use rs_genai::prelude::{Content, FunctionResponse};
use rs_genai::session::{SessionError, SessionEvent, SessionHandle, SessionWriter};
use std::sync::Arc;
use tokio::sync::broadcast;
use crate::error::AgentError;
use crate::state::State;
/// An input pushed through an [`AgentSession`], as seen by input subscribers.
///
/// Copies are published to receivers obtained from
/// [`AgentSession::subscribe_input`], but only while at least one such
/// receiver exists.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum InputEvent {
    /// Audio bytes passed to [`AgentSession::send_audio`].
    Audio(Vec<u8>),
    /// Text passed to [`AgentSession::send_text`].
    Text(String),
    /// Published by [`AgentSession::signal_activity_start`].
    ActivityStart,
    /// Published by [`AgentSession::signal_activity_end`].
    ActivityEnd,
}
/// Handle for driving a live session from application code.
///
/// Outbound commands go to a [`SessionWriter`]. Session events are exposed
/// through the retained event sender. Audio, text and activity signals are
/// also mirrored onto an [`InputEvent`] broadcast channel. A [`State`] is
/// carried alongside.
#[derive(Clone)]
pub struct AgentSession {
    /// State object returned by `AgentSession::state`.
    state: State,
    /// Publishes copies of inputs to receivers from `subscribe_input`.
    input_broadcast: broadcast::Sender<InputEvent>,
    /// Sender half of the session event channel; `subscribe_events` draws
    /// receivers from it.
    event_tx: broadcast::Sender<SessionEvent>,
    /// Destination for every outbound command.
    writer: Arc<dyn SessionWriter>,
}
impl AgentSession {
pub fn new(session: SessionHandle) -> Self {
let (input_broadcast, _) = broadcast::channel(256);
let event_tx = session.event_sender().clone();
Self {
writer: Arc::new(session),
event_tx,
input_broadcast,
state: State::new(),
}
}
pub fn from_writer(
writer: Arc<dyn SessionWriter>,
event_tx: broadcast::Sender<SessionEvent>,
) -> Self {
let (input_broadcast, _) = broadcast::channel(256);
Self {
writer,
event_tx,
input_broadcast,
state: State::new(),
}
}
pub async fn send_audio(&self, data: Vec<u8>) -> Result<(), AgentError> {
if self.input_broadcast.receiver_count() > 0 {
let _ = self.input_broadcast.send(InputEvent::Audio(data.clone()));
}
self.writer
.send_audio(data)
.await
.map_err(AgentError::Session)
}
pub async fn send_text(&self, text: impl Into<String>) -> Result<(), AgentError> {
let t = text.into();
if self.input_broadcast.receiver_count() > 0 {
let _ = self.input_broadcast.send(InputEvent::Text(t.clone()));
}
self.writer.send_text(t).await.map_err(AgentError::Session)
}
pub async fn send_tool_response(
&self,
responses: Vec<FunctionResponse>,
) -> Result<(), AgentError> {
self.writer
.send_tool_response(responses)
.await
.map_err(AgentError::Session)
}
pub async fn send_client_content(
&self,
turns: Vec<Content>,
turn_complete: bool,
) -> Result<(), AgentError> {
self.writer
.send_client_content(turns, turn_complete)
.await
.map_err(AgentError::Session)
}
pub async fn send_video(&self, jpeg_data: Vec<u8>) -> Result<(), AgentError> {
self.writer
.send_video(jpeg_data)
.await
.map_err(AgentError::Session)
}
pub async fn update_instruction(
&self,
instruction: impl Into<String>,
) -> Result<(), AgentError> {
self.writer
.update_instruction(instruction.into())
.await
.map_err(AgentError::Session)
}
pub async fn signal_activity_start(&self) -> Result<(), AgentError> {
if self.input_broadcast.receiver_count() > 0 {
let _ = self.input_broadcast.send(InputEvent::ActivityStart);
}
self.writer
.signal_activity_start()
.await
.map_err(AgentError::Session)
}
pub async fn signal_activity_end(&self) -> Result<(), AgentError> {
if self.input_broadcast.receiver_count() > 0 {
let _ = self.input_broadcast.send(InputEvent::ActivityEnd);
}
self.writer
.signal_activity_end()
.await
.map_err(AgentError::Session)
}
pub async fn disconnect(&self) -> Result<(), AgentError> {
self.writer.disconnect().await.map_err(AgentError::Session)
}
pub fn subscribe_input(&self) -> broadcast::Receiver<InputEvent> {
self.input_broadcast.subscribe()
}
pub fn subscribe_events(&self) -> broadcast::Receiver<SessionEvent> {
self.event_tx.subscribe()
}
pub fn writer(&self) -> &dyn SessionWriter {
&*self.writer
}
pub fn state(&self) -> &State {
&self.state
}
pub fn input_subscriber_count(&self) -> usize {
self.input_broadcast.receiver_count()
}
}
/// A `SessionWriter` that ignores every command and always reports success.
///
/// Useful for building an `AgentSession` via `AgentSession::from_writer`
/// when no live connection is wanted (for example in tests).
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct NoOpSessionWriter;
// Every method discards its arguments and returns `Ok(())`.
#[async_trait::async_trait]
impl SessionWriter for NoOpSessionWriter {
    /// Discards the audio chunk.
    async fn send_audio(&self, _chunk: Vec<u8>) -> Result<(), SessionError> {
        Ok(())
    }

