1use std::sync::Arc;
26use std::time::Duration;
27
28use bytes::Bytes;
29use rs_genai::prelude::{FunctionCall, FunctionResponse, SessionPhase, UsageMetadata};
30use rs_genai::session::SessionWriter;
31
32use super::BoxFuture;
33use crate::state::State;
34
/// Selects how a callback is dispatched.
///
/// The type is a plain `Copy` tag, so it also derives `Hash` and can be used
/// as a map/set key or embedded in other `Hash`-deriving types.
///
/// NOTE(review): the code that interprets these variants is not part of this
/// definition; the per-variant descriptions are inferred from the variant
/// names and should be confirmed against the dispatcher.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum CallbackMode {
    /// The default mode. Presumably the caller waits for the callback to
    /// finish before it continues — TODO confirm.
    #[default]
    Blocking,
    /// Presumably the callback runs without holding up the caller — TODO
    /// confirm.
    Concurrent,
}
57
/// Registry of optional, user-supplied callbacks: one slot per event kind.
///
/// Every callback slot is an `Option` and stays `None` until a handler is
/// stored in it. Two callback shapes appear:
///
/// * **Synchronous** slots are `Box<dyn Fn(..) + Send + Sync>` closures that
///   return nothing.
/// * **Asynchronous** slots are `Arc<dyn Fn(..) -> BoxFuture<..> + Send + Sync>`
///   closures that take owned arguments and return a boxed future.
///
/// The `*_mode` fields each hold a [`CallbackMode`].
/// NOTE(review): presumably each mode governs the like-named async slot; the
/// dispatcher that reads them is not visible here — confirm.
pub struct EventCallbacks {
    /// Sync callback receiving a borrowed [`Bytes`] buffer.
    /// Presumably a chunk of audio output — TODO confirm format with the caller.
    pub on_audio: Option<Box<dyn Fn(&Bytes) + Send + Sync>>,
    /// Sync callback receiving a text slice.
    /// Presumably an incremental piece of text output — TODO confirm.
    pub on_text: Option<Box<dyn Fn(&str) + Send + Sync>>,
    /// Sync callback receiving a text slice.
    /// Presumably the complete text once a response ends — TODO confirm.
    pub on_text_complete: Option<Box<dyn Fn(&str) + Send + Sync>>,
    /// Sync callback receiving a text slice and a flag.
    /// Presumably input transcript text plus a "final" marker — TODO confirm the `bool`.
    pub on_input_transcript: Option<Box<dyn Fn(&str, bool) + Send + Sync>>,
    /// Sync callback receiving a text slice and a flag.
    /// Presumably output transcript text plus a "final" marker — TODO confirm the `bool`.
    pub on_output_transcript: Option<Box<dyn Fn(&str, bool) + Send + Sync>>,
    /// Sync callback receiving a text slice.
    /// Presumably model "thought" text — TODO confirm.
    pub on_thought: Option<Box<dyn Fn(&str) + Send + Sync>>,
    /// Sync callback with no arguments.
    /// Presumably fired when voice activity starts — TODO confirm.
    pub on_vad_start: Option<Box<dyn Fn() + Send + Sync>>,
    /// Sync callback with no arguments.
    /// Presumably fired when voice activity ends — TODO confirm.
    pub on_vad_end: Option<Box<dyn Fn() + Send + Sync>>,
    /// Sync callback receiving a [`SessionPhase`] by value.
    pub on_phase: Option<Box<dyn Fn(SessionPhase) + Send + Sync>>,
    /// Sync callback receiving borrowed [`UsageMetadata`].
    pub on_usage: Option<Box<dyn Fn(&UsageMetadata) + Send + Sync>>,

