use std::sync::Arc;
use std::time::Duration;
use bytes::Bytes;
use rs_genai::prelude::{FunctionCall, FunctionResponse, SessionPhase, UsageMetadata};
use rs_genai::session::SessionWriter;
use super::BoxFuture;
use crate::state::State;
/// Execution mode selected per callback through the `*_mode` fields of [`EventCallbacks`].
///
/// NOTE(review): the code that interprets these variants is not in this file. The variant
/// names suggest "wait for the callback to finish" versus "let it run alongside further
/// event processing" — confirm against the dispatcher before relying on that.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub enum CallbackMode {
    /// The mode returned by `CallbackMode::default()`.
    #[default]
    Blocking,
    /// The opt-in alternative to [`CallbackMode::Blocking`].
    Concurrent,
}
pub struct EventCallbacks {
pub on_audio: Option<Box<dyn Fn(&Bytes) + Send + Sync>>,
pub on_text: Option<Box<dyn Fn(&str) + Send + Sync>>,
pub on_text_complete: Option<Box<dyn Fn(&str) + Send + Sync>>,
pub on_input_transcript: Option<Box<dyn Fn(&str, bool) + Send + Sync>>,
pub on_output_transcript: Option<Box<dyn Fn(&str, bool) + Send + Sync>>,
pub on_thought: Option<Box<dyn Fn(&str) + Send + Sync>>,
pub on_vad_start: Option<Box<dyn Fn() + Send + Sync>>,
pub on_vad_end: Option<Box<dyn Fn() + Send + Sync>>,
pub on_phase: Option<Box<dyn Fn(SessionPhase) + Send + Sync>>,
pub on_usage: Option<Box<dyn Fn(&UsageMetadata) + Send + Sync>>,
pub on_interrupted: Option<Arc<dyn Fn() -> BoxFuture<()> + Send + Sync>>,
pub on_tool_call: Option<
Arc<
dyn Fn(Vec<FunctionCall>, State) -> BoxFuture<Option<Vec<FunctionResponse>>>
+ Send
+ Sync,
>,
>,
pub on_tool_cancelled: Option<Arc<dyn Fn(Vec<String>) -> BoxFuture<()> + Send + Sync>>,
pub on_turn_complete: Option<Arc<dyn Fn() -> BoxFuture<()> + Send + Sync>>,
pub on_go_away: Option<Arc<dyn Fn(Duration) -> BoxFuture<()> + Send + Sync>>,
pub on_connected: Option<Arc<dyn Fn(Arc<dyn SessionWriter>) -> BoxFuture<()> + Send + Sync>>,
pub on_disconnected: Option<Arc<dyn Fn(Option<String>) -> BoxFuture<()> + Send + Sync>>,
pub on_resumed: Option<Arc<dyn Fn() -> BoxFuture<()> + Send + Sync>>,
pub on_error: Option<Arc<dyn Fn(String) -> BoxFuture<()> + Send + Sync>>,
pub on_transfer: Option<Arc<dyn Fn(String, String) -> BoxFuture<()> + Send + Sync>>,
pub on_extracted: Option<Arc<dyn Fn(String, serde_json::Value) -> BoxFuture<()> + Send + Sync>>,
pub on_extraction_error: Option<Arc<dyn Fn(String, String) -> BoxFuture<()> + Send + Sync>>,
pub on_turn_complete_mode: CallbackMode,
pub on_connected_mode: CallbackMode,
pub on_disconnected_mode: CallbackMode,
pub on_error_mode: CallbackMode,
pub on_go_away_mode: CallbackMode,
pub on_extracted_mode: CallbackMode,
pub on_extraction_error_mode: CallbackMode,
pub on_tool_cancelled_mode: CallbackMode,
pub on_transfer_mode: CallbackMode,
pub on_resumed_mode: CallbackMode,
pub before_tool_response: Option<
Arc<dyn Fn(Vec<FunctionResponse>, State) -> BoxFuture<Vec<FunctionResponse>> + Send + Sync>,
>,
pub on_turn_boundary:
Option<Arc<dyn Fn(State, Arc<dyn SessionWriter>) -> BoxFuture<()> + Send + Sync>>,
pub instruction_template: Option<Arc<dyn Fn(&State) -> Option<String> + Send + Sync>>,
pub instruction_amendment: Option<Arc<dyn Fn(&State) -> Option<String> + Send + Sync>>,
}
impl Default for EventCallbacks {
    /// Builds a registry with no handler set and every mode equal to
    /// [`CallbackMode::Blocking`].
    fn default() -> Self {
        // `CallbackMode` is `Copy`, so a single binding can seed all ten mode fields.
        let blocking = CallbackMode::Blocking;

        Self {
            // Boxed handlers returning `()`.
            on_audio: None,
            on_text: None,
            on_text_complete: None,
            on_input_transcript: None,
            on_output_transcript: None,
            on_thought: None,
            on_vad_start: None,
            on_vad_end: None,
            on_phase: None,
            on_usage: None,

            // `Arc`-shared handlers returning a future.
            on_interrupted: None,
            on_tool_call: None,
            on_tool_cancelled: None,
            on_turn_complete: None,
            on_go_away: None,
            on_connected: None,
            on_disconnected: None,
            on_resumed: None,
            on_error: None,
            on_transfer: None,
            on_extracted: None,
            on_extraction_error: None,

            // Modes.
            on_turn_complete_mode: blocking,
            on_connected_mode: blocking,
            on_disconnected_mode: blocking,
            on_error_mode: blocking,
            on_go_away_mode: blocking,
            on_extracted_mode: blocking,
            on_extraction_error_mode: blocking,
            on_tool_cancelled_mode: blocking,
            on_transfer_mode: blocking,
            on_resumed_mode: blocking,

            // Interceptors and instruction hooks.
            before_tool_response: None,
            on_turn_boundary: None,
            instruction_template: None,
            instruction_amendment: None,
        }
    }
}
impl std::fmt::Debug for EventCallbacks {
    /// Renders the registry as a struct named `EventCallbacks`.
    ///
    /// The stored closures are not printed. Each handler field is shown as a `bool` saying
    /// whether it is set (`is_some()`), and each `*_mode` field is shown with its own
    /// `Debug` output. Fields appear in declaration order.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut out = f.debug_struct("EventCallbacks");

