lite_sync/request_response/many_to_one.rs
//! Many-to-one bidirectional request-response channel.
//!
//! Optimized for multiple request senders (side A) communicating with a single
//! response handler (side B). Uses a lock-free queue for concurrent request
//! submission.
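//!
//! # Example
//!
//! A minimal end-to-end sketch of the intended pattern: several cloned `SideA`
//! handles submit requests concurrently while a single `SideB` task answers them.
//!
//! ```
//! use lite_sync::request_response::many_to_one::channel;
//!
//! # tokio_test::block_on(async {
//! let (side_a, side_b) = channel::<i32, i32>();
//!
//! // The single handler task (side B) doubles every request.
//! tokio::spawn(async move {
//!     while side_b.handle_request(|req| req * 2).await.is_ok() {}
//! });
//!
//! // Multiple sender tasks (side A) share the channel via `clone`.
//! let mut handles = Vec::new();
//! for i in 0..4 {
//!     let sender = side_a.clone();
//!     handles.push(tokio::spawn(async move { sender.request(i).await.unwrap() }));
//! }
//!
//! for (i, handle) in handles.into_iter().enumerate() {
//!     assert_eq!(handle.await.unwrap(), (i as i32) * 2);
//! }
//! # });
//! ```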
use std::sync::Arc;
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use crossbeam_queue::SegQueue;

use crate::oneshot::generic::Sender as OneshotSender;
use super::common::ChannelError;

/// Internal request wrapper containing the request data and its response channel
struct RequestWrapper<Req, Resp> {
    /// The actual request data
    request: Req,

    /// Oneshot sender used to return the response
    response_tx: OneshotSender<Resp>,
}

/// Shared internal state for the many-to-one channel
struct Inner<Req, Resp> {
    /// Lock-free queue of pending requests
    queue: SegQueue<RequestWrapper<Req, Resp>>,

    /// Whether side B (the receiver) has been closed
    b_closed: AtomicBool,

    /// Number of active `SideA` instances
    sender_count: AtomicUsize,

    /// Waker for side B while it waits for requests
    b_waker: crate::atomic_waker::AtomicWaker,
}

impl<Req, Resp> Inner<Req, Resp> {
    /// Create new shared state
    #[inline]
    fn new() -> Self {
        Self {
            queue: SegQueue::new(),
            b_closed: AtomicBool::new(false),
            sender_count: AtomicUsize::new(1), // Start with one sender
            b_waker: crate::atomic_waker::AtomicWaker::new(),
        }
    }

    /// Check whether side B is closed
    #[inline]
    fn is_b_closed(&self) -> bool {
        self.b_closed.load(Ordering::Acquire)
    }
}

/// Side A endpoint (request sender, response receiver) - can be cloned
pub struct SideA<Req, Resp> {
    inner: Arc<Inner<Req, Resp>>,
}

impl<Req, Resp> Clone for SideA<Req, Resp> {
    fn clone(&self) -> Self {
        // Relaxed is sufficient for the increment; the matching decrement in
        // `Drop` uses Release and side B reads the count with Acquire.
        self.inner.sender_count.fetch_add(1, Ordering::Relaxed);
        Self {
            inner: self.inner.clone(),
        }
    }
}

// Decrement the sender count when a SideA handle is dropped.
impl<Req, Resp> Drop for SideA<Req, Resp> {
    fn drop(&mut self) {
        // Release ordering pairs with side B's Acquire load of the count.
        if self.inner.sender_count.fetch_sub(1, Ordering::Release) == 1 {
            // This was the last sender; wake side B so it can observe closure.
            self.inner.b_waker.wake();
        }
    }
}

/// Side B endpoint (request receiver, response sender) - single instance
pub struct SideB<Req, Resp> {
    inner: Arc<Inner<Req, Resp>>,
}

/// Create a new many-to-one request-response channel
///
/// Returns a `(SideA, SideB)` tuple. `SideA` can be cloned to create multiple senders.
///
/// # Example
///
/// ```
/// use lite_sync::request_response::many_to_one::channel;
///
/// # tokio_test::block_on(async {
/// let (side_a, side_b) = channel::<String, i32>();
///
/// // Clone side_a for multiple senders
/// let side_a2 = side_a.clone();
///
/// // Side B handles requests
/// tokio::spawn(async move {
///     while let Ok(guard) = side_b.recv_request().await {
///         let response = guard.request().len() as i32;
///         guard.reply(response);
///     }
/// });
///
/// // Multiple senders can submit requests
/// let response1 = side_a.request("Hello".to_string()).await;
/// let response2 = side_a2.request("World".to_string()).await;
///
/// assert_eq!(response1, Ok(5));
/// assert_eq!(response2, Ok(5));
/// # });
/// ```
#[inline]
pub fn channel<Req, Resp>() -> (SideA<Req, Resp>, SideB<Req, Resp>) {
    let inner = Arc::new(Inner::new());

    let side_a = SideA {
        inner: inner.clone(),
    };

    let side_b = SideB {
        inner,
    };

    (side_a, side_b)
}

impl<Req, Resp> SideA<Req, Resp> {
    /// Send a request and wait for the response
    ///
    /// This method will:
    /// 1. Push the request onto the queue
    /// 2. Wait for side B to process it and respond
    /// 3. Return the response
    ///
    /// # Returns
    ///
    /// - `Ok(response)`: Received a response from side B
    /// - `Err(ChannelError::Closed)`: Side B has been closed
    ///
    /// # Example
    ///
    /// ```
    /// # use lite_sync::request_response::many_to_one::channel;
    /// # tokio_test::block_on(async {
    /// let (side_a, side_b) = channel::<String, i32>();
    ///
    /// tokio::spawn(async move {
    ///     while let Ok(guard) = side_b.recv_request().await {
    ///         let len = guard.request().len() as i32;
    ///         guard.reply(len);
    ///     }
    /// });
    ///
    /// let response = side_a.request("Hello".to_string()).await;
    /// assert_eq!(response, Ok(5));
    /// # });
    /// ```
    pub async fn request(&self, req: Req) -> Result<Resp, ChannelError> {
        // Check whether B is already closed
        if self.inner.is_b_closed() {
            return Err(ChannelError::Closed);
        }

        // Create a oneshot channel for the response
        let (response_tx, response_rx) = OneshotSender::<Resp>::new();

        // Push the request onto the queue
        self.inner.queue.push(RequestWrapper {
            request: req,
            response_tx,
        });

        // Wake up side B
        self.inner.b_waker.wake();

        // Wait for the response
        Ok(response_rx.await)
    }

    /// Queue a request without waiting for the response
    ///
    /// The request is pushed onto the queue immediately; the returned future
    /// resolves to the response once side B replies.
    ///
    /// # Returns
    ///
    /// - `Ok(future)`: The request was queued; await the future for the response
    /// - `Err(ChannelError::Closed)`: Side B has been closed
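    ///
    /// # Example
    ///
    /// A minimal sketch mirroring the `request` example above; here the request
    /// is queued first and the response future is awaited separately.
    ///
    /// ```
    /// # use lite_sync::request_response::many_to_one::channel;
    /// # tokio_test::block_on(async {
    /// let (side_a, side_b) = channel::<String, i32>();
    ///
    /// tokio::spawn(async move {
    ///     while let Ok(guard) = side_b.recv_request().await {
    ///         let len = guard.request().len() as i32;
    ///         guard.reply(len);
    ///     }
    /// });
    ///
    /// // Queue the request now, await the response later.
    /// let pending = side_a.try_request("Hello".to_string()).unwrap();
    /// assert_eq!(pending.await, Ok(5));
    /// # });
    /// ```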
    pub fn try_request(&self, req: Req) -> Result<impl Future<Output = Result<Resp, ChannelError>>, ChannelError> {
        // Check if B is closed first
        if self.inner.is_b_closed() {
            return Err(ChannelError::Closed);
        }

        // Create oneshot channel for response
        let (response_tx, response_rx) = OneshotSender::<Resp>::new();

        // Push request to queue
        self.inner.queue.push(RequestWrapper {
            request: req,
            response_tx,
        });

        // Wake up side B
        self.inner.b_waker.wake();

        // Return future that waits for response
        Ok(async move {
            Ok(response_rx.await)
        })
    }
}

/// Request guard that enforces that side B must reply
///
/// This guard ensures that B calls `reply()` before the guard is dropped.
/// If the guard is dropped without replying, it panics to prevent the
/// requester (side A) from deadlocking.
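///
/// # Example
///
/// A sketch of the intended pattern: every branch ends in `reply()`, even for
/// requests the handler cannot satisfy (here answered with `-1` as a made-up
/// sentinel value).
///
/// ```
/// # use lite_sync::request_response::many_to_one::channel;
/// # tokio_test::block_on(async {
/// let (side_a, side_b) = channel::<i32, i32>();
///
/// tokio::spawn(async move {
///     while let Ok(guard) = side_b.recv_request().await {
///         let req = *guard.request();
///         if req >= 0 {
///             guard.reply(req * 2);
///         } else {
///             // Dropping the guard here would panic; reply instead.
///             guard.reply(-1);
///         }
///     }
/// });
///
/// assert_eq!(side_a.request(21).await, Ok(42));
/// assert_eq!(side_a.request(-3).await, Ok(-1));
/// # });
/// ```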
pub struct RequestGuard<Req, Resp>
where
    Req: Send,
    Resp: Send,
{
    req: Option<Req>,
    response_tx: Option<OneshotSender<Resp>>,
}

impl<Req, Resp> std::fmt::Debug for RequestGuard<Req, Resp>
where
    Req: Send + std::fmt::Debug,
    Resp: Send,
{
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("RequestGuard")
            .field("req", &self.req)
            .finish_non_exhaustive()
    }
}

// PartialEq for testing purposes
impl<Req, Resp> PartialEq for RequestGuard<Req, Resp>
where
    Req: Send + PartialEq,
    Resp: Send,
{
    fn eq(&self, other: &Self) -> bool {
        self.req == other.req
    }
}

impl<Req, Resp> RequestGuard<Req, Resp>
where
    Req: Send,
    Resp: Send,
{
    /// Get a reference to the request
    #[inline]
    pub fn request(&self) -> &Req {
        self.req.as_ref().expect("RequestGuard logic error: request already consumed")
    }

    /// Consume the guard and send the reply
    ///
    /// This method sends the response back to the requester.
    #[inline]
    pub fn reply(mut self, resp: Resp) {
        if let Some(response_tx) = self.response_tx.take() {
            let _ = response_tx.send(resp);
        }
        // Mark the guard as replied by taking the request
        self.req = None;
    }
}

/// Drop guard: if B drops the guard without calling `reply`, we panic.
/// This enforces the "must reply" protocol.
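///
/// # Example
///
/// A minimal sketch of the failure mode this enforces: dropping the guard
/// without calling `reply()` panics. The request is queued with `try_request`
/// so the whole exchange runs on a single task.
///
/// ```should_panic
/// # use lite_sync::request_response::many_to_one::channel;
/// # tokio_test::block_on(async {
/// let (side_a, side_b) = channel::<i32, i32>();
///
/// // Queue a request without awaiting the response yet.
/// let _pending = side_a.try_request(1).unwrap();
///
/// // Receiving the request yields a guard that must be replied to...
/// let guard = side_b.recv_request().await.unwrap();
///
/// // ...so dropping it without replying panics.
/// drop(guard);
/// # });
/// ```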
impl<Req, Resp> Drop for RequestGuard<Req, Resp>
where
    Req: Send,
    Resp: Send,
{
    fn drop(&mut self) {
        if self.req.is_some() {
            // B dropped the guard without replying.
            // This is a protocol error that would cause A to deadlock,
            // so we must panic to surface it.
            panic!("RequestGuard dropped without replying! This would cause the requester to deadlock. You must call reply() before dropping the guard.");
        }
    }
}

impl<Req, Resp> SideB<Req, Resp> {
    /// Wait for and receive the next request, returning a guard that must be replied to
    ///
    /// The returned `RequestGuard` enforces that you call `reply()` on it.
    /// If you drop the guard without calling `reply()`, it will panic.
    ///
    /// # Returns
    ///
    /// - `Ok(RequestGuard)`: Received a request from a side A
    /// - `Err(ChannelError::Closed)`: All side A instances have been closed
    ///
    /// # Example
    ///
    /// ```
    /// # use lite_sync::request_response::many_to_one::channel;
    /// # tokio_test::block_on(async {
    /// let (side_a, side_b) = channel::<String, i32>();
    ///
    /// tokio::spawn(async move {
    ///     while let Ok(guard) = side_b.recv_request().await {
    ///         let len = guard.request().len() as i32;
    ///         guard.reply(len);
    ///     }
    /// });
    ///
    /// let response = side_a.request("Hello".to_string()).await;
    /// assert_eq!(response, Ok(5));
    /// # });
    /// ```
    pub async fn recv_request(&self) -> Result<RequestGuard<Req, Resp>, ChannelError>
    where
        Req: Send,
        Resp: Send,
    {
        RecvRequest {
            inner: &self.inner,
        }.await
    }

    /// Convenience method that receives a request and sends the response
    ///
    /// This method will:
    /// 1. Wait for and receive a request
    /// 2. Call the handler function
    /// 3. Send the response via the guard
    ///
    /// # Example
    ///
    /// ```
    /// # use lite_sync::request_response::many_to_one::channel;
    /// # tokio_test::block_on(async {
    /// let (side_a, side_b) = channel::<String, i32>();
    ///
    /// tokio::spawn(async move {
    ///     while side_b.handle_request(|req| req.len() as i32).await.is_ok() {
    ///         // Continue handling
    ///     }
    /// });
    ///
    /// let response = side_a.request("Hello".to_string()).await;
    /// assert_eq!(response, Ok(5));
    /// # });
    /// ```
    pub async fn handle_request<F>(&self, handler: F) -> Result<(), ChannelError>
    where
        Req: Send,
        Resp: Send,
        F: FnOnce(&Req) -> Resp,
    {
        let guard = self.recv_request().await?;
        let resp = handler(guard.request());
        guard.reply(resp);
        Ok(())
    }

    /// Async convenience method that receives a request and sends the response
    ///
    /// Similar to `handle_request`, but supports async handler functions.
    /// Note: the handler takes ownership of the request to avoid lifetime issues.
    ///
    /// # Example
    ///
    /// ```
    /// # use lite_sync::request_response::many_to_one::channel;
    /// # tokio_test::block_on(async {
    /// let (side_a, side_b) = channel::<String, String>();
    ///
    /// tokio::spawn(async move {
    ///     while side_b.handle_request_async(|req| async move {
    ///         // Async processing; req is owned
    ///         req.to_uppercase()
    ///     }).await.is_ok() {
    ///         // Continue handling
    ///     }
    /// });
    ///
    /// let response = side_a.request("hello".to_string()).await;
    /// assert_eq!(response, Ok("HELLO".to_string()));
    /// # });
    /// ```
    pub async fn handle_request_async<F, Fut>(&self, handler: F) -> Result<(), ChannelError>
    where
        Req: Send,
        Resp: Send,
        F: FnOnce(Req) -> Fut,
        Fut: Future<Output = Resp>,
    {
        let mut guard = self.recv_request().await?;
        let req = guard.req.take().expect("RequestGuard logic error: request already consumed");
        let resp = handler(req).await;

        // Send the reply manually since the request has been moved out of the guard.
        if let Some(response_tx) = guard.response_tx.take() {
            let _ = response_tx.send(resp);
        }
        // `guard.req` is already `None`, so dropping the guard will not panic.

        Ok(())
    }
}

// Drop implementation to clean up
impl<Req, Resp> Drop for SideB<Req, Resp> {
    fn drop(&mut self) {
        // Mark side B as closed so senders can observe it
        self.inner.b_closed.store(true, Ordering::Release);

        // Drain the queue and drop all pending response channels.
        // The generic oneshot handles cleanup automatically.
        while let Some(_wrapper) = self.inner.queue.pop() {
            // Just drop the wrapper
        }
    }
}

/// Future returned by `SideB::recv_request`
struct RecvRequest<'a, Req, Resp> {
    inner: &'a Inner<Req, Resp>,
}

// RecvRequest is Unpin because it only holds a shared reference
impl<Req, Resp> Unpin for RecvRequest<'_, Req, Resp> {}

impl<Req, Resp> Future for RecvRequest<'_, Req, Resp>
where
    Req: Send,
    Resp: Send,
{
    type Output = Result<RequestGuard<Req, Resp>, ChannelError>;

    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
        // Fast path: try to pop a pending request
        if let Some(wrapper) = self.inner.queue.pop() {
            return Poll::Ready(Ok(RequestGuard {
                req: Some(wrapper.request),
                response_tx: Some(wrapper.response_tx),
            }));
        }

        // If no senders remain, the channel is closed
        if self.inner.sender_count.load(Ordering::Acquire) == 0 {
            return Poll::Ready(Err(ChannelError::Closed));
        }

        // Register the current waker on every poll so the most recent waker is stored
        self.inner.b_waker.register(cx.waker());

        // Always re-check the queue and the sender count after registering the waker.
        // This is critical to avoid a lost wakeup when a sender pushes a request
        // (or the last sender drops) between the checks above and the registration.
        if let Some(wrapper) = self.inner.queue.pop() {
            return Poll::Ready(Ok(RequestGuard {
                req: Some(wrapper.request),
                response_tx: Some(wrapper.response_tx),
            }));
        }

        // Final check for remaining senders
        if self.inner.sender_count.load(Ordering::Acquire) == 0 {
            return Poll::Ready(Err(ChannelError::Closed));
        }

        Poll::Pending
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use tokio::time::{sleep, Duration};

    #[tokio::test]
    async fn test_basic_many_to_one() {
        let (side_a, side_b) = channel::<String, i32>();

        tokio::spawn(async move {
            while let Ok(guard) = side_b.recv_request().await {
                let response = guard.request().len() as i32;
                guard.reply(response);
            }
        });

        let response = side_a.request("Hello".to_string()).await;
        assert_eq!(response, Ok(5));
    }

    #[tokio::test]
    async fn test_multiple_senders() {
        let (side_a, side_b) = channel::<i32, i32>();
        let side_a2 = side_a.clone();
        let side_a3 = side_a.clone();

        tokio::spawn(async move {
            while let Ok(guard) = side_b.recv_request().await {
                let result = *guard.request() * 2;
                guard.reply(result);
            }
        });

        let handle1 = tokio::spawn(async move {
            let mut sum = 0;
            for i in 0..10 {
                let resp = side_a.request(i).await.unwrap();
                sum += resp;
            }
            sum
        });

        let handle2 = tokio::spawn(async move {
            let mut sum = 0;
            for i in 10..20 {
                let resp = side_a2.request(i).await.unwrap();
                sum += resp;
            }
            sum
        });

        let handle3 = tokio::spawn(async move {
            let mut sum = 0;
            for i in 20..30 {
                let resp = side_a3.request(i).await.unwrap();
                sum += resp;
            }
            sum
        });

        let sum1 = handle1.await.unwrap();
        let sum2 = handle2.await.unwrap();
        let sum3 = handle3.await.unwrap();

        // Each range should give: sum(i * 2) = 2 * sum(i)
        assert_eq!(sum1, 2 * (0..10).sum::<i32>());
        assert_eq!(sum2, 2 * (10..20).sum::<i32>());
        assert_eq!(sum3, 2 * (20..30).sum::<i32>());
    }

    #[tokio::test]
    async fn test_side_a_closes() {
        let (side_a, side_b) = channel::<i32, i32>();

        // The only side A instance closes immediately
        drop(side_a);

        // Side B should receive Err
        let request = side_b.recv_request().await;
        assert!(request.is_err());
    }

    #[tokio::test]
    async fn test_all_side_a_close() {
        let (side_a, side_b) = channel::<i32, i32>();
        let side_a2 = side_a.clone();

        // All side A instances close
        drop(side_a);
        drop(side_a2);

        // Side B should receive Err
        let request = side_b.recv_request().await;
        assert!(request.is_err());
    }

    #[tokio::test]
    async fn test_handle_request() {
        let (side_a, side_b) = channel::<i32, i32>();

        tokio::spawn(async move {
            while side_b.handle_request(|req| req * 3).await.is_ok() {
                // Continue handling
            }
        });

        for i in 0..5 {
            let response = side_a.request(i).await.unwrap();
            assert_eq!(response, i * 3);
        }
    }

    #[tokio::test]
    async fn test_handle_request_async() {
        let (side_a, side_b) = channel::<String, usize>();

        tokio::spawn(async move {
            while side_b.handle_request_async(|req| async move {
                sleep(Duration::from_millis(10)).await;
                req.len()
            }).await.is_ok() {
                // Continue handling
            }
        });

        let test_strings = vec!["Hello", "World", "Rust"];
        for s in test_strings {
            let response = side_a.request(s.to_string()).await.unwrap();
            assert_eq!(response, s.len());
        }
    }

    #[tokio::test]
    async fn test_concurrent_requests() {
        let (side_a, side_b) = channel::<String, String>();

        tokio::spawn(async move {
            while side_b.handle_request_async(|req| async move {
                sleep(Duration::from_millis(5)).await;
                req.to_uppercase()
            }).await.is_ok() {
                // Continue
            }
        });

        // Send multiple requests concurrently
        let mut handles = vec![];
        for i in 0..10 {
            let side_a_clone = side_a.clone();
            let handle = tokio::spawn(async move {
                let msg = format!("message{}", i);
                let resp = side_a_clone.request(msg.clone()).await.unwrap();
                assert_eq!(resp, msg.to_uppercase());
            });
            handles.push(handle);
        }

        for handle in handles {
            handle.await.unwrap();
        }
    }

    #[tokio::test]
    async fn test_request_guard_must_reply() {
        let (side_a, side_b) = channel::<i32, i32>();

        let handle = tokio::spawn(async move {
            let _guard = side_b.recv_request().await.unwrap();
            // Intentionally not calling reply() - this should panic
        });

        // Send a request
        tokio::spawn(async move {
            let _ = side_a.request(42).await;
        });

        // Wait for the spawned task and verify that it panicked
        let result = handle.await;
        assert!(result.is_err(), "Task should have panicked");

        // Verify that the panic message contains the expected text
        if let Err(e) = result {
            if let Ok(panic_payload) = e.try_into_panic() {
                if let Some(s) = panic_payload.downcast_ref::<String>() {
                    assert!(s.contains("RequestGuard dropped without replying"),
                        "Panic message should mention RequestGuard: {}", s);
                } else if let Some(s) = panic_payload.downcast_ref::<&str>() {
                    assert!(s.contains("RequestGuard dropped without replying"),
                        "Panic message should mention RequestGuard: {}", s);
                } else {
                    panic!("Unexpected panic type");
                }
            }
        }
    }
}