Skip to main content

agent_air_runtime/agent/interface/
source.rs

1//! Input Source - Provides input to the engine
2//!
3//! The [`InputSource`] trait defines how consumers provide input to the engine.
4
5use std::future::Future;
6use std::pin::Pin;
7
8use tokio::sync::mpsc;
9
10use crate::controller::ControllerInputPayload;
11
12/// Provides input from a consumer to the agent engine.
13///
14/// Implementations handle receiving user messages, commands, and other
15/// input from the consumer and delivering them to the engine.
16///
17/// # Lifecycle
18///
19/// The engine calls `recv()` in a loop. When the consumer closes
20/// (user quits, connection dropped), `recv()` should return `None`
21/// to signal shutdown.
22///
23/// # Example
24///
25/// ```ignore
26/// use agent_air_runtime::agent::interface::InputSource;
27/// use agent_air_runtime::controller::ControllerInputPayload;
28/// use std::pin::Pin;
29/// use std::future::Future;
30///
31/// struct MyCustomSource { /* ... */ }
32///
33/// impl InputSource for MyCustomSource {
34///     fn recv(&mut self) -> Pin<Box<dyn Future<Output = Option<ControllerInputPayload>> + Send + '_>> {
35///         Box::pin(async move {
36///             // Receive from your transport
37///             None
38///         })
39///     }
40/// }
41/// ```
42pub trait InputSource: Send + 'static {
43    /// Receive the next input from the consumer.
44    ///
45    /// Returns `None` when the consumer is closed and no more input
46    /// will arrive. The engine will shut down when this returns `None`.
47    fn recv(&mut self)
48    -> Pin<Box<dyn Future<Output = Option<ControllerInputPayload>> + Send + '_>>;
49}
50
51/// Input source backed by an async channel.
52///
53/// This is the default source used internally. The consumer sends
54/// input through a channel that this source reads from.
55pub struct ChannelInputSource {
56    rx: mpsc::Receiver<ControllerInputPayload>,
57}
58
59impl ChannelInputSource {
60    /// Create a new channel-backed input source.
61    pub fn new(rx: mpsc::Receiver<ControllerInputPayload>) -> Self {
62        Self { rx }
63    }
64
65    /// Create a channel pair for input.
66    ///
67    /// Returns `(sender, source)` where sender is used by the consumer
68    /// to send input and source is passed to the engine.
69    pub fn channel(buffer: usize) -> (mpsc::Sender<ControllerInputPayload>, Self) {
70        let (tx, rx) = mpsc::channel(buffer);
71        (tx, Self::new(rx))
72    }
73}
74
75impl InputSource for ChannelInputSource {
76    fn recv(
77        &mut self,
78    ) -> Pin<Box<dyn Future<Output = Option<ControllerInputPayload>> + Send + '_>> {
79        Box::pin(async move { self.rx.recv().await })
80    }
81}
82
#[cfg(test)]
mod tests {
    use super::*;
    use crate::controller::TurnId;

    #[tokio::test]
    async fn test_channel_input_source_recv() {
        let (tx, mut source) = ChannelInputSource::channel(10);

        let payload = ControllerInputPayload::data(1, "hello", TurnId::new_user_turn(1));
        tx.send(payload).await.unwrap();

        let received = source.recv().await.unwrap();
        assert_eq!(received.session_id, 1);
        assert_eq!(received.content, "hello");
    }

    #[tokio::test]
    async fn test_channel_input_source_closed() {
        let (tx, mut source) = ChannelInputSource::channel(10);

        // Dropping the only sender closes the channel.
        drop(tx);

        // A closed, empty channel signals shutdown with `None`.
        let received = source.recv().await;
        assert!(received.is_none());
    }

    #[tokio::test]
    async fn test_channel_input_source_multiple() {
        let (tx, mut source) = ChannelInputSource::channel(10);

        // Send multiple messages.
        for i in 0..3i64 {
            let payload = ControllerInputPayload::data(
                1,
                format!("msg {}", i),
                TurnId::new_user_turn(i),
            );
            tx.send(payload).await.unwrap();
        }

        // They are received in the order they were sent.
        for i in 0..3 {
            let received = source.recv().await.unwrap();
            assert_eq!(received.content, format!("msg {}", i));
        }
    }

    #[tokio::test]
    async fn test_channel_input_source_drains_before_close() {
        let (tx, mut source) = ChannelInputSource::channel(10);

        for i in 0..2i64 {
            let payload = ControllerInputPayload::data(
                1,
                format!("msg {}", i),
                TurnId::new_user_turn(i),
            );
            tx.send(payload).await.unwrap();
        }
        drop(tx);

        // Input queued before the consumer closed must still reach the
        // engine; only after the buffer is empty does `recv` report shutdown.
        for i in 0..2 {
            let received = source.recv().await.unwrap();
            assert_eq!(received.content, format!("msg {}", i));
        }
        assert!(source.recv().await.is_none());
    }
}