1use crate::call::Command;
2use anyhow::{Result, anyhow};
3use async_trait::async_trait;
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6use serde_json::json;
7use std::sync::Arc;
8use tracing::{info, warn};
9use crate::ReferOption;
10use crate::event::SessionEvent;
11
12use super::LlmConfig;
13use super::dialogue::DialogueHandler;
14
/// A single entry in the chat transcript sent to the LLM.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct ChatMessage {
    /// Speaker role; this file uses "system", "user", and "assistant".
    pub role: String,
    /// Plain-text message content.
    pub content: String,
}
20
/// Upper bound on LLM round-trips triggered by `rag` tool invocations in a
/// single `interpret_response` call, guarding against an endless RAG loop.
const MAX_RAG_ATTEMPTS: usize = 3;
22
/// Abstraction over the chat-completion backend so tests can substitute a
/// canned provider (see `TestProvider` in the test module).
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Sends the full `history` to the model described by `config` and
    /// returns the raw text of the model's reply.
    async fn call(&self, config: &LlmConfig, history: &[ChatMessage]) -> Result<String>;
}
27
/// Production provider that talks to an OpenAI-compatible HTTP endpoint.
struct DefaultLlmProvider {
    // Reused reqwest client (connection pooling across calls).
    client: Client,
}
31
32impl DefaultLlmProvider {
33 fn new() -> Self {
34 Self {
35 client: Client::new(),
36 }
37 }
38}
39
40#[async_trait]
41impl LlmProvider for DefaultLlmProvider {
42 async fn call(&self, config: &LlmConfig, history: &[ChatMessage]) -> Result<String> {
43 let mut url = config
44 .base_url
45 .clone()
46 .unwrap_or_else(|| "https://api.openai.com/v1/chat/completions".to_string());
47 let model = config
48 .model
49 .clone()
50 .unwrap_or_else(|| "gpt-3.5-turbo".to_string());
51 let api_key = config.api_key.clone().unwrap_or_default();
52
53 if !url.ends_with("/chat/completions") {
54 url = format!("{}/chat/completions", url.trim_end_matches('/'));
55 }
56
57 let body = json!({
58 "model": model,
59 "messages": history,
60 });
61
62 let res = self
63 .client
64 .post(&url)
65 .header("Authorization", format!("Bearer {}", api_key))
66 .json(&body)
67 .send()
68 .await?;
69
70 if !res.status().is_success() {
71 return Err(anyhow!("LLM request failed: {}", res.status()));
72 }
73
74 let json: serde_json::Value = res.json().await?;
75 let content = json["choices"][0]["message"]["content"]
76 .as_str()
77 .ok_or_else(|| anyhow!("Invalid LLM response"))?
78 .to_string();
79
80 Ok(content)
81 }
82}
83
/// Retrieval backend for the `rag` tool; implementations fetch context for a
/// query so it can be appended to the conversation history.
#[async_trait]
pub trait RagRetriever: Send + Sync {
    /// Returns retrieved context text for `query` (may be empty).
    async fn retrieve(&self, query: &str) -> Result<String>;
}
88
/// Default retriever used when no RAG backend is configured; returns nothing.
struct NoopRagRetriever;
90
91#[async_trait]
92impl RagRetriever for NoopRagRetriever {
93 async fn retrieve(&self, _query: &str) -> Result<String> {
94 Ok(String::new())
95 }
96}
97
/// JSON shape the LLM may reply with instead of plain text, e.g.
/// `{"text": "...", "waitInputTimeout": 5000, "tools": [...]}`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct StructuredResponse {
    // Text to speak; when absent the raw reply is used as the spoken text.
    text: Option<String>,
    // Override for the TTS wait-input timeout (milliseconds).
    wait_input_timeout: Option<u32>,
    // Tool calls requested by the model, executed in order.
    tools: Option<Vec<ToolInvocation>>,
}
105
/// A tool call embedded in a structured response. Internally tagged on the
/// JSON "name" field, lowercased: {"name": "hangup" | "refer" | "rag", ...}.
#[derive(Debug, Deserialize)]
#[serde(tag = "name", rename_all = "lowercase")]
enum ToolInvocation {
    /// Terminate the call; maps to `Command::Hangup`.
    #[serde(rename_all = "camelCase")]
    Hangup {
        reason: Option<String>,
        initiator: Option<String>,
    },
    /// Transfer the call; maps to `Command::Refer`.
    #[serde(rename_all = "camelCase")]
    Refer {
        caller: String,
        callee: String,
        options: Option<ReferOption>,
    },
    /// Retrieve context via the RAG backend, then re-query the LLM.
    #[serde(rename_all = "camelCase")]
    Rag {
        query: String,
        source: Option<String>,
    },
}
126
/// Dialogue handler that drives a conversation through an LLM, executing any
/// tool calls (hangup/refer/rag) the model requests.
pub struct LlmHandler {
    config: LlmConfig,
    // Full transcript (system prompt + user/assistant turns + RAG notes).
    history: Vec<ChatMessage>,
    provider: Box<dyn LlmProvider>,
    rag_retriever: Arc<dyn RagRetriever>,
    // True while TTS output is presumed playing; used for barge-in handling.
    is_speaking: bool,
    // Optional channel for debug/metrics events; None disables emission.
    event_sender: Option<crate::event::EventSender>,
}
135
136impl LlmHandler {
137 pub fn new(config: LlmConfig) -> Self {
138 Self::with_provider(
139 config,
140 Box::new(DefaultLlmProvider::new()),
141 Arc::new(NoopRagRetriever),
142 )
143 }
144
145 pub fn with_provider(
146 config: LlmConfig,
147 provider: Box<dyn LlmProvider>,
148 rag_retriever: Arc<dyn RagRetriever>,
149 ) -> Self {
150 let mut history = Vec::new();
151 if let Some(prompt) = &config.prompt {
152 history.push(ChatMessage {
153 role: "system".to_string(),
154 content: prompt.clone(),
155 });
156 }
157
158 Self {
159 config,
160 history,
161 provider,
162 rag_retriever,
163 is_speaking: false,
164 event_sender: None,
165 }
166 }
167
    /// Attaches the channel on which debug/metrics session events are emitted.
    pub fn set_event_sender(&mut self, sender: crate::event::EventSender) {
        self.event_sender = Some(sender);
    }
171
172 fn send_debug_event(&self, key: &str, data: serde_json::Value) {
173 if let Some(sender) = &self.event_sender {
174 let event = crate::event::SessionEvent::Metrics {
175 timestamp: crate::media::get_timestamp(),
176 key: key.to_string(),
177 duration: 0,
178 data,
179 };
180 let _ = sender.send(event);
181 }
182 }
183
    /// Forwards the current conversation history to the configured provider
    /// and returns the raw model output.
    async fn call_llm(&self) -> Result<String> {
        self.provider.call(&self.config, &self.history).await
    }
187
188 fn create_tts_command(&self, text: String, wait_input_timeout: Option<u32>) -> Command {
189 let timeout = wait_input_timeout.unwrap_or(10000);
190 Command::Tts {
191 text,
192 speaker: None,
193 play_id: None,
194 auto_hangup: None,
195 streaming: None,
196 end_of_stream: None,
197 option: None,
198 wait_input_timeout: Some(timeout),
199 base64: None,
200 }
201 }
202
203 async fn generate_response(&mut self) -> Result<Vec<Command>> {
204 self.send_debug_event("llm_call_start", json!({
206 "history_length": self.history.len(),
207 }));
208
209 let initial = self.call_llm().await?;
210
211 self.send_debug_event("llm_response", json!({
213 "response": initial,
214 }));
215
216 self.interpret_response(initial).await
217 }
218
    /// Turns a raw LLM reply into commands.
    ///
    /// If the reply parses as a `StructuredResponse`, its tools are executed
    /// in order: `hangup`/`refer` become commands appended AFTER the TTS
    /// command; `rag` retrieves context, appends it to the history as a
    /// system message, and re-queries the LLM (bounded by MAX_RAG_ATTEMPTS,
    /// which counts loop iterations including the first). Non-JSON replies
    /// are spoken verbatim. The spoken text is also recorded as an assistant
    /// turn and marks the handler as speaking.
    async fn interpret_response(&mut self, initial: String) -> Result<Vec<Command>> {
        let mut tool_commands = Vec::new();
        let mut wait_input_timeout = None;
        let mut attempts = 0;
        let final_text: Option<String>;
        let mut raw = initial;

        // Each iteration either breaks with a final text or continues after a
        // RAG-triggered re-query; every path through the body breaks or
        // continues, so `final_text` is assigned exactly once.
        loop {
            attempts += 1;
            let mut rerun_for_rag = false;

            if let Some(structured) = parse_structured_response(&raw) {
                // Keep the first timeout seen across RAG re-queries.
                if wait_input_timeout.is_none() {
                    wait_input_timeout = structured.wait_input_timeout;
                }

                if let Some(tools) = structured.tools {
                    for tool in tools {
                        match tool {
                            ToolInvocation::Hangup { ref reason, ref initiator } => {
                                self.send_debug_event("tool_invocation", json!({
                                    "tool": "Hangup",
                                    "params": {
                                        "reason": reason,
                                        "initiator": initiator,
                                    }
                                }));
                                tool_commands.push(Command::Hangup {
                                    reason: reason.clone(),
                                    initiator: initiator.clone()
                                });
                            }
                            ToolInvocation::Refer {
                                ref caller,
                                ref callee,
                                ref options,
                            } => {
                                self.send_debug_event("tool_invocation", json!({
                                    "tool": "Refer",
                                    "params": {
                                        "caller": caller,
                                        "callee": callee,
                                    }
                                }));
                                tool_commands.push(Command::Refer {
                                    caller: caller.clone(),
                                    callee: callee.clone(),
                                    options: options.clone(),
                                });
                            }
                            ToolInvocation::Rag { ref query, ref source } => {
                                self.send_debug_event("tool_invocation", json!({
                                    "tool": "Rag",
                                    "params": {
                                        "query": query,
                                        "source": source,
                                    }
                                }));

                                let rag_result = self.rag_retriever.retrieve(&query).await?;

                                self.send_debug_event("rag_result", json!({
                                    "query": query,
                                    "result": rag_result,
                                }));

                                // Prefix the retrieved text with its source tag
                                // when one was supplied.
                                let summary = if let Some(source) = source {
                                    format!("[{}] {}", source, rag_result)
                                } else {
                                    rag_result
                                };
                                self.history.push(ChatMessage {
                                    role: "system".to_string(),
                                    content: format!("RAG result for {}: {}", query, summary),
                                });
                                rerun_for_rag = true;
                            }
                        }
                    }
                }

                if rerun_for_rag {
                    if attempts >= MAX_RAG_ATTEMPTS {
                        // Give up re-querying; fall back to this reply's text,
                        // or the raw reply when it carried no text.
                        warn!("Reached RAG iteration limit, using last response");
                        final_text = structured.text.or_else(|| Some(raw.clone()));
                        break;
                    }
                    raw = self.call_llm().await?;
                    continue;
                }

                final_text = Some(structured.text.unwrap_or_else(|| raw.clone()));
                break;
            }

            // Unparseable reply: speak it as-is.
            final_text = Some(raw.clone());
            break;
        }

        let mut commands = Vec::new();
        if let Some(text) = final_text {
            // Whitespace-only text produces no TTS command and no history entry.
            if !text.trim().is_empty() {
                self.history.push(ChatMessage {
                    role: "assistant".to_string(),
                    content: text.clone(),
                });
                self.is_speaking = true;
                commands.push(self.create_tts_command(text, wait_input_timeout));
            }
        }

        // Tool commands always follow the TTS command.
        commands.extend(tool_commands);

        Ok(commands)
    }
338}
339
340fn parse_structured_response(raw: &str) -> Option<StructuredResponse> {
341 let payload = extract_json_block(raw)?;
342 serde_json::from_str(payload).ok()
343}
344
/// Locates the JSON payload inside an LLM reply.
///
/// Accepts either a bare JSON object/array or one wrapped in a Markdown
/// code fence (with an optional "json" language tag). Returns `None` when
/// no payload is found.
///
/// Fix over the previous version: it tested only for a single leading
/// backtick before slicing `trimmed[3..end]`, which could panic on a
/// non-char-boundary (multibyte text after a lone backtick) or mis-slice
/// non-fenced input; `strip_prefix("```")` requires a real opening fence.
/// The language-tag check also avoids the `to_lowercase()` allocation.
fn extract_json_block(raw: &str) -> Option<&str> {
    let trimmed = raw.trim();
    if let Some(rest) = trimmed.strip_prefix("```") {
        // Require a closing fence within the remainder; otherwise the block
        // is malformed and there is nothing to extract.
        let end = rest.rfind("```")?;
        let mut inner = rest[..end].trim();
        // Drop an optional "json"/"JSON" language tag after the opening fence.
        // `get(..4)` is boundary-safe and returns None for short strings.
        let has_tag = inner
            .get(..4)
            .map_or(false, |tag| tag.eq_ignore_ascii_case("json"));
        if has_tag {
            inner = match inner.find('\n') {
                // Tag on its own line: keep everything after the newline.
                Some(newline) => inner[newline + 1..].trim(),
                // Tag glued to the payload: skip the 4 ASCII tag bytes.
                None => inner[4..].trim(),
            };
        }
        Some(inner)
    } else if trimmed.starts_with('{') || trimmed.starts_with('[') {
        Some(trimmed)
    } else {
        None
    }
}
370
371#[async_trait]
372impl DialogueHandler for LlmHandler {
373 async fn on_start(&mut self) -> Result<Vec<Command>> {
374 if let Some(greeting) = &self.config.greeting {
375 self.is_speaking = true;
376 return Ok(vec![self.create_tts_command(greeting.clone(), None)]);
377 }
378
379 self.generate_response().await
380 }
381
382 async fn on_event(&mut self, event: &SessionEvent) -> Result<Vec<Command>> {
383 match event {
384 SessionEvent::AsrFinal { text, .. } => {
385 if text.trim().is_empty() {
386 return Ok(vec![]);
387 }
388
389 self.history.push(ChatMessage {
390 role: "user".to_string(),
391 content: text.clone(),
392 });
393
394 self.generate_response().await
395 }
396
397 SessionEvent::AsrDelta { .. } | SessionEvent::Speaking { .. } => {
398 if self.is_speaking {
399 info!("Interruption detected, stopping playback");
400 self.is_speaking = false;
401 return Ok(vec![Command::Interrupt {
402 graceful: Some(true),
403 }]);
404 }
405 Ok(vec![])
406 }
407
408 SessionEvent::Silence { .. } => {
409 info!("Silence timeout detected, triggering follow-up");
410 self.generate_response().await
411 }
412
413 SessionEvent::TrackEnd { .. } => {
414 self.is_speaking = false;
415 Ok(vec![])
416 }
417
418 _ => Ok(vec![]),
419 }
420 }
421}
422
423#[cfg(test)]
424mod tests {
425 use super::*;
426 use anyhow::{Result, anyhow};
427 use async_trait::async_trait;
428 use std::collections::VecDeque;
429 use std::sync::Mutex;
430 use crate::event::SessionEvent;
431
    /// Fake LLM provider that replays a fixed queue of canned responses.
    struct TestProvider {
        responses: Mutex<VecDeque<String>>,
    }
435
436 impl TestProvider {
437 fn new(responses: Vec<String>) -> Self {
438 Self {
439 responses: Mutex::new(VecDeque::from(responses)),
440 }
441 }
442 }
443
444 #[async_trait]
445 impl LlmProvider for TestProvider {
446 async fn call(&self, _config: &LlmConfig, _history: &[ChatMessage]) -> Result<String> {
447 let mut guard = self.responses.lock().unwrap();
448 guard
449 .pop_front()
450 .ok_or_else(|| anyhow!("Test provider ran out of responses"))
451 }
452 }
453
    /// RAG stub that records every query it receives and echoes it back.
    struct RecordingRag {
        queries: Mutex<Vec<String>>,
    }
457
458 impl RecordingRag {
459 fn new() -> Self {
460 Self {
461 queries: Mutex::new(Vec::new()),
462 }
463 }
464
465 fn recorded_queries(&self) -> Vec<String> {
466 self.queries.lock().unwrap().clone()
467 }
468 }
469
470 #[async_trait]
471 impl RagRetriever for RecordingRag {
472 async fn retrieve(&self, query: &str) -> Result<String> {
473 self.queries.lock().unwrap().push(query.to_string());
474 Ok(format!("retrieved {}", query))
475 }
476 }
477
    // A structured reply with text + hangup + refer must yield a TTS command
    // (carrying the reply's waitInputTimeout) followed by both tool commands.
    #[tokio::test]
    async fn handler_applies_tool_instructions() -> Result<()> {
        let response = r#"{
            "text": "Goodbye",
            "waitInputTimeout": 15000,
            "tools": [
                {"name": "hangup", "reason": "done", "initiator": "agent"},
                {"name": "refer", "caller": "sip:bot", "callee": "sip:lead"}
            ]
        }"#;

        let provider = Box::new(TestProvider::new(vec![response.to_string()]));
        let mut handler =
            LlmHandler::with_provider(LlmConfig::default(), provider, Arc::new(NoopRagRetriever));

        let event = SessionEvent::AsrFinal {
            track_id: "track-1".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "hello".to_string(),
        };

        let commands = handler.on_event(&event).await?;
        // TTS comes first, with the timeout taken from the structured reply.
        assert!(matches!(
            commands.get(0),
            Some(Command::Tts {
                text,
                wait_input_timeout: Some(15000),
                ..
            }) if text == "Goodbye"
        ));
        assert!(commands.iter().any(|cmd| matches!(
            cmd,
            Command::Hangup {
                reason: Some(reason),
                initiator: Some(origin),
            } if reason == "done" && origin == "agent"
        )));
        assert!(commands.iter().any(|cmd| matches!(
            cmd,
            Command::Refer {
                caller,
                callee,
                ..
            } if caller == "sip:bot" && callee == "sip:lead"
        )));

        Ok(())
    }
