postcard_rpc/host_client/
mod.rs

1//! A postcard-rpc host client
2//!
3//! This library is meant to be used with the `Dispatch` type and the
4//! postcard-rpc wire protocol.
5
6use core::time::Duration;
7use std::{
8    collections::HashSet,
9    future::Future,
10    marker::PhantomData,
11    sync::{
12        atomic::{AtomicU32, Ordering},
13        Arc, RwLock,
14    },
15};
16use thiserror::Error;
17
18use maitake_sync::{
19    wait_map::{WaitError, WakeOutcome},
20    WaitMap,
21};
22use postcard_schema::{schema::owned::OwnedNamedType, Schema};
23use serde::{de::DeserializeOwned, Deserialize, Serialize};
24use tokio::{
25    select,
26    sync::{broadcast, mpsc, Mutex},
27};
28use util::Subscriptions;
29
30use crate::{
31    header::{VarHeader, VarKey, VarKeyKind, VarSeq, VarSeqKind},
32    standard_icd::{GetAllSchemaDataTopic, GetAllSchemasEndpoint, OwnedSchemaData},
33    Endpoint, Key, Topic, TopicDirection,
34};
35
36use self::util::Stopper;
37pub use crate::host_client::util::HostClientConfig;
38
39#[cfg(all(feature = "raw-nusb", not(target_family = "wasm")))]
40mod raw_nusb;
41
42#[cfg(all(feature = "cobs-serial", not(target_family = "wasm")))]
43mod serial;
44
45#[cfg(all(feature = "webusb", target_family = "wasm"))]
46pub mod webusb;
47
48pub(crate) mod util;
49
50#[cfg(feature = "test-utils")]
51pub mod test_channels;
52
/// Host Error Kind
///
/// Errors returned by [`HostClient`] request/response operations.
#[derive(Debug, PartialEq, Error)]
pub enum HostErr<WireErr> {
    /// An error of the user-specified wire error type, sent back by the
    /// device instead of the expected response
    #[error("a wire error occurred")]
    Wire(WireErr),
    /// We got a response that didn't match the expected value or the
    /// user specified wire error type
    ///
    /// This is also (mis)used to report when duplicate sequence numbers
    /// are detected in flight at the same time.
    #[error("the response received didn't match the expected value or the wire error type")]
    BadResponse,
    /// Deserialization of the message failed
    #[error("message deserialization failed")]
    Postcard(#[from] postcard::Error),
    /// The interface has been closed, and no further messages are possible
    #[error("the interface has been closed, and no further messages are possible")]
    Closed,
}
73
74impl<T> From<WaitError> for HostErr<T> {
75    fn from(_: WaitError) -> Self {
76        Self::Closed
77    }
78}
79
/// Wire Transmit Interface
///
/// Responsible for taking a serialized frame (including header and payload),
/// performing any further encoding if necessary, and transmitting to the device.
///
/// Should complete once the message is fully sent (e.g. not just enqueued)
/// if possible.
///
/// All errors are treated as fatal - resolvable or ignorable errors should not
/// be returned to the caller.
///
/// This is the `wasm` variant of the trait: unlike the native variant, neither
/// the implementor nor the returned future is required to be `Send`.
#[cfg(target_family = "wasm")]
pub trait WireTx: 'static {
    /// Transmit error type
    type Error: std::error::Error;
    /// Send a single frame
    fn send(&mut self, data: Vec<u8>) -> impl Future<Output = Result<(), Self::Error>>;
}

/// Wire Receive Interface
///
/// Responsible for accumulating a serialized frame (including header and payload),
/// performing any further decoding if necessary, and returning to the caller.
///
/// All errors are treated as fatal - resolvable or ignorable errors should not
/// be returned to the caller.
///
/// This is the `wasm` variant of the trait: unlike the native variant, neither
/// the implementor nor the returned future is required to be `Send`.
#[cfg(target_family = "wasm")]
pub trait WireRx: 'static {
    /// Receive error type
    type Error: std::error::Error;
    /// Receive a single frame
    fn receive(&mut self) -> impl Future<Output = Result<Vec<u8>, Self::Error>>;
}

/// Wire Spawn Interface
///
/// Should be suitable for spawning a task in the host executor.
///
/// This is the `wasm` variant of the trait: spawned futures only need to be
/// `'static`, not `Send`.
#[cfg(target_family = "wasm")]
pub trait WireSpawn: 'static {
    /// Spawn a task
    fn spawn(&mut self, fut: impl Future<Output = ()> + 'static);
}
121
/// Wire Transmit Interface
///
/// Responsible for taking a serialized frame (including header and payload),
/// performing any further encoding if necessary, and transmitting to the device.
///
/// Should complete once the message is fully sent (e.g. not just enqueued)
/// if possible.
///
/// All errors are treated as fatal - resolvable or ignorable errors should not
/// be returned to the caller.
///
/// This is the non-`wasm` variant of the trait: both the implementor and the
/// future returned by [`WireTx::send`] must be `Send`.
#[cfg(not(target_family = "wasm"))]
pub trait WireTx: Send + 'static {
    /// Transmit error type
    type Error: std::error::Error;
    /// Send a single frame
    fn send(&mut self, data: Vec<u8>) -> impl Future<Output = Result<(), Self::Error>> + Send;
}

/// Wire Receive Interface
///
/// Responsible for accumulating a serialized frame (including header and payload),
/// performing any further decoding if necessary, and returning to the caller.
///
/// All errors are treated as fatal - resolvable or ignorable errors should not
/// be returned to the caller.
///
/// This is the non-`wasm` variant of the trait: both the implementor and the
/// future returned by [`WireRx::receive`] must be `Send`.
#[cfg(not(target_family = "wasm"))]
pub trait WireRx: Send + 'static {
    /// Receive error type
    type Error: std::error::Error;
    /// Receive a single frame
    fn receive(&mut self) -> impl Future<Output = Result<Vec<u8>, Self::Error>> + Send;
}

