1use std::future::Future;
30use std::sync::Arc;
31
32use rapace::session::Session;
33use rapace_core::{
34 CancelReason, ControlPayload, ErrorCode, Frame, FrameFlags, MsgDescHot, NO_DEADLINE, RpcError,
35 RpcSession, Transport, control_method,
36};
37
38pub mod bidirectional;
39pub mod helper_binary;
40
41#[derive(Debug)]
43pub enum TestError {
44 Setup(String),
46 Rpc(rapace_core::RpcError),
48 Transport(rapace_core::TransportError),
50 Assertion(String),
52}
53
54impl std::fmt::Display for TestError {
55 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
56 match self {
57 TestError::Setup(msg) => write!(f, "setup error: {}", msg),
58 TestError::Rpc(e) => write!(f, "RPC error: {}", e),
59 TestError::Transport(e) => write!(f, "transport error: {}", e),
60 TestError::Assertion(msg) => write!(f, "assertion failed: {}", msg),
61 }
62 }
63}
64
65impl std::error::Error for TestError {}
66
67impl From<rapace_core::RpcError> for TestError {
68 fn from(e: rapace_core::RpcError) -> Self {
69 TestError::Rpc(e)
70 }
71}
72
73impl From<rapace_core::TransportError> for TestError {
74 fn from(e: rapace_core::TransportError) -> Self {
75 TestError::Transport(e)
76 }
77}
78
79pub trait TransportFactory: Send + Sync + 'static {
84 type Transport: Transport + Send + Sync + 'static;
86
87 fn connect_pair()
92 -> impl Future<Output = Result<(Self::Transport, Self::Transport), TestError>> + Send;
93}
94
95#[allow(async_fn_in_trait)]
101#[rapace::service]
102pub trait Adder {
103 async fn add(&self, a: i32, b: i32) -> i32;
105}
106
107pub struct AdderImpl;
109
110impl Adder for AdderImpl {
111 async fn add(&self, a: i32, b: i32) -> i32 {
112 a + b
113 }
114}
115
116#[allow(async_fn_in_trait)]
124#[rapace::service]
125pub trait RangeService {
126 async fn range(&self, n: u32) -> rapace_core::Streaming<u32>;
128}
129
130pub struct RangeServiceImpl;
132
133impl RangeService for RangeServiceImpl {
134 async fn range(&self, n: u32) -> rapace_core::Streaming<u32> {
135 let (tx, rx) = tokio::sync::mpsc::channel(16);
136 tokio::spawn(async move {
137 for i in 0..n {
138 if tx.send(Ok(i)).await.is_err() {
139 break;
140 }
141 }
142 });
143 Box::pin(tokio_stream::wrappers::ReceiverStream::new(rx))
144 }
145}
146
147pub async fn run_unary_happy_path<F: TransportFactory>() {
155 let result = run_unary_happy_path_inner::<F>().await;
156 if let Err(e) = result {
157 panic!("run_unary_happy_path failed: {}", e);
158 }
159}
160
161async fn run_unary_happy_path_inner<F: TransportFactory>() -> Result<(), TestError> {
162 let (client_transport, server_transport) = F::connect_pair().await?;
163 let client_transport = Arc::new(client_transport);
164 let server_transport = Arc::new(server_transport);
165
166 let server = AdderServer::new(AdderImpl);
167
168 let server_handle = tokio::spawn({
170 let server_transport = server_transport.clone();
171 async move {
172 let request = server_transport.recv_frame().await?;
173 let mut response = server
174 .dispatch(request.desc.method_id, request.payload)
175 .await
176 .map_err(TestError::Rpc)?;
177 response.desc.channel_id = request.desc.channel_id;
179 server_transport.send_frame(&response).await?;
180 Ok::<_, TestError>(())
181 }
182 });
183
184 let client_session = Arc::new(RpcSession::new(client_transport));
186 let client_session_runner = client_session.clone();
187 let _session_handle = tokio::spawn(async move { client_session_runner.run().await });
188
189 let client = AdderClient::new(client_session);
191 let result = client.add(2, 3).await?;
192
193 if result != 5 {
194 return Err(TestError::Assertion(format!(
195 "expected add(2, 3) = 5, got {}",
196 result
197 )));
198 }
199
200 server_handle
202 .await
203 .map_err(|e| TestError::Setup(format!("server task panicked: {}", e)))?
204 .map_err(|e| TestError::Setup(format!("server error: {}", e)))?;
205
206 Ok(())
207}
208
209pub async fn run_unary_multiple_calls<F: TransportFactory>() {
213 let result = run_unary_multiple_calls_inner::<F>().await;
214 if let Err(e) = result {
215 panic!("run_unary_multiple_calls failed: {}", e);
216 }
217}
218
219async fn run_unary_multiple_calls_inner<F: TransportFactory>() -> Result<(), TestError> {
220 let (client_transport, server_transport) = F::connect_pair().await?;
221 let client_transport = Arc::new(client_transport);
222 let server_transport = Arc::new(server_transport);
223
224 let server = AdderServer::new(AdderImpl);
225
226 let server_handle = tokio::spawn({
228 let server_transport = server_transport.clone();
229 async move {
230 for _ in 0..3 {
231 let request = server_transport.recv_frame().await?;
232 let mut response = server
233 .dispatch(request.desc.method_id, request.payload)
234 .await
235 .map_err(TestError::Rpc)?;
236 response.desc.channel_id = request.desc.channel_id;
238 server_transport.send_frame(&response).await?;
239 }
240 Ok::<_, TestError>(())
241 }
242 });
243
244 let client_session = Arc::new(RpcSession::new(client_transport));
246 let client_session_runner = client_session.clone();
247 let _session_handle = tokio::spawn(async move { client_session_runner.run().await });
248
249 let client = AdderClient::new(client_session);
250
251 let test_cases = [(1, 2, 3), (10, 20, 30), (-5, 5, 0)];
253
254 for (a, b, expected) in test_cases {
255 let result = client.add(a, b).await?;
256 if result != expected {
257 return Err(TestError::Assertion(format!(
258 "expected add({}, {}) = {}, got {}",
259 a, b, expected, result
260 )));
261 }
262 }
263
264 server_handle
265 .await
266 .map_err(|e| TestError::Setup(format!("server task panicked: {}", e)))?
267 .map_err(|e| TestError::Setup(format!("server error: {}", e)))?;
268
269 Ok(())
270}
271
272pub async fn run_error_response<F: TransportFactory>() {
281 let result = run_error_response_inner::<F>().await;
282 if let Err(e) = result {
283 panic!("run_error_response failed: {}", e);
284 }
285}
286
287async fn run_error_response_inner<F: TransportFactory>() -> Result<(), TestError> {
288 let (client_transport, server_transport) = F::connect_pair().await?;
289 let client_transport = Arc::new(client_transport);
290 let server_transport = Arc::new(server_transport);
291
292 let server_handle = tokio::spawn({
294 let server_transport = server_transport.clone();
295 async move {
296 let request = server_transport.recv_frame().await?;
297
298 let mut desc = MsgDescHot::new();
300 desc.msg_id = request.desc.msg_id;
301 desc.channel_id = request.desc.channel_id;
302 desc.method_id = request.desc.method_id;
303 desc.flags = FrameFlags::ERROR | FrameFlags::EOS;
304
305 let error_code = ErrorCode::InvalidArgument as u32;
307 let message = "test error message";
308 let mut payload = Vec::new();
309 payload.extend_from_slice(&error_code.to_le_bytes());
310 payload.extend_from_slice(&(message.len() as u32).to_le_bytes());
311 payload.extend_from_slice(message.as_bytes());
312
313 let frame = Frame::with_payload(desc, payload);
314 server_transport.send_frame(&frame).await?;
315
316 Ok::<_, TestError>(())
317 }
318 });
319
320 let client_session = Arc::new(RpcSession::new(client_transport));
322 let client_session_runner = client_session.clone();
323 let _session_handle = tokio::spawn(async move { client_session_runner.run().await });
324
325 let client = AdderClient::new(client_session);
327 let result = client.add(1, 2).await;
328
329 match result {
330 Err(RpcError::Status { code, message }) => {
331 if code != ErrorCode::InvalidArgument {
332 return Err(TestError::Assertion(format!(
333 "expected InvalidArgument, got {:?}",
334 code
335 )));
336 }
337 if message != "test error message" {
338 return Err(TestError::Assertion(format!(
339 "expected 'test error message', got '{}'",
340 message
341 )));
342 }
343 }
344 Ok(v) => {
345 return Err(TestError::Assertion(format!(
346 "expected error, got success: {}",
347 v
348 )));
349 }
350 Err(e) => {
351 return Err(TestError::Assertion(format!(
352 "expected Status error, got {:?}",
353 e
354 )));
355 }
356 }
357
358 server_handle
359 .await
360 .map_err(|e| TestError::Setup(format!("server task panicked: {}", e)))?
361 .map_err(|e| TestError::Setup(format!("server error: {}", e)))?;
362
363 Ok(())
364}
365
366pub async fn run_ping_pong<F: TransportFactory>() {
374 let result = run_ping_pong_inner::<F>().await;
375 if let Err(e) = result {
376 panic!("run_ping_pong failed: {}", e);
377 }
378}
379
380async fn run_ping_pong_inner<F: TransportFactory>() -> Result<(), TestError> {
381 let (client_transport, server_transport) = F::connect_pair().await?;
382 let client_transport = Arc::new(client_transport);
383 let server_transport = Arc::new(server_transport);
384
385 let server_handle = tokio::spawn({
387 let server_transport = server_transport.clone();
388 async move {
389 let request = server_transport.recv_frame().await?;
390
391 if request.desc.channel_id != 0 {
393 return Err(TestError::Assertion("expected control channel".into()));
394 }
395 if request.desc.method_id != control_method::PING {
396 return Err(TestError::Assertion("expected PING method_id".into()));
397 }
398 if !request.desc.flags.contains(FrameFlags::CONTROL) {
399 return Err(TestError::Assertion("expected CONTROL flag".into()));
400 }
401
402 let ping_payload: [u8; 8] = request
404 .payload
405 .try_into()
406 .map_err(|_| TestError::Assertion("ping payload should be 8 bytes".into()))?;
407
408 let mut desc = MsgDescHot::new();
409 desc.msg_id = request.desc.msg_id;
410 desc.channel_id = 0; desc.method_id = control_method::PONG;
412 desc.flags = FrameFlags::CONTROL | FrameFlags::EOS;
413
414 let frame = Frame::with_inline_payload(desc, &ping_payload)
415 .expect("pong payload should fit inline");
416 server_transport.send_frame(&frame).await?;
417
418 Ok::<_, TestError>(())
419 }
420 });
421
422 let ping_data: [u8; 8] = [0xDE, 0xAD, 0xBE, 0xEF, 0xCA, 0xFE, 0xBA, 0xBE];
424
425 let mut desc = MsgDescHot::new();
426 desc.msg_id = 1;
427 desc.channel_id = 0; desc.method_id = control_method::PING;
429 desc.flags = FrameFlags::CONTROL | FrameFlags::EOS;
430
431 let frame =
432 Frame::with_inline_payload(desc, &ping_data).expect("ping payload should fit inline");
433 client_transport.send_frame(&frame).await?;
434
435 let pong = client_transport.recv_frame().await?;
437
438 if pong.desc.channel_id != 0 {
439 return Err(TestError::Assertion("expected control channel".into()));
440 }
441 if pong.desc.method_id != control_method::PONG {
442 return Err(TestError::Assertion("expected PONG method_id".into()));
443 }
444 if pong.payload != ping_data {
445 return Err(TestError::Assertion(format!(
446 "PONG payload mismatch: expected {:?}, got {:?}",
447 ping_data, pong.payload
448 )));
449 }
450
451 server_handle
452 .await
453 .map_err(|e| TestError::Setup(format!("server task panicked: {}", e)))?
454 .map_err(|e| TestError::Setup(format!("server error: {}", e)))?;
455
456 Ok(())
457}
458
/// Nanoseconds elapsed on a process-wide monotonic clock.
///
/// The epoch is fixed by the first call to this function. Values are only
/// meaningful relative to other `now_ns()` results from the same process.
fn now_ns() -> u64 {
    static EPOCH: std::sync::OnceLock<std::time::Instant> = std::sync::OnceLock::new();

    let epoch = *EPOCH.get_or_init(std::time::Instant::now);
    // `as` truncates the u128 nanosecond count; that would only matter after
    // roughly 584 years.
    epoch.elapsed().as_nanos() as u64
}
471
472pub async fn run_deadline_success<F: TransportFactory>() {
474 let result = run_deadline_success_inner::<F>().await;
475 if let Err(e) = result {
476 panic!("run_deadline_success failed: {}", e);
477 }
478}
479
480async fn run_deadline_success_inner<F: TransportFactory>() -> Result<(), TestError> {
481 let (client_transport, server_transport) = F::connect_pair().await?;
482 let client_transport = Arc::new(client_transport);
483 let server_transport = Arc::new(server_transport);
484
485 let server = AdderServer::new(AdderImpl);
486
487 let server_handle = tokio::spawn({
489 let server_transport = server_transport.clone();
490 async move {
491 let request = server_transport.recv_frame().await?;
492
493 if request.desc.deadline_ns != NO_DEADLINE {
495 let now = now_ns();
496 if now > request.desc.deadline_ns {
497 let mut desc = MsgDescHot::new();
499 desc.msg_id = request.desc.msg_id;
500 desc.channel_id = request.desc.channel_id;
501 desc.flags = FrameFlags::ERROR | FrameFlags::EOS;
502
503 let error_code = ErrorCode::DeadlineExceeded as u32;
504 let message = "deadline exceeded";
505 let mut payload = Vec::new();
506 payload.extend_from_slice(&error_code.to_le_bytes());
507 payload.extend_from_slice(&(message.len() as u32).to_le_bytes());
508 payload.extend_from_slice(message.as_bytes());
509
510 let frame = Frame::with_payload(desc, payload);
511 server_transport.send_frame(&frame).await?;
512 return Ok(());
513 }
514 }
515
516 let mut response = server
518 .dispatch(request.desc.method_id, request.payload)
519 .await
520 .map_err(TestError::Rpc)?;
521 response.desc.channel_id = request.desc.channel_id;
523 server_transport.send_frame(&response).await?;
524 Ok::<_, TestError>(())
525 }
526 });
527
528 let client_session = Arc::new(RpcSession::new(client_transport));
530 let client_session_runner = client_session.clone();
531 let _session_handle = tokio::spawn(async move { client_session_runner.run().await });
532
533 let deadline = now_ns() + 10_000_000_000; let client = AdderClient::new(client_session);
540 let result = client.add(2, 3).await?;
541
542 if result != 5 {
543 return Err(TestError::Assertion(format!("expected 5, got {}", result)));
544 }
545
546 let _ = deadline; server_handle
549 .await
550 .map_err(|e| TestError::Setup(format!("server task panicked: {}", e)))?
551 .map_err(|e| TestError::Setup(format!("server error: {}", e)))?;
552
553 Ok(())
554}
555
556pub async fn run_deadline_exceeded<F: TransportFactory>() {
558 let result = run_deadline_exceeded_inner::<F>().await;
559 if let Err(e) = result {
560 panic!("run_deadline_exceeded failed: {}", e);
561 }
562}
563
564async fn run_deadline_exceeded_inner<F: TransportFactory>() -> Result<(), TestError> {
565 let (client_transport, server_transport) = F::connect_pair().await?;
566 let client_transport = Arc::new(client_transport);
567 let server_transport = Arc::new(server_transport);
568
569 let baseline = now_ns();
572 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
574 let expired_deadline = baseline;
576
577 let server_handle = tokio::spawn({
579 let server_transport = server_transport.clone();
580 async move {
581 let request = server_transport.recv_frame().await?;
582
583 if request.desc.deadline_ns != NO_DEADLINE {
585 let now = now_ns();
586 if now > request.desc.deadline_ns {
587 let mut desc = MsgDescHot::new();
589 desc.msg_id = request.desc.msg_id;
590 desc.channel_id = request.desc.channel_id;
591 desc.flags = FrameFlags::ERROR | FrameFlags::EOS;
592
593 let error_code = ErrorCode::DeadlineExceeded as u32;
594 let message = "deadline exceeded";
595 let mut payload = Vec::new();
596 payload.extend_from_slice(&error_code.to_le_bytes());
597 payload.extend_from_slice(&(message.len() as u32).to_le_bytes());
598 payload.extend_from_slice(message.as_bytes());
599
600 let frame = Frame::with_payload(desc, payload);
601 server_transport.send_frame(&frame).await?;
602 return Ok(());
603 }
604 }
605
606 Err(TestError::Assertion(
608 "server should have rejected expired deadline".into(),
609 ))
610 }
611 });
612
613 let request_payload = facet_postcard::to_vec(&(1i32, 2i32)).unwrap();
614
615 let mut desc = MsgDescHot::new();
616 desc.msg_id = 1;
617 desc.channel_id = 1;
618 desc.method_id = 1; desc.flags = FrameFlags::DATA | FrameFlags::EOS;
620 desc.deadline_ns = expired_deadline;
621
622 let frame = if request_payload.len() <= rapace_core::INLINE_PAYLOAD_SIZE {
623 Frame::with_inline_payload(desc, &request_payload).expect("should fit inline")
624 } else {
625 Frame::with_payload(desc, request_payload)
626 };
627
628 client_transport.send_frame(&frame).await?;
629
630 let response = client_transport.recv_frame().await?;
632
633 if !response.desc.flags.contains(FrameFlags::ERROR) {
634 return Err(TestError::Assertion(
635 "expected ERROR flag on response".into(),
636 ));
637 }
638
639 if response.payload.len() < 8 {
641 return Err(TestError::Assertion("error payload too short".into()));
642 }
643
644 let error_code = u32::from_le_bytes(response.payload[0..4].try_into().unwrap());
645 let code = ErrorCode::from_u32(error_code);
646
647 if code != Some(ErrorCode::DeadlineExceeded) {
648 return Err(TestError::Assertion(format!(
649 "expected DeadlineExceeded, got {:?}",
650 code
651 )));
652 }
653
654 server_handle
655 .await
656 .map_err(|e| TestError::Setup(format!("server task panicked: {}", e)))?
657 .map_err(|e| TestError::Setup(format!("server error: {}", e)))?;
658
659 Ok(())
660}
661
662pub async fn run_cancellation<F: TransportFactory>() {
671 let result = run_cancellation_inner::<F>().await;
672 if let Err(e) = result {
673 panic!("run_cancellation failed: {}", e);
674 }
675}
676
677async fn run_cancellation_inner<F: TransportFactory>() -> Result<(), TestError> {
678 let (client_transport, server_transport) = F::connect_pair().await?;
679 let client_transport = Arc::new(client_transport);
680 let server_transport = Arc::new(server_transport);
681
682 let channel_to_cancel: u32 = 42;
683
684 let server_handle = tokio::spawn({
686 let server_transport = server_transport.clone();
687 async move {
688 let request = server_transport.recv_frame().await?;
690 if request.desc.channel_id != channel_to_cancel {
691 return Err(TestError::Assertion(format!(
692 "expected channel {}, got {}",
693 channel_to_cancel, request.desc.channel_id
694 )));
695 }
696
697 let cancel = server_transport.recv_frame().await?;
699 if cancel.desc.channel_id != 0 {
700 return Err(TestError::Assertion(
701 "cancel should be on control channel".into(),
702 ));
703 }
704 if cancel.desc.method_id != control_method::CANCEL_CHANNEL {
705 return Err(TestError::Assertion(format!(
706 "expected CANCEL_CHANNEL method_id, got {}",
707 cancel.desc.method_id
708 )));
709 }
710 if !cancel.desc.flags.contains(FrameFlags::CONTROL) {
711 return Err(TestError::Assertion("expected CONTROL flag".into()));
712 }
713
714 let cancel_payload: ControlPayload = facet_postcard::from_slice(cancel.payload)
716 .map_err(|e| {
717 TestError::Assertion(format!("failed to decode CancelChannel: {:?}", e))
718 })?;
719
720 match cancel_payload {
721 ControlPayload::CancelChannel { channel_id, reason } => {
722 if channel_id != channel_to_cancel {
723 return Err(TestError::Assertion(format!(
724 "expected cancel for channel {}, got {}",
725 channel_to_cancel, channel_id
726 )));
727 }
728 if reason != CancelReason::ClientCancel {
729 return Err(TestError::Assertion(format!(
730 "expected ClientCancel reason, got {:?}",
731 reason
732 )));
733 }
734 }
735 _ => {
736 return Err(TestError::Assertion(format!(
737 "expected CancelChannel, got {:?}",
738 cancel_payload
739 )));
740 }
741 }
742
743 Ok::<_, TestError>(())
744 }
745 });
746
747 let request_payload = facet_postcard::to_vec(&(1i32, 2i32)).unwrap();
749
750 let mut desc = MsgDescHot::new();
751 desc.msg_id = 1;
752 desc.channel_id = channel_to_cancel;
753 desc.method_id = 1;
754 desc.flags = FrameFlags::DATA;
755
756 let frame = Frame::with_inline_payload(desc, &request_payload).expect("should fit inline");
757 client_transport.send_frame(&frame).await?;
758
759 let cancel_payload = ControlPayload::CancelChannel {
761 channel_id: channel_to_cancel,
762 reason: CancelReason::ClientCancel,
763 };
764 let cancel_bytes = facet_postcard::to_vec(&cancel_payload).unwrap();
765
766 let mut cancel_desc = MsgDescHot::new();
767 cancel_desc.msg_id = 2;
768 cancel_desc.channel_id = 0; cancel_desc.method_id = control_method::CANCEL_CHANNEL;
770 cancel_desc.flags = FrameFlags::CONTROL | FrameFlags::EOS;
771
772 let cancel_frame =
773 Frame::with_inline_payload(cancel_desc, &cancel_bytes).expect("should fit inline");
774 client_transport.send_frame(&cancel_frame).await?;
775
776 server_handle
777 .await
778 .map_err(|e| TestError::Setup(format!("server task panicked: {}", e)))?
779 .map_err(|e| TestError::Setup(format!("server error: {}", e)))?;
780
781 Ok(())
782}
783
784pub async fn run_credit_grant<F: TransportFactory>() {
792 let result = run_credit_grant_inner::<F>().await;
793 if let Err(e) = result {
794 panic!("run_credit_grant failed: {}", e);
795 }
796}
797
798async fn run_credit_grant_inner<F: TransportFactory>() -> Result<(), TestError> {
799 let (client_transport, server_transport) = F::connect_pair().await?;
800 let client_transport = Arc::new(client_transport);
801 let server_transport = Arc::new(server_transport);
802
803 let channel_id: u32 = 1;
804 let credit_amount: u32 = 65536;
805
806 let server_handle = tokio::spawn({
808 let server_transport = server_transport.clone();
809 async move {
810 let grant_payload = ControlPayload::GrantCredits {
812 channel_id,
813 bytes: credit_amount,
814 };
815 let grant_bytes = facet_postcard::to_vec(&grant_payload).unwrap();
816
817 let mut desc = MsgDescHot::new();
818 desc.msg_id = 1;
819 desc.channel_id = 0; desc.method_id = control_method::GRANT_CREDITS;
821 desc.flags = FrameFlags::CONTROL | FrameFlags::CREDITS | FrameFlags::EOS;
822 desc.credit_grant = credit_amount; let frame = Frame::with_inline_payload(desc, &grant_bytes).expect("should fit inline");
825 server_transport.send_frame(&frame).await?;
826
827 Ok::<_, TestError>(())
828 }
829 });
830
831 let grant = client_transport.recv_frame().await?;
833
834 if grant.desc.channel_id != 0 {
835 return Err(TestError::Assertion(
836 "credit grant should be on control channel".into(),
837 ));
838 }
839 if grant.desc.method_id != control_method::GRANT_CREDITS {
840 return Err(TestError::Assertion(format!(
841 "expected GRANT_CREDITS method_id, got {}",
842 grant.desc.method_id
843 )));
844 }
845 if !grant.desc.flags.contains(FrameFlags::CREDITS) {
846 return Err(TestError::Assertion("expected CREDITS flag".into()));
847 }
848 if grant.desc.credit_grant != credit_amount {
849 return Err(TestError::Assertion(format!(
850 "expected credit_grant {}, got {}",
851 credit_amount, grant.desc.credit_grant
852 )));
853 }
854
855 let grant_payload: ControlPayload = facet_postcard::from_slice(grant.payload)
857 .map_err(|e| TestError::Assertion(format!("failed to decode GrantCredits: {:?}", e)))?;
858
859 match grant_payload {
860 ControlPayload::GrantCredits {
861 channel_id: ch,
862 bytes,
863 } => {
864 if ch != channel_id {
865 return Err(TestError::Assertion(format!(
866 "expected channel {}, got {}",
867 channel_id, ch
868 )));
869 }
870 if bytes != credit_amount {
871 return Err(TestError::Assertion(format!(
872 "expected {} bytes, got {}",
873 credit_amount, bytes
874 )));
875 }
876 }
877 _ => {
878 return Err(TestError::Assertion(format!(
879 "expected GrantCredits, got {:?}",
880 grant_payload
881 )));
882 }
883 }
884
885 server_handle
886 .await
887 .map_err(|e| TestError::Setup(format!("server task panicked: {}", e)))?
888 .map_err(|e| TestError::Setup(format!("server error: {}", e)))?;
889
890 Ok(())
891}
892
893pub async fn run_session_credit_exhaustion<F: TransportFactory>() {
902 let result = run_session_credit_exhaustion_inner::<F>().await;
903 if let Err(e) = result {
904 panic!("run_session_credit_exhaustion failed: {}", e);
905 }
906}
907
908async fn run_session_credit_exhaustion_inner<F: TransportFactory>() -> Result<(), TestError> {
909 use rapace::session::DEFAULT_INITIAL_CREDITS;
910
911 let (client_transport, _server_transport) = F::connect_pair().await?;
912 let client_transport = Arc::new(client_transport);
913
914 let session = Session::new(client_transport);
916
917 let large_payload = vec![0u8; DEFAULT_INITIAL_CREDITS as usize + 1];
920
921 let mut desc = MsgDescHot::new();
922 desc.msg_id = 1;
923 desc.channel_id = 1; desc.method_id = 1;
925 desc.flags = FrameFlags::DATA | FrameFlags::EOS;
926 desc.payload_len = large_payload.len() as u32;
927
928 let frame = Frame::with_payload(desc, large_payload);
929
930 let result = session.send_frame(&frame).await;
932
933 match result {
934 Err(RpcError::Status {
935 code: ErrorCode::ResourceExhausted,
936 ..
937 }) => {
938 Ok(())
940 }
941 Ok(()) => Err(TestError::Assertion(
942 "expected ResourceExhausted error, got success".into(),
943 )),
944 Err(e) => Err(TestError::Assertion(format!(
945 "expected ResourceExhausted, got {:?}",
946 e
947 ))),
948 }
949}
950
951pub async fn run_session_cancelled_channel_drop<F: TransportFactory>() {
953 let result = run_session_cancelled_channel_drop_inner::<F>().await;
954 if let Err(e) = result {
955 panic!("run_session_cancelled_channel_drop failed: {}", e);
956 }
957}
958
959async fn run_session_cancelled_channel_drop_inner<F: TransportFactory>() -> Result<(), TestError> {
960 let (client_transport, server_transport) = F::connect_pair().await?;
961 let client_transport = Arc::new(client_transport);
962 let server_transport = Arc::new(server_transport);
963
964 let session = Session::new(client_transport);
965 let channel_id = 42u32;
966
967 session.cancel_channel(channel_id);
969
970 if !session.is_cancelled(channel_id) {
972 return Err(TestError::Assertion("channel should be cancelled".into()));
973 }
974
975 let mut desc = MsgDescHot::new();
977 desc.msg_id = 1;
978 desc.channel_id = channel_id;
979 desc.method_id = 1;
980 desc.flags = FrameFlags::DATA | FrameFlags::EOS;
981
982 let frame = Frame::with_inline_payload(desc, b"test").expect("should fit");
983
984 session.send_frame(&frame).await?;
986
987 let mut desc2 = MsgDescHot::new();
990 desc2.msg_id = 2;
991 desc2.channel_id = 99; desc2.method_id = 1;
993 desc2.flags = FrameFlags::DATA | FrameFlags::EOS;
994
995 let frame2 = Frame::with_inline_payload(desc2, b"marker").expect("should fit");
996 session.transport().send_frame(&frame2).await?;
997
998 let received = server_transport.recv_frame().await?;
1000 if received.desc.channel_id != 99 {
1001 return Err(TestError::Assertion(format!(
1002 "expected channel 99, got {}",
1003 received.desc.channel_id
1004 )));
1005 }
1006 if received.payload != b"marker" {
1007 return Err(TestError::Assertion("expected marker payload".into()));
1008 }
1009
1010 Ok(())
1011}
1012
1013pub async fn run_session_cancel_control_frame<F: TransportFactory>() {
1015 let result = run_session_cancel_control_frame_inner::<F>().await;
1016 if let Err(e) = result {
1017 panic!("run_session_cancel_control_frame failed: {}", e);
1018 }
1019}
1020
1021async fn run_session_cancel_control_frame_inner<F: TransportFactory>() -> Result<(), TestError> {
1022 let (client_transport, server_transport) = F::connect_pair().await?;
1023 let client_transport = Arc::new(client_transport);
1024 let server_transport = Arc::new(server_transport);
1025
1026 let session = Session::new(server_transport);
1027 let channel_to_cancel = 42u32;
1028
1029 let cancel_payload = ControlPayload::CancelChannel {
1031 channel_id: channel_to_cancel,
1032 reason: CancelReason::ClientCancel,
1033 };
1034 let cancel_bytes = facet_postcard::to_vec(&cancel_payload).unwrap();
1035
1036 let mut cancel_desc = MsgDescHot::new();
1037 cancel_desc.msg_id = 1;
1038 cancel_desc.channel_id = 0; cancel_desc.method_id = control_method::CANCEL_CHANNEL;
1040 cancel_desc.flags = FrameFlags::CONTROL | FrameFlags::EOS;
1041
1042 let cancel_frame = Frame::with_inline_payload(cancel_desc, &cancel_bytes).expect("should fit");
1043 client_transport.send_frame(&cancel_frame).await?;
1044
1045 let mut data_desc = MsgDescHot::new();
1047 data_desc.msg_id = 2;
1048 data_desc.channel_id = channel_to_cancel;
1049 data_desc.method_id = 1;
1050 data_desc.flags = FrameFlags::DATA | FrameFlags::EOS;
1051
1052 let data_frame = Frame::with_inline_payload(data_desc, b"dropped").expect("should fit");
1053 client_transport.send_frame(&data_frame).await?;
1054
1055 let mut marker_desc = MsgDescHot::new();
1057 marker_desc.msg_id = 3;
1058 marker_desc.channel_id = 99;
1059 marker_desc.method_id = 1;
1060 marker_desc.flags = FrameFlags::DATA | FrameFlags::EOS;
1061
1062 let marker_frame = Frame::with_inline_payload(marker_desc, b"marker").expect("should fit");
1063 client_transport.send_frame(&marker_frame).await?;
1064
1065 let frame1 = session.recv_frame().await?;
1067 if frame1.desc.channel_id != 0 {
1068 return Err(TestError::Assertion(
1069 "first frame should be control frame".into(),
1070 ));
1071 }
1072
1073 if !session.is_cancelled(channel_to_cancel) {
1075 return Err(TestError::Assertion(
1076 "channel should be cancelled after control frame".into(),
1077 ));
1078 }
1079
1080 let frame2 = session.recv_frame().await?;
1082 if frame2.desc.channel_id != 99 {
1083 return Err(TestError::Assertion(format!(
1084 "expected channel 99 (marker), got {}",
1085 frame2.desc.channel_id
1086 )));
1087 }
1088 if frame2.payload != b"marker" {
1089 return Err(TestError::Assertion("expected marker payload".into()));
1090 }
1091
1092 Ok(())
1093}
1094
1095pub async fn run_session_grant_credits_control_frame<F: TransportFactory>() {
1097 let result = run_session_grant_credits_control_frame_inner::<F>().await;
1098 if let Err(e) = result {
1099 panic!("run_session_grant_credits_control_frame failed: {}", e);
1100 }
1101}
1102
1103async fn run_session_grant_credits_control_frame_inner<F: TransportFactory>()
1104-> Result<(), TestError> {
1105 use rapace::session::DEFAULT_INITIAL_CREDITS;
1106
1107 let (client_transport, server_transport) = F::connect_pair().await?;
1108 let client_transport = Arc::new(client_transport);
1109 let server_transport = Arc::new(server_transport);
1110
1111 let session = Session::new(client_transport);
1112 let channel_id = 1u32;
1113
1114 let initial = session.get_credits(channel_id);
1116 if initial != DEFAULT_INITIAL_CREDITS {
1117 return Err(TestError::Assertion(format!(
1118 "expected initial credits {}, got {}",
1119 DEFAULT_INITIAL_CREDITS, initial
1120 )));
1121 }
1122
1123 let grant_payload = ControlPayload::GrantCredits {
1125 channel_id,
1126 bytes: 10000,
1127 };
1128 let grant_bytes = facet_postcard::to_vec(&grant_payload).unwrap();
1129
1130 let mut grant_desc = MsgDescHot::new();
1131 grant_desc.msg_id = 1;
1132 grant_desc.channel_id = 0;
1133 grant_desc.method_id = control_method::GRANT_CREDITS;
1134 grant_desc.flags = FrameFlags::CONTROL | FrameFlags::CREDITS | FrameFlags::EOS;
1135 grant_desc.credit_grant = 10000;
1136
1137 let grant_frame = Frame::with_inline_payload(grant_desc, &grant_bytes).expect("should fit");
1138 server_transport.send_frame(&grant_frame).await?;
1139
1140 let frame = session.recv_frame().await?;
1142 if frame.desc.channel_id != 0 {
1143 return Err(TestError::Assertion("expected control frame".into()));
1144 }
1145
1146 let updated = session.get_credits(channel_id);
1148 let expected = DEFAULT_INITIAL_CREDITS + 10000;
1149 if updated != expected {
1150 return Err(TestError::Assertion(format!(
1151 "expected credits {}, got {}",
1152 expected, updated
1153 )));
1154 }
1155
1156 Ok(())
1157}
1158
1159pub async fn run_session_deadline_check<F: TransportFactory>() {
1161 let result = run_session_deadline_check_inner::<F>().await;
1162 if let Err(e) = result {
1163 panic!("run_session_deadline_check failed: {}", e);
1164 }
1165}
1166
1167async fn run_session_deadline_check_inner<F: TransportFactory>() -> Result<(), TestError> {
1168 let (client_transport, _server_transport) = F::connect_pair().await?;
1169 let client_transport = Arc::new(client_transport);
1170
1171 let session = Session::new(client_transport);
1172
1173 let mut desc1 = MsgDescHot::new();
1175 desc1.deadline_ns = NO_DEADLINE;
1176
1177 if session.is_deadline_exceeded(&desc1) {
1178 return Err(TestError::Assertion(
1179 "NO_DEADLINE should not be exceeded".into(),
1180 ));
1181 }
1182
1183 let mut desc2 = MsgDescHot::new();
1185 desc2.deadline_ns = now_ns() + 10_000_000_000; if session.is_deadline_exceeded(&desc2) {
1188 return Err(TestError::Assertion(
1189 "future deadline should not be exceeded".into(),
1190 ));
1191 }
1192
1193 tokio::time::sleep(std::time::Duration::from_millis(10)).await;
1196 let mut desc3 = MsgDescHot::new();
1197 desc3.deadline_ns = 1; if !session.is_deadline_exceeded(&desc3) {
1200 return Err(TestError::Assertion(
1201 "past deadline should be exceeded".into(),
1202 ));
1203 }
1204
1205 Ok(())
1206}
1207
1208pub async fn run_server_streaming_happy_path<F: TransportFactory>() {
1221 let result = run_server_streaming_happy_path_inner::<F>().await;
1222 if let Err(e) = result {
1223 panic!("run_server_streaming_happy_path failed: {}", e);
1224 }
1225}
1226
1227async fn run_server_streaming_happy_path_inner<F: TransportFactory>() -> Result<(), TestError> {
1228 let (client_transport, server_transport) = F::connect_pair().await?;
1229 let client_transport = Arc::new(client_transport);
1230 let server_transport = Arc::new(server_transport);
1231
1232 let client_session = Session::new(client_transport.clone());
1233 let server_session = Session::new(server_transport.clone());
1234
1235 let channel_id = 1u32;
1236 let item_count = 5;
1237
1238 let server_handle = tokio::spawn({
1240 let server_session = server_session;
1241 async move {
1242 let request = server_session.recv_frame().await?;
1244 if request.desc.channel_id != channel_id {
1245 return Err(TestError::Assertion(format!(
1246 "expected channel {}, got {}",
1247 channel_id, request.desc.channel_id
1248 )));
1249 }
1250
1251 if !request.desc.flags.contains(FrameFlags::EOS) {
1253 return Err(TestError::Assertion("request should have EOS".into()));
1254 }
1255
1256 let count: i32 = facet_postcard::from_slice(request.payload)
1258 .map_err(|e| TestError::Assertion(format!("decode request: {:?}", e)))?;
1259
1260 for i in 0..count {
1262 let mut desc = MsgDescHot::new();
1263 desc.msg_id = (i + 1) as u64;
1264 desc.channel_id = channel_id;
1265 desc.method_id = request.desc.method_id;
1266 desc.flags = FrameFlags::DATA; let item_bytes = facet_postcard::to_vec(&i).unwrap();
1269 let frame =
1270 Frame::with_inline_payload(desc, &item_bytes).expect("item should fit inline");
1271 server_session.send_frame(&frame).await?;
1272 }
1273
1274 let mut eos_desc = MsgDescHot::new();
1276 eos_desc.msg_id = (count + 1) as u64;
1277 eos_desc.channel_id = channel_id;
1278 eos_desc.method_id = request.desc.method_id;
1279 eos_desc.flags = FrameFlags::DATA | FrameFlags::EOS;
1280
1281 let eos_frame =
1282 Frame::with_inline_payload(eos_desc, &[]).expect("empty frame should fit inline");
1283 server_session.send_frame(&eos_frame).await?;
1284
1285 Ok::<_, TestError>(())
1286 }
1287 });
1288
1289 let request_bytes = facet_postcard::to_vec(&item_count).unwrap();
1291
1292 let mut desc = MsgDescHot::new();
1293 desc.msg_id = 1;
1294 desc.channel_id = channel_id;
1295 desc.method_id = 1;
1296 desc.flags = FrameFlags::DATA | FrameFlags::EOS; let frame = Frame::with_inline_payload(desc, &request_bytes).expect("should fit inline");
1299 client_session.send_frame(&frame).await?;
1300
1301 let state = client_session.get_lifecycle(channel_id);
1303 if state != rapace::session::ChannelLifecycle::HalfClosedLocal {
1304 return Err(TestError::Assertion(format!(
1305 "expected HalfClosedLocal after client EOS, got {:?}",
1306 state
1307 )));
1308 }
1309
1310 let mut received = Vec::new();
1312 loop {
1313 let frame = client_session.recv_frame().await?;
1314 if frame.desc.channel_id != channel_id {
1315 return Err(TestError::Assertion(format!(
1316 "expected channel {}, got {}",
1317 channel_id, frame.desc.channel_id
1318 )));
1319 }
1320
1321 if frame.desc.flags.contains(FrameFlags::EOS) {
1323 break;
1324 }
1325
1326 let item: i32 = facet_postcard::from_slice(frame.payload)
1328 .map_err(|e| TestError::Assertion(format!("decode item: {:?}", e)))?;
1329 received.push(item);
1330 }
1331
1332 let expected: Vec<i32> = (0..item_count).collect();
1334 if received != expected {
1335 return Err(TestError::Assertion(format!(
1336 "expected {:?}, got {:?}",
1337 expected, received
1338 )));
1339 }
1340
1341 let final_state = client_session.get_lifecycle(channel_id);
1343 if final_state != rapace::session::ChannelLifecycle::Closed {
1344 return Err(TestError::Assertion(format!(
1345 "expected Closed after both EOS, got {:?}",
1346 final_state
1347 )));
1348 }
1349
1350 server_handle
1351 .await
1352 .map_err(|e| TestError::Setup(format!("server task panicked: {}", e)))?
1353 .map_err(|e| TestError::Setup(format!("server error: {}", e)))?;
1354
1355 Ok(())
1356}
1357
1358pub async fn run_client_streaming_happy_path<F: TransportFactory>() {
1365 let result = run_client_streaming_happy_path_inner::<F>().await;
1366 if let Err(e) = result {
1367 panic!("run_client_streaming_happy_path failed: {}", e);
1368 }
1369}
1370
1371async fn run_client_streaming_happy_path_inner<F: TransportFactory>() -> Result<(), TestError> {
1372 let (client_transport, server_transport) = F::connect_pair().await?;
1373 let client_transport = Arc::new(client_transport);
1374 let server_transport = Arc::new(server_transport);
1375
1376 let client_session = Session::new(client_transport.clone());
1377 let server_session = Session::new(server_transport.clone());
1378
1379 let channel_id = 1u32;
1380 let items_to_send: Vec<i32> = vec![10, 20, 30, 40, 50];
1381
1382 let server_handle = tokio::spawn({
1384 let server_session = server_session;
1385 let expected_items = items_to_send.clone();
1386 async move {
1387 let mut sum = 0i32;
1388 let mut count = 0;
1389
1390 loop {
1391 let frame = server_session.recv_frame().await?;
1392 if frame.desc.channel_id != channel_id {
1393 return Err(TestError::Assertion(format!(
1394 "expected channel {}, got {}",
1395 channel_id, frame.desc.channel_id
1396 )));
1397 }
1398
1399 if !frame.payload.is_empty() {
1401 let item: i32 = facet_postcard::from_slice(frame.payload)
1402 .map_err(|e| TestError::Assertion(format!("decode item: {:?}", e)))?;
1403 sum += item;
1404 count += 1;
1405 }
1406
1407 if frame.desc.flags.contains(FrameFlags::EOS) {
1409 break;
1410 }
1411 }
1412
1413 if count != expected_items.len() {
1415 return Err(TestError::Assertion(format!(
1416 "expected {} items, got {}",
1417 expected_items.len(),
1418 count
1419 )));
1420 }
1421
1422 let mut desc = MsgDescHot::new();
1424 desc.msg_id = 1;
1425 desc.channel_id = channel_id;
1426 desc.method_id = 1;
1427 desc.flags = FrameFlags::DATA | FrameFlags::EOS;
1428
1429 let response_bytes = facet_postcard::to_vec(&sum).unwrap();
1430 let frame =
1431 Frame::with_inline_payload(desc, &response_bytes).expect("should fit inline");
1432 server_session.send_frame(&frame).await?;
1433
1434 Ok::<_, TestError>(())
1435 }
1436 });
1437
1438 for (i, &item) in items_to_send.iter().enumerate() {
1440 let is_last = i == items_to_send.len() - 1;
1441
1442 let mut desc = MsgDescHot::new();
1443 desc.msg_id = (i + 1) as u64;
1444 desc.channel_id = channel_id;
1445 desc.method_id = 1;
1446 desc.flags = if is_last {
1447 FrameFlags::DATA | FrameFlags::EOS
1448 } else {
1449 FrameFlags::DATA
1450 };
1451
1452 let item_bytes = facet_postcard::to_vec(&item).unwrap();
1453 let frame = Frame::with_inline_payload(desc, &item_bytes).expect("should fit inline");
1454 client_session.send_frame(&frame).await?;
1455 }
1456
1457 let response = client_session.recv_frame().await?;
1459 if response.desc.channel_id != channel_id {
1460 return Err(TestError::Assertion(format!(
1461 "expected channel {}, got {}",
1462 channel_id, response.desc.channel_id
1463 )));
1464 }
1465 if !response.desc.flags.contains(FrameFlags::EOS) {
1466 return Err(TestError::Assertion("response should have EOS".into()));
1467 }
1468
1469 let sum: i32 = facet_postcard::from_slice(response.payload)
1470 .map_err(|e| TestError::Assertion(format!("decode response: {:?}", e)))?;
1471
1472 let expected_sum: i32 = items_to_send.iter().sum();
1473 if sum != expected_sum {
1474 return Err(TestError::Assertion(format!(
1475 "expected sum {}, got {}",
1476 expected_sum, sum
1477 )));
1478 }
1479
1480 let final_state = client_session.get_lifecycle(channel_id);
1482 if final_state != rapace::session::ChannelLifecycle::Closed {
1483 return Err(TestError::Assertion(format!(
1484 "expected Closed, got {:?}",
1485 final_state
1486 )));
1487 }
1488
1489 server_handle
1490 .await
1491 .map_err(|e| TestError::Setup(format!("server task panicked: {}", e)))?
1492 .map_err(|e| TestError::Setup(format!("server error: {}", e)))?;
1493
1494 Ok(())
1495}
1496
1497pub async fn run_bidirectional_streaming<F: TransportFactory>() {
1504 let result = run_bidirectional_streaming_inner::<F>().await;
1505 if let Err(e) = result {
1506 panic!("run_bidirectional_streaming failed: {}", e);
1507 }
1508}
1509
1510async fn run_bidirectional_streaming_inner<F: TransportFactory>() -> Result<(), TestError> {
1511 let (client_transport, server_transport) = F::connect_pair().await?;
1512 let client_transport = Arc::new(client_transport);
1513 let server_transport = Arc::new(server_transport);
1514
1515 let client_session = Session::new(client_transport.clone());
1516 let server_session = Session::new(server_transport.clone());
1517
1518 let channel_id = 1u32;
1519
1520 let server_handle = tokio::spawn({
1522 let server_session = server_session;
1523 async move {
1524 let mut received = Vec::new();
1525
1526 for (i, item) in [100i32, 200, 300].iter().enumerate() {
1528 let is_last = i == 2;
1529 let mut desc = MsgDescHot::new();
1530 desc.msg_id = (i + 1) as u64;
1531 desc.channel_id = channel_id;
1532 desc.method_id = 1;
1533 desc.flags = if is_last {
1534 FrameFlags::DATA | FrameFlags::EOS
1535 } else {
1536 FrameFlags::DATA
1537 };
1538
1539 let item_bytes = facet_postcard::to_vec(item).unwrap();
1540 let frame =
1541 Frame::with_inline_payload(desc, &item_bytes).expect("should fit inline");
1542 server_session.send_frame(&frame).await?;
1543 }
1544
1545 loop {
1547 let frame = server_session.recv_frame().await?;
1548 if frame.desc.channel_id != channel_id {
1549 continue; }
1551
1552 if !frame.payload.is_empty() {
1553 let item: i32 = facet_postcard::from_slice(frame.payload)
1554 .map_err(|e| TestError::Assertion(format!("decode: {:?}", e)))?;
1555 received.push(item);
1556 }
1557
1558 if frame.desc.flags.contains(FrameFlags::EOS) {
1559 break;
1560 }
1561 }
1562
1563 let expected = vec![1, 2, 3, 4, 5];
1565 if received != expected {
1566 return Err(TestError::Assertion(format!(
1567 "server expected {:?}, got {:?}",
1568 expected, received
1569 )));
1570 }
1571
1572 Ok::<_, TestError>(())
1573 }
1574 });
1575
1576 let mut client_received = Vec::new();
1578
1579 for (i, item) in [1i32, 2, 3, 4, 5].iter().enumerate() {
1581 let is_last = i == 4;
1582 let mut desc = MsgDescHot::new();
1583 desc.msg_id = (i + 100) as u64;
1584 desc.channel_id = channel_id;
1585 desc.method_id = 1;
1586 desc.flags = if is_last {
1587 FrameFlags::DATA | FrameFlags::EOS
1588 } else {
1589 FrameFlags::DATA
1590 };
1591
1592 let item_bytes = facet_postcard::to_vec(item).unwrap();
1593 let frame = Frame::with_inline_payload(desc, &item_bytes).expect("should fit inline");
1594 client_session.send_frame(&frame).await?;
1595 }
1596
1597 loop {
1599 let frame = client_session.recv_frame().await?;
1600 if frame.desc.channel_id != channel_id {
1601 continue;
1602 }
1603
1604 if !frame.payload.is_empty() {
1605 let item: i32 = facet_postcard::from_slice(frame.payload)
1606 .map_err(|e| TestError::Assertion(format!("decode: {:?}", e)))?;
1607 client_received.push(item);
1608 }
1609
1610 if frame.desc.flags.contains(FrameFlags::EOS) {
1611 break;
1612 }
1613 }
1614
1615 let expected = vec![100, 200, 300];
1617 if client_received != expected {
1618 return Err(TestError::Assertion(format!(
1619 "client expected {:?}, got {:?}",
1620 expected, client_received
1621 )));
1622 }
1623
1624 let final_state = client_session.get_lifecycle(channel_id);
1626 if final_state != rapace::session::ChannelLifecycle::Closed {
1627 return Err(TestError::Assertion(format!(
1628 "expected Closed, got {:?}",
1629 final_state
1630 )));
1631 }
1632
1633 server_handle
1634 .await
1635 .map_err(|e| TestError::Setup(format!("server task panicked: {}", e)))?
1636 .map_err(|e| TestError::Setup(format!("server error: {}", e)))?;
1637
1638 Ok(())
1639}
1640
1641pub async fn run_streaming_cancellation<F: TransportFactory>() {
1647 let result = run_streaming_cancellation_inner::<F>().await;
1648 if let Err(e) = result {
1649 panic!("run_streaming_cancellation failed: {}", e);
1650 }
1651}
1652
1653pub async fn run_macro_server_streaming<F: TransportFactory>() {
1664 let result = run_macro_server_streaming_inner::<F>().await;
1665 if let Err(e) = result {
1666 panic!("run_macro_server_streaming failed: {}", e);
1667 }
1668}
1669
1670async fn run_macro_server_streaming_inner<F: TransportFactory>() -> Result<(), TestError> {
1671 use futures::StreamExt;
1672
1673 let (client_transport, server_transport) = F::connect_pair().await?;
1674 let client_transport = Arc::new(client_transport);
1675 let server_transport = Arc::new(server_transport);
1676
1677 let server = RangeServiceServer::new(RangeServiceImpl);
1678
1679 let server_handle = tokio::spawn({
1681 let server_transport = server_transport.clone();
1682 async move {
1683 let request = server_transport.recv_frame().await?;
1685
1686 server
1688 .dispatch_streaming(
1689 request.desc.method_id,
1690 request.desc.channel_id,
1691 request.payload,
1692 server_transport.as_ref(),
1693 )
1694 .await
1695 .map_err(TestError::Rpc)?;
1696
1697 Ok::<_, TestError>(())
1698 }
1699 });
1700
1701 let client_session = Arc::new(RpcSession::new(client_transport));
1703 let client_session_runner = client_session.clone();
1704 let _session_handle = tokio::spawn(async move { client_session_runner.run().await });
1705
1706 let client = RangeServiceClient::new(client_session);
1708 let mut stream = client.range(5).await?;
1709
1710 let mut items = Vec::new();
1712 while let Some(result) = stream.next().await {
1713 let item = result?;
1714 items.push(item);
1715 }
1716
1717 let expected: Vec<u32> = (0..5).collect();
1719 if items != expected {
1720 return Err(TestError::Assertion(format!(
1721 "expected {:?}, got {:?}",
1722 expected, items
1723 )));
1724 }
1725
1726 server_handle
1728 .await
1729 .map_err(|e| TestError::Setup(format!("server task panicked: {}", e)))?
1730 .map_err(|e| TestError::Setup(format!("server error: {}", e)))?;
1731
1732 Ok(())
1733}
1734
1735async fn run_streaming_cancellation_inner<F: TransportFactory>() -> Result<(), TestError> {
1736 let (client_transport, server_transport) = F::connect_pair().await?;
1737 let client_transport = Arc::new(client_transport);
1738 let server_transport = Arc::new(server_transport);
1739
1740 let client_session = Session::new(client_transport.clone());
1741 let server_session = Session::new(server_transport.clone());
1742
1743 let channel_id = 1u32;
1744
1745 let server_handle = tokio::spawn({
1747 let server_session = server_session;
1748 async move {
1749 for i in 0..2 {
1751 let mut desc = MsgDescHot::new();
1752 desc.msg_id = (i + 1) as u64;
1753 desc.channel_id = channel_id;
1754 desc.method_id = 1;
1755 desc.flags = FrameFlags::DATA;
1756
1757 let item_bytes = facet_postcard::to_vec(&i).unwrap();
1758 let frame =
1759 Frame::with_inline_payload(desc, &item_bytes).expect("should fit inline");
1760 server_session.send_frame(&frame).await?;
1761 }
1762
1763 let cancel_payload = ControlPayload::CancelChannel {
1765 channel_id,
1766 reason: CancelReason::ClientCancel,
1767 };
1768 let cancel_bytes = facet_postcard::to_vec(&cancel_payload).unwrap();
1769
1770 let mut cancel_desc = MsgDescHot::new();
1771 cancel_desc.msg_id = 100;
1772 cancel_desc.channel_id = 0;
1773 cancel_desc.method_id = control_method::CANCEL_CHANNEL;
1774 cancel_desc.flags = FrameFlags::CONTROL | FrameFlags::EOS;
1775
1776 let cancel_frame =
1777 Frame::with_inline_payload(cancel_desc, &cancel_bytes).expect("should fit inline");
1778 server_session.transport().send_frame(&cancel_frame).await?;
1779
1780 let mut marker_desc = MsgDescHot::new();
1782 marker_desc.msg_id = 200;
1783 marker_desc.channel_id = 99;
1784 marker_desc.method_id = 1;
1785 marker_desc.flags = FrameFlags::DATA | FrameFlags::EOS;
1786
1787 let marker_frame =
1788 Frame::with_inline_payload(marker_desc, b"done").expect("should fit inline");
1789 server_session.transport().send_frame(&marker_frame).await?;
1790
1791 Ok::<_, TestError>(())
1792 }
1793 });
1794
1795 let mut received = Vec::new();
1797
1798 loop {
1799 let frame = client_session.recv_frame().await?;
1800
1801 if frame.desc.channel_id == 99 {
1802 break;
1804 }
1805
1806 if frame.desc.channel_id == 0 {
1807 continue;
1809 }
1810
1811 if frame.desc.channel_id == channel_id && !frame.payload.is_empty() {
1812 let item: i32 = facet_postcard::from_slice(frame.payload)
1813 .map_err(|e| TestError::Assertion(format!("decode: {:?}", e)))?;
1814 received.push(item);
1815 }
1816 }
1817
1818 if received.len() > 2 {
1821 return Err(TestError::Assertion(format!(
1822 "expected at most 2 items (before cancel), got {:?}",
1823 received
1824 )));
1825 }
1826
1827 if !client_session.is_cancelled(channel_id) {
1829 return Err(TestError::Assertion("channel should be cancelled".into()));
1830 }
1831
1832 server_handle
1833 .await
1834 .map_err(|e| TestError::Setup(format!("server task panicked: {}", e)))?
1835 .map_err(|e| TestError::Setup(format!("server error: {}", e)))?;
1836
1837 Ok(())
1838}
1839
// Silences the `async_fn_in_trait` lint for this public trait.
#[allow(async_fn_in_trait)]
/// Test service that moves byte blobs through the RPC layer.
///
/// Used by the `run_large_blob_*` conformance checks. Method declaration
/// order is significant: it feeds the ids generated by `#[rapace::service]`.
#[rapace::service]
pub trait LargeBlobService {
    /// Returns `data` unchanged.
    async fn echo(&self, data: Vec<u8>) -> Vec<u8>;

