Skip to main content

agent_core_runtime/agent/interface/
source.rs

//! Input source: delivers consumer input to the engine.
//!
//! The [`InputSource`] trait defines how consumers provide input to the engine;
//! [`ChannelInputSource`] is the default, channel-backed implementation.

5use std::future::Future;
6use std::pin::Pin;
7
8use tokio::sync::mpsc;
9
10use crate::controller::ControllerInputPayload;
11
12/// Provides input from a consumer to the agent engine.
13///
14/// Implementations handle receiving user messages, commands, and other
15/// input from the consumer and delivering them to the engine.
16///
17/// # Lifecycle
18///
19/// The engine calls `recv()` in a loop. When the consumer closes
20/// (user quits, connection dropped), `recv()` should return `None`
21/// to signal shutdown.
22///
23/// # Example
24///
25/// ```ignore
26/// use agent_core_runtime::agent::interface::InputSource;
27/// use agent_core_runtime::controller::ControllerInputPayload;
28/// use std::pin::Pin;
29/// use std::future::Future;
30///
31/// struct MyCustomSource { /* ... */ }
32///
33/// impl InputSource for MyCustomSource {
34///     fn recv(&mut self) -> Pin<Box<dyn Future<Output = Option<ControllerInputPayload>> + Send + '_>> {
35///         Box::pin(async move {
36///             // Receive from your transport
37///             None
38///         })
39///     }
40/// }
41/// ```
42pub trait InputSource: Send + 'static {
43    /// Receive the next input from the consumer.
44    ///
45    /// Returns `None` when the consumer is closed and no more input
46    /// will arrive. The engine will shut down when this returns `None`.
47    fn recv(&mut self) -> Pin<Box<dyn Future<Output = Option<ControllerInputPayload>> + Send + '_>>;
48}
49
50/// Input source backed by an async channel.
51///
52/// This is the default source used internally. The consumer sends
53/// input through a channel that this source reads from.
54pub struct ChannelInputSource {
55    rx: mpsc::Receiver<ControllerInputPayload>,
56}
57
58impl ChannelInputSource {
59    /// Create a new channel-backed input source.
60    pub fn new(rx: mpsc::Receiver<ControllerInputPayload>) -> Self {
61        Self { rx }
62    }
63
64    /// Create a channel pair for input.
65    ///
66    /// Returns `(sender, source)` where sender is used by the consumer
67    /// to send input and source is passed to the engine.
68    pub fn channel(buffer: usize) -> (mpsc::Sender<ControllerInputPayload>, Self) {
69        let (tx, rx) = mpsc::channel(buffer);
70        (tx, Self::new(rx))
71    }
72}
73
74impl InputSource for ChannelInputSource {
75    fn recv(&mut self) -> Pin<Box<dyn Future<Output = Option<ControllerInputPayload>> + Send + '_>> {
76        Box::pin(async move { self.rx.recv().await })
77    }
78}
79
#[cfg(test)]
mod tests {
    use super::*;
    use crate::controller::TurnId;

    #[tokio::test]
    async fn test_channel_input_source_recv() {
        let (tx, mut source) = ChannelInputSource::channel(10);

        let payload = ControllerInputPayload::data(1, "hello", TurnId::new_user_turn(1));
        tx.send(payload).await.unwrap();

        let received = source.recv().await.unwrap();
        assert_eq!(received.session_id, 1);
        assert_eq!(received.content, "hello");
    }

    #[tokio::test]
    async fn test_channel_input_source_closed() {
        let (tx, mut source) = ChannelInputSource::channel(10);

        // Dropping the only sender closes the channel.
        drop(tx);

        // With nothing buffered, the source reports end-of-input immediately.
        assert!(source.recv().await.is_none());
    }

    #[tokio::test]
    async fn test_channel_input_source_multiple() {
        let (tx, mut source) = ChannelInputSource::channel(10);

        // Send several payloads back to back.
        for i in 0..3_i64 {
            let payload =
                ControllerInputPayload::data(1, format!("msg {}", i), TurnId::new_user_turn(i));
            tx.send(payload).await.unwrap();
        }
        drop(tx);

        // They must come out in FIFO order...
        for i in 0..3_i64 {
            let received = source.recv().await.unwrap();
            assert_eq!(received.content, format!("msg {}", i));
        }
        // ...followed by end-of-input once the sender is gone.
        assert!(source.recv().await.is_none());
    }

    #[tokio::test]
    async fn test_channel_input_source_drains_buffer_after_close() {
        let (tx, mut source) = ChannelInputSource::channel(10);

        let payload = ControllerInputPayload::data(1, "pending", TurnId::new_user_turn(1));
        tx.send(payload).await.unwrap();
        drop(tx);

        // Input queued before the close must still reach the engine...
        let received = source.recv().await.unwrap();
        assert_eq!(received.content, "pending");
        // ...and only afterwards does the source signal shutdown.
        assert!(source.recv().await.is_none());
    }
}