1use tokio::sync::{mpsc, oneshot};
2
3use crate::{
4 util,
5 Error,
6 Message,
7 PeerHandle,
8 ReceivedMessage,
9 SentRequestHandle,
10};
11use crate::request_tracker::RequestTracker;
12use crate::util::{select, Either};
13
/// A command processed by the command loop of a [`Peer`].
///
/// Commands travel over an unbounded channel. Senders are held by the [`PeerHandle`],
/// by the request tracker and by the read loop; the command loop handles them in order.
pub enum Command<Body> {
    /// Allocate a request ID, write a request message and report the resulting handle.
    SendRequest(SendRequest<Body>),

    /// Write a pre-built message.
    ///
    /// If the message is a response, the matching received request is removed from the request tracker first.
    SendRawMessage(SendRawMessage<Body>),

    /// Handle a message (or read error) produced by the read loop.
    ProcessReceivedMessage(ProcessReceivedMessage<Body>),

    /// Stop the command loop without processing any further commands.
    Stop,

    /// Mark the read handle as dropped.
    UnregisterReadHandle,

    /// Increment the number of live write handles.
    RegisterWriteHandle,

    /// Decrement the number of live write handles.
    UnregisterWriteHandle,
}
24
/// A peer that drives a single transport.
///
/// The peer owns the transport and processes commands sent by its [`PeerHandle`].
/// Nothing happens until [`Peer::run`] is called (or the peer is created with [`Peer::spawn`]).
pub struct Peer<Transport: crate::transport::Transport> {
    /// The transport used to send and receive messages.
    transport: Transport,

    /// Tracker for requests sent and received through this peer.
    request_tracker: RequestTracker<Transport::Body>,

    /// The peer's own sender for the command channel.
    ///
    /// Keeping a sender here means the command channel cannot close while the command loop runs.
    /// It is also used to send `Command::Stop` when the read loop finishes first.
    command_tx: mpsc::UnboundedSender<Command<Transport::Body>>,

    /// Receiver for the command channel, consumed by the command loop.
    command_rx: mpsc::UnboundedReceiver<Command<Transport::Body>>,

    /// Sender used to deliver incoming messages and errors to the read side of the peer handle.
    incoming_tx: mpsc::UnboundedSender<Result<ReceivedMessage<Transport::Body>, Error>>,

