lite_sync/request_response/one_to_one.rs
//! Lightweight bidirectional request-response channel for one-to-one communication.
//!
//! Optimized for a strict request-response pattern: side A sends a request, and
//! side B must respond before A can send the next one, so no buffer is needed.
//!
//! 轻量级一对一双向请求-响应通道
//!
//! 为严格的请求-响应模式优化,A方发送请求,B方必须响应后A方才能发送下一个请求。
//! 无需缓冲区。
10use std::cell::UnsafeCell;
11use std::mem::MaybeUninit;
12use std::sync::atomic::{AtomicU8, AtomicBool, Ordering};
13use std::future::Future;
14use std::pin::Pin;
15use std::task::{Context, Poll};
16use std::sync::Arc;
17
18use crate::atomic_waker::AtomicWaker;
19use super::common::{ChannelError, state};
20
// Channel state values, aliased from the shared `state` module in
// `super::common` so the state machine in this file can use short names.
//
// Transitions performed by the `Inner` helpers:
//   IDLE             -> WAITING_RESPONSE  (`try_send_request`, side A)
//   (any)            -> RESPONSE_READY    (`mark_response_ready`, side B's reply)
//   (any)            -> IDLE              (`complete_response`, side A took the reply)
const IDLE: u8 = state::IDLE;
const WAITING_RESPONSE: u8 = state::WAITING_RESPONSE;
const RESPONSE_READY: u8 = state::RESPONSE_READY;
25
/// Shared core of a one-to-one request-response channel.
///
/// Both endpoints hold an `Arc` to this value. The atomic `state` decides which
/// side may touch which `UnsafeCell` slot: side A writes `request` and reads
/// `response`; side B reads `request` and writes `response`.
///
/// NOTE(review): `request` and `response` are `MaybeUninit`, which never drops
/// its contents, and this file defines no `Drop` impl for `Inner`. A request or
/// response still stored in a slot when the last `Arc<Inner>` goes away is
/// therefore leaked rather than dropped.
struct Inner<Req, Resp> {
    /// Current state-machine value; see the `IDLE` / `WAITING_RESPONSE` /
    /// `RESPONSE_READY` constants.
    state: AtomicU8,

    /// Waker registered by side A; woken by a reply and by dropping `SideB`.
    a_waker: AtomicWaker,

    /// Waker registered by side B; woken by a new request and by dropping `SideA`.
    b_waker: AtomicWaker,

    /// Set to `true` when `SideA` is dropped.
    a_closed: AtomicBool,

    /// Set to `true` when `SideB` is dropped.
    b_closed: AtomicBool,

    /// Slot carrying the in-flight request from A to B.
    request: UnsafeCell<MaybeUninit<Req>>,

