// lite_sync/request_response/many_to_one.rs
//! Many-to-one bidirectional request-response channel.
//!
//! Optimized for multiple request senders (side A) communicating with a single
//! response handler (side B). Uses a lock-free queue for concurrent request
//! submission.
//!
//! 多对一双向请求-响应通道
//!
//! 为多个请求发送方(A方)与单个响应处理方(B方)通信而优化。
//! 使用无锁队列实现并发请求提交。
10use std::sync::Arc;
11use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
12use std::future::Future;
13use std::pin::Pin;
14use std::task::{Context, Poll};
15use crossbeam_queue::SegQueue;
16
17use crate::oneshot::generic::Sender as OneshotSender;
18use super::common::ChannelError;
19
/// Internal request wrapper pairing the request payload with the oneshot
/// channel that carries the response back to the requester.
struct RequestWrapper<Req, Resp> {
    /// The actual request data, handed to side B inside a `RequestGuard`.
    request: Req,

    /// Oneshot sender used by side B to return the response. If the wrapper
    /// is dropped without sending (e.g. when `SideB::drop` drains the queue),
    /// cleanup is left to the oneshot implementation.
    response_tx: OneshotSender<Resp>,
}
34
/// Shared internal state for the many-to-one channel, owned jointly by all
/// `SideA` clones and the single `SideB` through an `Arc`.
struct Inner<Req, Resp> {
    /// Lock-free queue holding requests that side B has not yet picked up.
    queue: SegQueue<RequestWrapper<Req, Resp>>,

    /// Set to `true` (Release) when side B is dropped; senders load it with
    /// Acquire to fail fast with `ChannelError::Closed`.
    b_closed: AtomicBool,

    /// Number of live `SideA` instances. Starts at 1 for the `SideA` handed
    /// out by `channel()`; once it reaches 0, side B observes `Closed`.
    sender_count: AtomicUsize,

