bokken_runtime/
ipc_comm.rs1use std::{collections::{VecDeque}, io, sync::{Arc, atomic::{AtomicBool, Ordering}}};
2
3use borsh::{BorshSerialize, BorshDeserialize};
4use tokio::{task, net::{UnixStream, unix}, sync::{Mutex, watch}};
6
7
/// Which part of a length-prefixed frame the reader is currently filling.
enum IPCCommReadState {
    /// Filling the fixed-size length prefix.
    MsgLength,
    /// Filling the message body whose size was announced by the prefix.
    MsgBody,
}
/// Outcome of one read attempt on the IPC socket.
enum IPCCommReadResult {
    /// The peer closed its end of the connection.
    Shutdown,
    /// Nothing complete yet (spurious wake-up or a partially received frame).
    Waiting,
    /// The raw bytes of one complete message body.
    Message(Vec<u8>),
}
17
18
19struct IPCCommReadHandler {
20 buffer: Vec<u8>,
21 buffer_index: usize,
22 state: IPCCommReadState,
23 stream: unix::OwnedReadHalf
24}
25impl IPCCommReadHandler {
26 pub fn new(
27 stream: unix::OwnedReadHalf,
28 ) -> Self {
29 Self {
30 buffer: vec![0; 8],
31 buffer_index: 0,
32 state: IPCCommReadState::MsgLength,
33 stream
34 }
35 }
36 async fn read_tick(&mut self) -> Result<IPCCommReadResult, io::Error> {
37 self.stream.readable().await?;
38
39
40 let buf_slice = &mut self.buffer.as_mut_slice()[self.buffer_index..];
41 if buf_slice.len() == 0 {
42 panic!("Zero-length message, this shouldn't happen");
43 }
44 let read_result = match self.stream.try_read(buf_slice) {
45 Ok(0) => {
46 IPCCommReadResult::Shutdown
47 },
48 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
49 IPCCommReadResult::Waiting
50 },
51 Ok(n) => {
52 self.buffer_index += n;
53 if self.buffer_index == self.buffer.len() {
54 match self.state {
55 IPCCommReadState::MsgLength => {
56 let size = u64::from_le_bytes(
57 self.buffer.as_slice()
58 .try_into()
59 .expect("vector for msg len should have been 8 bytes long")
60 );
61 self.buffer = vec![0; size as usize];
62 self.buffer_index = 0;
63 self.state = IPCCommReadState::MsgBody;
64 IPCCommReadResult::Waiting
65 },
66 IPCCommReadState::MsgBody => {
67 let final_msg = self.buffer.clone();
68 self.buffer = vec![0; 8];
69 self.buffer_index = 0;
70 self.state = IPCCommReadState::MsgLength;
71 IPCCommReadResult::Message(final_msg)
72 }
73 }
74 }else{
75 IPCCommReadResult::Waiting
76 }
77 }
78 Err(e) => {
79 return Err(e.into())
80 }
81 };
82 Ok(read_result)
83 }
84}
85
86
87struct IPCCommWriteHandler {
88 queue: Arc<Mutex<VecDeque<Vec<u8>>>>,
89 stream: unix::OwnedWriteHalf
90}
91impl IPCCommWriteHandler {
92 pub fn new(
93 stream: unix::OwnedWriteHalf,
94 bytes_queue: Arc<Mutex<VecDeque<Vec<u8>>>>,
95 ) -> Self {
96 Self {
97 queue: bytes_queue,
98 stream
99 }
100 }
101 async fn write_tick(&mut self) -> Result<(), io::Error> {
102 self.stream.writable().await?;
103 let mut send_queue = self.queue.lock().await;
104 if let Some(send_data) = send_queue.pop_front() {
105 match self.stream.try_write(send_data.as_slice()) {
106 Ok(n) => {
107 if n < send_data.len() {
108 send_queue.push_front(send_data[{send_data.len() - n}..].into());
110 }
111 },
112 Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => {
113 send_queue.push_front(send_data);
115 },
116 Err(e) => {
117 return Err(e.into())
118 }
119 }
120 }
121 Ok(())
122 }
123}
124
125#[derive(Debug)]
127pub struct IPCComm {
128 write_handle: task::JoinHandle<()>,
129 read_handle: task::JoinHandle<()>,
130 should_stop: Arc<AtomicBool>,
131 send_queue_bytes: Arc<Mutex<VecDeque<Vec<u8>>>>,
132 recv_queue_bytes: Arc<Mutex<VecDeque<Vec<u8>>>>,
133 recv_notif: watch::Receiver<usize>
134}
135
136impl IPCComm {
138 pub fn new(
140 stream: UnixStream,
141 ) -> Self {
142 let recv_queue_bytes_mutex = Arc::new(Mutex::new(VecDeque::new()));
143 let send_queue_bytes_mutex = Arc::new(Mutex::new(VecDeque::new()));
144 let should_stop = Arc::new(AtomicBool::new(false));
145 let (recv_notif_sender, recv_notif) = watch::channel(0usize);
146
147
148 let (read_stream, write_stream) = stream.into_split();
149
150 let mut read_handler = IPCCommReadHandler::new(read_stream);
151 let should_stop_clone = should_stop.clone();
152 let recv_queue_bytes_mutex_clone = recv_queue_bytes_mutex.clone();
153 let read_handle = task::spawn(async move {
154 while !should_stop_clone.load(Ordering::Relaxed) {
155 match read_handler.read_tick().await.unwrap() {
156 IPCCommReadResult::Shutdown => {
157 should_stop_clone.store(true, Ordering::Relaxed);
158 recv_notif_sender.send_modify(|val| {
159 (*val, _) = val.overflowing_add(1)
160 })
161 },
162 IPCCommReadResult::Waiting => {
163 },
165 IPCCommReadResult::Message(msg_bytes) => {
166 let mut recv_queue_bytes = recv_queue_bytes_mutex_clone.lock().await;
167 recv_queue_bytes.push_back(msg_bytes);
168 recv_notif_sender.send_modify(|val| {
169 (*val, _) = val.overflowing_add(1)
170 })
171 },
172 }
173 }
174 });
175
176 let mut write_handler = IPCCommWriteHandler::new(write_stream, send_queue_bytes_mutex.clone());
177 let should_stop_clone = should_stop.clone();
178 let write_handle = task::spawn(async move {
179 while !should_stop_clone.load(Ordering::Relaxed) {
180 write_handler.write_tick().await.unwrap();
181 }
182 });
183
184 Self {
185 write_handle,
186 read_handle,
187 should_stop,
188 send_queue_bytes: send_queue_bytes_mutex,
189 recv_queue_bytes: recv_queue_bytes_mutex,
190 recv_notif
191 }
192 }
193
194 pub async fn new_with_identifier<I: BorshDeserialize>(stream: UnixStream) -> Result<(Self, I), io::Error> {
198 let mut new_self = Self::new(stream);
199 let id = new_self.until_recv_msg().await?.ok_or(io::Error::from(io::ErrorKind::UnexpectedEof))?;
200 Ok((new_self, id))
201 }
202
203 pub async fn send_msg<S: BorshSerialize>(&mut self, msg: S) -> Result<(), io::Error> {
206 let msg_bytes = msg.try_to_vec()?;
207 let mut send_queue_bytes = self.send_queue_bytes.lock().await;
208 send_queue_bytes.push_back((msg_bytes.len() as u64).to_le_bytes().to_vec());
209 send_queue_bytes.push_back(msg_bytes);
210 Ok(())
211 }
212
213 pub fn blocking_send_msg<S: BorshSerialize>(&mut self, msg: S) -> Result<(), io::Error> {
216 let msg_bytes = msg.try_to_vec()?;
217 let mut send_queue_bytes = self.send_queue_bytes.blocking_lock();
218 send_queue_bytes.push_back((msg_bytes.len() as u64).to_le_bytes().to_vec());
219 send_queue_bytes.push_back(msg_bytes);
220 Ok(())
221 }
222
223 pub async fn recv_msg<R: BorshDeserialize>(&mut self) -> Result<Option<R>, io::Error> {
226 let mut recv_queue_bytes = self.recv_queue_bytes.lock().await;
227 match recv_queue_bytes.pop_front() {
228 Some(msg_bytes) => {
229 Ok(Some(R::try_from_slice(&msg_bytes)?))
230 },
231 None => Ok(None),
232 }
233 }
234
235 pub async fn until_recv_msg<R: BorshDeserialize>(&mut self) -> Result<Option<R>, io::Error> {
239 loop {
240 if self.should_stop.load(Ordering::Relaxed) {
241 return Ok(None);
242 }
243 if let Some(msg) = self.recv_msg::<R>().await? {
244 return Ok(Some(msg));
245 }
246 self.recv_notif.changed().await.expect("Receiver shouldn't drop without sending a message first");
247 }
248 }
249
250 pub fn stopped(&self) -> bool {
252 self.should_stop.load(Ordering::Relaxed)
253 }
254
255 pub fn stop(&self) {
257 self.should_stop.store(true, Ordering::Relaxed);
258 }
259
260 pub async fn wait_until_stopped(self) {
262 self.write_handle.await.unwrap();
263 self.read_handle.await.unwrap();
264 }
265
266}