1use core::time::Duration;
7use std::{
8 collections::HashSet,
9 future::Future,
10 marker::PhantomData,
11 sync::{
12 atomic::{AtomicU32, Ordering},
13 Arc, RwLock,
14 },
15};
16use thiserror::Error;
17
18use maitake_sync::{
19 wait_map::{WaitError, WakeOutcome},
20 WaitMap,
21};
22use postcard_schema::{schema::owned::OwnedNamedType, Schema};
23use serde::{de::DeserializeOwned, Deserialize, Serialize};
24use tokio::{
25 select,
26 sync::{broadcast, mpsc, Mutex},
27};
28use util::Subscriptions;
29
30use crate::{
31 header::{VarHeader, VarKey, VarKeyKind, VarSeq, VarSeqKind},
32 standard_icd::{GetAllSchemaDataTopic, GetAllSchemasEndpoint, OwnedSchemaData},
33 Endpoint, Key, Topic, TopicDirection,
34};
35
36use self::util::Stopper;
37pub use crate::host_client::util::HostClientConfig;
38
39#[cfg(all(feature = "raw-nusb", not(target_family = "wasm")))]
40mod raw_nusb;
41
42#[cfg(all(feature = "cobs-serial", not(target_family = "wasm")))]
43mod serial;
44
45#[cfg(all(feature = "webusb", target_family = "wasm"))]
46pub mod webusb;
47
48pub(crate) mod util;
49
50#[cfg(feature = "test-utils")]
51pub mod test_channels;
52
53#[derive(Debug, PartialEq, Error)]
55pub enum HostErr<WireErr> {
56 #[error("a wire error occurred")]
58 Wire(WireErr),
59 #[error("the response received didn't match the expected value or the wire error type")]
65 BadResponse,
66 #[error("message deserialization failed")]
68 Postcard(#[from] postcard::Error),
69 #[error("the interface has been closed, and no further messages are possible")]
71 Closed,
72}
73
74impl<T> From<WaitError> for HostErr<T> {
75 fn from(_: WaitError) -> Self {
76 Self::Closed
77 }
78}
79
/// Outgoing half of a host transport (wasm build: no `Send` requirement).
#[cfg(target_family = "wasm")]
pub trait WireTx: 'static {
    /// Error reported when a send fails.
    type Error: std::error::Error;

    /// Transmit `data` over the transport.
    fn send(&mut self, data: Vec<u8>) -> impl Future<Output = Result<(), Self::Error>>;
}
97
/// Incoming half of a host transport (wasm build: no `Send` requirement).
#[cfg(target_family = "wasm")]
pub trait WireRx: 'static {
    /// Error reported when a receive fails.
    type Error: std::error::Error;

    /// Wait for the next chunk of bytes from the transport.
    fn receive(&mut self) -> impl Future<Output = Result<Vec<u8>, Self::Error>>;
}
112
/// Hook for handing background futures to an executor (wasm build: no `Send`
/// requirement on the future).
#[cfg(target_family = "wasm")]
pub trait WireSpawn: 'static {
    /// Spawn `fut` onto the implementor's executor.
    fn spawn(&mut self, fut: impl Future<Output = ()> + 'static);
}
121
/// Outgoing half of a host transport (native build: the transport and its
/// send future must be `Send`).
#[cfg(not(target_family = "wasm"))]
pub trait WireTx: Send + 'static {
    /// Error reported when a send fails.
    type Error: std::error::Error;

    /// Transmit `data` over the transport.
    fn send(&mut self, data: Vec<u8>) -> impl Future<Output = Result<(), Self::Error>> + Send;
}
139
/// Incoming half of a host transport (native build: the transport and its
/// receive future must be `Send`).
#[cfg(not(target_family = "wasm"))]
pub trait WireRx: Send + 'static {
    /// Error reported when a receive fails.
    type Error: std::error::Error;

    /// Wait for the next chunk of bytes from the transport.
    fn receive(&mut self) -> impl Future<Output = Result<Vec<u8>, Self::Error>> + Send;
}
154
/// Hook for handing background futures to an executor (native build: spawned
/// futures must be `Send`).
#[cfg(not(target_family = "wasm"))]
pub trait WireSpawn: 'static {
    /// Spawn `fut` onto the implementor's executor.
    fn spawn(&mut self, fut: impl Future<Output = ()> + Send + 'static);
}
163
164pub struct HostClient<WireErr> {
177 ctx: Arc<HostContext>,
178 out: mpsc::Sender<RpcFrame>,
179 subscriptions: Arc<Mutex<Subscriptions>>,
180 err_key: Key,
181 stopper: Stopper,
182 seq_kind: VarSeqKind,
183 _pd: PhantomData<fn() -> WireErr>,
184}
185
186impl<W> core::fmt::Debug for HostClient<W> {
187 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
188 f.debug_struct("HostClient").finish_non_exhaustive()
189 }
190}
191
192impl<WireErr> HostClient<WireErr>
194where
195 WireErr: DeserializeOwned + Schema,
196{
197 pub(crate) fn new_manual_priv(config: &HostClientConfig) -> (Self, WireContext) {
199 let (tx_pc, rx_pc) = tokio::sync::mpsc::channel(config.outgoing_depth);
200
201 let ctx = Arc::new(HostContext {
202 kkind: RwLock::new(VarKeyKind::Key8),
203 map: WaitMap::new(),
204 seq: AtomicU32::new(0),
205 subscription_timeout: config.subscriber_timeout_if_full,
206 });
207
208 let err_key = Key::for_path::<WireErr>(config.err_uri_path);
209
210 let me = HostClient {
211 ctx: ctx.clone(),
212 out: tx_pc,
213 err_key,
214 _pd: PhantomData,
215 subscriptions: Arc::new(Mutex::new(Subscriptions::default())),
216 stopper: Stopper::new(),
217 seq_kind: config.seq_kind,
218 };
219
220 let wire = WireContext {
221 outgoing: rx_pc,
222 incoming: ctx,
223 };
224
225 (me, wire)
226 }
227}
228
229#[derive(Debug, Error)]
231pub enum SchemaError<WireErr> {
232 #[error("A communication error occurred")]
234 Comms(#[from] HostErr<WireErr>),
235 #[error("An error occurred internally. Please open an issue.")]
237 TaskError,
238 #[error("Invalid report data was received. Please open an issue.")]
241 InvalidReportData,
242 #[error(
245 "Data was lost while transmitting. If a retry does not solve this, please open an issue."
246 )]
247 LostData,
248}
249
250impl<WireErr> From<UnableToFindType> for SchemaError<WireErr> {
251 fn from(_: UnableToFindType) -> Self {
252 Self::InvalidReportData
253 }
254}
255
256impl<WireErr> HostClient<WireErr>
258where
259 WireErr: DeserializeOwned + Schema,
260{
261 pub async fn get_schema_report(&self) -> Result<SchemaReport, SchemaError<WireErr>> {
263 let Ok(mut sub) = self.subscribe_multi::<GetAllSchemaDataTopic>(64).await else {
264 return Err(SchemaError::Comms(HostErr::Closed));
265 };
266
267 let collect_task = tokio::task::spawn({
268 async move {
269 let mut got = vec![];
270 while let Ok(Ok(val)) =
271 tokio::time::timeout(Duration::from_millis(500), sub.recv()).await
272 {
273 got.push(val);
274 }
275 got
276 }
277 });
278 let trigger_task = self.send_resp::<GetAllSchemasEndpoint>(&()).await;
279 let data = collect_task.await;
280 let (resp, data) = match (trigger_task, data) {
281 (Ok(a), Ok(b)) => (a, b),
282 (Ok(_), Err(_)) => return Err(SchemaError::TaskError),
283 (Err(e), Ok(_)) => return Err(SchemaError::Comms(e)),
284 (Err(e1), Err(_e2)) => return Err(SchemaError::Comms(e1)),
285 };
286 let mut rpt = SchemaReport::default();
287 let mut e_and_t = vec![];
288
289 for d in data {
290 match d {
291 OwnedSchemaData::Type(d) => {
292 rpt.add_type(d);
293 }
294 e @ OwnedSchemaData::Endpoint { .. } => e_and_t.push(e),
295 t @ OwnedSchemaData::Topic { .. } => e_and_t.push(t),
296 }
297 }
298
299 for e in e_and_t {
300 match e {
301 OwnedSchemaData::Type(_) => unreachable!(),
302 OwnedSchemaData::Endpoint {
303 path,
304 request_key,
305 response_key,
306 } => {
307 rpt.add_endpoint(path, request_key, response_key)?;
308 }
309 OwnedSchemaData::Topic {
310 path,
311 key,
312 direction,
313 } => match direction {
314 TopicDirection::ToServer => rpt.add_topic_in(path, key)?,
315 TopicDirection::ToClient => rpt.add_topic_out(path, key)?,
316 },
317 }
318 }
319
320 let mut data_matches = true;
321 data_matches &= resp.endpoints_sent as usize == rpt.endpoints.len();
322 data_matches &= resp.topics_in_sent as usize == rpt.topics_in.len();
323 data_matches &= resp.topics_out_sent as usize == rpt.topics_out.len();
324 data_matches &= resp.errors == 0;
325
326 if data_matches {
327 Ok(rpt)
329 } else {
330 Err(SchemaError::LostData)
331 }
332 }
333
334 pub async fn send_resp<E: Endpoint>(
339 &self,
340 t: &E::Request,
341 ) -> Result<E::Response, HostErr<WireErr>>
342 where
343 E::Request: Serialize + Schema,
344 E::Response: DeserializeOwned + Schema,
345 {
346 let seq_no = self.ctx.seq.fetch_add(1, Ordering::Relaxed);
347
348 let msg = postcard::to_stdvec(&t).expect("Allocations should not ever fail");
349 let frame = RpcFrame {
350 header: VarHeader {
353 key: VarKey::Key8(E::REQ_KEY),
354 seq_no: VarSeq::Seq4(seq_no),
355 },
356 body: msg,
357 };
358 let frame = self.send_resp_raw(frame, E::RESP_KEY).await?;
359 let r = postcard::from_bytes::<E::Response>(&frame.body)?;
360 Ok(r)
361 }
362
363 pub async fn send_resp_raw(
366 &self,
367 mut rqst: RpcFrame,
368 resp_key: Key,
369 ) -> Result<RpcFrame, HostErr<WireErr>> {
370 let cancel_fut = self.stopper.wait_stopped();
371 let kkind: VarKeyKind = *self.ctx.kkind.read().unwrap();
372 rqst.header.key.shrink_to(kkind);
373 let mut resp_key = VarKey::Key8(resp_key);
374 let mut err_key = VarKey::Key8(self.err_key);
375 resp_key.shrink_to(kkind);
376 err_key.shrink_to(kkind);
377
378 let ok_resp = self.ctx.map.wait(VarHeader {
382 seq_no: rqst.header.seq_no,
383 key: resp_key,
384 });
385 let err_resp = self.ctx.map.wait(VarHeader {
386 seq_no: rqst.header.seq_no,
387 key: err_key,
388 });
389 let mut ok_resp = std::pin::pin!(ok_resp);
390 let mut err_resp = std::pin::pin!(err_resp);
391 let setup_fut: Result<(), WaitError> = async {
392 ok_resp.as_mut().subscribe().await?;
393 err_resp.as_mut().subscribe().await?;
394 Ok(())
395 }
396 .await;
397
398 if let Err(e) = setup_fut {
400 return Err(match e {
401 WaitError::Closed => HostErr::Closed,
402 WaitError::Duplicate => {
403 tracing::error!("Attempted to register a duplicate wait for a reply. This can happen if sequence numbers are reused.");
404 HostErr::BadResponse
407 }
408
409 _ => {
411 tracing::error!("Internal error setting up reply: {e:?}, closing");
412 self.close();
413 HostErr::Closed
414 }
415 });
416 };
417
418 self.out.send(rqst).await.map_err(|_| HostErr::Closed)?;
419
420 select! {
421 _c = cancel_fut => Err(HostErr::Closed),
422 o = ok_resp => {
423 let (hdr, resp) = o?;
424 if hdr.key.kind() != kkind {
425 *self.ctx.kkind.write().unwrap() = hdr.key.kind();
426 }
427 Ok(RpcFrame { header: hdr, body: resp })
428 },
429 e = err_resp => {
430 let (hdr, resp) = e?;
431 if hdr.key.kind() != kkind {
432 *self.ctx.kkind.write().unwrap() = hdr.key.kind();
433 }
434 let r = postcard::from_bytes::<WireErr>(&resp)?;
435 Err(HostErr::Wire(r))
436 },
437 }
438 }
439
440 pub async fn publish<T: Topic>(&self, seq_no: VarSeq, msg: &T::Message) -> Result<(), IoClosed>
445 where
446 T::Message: Serialize,
447 {
448 let smsg = postcard::to_stdvec(msg).expect("alloc should never fail");
449 let frame = RpcFrame {
450 header: VarHeader {
451 key: VarKey::Key8(T::TOPIC_KEY),
452 seq_no,
453 },
454 body: smsg,
455 };
456 self.publish_raw(frame).await
457 }
458
459 pub async fn publish_raw(&self, mut frame: RpcFrame) -> Result<(), IoClosed> {
461 let kkind: VarKeyKind = *self.ctx.kkind.read().unwrap();
462 frame.header.key.shrink_to(kkind);
463
464 let cancel_fut = self.stopper.wait_stopped();
465 let operate_fut = self.out.send(frame);
466
467 select! {
468 _ = cancel_fut => Err(IoClosed),
469 res = operate_fut => res.map_err(|_| IoClosed),
470 }
471 }
472
473 pub async fn subscribe_multi<T: Topic>(
483 &self,
484 depth: usize,
485 ) -> Result<MultiSubscription<T::Message>, IoClosed>
486 where
487 T::Message: DeserializeOwned,
488 {
489 let cancel_fut = self.stopper.wait_stopped();
490 let operate_fut = self.subscribe_multi_inner::<T>(depth);
491 select! {
492 _ = cancel_fut => Err(IoClosed),
493 res = operate_fut => res,
494 }
495 }
496
497 async fn subscribe_multi_inner<T: Topic>(
499 &self,
500 depth: usize,
501 ) -> Result<MultiSubscription<T::Message>, IoClosed>
502 where
503 T::Message: DeserializeOwned,
504 {
505 let rx = {
506 let mut guard = self.subscriptions.lock().await;
507 if guard.stopped {
508 return Err(IoClosed);
509 }
510 if let Some(entry) = guard
511 .broadcast_list
512 .iter_mut()
513 .find(|(k, _)| *k == T::TOPIC_KEY)
514 {
515 entry.1.subscribe()
516 } else {
517 let (tx, rx) = broadcast::channel(depth);
518 guard.broadcast_list.push((T::TOPIC_KEY, tx));
519 rx
520 }
521 };
522 Ok(MultiSubscription {
523 rx,
524 _pd: PhantomData,
525 })
526 }
527
528 pub async fn subscribe_multi_raw(
530 &self,
531 key: Key,
532 depth: usize,
533 ) -> Result<RawMultiSubscription, IoClosed> {
534 let cancel_fut = self.stopper.wait_stopped();
535 let operate_fut = self.subscribe_multi_inner_raw(key, depth);
536 select! {
537 _ = cancel_fut => Err(IoClosed),
538 res = operate_fut => res,
539 }
540 }
541
542 async fn subscribe_multi_inner_raw(
544 &self,
545 key: Key,
546 depth: usize,
547 ) -> Result<RawMultiSubscription, IoClosed> {
548 let rx = {
549 let mut guard = self.subscriptions.lock().await;
550 if guard.stopped {
551 return Err(IoClosed);
552 }
553 if let Some(entry) = guard.broadcast_list.iter_mut().find(|(k, _)| *k == key) {
554 entry.1.subscribe()
555 } else {
556 let (tx, rx) = broadcast::channel(depth);
557 guard.broadcast_list.push((key, tx));
558 rx
559 }
560 };
561 Ok(RawMultiSubscription { rx })
562 }
563
564 #[deprecated = "In future versions, `subscribe` will be removed. Use `subscribe_multi` or `subscribe_exclusive` instead."]
578 pub async fn subscribe<T: Topic>(
579 &self,
580 depth: usize,
581 ) -> Result<Subscription<T::Message>, IoClosed>
582 where
583 T::Message: DeserializeOwned,
584 {
585 let cancel_fut = self.stopper.wait_stopped();
586 let operate_fut = self.subscribe_inner::<T>(depth);
587 select! {
588 _ = cancel_fut => Err(IoClosed),
589 res = operate_fut => res,
590 }
591 }
592
593 async fn subscribe_inner<T: Topic>(
595 &self,
596 depth: usize,
597 ) -> Result<Subscription<T::Message>, IoClosed>
598 where
599 T::Message: DeserializeOwned,
600 {
601 let (tx, rx) = tokio::sync::mpsc::channel(depth);
602 {
603 let mut guard = self.subscriptions.lock().await;
604 if guard.stopped {
605 return Err(IoClosed);
606 }
607 if let Some(entry) = guard
608 .exclusive_list
609 .iter_mut()
610 .find(|(k, _)| *k == T::TOPIC_KEY)
611 {
612 if !entry.1.is_closed() {
613 tracing::warn!("replacing subscription for topic path '{}'", T::PATH);
614 }
615 entry.1 = tx;
616 } else {
617 guard.exclusive_list.push((T::TOPIC_KEY, tx));
618 }
619 }
620 Ok(Subscription {
621 rx,
622 _pd: PhantomData,
623 })
624 }
625
626 #[deprecated = "In future versions, `subscribe_raw` will be removed. Use `subscribe_multi_raw` or `subscribe_exclusive_raw` instead."]
634 pub async fn subscribe_raw(&self, key: Key, depth: usize) -> Result<RawSubscription, IoClosed> {
635 let cancel_fut = self.stopper.wait_stopped();
636 let operate_fut = self.subscribe_inner_raw(key, depth);
637 select! {
638 _ = cancel_fut => Err(IoClosed),
639 res = operate_fut => res,
640 }
641 }
642
643 async fn subscribe_inner_raw(
645 &self,
646 key: Key,
647 depth: usize,
648 ) -> Result<RawSubscription, IoClosed> {
649 let (tx, rx) = tokio::sync::mpsc::channel(depth);
650 {
651 let mut guard = self.subscriptions.lock().await;
652 if guard.stopped {
653 return Err(IoClosed);
654 }
655 if let Some(entry) = guard.exclusive_list.iter_mut().find(|(k, _)| *k == key) {
656 if !entry.1.is_closed() {
657 tracing::warn!("replacing subscription for raw topic key '{:?}'", key);
658 }
659 entry.1 = tx;
660 } else {
661 guard.exclusive_list.push((key, tx));
662 }
663 }
664 Ok(RawSubscription { rx })
665 }
666
667 pub async fn subscribe_exclusive<T: Topic>(
680 &self,
681 depth: usize,
682 ) -> Result<Subscription<T::Message>, SubscribeError>
683 where
684 T::Message: DeserializeOwned,
685 {
686 let cancel_fut = self.stopper.wait_stopped();
687 let operate_fut = self.subscribe_inner_exclusive::<T>(depth);
688 select! {
689 _ = cancel_fut => Err(SubscribeError::IoClosed),
690 res = operate_fut => res,
691 }
692 }
693
694 async fn subscribe_inner_exclusive<T: Topic>(
696 &self,
697 depth: usize,
698 ) -> Result<Subscription<T::Message>, SubscribeError>
699 where
700 T::Message: DeserializeOwned,
701 {
702 let (tx, rx) = tokio::sync::mpsc::channel(depth);
703 {
704 let mut guard = self.subscriptions.lock().await;
705 if guard.stopped {
706 return Err(SubscribeError::IoClosed);
707 }
708 if let Some(entry) = guard
709 .exclusive_list
710 .iter_mut()
711 .find(|(k, _)| *k == T::TOPIC_KEY)
712 {
713 if !entry.1.is_closed() {
714 return Err(SubscribeError::AlreadySubscribed);
715 }
716 entry.1 = tx;
717 } else {
718 guard.exclusive_list.push((T::TOPIC_KEY, tx));
719 }
720 }
721 Ok(Subscription {
722 rx,
723 _pd: PhantomData,
724 })
725 }
726
727 pub async fn subscribe_exclusive_raw(
735 &self,
736 key: Key,
737 depth: usize,
738 ) -> Result<RawSubscription, SubscribeError> {
739 let cancel_fut = self.stopper.wait_stopped();
740 let operate_fut = self.subscribe_inner_exclusive_raw(key, depth);
741 select! {
742 _ = cancel_fut => Err(SubscribeError::IoClosed),
743 res = operate_fut => res,
744 }
745 }
746
747 async fn subscribe_inner_exclusive_raw(
749 &self,
750 key: Key,
751 depth: usize,
752 ) -> Result<RawSubscription, SubscribeError> {
753 let (tx, rx) = tokio::sync::mpsc::channel(depth);
754 {
755 let mut guard = self.subscriptions.lock().await;
756 if guard.stopped {
757 return Err(SubscribeError::IoClosed);
758 }
759 if let Some(entry) = guard.exclusive_list.iter_mut().find(|(k, _)| *k == key) {
760 if !entry.1.is_closed() {
761 return Err(SubscribeError::AlreadySubscribed);
762 }
763 entry.1 = tx;
764 } else {
765 guard.exclusive_list.push((key, tx));
766 }
767 }
768 Ok(RawSubscription { rx })
769 }
770
771 pub fn close(&self) {
779 self.stopper.stop()
780 }
781
782 pub fn is_closed(&self) -> bool {
784 self.stopper.is_stopped()
785 }
786
787 pub async fn wait_closed(&self) {
789 self.stopper.wait_stopped().await;
790 }
791}
792
793pub struct RawSubscription {
796 rx: mpsc::Receiver<RpcFrame>,
797}
798
799impl RawSubscription {
800 pub async fn recv(&mut self) -> Option<RpcFrame> {
804 self.rx.recv().await
805 }
806}
807
808pub struct Subscription<M> {
810 rx: mpsc::Receiver<RpcFrame>,
811 _pd: PhantomData<M>,
812}
813
814impl<M> Subscription<M>
815where
816 M: DeserializeOwned,
817{
818 pub async fn recv(&mut self) -> Option<M> {
822 loop {
823 let frame = self.rx.recv().await?;
824 if let Ok(m) = postcard::from_bytes(&frame.body) {
825 return Some(m);
826 }
827 }
828 }
829}
830
831pub struct RawMultiSubscription {
834 rx: broadcast::Receiver<RpcFrame>,
835}
836
837impl RawMultiSubscription {
838 pub async fn recv(&mut self) -> Result<RpcFrame, MultiSubRxError> {
842 match self.rx.recv().await {
843 Ok(f) => Ok(f),
844 Err(broadcast::error::RecvError::Closed) => Err(MultiSubRxError::IoClosed),
845 Err(broadcast::error::RecvError::Lagged(n)) => Err(MultiSubRxError::Lagged(n)),
846 }
847 }
848}
849
850pub struct MultiSubscription<M> {
852 rx: broadcast::Receiver<RpcFrame>,
853 _pd: PhantomData<M>,
854}
855
856#[derive(Debug, PartialEq, Error)]
858pub enum MultiSubRxError {
859 #[error("Receiver closed")]
861 IoClosed,
862 #[error("Lagged behind, lost {0} messages")]
864 Lagged(u64),
865}
866
867impl<M> MultiSubscription<M>
868where
869 M: DeserializeOwned,
870{
871 pub async fn recv(&mut self) -> Result<M, MultiSubRxError> {
875 loop {
876 let frame = match self.rx.recv().await {
877 Ok(f) => f,
878 Err(broadcast::error::RecvError::Closed) => return Err(MultiSubRxError::IoClosed),
879 Err(broadcast::error::RecvError::Lagged(n)) => {
880 return Err(MultiSubRxError::Lagged(n))
881 }
882 };
883 if let Ok(m) = postcard::from_bytes(&frame.body) {
884 return Ok(m);
885 }
886 }
887 }
888}
889
890impl<WireErr> Clone for HostClient<WireErr> {
892 fn clone(&self) -> Self {
893 Self {
894 ctx: self.ctx.clone(),
895 out: self.out.clone(),
896 err_key: self.err_key,
897 _pd: PhantomData,
898 subscriptions: self.subscriptions.clone(),
899 stopper: self.stopper.clone(),
900 seq_kind: self.seq_kind,
901 }
902 }
903}
904
905pub struct WireContext {
907 pub outgoing: mpsc::Receiver<RpcFrame>,
910 pub incoming: Arc<HostContext>,
913}
914
915#[derive(Clone)]
917pub struct RpcFrame {
918 pub header: VarHeader,
920 pub body: Vec<u8>,
922}
923
924impl RpcFrame {
925 pub fn to_bytes(&self) -> Vec<u8> {
927 let mut out = self.header.write_to_vec();
928 out.extend_from_slice(&self.body);
929 out
930 }
931}
932
933pub struct HostContext {
935 kkind: RwLock<VarKeyKind>,
936 map: WaitMap<VarHeader, (VarHeader, Vec<u8>)>,
937 seq: AtomicU32,
938 subscription_timeout: Duration,
939}
940
941impl core::fmt::Debug for HostContext {
942 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
943 f.debug_struct("HostContext").finish_non_exhaustive()
944 }
945}
946
947#[derive(Debug, Error)]
949#[error("The I/O worker has closed")]
950pub struct IoClosed;
951
952#[derive(Debug, Error)]
954pub enum SubscribeError {
955 #[error("The subscription was already active")]
957 AlreadySubscribed,
958 #[error("The I/O worker has closed")]
960 IoClosed,
961}
962
963#[derive(Debug, PartialEq, Error)]
965pub enum ProcessError {
966 #[error("All clients have been dropped")]
969 Closed,
970}
971
972impl HostContext {
973 pub fn process_did_wake(&self, frame: RpcFrame) -> Result<bool, ProcessError> {
976 match self.map.wake(&frame.header, (frame.header, frame.body)) {
977 WakeOutcome::Woke => Ok(true),
978 WakeOutcome::NoMatch(_) => Ok(false),
979 WakeOutcome::Closed(_) => Err(ProcessError::Closed),
980 }
981 }
982
983 pub fn process(&self, frame: RpcFrame) -> Result<(), ProcessError> {
987 if let WakeOutcome::Closed(_) = self.map.wake(&frame.header, (frame.header, frame.body)) {
988 Err(ProcessError::Closed)
989 } else {
990 Ok(())
991 }
992 }
993}
994
995#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Schema)]
997pub struct SchemaReport {
998 pub types: HashSet<OwnedNamedType>,
1002 pub topics_in: Vec<TopicReport>,
1004 pub topics_out: Vec<TopicReport>,
1006 pub endpoints: Vec<EndpointReport>,
1008}
1009
1010impl Default for SchemaReport {
1011 fn default() -> Self {
1012 let mut me = Self {
1013 types: Default::default(),
1014 topics_in: Default::default(),
1015 topics_out: Default::default(),
1016 endpoints: Default::default(),
1017 };
1018
1019 me.add_type(OwnedNamedType::from(<bool as Schema>::SCHEMA));
1022 me.add_type(OwnedNamedType::from(<i8 as Schema>::SCHEMA));
1024 me.add_type(OwnedNamedType::from(<u8 as Schema>::SCHEMA));
1026 me.add_type(OwnedNamedType::from(<i16 as Schema>::SCHEMA));
1028 me.add_type(OwnedNamedType::from(<i32 as Schema>::SCHEMA));
1030 me.add_type(OwnedNamedType::from(<i64 as Schema>::SCHEMA));
1032 me.add_type(OwnedNamedType::from(<i128 as Schema>::SCHEMA));
1034 me.add_type(OwnedNamedType::from(<u16 as Schema>::SCHEMA));
1036 me.add_type(OwnedNamedType::from(<u32 as Schema>::SCHEMA));
1038 me.add_type(OwnedNamedType::from(<u64 as Schema>::SCHEMA));
1040 me.add_type(OwnedNamedType::from(<u128 as Schema>::SCHEMA));
1042 me.add_type(OwnedNamedType::from(<f32 as Schema>::SCHEMA));
1048 me.add_type(OwnedNamedType::from(<f64 as Schema>::SCHEMA));
1050 me.add_type(OwnedNamedType::from(<char as Schema>::SCHEMA));
1052 me.add_type(OwnedNamedType::from(<String as Schema>::SCHEMA));
1054 me.add_type(OwnedNamedType::from(<Vec<u8> as Schema>::SCHEMA));
1056 me.add_type(OwnedNamedType::from(<() as Schema>::SCHEMA));
1058 me.add_type(OwnedNamedType::from(<OwnedNamedType as Schema>::SCHEMA));
1060
1061 me
1062 }
1063}
1064
1065#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Schema)]
1067pub struct TopicReport {
1068 pub path: String,
1070 pub key: Key,
1072 pub ty: OwnedNamedType,
1074}
1075
1076#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Schema)]
1078pub struct EndpointReport {
1079 pub path: String,
1081 pub req_key: Key,
1083 pub req_ty: OwnedNamedType,
1085 pub resp_key: Key,
1087 pub resp_ty: OwnedNamedType,
1089}
1090
/// No known type, combined with the reported path, produced the reported key.
#[derive(Debug)]
pub struct UnableToFindType;
1094
1095impl SchemaReport {
1096 pub fn add_type(&mut self, t: OwnedNamedType) {
1098 self.types.insert(t);
1099 }
1100
1101 pub fn add_topic_in(&mut self, path: String, key: Key) -> Result<(), UnableToFindType> {
1105 for ty in self.types.iter() {
1107 let calc_key = Key::for_owned_schema_path(&path, ty);
1108 if calc_key == key {
1109 self.topics_in.push(TopicReport {
1110 path,
1111 key,
1112 ty: ty.clone(),
1113 });
1114 return Ok(());
1115 }
1116 }
1117 Err(UnableToFindType)
1118 }
1119
1120 pub fn add_topic_out(&mut self, path: String, key: Key) -> Result<(), UnableToFindType> {
1124 for ty in self.types.iter() {
1126 let calc_key = Key::for_owned_schema_path(&path, ty);
1127 if calc_key == key {
1128 self.topics_out.push(TopicReport {
1129 path,
1130 key,
1131 ty: ty.clone(),
1132 });
1133 return Ok(());
1134 }
1135 }
1136 Err(UnableToFindType)
1137 }
1138
1139 pub fn add_endpoint(
1143 &mut self,
1144 path: String,
1145 req_key: Key,
1146 resp_key: Key,
1147 ) -> Result<(), UnableToFindType> {
1148 let mut req_ty = None;
1150 for ty in self.types.iter() {
1151 let calc_key = Key::for_owned_schema_path(&path, ty);
1152 if calc_key == req_key {
1153 req_ty = Some(ty.clone());
1154 break;
1155 }
1156 }
1157 let Some(req_ty) = req_ty else {
1158 return Err(UnableToFindType);
1159 };
1160
1161 let mut resp_ty = None;
1162 for ty in self.types.iter() {
1163 let calc_key = Key::for_owned_schema_path(&path, ty);
1164 if calc_key == resp_key {
1165 resp_ty = Some(ty.clone());
1166 break;
1167 }
1168 }
1169 let Some(resp_ty) = resp_ty else {
1170 return Err(UnableToFindType);
1171 };
1172
1173 self.endpoints.push(EndpointReport {
1174 path,
1175 req_key,
1176 req_ty,
1177 resp_key,
1178 resp_ty,
1179 });
1180 Ok(())
1181 }
1182}