roblib_client/transports/
tcp.rs1use super::{Subscribable, Transport};
2use anyhow::Result;
3use roblib::{
4 cmd::{self, has_return, Command},
5 event::Event,
6};
7use serde::Deserialize;
8use std::{
9 collections::HashMap,
10 io::{Cursor, Read, Write},
11 sync::Arc,
12};
13
14type D<'a> = bincode::Deserializer<
15 bincode::de::read::IoReader<&'a mut Cursor<&'a [u8]>>,
16 bincode::DefaultOptions,
17>;
18type Handler = Box<dyn Send + Sync + (for<'a> FnMut(D<'a>) -> Result<()>)>;
19
20struct TcpInner {
21 handlers: std::sync::Mutex<HashMap<u32, (Handler, bool)>>,
22 events: std::sync::Mutex<HashMap<roblib::event::ConcreteType, u32>>,
23 running: std::sync::RwLock<bool>,
24}
25pub struct Tcp {
26 inner: Arc<TcpInner>,
27
28 socket: std::net::TcpStream,
29 id: std::sync::Mutex<u32>,
30}
31
32impl Tcp {
33 const HEADER: usize = std::mem::size_of::<u32>();
34
35 pub fn connect(robot: impl std::net::ToSocketAddrs) -> anyhow::Result<Self> {
36 let socket = std::net::TcpStream::connect(robot)?;
37
38 let inner = Arc::new(TcpInner {
39 handlers: HashMap::new().into(),
40 events: HashMap::new().into(),
41 running: true.into(),
42 });
43
44 let inner_clone = inner.clone();
45 let socket_clone = socket.try_clone()?;
46 std::thread::spawn(|| Self::listen(inner_clone, socket_clone));
47
48 Ok(Self {
49 inner,
50 id: super::ID_START.into(),
51 socket,
52 })
53 }
54
55 fn listen(inner: Arc<TcpInner>, mut socket: std::net::TcpStream) -> Result<()> {
56 let bin = bincode::options();
57 let mut buf = vec![0; 512];
58 loop {
59 let running = inner.running.read().unwrap();
60 if !*running {
61 return Ok(());
62 }
63 drop(running);
64
65 socket.read_exact(&mut buf[..Self::HEADER])?;
66 let len = u32::from_be_bytes(buf[..Self::HEADER].try_into()?) as usize;
67 let end = Self::HEADER + len;
68 if len > buf.len() {
70 buf.resize(len, 0);
71 log::debug!("Connection buffer resized to {len}");
72 }
73 socket.read_exact(&mut buf[Self::HEADER..end])?;
74
75 let mut c = Cursor::new(&buf[Self::HEADER..end]);
76 let id: u32 = bincode::Options::deserialize_from(bin, &mut c)?;
77
78 let Some(mut handler) = inner.handlers.lock().unwrap().remove(&id) else {
79 return Err(anyhow::Error::msg("received response for unknown id"));
80 };
81
82 handler.0(bincode::Deserializer::with_reader(&mut c, bin))?;
83
84 if handler.1 {
85 inner.handlers.lock().unwrap().insert(id, handler);
86 }
87 }
88 }
89
90 fn cmd_id<C>(&self, cmd: C, id: u32) -> Result<C::Return>
91 where
92 C: Command,
93 {
94 let concrete: cmd::Concrete = cmd.into();
95 let buf = bincode::Options::serialize(bincode::options(), &(id, concrete))?;
96 (&self.socket).write_all(&(buf.len() as u32).to_be_bytes())?;
97 (&self.socket).write_all(&buf)?;
98
99 Ok(if has_return::<C>() {
100 let (tx, rx) = std::sync::mpsc::sync_channel(1);
101
102 let a: Handler = Box::new(move |mut des: D| {
103 let r = C::Return::deserialize(&mut des)?;
104 tx.send(r).unwrap();
105 Ok::<(), anyhow::Error>(())
106 });
107 self.inner.handlers.lock().unwrap().insert(id, (a, false));
108
109 rx.recv()?
110 } else {
111 unsafe { std::mem::zeroed() }
112 })
113 }
114}
115
116impl Transport for Tcp {
117 fn cmd<C>(&self, cmd: C) -> anyhow::Result<C::Return>
118 where
119 C: Command,
120 {
121 let mut id_handle = self.id.lock().unwrap();
122 let id = *id_handle;
123 *id_handle = id + 1;
124 drop(id_handle);
125 self.cmd_id(cmd, id)
126 }
127}
128
129impl Subscribable for Tcp {
130 fn subscribe<E, F>(&self, ev: E, mut handler: F) -> Result<()>
131 where
132 E: Event,
133 F: (FnMut(E::Item) -> Result<()>) + Send + Sync + 'static,
134 {
135 let mut id_handle = self.id.lock().unwrap();
136 let id = *id_handle;
137 *id_handle = id + 1;
138 drop(id_handle);
139
140 let ev = ev.into();
141
142 self.inner.handlers.lock().unwrap().insert(
143 id,
144 (
145 Box::new(move |mut des| handler(E::Item::deserialize(&mut des)?)),
146 true,
147 ),
148 );
149 self.inner.events.lock().unwrap().insert(ev.clone(), id);
150
151 self.cmd_id(cmd::Subscribe(ev), id)?;
152
153 Ok(())
154 }
155
156 fn unsubscribe<E: roblib::event::Event>(&self, ev: E) -> Result<()> {
157 let ev = ev.into();
158 let cmd = cmd::Unsubscribe(ev.clone());
159
160 let mut lock = self.inner.events.lock().unwrap();
161 match lock.entry(ev) {
162 std::collections::hash_map::Entry::Occupied(v) => {
163 let id = v.remove();
164 self.cmd_id(cmd, id)?;
165 self.inner.handlers.lock().unwrap().remove(&id);
166 }
167 std::collections::hash_map::Entry::Vacant(_) => anyhow::bail!("Subscription not found"),
168 }
169
170 Ok(())
171 }
172}
173
174#[cfg(feature = "async")]
175pub use tcp_async::*;
176#[cfg(feature = "async")]
177pub mod tcp_async {
178 use std::{collections::HashMap, io::Cursor, time::Duration};
179
180 use crate::transports::{SubscribableAsync, TransportAsync};
181 use anyhow::Result;
182 use async_trait::async_trait;
183 use roblib::{
184 cmd::{self, has_return, Command},
185 event::{self, Event},
186 };
187 use serde::{Deserialize, Serialize};
188 use tokio::{
189 io::{AsyncReadExt, AsyncWriteExt, Interest},
190 net::{TcpStream, ToSocketAddrs},
191 sync::{broadcast, mpsc, oneshot},
192 task::JoinHandle,
193 };
194
195 type D = bincode::Deserializer<
196 bincode::de::read::IoReader<Cursor<Vec<u8>>>,
197 bincode::DefaultOptions,
198 >;
199
200 enum Action {
201 ServerMessage(usize),
202 Cmd(cmd::Concrete, Option<oneshot::Sender<D>>),
203 Sub(event::ConcreteType, Option<mpsc::UnboundedSender<D>>),
204 }
205
206 struct Worker {
207 stream: TcpStream,
208 cmd_rx: mpsc::UnboundedReceiver<(cmd::Concrete, Option<oneshot::Sender<D>>)>,
209 sub_rx: mpsc::UnboundedReceiver<(event::ConcreteType, Option<mpsc::UnboundedSender<D>>)>,
210 }
211 impl Worker {
212 pub fn new(
213 stream: TcpStream,
214 ) -> (
215 Self,
216 mpsc::UnboundedSender<(cmd::Concrete, Option<oneshot::Sender<D>>)>,
217 mpsc::UnboundedSender<(event::ConcreteType, Option<mpsc::UnboundedSender<D>>)>,
218 ) {
219 let (cmd_tx, cmd_rx) = mpsc::unbounded_channel();
220 let (sub_tx, sub_rx) = mpsc::unbounded_channel();
221 let s = Self {
222 stream,
223 cmd_rx,
224 sub_rx,
225 };
226 (s, cmd_tx, sub_tx)
227 }
228 pub async fn worker(mut self) -> Result<()> {
229 const HEADER: usize = std::mem::size_of::<u32>();
230
231 let mut next_id = super::super::ID_START;
232 let bin = bincode::options();
233 let mut buf = vec![0; 512];
234 let mut len = 0; let mut maybe_cmd_len = None;
236 let mut cmds: HashMap<u32, oneshot::Sender<D>> = HashMap::new();
237 let mut subs: HashMap<u32, mpsc::UnboundedSender<D>> = HashMap::new();
238 let mut sub_ids: HashMap<event::ConcreteType, u32> = HashMap::new();
239 loop {
240 let action = tokio::select! {
241 Ok(n) = self.stream.read(&mut buf[len..( HEADER + maybe_cmd_len.unwrap_or(0) )]) => Action::ServerMessage(n),
242 Some(cmd) = self.cmd_rx.recv() => Action::Cmd(cmd.0, cmd.1),
243 Some(sub) = self.sub_rx.recv() => Action::Sub(sub.0, sub.1),
244 };
249
250 match action {
251 Action::ServerMessage(n) => {
253 if n == 0 {
254 log::debug!("tcp: received 0 sized msg, investigating disconnect");
255 tokio::time::sleep(Duration::from_millis(100)).await;
257 if self.check_disconnect().await {
258 anyhow::bail!("Server disconnected!");
259 }
260 }
261
262 len += n;
263 if len < HEADER {
264 continue;
265 }
266 let cmd_len = match maybe_cmd_len {
267 Some(n) => n,
268 None => {
269 let cmd = u32::from_be_bytes((&buf[..HEADER]).try_into().unwrap())
270 as usize;
271 maybe_cmd_len = Some(cmd);
273 cmd
275 }
276 };
277 if len < HEADER + cmd_len {
278 continue;
279 }
280
281 let mut c = Cursor::new(buf[HEADER..len].to_vec()); let id: u32 = bincode::Options::deserialize_from(bin, &mut c)?;
283 if let Some(tx) = subs.get(&id) {
284 tx.send(bincode::Deserializer::with_reader(c, bin))?;
285 } else if let Some(tx) = cmds.remove(&id) {
286 if tx.send(bincode::Deserializer::with_reader(c, bin)).is_err() {
287 log::error!("cmd receiver dropped");
288 }
289 } else {
290 log::error!("server sent invalid id");
291 }
292
293 len = 0;
294 maybe_cmd_len = None;
295 }
296 Action::Cmd(cmd, maybe_tx) => {
297 let id = next_id;
298 next_id += 1;
299 if let Some(tx) = maybe_tx {
300 cmds.insert(id, tx);
301 }
302 self.send((id, cmd)).await?;
303 }
304 Action::Sub(ev, Some(tx)) => {
305 let id = next_id;
306 next_id += 1;
307 subs.insert(id, tx);
308 let cmd: cmd::Concrete = cmd::Subscribe(ev).into();
309 self.send((id, cmd)).await?;
310 }
311 Action::Sub(ev, None) => {
313 let Some(id) = sub_ids.remove(&ev) else {
314 log::error!("unsubscribe failed: {ev:?} subscription not found");
315 continue;
316 };
317 subs.remove(&id);
318 let cmd: cmd::Concrete = cmd::Unsubscribe(ev).into();
319 self.send((id, cmd)).await?;
320 }
321 }
322 }
323 }
324 async fn check_disconnect(&mut self) -> bool {
325 let r = self
326 .stream
327 .ready(Interest::READABLE | Interest::WRITABLE)
328 .await;
329 if r.as_ref()
330 .map_or(true, |r| r.is_read_closed() || r.is_write_closed())
331 {
332 log::error!("Server disconnected!");
333 log::debug!("{r:#?}");
334 return true;
335 }
336 return false;
337 }
338 async fn send(&mut self, data: impl Serialize) -> Result<()> {
339 let buf = bincode::Options::serialize(bincode::options(), &data)?;
340 log::debug!("{buf:?}");
341 self.stream
342 .write_all(&(buf.len() as u32).to_be_bytes())
343 .await?;
344 self.stream.write_all(&buf).await?;
345 Ok(())
346 }
347 }
348
349 pub struct TcpAsync {
350 _handle: Option<JoinHandle<Result<()>>>,
351 cmd_tx: mpsc::UnboundedSender<(cmd::Concrete, Option<oneshot::Sender<D>>)>,
352 sub_tx: mpsc::UnboundedSender<(event::ConcreteType, Option<mpsc::UnboundedSender<D>>)>,
353 }
354
355 impl TcpAsync {
356 pub async fn connect(addr: impl ToSocketAddrs) -> Result<Self> {
357 let stream = TcpStream::connect(addr).await?;
358 let (worker, cmd_tx, sub_tx) = Worker::new(stream);
359 let handle = Some(tokio::spawn(async {
360 let r = worker.worker().await;
361 log::debug!("worker dropped??");
362 r
363 }));
364
365 Ok(Self {
366 _handle: handle,
367 cmd_tx,
368 sub_tx,
369 })
370 }
371 }
372
373 #[async_trait]
374 impl TransportAsync for TcpAsync {
375 async fn cmd<C>(&self, cmd: C) -> Result<C::Return>
376 where
377 C: Command,
378 {
379 let concr: cmd::Concrete = cmd.into();
380 if has_return::<C>() {
381 let (tx, rx) = oneshot::channel();
382 self.cmd_tx.send((concr, Some(tx)))?;
383 let mut de = rx.await?;
384 Ok(C::Return::deserialize(&mut de)?)
385 } else {
386 self.cmd_tx.send((concr, None))?;
387 unsafe { std::mem::zeroed() }
388 }
389 }
390 }
391
392 #[async_trait]
393 impl SubscribableAsync for TcpAsync {
394 async fn subscribe<E: Event>(&self, ev: E) -> Result<broadcast::Receiver<E::Item>> {
395 let (worker_tx, mut worker_rx) = mpsc::unbounded_channel();
396 self.sub_tx.send((ev.into(), Some(worker_tx)))?;
397
398 let (client_tx, client_rx) = broadcast::channel(128);
399 tokio::spawn(async move {
400 while let Some(mut de) = worker_rx.recv().await {
401 let item = E::Item::deserialize(&mut de)?;
402 if client_tx.send(item).is_err() {
403 log::error!("no receiver for active subscription");
404 };
405 }
406 anyhow::Ok(())
407 });
408 Ok(client_rx)
409 }
410
411 async fn unsubscribe<E>(&self, ev: E) -> Result<()>
412 where
413 E: Event,
414 {
415 Ok(self.sub_tx.send((ev.into(), None))?)
416 }
417 }
418}