postcard_rpc/
test_utils.rs1use core::{fmt::Display, future::Future};
4
5use crate::header::{VarHeader, VarKey, VarSeq, VarSeqKind};
6use crate::host_client::util::Stopper;
7use crate::{
8 host_client::{HostClient, RpcFrame, WireRx, WireSpawn, WireTx},
9 Endpoint, Topic,
10};
11use postcard_schema::Schema;
12use serde::{de::DeserializeOwned, Serialize};
13use tokio::{
14 select,
15 sync::mpsc::{channel, Receiver, Sender},
16};
17
18pub struct LocalRx {
20 fake_error: Stopper,
21 from_server: Receiver<Vec<u8>>,
22}
23pub struct LocalTx {
25 fake_error: Stopper,
26 to_server: Sender<Vec<u8>>,
27}
/// Spawner for the in-memory wire: background futures are handed to
/// `tokio::task::spawn` (tokio panics if that is called outside a runtime).
///
/// It carries no state, so it is cheap to copy and trivially default-constructible.
#[derive(Debug, Clone, Copy, Default)]
pub struct LocalSpawn;
30pub struct LocalFakeServer {
32 fake_error: Stopper,
33 pub from_client: Receiver<Vec<u8>>,
35 pub to_client: Sender<Vec<u8>>,
37}
38
39impl LocalFakeServer {
40 pub async fn recv_from_client(&mut self) -> Result<RpcFrame, LocalError> {
42 let msg = self.from_client.recv().await.ok_or(LocalError::TxClosed)?;
43 let Some((hdr, body)) = VarHeader::take_from_slice(&msg) else {
44 return Err(LocalError::BadFrame);
45 };
46 Ok(RpcFrame {
47 header: hdr,
48 body: body.to_vec(),
49 })
50 }
51
52 pub async fn reply<E: Endpoint>(
54 &mut self,
55 seq_no: u32,
56 data: &E::Response,
57 ) -> Result<(), LocalError>
58 where
59 E::Response: Serialize,
60 {
61 let frame = RpcFrame {
62 header: VarHeader {
63 key: VarKey::Key8(E::RESP_KEY),
64 seq_no: VarSeq::Seq4(seq_no),
65 },
66 body: postcard::to_stdvec(data).unwrap(),
67 };
68 self.to_client
69 .send(frame.to_bytes())
70 .await
71 .map_err(|_| LocalError::RxClosed)
72 }
73
74 pub async fn publish<T: Topic>(
76 &mut self,
77 seq_no: u32,
78 data: &T::Message,
79 ) -> Result<(), LocalError>
80 where
81 T::Message: Serialize,
82 {
83 let frame = RpcFrame {
84 header: VarHeader {
85 key: VarKey::Key8(T::TOPIC_KEY),
86 seq_no: VarSeq::Seq4(seq_no),
87 },
88 body: postcard::to_stdvec(data).unwrap(),
89 };
90 self.to_client
91 .send(frame.to_bytes())
92 .await
93 .map_err(|_| LocalError::RxClosed)
94 }
95
96 pub fn cause_fatal_error(&self) {
98 self.fake_error.stop();
99 }
100}
101
/// Errors produced by the in-memory test wire and the [`LocalFakeServer`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum LocalError {
    /// The server-to-client channel is closed. The server could not deliver a
    /// frame, or the client's receiver found every sender gone.
    RxClosed,
    /// The client-to-server channel is closed. The client could not send, or
    /// the server's receiver found every sender gone.
    TxClosed,
    /// A frame received from the client did not start with a valid header.
    BadFrame,
    /// A fatal transport error was injected through the shared stop signal.
    FatalError,
}
114
115impl Display for LocalError {
116 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
117 <Self as core::fmt::Debug>::fmt(self, f)
118 }
119}
120
121impl std::error::Error for LocalError {}
122
123impl WireRx for LocalRx {
124 type Error = LocalError;
125
126 #[allow(clippy::manual_async_fn)]
127 fn receive(&mut self) -> impl Future<Output = Result<Vec<u8>, Self::Error>> + Send {
128 async {
129 let recv_fut = self.from_server.recv();
132 let error_fut = self.fake_error.wait_stopped();
133
134 if self.fake_error.is_stopped() {
137 return Err(LocalError::FatalError);
138 }
139
140 select! {
141 recv = recv_fut => recv.ok_or(LocalError::RxClosed),
142 _err = error_fut => Err(LocalError::FatalError),
143 }
144 }
145 }
146}
147
148impl WireTx for LocalTx {
149 type Error = LocalError;
150
151 #[allow(clippy::manual_async_fn)]
152 fn send(&mut self, data: Vec<u8>) -> impl Future<Output = Result<(), Self::Error>> + Send {
153 async {
154 let send_fut = self.to_server.send(data);
157 let error_fut = self.fake_error.wait_stopped();
158
159 if self.fake_error.is_stopped() {
162 return Err(LocalError::FatalError);
163 }
164
165 select! {
166 send = send_fut => send.map_err(|_| LocalError::TxClosed),
167 _err = error_fut => Err(LocalError::FatalError),
168 }
169 }
170 }
171}
172
173impl WireSpawn for LocalSpawn {
174 fn spawn(&mut self, fut: impl Future<Output = ()> + Send + 'static) {
175 tokio::task::spawn(fut);
176 }
177}
178
179pub fn local_setup<E>(bound: usize, err_uri_path: &str) -> (LocalFakeServer, HostClient<E>)
184where
185 E: Schema + DeserializeOwned,
186{
187 let (c2s_tx, c2s_rx) = channel(bound);
188 let (s2c_tx, s2c_rx) = channel(bound);
189
190 let fake_error = Stopper::new();
194
195 let client = HostClient::<E>::new_with_wire(
196 LocalTx {
197 to_server: c2s_tx,
198 fake_error: fake_error.clone(),
199 },
200 LocalRx {
201 from_server: s2c_rx,
202 fake_error: fake_error.clone(),
203 },
204 LocalSpawn,
205 VarSeqKind::Seq2,
206 err_uri_path,
207 bound,
208 );
209
210 let lfs = LocalFakeServer {
211 from_client: c2s_rx,
212 to_client: s2c_tx,
213 fake_error: fake_error.clone(),
214 };
215
216 (lfs, client)
217}