active_call/playbook/
handler.rs

1use crate::ReferOption;
2use crate::call::Command;
3use crate::event::SessionEvent;
4use anyhow::{Result, anyhow};
5use async_trait::async_trait;
6use futures::{Stream, StreamExt};
7use once_cell::sync::Lazy;
8use regex::Regex;
9use reqwest::Client;
10use serde::{Deserialize, Serialize};
11use serde_json::json;
12use std::collections::HashMap;
13use std::pin::Pin;
14use std::sync::Arc;
15use tracing::{info, warn};
16
// Regexes for the inline XML action tags the LLM is instructed to emit
// (see `build_system_prompt`). Compiled once, lazily, and reused per chunk.
static RE_HANGUP: Lazy<Regex> = Lazy::new(|| Regex::new(r"<hangup\s*/>").unwrap());
static RE_REFER: Lazy<Regex> = Lazy::new(|| Regex::new(r#"<refer\s+to="([^"]+)"\s*/>"#).unwrap());
static RE_PLAY: Lazy<Regex> = Lazy::new(|| Regex::new(r#"<play\s+file="([^"]+)"\s*/>"#).unwrap());
static RE_GOTO: Lazy<Regex> = Lazy::new(|| Regex::new(r#"<goto\s+scene="([^"]+)"\s*/>"#).unwrap());
// Sentence boundary used to cut streaming text into TTS-sized pieces:
// ASCII or full-width CJK terminators, or a newline, plus trailing whitespace.
static RE_SENTENCE: Lazy<Regex> = Lazy::new(|| Regex::new(r"(?m)[.!?。!?\n]\s*").unwrap());
22static FILLERS: Lazy<std::collections::HashSet<String>> = Lazy::new(|| {
23    let mut s = std::collections::HashSet::new();
24    let default_fillers = ["嗯", "啊", "哦", "那个", "那个...", "uh", "um", "ah"];
25
26    if let Ok(content) = std::fs::read_to_string("config/fillers.txt") {
27        for line in content.lines() {
28            let trimmed = line.trim().to_lowercase();
29            if !trimmed.is_empty() {
30                s.insert(trimmed);
31            }
32        }
33    }
34
35    if s.is_empty() {
36        for f in default_fillers {
37            s.insert(f.to_string());
38        }
39    }
40    s
41});
42
43use super::ChatMessage;
44use super::InterruptionStrategy;
45use super::LlmConfig;
46use super::dialogue::DialogueHandler;
47
// NOTE(review): appears to bound the retry loop in `interpret_response`
// (which increments an `attempts` counter) — confirm against the full method.
const MAX_RAG_ATTEMPTS: usize = 3;

/// One incremental delta from a streaming LLM completion.
#[derive(Debug, Clone)]
pub enum LlmStreamEvent {
    /// A chunk of the assistant's answer text.
    Content(String),
    /// A chunk of the model's reasoning/chain-of-thought text.
    Reasoning(String),
}

/// Abstraction over a chat-completion backend.
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Run one completion and return the full assistant message text.
    async fn call(&self, config: &LlmConfig, history: &[ChatMessage]) -> Result<String>;
    /// Run one completion and yield incremental [`LlmStreamEvent`]s.
    async fn call_stream(
        &self,
        config: &LlmConfig,
        history: &[ChatMessage],
    ) -> Result<Pin<Box<dyn Stream<Item = Result<LlmStreamEvent>> + Send>>>;
}

/// One event emitted by a speech-to-speech ("realtime") provider.
pub struct RealtimeResponse {
    /// Raw audio chunk carried by this event, if any.
    pub audio_delta: Option<Vec<u8>>,
    /// Incremental text delta carried by this event, if any.
    pub text_delta: Option<String>,
    /// Tool/function call requested by the model, if any.
    pub function_call: Option<ToolInvocation>,
    /// True when the provider signalled the start of (user) speech.
    pub speech_started: bool,
}

/// Abstraction over a bidirectional realtime provider: PCM audio in,
/// [`RealtimeResponse`] events out.
#[async_trait]
pub trait RealtimeProvider: Send + Sync {
    async fn connect(&self, config: &LlmConfig) -> Result<()>;
    async fn send_audio(&self, audio: &[i16]) -> Result<()>;
    async fn subscribe(
        &self,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<RealtimeResponse>> + Send>>>;
}
81
/// Default [`LlmProvider`] backed by an OpenAI-compatible HTTP chat API.
struct DefaultLlmProvider {
    // Reused reqwest client (connection pooling across requests).
    client: Client,
}

impl DefaultLlmProvider {
    fn new() -> Self {
        Self {
            client: Client::new(),
        }
    }
}
93
94#[async_trait]
95impl LlmProvider for DefaultLlmProvider {
96    async fn call(&self, config: &LlmConfig, history: &[ChatMessage]) -> Result<String> {
97        let mut url = config
98            .base_url
99            .clone()
100            .unwrap_or_else(|| "https://api.openai.com/v1".to_string());
101        let model = config
102            .model
103            .clone()
104            .unwrap_or_else(|| "gpt-4-turbo".to_string());
105        let api_key = config.api_key.clone().unwrap_or_default();
106
107        if !url.ends_with("/chat/completions") {
108            url = format!("{}/chat/completions", url.trim_end_matches('/'));
109        }
110
111        let body = json!({
112            "model": model,
113            "messages": history,
114        });
115
116        let res = self
117            .client
118            .post(&url)
119            .header("Authorization", format!("Bearer {}", api_key))
120            .json(&body)
121            .send()
122            .await?;
123
124        if !res.status().is_success() {
125            return Err(anyhow!("LLM request failed: {}", res.status()));
126        }
127
128        let json: serde_json::Value = res.json().await?;
129        let content = json["choices"][0]["message"]["content"]
130            .as_str()
131            .ok_or_else(|| anyhow!("Invalid LLM response"))?
132            .to_string();
133
134        Ok(content)
135    }
136
137    async fn call_stream(
138        &self,
139        config: &LlmConfig,
140        history: &[ChatMessage],
141    ) -> Result<Pin<Box<dyn Stream<Item = Result<LlmStreamEvent>> + Send>>> {
142        let mut url = config
143            .base_url
144            .clone()
145            .unwrap_or_else(|| "https://api.openai.com/v1".to_string());
146        let model = config
147            .model
148            .clone()
149            .unwrap_or_else(|| "gpt-4-turbo".to_string());
150        let api_key = config.api_key.clone().unwrap_or_default();
151
152        if !url.ends_with("/chat/completions") {
153            url = format!("{}/chat/completions", url.trim_end_matches('/'));
154        }
155
156        let body = json!({
157            "model": model,
158            "messages": history,
159            "stream": true,
160        });
161
162        let res = self
163            .client
164            .post(&url)
165            .header("Authorization", format!("Bearer {}", api_key))
166            .json(&body)
167            .send()
168            .await?;
169
170        if !res.status().is_success() {
171            return Err(anyhow!("LLM request failed: {}", res.status()));
172        }
173
174        let stream = res.bytes_stream();
175        let s = async_stream::stream! {
176            let mut buffer = String::new();
177            for await chunk in stream {
178                match chunk {
179                    Ok(bytes) => {
180                        let text = String::from_utf8_lossy(&bytes);
181                        buffer.push_str(&text);
182
183                        while let Some(line_end) = buffer.find('\n') {
184                            let line = buffer[..line_end].trim();
185                            if line.starts_with("data:") {
186                                let data = &line[5..].trim();
187                                if *data == "[DONE]" {
188                                    break;
189                                }
190                                if let Ok(json) = serde_json::from_str::<serde_json::Value>(data) {
191                                    if let Some(delta) = json["choices"][0].get("delta") {
192                                         if let Some(thinking) = delta.get("reasoning_content").and_then(|v| v.as_str()) {
193                                             yield Ok(LlmStreamEvent::Reasoning(thinking.to_string()));
194                                         }
195                                         if let Some(content) = delta.get("content").and_then(|v| v.as_str()) {
196                                             yield Ok(LlmStreamEvent::Content(content.to_string()));
197                                         }
198                                    }
199                                }
200                            }
201                            buffer.drain(..=line_end);
202                        }
203                    }
204                    Err(e) => yield Err(anyhow!(e)),
205                }
206            }
207        };
208
209        Ok(Box::pin(s))
210    }
211}
212
/// Retrieval-augmented-generation backend: fetch context text for a query.
#[async_trait]
pub trait RagRetriever: Send + Sync {
    async fn retrieve(&self, query: &str) -> Result<String>;
}

/// Default retriever that performs no retrieval and returns empty context.
struct NoopRagRetriever;

#[async_trait]
impl RagRetriever for NoopRagRetriever {
    async fn retrieve(&self, _query: &str) -> Result<String> {
        Ok(String::new())
    }
}
226
/// Shape of a JSON-mode LLM reply: optional spoken text plus tool calls.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct StructuredResponse {
    // Text to speak via TTS, if any.
    text: Option<String>,
    // Time to wait for user input after playback — presumably milliseconds
    // (TTS commands default to 10000); TODO confirm unit.
    wait_input_timeout: Option<u32>,
    // Tool invocations requested by the model, if any.
    tools: Option<Vec<ToolInvocation>>,
}

/// A tool call requested by the LLM. Serialized as a JSON object whose
/// lowercase `name` field selects the variant (e.g. `{"name": "hangup", ...}`).
#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(tag = "name", rename_all = "lowercase")]
pub enum ToolInvocation {
    /// Terminate the call.
    #[serde(rename_all = "camelCase")]
    Hangup {
        reason: Option<String>,
        initiator: Option<String>,
    },
    /// Transfer the call to another party.
    #[serde(rename_all = "camelCase")]
    Refer {
        caller: String,
        callee: String,
        options: Option<ReferOption>,
    },
    /// Retrieve external context for a query before answering.
    #[serde(rename_all = "camelCase")]
    Rag {
        query: String,
        source: Option<String>,
    },
    /// Accept an incoming call.
    #[serde(rename_all = "camelCase")]
    Accept { options: Option<crate::CallOption> },
    /// Reject an incoming call.
    #[serde(rename_all = "camelCase")]
    Reject {
        reason: Option<String>,
        code: Option<u32>,
    },
}
262
/// Dialogue handler that drives a call with a (streaming) chat LLM:
/// keeps the conversation history, turns LLM output into call commands
/// (TTS / refer / play / hangup), and tracks scene and DTMF state.
pub struct LlmHandler {
    // LLM endpoint/model/prompt configuration.
    config: LlmConfig,
    interruption_config: super::InterruptionConfig,
    global_follow_up_config: Option<super::FollowUpConfig>,
    // Global digit -> action map; scene-local maps take precedence.
    dtmf_config: Option<HashMap<String, super::DtmfAction>>,
    // Chat history; index 0 is the system prompt (rewritten on scene switch).
    history: Vec<ChatMessage>,
    provider: Arc<dyn LlmProvider>,
    rag_retriever: Arc<dyn RagRetriever>,
    // True while assistant TTS output is (believed to be) playing.
    is_speaking: bool,
    // Set once a hangup has been emitted, to prevent interruption.
    is_hanging_up: bool,
    consecutive_follow_ups: u32,
    last_interaction_at: std::time::Instant,
    // Where session/debug events are published, once attached.
    event_sender: Option<crate::event::EventSender>,
    last_asr_final_at: Option<std::time::Instant>,
    last_tts_start_at: Option<std::time::Instant>,
    // When set, streaming commands are enqueued directly on the call
    // instead of being returned from `generate_response`.
    call: Option<crate::call::ActiveCallRef>,
    // All configured scenes, keyed by scene id.
    scenes: HashMap<String, super::Scene>,
    current_scene_id: Option<String>,
}
282
283impl LlmHandler {
    /// Build a handler with the default HTTP LLM provider and a no-op RAG
    /// retriever. Thin wrapper around [`Self::with_provider`].
    pub fn new(
        config: LlmConfig,
        interruption: super::InterruptionConfig,
        global_follow_up_config: Option<super::FollowUpConfig>,
        scenes: HashMap<String, super::Scene>,
        dtmf: Option<HashMap<String, super::DtmfAction>>,
        initial_scene_id: Option<String>,
    ) -> Self {
        Self::with_provider(
            config,
            Arc::new(DefaultLlmProvider::new()),
            Arc::new(NoopRagRetriever),
            interruption,
            global_follow_up_config,
            scenes,
            dtmf,
            initial_scene_id,
        )
    }
303
    /// Build a handler with explicit LLM and RAG backends (used directly by
    /// tests / alternative providers). Seeds the history with the system
    /// prompt derived from `config.prompt`.
    pub fn with_provider(
        config: LlmConfig,
        provider: Arc<dyn LlmProvider>,
        rag_retriever: Arc<dyn RagRetriever>,
        interruption: super::InterruptionConfig,
        global_follow_up_config: Option<super::FollowUpConfig>,
        scenes: HashMap<String, super::Scene>,
        dtmf: Option<HashMap<String, super::DtmfAction>>,
        initial_scene_id: Option<String>,
    ) -> Self {
        let mut history = Vec::new();
        let prompt = config.prompt.clone().unwrap_or_default();
        let system_prompt = Self::build_system_prompt(&prompt);

        // History slot 0 is always the system prompt; scene switches
        // overwrite its content in place.
        history.push(ChatMessage {
            role: "system".to_string(),
            content: system_prompt,
        });

        Self {
            config,
            interruption_config: interruption,
            global_follow_up_config,
            dtmf_config: dtmf,
            history,
            provider,
            rag_retriever,
            is_speaking: false,
            is_hanging_up: false,
            consecutive_follow_ups: 0,
            last_interaction_at: std::time::Instant::now(),
            event_sender: None,
            last_asr_final_at: None,
            last_tts_start_at: None,
            call: None,
            scenes,
            current_scene_id: initial_scene_id,
        }
    }
343
344    fn build_system_prompt(prompt: &str) -> String {
345        format!(
346            "{}\n\n\
347            Tool usage instructions:\n\
348            - To hang up the call, output: <hangup/>\n\
349            - To transfer the call, output: <refer to=\"sip:xxxx\"/>\n\
350            - To play an audio file, output: <play file=\"path/to/file.wav\"/>\n\
351            - To switch to another scene, output: <goto scene=\"scene_id\"/>\n\
352            Please use these XML tags for action triggers. They are optimized for streaming playback. \
353            Output your response in short sentences. Each sentence will be played as soon as it is finished.",
354            prompt
355        )
356    }
357
358    fn get_dtmf_action(&self, digit: &str) -> Option<super::DtmfAction> {
359        // 1. Check current scene
360        if let Some(scene_id) = &self.current_scene_id {
361            if let Some(scene) = self.scenes.get(scene_id) {
362                if let Some(dtmf) = &scene.dtmf {
363                    if let Some(action) = dtmf.get(digit) {
364                        return Some(action.clone());
365                    }
366                }
367            }
368        }
369
370        // 2. Check global
371        if let Some(dtmf) = &self.dtmf_config {
372            if let Some(action) = dtmf.get(digit) {
373                return Some(action.clone());
374            }
375        }
376
377        None
378    }
379
380    async fn handle_dtmf_action(&mut self, action: super::DtmfAction) -> Result<Vec<Command>> {
381        match action {
382            super::DtmfAction::Goto { scene } => {
383                info!("DTMF action: switch to scene {}", scene);
384                self.switch_to_scene(&scene, true).await
385            }
386            super::DtmfAction::Transfer { target } => {
387                info!("DTMF action: transfer to {}", target);
388                Ok(vec![Command::Refer {
389                    caller: String::new(),
390                    callee: target,
391                    options: None,
392                }])
393            }
394            super::DtmfAction::Hangup => {
395                info!("DTMF action: hangup");
396                Ok(vec![Command::Hangup {
397                    reason: Some("DTMF Hangup".to_string()),
398                    initiator: Some("ai".to_string()),
399                }])
400            }
401        }
402    }
403
    /// Switch the handler to `scene_id`: rewrite the system prompt in place,
    /// optionally queue the scene's entry audio, and (when `trigger_response`
    /// is set) generate an immediate LLM response under the new prompt.
    /// Unknown scene ids are logged and yield no commands.
    async fn switch_to_scene(
        &mut self,
        scene_id: &str,
        trigger_response: bool,
    ) -> Result<Vec<Command>> {
        // Clone the scene so we can keep using it while mutably borrowing
        // `self` below (generate_response takes `&mut self`).
        if let Some(scene) = self.scenes.get(scene_id).cloned() {
            info!("Switching to scene: {}", scene_id);
            self.current_scene_id = Some(scene_id.to_string());
            // Update system prompt in history
            let system_prompt = Self::build_system_prompt(&scene.prompt);
            if let Some(first_msg) = self.history.get_mut(0) {
                if first_msg.role == "system" {
                    first_msg.content = system_prompt;
                }
            }

            let mut commands = Vec::new();
            // Scene entry audio, if configured, plays before the response.
            if let Some(url) = &scene.play {
                commands.push(Command::Play {
                    url: url.clone(),
                    play_id: None,
                    auto_hangup: None,
                    wait_input_timeout: None,
                });
            }

            if trigger_response {
                // Must run after the system prompt swap above so the
                // response is generated under the new scene's prompt.
                let response_cmds = self.generate_response().await?;
                commands.extend(response_cmds);
            }
            Ok(commands)
        } else {
            warn!("Scene not found: {}", scene_id);
            Ok(vec![])
        }
    }
440
    /// Borrow the full chat history (index 0 is the system prompt).
    pub fn get_history_ref(&self) -> &[ChatMessage] {
        &self.history
    }

    /// The id of the currently active scene, if any.
    pub fn get_current_scene_id(&self) -> Option<String> {
        self.current_scene_id.clone()
    }

    /// Attach the active call; once set, streaming commands are enqueued on
    /// it directly instead of being returned from `generate_response`.
    pub fn set_call(&mut self, call: crate::call::ActiveCallRef) {
        self.call = Some(call);
    }
452
    /// Attach the session event sender and, if a greeting is configured,
    /// immediately record it in the session history as an assistant line.
    /// Send failures are intentionally ignored (best-effort telemetry).
    pub fn set_event_sender(&mut self, sender: crate::event::EventSender) {
        self.event_sender = Some(sender.clone());
        // Send initial greeting if any
        if let Some(greeting) = &self.config.greeting {
            let _ = sender.send(crate::event::SessionEvent::AddHistory {
                sender: Some("system".to_string()),
                timestamp: crate::media::get_timestamp(),
                speaker: "assistant".to_string(),
                text: greeting.clone(),
            });
        }
    }
465
466    fn send_debug_event(&self, key: &str, data: serde_json::Value) {
467        if let Some(sender) = &self.event_sender {
468            let timestamp = crate::media::get_timestamp();
469            // If this is an LLM response, sync it with history
470            if key == "llm_response" {
471                if let Some(text) = data.get("response").and_then(|v| v.as_str()) {
472                    let _ = sender.send(crate::event::SessionEvent::AddHistory {
473                        sender: Some("llm".to_string()),
474                        timestamp,
475                        speaker: "assistant".to_string(),
476                        text: text.to_string(),
477                    });
478                }
479            }
480
481            let event = crate::event::SessionEvent::Metrics {
482                timestamp,
483                key: key.to_string(),
484                duration: 0,
485                data,
486            };
487            let _ = sender.send(event);
488        }
489    }
490
    /// One non-streaming completion against the current history.
    async fn call_llm(&self) -> Result<String> {
        self.provider.call(&self.config, &self.history).await
    }
494
    /// Build a non-streaming TTS command with a fresh play id, publishing a
    /// `tts_play_id_map` metrics event so the play id can later be correlated
    /// back to the spoken text.
    ///
    /// `wait_input_timeout` defaults to 10000 (presumably milliseconds —
    /// TODO confirm unit against the Command consumer).
    fn create_tts_command(
        &self,
        text: String,
        wait_input_timeout: Option<u32>,
        auto_hangup: Option<bool>,
    ) -> Command {
        let timeout = wait_input_timeout.unwrap_or(10000);
        let play_id = uuid::Uuid::new_v4().to_string();

        // Best-effort: map play_id -> text for debugging/correlation.
        if let Some(sender) = &self.event_sender {
            let _ = sender.send(crate::event::SessionEvent::Metrics {
                timestamp: crate::media::get_timestamp(),
                key: "tts_play_id_map".to_string(),
                duration: 0,
                data: serde_json::json!({
                    "playId": play_id,
                    "text": text,
                }),
            });
        }

        Command::Tts {
            text,
            speaker: None,
            play_id: Some(play_id),
            auto_hangup,
            streaming: None,
            end_of_stream: Some(true),
            option: None,
            wait_input_timeout: Some(timeout),
            base64: None,
        }
    }
528
    /// Run one streaming LLM turn and convert the output into call commands.
    ///
    /// Two output modes are detected from the first non-blank characters:
    /// - JSON mode (starts with '{' or '`'): the full response is buffered
    ///   and handed to `interpret_response` (structured tools path).
    /// - Plain-text mode: sentences and inline XML tags are extracted
    ///   incrementally via `extract_streaming_commands` so TTS can start
    ///   before the model finishes.
    ///
    /// When a call is attached, extracted commands are enqueued on it
    /// directly and the returned Vec stays empty; otherwise they are
    /// collected and returned.
    async fn generate_response(&mut self) -> Result<Vec<Command>> {
        let start_time = crate::media::get_timestamp();
        // One play id groups all streamed TTS fragments of this turn.
        let play_id = uuid::Uuid::new_v4().to_string();

        // Send debug event - LLM call started
        self.send_debug_event(
            "llm_call_start",
            json!({
                "history_length": self.history.len(),
                "playId": play_id,
            }),
        );

        let mut stream = self
            .provider
            .call_stream(&self.config, &self.history)
            .await?;

        let mut full_content = String::new();
        let mut full_reasoning = String::new();
        // Working buffer for incremental sentence/tag extraction.
        let mut buffer = String::new();
        let mut commands = Vec::new();
        let mut is_json_mode = false;
        // Mode detection runs once, on the first non-blank content.
        let mut checked_json_mode = false;
        let mut first_token_time = None;

        while let Some(chunk_result) = stream.next().await {
            let event = match chunk_result {
                Ok(c) => c,
                Err(e) => {
                    // Stream errors end the turn; whatever was received so
                    // far is still processed below.
                    warn!("LLM stream error: {}", e);
                    break;
                }
            };

            match event {
                LlmStreamEvent::Reasoning(text) => {
                    // Reasoning is collected for debugging only; it is never
                    // spoken or added to the chat history.
                    full_reasoning.push_str(&text);
                }
                LlmStreamEvent::Content(chunk) => {
                    // Time-to-first-token, measured at the first non-blank chunk.
                    if first_token_time.is_none() && !chunk.trim().is_empty() {
                        first_token_time = Some(crate::media::get_timestamp());
                    }

                    full_content.push_str(&chunk);
                    buffer.push_str(&chunk);

                    if !checked_json_mode {
                        let trimmed = full_content.trim();
                        if !trimmed.is_empty() {
                            // '`' covers fenced ```json blocks.
                            if trimmed.starts_with('{') || trimmed.starts_with('`') {
                                is_json_mode = true;
                            }
                            checked_json_mode = true;
                        }
                    }

                    // Plain-text mode: flush complete sentences/tags as
                    // they appear (is_final = false keeps partials buffered).
                    if checked_json_mode && !is_json_mode {
                        let extracted =
                            self.extract_streaming_commands(&mut buffer, &play_id, false);
                        for cmd in extracted {
                            if let Some(call) = &self.call {
                                let _ = call.enqueue_command(cmd).await;
                            } else {
                                commands.push(cmd);
                            }
                        }
                    }
                }
            }
        }

        // Send debug event - LLM response received
        let end_time = crate::media::get_timestamp();
        self.send_debug_event(
            "llm_response",
            json!({
                "response": full_content,
                "reasoning": full_reasoning,
                "is_json_mode": is_json_mode,
                "duration": end_time - start_time,
                "ttfb": first_token_time.map(|t| t - start_time).unwrap_or(0),
                "playId": play_id,
            }),
        );

        if is_json_mode {
            // NOTE(review): the JSON path delegates history bookkeeping to
            // interpret_response (not fully visible here) — confirm it
            // records the assistant turn.
            self.interpret_response(full_content).await
        } else {
            // Final flush: is_final = true also emits any trailing partial
            // sentence and closes the TTS stream.
            let extracted = self.extract_streaming_commands(&mut buffer, &play_id, true);
            for cmd in extracted {
                if let Some(call) = &self.call {
                    let _ = call.enqueue_command(cmd).await;
                } else {
                    commands.push(cmd);
                }
            }
            if !full_content.trim().is_empty() {
                // Record the raw assistant output (tags included) and mark
                // speech as started for interruption tracking.
                self.history.push(ChatMessage {
                    role: "assistant".to_string(),
                    content: full_content,
                });
                self.is_speaking = true;
                self.last_tts_start_at = Some(std::time::Instant::now());
            }
            Ok(commands)
        }
    }
637
    /// Scan `buffer` for the earliest actionable token — an inline XML tag
    /// (`<hangup/>`, `<refer/>`, `<play/>`, `<goto/>`) or a sentence
    /// boundary — and convert everything found into commands, draining the
    /// consumed text. Repeats until no complete token remains; when
    /// `is_final` is set, any trailing text is flushed and the streaming
    /// TTS sequence is closed with `end_of_stream`.
    ///
    /// NOTE(review): a sentence terminator arriving before a tag is fully
    /// buffered (e.g. a '.' inside a partially-received `<play file=...>`)
    /// could split mid-tag — presumably rare in practice; confirm with the
    /// streaming providers in use.
    fn extract_streaming_commands(
        &mut self,
        buffer: &mut String,
        play_id: &str,
        is_final: bool,
    ) -> Vec<Command> {
        let mut commands = Vec::new();

        loop {
            // Locate each token kind independently in the current buffer.
            let hangup_pos = RE_HANGUP.find(buffer);
            let refer_pos = RE_REFER.captures(buffer);
            let play_pos = RE_PLAY.captures(buffer);
            let goto_pos = RE_GOTO.captures(buffer);
            let sentence_pos = RE_SENTENCE.find(buffer);

            // Find the first occurrence
            // (kind tags: 0=hangup, 1=refer, 2=sentence, 3=play, 4=goto)
            let mut positions = Vec::new();
            if let Some(m) = hangup_pos {
                positions.push((m.start(), 0));
            }
            if let Some(caps) = &refer_pos {
                positions.push((caps.get(0).unwrap().start(), 1));
            }
            if let Some(caps) = &play_pos {
                positions.push((caps.get(0).unwrap().start(), 3));
            }
            if let Some(caps) = &goto_pos {
                positions.push((caps.get(0).unwrap().start(), 4));
            }
            if let Some(m) = sentence_pos {
                positions.push((m.start(), 2));
            }

            // Earliest match in the buffer wins.
            positions.sort_by_key(|p| p.0);

            if let Some((pos, kind)) = positions.first() {
                let pos = *pos;
                match kind {
                    0 => {
                        // Hangup: speak any preceding text (with
                        // auto_hangup so the call ends after playback),
                        // then stop processing entirely.
                        let prefix = buffer[..pos].to_string();
                        if !prefix.trim().is_empty() {
                            let mut cmd = self.create_tts_command_with_id(
                                prefix,
                                play_id.to_string(),
                                Some(true),
                            );
                            if let Command::Tts { end_of_stream, .. } = &mut cmd {
                                *end_of_stream = Some(true);
                            }
                            // Mark as hanging up to prevent interruption
                            self.is_hanging_up = true;
                            commands.push(cmd);
                        } else {
                            // Send an empty TTS command with auto_hangup=true to close the stream and trigger hangup
                            let mut cmd = self.create_tts_command_with_id(
                                "".to_string(),
                                play_id.to_string(),
                                Some(true),
                            );
                            if let Command::Tts { end_of_stream, .. } = &mut cmd {
                                *end_of_stream = Some(true);
                            }
                            self.is_hanging_up = true;
                            commands.push(cmd);
                        }
                        // Re-run the regex here: the earlier Match borrow
                        // cannot be held across the mutable drain.
                        buffer.drain(..RE_HANGUP.find(buffer).unwrap().end());
                        // Stop after hangup
                        return commands;
                    }
                    1 => {
                        // Refer
                        let caps = RE_REFER.captures(buffer).unwrap();
                        let mat = caps.get(0).unwrap();
                        let callee = caps.get(1).unwrap().as_str().to_string();

                        // Speak any text preceding the tag first.
                        let prefix = buffer[..pos].to_string();
                        if !prefix.trim().is_empty() {
                            commands.push(self.create_tts_command_with_id(
                                prefix,
                                play_id.to_string(),
                                None,
                            ));
                        }
                        commands.push(Command::Refer {
                            caller: String::new(),
                            callee,
                            options: None,
                        });
                        buffer.drain(..mat.end());
                    }
                    3 => {
                        // Play audio
                        let caps = RE_PLAY.captures(buffer).unwrap();
                        let mat = caps.get(0).unwrap();
                        let url = caps.get(1).unwrap().as_str().to_string();

                        let prefix = buffer[..pos].to_string();
                        if !prefix.trim().is_empty() {
                            commands.push(self.create_tts_command_with_id(
                                prefix,
                                play_id.to_string(),
                                None,
                            ));
                        }
                        commands.push(Command::Play {
                            url,
                            play_id: None,
                            auto_hangup: None,
                            wait_input_timeout: None,
                        });
                        buffer.drain(..mat.end());
                    }
                    4 => {
                        // Goto Scene
                        let caps = RE_GOTO.captures(buffer).unwrap();
                        let mat = caps.get(0).unwrap();
                        let scene_id = caps.get(1).unwrap().as_str().to_string();

                        let prefix = buffer[..pos].to_string();
                        if !prefix.trim().is_empty() {
                            commands.push(self.create_tts_command_with_id(
                                prefix,
                                play_id.to_string(),
                                None,
                            ));
                        }

                        // Unlike switch_to_scene, this in-stream variant only
                        // swaps the prompt; no scene audio, no new response.
                        info!("Switching to scene (from stream): {}", scene_id);
                        if let Some(scene) = self.scenes.get(&scene_id) {
                            self.current_scene_id = Some(scene_id);
                            // Update system prompt in history
                            let system_prompt = Self::build_system_prompt(&scene.prompt);
                            if let Some(first_msg) = self.history.get_mut(0) {
                                if first_msg.role == "system" {
                                    first_msg.content = system_prompt;
                                }
                            }
                        } else {
                            warn!("Scene not found: {}", scene_id);
                        }

                        buffer.drain(..mat.end());
                    }
                    2 => {
                        // Sentence
                        let mat = sentence_pos.unwrap();
                        // Include the terminator itself in the spoken text.
                        let sentence = buffer[..mat.end()].to_string();
                        if !sentence.trim().is_empty() {
                            commands.push(self.create_tts_command_with_id(
                                sentence,
                                play_id.to_string(),
                                None,
                            ));
                        }
                        buffer.drain(..mat.end());
                    }
                    _ => unreachable!(),
                }
            } else {
                // No complete token left; keep the partial text buffered.
                break;
            }
        }

        if is_final {
            // Flush whatever partial sentence remains.
            let remaining = buffer.trim().to_string();
            if !remaining.is_empty() {
                commands.push(self.create_tts_command_with_id(
                    remaining,
                    play_id.to_string(),
                    None,
                ));
            }
            buffer.clear();

            // Close the streaming TTS sequence. NOTE(review): if the last
            // command is not a Tts (e.g. Refer/Play), no end_of_stream is
            // emitted — confirm whether the consumer tolerates that.
            if let Some(last) = commands.last_mut() {
                if let Command::Tts { end_of_stream, .. } = last {
                    *end_of_stream = Some(true);
                }
            } else if !self.is_hanging_up {
                commands.push(Command::Tts {
                    text: "".to_string(),
                    speaker: None,
                    play_id: Some(play_id.to_string()),
                    auto_hangup: None,
                    streaming: Some(true),
                    end_of_stream: Some(true),
                    option: None,
                    wait_input_timeout: None,
                    base64: None,
                });
            }
        }

        commands
    }
834
    /// Build one streaming TTS fragment bound to an existing `play_id`
    /// (all fragments of a turn share the id). `end_of_stream` is left
    /// unset here; callers mark the final fragment themselves.
    fn create_tts_command_with_id(
        &self,
        text: String,
        play_id: String,
        auto_hangup: Option<bool>,
    ) -> Command {
        Command::Tts {
            text,
            speaker: None,
            play_id: Some(play_id),
            auto_hangup,
            streaming: Some(true),
            end_of_stream: None,
            option: None,
            // Same 10000 default as create_tts_command — presumably ms.
            wait_input_timeout: Some(10000),
            base64: None,
        }
    }
853
854    async fn interpret_response(&mut self, initial: String) -> Result<Vec<Command>> {
855        let mut tool_commands = Vec::new();
856        let mut wait_input_timeout = None;
857        let mut attempts = 0;
858        let final_text: Option<String>;
859        let mut raw = initial;
860
861        loop {
862            attempts += 1;
863            let mut rerun_for_rag = false;
864
865            if let Some(structured) = parse_structured_response(&raw) {
866                if wait_input_timeout.is_none() {
867                    wait_input_timeout = structured.wait_input_timeout;
868                }
869
870                if let Some(tools) = structured.tools {
871                    for tool in tools {
872                        match tool {
873                            ToolInvocation::Hangup {
874                                ref reason,
875                                ref initiator,
876                            } => {
877                                // Send debug event
878                                self.send_debug_event(
879                                    "tool_invocation",
880                                    json!({
881                                        "tool": "Hangup",
882                                        "params": {
883                                            "reason": reason,
884                                            "initiator": initiator,
885                                        }
886                                    }),
887                                );
888                                tool_commands.push(Command::Hangup {
889                                    reason: reason.clone(),
890                                    initiator: initiator.clone(),
891                                });
892                            }
893                            ToolInvocation::Refer {
894                                ref caller,
895                                ref callee,
896                                ref options,
897                            } => {
898                                // Send debug event
899                                self.send_debug_event(
900                                    "tool_invocation",
901                                    json!({
902                                        "tool": "Refer",
903                                        "params": {
904                                            "caller": caller,
905                                            "callee": callee,
906                                        }
907                                    }),
908                                );
909                                tool_commands.push(Command::Refer {
910                                    caller: caller.clone(),
911                                    callee: callee.clone(),
912                                    options: options.clone(),
913                                });
914                            }
915                            ToolInvocation::Rag {
916                                ref query,
917                                ref source,
918                            } => {
919                                // Send debug event - RAG query started
920                                self.send_debug_event(
921                                    "tool_invocation",
922                                    json!({
923                                        "tool": "Rag",
924                                        "params": {
925                                            "query": query,
926                                            "source": source,
927                                        }
928                                    }),
929                                );
930
931                                let rag_result = self.rag_retriever.retrieve(&query).await?;
932
933                                // Send debug event - RAG result
934                                self.send_debug_event(
935                                    "rag_result",
936                                    json!({
937                                        "query": query,
938                                        "result": rag_result,
939                                    }),
940                                );
941
942                                let summary = if let Some(source) = source {
943                                    format!("[{}] {}", source, rag_result)
944                                } else {
945                                    rag_result
946                                };
947                                self.history.push(ChatMessage {
948                                    role: "system".to_string(),
949                                    content: format!("RAG result for {}: {}", query, summary),
950                                });
951                                rerun_for_rag = true;
952                            }
953                            ToolInvocation::Accept { ref options } => {
954                                self.send_debug_event(
955                                    "tool_invocation",
956                                    json!({
957                                        "tool": "Accept",
958                                    }),
959                                );
960                                tool_commands.push(Command::Accept {
961                                    option: options.clone().unwrap_or_default(),
962                                });
963                            }
964                            ToolInvocation::Reject { ref reason, code } => {
965                                self.send_debug_event(
966                                    "tool_invocation",
967                                    json!({
968                                        "tool": "Reject",
969                                        "params": {
970                                            "reason": reason,
971                                            "code": code,
972                                        }
973                                    }),
974                                );
975                                tool_commands.push(Command::Reject {
976                                    reason: reason
977                                        .clone()
978                                        .unwrap_or_else(|| "Rejected by agent".to_string()),
979                                    code,
980                                });
981                            }
982                        }
983                    }
984                }
985
986                if rerun_for_rag {
987                    if attempts >= MAX_RAG_ATTEMPTS {
988                        warn!("Reached RAG iteration limit, using last response");
989                        final_text = structured.text.or_else(|| Some(raw.clone()));
990                        break;
991                    }
992                    raw = self.call_llm().await?;
993                    continue;
994                }
995
996                final_text = structured.text;
997                break;
998            }
999
1000            final_text = Some(raw.clone());
1001            break;
1002        }
1003
1004        let mut commands = Vec::new();
1005
1006        // Check if any tool invocation is a hangup
1007        let mut has_hangup = false;
1008        for tool in &tool_commands {
1009            if matches!(tool, Command::Hangup { .. }) {
1010                has_hangup = true;
1011                break;
1012            }
1013        }
1014
1015        if let Some(text) = final_text {
1016            if !text.trim().is_empty() {
1017                self.history.push(ChatMessage {
1018                    role: "assistant".to_string(),
1019                    content: text.clone(),
1020                });
1021                self.last_tts_start_at = Some(std::time::Instant::now());
1022                self.is_speaking = true;
1023
1024                // If we have text and a hangup command, use auto_hangup on the TTS command
1025                // and remove the separate hangup command.
1026                if has_hangup {
1027                    commands.push(self.create_tts_command(text, wait_input_timeout, Some(true)));
1028                    tool_commands.retain(|c| !matches!(c, Command::Hangup { .. }));
1029                    // Mark as hanging up to prevent interruption
1030                    self.is_hanging_up = true;
1031                } else {
1032                    commands.push(self.create_tts_command(text, wait_input_timeout, None));
1033                }
1034            }
1035        }
1036
1037        commands.extend(tool_commands);
1038
1039        Ok(commands)
1040    }
1041}
1042
1043fn parse_structured_response(raw: &str) -> Option<StructuredResponse> {
1044    let payload = extract_json_block(raw)?;
1045    serde_json::from_str(payload).ok()
1046}
1047
1048fn is_likely_filler(text: &str) -> bool {
1049    let trimmed = text.trim().to_lowercase();
1050    FILLERS.contains(&trimmed)
1051}
1052
/// Extract the JSON payload from an LLM response.
///
/// Accepts either a bare JSON object/array or one wrapped in a Markdown code
/// fence (with an optional case-insensitive `json` language tag). Returns the
/// inner slice, or `None` when no JSON block is present.
///
/// Fix: the previous version checked only for a single leading backtick and
/// then byte-sliced `trimmed[3..end]`, which panics on a non-char-boundary
/// for inputs like a lone backtick followed by a multi-byte character and a
/// later triple fence. `strip_prefix` only matches a full ```` ``` ```` fence,
/// so all slice indices here are guaranteed char boundaries. The tag check
/// also no longer allocates a lowercased copy of the whole payload.
fn extract_json_block(raw: &str) -> Option<&str> {
    let trimmed = raw.trim();
    if let Some(rest) = trimmed.strip_prefix("```") {
        // Require a closing fence after the opening one.
        let end = rest.rfind("```")?;
        let mut inner = rest[..end].trim();
        // Drop an optional "json" language tag (ASCII, any case). `get(..4)`
        // returns None on short or non-boundary prefixes, so this never panics.
        if inner
            .get(..4)
            .map_or(false, |tag| tag.eq_ignore_ascii_case("json"))
        {
            inner = match inner.find('\n') {
                // Tag on its own line: keep everything after it.
                Some(newline) => inner[newline + 1..].trim(),
                // Tag glued to the payload: skip the 4 ASCII tag bytes.
                None => inner[4..].trim(),
            };
        }
        return Some(inner);
    }
    if trimmed.starts_with('{') || trimmed.starts_with('[') {
        return Some(trimmed);
    }
    None
}
1078
1079#[async_trait]
1080impl DialogueHandler for LlmHandler {
1081    async fn on_start(&mut self) -> Result<Vec<Command>> {
1082        self.last_tts_start_at = Some(std::time::Instant::now());
1083
1084        let mut commands = Vec::new();
1085
1086        // Check if current scene has an audio file to play
1087        if let Some(scene_id) = &self.current_scene_id {
1088            if let Some(scene) = self.scenes.get(scene_id) {
1089                if let Some(audio_file) = &scene.play {
1090                    commands.push(Command::Play {
1091                        url: audio_file.clone(),
1092                        play_id: None,
1093                        auto_hangup: None,
1094                        wait_input_timeout: None,
1095                    });
1096                }
1097            }
1098        }
1099
1100        if let Some(greeting) = &self.config.greeting {
1101            self.is_speaking = true;
1102            commands.push(self.create_tts_command(greeting.clone(), None, None));
1103            return Ok(commands);
1104        }
1105
1106        let response_commands = self.generate_response().await?;
1107        commands.extend(response_commands);
1108        Ok(commands)
1109    }
1110
1111    async fn on_event(&mut self, event: &SessionEvent) -> Result<Vec<Command>> {
1112        match event {
1113            SessionEvent::Dtmf { digit, .. } => {
1114                info!("DTMF received: {}", digit);
1115                let action = self.get_dtmf_action(digit);
1116                if let Some(action) = action {
1117                    return self.handle_dtmf_action(action).await;
1118                }
1119                Ok(vec![])
1120            }
1121
1122            SessionEvent::AsrFinal { text, .. } => {
1123                if text.trim().is_empty() {
1124                    return Ok(vec![]);
1125                }
1126
1127                self.last_asr_final_at = Some(std::time::Instant::now());
1128                self.last_interaction_at = std::time::Instant::now();
1129                self.is_speaking = false;
1130                self.consecutive_follow_ups = 0;
1131
1132                self.history.push(ChatMessage {
1133                    role: "user".to_string(),
1134                    content: text.clone(),
1135                });
1136
1137                self.generate_response().await
1138            }
1139
1140            SessionEvent::AsrDelta { is_filler, .. } | SessionEvent::Speaking { is_filler, .. } => {
1141                let strategy = self.interruption_config.strategy;
1142                let should_check = match (strategy, event) {
1143                    (InterruptionStrategy::None, _) => false,
1144                    (InterruptionStrategy::Vad, SessionEvent::Speaking { .. }) => true,
1145                    (InterruptionStrategy::Asr, SessionEvent::AsrDelta { .. }) => true,
1146                    (InterruptionStrategy::Both, _) => true,
1147                    _ => false,
1148                };
1149
1150                // Do not allow interruption if we are in the process of hanging up (e.g. saying goodbye)
1151                if self.is_speaking && !self.is_hanging_up && should_check {
1152                    // 1. Protection Period Check
1153                    if let Some(last_start) = self.last_tts_start_at {
1154                        let ignore_ms = self.interruption_config.ignore_first_ms.unwrap_or(800);
1155                        if last_start.elapsed().as_millis() < ignore_ms as u128 {
1156                            return Ok(vec![]);
1157                        }
1158                    }
1159
1160                    // 2. Filler Word Check
1161                    if self.interruption_config.filler_word_filter.unwrap_or(false) {
1162                        if let Some(true) = is_filler {
1163                            return Ok(vec![]);
1164                        }
1165                        // Secondary text-based filler check for ASR Delta
1166                        if let SessionEvent::AsrDelta { text, .. } = event {
1167                            if is_likely_filler(text) {
1168                                return Ok(vec![]);
1169                            }
1170                        }
1171                    }
1172
1173                    // 3. Stale event check (already had one in original code)
1174                    if let Some(last_final) = self.last_asr_final_at {
1175                        if last_final.elapsed().as_millis() < 500 {
1176                            return Ok(vec![]);
1177                        }
1178                    }
1179
1180                    info!("Smart interruption detected, stopping playback");
1181                    self.is_speaking = false;
1182                    return Ok(vec![Command::Interrupt {
1183                        graceful: Some(true),
1184                        fade_out_ms: self.interruption_config.volume_fade_ms,
1185                    }]);
1186                }
1187                Ok(vec![])
1188            }
1189
1190            SessionEvent::Eou { completed, .. } => {
1191                if *completed && self.is_speaking == false {
1192                    info!("EOU detected, triggering early response");
1193                    return self.generate_response().await;
1194                }
1195                Ok(vec![])
1196            }
1197
1198            SessionEvent::Silence { .. } => {
1199                let follow_up_config = if let Some(scene_id) = &self.current_scene_id {
1200                    self.scenes
1201                        .get(scene_id)
1202                        .and_then(|s| s.follow_up)
1203                        .or(self.global_follow_up_config)
1204                } else {
1205                    self.global_follow_up_config
1206                };
1207
1208                if let Some(config) = follow_up_config {
1209                    if !self.is_speaking
1210                        && self.last_interaction_at.elapsed().as_millis() as u64 >= config.timeout
1211                    {
1212                        if self.consecutive_follow_ups >= config.max_count {
1213                            info!("Max follow-up count reached, hanging up");
1214                            return Ok(vec![Command::Hangup {
1215                                reason: Some("Max follow-up reached".to_string()),
1216                                initiator: Some("system".to_string()),
1217                            }]);
1218                        }
1219
1220                        info!(
1221                            "Silence timeout detected ({}ms), triggering follow-up ({}/{})",
1222                            self.last_interaction_at.elapsed().as_millis(),
1223                            self.consecutive_follow_ups + 1,
1224                            config.max_count
1225                        );
1226                        self.consecutive_follow_ups += 1;
1227                        self.last_interaction_at = std::time::Instant::now();
1228                        return self.generate_response().await;
1229                    }
1230                }
1231                Ok(vec![])
1232            }
1233
1234            SessionEvent::TrackStart { .. } => {
1235                self.is_speaking = true;
1236                Ok(vec![])
1237            }
1238
1239            SessionEvent::TrackEnd { .. } => {
1240                self.is_speaking = false;
1241                self.is_hanging_up = false;
1242                self.last_interaction_at = std::time::Instant::now();
1243                Ok(vec![])
1244            }
1245
1246            SessionEvent::FunctionCall {
1247                name, arguments, ..
1248            } => {
1249                info!(
1250                    "Function call from Realtime: {} with args {}",
1251                    name, arguments
1252                );
1253                let args: serde_json::Value = serde_json::from_str(arguments).unwrap_or_default();
1254                match name.as_str() {
1255                    "hangup_call" => Ok(vec![Command::Hangup {
1256                        reason: args["reason"].as_str().map(|s| s.to_string()),
1257                        initiator: Some("ai".to_string()),
1258                    }]),
1259                    "transfer_call" | "refer_call" => {
1260                        if let Some(callee) = args["callee"]
1261                            .as_str()
1262                            .or_else(|| args["callee_uri"].as_str())
1263                        {
1264                            Ok(vec![Command::Refer {
1265                                caller: String::new(),
1266                                callee: callee.to_string(),
1267                                options: None,
1268                            }])
1269                        } else {
1270                            warn!("No callee provided for transfer_call");
1271                            Ok(vec![])
1272                        }
1273                    }
1274                    "goto_scene" => {
1275                        if let Some(scene) = args["scene"].as_str() {
1276                            self.switch_to_scene(scene, false).await
1277                        } else {
1278                            Ok(vec![])
1279                        }
1280                    }
1281                    _ => {
1282                        warn!("Unhandled function call: {}", name);
1283                        Ok(vec![])
1284                    }
1285                }
1286            }
1287
1288            _ => Ok(vec![]),
1289        }
1290    }
1291
1292    async fn get_history(&self) -> Vec<ChatMessage> {
1293        self.history.clone()
1294    }
1295
1296    async fn summarize(&mut self, prompt: &str) -> Result<String> {
1297        info!("Generating summary with prompt: {}", prompt);
1298        let mut summary_history = self.history.clone();
1299        summary_history.push(ChatMessage {
1300            role: "user".to_string(),
1301            content: prompt.to_string(),
1302        });
1303
1304        self.provider.call(&self.config, &summary_history).await
1305    }
1306}
1307
1308#[cfg(test)]
1309mod tests {
1310    use super::*;
1311    use crate::event::SessionEvent;
1312    use anyhow::{Result, anyhow};
1313    use async_trait::async_trait;
1314    use std::collections::VecDeque;
1315    use std::sync::Mutex;
1316
    /// LLM stub that replays a fixed queue of canned responses.
    struct TestProvider {
        // FIFO of responses handed out one per `call`.
        responses: Mutex<VecDeque<String>>,
    }

    impl TestProvider {
        /// Seed the provider with the responses to return, in order.
        fn new(responses: Vec<String>) -> Self {
            Self {
                responses: Mutex::new(VecDeque::from(responses)),
            }
        }
    }
1328
    #[async_trait]
    impl LlmProvider for TestProvider {
        /// Pop the next canned response; errors once the queue is exhausted
        /// so a test that over-queries the LLM fails loudly.
        async fn call(&self, _config: &LlmConfig, _history: &[ChatMessage]) -> Result<String> {
            let mut guard = self.responses.lock().unwrap();
            guard
                .pop_front()
                .ok_or_else(|| anyhow!("Test provider ran out of responses"))
        }

        /// Streaming variant: yields the whole canned response as a single
        /// `Content` chunk.
        async fn call_stream(
            &self,
            _config: &LlmConfig,
            _history: &[ChatMessage],
        ) -> Result<Pin<Box<dyn Stream<Item = Result<LlmStreamEvent>> + Send>>> {
            let response = self.call(_config, _history).await?;
            let s = async_stream::stream! {
                yield Ok(LlmStreamEvent::Content(response));
            };
            Ok(Box::pin(s))
        }
    }
1350
    /// RAG stub that records every query and returns a deterministic
    /// `"retrieved <query>"` result.
    struct RecordingRag {
        queries: Mutex<Vec<String>>,
    }

    impl RecordingRag {
        fn new() -> Self {
            Self {
                queries: Mutex::new(Vec::new()),
            }
        }

        /// Queries seen so far, in call order.
        fn recorded_queries(&self) -> Vec<String> {
            self.queries.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl RagRetriever for RecordingRag {
        async fn retrieve(&self, query: &str) -> Result<String> {
            self.queries.lock().unwrap().push(query.to_string());
            Ok(format!("retrieved {}", query))
        }
    }
1374
    /// A structured response carrying hangup + refer tools must yield a TTS
    /// command with `auto_hangup: Some(true)` (the hangup is folded into the
    /// spoken farewell) plus a separate `Refer` command, and the response's
    /// `waitInputTimeout` must be applied to the TTS command.
    #[tokio::test]
    async fn handler_applies_tool_instructions() -> Result<()> {
        let response = r#"{
            "text": "Goodbye",
            "waitInputTimeout": 15000,
            "tools": [
                {"name": "hangup", "reason": "done", "initiator": "agent"},
                {"name": "refer", "caller": "sip:bot", "callee": "sip:lead"}
            ]
        }"#;

        let provider = Arc::new(TestProvider::new(vec![response.to_string()]));
        let mut handler = LlmHandler::with_provider(
            LlmConfig::default(),
            provider,
            Arc::new(NoopRagRetriever),
            crate::playbook::InterruptionConfig::default(),
            None,
            HashMap::new(),
            None,
            None,
        );

        let event = SessionEvent::AsrFinal {
            track_id: "track-1".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "hello".to_string(),
            is_filler: None,
            confidence: None,
        };

        let commands = handler.on_event(&event).await?;
        assert!(matches!(
            commands.get(0),
            Some(Command::Tts {
                text,
                wait_input_timeout: Some(15000),
                auto_hangup: Some(true),
                ..
            }) if text == "Goodbye"
        ));
        assert!(commands.iter().any(|cmd| matches!(
            cmd,
            Command::Refer {
                caller,
                callee,
                ..
            } if caller == "sip:bot" && callee == "sip:lead"
        )));

        Ok(())
    }
1430
    /// A response containing only a RAG tool invocation must trigger a
    /// retrieval and a second LLM call; the second response becomes the
    /// spoken text with the default 10s wait-input timeout.
    #[tokio::test]
    async fn handler_requeries_after_rag() -> Result<()> {
        let rag_instruction = r#"{"tools": [{"name": "rag", "query": "policy"}]}"#;
        let provider = Arc::new(TestProvider::new(vec![
            rag_instruction.to_string(),
            "Final answer".to_string(),
        ]));
        let rag = Arc::new(RecordingRag::new());
        let mut handler = LlmHandler::with_provider(
            LlmConfig::default(),
            provider,
            rag.clone(),
            crate::playbook::InterruptionConfig::default(),
            None,
            HashMap::new(),
            None,
            None,
        );

        let event = SessionEvent::AsrFinal {
            track_id: "track-2".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "reep".to_string(),
            is_filler: None,
            confidence: None,
        };

        let commands = handler.on_event(&event).await?;
        assert!(matches!(
            commands.get(0),
            Some(Command::Tts {
                text,
                wait_input_timeout: Some(timeout),
                ..
            }) if text == "Final answer" && *timeout == 10000
        ));
        // The retriever must have been hit exactly once, with the query from
        // the tool invocation.
        assert_eq!(rag.recorded_queries(), vec!["policy".to_string()]);

        Ok(())
    }
1474
    /// End-to-end walk through a call: greeting on start, a plain-text reply
    /// (sentence-split into streamed TTS chunks), a structured reply with a
    /// custom wait-input timeout, and a final reply whose hangup tool folds
    /// into the farewell TTS as `auto_hangup`.
    #[tokio::test]
    async fn test_full_dialogue_flow() -> Result<()> {
        let responses = vec![
            "Hello! How can I help you today?".to_string(),
            r#"{"text": "I can help with that. Anything else?", "waitInputTimeout": 5000}"#
                .to_string(),
            r#"{"text": "Goodbye!", "tools": [{"name": "hangup", "reason": "completed"}]}"#
                .to_string(),
        ];

        let provider = Arc::new(TestProvider::new(responses));
        let config = LlmConfig {
            greeting: Some("Welcome to the voice assistant.".to_string()),
            ..Default::default()
        };

        let mut handler = LlmHandler::with_provider(
            config,
            provider,
            Arc::new(NoopRagRetriever),
            crate::playbook::InterruptionConfig::default(),
            None,
            HashMap::new(),
            None,
            None,
        );

        // 1. Start the dialogue
        let commands = handler.on_start().await?;
        assert_eq!(commands.len(), 1);
        if let Command::Tts { text, .. } = &commands[0] {
            assert_eq!(text, "Welcome to the voice assistant.");
        } else {
            panic!("Expected Tts command");
        }

        // 2. User says something
        let event = SessionEvent::AsrFinal {
            track_id: "test".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "I need help".to_string(),
            is_filler: None,
            confidence: None,
        };
        let commands = handler.on_event(&event).await?;
        // "Hello! How can I help you today?" -> split into two + EOS
        assert_eq!(commands.len(), 3);
        if let Command::Tts { text, .. } = &commands[0] {
            assert!(text.contains("Hello"));
        } else {
            panic!("Expected Tts command");
        }

        // 3. User says something else
        let event = SessionEvent::AsrFinal {
            track_id: "test".to_string(),
            timestamp: 0,
            index: 1,
            start_time: None,
            end_time: None,
            text: "Tell me a joke".to_string(),
            is_filler: None,
            confidence: None,
        };
        let commands = handler.on_event(&event).await?;
        assert_eq!(commands.len(), 1);
        if let Command::Tts {
            text,
            wait_input_timeout,
            ..
        } = &commands[0]
        {
            assert_eq!(text, "I can help with that. Anything else?");
            assert_eq!(*wait_input_timeout, Some(5000));
        } else {
            panic!("Expected Tts command");
        }

        // 4. User says goodbye
        let event = SessionEvent::AsrFinal {
            track_id: "test".to_string(),
            timestamp: 0,
            index: 2,
            start_time: None,
            end_time: None,
            text: "That's all, thanks".to_string(),
            is_filler: None,
            confidence: None,
        };
        let commands = handler.on_event(&event).await?;
        // Should have Tts with auto_hangup
        assert_eq!(commands.len(), 1);

        let has_tts_hangup = commands.iter().any(|c| {
            matches!(
                c,
                Command::Tts {
                    text,
                    auto_hangup: Some(true),
                    ..
                } if text == "Goodbye!"
            )
        });

        assert!(has_tts_hangup);

        Ok(())
    }
1586
    /// Inline XML tool tags (`<refer/>`, `<hangup/>`) embedded in a plain
    /// text reply must be converted to commands interleaved with the spoken
    /// sentence chunks, all TTS chunks sharing one `play_id`, and a trailing
    /// `<hangup/>` becoming `auto_hangup` on the last TTS command.
    #[tokio::test]
    async fn test_xml_tools_and_sentence_splitting() -> Result<()> {
        let responses = vec!["Hello! <refer to=\"sip:123\"/> How are you? <hangup/>".to_string()];
        let provider = Arc::new(TestProvider::new(responses));
        let mut handler = LlmHandler::with_provider(
            LlmConfig::default(),
            provider,
            Arc::new(NoopRagRetriever),
            crate::playbook::InterruptionConfig::default(),
            None,
            HashMap::new(),
            None,
            None,
        );

        let event = SessionEvent::AsrFinal {
            track_id: "test".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "hi".to_string(),
            is_filler: None,
            confidence: None,
        };

        let commands = handler.on_event(&event).await?;

        // Expected commands:
        // 1. TTS "Hello! "
        // 2. Refer "sip:123"
        // 3. TTS " How are you? "
        // 4. Hangup
        assert_eq!(commands.len(), 4);

        if let Command::Tts {
            text,
            play_id: pid1,
            ..
        } = &commands[0]
        {
            assert!(text.contains("Hello"));
            assert!(pid1.is_some());

            if let Command::Refer { callee, .. } = &commands[1] {
                assert_eq!(callee, "sip:123");
            } else {
                panic!("Expected Refer");
            }

            if let Command::Tts {
                text,
                play_id: pid2,
                ..
            } = &commands[2]
            {
                assert!(text.contains("How are you"));
                assert_eq!(*pid1, *pid2); // Same play_id
            } else {
                panic!("Expected Tts");
            }

            if let Command::Tts {
                auto_hangup: Some(true),
                ..
            } = &commands[3]
            {
                // Ok
            } else {
                panic!("Expected Tts with auto_hangup");
            }
        } else {
            panic!("Expected Tts");
        }

        Ok(())
    }
1664
1665    #[tokio::test]
1666    async fn test_interruption_logic() -> Result<()> {
1667        let provider = Arc::new(TestProvider::new(vec!["Some long response".to_string()]));
1668        let mut handler = LlmHandler::with_provider(
1669            LlmConfig::default(),
1670            provider,
1671            Arc::new(NoopRagRetriever),
1672            crate::playbook::InterruptionConfig::default(),
1673            None,
1674            HashMap::new(),
1675            None,
1676            None,
1677        );
1678
1679        // 1. Trigger a response
1680        let event = SessionEvent::AsrFinal {
1681            track_id: "test".to_string(),
1682            timestamp: 0,
1683            index: 0,
1684            start_time: None,
1685            end_time: None,
1686            text: "hello".to_string(),
1687            is_filler: None,
1688            confidence: None,
1689        };
1690        handler.on_event(&event).await?;
1691        assert!(handler.is_speaking);
1692
1693        // Sleep to bypass the 800ms interruption guard
1694        tokio::time::sleep(std::time::Duration::from_millis(850)).await;
1695
1696        // 2. Simulate user starting to speak (AsrDelta)
1697        let event = SessionEvent::AsrDelta {
1698            track_id: "test".to_string(),
1699            timestamp: 0,
1700            index: 0,
1701            start_time: None,
1702            end_time: None,
1703            text: "I...".to_string(),
1704            is_filler: None,
1705            confidence: None,
1706        };
1707        let commands = handler.on_event(&event).await?;
1708        assert_eq!(commands.len(), 1);
1709        assert!(matches!(commands[0], Command::Interrupt { .. }));
1710        assert!(!handler.is_speaking);
1711
1712        Ok(())
1713    }
1714
1715    #[tokio::test]
1716    async fn test_rag_iteration_limit() -> Result<()> {
1717        // Provider that always returns a RAG tool call
1718        let rag_instruction = r#"{"tools": [{"name": "rag", "query": "endless"}]}"#;
1719        let provider = Arc::new(TestProvider::new(vec![
1720            rag_instruction.to_string(),
1721            rag_instruction.to_string(),
1722            rag_instruction.to_string(),
1723            rag_instruction.to_string(),
1724            "Should not reach here".to_string(),
1725        ]));
1726
1727        let mut handler = LlmHandler::with_provider(
1728            LlmConfig::default(),
1729            provider,
1730            Arc::new(RecordingRag::new()),
1731            crate::playbook::InterruptionConfig::default(),
1732            None,
1733            HashMap::new(),
1734            None,
1735            None,
1736        );
1737
1738        let event = SessionEvent::AsrFinal {
1739            track_id: "test".to_string(),
1740            timestamp: 0,
1741            index: 0,
1742            start_time: None,
1743            end_time: None,
1744            text: "loop".to_string(),
1745            is_filler: None,
1746            confidence: None,
1747        };
1748
1749        let commands = handler.on_event(&event).await?;
1750        // After 3 attempts (MAX_RAG_ATTEMPTS), it should stop and return the last raw response
1751        assert_eq!(commands.len(), 1);
1752        if let Command::Tts { text, .. } = &commands[0] {
1753            assert_eq!(text, rag_instruction);
1754        }
1755
1756        Ok(())
1757    }
1758
1759    #[tokio::test]
1760    async fn test_follow_up_logic() -> Result<()> {
1761        use std::time::Duration;
1762
1763        // 1. Setup handler with follow-up config
1764        let follow_up_config = super::super::FollowUpConfig {
1765            timeout: 100, // 100ms for testing
1766            max_count: 2,
1767        };
1768
1769        // Provider that returns responses for follow-ups
1770        let provider = Arc::new(TestProvider::new(vec![
1771            "Follow up 1".to_string(),
1772            "Follow up 2".to_string(),
1773            "Response to user".to_string(),
1774        ]));
1775
1776        let mut handler = LlmHandler::with_provider(
1777            LlmConfig::default(),
1778            provider,
1779            Arc::new(NoopRagRetriever),
1780            crate::playbook::InterruptionConfig::default(),
1781            Some(follow_up_config),
1782            HashMap::new(),
1783            None,
1784            None,
1785        );
1786
1787        // 2. Simulate initial interaction end
1788        handler.last_interaction_at = std::time::Instant::now();
1789        handler.is_speaking = false;
1790
1791        // 3. Silence < timeout
1792        let event = SessionEvent::Silence {
1793            track_id: "t1".to_string(),
1794            timestamp: 0,
1795            start_time: 0,
1796            duration: 50,
1797            samples: None,
1798        };
1799        let commands = handler.on_event(&event).await?;
1800        assert!(commands.is_empty(), "Should not trigger if < timeout");
1801
1802        // 4. Silence >= timeout
1803        tokio::time::sleep(Duration::from_millis(110)).await;
1804        // We need to act as if time passed. logic uses last_interaction_at.elapsed()
1805        // Wait, handler.last_interaction_at was set to now(). Sleep ensures elapsed() > 100ms.
1806
1807        let event = SessionEvent::Silence {
1808            track_id: "t1".to_string(),
1809            timestamp: 0,
1810            start_time: 0,
1811            duration: 100,
1812            samples: None,
1813        };
1814        let commands = handler.on_event(&event).await?;
1815        assert_eq!(commands.len(), 1, "Should trigger follow-up 1");
1816        if let Command::Tts { text, .. } = &commands[0] {
1817            assert_eq!(text, "Follow up 1");
1818        }
1819        assert_eq!(handler.consecutive_follow_ups, 1);
1820
1821        // 4b. Simulate bot finishing speaking Follow up 1
1822        // generate_response sets is_speaking = true. We need to clear it.
1823        let event = SessionEvent::TrackEnd {
1824            track_id: "t1".to_string(),
1825            timestamp: 0,
1826            play_id: None,
1827            duration: 100,
1828            ssrc: 0,
1829        };
1830        handler.on_event(&event).await?;
1831        assert!(
1832            !handler.is_speaking,
1833            "Bot should not be speaking after TrackEnd"
1834        );
1835
1836        // 5. Simulate bot speaking (TrackStart/End updates tracking) -- Actually verifying 2nd timeout
1837        // Reset interaction time to now (done inside handler on Silence event trigger?
1838        // Yes: self.last_interaction_at = std::time::Instant::now(); in Silence block)
1839        // But we want to verify the loop.
1840
1841        // Wait again for timeout
1842        tokio::time::sleep(Duration::from_millis(110)).await;
1843        let event = SessionEvent::Silence {
1844            track_id: "t1".to_string(),
1845            timestamp: 0,
1846            start_time: 0,
1847            duration: 100,
1848            samples: None,
1849        };
1850        let commands = handler.on_event(&event).await?;
1851        assert_eq!(commands.len(), 1, "Should trigger follow-up 2");
1852        if let Command::Tts { text, .. } = &commands[0] {
1853            assert_eq!(text, "Follow up 2");
1854        }
1855        assert_eq!(handler.consecutive_follow_ups, 2);
1856
1857        // 5b. Simulate bot finishing speaking Follow up 2
1858        let event = SessionEvent::TrackEnd {
1859            track_id: "t1".to_string(),
1860            timestamp: 0,
1861            play_id: None,
1862            duration: 100,
1863            ssrc: 0,
1864        };
1865        handler.on_event(&event).await?;
1866
1867        // 6. Max count reached
1868        tokio::time::sleep(Duration::from_millis(110)).await;
1869        let event = SessionEvent::Silence {
1870            track_id: "t1".to_string(),
1871            timestamp: 0,
1872            start_time: 0,
1873            duration: 100,
1874            samples: None,
1875        };
1876        let commands = handler.on_event(&event).await?;
1877        assert_eq!(commands.len(), 1, "Should hangup after max count");
1878        assert!(matches!(commands[0], Command::Hangup { .. }));
1879
1880        // 7. Reset on user speech
1881        handler.consecutive_follow_ups = 2; // Artificially set high
1882        let event = SessionEvent::AsrFinal {
1883            track_id: "t1".to_string(),
1884            timestamp: 0,
1885            index: 0,
1886            start_time: None,
1887            end_time: None,
1888            text: "User speaks".to_string(),
1889            is_filler: None,
1890            confidence: None,
1891        };
1892        // This will trigger generate_response (consuming "Response to user" from provider)
1893        let _ = handler.on_event(&event).await?;
1894        assert_eq!(
1895            handler.consecutive_follow_ups, 0,
1896            "Should reset count on AsrFinal"
1897        );
1898
1899        Ok(())
1900    }
1901
1902    #[tokio::test]
1903    async fn test_interruption_protection_period() -> Result<()> {
1904        let provider = Arc::new(TestProvider::new(vec!["Some long response".to_string()]));
1905        let mut config = crate::playbook::InterruptionConfig::default();
1906        config.ignore_first_ms = Some(800);
1907
1908        let mut handler = LlmHandler::with_provider(
1909            LlmConfig::default(),
1910            provider,
1911            Arc::new(NoopRagRetriever),
1912            config,
1913            None,
1914            HashMap::new(),
1915            None,
1916            None,
1917        );
1918
1919        // 1. Trigger a response
1920        let event = SessionEvent::AsrFinal {
1921            track_id: "test".to_string(),
1922            timestamp: 0,
1923            index: 0,
1924            start_time: None,
1925            end_time: None,
1926            text: "hello".to_string(),
1927            is_filler: None,
1928            confidence: None,
1929        };
1930        handler.on_event(&event).await?;
1931        assert!(handler.is_speaking);
1932
1933        // 2. Simulate user starting to speak immediately (AsrDelta)
1934        let event = SessionEvent::AsrDelta {
1935            track_id: "test".to_string(),
1936            timestamp: 0,
1937            index: 0,
1938            start_time: None,
1939            end_time: None,
1940            text: "I...".to_string(),
1941            is_filler: None,
1942            confidence: None,
1943        };
1944        let commands = handler.on_event(&event).await?;
1945        // Should be ignored due to protection period
1946        assert_eq!(commands.len(), 0);
1947        assert!(handler.is_speaking);
1948
1949        Ok(())
1950    }
1951
1952    #[tokio::test]
1953    async fn test_interruption_filler_word() -> Result<()> {
1954        let provider = Arc::new(TestProvider::new(vec!["Some long response".to_string()]));
1955        let mut config = crate::playbook::InterruptionConfig::default();
1956        config.filler_word_filter = Some(true);
1957        config.ignore_first_ms = Some(0); // Disable protection period for this test
1958
1959        let mut handler = LlmHandler::with_provider(
1960            LlmConfig::default(),
1961            provider,
1962            Arc::new(NoopRagRetriever),
1963            config,
1964            None,
1965            HashMap::new(),
1966            None,
1967            None,
1968        );
1969
1970        // 1. Trigger a response
1971        let event = SessionEvent::AsrFinal {
1972            track_id: "test".to_string(),
1973            timestamp: 0,
1974            index: 0,
1975            start_time: None,
1976            end_time: None,
1977            text: "hello".to_string(),
1978            is_filler: None,
1979            confidence: None,
1980        };
1981        handler.on_event(&event).await?;
1982        assert!(handler.is_speaking);
1983
1984        // Sleep to bypass the 500ms stale event guard
1985        tokio::time::sleep(std::time::Duration::from_millis(600)).await;
1986
1987        // 2. Simulate user saying "uh" (filler)
1988        let event = SessionEvent::AsrDelta {
1989            track_id: "test".to_string(),
1990            timestamp: 0,
1991            index: 0,
1992            start_time: None,
1993            end_time: None,
1994            text: "uh".to_string(),
1995            is_filler: Some(true),
1996            confidence: None,
1997        };
1998        let commands = handler.on_event(&event).await?;
1999        // Should be ignored
2000        assert_eq!(commands.len(), 0);
2001        assert!(handler.is_speaking);
2002
2003        // 3. Simulate user saying "Wait" (not filler)
2004        let event = SessionEvent::AsrDelta {
2005            track_id: "test".to_string(),
2006            timestamp: 0,
2007            index: 0,
2008            start_time: None,
2009            end_time: None,
2010            text: "Wait".to_string(),
2011            is_filler: Some(false),
2012            confidence: None,
2013        };
2014        let commands = handler.on_event(&event).await?;
2015        // Should trigger interruption
2016        assert_eq!(commands.len(), 1);
2017        assert!(matches!(commands[0], Command::Interrupt { .. }));
2018
2019        Ok(())
2020    }
2021
2022    #[tokio::test]
2023    async fn test_eou_early_response() -> Result<()> {
2024        let provider = Arc::new(TestProvider::new(vec![
2025            "End of Utterance response".to_string(),
2026        ]));
2027        let mut handler = LlmHandler::with_provider(
2028            LlmConfig::default(),
2029            provider,
2030            Arc::new(NoopRagRetriever),
2031            crate::playbook::InterruptionConfig::default(),
2032            None,
2033            HashMap::new(),
2034            None,
2035            None,
2036        );
2037
2038        // 1. Receive EOU
2039        let event = SessionEvent::Eou {
2040            track_id: "test".to_string(),
2041            timestamp: 0,
2042            completed: true,
2043        };
2044        let commands = handler.on_event(&event).await?;
2045        assert_eq!(commands.len(), 1);
2046        if let Command::Tts { text, .. } = &commands[0] {
2047            assert_eq!(text, "End of Utterance response");
2048        } else {
2049            panic!("Expected Tts");
2050        }
2051
2052        Ok(())
2053    }
2054
2055    #[tokio::test]
2056    async fn test_summary_and_history() -> Result<()> {
2057        let provider = Arc::new(TestProvider::new(vec!["Test summary".to_string()]));
2058        let mut handler = LlmHandler::with_provider(
2059            LlmConfig::default(),
2060            provider,
2061            Arc::new(NoopRagRetriever),
2062            crate::playbook::InterruptionConfig::default(),
2063            None,
2064            HashMap::new(),
2065            None,
2066            None,
2067        );
2068
2069        // Add some history
2070        handler.history.push(ChatMessage {
2071            role: "user".to_string(),
2072            content: "Hello".to_string(),
2073        });
2074        handler.history.push(ChatMessage {
2075            role: "assistant".to_string(),
2076            content: "Hi there".to_string(),
2077        });
2078
2079        // Test get_history
2080        let history = handler.get_history().await;
2081        assert_eq!(history.len(), 3); // system + user + assistant
2082
2083        // Test summarize
2084        let summary = handler.summarize("Summarize this").await?;
2085        assert_eq!(summary, "Test summary");
2086
2087        Ok(())
2088    }
2089}