        // Handlers declared ahead of the mode fields.
        let leading_handlers = [
            ("on_audio", self.on_audio.is_some()),
            ("on_text", self.on_text.is_some()),
            ("on_text_complete", self.on_text_complete.is_some()),
            ("on_input_transcript", self.on_input_transcript.is_some()),
            ("on_output_transcript", self.on_output_transcript.is_some()),
            ("on_thought", self.on_thought.is_some()),
            ("on_vad_start", self.on_vad_start.is_some()),
            ("on_vad_end", self.on_vad_end.is_some()),
            ("on_phase", self.on_phase.is_some()),
            ("on_usage", self.on_usage.is_some()),
            ("on_interrupted", self.on_interrupted.is_some()),
            ("on_tool_call", self.on_tool_call.is_some()),
            ("on_tool_cancelled", self.on_tool_cancelled.is_some()),
            ("on_turn_complete", self.on_turn_complete.is_some()),
            ("on_go_away", self.on_go_away.is_some()),
            ("on_connected", self.on_connected.is_some()),
            ("on_disconnected", self.on_disconnected.is_some()),
            ("on_resumed", self.on_resumed.is_some()),
            ("on_error", self.on_error.is_some()),
            ("on_transfer", self.on_transfer.is_some()),
            ("on_extracted", self.on_extracted.is_some()),
            ("on_extraction_error", self.on_extraction_error.is_some()),
        ];
        for (name, registered) in leading_handlers {
            out.field(name, &registered);
        }

        // Mode fields; `CallbackMode` is `Copy`, so the values are copied into the table.
        let modes = [
            ("on_turn_complete_mode", self.on_turn_complete_mode),
            ("on_connected_mode", self.on_connected_mode),
            ("on_disconnected_mode", self.on_disconnected_mode),
            ("on_error_mode", self.on_error_mode),
            ("on_go_away_mode", self.on_go_away_mode),
            ("on_extracted_mode", self.on_extracted_mode),
            ("on_extraction_error_mode", self.on_extraction_error_mode),
            ("on_tool_cancelled_mode", self.on_tool_cancelled_mode),
            ("on_transfer_mode", self.on_transfer_mode),
            ("on_resumed_mode", self.on_resumed_mode),
        ];
        for (name, mode) in modes {
            out.field(name, &mode);
        }

