postcard_rpc/
test_utils.rs

1//! Test utilities for doctests and integration tests
2
3use core::{fmt::Display, future::Future};
4
5use crate::header::{VarHeader, VarKey, VarSeq, VarSeqKind};
6use crate::host_client::util::Stopper;
7use crate::{
8    host_client::{HostClient, RpcFrame, WireRx, WireSpawn, WireTx},
9    Endpoint, Topic,
10};
11use postcard_schema::Schema;
12use serde::{de::DeserializeOwned, Serialize};
13use tokio::{
14    select,
15    sync::mpsc::{channel, Receiver, Sender},
16};
17
18/// Rx Helper type
19pub struct LocalRx {
20    fake_error: Stopper,
21    from_server: Receiver<Vec<u8>>,
22}
23/// Tx Helper type
24pub struct LocalTx {
25    fake_error: Stopper,
26    to_server: Sender<Vec<u8>>,
27}
28/// Spawn helper type
29pub struct LocalSpawn;
30/// Server type
31pub struct LocalFakeServer {
32    fake_error: Stopper,
33    /// from client to server
34    pub from_client: Receiver<Vec<u8>>,
35    /// from server to client
36    pub to_client: Sender<Vec<u8>>,
37}
38
39impl LocalFakeServer {
40    /// receive a frame
41    pub async fn recv_from_client(&mut self) -> Result<RpcFrame, LocalError> {
42        let msg = self.from_client.recv().await.ok_or(LocalError::TxClosed)?;
43        let Some((hdr, body)) = VarHeader::take_from_slice(&msg) else {
44            return Err(LocalError::BadFrame);
45        };
46        Ok(RpcFrame {
47            header: hdr,
48            body: body.to_vec(),
49        })
50    }
51
52    /// Reply
53    pub async fn reply<E: Endpoint>(
54        &mut self,
55        seq_no: u32,
56        data: &E::Response,
57    ) -> Result<(), LocalError>
58    where
59        E::Response: Serialize,
60    {
61        let frame = RpcFrame {
62            header: VarHeader {
63                key: VarKey::Key8(E::RESP_KEY),
64                seq_no: VarSeq::Seq4(seq_no),
65            },
66            body: postcard::to_stdvec(data).unwrap(),
67        };
68        self.to_client
69            .send(frame.to_bytes())
70            .await
71            .map_err(|_| LocalError::RxClosed)
72    }
73
74    /// Publish
75    pub async fn publish<T: Topic>(
76        &mut self,
77        seq_no: u32,
78        data: &T::Message,
79    ) -> Result<(), LocalError>
80    where
81        T::Message: Serialize,
82    {
83        let frame = RpcFrame {
84            header: VarHeader {
85                key: VarKey::Key8(T::TOPIC_KEY),
86                seq_no: VarSeq::Seq4(seq_no),
87            },
88            body: postcard::to_stdvec(data).unwrap(),
89        };
90        self.to_client
91            .send(frame.to_bytes())
92            .await
93            .map_err(|_| LocalError::RxClosed)
94    }
95
96    /// oops
97    pub fn cause_fatal_error(&self) {
98        self.fake_error.stop();
99    }
100}
101
/// Errors produced by the local (in-memory) test wire.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LocalError {
    /// The server-to-client channel is closed.
    RxClosed,
    /// The client-to-server channel is closed.
    TxClosed,
    /// Bytes received from the client did not start with a decodable header.
    BadFrame,
    /// A simulated fatal I/O error was raised (see `cause_fatal_error`).
    FatalError,
}
114
115impl Display for LocalError {
116    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
117        <Self as core::fmt::Debug>::fmt(self, f)
118    }
119}
120
121impl std::error::Error for LocalError {}
122
123impl WireRx for LocalRx {
124    type Error = LocalError;
125
126    #[allow(clippy::manual_async_fn)]
127    fn receive(&mut self) -> impl Future<Output = Result<Vec<u8>, Self::Error>> + Send {
128        async {
129            // This is not usually necessary - HostClient machinery takes care of listening
130            // to the stopper, but we have an EXTRA one to simulate I/O failure
131            let recv_fut = self.from_server.recv();
132            let error_fut = self.fake_error.wait_stopped();
133
134            // Before we await, do a quick check to see if an error occured, this way
135            // recv can't accidentally win the select
136            if self.fake_error.is_stopped() {
137                return Err(LocalError::FatalError);
138            }
139
140            select! {
141                recv = recv_fut => recv.ok_or(LocalError::RxClosed),
142                _err = error_fut => Err(LocalError::FatalError),
143            }
144        }
145    }
146}
147
148impl WireTx for LocalTx {
149    type Error = LocalError;
150
151    #[allow(clippy::manual_async_fn)]
152    fn send(&mut self, data: Vec<u8>) -> impl Future<Output = Result<(), Self::Error>> + Send {
153        async {
154            // This is not usually necessary - HostClient machinery takes care of listening
155            // to the stopper, but we have an EXTRA one to simulate I/O failure
156            let send_fut = self.to_server.send(data);
157            let error_fut = self.fake_error.wait_stopped();
158
159            // Before we await, do a quick check to see if an error occured, this way
160            // send can't accidentally win the select
161            if self.fake_error.is_stopped() {
162                return Err(LocalError::FatalError);
163            }
164
165            select! {
166                send = send_fut => send.map_err(|_| LocalError::TxClosed),
167                _err = error_fut => Err(LocalError::FatalError),
168            }
169        }
170    }
171}
172
173impl WireSpawn for LocalSpawn {
174    fn spawn(&mut self, fut: impl Future<Output = ()> + Send + 'static) {
175        tokio::task::spawn(fut);
176    }
177}
178
179/// This function creates a directly-linked Server and Client.
180///
181/// This is useful for testing and demonstrating server/client behavior,
182/// without actually requiring an external device.
183pub fn local_setup<E>(bound: usize, err_uri_path: &str) -> (LocalFakeServer, HostClient<E>)
184where
185    E: Schema + DeserializeOwned,
186{
187    let (c2s_tx, c2s_rx) = channel(bound);
188    let (s2c_tx, s2c_rx) = channel(bound);
189
190    // NOTE: the normal HostClient machinery has it's own Stopper used for signalling
191    // errors, this is an EXTRA stopper we use to simulate the error occurring, like
192    // if our USB device disconnected or the serial port was closed
193    let fake_error = Stopper::new();
194
195    let client = HostClient::<E>::new_with_wire(
196        LocalTx {
197            to_server: c2s_tx,
198            fake_error: fake_error.clone(),
199        },
200        LocalRx {
201            from_server: s2c_rx,
202            fake_error: fake_error.clone(),
203        },
204        LocalSpawn,
205        VarSeqKind::Seq2,
206        err_uri_path,
207        bound,
208    );
209
210    let lfs = LocalFakeServer {
211        from_client: c2s_rx,
212        to_client: s2c_tx,
213        fake_error: fake_error.clone(),
214    };
215
216    (lfs, client)
217}