    /// Returns `data` with every byte XORed with `pattern`.
    async fn xor_transform(&self, data: Vec<u8>, pattern: u8) -> Vec<u8>;

    /// Returns `(length, sum)` for `data`, where `sum` is the wrapping `u32`
    /// sum of all bytes.
    async fn checksum(&self, data: Vec<u8>) -> (u32, u32);
}
1864
1865pub struct LargeBlobServiceImpl;
1867
1868impl LargeBlobService for LargeBlobServiceImpl {
1869 async fn echo(&self, data: Vec<u8>) -> Vec<u8> {
1870 data
1871 }
1872
1873 async fn xor_transform(&self, data: Vec<u8>, pattern: u8) -> Vec<u8> {
1874 data.into_iter().map(|b| b ^ pattern).collect()
1875 }
1876
1877 async fn checksum(&self, data: Vec<u8>) -> (u32, u32) {
1878 let len = data.len() as u32;
1879 let sum = data.iter().fold(0u32, |acc, &b| acc.wrapping_add(b as u32));
1880 (len, sum)
1881 }
1882}
1883
1884pub async fn run_large_blob_echo<F: TransportFactory>() {
1893 let result = run_large_blob_echo_inner::<F>().await;
1894 if let Err(e) = result {
1895 panic!("run_large_blob_echo failed: {}", e);
1896 }
1897}
1898
1899async fn run_large_blob_echo_inner<F: TransportFactory>() -> Result<(), TestError> {
1900 let (client_transport, server_transport) = F::connect_pair().await?;
1901 let client_transport = Arc::new(client_transport);
1902 let server_transport = Arc::new(server_transport);
1903
1904 let server = LargeBlobServiceServer::new(LargeBlobServiceImpl);
1905
1906 let server_handle = tokio::spawn({
1908 let server_transport = server_transport.clone();
1909 async move {
1910 let request = server_transport.recv_frame().await?;
1911 let mut response = server
1912 .dispatch(request.desc.method_id, request.payload)
1913 .await
1914 .map_err(TestError::Rpc)?;
1915 response.desc.channel_id = request.desc.channel_id;
1917 server_transport.send_frame(&response).await?;
1918 Ok::<_, TestError>(())
1919 }
1920 });
1921
1922 let client_session = Arc::new(RpcSession::new(client_transport));
1924 let client_session_runner = client_session.clone();
1925 let _session_handle = tokio::spawn(async move { client_session_runner.run().await });
1926
1927 let blob: Vec<u8> = (0..1024).map(|i| (i % 256) as u8).collect();
1929
1930 let client = LargeBlobServiceClient::new(client_session);
1931 let result = client.echo(blob.clone()).await?;
1932
1933 if result != blob {
1934 return Err(TestError::Assertion(format!(
1935 "echo mismatch: expected {} bytes, got {} bytes",
1936 blob.len(),
1937 result.len()
1938 )));
1939 }
1940
1941 server_handle
1942 .await
1943 .map_err(|e| TestError::Setup(format!("server task panicked: {}", e)))?
1944 .map_err(|e| TestError::Setup(format!("server error: {}", e)))?;
1945
1946 Ok(())
1947}
1948
1949pub async fn run_large_blob_transform<F: TransportFactory>() {
1953 let result = run_large_blob_transform_inner::<F>().await;
1954 if let Err(e) = result {
1955 panic!("run_large_blob_transform failed: {}", e);
1956 }
1957}
1958
1959async fn run_large_blob_transform_inner<F: TransportFactory>() -> Result<(), TestError> {
1960 let (client_transport, server_transport) = F::connect_pair().await?;
1961 let client_transport = Arc::new(client_transport);
1962 let server_transport = Arc::new(server_transport);
1963
1964 let server = LargeBlobServiceServer::new(LargeBlobServiceImpl);
1965
1966 let server_handle = tokio::spawn({
1968 let server_transport = server_transport.clone();
1969 async move {
1970 let request = server_transport.recv_frame().await?;
1971 let mut response = server
1972 .dispatch(request.desc.method_id, request.payload)
1973 .await
1974 .map_err(TestError::Rpc)?;
1975 response.desc.channel_id = request.desc.channel_id;
1977 server_transport.send_frame(&response).await?;
1978 Ok::<_, TestError>(())
1979 }
1980 });
1981
1982 let client_session = Arc::new(RpcSession::new(client_transport));
1984 let client_session_runner = client_session.clone();
1985 let _session_handle = tokio::spawn(async move { client_session_runner.run().await });
1986
1987 let blob: Vec<u8> = (0..2048).map(|i| (i % 256) as u8).collect();
1989 let pattern: u8 = 0xAA;
1990 let expected: Vec<u8> = blob.iter().map(|&b| b ^ pattern).collect();
1991
1992 let client = LargeBlobServiceClient::new(client_session);
1993 let result = client.xor_transform(blob, pattern).await?;
1994
1995 if result != expected {
1996 return Err(TestError::Assertion(format!(
1997 "transform mismatch at byte 0: expected {:02x}, got {:02x}",
1998 expected[0], result[0]
1999 )));
2000 }
2001
2002 server_handle
2003 .await
2004 .map_err(|e| TestError::Setup(format!("server task panicked: {}", e)))?
2005 .map_err(|e| TestError::Setup(format!("server error: {}", e)))?;
2006
2007 Ok(())
2008}
2009
2010pub async fn run_large_blob_checksum<F: TransportFactory>() {
2012 let result = run_large_blob_checksum_inner::<F>().await;
2013 if let Err(e) = result {
2014 panic!("run_large_blob_checksum failed: {}", e);
2015 }
2016}
2017
2018async fn run_large_blob_checksum_inner<F: TransportFactory>() -> Result<(), TestError> {
2019 let (client_transport, server_transport) = F::connect_pair().await?;
2020 let client_transport = Arc::new(client_transport);
2021 let server_transport = Arc::new(server_transport);
2022
2023 let server = LargeBlobServiceServer::new(LargeBlobServiceImpl);
2024
2025 let client_session = Arc::new(RpcSession::new(client_transport));
2027 let client_session_runner = client_session.clone();
2028 let _session_handle = tokio::spawn(async move { client_session_runner.run().await });
2029
2030 let sizes = [100, 1000, 4000]; for size in sizes {
2034 let server_handle = tokio::spawn({
2036 let server_transport = server_transport.clone();
2037 let server = LargeBlobServiceServer::new(LargeBlobServiceImpl);
2038 async move {
2039 let request = server_transport.recv_frame().await?;
2040 let mut response = server
2041 .dispatch(request.desc.method_id, request.payload)
2042 .await
2043 .map_err(TestError::Rpc)?;
2044 response.desc.channel_id = request.desc.channel_id;
2046 server_transport.send_frame(&response).await?;
2047 Ok::<_, TestError>(())
2048 }
2049 });
2050
2051 let blob: Vec<u8> = (0..size).map(|i| ((i * 7) % 256) as u8).collect();
2053 let expected_len = blob.len() as u32;
2054 let expected_sum = blob.iter().fold(0u32, |acc, &b| acc.wrapping_add(b as u32));
2055
2056 let client = LargeBlobServiceClient::new(client_session.clone());
2057 let (len, sum) = client.checksum(blob).await?;
2058
2059 if len != expected_len {
2060 return Err(TestError::Assertion(format!(
2061 "size {}: length mismatch: expected {}, got {}",
2062 size, expected_len, len
2063 )));
2064 }
2065 if sum != expected_sum {
2066 return Err(TestError::Assertion(format!(
2067 "size {}: sum mismatch: expected {}, got {}",
2068 size, expected_sum, sum
2069 )));
2070 }
2071
2072 server_handle
2073 .await
2074 .map_err(|e| TestError::Setup(format!("server task panicked: {}", e)))?
2075 .map_err(|e| TestError::Setup(format!("server error: {}", e)))?;
2076 }
2077
2078 let _ = server; Ok(())
2081}
2082
/// Unit tests for the registry glue generated by `#[rapace::service]`.
///
/// Covers registration, method-id assignment, name/id lookup and doc-comment
/// capture.
#[cfg(test)]
mod registry_tests {
    use rapace_registry::ServiceRegistry;