    /// Number of live write handles.
    ///
    /// Starts at 1, presumably accounting for the handle returned by `Peer::new`.
    write_handles: usize,
}
58
59impl<Transport: crate::transport::Transport> Peer<Transport> {
60 pub fn new(transport: Transport) -> (Self, PeerHandle<Transport::Body>) {
73 let (incoming_tx, incoming_rx) = mpsc::unbounded_channel();
74 let (command_tx, command_rx) = mpsc::unbounded_channel();
75 let request_tracker = RequestTracker::new(command_tx.clone());
76
77 let peer = Self {
78 transport,
79 request_tracker,
80 command_tx: command_tx.clone(),
81 command_rx,
82 incoming_tx,
83 write_handles: 1,
84 };
85
86 let handle = PeerHandle::new(incoming_rx, command_tx);
87
88 (peer, handle)
89 }
90
91 pub fn spawn(transport: Transport) -> PeerHandle<Transport::Body> {
101 let (peer, handle) = Self::new(transport);
102 tokio::spawn(peer.run());
103 handle
104 }
105
106 pub async fn connect<'a, Address>(address: Address, config: Transport::Config) -> std::io::Result<(PeerHandle<Transport::Body>, Transport::Info)>
116 where
117 Address: 'a,
118 Transport: util::Connect<'a, Address>,
119 {
120 let transport = Transport::connect(address, config).await?;
121 let info = transport.info()?;
122 Ok((Self::spawn(transport), info))
123 }
124
125 pub async fn run(mut self) {
127 let Self {
128 transport,
129 request_tracker,
130 command_tx,
131 command_rx,
132 incoming_tx,
133 write_handles,
134 } = &mut self;
135
136 let (read_half, write_half) = transport.split();
137
138 let mut read_loop = ReadLoop {
139 read_half,
140 command_tx: command_tx.clone(),
141 };
142
143 let mut command_loop = CommandLoop {
144 write_half,
145 request_tracker,
146 command_rx,
147 incoming_tx,
148 read_handle_dropped: &mut false,
149 write_handles,
150 };
151
152 let read_loop = read_loop.run();
153 let command_loop = command_loop.run();
154
155 tokio::pin!(read_loop);
157 tokio::pin!(command_loop);
158
159 match select(read_loop, command_loop).await {
160 Either::Left(((), command_loop)) => {
161 command_tx
163 .send(Command::Stop)
164 .map_err(drop)
165 .expect("command loop did not stop yet but command channel is closed");
166 command_loop.await;
167 },
168 Either::Right((_read_loop, ())) => {
169 },
173 }
174 }
175
176 pub fn transport(&self) -> &Transport {
178 &self.transport
179 }
180
181 pub fn transport_mut(&mut self) -> &mut Transport {
183 &mut self.transport
184 }
185}
186
/// Reads messages from the transport and forwards them to the command loop.
struct ReadLoop<R>
where
    R: crate::transport::TransportReadHalf,
{
    /// The read half of the transport.
    read_half: R,

    /// Sender used to forward each read result to the command loop as `Command::ProcessReceivedMessage`.
    command_tx: mpsc::UnboundedSender<Command<R::Body>>,
}
198
199impl<R> ReadLoop<R>
200where
201 R: crate::transport::TransportReadHalf,
202{
203 async fn run(&mut self) {
205 loop {
206 let message = self.read_half.read_msg().await;
208 let stop = matches!(&message, Err(e) if e.is_fatal());
209 let message = message.map_err(|e| e.into_inner());
210
211 if self.command_tx.send(crate::peer::ProcessReceivedMessage { message }.into()).is_err() {
214 break;
215 }
216
217 if stop {
218 break;
219 }
220 }
221 }
222}
223
/// Processes commands: writes outgoing messages and dispatches incoming ones.
///
/// The borrowed fields point at state owned by the `Peer::run` call that created the loop.
struct CommandLoop<'a, W>
where
    W: crate::transport::TransportWriteHalf,
{
    /// The write half of the transport.
    write_half: W,

    /// Tracker for sent and received requests.
    request_tracker: &'a mut RequestTracker<W::Body>,

    /// Receiver for commands.
    command_rx: &'a mut mpsc::UnboundedReceiver<Command<W::Body>>,

    /// Sender used to deliver incoming messages and errors to the read handle.
    incoming_tx: &'a mut mpsc::UnboundedSender<Result<ReceivedMessage<W::Body>, Error>>,

    /// Set once the read handle is known to be gone.
    read_handle_dropped: &'a mut bool,

    /// Number of live write handles.
    write_handles: &'a mut usize,
}
247
248impl<W> CommandLoop<'_, W>
249where
250 W: crate::transport::TransportWriteHalf,
251{
252 async fn run(&mut self) {
254 loop {
255 if *self.read_handle_dropped && *self.write_handles == 0 {
257 break;
258 }
259
260 let command = self
262 .command_rx
263 .recv()
264 .await
265 .expect("all command channels closed, but we keep one open ourselves");
266
267 let flow = match command {
269 Command::SendRequest(command) => self.send_request(command).await,
270 Command::SendRawMessage(command) => self.send_raw_message(command).await,
271 Command::ProcessReceivedMessage(command) => self.process_incoming_message(command).await,
272 Command::Stop => LoopFlow::Stop,
273 Command::UnregisterReadHandle => {
274 *self.read_handle_dropped = true;
275 LoopFlow::Continue
276 },
277 Command::RegisterWriteHandle => {
278 *self.write_handles += 1;
279 LoopFlow::Continue
280 },
281 Command::UnregisterWriteHandle => {
282 *self.write_handles -= 1;
283 LoopFlow::Continue
284 },
285 };
286
287 match flow {
289 LoopFlow::Stop => break,
290 LoopFlow::Continue => continue,
291 }
292 }
293 }
294
295 async fn send_request(&mut self, command: crate::peer::SendRequest<W::Body>) -> LoopFlow {
297 let request = match self.request_tracker.allocate_sent_request(command.service_id) {
298 Ok(x) => x,
299 Err(e) => {
300 let _: Result<_, _> = command.result_tx.send(Err(e));
301 return LoopFlow::Continue;
302 },
303 };
304
305 let request_id = request.request_id();
306
307 let message = Message::request(request.request_id(), request.service_id(), command.body);
308 if let Err((e, flow)) = self.write_message(&message).await {
309 let _: Result<_, _> = command.result_tx.send(Err(e));
310 let _: Result<_, _> = self.request_tracker.remove_sent_request(request_id);
311 return flow;
312 }
313
314 if command.result_tx.send(Ok(request)).is_err() {
317 let _: Result<_, _> = self.request_tracker.remove_sent_request(request_id);
318 }
319
320 LoopFlow::Continue
321 }
322
323 async fn send_raw_message(&mut self, command: crate::peer::SendRawMessage<W::Body>) -> LoopFlow {
325 if command.message.header.message_type.is_response() {
327 let _: Result<_, _> = self.request_tracker.remove_received_request(command.message.header.request_id);
328 }
329
330 if let Err((e, flow)) = self.write_message(&command.message).await {
338 let _: Result<_, _> = command.result_tx.send(Err(e));
339 return flow;
340 }
341
342 let _: Result<_, _> = command.result_tx.send(Ok(()));
343 LoopFlow::Continue
344 }
345
346 async fn process_incoming_message(&mut self, command: crate::peer::ProcessReceivedMessage<W::Body>) -> LoopFlow {
348 let message = match command.message {
350 Ok(x) => x,
351 Err(e) => {
352 let _: Result<_, _> = self.send_incoming(Err(e)).await;
353 return LoopFlow::Continue;
354 },
355 };
356
357 let incoming = match self.request_tracker.process_incoming_message(message).await {
359 Ok(None) => return LoopFlow::Continue,
360 Ok(Some(x)) => x,
361 Err(e) => {
362 let _: Result<_, _> = self.send_incoming(Err(e)).await;
363 return LoopFlow::Continue;
364 },
365 };
366
367 match self.incoming_tx.send(Ok(incoming)) {
369 Ok(()) => LoopFlow::Continue,
370
371 Err(mpsc::error::SendError(msg)) => match msg.unwrap() {
374 ReceivedMessage::Request(request, _body) => {
376 let error_msg = format!("unexpected request for service {}", request.service_id());
377 let response = Message::error_response(request.request_id(), &error_msg);
378 if self.write_message(&response).await.is_err() {
379 LoopFlow::Stop
382 } else {
383 LoopFlow::Continue
384 }
385 },
386 ReceivedMessage::Stream(_) => LoopFlow::Continue,
387 },
388 }
389 }
390
391 async fn send_incoming(&mut self, incoming: Result<ReceivedMessage<W::Body>, Error>) -> Result<(), ()> {
393 if self.incoming_tx.send(incoming).is_err() {
394 *self.read_handle_dropped = true;
395 Err(())
396 } else {
397 Ok(())
398 }
399 }
400
401 async fn write_message(&mut self, message: &Message<W::Body>) -> Result<(), (Error, LoopFlow)> {
402 match self.write_half.write_msg(&message.header, &message.body).await {
403 Ok(()) => Ok(()),
404 Err(e) => {
405 let flow = if e.is_fatal() {
406 LoopFlow::Stop
407 } else {
408 LoopFlow::Continue
409 };
410 Err((e.into_inner(), flow))
411 },
412 }
413 }
414}
415
/// Whether the command loop should keep processing commands after handling one.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum LoopFlow {
    /// Continue with the next command.
    Continue,

    /// Stop the command loop.
    Stop,
}
427
/// Payload of [`Command::SendRequest`].
pub struct SendRequest<Body> {
    /// The service ID of the request.
    pub service_id: i32,

