Skip to main content

asteroid_mq/protocol/node/raft/
network_factory.rs

1use std::{
2    collections::HashMap,
3    ops::Deref,
4    sync::{
5        atomic::{self, AtomicBool, AtomicU64, AtomicUsize},
6        Arc, OnceLock,
7    },
8};
9
10use asteroid_mq_model::codec::BINCODE_CONFIG;
11use openraft::{error::Unreachable, raft::ClientWriteResponse, Raft, RaftNetworkFactory};
12use serde::{Deserialize, Serialize};
13use tokio::{
14    io::{AsyncReadExt, AsyncWriteExt},
15    net::{TcpListener, TcpStream},
16    sync::oneshot,
17};
18use tokio_util::sync::CancellationToken;
19use tracing::{instrument, Instrument};
20
21use crate::{
22    prelude::NodeId,
23    protocol::node::raft::{network::TcpNetwork, TypeConfig},
24};
25
26use super::{
27    network::{Packet, Payload, Request, Response},
28    proposal::Proposal,
29    raft_node::TcpNode,
30    MaybeLoadingRaft,
31};
32#[derive(Clone, Debug)]
33pub struct TcpNetworkService {
34    pub info: RaftNodeInfo,
35    pub raft: MaybeLoadingRaft,
36    pub service_api: Arc<OnceLock<tokio::sync::mpsc::UnboundedSender<TcpNetworkServiceRequest>>>,
37    pub ct: CancellationToken,
38}
39
// TODO: Expose the API rather than the service itself
// pub struct TcpNetworkServiceApi {
//     api: tokio::sync::mpsc::UnboundedSender<TcpNetworkServiceRequest>,
// }
44
/// Initial capacity, in bytes, of each connection's receive buffer (16 MiB).
/// The buffer still grows when a larger frame arrives.
const DEFAULT_BUFFER_SIZE: usize = 16 << 20;
47#[derive(Debug)]
48
49pub struct GetConnection {
50    peer_id: NodeId,
51    responder: oneshot::Sender<Option<Arc<RaftTcpConnection>>>,
52}
53#[derive(Debug)]
54pub struct EnsureConnection {
55    peer_id: NodeId,
56    peer_addr: String,
57    responder: oneshot::Sender<Arc<RaftTcpConnection>>,
58}
59#[derive(Debug)]
60pub enum TcpNetworkServiceRequest {
61    GetConnection(GetConnection),
62    EnsureConnection(EnsureConnection),
63}
64
65impl TcpNetworkService {
66    pub async fn get_connection(&self, peer_id: NodeId) -> Option<Arc<RaftTcpConnection>> {
67        let sender = self.service_api.get()?;
68        let (responder, receiver) = oneshot::channel();
69        let get_connection = GetConnection { peer_id, responder };
70        let _ = sender
71            .send(TcpNetworkServiceRequest::GetConnection(get_connection))
72            .inspect_err(|_| {
73                tracing::error!("service not running");
74            });
75        receiver.await.ok().flatten()
76    }
77    #[instrument(skip_all, fields(local=%self.info.id, peer=%peer_id))]
78    pub async fn ensure_connection(
79        &self,
80        peer_id: NodeId,
81        peer_addr: String,
82    ) -> std::io::Result<Arc<RaftTcpConnection>> {
83        let Some(sender) = self.service_api.get() else {
84            return Err(std::io::Error::new(
85                std::io::ErrorKind::NotConnected,
86                "service not running",
87            ));
88        };
89        let (responder, receiver) = oneshot::channel();
90        let ensure_connection = EnsureConnection {
91            peer_id,
92            peer_addr,
93            responder,
94        };
95
96        sender
97            .send(TcpNetworkServiceRequest::EnsureConnection(
98                ensure_connection,
99            ))
100            .map_err(|_| {
101                std::io::Error::new(std::io::ErrorKind::NotConnected, "service not running")
102            })?;
103        let connection = receiver.await.map_err(|_| {
104            std::io::Error::new(std::io::ErrorKind::NotConnected, "service not running")
105        });
106        tracing::trace!(?connection, "response received");
107        connection
108    }
109    pub fn run_service(&self) {
110        {
111            let tcp_service = self.clone();
112            let info = self.info.clone();
113            let create_task = move || {
114                let ct = tcp_service.ct.clone();
115                let (ensure_connection_tx, mut ensure_connection_rx) =
116                    tokio::sync::mpsc::unbounded_channel();
117                tokio::spawn(
118                    async move {
119                        tracing::info!(?info, "tcp service started");
120                        let this_id = info.id;
121                        let inner_task = async move {
122                            let tcp_listener = TcpListener::bind(info.node.addr).await?;
123                            let mut connection_map: HashMap<NodeId, Arc<RaftTcpConnection>> = HashMap::new();
124                            let mut ensure_waiting_queue:HashMap<NodeId, Vec<oneshot::Sender<Arc<RaftTcpConnection>>>>  = HashMap::new();
125                            // let pending_connection: Arc<Mutex<HashMap::<NodeId, Arc<Notify>>>> = Default::default();
126                            enum SelectEvent {
127                                Accepted(TcpStream),
128                                Request(TcpNetworkServiceRequest),
129                            }
130                            loop {
131                                let event: SelectEvent = tokio::select! {
132                                    _ = ct.cancelled() => {
133                                        return Ok(());
134                                    }
135                                    accepted = tcp_listener.accept() => {
136                                        let Ok((stream, _)) = accepted else {
137                                            continue;
138                                        };
139                                        tracing::info!(local=%this_id, "tcp connection accepted");
140                                        SelectEvent::Accepted(stream)
141                                    }
142                                    ensure_connection_req = ensure_connection_rx.recv() => {
143                                        if let Some(ensure_connection_req) = ensure_connection_req {
144                                            SelectEvent::Request(ensure_connection_req)
145                                        } else {
146                                            return Ok(());
147                                        }
148                                    }
149                                };
150                                match event {
151                                    SelectEvent::Accepted(stream) => {
152                                        if let Ok(connection) =
153                                                RaftTcpConnection::from_tokio_tcp_stream(
154                                                    stream,
155                                                    tcp_service.clone(),
156                                                )
157                                                .await.inspect_err(|e| {
158                                                    tracing::error!(%e, "tcp connection error");
159                                                })
160                                            {
161                                                let peer_id = connection.peer_id();
162                                                tracing::info!(local=%this_id, peer=%peer_id, "tcp connection established");
163                                                if let Some(connection) = connection_map.get(&peer_id) {
164                                                    if connection.is_alive() {
165                                                        if let Some(waiting) = ensure_waiting_queue.remove(&peer_id) {
166                                                            for responder in waiting {
167                                                                let _ = responder.send(connection.clone());
168                                                            }
169                                                        }
170                                                        tracing::trace!(local=%this_id, peer=%peer_id, "connection exists");
171                                                        continue;
172                                                    }
173                                                }
174                                                let connection = Arc::new(connection);
175                                                connection_map.insert(peer_id, connection.clone());
176                                                if let Some(waiting) = ensure_waiting_queue.remove(&peer_id) {
177                                                    for responder in waiting {
178                                                        let _ = responder.send(connection.clone());
179                                                    }
180                                                }
181                                                tracing::info!(local=%this_id, peer=%peer_id, "connection stored");
182                                            }
183                                    }
184                                    SelectEvent::Request(request) => {
185                                        match request {
186                                            TcpNetworkServiceRequest::GetConnection(get_connection) => {
187                                                let GetConnection {
188                                                    peer_id,
189                                                    responder,
190                                                } = get_connection;
191                                                let connection = connection_map.get(&peer_id).cloned();
192                                                let _ = responder.send(connection);
193                                            },
194                                            TcpNetworkServiceRequest::EnsureConnection(ensure_connection) => {
195                                                static REQ_ID: AtomicUsize = AtomicUsize::new(0);
196                                                let req_id = REQ_ID.fetch_add(1, atomic::Ordering::Relaxed);
197
198                                                let EnsureConnection {
199                                                    peer_id,
200                                                    peer_addr,
201                                                    responder,
202                                                } = ensure_connection;
203                                                if let Some(connection) = connection_map.get(&peer_id) {
204                                                    if connection.is_alive() {
205                                                        tracing::trace!(req_id, local=%this_id, peer=%peer_id, "connection exists");
206                                                        let _ = responder.send(connection.clone());
207                                                        continue;
208                                                    }
209                                                }
210                                                // compare id 
211                                                match peer_id.cmp(&info.id) {
212                                                    std::cmp::Ordering::Less => {
213                                                        // just wait for connection
214                                                        tracing::info!(local=%this_id, peer=%peer_id, "waiting for connection({req_id})");
215                                                        ensure_waiting_queue.entry(peer_id).or_default().push(responder);
216                                                    },
217                                                    std::cmp::Ordering::Equal => {
218                                                        // self connection is not allowed
219                                                        panic!("self connection is not allowed");
220                                                    },
221                                                    std::cmp::Ordering::Greater => {
222                                                        ensure_waiting_queue.entry(peer_id).or_default().push(responder);
223                                                        let create = async {
224                                                            tracing::info!(req_id, local=%this_id, peer=%peer_id, %peer_addr, "tcp connecting");
225                                                            let stream = TcpStream::connect(&peer_addr).await?;
226                                                            tracing::info!(req_id, local=%this_id, peer=%peer_id, %peer_addr, "tcp connected");
227                                                            let connection =
228                                                                RaftTcpConnection::from_tokio_tcp_stream(
229                                                                    stream,
230                                                                    tcp_service.clone(),
231                                                                )
232                                                                .await?;
233                                                            tracing::info!(req_id, local=%this_id, peer=%peer_id, %peer_addr, "connected established");
234                                                            <Result<Arc<RaftTcpConnection>, std::io::Error>>::Ok(
235                                                                Arc::new(connection),
236                                                            )
237                                                        };
238                                                        let result = create.await
239                                                        .inspect_err(|e| {
240                                                            tracing::error!(req_id, local=%this_id, peer=%peer_id, %peer_addr, %e, "tcp connection error");
241                                                        });
242
243                                                        if let Ok(connection) = result {
244                                                            connection_map.insert(peer_id, connection.clone());
245                                                            if let Some(waiting) = ensure_waiting_queue.remove(&peer_id) {
246                                                                for responder in waiting {
247                                                                    let _ = responder.send(connection.clone());
248                                                                }
249                                                            }
250                                                        } else {
251                                                            // drop waiting handles
252                                                            ensure_waiting_queue.remove(&peer_id);
253                                                        }
254                                                    },
255                                                }
256                                            },
257                                        }
258                                    }
259                                }
260                            }
261                            #[allow(unreachable_code)]
262                            std::io::Result::Ok(())
263                        };
264                        if let Err(e) = inner_task.await {
265                            tracing::error!(?e, "tcp service error");
266                        };
267                    }
268                    .instrument(tracing::span!(
269                        tracing::Level::INFO,
270                        "tcp_network_service",
271                    )),
272                );
273                ensure_connection_tx
274            };
275            self.service_api.get_or_init(create_task);
276        }
277    }
278}
279
280#[derive(Clone, Default, Debug)]
281pub struct RaftTcpConnectionMap {
282    map: Arc<tokio::sync::RwLock<HashMap<NodeId, Arc<RaftTcpConnection>>>>,
283}
284
285impl Deref for RaftTcpConnectionMap {
286    type Target = Arc<tokio::sync::RwLock<HashMap<NodeId, Arc<RaftTcpConnection>>>>;
287
288    fn deref(&self) -> &Self::Target {
289        &self.map
290    }
291}
292#[derive(Debug)]
293pub struct RaftTcpConnection {
294    peer: RaftNodeInfo,
295    packet_tx: tokio::sync::mpsc::Sender<Packet>,
296    wait_poll: Arc<tokio::sync::Mutex<HashMap<u64, oneshot::Sender<Response>>>>,
297    local_seq: Arc<AtomicU64>,
298    alive: Arc<AtomicBool>,
299    ct: CancellationToken,
300}
301
302impl Drop for RaftTcpConnection {
303    fn drop(&mut self) {
304        let peer = self.peer.id;
305        tracing::info!(%peer, "connection dropped");
306        self.ct.cancel();
307        self.alive.store(false, atomic::Ordering::Relaxed);
308    }
309}
310#[derive(Clone, Debug, Serialize, Deserialize)]
311pub struct RaftNodeInfo {
312    pub id: NodeId,
313    pub node: TcpNode,
314}
315impl RaftTcpConnection {
316    pub fn is_alive(&self) -> bool {
317        self.alive.load(atomic::Ordering::Relaxed)
318    }
319    pub fn peer_id(&self) -> NodeId {
320        self.peer.id
321    }
322    pub fn peer_node(&self) -> &TcpNode {
323        &self.peer.node
324    }
325    fn next_seq(&self) -> u64 {
326        self.local_seq.fetch_add(1, atomic::Ordering::Relaxed)
327    }
328    pub(crate) async fn propose(
329        &self,
330        proposal: Proposal,
331    ) -> crate::Result<ClientWriteResponse<TypeConfig>> {
332        let req = Request::Proposal(proposal);
333        let resp = self
334            .send_request(req)
335            .await
336            .map_err(crate::Error::contextual_custom(
337                "sending proposal to remote",
338            ))?;
339        let resp = resp.await.map_err(crate::Error::contextual_custom(
340            "waiting for proposal response",
341        ))?;
342        let Response::Proposal(resp) = resp else {
343            return Err(crate::Error::unknown("unexpected response"));
344        };
345        let resp = resp.map_err(crate::Error::contextual("remote proposal"))?;
346        Ok(resp)
347    }
348    pub(super) async fn send_request(
349        &self,
350        req: Request,
351    ) -> Result<oneshot::Receiver<Response>, Unreachable> {
352        tracing::trace!(?req, "send request");
353        let payload = Payload::Request(req);
354        let seq_id = self.next_seq();
355        let packet = Packet { seq_id, payload };
356        let (sender, receiver) = tokio::sync::oneshot::channel();
357
358        self.wait_poll.lock().await.insert(seq_id, sender);
359        self.packet_tx
360            .send(packet)
361            .await
362            .inspect_err(|_| {
363                let pool = self.wait_poll.clone();
364                tokio::spawn(async move {
365                    pool.lock().await.remove(&seq_id);
366                });
367            })
368            .map_err(|e| Unreachable::new(&e))?;
369        Ok(receiver)
370    }
371    pub async fn from_tokio_tcp_stream(
372        mut stream: TcpStream,
373        service: TcpNetworkService,
374    ) -> std::io::Result<Self> {
375        let connection_ct = service.ct.child_token();
376        let info = service.info.clone();
377        let pending_raft = service.raft.clone();
378        let local_id = info.id;
379        let packet = bincode::serde::encode_to_vec(&info, BINCODE_CONFIG)
380            .map_err(|_| std::io::ErrorKind::InvalidData)?;
381        stream.write_u32(packet.len() as u32).await?;
382        stream.write_all(&packet).await?;
383        let hello_size = stream.read_u32().await?;
384        let mut hello_data = vec![0; hello_size as usize];
385        stream.read_exact(&mut hello_data).await?;
386        let peer: RaftNodeInfo = bincode::serde::decode_from_slice(&hello_data, BINCODE_CONFIG)
387            .map_err(|_| std::io::ErrorKind::InvalidData)?
388            .0;
389        let peer_id = peer.id;
390        tracing::info!(peer=%peer_id, local=%local_id, "hello received");
391        let (mut read, mut write) = stream.into_split();
392        let wait_pool = Arc::new(tokio::sync::Mutex::new(HashMap::<
393            u64,
394            oneshot::Sender<Response>,
395        >::new()));
396        let wait_poll_clone = wait_pool.clone();
397        let (packet_tx, mut packet_rx) = tokio::sync::mpsc::channel::<Packet>(512);
398        let write_task_ct = connection_ct.child_token();
399        let _write_task = tokio::spawn(
400            async move {
401                let write_loop = async {
402                    loop {
403                        let packet = tokio::select! {
404                            _ = write_task_ct.cancelled() => {
405                                return std::io::Result::<()>::Ok(());
406                            }
407                            maybe_packet = packet_rx.recv() => {
408                                match maybe_packet {
409                                    None => {
410                                        return std::io::Result::<()>::Ok(());
411                                    }
412                                    Some(packet) => packet,
413                                }
414                            }
415                        };
416                        let bytes = bincode::serde::encode_to_vec(&packet.payload, BINCODE_CONFIG)
417                            .expect("should be valid for bincode");
418                        write.write_u64(packet.seq_id).await?;
419                        write.write_u32(bytes.len() as u32).await?;
420                        write.write_all(&bytes).await?;
421                        write.flush().await?;
422                        tracing::trace!(?packet, "flushed");
423                    }
424                };
425                match write_loop.await {
426                    Ok(_) => {}
427                    Err(e) => {
428                        tracing::error!(%e, "write loop error");
429                    }
430                }
431            }
432            .instrument(tracing::span!(
433                tracing::Level::INFO,
434                "write_loop",
435                ?info,
436                ?peer
437            )),
438        );
439        let alive = Arc::new(AtomicBool::new(true));
440        let read_task_ct = connection_ct.child_token();
441        let _read_task = {
442            let packet_tx = packet_tx.clone();
443            let alive = alive.clone();
444            let inner_task = async move {
445                let mut buffer = Vec::with_capacity(DEFAULT_BUFFER_SIZE);
446                loop {
447                    let seq_id = tokio::select! {
448                        seq_id = read.read_u64() => {
449                            seq_id
450                        }
451                        _ = read_task_ct.cancelled() => {
452                            return Ok(())
453                        }
454                    };
455                    let seq_id = seq_id?;
456                    let len = read.read_u32().await? as usize;
457                    if len > buffer.capacity() {
458                        buffer.reserve(len - buffer.capacity());
459                    }
460                    buffer.resize(len, 0);
461                    // unsafe {
462                    //     buffer.set_len(len);
463                    // }
464                    let data = &mut buffer[..len];
465                    read.read_exact(data).await?;
466                    let Ok((payload, _)) =
467                        bincode::serde::decode_from_slice::<Payload, _>(data, BINCODE_CONFIG)
468                            .inspect_err(|e| {
469                                tracing::error!(?e);
470                            })
471                    else {
472                        continue;
473                    };
474                    tracing::trace!(?seq_id, ?payload, "received");
475                    match payload {
476                        Payload::Request(req) => {
477                            let pending_raft = pending_raft.clone();
478                            let packet_tx = packet_tx.clone();
479                            tokio::spawn(
480                                async move {
481                                    let raft = pending_raft.get().await;
482                                    let resp = match req {
483                                        Request::Vote(vote) => {
484                                            Response::Vote(raft.vote(vote).await)
485                                        }
486                                        Request::AppendEntries(append) => {
487                                            Response::AppendEntries(raft.append_entries(append).await)
488                                        },
489                                        Request::InstallSnapshot(install) => {
490                                            let offset = install.offset;
491                                            let installed = asteroid_mq_model::MemUnit(offset as usize);
492                                            let size = asteroid_mq_model::MemUnit(install.data.len());
493                                            let done = install.done;
494
495                                            tracing::info!({offset, %installed, %size, done}, "installing snapshot");
496                                            Response::InstallSnapshot(
497                                                raft.install_snapshot(install).await,
498                                            )
499                                        }
500                                        Request::Proposal(proposal) => {
501                                            Response::Proposal(raft.client_write(proposal).await)
502                                        }
503                                    };
504                                    if let Some(fatal) = resp.catch_fatal() {
505                                        tracing::error!(?fatal, "⚠⚠⚠ FATAL ⚠⚠⚠");
506                                        // it's over
507                                        raft.shutdown().await.expect("join error when shutting down raft node");
508                                        std::process::exit(1);
509                                    };
510                                    let payload = Payload::Response(resp);
511                                    let _ = packet_tx.send(Packet { seq_id, payload }).await;
512                                }
513                                .instrument(tracing::span!(
514                                    tracing::Level::INFO,
515                                    "tcp_request_handler",
516                                )),
517                            );
518                        }
519                        Payload::Response(resp) => {
520                            let sender = wait_poll_clone.lock().await.remove(&seq_id);
521                            if let Some(sender) = sender {
522                                let _result = sender.send(resp);
523                            } else {
524                                tracing::warn!(?seq_id, "responder not found");
525                            }
526                        }
527                    }
528                }
529            };
530            tokio::spawn(
531                async move {
532                    let result: std::io::Result<()> = inner_task.await;
533                    if let Err(e) = result {
534                        tracing::error!(%e, "read task error");
535                    }
536                    alive.store(false, atomic::Ordering::Relaxed);
537                }
538                .instrument(tracing::span!(
539                    tracing::Level::INFO,
540                    "tcp_read_loop",
541                    local=%local_id,
542                    peer=%peer_id
543                )),
544            )
545        };
546        Ok(Self {
547            packet_tx,
548            wait_poll: wait_pool,
549            peer,
550            local_seq: Arc::new(AtomicU64::new(0)),
551            alive,
552            ct: connection_ct,
553        })
554    }
555}
556
557impl TcpNetworkService {
558    pub fn new(info: RaftNodeInfo, raft: MaybeLoadingRaft, ct: CancellationToken) -> Self {
559        Self {
560            info,
561            raft,
562            service_api: Arc::new(OnceLock::new()),
563            ct,
564        }
565    }
566    pub fn set_raft(&self, raft: Raft<TypeConfig>) {
567        self.raft.set(raft);
568    }
569}
570
571impl RaftNetworkFactory<TypeConfig> for TcpNetworkService {
572    type Network = TcpNetwork;
573    async fn new_client(
574        &mut self,
575        target: <TypeConfig as openraft::RaftTypeConfig>::NodeId,
576        node: &<TypeConfig as openraft::RaftTypeConfig>::Node,
577    ) -> Self::Network {
578        TcpNetwork::new(
579            RaftNodeInfo {
580                id: target,
581                node: node.clone(),
582            },
583            self.clone(),
584        )
585    }
586}
/// Runs openraft's storage test suite against the crate's log storage and
/// state machine store.
#[cfg(test)]
#[test]
fn test_mem() {
    use crate::protocol::node::{LogStorage, StateMachineStore};

    /// Builder handing a fresh storage pair to each suite case.
    pub struct MemStore {}
    impl openraft::testing::StoreBuilder<TypeConfig, LogStorage, Arc<StateMachineStore>> for MemStore {
        async fn build(
            &self,
        ) -> Result<((), LogStorage, Arc<StateMachineStore>), openraft::StorageError<NodeId>>
        {
            let log_storage = LogStorage::default();
            // NOTE(review): the safety contract of `new_uninitialized` is not
            // visible from this file — confirm it holds for the suite.
            let state_machine = Arc::new(unsafe { StateMachineStore::new_uninitialized() });
            Ok(((), log_storage, state_machine))
        }
    }

    tracing_subscriber::fmt()
        .with_max_level(tracing::Level::INFO)
        .init();
    openraft::testing::Suite::test_all(MemStore {}).unwrap();
}