rapace_testkit/
bidirectional.rs1use std::sync::Arc;
22
23use rapace_core::{ErrorCode, Frame, FrameFlags, MsgDescHot, RpcError, Transport};
24
25use crate::RpcSession;
26use crate::{TestError, TransportFactory};
27
/// Bidirectional RPC scenarios in which each peer acts as both client and server.
///
/// Each variant is executed by [`run_bidirectional_scenario`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BidirectionalScenario {
    /// A calls B, and B echoes the request payload back unchanged.
    SimpleEcho,

    /// A calls B, and B makes one callback into A before replying.
    NestedCallback,

    /// A calls B, and B makes three sequential callbacks into A before replying.
    MultipleNestedCallbacks,
}
40
41pub async fn run_bidirectional_scenario<F: TransportFactory>(scenario: BidirectionalScenario) {
43 let result = match scenario {
44 BidirectionalScenario::SimpleEcho => run_simple_echo::<F>().await,
45 BidirectionalScenario::NestedCallback => run_nested_callback::<F>().await,
46 BidirectionalScenario::MultipleNestedCallbacks => run_multiple_nested::<F>().await,
47 };
48
49 if let Err(e) = result {
50 panic!("bidirectional scenario {:?} failed: {}", scenario, e);
51 }
52}
53
54async fn run_simple_echo<F: TransportFactory>() -> Result<(), TestError> {
59 let (transport_a, transport_b) = F::connect_pair().await?;
60 let transport_a = Arc::new(transport_a);
61 let transport_b = Arc::new(transport_b);
62
63 let session_a = Arc::new(RpcSession::with_channel_start(transport_a.clone(), 1));
65
66 let session_b = Arc::new(RpcSession::with_channel_start(transport_b.clone(), 2));
68 session_b.set_dispatcher(|_channel_id, _method_id, payload| async move {
69 let mut desc = MsgDescHot::new();
71 desc.flags = FrameFlags::DATA | FrameFlags::EOS;
72 Ok(Frame::with_payload(desc, payload))
73 });
74
75 let session_a_clone = session_a.clone();
77 let handle_a = tokio::spawn(async move { session_a_clone.run().await });
78
79 let session_b_clone = session_b.clone();
80 let handle_b = tokio::spawn(async move { session_b_clone.run().await });
81
82 let channel_id = session_a.next_channel_id();
84 let response = session_a
85 .call(channel_id, 1, b"hello".to_vec())
86 .await
87 .map_err(TestError::Rpc)?;
88
89 if response.payload != b"hello" {
90 return Err(TestError::Assertion(format!(
91 "expected echo 'hello', got {:?}",
92 response.payload
93 )));
94 }
95
96 let _ = transport_a.close().await;
98 let _ = transport_b.close().await;
99 handle_a.abort();
100 handle_b.abort();
101
102 Ok(())
103}
104
105async fn run_nested_callback<F: TransportFactory>() -> Result<(), TestError> {
110 let (transport_a, transport_b) = F::connect_pair().await?;
111 let transport_a = Arc::new(transport_a);
112 let transport_b = Arc::new(transport_b);
113
114 let session_a = Arc::new(RpcSession::with_channel_start(transport_a.clone(), 1));
117 session_a.set_dispatcher(|_channel_id, method_id, _payload| async move {
118 if method_id == 1 {
120 let prefix = b"PREFIX:";
121 let mut desc = MsgDescHot::new();
122 desc.flags = FrameFlags::DATA | FrameFlags::EOS;
123 Ok(Frame::with_payload(desc, prefix.to_vec()))
124 } else {
125 Err(RpcError::Status {
126 code: ErrorCode::Unimplemented,
127 message: "unknown method".into(),
128 })
129 }
130 });
131
132 let session_b = Arc::new(RpcSession::with_channel_start(transport_b.clone(), 2));
135 let session_b_for_dispatcher = session_b.clone();
136 session_b.set_dispatcher(move |_channel_id, method_id, payload| {
137 let session = session_b_for_dispatcher.clone();
138 async move {
139 if method_id == 1 {
141 let cb_channel = session.next_channel_id();
143 let cb_response =
144 session
145 .call(cb_channel, 1, vec![])
146 .await
147 .map_err(|e| RpcError::Status {
148 code: ErrorCode::Internal,
149 message: format!("callback failed: {:?}", e),
150 })?;
151
152 let mut result = cb_response.payload;
154 result.extend(payload);
155
156 let mut desc = MsgDescHot::new();
157 desc.flags = FrameFlags::DATA | FrameFlags::EOS;
158 Ok(Frame::with_payload(desc, result))
159 } else {
160 Err(RpcError::Status {
161 code: ErrorCode::Unimplemented,
162 message: "unknown method".into(),
163 })
164 }
165 }
166 });
167
168 let session_a_clone = session_a.clone();
170 let handle_a = tokio::spawn(async move { session_a_clone.run().await });
171
172 let session_b_clone = session_b.clone();
173 let handle_b = tokio::spawn(async move { session_b_clone.run().await });
174
175 let channel_id = session_a.next_channel_id();
177 let response = session_a
178 .call(channel_id, 1, b"test".to_vec())
179 .await
180 .map_err(TestError::Rpc)?;
181
182 if response.payload != b"PREFIX:test" {
183 return Err(TestError::Assertion(format!(
184 "expected 'PREFIX:test', got {:?}",
185 String::from_utf8_lossy(&response.payload)
186 )));
187 }
188
189 let _ = transport_a.close().await;
191 let _ = transport_b.close().await;
192 handle_a.abort();
193 handle_b.abort();
194
195 Ok(())
196}
197
198async fn run_multiple_nested<F: TransportFactory>() -> Result<(), TestError> {
203 let (transport_a, transport_b) = F::connect_pair().await?;
204 let transport_a = Arc::new(transport_a);
205 let transport_b = Arc::new(transport_b);
206
207 let session_a = Arc::new(RpcSession::with_channel_start(transport_a.clone(), 1));
210 session_a.set_dispatcher(|_channel_id, method_id, payload| async move {
211 if method_id == 1 {
213 let mut result = b"value_".to_vec();
215 result.extend(payload);
216 let mut desc = MsgDescHot::new();
217 desc.flags = FrameFlags::DATA | FrameFlags::EOS;
218 Ok(Frame::with_payload(desc, result))
219 } else {
220 Err(RpcError::Status {
221 code: ErrorCode::Unimplemented,
222 message: "unknown method".into(),
223 })
224 }
225 });
226
227 let session_b = Arc::new(RpcSession::with_channel_start(transport_b.clone(), 2));
230 let session_b_for_dispatcher = session_b.clone();
231 session_b.set_dispatcher(move |_channel_id, method_id, _payload| {
232 let session = session_b_for_dispatcher.clone();
233 async move {
234 if method_id == 1 {
236 let mut result = Vec::new();
237
238 for key in [b"a".as_slice(), b"b", b"c"] {
240 let cb_channel = session.next_channel_id();
241 let cb_response =
242 session
243 .call(cb_channel, 1, key.to_vec())
244 .await
245 .map_err(|e| RpcError::Status {
246 code: ErrorCode::Internal,
247 message: format!("callback failed: {:?}", e),
248 })?;
249 result.extend(&cb_response.payload);
250 result.push(b',');
251 }
252
253 if !result.is_empty() {
255 result.pop();
256 }
257
258 let mut desc = MsgDescHot::new();
259 desc.flags = FrameFlags::DATA | FrameFlags::EOS;
260 Ok(Frame::with_payload(desc, result))
261 } else {
262 Err(RpcError::Status {
263 code: ErrorCode::Unimplemented,
264 message: "unknown method".into(),
265 })
266 }
267 }
268 });
269
270 let session_a_clone = session_a.clone();
272 let handle_a = tokio::spawn(async move { session_a_clone.run().await });
273
274 let session_b_clone = session_b.clone();
275 let handle_b = tokio::spawn(async move { session_b_clone.run().await });
276
277 let channel_id = session_a.next_channel_id();
279 let response = session_a
280 .call(channel_id, 1, vec![])
281 .await
282 .map_err(TestError::Rpc)?;
283
284 let expected = b"value_a,value_b,value_c";
285 if response.payload != expected {
286 return Err(TestError::Assertion(format!(
287 "expected '{}', got '{}'",
288 String::from_utf8_lossy(expected),
289 String::from_utf8_lossy(&response.payload)
290 )));
291 }
292
293 let _ = transport_a.close().await;
295 let _ = transport_b.close().await;
296 handle_a.abort();
297 handle_b.abort();
298
299 Ok(())
300}
301
#[cfg(test)]
mod tests {
    //! Runs every bidirectional scenario over the in-process memory transport.

    use super::*;
    use rapace_transport_mem::InProcTransport;

    /// Factory that hands out a connected pair of in-process transports.
    struct InProcFactory;

    impl TransportFactory for InProcFactory {
        type Transport = InProcTransport;

        async fn connect_pair() -> Result<(Self::Transport, Self::Transport), TestError> {
            let pair = InProcTransport::pair();
            Ok(pair)
        }
    }

    #[tokio::test]
    async fn test_simple_echo_inproc() {
        let scenario = BidirectionalScenario::SimpleEcho;
        run_bidirectional_scenario::<InProcFactory>(scenario).await;
    }

    #[tokio::test]
    async fn test_nested_callback_inproc() {
        let scenario = BidirectionalScenario::NestedCallback;
        run_bidirectional_scenario::<InProcFactory>(scenario).await;
    }

    #[tokio::test]
    async fn test_multiple_nested_inproc() {
        let scenario = BidirectionalScenario::MultipleNestedCallbacks;
        run_bidirectional_scenario::<InProcFactory>(scenario).await;
    }
}