    /// Slot carrying the in-flight response from B to A.
    response: UnsafeCell<MaybeUninit<Resp>>,
}
57
58// SAFETY: Access to UnsafeCell is synchronized via atomic state machine
59unsafe impl<Req: Send, Resp: Send> Send for Inner<Req, Resp> {}
60unsafe impl<Req: Send, Resp: Send> Sync for Inner<Req, Resp> {}
61
62impl<Req, Resp> std::fmt::Debug for Inner<Req, Resp> {
63 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
64 let state = self.current_state();
65 let state_str = match state {
66 IDLE => "Idle",
67 WAITING_RESPONSE => "WaitingResponse",
68 RESPONSE_READY => "ResponseReady",
69 _ => "Unknown",
70 };
71 f.debug_struct("Inner")
72 .field("state", &state_str)
73 .field("a_closed", &self.is_a_closed())
74 .field("b_closed", &self.is_b_closed())
75 .finish()
76 }
77}
78
79impl<Req, Resp> Inner<Req, Resp> {
80 /// Create new channel internal state
81 ///
82 /// 创建新的 channel 内部状态
83 #[inline]
84 fn new() -> Self {
85 Self {
86 state: AtomicU8::new(IDLE),
87 a_waker: AtomicWaker::new(),
88 b_waker: AtomicWaker::new(),
89 a_closed: AtomicBool::new(false),
90 b_closed: AtomicBool::new(false),
91 request: UnsafeCell::new(MaybeUninit::uninit()),
92 response: UnsafeCell::new(MaybeUninit::uninit()),
93 }
94 }
95
96 /// Try to send request (state transition: IDLE -> WAITING_RESPONSE)
97 ///
98 /// 尝试发送请求(状态转换:IDLE -> WAITING_RESPONSE)
99 #[inline]
100 fn try_send_request(&self) -> bool {
101 self.state.compare_exchange(
102 IDLE,
103 WAITING_RESPONSE,
104 Ordering::AcqRel,
105 Ordering::Acquire,
106 ).is_ok()
107 }
108
109 /// Mark response as ready (state transition: WAITING_RESPONSE -> RESPONSE_READY)
110 ///
111 /// 标记响应已就绪(状态转换:WAITING_RESPONSE -> RESPONSE_READY)
112 #[inline]
113 fn mark_response_ready(&self) {
114 self.state.store(RESPONSE_READY, Ordering::Release);
115 }
116
117 /// Complete response reception (state transition: RESPONSE_READY -> IDLE)
118 ///
119 /// 完成响应接收(状态转换:RESPONSE_READY -> IDLE)
120 #[inline]
121 fn complete_response(&self) {
122 self.state.store(IDLE, Ordering::Release);
123 }
124
125 /// Check if side A is closed
126 ///
127 /// 检查 A 方是否已关闭
128 #[inline]
129 fn is_a_closed(&self) -> bool {
130 self.a_closed.load(Ordering::Acquire)
131 }
132
133 /// Check if side B is closed
134 ///
135 /// 检查 B 方是否已关闭
136 #[inline]
137 fn is_b_closed(&self) -> bool {
138 self.b_closed.load(Ordering::Acquire)
139 }
140
141 /// Get current state
142 ///
143 /// 获取当前状态
144 #[inline]
145 fn current_state(&self) -> u8 {
146 self.state.load(Ordering::Acquire)
147 }
148}
149
150/// Side A endpoint (request sender, response receiver)
151///
152/// A 方的 channel 端点(请求发送方,响应接收方)
153pub struct SideA<Req, Resp> {
154 inner: Arc<Inner<Req, Resp>>,
155}
156
157impl<Req, Resp> std::fmt::Debug for SideA<Req, Resp> {
158 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
159 f.debug_struct("SideA")
160 .field("inner", &self.inner)
161 .finish()
162 }
163}
164
165/// Side B endpoint (request receiver, response sender)
166///
167/// B 方的 channel 端点(请求接收方,响应发送方)
168pub struct SideB<Req, Resp> {
169 inner: Arc<Inner<Req, Resp>>,
170}
171
172impl<Req, Resp> std::fmt::Debug for SideB<Req, Resp> {
173 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
174 f.debug_struct("SideB")
175 .field("inner", &self.inner)
176 .finish()
177 }
178}
179
180/// Create a new request-response channel
181///
182/// Returns (SideA, SideB) tuple representing both ends of the channel.
183///
184/// 创建一个新的请求-响应 channel
185///
186/// 返回 (SideA, SideB) 元组,分别代表 channel 的两端。
187///
188/// # Example
189///
190/// ```
191/// use lite_sync::request_response::one_to_one::channel;
192///
193/// # tokio_test::block_on(async {
194/// let (side_a, side_b) = channel::<String, i32>();
195///
196/// // Side B uses convenient handle_request method
197/// tokio::spawn(async move {
198/// while side_b.handle_request(|request| request.len() as i32).await.is_ok() {
199/// // Continue handling requests
200/// }
201/// });
202///
203/// let response = side_a.request("Hello".to_string()).await;
204/// assert_eq!(response, Ok(5));
205/// # });
206/// ```
207///
208/// # Advanced Example: Async Processing
209///
210/// ```
211/// use lite_sync::request_response::one_to_one::channel;
212///
213/// # tokio_test::block_on(async {
214/// let (side_a, side_b) = channel::<String, String>();
215///
216/// tokio::spawn(async move {
217/// while side_b.handle_request_async(|req| async move {
218/// // Async processing logic
219/// req.to_uppercase()
220/// }).await.is_ok() {
221/// // Continue handling
222/// }
223/// });
224///
225/// let result = side_a.request("hello".to_string()).await;
226/// assert_eq!(result, Ok("HELLO".to_string()));
227/// # });
228/// ```
229#[inline]
230pub fn channel<Req, Resp>() -> (SideA<Req, Resp>, SideB<Req, Resp>) {
231 let inner = Arc::new(Inner::new());
232
233 let side_a = SideA {
234 inner: inner.clone(),
235 };
236
237 let side_b = SideB {
238 inner,
239 };
240
241 (side_a, side_b)
242}
243
244impl<Req, Resp> SideA<Req, Resp> {
245 /// Send a request and wait for response
246 ///
247 /// This method will:
248 /// 1. Wait for channel to be idle (if previous request is still being processed)
249 /// 2. Send request to side B
250 /// 3. Wait for side B's response
251 /// 4. Return the response
252 ///
253 /// 发送请求并等待响应
254 ///
255 /// 这个方法会:
256 /// 1. 等待 channel 进入空闲状态(如果之前的请求还在处理中)
257 /// 2. 发送请求到 B 方
258 /// 3. 等待 B 方的响应
259 /// 4. 返回响应
260 ///
261 /// # Returns
262 ///
263 /// - `Ok(response)`: Received response from side B
264 /// - `Err(ChannelError::Closed)`: Side B has been closed
265 ///
266 /// # Example
267 ///
268 /// ```
269 /// # use lite_sync::request_response::one_to_one::channel;
270 /// # tokio_test::block_on(async {
271 /// let (side_a, side_b) = channel::<String, i32>();
272 ///
273 /// tokio::spawn(async move {
274 /// while let Ok(guard) = side_b.recv_request().await {
275 /// let response = guard.request().len() as i32;
276 /// guard.reply(response);
277 /// }
278 /// });
279 ///
280 /// let response = side_a.request("Hello".to_string()).await;
281 /// assert_eq!(response, Ok(5));
282 /// # });
283 /// ```
284 pub async fn request(&self, req: Req) -> Result<Resp, ChannelError> {
285 // Send request
286 self.send_request(req).await?;
287
288 // Wait for response
289 self.recv_response().await
290 }
291
292 /// Send request (without waiting for response)
293 ///
294 /// Will wait if channel is not in idle state.
295 ///
296 /// 发送请求(不等待响应)
297 ///
298 /// 如果 channel 不在空闲状态,会等待直到可以发送。
299 async fn send_request(&self, req: Req) -> Result<(), ChannelError> {
300 SendRequest {
301 inner: &self.inner,
302 request: Some(req),
303 registered: false,
304 }.await
305 }
306
307 /// Wait for and receive response
308 ///
309 /// 等待并接收响应
310 ///
311 /// # Returns
312 ///
313 /// - `Ok(response)`: Received response
314 /// - `Err(ChannelError::Closed)`: Side B closed
315 async fn recv_response(&self) -> Result<Resp, ChannelError> {
316 RecvResponse {
317 inner: &self.inner,
318 registered: false,
319 }.await
320 }
321}
322
323/// Request guard that enforces B must reply
324///
325/// This guard ensures that B must call `reply()` before dropping the guard.
326/// If the guard is dropped without replying, it will panic to prevent A from deadlocking.
327///
328/// 强制 B 必须回复的 Guard
329///
330/// 这个 guard 确保 B 必须在丢弃 guard 之前调用 `reply()`。
331/// 如果 guard 在没有回复的情况下被丢弃,会 panic 以防止 A 死锁。
332pub struct RequestGuard<'a, Req, Resp>
333where
334 Req: Send, Resp: Send,
335{
336 inner: &'a Inner<Req, Resp>,
337 req: Option<Req>,
338 replied: bool,
339}
340
341impl<'a, Req, Resp> std::fmt::Debug for RequestGuard<'a, Req, Resp>
342where
343 Req: Send + std::fmt::Debug,
344 Resp: Send,
345{
346 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
347 f.debug_struct("RequestGuard")
348 .field("req", &self.req)
349 .finish_non_exhaustive()
350 }
351}
352
353// PartialEq for testing purposes - comparing RequestGuards doesn't make sense
354// but we need it for Result<RequestGuard, _> comparisons in tests
355impl<'a, Req, Resp> PartialEq for RequestGuard<'a, Req, Resp>
356where
357 Req: Send + PartialEq,
358 Resp: Send,
359{
360 fn eq(&self, other: &Self) -> bool {
361 // Two guards are equal if they hold the same request value
362 // This is mainly for testing purposes
363 self.req == other.req
364 }
365}
366
367impl<'a, Req, Resp> RequestGuard<'a, Req, Resp>
368where
369 Req: Send, Resp: Send,
370{
371 /// Get a reference to the request
372 ///
373 /// 获取请求内容的引用
374 #[inline]
375 pub fn request(&self) -> &Req {
376 self.req.as_ref().expect("RequestGuard logic error: request already consumed")
377 }
378
379 /// Consume the guard and send reply
380 ///
381 /// This method will:
382 /// 1. Store the response
383 /// 2. Update state to RESPONSE_READY
384 /// 3. Wake up side A
385 ///
386 /// 消耗 Guard 并发送回复
387 ///
388 /// 这个方法会:
389 /// 1. 存储响应
390 /// 2. 更新状态为 RESPONSE_READY
391 /// 3. 唤醒 A 方
392 #[inline]
393 pub fn reply(mut self, resp: Resp) {
394 // Store response
395 // SAFETY: Side B has exclusive access to response storage
396 unsafe {
397 (*self.inner.response.get()).write(resp);
398 }
399
400 // Mark state as RESPONSE_READY
401 self.inner.mark_response_ready();
402
403 // Wake up side A
404 self.inner.a_waker.wake();
405
406 // Mark as replied (prevent Drop panic)
407 self.replied = true;
408 }
409
410}
411
412/// Drop guard: If B drops the guard without calling `reply`, we panic.
413/// This enforces the "must reply" protocol.
414///
415/// Drop 守卫:如果 B 不调用 `reply` 就丢弃了 Guard,我们会 panic。
416/// 这强制执行了 "必须回复" 的协议。
417impl<'a, Req, Resp> Drop for RequestGuard<'a, Req, Resp>
418where
419 Req: Send, Resp: Send,
420{
421 fn drop(&mut self) {
422 if !self.replied {
423 // B dropped the guard without replying
424 // This is a protocol error that would cause A to deadlock
425 // We must panic to prevent this
426 panic!("RequestGuard dropped without replying! This would cause the requester to deadlock. You must call reply() before dropping the guard.");
427 }
428 }
429}
430
431impl<Req, Resp> SideB<Req, Resp> {
432 /// Wait for and receive request, returning a guard that must be replied to
433 ///
434 /// The returned `RequestGuard` enforces that you must call `reply()` on it.
435 /// If you drop the guard without calling `reply()`, it will panic.
436 ///
437 /// 等待并接收请求,返回一个必须回复的 guard
438 ///
439 /// 返回的 `RequestGuard` 强制你必须调用 `reply()`。
440 /// 如果你在没有调用 `reply()` 的情况下丢弃 guard,会 panic。
441 ///
442 /// # Returns
443 ///
444 /// - `Ok(RequestGuard)`: Received request from side A
445 /// - `Err(ChannelError::Closed)`: Side A has been closed
446 ///
447 /// # Example
448 ///
449 /// ```
450 /// # use lite_sync::request_response::one_to_one::channel;
451 /// # tokio_test::block_on(async {
452 /// let (side_a, side_b) = channel::<String, i32>();
453 ///
454 /// tokio::spawn(async move {
455 /// while let Ok(guard) = side_b.recv_request().await {
456 /// let len = guard.request().len() as i32;
457 /// guard.reply(len);
458 /// }
459 /// });
460 ///
461 /// let response = side_a.request("Hello".to_string()).await;
462 /// assert_eq!(response, Ok(5));
463 /// # });
464 /// ```
465 pub async fn recv_request(&self) -> Result<RequestGuard<'_, Req, Resp>, ChannelError>
466 where
467 Req: Send,
468 Resp: Send,
469 {
470 let req = RecvRequest {
471 inner: &self.inner,
472 registered: false,
473 }.await?;
474
475 Ok(RequestGuard {
476 inner: &self.inner,
477 req: Some(req),
478 replied: false,
479 })
480 }
481
482 /// Convenient method to handle request and send response
483 ///
484 /// This method will:
485 /// 1. Wait for and receive request
486 /// 2. Call the handler function
487 /// 3. Send the response via the guard
488 ///
489 /// 处理请求并发送响应的便捷方法
490 ///
491 /// 这个方法会:
492 /// 1. 等待并接收请求
493 /// 2. 调用处理函数
494 /// 3. 通过 guard 发送响应
495 ///
496 /// # Returns
497 ///
498 /// - `Ok(())`: Successfully handled request and sent response
499 /// - `Err(ChannelError::Closed)`: Side A closed
500 ///
501 /// # Example
502 ///
503 /// ```
504 /// # use lite_sync::request_response::one_to_one::channel;
505 /// # tokio_test::block_on(async {
506 /// let (side_a, side_b) = channel::<String, i32>();
507 ///
508 /// tokio::spawn(async move {
509 /// while side_b.handle_request(|req| req.len() as i32).await.is_ok() {
510 /// // Continue handling
511 /// }
512 /// });
513 ///
514 /// let response = side_a.request("Hello".to_string()).await;
515 /// assert_eq!(response, Ok(5));
516 /// # });
517 /// ```
518 pub async fn handle_request<F>(&self, handler: F) -> Result<(), ChannelError>
519 where
520 Req: Send,
521 Resp: Send,
522 F: FnOnce(&Req) -> Resp,
523 {
524 let guard = self.recv_request().await?;
525 let resp = handler(guard.request());
526 guard.reply(resp);
527 Ok(())
528 }
529
530 /// Convenient async method to handle request and send response
531 ///
532 /// Similar to `handle_request`, but supports async handler functions.
533 /// Note: The handler takes ownership of the request to avoid lifetime issues.
534 ///
535 /// 处理请求并发送响应的异步便捷方法
536 ///
537 /// 与 `handle_request` 类似,但支持异步处理函数。
538 /// 注意:处理函数会获取请求的所有权以避免生命周期问题。
539 ///
540 /// # Example
541 ///
542 /// ```
543 /// # use lite_sync::request_response::one_to_one::channel;
544 /// # tokio_test::block_on(async {
545 /// let (side_a, side_b) = channel::<String, String>();
546 ///
547 /// tokio::spawn(async move {
548 /// while side_b.handle_request_async(|req| async move {
549 /// // Async processing - req is owned
550 /// req.to_uppercase()
551 /// }).await.is_ok() {
552 /// // Continue handling
553 /// }
554 /// });
555 ///
556 /// let response = side_a.request("hello".to_string()).await;
557 /// assert_eq!(response, Ok("HELLO".to_string()));
558 /// # });
559 /// ```
560 pub async fn handle_request_async<F, Fut>(&self, handler: F) -> Result<(), ChannelError>
561 where
562 Req: Send,
563 Resp: Send,
564 F: FnOnce(Req) -> Fut,
565 Fut: Future<Output = Resp>,
566 {
567 let mut guard = self.recv_request().await?;
568 let req = guard.req.take().expect("RequestGuard logic error: request already consumed");
569 let resp = handler(req).await;
570
571 // Manually send the reply since we've consumed the request
572 unsafe {
573 (*guard.inner.response.get()).write(resp);
574 }
575 guard.inner.mark_response_ready();
576 guard.inner.a_waker.wake();
577 guard.replied = true;
578
579 Ok(())
580 }
581}
582
583// Drop implementations to clean up wakers
584impl<Req, Resp> Drop for SideA<Req, Resp> {
585 fn drop(&mut self) {
586 // Side A closed, wake up side B that might be waiting
587 self.inner.a_closed.store(true, Ordering::Release);
588 self.inner.b_waker.wake();
589 }
590}
591
592impl<Req, Resp> Drop for SideB<Req, Resp> {
593 fn drop(&mut self) {
594 // Side B closed, wake up side A that might be waiting
595 self.inner.b_closed.store(true, Ordering::Release);
596 self.inner.a_waker.wake();
597 }
598}
599
600/// Future: Side A sends request
601struct SendRequest<'a, Req, Resp> {
602 inner: &'a Inner<Req, Resp>,
603 request: Option<Req>,
604 registered: bool,
605}
606
607// SendRequest is Unpin because we only need to move data, not pin it
608impl<Req, Resp> Unpin for SendRequest<'_, Req, Resp> {}
609
610impl<Req, Resp> Future for SendRequest<'_, Req, Resp> {
611 type Output = Result<(), ChannelError>;
612
613 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
614 let this = self.get_mut();
615
616 // Check if side B is closed
617 if this.inner.is_b_closed() {
618 return Poll::Ready(Err(ChannelError::Closed));
619 }
620
621 // Try to send request
622 if this.inner.try_send_request() {
623 // Successfully sent request
624 // SAFETY: We have exclusive access (guaranteed by state machine)
625 unsafe {
626 (*this.inner.request.get()).write(this.request.take().unwrap());
627 }
628
629 // Wake up side B
630 this.inner.b_waker.wake();
631
632 return Poll::Ready(Ok(()));
633 }
634
635 // Channel busy, register waker and wait
636 if !this.registered {
637 this.inner.a_waker.register(cx.waker());
638 this.registered = true;
639
640 // Check again (avoid race condition)
641 if this.inner.is_b_closed() {
642 return Poll::Ready(Err(ChannelError::Closed));
643 }
644
645 if this.inner.current_state() == IDLE {
646 cx.waker().wake_by_ref();
647 }
648 }
649
650 Poll::Pending
651 }
652}
653
654/// Future: Side A receives response
655struct RecvResponse<'a, Req, Resp> {
656 inner: &'a Inner<Req, Resp>,
657 registered: bool,
658}
659
660// RecvResponse is Unpin because it only holds references and a bool
661impl<Req, Resp> Unpin for RecvResponse<'_, Req, Resp> {}
662
663impl<Req, Resp> Future for RecvResponse<'_, Req, Resp> {
664 type Output = Result<Resp, ChannelError>;
665
666 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
667 // Check if state is RESPONSE_READY first
668 // Try to receive sent response even if B closed
669 if self.inner.current_state() == RESPONSE_READY {
670 // Take the response
671 // SAFETY: Response must exist when state is RESPONSE_READY
672 let response = unsafe {
673 (*self.inner.response.get()).assume_init_read()
674 };
675
676 // Reset to IDLE for next round
677 self.inner.complete_response();
678
679 return Poll::Ready(Ok(response));
680 }
681
682 // Check if B closed (after confirming no response)
683 if self.inner.is_b_closed() {
684 return Poll::Ready(Err(ChannelError::Closed));
685 }
686
687 // Register waker
688 if !self.registered {
689 self.inner.a_waker.register(cx.waker());
690 self.registered = true;
691
692 // Check state again (avoid race condition)
693 if self.inner.current_state() == RESPONSE_READY {
694 let response = unsafe {
695 (*self.inner.response.get()).assume_init_read()
696 };
697 self.inner.complete_response();
698 return Poll::Ready(Ok(response));
699 }
700
701 // Check again if B closed
702 if self.inner.is_b_closed() {
703 return Poll::Ready(Err(ChannelError::Closed));
704 }
705 }
706
707 Poll::Pending
708 }
709}
710
711/// Future: Side B receives request
712struct RecvRequest<'a, Req, Resp> {
713 inner: &'a Inner<Req, Resp>,
714 registered: bool,
715}
716
717// RecvRequest is Unpin because it only holds references and a bool
718impl<Req, Resp> Unpin for RecvRequest<'_, Req, Resp> {}
719
720impl<Req, Resp> Future for RecvRequest<'_, Req, Resp> {
721 type Output = Result<Req, ChannelError>;
722
723 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
724 // Check state first, then access data (fixes race condition)
725 // Try to receive sent request even if A closed
726 if self.inner.current_state() == WAITING_RESPONSE {
727 // Take the request
728 // SAFETY: Request must exist when state is WAITING_RESPONSE
729 let request = unsafe {
730 (*self.inner.request.get()).assume_init_read()
731 };
732
733 return Poll::Ready(Ok(request));
734 }
735
736 // Check if A closed (after confirming no request)
737 if self.inner.is_a_closed() {
738 return Poll::Ready(Err(ChannelError::Closed));
739 }
740
741 // Register waker
742 if !self.registered {
743 self.inner.b_waker.register(cx.waker());
744 self.registered = true;
745
746 // Check again (avoid race condition)
747 if self.inner.current_state() == WAITING_RESPONSE {
748 let request = unsafe {
749 (*self.inner.request.get()).assume_init_read()
750 };
751 return Poll::Ready(Ok(request));
752 }
753
754 // Check again if A closed
755 if self.inner.is_a_closed() {
756 return Poll::Ready(Err(ChannelError::Closed));
757 }
758 }
759
760 Poll::Pending
761 }
762}
763
764
// Unit tests for the one-to-one request-response channel. Side B typically runs
// in a `tokio::spawn`ed task while the test body acts as side A.
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{sleep, Duration};

    /// A single exchange: B checks the request text and replies with 42.
    #[tokio::test]
    async fn test_basic_request_response() {
        let (side_a, side_b) = channel::<String, i32>();

        tokio::spawn(async move {
            let guard = side_b.recv_request().await.unwrap();
            assert_eq!(guard.request(), "Hello");
            guard.reply(42);
        });

        let response = side_a.request("Hello".to_string()).await;
        assert_eq!(response, Ok(42));
    }

    /// Five sequential exchanges reuse the same channel; requests arrive in order.
    #[tokio::test]
    async fn test_multiple_rounds() {
        let (side_a, side_b) = channel::<i32, i32>();

        tokio::spawn(async move {
            for i in 0..5 {
                let guard = side_b.recv_request().await.unwrap();
                assert_eq!(*guard.request(), i);
                guard.reply(i * 2);
            }
        });

        for i in 0..5 {
            let response = side_a.request(i).await;
            assert_eq!(response, Ok(i * 2));
        }
    }

    /// B holds the guard across an `.await` (50 ms sleep) before replying.
    #[tokio::test]
    async fn test_delayed_response() {
        let (side_a, side_b) = channel::<String, String>();

        tokio::spawn(async move {
            let guard = side_b.recv_request().await.unwrap();
            sleep(Duration::from_millis(50)).await;
            let response = guard.request().to_uppercase();
            guard.reply(response);
        });

        let response = side_a.request("hello".to_string()).await;
        assert_eq!(response, Ok("HELLO".to_string()));
    }

    /// Dropping side B first makes side A's `request` fail with `Closed`.
    #[tokio::test]
    async fn test_side_b_closes() {
        let (side_a, side_b) = channel::<i32, i32>();

        // Side B closes immediately
        drop(side_b);

        // Side A should receive Err
        let response = side_a.request(42).await;
        assert_eq!(response, Err(ChannelError::Closed));
    }

    /// Dropping side A first makes side B's `recv_request` fail with `Closed`.
    #[tokio::test]
    async fn test_side_a_closes() {
        let (side_a, side_b) = channel::<i32, i32>();

        // Side A closes immediately
        drop(side_a);

        // Side B should receive Err
        let request = side_b.recv_request().await;
        assert_eq!(request, Err(ChannelError::Closed));
    }

    /// A and B each run in their own spawned task. Requests are still issued one
    /// at a time from A's task. B counts the requests it served and stops when
    /// A is dropped.
    #[tokio::test]
    async fn test_concurrent_requests() {
        let (side_a, side_b) = channel::<i32, i32>();

        let handle_b = tokio::spawn(async move {
            let mut count = 0;
            loop {
                if let Ok(guard) = side_b.recv_request().await {
                    count += 1;
                    let response = *guard.request() * 2;
                    guard.reply(response);
                } else {
                    break;
                }
            }
            count
        });

        let handle_a = tokio::spawn(async move {
            for i in 0..10 {
                let response = side_a.request(i).await.unwrap();
                assert_eq!(response, i * 2);
            }
            drop(side_a);
        });

        handle_a.await.unwrap();
        let count = handle_b.await.unwrap();
        assert_eq!(count, 10);
    }

    /// Owned `String` payloads in both directions (an echo server).
    #[tokio::test]
    async fn test_string_messages() {
        let (side_a, side_b) = channel::<String, String>();

        tokio::spawn(async move {
            loop {
                if let Ok(guard) = side_b.recv_request().await {
                    let response = format!("Echo: {}", guard.request());
                    guard.reply(response);
                } else {
                    break;
                }
            }
        });

        let messages = vec!["Hello", "World", "Rust"];
        for msg in messages {
            let response = side_a.request(msg.to_string()).await.unwrap();
            assert_eq!(response, format!("Echo: {}", msg));
        }
    }

    /// `handle_request` with a synchronous closure.
    #[tokio::test]
    async fn test_handle_request() {
        let (side_a, side_b) = channel::<i32, i32>();

        tokio::spawn(async move {
            // Using handle_request convenience method
            while side_b.handle_request(|req| req * 3).await.is_ok() {
                // Continue handling
            }
        });

        for i in 0..5 {
            let response = side_a.request(i).await.unwrap();
            assert_eq!(response, i * 3);
        }
    }

    /// `handle_request_async` with a handler that awaits before answering.
    #[tokio::test]
    async fn test_handle_request_async() {
        let (side_a, side_b) = channel::<String, usize>();

        tokio::spawn(async move {
            // Using handle_request_async async convenience method
            while side_b.handle_request_async(|req| async move {
                sleep(Duration::from_millis(10)).await;
                req.len()
            }).await.is_ok() {
                // Continue handling
            }
        });

        let test_strings = vec!["Hello", "World", "Rust", "Async"];
        for s in test_strings {
            let response = side_a.request(s.to_string()).await.unwrap();
            assert_eq!(response, s.len());
        }
    }

    /// `ChannelError::Closed` displays as "channel closed".
    #[tokio::test]
    async fn test_error_display() {
        // Test Display implementation for error types
        assert_eq!(format!("{}", ChannelError::Closed), "channel closed");
    }

    /// B keeps state (a counter) across rounds; the loop ends when A is dropped.
    #[tokio::test]
    async fn test_multiple_handle_request_rounds() {
        let (side_a, side_b) = channel::<String, String>();

        let handle = tokio::spawn(async move {
            let mut count = 0;
            // Manually handle to maintain state
            while let Ok(guard) = side_b.recv_request().await {
                count += 1;
                let resp = format!("{}:{}", count, guard.request().to_uppercase());
                guard.reply(resp);
            }
            count
        });

        let response1 = side_a.request("hello".to_string()).await.unwrap();
        assert_eq!(response1, "1:HELLO");

        let response2 = side_a.request("world".to_string()).await.unwrap();
        assert_eq!(response2, "2:WORLD");

        let response3 = side_a.request("rust".to_string()).await.unwrap();
        assert_eq!(response3, "3:RUST");

        drop(side_a);
        let count = handle.await.unwrap();
        assert_eq!(count, 3);
    }

    /// Dropping a `RequestGuard` without replying panics the task that owns it.
    /// The panic payload may be a `String` or a `&str`; both are checked.
    #[tokio::test]
    async fn test_request_guard_must_reply() {
        let (side_a, side_b) = channel::<i32, i32>();

        let handle = tokio::spawn(async move {
            let _guard = side_b.recv_request().await.unwrap();
            // Intentionally not calling reply() - this should panic
        });

        // Send a request
        tokio::spawn(async move {
            let _ = side_a.request(42).await;
        });

        // Wait for the spawned task and verify it panicked
        let result = handle.await;
        assert!(result.is_err(), "Task should have panicked");

        // Verify the panic message contains our expected text
        if let Err(e) = result {
            if let Ok(panic_payload) = e.try_into_panic() {
                if let Some(s) = panic_payload.downcast_ref::<String>() {
                    assert!(s.contains("RequestGuard dropped without replying"),
                        "Panic message should mention RequestGuard: {}", s);
                } else if let Some(s) = panic_payload.downcast_ref::<&str>() {
                    assert!(s.contains("RequestGuard dropped without replying"),
                        "Panic message should mention RequestGuard: {}", s);
                } else {
                    panic!("Unexpected panic type");
                }
            }
        }
    }
}