// mofa_foundation/secretary/connection.rs
use async_trait::async_trait;
use mofa_kernel::agent::secretary::UserConnection;
use tokio::sync::mpsc;

9pub struct ChannelConnection<I, O> {
17 input_rx: tokio::sync::Mutex<mpsc::Receiver<I>>,
19 output_tx: mpsc::Sender<O>,
21 connected: std::sync::atomic::AtomicBool,
23}
24
25impl<I, O> ChannelConnection<I, O>
26where
27 I: Send + 'static,
28 O: Send + 'static,
29{
30 pub fn new(input_rx: mpsc::Receiver<I>, output_tx: mpsc::Sender<O>) -> Self {
32 Self {
33 input_rx: tokio::sync::Mutex::new(input_rx),
34 output_tx,
35 connected: std::sync::atomic::AtomicBool::new(true),
36 }
37 }
38
39 pub fn new_pair(buffer_size: usize) -> (Self, mpsc::Sender<I>, mpsc::Receiver<O>) {
46 let (input_tx, input_rx) = mpsc::channel(buffer_size);
47 let (output_tx, output_rx) = mpsc::channel(buffer_size);
48
49 let conn = Self::new(input_rx, output_tx);
50 (conn, input_tx, output_rx)
51 }
52}
53
54#[async_trait]
55impl<I, O> UserConnection for ChannelConnection<I, O>
56where
57 I: Send + 'static,
58 O: Send + 'static,
59{
60 type Input = I;
61 type Output = O;
62
63 async fn receive(&self) -> anyhow::Result<Self::Input> {
64 let mut rx = self.input_rx.lock().await;
65 rx.recv()
66 .await
67 .ok_or_else(|| anyhow::anyhow!("Channel closed"))
68 }
69
70 async fn try_receive(&self) -> anyhow::Result<Option<Self::Input>> {
71 let mut rx = self.input_rx.lock().await;
72 match rx.try_recv() {
73 Ok(input) => Ok(Some(input)),
74 Err(mpsc::error::TryRecvError::Empty) => Ok(None),
75 Err(mpsc::error::TryRecvError::Disconnected) => {
76 self.connected
77 .store(false, std::sync::atomic::Ordering::SeqCst);
78 Err(anyhow::anyhow!("Channel disconnected"))
79 }
80 }
81 }
82
83 async fn send(&self, output: Self::Output) -> anyhow::Result<()> {
84 self.output_tx
85 .send(output)
86 .await
87 .map_err(|_| anyhow::anyhow!("Failed to send output"))
88 }
89
90 fn is_connected(&self) -> bool {
91 self.connected.load(std::sync::atomic::Ordering::SeqCst) && !self.output_tx.is_closed()
92 }
93
94 async fn close(&self) -> anyhow::Result<()> {
95 self.connected
96 .store(false, std::sync::atomic::Ordering::SeqCst);
97 Ok(())
98 }
99}
100
/// Decorator that puts per-operation time limits on another
/// [`UserConnection`].
///
/// Only `receive` and `send` are time-limited. The remaining methods delegate
/// straight to the wrapped connection.
pub struct TimeoutConnection<C> {
    /// Upper bound, in milliseconds, for a single `send`.
    send_timeout_ms: u64,
    /// Upper bound, in milliseconds, for a single `receive`.
    receive_timeout_ms: u64,
    /// The wrapped connection.
    inner: C,
}