/// Wire Spawn Interface
///
/// Should be suitable for spawning a task in the host executor.
///
/// This is the non-`wasm` variant of the trait: the spawner itself need not be
/// `Send`, but every spawned future must be `Send + 'static`.
#[cfg(not(target_family = "wasm"))]
pub trait WireSpawn: 'static {
    /// Spawn a task
    fn spawn(&mut self, fut: impl Future<Output = ()> + Send + 'static);
}
163
/// The [HostClient] is the primary PC-side interface.
///
/// It is generic over a single type, `WireErr`, which can be used by the
/// embedded system when a request was not understood, or some other error
/// has occurred.
///
/// [HostClient]s can be cloned, and used across multiple tasks/threads.
/// All clones share the same connection: closing one closes all of them.
///
/// There are currently two ways to create one, based on the transport used:
///
/// 1. With raw USB Bulk transfers: [`HostClient::new_raw_nusb()`] (**recommended**)
/// 2. With cobs CDC-ACM transfers: [`HostClient::new_serial_cobs()`]
pub struct HostClient<WireErr> {
    /// State shared with the I/O worker (handed out as [`WireContext::incoming`]),
    /// including the map used to route replies to waiting requests
    ctx: Arc<HostContext>,
    /// Sending half of the outgoing frame queue; the receiving half is handed
    /// to the I/O worker as [`WireContext::outgoing`]
    out: mpsc::Sender<RpcFrame>,
    /// Registered topic subscriptions, shared between all clones
    subscriptions: Arc<Mutex<Subscriptions>>,
    /// Key used to recognize `WireErr` replies, derived from the configured error path
    err_key: Key,
    /// Shared close signal, see [`HostClient::close`]
    stopper: Stopper,
    /// Sequence number kind taken from the [`HostClientConfig`]
    seq_kind: VarSeqKind,
    /// Marks the `WireErr` type parameter without storing a `WireErr` value
    _pd: PhantomData<fn() -> WireErr>,
}
185
186impl<W> core::fmt::Debug for HostClient<W> {
187    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
188        f.debug_struct("HostClient").finish_non_exhaustive()
189    }
190}
191
192/// # Constructor Methods
193impl<WireErr> HostClient<WireErr>
194where
195    WireErr: DeserializeOwned + Schema,
196{
197    /// Private method for creating internal context
198    pub(crate) fn new_manual_priv(config: &HostClientConfig) -> (Self, WireContext) {
199        let (tx_pc, rx_pc) = tokio::sync::mpsc::channel(config.outgoing_depth);
200
201        let ctx = Arc::new(HostContext {
202            kkind: RwLock::new(VarKeyKind::Key8),
203            map: WaitMap::new(),
204            seq: AtomicU32::new(0),
205            subscription_timeout: config.subscriber_timeout_if_full,
206        });
207
208        let err_key = Key::for_path::<WireErr>(config.err_uri_path);
209
210        let me = HostClient {
211            ctx: ctx.clone(),
212            out: tx_pc,
213            err_key,
214            _pd: PhantomData,
215            subscriptions: Arc::new(Mutex::new(Subscriptions::default())),
216            stopper: Stopper::new(),
217            seq_kind: config.seq_kind,
218        };
219
220        let wire = WireContext {
221            outgoing: rx_pc,
222            incoming: ctx,
223        };
224
225        (me, wire)
226    }
227}
228
/// Errors related to retrieving the schema
#[derive(Debug, Error)]
pub enum SchemaError<WireErr> {
    /// Some kind of communication error occurred
    #[error("A communication error occurred")]
    Comms(#[from] HostErr<WireErr>),
    /// An error occurred internally. Please open an issue.
    ///
    /// In [`HostClient::get_schema_report`] this is returned when the
    /// background task collecting schema data fails.
    #[error("An error occurred internally. Please open an issue.")]
    TaskError,
    /// Invalid report data was received, including endpoints or
    /// topics that referred to unknown types. Please open an issue
    #[error("Invalid report data was received. Please open an issue.")]
    InvalidReportData,
    /// Data was lost while transmitting. If a retry does not solve
    /// this, please open an issue.
    #[error(
        "Data was lost while transmitting. If a retry does not solve this, please open an issue."
    )]
    LostData,
}
249
250impl<WireErr> From<UnableToFindType> for SchemaError<WireErr> {
251    fn from(_: UnableToFindType) -> Self {
252        Self::InvalidReportData
253    }
254}
255
256/// # Interface Methods
257impl<WireErr> HostClient<WireErr>
258where
259    WireErr: DeserializeOwned + Schema,
260{
261    /// Obtain a [`SchemaReport`] describing the connected device
262    pub async fn get_schema_report(&self) -> Result<SchemaReport, SchemaError<WireErr>> {
263        let Ok(mut sub) = self.subscribe_multi::<GetAllSchemaDataTopic>(64).await else {
264            return Err(SchemaError::Comms(HostErr::Closed));
265        };
266
267        let collect_task = tokio::task::spawn({
268            async move {
269                let mut got = vec![];
270                while let Ok(Ok(val)) =
271                    tokio::time::timeout(Duration::from_millis(500), sub.recv()).await
272                {
273                    got.push(val);
274                }
275                got
276            }
277        });
278        let trigger_task = self.send_resp::<GetAllSchemasEndpoint>(&()).await;
279        let data = collect_task.await;
280        let (resp, data) = match (trigger_task, data) {
281            (Ok(a), Ok(b)) => (a, b),
282            (Ok(_), Err(_)) => return Err(SchemaError::TaskError),
283            (Err(e), Ok(_)) => return Err(SchemaError::Comms(e)),
284            (Err(e1), Err(_e2)) => return Err(SchemaError::Comms(e1)),
285        };
286        let mut rpt = SchemaReport::default();
287        let mut e_and_t = vec![];
288
289        for d in data {
290            match d {
291                OwnedSchemaData::Type(d) => {
292                    rpt.add_type(d);
293                }
294                e @ OwnedSchemaData::Endpoint { .. } => e_and_t.push(e),
295                t @ OwnedSchemaData::Topic { .. } => e_and_t.push(t),
296            }
297        }
298
299        for e in e_and_t {
300            match e {
301                OwnedSchemaData::Type(_) => unreachable!(),
302                OwnedSchemaData::Endpoint {
303                    path,
304                    request_key,
305                    response_key,
306                } => {
307                    rpt.add_endpoint(path, request_key, response_key)?;
308                }
309                OwnedSchemaData::Topic {
310                    path,
311                    key,
312                    direction,
313                } => match direction {
314                    TopicDirection::ToServer => rpt.add_topic_in(path, key)?,
315                    TopicDirection::ToClient => rpt.add_topic_out(path, key)?,
316                },
317            }
318        }
319
320        let mut data_matches = true;
321        data_matches &= resp.endpoints_sent as usize == rpt.endpoints.len();
322        data_matches &= resp.topics_in_sent as usize == rpt.topics_in.len();
323        data_matches &= resp.topics_out_sent as usize == rpt.topics_out.len();
324        data_matches &= resp.errors == 0;
325
326        if data_matches {
327            // TODO: filter primitive types out?
328            Ok(rpt)
329        } else {
330            Err(SchemaError::LostData)
331        }
332    }
333
334    /// Send a message of type [Endpoint::Request][Endpoint] to `path`, and await
335    /// a response of type [Endpoint::Response][Endpoint] (or WireErr) to `path`.
336    ///
337    /// This function will wait potentially forever. Consider using with a timeout.
338    pub async fn send_resp<E: Endpoint>(
339        &self,
340        t: &E::Request,
341    ) -> Result<E::Response, HostErr<WireErr>>
342    where
343        E::Request: Serialize + Schema,
344        E::Response: DeserializeOwned + Schema,
345    {
346        let seq_no = self.ctx.seq.fetch_add(1, Ordering::Relaxed);
347
348        let msg = postcard::to_stdvec(&t).expect("Allocations should not ever fail");
349        let frame = RpcFrame {
350            // NOTE: send_resp_raw automatically shrinks down key and sequence
351            // kinds to the appropriate amount
352            header: VarHeader {
353                key: VarKey::Key8(E::REQ_KEY),
354                seq_no: VarSeq::Seq4(seq_no),
355            },
356            body: msg,
357        };
358        let frame = self.send_resp_raw(frame, E::RESP_KEY).await?;
359        let r = postcard::from_bytes::<E::Response>(&frame.body)?;
360        Ok(r)
361    }
362
363    /// Perform an endpoint request/response,but without handling the
364    /// Ser/De automatically
365    pub async fn send_resp_raw(
366        &self,
367        mut rqst: RpcFrame,
368        resp_key: Key,
369    ) -> Result<RpcFrame, HostErr<WireErr>> {
370        let cancel_fut = self.stopper.wait_stopped();
371        let kkind: VarKeyKind = *self.ctx.kkind.read().unwrap();
372        rqst.header.key.shrink_to(kkind);
373        let mut resp_key = VarKey::Key8(resp_key);
374        let mut err_key = VarKey::Key8(self.err_key);
375        resp_key.shrink_to(kkind);
376        err_key.shrink_to(kkind);
377
378        // Prepare to receive the reply, BEFORE we send the request.
379        // This uses the `enqueue` feature of WaitMap, which makes sure that
380        // our receiver is ready to "catch" before we even send the request.
381        let ok_resp = self.ctx.map.wait(VarHeader {
382            seq_no: rqst.header.seq_no,
383            key: resp_key,
384        });
385        let err_resp = self.ctx.map.wait(VarHeader {
386            seq_no: rqst.header.seq_no,
387            key: err_key,
388        });
389        let mut ok_resp = std::pin::pin!(ok_resp);
390        let mut err_resp = std::pin::pin!(err_resp);
391        let setup_fut: Result<(), WaitError> = async {
392            ok_resp.as_mut().subscribe().await?;
393            err_resp.as_mut().subscribe().await?;
394            Ok(())
395        }
396        .await;
397
398        // If registering for the response failed, return an error
399        if let Err(e) = setup_fut {
400            return Err(match e {
401                WaitError::Closed => HostErr::Closed,
402                WaitError::Duplicate => {
403                    tracing::error!("Attempted to register a duplicate wait for a reply. This can happen if sequence numbers are reused.");
404                    // TODO: This is the wrong kind of error, but we don't want to report closed.
405                    // Fix this in the next breaking change of postcard-rpc, or make HostErr non-exhaustive
406                    HostErr::BadResponse
407                }
408
409                // These should never happen: NeverAdded and AlreadyConsumed
410                _ => {
411                    tracing::error!("Internal error setting up reply: {e:?}, closing");
412                    self.close();
413                    HostErr::Closed
414                }
415            });
416        };
417
418        self.out.send(rqst).await.map_err(|_| HostErr::Closed)?;
419
420        select! {
421            _c = cancel_fut => Err(HostErr::Closed),
422            o = ok_resp => {
423                let (hdr, resp) = o?;
424                if hdr.key.kind() != kkind {
425                    *self.ctx.kkind.write().unwrap() = hdr.key.kind();
426                }
427                Ok(RpcFrame { header: hdr, body: resp })
428            },
429            e = err_resp => {
430                let (hdr, resp) = e?;
431                if hdr.key.kind() != kkind {
432                    *self.ctx.kkind.write().unwrap() = hdr.key.kind();
433                }
434                let r = postcard::from_bytes::<WireErr>(&resp)?;
435                Err(HostErr::Wire(r))
436            },
437        }
438    }
439
440    /// Publish a [Topic] [Message][Topic::Message].
441    ///
442    /// There is no feedback if the server received our message. If the I/O worker is
443    /// closed, an error is returned.
444    pub async fn publish<T: Topic>(&self, seq_no: VarSeq, msg: &T::Message) -> Result<(), IoClosed>
445    where
446        T::Message: Serialize,
447    {
448        let smsg = postcard::to_stdvec(msg).expect("alloc should never fail");
449        let frame = RpcFrame {
450            header: VarHeader {
451                key: VarKey::Key8(T::TOPIC_KEY),
452                seq_no,
453            },
454            body: smsg,
455        };
456        self.publish_raw(frame).await
457    }
458
459    /// Publish the given raw frame
460    pub async fn publish_raw(&self, mut frame: RpcFrame) -> Result<(), IoClosed> {
461        let kkind: VarKeyKind = *self.ctx.kkind.read().unwrap();
462        frame.header.key.shrink_to(kkind);
463
464        let cancel_fut = self.stopper.wait_stopped();
465        let operate_fut = self.out.send(frame);
466
467        select! {
468            _ = cancel_fut => Err(IoClosed),
469            res = operate_fut => res.map_err(|_| IoClosed),
470        }
471    }
472
473    ///////////////////////////////////////////////////////////////////////////
474    // Subscribe Multi
475    ///////////////////////////////////////////////////////////////////////////
476
477    /// Begin listening to a [Topic], receiving a [Subscription] that will give a
478    /// stream of [Message][Topic::Message]s. Unlike `subscribe`, multiple subscribers
479    /// to the same stream are allowed, and behave as a broadcast channel.
480    ///
481    /// Returns an Error if the I/O worker is closed.
482    pub async fn subscribe_multi<T: Topic>(
483        &self,
484        depth: usize,
485    ) -> Result<MultiSubscription<T::Message>, IoClosed>
486    where
487        T::Message: DeserializeOwned,
488    {
489        let cancel_fut = self.stopper.wait_stopped();
490        let operate_fut = self.subscribe_multi_inner::<T>(depth);
491        select! {
492            _ = cancel_fut => Err(IoClosed),
493            res = operate_fut => res,
494        }
495    }
496
497    /// Inner function version of [Self::subscribe_multi]
498    async fn subscribe_multi_inner<T: Topic>(
499        &self,
500        depth: usize,
501    ) -> Result<MultiSubscription<T::Message>, IoClosed>
502    where
503        T::Message: DeserializeOwned,
504    {
505        let rx = {
506            let mut guard = self.subscriptions.lock().await;
507            if guard.stopped {
508                return Err(IoClosed);
509            }
510            if let Some(entry) = guard
511                .broadcast_list
512                .iter_mut()
513                .find(|(k, _)| *k == T::TOPIC_KEY)
514            {
515                entry.1.subscribe()
516            } else {
517                let (tx, rx) = broadcast::channel(depth);
518                guard.broadcast_list.push((T::TOPIC_KEY, tx));
519                rx
520            }
521        };
522        Ok(MultiSubscription {
523            rx,
524            _pd: PhantomData,
525        })
526    }
527
528    /// Subscribe to the given [`Key`], without automatically handling deserialization
529    pub async fn subscribe_multi_raw(
530        &self,
531        key: Key,
532        depth: usize,
533    ) -> Result<RawMultiSubscription, IoClosed> {
534        let cancel_fut = self.stopper.wait_stopped();
535        let operate_fut = self.subscribe_multi_inner_raw(key, depth);
536        select! {
537            _ = cancel_fut => Err(IoClosed),
538            res = operate_fut => res,
539        }
540    }
541
542    /// Inner function version of [Self::subscribe]
543    async fn subscribe_multi_inner_raw(
544        &self,
545        key: Key,
546        depth: usize,
547    ) -> Result<RawMultiSubscription, IoClosed> {
548        let rx = {
549            let mut guard = self.subscriptions.lock().await;
550            if guard.stopped {
551                return Err(IoClosed);
552            }
553            if let Some(entry) = guard.broadcast_list.iter_mut().find(|(k, _)| *k == key) {
554                entry.1.subscribe()
555            } else {
556                let (tx, rx) = broadcast::channel(depth);
557                guard.broadcast_list.push((key, tx));
558                rx
559            }
560        };
561        Ok(RawMultiSubscription { rx })
562    }
563
564    ///////////////////////////////////////////////////////////////////////////
565    // Subscribe (Legacy)
566    ///////////////////////////////////////////////////////////////////////////
567
568    /// Begin listening to a [Topic], receiving a [Subscription] that will give a
569    /// stream of [Message][Topic::Message]s.
570    ///
571    /// If you subscribe to the same topic multiple times, the previous subscription
572    /// will be closed (there can be only one). This does not apply to subscriptions
573    /// created with `subscribe_multi`. This also WILL close subscriptions opened by
574    /// [`subscribe_exclusive`](Self::subscribe_exclusive).
575    ///
576    /// Returns an Error if the I/O worker is closed.
577    #[deprecated = "In future versions, `subscribe` will be removed. Use `subscribe_multi` or `subscribe_exclusive` instead."]
578    pub async fn subscribe<T: Topic>(
579        &self,
580        depth: usize,
581    ) -> Result<Subscription<T::Message>, IoClosed>
582    where
583        T::Message: DeserializeOwned,
584    {
585        let cancel_fut = self.stopper.wait_stopped();
586        let operate_fut = self.subscribe_inner::<T>(depth);
587        select! {
588            _ = cancel_fut => Err(IoClosed),
589            res = operate_fut => res,
590        }
591    }
592
593    /// Inner function version of [Self::subscribe]
594    async fn subscribe_inner<T: Topic>(
595        &self,
596        depth: usize,
597    ) -> Result<Subscription<T::Message>, IoClosed>
598    where
599        T::Message: DeserializeOwned,
600    {
601        let (tx, rx) = tokio::sync::mpsc::channel(depth);
602        {
603            let mut guard = self.subscriptions.lock().await;
604            if guard.stopped {
605                return Err(IoClosed);
606            }
607            if let Some(entry) = guard
608                .exclusive_list
609                .iter_mut()
610                .find(|(k, _)| *k == T::TOPIC_KEY)
611            {
612                if !entry.1.is_closed() {
613                    tracing::warn!("replacing subscription for topic path '{}'", T::PATH);
614                }
615                entry.1 = tx;
616            } else {
617                guard.exclusive_list.push((T::TOPIC_KEY, tx));
618            }
619        }
620        Ok(Subscription {
621            rx,
622            _pd: PhantomData,
623        })
624    }
625
626    /// Subscribe to the given [`Key`], without automatically handling deserialization.
627    ///
628    /// If you subscribe to the same topic multiple times, the previous subscription
629    /// will be closed (there can be only one). This does not apply to subscriptions
630    /// created with `subscribe_multi`.
631    ///
632    /// Returns an Error if the I/O worker is closed.
633    #[deprecated = "In future versions, `subscribe_raw` will be removed. Use `subscribe_multi_raw` or `subscribe_exclusive_raw` instead."]
634    pub async fn subscribe_raw(&self, key: Key, depth: usize) -> Result<RawSubscription, IoClosed> {
635        let cancel_fut = self.stopper.wait_stopped();
636        let operate_fut = self.subscribe_inner_raw(key, depth);
637        select! {
638            _ = cancel_fut => Err(IoClosed),
639            res = operate_fut => res,
640        }
641    }
642
643    /// Inner function version of [Self::subscribe_raw]
644    async fn subscribe_inner_raw(
645        &self,
646        key: Key,
647        depth: usize,
648    ) -> Result<RawSubscription, IoClosed> {
649        let (tx, rx) = tokio::sync::mpsc::channel(depth);
650        {
651            let mut guard = self.subscriptions.lock().await;
652            if guard.stopped {
653                return Err(IoClosed);
654            }
655            if let Some(entry) = guard.exclusive_list.iter_mut().find(|(k, _)| *k == key) {
656                if !entry.1.is_closed() {
657                    tracing::warn!("replacing subscription for raw topic key '{:?}'", key);
658                }
659                entry.1 = tx;
660            } else {
661                guard.exclusive_list.push((key, tx));
662            }
663        }
664        Ok(RawSubscription { rx })
665    }
666
667    ///////////////////////////////////////////////////////////////////////////
668    // Subscribe Exclusive
669    ///////////////////////////////////////////////////////////////////////////
670
671    /// Begin listening to a [Topic], receiving a [Subscription] that will give a
672    /// stream of [Message][Topic::Message]s.
673    ///
674    /// If you try to subscribe to the same topic multiple times, this function returns a
675    /// [`SubscribeError::AlreadySubscribed`] (there can be only one).
676    /// This does not apply to subscriptions created with `subscribe_multi`.
677    ///
678    /// Returns an Error if the I/O worker is closed.
679    pub async fn subscribe_exclusive<T: Topic>(
680        &self,
681        depth: usize,
682    ) -> Result<Subscription<T::Message>, SubscribeError>
683    where
684        T::Message: DeserializeOwned,
685    {
686        let cancel_fut = self.stopper.wait_stopped();
687        let operate_fut = self.subscribe_inner_exclusive::<T>(depth);
688        select! {
689            _ = cancel_fut => Err(SubscribeError::IoClosed),
690            res = operate_fut => res,
691        }
692    }
693
694    /// Inner function version of [Self::subscribe_exclusive]
695    async fn subscribe_inner_exclusive<T: Topic>(
696        &self,
697        depth: usize,
698    ) -> Result<Subscription<T::Message>, SubscribeError>
699    where
700        T::Message: DeserializeOwned,
701    {
702        let (tx, rx) = tokio::sync::mpsc::channel(depth);
703        {
704            let mut guard = self.subscriptions.lock().await;
705            if guard.stopped {
706                return Err(SubscribeError::IoClosed);
707            }
708            if let Some(entry) = guard
709                .exclusive_list
710                .iter_mut()
711                .find(|(k, _)| *k == T::TOPIC_KEY)
712            {
713                if !entry.1.is_closed() {
714                    return Err(SubscribeError::AlreadySubscribed);
715                }
716                entry.1 = tx;
717            } else {
718                guard.exclusive_list.push((T::TOPIC_KEY, tx));
719            }
720        }
721        Ok(Subscription {
722            rx,
723            _pd: PhantomData,
724        })
725    }
726
727    /// Subscribe to the given [`Key`], without automatically handling deserialization.
728    ///
729    /// If you try to subscribe to the same topic multiple times, this function returns a
730    /// [`SubscribeError::AlreadySubscribed`] (there can be only one).
731    /// This does not apply to subscriptions created with `subscribe_multi`.
732    ///
733    /// Returns an Error if the I/O worker is closed.
734    pub async fn subscribe_exclusive_raw(
735        &self,
736        key: Key,
737        depth: usize,
738    ) -> Result<RawSubscription, SubscribeError> {
739        let cancel_fut = self.stopper.wait_stopped();
740        let operate_fut = self.subscribe_inner_exclusive_raw(key, depth);
741        select! {
742            _ = cancel_fut => Err(SubscribeError::IoClosed),
743            res = operate_fut => res,
744        }
745    }
746
747    /// Inner function version of [Self::subscribe_exclusive_raw]
748    async fn subscribe_inner_exclusive_raw(
749        &self,
750        key: Key,
751        depth: usize,
752    ) -> Result<RawSubscription, SubscribeError> {
753        let (tx, rx) = tokio::sync::mpsc::channel(depth);
754        {
755            let mut guard = self.subscriptions.lock().await;
756            if guard.stopped {
757                return Err(SubscribeError::IoClosed);
758            }
759            if let Some(entry) = guard.exclusive_list.iter_mut().find(|(k, _)| *k == key) {
760                if !entry.1.is_closed() {
761                    return Err(SubscribeError::AlreadySubscribed);
762                }
763                entry.1 = tx;
764            } else {
765                guard.exclusive_list.push((key, tx));
766            }
767        }
768        Ok(RawSubscription { rx })
769    }
770
771    /// Permanently close the connection to the client
772    ///
773    /// All other HostClients sharing the connection (e.g. created by cloning
774    /// a single HostClient) will also stop, and no further communication will
775    /// succeed. The in-flight messages will not be flushed.
776    ///
777    /// This will also signal any I/O worker tasks to halt immediately as well.
778    pub fn close(&self) {
779        self.stopper.stop()
780    }
781
782    /// Has this host client been closed?
783    pub fn is_closed(&self) -> bool {
784        self.stopper.is_stopped()
785    }
786
787    /// Wait for the host client to be closed
788    pub async fn wait_closed(&self) {
789        self.stopper.wait_stopped().await;
790    }
791}
792
793/// Like Subscription, but receives Raw frames that are not
794/// automatically deserialized
795pub struct RawSubscription {
796    rx: mpsc::Receiver<RpcFrame>,
797}
798
799impl RawSubscription {
800    /// Await a message for the given subscription.
801    ///
802    /// Returns [None]` if the subscription was closed
803    pub async fn recv(&mut self) -> Option<RpcFrame> {
804        self.rx.recv().await
805    }
806}
807
808/// A structure that represents a subscription to the given topic
809pub struct Subscription<M> {
810    rx: mpsc::Receiver<RpcFrame>,
811    _pd: PhantomData<M>,
812}
813
814impl<M> Subscription<M>
815where
816    M: DeserializeOwned,
817{
818    /// Await a message for the given subscription.
819    ///
820    /// Returns [None]` if the subscription was closed
821    pub async fn recv(&mut self) -> Option<M> {
822        loop {
823            let frame = self.rx.recv().await?;
824            if let Ok(m) = postcard::from_bytes(&frame.body) {
825                return Some(m);
826            }
827        }
828    }
829}
830
831/// Like MultiSubscription, but receives Raw frames that are not
832/// automatically deserialized
833pub struct RawMultiSubscription {
834    rx: broadcast::Receiver<RpcFrame>,
835}
836
837impl RawMultiSubscription {
838    /// Await a message for the given subscription.
839    ///
840    /// Returns [None]` if the subscription was closed
841    pub async fn recv(&mut self) -> Result<RpcFrame, MultiSubRxError> {
842        match self.rx.recv().await {
843            Ok(f) => Ok(f),
844            Err(broadcast::error::RecvError::Closed) => Err(MultiSubRxError::IoClosed),
845            Err(broadcast::error::RecvError::Lagged(n)) => Err(MultiSubRxError::Lagged(n)),
846        }
847    }
848}
849
/// A subscription to a [Topic] that may be shared with other subscribers
///
/// Backed by a `broadcast` receiver, so every subscriber to the same topic
/// receives its own copy of each message.
pub struct MultiSubscription<M> {
    rx: broadcast::Receiver<RpcFrame>,
    _pd: PhantomData<M>,
}
855
/// Errors returned when receiving from a [`MultiSubscription`] or a
/// [`RawMultiSubscription`]
#[derive(Debug, PartialEq, Error)]
pub enum MultiSubRxError {
    /// The receiver was closed
    #[error("Receiver closed")]
    IoClosed,
    /// Lagged behind, this many messages were lost
    #[error("Lagged behind, lost {0} messages")]
    Lagged(u64),
}
866
867impl<M> MultiSubscription<M>
868where
869    M: DeserializeOwned,
870{
871    /// Await a message for the given subscription.
872    ///
873    /// Returns [None]` if the subscription was closed
874    pub async fn recv(&mut self) -> Result<M, MultiSubRxError> {
875        loop {
876            let frame = match self.rx.recv().await {
877                Ok(f) => f,
878                Err(broadcast::error::RecvError::Closed) => return Err(MultiSubRxError::IoClosed),
879                Err(broadcast::error::RecvError::Lagged(n)) => {
880                    return Err(MultiSubRxError::Lagged(n))
881                }
882            };
883            if let Ok(m) = postcard::from_bytes(&frame.body) {
884                return Ok(m);
885            }
886        }
887    }
888}
889
890// Manual Clone impl because WireErr may not impl Clone
891impl<WireErr> Clone for HostClient<WireErr> {
892    fn clone(&self) -> Self {
893        Self {
894            ctx: self.ctx.clone(),
895            out: self.out.clone(),
896            err_key: self.err_key,
897            _pd: PhantomData,
898            subscriptions: self.subscriptions.clone(),
899            stopper: self.stopper.clone(),
900            seq_kind: self.seq_kind,
901        }
902    }
903}
904
/// Items necessary for implementing a custom I/O Task
///
/// A custom I/O task is expected to drain `outgoing` onto the wire and to
/// deliver frames read from the wire to `incoming`.
pub struct WireContext {
    /// This is a stream of frames that should be placed on the
    /// wire towards the server.
    ///
    /// NOTE(review): presumably each frame is encoded with
    /// [`RpcFrame::to_bytes`] before sending — confirm against the built-in
    /// transports.
    pub outgoing: mpsc::Receiver<RpcFrame>,
    /// This shared information contains the WaitMap used for replying to
    /// open requests.
    ///
    /// Received frames can be handed to [`HostContext::process`] or
    /// [`HostContext::process_did_wake`], which wake the request waiting on
    /// the frame's header.
    pub incoming: Arc<HostContext>,
}
914
915/// A single postcard-rpc frame
916#[derive(Clone)]
917pub struct RpcFrame {
918    /// The wire header
919    pub header: VarHeader,
920    /// The serialized message payload
921    pub body: Vec<u8>,
922}
923
924impl RpcFrame {
925    /// Serialize the `RpcFrame` into a Vec of bytes
926    pub fn to_bytes(&self) -> Vec<u8> {
927        let mut out = self.header.write_to_vec();
928        out.extend_from_slice(&self.body);
929        out
930    }
931}
932
933/// Shared context between [HostClient] and the I/O worker task
934pub struct HostContext {
935    kkind: RwLock<VarKeyKind>,
936    map: WaitMap<VarHeader, (VarHeader, Vec<u8>)>,
937    seq: AtomicU32,
938    subscription_timeout: Duration,
939}
940
941impl core::fmt::Debug for HostContext {
942    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
943        f.debug_struct("HostContext").finish_non_exhaustive()
944    }
945}
946
/// The I/O worker has closed.
///
/// A unit error type carrying no further detail; its `Display` output is the
/// message given in the `#[error]` attribute below.
#[derive(Debug, Error)]
#[error("The I/O worker has closed")]
pub struct IoClosed;
951
952/// The I/O worker has closed.
953#[derive(Debug, Error)]
954pub enum SubscribeError {
955    /// The subscription was already active
956    #[error("The subscription was already active")]
957    AlreadySubscribed,
958    /// The I/O worker has closed.
959    #[error("The I/O worker has closed")]
960    IoClosed,
961}
962
/// Error for [`HostContext::process`] and [`HostContext::process_did_wake`].
#[derive(Debug, PartialEq, Error)]
pub enum ProcessError {
    /// All [`HostClient`]s have been dropped, no further requests
    /// will be made and no responses will be processed.
    ///
    /// Returned when the wait map reports that it has been closed.
    #[error("All clients have been dropped")]
    Closed,
}
971
972impl HostContext {
973    /// Like `HostContext::process` but tells you if we processed the message or
974    /// nobody wanted it
975    pub fn process_did_wake(&self, frame: RpcFrame) -> Result<bool, ProcessError> {
976        match self.map.wake(&frame.header, (frame.header, frame.body)) {
977            WakeOutcome::Woke => Ok(true),
978            WakeOutcome::NoMatch(_) => Ok(false),
979            WakeOutcome::Closed(_) => Err(ProcessError::Closed),
980        }
981    }
982
983    /// Process the message, returns Ok if the message was taken or dropped.
984    ///
985    /// Returns an Err if the map was closed.
986    pub fn process(&self, frame: RpcFrame) -> Result<(), ProcessError> {
987        if let WakeOutcome::Closed(_) = self.map.wake(&frame.header, (frame.header, frame.body)) {
988            Err(ProcessError::Closed)
989        } else {
990            Ok(())
991        }
992    }
993}
994
/// A report describing the schema spoken by the connected device
///
/// [`Default::default`] returns a report whose `types` set is pre-populated
/// with primitive type schemas. Topics and endpoints are then added with
/// [`SchemaReport::add_topic_in`], [`SchemaReport::add_topic_out`] and
/// [`SchemaReport::add_endpoint`], which resolve each key against `types`.
///
/// NOTE: this type derives `Serialize`/`Deserialize`; with a
/// non-self-describing format such as postcard, the field order below is part
/// of the serialized form.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Schema)]
pub struct SchemaReport {
    /// All custom types spoken by the device (on any endpoint or topic),
    /// as well as all primitive types. In the future, primitive types may
    /// be removed.
    pub types: HashSet<OwnedNamedType>,
    /// All incoming (client to server) topics reported by the device
    pub topics_in: Vec<TopicReport>,
    /// All outgoing (server to client) topics reported by the device
    pub topics_out: Vec<TopicReport>,
    /// All endpoints reported by the device
    pub endpoints: Vec<EndpointReport>,
}
1009
1010impl Default for SchemaReport {
1011    fn default() -> Self {
1012        let mut me = Self {
1013            types: Default::default(),
1014            topics_in: Default::default(),
1015            topics_out: Default::default(),
1016            endpoints: Default::default(),
1017        };
1018
1019        // We need to pre-populate all of the types we consider primitives:
1020        // DataModelType::Bool
1021        me.add_type(OwnedNamedType::from(<bool as Schema>::SCHEMA));
1022        // DataModelType::I8
1023        me.add_type(OwnedNamedType::from(<i8 as Schema>::SCHEMA));
1024        // DataModelType::U8
1025        me.add_type(OwnedNamedType::from(<u8 as Schema>::SCHEMA));
1026        // DataModelType::I16
1027        me.add_type(OwnedNamedType::from(<i16 as Schema>::SCHEMA));
1028        // DataModelType::I32
1029        me.add_type(OwnedNamedType::from(<i32 as Schema>::SCHEMA));
1030        // DataModelType::I64
1031        me.add_type(OwnedNamedType::from(<i64 as Schema>::SCHEMA));
1032        // DataModelType::I128
1033        me.add_type(OwnedNamedType::from(<i128 as Schema>::SCHEMA));
1034        // DataModelType::U16
1035        me.add_type(OwnedNamedType::from(<u16 as Schema>::SCHEMA));
1036        // DataModelType::U32
1037        me.add_type(OwnedNamedType::from(<u32 as Schema>::SCHEMA));
1038        // DataModelType::U64
1039        me.add_type(OwnedNamedType::from(<u64 as Schema>::SCHEMA));
1040        // DataModelType::U128
1041        me.add_type(OwnedNamedType::from(<u128 as Schema>::SCHEMA));
1042        // // DataModelType::Usize
1043        // me.add_type(OwnedNamedType::from(<usize as Schema>::SCHEMA));
1044        // // DataModelType::Isize
1045        // me.add_type(OwnedNamedType::from(<isize as Schema>::SCHEMA));
1046        // DataModelType::F32
1047        me.add_type(OwnedNamedType::from(<f32 as Schema>::SCHEMA));
1048        // DataModelType::F64
1049        me.add_type(OwnedNamedType::from(<f64 as Schema>::SCHEMA));
1050        // DataModelType::Char
1051        me.add_type(OwnedNamedType::from(<char as Schema>::SCHEMA));
1052        // DataModelType::String
1053        me.add_type(OwnedNamedType::from(<String as Schema>::SCHEMA));
1054        // DataModelType::ByteArray
1055        me.add_type(OwnedNamedType::from(<Vec<u8> as Schema>::SCHEMA));
1056        // DataModelType::Unit
1057        me.add_type(OwnedNamedType::from(<() as Schema>::SCHEMA));
1058        // DataModelType::Schema
1059        me.add_type(OwnedNamedType::from(<OwnedNamedType as Schema>::SCHEMA));
1060
1061        me
1062    }
1063}
1064
/// A description of a single Topic
///
/// Produced by [`SchemaReport::add_topic_in`] and
/// [`SchemaReport::add_topic_out`] once `ty` has been resolved from `key`.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Schema)]
pub struct TopicReport {
    /// The human readable path of the topic
    pub path: String,
    /// The Key of the topic (which hashes the path and type)
    pub key: Key,
    /// The schema of the type of the message
    pub ty: OwnedNamedType,
}
1075
/// A description of a single Endpoint
///
/// Produced by [`SchemaReport::add_endpoint`] once both `req_ty` and
/// `resp_ty` have been resolved from their keys.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Schema)]
pub struct EndpointReport {
    /// The human readable path of the endpoint
    pub path: String,
    /// The Key of the request (which hashes the path and type)
    pub req_key: Key,
    /// The schema of the request type
    pub req_ty: OwnedNamedType,
    /// The Key of the response (which hashes the path and type)
    pub resp_key: Key,
    /// The schema of the response type
    pub resp_ty: OwnedNamedType,
}
1090
/// An error that denotes we were unable to resolve the type used by a given key
///
/// Returned by the `SchemaReport::add_*` methods when none of the known types,
/// combined with the given path, produces the requested key.
#[derive(Debug)]
pub struct UnableToFindType;