    /// Waker registered by side B while waiting for requests. Senders wake it
    /// after pushing, and the last sender to drop wakes it on shutdown.
    b_waker: crate::atomic_waker::AtomicWaker,
}
59
60impl<Req, Resp> std::fmt::Debug for Inner<Req, Resp> {
61 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
62 let sender_count = self.sender_count.load(Ordering::Acquire);
63 f.debug_struct("Inner")
64 .field("b_closed", &self.is_b_closed())
65 .field("sender_count", &sender_count)
66 .finish()
67 }
68}
69
70impl<Req, Resp> Inner<Req, Resp> {
71 /// Create new shared state
72 ///
73 /// 创建新的共享状态
74 #[inline]
75 fn new() -> Self {
76 Self {
77 queue: SegQueue::new(),
78 b_closed: AtomicBool::new(false),
79 sender_count: AtomicUsize::new(1), // Start with 1 sender
80 b_waker: crate::atomic_waker::AtomicWaker::new(),
81 }
82 }
83
84 /// Check if side B is closed
85 ///
86 /// 检查 B 方是否已关闭
87 #[inline]
88 fn is_b_closed(&self) -> bool {
89 self.b_closed.load(Ordering::Acquire)
90 }
91}
92
/// Side A endpoint (request sender, response receiver).
///
/// Cloneable: each clone bumps the shared sender count, and side B only
/// observes `Closed` after every clone has been dropped.
pub struct SideA<Req, Resp> {
    inner: Arc<Inner<Req, Resp>>,
}
99
100impl<Req, Resp> std::fmt::Debug for SideA<Req, Resp> {
101 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
102 f.debug_struct("SideA")
103 .field("inner", &self.inner)
104 .finish()
105 }
106}
107
108impl<Req, Resp> Clone for SideA<Req, Resp> {
109 fn clone(&self) -> Self {
110 // Increment sender count with Relaxed (reads will use Acquire)
111 self.inner.sender_count.fetch_add(1, Ordering::Relaxed);
112 Self {
113 inner: self.inner.clone(),
114 }
115 }
116}
117
118// Drop implementation for SideA to decrement sender count
119impl<Req, Resp> Drop for SideA<Req, Resp> {
120 fn drop(&mut self) {
121 // Decrement sender count with Release ordering to ensure visibility
122 if self.inner.sender_count.fetch_sub(1, Ordering::Release) == 1 {
123 // This was the last sender, wake up side B
124 self.inner.b_waker.wake();
125 }
126 }
127}
128
/// Side B endpoint (request receiver, response sender).
///
/// Single instance: this type has no `Clone` impl in this module, and
/// dropping it closes the channel for all senders (see its `Drop` impl).
pub struct SideB<Req, Resp> {
    inner: Arc<Inner<Req, Resp>>,
}
135
136impl<Req, Resp> std::fmt::Debug for SideB<Req, Resp> {
137 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
138 f.debug_struct("SideB")
139 .field("inner", &self.inner)
140 .finish()
141 }
142}
143
144/// Create a new many-to-one request-response channel
145///
146/// Returns (SideA, SideB) tuple. SideA can be cloned to create multiple senders.
147///
148/// 创建一个新的多对一请求-响应 channel
149///
150/// 返回 (SideA, SideB) 元组。SideA 可以克隆以创建多个发送方。
151///
152/// # Example
153///
154/// ```
155/// use lite_sync::request_response::many_to_one::channel;
156///
157/// # tokio_test::block_on(async {
158/// let (side_a, side_b) = channel::<String, i32>();
159///
160/// // Clone side_a for multiple senders
161/// let side_a2 = side_a.clone();
162///
163/// // Side B handles requests
164/// tokio::spawn(async move {
165/// while let Ok(guard) = side_b.recv_request().await {
166/// let response = guard.request().len() as i32;
167/// guard.reply(response);
168/// }
169/// });
170///
171/// // Multiple senders can send concurrently
172/// let response1 = side_a.request("Hello".to_string()).await;
173/// let response2 = side_a2.request("World".to_string()).await;
174///
175/// assert_eq!(response1, Ok(5));
176/// assert_eq!(response2, Ok(5));
177/// # });
178/// ```
179#[inline]
180pub fn channel<Req, Resp>() -> (SideA<Req, Resp>, SideB<Req, Resp>) {
181 let inner = Arc::new(Inner::new());
182
183 let side_a = SideA {
184 inner: inner.clone(),
185 };
186
187 let side_b = SideB {
188 inner,
189 };
190
191 (side_a, side_b)
192}
193
194impl<Req, Resp> SideA<Req, Resp> {
195 /// Send a request and wait for response
196 ///
197 /// This method will:
198 /// 1. Push request to the queue
199 /// 2. Wait for side B to process and respond
200 /// 3. Return the response
201 ///
202 /// 发送请求并等待响应
203 ///
204 /// 这个方法会:
205 /// 1. 将请求推入队列
206 /// 2. 等待 B 方处理并响应
207 /// 3. 返回响应
208 ///
209 /// # Returns
210 ///
211 /// - `Ok(response)`: Received response from side B
212 /// - `Err(ChannelError::Closed)`: Side B has been closed
213 ///
214 /// # Example
215 ///
216 /// ```
217 /// # use lite_sync::request_response::many_to_one::channel;
218 /// # tokio_test::block_on(async {
219 /// let (side_a, side_b) = channel::<String, i32>();
220 ///
221 /// tokio::spawn(async move {
222 /// while let Ok(guard) = side_b.recv_request().await {
223 /// let len = guard.request().len() as i32;
224 /// guard.reply(len);
225 /// }
226 /// });
227 ///
228 /// let response = side_a.request("Hello".to_string()).await;
229 /// assert_eq!(response, Ok(5));
230 /// # });
231 /// ```
232 pub async fn request(&self, req: Req) -> Result<Resp, ChannelError> {
233 // Check if B is closed first
234 if self.inner.is_b_closed() {
235 return Err(ChannelError::Closed);
236 }
237
238 // Create oneshot channel for response
239 let (response_tx, response_rx) = OneshotSender::<Resp>::new();
240
241 // Push request to queue
242 self.inner.queue.push(RequestWrapper {
243 request: req,
244 response_tx,
245 });
246
247 // Wake up side B
248 self.inner.b_waker.wake();
249
250 // Wait for response
251 Ok(response_rx.await)
252 }
253
254 /// Try to send a request without waiting for response
255 ///
256 /// Returns a future that will resolve to the response.
257 ///
258 /// 尝试发送请求但不等待响应
259 ///
260 /// 返回一个 future,将解析为响应。
261 pub fn try_request(&self, req: Req) -> Result<impl Future<Output = Result<Resp, ChannelError>>, ChannelError> {
262 // Check if B is closed first
263 if self.inner.is_b_closed() {
264 return Err(ChannelError::Closed);
265 }
266
267 // Create oneshot channel for response
268 let (response_tx, response_rx) = OneshotSender::<Resp>::new();
269
270 // Push request to queue
271 self.inner.queue.push(RequestWrapper {
272 request: req,
273 response_tx,
274 });
275
276 // Wake up side B
277 self.inner.b_waker.wake();
278
279 // Return future that waits for response
280 Ok(async move {
281 Ok(response_rx.await)
282 })
283 }
284}
285
/// Request guard that enforces that side B must reply.
///
/// Side B must call `reply()` before dropping the guard; otherwise the
/// guard's `Drop` impl panics, so a forgotten request is surfaced instead of
/// leaving side A waiting forever.
///
/// The `Send` bounds on the struct mirror the bounds on its `Drop` impl
/// (Rust requires them to match).
pub struct RequestGuard<Req, Resp>
where
    Req: Send, Resp: Send,
{
    // `Some` until consumed; also doubles as the "not yet replied" flag
    // that the Drop impl checks.
    req: Option<Req>,
    // Taken (set to `None`) when the reply is sent.
    response_tx: Option<OneshotSender<Resp>>,
}
302
303impl<Req, Resp> std::fmt::Debug for RequestGuard<Req, Resp>
304where
305 Req: Send + std::fmt::Debug,
306 Resp: Send,
307{
308 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
309 f.debug_struct("RequestGuard")
310 .field("req", &self.req)
311 .finish_non_exhaustive()
312 }
313}
314
315// PartialEq for testing purposes
316impl<Req, Resp> PartialEq for RequestGuard<Req, Resp>
317where
318 Req: Send + PartialEq,
319 Resp: Send,
320{
321 fn eq(&self, other: &Self) -> bool {
322 self.req == other.req
323 }
324}
325
326impl<Req, Resp> RequestGuard<Req, Resp>
327where
328 Req: Send, Resp: Send,
329{
330 /// Get a reference to the request
331 ///
332 /// 获取请求内容的引用
333 #[inline]
334 pub fn request(&self) -> &Req {
335 self.req.as_ref().expect("RequestGuard logic error: request already consumed")
336 }
337
338 /// Consume the guard and send reply
339 ///
340 /// This method will send the response back to the requester.
341 ///
342 /// 消耗 Guard 并发送回复
343 ///
344 /// 这个方法会将响应发送回请求方。
345 #[inline]
346 pub fn reply(mut self, resp: Resp) {
347 if let Some(response_tx) = self.response_tx.take() {
348 let _ = response_tx.send(resp);
349 }
350 // Mark as replied by taking the request
351 self.req = None;
352 }
353}
354
355/// Drop guard: If B drops the guard without calling `reply`, we panic.
356/// This enforces the "must reply" protocol.
357///
358/// Drop 守卫:如果 B 不调用 `reply` 就丢弃了 Guard,我们会 panic。
359/// 这强制执行了 "必须回复" 的协议。
360impl<Req, Resp> Drop for RequestGuard<Req, Resp>
361where
362 Req: Send, Resp: Send,
363{
364 fn drop(&mut self) {
365 if self.req.is_some() {
366 // B dropped the guard without replying
367 // This is a protocol error that would cause A to deadlock
368 // We must panic to prevent this
369 panic!("RequestGuard dropped without replying! This would cause the requester to deadlock. You must call reply() before dropping the guard.");
370 }
371 }
372}
373
374impl<Req, Resp> SideB<Req, Resp> {
375 /// Wait for and receive next request, returning a guard that must be replied to
376 ///
377 /// The returned `RequestGuard` enforces that you must call `reply()` on it.
378 /// If you drop the guard without calling `reply()`, it will panic.
379 ///
380 /// 等待并接收下一个请求,返回一个必须回复的 guard
381 ///
382 /// 返回的 `RequestGuard` 强制你必须调用 `reply()`。
383 /// 如果你在没有调用 `reply()` 的情况下丢弃 guard,会 panic。
384 ///
385 /// # Returns
386 ///
387 /// - `Ok(RequestGuard)`: Received request from a side A
388 /// - `Err(ChannelError::Closed)`: All side A instances have been closed
389 ///
390 /// # Example
391 ///
392 /// ```
393 /// # use lite_sync::request_response::many_to_one::channel;
394 /// # tokio_test::block_on(async {
395 /// let (side_a, side_b) = channel::<String, i32>();
396 ///
397 /// tokio::spawn(async move {
398 /// while let Ok(guard) = side_b.recv_request().await {
399 /// let len = guard.request().len() as i32;
400 /// guard.reply(len);
401 /// }
402 /// });
403 ///
404 /// let response = side_a.request("Hello".to_string()).await;
405 /// assert_eq!(response, Ok(5));
406 /// # });
407 /// ```
408 pub async fn recv_request(&self) -> Result<RequestGuard<Req, Resp>, ChannelError>
409 where
410 Req: Send,
411 Resp: Send,
412 {
413 RecvRequest {
414 inner: &self.inner,
415 registered: false,
416 }.await
417 }
418
419 /// Convenient method to handle request and send response
420 ///
421 /// This method will:
422 /// 1. Wait for and receive request
423 /// 2. Call the handler function
424 /// 3. Send the response via the guard
425 ///
426 /// 处理请求并发送响应的便捷方法
427 ///
428 /// 这个方法会:
429 /// 1. 等待并接收请求
430 /// 2. 调用处理函数
431 /// 3. 通过 guard 发送响应
432 ///
433 /// # Example
434 ///
435 /// ```
436 /// # use lite_sync::request_response::many_to_one::channel;
437 /// # tokio_test::block_on(async {
438 /// let (side_a, side_b) = channel::<String, i32>();
439 ///
440 /// tokio::spawn(async move {
441 /// while side_b.handle_request(|req| req.len() as i32).await.is_ok() {
442 /// // Continue handling
443 /// }
444 /// });
445 ///
446 /// let response = side_a.request("Hello".to_string()).await;
447 /// assert_eq!(response, Ok(5));
448 /// # });
449 /// ```
450 pub async fn handle_request<F>(&self, handler: F) -> Result<(), ChannelError>
451 where
452 Req: Send,
453 Resp: Send,
454 F: FnOnce(&Req) -> Resp,
455 {
456 let guard = self.recv_request().await?;
457 let resp = handler(guard.request());
458 guard.reply(resp);
459 Ok(())
460 }
461
462 /// Convenient async method to handle request and send response
463 ///
464 /// Similar to `handle_request`, but supports async handler functions.
465 /// Note: The handler takes ownership of the request to avoid lifetime issues.
466 ///
467 /// 处理请求并发送响应的异步便捷方法
468 ///
469 /// 与 `handle_request` 类似,但支持异步处理函数。
470 /// 注意:处理函数会获取请求的所有权以避免生命周期问题。
471 ///
472 /// # Example
473 ///
474 /// ```
475 /// # use lite_sync::request_response::many_to_one::channel;
476 /// # tokio_test::block_on(async {
477 /// let (side_a, side_b) = channel::<String, String>();
478 ///
479 /// tokio::spawn(async move {
480 /// while side_b.handle_request_async(|req| async move {
481 /// // Async processing - req is owned
482 /// req.to_uppercase()
483 /// }).await.is_ok() {
484 /// // Continue handling
485 /// }
486 /// });
487 ///
488 /// let response = side_a.request("hello".to_string()).await;
489 /// assert_eq!(response, Ok("HELLO".to_string()));
490 /// # });
491 /// ```
492 pub async fn handle_request_async<F, Fut>(&self, handler: F) -> Result<(), ChannelError>
493 where
494 Req: Send,
495 Resp: Send,
496 F: FnOnce(Req) -> Fut,
497 Fut: Future<Output = Resp>,
498 {
499 let mut guard = self.recv_request().await?;
500 let req = guard.req.take().expect("RequestGuard logic error: request already consumed");
501 let resp = handler(req).await;
502
503 // Manually send the reply since we've consumed the request
504 if let Some(response_tx) = guard.response_tx.take() {
505 let _ = response_tx.send(resp);
506 }
507 // Mark as replied
508 guard.req = None;
509
510 Ok(())
511 }
512}
513
514// Drop implementation to clean up
515impl<Req, Resp> Drop for SideB<Req, Resp> {
516 fn drop(&mut self) {
517 // Side B closed, notify any waiting senders
518 self.inner.b_closed.store(true, Ordering::Release);
519
520 // Drain queue and drop all pending response channels
521 // The generic oneshot will handle cleanup automatically
522 while let Some(_wrapper) = self.inner.queue.pop() {
523 // Just drop the wrapper, oneshot cleanup is automatic
524 }
525 }
526}
527
/// Future returned by `SideB::recv_request`: resolves to the next queued
/// request (wrapped in a `RequestGuard`) or to `ChannelError::Closed` once
/// all senders are gone.
struct RecvRequest<'a, Req, Resp> {
    inner: &'a Inner<Req, Resp>,
    // Whether the B-side waker has been registered by a previous poll.
    registered: bool,
}

