1use rs_genai::prelude::{Content, FunctionResponse};
14use rs_genai::session::{SessionError, SessionEvent, SessionHandle, SessionWriter};
15use std::sync::Arc;
16use tokio::sync::broadcast;
17
18use crate::error::AgentError;
19use crate::state::State;
20
/// A piece of client input mirrored to `AgentSession::subscribe_input` receivers.
///
/// Each variant corresponds to an `AgentSession` method that also forwards the
/// same input to the session writer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum InputEvent {
    /// Bytes passed to `AgentSession::send_audio`.
    Audio(Vec<u8>),
    /// Text passed to `AgentSession::send_text`.
    Text(String),
    /// Emitted by `AgentSession::signal_activity_start`.
    ActivityStart,
    /// Emitted by `AgentSession::signal_activity_end`.
    ActivityEnd,
}
34
35#[derive(Clone)]
40pub struct AgentSession {
41 writer: Arc<dyn SessionWriter>,
43 event_tx: broadcast::Sender<SessionEvent>,
45 input_broadcast: broadcast::Sender<InputEvent>,
48 state: State,
50}
51
52impl AgentSession {
53 pub fn new(session: SessionHandle) -> Self {
55 let (input_broadcast, _) = broadcast::channel(256);
56 let event_tx = session.event_sender().clone();
57 Self {
58 writer: Arc::new(session),
59 event_tx,
60 input_broadcast,
61 state: State::new(),
62 }
63 }
64
65 pub fn from_writer(
67 writer: Arc<dyn SessionWriter>,
68 event_tx: broadcast::Sender<SessionEvent>,
69 ) -> Self {
70 let (input_broadcast, _) = broadcast::channel(256);
71 Self {
72 writer,
73 event_tx,
74 input_broadcast,
75 state: State::new(),
76 }
77 }
78
79 pub async fn send_audio(&self, data: Vec<u8>) -> Result<(), AgentError> {
81 if self.input_broadcast.receiver_count() > 0 {
83 let _ = self.input_broadcast.send(InputEvent::Audio(data.clone()));
84 }
85 self.writer
87 .send_audio(data)
88 .await
89 .map_err(AgentError::Session)
90 }
91
92 pub async fn send_text(&self, text: impl Into<String>) -> Result<(), AgentError> {
94 let t = text.into();
95 if self.input_broadcast.receiver_count() > 0 {
96 let _ = self.input_broadcast.send(InputEvent::Text(t.clone()));
97 }
98 self.writer.send_text(t).await.map_err(AgentError::Session)
99 }
100
101 pub async fn send_tool_response(
103 &self,
104 responses: Vec<FunctionResponse>,
105 ) -> Result<(), AgentError> {
106 self.writer
107 .send_tool_response(responses)
108 .await
109 .map_err(AgentError::Session)
110 }
111
112 pub async fn send_client_content(
114 &self,
115 turns: Vec<Content>,
116 turn_complete: bool,
117 ) -> Result<(), AgentError> {
118 self.writer
119 .send_client_content(turns, turn_complete)
120 .await
121 .map_err(AgentError::Session)
122 }
123
124 pub async fn send_video(&self, jpeg_data: Vec<u8>) -> Result<(), AgentError> {
126 self.writer
127 .send_video(jpeg_data)
128 .await
129 .map_err(AgentError::Session)
130 }
131
132 pub async fn update_instruction(
134 &self,
135 instruction: impl Into<String>,
136 ) -> Result<(), AgentError> {
137 self.writer
138 .update_instruction(instruction.into())
139 .await
140 .map_err(AgentError::Session)
141 }
142
143 pub async fn signal_activity_start(&self) -> Result<(), AgentError> {
145 if self.input_broadcast.receiver_count() > 0 {
146 let _ = self.input_broadcast.send(InputEvent::ActivityStart);
147 }
148 self.writer
149 .signal_activity_start()
150 .await
151 .map_err(AgentError::Session)
152 }
153
154 pub async fn signal_activity_end(&self) -> Result<(), AgentError> {
156 if self.input_broadcast.receiver_count() > 0 {
157 let _ = self.input_broadcast.send(InputEvent::ActivityEnd);
158 }
159 self.writer
160 .signal_activity_end()
161 .await
162 .map_err(AgentError::Session)
163 }
164
165 pub async fn disconnect(&self) -> Result<(), AgentError> {
167 self.writer.disconnect().await.map_err(AgentError::Session)
168 }
169
170 pub fn subscribe_input(&self) -> broadcast::Receiver<InputEvent> {
172 self.input_broadcast.subscribe()
173 }
174
175 pub fn subscribe_events(&self) -> broadcast::Receiver<SessionEvent> {
177 self.event_tx.subscribe()
178 }
179
180 pub fn writer(&self) -> &dyn SessionWriter {
182 &*self.writer
183 }
184
185 pub fn state(&self) -> &State {
187 &self.state
188 }
189
190 pub fn input_subscriber_count(&self) -> usize {
192 self.input_broadcast.receiver_count()
193 }
194}
195
/// A [`SessionWriter`] that discards everything and always reports success.
///
/// Useful for building an `AgentSession` via `AgentSession::from_writer`
/// without a live connection.
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq)]
pub struct NoOpSessionWriter;
199
200#[async_trait::async_trait]
201impl SessionWriter for NoOpSessionWriter {
202 async fn send_audio(&self, _data: Vec<u8>) -> Result<(), SessionError> {
203 Ok(())
204 }
205 async fn send_text(&self, _text: String) -> Result<(), SessionError> {
206 Ok(())
207 }
208 async fn send_tool_response(
209 &self,
210 _responses: Vec<FunctionResponse>,
211 ) -> Result<(), SessionError> {
212 Ok(())
213 }
214 async fn send_client_content(
215 &self,
216 _turns: Vec<Content>,
217 _turn_complete: bool,
218 ) -> Result<(), SessionError> {
219 Ok(())
220 }
221 async fn send_video(&self, _jpeg_data: Vec<u8>) -> Result<(), SessionError> {
222 Ok(())
223 }
224 async fn update_instruction(&self, _instruction: String) -> Result<(), SessionError> {
225 Ok(())
226 }
227 async fn signal_activity_start(&self) -> Result<(), SessionError> {
228 Ok(())
229 }
230 async fn signal_activity_end(&self) -> Result<(), SessionError> {
231 Ok(())
232 }
233 async fn disconnect(&self) -> Result<(), SessionError> {
234 Ok(())
235 }
236}
237
#[cfg(test)]
mod tests {
    use super::*;
    use rs_genai::session::{SessionHandle, SessionPhase, SessionState};
    use std::sync::Arc;
    use tokio::sync::{broadcast, mpsc, watch};