        // Handlers declared after the mode fields.
        let trailing_handlers = [
            ("before_tool_response", self.before_tool_response.is_some()),
            ("on_turn_boundary", self.on_turn_boundary.is_some()),
            ("instruction_template", self.instruction_template.is_some()),
            ("instruction_amendment", self.instruction_amendment.is_some()),
        ];
        for (name, registered) in trailing_handlers {
            out.field(name, &registered);
        }

        out.finish()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// A default-constructed registry must have no handler set — checked for every
    /// handler field, so a field added to `Default` with a non-`None` value is caught.
    #[test]
    fn default_callbacks_all_none() {
        let cb = EventCallbacks::default();

        // Boxed handlers returning `()`.
        assert!(cb.on_audio.is_none());
        assert!(cb.on_text.is_none());
        assert!(cb.on_text_complete.is_none());
        assert!(cb.on_input_transcript.is_none());
        assert!(cb.on_output_transcript.is_none());
        assert!(cb.on_thought.is_none());
        assert!(cb.on_vad_start.is_none());
        assert!(cb.on_vad_end.is_none());
        assert!(cb.on_phase.is_none());
        assert!(cb.on_usage.is_none());

        // `Arc`-shared handlers returning a future.
        assert!(cb.on_interrupted.is_none());
        assert!(cb.on_tool_call.is_none());
        assert!(cb.on_tool_cancelled.is_none());
        assert!(cb.on_turn_complete.is_none());
        assert!(cb.on_go_away.is_none());
        assert!(cb.on_connected.is_none());
        assert!(cb.on_disconnected.is_none());
        assert!(cb.on_resumed.is_none());
        assert!(cb.on_error.is_none());
        assert!(cb.on_transfer.is_none());
        assert!(cb.on_extracted.is_none());
        assert!(cb.on_extraction_error.is_none());

        // Interceptors and instruction hooks.
        assert!(cb.before_tool_response.is_none());
        assert!(cb.on_turn_boundary.is_none());
        assert!(cb.instruction_template.is_none());
        assert!(cb.instruction_amendment.is_none());
    }

    /// A registered boxed handler can be invoked through the field.
    #[test]
    fn sync_callback_callable() {
        let called = Arc::new(AtomicBool::new(false));
        let flag = Arc::clone(&called);

        let cb = EventCallbacks {
            on_text: Some(Box::new(move |_text| {
                flag.store(true, Ordering::SeqCst);
            })),
            ..Default::default()
        };

        if let Some(handler) = &cb.on_text {
            handler("hello");
        }
        assert!(called.load(Ordering::SeqCst));
    }

    /// Every mode field defaults to `CallbackMode::Blocking`.
    #[test]
    fn callback_mode_defaults_to_blocking() {
        let cb = EventCallbacks::default();
        assert_eq!(cb.on_turn_complete_mode, CallbackMode::Blocking);
        assert_eq!(cb.on_connected_mode, CallbackMode::Blocking);
        assert_eq!(cb.on_disconnected_mode, CallbackMode::Blocking);
        assert_eq!(cb.on_error_mode, CallbackMode::Blocking);
        assert_eq!(cb.on_go_away_mode, CallbackMode::Blocking);
        assert_eq!(cb.on_extracted_mode, CallbackMode::Blocking);
        assert_eq!(cb.on_extraction_error_mode, CallbackMode::Blocking);
        assert_eq!(cb.on_tool_cancelled_mode, CallbackMode::Blocking);
        assert_eq!(cb.on_transfer_mode, CallbackMode::Blocking);
        assert_eq!(cb.on_resumed_mode, CallbackMode::Blocking);
    }

    /// The `Debug` output reports handlers as set / not set rather than printing closures.
    #[test]
    fn debug_shows_registered() {
        let cb = EventCallbacks {
            on_audio: Some(Box::new(|_| {})),
            ..Default::default()
        };

        let debug = format!("{:?}", cb);
        assert!(debug.contains("on_audio: true"));
        assert!(debug.contains("on_text: false"));
    }
}