1use crate::ReferOption;
2use crate::call::Command;
3use crate::event::SessionEvent;
4use anyhow::{Result, anyhow};
5use async_trait::async_trait;
6use futures::{Stream, StreamExt};
7use once_cell::sync::Lazy;
8use regex::Regex;
9use reqwest::Client;
10use serde::{Deserialize, Serialize};
11use serde_json::json;
12use std::collections::HashMap;
13use std::pin::Pin;
14use std::sync::Arc;
15use tracing::{info, warn};
16
// Regexes for the inline XML action tags the LLM is instructed to emit
// (see `LlmHandler::build_system_prompt`); compiled once, on first use.
static RE_HANGUP: Lazy<Regex> = Lazy::new(|| Regex::new(r"<hangup\s*/>").unwrap());
static RE_REFER: Lazy<Regex> = Lazy::new(|| Regex::new(r#"<refer\s+to="([^"]+)"\s*/>"#).unwrap());
static RE_PLAY: Lazy<Regex> = Lazy::new(|| Regex::new(r#"<play\s+file="([^"]+)"\s*/>"#).unwrap());
static RE_GOTO: Lazy<Regex> = Lazy::new(|| Regex::new(r#"<goto\s+scene="([^"]+)"\s*/>"#).unwrap());
// Sentence-boundary detector (ASCII + CJK punctuation, plus newline) used to
// flush streamed LLM output to TTS one sentence at a time.
static RE_SENTENCE: Lazy<Regex> = Lazy::new(|| Regex::new(r"(?m)[.!?。!?\n]\s*").unwrap());
22static FILLERS: Lazy<std::collections::HashSet<String>> = Lazy::new(|| {
23 let mut s = std::collections::HashSet::new();
24 let default_fillers = ["嗯", "啊", "哦", "那个", "那个...", "uh", "um", "ah"];
25
26 if let Ok(content) = std::fs::read_to_string("config/fillers.txt") {
27 for line in content.lines() {
28 let trimmed = line.trim().to_lowercase();
29 if !trimmed.is_empty() {
30 s.insert(trimmed);
31 }
32 }
33 }
34
35 if s.is_empty() {
36 for f in default_fillers {
37 s.insert(f.to_string());
38 }
39 }
40 s
41});
42
43use super::ChatMessage;
44use super::InterruptionStrategy;
45use super::LlmConfig;
46use super::dialogue::DialogueHandler;
47
/// Upper bound on LLM round-trips triggered by `rag` tool calls within a
/// single `interpret_response` pass, to avoid unbounded retrieve→ask loops.
const MAX_RAG_ATTEMPTS: usize = 3;
49
/// Incremental output from a streaming LLM call.
#[derive(Debug, Clone)]
pub enum LlmStreamEvent {
    /// A chunk of the assistant's answer text.
    Content(String),
    /// A chunk of reasoning text (`reasoning_content` in OpenAI-compatible
    /// APIs); collected for diagnostics, never spoken.
    Reasoning(String),
}
55
/// Abstraction over a chat-completion backend.
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// One-shot completion: returns the full assistant message text.
    async fn call(&self, config: &LlmConfig, history: &[ChatMessage]) -> Result<String>;
    /// Streaming completion: yields content/reasoning deltas as they arrive.
    async fn call_stream(
        &self,
        config: &LlmConfig,
        history: &[ChatMessage],
    ) -> Result<Pin<Box<dyn Stream<Item = Result<LlmStreamEvent>> + Send>>>;
}
65
/// One incremental frame from a realtime (speech-to-speech) provider.
pub struct RealtimeResponse {
    /// Raw synthesized-audio chunk, when present.
    pub audio_delta: Option<Vec<u8>>,
    /// Transcript/text chunk, when present.
    pub text_delta: Option<String>,
    /// Tool/function call requested by the model, if any.
    pub function_call: Option<ToolInvocation>,
    // NOTE(review): presumably set when the provider detects the start of
    // user speech — confirm against the provider implementation.
    pub speech_started: bool,
}
72
/// Abstraction over a bidirectional realtime voice provider.
#[async_trait]
pub trait RealtimeProvider: Send + Sync {
    /// Open the provider connection using `config`.
    async fn connect(&self, config: &LlmConfig) -> Result<()>;
    /// Push a chunk of 16-bit audio samples (presumably PCM — confirm).
    async fn send_audio(&self, audio: &[i16]) -> Result<()>;
    /// Subscribe to the provider's response stream.
    async fn subscribe(
        &self,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<RealtimeResponse>> + Send>>>;
}
81
/// Default `LlmProvider` speaking the OpenAI-compatible HTTP API.
struct DefaultLlmProvider {
    client: Client,
}
85
86impl DefaultLlmProvider {
87 fn new() -> Self {
88 Self {
89 client: Client::new(),
90 }
91 }
92}
93
94#[async_trait]
95impl LlmProvider for DefaultLlmProvider {
96 async fn call(&self, config: &LlmConfig, history: &[ChatMessage]) -> Result<String> {
97 let mut url = config
98 .base_url
99 .clone()
100 .unwrap_or_else(|| "https://api.openai.com/v1".to_string());
101 let model = config
102 .model
103 .clone()
104 .unwrap_or_else(|| "gpt-4-turbo".to_string());
105 let api_key = config.api_key.clone().unwrap_or_default();
106
107 if !url.ends_with("/chat/completions") {
108 url = format!("{}/chat/completions", url.trim_end_matches('/'));
109 }
110
111 let body = json!({
112 "model": model,
113 "messages": history,
114 });
115
116 let res = self
117 .client
118 .post(&url)
119 .header("Authorization", format!("Bearer {}", api_key))
120 .json(&body)
121 .send()
122 .await?;
123
124 if !res.status().is_success() {
125 return Err(anyhow!("LLM request failed: {}", res.status()));
126 }
127
128 let json: serde_json::Value = res.json().await?;
129 let content = json["choices"][0]["message"]["content"]
130 .as_str()
131 .ok_or_else(|| anyhow!("Invalid LLM response"))?
132 .to_string();
133
134 Ok(content)
135 }
136
137 async fn call_stream(
138 &self,
139 config: &LlmConfig,
140 history: &[ChatMessage],
141 ) -> Result<Pin<Box<dyn Stream<Item = Result<LlmStreamEvent>> + Send>>> {
142 let mut url = config
143 .base_url
144 .clone()
145 .unwrap_or_else(|| "https://api.openai.com/v1".to_string());
146 let model = config
147 .model
148 .clone()
149 .unwrap_or_else(|| "gpt-4-turbo".to_string());
150 let api_key = config.api_key.clone().unwrap_or_default();
151
152 if !url.ends_with("/chat/completions") {
153 url = format!("{}/chat/completions", url.trim_end_matches('/'));
154 }
155
156 let body = json!({
157 "model": model,
158 "messages": history,
159 "stream": true,
160 });
161
162 let res = self
163 .client
164 .post(&url)
165 .header("Authorization", format!("Bearer {}", api_key))
166 .json(&body)
167 .send()
168 .await?;
169
170 if !res.status().is_success() {
171 return Err(anyhow!("LLM request failed: {}", res.status()));
172 }
173
174 let stream = res.bytes_stream();
175 let s = async_stream::stream! {
176 let mut buffer = String::new();
177 for await chunk in stream {
178 match chunk {
179 Ok(bytes) => {
180 let text = String::from_utf8_lossy(&bytes);
181 buffer.push_str(&text);
182
183 while let Some(line_end) = buffer.find('\n') {
184 let line = buffer[..line_end].trim();
185 if line.starts_with("data:") {
186 let data = &line[5..].trim();
187 if *data == "[DONE]" {
188 break;
189 }
190 if let Ok(json) = serde_json::from_str::<serde_json::Value>(data) {
191 if let Some(delta) = json["choices"][0].get("delta") {
192 if let Some(thinking) = delta.get("reasoning_content").and_then(|v| v.as_str()) {
193 yield Ok(LlmStreamEvent::Reasoning(thinking.to_string()));
194 }
195 if let Some(content) = delta.get("content").and_then(|v| v.as_str()) {
196 yield Ok(LlmStreamEvent::Content(content.to_string()));
197 }
198 }
199 }
200 }
201 buffer.drain(..=line_end);
202 }
203 }
204 Err(e) => yield Err(anyhow!(e)),
205 }
206 }
207 };
208
209 Ok(Box::pin(s))
210 }
211}
212
/// Retrieval hook used by the `rag` tool to fetch grounding context.
#[async_trait]
pub trait RagRetriever: Send + Sync {
    /// Return retrieved context text for `query`.
    async fn retrieve(&self, query: &str) -> Result<String>;
}
217
/// `RagRetriever` that returns no context; used when RAG is not configured.
struct NoopRagRetriever;
219
220#[async_trait]
221impl RagRetriever for NoopRagRetriever {
222 async fn retrieve(&self, _query: &str) -> Result<String> {
223 Ok(String::new())
224 }
225}
226
/// Shape of a structured (JSON) LLM reply.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct StructuredResponse {
    /// Text to speak, if any.
    text: Option<String>,
    // Override for the post-playback input timeout; presumably milliseconds
    // (create_tts_command defaults to 10000) — confirm.
    wait_input_timeout: Option<u32>,
    /// Tool calls requested by the model.
    tools: Option<Vec<ToolInvocation>>,
}
234
/// Tool call embedded in a structured LLM reply. Internally tagged on
/// `"name"`, so the JSON looks like `{"name": "hangup", ...}` with the
/// variant name lower-cased.
#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(tag = "name", rename_all = "lowercase")]
pub enum ToolInvocation {
    /// Terminate the call.
    #[serde(rename_all = "camelCase")]
    Hangup {
        reason: Option<String>,
        initiator: Option<String>,
    },
    /// Transfer the call to `callee`.
    #[serde(rename_all = "camelCase")]
    Refer {
        caller: String,
        callee: String,
        options: Option<ReferOption>,
    },
    /// Retrieve grounding context for `query`, then re-ask the LLM.
    #[serde(rename_all = "camelCase")]
    Rag {
        query: String,
        source: Option<String>,
    },
    /// Answer an incoming call.
    #[serde(rename_all = "camelCase")]
    Accept { options: Option<crate::CallOption> },
    /// Decline an incoming call.
    #[serde(rename_all = "camelCase")]
    Reject {
        reason: Option<String>,
        code: Option<u32>,
    },
}
262
/// Text-pipeline dialogue handler: drives an `LlmProvider` from session
/// events (ASR, DTMF, silence, …) and turns its streamed output into call
/// `Command`s.
pub struct LlmHandler {
    config: LlmConfig,
    interruption_config: super::InterruptionConfig,
    /// Fallback follow-up policy when the active scene defines none.
    global_follow_up_config: Option<super::FollowUpConfig>,
    /// Global DTMF keymap; scene-level maps take precedence.
    dtmf_config: Option<HashMap<String, super::DtmfAction>>,
    /// Chat transcript; slot 0 holds the system prompt.
    history: Vec<ChatMessage>,
    provider: Arc<dyn LlmProvider>,
    rag_retriever: Arc<dyn RagRetriever>,
    /// True while TTS playback is believed to be in progress.
    is_speaking: bool,
    /// True once a hangup has been scheduled; suppresses interruption.
    is_hanging_up: bool,
    /// Consecutive silence-triggered follow-ups since the last user input.
    consecutive_follow_ups: u32,
    last_interaction_at: std::time::Instant,
    event_sender: Option<crate::event::EventSender>,
    /// When the last final ASR result arrived (debounces barge-in).
    last_asr_final_at: Option<std::time::Instant>,
    /// When the current TTS playback started (barge-in grace period).
    last_tts_start_at: Option<std::time::Instant>,
    /// Live call handle; lets streamed commands be enqueued immediately.
    call: Option<crate::call::ActiveCallRef>,
    scenes: HashMap<String, super::Scene>,
    current_scene_id: Option<String>,
}
282
283impl LlmHandler {
    /// Build a handler with the default OpenAI-compatible provider and a
    /// no-op RAG retriever.
    pub fn new(
        config: LlmConfig,
        interruption: super::InterruptionConfig,
        global_follow_up_config: Option<super::FollowUpConfig>,
        scenes: HashMap<String, super::Scene>,
        dtmf: Option<HashMap<String, super::DtmfAction>>,
        initial_scene_id: Option<String>,
    ) -> Self {
        Self::with_provider(
            config,
            Arc::new(DefaultLlmProvider::new()),
            Arc::new(NoopRagRetriever),
            interruption,
            global_follow_up_config,
            scenes,
            dtmf,
            initial_scene_id,
        )
    }
303
    /// Build a handler with explicit provider and retriever implementations
    /// (primarily for tests and custom backends).
    pub fn with_provider(
        config: LlmConfig,
        provider: Arc<dyn LlmProvider>,
        rag_retriever: Arc<dyn RagRetriever>,
        interruption: super::InterruptionConfig,
        global_follow_up_config: Option<super::FollowUpConfig>,
        scenes: HashMap<String, super::Scene>,
        dtmf: Option<HashMap<String, super::DtmfAction>>,
        initial_scene_id: Option<String>,
    ) -> Self {
        let mut history = Vec::new();
        let prompt = config.prompt.clone().unwrap_or_default();
        let system_prompt = Self::build_system_prompt(&prompt);

        // The system prompt always occupies slot 0 of the history.
        history.push(ChatMessage {
            role: "system".to_string(),
            content: system_prompt,
        });

        Self {
            config,
            interruption_config: interruption,
            global_follow_up_config,
            dtmf_config: dtmf,
            history,
            provider,
            rag_retriever,
            is_speaking: false,
            is_hanging_up: false,
            consecutive_follow_ups: 0,
            last_interaction_at: std::time::Instant::now(),
            event_sender: None,
            last_asr_final_at: None,
            last_tts_start_at: None,
            call: None,
            scenes,
            current_scene_id: initial_scene_id,
        }
    }
343
344 fn build_system_prompt(prompt: &str) -> String {
345 format!(
346 "{}\n\n\
347 Tool usage instructions:\n\
348 - To hang up the call, output: <hangup/>\n\
349 - To transfer the call, output: <refer to=\"sip:xxxx\"/>\n\
350 - To play an audio file, output: <play file=\"path/to/file.wav\"/>\n\
351 - To switch to another scene, output: <goto scene=\"scene_id\"/>\n\
352 Please use these XML tags for action triggers. They are optimized for streaming playback. \
353 Output your response in short sentences. Each sentence will be played as soon as it is finished.",
354 prompt
355 )
356 }
357
358 fn get_dtmf_action(&self, digit: &str) -> Option<super::DtmfAction> {
359 if let Some(scene_id) = &self.current_scene_id {
361 if let Some(scene) = self.scenes.get(scene_id) {
362 if let Some(dtmf) = &scene.dtmf {
363 if let Some(action) = dtmf.get(digit) {
364 return Some(action.clone());
365 }
366 }
367 }
368 }
369
370 if let Some(dtmf) = &self.dtmf_config {
372 if let Some(action) = dtmf.get(digit) {
373 return Some(action.clone());
374 }
375 }
376
377 None
378 }
379
380 async fn handle_dtmf_action(&mut self, action: super::DtmfAction) -> Result<Vec<Command>> {
381 match action {
382 super::DtmfAction::Goto { scene } => {
383 info!("DTMF action: switch to scene {}", scene);
384 self.switch_to_scene(&scene, true).await
385 }
386 super::DtmfAction::Transfer { target } => {
387 info!("DTMF action: transfer to {}", target);
388 Ok(vec![Command::Refer {
389 caller: String::new(),
390 callee: target,
391 options: None,
392 }])
393 }
394 super::DtmfAction::Hangup => {
395 info!("DTMF action: hangup");
396 Ok(vec![Command::Hangup {
397 reason: Some("DTMF Hangup".to_string()),
398 initiator: Some("ai".to_string()),
399 }])
400 }
401 }
402 }
403
404 async fn switch_to_scene(
405 &mut self,
406 scene_id: &str,
407 trigger_response: bool,
408 ) -> Result<Vec<Command>> {
409 if let Some(scene) = self.scenes.get(scene_id).cloned() {
410 info!("Switching to scene: {}", scene_id);
411 self.current_scene_id = Some(scene_id.to_string());
412 let system_prompt = Self::build_system_prompt(&scene.prompt);
414 if let Some(first_msg) = self.history.get_mut(0) {
415 if first_msg.role == "system" {
416 first_msg.content = system_prompt;
417 }
418 }
419
420 let mut commands = Vec::new();
421 if let Some(url) = &scene.play {
422 commands.push(Command::Play {
423 url: url.clone(),
424 play_id: None,
425 auto_hangup: None,
426 wait_input_timeout: None,
427 });
428 }
429
430 if trigger_response {
431 let response_cmds = self.generate_response().await?;
432 commands.extend(response_cmds);
433 }
434 Ok(commands)
435 } else {
436 warn!("Scene not found: {}", scene_id);
437 Ok(vec![])
438 }
439 }
440
    /// Read-only view of the chat history (system prompt included).
    pub fn get_history_ref(&self) -> &[ChatMessage] {
        &self.history
    }

    /// Currently active scene id, if any.
    pub fn get_current_scene_id(&self) -> Option<String> {
        self.current_scene_id.clone()
    }

    /// Attach the live call so streamed commands can be enqueued directly.
    pub fn set_call(&mut self, call: crate::call::ActiveCallRef) {
        self.call = Some(call);
    }
452
453 pub fn set_event_sender(&mut self, sender: crate::event::EventSender) {
454 self.event_sender = Some(sender.clone());
455 if let Some(greeting) = &self.config.greeting {
457 let _ = sender.send(crate::event::SessionEvent::AddHistory {
458 sender: Some("system".to_string()),
459 timestamp: crate::media::get_timestamp(),
460 speaker: "assistant".to_string(),
461 text: greeting.clone(),
462 });
463 }
464 }
465
466 fn send_debug_event(&self, key: &str, data: serde_json::Value) {
467 if let Some(sender) = &self.event_sender {
468 let timestamp = crate::media::get_timestamp();
469 if key == "llm_response" {
471 if let Some(text) = data.get("response").and_then(|v| v.as_str()) {
472 let _ = sender.send(crate::event::SessionEvent::AddHistory {
473 sender: Some("llm".to_string()),
474 timestamp,
475 speaker: "assistant".to_string(),
476 text: text.to_string(),
477 });
478 }
479 }
480
481 let event = crate::event::SessionEvent::Metrics {
482 timestamp,
483 key: key.to_string(),
484 duration: 0,
485 data,
486 };
487 let _ = sender.send(event);
488 }
489 }
490
    /// Non-streaming LLM completion over the current history.
    async fn call_llm(&self) -> Result<String> {
        self.provider.call(&self.config, &self.history).await
    }
494
495 fn create_tts_command(
496 &self,
497 text: String,
498 wait_input_timeout: Option<u32>,
499 auto_hangup: Option<bool>,
500 ) -> Command {
501 let timeout = wait_input_timeout.unwrap_or(10000);
502 let play_id = uuid::Uuid::new_v4().to_string();
503
504 if let Some(sender) = &self.event_sender {
505 let _ = sender.send(crate::event::SessionEvent::Metrics {
506 timestamp: crate::media::get_timestamp(),
507 key: "tts_play_id_map".to_string(),
508 duration: 0,
509 data: serde_json::json!({
510 "playId": play_id,
511 "text": text,
512 }),
513 });
514 }
515
516 Command::Tts {
517 text,
518 speaker: None,
519 play_id: Some(play_id),
520 auto_hangup,
521 streaming: None,
522 end_of_stream: Some(true),
523 option: None,
524 wait_input_timeout: Some(timeout),
525 base64: None,
526 }
527 }
528
    /// Run one streaming LLM turn over the current history.
    ///
    /// Plain-text replies are flushed sentence-by-sentence to TTS as they
    /// stream in (directly onto the live call when one is attached);
    /// replies opening with `{` or a backtick are buffered whole and handed
    /// to `interpret_response` as structured JSON.
    async fn generate_response(&mut self) -> Result<Vec<Command>> {
        let start_time = crate::media::get_timestamp();
        // One play id groups every TTS segment of this turn.
        let play_id = uuid::Uuid::new_v4().to_string();

        self.send_debug_event(
            "llm_call_start",
            json!({
                "history_length": self.history.len(),
                "playId": play_id,
            }),
        );

        let mut stream = self
            .provider
            .call_stream(&self.config, &self.history)
            .await?;

        let mut full_content = String::new();
        let mut full_reasoning = String::new();
        let mut buffer = String::new();
        let mut commands = Vec::new();
        let mut is_json_mode = false;
        let mut checked_json_mode = false;
        let mut first_token_time = None;

        while let Some(chunk_result) = stream.next().await {
            let event = match chunk_result {
                Ok(c) => c,
                Err(e) => {
                    // Keep whatever arrived so far instead of failing the turn.
                    warn!("LLM stream error: {}", e);
                    break;
                }
            };

            match event {
                LlmStreamEvent::Reasoning(text) => {
                    full_reasoning.push_str(&text);
                }
                LlmStreamEvent::Content(chunk) => {
                    // Time-to-first-byte is measured at the first non-blank token.
                    if first_token_time.is_none() && !chunk.trim().is_empty() {
                        first_token_time = Some(crate::media::get_timestamp());
                    }

                    full_content.push_str(&chunk);
                    buffer.push_str(&chunk);

                    // Decide once, on the first non-empty content, whether
                    // this reply is structured JSON or streamable text.
                    if !checked_json_mode {
                        let trimmed = full_content.trim();
                        if !trimmed.is_empty() {
                            if trimmed.starts_with('{') || trimmed.starts_with('`') {
                                is_json_mode = true;
                            }
                            checked_json_mode = true;
                        }
                    }

                    if checked_json_mode && !is_json_mode {
                        let extracted =
                            self.extract_streaming_commands(&mut buffer, &play_id, false);
                        for cmd in extracted {
                            if let Some(call) = &self.call {
                                // Fire immediately on the live call.
                                let _ = call.enqueue_command(cmd).await;
                            } else {
                                commands.push(cmd);
                            }
                        }
                    }
                }
            }
        }

        let end_time = crate::media::get_timestamp();
        self.send_debug_event(
            "llm_response",
            json!({
                "response": full_content,
                "reasoning": full_reasoning,
                "is_json_mode": is_json_mode,
                "duration": end_time - start_time,
                "ttfb": first_token_time.map(|t| t - start_time).unwrap_or(0),
                "playId": play_id,
            }),
        );

        if is_json_mode {
            self.interpret_response(full_content).await
        } else {
            // Flush the tail and mark the final TTS segment of the turn.
            let extracted = self.extract_streaming_commands(&mut buffer, &play_id, true);
            for cmd in extracted {
                if let Some(call) = &self.call {
                    let _ = call.enqueue_command(cmd).await;
                } else {
                    commands.push(cmd);
                }
            }
            if !full_content.trim().is_empty() {
                self.history.push(ChatMessage {
                    role: "assistant".to_string(),
                    content: full_content,
                });
                self.is_speaking = true;
                self.last_tts_start_at = Some(std::time::Instant::now());
            }
            Ok(commands)
        }
    }
637
638 fn extract_streaming_commands(
639 &mut self,
640 buffer: &mut String,
641 play_id: &str,
642 is_final: bool,
643 ) -> Vec<Command> {
644 let mut commands = Vec::new();
645
646 loop {
647 let hangup_pos = RE_HANGUP.find(buffer);
648 let refer_pos = RE_REFER.captures(buffer);
649 let play_pos = RE_PLAY.captures(buffer);
650 let goto_pos = RE_GOTO.captures(buffer);
651 let sentence_pos = RE_SENTENCE.find(buffer);
652
653 let mut positions = Vec::new();
655 if let Some(m) = hangup_pos {
656 positions.push((m.start(), 0));
657 }
658 if let Some(caps) = &refer_pos {
659 positions.push((caps.get(0).unwrap().start(), 1));
660 }
661 if let Some(caps) = &play_pos {
662 positions.push((caps.get(0).unwrap().start(), 3));
663 }
664 if let Some(caps) = &goto_pos {
665 positions.push((caps.get(0).unwrap().start(), 4));
666 }
667 if let Some(m) = sentence_pos {
668 positions.push((m.start(), 2));
669 }
670
671 positions.sort_by_key(|p| p.0);
672
673 if let Some((pos, kind)) = positions.first() {
674 let pos = *pos;
675 match kind {
676 0 => {
677 let prefix = buffer[..pos].to_string();
679 if !prefix.trim().is_empty() {
680 let mut cmd = self.create_tts_command_with_id(
681 prefix,
682 play_id.to_string(),
683 Some(true),
684 );
685 if let Command::Tts { end_of_stream, .. } = &mut cmd {
686 *end_of_stream = Some(true);
687 }
688 self.is_hanging_up = true;
690 commands.push(cmd);
691 } else {
692 let mut cmd = self.create_tts_command_with_id(
694 "".to_string(),
695 play_id.to_string(),
696 Some(true),
697 );
698 if let Command::Tts { end_of_stream, .. } = &mut cmd {
699 *end_of_stream = Some(true);
700 }
701 self.is_hanging_up = true;
702 commands.push(cmd);
703 }
704 buffer.drain(..RE_HANGUP.find(buffer).unwrap().end());
705 return commands;
707 }
708 1 => {
709 let caps = RE_REFER.captures(buffer).unwrap();
711 let mat = caps.get(0).unwrap();
712 let callee = caps.get(1).unwrap().as_str().to_string();
713
714 let prefix = buffer[..pos].to_string();
715 if !prefix.trim().is_empty() {
716 commands.push(self.create_tts_command_with_id(
717 prefix,
718 play_id.to_string(),
719 None,
720 ));
721 }
722 commands.push(Command::Refer {
723 caller: String::new(),
724 callee,
725 options: None,
726 });
727 buffer.drain(..mat.end());
728 }
729 3 => {
730 let caps = RE_PLAY.captures(buffer).unwrap();
732 let mat = caps.get(0).unwrap();
733 let url = caps.get(1).unwrap().as_str().to_string();
734
735 let prefix = buffer[..pos].to_string();
736 if !prefix.trim().is_empty() {
737 commands.push(self.create_tts_command_with_id(
738 prefix,
739 play_id.to_string(),
740 None,
741 ));
742 }
743 commands.push(Command::Play {
744 url,
745 play_id: None,
746 auto_hangup: None,
747 wait_input_timeout: None,
748 });
749 buffer.drain(..mat.end());
750 }
751 4 => {
752 let caps = RE_GOTO.captures(buffer).unwrap();
754 let mat = caps.get(0).unwrap();
755 let scene_id = caps.get(1).unwrap().as_str().to_string();
756
757 let prefix = buffer[..pos].to_string();
758 if !prefix.trim().is_empty() {
759 commands.push(self.create_tts_command_with_id(
760 prefix,
761 play_id.to_string(),
762 None,
763 ));
764 }
765
766 info!("Switching to scene (from stream): {}", scene_id);
767 if let Some(scene) = self.scenes.get(&scene_id) {
768 self.current_scene_id = Some(scene_id);
769 let system_prompt = Self::build_system_prompt(&scene.prompt);
771 if let Some(first_msg) = self.history.get_mut(0) {
772 if first_msg.role == "system" {
773 first_msg.content = system_prompt;
774 }
775 }
776 } else {
777 warn!("Scene not found: {}", scene_id);
778 }
779
780 buffer.drain(..mat.end());
781 }
782 2 => {
783 let mat = sentence_pos.unwrap();
785 let sentence = buffer[..mat.end()].to_string();
786 if !sentence.trim().is_empty() {
787 commands.push(self.create_tts_command_with_id(
788 sentence,
789 play_id.to_string(),
790 None,
791 ));
792 }
793 buffer.drain(..mat.end());
794 }
795 _ => unreachable!(),
796 }
797 } else {
798 break;
799 }
800 }
801
802 if is_final {
803 let remaining = buffer.trim().to_string();
804 if !remaining.is_empty() {
805 commands.push(self.create_tts_command_with_id(
806 remaining,
807 play_id.to_string(),
808 None,
809 ));
810 }
811 buffer.clear();
812
813 if let Some(last) = commands.last_mut() {
814 if let Command::Tts { end_of_stream, .. } = last {
815 *end_of_stream = Some(true);
816 }
817 } else if !self.is_hanging_up {
818 commands.push(Command::Tts {
819 text: "".to_string(),
820 speaker: None,
821 play_id: Some(play_id.to_string()),
822 auto_hangup: None,
823 streaming: Some(true),
824 end_of_stream: Some(true),
825 option: None,
826 wait_input_timeout: None,
827 base64: None,
828 });
829 }
830 }
831
832 commands
833 }
834
    /// Build a streaming TTS segment bound to an existing `play_id`.
    ///
    /// `end_of_stream` is deliberately left unset; callers mark the final
    /// segment of a turn themselves (see `extract_streaming_commands`).
    fn create_tts_command_with_id(
        &self,
        text: String,
        play_id: String,
        auto_hangup: Option<bool>,
    ) -> Command {
        Command::Tts {
            text,
            speaker: None,
            play_id: Some(play_id),
            auto_hangup,
            streaming: Some(true),
            end_of_stream: None,
            option: None,
            wait_input_timeout: Some(10000),
            base64: None,
        }
    }
853
    /// Interpret a structured (JSON) LLM reply: execute its tool calls and
    /// turn its text into a TTS command.
    ///
    /// `rag` tool calls append their result to the history and re-ask the
    /// LLM, up to `MAX_RAG_ATTEMPTS` rounds. A `hangup` tool is folded into
    /// the TTS command (`auto_hangup`) when there is text to speak, so the
    /// call ends after playback rather than immediately.
    async fn interpret_response(&mut self, initial: String) -> Result<Vec<Command>> {
        let mut tool_commands = Vec::new();
        let mut wait_input_timeout = None;
        let mut attempts = 0;
        let final_text: Option<String>;
        let mut raw = initial;

        loop {
            attempts += 1;
            let mut rerun_for_rag = false;

            if let Some(structured) = parse_structured_response(&raw) {
                // Only the first round's timeout override is honored.
                if wait_input_timeout.is_none() {
                    wait_input_timeout = structured.wait_input_timeout;
                }

                if let Some(tools) = structured.tools {
                    for tool in tools {
                        match tool {
                            ToolInvocation::Hangup {
                                ref reason,
                                ref initiator,
                            } => {
                                self.send_debug_event(
                                    "tool_invocation",
                                    json!({
                                        "tool": "Hangup",
                                        "params": {
                                            "reason": reason,
                                            "initiator": initiator,
                                        }
                                    }),
                                );
                                tool_commands.push(Command::Hangup {
                                    reason: reason.clone(),
                                    initiator: initiator.clone(),
                                });
                            }
                            ToolInvocation::Refer {
                                ref caller,
                                ref callee,
                                ref options,
                            } => {
                                self.send_debug_event(
                                    "tool_invocation",
                                    json!({
                                        "tool": "Refer",
                                        "params": {
                                            "caller": caller,
                                            "callee": callee,
                                        }
                                    }),
                                );
                                tool_commands.push(Command::Refer {
                                    caller: caller.clone(),
                                    callee: callee.clone(),
                                    options: options.clone(),
                                });
                            }
                            ToolInvocation::Rag {
                                ref query,
                                ref source,
                            } => {
                                self.send_debug_event(
                                    "tool_invocation",
                                    json!({
                                        "tool": "Rag",
                                        "params": {
                                            "query": query,
                                            "source": source,
                                        }
                                    }),
                                );

                                // NOTE(review): a retrieval failure aborts the
                                // whole response via `?` — confirm intended.
                                let rag_result = self.rag_retriever.retrieve(&query).await?;

                                self.send_debug_event(
                                    "rag_result",
                                    json!({
                                        "query": query,
                                        "result": rag_result,
                                    }),
                                );

                                let summary = if let Some(source) = source {
                                    format!("[{}] {}", source, rag_result)
                                } else {
                                    rag_result
                                };
                                // Feed the context back and re-ask the model.
                                self.history.push(ChatMessage {
                                    role: "system".to_string(),
                                    content: format!("RAG result for {}: {}", query, summary),
                                });
                                rerun_for_rag = true;
                            }
                            ToolInvocation::Accept { ref options } => {
                                self.send_debug_event(
                                    "tool_invocation",
                                    json!({
                                        "tool": "Accept",
                                    }),
                                );
                                tool_commands.push(Command::Accept {
                                    option: options.clone().unwrap_or_default(),
                                });
                            }
                            ToolInvocation::Reject { ref reason, code } => {
                                self.send_debug_event(
                                    "tool_invocation",
                                    json!({
                                        "tool": "Reject",
                                        "params": {
                                            "reason": reason,
                                            "code": code,
                                        }
                                    }),
                                );
                                tool_commands.push(Command::Reject {
                                    reason: reason
                                        .clone()
                                        .unwrap_or_else(|| "Rejected by agent".to_string()),
                                    code,
                                });
                            }
                        }
                    }
                }

                if rerun_for_rag {
                    if attempts >= MAX_RAG_ATTEMPTS {
                        warn!("Reached RAG iteration limit, using last response");
                        final_text = structured.text.or_else(|| Some(raw.clone()));
                        break;
                    }
                    raw = self.call_llm().await?;
                    continue;
                }

                final_text = structured.text;
                break;
            }

            // Not parseable as structured JSON: speak the raw reply verbatim.
            final_text = Some(raw.clone());
            break;
        }

        let mut commands = Vec::new();

        let mut has_hangup = false;
        for tool in &tool_commands {
            if matches!(tool, Command::Hangup { .. }) {
                has_hangup = true;
                break;
            }
        }

        if let Some(text) = final_text {
            if !text.trim().is_empty() {
                self.history.push(ChatMessage {
                    role: "assistant".to_string(),
                    content: text.clone(),
                });
                self.last_tts_start_at = Some(std::time::Instant::now());
                self.is_speaking = true;

                if has_hangup {
                    // Defer the hangup to the end of TTS playback.
                    commands.push(self.create_tts_command(text, wait_input_timeout, Some(true)));
                    tool_commands.retain(|c| !matches!(c, Command::Hangup { .. }));
                    self.is_hanging_up = true;
                } else {
                    commands.push(self.create_tts_command(text, wait_input_timeout, None));
                }
            }
        }

        commands.extend(tool_commands);

        Ok(commands)
    }
1041}
1042
1043fn parse_structured_response(raw: &str) -> Option<StructuredResponse> {
1044 let payload = extract_json_block(raw)?;
1045 serde_json::from_str(payload).ok()
1046}
1047
1048fn is_likely_filler(text: &str) -> bool {
1049 let trimmed = text.trim().to_lowercase();
1050 FILLERS.contains(&trimmed)
1051}
1052
/// Extract the JSON payload from an LLM reply: either the inside of a
/// fenced code block (with an optional "json" language tag) or a bare
/// object/array literal. Returns `None` when no plausible span is found.
///
/// Matches the full three-backtick fence before slicing, so a multi-byte
/// character right after a lone backtick can no longer cause a panic on a
/// non-char-boundary byte index (the previous version sliced `[3..]` after
/// checking only the first backtick).
fn extract_json_block(raw: &str) -> Option<&str> {
    let trimmed = raw.trim();

    if let Some(rest) = trimmed.strip_prefix("```") {
        // Require a closing fence after the opening one.
        let end = rest.rfind("```")?;
        let mut inner = rest[..end].trim();
        // Drop an optional case-insensitive "json" tag on the fence line.
        if inner.get(..4).map_or(false, |tag| tag.eq_ignore_ascii_case("json")) {
            inner = match inner.find('\n') {
                Some(newline) => inner[newline + 1..].trim(),
                None => inner[4..].trim(),
            };
        }
        return Some(inner);
    }

    if trimmed.starts_with('{') || trimmed.starts_with('[') {
        return Some(trimmed);
    }
    None
}
1078
1079#[async_trait]
1080impl DialogueHandler for LlmHandler {
1081 async fn on_start(&mut self) -> Result<Vec<Command>> {
1082 self.last_tts_start_at = Some(std::time::Instant::now());
1083
1084 let mut commands = Vec::new();
1085
1086 if let Some(scene_id) = &self.current_scene_id {
1088 if let Some(scene) = self.scenes.get(scene_id) {
1089 if let Some(audio_file) = &scene.play {
1090 commands.push(Command::Play {
1091 url: audio_file.clone(),
1092 play_id: None,
1093 auto_hangup: None,
1094 wait_input_timeout: None,
1095 });
1096 }
1097 }
1098 }
1099
1100 if let Some(greeting) = &self.config.greeting {
1101 self.is_speaking = true;
1102 commands.push(self.create_tts_command(greeting.clone(), None, None));
1103 return Ok(commands);
1104 }
1105
1106 let response_commands = self.generate_response().await?;
1107 commands.extend(response_commands);
1108 Ok(commands)
1109 }
1110
    /// Route a session event to the appropriate dialogue reaction.
    async fn on_event(&mut self, event: &SessionEvent) -> Result<Vec<Command>> {
        match event {
            // Keypad input: scene-level mapping wins over the global one.
            SessionEvent::Dtmf { digit, .. } => {
                info!("DTMF received: {}", digit);
                let action = self.get_dtmf_action(digit);
                if let Some(action) = action {
                    return self.handle_dtmf_action(action).await;
                }
                Ok(vec![])
            }

            // Final transcript: record the user turn and answer it.
            SessionEvent::AsrFinal { text, .. } => {
                if text.trim().is_empty() {
                    return Ok(vec![]);
                }

                self.last_asr_final_at = Some(std::time::Instant::now());
                self.last_interaction_at = std::time::Instant::now();
                self.is_speaking = false;
                self.consecutive_follow_ups = 0;

                self.history.push(ChatMessage {
                    role: "user".to_string(),
                    content: text.clone(),
                });

                self.generate_response().await
            }

            // Possible barge-in while the assistant is speaking.
            SessionEvent::AsrDelta { is_filler, .. } | SessionEvent::Speaking { is_filler, .. } => {
                let strategy = self.interruption_config.strategy;
                // Only the event type matching the configured strategy may
                // trigger an interruption.
                let should_check = match (strategy, event) {
                    (InterruptionStrategy::None, _) => false,
                    (InterruptionStrategy::Vad, SessionEvent::Speaking { .. }) => true,
                    (InterruptionStrategy::Asr, SessionEvent::AsrDelta { .. }) => true,
                    (InterruptionStrategy::Both, _) => true,
                    _ => false,
                };

                if self.is_speaking && !self.is_hanging_up && should_check {
                    // Grace period right after TTS starts (default 800 ms).
                    if let Some(last_start) = self.last_tts_start_at {
                        let ignore_ms = self.interruption_config.ignore_first_ms.unwrap_or(800);
                        if last_start.elapsed().as_millis() < ignore_ms as u128 {
                            return Ok(vec![]);
                        }
                    }

                    // Optionally ignore filler words ("uh", "嗯", ...).
                    if self.interruption_config.filler_word_filter.unwrap_or(false) {
                        if let Some(true) = is_filler {
                            return Ok(vec![]);
                        }
                        if let SessionEvent::AsrDelta { text, .. } = event {
                            if is_likely_filler(text) {
                                return Ok(vec![]);
                            }
                        }
                    }

                    // Debounce: audio right after a final result is likely
                    // trailing echo, not a genuine interruption.
                    if let Some(last_final) = self.last_asr_final_at {
                        if last_final.elapsed().as_millis() < 500 {
                            return Ok(vec![]);
                        }
                    }

                    info!("Smart interruption detected, stopping playback");
                    self.is_speaking = false;
                    return Ok(vec![Command::Interrupt {
                        graceful: Some(true),
                        fade_out_ms: self.interruption_config.volume_fade_ms,
                    }]);
                }
                Ok(vec![])
            }

            // End-of-utterance hint: answer early when not already speaking.
            SessionEvent::Eou { completed, .. } => {
                // NOTE(review): `self.is_speaking == false` would read better
                // as `!self.is_speaking` (clippy::bool_comparison).
                if *completed && self.is_speaking == false {
                    info!("EOU detected, triggering early response");
                    return self.generate_response().await;
                }
                Ok(vec![])
            }

            // Prolonged silence: follow up, hanging up after max_count.
            SessionEvent::Silence { .. } => {
                // Scene-level follow-up policy overrides the global one.
                let follow_up_config = if let Some(scene_id) = &self.current_scene_id {
                    self.scenes
                        .get(scene_id)
                        .and_then(|s| s.follow_up)
                        .or(self.global_follow_up_config)
                } else {
                    self.global_follow_up_config
                };

                if let Some(config) = follow_up_config {
                    if !self.is_speaking
                        && self.last_interaction_at.elapsed().as_millis() as u64 >= config.timeout
                    {
                        if self.consecutive_follow_ups >= config.max_count {
                            info!("Max follow-up count reached, hanging up");
                            return Ok(vec![Command::Hangup {
                                reason: Some("Max follow-up reached".to_string()),
                                initiator: Some("system".to_string()),
                            }]);
                        }

                        info!(
                            "Silence timeout detected ({}ms), triggering follow-up ({}/{})",
                            self.last_interaction_at.elapsed().as_millis(),
                            self.consecutive_follow_ups + 1,
                            config.max_count
                        );
                        self.consecutive_follow_ups += 1;
                        self.last_interaction_at = std::time::Instant::now();
                        return self.generate_response().await;
                    }
                }
                Ok(vec![])
            }

            SessionEvent::TrackStart { .. } => {
                self.is_speaking = true;
                Ok(vec![])
            }

            // Playback finished: reset speaking state, restart silence timer.
            SessionEvent::TrackEnd { .. } => {
                self.is_speaking = false;
                self.is_hanging_up = false;
                self.last_interaction_at = std::time::Instant::now();
                Ok(vec![])
            }

            // Function calls surfaced by a realtime provider.
            SessionEvent::FunctionCall {
                name, arguments, ..
            } => {
                info!(
                    "Function call from Realtime: {} with args {}",
                    name, arguments
                );
                let args: serde_json::Value = serde_json::from_str(arguments).unwrap_or_default();
                match name.as_str() {
                    "hangup_call" => Ok(vec![Command::Hangup {
                        reason: args["reason"].as_str().map(|s| s.to_string()),
                        initiator: Some("ai".to_string()),
                    }]),
                    "transfer_call" | "refer_call" => {
                        // Accept either argument spelling for the target URI.
                        if let Some(callee) = args["callee"]
                            .as_str()
                            .or_else(|| args["callee_uri"].as_str())
                        {
                            Ok(vec![Command::Refer {
                                caller: String::new(),
                                callee: callee.to_string(),
                                options: None,
                            }])
                        } else {
                            warn!("No callee provided for transfer_call");
                            Ok(vec![])
                        }
                    }
                    "goto_scene" => {
                        if let Some(scene) = args["scene"].as_str() {
                            self.switch_to_scene(scene, false).await
                        } else {
                            Ok(vec![])
                        }
                    }
                    _ => {
                        warn!("Unhandled function call: {}", name);
                        Ok(vec![])
                    }
                }
            }

            _ => Ok(vec![]),
        }
    }
1291
1292 async fn get_history(&self) -> Vec<ChatMessage> {
1293 self.history.clone()
1294 }
1295
1296 async fn summarize(&mut self, prompt: &str) -> Result<String> {
1297 info!("Generating summary with prompt: {}", prompt);
1298 let mut summary_history = self.history.clone();
1299 summary_history.push(ChatMessage {
1300 role: "user".to_string(),
1301 content: prompt.to_string(),
1302 });
1303
1304 self.provider.call(&self.config, &summary_history).await
1305 }
1306}
1307
1308#[cfg(test)]
1309mod tests {
1310 use super::*;
1311 use crate::event::SessionEvent;
1312 use anyhow::{Result, anyhow};
1313 use async_trait::async_trait;
1314 use std::collections::VecDeque;
1315 use std::sync::Mutex;
1316
1317 struct TestProvider {
1318 responses: Mutex<VecDeque<String>>,
1319 }
1320
1321 impl TestProvider {
1322 fn new(responses: Vec<String>) -> Self {
1323 Self {
1324 responses: Mutex::new(VecDeque::from(responses)),
1325 }
1326 }
1327 }
1328
1329 #[async_trait]
1330 impl LlmProvider for TestProvider {
1331 async fn call(&self, _config: &LlmConfig, _history: &[ChatMessage]) -> Result<String> {
1332 let mut guard = self.responses.lock().unwrap();
1333 guard
1334 .pop_front()
1335 .ok_or_else(|| anyhow!("Test provider ran out of responses"))
1336 }
1337
1338 async fn call_stream(
1339 &self,
1340 _config: &LlmConfig,
1341 _history: &[ChatMessage],
1342 ) -> Result<Pin<Box<dyn Stream<Item = Result<LlmStreamEvent>> + Send>>> {
1343 let response = self.call(_config, _history).await?;
1344 let s = async_stream::stream! {
1345 yield Ok(LlmStreamEvent::Content(response));
1346 };
1347 Ok(Box::pin(s))
1348 }
1349 }
1350
1351 struct RecordingRag {
1352 queries: Mutex<Vec<String>>,
1353 }
1354
1355 impl RecordingRag {
1356 fn new() -> Self {
1357 Self {
1358 queries: Mutex::new(Vec::new()),
1359 }
1360 }
1361
1362 fn recorded_queries(&self) -> Vec<String> {
1363 self.queries.lock().unwrap().clone()
1364 }
1365 }
1366
1367 #[async_trait]
1368 impl RagRetriever for RecordingRag {
1369 async fn retrieve(&self, query: &str) -> Result<String> {
1370 self.queries.lock().unwrap().push(query.to_string());
1371 Ok(format!("retrieved {}", query))
1372 }
1373 }
1374
1375 #[tokio::test]
1376 async fn handler_applies_tool_instructions() -> Result<()> {
1377 let response = r#"{
1378 "text": "Goodbye",
1379 "waitInputTimeout": 15000,
1380 "tools": [
1381 {"name": "hangup", "reason": "done", "initiator": "agent"},
1382 {"name": "refer", "caller": "sip:bot", "callee": "sip:lead"}
1383 ]
1384 }"#;
1385
1386 let provider = Arc::new(TestProvider::new(vec![response.to_string()]));
1387 let mut handler = LlmHandler::with_provider(
1388 LlmConfig::default(),
1389 provider,
1390 Arc::new(NoopRagRetriever),
1391 crate::playbook::InterruptionConfig::default(),
1392 None,
1393 HashMap::new(),
1394 None,
1395 None,
1396 );
1397
1398 let event = SessionEvent::AsrFinal {
1399 track_id: "track-1".to_string(),
1400 timestamp: 0,
1401 index: 0,
1402 start_time: None,
1403 end_time: None,
1404 text: "hello".to_string(),
1405 is_filler: None,
1406 confidence: None,
1407 };
1408
1409 let commands = handler.on_event(&event).await?;
1410 assert!(matches!(
1411 commands.get(0),
1412 Some(Command::Tts {
1413 text,
1414 wait_input_timeout: Some(15000),
1415 auto_hangup: Some(true),
1416 ..
1417 }) if text == "Goodbye"
1418 ));
1419 assert!(commands.iter().any(|cmd| matches!(
1420 cmd,
1421 Command::Refer {
1422 caller,
1423 callee,
1424 ..
1425 } if caller == "sip:bot" && callee == "sip:lead"
1426 )));
1427
1428 Ok(())
1429 }
1430
1431 #[tokio::test]
1432 async fn handler_requeries_after_rag() -> Result<()> {
1433 let rag_instruction = r#"{"tools": [{"name": "rag", "query": "policy"}]}"#;
1434 let provider = Arc::new(TestProvider::new(vec![
1435 rag_instruction.to_string(),
1436 "Final answer".to_string(),
1437 ]));
1438 let rag = Arc::new(RecordingRag::new());
1439 let mut handler = LlmHandler::with_provider(
1440 LlmConfig::default(),
1441 provider,
1442 rag.clone(),
1443 crate::playbook::InterruptionConfig::default(),
1444 None,
1445 HashMap::new(),
1446 None,
1447 None,
1448 );
1449
1450 let event = SessionEvent::AsrFinal {
1451 track_id: "track-2".to_string(),
1452 timestamp: 0,
1453 index: 0,
1454 start_time: None,
1455 end_time: None,
1456 text: "reep".to_string(),
1457 is_filler: None,
1458 confidence: None,
1459 };
1460
1461 let commands = handler.on_event(&event).await?;
1462 assert!(matches!(
1463 commands.get(0),
1464 Some(Command::Tts {
1465 text,
1466 wait_input_timeout: Some(timeout),
1467 ..
1468 }) if text == "Final answer" && *timeout == 10000
1469 ));
1470 assert_eq!(rag.recorded_queries(), vec!["policy".to_string()]);
1471
1472 Ok(())
1473 }
1474
1475 #[tokio::test]
1476 async fn test_full_dialogue_flow() -> Result<()> {
1477 let responses = vec![
1478 "Hello! How can I help you today?".to_string(),
1479 r#"{"text": "I can help with that. Anything else?", "waitInputTimeout": 5000}"#
1480 .to_string(),
1481 r#"{"text": "Goodbye!", "tools": [{"name": "hangup", "reason": "completed"}]}"#
1482 .to_string(),
1483 ];
1484
1485 let provider = Arc::new(TestProvider::new(responses));
1486 let config = LlmConfig {
1487 greeting: Some("Welcome to the voice assistant.".to_string()),
1488 ..Default::default()
1489 };
1490
1491 let mut handler = LlmHandler::with_provider(
1492 config,
1493 provider,
1494 Arc::new(NoopRagRetriever),
1495 crate::playbook::InterruptionConfig::default(),
1496 None,
1497 HashMap::new(),
1498 None,
1499 None,
1500 );
1501
1502 let commands = handler.on_start().await?;
1504 assert_eq!(commands.len(), 1);
1505 if let Command::Tts { text, .. } = &commands[0] {
1506 assert_eq!(text, "Welcome to the voice assistant.");
1507 } else {
1508 panic!("Expected Tts command");
1509 }
1510
1511 let event = SessionEvent::AsrFinal {
1513 track_id: "test".to_string(),
1514 timestamp: 0,
1515 index: 0,
1516 start_time: None,
1517 end_time: None,
1518 text: "I need help".to_string(),
1519 is_filler: None,
1520 confidence: None,
1521 };
1522 let commands = handler.on_event(&event).await?;
1523 assert_eq!(commands.len(), 3);
1525 if let Command::Tts { text, .. } = &commands[0] {
1526 assert!(text.contains("Hello"));
1527 } else {
1528 panic!("Expected Tts command");
1529 }
1530
1531 let event = SessionEvent::AsrFinal {
1533 track_id: "test".to_string(),
1534 timestamp: 0,
1535 index: 1,
1536 start_time: None,
1537 end_time: None,
1538 text: "Tell me a joke".to_string(),
1539 is_filler: None,
1540 confidence: None,
1541 };
1542 let commands = handler.on_event(&event).await?;
1543 assert_eq!(commands.len(), 1);
1544 if let Command::Tts {
1545 text,
1546 wait_input_timeout,
1547 ..
1548 } = &commands[0]
1549 {
1550 assert_eq!(text, "I can help with that. Anything else?");
1551 assert_eq!(*wait_input_timeout, Some(5000));
1552 } else {
1553 panic!("Expected Tts command");
1554 }
1555
1556 let event = SessionEvent::AsrFinal {
1558 track_id: "test".to_string(),
1559 timestamp: 0,
1560 index: 2,
1561 start_time: None,
1562 end_time: None,
1563 text: "That's all, thanks".to_string(),
1564 is_filler: None,
1565 confidence: None,
1566 };
1567 let commands = handler.on_event(&event).await?;
1568 assert_eq!(commands.len(), 1);
1570
1571 let has_tts_hangup = commands.iter().any(|c| {
1572 matches!(
1573 c,
1574 Command::Tts {
1575 text,
1576 auto_hangup: Some(true),
1577 ..
1578 } if text == "Goodbye!"
1579 )
1580 });
1581
1582 assert!(has_tts_hangup);
1583
1584 Ok(())
1585 }
1586
1587 #[tokio::test]
1588 async fn test_xml_tools_and_sentence_splitting() -> Result<()> {
1589 let responses = vec!["Hello! <refer to=\"sip:123\"/> How are you? <hangup/>".to_string()];
1590 let provider = Arc::new(TestProvider::new(responses));
1591 let mut handler = LlmHandler::with_provider(
1592 LlmConfig::default(),
1593 provider,
1594 Arc::new(NoopRagRetriever),
1595 crate::playbook::InterruptionConfig::default(),
1596 None,
1597 HashMap::new(),
1598 None,
1599 None,
1600 );
1601
1602 let event = SessionEvent::AsrFinal {
1603 track_id: "test".to_string(),
1604 timestamp: 0,
1605 index: 0,
1606 start_time: None,
1607 end_time: None,
1608 text: "hi".to_string(),
1609 is_filler: None,
1610 confidence: None,
1611 };
1612
1613 let commands = handler.on_event(&event).await?;
1614
1615 assert_eq!(commands.len(), 4);
1621
1622 if let Command::Tts {
1623 text,
1624 play_id: pid1,
1625 ..
1626 } = &commands[0]
1627 {
1628 assert!(text.contains("Hello"));
1629 assert!(pid1.is_some());
1630
1631 if let Command::Refer { callee, .. } = &commands[1] {
1632 assert_eq!(callee, "sip:123");
1633 } else {
1634 panic!("Expected Refer");
1635 }
1636
1637 if let Command::Tts {
1638 text,
1639 play_id: pid2,
1640 ..
1641 } = &commands[2]
1642 {
1643 assert!(text.contains("How are you"));
1644 assert_eq!(*pid1, *pid2); } else {
1646 panic!("Expected Tts");
1647 }
1648
1649 if let Command::Tts {
1650 auto_hangup: Some(true),
1651 ..
1652 } = &commands[3]
1653 {
1654 } else {
1656 panic!("Expected Tts with auto_hangup");
1657 }
1658 } else {
1659 panic!("Expected Tts");
1660 }
1661
1662 Ok(())
1663 }
1664
1665 #[tokio::test]
1666 async fn test_interruption_logic() -> Result<()> {
1667 let provider = Arc::new(TestProvider::new(vec!["Some long response".to_string()]));
1668 let mut handler = LlmHandler::with_provider(
1669 LlmConfig::default(),
1670 provider,
1671 Arc::new(NoopRagRetriever),
1672 crate::playbook::InterruptionConfig::default(),
1673 None,
1674 HashMap::new(),
1675 None,
1676 None,
1677 );
1678
1679 let event = SessionEvent::AsrFinal {
1681 track_id: "test".to_string(),
1682 timestamp: 0,
1683 index: 0,
1684 start_time: None,
1685 end_time: None,
1686 text: "hello".to_string(),
1687 is_filler: None,
1688 confidence: None,
1689 };
1690 handler.on_event(&event).await?;
1691 assert!(handler.is_speaking);
1692
1693 tokio::time::sleep(std::time::Duration::from_millis(850)).await;
1695
1696 let event = SessionEvent::AsrDelta {
1698 track_id: "test".to_string(),
1699 timestamp: 0,
1700 index: 0,
1701 start_time: None,
1702 end_time: None,
1703 text: "I...".to_string(),
1704 is_filler: None,
1705 confidence: None,
1706 };
1707 let commands = handler.on_event(&event).await?;
1708 assert_eq!(commands.len(), 1);
1709 assert!(matches!(commands[0], Command::Interrupt { .. }));
1710 assert!(!handler.is_speaking);
1711
1712 Ok(())
1713 }
1714
1715 #[tokio::test]
1716 async fn test_rag_iteration_limit() -> Result<()> {
1717 let rag_instruction = r#"{"tools": [{"name": "rag", "query": "endless"}]}"#;
1719 let provider = Arc::new(TestProvider::new(vec![
1720 rag_instruction.to_string(),
1721 rag_instruction.to_string(),
1722 rag_instruction.to_string(),
1723 rag_instruction.to_string(),
1724 "Should not reach here".to_string(),
1725 ]));
1726
1727 let mut handler = LlmHandler::with_provider(
1728 LlmConfig::default(),
1729 provider,
1730 Arc::new(RecordingRag::new()),
1731 crate::playbook::InterruptionConfig::default(),
1732 None,
1733 HashMap::new(),
1734 None,
1735 None,
1736 );
1737
1738 let event = SessionEvent::AsrFinal {
1739 track_id: "test".to_string(),
1740 timestamp: 0,
1741 index: 0,
1742 start_time: None,
1743 end_time: None,
1744 text: "loop".to_string(),
1745 is_filler: None,
1746 confidence: None,
1747 };
1748
1749 let commands = handler.on_event(&event).await?;
1750 assert_eq!(commands.len(), 1);
1752 if let Command::Tts { text, .. } = &commands[0] {
1753 assert_eq!(text, rag_instruction);
1754 }
1755
1756 Ok(())
1757 }
1758
1759 #[tokio::test]
1760 async fn test_follow_up_logic() -> Result<()> {
1761 use std::time::Duration;
1762
1763 let follow_up_config = super::super::FollowUpConfig {
1765 timeout: 100, max_count: 2,
1767 };
1768
1769 let provider = Arc::new(TestProvider::new(vec![
1771 "Follow up 1".to_string(),
1772 "Follow up 2".to_string(),
1773 "Response to user".to_string(),
1774 ]));
1775
1776 let mut handler = LlmHandler::with_provider(
1777 LlmConfig::default(),
1778 provider,
1779 Arc::new(NoopRagRetriever),
1780 crate::playbook::InterruptionConfig::default(),
1781 Some(follow_up_config),
1782 HashMap::new(),
1783 None,
1784 None,
1785 );
1786
1787 handler.last_interaction_at = std::time::Instant::now();
1789 handler.is_speaking = false;
1790
1791 let event = SessionEvent::Silence {
1793 track_id: "t1".to_string(),
1794 timestamp: 0,
1795 start_time: 0,
1796 duration: 50,
1797 samples: None,
1798 };
1799 let commands = handler.on_event(&event).await?;
1800 assert!(commands.is_empty(), "Should not trigger if < timeout");
1801
1802 tokio::time::sleep(Duration::from_millis(110)).await;
1804 let event = SessionEvent::Silence {
1808 track_id: "t1".to_string(),
1809 timestamp: 0,
1810 start_time: 0,
1811 duration: 100,
1812 samples: None,
1813 };
1814 let commands = handler.on_event(&event).await?;
1815 assert_eq!(commands.len(), 1, "Should trigger follow-up 1");
1816 if let Command::Tts { text, .. } = &commands[0] {
1817 assert_eq!(text, "Follow up 1");
1818 }
1819 assert_eq!(handler.consecutive_follow_ups, 1);
1820
1821 let event = SessionEvent::TrackEnd {
1824 track_id: "t1".to_string(),
1825 timestamp: 0,
1826 play_id: None,
1827 duration: 100,
1828 ssrc: 0,
1829 };
1830 handler.on_event(&event).await?;
1831 assert!(
1832 !handler.is_speaking,
1833 "Bot should not be speaking after TrackEnd"
1834 );
1835
1836 tokio::time::sleep(Duration::from_millis(110)).await;
1843 let event = SessionEvent::Silence {
1844 track_id: "t1".to_string(),
1845 timestamp: 0,
1846 start_time: 0,
1847 duration: 100,
1848 samples: None,
1849 };
1850 let commands = handler.on_event(&event).await?;
1851 assert_eq!(commands.len(), 1, "Should trigger follow-up 2");
1852 if let Command::Tts { text, .. } = &commands[0] {
1853 assert_eq!(text, "Follow up 2");
1854 }
1855 assert_eq!(handler.consecutive_follow_ups, 2);
1856
1857 let event = SessionEvent::TrackEnd {
1859 track_id: "t1".to_string(),
1860 timestamp: 0,
1861 play_id: None,
1862 duration: 100,
1863 ssrc: 0,
1864 };
1865 handler.on_event(&event).await?;
1866
1867 tokio::time::sleep(Duration::from_millis(110)).await;
1869 let event = SessionEvent::Silence {
1870 track_id: "t1".to_string(),
1871 timestamp: 0,
1872 start_time: 0,
1873 duration: 100,
1874 samples: None,
1875 };
1876 let commands = handler.on_event(&event).await?;
1877 assert_eq!(commands.len(), 1, "Should hangup after max count");
1878 assert!(matches!(commands[0], Command::Hangup { .. }));
1879
1880 handler.consecutive_follow_ups = 2; let event = SessionEvent::AsrFinal {
1883 track_id: "t1".to_string(),
1884 timestamp: 0,
1885 index: 0,
1886 start_time: None,
1887 end_time: None,
1888 text: "User speaks".to_string(),
1889 is_filler: None,
1890 confidence: None,
1891 };
1892 let _ = handler.on_event(&event).await?;
1894 assert_eq!(
1895 handler.consecutive_follow_ups, 0,
1896 "Should reset count on AsrFinal"
1897 );
1898
1899 Ok(())
1900 }
1901
1902 #[tokio::test]
1903 async fn test_interruption_protection_period() -> Result<()> {
1904 let provider = Arc::new(TestProvider::new(vec!["Some long response".to_string()]));
1905 let mut config = crate::playbook::InterruptionConfig::default();
1906 config.ignore_first_ms = Some(800);
1907
1908 let mut handler = LlmHandler::with_provider(
1909 LlmConfig::default(),
1910 provider,
1911 Arc::new(NoopRagRetriever),
1912 config,
1913 None,
1914 HashMap::new(),
1915 None,
1916 None,
1917 );
1918
1919 let event = SessionEvent::AsrFinal {
1921 track_id: "test".to_string(),
1922 timestamp: 0,
1923 index: 0,
1924 start_time: None,
1925 end_time: None,
1926 text: "hello".to_string(),
1927 is_filler: None,
1928 confidence: None,
1929 };
1930 handler.on_event(&event).await?;
1931 assert!(handler.is_speaking);
1932
1933 let event = SessionEvent::AsrDelta {
1935 track_id: "test".to_string(),
1936 timestamp: 0,
1937 index: 0,
1938 start_time: None,
1939 end_time: None,
1940 text: "I...".to_string(),
1941 is_filler: None,
1942 confidence: None,
1943 };
1944 let commands = handler.on_event(&event).await?;
1945 assert_eq!(commands.len(), 0);
1947 assert!(handler.is_speaking);
1948
1949 Ok(())
1950 }
1951
1952 #[tokio::test]
1953 async fn test_interruption_filler_word() -> Result<()> {
1954 let provider = Arc::new(TestProvider::new(vec!["Some long response".to_string()]));
1955 let mut config = crate::playbook::InterruptionConfig::default();
1956 config.filler_word_filter = Some(true);
1957 config.ignore_first_ms = Some(0); let mut handler = LlmHandler::with_provider(
1960 LlmConfig::default(),
1961 provider,
1962 Arc::new(NoopRagRetriever),
1963 config,
1964 None,
1965 HashMap::new(),
1966 None,
1967 None,
1968 );
1969
1970 let event = SessionEvent::AsrFinal {
1972 track_id: "test".to_string(),
1973 timestamp: 0,
1974 index: 0,
1975 start_time: None,
1976 end_time: None,
1977 text: "hello".to_string(),
1978 is_filler: None,
1979 confidence: None,
1980 };
1981 handler.on_event(&event).await?;
1982 assert!(handler.is_speaking);
1983
1984 tokio::time::sleep(std::time::Duration::from_millis(600)).await;
1986
1987 let event = SessionEvent::AsrDelta {
1989 track_id: "test".to_string(),
1990 timestamp: 0,
1991 index: 0,
1992 start_time: None,
1993 end_time: None,
1994 text: "uh".to_string(),
1995 is_filler: Some(true),
1996 confidence: None,
1997 };
1998 let commands = handler.on_event(&event).await?;
1999 assert_eq!(commands.len(), 0);
2001 assert!(handler.is_speaking);
2002
2003 let event = SessionEvent::AsrDelta {
2005 track_id: "test".to_string(),
2006 timestamp: 0,
2007 index: 0,
2008 start_time: None,
2009 end_time: None,
2010 text: "Wait".to_string(),
2011 is_filler: Some(false),
2012 confidence: None,
2013 };
2014 let commands = handler.on_event(&event).await?;
2015 assert_eq!(commands.len(), 1);
2017 assert!(matches!(commands[0], Command::Interrupt { .. }));
2018
2019 Ok(())
2020 }
2021
2022 #[tokio::test]
2023 async fn test_eou_early_response() -> Result<()> {
2024 let provider = Arc::new(TestProvider::new(vec![
2025 "End of Utterance response".to_string(),
2026 ]));
2027 let mut handler = LlmHandler::with_provider(
2028 LlmConfig::default(),
2029 provider,
2030 Arc::new(NoopRagRetriever),
2031 crate::playbook::InterruptionConfig::default(),
2032 None,
2033 HashMap::new(),
2034 None,
2035 None,
2036 );
2037
2038 let event = SessionEvent::Eou {
2040 track_id: "test".to_string(),
2041 timestamp: 0,
2042 completed: true,
2043 };
2044 let commands = handler.on_event(&event).await?;
2045 assert_eq!(commands.len(), 1);
2046 if let Command::Tts { text, .. } = &commands[0] {
2047 assert_eq!(text, "End of Utterance response");
2048 } else {
2049 panic!("Expected Tts");
2050 }
2051
2052 Ok(())
2053 }
2054
2055 #[tokio::test]
2056 async fn test_summary_and_history() -> Result<()> {
2057 let provider = Arc::new(TestProvider::new(vec!["Test summary".to_string()]));
2058 let mut handler = LlmHandler::with_provider(
2059 LlmConfig::default(),
2060 provider,
2061 Arc::new(NoopRagRetriever),
2062 crate::playbook::InterruptionConfig::default(),
2063 None,
2064 HashMap::new(),
2065 None,
2066 None,
2067 );
2068
2069 handler.history.push(ChatMessage {
2071 role: "user".to_string(),
2072 content: "Hello".to_string(),
2073 });
2074 handler.history.push(ChatMessage {
2075 role: "assistant".to_string(),
2076 content: "Hi there".to_string(),
2077 });
2078
2079 let history = handler.get_history().await;
2081 assert_eq!(history.len(), 3); let summary = handler.summarize("Summarize this").await?;
2085 assert_eq!(summary, "Test summary");
2086
2087 Ok(())
2088 }
2089}