1use std::{
2 collections::HashMap,
3 ops::Deref,
4 sync::{
5 atomic::{self, AtomicBool, AtomicU64, AtomicUsize},
6 Arc, OnceLock,
7 },
8};
9
10use asteroid_mq_model::codec::BINCODE_CONFIG;
11use openraft::{error::Unreachable, raft::ClientWriteResponse, Raft, RaftNetworkFactory};
12use serde::{Deserialize, Serialize};
13use tokio::{
14 io::{AsyncReadExt, AsyncWriteExt},
15 net::{TcpListener, TcpStream},
16 sync::oneshot,
17};
18use tokio_util::sync::CancellationToken;
19use tracing::{instrument, Instrument};
20
21use crate::{
22 prelude::NodeId,
23 protocol::node::raft::{network::TcpNetwork, TypeConfig},
24};
25
26use super::{
27 network::{Packet, Payload, Request, Response},
28 proposal::Proposal,
29 raft_node::TcpNode,
30 MaybeLoadingRaft,
31};
32#[derive(Clone, Debug)]
33pub struct TcpNetworkService {
34 pub info: RaftNodeInfo,
35 pub raft: MaybeLoadingRaft,
36 pub service_api: Arc<OnceLock<tokio::sync::mpsc::UnboundedSender<TcpNetworkServiceRequest>>>,
37 pub ct: CancellationToken,
38}
39
/// Initial capacity (16 MiB) of each connection's frame read buffer.
const DEFAULT_BUFFER_SIZE: usize = 16 << 20;
47#[derive(Debug)]
48
49pub struct GetConnection {
50 peer_id: NodeId,
51 responder: oneshot::Sender<Option<Arc<RaftTcpConnection>>>,
52}
53#[derive(Debug)]
54pub struct EnsureConnection {
55 peer_id: NodeId,
56 peer_addr: String,
57 responder: oneshot::Sender<Arc<RaftTcpConnection>>,
58}
59#[derive(Debug)]
60pub enum TcpNetworkServiceRequest {
61 GetConnection(GetConnection),
62 EnsureConnection(EnsureConnection),
63}
64
65impl TcpNetworkService {
66 pub async fn get_connection(&self, peer_id: NodeId) -> Option<Arc<RaftTcpConnection>> {
67 let sender = self.service_api.get()?;
68 let (responder, receiver) = oneshot::channel();
69 let get_connection = GetConnection { peer_id, responder };
70 let _ = sender
71 .send(TcpNetworkServiceRequest::GetConnection(get_connection))
72 .inspect_err(|_| {
73 tracing::error!("service not running");
74 });
75 receiver.await.ok().flatten()
76 }
77 #[instrument(skip_all, fields(local=%self.info.id, peer=%peer_id))]
78 pub async fn ensure_connection(
79 &self,
80 peer_id: NodeId,
81 peer_addr: String,
82 ) -> std::io::Result<Arc<RaftTcpConnection>> {
83 let Some(sender) = self.service_api.get() else {
84 return Err(std::io::Error::new(
85 std::io::ErrorKind::NotConnected,
86 "service not running",
87 ));
88 };
89 let (responder, receiver) = oneshot::channel();
90 let ensure_connection = EnsureConnection {
91 peer_id,
92 peer_addr,
93 responder,
94 };
95
96 sender
97 .send(TcpNetworkServiceRequest::EnsureConnection(
98 ensure_connection,
99 ))
100 .map_err(|_| {
101 std::io::Error::new(std::io::ErrorKind::NotConnected, "service not running")
102 })?;
103 let connection = receiver.await.map_err(|_| {
104 std::io::Error::new(std::io::ErrorKind::NotConnected, "service not running")
105 });
106 tracing::trace!(?connection, "response received");
107 connection
108 }
109 pub fn run_service(&self) {
110 {
111 let tcp_service = self.clone();
112 let info = self.info.clone();
113 let create_task = move || {
114 let ct = tcp_service.ct.clone();
115 let (ensure_connection_tx, mut ensure_connection_rx) =
116 tokio::sync::mpsc::unbounded_channel();
117 tokio::spawn(
118 async move {
119 tracing::info!(?info, "tcp service started");
120 let this_id = info.id;
121 let inner_task = async move {
122 let tcp_listener = TcpListener::bind(info.node.addr).await?;
123 let mut connection_map: HashMap<NodeId, Arc<RaftTcpConnection>> = HashMap::new();
124 let mut ensure_waiting_queue:HashMap<NodeId, Vec<oneshot::Sender<Arc<RaftTcpConnection>>>> = HashMap::new();
125 enum SelectEvent {
127 Accepted(TcpStream),
128 Request(TcpNetworkServiceRequest),
129 }
130 loop {
131 let event: SelectEvent = tokio::select! {
132 _ = ct.cancelled() => {
133 return Ok(());
134 }
135 accepted = tcp_listener.accept() => {
136 let Ok((stream, _)) = accepted else {
137 continue;
138 };
139 tracing::info!(local=%this_id, "tcp connection accepted");
140 SelectEvent::Accepted(stream)
141 }
142 ensure_connection_req = ensure_connection_rx.recv() => {
143 if let Some(ensure_connection_req) = ensure_connection_req {
144 SelectEvent::Request(ensure_connection_req)
145 } else {
146 return Ok(());
147 }
148 }
149 };
150 match event {
151 SelectEvent::Accepted(stream) => {
152 if let Ok(connection) =
153 RaftTcpConnection::from_tokio_tcp_stream(
154 stream,
155 tcp_service.clone(),
156 )
157 .await.inspect_err(|e| {
158 tracing::error!(%e, "tcp connection error");
159 })
160 {
161 let peer_id = connection.peer_id();
162 tracing::info!(local=%this_id, peer=%peer_id, "tcp connection established");
163 if let Some(connection) = connection_map.get(&peer_id) {
164 if connection.is_alive() {
165 if let Some(waiting) = ensure_waiting_queue.remove(&peer_id) {
166 for responder in waiting {
167 let _ = responder.send(connection.clone());
168 }
169 }
170 tracing::trace!(local=%this_id, peer=%peer_id, "connection exists");
171 continue;
172 }
173 }
174 let connection = Arc::new(connection);
175 connection_map.insert(peer_id, connection.clone());
176 if let Some(waiting) = ensure_waiting_queue.remove(&peer_id) {
177 for responder in waiting {
178 let _ = responder.send(connection.clone());
179 }
180 }
181 tracing::info!(local=%this_id, peer=%peer_id, "connection stored");
182 }
183 }
184 SelectEvent::Request(request) => {
185 match request {
186 TcpNetworkServiceRequest::GetConnection(get_connection) => {
187 let GetConnection {
188 peer_id,
189 responder,
190 } = get_connection;
191 let connection = connection_map.get(&peer_id).cloned();
192 let _ = responder.send(connection);
193 },
194 TcpNetworkServiceRequest::EnsureConnection(ensure_connection) => {
195 static REQ_ID: AtomicUsize = AtomicUsize::new(0);
196 let req_id = REQ_ID.fetch_add(1, atomic::Ordering::Relaxed);
197
198 let EnsureConnection {
199 peer_id,
200 peer_addr,
201 responder,
202 } = ensure_connection;
203 if let Some(connection) = connection_map.get(&peer_id) {
204 if connection.is_alive() {
205 tracing::trace!(req_id, local=%this_id, peer=%peer_id, "connection exists");
206 let _ = responder.send(connection.clone());
207 continue;
208 }
209 }
210 match peer_id.cmp(&info.id) {
212 std::cmp::Ordering::Less => {
213 tracing::info!(local=%this_id, peer=%peer_id, "waiting for connection({req_id})");
215 ensure_waiting_queue.entry(peer_id).or_default().push(responder);
216 },
217 std::cmp::Ordering::Equal => {
218 panic!("self connection is not allowed");
220 },
221 std::cmp::Ordering::Greater => {
222 ensure_waiting_queue.entry(peer_id).or_default().push(responder);
223 let create = async {
224 tracing::info!(req_id, local=%this_id, peer=%peer_id, %peer_addr, "tcp connecting");
225 let stream = TcpStream::connect(&peer_addr).await?;
226 tracing::info!(req_id, local=%this_id, peer=%peer_id, %peer_addr, "tcp connected");
227 let connection =
228 RaftTcpConnection::from_tokio_tcp_stream(
229 stream,
230 tcp_service.clone(),
231 )
232 .await?;
233 tracing::info!(req_id, local=%this_id, peer=%peer_id, %peer_addr, "connected established");
234 <Result<Arc<RaftTcpConnection>, std::io::Error>>::Ok(
235 Arc::new(connection),
236 )
237 };
238 let result = create.await
239 .inspect_err(|e| {
240 tracing::error!(req_id, local=%this_id, peer=%peer_id, %peer_addr, %e, "tcp connection error");
241 });
242
243 if let Ok(connection) = result {
244 connection_map.insert(peer_id, connection.clone());
245 if let Some(waiting) = ensure_waiting_queue.remove(&peer_id) {
246 for responder in waiting {
247 let _ = responder.send(connection.clone());
248 }
249 }
250 } else {
251 ensure_waiting_queue.remove(&peer_id);
253 }
254 },
255 }
256 },
257 }
258 }
259 }
260 }
261 #[allow(unreachable_code)]
262 std::io::Result::Ok(())
263 };
264 if let Err(e) = inner_task.await {
265 tracing::error!(?e, "tcp service error");
266 };
267 }
268 .instrument(tracing::span!(
269 tracing::Level::INFO,
270 "tcp_network_service",
271 )),
272 );
273 ensure_connection_tx
274 };
275 self.service_api.get_or_init(create_task);
276 }
277 }
278}
279
280#[derive(Clone, Default, Debug)]
281pub struct RaftTcpConnectionMap {
282 map: Arc<tokio::sync::RwLock<HashMap<NodeId, Arc<RaftTcpConnection>>>>,
283}
284
285impl Deref for RaftTcpConnectionMap {
286 type Target = Arc<tokio::sync::RwLock<HashMap<NodeId, Arc<RaftTcpConnection>>>>;
287
288 fn deref(&self) -> &Self::Target {
289 &self.map
290 }
291}
292#[derive(Debug)]
293pub struct RaftTcpConnection {
294 peer: RaftNodeInfo,
295 packet_tx: tokio::sync::mpsc::Sender<Packet>,
296 wait_poll: Arc<tokio::sync::Mutex<HashMap<u64, oneshot::Sender<Response>>>>,
297 local_seq: Arc<AtomicU64>,
298 alive: Arc<AtomicBool>,
299 ct: CancellationToken,
300}
301
302impl Drop for RaftTcpConnection {
303 fn drop(&mut self) {
304 let peer = self.peer.id;
305 tracing::info!(%peer, "connection dropped");
306 self.ct.cancel();
307 self.alive.store(false, atomic::Ordering::Relaxed);
308 }
309}
310#[derive(Clone, Debug, Serialize, Deserialize)]
311pub struct RaftNodeInfo {
312 pub id: NodeId,
313 pub node: TcpNode,
314}
315impl RaftTcpConnection {
316 pub fn is_alive(&self) -> bool {
317 self.alive.load(atomic::Ordering::Relaxed)
318 }
319 pub fn peer_id(&self) -> NodeId {
320 self.peer.id
321 }
322 pub fn peer_node(&self) -> &TcpNode {
323 &self.peer.node
324 }
325 fn next_seq(&self) -> u64 {
326 self.local_seq.fetch_add(1, atomic::Ordering::Relaxed)
327 }
328 pub(crate) async fn propose(
329 &self,
330 proposal: Proposal,
331 ) -> crate::Result<ClientWriteResponse<TypeConfig>> {
332 let req = Request::Proposal(proposal);
333 let resp = self
334 .send_request(req)
335 .await
336 .map_err(crate::Error::contextual_custom(
337 "sending proposal to remote",
338 ))?;
339 let resp = resp.await.map_err(crate::Error::contextual_custom(
340 "waiting for proposal response",
341 ))?;
342 let Response::Proposal(resp) = resp else {
343 return Err(crate::Error::unknown("unexpected response"));
344 };
345 let resp = resp.map_err(crate::Error::contextual("remote proposal"))?;
346 Ok(resp)
347 }
348 pub(super) async fn send_request(
349 &self,
350 req: Request,
351 ) -> Result<oneshot::Receiver<Response>, Unreachable> {
352 tracing::trace!(?req, "send request");
353 let payload = Payload::Request(req);
354 let seq_id = self.next_seq();
355 let packet = Packet { seq_id, payload };
356 let (sender, receiver) = tokio::sync::oneshot::channel();
357
358 self.wait_poll.lock().await.insert(seq_id, sender);
359 self.packet_tx
360 .send(packet)
361 .await
362 .inspect_err(|_| {
363 let pool = self.wait_poll.clone();
364 tokio::spawn(async move {
365 pool.lock().await.remove(&seq_id);
366 });
367 })
368 .map_err(|e| Unreachable::new(&e))?;
369 Ok(receiver)
370 }
371 pub async fn from_tokio_tcp_stream(
372 mut stream: TcpStream,
373 service: TcpNetworkService,
374 ) -> std::io::Result<Self> {
375 let connection_ct = service.ct.child_token();
376 let info = service.info.clone();
377 let pending_raft = service.raft.clone();
378 let local_id = info.id;
379 let packet = bincode::serde::encode_to_vec(&info, BINCODE_CONFIG)
380 .map_err(|_| std::io::ErrorKind::InvalidData)?;
381 stream.write_u32(packet.len() as u32).await?;
382 stream.write_all(&packet).await?;
383 let hello_size = stream.read_u32().await?;
384 let mut hello_data = vec![0; hello_size as usize];
385 stream.read_exact(&mut hello_data).await?;
386 let peer: RaftNodeInfo = bincode::serde::decode_from_slice(&hello_data, BINCODE_CONFIG)
387 .map_err(|_| std::io::ErrorKind::InvalidData)?
388 .0;
389 let peer_id = peer.id;
390 tracing::info!(peer=%peer_id, local=%local_id, "hello received");
391 let (mut read, mut write) = stream.into_split();
392 let wait_pool = Arc::new(tokio::sync::Mutex::new(HashMap::<
393 u64,
394 oneshot::Sender<Response>,
395 >::new()));
396 let wait_poll_clone = wait_pool.clone();
397 let (packet_tx, mut packet_rx) = tokio::sync::mpsc::channel::<Packet>(512);
398 let write_task_ct = connection_ct.child_token();
399 let _write_task = tokio::spawn(
400 async move {
401 let write_loop = async {
402 loop {
403 let packet = tokio::select! {
404 _ = write_task_ct.cancelled() => {
405 return std::io::Result::<()>::Ok(());
406 }
407 maybe_packet = packet_rx.recv() => {
408 match maybe_packet {
409 None => {
410 return std::io::Result::<()>::Ok(());
411 }
412 Some(packet) => packet,
413 }
414 }
415 };
416 let bytes = bincode::serde::encode_to_vec(&packet.payload, BINCODE_CONFIG)
417 .expect("should be valid for bincode");
418 write.write_u64(packet.seq_id).await?;
419 write.write_u32(bytes.len() as u32).await?;
420 write.write_all(&bytes).await?;
421 write.flush().await?;
422 tracing::trace!(?packet, "flushed");
423 }
424 };
425 match write_loop.await {
426 Ok(_) => {}
427 Err(e) => {
428 tracing::error!(%e, "write loop error");
429 }
430 }
431 }
432 .instrument(tracing::span!(
433 tracing::Level::INFO,
434 "write_loop",
435 ?info,
436 ?peer
437 )),
438 );
439 let alive = Arc::new(AtomicBool::new(true));
440 let read_task_ct = connection_ct.child_token();
441 let _read_task = {
442 let packet_tx = packet_tx.clone();
443 let alive = alive.clone();
444 let inner_task = async move {
445 let mut buffer = Vec::with_capacity(DEFAULT_BUFFER_SIZE);
446 loop {
447 let seq_id = tokio::select! {
448 seq_id = read.read_u64() => {
449 seq_id
450 }
451 _ = read_task_ct.cancelled() => {
452 return Ok(())
453 }
454 };
455 let seq_id = seq_id?;
456 let len = read.read_u32().await? as usize;
457 if len > buffer.capacity() {
458 buffer.reserve(len - buffer.capacity());
459 }
460 buffer.resize(len, 0);
461 let data = &mut buffer[..len];
465 read.read_exact(data).await?;
466 let Ok((payload, _)) =
467 bincode::serde::decode_from_slice::<Payload, _>(data, BINCODE_CONFIG)
468 .inspect_err(|e| {
469 tracing::error!(?e);
470 })
471 else {
472 continue;
473 };
474 tracing::trace!(?seq_id, ?payload, "received");
475 match payload {
476 Payload::Request(req) => {
477 let pending_raft = pending_raft.clone();
478 let packet_tx = packet_tx.clone();
479 tokio::spawn(
480 async move {
481 let raft = pending_raft.get().await;
482 let resp = match req {
483 Request::Vote(vote) => {
484 Response::Vote(raft.vote(vote).await)
485 }
486 Request::AppendEntries(append) => {
487 Response::AppendEntries(raft.append_entries(append).await)
488 },
489 Request::InstallSnapshot(install) => {
490 let offset = install.offset;
491 let installed = asteroid_mq_model::MemUnit(offset as usize);
492 let size = asteroid_mq_model::MemUnit(install.data.len());
493 let done = install.done;
494
495 tracing::info!({offset, %installed, %size, done}, "installing snapshot");
496 Response::InstallSnapshot(
497 raft.install_snapshot(install).await,
498 )
499 }
500 Request::Proposal(proposal) => {
501 Response::Proposal(raft.client_write(proposal).await)
502 }
503 };
504 if let Some(fatal) = resp.catch_fatal() {
505 tracing::error!(?fatal, "⚠⚠⚠ FATAL ⚠⚠⚠");
506 raft.shutdown().await.expect("join error when shutting down raft node");
508 std::process::exit(1);
509 };
510 let payload = Payload::Response(resp);
511 let _ = packet_tx.send(Packet { seq_id, payload }).await;
512 }
513 .instrument(tracing::span!(
514 tracing::Level::INFO,
515 "tcp_request_handler",
516 )),
517 );
518 }
519 Payload::Response(resp) => {
520 let sender = wait_poll_clone.lock().await.remove(&seq_id);
521 if let Some(sender) = sender {
522 let _result = sender.send(resp);
523 } else {
524 tracing::warn!(?seq_id, "responder not found");
525 }
526 }
527 }
528 }
529 };
530 tokio::spawn(
531 async move {
532 let result: std::io::Result<()> = inner_task.await;
533 if let Err(e) = result {
534 tracing::error!(%e, "read task error");
535 }
536 alive.store(false, atomic::Ordering::Relaxed);
537 }
538 .instrument(tracing::span!(
539 tracing::Level::INFO,
540 "tcp_read_loop",
541 local=%local_id,
542 peer=%peer_id
543 )),
544 )
545 };
546 Ok(Self {
547 packet_tx,
548 wait_poll: wait_pool,
549 peer,
550 local_seq: Arc::new(AtomicU64::new(0)),
551 alive,
552 ct: connection_ct,
553 })
554 }
555}
556
557impl TcpNetworkService {
558 pub fn new(info: RaftNodeInfo, raft: MaybeLoadingRaft, ct: CancellationToken) -> Self {
559 Self {
560 info,
561 raft,
562 service_api: Arc::new(OnceLock::new()),
563 ct,
564 }
565 }
566 pub fn set_raft(&self, raft: Raft<TypeConfig>) {
567 self.raft.set(raft);
568 }
569}
570
571impl RaftNetworkFactory<TypeConfig> for TcpNetworkService {
572 type Network = TcpNetwork;
573 async fn new_client(
574 &mut self,
575 target: <TypeConfig as openraft::RaftTypeConfig>::NodeId,
576 node: &<TypeConfig as openraft::RaftTypeConfig>::Node,
577 ) -> Self::Network {
578 TcpNetwork::new(
579 RaftNodeInfo {
580 id: target,
581 node: node.clone(),
582 },
583 self.clone(),
584 )
585 }
586}
/// Runs openraft's storage test suite against this crate's log and state-machine stores.
#[cfg(test)]
#[test]
fn test_mem() {
    use crate::protocol::node::{LogStorage, StateMachineStore};

    // `init` panics when a global subscriber is already installed (e.g. by another test
    // in the same test binary); `try_init` makes that case harmless.
    let _ = tracing_subscriber::fmt()
        .with_max_level(tracing::Level::INFO)
        .try_init();

    /// Supplies the storage handed to `openraft::testing::Suite`.
    pub struct MemStore {}
    impl openraft::testing::StoreBuilder<TypeConfig, LogStorage, Arc<StateMachineStore>> for MemStore {
        async fn build(
            &self,
        ) -> Result<((), LogStorage, Arc<StateMachineStore>), openraft::StorageError<NodeId>>
        {
            Ok((
                (),
                LogStorage::default(),
                // SAFETY: NOTE(review): relies on the contract of `new_uninitialized`
                // (defined elsewhere) being satisfied by how the suite uses the store;
                // confirm against its documentation.
                Arc::new(unsafe { StateMachineStore::new_uninitialized() }),
            ))
        }
    }
    openraft::testing::Suite::test_all(MemStore {}).unwrap();
}
609}