    use super::*;

    #[test]
    fn test_adder_registration() {
        let mut registry = ServiceRegistry::new();
        adder_register(&mut registry);

        let service = registry
            .service("Adder")
            .expect("Adder service should exist");
        assert_eq!(service.name, "Adder");

        let add_method = service.method("add").expect("add method should exist");
        assert_eq!(add_method.name, "add");
        assert_eq!(add_method.full_name, "Adder.add");
        assert!(!add_method.is_streaming);

        assert!(
            !add_method.request_shape.type_identifier.is_empty(),
            "request shape should have a type identifier"
        );
        assert!(
            !add_method.response_shape.type_identifier.is_empty(),
            "response shape should have a type identifier"
        );

        // The id handed out at registration must resolve back to the method.
        let by_id = registry.method_by_id(add_method.id);
        assert!(by_id.is_some());
        assert_eq!(by_id.unwrap().name, "add");
    }

    #[test]
    fn test_range_service_registration() {
        let mut registry = ServiceRegistry::new();
        range_service_register(&mut registry);

        let service = registry
            .service("RangeService")
            .expect("RangeService should exist");
        assert_eq!(service.name, "RangeService");

        let range_method = service.method("range").expect("range method should exist");
        assert_eq!(range_method.name, "range");
        assert_eq!(range_method.full_name, "RangeService.range");
        assert!(
            range_method.is_streaming,
            "range should be a streaming method"
        );
    }

