fizyr_rpc/
peer.rs

1use tokio::sync::{mpsc, oneshot};
2
3use crate::{
4	util,
5	Error,
6	Message,
7	PeerHandle,
8	ReceivedMessage,
9	SentRequestHandle,
10};
11use crate::request_tracker::RequestTracker;
12use crate::util::{select, Either};
13
/// Message for the internal peer command loop.
///
/// Commands are received from the command channel and handled one at a time by the command loop.
pub enum Command<Body> {
	/// Send a new request to the remote peer.
	SendRequest(SendRequest<Body>),
	/// Write an already constructed message to the remote peer.
	SendRawMessage(SendRawMessage<Body>),
	/// Process a message (or a read error) that was produced by the read loop.
	ProcessReceivedMessage(ProcessReceivedMessage<Body>),
	/// Stop the command loop.
	Stop,
	/// Mark the read handle of the peer as dropped.
	UnregisterReadHandle,
	/// Increment the number of open write handles.
	RegisterWriteHandle,
	/// Decrement the number of open write handles.
	UnregisterWriteHandle,
}
24
/// Peer read/write loop.
///
/// This struct is used to run the read/write loop of the peer.
/// To send or receive requests and stream messages,
/// you need to use the [`PeerHandle`] instead.
pub struct Peer<Transport: crate::transport::Transport> {
	/// The transport to use for sending/receiving messages.
	transport: Transport,

	/// The request tracker to track open requests.
	request_tracker: RequestTracker<Transport::Body>,

	/// Sending end of the command channel, so we can send commands to ourselves.
	///
	/// This is used to have the read loop inject things into the command loop.
	/// That way, the read loop doesn't need a mutable reference to the request tracker,
	/// which simplifies the implementation.
	command_tx: mpsc::UnboundedSender<Command<Transport::Body>>,

	/// Receiving end of the command channel.
	///
	/// Used to make the command loop do the things we want.
	command_rx: mpsc::UnboundedReceiver<Command<Transport::Body>>,

	/// Sending end of the channel for incoming requests and stream messages.
	///
	/// Errors that should be reported to the read handle are sent over this channel too.
	incoming_tx: mpsc::UnboundedSender<Result<ReceivedMessage<Transport::Body>, Error>>,

