lite_sync/request_response/
many_to_one.rs

//! Many-to-one bidirectional request-response channel.
//!
//! Optimized for multiple request senders (side A) communicating with a single
//! response handler (side B). Uses a lock-free queue for concurrent request submission.
//!
//! 多对一双向请求-响应通道
//!
//! 为多个请求发送方(A方)与单个响应处理方(B方)通信而优化。
//! 使用无锁队列实现并发请求提交。
10use std::sync::Arc;
11use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
12use std::future::Future;
13use std::pin::Pin;
14use std::task::{Context, Poll};
15use crossbeam_queue::SegQueue;
16
17use crate::oneshot::generic::Sender as OneshotSender;
18use super::common::ChannelError;
19
20/// Internal request wrapper containing request data and response channel
21/// 
22/// 内部请求包装器,包含请求数据和响应通道
23struct RequestWrapper<Req, Resp> {
24    /// The actual request data
25    /// 
26    /// 实际的请求数据
27    request: Req,
28    
29    /// Oneshot sender to return the response
30    /// 
31    /// 用于返回响应的 oneshot sender
32    response_tx: OneshotSender<Resp>,
33}
34
35/// Shared internal state for many-to-one channel
36/// 
37/// 多对一通道的共享内部状态
38struct Inner<Req, Resp> {
39    /// Lock-free queue for pending requests
40    /// 
41    /// 待处理请求的无锁队列
42    queue: SegQueue<RequestWrapper<Req, Resp>>,
43    
44    /// Whether side B (receiver) is closed
45    /// 
46    /// B 方(接收方)是否已关闭
47    b_closed: AtomicBool,
48    
49    /// Number of active SideA instances
50    /// 
51    /// 活跃的 SideA 实例数量
52    sender_count: AtomicUsize,
53    
54    /// Waker for side B waiting for requests
55    /// 
56    /// B 方等待请求的 waker
57    b_waker: crate::atomic_waker::AtomicWaker,
58}
59
60impl<Req, Resp> Inner<Req, Resp> {
61    /// Create new shared state
62    /// 
63    /// 创建新的共享状态
64    #[inline]
65    fn new() -> Self {
66        Self {
67            queue: SegQueue::new(),
68            b_closed: AtomicBool::new(false),
69            sender_count: AtomicUsize::new(1), // Start with 1 sender
70            b_waker: crate::atomic_waker::AtomicWaker::new(),
71        }
72    }
73    
74    /// Check if side B is closed
75    /// 
76    /// 检查 B 方是否已关闭
77    #[inline]
78    fn is_b_closed(&self) -> bool {
79        self.b_closed.load(Ordering::Acquire)
80    }
81}
82
83/// Side A endpoint (request sender, response receiver) - can be cloned
84/// 
85/// A 方的 channel 端点(请求发送方,响应接收方)- 可以克隆
86pub struct SideA<Req, Resp> {
87    inner: Arc<Inner<Req, Resp>>,
88}
89
90impl<Req, Resp> Clone for SideA<Req, Resp> {
91    fn clone(&self) -> Self {
92        // Increment sender count with Relaxed (reads will use Acquire)
93        self.inner.sender_count.fetch_add(1, Ordering::Relaxed);
94        Self {
95            inner: self.inner.clone(),
96        }
97    }
98}
99
100// Drop implementation for SideA to decrement sender count
101impl<Req, Resp> Drop for SideA<Req, Resp> {
102    fn drop(&mut self) {
103        // Decrement sender count with Release ordering to ensure visibility
104        if self.inner.sender_count.fetch_sub(1, Ordering::Release) == 1 {
105            // This was the last sender, wake up side B
106            self.inner.b_waker.wake();
107        }
108    }
109}
110
111/// Side B endpoint (request receiver, response sender) - single instance
112/// 
113/// B 方的 channel 端点(请求接收方,响应发送方)- 单实例
114pub struct SideB<Req, Resp> {
115    inner: Arc<Inner<Req, Resp>>,
116}
117
118/// Create a new many-to-one request-response channel
119/// 
120/// Returns (SideA, SideB) tuple. SideA can be cloned to create multiple senders.
121/// 
122/// 创建一个新的多对一请求-响应 channel
123/// 
124/// 返回 (SideA, SideB) 元组。SideA 可以克隆以创建多个发送方。
125/// 
126/// # Example
127/// 
128/// ```
129/// use lite_sync::request_response::many_to_one::channel;
130/// 
131/// # tokio_test::block_on(async {
132/// let (side_a, side_b) = channel::<String, i32>();
133/// 
134/// // Clone side_a for multiple senders
135/// let side_a2 = side_a.clone();
136/// 
137/// // Side B handles requests
138/// tokio::spawn(async move {
139///     while let Ok(guard) = side_b.recv_request().await {
140///         let response = guard.request().len() as i32;
141///         guard.reply(response);
142///     }
143/// });
144/// 
145/// // Multiple senders can send concurrently
146/// let response1 = side_a.request("Hello".to_string()).await;
147/// let response2 = side_a2.request("World".to_string()).await;
148/// 
149/// assert_eq!(response1, Ok(5));
150/// assert_eq!(response2, Ok(5));
151/// # });
152/// ```
153#[inline]
154pub fn channel<Req, Resp>() -> (SideA<Req, Resp>, SideB<Req, Resp>) {
155    let inner = Arc::new(Inner::new());
156    
157    let side_a = SideA {
158        inner: inner.clone(),
159    };
160    
161    let side_b = SideB {
162        inner,
163    };
164    
165    (side_a, side_b)
166}
167
168impl<Req, Resp> SideA<Req, Resp> {
169    /// Send a request and wait for response
170    /// 
171    /// This method will:
172    /// 1. Push request to the queue
173    /// 2. Wait for side B to process and respond
174    /// 3. Return the response
175    /// 
176    /// 发送请求并等待响应
177    /// 
178    /// 这个方法会:
179    /// 1. 将请求推入队列
180    /// 2. 等待 B 方处理并响应
181    /// 3. 返回响应
182    /// 
183    /// # Returns
184    /// 
185    /// - `Ok(response)`: Received response from side B
186    /// - `Err(ChannelError::Closed)`: Side B has been closed
187    /// 
188    /// # Example
189    /// 
190    /// ```
191    /// # use lite_sync::request_response::many_to_one::channel;
192    /// # tokio_test::block_on(async {
193    /// let (side_a, side_b) = channel::<String, i32>();
194    /// 
195    /// tokio::spawn(async move {
196    ///     while let Ok(guard) = side_b.recv_request().await {
197    ///         let len = guard.request().len() as i32;
198    ///         guard.reply(len);
199    ///     }
200    /// });
201    /// 
202    /// let response = side_a.request("Hello".to_string()).await;
203    /// assert_eq!(response, Ok(5));
204    /// # });
205    /// ```
206    pub async fn request(&self, req: Req) -> Result<Resp, ChannelError> {
207        // Check if B is closed first
208        if self.inner.is_b_closed() {
209            return Err(ChannelError::Closed);
210        }
211        
212        // Create oneshot channel for response
213        let (response_tx, response_rx) = OneshotSender::<Resp>::new();
214        
215        // Push request to queue
216        self.inner.queue.push(RequestWrapper {
217            request: req,
218            response_tx,
219        });
220        
221        // Wake up side B
222        self.inner.b_waker.wake();
223        
224        // Wait for response
225        Ok(response_rx.await)
226    }
227    
228    /// Try to send a request without waiting for response
229    /// 
230    /// Returns a future that will resolve to the response.
231    /// 
232    /// 尝试发送请求但不等待响应
233    /// 
234    /// 返回一个 future,将解析为响应。
235    pub fn try_request(&self, req: Req) -> Result<impl Future<Output = Result<Resp, ChannelError>>, ChannelError> {
236        // Check if B is closed first
237        if self.inner.is_b_closed() {
238            return Err(ChannelError::Closed);
239        }
240        
241        // Create oneshot channel for response
242        let (response_tx, response_rx) = OneshotSender::<Resp>::new();
243        
244        // Push request to queue
245        self.inner.queue.push(RequestWrapper {
246            request: req,
247            response_tx,
248        });
249        
250        // Wake up side B
251        self.inner.b_waker.wake();
252        
253        // Return future that waits for response
254        Ok(async move {
255            Ok(response_rx.await)
256        })
257    }
258}
259
260/// Request guard that enforces B must reply
261/// 
262/// This guard ensures that B must call `reply()` before dropping the guard.
263/// If the guard is dropped without replying, it will panic to prevent A from deadlocking.
264/// 
265/// 强制 B 必须回复的 Guard
266/// 
267/// 这个 guard 确保 B 必须在丢弃 guard 之前调用 `reply()`。
268/// 如果 guard 在没有回复的情况下被丢弃,会 panic 以防止 A 死锁。
269pub struct RequestGuard<Req, Resp>
270where
271    Req: Send, Resp: Send,
272{
273    req: Option<Req>,
274    response_tx: Option<OneshotSender<Resp>>,
275}
276
277impl<Req, Resp> std::fmt::Debug for RequestGuard<Req, Resp>
278where
279    Req: Send + std::fmt::Debug,
280    Resp: Send,
281{
282    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
283        f.debug_struct("RequestGuard")
284            .field("req", &self.req)
285            .finish_non_exhaustive()
286    }
287}
288
289// PartialEq for testing purposes
290impl<Req, Resp> PartialEq for RequestGuard<Req, Resp>
291where
292    Req: Send + PartialEq,
293    Resp: Send,
294{
295    fn eq(&self, other: &Self) -> bool {
296        self.req == other.req
297    }
298}
299
300impl<Req, Resp> RequestGuard<Req, Resp>
301where
302    Req: Send, Resp: Send,
303{
304    /// Get a reference to the request
305    /// 
306    /// 获取请求内容的引用
307    #[inline]
308    pub fn request(&self) -> &Req {
309        self.req.as_ref().expect("RequestGuard logic error: request already consumed")
310    }
311    
312    /// Consume the guard and send reply
313    /// 
314    /// This method will send the response back to the requester.
315    /// 
316    /// 消耗 Guard 并发送回复
317    /// 
318    /// 这个方法会将响应发送回请求方。
319    #[inline]
320    pub fn reply(mut self, resp: Resp) {
321        if let Some(response_tx) = self.response_tx.take() {
322            let _ = response_tx.send(resp);
323        }
324        // Mark as replied by taking the request
325        self.req = None;
326    }
327}
328
329/// Drop guard: If B drops the guard without calling `reply`, we panic.
330/// This enforces the "must reply" protocol.
331/// 
332/// Drop 守卫:如果 B 不调用 `reply` 就丢弃了 Guard,我们会 panic。
333/// 这强制执行了 "必须回复" 的协议。
334impl<Req, Resp> Drop for RequestGuard<Req, Resp>
335where
336    Req: Send, Resp: Send,
337{
338    fn drop(&mut self) {
339        if self.req.is_some() {
340            // B dropped the guard without replying
341            // This is a protocol error that would cause A to deadlock
342            // We must panic to prevent this
343            panic!("RequestGuard dropped without replying! This would cause the requester to deadlock. You must call reply() before dropping the guard.");
344        }
345    }
346}
347
348impl<Req, Resp> SideB<Req, Resp> {
349    /// Wait for and receive next request, returning a guard that must be replied to
350    /// 
351    /// The returned `RequestGuard` enforces that you must call `reply()` on it.
352    /// If you drop the guard without calling `reply()`, it will panic.
353    /// 
354    /// 等待并接收下一个请求,返回一个必须回复的 guard
355    /// 
356    /// 返回的 `RequestGuard` 强制你必须调用 `reply()`。
357    /// 如果你在没有调用 `reply()` 的情况下丢弃 guard,会 panic。
358    /// 
359    /// # Returns
360    /// 
361    /// - `Ok(RequestGuard)`: Received request from a side A
362    /// - `Err(ChannelError::Closed)`: All side A instances have been closed
363    /// 
364    /// # Example
365    /// 
366    /// ```
367    /// # use lite_sync::request_response::many_to_one::channel;
368    /// # tokio_test::block_on(async {
369    /// let (side_a, side_b) = channel::<String, i32>();
370    /// 
371    /// tokio::spawn(async move {
372    ///     while let Ok(guard) = side_b.recv_request().await {
373    ///         let len = guard.request().len() as i32;
374    ///         guard.reply(len);
375    ///     }
376    /// });
377    /// 
378    /// let response = side_a.request("Hello".to_string()).await;
379    /// assert_eq!(response, Ok(5));
380    /// # });
381    /// ```
382    pub async fn recv_request(&self) -> Result<RequestGuard<Req, Resp>, ChannelError>
383    where
384        Req: Send,
385        Resp: Send,
386    {
387        RecvRequest {
388            inner: &self.inner,
389            registered: false,
390        }.await
391    }
392    
393    /// Convenient method to handle request and send response
394    /// 
395    /// This method will:
396    /// 1. Wait for and receive request
397    /// 2. Call the handler function
398    /// 3. Send the response via the guard
399    /// 
400    /// 处理请求并发送响应的便捷方法
401    /// 
402    /// 这个方法会:
403    /// 1. 等待并接收请求
404    /// 2. 调用处理函数
405    /// 3. 通过 guard 发送响应
406    /// 
407    /// # Example
408    /// 
409    /// ```
410    /// # use lite_sync::request_response::many_to_one::channel;
411    /// # tokio_test::block_on(async {
412    /// let (side_a, side_b) = channel::<String, i32>();
413    /// 
414    /// tokio::spawn(async move {
415    ///     while side_b.handle_request(|req| req.len() as i32).await.is_ok() {
416    ///         // Continue handling
417    ///     }
418    /// });
419    /// 
420    /// let response = side_a.request("Hello".to_string()).await;
421    /// assert_eq!(response, Ok(5));
422    /// # });
423    /// ```
424    pub async fn handle_request<F>(&self, handler: F) -> Result<(), ChannelError>
425    where
426        Req: Send,
427        Resp: Send,
428        F: FnOnce(&Req) -> Resp,
429    {
430        let guard = self.recv_request().await?;
431        let resp = handler(guard.request());
432        guard.reply(resp);
433        Ok(())
434    }
435    
436    /// Convenient async method to handle request and send response
437    /// 
438    /// Similar to `handle_request`, but supports async handler functions.
439    /// Note: The handler takes ownership of the request to avoid lifetime issues.
440    /// 
441    /// 处理请求并发送响应的异步便捷方法
442    /// 
443    /// 与 `handle_request` 类似,但支持异步处理函数。
444    /// 注意:处理函数会获取请求的所有权以避免生命周期问题。
445    /// 
446    /// # Example
447    /// 
448    /// ```
449    /// # use lite_sync::request_response::many_to_one::channel;
450    /// # tokio_test::block_on(async {
451    /// let (side_a, side_b) = channel::<String, String>();
452    /// 
453    /// tokio::spawn(async move {
454    ///     while side_b.handle_request_async(|req| async move {
455    ///         // Async processing - req is owned
456    ///         req.to_uppercase()
457    ///     }).await.is_ok() {
458    ///         // Continue handling
459    ///     }
460    /// });
461    /// 
462    /// let response = side_a.request("hello".to_string()).await;
463    /// assert_eq!(response, Ok("HELLO".to_string()));
464    /// # });
465    /// ```
466    pub async fn handle_request_async<F, Fut>(&self, handler: F) -> Result<(), ChannelError>
467    where
468        Req: Send,
469        Resp: Send,
470        F: FnOnce(Req) -> Fut,
471        Fut: Future<Output = Resp>,
472    {
473        let mut guard = self.recv_request().await?;
474        let req = guard.req.take().expect("RequestGuard logic error: request already consumed");
475        let resp = handler(req).await;
476        
477        // Manually send the reply since we've consumed the request
478        if let Some(response_tx) = guard.response_tx.take() {
479            let _ = response_tx.send(resp);
480        }
481        // Mark as replied
482        guard.req = None;
483        
484        Ok(())
485    }
486}
487
488// Drop implementation to clean up
489impl<Req, Resp> Drop for SideB<Req, Resp> {
490    fn drop(&mut self) {
491        // Side B closed, notify any waiting senders
492        self.inner.b_closed.store(true, Ordering::Release);
493        
494        // Drain queue and drop all pending response channels
495        // The generic oneshot will handle cleanup automatically
496        while let Some(_wrapper) = self.inner.queue.pop() {
497            // Just drop the wrapper, oneshot cleanup is automatic
498        }
499    }
500}
501
502/// Future: Side B receives request
503struct RecvRequest<'a, Req, Resp> {
504    inner: &'a Inner<Req, Resp>,
505    registered: bool,
506}
507
508// RecvRequest is Unpin because it only holds references and a bool
509impl<Req, Resp> Unpin for RecvRequest<'_, Req, Resp> {}
510
511impl<Req, Resp> Future for RecvRequest<'_, Req, Resp>
512where
513    Req: Send,
514    Resp: Send,
515{
516    type Output = Result<RequestGuard<Req, Resp>, ChannelError>;
517    
518    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
519        // Try to pop from queue
520        if let Some(wrapper) = self.inner.queue.pop() {
521            return Poll::Ready(Ok(RequestGuard {
522                req: Some(wrapper.request),
523                response_tx: Some(wrapper.response_tx),
524            }));
525        }
526        
527        // Check if there are any senders left
528        if self.inner.sender_count.load(Ordering::Acquire) == 0 {
529            return Poll::Ready(Err(ChannelError::Closed));
530        }
531        
532        // Register waker if not already registered
533        if !self.registered {
534            self.inner.b_waker.register(cx.waker());
535            self.registered = true;
536        }
537        
538        // Always check queue and sender_count again before returning Pending
539        // This is critical to avoid deadlock when senders drop after waker is registered
540        if let Some(wrapper) = self.inner.queue.pop() {
541            return Poll::Ready(Ok(RequestGuard {
542                req: Some(wrapper.request),
543                response_tx: Some(wrapper.response_tx),
544            }));
545        }
546        
547        // Final check if there are any senders
548        if self.inner.sender_count.load(Ordering::Acquire) == 0 {
549            return Poll::Ready(Err(ChannelError::Closed));
550        }
551        
552        Poll::Pending
553    }
554}
555
#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{sleep, Duration};

    /// A single sender gets the answer computed by side B.
    #[tokio::test]
    async fn test_basic_many_to_one() {
        let (side_a, side_b) = channel::<String, i32>();

        tokio::spawn(async move {
            while let Ok(guard) = side_b.recv_request().await {
                let response = guard.request().len() as i32;
                guard.reply(response);
            }
        });

        let response = side_a.request("Hello".to_string()).await;
        assert_eq!(response, Ok(5));
    }

    /// Three senders issue requests concurrently; each must see its own answers.
    #[tokio::test]
    async fn test_multiple_senders() {
        let (side_a, side_b) = channel::<i32, i32>();

        tokio::spawn(async move {
            while let Ok(guard) = side_b.recv_request().await {
                let result = *guard.request() * 2;
                guard.reply(result);
            }
        });

        let mut workers = Vec::new();
        for range in vec![0..10, 10..20, 20..30] {
            let sender = side_a.clone();
            let inputs = range.clone();
            let worker = tokio::spawn(async move {
                let mut sum = 0;
                for i in inputs {
                    sum += sender.request(i).await.unwrap();
                }
                sum
            });
            workers.push((range, worker));
        }

        // Each range should give: sum(i*2) = 2 * sum(i)
        for (range, worker) in workers {
            assert_eq!(worker.await.unwrap(), 2 * range.sum::<i32>());
        }
    }

    /// With the only sender dropped, side B must report the channel as closed.
    #[tokio::test]
    async fn test_recv_after_side_a_dropped() {
        let (side_a, side_b) = channel::<i32, i32>();

        // Side A closes immediately
        drop(side_a);

        let request = side_b.recv_request().await;
        assert!(matches!(request, Err(ChannelError::Closed)));
    }

    /// The channel closes only after every clone of side A is gone.
    #[tokio::test]
    async fn test_all_side_a_close() {
        let (side_a, side_b) = channel::<i32, i32>();
        let side_a2 = side_a.clone();

        // All side A instances close
        drop(side_a);
        drop(side_a2);

        let request = side_b.recv_request().await;
        assert!(matches!(request, Err(ChannelError::Closed)));
    }

    /// `handle_request` answers with the closure's result.
    #[tokio::test]
    async fn test_handle_request() {
        let (side_a, side_b) = channel::<i32, i32>();

        tokio::spawn(async move {
            while side_b.handle_request(|req| req * 3).await.is_ok() {
                // Continue handling
            }
        });

        for i in 0..5 {
            let response = side_a.request(i).await.unwrap();
            assert_eq!(response, i * 3);
        }
    }

    /// `handle_request_async` answers with the async handler's result.
    #[tokio::test]
    async fn test_handle_request_async() {
        let (side_a, side_b) = channel::<String, usize>();

        tokio::spawn(async move {
            while side_b
                .handle_request_async(|req| async move {
                    sleep(Duration::from_millis(10)).await;
                    req.len()
                })
                .await
                .is_ok()
            {
                // Continue handling
            }
        });

        for s in ["Hello", "World", "Rust"].iter() {
            let response = side_a.request(s.to_string()).await.unwrap();
            assert_eq!(response, s.len());
        }
    }

    /// Ten tasks send at once; every task must get the reply to its own message.
    #[tokio::test]
    async fn test_concurrent_requests() {
        let (side_a, side_b) = channel::<String, String>();

        tokio::spawn(async move {
            while side_b
                .handle_request_async(|req| async move {
                    sleep(Duration::from_millis(5)).await;
                    req.to_uppercase()
                })
                .await
                .is_ok()
            {
                // Continue
            }
        });

        // Send multiple requests concurrently
        let handles: Vec<_> = (0..10)
            .map(|i| {
                let sender = side_a.clone();
                tokio::spawn(async move {
                    let msg = format!("message{}", i);
                    let resp = sender.request(msg.clone()).await.unwrap();
                    assert_eq!(resp, msg.to_uppercase());
                })
            })
            .collect();

        for handle in handles {
            handle.await.unwrap();
        }
    }

    /// Dropping a guard without replying must panic with the documented message.
    #[tokio::test]
    async fn test_request_guard_must_reply() {
        let (side_a, side_b) = channel::<i32, i32>();

        let handle = tokio::spawn(async move {
            let _guard = side_b.recv_request().await.unwrap();
            // Intentionally not calling reply() - this should panic
        });

        // Send a request
        tokio::spawn(async move {
            let _ = side_a.request(42).await;
        });

        // The receiving task must fail, and it must fail by panicking.
        let join_error = handle.await.expect_err("Task should have panicked");
        let payload = join_error
            .try_into_panic()
            .expect("Task should have ended by panicking, not by cancellation");

        // Panic payloads are either `String` or `&str`, depending on how the
        // message was built.
        let message = payload
            .downcast_ref::<String>()
            .map(String::as_str)
            .or_else(|| payload.downcast_ref::<&str>().copied())
            .expect("Unexpected panic type");
        assert!(
            message.contains("RequestGuard dropped without replying"),
            "Panic message should mention RequestGuard: {}",
            message
        );
    }
}