active_call/playbook/

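//! LLM-driven dialogue handler for the active-call playbook.
//!
//! Keeps the chat history, calls a chat-completions style LLM backend, parses
//! structured JSON replies (text, wait-input timeout, tool calls such as
//! hangup, refer, and RAG), and turns them into call `Command`s.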
use anyhow::{Result, anyhow};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::json;
use std::sync::Arc;
use tracing::{info, warn};

use crate::ReferOption;
use crate::call::Command;
use crate::event::SessionEvent;

use super::LlmConfig;
use super::dialogue::DialogueHandler;

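/// A single message in the LLM chat history, in OpenAI-compatible role/content form.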
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct ChatMessage {
    pub role: String,
    pub content: String,
}

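/// Maximum number of LLM round trips triggered by RAG tool calls before falling back to the last response.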
const MAX_RAG_ATTEMPTS: usize = 3;

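/// Abstraction over the LLM backend so tests can inject scripted responses.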
#[async_trait]
pub trait LlmProvider: Send + Sync {
    async fn call(&self, config: &LlmConfig, history: &[ChatMessage]) -> Result<String>;
}

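/// Default provider that POSTs the history to an OpenAI-compatible `/chat/completions` endpoint.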
struct DefaultLlmProvider {
    client: Client,
}

impl DefaultLlmProvider {
    fn new() -> Self {
        Self {
            client: Client::new(),
        }
    }
}

#[async_trait]
impl LlmProvider for DefaultLlmProvider {
    async fn call(&self, config: &LlmConfig, history: &[ChatMessage]) -> Result<String> {
        let mut url = config
            .base_url
            .clone()
            .unwrap_or_else(|| "https://api.openai.com/v1/chat/completions".to_string());
        let model = config
            .model
            .clone()
            .unwrap_or_else(|| "gpt-3.5-turbo".to_string());
        let api_key = config.api_key.clone().unwrap_or_default();

        if !url.ends_with("/chat/completions") {
            url = format!("{}/chat/completions", url.trim_end_matches('/'));
        }

        let body = json!({
            "model": model,
            "messages": history,
        });

        let res = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", api_key))
            .json(&body)
            .send()
            .await?;

        if !res.status().is_success() {
            return Err(anyhow!("LLM request failed: {}", res.status()));
        }

        let json: serde_json::Value = res.json().await?;
        let content = json["choices"][0]["message"]["content"]
            .as_str()
            .ok_or_else(|| anyhow!("Invalid LLM response"))?
            .to_string();

        Ok(content)
    }
}

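/// Retrieval hook backing the `rag` tool; implementations return context text for a query.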
#[async_trait]
pub trait RagRetriever: Send + Sync {
    async fn retrieve(&self, query: &str) -> Result<String>;
}

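/// Retriever that returns an empty string, used when no RAG backend is configured.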
struct NoopRagRetriever;

#[async_trait]
impl RagRetriever for NoopRagRetriever {
    async fn retrieve(&self, _query: &str) -> Result<String> {
        Ok(String::new())
    }
}

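/// Structured (JSON) LLM reply: optional spoken text, optional wait-input timeout, and tool calls.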
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct StructuredResponse {
    text: Option<String>,
    wait_input_timeout: Option<u32>,
    tools: Option<Vec<ToolInvocation>>,
}

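/// Tool calls the LLM may request, tagged by the `name` field in the JSON payload.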
#[derive(Debug, Deserialize)]
#[serde(tag = "name", rename_all = "lowercase")]
enum ToolInvocation {
    #[serde(rename_all = "camelCase")]
    Hangup {
        reason: Option<String>,
        initiator: Option<String>,
    },
    #[serde(rename_all = "camelCase")]
    Refer {
        caller: String,
        callee: String,
        options: Option<ReferOption>,
    },
    #[serde(rename_all = "camelCase")]
    Rag {
        query: String,
        source: Option<String>,
    },
}

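/// Dialogue handler backed by an LLM: maintains chat history, interprets structured
/// responses, and converts them into TTS and call-control commands.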
pub struct LlmHandler {
    config: LlmConfig,
    history: Vec<ChatMessage>,
    provider: Box<dyn LlmProvider>,
    rag_retriever: Arc<dyn RagRetriever>,
    is_speaking: bool,
    event_sender: Option<crate::event::EventSender>,
}

impl LlmHandler {
    pub fn new(config: LlmConfig) -> Self {
        Self::with_provider(
            config,
            Box::new(DefaultLlmProvider::new()),
            Arc::new(NoopRagRetriever),
        )
    }

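    /// Builds a handler with explicit provider and retriever implementations
    /// (primarily for tests); seeds the history with the configured system prompt.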
    pub fn with_provider(
        config: LlmConfig,
        provider: Box<dyn LlmProvider>,
        rag_retriever: Arc<dyn RagRetriever>,
    ) -> Self {
        let mut history = Vec::new();
        if let Some(prompt) = &config.prompt {
            history.push(ChatMessage {
                role: "system".to_string(),
                content: prompt.clone(),
            });
        }

        Self {
            config,
            history,
            provider,
            rag_retriever,
            is_speaking: false,
            event_sender: None,
        }
    }

    pub fn set_event_sender(&mut self, sender: crate::event::EventSender) {
        self.event_sender = Some(sender);
    }

    fn send_debug_event(&self, key: &str, data: serde_json::Value) {
        if let Some(sender) = &self.event_sender {
            let event = crate::event::SessionEvent::Metrics {
                timestamp: crate::media::get_timestamp(),
                key: key.to_string(),
                duration: 0,
                data,
            };
            let _ = sender.send(event);
        }
    }

    async fn call_llm(&self) -> Result<String> {
        self.provider.call(&self.config, &self.history).await
    }

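    /// Wraps text in a `Tts` command, defaulting the wait-input timeout to 10 seconds.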
    fn create_tts_command(&self, text: String, wait_input_timeout: Option<u32>) -> Command {
        let timeout = wait_input_timeout.unwrap_or(10000);
        Command::Tts {
            text,
            speaker: None,
            play_id: None,
            auto_hangup: None,
            streaming: None,
            end_of_stream: None,
            option: None,
            wait_input_timeout: Some(timeout),
            base64: None,
        }
    }

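    /// Calls the LLM with the current history and converts the reply into commands,
    /// emitting debug metrics before and after the call.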
    async fn generate_response(&mut self) -> Result<Vec<Command>> {
        // Send debug event - LLM call started
        self.send_debug_event("llm_call_start", json!({
            "history_length": self.history.len(),
        }));

        let initial = self.call_llm().await?;

        // Send debug event - LLM response received
        self.send_debug_event("llm_response", json!({
            "response": initial,
        }));

        self.interpret_response(initial).await
    }

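    /// Interprets a raw LLM reply: parses structured JSON, runs tool calls
    /// (re-querying the LLM after RAG results up to `MAX_RAG_ATTEMPTS`), and emits
    /// a TTS command for the final text followed by any tool commands.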
    async fn interpret_response(&mut self, initial: String) -> Result<Vec<Command>> {
        let mut tool_commands = Vec::new();
        let mut wait_input_timeout = None;
        let mut attempts = 0;
        let final_text: Option<String>;
        let mut raw = initial;

        loop {
            attempts += 1;
            let mut rerun_for_rag = false;

            if let Some(structured) = parse_structured_response(&raw) {
                if wait_input_timeout.is_none() {
                    wait_input_timeout = structured.wait_input_timeout;
                }

                if let Some(tools) = structured.tools {
                    for tool in tools {
                        match tool {
                            ToolInvocation::Hangup { ref reason, ref initiator } => {
                                // Send debug event
                                self.send_debug_event("tool_invocation", json!({
                                    "tool": "Hangup",
                                    "params": {
                                        "reason": reason,
                                        "initiator": initiator,
                                    }
                                }));
                                tool_commands.push(Command::Hangup {
                                    reason: reason.clone(),
                                    initiator: initiator.clone()
                                });
                            }
                            ToolInvocation::Refer {
                                ref caller,
                                ref callee,
                                ref options,
                            } => {
                                // Send debug event
                                self.send_debug_event("tool_invocation", json!({
                                    "tool": "Refer",
                                    "params": {
                                        "caller": caller,
                                        "callee": callee,
                                    }
                                }));
                                tool_commands.push(Command::Refer {
                                    caller: caller.clone(),
                                    callee: callee.clone(),
                                    options: options.clone(),
                                });
                            }
                            ToolInvocation::Rag { ref query, ref source } => {
                                // Send debug event - RAG query started
                                self.send_debug_event("tool_invocation", json!({
                                    "tool": "Rag",
                                    "params": {
                                        "query": query,
                                        "source": source,
                                    }
                                }));

                                let rag_result = self.rag_retriever.retrieve(&query).await?;

                                // Send debug event - RAG result
                                self.send_debug_event("rag_result", json!({
                                    "query": query,
                                    "result": rag_result,
                                }));

                                let summary = if let Some(source) = source {
                                    format!("[{}] {}", source, rag_result)
                                } else {
                                    rag_result
                                };
                                self.history.push(ChatMessage {
                                    role: "system".to_string(),
                                    content: format!("RAG result for {}: {}", query, summary),
                                });
                                rerun_for_rag = true;
                            }
                        }
                    }
                }

                if rerun_for_rag {
                    if attempts >= MAX_RAG_ATTEMPTS {
                        warn!("Reached RAG iteration limit, using last response");
                        final_text = structured.text.or_else(|| Some(raw.clone()));
                        break;
                    }
                    raw = self.call_llm().await?;
                    continue;
                }

                final_text = Some(structured.text.unwrap_or_else(|| raw.clone()));
                break;
            }

            final_text = Some(raw.clone());
            break;
        }

        let mut commands = Vec::new();
        if let Some(text) = final_text {
            if !text.trim().is_empty() {
                self.history.push(ChatMessage {
                    role: "assistant".to_string(),
                    content: text.clone(),
                });
                self.is_speaking = true;
                commands.push(self.create_tts_command(text, wait_input_timeout));
            }
        }

        commands.extend(tool_commands);

        Ok(commands)
    }
}

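/// Tries to parse a structured JSON response, including replies wrapped in a Markdown code fence.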
fn parse_structured_response(raw: &str) -> Option<StructuredResponse> {
    let payload = extract_json_block(raw)?;
    serde_json::from_str(payload).ok()
}

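/// Extracts the JSON payload from a reply, stripping a surrounding ```json fence if present;
/// returns `None` for plain prose.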
fn extract_json_block(raw: &str) -> Option<&str> {
    let trimmed = raw.trim();
    if trimmed.starts_with('`') {
        if let Some(end) = trimmed.rfind("```") {
            if end <= 3 {
                return None;
            }
            let mut inner = &trimmed[3..end];
            inner = inner.trim();
            if inner.to_lowercase().starts_with("json") {
                if let Some(newline) = inner.find('\n') {
                    inner = inner[newline + 1..].trim();
                } else if inner.len() > 4 {
                    inner = inner[4..].trim();
                } else {
                    inner = inner.trim();
                }
            }
            return Some(inner);
        }
    } else if trimmed.starts_with('{') || trimmed.starts_with('[') {
        return Some(trimmed);
    }
    None
}

#[async_trait]
impl DialogueHandler for LlmHandler {
    async fn on_start(&mut self) -> Result<Vec<Command>> {
        if let Some(greeting) = &self.config.greeting {
            self.is_speaking = true;
            return Ok(vec![self.create_tts_command(greeting.clone(), None)]);
        }

        self.generate_response().await
    }

    async fn on_event(&mut self, event: &SessionEvent) -> Result<Vec<Command>> {
        match event {
            SessionEvent::AsrFinal { text, .. } => {
                if text.trim().is_empty() {
                    return Ok(vec![]);
                }

                self.history.push(ChatMessage {
                    role: "user".to_string(),
                    content: text.clone(),
                });

                self.generate_response().await
            }

            SessionEvent::AsrDelta { .. } | SessionEvent::Speaking { .. } => {
                if self.is_speaking {
                    info!("Interruption detected, stopping playback");
                    self.is_speaking = false;
                    return Ok(vec![Command::Interrupt {
                        graceful: Some(true),
                    }]);
                }
                Ok(vec![])
            }

            SessionEvent::Silence { .. } => {
                info!("Silence timeout detected, triggering follow-up");
                self.generate_response().await
            }

            SessionEvent::TrackEnd { .. } => {
                self.is_speaking = false;
                Ok(vec![])
            }

            _ => Ok(vec![]),
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::{Result, anyhow};
    use async_trait::async_trait;
    use std::collections::VecDeque;
    use std::sync::Mutex;
    use crate::event::SessionEvent;

    struct TestProvider {
        responses: Mutex<VecDeque<String>>,
    }

    impl TestProvider {
        fn new(responses: Vec<String>) -> Self {
            Self {
                responses: Mutex::new(VecDeque::from(responses)),
            }
        }
    }

    #[async_trait]
    impl LlmProvider for TestProvider {
        async fn call(&self, _config: &LlmConfig, _history: &[ChatMessage]) -> Result<String> {
            let mut guard = self.responses.lock().unwrap();
            guard
                .pop_front()
                .ok_or_else(|| anyhow!("Test provider ran out of responses"))
        }
    }

    struct RecordingRag {
        queries: Mutex<Vec<String>>,
    }

    impl RecordingRag {
        fn new() -> Self {
            Self {
                queries: Mutex::new(Vec::new()),
            }
        }

        fn recorded_queries(&self) -> Vec<String> {
            self.queries.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl RagRetriever for RecordingRag {
        async fn retrieve(&self, query: &str) -> Result<String> {
            self.queries.lock().unwrap().push(query.to_string());
            Ok(format!("retrieved {}", query))
        }
    }

    #[tokio::test]
    async fn handler_applies_tool_instructions() -> Result<()> {
        let response = r#"{
            "text": "Goodbye",
            "waitInputTimeout": 15000,
            "tools": [
                {"name": "hangup", "reason": "done", "initiator": "agent"},
                {"name": "refer", "caller": "sip:bot", "callee": "sip:lead"}
            ]
        }"#;

        let provider = Box::new(TestProvider::new(vec![response.to_string()]));
        let mut handler =
            LlmHandler::with_provider(LlmConfig::default(), provider, Arc::new(NoopRagRetriever));

        let event = SessionEvent::AsrFinal {
            track_id: "track-1".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "hello".to_string(),
        };

        let commands = handler.on_event(&event).await?;
        assert!(matches!(
            commands.get(0),
            Some(Command::Tts {
                text,
                wait_input_timeout: Some(15000),
                ..
            }) if text == "Goodbye"
        ));
        assert!(commands.iter().any(|cmd| matches!(
            cmd,
            Command::Hangup {
                reason: Some(reason),
                initiator: Some(origin),
            } if reason == "done" && origin == "agent"
        )));
        assert!(commands.iter().any(|cmd| matches!(
            cmd,
            Command::Refer {
                caller,
                callee,
                ..
            } if caller == "sip:bot" && callee == "sip:lead"
        )));

        Ok(())
    }

    #[tokio::test]
    async fn handler_requeries_after_rag() -> Result<()> {
        let rag_instruction = r#"{"tools": [{"name": "rag", "query": "policy"}]}"#;
        let provider = Box::new(TestProvider::new(vec![
            rag_instruction.to_string(),
            "Final answer".to_string(),
        ]));
        let rag = Arc::new(RecordingRag::new());
        let mut handler = LlmHandler::with_provider(LlmConfig::default(), provider, rag.clone());

        let event = SessionEvent::AsrFinal {
            track_id: "track-2".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "reep".to_string(),
        };

        let commands = handler.on_event(&event).await?;
        assert!(matches!(
            commands.get(0),
            Some(Command::Tts {
                text,
                wait_input_timeout: Some(timeout),
                ..
            }) if text == "Final answer" && *timeout == 10000
        ));
        assert_eq!(rag.recorded_queries(), vec!["policy".to_string()]);

        Ok(())
    }

    #[tokio::test]
    async fn test_full_dialogue_flow() -> Result<()> {
        let responses = vec![
            "Hello! How can I help you today?".to_string(),
            r#"{"text": "I can help with that. Anything else?", "waitInputTimeout": 5000}"#
                .to_string(),
            r#"{"text": "Goodbye!", "tools": [{"name": "hangup", "reason": "completed"}]}"#
                .to_string(),
        ];

        let provider = Box::new(TestProvider::new(responses));
        let config = LlmConfig {
            greeting: Some("Welcome to the voice assistant.".to_string()),
            ..Default::default()
        };

        let mut handler = LlmHandler::with_provider(config, provider, Arc::new(NoopRagRetriever));

        // 1. Start the dialogue
        let commands = handler.on_start().await?;
        assert_eq!(commands.len(), 1);
        if let Command::Tts { text, .. } = &commands[0] {
            assert_eq!(text, "Welcome to the voice assistant.");
        } else {
            panic!("Expected Tts command");
        }

        // 2. User says something
        let event = SessionEvent::AsrFinal {
            track_id: "test".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "I need help".to_string(),
        };
        let commands = handler.on_event(&event).await?;
        assert_eq!(commands.len(), 1);
        if let Command::Tts { text, .. } = &commands[0] {
            assert_eq!(text, "Hello! How can I help you today?");
        } else {
            panic!("Expected Tts command");
        }

        // 3. User says something else
        let event = SessionEvent::AsrFinal {
            track_id: "test".to_string(),
            timestamp: 0,
            index: 1,
            start_time: None,
            end_time: None,
            text: "Tell me a joke".to_string(),
        };
        let commands = handler.on_event(&event).await?;
        assert_eq!(commands.len(), 1);
        if let Command::Tts {
            text,
            wait_input_timeout,
            ..
        } = &commands[0]
        {
            assert_eq!(text, "I can help with that. Anything else?");
            assert_eq!(*wait_input_timeout, Some(5000));
        } else {
            panic!("Expected Tts command");
        }

        // 4. User says goodbye
        let event = SessionEvent::AsrFinal {
            track_id: "test".to_string(),
            timestamp: 0,
            index: 2,
            start_time: None,
            end_time: None,
            text: "That's all, thanks".to_string(),
        };
        let commands = handler.on_event(&event).await?;
        // Should have Tts and Hangup
        assert_eq!(commands.len(), 2);

        let has_tts = commands
            .iter()
            .any(|c| matches!(c, Command::Tts { text, .. } if text == "Goodbye!"));
        let has_hangup = commands.iter().any(|c| matches!(c, Command::Hangup { .. }));

        assert!(has_tts);
        assert!(has_hangup);

        Ok(())
    }

    #[tokio::test]
    async fn test_interruption_logic() -> Result<()> {
        let provider = Box::new(TestProvider::new(vec!["Some long response".to_string()]));
        let mut handler =
            LlmHandler::with_provider(LlmConfig::default(), provider, Arc::new(NoopRagRetriever));

        // 1. Trigger a response
        let event = SessionEvent::AsrFinal {
            track_id: "test".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "hello".to_string(),
        };
        handler.on_event(&event).await?;
        assert!(handler.is_speaking);

        // 2. Simulate user starting to speak (AsrDelta)
        let event = SessionEvent::AsrDelta {
            track_id: "test".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "I...".to_string(),
        };
        let commands = handler.on_event(&event).await?;
        assert_eq!(commands.len(), 1);
        assert!(matches!(commands[0], Command::Interrupt { .. }));
        assert!(!handler.is_speaking);

        Ok(())
    }

    #[tokio::test]
    async fn test_rag_iteration_limit() -> Result<()> {
        // Provider that always returns a RAG tool call
        let rag_instruction = r#"{"tools": [{"name": "rag", "query": "endless"}]}"#;
        let provider = Box::new(TestProvider::new(vec![
            rag_instruction.to_string(),
            rag_instruction.to_string(),
            rag_instruction.to_string(),
            rag_instruction.to_string(),
            "Should not reach here".to_string(),
        ]));

        let mut handler = LlmHandler::with_provider(
            LlmConfig::default(),
            provider,
            Arc::new(RecordingRag::new()),
        );

        let event = SessionEvent::AsrFinal {
            track_id: "test".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "loop".to_string(),
        };

        let commands = handler.on_event(&event).await?;
        // After 3 attempts (MAX_RAG_ATTEMPTS), it should stop and return the last raw response
        assert_eq!(commands.len(), 1);
        if let Command::Tts { text, .. } = &commands[0] {
            assert_eq!(text, rag_instruction);
        }

        Ok(())
    }
}