1use crate::ReferOption;
2use crate::call::Command;
3use crate::event::SessionEvent;
4use anyhow::{Result, anyhow};
5use async_trait::async_trait;
6use futures::{Stream, StreamExt};
7use once_cell::sync::Lazy;
8use regex::Regex;
9use reqwest::Client;
10use serde::{Deserialize, Serialize};
11use serde_json::json;
12use std::collections::HashMap;
13use std::pin::Pin;
14use std::sync::Arc;
15use tracing::{info, warn};
16
// Inline XML-ish action tags the system prompt instructs the LLM to emit.
// Each capture group (where present) carries the tag's single argument.
static RE_HANGUP: Lazy<Regex> = Lazy::new(|| Regex::new(r"<hangup\s*/>").unwrap());
static RE_REFER: Lazy<Regex> = Lazy::new(|| Regex::new(r#"<refer\s+to="([^"]+)"\s*/>"#).unwrap());
static RE_PLAY: Lazy<Regex> = Lazy::new(|| Regex::new(r#"<play\s+file="([^"]+)"\s*/>"#).unwrap());
static RE_GOTO: Lazy<Regex> = Lazy::new(|| Regex::new(r#"<goto\s+scene="([^"]+)"\s*/>"#).unwrap());
// Sentence boundary (Latin + CJK punctuation, or newline) plus any trailing
// whitespace — used to cut streamed text into TTS-sized segments.
static RE_SENTENCE: Lazy<Regex> = Lazy::new(|| Regex::new(r"(?m)[.!?。!?\n]\s*").unwrap());
22static FILLERS: Lazy<std::collections::HashSet<String>> = Lazy::new(|| {
23 let mut s = std::collections::HashSet::new();
24 let default_fillers = ["嗯", "啊", "哦", "那个", "那个...", "uh", "um", "ah"];
25
26 if let Ok(content) = std::fs::read_to_string("config/fillers.txt") {
27 for line in content.lines() {
28 let trimmed = line.trim().to_lowercase();
29 if !trimmed.is_empty() {
30 s.insert(trimmed);
31 }
32 }
33 }
34
35 if s.is_empty() {
36 for f in default_fillers {
37 s.insert(f.to_string());
38 }
39 }
40 s
41});
42
43use super::ChatMessage;
44use super::InterruptionStrategy;
45use super::LlmConfig;
46use super::dialogue::DialogueHandler;
47
/// Upper bound on LLM re-invocations triggered by RAG/HTTP tool results,
/// guarding against infinite tool-call loops in `interpret_response`.
const MAX_RAG_ATTEMPTS: usize = 3;
49
/// One increment of a streaming LLM response.
#[derive(Debug, Clone)]
pub enum LlmStreamEvent {
    /// A delta of the user-facing answer text.
    Content(String),
    /// A delta of the model's reasoning channel (e.g. `reasoning_content`).
    Reasoning(String),
}
55
/// Pluggable chat-completion backend.
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Non-streaming call: returns the complete assistant message.
    async fn call(&self, config: &LlmConfig, history: &[ChatMessage]) -> Result<String>;
    /// Streaming call: yields content/reasoning deltas as they arrive.
    async fn call_stream(
        &self,
        config: &LlmConfig,
        history: &[ChatMessage],
    ) -> Result<Pin<Box<dyn Stream<Item = Result<LlmStreamEvent>> + Send>>>;
}
65
/// One increment from a realtime (speech-to-speech) provider.
pub struct RealtimeResponse {
    /// Raw audio bytes, when the provider emitted audio for this increment.
    pub audio_delta: Option<Vec<u8>>,
    /// Text transcript delta, when available.
    pub text_delta: Option<String>,
    /// A tool call the model requested, if any.
    pub function_call: Option<ToolInvocation>,
    /// True when the provider detected the start of user speech.
    pub speech_started: bool,
}
72
/// Pluggable realtime (bidirectional audio) backend.
#[async_trait]
pub trait RealtimeProvider: Send + Sync {
    /// Establish the provider session using `config`.
    async fn connect(&self, config: &LlmConfig) -> Result<()>;
    /// Push a frame of PCM samples to the provider.
    async fn send_audio(&self, audio: &[i16]) -> Result<()>;
    /// Subscribe to the provider's response stream.
    async fn subscribe(
        &self,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<RealtimeResponse>> + Send>>>;
}
81
/// Default [`LlmProvider`] speaking the OpenAI-compatible REST API.
struct DefaultLlmProvider {
    // Reused HTTP client (connection pooling across calls).
    client: Client,
}

impl DefaultLlmProvider {
    fn new() -> Self {
        Self {
            client: Client::new(),
        }
    }
}
93
94#[async_trait]
95impl LlmProvider for DefaultLlmProvider {
96 async fn call(&self, config: &LlmConfig, history: &[ChatMessage]) -> Result<String> {
97 let mut url = config
98 .base_url
99 .clone()
100 .unwrap_or_else(|| "https://api.openai.com/v1".to_string());
101 let model = config
102 .model
103 .clone()
104 .unwrap_or_else(|| "gpt-4-turbo".to_string());
105 let api_key = config.api_key.clone().unwrap_or_default();
106
107 if !url.ends_with("/chat/completions") {
108 url = format!("{}/chat/completions", url.trim_end_matches('/'));
109 }
110
111 let body = json!({
112 "model": model,
113 "messages": history,
114 });
115
116 let res = self
117 .client
118 .post(&url)
119 .header("Authorization", format!("Bearer {}", api_key))
120 .json(&body)
121 .send()
122 .await?;
123
124 if !res.status().is_success() {
125 return Err(anyhow!("LLM request failed: {}", res.status()));
126 }
127
128 let json: serde_json::Value = res.json().await?;
129 let content = json["choices"][0]["message"]["content"]
130 .as_str()
131 .ok_or_else(|| anyhow!("Invalid LLM response"))?
132 .to_string();
133
134 Ok(content)
135 }
136
137 async fn call_stream(
138 &self,
139 config: &LlmConfig,
140 history: &[ChatMessage],
141 ) -> Result<Pin<Box<dyn Stream<Item = Result<LlmStreamEvent>> + Send>>> {
142 let mut url = config
143 .base_url
144 .clone()
145 .unwrap_or_else(|| "https://api.openai.com/v1".to_string());
146 let model = config
147 .model
148 .clone()
149 .unwrap_or_else(|| "gpt-4-turbo".to_string());
150 let api_key = config.api_key.clone().unwrap_or_default();
151
152 if !url.ends_with("/chat/completions") {
153 url = format!("{}/chat/completions", url.trim_end_matches('/'));
154 }
155
156 let body = json!({
157 "model": model,
158 "messages": history,
159 "stream": true,
160 });
161
162 let res = self
163 .client
164 .post(&url)
165 .header("Authorization", format!("Bearer {}", api_key))
166 .json(&body)
167 .send()
168 .await?;
169
170 if !res.status().is_success() {
171 return Err(anyhow!("LLM request failed: {}", res.status()));
172 }
173
174 let stream = res.bytes_stream();
175 let s = async_stream::stream! {
176 let mut buffer = String::new();
177 for await chunk in stream {
178 match chunk {
179 Ok(bytes) => {
180 let text = String::from_utf8_lossy(&bytes);
181 buffer.push_str(&text);
182
183 while let Some(line_end) = buffer.find('\n') {
184 let line = buffer[..line_end].trim();
185 if line.starts_with("data:") {
186 let data = &line[5..].trim();
187 if *data == "[DONE]" {
188 break;
189 }
190 if let Ok(json) = serde_json::from_str::<serde_json::Value>(data) {
191 if let Some(delta) = json["choices"][0].get("delta") {
192 if let Some(thinking) = delta.get("reasoning_content").and_then(|v| v.as_str()) {
193 yield Ok(LlmStreamEvent::Reasoning(thinking.to_string()));
194 }
195 if let Some(content) = delta.get("content").and_then(|v| v.as_str()) {
196 yield Ok(LlmStreamEvent::Content(content.to_string()));
197 }
198 }
199 }
200 }
201 buffer.drain(..=line_end);
202 }
203 }
204 Err(e) => yield Err(anyhow!(e)),
205 }
206 }
207 };
208
209 Ok(Box::pin(s))
210 }
211}
212
/// Knowledge-retrieval backend for the `rag` tool.
#[async_trait]
pub trait RagRetriever: Send + Sync {
    /// Return context text relevant to `query`; an empty string means no result.
    async fn retrieve(&self, query: &str) -> Result<String>;
}
217
/// Placeholder retriever used when no RAG backend is configured.
struct NoopRagRetriever;

#[async_trait]
impl RagRetriever for NoopRagRetriever {
    /// Always yields an empty context string.
    async fn retrieve(&self, _query: &str) -> Result<String> {
        Ok(String::new())
    }
}
226
/// Shape of a structured (JSON) LLM reply parsed by `interpret_response`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct StructuredResponse {
    // Text to speak, if any.
    text: Option<String>,
    // Override for the wait-for-input timeout after playback (appears to be
    // milliseconds, matching the 10000 default used elsewhere — confirm).
    wait_input_timeout: Option<u32>,
    // Tool calls to execute.
    tools: Option<Vec<ToolInvocation>>,
}
234
/// Tool calls the LLM may emit inside a JSON `tools` array.
///
/// Serde dispatches on the lowercase `name` field, e.g.
/// `{ "name": "hangup", "reason": "..." }`; the remaining fields use camelCase.
#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(tag = "name", rename_all = "lowercase")]
pub enum ToolInvocation {
    /// Terminate the call.
    #[serde(rename_all = "camelCase")]
    Hangup {
        reason: Option<String>,
        initiator: Option<String>,
    },
    /// Transfer the call to another party.
    #[serde(rename_all = "camelCase")]
    Refer {
        caller: String,
        callee: String,
        options: Option<ReferOption>,
    },
    /// Retrieve external knowledge; the result is fed back into the history
    /// and the LLM is re-invoked.
    #[serde(rename_all = "camelCase")]
    Rag {
        query: String,
        source: Option<String>,
    },
    /// Accept an incoming call.
    #[serde(rename_all = "camelCase")]
    Accept { options: Option<crate::CallOption> },
    /// Reject an incoming call.
    #[serde(rename_all = "camelCase")]
    Reject {
        reason: Option<String>,
        code: Option<u32>,
    },
    /// Call an external HTTP endpoint; the response text is fed back into the
    /// history and the LLM is re-invoked.
    #[serde(rename_all = "camelCase")]
    Http {
        url: String,
        method: Option<String>,
        body: Option<serde_json::Value>,
        headers: Option<HashMap<String, String>>,
    },
}
269
/// LLM-driven dialogue handler.
///
/// Owns the conversation history, scene state, interruption policy and the
/// pluggable LLM/RAG backends; translates model output into call [`Command`]s.
pub struct LlmHandler {
    config: LlmConfig,
    // Barge-in policy (strategy, grace period, filler filtering).
    interruption_config: super::InterruptionConfig,
    // Fallback follow-up policy used when the active scene has none.
    global_follow_up_config: Option<super::FollowUpConfig>,
    // Global DTMF digit -> action map; scene-level maps take precedence.
    dtmf_config: Option<HashMap<String, super::DtmfAction>>,
    // Chat transcript; seeded with the system prompt at index 0.
    history: Vec<ChatMessage>,
    provider: Arc<dyn LlmProvider>,
    rag_retriever: Arc<dyn RagRetriever>,
    // True while TTS playback is believed to be in progress.
    is_speaking: bool,
    // Set once a hangup has been scheduled; suppresses further interruptions.
    is_hanging_up: bool,
    // Consecutive silence-triggered follow-ups since the last user input.
    consecutive_follow_ups: u32,
    last_interaction_at: std::time::Instant,
    event_sender: Option<crate::event::EventSender>,
    last_asr_final_at: Option<std::time::Instant>,
    last_tts_start_at: Option<std::time::Instant>,
    // Live call handle; when present, streamed commands are enqueued directly.
    call: Option<crate::call::ActiveCallRef>,
    scenes: HashMap<String, super::Scene>,
    current_scene_id: Option<String>,
    // Reused HTTP client for the `http` tool.
    client: Client,
}
290
291impl LlmHandler {
    /// Create a handler backed by the default OpenAI-compatible provider and
    /// a no-op RAG retriever.
    pub fn new(
        config: LlmConfig,
        interruption: super::InterruptionConfig,
        global_follow_up_config: Option<super::FollowUpConfig>,
        scenes: HashMap<String, super::Scene>,
        dtmf: Option<HashMap<String, super::DtmfAction>>,
        initial_scene_id: Option<String>,
    ) -> Self {
        Self::with_provider(
            config,
            Arc::new(DefaultLlmProvider::new()),
            Arc::new(NoopRagRetriever),
            interruption,
            global_follow_up_config,
            scenes,
            dtmf,
            initial_scene_id,
        )
    }
311
312 pub fn with_provider(
313 config: LlmConfig,
314 provider: Arc<dyn LlmProvider>,
315 rag_retriever: Arc<dyn RagRetriever>,
316 interruption: super::InterruptionConfig,
317 global_follow_up_config: Option<super::FollowUpConfig>,
318 scenes: HashMap<String, super::Scene>,
319 dtmf: Option<HashMap<String, super::DtmfAction>>,
320 initial_scene_id: Option<String>,
321 ) -> Self {
322 let mut history = Vec::new();
323 let system_prompt = Self::build_system_prompt(&config, None);
324
325 history.push(ChatMessage {
326 role: "system".to_string(),
327 content: system_prompt,
328 });
329
330 Self {
331 config,
332 interruption_config: interruption,
333 global_follow_up_config,
334 dtmf_config: dtmf,
335 history,
336 provider,
337 rag_retriever,
338 is_speaking: false,
339 is_hanging_up: false,
340 consecutive_follow_ups: 0,
341 last_interaction_at: std::time::Instant::now(),
342 event_sender: None,
343 last_asr_final_at: None,
344 last_tts_start_at: None,
345 call: None,
346 scenes,
347 current_scene_id: initial_scene_id,
348 client: Client::new(),
349 }
350 }
351
352 fn build_system_prompt(config: &LlmConfig, scene_prompt: Option<&str>) -> String {
353 let base_prompt =
354 scene_prompt.unwrap_or_else(|| config.prompt.as_deref().unwrap_or_default());
355 let mut features_prompt = String::new();
356
357 if let Some(features) = &config.features {
358 let lang = config.language.as_deref().unwrap_or("zh");
359 for feature in features {
360 match Self::load_feature_snippet(feature, lang) {
361 Ok(snippet) => {
362 features_prompt.push_str(&format!("\n- {}", snippet));
363 }
364 Err(e) => {
365 warn!("Failed to load feature snippet {}: {}", feature, e);
366 }
367 }
368 }
369 }
370
371 let features_section = if features_prompt.is_empty() {
372 String::new()
373 } else {
374 format!("\n\n### Enhanced Capabilities:{}\n", features_prompt)
375 };
376
377 format!(
378 "{}{}\n\n\
379 Tool usage instructions:\n\
380 - To hang up the call, output: <hangup/>\n\
381 - To transfer the call, output: <refer to=\"sip:xxxx\"/>\n\
382 - To play an audio file, output: <play file=\"path/to/file.wav\"/>\n\
383 - To switch to another scene, output: <goto scene=\"scene_id\"/>\n\
384 - To call an external HTTP API, output JSON:\n\
385 ```json\n\
386 {{ \"tools\": [{{ \"name\": \"http\", \"url\": \"...\", \"method\": \"POST\", \"body\": {{ ... }} }}] }}\n\
387 ```\n\
388 Please use XML tags for simple actions and JSON blocks for tool calls. \
389 Output your response in short sentences. Each sentence will be played as soon as it is finished.",
390 base_prompt, features_section
391 )
392 }
393
394 fn load_feature_snippet(feature: &str, lang: &str) -> Result<String> {
395 let path = format!("features/{}.{}.md", feature, lang);
396 let content = std::fs::read_to_string(path)?;
397 Ok(content.trim().to_string())
398 }
399
400 fn get_dtmf_action(&self, digit: &str) -> Option<super::DtmfAction> {
401 if let Some(scene_id) = &self.current_scene_id {
403 if let Some(scene) = self.scenes.get(scene_id) {
404 if let Some(dtmf) = &scene.dtmf {
405 if let Some(action) = dtmf.get(digit) {
406 return Some(action.clone());
407 }
408 }
409 }
410 }
411
412 if let Some(dtmf) = &self.dtmf_config {
414 if let Some(action) = dtmf.get(digit) {
415 return Some(action.clone());
416 }
417 }
418
419 None
420 }
421
422 async fn handle_dtmf_action(&mut self, action: super::DtmfAction) -> Result<Vec<Command>> {
423 match action {
424 super::DtmfAction::Goto { scene } => {
425 info!("DTMF action: switch to scene {}", scene);
426 self.switch_to_scene(&scene, true).await
427 }
428 super::DtmfAction::Transfer { target } => {
429 info!("DTMF action: transfer to {}", target);
430 Ok(vec![Command::Refer {
431 caller: String::new(),
432 callee: target,
433 options: None,
434 }])
435 }
436 super::DtmfAction::Hangup => {
437 info!("DTMF action: hangup");
438 Ok(vec![Command::Hangup {
439 reason: Some("DTMF Hangup".to_string()),
440 initiator: Some("ai".to_string()),
441 }])
442 }
443 }
444 }
445
446 async fn switch_to_scene(
447 &mut self,
448 scene_id: &str,
449 trigger_response: bool,
450 ) -> Result<Vec<Command>> {
451 if let Some(scene) = self.scenes.get(scene_id).cloned() {
452 info!("Switching to scene: {}", scene_id);
453 self.current_scene_id = Some(scene_id.to_string());
454 let system_prompt = Self::build_system_prompt(&self.config, Some(&scene.prompt));
456 if let Some(first_msg) = self.history.get_mut(0) {
457 if first_msg.role == "system" {
458 first_msg.content = system_prompt;
459 }
460 }
461
462 let mut commands = Vec::new();
463 if let Some(url) = &scene.play {
464 commands.push(Command::Play {
465 url: url.clone(),
466 play_id: None,
467 auto_hangup: None,
468 wait_input_timeout: None,
469 });
470 }
471
472 if trigger_response {
473 let response_cmds = self.generate_response().await?;
474 commands.extend(response_cmds);
475 }
476 Ok(commands)
477 } else {
478 warn!("Scene not found: {}", scene_id);
479 Ok(vec![])
480 }
481 }
482
    /// Read-only view of the chat transcript (system prompt included).
    pub fn get_history_ref(&self) -> &[ChatMessage] {
        &self.history
    }

    /// Id of the currently active scene, if any.
    pub fn get_current_scene_id(&self) -> Option<String> {
        self.current_scene_id.clone()
    }

    /// Attach the live call so streamed commands can be enqueued directly.
    pub fn set_call(&mut self, call: crate::call::ActiveCallRef) {
        self.call = Some(call);
    }
494
495 pub fn set_event_sender(&mut self, sender: crate::event::EventSender) {
496 self.event_sender = Some(sender.clone());
497 if let Some(greeting) = &self.config.greeting {
499 let _ = sender.send(crate::event::SessionEvent::AddHistory {
500 sender: Some("system".to_string()),
501 timestamp: crate::media::get_timestamp(),
502 speaker: "assistant".to_string(),
503 text: greeting.clone(),
504 });
505 }
506 }
507
508 fn send_debug_event(&self, key: &str, data: serde_json::Value) {
509 if let Some(sender) = &self.event_sender {
510 let timestamp = crate::media::get_timestamp();
511 if key == "llm_response" {
513 if let Some(text) = data.get("response").and_then(|v| v.as_str()) {
514 let _ = sender.send(crate::event::SessionEvent::AddHistory {
515 sender: Some("llm".to_string()),
516 timestamp,
517 speaker: "assistant".to_string(),
518 text: text.to_string(),
519 });
520 }
521 }
522
523 let event = crate::event::SessionEvent::Metrics {
524 timestamp,
525 key: key.to_string(),
526 duration: 0,
527 data,
528 };
529 let _ = sender.send(event);
530 }
531 }
532
    /// Invoke the provider non-streaming with the current history (used for
    /// tool-result follow-up passes in `interpret_response`).
    async fn call_llm(&self) -> Result<String> {
        self.provider.call(&self.config, &self.history).await
    }
536
537 fn create_tts_command(
538 &self,
539 text: String,
540 wait_input_timeout: Option<u32>,
541 auto_hangup: Option<bool>,
542 ) -> Command {
543 let timeout = wait_input_timeout.unwrap_or(10000);
544 let play_id = uuid::Uuid::new_v4().to_string();
545
546 if let Some(sender) = &self.event_sender {
547 let _ = sender.send(crate::event::SessionEvent::Metrics {
548 timestamp: crate::media::get_timestamp(),
549 key: "tts_play_id_map".to_string(),
550 duration: 0,
551 data: serde_json::json!({
552 "playId": play_id,
553 "text": text,
554 }),
555 });
556 }
557
558 Command::Tts {
559 text,
560 speaker: None,
561 play_id: Some(play_id),
562 auto_hangup,
563 streaming: None,
564 end_of_stream: Some(true),
565 option: None,
566 wait_input_timeout: Some(timeout),
567 base64: None,
568 }
569 }
570
    /// Run one LLM turn: stream the completion and either
    /// - forward text incrementally as TTS/tool commands (plain-text mode), or
    /// - buffer the whole reply and hand it to `interpret_response`
    ///   (JSON mode, detected from the first non-blank character).
    async fn generate_response(&mut self) -> Result<Vec<Command>> {
        let start_time = crate::media::get_timestamp();
        let play_id = uuid::Uuid::new_v4().to_string();

        self.send_debug_event(
            "llm_call_start",
            json!({
                "history_length": self.history.len(),
                "playId": play_id,
            }),
        );

        let mut stream = self
            .provider
            .call_stream(&self.config, &self.history)
            .await?;

        let mut full_content = String::new();
        let mut full_reasoning = String::new();
        // Text received but not yet cut into sentences / action tags.
        let mut buffer = String::new();
        let mut commands = Vec::new();
        let mut is_json_mode = false;
        let mut checked_json_mode = false;
        let mut first_token_time = None;

        while let Some(chunk_result) = stream.next().await {
            let event = match chunk_result {
                Ok(c) => c,
                Err(e) => {
                    // A broken stream is not fatal: keep what was received.
                    warn!("LLM stream error: {}", e);
                    break;
                }
            };

            match event {
                LlmStreamEvent::Reasoning(text) => {
                    full_reasoning.push_str(&text);
                }
                LlmStreamEvent::Content(chunk) => {
                    // Time-to-first-token, measured at the first non-blank delta.
                    if first_token_time.is_none() && !chunk.trim().is_empty() {
                        first_token_time = Some(crate::media::get_timestamp());
                    }

                    full_content.push_str(&chunk);
                    buffer.push_str(&chunk);

                    // Decide once, from the first visible character, whether
                    // the reply is a JSON object or a fenced code block.
                    if !checked_json_mode {
                        let trimmed = full_content.trim();
                        if !trimmed.is_empty() {
                            if trimmed.starts_with('{') || trimmed.starts_with('`') {
                                is_json_mode = true;
                            }
                            checked_json_mode = true;
                        }
                    }

                    if checked_json_mode && !is_json_mode {
                        let extracted =
                            self.extract_streaming_commands(&mut buffer, &play_id, false);
                        for cmd in extracted {
                            // Prefer enqueueing straight onto the live call
                            // for minimal latency; otherwise collect and return.
                            if let Some(call) = &self.call {
                                let _ = call.enqueue_command(cmd).await;
                            } else {
                                commands.push(cmd);
                            }
                        }
                    }
                }
            }
        }

        let end_time = crate::media::get_timestamp();
        self.send_debug_event(
            "llm_response",
            json!({
                "response": full_content,
                "reasoning": full_reasoning,
                "is_json_mode": is_json_mode,
                "duration": end_time - start_time,
                "ttfb": first_token_time.map(|t| t - start_time).unwrap_or(0),
                "playId": play_id,
            }),
        );

        if is_json_mode {
            self.interpret_response(full_content).await
        } else {
            // Flush any trailing text and mark the TTS stream finished.
            let extracted = self.extract_streaming_commands(&mut buffer, &play_id, true);
            for cmd in extracted {
                if let Some(call) = &self.call {
                    let _ = call.enqueue_command(cmd).await;
                } else {
                    commands.push(cmd);
                }
            }
            if !full_content.trim().is_empty() {
                self.history.push(ChatMessage {
                    role: "assistant".to_string(),
                    content: full_content,
                });
                self.is_speaking = true;
                self.last_tts_start_at = Some(std::time::Instant::now());
            }
            Ok(commands)
        }
    }
679
680 fn extract_streaming_commands(
681 &mut self,
682 buffer: &mut String,
683 play_id: &str,
684 is_final: bool,
685 ) -> Vec<Command> {
686 let mut commands = Vec::new();
687
688 loop {
689 let hangup_pos = RE_HANGUP.find(buffer);
690 let refer_pos = RE_REFER.captures(buffer);
691 let play_pos = RE_PLAY.captures(buffer);
692 let goto_pos = RE_GOTO.captures(buffer);
693 let sentence_pos = RE_SENTENCE.find(buffer);
694
695 let mut positions = Vec::new();
697 if let Some(m) = hangup_pos {
698 positions.push((m.start(), 0));
699 }
700 if let Some(caps) = &refer_pos {
701 positions.push((caps.get(0).unwrap().start(), 1));
702 }
703 if let Some(caps) = &play_pos {
704 positions.push((caps.get(0).unwrap().start(), 3));
705 }
706 if let Some(caps) = &goto_pos {
707 positions.push((caps.get(0).unwrap().start(), 4));
708 }
709 if let Some(m) = sentence_pos {
710 positions.push((m.start(), 2));
711 }
712
713 positions.sort_by_key(|p| p.0);
714
715 if let Some((pos, kind)) = positions.first() {
716 let pos = *pos;
717 match kind {
718 0 => {
719 let prefix = buffer[..pos].to_string();
721 if !prefix.trim().is_empty() {
722 let mut cmd = self.create_tts_command_with_id(
723 prefix,
724 play_id.to_string(),
725 Some(true),
726 );
727 if let Command::Tts { end_of_stream, .. } = &mut cmd {
728 *end_of_stream = Some(true);
729 }
730 self.is_hanging_up = true;
732 commands.push(cmd);
733 } else {
734 let mut cmd = self.create_tts_command_with_id(
736 "".to_string(),
737 play_id.to_string(),
738 Some(true),
739 );
740 if let Command::Tts { end_of_stream, .. } = &mut cmd {
741 *end_of_stream = Some(true);
742 }
743 self.is_hanging_up = true;
744 commands.push(cmd);
745 }
746 buffer.drain(..RE_HANGUP.find(buffer).unwrap().end());
747 return commands;
749 }
750 1 => {
751 let caps = RE_REFER.captures(buffer).unwrap();
753 let mat = caps.get(0).unwrap();
754 let callee = caps.get(1).unwrap().as_str().to_string();
755
756 let prefix = buffer[..pos].to_string();
757 if !prefix.trim().is_empty() {
758 commands.push(self.create_tts_command_with_id(
759 prefix,
760 play_id.to_string(),
761 None,
762 ));
763 }
764 commands.push(Command::Refer {
765 caller: String::new(),
766 callee,
767 options: None,
768 });
769 buffer.drain(..mat.end());
770 }
771 3 => {
772 let caps = RE_PLAY.captures(buffer).unwrap();
774 let mat = caps.get(0).unwrap();
775 let url = caps.get(1).unwrap().as_str().to_string();
776
777 let prefix = buffer[..pos].to_string();
778 if !prefix.trim().is_empty() {
779 commands.push(self.create_tts_command_with_id(
780 prefix,
781 play_id.to_string(),
782 None,
783 ));
784 }
785 commands.push(Command::Play {
786 url,
787 play_id: None,
788 auto_hangup: None,
789 wait_input_timeout: None,
790 });
791 buffer.drain(..mat.end());
792 }
793 4 => {
794 let caps = RE_GOTO.captures(buffer).unwrap();
796 let mat = caps.get(0).unwrap();
797 let scene_id = caps.get(1).unwrap().as_str().to_string();
798
799 let prefix = buffer[..pos].to_string();
800 if !prefix.trim().is_empty() {
801 commands.push(self.create_tts_command_with_id(
802 prefix,
803 play_id.to_string(),
804 None,
805 ));
806 }
807
808 info!("Switching to scene (from stream): {}", scene_id);
809 if let Some(scene) = self.scenes.get(&scene_id) {
810 self.current_scene_id = Some(scene_id);
811 let system_prompt =
813 Self::build_system_prompt(&self.config, Some(&scene.prompt));
814 if let Some(first_msg) = self.history.get_mut(0) {
815 if first_msg.role == "system" {
816 first_msg.content = system_prompt;
817 }
818 }
819 } else {
820 warn!("Scene not found: {}", scene_id);
821 }
822
823 buffer.drain(..mat.end());
824 }
825 2 => {
826 let mat = sentence_pos.unwrap();
828 let sentence = buffer[..mat.end()].to_string();
829 if !sentence.trim().is_empty() {
830 commands.push(self.create_tts_command_with_id(
831 sentence,
832 play_id.to_string(),
833 None,
834 ));
835 }
836 buffer.drain(..mat.end());
837 }
838 _ => unreachable!(),
839 }
840 } else {
841 break;
842 }
843 }
844
845 if is_final {
846 let remaining = buffer.trim().to_string();
847 if !remaining.is_empty() {
848 commands.push(self.create_tts_command_with_id(
849 remaining,
850 play_id.to_string(),
851 None,
852 ));
853 }
854 buffer.clear();
855
856 if let Some(last) = commands.last_mut() {
857 if let Command::Tts { end_of_stream, .. } = last {
858 *end_of_stream = Some(true);
859 }
860 } else if !self.is_hanging_up {
861 commands.push(Command::Tts {
862 text: "".to_string(),
863 speaker: None,
864 play_id: Some(play_id.to_string()),
865 auto_hangup: None,
866 streaming: Some(true),
867 end_of_stream: Some(true),
868 option: None,
869 wait_input_timeout: None,
870 base64: None,
871 });
872 }
873 }
874
875 commands
876 }
877
    /// Build one streaming TTS segment bound to an existing `play_id`
    /// (used for incremental sentence playback; `end_of_stream` is set later
    /// by the caller on the final segment).
    fn create_tts_command_with_id(
        &self,
        text: String,
        play_id: String,
        auto_hangup: Option<bool>,
    ) -> Command {
        Command::Tts {
            text,
            speaker: None,
            play_id: Some(play_id),
            auto_hangup,
            streaming: Some(true),
            end_of_stream: None,
            option: None,
            // Default grace period for caller input after playback.
            wait_input_timeout: Some(10000),
            base64: None,
        }
    }
896
    /// Interpret a complete (non-streamed) LLM reply that may carry a JSON
    /// tool payload.
    ///
    /// RAG and HTTP tools feed their results back into the history and
    /// re-invoke the LLM (bounded by `MAX_RAG_ATTEMPTS`); terminal tools
    /// (hangup/refer/accept/reject) become commands appended after the TTS
    /// command for any accompanying text. A non-JSON reply is spoken as-is.
    async fn interpret_response(&mut self, initial: String) -> Result<Vec<Command>> {
        let mut tool_commands = Vec::new();
        let mut wait_input_timeout = None;
        let mut attempts = 0;
        let final_text: Option<String>;
        let mut raw = initial;

        loop {
            attempts += 1;
            // Set when a tool produced new context warranting another LLM pass.
            let mut rerun_for_rag = false;

            if let Some(structured) = parse_structured_response(&raw) {
                // Keep the first wait_input_timeout seen across iterations.
                if wait_input_timeout.is_none() {
                    wait_input_timeout = structured.wait_input_timeout;
                }

                if let Some(tools) = structured.tools {
                    for tool in tools {
                        match tool {
                            ToolInvocation::Hangup {
                                ref reason,
                                ref initiator,
                            } => {
                                self.send_debug_event(
                                    "tool_invocation",
                                    json!({
                                        "tool": "Hangup",
                                        "params": {
                                            "reason": reason,
                                            "initiator": initiator,
                                        }
                                    }),
                                );
                                tool_commands.push(Command::Hangup {
                                    reason: reason.clone(),
                                    initiator: initiator.clone(),
                                });
                            }
                            ToolInvocation::Refer {
                                ref caller,
                                ref callee,
                                ref options,
                            } => {
                                self.send_debug_event(
                                    "tool_invocation",
                                    json!({
                                        "tool": "Refer",
                                        "params": {
                                            "caller": caller,
                                            "callee": callee,
                                        }
                                    }),
                                );
                                tool_commands.push(Command::Refer {
                                    caller: caller.clone(),
                                    callee: callee.clone(),
                                    options: options.clone(),
                                });
                            }
                            ToolInvocation::Rag {
                                ref query,
                                ref source,
                            } => {
                                self.send_debug_event(
                                    "tool_invocation",
                                    json!({
                                        "tool": "Rag",
                                        "params": {
                                            "query": query,
                                            "source": source,
                                        }
                                    }),
                                );

                                // NOTE: a retrieval error aborts the whole
                                // response via `?`.
                                let rag_result = self.rag_retriever.retrieve(&query).await?;

                                self.send_debug_event(
                                    "rag_result",
                                    json!({
                                        "query": query,
                                        "result": rag_result,
                                    }),
                                );

                                let summary = if let Some(source) = source {
                                    format!("[{}] {}", source, rag_result)
                                } else {
                                    rag_result
                                };
                                // Feed the retrieval result back as system
                                // context for the follow-up LLM pass.
                                self.history.push(ChatMessage {
                                    role: "system".to_string(),
                                    content: format!("RAG result for {}: {}", query, summary),
                                });
                                rerun_for_rag = true;
                            }
                            ToolInvocation::Accept { ref options } => {
                                self.send_debug_event(
                                    "tool_invocation",
                                    json!({
                                        "tool": "Accept",
                                    }),
                                );
                                tool_commands.push(Command::Accept {
                                    option: options.clone().unwrap_or_default(),
                                });
                            }
                            ToolInvocation::Reject { ref reason, code } => {
                                self.send_debug_event(
                                    "tool_invocation",
                                    json!({
                                        "tool": "Reject",
                                        "params": {
                                            "reason": reason,
                                            "code": code,
                                        }
                                    }),
                                );
                                tool_commands.push(Command::Reject {
                                    reason: reason
                                        .clone()
                                        .unwrap_or_else(|| "Rejected by agent".to_string()),
                                    code,
                                });
                            }
                            ToolInvocation::Http {
                                ref url,
                                ref method,
                                ref body,
                                ref headers,
                            } => {
                                // Unknown verbs silently fall back to GET.
                                let method_str = method.as_deref().unwrap_or("GET").to_uppercase();
                                let method = reqwest::Method::from_bytes(method_str.as_bytes())
                                    .unwrap_or(reqwest::Method::GET);

                                self.send_debug_event(
                                    "tool_invocation",
                                    json!({
                                        "tool": "Http",
                                        "params": {
                                            "url": url,
                                            "method": method_str,
                                        }
                                    }),
                                );

                                let mut req = self.client.request(method, url);
                                if let Some(body) = body {
                                    req = req.json(body);
                                }
                                if let Some(headers) = headers {
                                    for (k, v) in headers {
                                        req = req.header(k, v);
                                    }
                                }

                                // HTTP failures are reported to the model as
                                // context, not propagated as errors.
                                match req.send().await {
                                    Ok(res) => {
                                        let status = res.status();
                                        let text = res.text().await.unwrap_or_default();
                                        self.history.push(ChatMessage {
                                            role: "system".to_string(),
                                            content: format!(
                                                "HTTP tool response ({}): {}",
                                                status, text
                                            ),
                                        });
                                    }
                                    Err(e) => {
                                        warn!("HTTP tool failed: {}", e);
                                        self.history.push(ChatMessage {
                                            role: "system".to_string(),
                                            content: format!("HTTP tool failed: {}", e),
                                        });
                                    }
                                }
                                rerun_for_rag = true;
                            }
                        }
                    }
                }

                if rerun_for_rag {
                    if attempts >= MAX_RAG_ATTEMPTS {
                        warn!("Reached RAG iteration limit, using last response");
                        final_text = structured.text.or_else(|| Some(raw.clone()));
                        break;
                    }
                    raw = self.call_llm().await?;
                    continue;
                }

                final_text = structured.text;
                break;
            }

            // Not structured JSON: speak the raw reply as-is.
            final_text = Some(raw.clone());
            break;
        }

        let mut commands = Vec::new();

        let mut has_hangup = false;
        for tool in &tool_commands {
            if matches!(tool, Command::Hangup { .. }) {
                has_hangup = true;
                break;
            }
        }

        if let Some(text) = final_text {
            if !text.trim().is_empty() {
                self.history.push(ChatMessage {
                    role: "assistant".to_string(),
                    content: text.clone(),
                });
                self.last_tts_start_at = Some(std::time::Instant::now());
                self.is_speaking = true;

                if has_hangup {
                    // Let TTS finish and hang up via auto_hangup instead of
                    // issuing an immediate Hangup command alongside it.
                    commands.push(self.create_tts_command(text, wait_input_timeout, Some(true)));
                    tool_commands.retain(|c| !matches!(c, Command::Hangup { .. }));
                    self.is_hanging_up = true;
                } else {
                    commands.push(self.create_tts_command(text, wait_input_timeout, None));
                }
            }
        }

        commands.extend(tool_commands);

        Ok(commands)
    }
1138}
1139
1140fn parse_structured_response(raw: &str) -> Option<StructuredResponse> {
1141 let payload = extract_json_block(raw)?;
1142 serde_json::from_str(payload).ok()
1143}
1144
1145fn is_likely_filler(text: &str) -> bool {
1146 let trimmed = text.trim().to_lowercase();
1147 FILLERS.contains(&trimmed)
1148}
1149
/// Pull a JSON payload out of an LLM reply.
///
/// Accepts either a bare JSON object/array or a ``` fenced code block,
/// optionally tagged `json`. Returns `None` when no JSON-looking payload is
/// present.
///
/// Fix: the previous version matched a single leading backtick and then
/// sliced `trimmed[3..end]` at a fixed byte offset, which panicked on a
/// non-char-boundary when a backtick was followed by a multi-byte character
/// (e.g. "`日```"). `strip_prefix("```")` requires a full fence and
/// `get(..4)` is boundary-safe.
fn extract_json_block(raw: &str) -> Option<&str> {
    let trimmed = raw.trim();

    if let Some(rest) = trimmed.strip_prefix("```") {
        // Fenced block: require a closing fence after the opening one.
        let end = rest.rfind("```")?;
        let mut inner = rest[..end].trim();

        // Drop an optional leading "json" language tag (any case).
        if inner
            .get(..4)
            .map_or(false, |tag| tag.eq_ignore_ascii_case("json"))
        {
            if let Some(newline) = inner.find('\n') {
                inner = inner[newline + 1..].trim();
            } else if inner.len() > 4 {
                inner = inner[4..].trim();
            }
        }
        return Some(inner);
    }

    if trimmed.starts_with('{') || trimmed.starts_with('[') {
        return Some(trimmed);
    }
    None
}
1175
#[async_trait]
impl DialogueHandler for LlmHandler {
    /// Session start: queue the active scene's intro audio (if any), then
    /// either speak the configured greeting or generate an opening LLM turn.
    async fn on_start(&mut self) -> Result<Vec<Command>> {
        self.last_tts_start_at = Some(std::time::Instant::now());

        let mut commands = Vec::new();

        let scene_audio = self
            .current_scene_id
            .as_ref()
            .and_then(|scene_id| self.scenes.get(scene_id))
            .and_then(|scene| scene.play.clone());
        if let Some(url) = scene_audio {
            commands.push(Command::Play {
                url,
                play_id: None,
                auto_hangup: None,
                wait_input_timeout: None,
            });
        }

        // A configured greeting short-circuits the initial LLM call.
        if let Some(greeting) = self.config.greeting.clone() {
            self.is_speaking = true;
            commands.push(self.create_tts_command(greeting, None, None));
            return Ok(commands);
        }

        commands.extend(self.generate_response().await?);
        Ok(commands)
    }
1207
1208 async fn on_event(&mut self, event: &SessionEvent) -> Result<Vec<Command>> {
1209 match event {
1210 SessionEvent::Dtmf { digit, .. } => {
1211 info!("DTMF received: {}", digit);
1212 let action = self.get_dtmf_action(digit);
1213 if let Some(action) = action {
1214 return self.handle_dtmf_action(action).await;
1215 }
1216 Ok(vec![])
1217 }
1218
1219 SessionEvent::AsrFinal { text, .. } => {
1220 if text.trim().is_empty() {
1221 return Ok(vec![]);
1222 }
1223
1224 self.last_asr_final_at = Some(std::time::Instant::now());
1225 self.last_interaction_at = std::time::Instant::now();
1226 self.is_speaking = false;
1227 self.consecutive_follow_ups = 0;
1228
1229 self.history.push(ChatMessage {
1230 role: "user".to_string(),
1231 content: text.clone(),
1232 });
1233
1234 self.generate_response().await
1235 }
1236
1237 SessionEvent::AsrDelta { is_filler, .. } | SessionEvent::Speaking { is_filler, .. } => {
1238 let strategy = self.interruption_config.strategy;
1239 let should_check = match (strategy, event) {
1240 (InterruptionStrategy::None, _) => false,
1241 (InterruptionStrategy::Vad, SessionEvent::Speaking { .. }) => true,
1242 (InterruptionStrategy::Asr, SessionEvent::AsrDelta { .. }) => true,
1243 (InterruptionStrategy::Both, _) => true,
1244 _ => false,
1245 };
1246
1247 if self.is_speaking && !self.is_hanging_up && should_check {
1249 if let Some(last_start) = self.last_tts_start_at {
1251 let ignore_ms = self.interruption_config.ignore_first_ms.unwrap_or(800);
1252 if last_start.elapsed().as_millis() < ignore_ms as u128 {
1253 return Ok(vec![]);
1254 }
1255 }
1256
1257 if self.interruption_config.filler_word_filter.unwrap_or(false) {
1259 if let Some(true) = is_filler {
1260 return Ok(vec![]);
1261 }
1262 if let SessionEvent::AsrDelta { text, .. } = event {
1264 if is_likely_filler(text) {
1265 return Ok(vec![]);
1266 }
1267 }
1268 }
1269
1270 if let Some(last_final) = self.last_asr_final_at {
1272 if last_final.elapsed().as_millis() < 500 {
1273 return Ok(vec![]);
1274 }
1275 }
1276
1277 info!("Smart interruption detected, stopping playback");
1278 self.is_speaking = false;
1279 return Ok(vec![Command::Interrupt {
1280 graceful: Some(true),
1281 fade_out_ms: self.interruption_config.volume_fade_ms,
1282 }]);
1283 }
1284 Ok(vec![])
1285 }
1286
1287 SessionEvent::Eou { completed, .. } => {
1288 if *completed && self.is_speaking == false {
1289 info!("EOU detected, triggering early response");
1290 return self.generate_response().await;
1291 }
1292 Ok(vec![])
1293 }
1294
1295 SessionEvent::Silence { .. } => {
1296 let follow_up_config = if let Some(scene_id) = &self.current_scene_id {
1297 self.scenes
1298 .get(scene_id)
1299 .and_then(|s| s.follow_up)
1300 .or(self.global_follow_up_config)
1301 } else {
1302 self.global_follow_up_config
1303 };
1304
1305 if let Some(config) = follow_up_config {
1306 if !self.is_speaking
1307 && self.last_interaction_at.elapsed().as_millis() as u64 >= config.timeout
1308 {
1309 if self.consecutive_follow_ups >= config.max_count {
1310 info!("Max follow-up count reached, hanging up");
1311 return Ok(vec![Command::Hangup {
1312 reason: Some("Max follow-up reached".to_string()),
1313 initiator: Some("system".to_string()),
1314 }]);
1315 }
1316
1317 info!(
1318 "Silence timeout detected ({}ms), triggering follow-up ({}/{})",
1319 self.last_interaction_at.elapsed().as_millis(),
1320 self.consecutive_follow_ups + 1,
1321 config.max_count
1322 );
1323 self.consecutive_follow_ups += 1;
1324 self.last_interaction_at = std::time::Instant::now();
1325 return self.generate_response().await;
1326 }
1327 }
1328 Ok(vec![])
1329 }
1330
1331 SessionEvent::TrackStart { .. } => {
1332 self.is_speaking = true;
1333 Ok(vec![])
1334 }
1335
1336 SessionEvent::TrackEnd { .. } => {
1337 self.is_speaking = false;
1338 self.is_hanging_up = false;
1339 self.last_interaction_at = std::time::Instant::now();
1340 Ok(vec![])
1341 }
1342
1343 SessionEvent::FunctionCall {
1344 name, arguments, ..
1345 } => {
1346 info!(
1347 "Function call from Realtime: {} with args {}",
1348 name, arguments
1349 );
1350 let args: serde_json::Value = serde_json::from_str(arguments).unwrap_or_default();
1351 match name.as_str() {
1352 "hangup_call" => Ok(vec![Command::Hangup {
1353 reason: args["reason"].as_str().map(|s| s.to_string()),
1354 initiator: Some("ai".to_string()),
1355 }]),
1356 "transfer_call" | "refer_call" => {
1357 if let Some(callee) = args["callee"]
1358 .as_str()
1359 .or_else(|| args["callee_uri"].as_str())
1360 {
1361 Ok(vec![Command::Refer {
1362 caller: String::new(),
1363 callee: callee.to_string(),
1364 options: None,
1365 }])
1366 } else {
1367 warn!("No callee provided for transfer_call");
1368 Ok(vec![])
1369 }
1370 }
1371 "goto_scene" => {
1372 if let Some(scene) = args["scene"].as_str() {
1373 self.switch_to_scene(scene, false).await
1374 } else {
1375 Ok(vec![])
1376 }
1377 }
1378 _ => {
1379 warn!("Unhandled function call: {}", name);
1380 Ok(vec![])
1381 }
1382 }
1383 }
1384
1385 _ => Ok(vec![]),
1386 }
1387 }
1388
1389 async fn get_history(&self) -> Vec<ChatMessage> {
1390 self.history.clone()
1391 }
1392
1393 async fn summarize(&mut self, prompt: &str) -> Result<String> {
1394 info!("Generating summary with prompt: {}", prompt);
1395 let mut summary_history = self.history.clone();
1396 summary_history.push(ChatMessage {
1397 role: "user".to_string(),
1398 content: prompt.to_string(),
1399 });
1400
1401 self.provider.call(&self.config, &summary_history).await
1402 }
1403}
1404
1405#[cfg(test)]
1406mod tests {
1407 use super::*;
1408 use crate::event::SessionEvent;
1409 use anyhow::{Result, anyhow};
1410 use async_trait::async_trait;
1411 use std::collections::VecDeque;
1412 use std::sync::Mutex;
1413
    /// Scripted [`LlmProvider`] test double: hands out pre-canned responses
    /// in order, one per `call`.
    struct TestProvider {
        // FIFO of scripted responses; Mutex because the trait takes `&self`
        // but the queue must be mutated on each call.
        responses: Mutex<VecDeque<String>>,
    }

    impl TestProvider {
        /// Builds a provider that yields `responses` front-to-back.
        fn new(responses: Vec<String>) -> Self {
            Self {
                responses: Mutex::new(VecDeque::from(responses)),
            }
        }
    }
1425
1426 #[async_trait]
1427 impl LlmProvider for TestProvider {
1428 async fn call(&self, _config: &LlmConfig, _history: &[ChatMessage]) -> Result<String> {
1429 let mut guard = self.responses.lock().unwrap();
1430 guard
1431 .pop_front()
1432 .ok_or_else(|| anyhow!("Test provider ran out of responses"))
1433 }
1434
1435 async fn call_stream(
1436 &self,
1437 _config: &LlmConfig,
1438 _history: &[ChatMessage],
1439 ) -> Result<Pin<Box<dyn Stream<Item = Result<LlmStreamEvent>> + Send>>> {
1440 let response = self.call(_config, _history).await?;
1441 let s = async_stream::stream! {
1442 yield Ok(LlmStreamEvent::Content(response));
1443 };
1444 Ok(Box::pin(s))
1445 }
1446 }
1447
    /// RAG test double that records every query and echoes it back.
    struct RecordingRag {
        // All queries seen by `retrieve`, in call order.
        queries: Mutex<Vec<String>>,
    }

    impl RecordingRag {
        /// Creates an empty recorder.
        fn new() -> Self {
            Self {
                queries: Mutex::new(Vec::new()),
            }
        }

        /// Snapshot of the queries recorded so far.
        fn recorded_queries(&self) -> Vec<String> {
            self.queries.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl RagRetriever for RecordingRag {
        async fn retrieve(&self, query: &str) -> Result<String> {
            // Record the query, then return a deterministic marker string so
            // tests can spot retrieved content in prompts.
            self.queries.lock().unwrap().push(query.to_string());
            Ok(format!("retrieved {}", query))
        }
    }
1471
1472 #[test]
1473 fn test_build_system_prompt_with_features() {
1474 let config = LlmConfig {
1475 prompt: Some("Base prompt".to_string()),
1476 language: Some("zh".to_string()),
1477 features: Some(vec!["intent_clarification".to_string()]),
1478 ..Default::default()
1479 };
1480
1481 let prompt = LlmHandler::build_system_prompt(&config, None);
1482 assert!(prompt.contains("Base prompt"));
1483 assert!(prompt.contains("### Enhanced Capabilities:"));
1484 assert!(prompt.contains("如果用户意图模糊"));
1486 assert!(prompt.contains("<hangup/>"));
1487 }
1488
1489 #[test]
1490 fn test_build_system_prompt_missing_feature() {
1491 let config = LlmConfig {
1492 prompt: Some("Base prompt".to_string()),
1493 language: Some("zh".to_string()),
1494 features: Some(vec!["non_existent_feature".to_string()]),
1495 ..Default::default()
1496 };
1497
1498 let prompt = LlmHandler::build_system_prompt(&config, None);
1500 assert!(prompt.contains("Base prompt"));
1501 assert!(!prompt.contains("Enhanced Capabilities"));
1502 }
1503
1504 #[test]
1505 fn test_build_system_prompt_en() {
1506 let config = LlmConfig {
1507 prompt: Some("Base prompt".to_string()),
1508 language: Some("en".to_string()),
1509 features: Some(vec!["intent_clarification".to_string()]),
1510 ..Default::default()
1511 };
1512
1513 let prompt = LlmHandler::build_system_prompt(&config, None);
1514 assert!(prompt.contains("If the user's intent is unclear"));
1515 }
1516
    /// A structured JSON reply should become a TTS command carrying
    /// `waitInputTimeout` and auto-hangup (from the `hangup` tool), plus a
    /// separate Refer command (from the `refer` tool).
    #[tokio::test]
    async fn handler_applies_tool_instructions() -> Result<()> {
        let response = r#"{
            "text": "Goodbye",
            "waitInputTimeout": 15000,
            "tools": [
                {"name": "hangup", "reason": "done", "initiator": "agent"},
                {"name": "refer", "caller": "sip:bot", "callee": "sip:lead"}
            ]
        }"#;

        let provider = Arc::new(TestProvider::new(vec![response.to_string()]));
        let mut handler = LlmHandler::with_provider(
            LlmConfig::default(),
            provider,
            Arc::new(NoopRagRetriever),
            crate::playbook::InterruptionConfig::default(),
            None,
            HashMap::new(),
            None,
            None,
        );

        // A final transcription triggers response generation.
        let event = SessionEvent::AsrFinal {
            track_id: "track-1".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "hello".to_string(),
            is_filler: None,
            confidence: None,
        };

        let commands = handler.on_event(&event).await?;
        // First command: the spoken text with the tool-driven settings applied.
        assert!(matches!(
            commands.get(0),
            Some(Command::Tts {
                text,
                wait_input_timeout: Some(15000),
                auto_hangup: Some(true),
                ..
            }) if text == "Goodbye"
        ));
        // The refer tool becomes a Refer command, caller/callee preserved.
        assert!(commands.iter().any(|cmd| matches!(
            cmd,
            Command::Refer {
                caller,
                callee,
                ..
            } if caller == "sip:bot" && callee == "sip:lead"
        )));

        Ok(())
    }
1572
    /// A `rag` tool instruction should trigger one retrieval and then a second
    /// LLM call whose answer is the spoken response.
    #[tokio::test]
    async fn handler_requeries_after_rag() -> Result<()> {
        // First provider reply asks for RAG; second is the real answer.
        let rag_instruction = r#"{"tools": [{"name": "rag", "query": "policy"}]}"#;
        let provider = Arc::new(TestProvider::new(vec![
            rag_instruction.to_string(),
            "Final answer".to_string(),
        ]));
        let rag = Arc::new(RecordingRag::new());
        let mut handler = LlmHandler::with_provider(
            LlmConfig::default(),
            provider,
            rag.clone(),
            crate::playbook::InterruptionConfig::default(),
            None,
            HashMap::new(),
            None,
            None,
        );

        let event = SessionEvent::AsrFinal {
            track_id: "track-2".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "reep".to_string(),
            is_filler: None,
            confidence: None,
        };

        let commands = handler.on_event(&event).await?;
        // NOTE(review): 10000 is presumably the default wait_input_timeout
        // applied by generate_response — confirm against that code.
        assert!(matches!(
            commands.get(0),
            Some(Command::Tts {
                text,
                wait_input_timeout: Some(timeout),
                ..
            }) if text == "Final answer" && *timeout == 10000
        ));
        // The retriever saw exactly the query from the tool instruction.
        assert_eq!(rag.recorded_queries(), vec!["policy".to_string()]);

        Ok(())
    }
1616
    /// End-to-end flow: greeting on start, plain-text reply (sentence-split),
    /// JSON reply with a timeout, then a hangup tool call.
    #[tokio::test]
    async fn test_full_dialogue_flow() -> Result<()> {
        let responses = vec![
            "Hello! How can I help you today?".to_string(),
            r#"{"text": "I can help with that. Anything else?", "waitInputTimeout": 5000}"#
                .to_string(),
            r#"{"text": "Goodbye!", "tools": [{"name": "hangup", "reason": "completed"}]}"#
                .to_string(),
        ];

        let provider = Arc::new(TestProvider::new(responses));
        let config = LlmConfig {
            greeting: Some("Welcome to the voice assistant.".to_string()),
            ..Default::default()
        };

        let mut handler = LlmHandler::with_provider(
            config,
            provider,
            Arc::new(NoopRagRetriever),
            crate::playbook::InterruptionConfig::default(),
            None,
            HashMap::new(),
            None,
            None,
        );

        // on_start with a configured greeting speaks it without calling the LLM.
        let commands = handler.on_start().await?;
        assert_eq!(commands.len(), 1);
        if let Command::Tts { text, .. } = &commands[0] {
            assert_eq!(text, "Welcome to the voice assistant.");
        } else {
            panic!("Expected Tts command");
        }

        // Turn 1: plain-text response.
        // NOTE(review): 3 commands expected — presumably from sentence-level
        // splitting of the reply; confirm against generate_response.
        let event = SessionEvent::AsrFinal {
            track_id: "test".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "I need help".to_string(),
            is_filler: None,
            confidence: None,
        };
        let commands = handler.on_event(&event).await?;
        assert_eq!(commands.len(), 3);
        if let Command::Tts { text, .. } = &commands[0] {
            assert!(text.contains("Hello"));
        } else {
            panic!("Expected Tts command");
        }

        // Turn 2: structured JSON reply carries waitInputTimeout through.
        let event = SessionEvent::AsrFinal {
            track_id: "test".to_string(),
            timestamp: 0,
            index: 1,
            start_time: None,
            end_time: None,
            text: "Tell me a joke".to_string(),
            is_filler: None,
            confidence: None,
        };
        let commands = handler.on_event(&event).await?;
        assert_eq!(commands.len(), 1);
        if let Command::Tts {
            text,
            wait_input_timeout,
            ..
        } = &commands[0]
        {
            assert_eq!(text, "I can help with that. Anything else?");
            assert_eq!(*wait_input_timeout, Some(5000));
        } else {
            panic!("Expected Tts command");
        }

        // Turn 3: the hangup tool marks the final TTS with auto_hangup.
        let event = SessionEvent::AsrFinal {
            track_id: "test".to_string(),
            timestamp: 0,
            index: 2,
            start_time: None,
            end_time: None,
            text: "That's all, thanks".to_string(),
            is_filler: None,
            confidence: None,
        };
        let commands = handler.on_event(&event).await?;
        assert_eq!(commands.len(), 1);

        let has_tts_hangup = commands.iter().any(|c| {
            matches!(
                c,
                Command::Tts {
                    text,
                    auto_hangup: Some(true),
                    ..
                } if text == "Goodbye!"
            )
        });

        assert!(has_tts_hangup);

        Ok(())
    }
1728
    /// Inline XML tools split the response: sentences become TTS commands
    /// sharing one play_id, `<refer/>` becomes Refer, and a trailing
    /// `<hangup/>` marks the last TTS with auto_hangup.
    #[tokio::test]
    async fn test_xml_tools_and_sentence_splitting() -> Result<()> {
        let responses = vec!["Hello! <refer to=\"sip:123\"/> How are you? <hangup/>".to_string()];
        let provider = Arc::new(TestProvider::new(responses));
        let mut handler = LlmHandler::with_provider(
            LlmConfig::default(),
            provider,
            Arc::new(NoopRagRetriever),
            crate::playbook::InterruptionConfig::default(),
            None,
            HashMap::new(),
            None,
            None,
        );

        let event = SessionEvent::AsrFinal {
            track_id: "test".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "hi".to_string(),
            is_filler: None,
            confidence: None,
        };

        let commands = handler.on_event(&event).await?;

        // Expected order: Tts("Hello!"), Refer, Tts("How are you?"),
        // Tts with auto_hangup.
        assert_eq!(commands.len(), 4);

        if let Command::Tts {
            text,
            play_id: pid1,
            ..
        } = &commands[0]
        {
            assert!(text.contains("Hello"));
            assert!(pid1.is_some());

            if let Command::Refer { callee, .. } = &commands[1] {
                assert_eq!(callee, "sip:123");
            } else {
                panic!("Expected Refer");
            }

            if let Command::Tts {
                text,
                play_id: pid2,
                ..
            } = &commands[2]
            {
                assert!(text.contains("How are you"));
                // Both sentences belong to the same playback session.
                assert_eq!(*pid1, *pid2);
            } else {
                panic!("Expected Tts");
            }

            if let Command::Tts {
                auto_hangup: Some(true),
                ..
            } = &commands[3]
            {
            } else {
                panic!("Expected Tts with auto_hangup");
            }
        } else {
            panic!("Expected Tts");
        }

        Ok(())
    }
1806
    /// After the default 800ms protection window, an ASR delta while the bot
    /// speaks should yield an Interrupt command and clear is_speaking.
    #[tokio::test]
    async fn test_interruption_logic() -> Result<()> {
        let provider = Arc::new(TestProvider::new(vec!["Some long response".to_string()]));
        let mut handler = LlmHandler::with_provider(
            LlmConfig::default(),
            provider,
            Arc::new(NoopRagRetriever),
            crate::playbook::InterruptionConfig::default(),
            None,
            HashMap::new(),
            None,
            None,
        );

        // Drive the handler into the speaking state.
        let event = SessionEvent::AsrFinal {
            track_id: "test".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "hello".to_string(),
            is_filler: None,
            confidence: None,
        };
        handler.on_event(&event).await?;
        assert!(handler.is_speaking);

        // Real sleep: the handler uses std::time::Instant internally, so
        // tokio's paused-time facilities would not help here.
        tokio::time::sleep(std::time::Duration::from_millis(850)).await;

        let event = SessionEvent::AsrDelta {
            track_id: "test".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "I...".to_string(),
            is_filler: None,
            confidence: None,
        };
        let commands = handler.on_event(&event).await?;
        assert_eq!(commands.len(), 1);
        assert!(matches!(commands[0], Command::Interrupt { .. }));
        assert!(!handler.is_speaking);

        Ok(())
    }
1856
    /// A provider that keeps requesting RAG must be cut off after the attempt
    /// limit instead of looping forever.
    #[tokio::test]
    async fn test_rag_iteration_limit() -> Result<()> {
        let rag_instruction = r#"{"tools": [{"name": "rag", "query": "endless"}]}"#;
        // More RAG instructions queued than MAX_RAG_ATTEMPTS allows.
        let provider = Arc::new(TestProvider::new(vec![
            rag_instruction.to_string(),
            rag_instruction.to_string(),
            rag_instruction.to_string(),
            rag_instruction.to_string(),
            "Should not reach here".to_string(),
        ]));

        let mut handler = LlmHandler::with_provider(
            LlmConfig::default(),
            provider,
            Arc::new(RecordingRag::new()),
            crate::playbook::InterruptionConfig::default(),
            None,
            HashMap::new(),
            None,
            None,
        );

        let event = SessionEvent::AsrFinal {
            track_id: "test".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "loop".to_string(),
            is_filler: None,
            confidence: None,
        };

        let commands = handler.on_event(&event).await?;
        assert_eq!(commands.len(), 1);
        // NOTE(review): after the limit, the unprocessed instruction text is
        // spoken verbatim — confirm this fallback is intended rather than,
        // say, an apology message.
        if let Command::Tts { text, .. } = &commands[0] {
            assert_eq!(text, rag_instruction);
        }

        Ok(())
    }
1900
    /// Silence-driven follow-ups: fire after the timeout, at most max_count
    /// times, then hang up; a real user turn resets the counter.
    #[tokio::test]
    async fn test_follow_up_logic() -> Result<()> {
        use std::time::Duration;

        // 100ms silence timeout, at most 2 follow-ups.
        let follow_up_config = super::super::FollowUpConfig {
            timeout: 100,
            max_count: 2,
        };

        let provider = Arc::new(TestProvider::new(vec![
            "Follow up 1".to_string(),
            "Follow up 2".to_string(),
            "Response to user".to_string(),
        ]));

        let mut handler = LlmHandler::with_provider(
            LlmConfig::default(),
            provider,
            Arc::new(NoopRagRetriever),
            crate::playbook::InterruptionConfig::default(),
            Some(follow_up_config),
            HashMap::new(),
            None,
            None,
        );

        handler.last_interaction_at = std::time::Instant::now();
        handler.is_speaking = false;

        // Silence before the timeout elapses: no follow-up yet.
        let event = SessionEvent::Silence {
            track_id: "t1".to_string(),
            timestamp: 0,
            start_time: 0,
            duration: 50,
            samples: None,
        };
        let commands = handler.on_event(&event).await?;
        assert!(commands.is_empty(), "Should not trigger if < timeout");

        // Let the 100ms timeout pass; follow-up 1 fires.
        tokio::time::sleep(Duration::from_millis(110)).await;
        let event = SessionEvent::Silence {
            track_id: "t1".to_string(),
            timestamp: 0,
            start_time: 0,
            duration: 100,
            samples: None,
        };
        let commands = handler.on_event(&event).await?;
        assert_eq!(commands.len(), 1, "Should trigger follow-up 1");
        if let Command::Tts { text, .. } = &commands[0] {
            assert_eq!(text, "Follow up 1");
        }
        assert_eq!(handler.consecutive_follow_ups, 1);

        // TrackEnd clears the speaking flag so the next silence can trigger.
        let event = SessionEvent::TrackEnd {
            track_id: "t1".to_string(),
            timestamp: 0,
            play_id: None,
            duration: 100,
            ssrc: 0,
        };
        handler.on_event(&event).await?;
        assert!(
            !handler.is_speaking,
            "Bot should not be speaking after TrackEnd"
        );

        // Second silence window: follow-up 2 fires.
        tokio::time::sleep(Duration::from_millis(110)).await;
        let event = SessionEvent::Silence {
            track_id: "t1".to_string(),
            timestamp: 0,
            start_time: 0,
            duration: 100,
            samples: None,
        };
        let commands = handler.on_event(&event).await?;
        assert_eq!(commands.len(), 1, "Should trigger follow-up 2");
        if let Command::Tts { text, .. } = &commands[0] {
            assert_eq!(text, "Follow up 2");
        }
        assert_eq!(handler.consecutive_follow_ups, 2);

        let event = SessionEvent::TrackEnd {
            track_id: "t1".to_string(),
            timestamp: 0,
            play_id: None,
            duration: 100,
            ssrc: 0,
        };
        handler.on_event(&event).await?;

        // Third silence exceeds max_count: the handler hangs up.
        tokio::time::sleep(Duration::from_millis(110)).await;
        let event = SessionEvent::Silence {
            track_id: "t1".to_string(),
            timestamp: 0,
            start_time: 0,
            duration: 100,
            samples: None,
        };
        let commands = handler.on_event(&event).await?;
        assert_eq!(commands.len(), 1, "Should hangup after max count");
        assert!(matches!(commands[0], Command::Hangup { .. }));

        // A real user turn resets the follow-up counter.
        handler.consecutive_follow_ups = 2;
        let event = SessionEvent::AsrFinal {
            track_id: "t1".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "User speaks".to_string(),
            is_filler: None,
            confidence: None,
        };
        let _ = handler.on_event(&event).await?;
        assert_eq!(
            handler.consecutive_follow_ups, 0,
            "Should reset count on AsrFinal"
        );

        Ok(())
    }
2043
    /// Within `ignore_first_ms` of TTS starting, ASR deltas must NOT
    /// interrupt playback.
    #[tokio::test]
    async fn test_interruption_protection_period() -> Result<()> {
        let provider = Arc::new(TestProvider::new(vec!["Some long response".to_string()]));
        let mut config = crate::playbook::InterruptionConfig::default();
        config.ignore_first_ms = Some(800);

        let mut handler = LlmHandler::with_provider(
            LlmConfig::default(),
            provider,
            Arc::new(NoopRagRetriever),
            config,
            None,
            HashMap::new(),
            None,
            None,
        );

        // Start speaking.
        let event = SessionEvent::AsrFinal {
            track_id: "test".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "hello".to_string(),
            is_filler: None,
            confidence: None,
        };
        handler.on_event(&event).await?;
        assert!(handler.is_speaking);

        // An immediate delta falls inside the 800ms protection window,
        // so no Interrupt is issued and the bot keeps speaking.
        let event = SessionEvent::AsrDelta {
            track_id: "test".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "I...".to_string(),
            is_filler: None,
            confidence: None,
        };
        let commands = handler.on_event(&event).await?;
        assert_eq!(commands.len(), 0);
        assert!(handler.is_speaking);

        Ok(())
    }
2093
    /// With the filler-word filter enabled, filler deltas are ignored but
    /// real speech still interrupts.
    #[tokio::test]
    async fn test_interruption_filler_word() -> Result<()> {
        let provider = Arc::new(TestProvider::new(vec!["Some long response".to_string()]));
        let mut config = crate::playbook::InterruptionConfig::default();
        config.filler_word_filter = Some(true);
        // Disable the startup protection window so only the filler filter
        // (and the post-AsrFinal debounce) is exercised.
        config.ignore_first_ms = Some(0);

        let mut handler = LlmHandler::with_provider(
            LlmConfig::default(),
            provider,
            Arc::new(NoopRagRetriever),
            config,
            None,
            HashMap::new(),
            None,
            None,
        );

        let event = SessionEvent::AsrFinal {
            track_id: "test".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "hello".to_string(),
            is_filler: None,
            confidence: None,
        };
        handler.on_event(&event).await?;
        assert!(handler.is_speaking);

        // Get past the 500ms post-AsrFinal debounce.
        tokio::time::sleep(std::time::Duration::from_millis(600)).await;

        // A delta flagged as filler is ignored.
        let event = SessionEvent::AsrDelta {
            track_id: "test".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "uh".to_string(),
            is_filler: Some(true),
            confidence: None,
        };
        let commands = handler.on_event(&event).await?;
        assert_eq!(commands.len(), 0);
        assert!(handler.is_speaking);

        // Non-filler speech triggers the interruption.
        let event = SessionEvent::AsrDelta {
            track_id: "test".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "Wait".to_string(),
            is_filler: Some(false),
            confidence: None,
        };
        let commands = handler.on_event(&event).await?;
        assert_eq!(commands.len(), 1);
        assert!(matches!(commands[0], Command::Interrupt { .. }));

        Ok(())
    }
2163
    /// A completed EOU while the bot is silent triggers an early response.
    #[tokio::test]
    async fn test_eou_early_response() -> Result<()> {
        let provider = Arc::new(TestProvider::new(vec![
            "End of Utterance response".to_string(),
        ]));
        let mut handler = LlmHandler::with_provider(
            LlmConfig::default(),
            provider,
            Arc::new(NoopRagRetriever),
            crate::playbook::InterruptionConfig::default(),
            None,
            HashMap::new(),
            None,
            None,
        );

        let event = SessionEvent::Eou {
            track_id: "test".to_string(),
            timestamp: 0,
            completed: true,
        };
        let commands = handler.on_event(&event).await?;
        assert_eq!(commands.len(), 1);
        if let Command::Tts { text, .. } = &commands[0] {
            assert_eq!(text, "End of Utterance response");
        } else {
            panic!("Expected Tts");
        }

        Ok(())
    }
2196
    /// `get_history` exposes the conversation; `summarize` feeds the history
    /// plus the summary prompt back to the provider.
    #[tokio::test]
    async fn test_summary_and_history() -> Result<()> {
        let provider = Arc::new(TestProvider::new(vec!["Test summary".to_string()]));
        let mut handler = LlmHandler::with_provider(
            LlmConfig::default(),
            provider,
            Arc::new(NoopRagRetriever),
            crate::playbook::InterruptionConfig::default(),
            None,
            HashMap::new(),
            None,
            None,
        );

        handler.history.push(ChatMessage {
            role: "user".to_string(),
            content: "Hello".to_string(),
        });
        handler.history.push(ChatMessage {
            role: "assistant".to_string(),
            content: "Hi there".to_string(),
        });

        let history = handler.get_history().await;
        // NOTE(review): 3 = the two pushed turns plus one pre-existing entry —
        // presumably a system prompt seeded by with_provider; confirm.
        assert_eq!(history.len(), 3);
        let summary = handler.summarize("Summarize this").await?;
        assert_eq!(summary, "Test summary");

        Ok(())
    }
2231}