// RecvRequest is Unpin because it only holds a shared reference and a bool;
// no self-referential state is created across polls.
impl<Req, Resp> Unpin for RecvRequest<'_, Req, Resp> {}
536
537impl<Req, Resp> Future for RecvRequest<'_, Req, Resp>
538where
539 Req: Send,
540 Resp: Send,
541{
542 type Output = Result<RequestGuard<Req, Resp>, ChannelError>;
543
544 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
545 // Try to pop from queue
546 if let Some(wrapper) = self.inner.queue.pop() {
547 return Poll::Ready(Ok(RequestGuard {
548 req: Some(wrapper.request),
549 response_tx: Some(wrapper.response_tx),
550 }));
551 }
552
553 // Check if there are any senders left
554 if self.inner.sender_count.load(Ordering::Acquire) == 0 {
555 return Poll::Ready(Err(ChannelError::Closed));
556 }
557
558 // Register waker if not already registered
559 if !self.registered {
560 self.inner.b_waker.register(cx.waker());
561 self.registered = true;
562 }
563
564 // Always check queue and sender_count again before returning Pending
565 // This is critical to avoid deadlock when senders drop after waker is registered
566 if let Some(wrapper) = self.inner.queue.pop() {
567 return Poll::Ready(Ok(RequestGuard {
568 req: Some(wrapper.request),
569 response_tx: Some(wrapper.response_tx),
570 }));
571 }
572
573 // Final check if there are any senders
574 if self.inner.sender_count.load(Ordering::Acquire) == 0 {
575 return Poll::Ready(Err(ChannelError::Closed));
576 }
577
578 Poll::Pending
579 }
580}
581
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{sleep, Duration};

    // Smoke test: one sender, one handler, a single round trip.
    #[tokio::test]
    async fn test_basic_many_to_one() {
        let (side_a, side_b) = channel::<String, i32>();

        tokio::spawn(async move {
            while let Ok(guard) = side_b.recv_request().await {
                let response = guard.request().len() as i32;
                guard.reply(response);
            }
        });

        let response = side_a.request("Hello".to_string()).await;
        assert_eq!(response, Ok(5));
    }

    // Three cloned senders issue disjoint ranges concurrently; the single
    // handler doubles each value and every task verifies its own sum.
    #[tokio::test]
    async fn test_multiple_senders() {
        let (side_a, side_b) = channel::<i32, i32>();
        let side_a2 = side_a.clone();
        let side_a3 = side_a.clone();

        tokio::spawn(async move {
            while let Ok(guard) = side_b.recv_request().await {
                let result = *guard.request() * 2;
                guard.reply(result);
            }
        });

        let handle1 = tokio::spawn(async move {
            let mut sum = 0;
            for i in 0..10 {
                let resp = side_a.request(i).await.unwrap();
                sum += resp;
            }
            sum
        });

        let handle2 = tokio::spawn(async move {
            let mut sum = 0;
            for i in 10..20 {
                let resp = side_a2.request(i).await.unwrap();
                sum += resp;
            }
            sum
        });

        let handle3 = tokio::spawn(async move {
            let mut sum = 0;
            for i in 20..30 {
                let resp = side_a3.request(i).await.unwrap();
                sum += resp;
            }
            sum
        });

        let sum1 = handle1.await.unwrap();
        let sum2 = handle2.await.unwrap();
        let sum3 = handle3.await.unwrap();

        // Each range should give: sum(i*2) = 2 * sum(i)
        assert_eq!(sum1, 2 * (0..10).sum::<i32>());
        assert_eq!(sum2, 2 * (10..20).sum::<i32>());
        assert_eq!(sum3, 2 * (20..30).sum::<i32>());
    }

    // NOTE(review): despite the name, this test covers side *A* closing —
    // side B must then get Err from recv_request. Consider renaming to
    // something like `test_single_side_a_closes`.
    #[tokio::test]
    async fn test_side_b_closes() {
        let (side_a, side_b) = channel::<i32, i32>();

        // Side A closes immediately
        drop(side_a);

        // Side B should receive Err
        let request = side_b.recv_request().await;
        assert!(request.is_err());
    }

    // recv_request errors only once *all* SideA clones are dropped.
    #[tokio::test]
    async fn test_all_side_a_close() {
        let (side_a, side_b) = channel::<i32, i32>();
        let side_a2 = side_a.clone();

        // All side A instances close
        drop(side_a);
        drop(side_a2);

        // Side B should receive Err
        let request = side_b.recv_request().await;
        assert!(request.is_err());
    }

    // Covers the synchronous convenience wrapper `handle_request`.
    #[tokio::test]
    async fn test_handle_request() {
        let (side_a, side_b) = channel::<i32, i32>();

        tokio::spawn(async move {
            while side_b.handle_request(|req| req * 3).await.is_ok() {
                // Continue handling
            }
        });

        for i in 0..5 {
            let response = side_a.request(i).await.unwrap();
            assert_eq!(response, i * 3);
        }
    }

    // Covers `handle_request_async` with an await inside the handler.
    #[tokio::test]
    async fn test_handle_request_async() {
        let (side_a, side_b) = channel::<String, usize>();

        tokio::spawn(async move {
            while side_b.handle_request_async(|req| async move {
                sleep(Duration::from_millis(10)).await;
                req.len()
            }).await.is_ok() {
                // Continue handling
            }
        });

        let test_strings = vec!["Hello", "World", "Rust"];
        for s in test_strings {
            let response = side_a.request(s.to_string()).await.unwrap();
            assert_eq!(response, s.len());
        }
    }

    // Ten concurrent requesters against a single async handler.
    #[tokio::test]
    async fn test_concurrent_requests() {
        let (side_a, side_b) = channel::<String, String>();

        tokio::spawn(async move {
            while side_b.handle_request_async(|req| async move {
                sleep(Duration::from_millis(5)).await;
                req.to_uppercase()
            }).await.is_ok() {
                // Continue
            }
        });

        // Send multiple requests concurrently
        let mut handles = vec![];
        for i in 0..10 {
            let side_a_clone = side_a.clone();
            let handle = tokio::spawn(async move {
                let msg = format!("message{}", i);
                let resp = side_a_clone.request(msg.clone()).await.unwrap();
                assert_eq!(resp, msg.to_uppercase());
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.await.unwrap();
        }
    }

    // Dropping a RequestGuard without replying must panic (protocol check);
    // the panic payload is inspected to make sure it is the guard's message.
    #[tokio::test]
    async fn test_request_guard_must_reply() {
        let (side_a, side_b) = channel::<i32, i32>();

        let handle = tokio::spawn(async move {
            let _guard = side_b.recv_request().await.unwrap();
            // Intentionally not calling reply() - this should panic
        });

        // Send a request
        tokio::spawn(async move {
            let _ = side_a.request(42).await;
        });

        // Wait for the spawned task and verify it panicked
        let result = handle.await;
        assert!(result.is_err(), "Task should have panicked");

        // Verify the panic message contains our expected text
        if let Err(e) = result {
            if let Ok(panic_payload) = e.try_into_panic() {
                if let Some(s) = panic_payload.downcast_ref::<String>() {
                    assert!(s.contains("RequestGuard dropped without replying"),
                        "Panic message should mention RequestGuard: {}", s);
                } else if let Some(s) = panic_payload.downcast_ref::<&str>() {
                    assert!(s.contains("RequestGuard dropped without replying"),
                        "Panic message should mention RequestGuard: {}", s);
                } else {
                    panic!("Unexpected panic type");
                }
            }
        }
    }
}
777}