529
    // A rag tool call must trigger retrieval and a second LLM call; the
    // second reply is spoken with the default 10s wait-input timeout.
    #[tokio::test]
    async fn handler_requeries_after_rag() -> Result<()> {
        let rag_instruction = r#"{"tools": [{"name": "rag", "query": "policy"}]}"#;
        let provider = Box::new(TestProvider::new(vec![
            rag_instruction.to_string(),
            "Final answer".to_string(),
        ]));
        let rag = Arc::new(RecordingRag::new());
        let mut handler = LlmHandler::with_provider(LlmConfig::default(), provider, rag.clone());

        let event = SessionEvent::AsrFinal {
            track_id: "track-2".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "reep".to_string(),
        };

        let commands = handler.on_event(&event).await?;
        assert!(matches!(
            commands.get(0),
            Some(Command::Tts {
                text,
                wait_input_timeout: Some(timeout),
                ..
            }) if text == "Final answer" && *timeout == 10000
        ));
        // The retriever must have been queried exactly once with "policy".
        assert_eq!(rag.recorded_queries(), vec!["policy".to_string()]);

        Ok(())
    }
562
    // End-to-end flow: greeting on start, then plain-text reply, then a
    // structured reply with timeout, then a reply that also hangs up.
    #[tokio::test]
    async fn test_full_dialogue_flow() -> Result<()> {
        let responses = vec![
            "Hello! How can I help you today?".to_string(),
            r#"{"text": "I can help with that. Anything else?", "waitInputTimeout": 5000}"#
                .to_string(),
            r#"{"text": "Goodbye!", "tools": [{"name": "hangup", "reason": "completed"}]}"#
                .to_string(),
        ];

        let provider = Box::new(TestProvider::new(responses));
        let config = LlmConfig {
            greeting: Some("Welcome to the voice assistant.".to_string()),
            ..Default::default()
        };

        let mut handler = LlmHandler::with_provider(config, provider, Arc::new(NoopRagRetriever));

        // Start: the configured greeting is spoken without calling the LLM.
        let commands = handler.on_start().await?;
        assert_eq!(commands.len(), 1);
        if let Command::Tts { text, .. } = &commands[0] {
            assert_eq!(text, "Welcome to the voice assistant.");
        } else {
            panic!("Expected Tts command");
        }

        // First user turn: plain-text reply is spoken verbatim.
        let event = SessionEvent::AsrFinal {
            track_id: "test".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "I need help".to_string(),
        };
        let commands = handler.on_event(&event).await?;
        assert_eq!(commands.len(), 1);
        if let Command::Tts { text, .. } = &commands[0] {
            assert_eq!(text, "Hello! How can I help you today?");
        } else {
            panic!("Expected Tts command");
        }

        // Second turn: structured reply with an explicit timeout override.
        let event = SessionEvent::AsrFinal {
            track_id: "test".to_string(),
            timestamp: 0,
            index: 1,
            start_time: None,
            end_time: None,
            text: "Tell me a joke".to_string(),
        };
        let commands = handler.on_event(&event).await?;
        assert_eq!(commands.len(), 1);
        if let Command::Tts {
            text,
            wait_input_timeout,
            ..
        } = &commands[0]
        {
            assert_eq!(text, "I can help with that. Anything else?");
            assert_eq!(*wait_input_timeout, Some(5000));
        } else {
            panic!("Expected Tts command");
        }

        // Final turn: text plus hangup tool yields two commands.
        let event = SessionEvent::AsrFinal {
            track_id: "test".to_string(),
            timestamp: 0,
            index: 2,
            start_time: None,
            end_time: None,
            text: "That's all, thanks".to_string(),
        };
        let commands = handler.on_event(&event).await?;
        assert_eq!(commands.len(), 2);

        let has_tts = commands
            .iter()
            .any(|c| matches!(c, Command::Tts { text, .. } if text == "Goodbye!"));
        let has_hangup = commands.iter().any(|c| matches!(c, Command::Hangup { .. }));

        assert!(has_tts);
        assert!(has_hangup);

        Ok(())
    }
