nt_rs/
lib.rs

1use core::panic;
2use std::{
3    collections::HashMap,
4    marker::PhantomData,
5    sync::{
6        atomic::{AtomicI64, AtomicU32},
7        Arc, Mutex,
8    },
9    time::Duration,
10};
11
12use flume::{unbounded, Receiver, RecvError, Sender};
13use futures::{select, Future, FutureExt};
14use http::uri::InvalidUri;
15use payload::Payload;
16use thiserror::Error;
17use time::get_time;
18use types::{
19    BinaryData, BinaryMessage, BinaryMessageError, Properties, SubscriptionOptions, TextMessage,
20};
21
22pub mod backend;
23pub mod payload;
24pub mod time;
25pub mod types;
26
27/// Any type of message that could be sent or recieved from a websocket
28#[derive(Debug)]
29pub enum NtMessage {
30    Text(TextMessage),
31    Binary(BinaryMessage),
32}
33
34struct Topics {
35    topics: HashMap<String, Sender<SubscriberUpdate>>,
36    topic_ids: HashMap<u32, String>,
37}
38
39impl Default for Topics {
40    fn default() -> Self {
41        Self {
42            topics: Default::default(),
43            topic_ids: Default::default(),
44        }
45    }
46}
47
48struct InnerNetworkTableClient {
49    receive: Receiver<Result<NtMessage>>,
50    send: Sender<NtMessage>,
51    topics: Mutex<Topics>,
52    time_offset: AtomicI64,
53
54    subuid: AtomicU32,
55    pubuid: AtomicU32,
56}
57
58enum SubscriberUpdate {
59    Properties(Properties),
60    Data(BinaryData),
61    Type(String),
62}
63
64#[derive(Debug, Error)]
65pub enum Error {
66    #[error("Encountered an http error: {0}")]
67    Http(#[from] http::Error),
68    #[error("Encountered a websocket error: {0}")]
69    Websocket(#[from] tungstenite::Error),
70    #[error("Invalid uri error: {0}")]
71    Uri(#[from] InvalidUri),
72    #[error("Server does not support the nt v4.0 protocol")]
73    UnsupportedServer,
74    #[error("Error while encoding or decoding a binary message: {0}")]
75    BinaryMessage(#[from] BinaryMessageError),
76    #[error("Error while encoding or decoding a text message: {0}")]
77    TextMessage(#[from] serde_json::Error),
78    #[error("Error while sending a message")]
79    Send,
80    #[error("Error while receiving a message: {0}")]
81    Receive(#[from] RecvError),
82    #[error("Other error occured: {0}")]
83    Other(Box<dyn std::error::Error + 'static + Send>),
84    #[error("Encountered an unknown frame")]
85    UnknownFrame,
86    #[error("Encountered an incorrect type")]
87    Type,
88}
89
90type Result<T, E = Error> = std::result::Result<T, E>;
91
/// Generic timer abstraction; the client uses it to pace its periodic time synchronisation.
pub trait Timer {
    /// Returns a future that completes once `duration` has elapsed.
    fn time(duration: Duration) -> impl std::future::Future<Output = ()> + Send;
}
97
98/// A generic backend that a client can use. [backend::TokioBackend] is a good example.
99pub trait Backend {
100    /// A type like a join handle that whatever is using the client might need
101    type Output;
102    type Error: std::error::Error + 'static + Send;
103
104    /// Using the hostname and client name create a backend that sends [NtMessage] or [Error] to
105    /// the client and passes on [NtMessage] to the server
106    fn create(
107        host: &str,
108        name: &str,
109        send: Sender<Result<NtMessage>>,
110        receive: Receiver<NtMessage>,
111    ) -> std::result::Result<Self::Output, Self::Error>;
112}
113
114impl InnerNetworkTableClient {
115    async fn new<B: Backend>(host: &str, name: &str) -> Result<(Self, B::Output)> {
116        let (send_out, receive_out) = unbounded();
117        let (send_in, receive_in) = unbounded();
118
119        let out = match B::create(host, name, send_in, receive_out) {
120            Ok(out) => out,
121            Err(err) => return Err(Error::Other(Box::new(err))),
122        };
123
124        send_out
125            .send(NtMessage::Binary(BinaryMessage {
126                id: -1,
127                timestamp: 0,
128                data: BinaryData::Int(get_time() as i64),
129            }))
130            .map_err(|_| Error::Send)?;
131
132        let NtMessage::Binary(msg) = receive_in.recv_async().await?? else {
133            return Err(Error::Type);
134        };
135
136        if msg.id != -1 {
137            return Err(Error::Type); // TODO: Maybe not the right response
138        }
139
140        let BinaryData::Int(time) = msg.data else {
141            return Err(Error::Type);
142        };
143
144        let server_time = (get_time() as i64 - time) / 2 + msg.timestamp as i64;
145        let offset = server_time - get_time() as i64;
146
147        Ok((
148            Self {
149                send: send_out,
150                receive: receive_in,
151                topics: Mutex::new(Default::default()),
152                time_offset: AtomicI64::new(offset),
153
154                subuid: AtomicU32::new(u32::MIN),
155                pubuid: AtomicU32::new(u32::MIN),
156            },
157            out,
158        ))
159    }
160
161    fn get_server_time(&self) -> u64 {
162        let offset = self.time_offset.load(std::sync::atomic::Ordering::Relaxed);
163        (get_time() as i64 + offset) as u64
164    }
165
166    async fn main_loop<T: Timer>(&self) -> Result<()> {
167        select! {
168            res = self.time_loop::<T>().fuse() => {
169                return res;
170            }
171            res = self.recv_loop().fuse() => {
172                return res;
173            }
174        }
175    }
176
177    async fn time_loop<T: Timer>(&self) -> Result<()> {
178        loop {
179            T::time(Duration::from_secs(2)).await;
180
181            self.start_sync_time()?;
182        }
183    }
184
185    async fn recv_loop(&self) -> Result<()> {
186        loop {
187            let val = self.receive.recv_async().await??;
188
189            match val {
190                NtMessage::Text(msg) => match msg {
191                    TextMessage::Announce {
192                        name,
193                        id,
194                        data_type,
195                        pubuid: _,
196                        properties,
197                    } => {
198                        let mut topics = self.topics.lock().unwrap();
199
200                        let Some(sender) = topics.topics.get(&name) else {
201                            continue;
202                        };
203
204                        if sender.send(SubscriberUpdate::Type(data_type)).is_err()
205                            || sender
206                                .send(SubscriberUpdate::Properties(properties))
207                                .is_err()
208                        {
209                            topics.topics.remove(&name);
210                        } else {
211                            topics.topic_ids.insert(id, name);
212                        }
213                    }
214                    TextMessage::Unannounce { name, id } => {
215                        let mut topics = self.topics.lock().unwrap();
216
217                        topics.topics.remove(&name);
218                        topics.topic_ids.remove(&id);
219                    }
220                    TextMessage::Properties {
221                        name,
222                        ack: _,
223                        update,
224                    } => {
225                        let mut topics = self.topics.lock().unwrap();
226
227                        let topic = topics.topics.get(&name);
228
229                        if let Some(topic) = topic {
230                            if topic.send(SubscriberUpdate::Properties(update)).is_err() {
231                                topics.topics.remove(&name);
232                            }
233                        }
234                    }
235                    _ => unreachable!("A server-bound message was sent to the client"),
236                },
237                NtMessage::Binary(msg) => {
238                    if msg.id == -1 {
239                        let BinaryData::Int(time) = msg.data else {
240                            return Err(Error::Type);
241                        };
242
243                        let server_time = (get_time() as i64 - time) / 2 + msg.timestamp as i64;
244                        let offset = server_time - get_time() as i64;
245
246                        self.time_offset
247                            .fetch_min(offset, std::sync::atomic::Ordering::Relaxed);
248                    } else {
249                        let mut topics = self.topics.lock().unwrap();
250
251                        let Some(name) = topics.topic_ids.get(&(msg.id as u32)) else {
252                            topics.topic_ids.remove(&(msg.id as u32));
253                            continue;
254                        };
255
256                        let is_sender_dropped = topics
257                            .topics
258                            .get(name)
259                            .map(|topic| topic.send(SubscriberUpdate::Data(msg.data)).is_err())
260                            .unwrap_or(false);
261
262                        if is_sender_dropped {
263                            let name = name.to_owned();
264                            topics.topics.remove(&name);
265                            topics.topic_ids.remove(&(msg.id as u32));
266                        }
267                    }
268                }
269            }
270        }
271    }
272
273    fn start_sync_time(&self) -> Result<()> {
274        self.send
275            .send(NtMessage::Binary(BinaryMessage {
276                id: -1,
277                timestamp: 0,
278                data: BinaryData::Int(get_time() as i64),
279            }))
280            .map_err(|_| Error::Send)?;
281
282        Ok(())
283    }
284
285    fn subscribe(&self, topics: Vec<String>, options: SubscriptionOptions) -> Result<u32> {
286        let id = self.new_subuid();
287        self.send
288            .send(NtMessage::Text(TextMessage::Subscribe {
289                topics,
290                subuid: id,
291                options,
292            }))
293            .map_err(|_| Error::Send)?;
294
295        Ok(id)
296    }
297
298    fn unsubscribe(&self, id: u32) -> Result<()> {
299        self.send
300            .send(NtMessage::Text(TextMessage::Unsubscribe { subuid: id }))
301            .map_err(|_| Error::Send)?;
302
303        Ok(())
304    }
305
306    fn publish(&self, name: String, data_type: String, properties: Properties) -> Result<u32> {
307        let id = self.new_pubuid();
308
309        self.send
310            .send(NtMessage::Text(TextMessage::Publish {
311                name,
312                pubuid: id,
313                data_type,
314                properties,
315            }))
316            .map_err(|_| Error::Send)?;
317
318        Ok(id)
319    }
320
321    fn unpublish(&self, id: u32) -> Result<()> {
322        self.send
323            .send(NtMessage::Text(TextMessage::Unpublish { pubuid: id }))
324            .map_err(|_| Error::Send)?;
325
326        Ok(())
327    }
328
329    fn new_subuid(&self) -> u32 {
330        self.subuid
331            .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
332    }
333
334    fn new_pubuid(&self) -> u32 {
335        self.pubuid
336            .fetch_add(1, std::sync::atomic::Ordering::Relaxed)
337    }
338}
339
340// An instance of a network table client. Based on an [Arc] internally, so cheap to copy
341#[derive(Clone)]
342pub struct NetworkTableClient {
343    inner: Arc<InnerNetworkTableClient>,
344}
345
346impl NetworkTableClient {
347    /// Create a new client using the hostname, client name, and a backend type
348    pub async fn new<B: Backend>(host: &str, name: &str) -> Result<(Self, B::Output)> {
349        let (inner, out) = InnerNetworkTableClient::new::<B>(host, name).await?;
350
351        Ok((
352            Self {
353                inner: Arc::new(inner),
354            },
355            out,
356        ))
357    }
358
359    /// This returns a future that should be run on the side, usually, in an async task. This
360    /// future must remain alive for as long as subscriber and time updates are required
361    pub fn main_task<T: Timer>(&self) -> impl Future<Output = Result<()>> + 'static {
362        let inner = self.inner.clone();
363
364        async move { inner.main_loop::<T>().await }
365    }
366
367    /// Create a subscriber for a topic with a certain payload type
368    pub fn subscribe<P: Payload>(&self, name: String) -> Result<Subscriber<P>> {
369        let (sender, receiver) = unbounded();
370
371        self.inner
372            .topics
373            .lock()
374            .unwrap()
375            .topics
376            .insert(name.clone(), sender);
377
378        let id = self
379            .inner
380            .subscribe(vec![name.clone()], Default::default())?;
381
382        Ok(Subscriber {
383            name,
384            properties: None,
385            input: receiver,
386            id,
387            client: self.inner.clone(),
388            phantom: PhantomData,
389        })
390    }
391
392    /// Create a publisher for a topic with a certain payload type
393    pub fn publish<P: Payload>(&self, name: String) -> Result<Publisher<P>> {
394        let id = self
395            .inner
396            .publish(name.clone(), P::name().to_owned(), Default::default())?;
397
398        Ok(Publisher {
399            name,
400            id,
401            client: self.inner.clone(),
402            phantom: PhantomData,
403        })
404    }
405}
406
407pub struct Subscriber<P: Payload> {
408    name: String,
409    properties: Option<Properties>,
410    input: Receiver<SubscriberUpdate>,
411    id: u32,
412    client: Arc<InnerNetworkTableClient>,
413    phantom: PhantomData<P>,
414}
415
416impl<P: Payload> Subscriber<P> {
417    fn consume_updates(&mut self) -> Result<Option<P>> {
418        let mut data = None;
419        for update in self.input.drain() {
420            match update {
421                SubscriberUpdate::Properties(props) => {
422                    if self.properties.is_none() {
423                        self.properties = Some(Default::default());
424                    }
425
426                    self.properties.as_mut().unwrap().update(props);
427                }
428                SubscriberUpdate::Data(bin_data) => {
429                    data = Some(P::parse(bin_data).map_err(|_| Error::Type)?);
430                }
431                SubscriberUpdate::Type(val) => {
432                    if &val != P::name() {
433                        return Err(Error::Type);
434                    }
435                }
436            }
437        }
438
439        Ok(data)
440    }
441
442    /// Wait for a new payload value to become avaliable
443    pub async fn get(&mut self) -> Result<P> {
444        if !self.input.is_empty() {
445            if let Some(val) = self.consume_updates()? {
446                return Ok(val);
447            }
448        }
449
450        loop {
451            let val = self.input.recv_async().await?;
452
453            match val {
454                SubscriberUpdate::Properties(props) => {
455                    if self.properties.is_none() {
456                        self.properties = Some(Default::default());
457                    }
458
459                    self.properties.as_mut().unwrap().update(props);
460                }
461                SubscriberUpdate::Data(val) => {
462                    break P::parse(val).map_err(|_| Error::Type);
463                }
464                SubscriberUpdate::Type(ty) => {
465                    if &ty != P::name() {
466                        return Err(Error::Type);
467                    }
468                }
469            }
470        }
471    }
472
473    pub fn properties(&self) -> Option<&Properties> {
474        self.properties.as_ref()
475    }
476
477    pub fn name(&self) -> &str {
478        &self.name
479    }
480}
481
482impl<P: Payload> Drop for Subscriber<P> {
483    fn drop(&mut self) {
484        let _ = self.client.unsubscribe(self.id);
485    }
486}
487
488pub struct Publisher<P: Payload> {
489    name: String,
490    id: u32,
491    client: Arc<InnerNetworkTableClient>,
492    phantom: PhantomData<P>,
493}
494
495impl<P: Payload> Publisher<P> {
496    pub fn set(&self, value: P) -> Result<()> {
497        self.client
498            .send
499            .send(NtMessage::Binary(BinaryMessage {
500                id: self.id as i64,
501                timestamp: self.client.get_server_time(),
502                data: value.to_val(),
503            }))
504            .map_err(|_| Error::Send)
505    }
506
507    pub fn set_properties(&self, props: Properties) -> Result<()> {
508        self.client
509            .send
510            .send(NtMessage::Text(TextMessage::SetProperties {
511                name: self.name.clone(),
512                update: props,
513            }))
514            .map_err(|_| Error::Send)
515    }
516
517    pub fn name(&self) -> &str {
518        &self.name
519    }
520}
521
522impl<P: Payload> Drop for Publisher<P> {
523    fn drop(&mut self) {
524        let _ = self.client.unpublish(self.id);
525    }
526}