    /// The body of the request message.
    pub body: Body,

    /// Channel used to report the handle of the sent request, or the error that prevented sending it.
    pub result_tx: oneshot::Sender<Result<SentRequestHandle<Body>, Error>>,
}
439
/// Payload of [`Command::SendRawMessage`].
pub struct SendRawMessage<Body> {
    /// The message to write to the transport.
    pub message: Message<Body>,

    /// Channel used to report whether the message was written successfully.
    pub result_tx: oneshot::Sender<Result<(), Error>>,
}
448
/// Payload of [`Command::ProcessReceivedMessage`].
pub struct ProcessReceivedMessage<Body> {
    /// The message read from the transport, or the error produced while reading it.
    pub message: Result<Message<Body>, Error>,
}
454
455impl<Body> std::fmt::Debug for Command<Body> {
456 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
457 let mut debug = f.debug_struct("Command");
458 match self {
459 Self::SendRequest(x) => debug.field("SendRequest", x),
460 Self::SendRawMessage(x) => debug.field("SendRawMessage", x),
461 Self::ProcessReceivedMessage(x) => debug.field("ProcessReceivedMessage", x),
462 Self::Stop => debug.field("Stop", &()),
463 Self::UnregisterReadHandle => debug.field("UnregisterReadHandle", &()),
464 Self::RegisterWriteHandle => debug.field("RegisterWriteHandle", &()),
465 Self::UnregisterWriteHandle => debug.field("UnregisterWriteHandle", &()),
466
467 }.finish()
468 }
469}
470
471impl<Body> std::fmt::Debug for SendRequest<Body> {
472 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
473 f.debug_struct("SendRequest").field("service_id", &self.service_id).finish()
474 }
475}
476
477impl<Body> std::fmt::Debug for SendRawMessage<Body> {
478 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
479 f.debug_struct("SendRawMessage").field("message", &self.message).finish()
480 }
481}
482
483impl<Body> std::fmt::Debug for ProcessReceivedMessage<Body> {
484 fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
485 f.debug_struct("ProcessReceivedMessage").field("message", &self.message).finish()
486 }
487}
488
489impl<Body> From<SendRequest<Body>> for Command<Body> {
490 fn from(other: SendRequest<Body>) -> Self {
491 Self::SendRequest(other)
492 }
493}
494
495impl<Body> From<SendRawMessage<Body>> for Command<Body> {
496 fn from(other: SendRawMessage<Body>) -> Self {
497 Self::SendRawMessage(other)
498 }
499}
500
501impl<Body> From<ProcessReceivedMessage<Body>> for Command<Body> {
502 fn from(other: ProcessReceivedMessage<Body>) -> Self {
503 Self::ProcessReceivedMessage(other)
504 }
505}
506
#[cfg(test)]
mod test {
    use super::*;
    use assert2::assert;
    use assert2::let_assert;