653
    // While TTS output is playing, a partial transcript (AsrDelta) must emit
    // an Interrupt command and clear the speaking flag.
    #[tokio::test]
    async fn test_interruption_logic() -> Result<()> {
        let provider = Box::new(TestProvider::new(vec!["Some long response".to_string()]));
        let mut handler =
            LlmHandler::with_provider(LlmConfig::default(), provider, Arc::new(NoopRagRetriever));

        // Drive one LLM turn so the handler enters the speaking state.
        let event = SessionEvent::AsrFinal {
            track_id: "test".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "hello".to_string(),
        };
        handler.on_event(&event).await?;
        assert!(handler.is_speaking);

        let event = SessionEvent::AsrDelta {
            track_id: "test".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "I...".to_string(),
        };
        let commands = handler.on_event(&event).await?;
        assert_eq!(commands.len(), 1);
        assert!(matches!(commands[0], Command::Interrupt { .. }));
        assert!(!handler.is_speaking);

        Ok(())
    }
688
    // When the model keeps asking for RAG, the loop must stop after
    // MAX_RAG_ATTEMPTS and speak the last raw reply instead of re-querying.
    #[tokio::test]
    async fn test_rag_iteration_limit() -> Result<()> {
        let rag_instruction = r#"{"tools": [{"name": "rag", "query": "endless"}]}"#;
        let provider = Box::new(TestProvider::new(vec![
            rag_instruction.to_string(),
            rag_instruction.to_string(),
            rag_instruction.to_string(),
            rag_instruction.to_string(),
            "Should not reach here".to_string(),
        ]));

        let mut handler = LlmHandler::with_provider(
            LlmConfig::default(),
            provider,
            Arc::new(RecordingRag::new()),
        );

        let event = SessionEvent::AsrFinal {
            track_id: "test".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "loop".to_string(),
        };

        let commands = handler.on_event(&event).await?;
        assert_eq!(commands.len(), 1);
        // The limit-hit fallback speaks the raw JSON instruction itself.
        if let Command::Tts { text, .. } = &commands[0] {
            assert_eq!(text, rag_instruction);
        }

        Ok(())
    }
725}