	/// The number of [`PeerWriteHandle`][crate::PeerWriteHandle] objects for this peer.
	///
	/// The count is adjusted by the command loop when it processes
	/// [`Command::RegisterWriteHandle`] and [`Command::UnregisterWriteHandle`].
	///
	/// When it hits zero, and the [`PeerReadHandle`][crate::PeerReadHandle] is dropped,
	/// the internal loops are stopped.
	write_handles: usize,
}
58
59impl<Transport: crate::transport::Transport> Peer<Transport> {
60	/// Create a new peer and a handle to it.
61	///
62	/// The [`Peer`] itself is used to run the read/write loop.
63	/// The returned [`PeerHandle`] is used to send and receive requests and stream messages.
64	///
65	/// If [`Self::run()`] is not called (or aborted),
66	/// none of the functions of the [`PeerHandle`] will work.
67	/// They will just wait forever.
68	///
69	/// You can also use [`Self::spawn()`] to run the read/write loop in a newly spawned task,
70	/// and only get a [`PeerHandle`].
71	/// You should only use [`Self::spawn()`] if you do not need full control over the execution of the read/write loop.
72	pub fn new(transport: Transport) -> (Self, PeerHandle<Transport::Body>) {
73		let (incoming_tx, incoming_rx) = mpsc::unbounded_channel();
74		let (command_tx, command_rx) = mpsc::unbounded_channel();
75		let request_tracker = RequestTracker::new(command_tx.clone());
76
77		let peer = Self {
78			transport,
79			request_tracker,
80			command_tx: command_tx.clone(),
81			command_rx,
82			incoming_tx,
83			write_handles: 1,
84		};
85
86		let handle = PeerHandle::new(incoming_rx, command_tx);
87
88		(peer, handle)
89	}
90
91	/// Spawn a peer in a new task, and get a handle to the peer.
92	///
93	/// The spawned task will immediately be detached.
94	/// It can not be joined.
95	///
96	/// The returned [`PeerHandle`] can be used to send and receive requests and stream messages.
97	///
98	/// If you need more control of the execution of the peer read/write loop,
99	/// you should use [`Self::new()`] instead.
100	pub fn spawn(transport: Transport) -> PeerHandle<Transport::Body> {
101		let (peer, handle) = Self::new(transport);
102		tokio::spawn(peer.run());
103		handle
104	}
105
106	/// Connect to a remote server.
107	///
108	/// Similar to [`Self::spawn()`], this spawns a background task for the peer.
109	///
110	/// The returned [`PeerHandle`] can be used to send and receive requests and stream messages.
111	///
112	/// The type of address accepted depends on the transport.
113	/// For internet transports such as TCP, the address must implement [`tokio::net::ToSocketAddrs`].
114	/// For unix transports, the address must implement [`AsRef<std::path::Path>`].
115	pub async fn connect<'a, Address>(address: Address, config: Transport::Config) -> std::io::Result<(PeerHandle<Transport::Body>, Transport::Info)>
116	where
117		Address: 'a,
118		Transport: util::Connect<'a, Address>,
119	{
120		let transport = Transport::connect(address, config).await?;
121		let info = transport.info()?;
122		Ok((Self::spawn(transport), info))
123	}
124
	/// Run the read/write loop.
	///
	/// The read loop and the command loop are polled together until one of them finishes:
	/// - If the read loop finishes first, a [`Command::Stop`] is queued
	///   and the command loop is awaited until it stops.
	/// - If the command loop finishes first, the read loop is dropped.
	///
	/// # Panics
	/// Panics if the stop command can not be queued because the command channel is closed.
	pub async fn run(mut self) {
		// Borrow all fields separately so they can be divided over the two loops.
		let Self {
			transport,
			request_tracker,
			command_tx,
			command_rx,
			incoming_tx,
			write_handles,
		} = &mut self;

		let (read_half, write_half) = transport.split();

		// The read loop only needs the read half and a way to send commands to the command loop.
		let mut read_loop = ReadLoop {
			read_half,
			command_tx: command_tx.clone(),
		};

		let mut command_loop = CommandLoop {
			write_half,
			request_tracker,
			command_rx,
			incoming_tx,
			// The flag starts out as false and only lives for the duration of this call.
			read_handle_dropped: &mut false,
			write_handles,
		};

		let read_loop = read_loop.run();
		let command_loop = command_loop.run();

		// Futures must be pinned in order to poll them.
		tokio::pin!(read_loop);
		tokio::pin!(command_loop);

		match select(read_loop, command_loop).await {
			Either::Left(((), command_loop)) => {
				// If the read loop stopped we should still flush all queued incoming messages, then stop.
				// The stop command is queued on the same channel as the messages from the read loop.
				command_tx
					.send(Command::Stop)
					.map_err(drop)
					.expect("command loop did not stop yet but command channel is closed");
				command_loop.await;
			},
			Either::Right((_read_loop, ())) => {
				// If the command loop stopped, the read loop is pointless.
				// Nobody will ever observe any effects of the read loop without the command loop.
				// The read loop is dropped here.
			},
		}
	}
175
176	/// Get direct access to the underlying transport.
177	pub fn transport(&self) -> &Transport {
178		&self.transport
179	}
180
181	/// Get direct mutable access to the underlying transport.
182	pub fn transport_mut(&mut self) -> &mut Transport {
183		&mut self.transport
184	}
185}
186
/// Implementation of the read loop of a peer.
///
/// The read loop reads messages from the transport
/// and passes every read result on to the command loop.
struct ReadLoop<R>
where
	R: crate::transport::TransportReadHalf,
{
	/// The read half of the message transport.
	read_half: R,

	/// The channel used to inject things into the peer read/write loop.
	///
	/// This is the sending end of the command channel.
	command_tx: mpsc::UnboundedSender<Command<R::Body>>,
}
198
199impl<R> ReadLoop<R>
200where
201	R: crate::transport::TransportReadHalf,
202{
203	/// Run the read loop.
204	async fn run(&mut self) {
205		loop {
206			// Read a message, and stop the read loop on errors.
207			let message = self.read_half.read_msg().await;
208			let stop = matches!(&message, Err(e) if e.is_fatal());
209			let message = message.map_err(|e| e.into_inner());
210
211			// But first send the error to the command loop so it can be delivered to the peer.
212			// If that fails the command loop already closed, so just stop the read loop.
213			if self.command_tx.send(crate::peer::ProcessReceivedMessage { message }.into()).is_err() {
214				break;
215			}
216
217			if stop {
218				break;
219			}
220		}
221	}
222}
223
/// Implementation of the command loop of a peer.
///
/// The command loop processes the commands from the command channel one by one.
/// It is the only place that writes to the transport and that touches the request tracker.
struct CommandLoop<'a, W>
where
	W: crate::transport::TransportWriteHalf,
{
	/// The write half of the message transport.
	write_half: W,

	/// The request tracker.
	request_tracker: &'a mut RequestTracker<W::Body>,

	/// The channel for incoming commands.
	command_rx: &'a mut mpsc::UnboundedReceiver<Command<W::Body>>,

	/// The channel for sending incoming messages to the [`PeerHandle`].
	incoming_tx: &'a mut mpsc::UnboundedSender<Result<ReceivedMessage<W::Body>, Error>>,

	/// Flag to indicate if the peer read handle has already been dropped.
	///
	/// Set when a [`Command::UnregisterReadHandle`] is processed,
	/// or when forwarding an error over `incoming_tx` fails.
	read_handle_dropped: &'a mut bool,

	/// Number of open write handles.
	write_handles: &'a mut usize,
}
247
248impl<W> CommandLoop<'_, W>
249where
250	W: crate::transport::TransportWriteHalf,
251{
252	/// Run the command loop.
253	async fn run(&mut self) {
254		loop {
255			// Stop the command loop if both halves of the PeerHandle are dropped.
256			if *self.read_handle_dropped && *self.write_handles == 0 {
257				break;
258			}
259
260			// Get the next command from the channel.
261			let command = self
262				.command_rx
263				.recv()
264				.await
265				.expect("all command channels closed, but we keep one open ourselves");
266
267			// Process the command.
268			let flow = match command {
269				Command::SendRequest(command) => self.send_request(command).await,
270				Command::SendRawMessage(command) => self.send_raw_message(command).await,
271				Command::ProcessReceivedMessage(command) => self.process_incoming_message(command).await,
272				Command::Stop => LoopFlow::Stop,
273				Command::UnregisterReadHandle => {
274					*self.read_handle_dropped = true;
275					LoopFlow::Continue
276				},
277				Command::RegisterWriteHandle => {
278					*self.write_handles += 1;
279					LoopFlow::Continue
280				},
281				Command::UnregisterWriteHandle => {
282					*self.write_handles -= 1;
283					LoopFlow::Continue
284				},
285			};
286
287			// Stop the loop if the command dictates it.
288			match flow {
289				LoopFlow::Stop => break,
290				LoopFlow::Continue => continue,
291			}
292		}
293	}
294
295	/// Process a SendRequest command.
296	async fn send_request(&mut self, command: crate::peer::SendRequest<W::Body>) -> LoopFlow {
297		let request = match self.request_tracker.allocate_sent_request(command.service_id) {
298			Ok(x) => x,
299			Err(e) => {
300				let _: Result<_, _> = command.result_tx.send(Err(e));
301				return LoopFlow::Continue;
302			},
303		};
304
305		let request_id = request.request_id();
306
307		let message = Message::request(request.request_id(), request.service_id(), command.body);
308		if let Err((e, flow)) = self.write_message(&message).await {
309			let _: Result<_, _> = command.result_tx.send(Err(e));
310			let _: Result<_, _> = self.request_tracker.remove_sent_request(request_id);
311			return flow;
312		}
313
314		// If sending fails, the result_rx was dropped.
315		// Then remove the request from the tracker.
316		if command.result_tx.send(Ok(request)).is_err() {
317			let _: Result<_, _> = self.request_tracker.remove_sent_request(request_id);
318		}
319
320		LoopFlow::Continue
321	}
322
323	/// Process a SendRawMessage command.
324	async fn send_raw_message(&mut self, command: crate::peer::SendRawMessage<W::Body>) -> LoopFlow {
325		// Remove tracked received requests when we send a response.
326		if command.message.header.message_type.is_response() {
327			let _: Result<_, _> = self.request_tracker.remove_received_request(command.message.header.request_id);
328		}
329
330		// TODO: replace SendRawMessage with specific command for different message types.
331		// Then we can use that to remove the appropriate request from the tracker if result_tx is dropped.
332		// Or just parse the message header to determine which request to remove.
333		//
334		// Actually, should we remove the request if result_tx is dropped?
335		// Needs more thought.
336
337		if let Err((e, flow)) = self.write_message(&command.message).await {
338			let _: Result<_, _> = command.result_tx.send(Err(e));
339			return flow;
340		}
341
342		let _: Result<_, _> = command.result_tx.send(Ok(()));
343		LoopFlow::Continue
344	}
345
346	/// Process an incoming message.
347	async fn process_incoming_message(&mut self, command: crate::peer::ProcessReceivedMessage<W::Body>) -> LoopFlow {
348		// Forward errors to the peer read handle.
349		let message = match command.message {
350			Ok(x) => x,
351			Err(e) => {
352				let _: Result<_, _> = self.send_incoming(Err(e)).await;
353				return LoopFlow::Continue;
354			},
355		};
356
357		// Forward errors from the request tracker too.
358		let incoming = match self.request_tracker.process_incoming_message(message).await {
359			Ok(None) => return LoopFlow::Continue,
360			Ok(Some(x)) => x,
361			Err(e) => {
362				let _: Result<_, _> = self.send_incoming(Err(e)).await;
363				return LoopFlow::Continue;
364			},
365		};
366
367		// Deliver the message to the peer read handle.
368		match self.incoming_tx.send(Ok(incoming)) {
369			Ok(()) => LoopFlow::Continue,
370
371			// The read handle was dropped.
372			// `msg` must be Ok(), because we checked it before.
373			Err(mpsc::error::SendError(msg)) => match msg.unwrap() {
374				// Respond to requests with an error.
375				ReceivedMessage::Request(request, _body) => {
376					let error_msg = format!("unexpected request for service {}", request.service_id());
377					let response = Message::error_response(request.request_id(), &error_msg);
378					if self.write_message(&response).await.is_err() {
379						// If we can't send the error to the remote peer, just close the connection.
380						// Even if the transport doesn't say that the write error is fatal.
381						LoopFlow::Stop
382					} else {
383						LoopFlow::Continue
384					}
385				},
386				ReceivedMessage::Stream(_) => LoopFlow::Continue,
387			},
388		}
389	}
390
391	/// Send an incoming message to the PeerHandle.
392	async fn send_incoming(&mut self, incoming: Result<ReceivedMessage<W::Body>, Error>) -> Result<(), ()> {
393		if self.incoming_tx.send(incoming).is_err() {
394			*self.read_handle_dropped = true;
395			Err(())
396		} else {
397			Ok(())
398		}
399	}
400
401	async fn write_message(&mut self, message: &Message<W::Body>) -> Result<(), (Error, LoopFlow)> {
402		match self.write_half.write_msg(&message.header, &message.body).await {
403			Ok(()) => Ok(()),
404			Err(e) => {
405				let flow = if e.is_fatal() {
406					LoopFlow::Stop
407				} else {
408					LoopFlow::Continue
409				};
410				Err((e.into_inner(), flow))
411			},
412		}
413	}
414}
415
/// Loop control flow command.
///
/// Allows other methods to make decisions on loop control flow.
/// The command handlers of the command loop return this
/// to tell the loop whether to keep running.
#[derive(Debug, Copy, Clone, Eq, PartialEq)]
enum LoopFlow {
	/// Keep the loop running.
	Continue,

