Skip to main content

rs_adk/
agent_session.rs

1//! AgentSession — intercepting wrapper around SessionHandle.
2//!
3//! Replaces ADK Python's LiveRequestQueue. Instead of adding a second queue
4//! on top of SessionHandle's existing mpsc channel, this wraps a SessionWriter
5//! and intercepts sends for: (1) input fan-out to streaming tools,
6//! (2) middleware hooks, (3) state tracking.
7//!
8//! Data flow: App → AgentSession → SessionWriter → WebSocket
9//!                                ↘ broadcast to input-streaming tools
10//!
11//! ONE queue, ONE consumer task, zero-copy on the hot path.
12
13use rs_genai::prelude::{Content, FunctionResponse};
14use rs_genai::session::{SessionError, SessionEvent, SessionHandle, SessionWriter};
15use std::sync::Arc;
16use tokio::sync::broadcast;
17
18use crate::error::AgentError;
19use crate::state::State;
20
/// Input events broadcast to input-streaming tools.
///
/// Distinct from `SessionCommand` — this is observation-only.
///
/// The enum is plain data, so besides `Debug` and `Clone` it also derives
/// `PartialEq`/`Eq`; this lets callers and tests compare events directly
/// (for example with `assert_eq!`) instead of destructuring them.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InputEvent {
    /// Raw PCM16 audio bytes.
    Audio(Vec<u8>),
    /// Text content.
    Text(String),
    /// User started speaking.
    ActivityStart,
    /// User stopped speaking.
    ActivityEnd,
}
34
35/// Intercepting wrapper around a SessionWriter.
36///
37/// Adds input fan-out, middleware hooks, and state tracking without
38/// introducing a second queue (avoids double-queuing).
39#[derive(Clone)]
40pub struct AgentSession {
41    /// The underlying wire-level session writer (Layer 0).
42    writer: Arc<dyn SessionWriter>,
43    /// Event subscription source.
44    event_tx: broadcast::Sender<SessionEvent>,
45    /// Fan-out for input-streaming tools.
46    /// Zero-cost when no tools are subscribed (receiver_count == 0).
47    input_broadcast: broadcast::Sender<InputEvent>,
48    /// Conversation state container.
49    state: State,
50}
51
52impl AgentSession {
53    /// Create a new AgentSession wrapping a SessionHandle.
54    pub fn new(session: SessionHandle) -> Self {
55        let (input_broadcast, _) = broadcast::channel(256);
56        let event_tx = session.event_sender().clone();
57        Self {
58            writer: Arc::new(session),
59            event_tx,
60            input_broadcast,
61            state: State::new(),
62        }
63    }
64
65    /// Create from a trait-object writer (enables mock testing and middleware injection).
66    pub fn from_writer(
67        writer: Arc<dyn SessionWriter>,
68        event_tx: broadcast::Sender<SessionEvent>,
69    ) -> Self {
70        let (input_broadcast, _) = broadcast::channel(256);
71        Self {
72            writer,
73            event_tx,
74            input_broadcast,
75            state: State::new(),
76        }
77    }
78
79    /// Send audio data. Fans out to input-streaming tools ONLY if listeners exist.
80    pub async fn send_audio(&self, data: Vec<u8>) -> Result<(), AgentError> {
81        // Fan-out ONLY if input-streaming tools are listening
82        if self.input_broadcast.receiver_count() > 0 {
83            let _ = self.input_broadcast.send(InputEvent::Audio(data.clone()));
84        }
85        // Forward directly to Layer 0 (ONE hop to WebSocket)
86        self.writer
87            .send_audio(data)
88            .await
89            .map_err(AgentError::Session)
90    }
91
92    /// Send a text message.
93    pub async fn send_text(&self, text: impl Into<String>) -> Result<(), AgentError> {
94        let t = text.into();
95        if self.input_broadcast.receiver_count() > 0 {
96            let _ = self.input_broadcast.send(InputEvent::Text(t.clone()));
97        }
98        self.writer.send_text(t).await.map_err(AgentError::Session)
99    }
100
101    /// Send tool responses.
102    pub async fn send_tool_response(
103        &self,
104        responses: Vec<FunctionResponse>,
105    ) -> Result<(), AgentError> {
106        self.writer
107            .send_tool_response(responses)
108            .await
109            .map_err(AgentError::Session)
110    }
111
112    /// Send client content (conversation history or context injection).
113    pub async fn send_client_content(
114        &self,
115        turns: Vec<Content>,
116        turn_complete: bool,
117    ) -> Result<(), AgentError> {
118        self.writer
119            .send_client_content(turns, turn_complete)
120            .await
121            .map_err(AgentError::Session)
122    }
123
124    /// Send video/image data (raw JPEG bytes).
125    pub async fn send_video(&self, jpeg_data: Vec<u8>) -> Result<(), AgentError> {
126        self.writer
127            .send_video(jpeg_data)
128            .await
129            .map_err(AgentError::Session)
130    }
131
132    /// Update the system instruction mid-session.
133    pub async fn update_instruction(
134        &self,
135        instruction: impl Into<String>,
136    ) -> Result<(), AgentError> {
137        self.writer
138            .update_instruction(instruction.into())
139            .await
140            .map_err(AgentError::Session)
141    }
142
143    /// Signal activity start (user started speaking).
144    pub async fn signal_activity_start(&self) -> Result<(), AgentError> {
145        if self.input_broadcast.receiver_count() > 0 {
146            let _ = self.input_broadcast.send(InputEvent::ActivityStart);
147        }
148        self.writer
149            .signal_activity_start()
150            .await
151            .map_err(AgentError::Session)
152    }
153
154    /// Signal activity end (user stopped speaking).
155    pub async fn signal_activity_end(&self) -> Result<(), AgentError> {
156        if self.input_broadcast.receiver_count() > 0 {
157            let _ = self.input_broadcast.send(InputEvent::ActivityEnd);
158        }
159        self.writer
160            .signal_activity_end()
161            .await
162            .map_err(AgentError::Session)
163    }
164
165    /// Gracefully disconnect.
166    pub async fn disconnect(&self) -> Result<(), AgentError> {
167        self.writer.disconnect().await.map_err(AgentError::Session)
168    }
169
170    /// Subscribe to input events (for input-streaming tools).
171    pub fn subscribe_input(&self) -> broadcast::Receiver<InputEvent> {
172        self.input_broadcast.subscribe()
173    }
174
175    /// Subscribe to session events.
176    pub fn subscribe_events(&self) -> broadcast::Receiver<SessionEvent> {
177        self.event_tx.subscribe()
178    }
179
180    /// Access the underlying session writer.
181    pub fn writer(&self) -> &dyn SessionWriter {
182        &*self.writer
183    }
184
185    /// Access conversation state.
186    pub fn state(&self) -> &State {
187        &self.state
188    }
189
190    /// Number of input-streaming subscribers (for diagnostics).
191    pub fn input_subscriber_count(&self) -> usize {
192        self.input_broadcast.receiver_count()
193    }
194}
195
/// A SessionWriter that discards all writes.
/// Used for isolated agent execution (AgentTool) where no real WebSocket exists.
///
/// The type is a stateless unit struct, so it derives the common traits:
/// `Debug` for diagnostics, `Clone`/`Copy` because there is nothing to
/// duplicate, and `Default` so it can be built generically.
#[derive(Debug, Clone, Copy, Default)]
pub struct NoOpSessionWriter;
199
200#[async_trait::async_trait]
201impl SessionWriter for NoOpSessionWriter {
202    async fn send_audio(&self, _data: Vec<u8>) -> Result<(), SessionError> {
203        Ok(())
204    }
205    async fn send_text(&self, _text: String) -> Result<(), SessionError> {
206        Ok(())
207    }
208    async fn send_tool_response(
209        &self,
210        _responses: Vec<FunctionResponse>,
211    ) -> Result<(), SessionError> {
212        Ok(())
213    }
214    async fn send_client_content(
215        &self,
216        _turns: Vec<Content>,
217        _turn_complete: bool,
218    ) -> Result<(), SessionError> {
219        Ok(())
220    }
221    async fn send_video(&self, _jpeg_data: Vec<u8>) -> Result<(), SessionError> {
222        Ok(())
223    }
224    async fn update_instruction(&self, _instruction: String) -> Result<(), SessionError> {
225        Ok(())
226    }
227    async fn signal_activity_start(&self) -> Result<(), SessionError> {
228        Ok(())
229    }
230    async fn signal_activity_end(&self) -> Result<(), SessionError> {
231        Ok(())
232    }
233    async fn disconnect(&self) -> Result<(), SessionError> {
234        Ok(())
235    }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use rs_genai::session::{SessionHandle, SessionPhase, SessionState};
    use std::sync::Arc;
    use tokio::sync::broadcast::error::TryRecvError;
    use tokio::sync::{broadcast, mpsc, watch};

    /// Builds a `SessionHandle` with no transport behind it.
    ///
    /// The command receiver is dropped when this function returns, so writes
    /// through the handle are presumably rejected; tests using it therefore
    /// ignore the send result and only inspect the fan-out side.
    fn mock_session_handle() -> SessionHandle {
        let (cmd_tx, _cmd_rx) = mpsc::channel(16);
        let (evt_tx, _) = broadcast::channel(16);
        let (phase_tx, phase_rx) = watch::channel(SessionPhase::Disconnected);
        let state = Arc::new(SessionState::new(phase_tx));
        SessionHandle::new(cmd_tx, evt_tx, state, phase_rx)
    }

    /// Builds an `AgentSession` over `NoOpSessionWriter`, whose writes always
    /// succeed, so tests can assert on the send result as well.
    fn noop_session() -> AgentSession {
        let (event_tx, _) = broadcast::channel(16);
        AgentSession::from_writer(Arc::new(NoOpSessionWriter), event_tx)
    }

    #[tokio::test]
    async fn send_audio_without_subscribers_no_broadcast() {
        let session = noop_session();
        assert_eq!(session.input_subscriber_count(), 0);

        // Exercise the no-subscriber path: the send must still reach the
        // writer and succeed, and must not create any subscriber.
        assert!(session.send_audio(vec![1, 2, 3, 4]).await.is_ok());
        assert_eq!(session.input_subscriber_count(), 0);

        // A tool subscribing afterwards must not find a stale event.
        let mut late_rx = session.subscribe_input();
        assert!(matches!(late_rx.try_recv(), Err(TryRecvError::Empty)));
    }

    #[tokio::test]
    async fn send_audio_with_subscriber_broadcasts() {
        let handle = mock_session_handle();
        let session = AgentSession::new(handle);
        let mut input_rx = session.subscribe_input();
        assert_eq!(session.input_subscriber_count(), 1);

        // send_audio will fail at SessionHandle level (no real WS), but
        // the broadcast should still fire
        let data = vec![1, 2, 3, 4];
        let _ = session.send_audio(data.clone()).await;

        match input_rx.try_recv() {
            Ok(InputEvent::Audio(received)) => assert_eq!(received, data),
            other => panic!("expected Audio, got {:?}", other),
        }
    }

    #[test]
    fn agent_session_is_clone() {
        let handle = mock_session_handle();
        let session = AgentSession::new(handle);
        let _clone = session.clone();
    }

    #[test]
    fn state_accessible() {
        let handle = mock_session_handle();
        let session = AgentSession::new(handle);
        session.state().set("key", "value");
        assert_eq!(
            session.state().get::<String>("key"),
            Some("value".to_string())
        );
    }

    #[tokio::test]
    async fn text_broadcast() {
        let handle = mock_session_handle();
        let session = AgentSession::new(handle);
        let mut input_rx = session.subscribe_input();

        let _ = session.send_text("hello").await;

        match input_rx.try_recv() {
            Ok(InputEvent::Text(t)) => assert_eq!(t, "hello"),
            other => panic!("expected Text, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn activity_signals_broadcast() {
        let handle = mock_session_handle();
        let session = AgentSession::new(handle);
        let mut input_rx = session.subscribe_input();

        let _ = session.signal_activity_start().await;
        let _ = session.signal_activity_end().await;

        assert!(matches!(input_rx.try_recv(), Ok(InputEvent::ActivityStart)));
        assert!(matches!(input_rx.try_recv(), Ok(InputEvent::ActivityEnd)));
    }

    #[tokio::test]
    async fn from_writer_with_mock() {
        // NoOpSessionWriter is a true mock: unlike a detached SessionHandle
        // its writes succeed, so the whole send path can be asserted.
        let session = noop_session();
        assert_eq!(session.input_subscriber_count(), 0);

        let mut input_rx = session.subscribe_input();
        assert_eq!(session.input_subscriber_count(), 1);

        assert!(session.send_text("hi").await.is_ok());
        match input_rx.try_recv() {
            Ok(InputEvent::Text(t)) => assert_eq!(t, "hi"),
            other => panic!("expected Text, got {:?}", other),
        }
    }
}