    #[test]
    fn test_multiple_services_registration() {
        let mut registry = ServiceRegistry::new();

        adder_register(&mut registry);
        range_service_register(&mut registry);

        assert_eq!(registry.service_count(), 2);
        assert_eq!(registry.method_count(), 2);

        let add_id = registry.resolve_method_id("Adder", "add").unwrap();
        let range_id = registry.resolve_method_id("RangeService", "range").unwrap();
        assert_ne!(add_id, range_id);

        assert!(add_id.0 >= 1);
        assert!(range_id.0 >= 1);
    }

    #[test]
    fn test_lookup_by_name() {
        let mut registry = ServiceRegistry::new();
        adder_register(&mut registry);
        range_service_register(&mut registry);

        assert!(registry.lookup_method("Adder", "add").is_some());
        assert!(registry.lookup_method("RangeService", "range").is_some());

        assert!(registry.lookup_method("Adder", "subtract").is_none());
        assert!(registry.lookup_method("NonExistent", "add").is_none());
    }

    #[test]
    fn test_registry_client_method_ids() {
        use rapace_transport_mem::InProcTransport;
        use std::sync::Arc;

        let mut registry = ServiceRegistry::new();

        adder_register(&mut registry);
        range_service_register(&mut registry);

        let add_id = registry.resolve_method_id("Adder", "add").unwrap();
        let range_id = registry.resolve_method_id("RangeService", "range").unwrap();

        // A fresh registry hands out ids in registration order.
        assert_eq!(add_id.0, 1, "First method should have ID 1");
        assert_eq!(range_id.0, 2, "Second method should have ID 2");

        let (client_transport, _server_transport) = InProcTransport::pair();
        let client_session = RpcSession::new(Arc::new(client_transport));
        // Was mangled to `®istry` (HTML `&reg` entity), which does not compile.
        let client = AdderRegistryClient::new(Arc::new(client_session), &registry);

        assert_eq!(client.add_method_id, 1);
    }

    #[test]
    fn test_doc_capture() {
        let mut registry = ServiceRegistry::new();
        adder_register(&mut registry);

        let service = registry
            .service("Adder")
            .expect("Adder service should exist");

        assert!(
            service.doc.contains("Simple arithmetic service"),
            "Service doc should contain trait doc comment, got: {:?}",
            service.doc
        );

        let add_method = service.method("add").expect("add method should exist");
        assert!(
            add_method.doc.contains("Add two numbers"),
            "Method doc should contain method doc comment, got: {:?}",
            add_method.doc
        );
    }

    #[test]
    fn test_streaming_method_doc_capture() {
        let mut registry = ServiceRegistry::new();
        range_service_register(&mut registry);

        let service = registry
            .service("RangeService")
            .expect("RangeService should exist");

        assert!(
            service.doc.contains("Service with server-streaming RPC"),
            "Service doc should contain trait doc comment, got: {:?}",
            service.doc
        );

        let range_method = service.method("range").expect("range method should exist");
        assert!(
            range_method.doc.contains("Stream numbers"),
            "Method doc should contain method doc comment, got: {:?}",
            range_method.doc
        );
    }
}