115impl<C> TimeoutConnection<C> {
116 pub fn new(inner: C, receive_timeout_ms: u64, send_timeout_ms: u64) -> Self {
118 Self {
119 inner,
120 receive_timeout_ms,
121 send_timeout_ms,
122 }
123 }
124}
125
126#[async_trait]
127impl<C> UserConnection for TimeoutConnection<C>
128where
129 C: UserConnection,
130{
131 type Input = C::Input;
132 type Output = C::Output;
133
134 async fn receive(&self) -> anyhow::Result<Self::Input> {
135 tokio::time::timeout(
136 tokio::time::Duration::from_millis(self.receive_timeout_ms),
137 self.inner.receive(),
138 )
139 .await
140 .map_err(|_| anyhow::anyhow!("Receive timeout"))?
141 }
142
143 async fn try_receive(&self) -> anyhow::Result<Option<Self::Input>> {
144 self.inner.try_receive().await
145 }
146
147 async fn send(&self, output: Self::Output) -> anyhow::Result<()> {
148 tokio::time::timeout(
149 tokio::time::Duration::from_millis(self.send_timeout_ms),
150 self.inner.send(output),
151 )
152 .await
153 .map_err(|_| anyhow::anyhow!("Send timeout"))?
154 }
155
156 fn is_connected(&self) -> bool {
157 self.inner.is_connected()
158 }
159
160 async fn close(&self) -> anyhow::Result<()> {
161 self.inner.close().await
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;

    /// Round-trips one input and one output through a channel pair.
    #[tokio::test]
    async fn test_channel_connection() {
        let (conn, input_tx, mut output_rx) = ChannelConnection::<String, String>::new_pair(10);

        input_tx.send("Hello".to_string()).await.unwrap();

        let input = conn.receive().await.unwrap();
        assert_eq!(input, "Hello");

        conn.send("World".to_string()).await.unwrap();

        let output = output_rx.recv().await.unwrap();
        assert_eq!(output, "World");

        assert!(conn.is_connected());
    }

    /// An empty but open input channel yields `Ok(None)`.
    #[tokio::test]
    async fn test_try_receive() {
        let (conn, _input_tx, _output_rx) = ChannelConnection::<String, String>::new_pair(10);

        let result = conn.try_receive().await.unwrap();
        assert!(result.is_none());
    }

    /// Dropping the only input sender makes `try_receive` fail and marks the
    /// connection disconnected.
    #[tokio::test]
    async fn test_try_receive_after_sender_dropped() {
        let (conn, input_tx, _output_rx) = ChannelConnection::<String, String>::new_pair(10);
        drop(input_tx);

        assert!(conn.try_receive().await.is_err());
        assert!(!conn.is_connected());
    }

    /// Sending after the output receiver is dropped fails, and the connection
    /// reports itself disconnected.
    #[tokio::test]
    async fn test_send_after_receiver_dropped() {
        let (conn, _input_tx, output_rx) = ChannelConnection::<String, String>::new_pair(10);
        drop(output_rx);

        assert!(conn.send("lost".to_string()).await.is_err());
        assert!(!conn.is_connected());
    }

    /// `close` flips the connection to disconnected.
    #[tokio::test]
    async fn test_close_marks_disconnected() {
        let (conn, _input_tx, _output_rx) = ChannelConnection::<String, String>::new_pair(10);

        conn.close().await.unwrap();
        assert!(!conn.is_connected());
    }

    /// With the input sender kept alive but idle, `receive` hits the timeout.
    #[tokio::test]
    async fn test_timeout_receive_expires() {
        let (inner, _input_tx, _output_rx) = ChannelConnection::<String, String>::new_pair(10);
        let conn = TimeoutConnection::new(inner, 10, 10);

        let err = conn.receive().await.unwrap_err();
        assert_eq!(err.to_string(), "Receive timeout");
    }

    /// Operations that finish within the limits pass straight through.
    #[tokio::test]
    async fn test_timeout_passes_through() {
        let (inner, input_tx, mut output_rx) = ChannelConnection::<String, String>::new_pair(10);
        let conn = TimeoutConnection::new(inner, 1_000, 1_000);

        input_tx.send("ping".to_string()).await.unwrap();
        assert_eq!(conn.receive().await.unwrap(), "ping");

        conn.send("pong".to_string()).await.unwrap();
        assert_eq!(output_rx.recv().await.unwrap(), "pong");

        assert!(conn.is_connected());
    }
}