    /// Async callback with no arguments. No `*_mode` field exists for this slot.
    pub on_interrupted: Option<Arc<dyn Fn() -> BoxFuture<()> + Send + Sync>>,
    /// Async callback receiving the requested [`FunctionCall`]s and a [`State`];
    /// its future resolves to `Option<Vec<FunctionResponse>>`.
    /// NOTE(review): what a `None` result means is decided by the caller — confirm.
    pub on_tool_call: Option<
        Arc<
            dyn Fn(Vec<FunctionCall>, State) -> BoxFuture<Option<Vec<FunctionResponse>>>
                + Send
                + Sync,
        >,
    >,
    /// Async callback receiving a list of strings.
    /// Presumably the ids of cancelled tool calls — TODO confirm.
    pub on_tool_cancelled: Option<Arc<dyn Fn(Vec<String>) -> BoxFuture<()> + Send + Sync>>,
    /// Async callback with no arguments.
    pub on_turn_complete: Option<Arc<dyn Fn() -> BoxFuture<()> + Send + Sync>>,
    /// Async callback receiving a [`Duration`].
    /// Presumably the time left before the session goes away — TODO confirm.
    pub on_go_away: Option<Arc<dyn Fn(Duration) -> BoxFuture<()> + Send + Sync>>,
    /// Async callback receiving a shared [`SessionWriter`] handle.
    pub on_connected: Option<Arc<dyn Fn(Arc<dyn SessionWriter>) -> BoxFuture<()> + Send + Sync>>,
    /// Async callback receiving an optional string.
    /// Presumably the disconnect reason, when one is known — TODO confirm.
    pub on_disconnected: Option<Arc<dyn Fn(Option<String>) -> BoxFuture<()> + Send + Sync>>,
    /// Async callback with no arguments.
    pub on_resumed: Option<Arc<dyn Fn() -> BoxFuture<()> + Send + Sync>>,
    /// Async callback receiving a string. Presumably an error message — TODO confirm.
    pub on_error: Option<Arc<dyn Fn(String) -> BoxFuture<()> + Send + Sync>>,
    /// Async callback receiving two strings.
    /// Presumably the source and target of a transfer — TODO confirm order and meaning.
    pub on_transfer: Option<Arc<dyn Fn(String, String) -> BoxFuture<()> + Send + Sync>>,
    /// Async callback receiving a string and a JSON value.
    /// Presumably an extractor name and its extracted value — TODO confirm.
    pub on_extracted: Option<Arc<dyn Fn(String, serde_json::Value) -> BoxFuture<()> + Send + Sync>>,
    /// Async callback receiving two strings.
    /// Presumably an extractor name and an error message — TODO confirm.
    pub on_extraction_error: Option<Arc<dyn Fn(String, String) -> BoxFuture<()> + Send + Sync>>,

    /// [`CallbackMode`] paired by name with `on_turn_complete`.
    pub on_turn_complete_mode: CallbackMode,
    /// [`CallbackMode`] paired by name with `on_connected`.
    pub on_connected_mode: CallbackMode,
    /// [`CallbackMode`] paired by name with `on_disconnected`.
    pub on_disconnected_mode: CallbackMode,
    /// [`CallbackMode`] paired by name with `on_error`.
    pub on_error_mode: CallbackMode,
    /// [`CallbackMode`] paired by name with `on_go_away`.
    pub on_go_away_mode: CallbackMode,
    /// [`CallbackMode`] paired by name with `on_extracted`.
    pub on_extracted_mode: CallbackMode,
    /// [`CallbackMode`] paired by name with `on_extraction_error`.
    pub on_extraction_error_mode: CallbackMode,
    /// [`CallbackMode`] paired by name with `on_tool_cancelled`.
    pub on_tool_cancelled_mode: CallbackMode,
    /// [`CallbackMode`] paired by name with `on_transfer`.
    pub on_transfer_mode: CallbackMode,
    /// [`CallbackMode`] paired by name with `on_resumed`.
    pub on_resumed_mode: CallbackMode,

    /// Async hook that takes a list of [`FunctionResponse`]s plus a [`State`]
    /// and resolves to a list of [`FunctionResponse`]s.
    /// Presumably lets the caller rewrite tool responses before they are sent — TODO confirm.
    pub before_tool_response: Option<
        Arc<dyn Fn(Vec<FunctionResponse>, State) -> BoxFuture<Vec<FunctionResponse>> + Send + Sync>,
    >,

    /// Async hook receiving a [`State`] and a shared [`SessionWriter`] handle.
    /// NOTE(review): when exactly a "turn boundary" fires is not visible here — confirm.
    pub on_turn_boundary:
        Option<Arc<dyn Fn(State, Arc<dyn SessionWriter>) -> BoxFuture<()> + Send + Sync>>,