// Implemented by hand so this error composes with `?`/`Box<dyn Error>` like the
// other error types in this module.
impl core::fmt::Display for UnableToFindType {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.write_str("unable to find the type used by the given key")
    }
}

impl std::error::Error for UnableToFindType {}
1094
1095impl SchemaReport {
1096    /// Insert a new type
1097    pub fn add_type(&mut self, t: OwnedNamedType) {
1098        self.types.insert(t);
1099    }
1100
1101    /// Insert a new incoming (client to server) topic
1102    ///
1103    /// Returns an error if we are unable to find the type used for this topic
1104    pub fn add_topic_in(&mut self, path: String, key: Key) -> Result<(), UnableToFindType> {
1105        // We need to figure out which type goes with this topic
1106        for ty in self.types.iter() {
1107            let calc_key = Key::for_owned_schema_path(&path, ty);
1108            if calc_key == key {
1109                self.topics_in.push(TopicReport {
1110                    path,
1111                    key,
1112                    ty: ty.clone(),
1113                });
1114                return Ok(());
1115            }
1116        }
1117        Err(UnableToFindType)
1118    }
1119
1120    /// Insert a new outgoing (server to client) topic
1121    ///
1122    /// Returns an error if we are unable to find the type used for this topic
1123    pub fn add_topic_out(&mut self, path: String, key: Key) -> Result<(), UnableToFindType> {
1124        // We need to figure out which type goes with this topic
1125        for ty in self.types.iter() {
1126            let calc_key = Key::for_owned_schema_path(&path, ty);
1127            if calc_key == key {
1128                self.topics_out.push(TopicReport {
1129                    path,
1130                    key,
1131                    ty: ty.clone(),
1132                });
1133                return Ok(());
1134            }
1135        }
1136        Err(UnableToFindType)
1137    }
1138
1139    /// Insert a new endpoint
1140    ///
1141    /// Returns an error if we are unable to find the type used for the request/response
1142    pub fn add_endpoint(
1143        &mut self,
1144        path: String,
1145        req_key: Key,
1146        resp_key: Key,
1147    ) -> Result<(), UnableToFindType> {
1148        // We need to figure out which types go with this endpoint
1149        let mut req_ty = None;
1150        for ty in self.types.iter() {
1151            let calc_key = Key::for_owned_schema_path(&path, ty);
1152            if calc_key == req_key {
1153                req_ty = Some(ty.clone());
1154                break;
1155            }
1156        }
1157        let Some(req_ty) = req_ty else {
1158            return Err(UnableToFindType);
1159        };
1160
1161        let mut resp_ty = None;
1162        for ty in self.types.iter() {
1163            let calc_key = Key::for_owned_schema_path(&path, ty);
1164            if calc_key == resp_key {
1165                resp_ty = Some(ty.clone());
1166                break;
1167            }
1168        }
1169        let Some(resp_ty) = resp_ty else {
1170            return Err(UnableToFindType);
1171        };
1172
1173        self.endpoints.push(EndpointReport {
1174            path,
1175            req_key,
1176            req_ty,
1177            resp_key,
1178            resp_ty,
1179        });
1180        Ok(())
1181    }
1182}