bokken_runtime/
ipc_comm.rs

1use std::{collections::{VecDeque}, io, sync::{Arc, atomic::{AtomicBool, Ordering}}};
2
3use borsh::{BorshSerialize, BorshDeserialize};
4// use borsh::{BorshSerialize, BorshDeserialize};
5use tokio::{task, net::{UnixStream, unix}, sync::{Mutex, watch}};
6
7
/// Which part of a length-prefixed frame the reader is currently filling.
enum IPCCommReadState {
	/// Reading the 8-byte little-endian `u64` length prefix.
	MsgLength,
	/// Reading the message body whose length came from the prefix.
	MsgBody,
}
/// Outcome of a single read step on the IPC socket.
enum IPCCommReadResult {
	/// The peer closed the connection (a read returned 0 bytes).
	Shutdown,
	/// No complete message is available yet; call again.
	Waiting,
	/// A complete message body, without its length prefix.
	Message(Vec<u8>),
}
17
18
19struct IPCCommReadHandler {
20	buffer: Vec<u8>,
21	buffer_index: usize,
22	state: IPCCommReadState,
23	stream: unix::OwnedReadHalf
24}
25impl IPCCommReadHandler {
26	pub fn new(
27		stream: unix::OwnedReadHalf,
28	) -> Self {
29		Self {
30			buffer: vec![0; 8],
31			buffer_index: 0,
32			state: IPCCommReadState::MsgLength,
33			stream
34		}
35	}
36	async fn read_tick(&mut self) -> Result<IPCCommReadResult, io::Error> {
37		self.stream.readable().await?;
38
39		
40		let buf_slice = &mut self.buffer.as_mut_slice()[self.buffer_index..];
41		if buf_slice.len() == 0 {
42			panic!("Zero-length message, this shouldn't happen");
43		}
44		let read_result = match self.stream.try_read(buf_slice) {
45			Ok(0) => {
46				IPCCommReadResult::Shutdown
47			},
48			Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
49				IPCCommReadResult::Waiting
50			},
51			Ok(n) => {
52				self.buffer_index += n;
53				if self.buffer_index == self.buffer.len() {
54					match self.state {
55						IPCCommReadState::MsgLength => {
56							let size = u64::from_le_bytes(
57								self.buffer.as_slice()
58									.try_into()
59									.expect("vector for msg len should have been 8 bytes long")
60							);
61							self.buffer = vec![0; size as usize];
62							self.buffer_index = 0;
63							self.state = IPCCommReadState::MsgBody;
64							IPCCommReadResult::Waiting
65						},
66						IPCCommReadState::MsgBody => {
67							let final_msg = self.buffer.clone();
68							self.buffer = vec![0; 8];
69							self.buffer_index = 0;
70							self.state = IPCCommReadState::MsgLength;
71							IPCCommReadResult::Message(final_msg)
72						}
73					}
74				}else{
75					IPCCommReadResult::Waiting
76				}
77			}
78			Err(e) => {
79				return Err(e.into())
80			}
81		};
82		Ok(read_result)
83	}
84}
85
86
87struct IPCCommWriteHandler {
88	queue: Arc<Mutex<VecDeque<Vec<u8>>>>,
89	stream: unix::OwnedWriteHalf
90}
91impl IPCCommWriteHandler {
92	pub fn new(
93		stream: unix::OwnedWriteHalf,
94		bytes_queue: Arc<Mutex<VecDeque<Vec<u8>>>>,
95	) -> Self {
96		Self {
97			queue: bytes_queue,
98			stream
99		}
100	}
101	async fn write_tick(&mut self) -> Result<(), io::Error> {
102		self.stream.writable().await?;
103		let mut send_queue = self.queue.lock().await;
104		if let Some(send_data) = send_queue.pop_front() {
105			match self.stream.try_write(send_data.as_slice()) {
106				Ok(n) => {
107					if n < send_data.len() {
108						// Not all the bytes have been written, add the remaining ones to the queue
109						send_queue.push_front(send_data[{send_data.len() - n}..].into());
110					}
111				},
112				Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
113					// We can't write it now, add it to the queue
114					send_queue.push_front(send_data);
115				},
116				Err(e) => {
117					return Err(e.into())
118				}
119			}
120		}
121		Ok(())
122	}
123}
124
125// #[derive(Debug, Clone)]
126#[derive(Debug)]
127pub struct IPCComm {
128	write_handle: task::JoinHandle<()>,
129	read_handle: task::JoinHandle<()>,
130	should_stop: Arc<AtomicBool>,
131	send_queue_bytes: Arc<Mutex<VecDeque<Vec<u8>>>>,
132	recv_queue_bytes: Arc<Mutex<VecDeque<Vec<u8>>>>,
133	recv_notif: watch::Receiver<usize>
134}
135
136/// Simple length-prefixed Borsh-encoded messages
137impl IPCComm {
138	/// Consumes a unix stream for length-prefixed Borsh-encoded communication.
139	pub fn new(
140		stream: UnixStream,
141	) -> Self {
142		let recv_queue_bytes_mutex = Arc::new(Mutex::new(VecDeque::new()));
143		let send_queue_bytes_mutex = Arc::new(Mutex::new(VecDeque::new()));
144		let should_stop = Arc::new(AtomicBool::new(false));
145		let (recv_notif_sender, recv_notif) = watch::channel(0usize);
146
147
148		let (read_stream, write_stream) = stream.into_split();
149
150		let mut read_handler = IPCCommReadHandler::new(read_stream);
151		let should_stop_clone = should_stop.clone();
152		let recv_queue_bytes_mutex_clone = recv_queue_bytes_mutex.clone();
153		let read_handle = task::spawn(async move {
154			while !should_stop_clone.load(Ordering::Relaxed) {
155				match read_handler.read_tick().await.unwrap() {
156					IPCCommReadResult::Shutdown => {
157						should_stop_clone.store(true, Ordering::Relaxed);
158						recv_notif_sender.send_modify(|val| {
159							(*val, _) = val.overflowing_add(1)
160						})
161					},
162					IPCCommReadResult::Waiting => {
163						// Nothing else to do!
164					},
165					IPCCommReadResult::Message(msg_bytes) => {
166						let mut recv_queue_bytes = recv_queue_bytes_mutex_clone.lock().await;
167							recv_queue_bytes.push_back(msg_bytes);
168						recv_notif_sender.send_modify(|val| {
169							(*val, _) = val.overflowing_add(1)
170						})
171					},
172				}
173			}
174		});
175
176		let mut write_handler = IPCCommWriteHandler::new(write_stream, send_queue_bytes_mutex.clone());
177		let should_stop_clone = should_stop.clone();
178		let write_handle = task::spawn(async move {
179			while !should_stop_clone.load(Ordering::Relaxed) {
180				write_handler.write_tick().await.unwrap();
181			}
182		});
183		
184		Self {
185			write_handle,
186			read_handle,
187			should_stop,
188			send_queue_bytes: send_queue_bytes_mutex,
189			recv_queue_bytes: recv_queue_bytes_mutex,
190			recv_notif
191		}
192	}
193
194	/// Consumes a unix stream for length-prefixed Borsh-encoded communication.
195	/// 
196	/// Waits until type `I` is received, will error if the initial message couldn't be decoded.
197	pub async fn new_with_identifier<I: BorshDeserialize>(stream: UnixStream) -> Result<(Self, I), io::Error> {
198		let mut new_self = Self::new(stream);
199		let id = new_self.until_recv_msg().await?.ok_or(io::Error::from(io::ErrorKind::UnexpectedEof))?;
200		Ok((new_self, id))
201	}
202
203	/// Adds the provided message to a queue for sending over the underlying connection, but does not wait until
204	/// the message is actually sent
205	pub async fn send_msg<S: BorshSerialize>(&mut self, msg: S) -> Result<(), io::Error> {
206		let msg_bytes = msg.try_to_vec()?;
207		let mut send_queue_bytes = self.send_queue_bytes.lock().await;
208		send_queue_bytes.push_back((msg_bytes.len() as u64).to_le_bytes().to_vec());
209		send_queue_bytes.push_back(msg_bytes);
210		Ok(())
211	}
212
213	/// Adds the provided message to a queue for sending over the underlying connection, but does not block until
214	/// the message is actually sent
215	pub fn blocking_send_msg<S: BorshSerialize>(&mut self, msg: S) -> Result<(), io::Error> {
216		let msg_bytes = msg.try_to_vec()?;
217		let mut send_queue_bytes = self.send_queue_bytes.blocking_lock();
218		send_queue_bytes.push_back((msg_bytes.len() as u64).to_le_bytes().to_vec());
219		send_queue_bytes.push_back(msg_bytes);
220		Ok(())
221	}
222
223	/// Removes and parses a message received messages queue.
224	/// If there are no pending messages, None is returned.
225	pub async fn recv_msg<R: BorshDeserialize>(&mut self) -> Result<Option<R>, io::Error> {
226		let mut recv_queue_bytes = self.recv_queue_bytes.lock().await;
227		match recv_queue_bytes.pop_front() {
228			Some(msg_bytes) => {
229				Ok(Some(R::try_from_slice(&msg_bytes)?))
230			},
231			None => Ok(None),
232		}
233	}
234	
235	/// Removes and parses a message received messages queue.
236	/// If there are no pending messages, this function waits until there is one.
237	/// If the underlying connection is closed before a message could be received, None is returned.
238	pub async fn until_recv_msg<R: BorshDeserialize>(&mut self) -> Result<Option<R>, io::Error> {
239		loop {
240			if self.should_stop.load(Ordering::Relaxed) {
241				return Ok(None);
242			}
243			if let Some(msg) = self.recv_msg::<R>().await? {
244				return Ok(Some(msg));
245			}
246			self.recv_notif.changed().await.expect("Receiver shouldn't drop without sending a message first");
247		}
248	}
249
250	/// Checks if the underlying connection is closed
251	pub fn stopped(&self) -> bool {
252		self.should_stop.load(Ordering::Relaxed)
253	}
254	
255	/// Stops parsing received messages
256	pub fn stop(&self) {
257		self.should_stop.store(true, Ordering::Relaxed);
258	}
259
260	/// Waits until the read/write tasks are stopped
261	pub async fn wait_until_stopped(self) {
262		self.write_handle.await.unwrap();
263		self.read_handle.await.unwrap();
264	}
265
266}