    /// Builds a disconnected `SessionHandle` backed by in-memory channels.
    ///
    /// The command receiver is dropped when this function returns, so writer
    /// calls through the handle may fail; tests using it therefore ignore the
    /// writer's result and only inspect the input broadcast.
    fn mock_session_handle() -> SessionHandle {
        let (cmd_tx, _cmd_rx) = mpsc::channel(16);
        let (evt_tx, _) = broadcast::channel(16);
        let (phase_tx, phase_rx) = watch::channel(SessionPhase::Disconnected);
        let state = Arc::new(SessionState::new(phase_tx));
        SessionHandle::new(cmd_tx, evt_tx, state, phase_rx)
    }

    #[tokio::test]
    async fn send_audio_without_subscribers_no_broadcast() {
        let session = AgentSession::new(mock_session_handle());
        assert_eq!(session.input_subscriber_count(), 0);

        // Sending with nobody listening must neither panic nor create receivers.
        let _ = session.send_audio(vec![1, 2, 3]).await;
        assert_eq!(session.input_subscriber_count(), 0);

        // Broadcast receivers only see values sent after they subscribe, so a
        // late subscriber must find nothing queued.
        let mut late_rx = session.subscribe_input();
        assert!(late_rx.try_recv().is_err());
    }

    #[tokio::test]
    async fn send_audio_with_subscriber_broadcasts() {
        let handle = mock_session_handle();
        let session = AgentSession::new(handle);
        let mut input_rx = session.subscribe_input();
        assert_eq!(session.input_subscriber_count(), 1);

        let data = vec![1, 2, 3, 4];
        // The writer result is irrelevant here; only the broadcast is checked.
        let _ = session.send_audio(data.clone()).await;

        match input_rx.try_recv() {
            Ok(InputEvent::Audio(received)) => assert_eq!(received, data),
            other => panic!("expected Audio, got {:?}", other),
        }
    }