    /// Discards the video frame.
    async fn send_video(&self, _frame: Vec<u8>) -> Result<(), SessionError> {
        Ok(())
    }

    /// Discards the text message.
    async fn send_text(&self, _message: String) -> Result<(), SessionError> {
        Ok(())
    }

    /// Discards the conversation turns.
    async fn send_client_content(
        &self,
        _content: Vec<Content>,
        _complete: bool,
    ) -> Result<(), SessionError> {
        Ok(())
    }

    /// Discards the function responses.
    async fn send_tool_response(
        &self,
        _replies: Vec<FunctionResponse>,
    ) -> Result<(), SessionError> {
        Ok(())
    }

    /// Discards the new instruction.
    async fn update_instruction(&self, _text: String) -> Result<(), SessionError> {
        Ok(())
    }

    /// Does nothing.
    async fn signal_activity_start(&self) -> Result<(), SessionError> {
        Ok(())
    }

    /// Does nothing.
    async fn signal_activity_end(&self) -> Result<(), SessionError> {
        Ok(())
    }

    /// Does nothing.
    async fn disconnect(&self) -> Result<(), SessionError> {
        Ok(())
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rs_genai::session::{SessionHandle, SessionPhase, SessionState};
    use std::sync::Arc;
    use tokio::sync::{broadcast, mpsc, watch};

    /// Builds a `SessionHandle` wired to throwaway channels.
    ///
    /// The command receiver is dropped when this function returns, so writes
    /// through the handle may fail. Tests using it ignore the `Result` and
    /// only inspect the input broadcast.
    fn mock_session_handle() -> SessionHandle {
        let (cmd_tx, _cmd_rx) = mpsc::channel(16);
        let (evt_tx, _) = broadcast::channel(16);
        let (phase_tx, phase_rx) = watch::channel(SessionPhase::Disconnected);
        let state = Arc::new(SessionState::new(phase_tx));
        SessionHandle::new(cmd_tx, evt_tx, state, phase_rx)
    }

    /// Builds an `AgentSession` backed by `NoOpSessionWriter`, whose writes
    /// always succeed.
    fn noop_session() -> AgentSession {
        let (event_tx, _) = broadcast::channel(16);
        AgentSession::from_writer(Arc::new(NoOpSessionWriter), event_tx)
    }

    #[tokio::test]
    async fn send_audio_without_subscribers_no_broadcast() {
        let session = noop_session();
        assert_eq!(session.input_subscriber_count(), 0);
        // Actually exercise the no-subscriber path; it must still reach the writer.
        assert!(session.send_audio(vec![1, 2, 3, 4]).await.is_ok());
        assert_eq!(session.input_subscriber_count(), 0);
        // A receiver created afterwards must not observe the earlier audio.
        let mut late_rx = session.subscribe_input();
        assert!(matches!(
            late_rx.try_recv(),
            Err(broadcast::error::TryRecvError::Empty)
        ));
    }

    #[tokio::test]
    async fn send_audio_with_subscriber_broadcasts() {
        let handle = mock_session_handle();
        let session = AgentSession::new(handle);
        let mut input_rx = session.subscribe_input();
        assert_eq!(session.input_subscriber_count(), 1);
        let data = vec![1, 2, 3, 4];
        let _ = session.send_audio(data.clone()).await;
        match input_rx.try_recv() {
            Ok(InputEvent::Audio(received)) => assert_eq!(received, data),
            other => panic!("expected Audio, got {:?}", other),
        }
    }

    #[test]
    fn agent_session_is_clone() {
        let handle = mock_session_handle();
        let session = AgentSession::new(handle);
        let _clone = session.clone();
    }

    #[test]
    fn state_accessible() {
        let handle = mock_session_handle();
        let session = AgentSession::new(handle);
        session.state().set("key", "value");
        assert_eq!(
            session.state().get::<String>("key"),
            Some("value".to_string())
        );
    }

    #[tokio::test]
    async fn text_broadcast() {
        let handle = mock_session_handle();
        let session = AgentSession::new(handle);
        let mut input_rx = session.subscribe_input();
        let _ = session.send_text("hello").await;
        match input_rx.try_recv() {
            Ok(InputEvent::Text(t)) => assert_eq!(t, "hello"),
            other => panic!("expected Text, got {:?}", other),
        }
    }

    #[tokio::test]
    async fn activity_signals_broadcast() {
        let handle = mock_session_handle();
        let session = AgentSession::new(handle);
        let mut input_rx = session.subscribe_input();
        let _ = session.signal_activity_start().await;
        let _ = session.signal_activity_end().await;
        assert!(matches!(input_rx.try_recv(), Ok(InputEvent::ActivityStart)));
        assert!(matches!(input_rx.try_recv(), Ok(InputEvent::ActivityEnd)));
    }

    #[tokio::test]
    async fn from_writer_with_mock() {
        let handle = mock_session_handle();
        let event_tx = handle.event_sender().clone();
        let writer: Arc<dyn SessionWriter> = Arc::new(handle);
        let session = AgentSession::from_writer(writer, event_tx);
        assert_eq!(session.input_subscriber_count(), 0);
    }

    #[tokio::test]
    async fn noop_writer_accepts_every_command() {
        let session = noop_session();
        let mut input_rx = session.subscribe_input();
        assert!(session.send_audio(vec![0]).await.is_ok());
        assert!(session.send_text("hi").await.is_ok());
        assert!(session.send_tool_response(Vec::new()).await.is_ok());
        assert!(session.send_client_content(Vec::new(), true).await.is_ok());
        assert!(session.send_video(vec![0xFF, 0xD8]).await.is_ok());
        assert!(session.update_instruction("be brief").await.is_ok());
        assert!(session.signal_activity_start().await.is_ok());
        assert!(session.signal_activity_end().await.is_ok());
        assert!(session.disconnect().await.is_ok());
        // Only audio, text and activity signals are mirrored, in send order.
        assert!(matches!(input_rx.try_recv(), Ok(InputEvent::Audio(ref a)) if a == &vec![0u8]));
        assert!(matches!(input_rx.try_recv(), Ok(InputEvent::Text(ref t)) if t == "hi"));
        assert!(matches!(input_rx.try_recv(), Ok(InputEvent::ActivityStart)));
        assert!(matches!(input_rx.try_recv(), Ok(InputEvent::ActivityEnd)));
        assert!(matches!(
            input_rx.try_recv(),
            Err(broadcast::error::TryRecvError::Empty)
        ));
    }
}