    use crate::MessageHeader;
    use crate::transport::StreamTransport;
    use tokio::net::UnixStream;

    /// Exchange a request, an update in each direction and a response between two connected peers,
    /// then check that both peer tasks finish without panicking after the handles are dropped.
    #[tokio::test]
    async fn test_peer() {
        let_assert!(Ok((peer_a, peer_b)) = UnixStream::pair());

        let (peer_a, handle_a) = Peer::new(StreamTransport::new(peer_a, Default::default()));
        let (peer_b, mut handle_b) = Peer::new(StreamTransport::new(peer_b, Default::default()));

        let task_a = tokio::spawn(peer_a.run());
        let task_b = tokio::spawn(peer_b.run());

        // Peer A sends a request for service 1; peer B receives it.
        let_assert!(Ok(mut sent_request) = handle_a.send_request(1, &[2][..]).await);
        let request_id = sent_request.request_id();

        let_assert!(Ok(ReceivedMessage::Request(mut received_request, _body)) = handle_b.recv_message().await);

        // Requester update: A -> B.
        let_assert!(Ok(()) = sent_request.send_update(3, &[4][..]).await);
        let_assert!(Some(update) = received_request.recv_update().await);
        assert!(update.header == MessageHeader::requester_update(request_id, 3));
        assert!(update.body.as_ref() == &[4]);

        // Responder update: B -> A.
        let_assert!(Ok(()) = received_request.send_update(5, &[6][..]).await);
        let_assert!(Some(update) = sent_request.recv_update().await);
        assert!(update.header == MessageHeader::responder_update(request_id, 5));
        assert!(update.body.as_ref() == &[6]);

        // Final response: B -> A.
        let_assert!(Ok(()) = received_request.send_response(7, &[8][..]).await);
        let_assert!(Ok(response) = sent_request.recv_response().await);
        assert!(response.header == MessageHeader::response(request_id, 7));
        assert!(response.body.as_ref() == &[8]);

        // Drop the handles so the peers can shut down; both tasks must then complete.
        drop(handle_a);
        drop(handle_b);
        drop(sent_request);

        assert!(let Ok(()) = task_a.await);
        assert!(let Ok(()) = task_b.await);
    }