	/// Stop the loop.
	Stop,
}
427
/// Command to send a request to the remote peer.
///
/// The command loop allocates a tracked request for the service ID,
/// writes the request message and reports the outcome over `result_tx`.
pub struct SendRequest<Body> {
	/// The service ID for the request.
	pub service_id: i32,

	/// The body for the request.
	pub body: Body,

	/// One-shot channel to transmit back the created [`SentRequestHandle`] object, or an error.
	pub result_tx: oneshot::Sender<Result<SentRequestHandle<Body>, Error>>,
}
439
/// Command to send a raw message to the remote peer.
///
/// The command loop writes the message as-is and reports the outcome over `result_tx`.
/// If the message is a response, the matching received request is removed from the request tracker.
pub struct SendRawMessage<Body> {
	/// The message to send.
	pub message: Message<Body>,

	/// One-shot channel to receive the result of sending the message.
	pub result_tx: oneshot::Sender<Result<(), Error>>,
}
448
/// Command to process an incoming message from the remote peer.
///
/// The read loop sends this command to the command loop for every read result.
pub struct ProcessReceivedMessage<Body> {
	/// The message from the remote peer, or an error.
	pub message: Result<Message<Body>, Error>,
}
454
455impl<Body> std::fmt::Debug for Command<Body> {
456	fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
457		let mut debug = f.debug_struct("Command");
458		match self {
459			Self::SendRequest(x) => debug.field("SendRequest", x),
460			Self::SendRawMessage(x) => debug.field("SendRawMessage", x),
461			Self::ProcessReceivedMessage(x) => debug.field("ProcessReceivedMessage", x),
462			Self::Stop => debug.field("Stop", &()),
463			Self::UnregisterReadHandle => debug.field("UnregisterReadHandle", &()),
464			Self::RegisterWriteHandle => debug.field("RegisterWriteHandle", &()),
465			Self::UnregisterWriteHandle => debug.field("UnregisterWriteHandle", &()),
466
467		}.finish()
468	}
469}
470
471impl<Body> std::fmt::Debug for SendRequest<Body> {
472	fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
473		f.debug_struct("SendRequest").field("service_id", &self.service_id).finish()
474	}
475}
476
477impl<Body> std::fmt::Debug for SendRawMessage<Body> {
478	fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
479		f.debug_struct("SendRawMessage").field("message", &self.message).finish()
480	}
481}
482
483impl<Body> std::fmt::Debug for ProcessReceivedMessage<Body> {
484	fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
485		f.debug_struct("ProcessReceivedMessage").field("message", &self.message).finish()
486	}
487}
488
489impl<Body> From<SendRequest<Body>> for Command<Body> {
490	fn from(other: SendRequest<Body>) -> Self {
491		Self::SendRequest(other)
492	}
493}
494
495impl<Body> From<SendRawMessage<Body>> for Command<Body> {
496	fn from(other: SendRawMessage<Body>) -> Self {
497		Self::SendRawMessage(other)
498	}
499}
500
501impl<Body> From<ProcessReceivedMessage<Body>> for Command<Body> {
502	fn from(other: ProcessReceivedMessage<Body>) -> Self {
503		Self::ProcessReceivedMessage(other)
504	}
505}
506
507#[cfg(test)]
508mod test {
509	use super::*;
510	use assert2::assert;
511	use assert2::let_assert;
512
513	use crate::MessageHeader;
514	use crate::transport::StreamTransport;
515	use tokio::net::UnixStream;
516
	/// Exchange a request, updates in both directions and a response between two connected peers,
	/// then check that both peer loops finish once all handles are dropped.
	#[tokio::test]
	async fn test_peer() {
		let_assert!(Ok((peer_a, peer_b)) = UnixStream::pair());

		let (peer_a, handle_a) = Peer::new(StreamTransport::new(peer_a, Default::default()));
		let (peer_b, mut handle_b) = Peer::new(StreamTransport::new(peer_b, Default::default()));

		// Run both peers in their own task, but keep the join handles to check on them later.
		let task_a = tokio::spawn(peer_a.run());
		let task_b = tokio::spawn(peer_b.run());

		// Send a request from A.
		let_assert!(Ok(mut sent_request) = handle_a.send_request(1, &[2][..]).await);
		let request_id = sent_request.request_id();

		// Receive the request on B.
		let_assert!(Ok(ReceivedMessage::Request(mut received_request, _body)) = handle_b.recv_message().await);

		// Send an update from A and receive it on B.
		let_assert!(Ok(()) = sent_request.send_update(3, &[4][..]).await);
		let_assert!(Some(update) = received_request.recv_update().await);
		assert!(update.header == MessageHeader::requester_update(request_id, 3));
		assert!(update.body.as_ref() == &[4]);

		// Send an update from B and receive it on A.
		let_assert!(Ok(()) = received_request.send_update(5, &[6][..]).await);
		let_assert!(Some(update) = sent_request.recv_update().await);
		assert!(update.header == MessageHeader::responder_update(request_id, 5));
		assert!(update.body.as_ref() == &[6]);

		// Send the response from B and receive it on A.
		let_assert!(Ok(()) = received_request.send_response(7, &[8][..]).await);
		let_assert!(Ok(response) = sent_request.recv_response().await);
		assert!(response.header == MessageHeader::response(request_id, 7));
		assert!(response.body.as_ref() == &[8]);

		// Drop the peer handles and the sent request before joining the peer tasks.
		drop(handle_a);
		drop(handle_b);
		drop(sent_request);

		// Both peer loops should finish without panicking.
		assert!(let Ok(()) = task_a.await);
		assert!(let Ok(()) = task_b.await);
	}
559
	/// Check that a response that is encountered while receiving updates can still be received afterwards.
	#[tokio::test]
	async fn peeked_response_is_not_gone() {
		let_assert!(Ok((peer_a, peer_b)) = UnixStream::pair());
		let handle_a = Peer::spawn(StreamTransport::new(peer_a, Default::default()));
		let mut handle_b = Peer::spawn(StreamTransport::new(peer_b, Default::default()));

		// Send a request from A.
		let_assert!(Ok(mut sent_request) = handle_a.send_request(1, &[2][..]).await);
		let request_id = sent_request.request_id();

		// Receive the request on B.
		let_assert!(Ok(ReceivedMessage::Request(received_request, _body)) = handle_b.recv_message().await);

		// Send two updates and a response from B to A.
		let_assert!(Ok(()) = received_request.send_update(5, &b"Hello world!"[..]).await);
		let_assert!(Ok(()) = received_request.send_update(6, &b"Hello world!"[..]).await);
		let_assert!(Ok(()) = received_request.send_response(7, &b"Goodbye!"[..]).await);

		// Try to receive three updates.
		// Only two updates were sent, so the third attempt runs into the response.
		// This should stuff the response in the internal peek buffer.
		assert!(let Some(_) = sent_request.recv_update().await);
		assert!(let Some(_) = sent_request.recv_update().await);
		assert!(let None = sent_request.recv_update().await);

		// Now receive the response, which should be returned intact from the peek buffer exactly once.
		let_assert!(Ok(response) = sent_request.recv_response().await);
		assert!(let Err(_) = sent_request.recv_response().await);

		assert!(response.header == MessageHeader::response(request_id, 7));
		assert!(response.body.as_ref() == b"Goodbye!");
	}
591
	/// Check that an update that is encountered while receiving the response can still be received afterwards.
	#[tokio::test]
	async fn peeked_update_is_not_gone() {
		let_assert!(Ok((peer_a, peer_b)) = UnixStream::pair());
		let handle_a = Peer::spawn(StreamTransport::new(peer_a, Default::default()));
		let mut handle_b = Peer::spawn(StreamTransport::new(peer_b, Default::default()));

		// Send a request from A.
		let_assert!(Ok(mut sent_request) = handle_a.send_request(1, &[2][..]).await);
		let request_id = sent_request.request_id();

		// Receive the request on B.
		let_assert!(Ok(ReceivedMessage::Request(received_request, _body)) = handle_b.recv_message().await);

		// Send one update and a response from B to A.
		let_assert!(Ok(()) = received_request.send_update(5, &b"Hello world!"[..]).await);
		let_assert!(Ok(()) = received_request.send_response(6, &b"Goodbye!"[..]).await);

		// Trying to read a response should stuff the update in the internal peek buffer.
		// The attempt itself fails, because the next message is an update.
		assert!(let Err(_) = sent_request.recv_response().await);

		// Now we should receive the update intact from the peek buffer exactly once.
		let_assert!(Some(update) = sent_request.recv_update().await);
		assert!(update.header == MessageHeader::responder_update(request_id, 5));
		assert!(update.body.as_ref() == b"Hello world!");
		assert!(let None = sent_request.recv_update().await);

		// Now receive the response.
		let_assert!(Ok(response) = sent_request.recv_response().await);
		assert!(response.header == MessageHeader::response(request_id, 6));
		assert!(response.body.as_ref() == b"Goodbye!");
	}
622	}
623}