    #[test]
    fn agent_session_is_clone() {
        let handle = mock_session_handle();
        let session = AgentSession::new(handle);
        let _clone = session.clone();
    }

    #[tokio::test]
    async fn clones_share_input_broadcast() {
        let session = AgentSession::new(mock_session_handle());
        let clone = session.clone();
        let mut input_rx = session.subscribe_input();
        assert_eq!(clone.input_subscriber_count(), 1);

        let _ = clone.send_text("from clone").await;

        match input_rx.try_recv() {
            Ok(InputEvent::Text(t)) => assert_eq!(t, "from clone"),
            other => panic!("expected Text, got {:?}", other),
        }
    }

    #[test]
    fn state_accessible() {
        let handle = mock_session_handle();
        let session = AgentSession::new(handle);
        session.state().set("key", "value");
        assert_eq!(
            session.state().get::<String>("key"),
            Some("value".to_string())
        );
    }

    #[tokio::test]
    async fn text_broadcast() {
        let handle = mock_session_handle();
        let session = AgentSession::new(handle);
        let mut input_rx = session.subscribe_input();

        let _ = session.send_text("hello").await;

        match input_rx.try_recv() {
            Ok(InputEvent::Text(t)) => assert_eq!(t, "hello"),
            other => panic!("expected Text, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn activity_signals_broadcast() {
        let handle = mock_session_handle();
        let session = AgentSession::new(handle);
        let mut input_rx = session.subscribe_input();

        let _ = session.signal_activity_start().await;
        let _ = session.signal_activity_end().await;

        assert!(matches!(input_rx.try_recv(), Ok(InputEvent::ActivityStart)));
        assert!(matches!(input_rx.try_recv(), Ok(InputEvent::ActivityEnd)));
    }

    #[tokio::test]
    async fn from_writer_with_mock() {
        let handle = mock_session_handle();
        let event_tx = handle.event_sender().clone();
        let writer: Arc<dyn SessionWriter> = Arc::new(handle);
        let session = AgentSession::from_writer(writer, event_tx);
        assert_eq!(session.input_subscriber_count(), 0);
    }

    #[tokio::test]
    async fn noop_writer_accepts_every_call() {
        let (event_tx, _) = broadcast::channel(16);
        let writer: Arc<dyn SessionWriter> = Arc::new(NoOpSessionWriter);
        let session = AgentSession::from_writer(writer, event_tx);
        let mut input_rx = session.subscribe_input();

        assert!(session.send_audio(vec![9, 8]).await.is_ok());
        assert!(session.send_text("hi").await.is_ok());
        assert!(session.send_tool_response(Vec::new()).await.is_ok());
        assert!(session.send_client_content(Vec::new(), true).await.is_ok());
        assert!(session.send_video(Vec::new()).await.is_ok());
        assert!(session.update_instruction("be brief").await.is_ok());
        assert!(session.signal_activity_start().await.is_ok());
        assert!(session.signal_activity_end().await.is_ok());
        assert!(session.disconnect().await.is_ok());

        // Only audio, text and activity signals are mirrored, in call order.
        match input_rx.try_recv() {
            Ok(InputEvent::Audio(data)) => assert_eq!(data, vec![9, 8]),
            other => panic!("expected Audio, got {:?}", other),
        }
        match input_rx.try_recv() {
            Ok(InputEvent::Text(t)) => assert_eq!(t, "hi"),
            other => panic!("expected Text, got {:?}", other),
        }
        assert!(matches!(input_rx.try_recv(), Ok(InputEvent::ActivityStart)));
        assert!(matches!(input_rx.try_recv(), Ok(InputEvent::ActivityEnd)));
        assert!(input_rx.try_recv().is_err());
    }
}