Skip to main content

mofa_foundation/secretary/
connection.rs

1//! 用户连接实现 (Foundation 层)
2//!
3//! Kernel 仅定义 `UserConnection` 抽象,具体连接实现放在 Foundation 层。
4
5use async_trait::async_trait;
6use mofa_kernel::agent::secretary::UserConnection;
7use tokio::sync::mpsc;
8
9// =============================================================================
10// 基于通道的连接实现
11// =============================================================================
12
13/// 基于 mpsc 通道的连接
14///
15/// 用于进程内通信,如测试或单机应用
16pub struct ChannelConnection<I, O> {
17    /// 输入接收器
18    input_rx: tokio::sync::Mutex<mpsc::Receiver<I>>,
19    /// 输出发送器
20    output_tx: mpsc::Sender<O>,
21    /// 连接状态
22    connected: std::sync::atomic::AtomicBool,
23}
24
25impl<I, O> ChannelConnection<I, O>
26where
27    I: Send + 'static,
28    O: Send + 'static,
29{
30    /// 创建新的通道连接
31    pub fn new(input_rx: mpsc::Receiver<I>, output_tx: mpsc::Sender<O>) -> Self {
32        Self {
33            input_rx: tokio::sync::Mutex::new(input_rx),
34            output_tx,
35            connected: std::sync::atomic::AtomicBool::new(true),
36        }
37    }
38
39    /// 创建连接对
40    ///
41    /// 返回 (connection, input_tx, output_rx)
42    /// - `connection`: 给秘书使用的连接
43    /// - `input_tx`: 用户用来发送输入
44    /// - `output_rx`: 用户用来接收输出
45    pub fn new_pair(buffer_size: usize) -> (Self, mpsc::Sender<I>, mpsc::Receiver<O>) {
46        let (input_tx, input_rx) = mpsc::channel(buffer_size);
47        let (output_tx, output_rx) = mpsc::channel(buffer_size);
48
49        let conn = Self::new(input_rx, output_tx);
50        (conn, input_tx, output_rx)
51    }
52}
53
54#[async_trait]
55impl<I, O> UserConnection for ChannelConnection<I, O>
56where
57    I: Send + 'static,
58    O: Send + 'static,
59{
60    type Input = I;
61    type Output = O;
62
63    async fn receive(&self) -> anyhow::Result<Self::Input> {
64        let mut rx = self.input_rx.lock().await;
65        rx.recv()
66            .await
67            .ok_or_else(|| anyhow::anyhow!("Channel closed"))
68    }
69
70    async fn try_receive(&self) -> anyhow::Result<Option<Self::Input>> {
71        let mut rx = self.input_rx.lock().await;
72        match rx.try_recv() {
73            Ok(input) => Ok(Some(input)),
74            Err(mpsc::error::TryRecvError::Empty) => Ok(None),
75            Err(mpsc::error::TryRecvError::Disconnected) => {
76                self.connected
77                    .store(false, std::sync::atomic::Ordering::SeqCst);
78                Err(anyhow::anyhow!("Channel disconnected"))
79            }
80        }
81    }
82
83    async fn send(&self, output: Self::Output) -> anyhow::Result<()> {
84        self.output_tx
85            .send(output)
86            .await
87            .map_err(|_| anyhow::anyhow!("Failed to send output"))
88    }
89
90    fn is_connected(&self) -> bool {
91        self.connected.load(std::sync::atomic::Ordering::SeqCst) && !self.output_tx.is_closed()
92    }
93
94    async fn close(&self) -> anyhow::Result<()> {
95        self.connected
96            .store(false, std::sync::atomic::Ordering::SeqCst);
97        Ok(())
98    }
99}
100
101// =============================================================================
102// 超时包装连接
103// =============================================================================
104
/// Wrapper that applies time limits to another connection's blocking
/// `receive` and `send` operations.
pub struct TimeoutConnection<C> {
    /// The wrapped connection that performs the actual work.
    inner: C,
    /// Upper bound for a single `receive` call, in milliseconds.
    receive_timeout_ms: u64,
    /// Upper bound for a single `send` call, in milliseconds.
    send_timeout_ms: u64,
}
114
115impl<C> TimeoutConnection<C> {
116    /// 创建带超时的连接
117    pub fn new(inner: C, receive_timeout_ms: u64, send_timeout_ms: u64) -> Self {
118        Self {
119            inner,
120            receive_timeout_ms,
121            send_timeout_ms,
122        }
123    }
124}
125
126#[async_trait]
127impl<C> UserConnection for TimeoutConnection<C>
128where
129    C: UserConnection,
130{
131    type Input = C::Input;
132    type Output = C::Output;
133
134    async fn receive(&self) -> anyhow::Result<Self::Input> {
135        tokio::time::timeout(
136            tokio::time::Duration::from_millis(self.receive_timeout_ms),
137            self.inner.receive(),
138        )
139        .await
140        .map_err(|_| anyhow::anyhow!("Receive timeout"))?
141    }
142
143    async fn try_receive(&self) -> anyhow::Result<Option<Self::Input>> {
144        self.inner.try_receive().await
145    }
146
147    async fn send(&self, output: Self::Output) -> anyhow::Result<()> {
148        tokio::time::timeout(
149            tokio::time::Duration::from_millis(self.send_timeout_ms),
150            self.inner.send(output),
151        )
152        .await
153        .map_err(|_| anyhow::anyhow!("Send timeout"))?
154    }
155
156    fn is_connected(&self) -> bool {
157        self.inner.is_connected()
158    }
159
160    async fn close(&self) -> anyhow::Result<()> {
161        self.inner.close().await
162    }
163}
164
165// =============================================================================
166// 测试
167// =============================================================================
168
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips one input and one output through a connection pair.
    #[tokio::test]
    async fn test_channel_connection() {
        let (connection, user_tx, mut user_rx) =
            ChannelConnection::<String, String>::new_pair(10);

        // User -> secretary: push an input and read it back via the connection.
        user_tx.send(String::from("Hello")).await.unwrap();
        let received = connection.receive().await.unwrap();
        assert_eq!(received, "Hello");

        // Secretary -> user: push an output and read it on the user side.
        connection.send(String::from("World")).await.unwrap();
        let delivered = user_rx.recv().await.unwrap();
        assert_eq!(delivered, "World");

        // Both user-side endpoints are still alive, so the link stays up.
        assert!(connection.is_connected());
    }

    /// `try_receive` yields `None` when nothing has been sent yet.
    #[tokio::test]
    async fn test_try_receive() {
        // The user-side endpoints are bound so the channels stay open.
        let (connection, _user_tx, _user_rx) =
            ChannelConnection::<String, String>::new_pair(10);

        let pending = connection.try_receive().await.unwrap();
        assert!(pending.is_none());
    }
}
202}