roblib_client/transports/
tcp.rs

1use super::{Subscribable, Transport};
2use anyhow::Result;
3use roblib::{
4    cmd::{self, has_return, Command},
5    event::Event,
6};
7use serde::Deserialize;
8use std::{
9    collections::HashMap,
10    io::{Cursor, Read, Write},
11    sync::Arc,
12};
13
/// Bincode deserializer (default options) that reads through a mutably borrowed
/// `Cursor` over a byte slice; this is the argument type of every [`Handler`].
14type D<'a> = bincode::Deserializer<
15    bincode::de::read::IoReader<&'a mut Cursor<&'a [u8]>>,
16    bincode::DefaultOptions,
17>;
/// Boxed, `Send + Sync` callback that consumes a deserializer (`D`, for any
/// lifetime) and reports success or failure through `Result<()>`.
18type Handler = Box<dyn Send + Sync + (for<'a> FnMut(D<'a>) -> Result<()>)>;
19
/// State shared (behind an `Arc`) between a `Tcp` value and its listener thread.
20struct TcpInner {
    /// Message id -> (callback, keep flag). The listener removes the entry to call
    /// the callback and re-inserts it afterwards only when the flag is `true`.
21    handlers: std::sync::Mutex<HashMap<u32, (Handler, bool)>>,
    /// Subscribed event -> the id it was registered under; consulted by `unsubscribe`.
22    events: std::sync::Mutex<HashMap<roblib::event::ConcreteType, u32>>,
    /// Checked by the listener at the top of every loop iteration; it returns once
    /// this reads `false`.
23    running: std::sync::RwLock<bool>,
24}
/// Blocking TCP transport: commands are written on `socket`, replies and events
/// are dispatched by a background listener thread through `inner`.
25pub struct Tcp {
    /// State shared with the listener thread.
26    inner: Arc<TcpInner>,
27
    /// Stream used for writing outgoing frames.
28    socket: std::net::TcpStream,
    /// Next message id to hand out; read and incremented under the lock.
29    id: std::sync::Mutex<u32>,
30}
31
32impl Tcp {
33    const HEADER: usize = std::mem::size_of::<u32>();
34
35    pub fn connect(robot: impl std::net::ToSocketAddrs) -> anyhow::Result<Self> {
36        let socket = std::net::TcpStream::connect(robot)?;
37
38        let inner = Arc::new(TcpInner {
39            handlers: HashMap::new().into(),
40            events: HashMap::new().into(),
41            running: true.into(),
42        });
43
44        let inner_clone = inner.clone();
45        let socket_clone = socket.try_clone()?;
46        std::thread::spawn(|| Self::listen(inner_clone, socket_clone));
47
48        Ok(Self {
49            inner,
50            id: super::ID_START.into(),
51            socket,
52        })
53    }
54
55    fn listen(inner: Arc<TcpInner>, mut socket: std::net::TcpStream) -> Result<()> {
56        let bin = bincode::options();
57        let mut buf = vec![0; 512];
58        loop {
59            let running = inner.running.read().unwrap();
60            if !*running {
61                return Ok(());
62            }
63            drop(running);
64
65            socket.read_exact(&mut buf[..Self::HEADER])?;
66            let len = u32::from_be_bytes(buf[..Self::HEADER].try_into()?) as usize;
67            let end = Self::HEADER + len;
68            // log::debug!("Receiving {len} bytes");
69            if len > buf.len() {
70                buf.resize(len, 0);
71                log::debug!("Connection buffer resized to {len}");
72            }
73            socket.read_exact(&mut buf[Self::HEADER..end])?;
74
75            let mut c = Cursor::new(&buf[Self::HEADER..end]);
76            let id: u32 = bincode::Options::deserialize_from(bin, &mut c)?;
77
78            let Some(mut handler) = inner.handlers.lock().unwrap().remove(&id) else {
79                return Err(anyhow::Error::msg("received response for unknown id"));
80            };
81
82            handler.0(bincode::Deserializer::with_reader(&mut c, bin))?;
83
84            if handler.1 {
85                inner.handlers.lock().unwrap().insert(id, handler);
86            }
87        }
88    }
89
90    fn cmd_id<C>(&self, cmd: C, id: u32) -> Result<C::Return>
91    where
92        C: Command,
93    {
94        let concrete: cmd::Concrete = cmd.into();
95        let buf = bincode::Options::serialize(bincode::options(), &(id, concrete))?;
96        (&self.socket).write_all(&(buf.len() as u32).to_be_bytes())?;
97        (&self.socket).write_all(&buf)?;
98
99        Ok(if has_return::<C>() {
100            let (tx, rx) = std::sync::mpsc::sync_channel(1);
101
102            let a: Handler = Box::new(move |mut des: D| {
103                let r = C::Return::deserialize(&mut des)?;
104                tx.send(r).unwrap();
105                Ok::<(), anyhow::Error>(())
106            });
107            self.inner.handlers.lock().unwrap().insert(id, (a, false));
108
109            rx.recv()?
110        } else {
111            unsafe { std::mem::zeroed() }
112        })
113    }
114}
115
116impl Transport for Tcp {
117    fn cmd<C>(&self, cmd: C) -> anyhow::Result<C::Return>
118    where
119        C: Command,
120    {
121        let mut id_handle = self.id.lock().unwrap();
122        let id = *id_handle;
123        *id_handle = id + 1;
124        drop(id_handle);
125        self.cmd_id(cmd, id)
126    }
127}
128
129impl Subscribable for Tcp {
130    fn subscribe<E, F>(&self, ev: E, mut handler: F) -> Result<()>
131    where
132        E: Event,
133        F: (FnMut(E::Item) -> Result<()>) + Send + Sync + 'static,
134    {
135        let mut id_handle = self.id.lock().unwrap();
136        let id = *id_handle;
137        *id_handle = id + 1;
138        drop(id_handle);
139
140        let ev = ev.into();
141
142        self.inner.handlers.lock().unwrap().insert(
143            id,
144            (
145                Box::new(move |mut des| handler(E::Item::deserialize(&mut des)?)),
146                true,
147            ),
148        );
149        self.inner.events.lock().unwrap().insert(ev.clone(), id);
150
151        self.cmd_id(cmd::Subscribe(ev), id)?;
152
153        Ok(())
154    }
155
156    fn unsubscribe<E: roblib::event::Event>(&self, ev: E) -> Result<()> {
157        let ev = ev.into();
158        let cmd = cmd::Unsubscribe(ev.clone());
159
160        let mut lock = self.inner.events.lock().unwrap();
161        match lock.entry(ev) {
162            std::collections::hash_map::Entry::Occupied(v) => {
163                let id = v.remove();
164                self.cmd_id(cmd, id)?;
165                self.inner.handlers.lock().unwrap().remove(&id);
166            }
167            std::collections::hash_map::Entry::Vacant(_) => anyhow::bail!("Subscription not found"),
168        }
169
170        Ok(())
171    }
172}
173
174#[cfg(feature = "async")]
175pub use tcp_async::*;
176#[cfg(feature = "async")]
177pub mod tcp_async {
178    use std::{collections::HashMap, io::Cursor, time::Duration};
179
180    use crate::transports::{SubscribableAsync, TransportAsync};
181    use anyhow::Result;
182    use async_trait::async_trait;
183    use roblib::{
184        cmd::{self, has_return, Command},
185        event::{self, Event},
186    };
187    use serde::{Deserialize, Serialize};
188    use tokio::{
189        io::{AsyncReadExt, AsyncWriteExt, Interest},
190        net::{TcpStream, ToSocketAddrs},
191        sync::{broadcast, mpsc, oneshot},
192        task::JoinHandle,
193    };
194
    /// Bincode deserializer (default options) that owns its input: it reads from a
    /// `Cursor` over a `Vec<u8>`, so it can be moved through channels.
195    type D = bincode::Deserializer<
196        bincode::de::read::IoReader<Cursor<Vec<u8>>>,
197        bincode::DefaultOptions,
198    >;
199
    /// What woke the worker loop up in one iteration of its `select!`.
200    enum Action {
        /// The socket read completed; carries the number of bytes read.
201        ServerMessage(usize),
        /// A command to send, plus an optional channel for its reply.
202        Cmd(cmd::Concrete, Option<oneshot::Sender<D>>),
        /// A subscription request; `None` in the second field means unsubscribe.
203        Sub(event::ConcreteType, Option<mpsc::UnboundedSender<D>>),
204    }
205
    /// Owns the TCP stream and the receiving ends of the command and subscription
    /// queues; consumed by `Worker::worker`.
206    struct Worker {
        /// Connection to the server, used for both reading and writing.
207        stream: TcpStream,
        /// Incoming commands, each with an optional reply channel.
208        cmd_rx: mpsc::UnboundedReceiver<(cmd::Concrete, Option<oneshot::Sender<D>>)>,
        /// Incoming subscribe (`Some`) / unsubscribe (`None`) requests.
209        sub_rx: mpsc::UnboundedReceiver<(event::ConcreteType, Option<mpsc::UnboundedSender<D>>)>,
210    }
211    impl Worker {
212        pub fn new(
213            stream: TcpStream,
214        ) -> (
215            Self,
216            mpsc::UnboundedSender<(cmd::Concrete, Option<oneshot::Sender<D>>)>,
217            mpsc::UnboundedSender<(event::ConcreteType, Option<mpsc::UnboundedSender<D>>)>,
218        ) {
219            let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
220            let (sub_tx, sub_rx) = mpsc::unbounded_channel();
221            let s = Self {
222                stream,
223                cmd_rx,
224                sub_rx,
225            };
226            (s, cmd_tx, sub_tx)
227        }
228        pub async fn worker(mut self) -> Result<()> {
229            const HEADER: usize = std::mem::size_of::<u32>();
230
231            let mut next_id = super::super::ID_START;
232            let bin = bincode::options();
233            let mut buf = vec![0; 512];
234            let mut len = 0; // no. of bytes read for the current command we're attempting to parse
235            let mut maybe_cmd_len = None;
236            let mut cmds: HashMap<u32, oneshot::Sender<D>> = HashMap::new();
237            let mut subs: HashMap<u32, mpsc::UnboundedSender<D>> = HashMap::new();
238            let mut sub_ids: HashMap<event::ConcreteType, u32> = HashMap::new();
239            loop {
240                let action = tokio::select! {
241                    Ok(n) = self.stream.read(&mut buf[len..( HEADER + maybe_cmd_len.unwrap_or(0) )]) => Action::ServerMessage(n),
242                    Some(cmd) = self.cmd_rx.recv() => Action::Cmd(cmd.0, cmd.1),
243                    Some(sub) = self.sub_rx.recv() => Action::Sub(sub.0, sub.1),
244                    // _ = tokio::time::sleep(Duration::from_secs(5)) => {
245                    //     self.check_disconnect().await;
246                    //     continue;
247                    // }
248                };
249
250                match action {
251                    // adapted from server/src/transports/tcp.rs
252                    Action::ServerMessage(n) => {
253                        if n == 0 {
254                            log::debug!("tcp: received 0 sized msg, investigating disconnect");
255                            // give the socket some time to fully realize disconnect
256                            tokio::time::sleep(Duration::from_millis(100)).await;
257                            if self.check_disconnect().await {
258                                anyhow::bail!("Server disconnected!");
259                            }
260                        }
261
262                        len += n;
263                        if len < HEADER {
264                            continue;
265                        }
266                        let cmd_len = match maybe_cmd_len {
267                            Some(n) => n,
268                            None => {
269                                let cmd = u32::from_be_bytes((&buf[..HEADER]).try_into().unwrap())
270                                    as usize;
271                                // buf.resize(HEADER + cmd, 0);
272                                maybe_cmd_len = Some(cmd);
273                                // log::debug!("header received, cmdlen: {cmd}");
274                                cmd
275                            }
276                        };
277                        if len < HEADER + cmd_len {
278                            continue;
279                        }
280
281                        let mut c = Cursor::new(buf[HEADER..len].to_vec()); // clone :(
282                        let id: u32 = bincode::Options::deserialize_from(bin, &mut c)?;
283                        if let Some(tx) = subs.get(&id) {
284                            tx.send(bincode::Deserializer::with_reader(c, bin))?;
285                        } else if let Some(tx) = cmds.remove(&id) {
286                            if tx.send(bincode::Deserializer::with_reader(c, bin)).is_err() {
287                                log::error!("cmd receiver dropped");
288                            }
289                        } else {
290                            log::error!("server sent invalid id");
291                        }
292
293                        len = 0;
294                        maybe_cmd_len = None;
295                    }
296                    Action::Cmd(cmd, maybe_tx) => {
297                        let id = next_id;
298                        next_id += 1;
299                        if let Some(tx) = maybe_tx {
300                            cmds.insert(id, tx);
301                        }
302                        self.send((id, cmd)).await?;
303                    }
304                    Action::Sub(ev, Some(tx)) => {
305                        let id = next_id;
306                        next_id += 1;
307                        subs.insert(id, tx);
308                        let cmd: cmd::Concrete = cmd::Subscribe(ev).into();
309                        self.send((id, cmd)).await?;
310                    }
311                    // None: unsubscribe
312                    Action::Sub(ev, None) => {
313                        let Some(id) = sub_ids.remove(&ev) else {
314                            log::error!("unsubscribe failed: {ev:?} subscription not found");
315                            continue;
316                        };
317                        subs.remove(&id);
318                        let cmd: cmd::Concrete = cmd::Unsubscribe(ev).into();
319                        self.send((id, cmd)).await?;
320                    }
321                }
322            }
323        }
324        async fn check_disconnect(&mut self) -> bool {
325            let r = self
326                .stream
327                .ready(Interest::READABLE | Interest::WRITABLE)
328                .await;
329            if r.as_ref()
330                .map_or(true, |r| r.is_read_closed() || r.is_write_closed())
331            {
332                log::error!("Server disconnected!");
333                log::debug!("{r:#?}");
334                return true;
335            }
336            return false;
337        }
338        async fn send(&mut self, data: impl Serialize) -> Result<()> {
339            let buf = bincode::Options::serialize(bincode::options(), &data)?;
340            log::debug!("{buf:?}");
341            self.stream
342                .write_all(&(buf.len() as u32).to_be_bytes())
343                .await?;
344            self.stream.write_all(&buf).await?;
345            Ok(())
346        }
347    }
348
    /// Async TCP transport handle: it forwards commands and subscription requests
    /// over channels to a spawned worker task.
349    pub struct TcpAsync {
        /// Join handle of the spawned worker task; stored but not read in this module.
350        _handle: Option<JoinHandle<Result<()>>>,
        /// Queue of commands, each with an optional channel for the reply.
351        cmd_tx: mpsc::UnboundedSender<(cmd::Concrete, Option<oneshot::Sender<D>>)>,
        /// Queue of subscribe (`Some`) / unsubscribe (`None`) requests.
352        sub_tx: mpsc::UnboundedSender<(event::ConcreteType, Option<mpsc::UnboundedSender<D>>)>,
353    }
354
355    impl TcpAsync {
356        pub async fn connect(addr: impl ToSocketAddrs) -> Result<Self> {
357            let stream = TcpStream::connect(addr).await?;
358            let (worker, cmd_tx, sub_tx) = Worker::new(stream);
359            let handle = Some(tokio::spawn(async {
360                let r = worker.worker().await;
361                log::debug!("worker dropped??");
362                r
363            }));
364
365            Ok(Self {
366                _handle: handle,
367                cmd_tx,
368                sub_tx,
369            })
370        }
371    }
372
373    #[async_trait]
374    impl TransportAsync for TcpAsync {
375        async fn cmd<C>(&self, cmd: C) -> Result<C::Return>
376        where
377            C: Command,
378        {
379            let concr: cmd::Concrete = cmd.into();
380            if has_return::<C>() {
381                let (tx, rx) = oneshot::channel();
382                self.cmd_tx.send((concr, Some(tx)))?;
383                let mut de = rx.await?;
384                Ok(C::Return::deserialize(&mut de)?)
385            } else {
386                self.cmd_tx.send((concr, None))?;
387                unsafe { std::mem::zeroed() }
388            }
389        }
390    }
391
392    #[async_trait]
393    impl SubscribableAsync for TcpAsync {
394        async fn subscribe<E: Event>(&self, ev: E) -> Result<broadcast::Receiver<E::Item>> {
395            let (worker_tx, mut worker_rx) = mpsc::unbounded_channel();
396            self.sub_tx.send((ev.into(), Some(worker_tx)))?;
397
398            let (client_tx, client_rx) = broadcast::channel(128);
399            tokio::spawn(async move {
400                while let Some(mut de) = worker_rx.recv().await {
401                    let item = E::Item::deserialize(&mut de)?;
402                    if client_tx.send(item).is_err() {
403                        log::error!("no receiver for active subscription");
404                    };
405                }
406                anyhow::Ok(())
407            });
408            Ok(client_rx)
409        }
410
411        async fn unsubscribe<E>(&self, ev: E) -> Result<()>
412        where
413            E: Event,
414        {
415            Ok(self.sub_tx.send((ev.into(), None))?)
416        }
417    }
418}