rapace_testkit/
bidirectional.rs

//! Bidirectional RPC test harness.
//!
//! This module provides a shared test harness for bidirectional RPC patterns
//! where both peers can call each other (like the template engine example with
//! host callbacks).
//!
//! # Usage
//!
//! ```ignore
//! use rapace_testkit::bidirectional::{run_bidirectional_scenario, BidirectionalScenario};
//! use rapace_testkit::TransportFactory;
//!
//! struct MyFactory;
//! impl TransportFactory for MyFactory { ... }
//!
//! #[tokio::test]
//! async fn test_bidirectional() {
//!     run_bidirectional_scenario::<MyFactory>(BidirectionalScenario::NestedCallback).await;
//! }
//! ```
20
21use std::sync::Arc;
22
23use rapace_core::{ErrorCode, Frame, FrameFlags, MsgDescHot, RpcError, Transport};
24
25use crate::RpcSession;
26use crate::{TestError, TransportFactory};
27
/// The bidirectional RPC patterns exercised by [`run_bidirectional_scenario`].
///
/// In every scenario, "A" is the peer that issues the initial call and "B" is
/// the peer that serves it.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum BidirectionalScenario {
    /// A calls B once; B replies with the request payload unchanged.
    SimpleEcho,

    /// A calls B; while handling that call, B makes one call back into A.
    NestedCallback,

    /// A calls B; while handling that call, B makes several sequential calls
    /// back into A and combines the replies.
    MultipleNestedCallbacks,
}
40
41/// Run a bidirectional RPC scenario.
42pub async fn run_bidirectional_scenario<F: TransportFactory>(scenario: BidirectionalScenario) {
43    let result = match scenario {
44        BidirectionalScenario::SimpleEcho => run_simple_echo::<F>().await,
45        BidirectionalScenario::NestedCallback => run_nested_callback::<F>().await,
46        BidirectionalScenario::MultipleNestedCallbacks => run_multiple_nested::<F>().await,
47    };
48
49    if let Err(e) = result {
50        panic!("bidirectional scenario {:?} failed: {}", scenario, e);
51    }
52}
53
54// ============================================================================
55// Scenario: Simple Echo
56// ============================================================================
57
58async fn run_simple_echo<F: TransportFactory>() -> Result<(), TestError> {
59    let (transport_a, transport_b) = F::connect_pair().await?;
60    let transport_a = Arc::new(transport_a);
61    let transport_b = Arc::new(transport_b);
62
63    // Session A (uses odd channel IDs)
64    let session_a = Arc::new(RpcSession::with_channel_start(transport_a.clone(), 1));
65
66    // Session B (uses even channel IDs) - simple echo dispatcher
67    let session_b = Arc::new(RpcSession::with_channel_start(transport_b.clone(), 2));
68    session_b.set_dispatcher(|_channel_id, _method_id, payload| async move {
69        // Echo: respond with the same payload
70        let mut desc = MsgDescHot::new();
71        desc.flags = FrameFlags::DATA | FrameFlags::EOS;
72        Ok(Frame::with_payload(desc, payload))
73    });
74
75    // Spawn demux loops
76    let session_a_clone = session_a.clone();
77    let handle_a = tokio::spawn(async move { session_a_clone.run().await });
78
79    let session_b_clone = session_b.clone();
80    let handle_b = tokio::spawn(async move { session_b_clone.run().await });
81
82    // A calls B
83    let channel_id = session_a.next_channel_id();
84    let response = session_a
85        .call(channel_id, 1, b"hello".to_vec())
86        .await
87        .map_err(TestError::Rpc)?;
88
89    if response.payload != b"hello" {
90        return Err(TestError::Assertion(format!(
91            "expected echo 'hello', got {:?}",
92            response.payload
93        )));
94    }
95
96    // Cleanup
97    let _ = transport_a.close().await;
98    let _ = transport_b.close().await;
99    handle_a.abort();
100    handle_b.abort();
101
102    Ok(())
103}
104
105// ============================================================================
106// Scenario: Nested Callback
107// ============================================================================
108
109async fn run_nested_callback<F: TransportFactory>() -> Result<(), TestError> {
110    let (transport_a, transport_b) = F::connect_pair().await?;
111    let transport_a = Arc::new(transport_a);
112    let transport_b = Arc::new(transport_b);
113
114    // Session A (uses odd channel IDs)
115    // A provides a "get_prefix" service: returns "PREFIX:"
116    let session_a = Arc::new(RpcSession::with_channel_start(transport_a.clone(), 1));
117    session_a.set_dispatcher(|_channel_id, method_id, _payload| async move {
118        // method 1 = get_prefix
119        if method_id == 1 {
120            let prefix = b"PREFIX:";
121            let mut desc = MsgDescHot::new();
122            desc.flags = FrameFlags::DATA | FrameFlags::EOS;
123            Ok(Frame::with_payload(desc, prefix.to_vec()))
124        } else {
125            Err(RpcError::Status {
126                code: ErrorCode::Unimplemented,
127                message: "unknown method".into(),
128            })
129        }
130    });
131
132    // Session B (uses even channel IDs)
133    // B provides a "format" service: calls A's get_prefix, then appends the input
134    let session_b = Arc::new(RpcSession::with_channel_start(transport_b.clone(), 2));
135    let session_b_for_dispatcher = session_b.clone();
136    session_b.set_dispatcher(move |_channel_id, method_id, payload| {
137        let session = session_b_for_dispatcher.clone();
138        async move {
139            // method 1 = format
140            if method_id == 1 {
141                // Call A's get_prefix
142                let cb_channel = session.next_channel_id();
143                let cb_response =
144                    session
145                        .call(cb_channel, 1, vec![])
146                        .await
147                        .map_err(|e| RpcError::Status {
148                            code: ErrorCode::Internal,
149                            message: format!("callback failed: {:?}", e),
150                        })?;
151
152                // Combine prefix + input
153                let mut result = cb_response.payload;
154                result.extend(payload);
155
156                let mut desc = MsgDescHot::new();
157                desc.flags = FrameFlags::DATA | FrameFlags::EOS;
158                Ok(Frame::with_payload(desc, result))
159            } else {
160                Err(RpcError::Status {
161                    code: ErrorCode::Unimplemented,
162                    message: "unknown method".into(),
163                })
164            }
165        }
166    });
167
168    // Spawn demux loops
169    let session_a_clone = session_a.clone();
170    let handle_a = tokio::spawn(async move { session_a_clone.run().await });
171
172    let session_b_clone = session_b.clone();
173    let handle_b = tokio::spawn(async move { session_b_clone.run().await });
174
175    // A calls B's format service
176    let channel_id = session_a.next_channel_id();
177    let response = session_a
178        .call(channel_id, 1, b"test".to_vec())
179        .await
180        .map_err(TestError::Rpc)?;
181
182    if response.payload != b"PREFIX:test" {
183        return Err(TestError::Assertion(format!(
184            "expected 'PREFIX:test', got {:?}",
185            String::from_utf8_lossy(&response.payload)
186        )));
187    }
188
189    // Cleanup
190    let _ = transport_a.close().await;
191    let _ = transport_b.close().await;
192    handle_a.abort();
193    handle_b.abort();
194
195    Ok(())
196}
197
198// ============================================================================
199// Scenario: Multiple Nested Callbacks
200// ============================================================================
201
202async fn run_multiple_nested<F: TransportFactory>() -> Result<(), TestError> {
203    let (transport_a, transport_b) = F::connect_pair().await?;
204    let transport_a = Arc::new(transport_a);
205    let transport_b = Arc::new(transport_b);
206
207    // Session A (uses odd channel IDs)
208    // A provides a "get_value" service: returns "value_N" where N is from the request
209    let session_a = Arc::new(RpcSession::with_channel_start(transport_a.clone(), 1));
210    session_a.set_dispatcher(|_channel_id, method_id, payload| async move {
211        // method 1 = get_value
212        if method_id == 1 {
213            // payload is the key, return "value_" + key
214            let mut result = b"value_".to_vec();
215            result.extend(payload);
216            let mut desc = MsgDescHot::new();
217            desc.flags = FrameFlags::DATA | FrameFlags::EOS;
218            Ok(Frame::with_payload(desc, result))
219        } else {
220            Err(RpcError::Status {
221                code: ErrorCode::Unimplemented,
222                message: "unknown method".into(),
223            })
224        }
225    });
226
227    // Session B (uses even channel IDs)
228    // B provides a "combine" service: calls A's get_value 3 times and combines results
229    let session_b = Arc::new(RpcSession::with_channel_start(transport_b.clone(), 2));
230    let session_b_for_dispatcher = session_b.clone();
231    session_b.set_dispatcher(move |_channel_id, method_id, _payload| {
232        let session = session_b_for_dispatcher.clone();
233        async move {
234            // method 1 = combine
235            if method_id == 1 {
236                let mut result = Vec::new();
237
238                // Call A three times
239                for key in [b"a".as_slice(), b"b", b"c"] {
240                    let cb_channel = session.next_channel_id();
241                    let cb_response =
242                        session
243                            .call(cb_channel, 1, key.to_vec())
244                            .await
245                            .map_err(|e| RpcError::Status {
246                                code: ErrorCode::Internal,
247                                message: format!("callback failed: {:?}", e),
248                            })?;
249                    result.extend(&cb_response.payload);
250                    result.push(b',');
251                }
252
253                // Remove trailing comma
254                if !result.is_empty() {
255                    result.pop();
256                }
257
258                let mut desc = MsgDescHot::new();
259                desc.flags = FrameFlags::DATA | FrameFlags::EOS;
260                Ok(Frame::with_payload(desc, result))
261            } else {
262                Err(RpcError::Status {
263                    code: ErrorCode::Unimplemented,
264                    message: "unknown method".into(),
265                })
266            }
267        }
268    });
269
270    // Spawn demux loops
271    let session_a_clone = session_a.clone();
272    let handle_a = tokio::spawn(async move { session_a_clone.run().await });
273
274    let session_b_clone = session_b.clone();
275    let handle_b = tokio::spawn(async move { session_b_clone.run().await });
276
277    // A calls B's combine service
278    let channel_id = session_a.next_channel_id();
279    let response = session_a
280        .call(channel_id, 1, vec![])
281        .await
282        .map_err(TestError::Rpc)?;
283
284    let expected = b"value_a,value_b,value_c";
285    if response.payload != expected {
286        return Err(TestError::Assertion(format!(
287            "expected '{}', got '{}'",
288            String::from_utf8_lossy(expected),
289            String::from_utf8_lossy(&response.payload)
290        )));
291    }
292
293    // Cleanup
294    let _ = transport_a.close().await;
295    let _ = transport_b.close().await;
296    handle_a.abort();
297    handle_b.abort();
298
299    Ok(())
300}
301
#[cfg(test)]
mod tests {
    use super::*;
    use rapace_transport_mem::InProcTransport;

    /// [`TransportFactory`] that connects two peers with the in-process memory transport.
    struct InProcFactory;

    impl TransportFactory for InProcFactory {
        type Transport = InProcTransport;

        async fn connect_pair() -> Result<(Self::Transport, Self::Transport), TestError> {
            let pair = InProcTransport::pair();
            Ok(pair)
        }
    }

    #[tokio::test]
    async fn test_simple_echo_inproc() {
        let scenario = BidirectionalScenario::SimpleEcho;
        run_bidirectional_scenario::<InProcFactory>(scenario).await;
    }

    #[tokio::test]
    async fn test_nested_callback_inproc() {
        let scenario = BidirectionalScenario::NestedCallback;
        run_bidirectional_scenario::<InProcFactory>(scenario).await;
    }

    #[tokio::test]
    async fn test_multiple_nested_inproc() {
        let scenario = BidirectionalScenario::MultipleNestedCallbacks;
        run_bidirectional_scenario::<InProcFactory>(scenario).await;
    }
}