agent_air_runtime/agent/interface/source.rs
1//! Input Source - Provides input to the engine
2//!
3//! The [`InputSource`] trait defines how consumers provide input to the engine.
4
5use std::future::Future;
6use std::pin::Pin;
7
8use tokio::sync::mpsc;
9
10use crate::controller::ControllerInputPayload;
11
12/// Provides input from a consumer to the agent engine.
13///
14/// Implementations handle receiving user messages, commands, and other
15/// input from the consumer and delivering them to the engine.
16///
17/// # Lifecycle
18///
19/// The engine calls `recv()` in a loop. When the consumer closes
20/// (user quits, connection dropped), `recv()` should return `None`
21/// to signal shutdown.
22///
23/// # Example
24///
25/// ```ignore
26/// use agent_air_runtime::agent::interface::InputSource;
27/// use agent_air_runtime::controller::ControllerInputPayload;
28/// use std::pin::Pin;
29/// use std::future::Future;
30///
31/// struct MyCustomSource { /* ... */ }
32///
33/// impl InputSource for MyCustomSource {
34/// fn recv(&mut self) -> Pin<Box<dyn Future<Output = Option<ControllerInputPayload>> + Send + '_>> {
35/// Box::pin(async move {
36/// // Receive from your transport
37/// None
38/// })
39/// }
40/// }
41/// ```
42pub trait InputSource: Send + 'static {
43 /// Receive the next input from the consumer.
44 ///
45 /// Returns `None` when the consumer is closed and no more input
46 /// will arrive. The engine will shut down when this returns `None`.
47 fn recv(&mut self)
48 -> Pin<Box<dyn Future<Output = Option<ControllerInputPayload>> + Send + '_>>;
49}
50
51/// Input source backed by an async channel.
52///
53/// This is the default source used internally. The consumer sends
54/// input through a channel that this source reads from.
55pub struct ChannelInputSource {
56 rx: mpsc::Receiver<ControllerInputPayload>,
57}
58
59impl ChannelInputSource {
60 /// Create a new channel-backed input source.
61 pub fn new(rx: mpsc::Receiver<ControllerInputPayload>) -> Self {
62 Self { rx }
63 }
64
65 /// Create a channel pair for input.
66 ///
67 /// Returns `(sender, source)` where sender is used by the consumer
68 /// to send input and source is passed to the engine.
69 pub fn channel(buffer: usize) -> (mpsc::Sender<ControllerInputPayload>, Self) {
70 let (tx, rx) = mpsc::channel(buffer);
71 (tx, Self::new(rx))
72 }
73}
74
75impl InputSource for ChannelInputSource {
76 fn recv(
77 &mut self,
78 ) -> Pin<Box<dyn Future<Output = Option<ControllerInputPayload>> + Send + '_>> {
79 Box::pin(async move { self.rx.recv().await })
80 }
81}
82
#[cfg(test)]
mod tests {
    use super::*;
    use crate::controller::TurnId;

    #[tokio::test]
    async fn test_channel_input_source_recv() {
        let (tx, mut source) = ChannelInputSource::channel(10);

        let payload = ControllerInputPayload::data(1, "hello", TurnId::new_user_turn(1));
        tx.send(payload).await.unwrap();

        let received = source.recv().await.unwrap();
        assert_eq!(received.session_id, 1);
        assert_eq!(received.content, "hello");
    }

    #[tokio::test]
    async fn test_channel_input_source_closed() {
        let (tx, mut source) = ChannelInputSource::channel(10);

        // Dropping the only sender closes the channel.
        drop(tx);

        // A closed, empty channel signals shutdown with `None`.
        let received = source.recv().await;
        assert!(received.is_none());
    }

    #[tokio::test]
    async fn test_channel_input_source_drains_before_close() {
        let (tx, mut source) = ChannelInputSource::channel(10);

        let payload = ControllerInputPayload::data(1, "last words", TurnId::new_user_turn(1));
        tx.send(payload).await.unwrap();
        drop(tx);

        // Input queued before the consumer closed must still be delivered...
        let received = source.recv().await.unwrap();
        assert_eq!(received.content, "last words");
        // ...and only after that does the source report shutdown.
        assert!(source.recv().await.is_none());
    }

    #[tokio::test]
    async fn test_channel_input_source_multiple() {
        let (tx, mut source) = ChannelInputSource::channel(10);

        // Send multiple messages.
        for i in 0..3_i64 {
            let payload =
                ControllerInputPayload::data(1, format!("msg {}", i), TurnId::new_user_turn(i));
            tx.send(payload).await.unwrap();
        }

        // Receive all of them, in FIFO order.
        for i in 0..3 {
            let received = source.recv().await.unwrap();
            assert_eq!(received.content, format!("msg {}", i));
        }
    }

    #[tokio::test]
    async fn test_channel_input_source_as_trait_object() {
        let (tx, source) = ChannelInputSource::channel(1);
        // The engine may hold sources type-erased; the trait must stay object-safe.
        let mut boxed: Box<dyn InputSource> = Box::new(source);

        tx.send(ControllerInputPayload::data(1, "boxed", TurnId::new_user_turn(1)))
            .await
            .unwrap();

        let received = boxed.recv().await.unwrap();
        assert_eq!(received.content, "boxed");
    }
}