    /// When two updates and the response are all sent before the requester reads anything,
    /// `recv_update` yields both updates and then `None`. The response must still be
    /// available from `recv_response` afterwards, exactly once.
    #[tokio::test]
    async fn peeked_response_is_not_gone() {
        let_assert!(Ok((peer_a, peer_b)) = UnixStream::pair());
        let handle_a = Peer::spawn(StreamTransport::new(peer_a, Default::default()));
        let mut handle_b = Peer::spawn(StreamTransport::new(peer_b, Default::default()));

        let_assert!(Ok(mut sent_request) = handle_a.send_request(1, &[2][..]).await);
        let request_id = sent_request.request_id();

        let_assert!(Ok(ReceivedMessage::Request(received_request, _body)) = handle_b.recv_message().await);

        let_assert!(Ok(()) = received_request.send_update(5, &b"Hello world!"[..]).await);
        let_assert!(Ok(()) = received_request.send_update(6, &b"Hello world!"[..]).await);
        let_assert!(Ok(()) = received_request.send_response(7, &b"Goodbye!"[..]).await);

        assert!(let Some(_) = sent_request.recv_update().await);
        assert!(let Some(_) = sent_request.recv_update().await);
        assert!(let None = sent_request.recv_update().await);

        let_assert!(Ok(response) = sent_request.recv_response().await);
        assert!(let Err(_) = sent_request.recv_response().await);

        assert!(response.header == MessageHeader::response(request_id, 7));
        assert!(response.body.as_ref() == b"Goodbye!");
    }

    /// When an update and the response are sent before the requester reads anything,
    /// `recv_response` fails while the update is still pending. That update must not be lost:
    /// `recv_update` still yields it, and only then does `recv_response` succeed.
    #[tokio::test]
    async fn peeked_update_is_not_gone() {
        let_assert!(Ok((peer_a, peer_b)) = UnixStream::pair());
        let handle_a = Peer::spawn(StreamTransport::new(peer_a, Default::default()));
        let mut handle_b = Peer::spawn(StreamTransport::new(peer_b, Default::default()));

        let_assert!(Ok(mut sent_request) = handle_a.send_request(1, &[2][..]).await);
        let request_id = sent_request.request_id();

        let_assert!(Ok(ReceivedMessage::Request(received_request, _body)) = handle_b.recv_message().await);

        let_assert!(Ok(()) = received_request.send_update(5, &b"Hello world!"[..]).await);
        let_assert!(Ok(()) = received_request.send_response(6, &b"Goodbye!"[..]).await);

        assert!(let Err(_) = sent_request.recv_response().await);

        let_assert!(Some(update) = sent_request.recv_update().await);
        assert!(update.header == MessageHeader::responder_update(request_id, 5));
        assert!(update.body.as_ref() == b"Hello world!");
        assert!(let None = sent_request.recv_update().await);

        let_assert!(Ok(response) = sent_request.recv_response().await);
        assert!(response.header == MessageHeader::response(request_id, 6));
        assert!(response.body.as_ref() == b"Goodbye!");
    }
}