    /// Synchronous function from a borrowed [`State`] to an optional string.
    /// Presumably produces a full instruction text from state — TODO confirm.
    pub instruction_template: Option<Arc<dyn Fn(&State) -> Option<String> + Send + Sync>>,

    /// Synchronous function from a borrowed [`State`] to an optional string.
    /// Presumably produces text added to an existing instruction — TODO confirm.
    pub instruction_amendment: Option<Arc<dyn Fn(&State) -> Option<String> + Send + Sync>>,
}
185
186impl Default for EventCallbacks {
187 fn default() -> Self {
188 Self {
189 on_audio: None,
190 on_text: None,
191 on_text_complete: None,
192 on_input_transcript: None,
193 on_output_transcript: None,
194 on_thought: None,
195 on_vad_start: None,
196 on_vad_end: None,
197 on_phase: None,
198 on_usage: None,
199 on_interrupted: None,
200 on_tool_call: None,
201 on_tool_cancelled: None,
202 on_turn_complete: None,
203 on_go_away: None,
204 on_connected: None,
205 on_disconnected: None,
206 on_resumed: None,
207 on_error: None,
208 on_transfer: None,
209 on_extracted: None,
210 on_extraction_error: None,
211 on_turn_complete_mode: CallbackMode::Blocking,
212 on_connected_mode: CallbackMode::Blocking,
213 on_disconnected_mode: CallbackMode::Blocking,
214 on_error_mode: CallbackMode::Blocking,
215 on_go_away_mode: CallbackMode::Blocking,
216 on_extracted_mode: CallbackMode::Blocking,
217 on_extraction_error_mode: CallbackMode::Blocking,
218 on_tool_cancelled_mode: CallbackMode::Blocking,
219 on_transfer_mode: CallbackMode::Blocking,
220 on_resumed_mode: CallbackMode::Blocking,
221 before_tool_response: None,
222 on_turn_boundary: None,
223 instruction_template: None,
224 instruction_amendment: None,
225 }
226 }
227}
228
229impl std::fmt::Debug for EventCallbacks {
230 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
231 f.debug_struct("EventCallbacks")
232 .field("on_audio", &self.on_audio.is_some())
233 .field("on_text", &self.on_text.is_some())
234 .field("on_text_complete", &self.on_text_complete.is_some())
235 .field("on_input_transcript", &self.on_input_transcript.is_some())
236 .field("on_output_transcript", &self.on_output_transcript.is_some())
237 .field("on_thought", &self.on_thought.is_some())
238 .field("on_vad_start", &self.on_vad_start.is_some())
239 .field("on_vad_end", &self.on_vad_end.is_some())
240 .field("on_phase", &self.on_phase.is_some())
241 .field("on_usage", &self.on_usage.is_some())
242 .field("on_interrupted", &self.on_interrupted.is_some())
243 .field("on_tool_call", &self.on_tool_call.is_some())
244 .field("on_tool_cancelled", &self.on_tool_cancelled.is_some())
245 .field("on_turn_complete", &self.on_turn_complete.is_some())
246 .field("on_go_away", &self.on_go_away.is_some())
247 .field("on_connected", &self.on_connected.is_some())
248 .field("on_disconnected", &self.on_disconnected.is_some())
249 .field("on_resumed", &self.on_resumed.is_some())
250 .field("on_error", &self.on_error.is_some())
251 .field("on_transfer", &self.on_transfer.is_some())
252 .field("on_extracted", &self.on_extracted.is_some())
253 .field("on_extraction_error", &self.on_extraction_error.is_some())
254 .field("on_turn_complete_mode", &self.on_turn_complete_mode)
255 .field("on_connected_mode", &self.on_connected_mode)
256 .field("on_disconnected_mode", &self.on_disconnected_mode)
257 .field("on_error_mode", &self.on_error_mode)
258 .field("on_go_away_mode", &self.on_go_away_mode)
259 .field("on_extracted_mode", &self.on_extracted_mode)
260 .field("on_extraction_error_mode", &self.on_extraction_error_mode)
261 .field("on_tool_cancelled_mode", &self.on_tool_cancelled_mode)
262 .field("on_transfer_mode", &self.on_transfer_mode)
263 .field("on_resumed_mode", &self.on_resumed_mode)
264 .field("before_tool_response", &self.before_tool_response.is_some())
265 .field("on_turn_boundary", &self.on_turn_boundary.is_some())
266 .field("instruction_template", &self.instruction_template.is_some())
267 .field(
268 "instruction_amendment",
269 &self.instruction_amendment.is_some(),
270 )
271 .finish()
272 }
273}
274
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicBool, Ordering};

    /// A default registry must not have a handler in any callback slot.
    #[test]
    fn default_callbacks_all_none() {
        let cb = EventCallbacks::default();

        // Synchronous slots.
        assert!(cb.on_audio.is_none());
        assert!(cb.on_text.is_none());
        assert!(cb.on_text_complete.is_none());
        assert!(cb.on_input_transcript.is_none());
        assert!(cb.on_output_transcript.is_none());
        assert!(cb.on_thought.is_none());
        assert!(cb.on_vad_start.is_none());
        assert!(cb.on_vad_end.is_none());
        assert!(cb.on_phase.is_none());
        assert!(cb.on_usage.is_none());

        // Asynchronous slots.
        assert!(cb.on_interrupted.is_none());
        assert!(cb.on_tool_call.is_none());
        assert!(cb.on_tool_cancelled.is_none());
        assert!(cb.on_turn_complete.is_none());
        assert!(cb.on_go_away.is_none());
        assert!(cb.on_connected.is_none());
        assert!(cb.on_disconnected.is_none());
        assert!(cb.on_resumed.is_none());
        assert!(cb.on_error.is_none());
        assert!(cb.on_transfer.is_none());
        assert!(cb.on_extracted.is_none());
        assert!(cb.on_extraction_error.is_none());

        // Interceptors and instruction hooks.
        assert!(cb.before_tool_response.is_none());
        assert!(cb.on_turn_boundary.is_none());
        assert!(cb.instruction_template.is_none());
        assert!(cb.instruction_amendment.is_none());
    }

    /// A handler stored in a sync slot can be invoked through the slot.
    #[test]
    fn sync_callback_callable() {
        let called = Arc::new(AtomicBool::new(false));
        let called_clone = Arc::clone(&called);
        // Struct-update syntax instead of mutating a freshly built default.
        let cb = EventCallbacks {
            on_text: Some(Box::new(move |_text: &str| {
                called_clone.store(true, Ordering::SeqCst);
            })),
            ..Default::default()
        };

        if let Some(f) = &cb.on_text {
            f("hello");
        }
        assert!(called.load(Ordering::SeqCst));
    }

    /// Every `*_mode` field starts out as `CallbackMode::Blocking`.
    #[test]
    fn callback_mode_defaults_to_blocking() {
        let cb = EventCallbacks::default();
        assert_eq!(cb.on_turn_complete_mode, CallbackMode::Blocking);
        assert_eq!(cb.on_connected_mode, CallbackMode::Blocking);
        assert_eq!(cb.on_disconnected_mode, CallbackMode::Blocking);
        assert_eq!(cb.on_error_mode, CallbackMode::Blocking);
        assert_eq!(cb.on_go_away_mode, CallbackMode::Blocking);
        assert_eq!(cb.on_extracted_mode, CallbackMode::Blocking);
        assert_eq!(cb.on_extraction_error_mode, CallbackMode::Blocking);
        assert_eq!(cb.on_tool_cancelled_mode, CallbackMode::Blocking);
        assert_eq!(cb.on_transfer_mode, CallbackMode::Blocking);
        assert_eq!(cb.on_resumed_mode, CallbackMode::Blocking);
    }

    /// `Debug` output reports each slot as a registered/unregistered flag.
    #[test]
    fn debug_shows_registered() {
        let cb = EventCallbacks {
            on_audio: Some(Box::new(|_: &Bytes| {})),
            ..Default::default()
        };
        let debug = format!("{cb:?}");
        assert!(debug.contains("on_audio: true"));
        assert!(debug.contains("on_text: false"));
    }
}