active_call/playbook/
handler.rs

1use crate::call::Command;
2use anyhow::{Result, anyhow};
3use async_trait::async_trait;
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6use serde_json::json;
7use std::sync::Arc;
8use tracing::{info, warn};
9use voice_engine::ReferOption;
10use voice_engine::event::SessionEvent;
11
12use super::LlmConfig;
13use super::dialogue::DialogueHandler;
14
/// One turn of the conversation sent to the LLM, in the OpenAI-style
/// `{role, content}` message shape.
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct ChatMessage {
    // Roles used in this file: "system", "user", "assistant".
    pub role: String,
    pub content: String,
}
20
/// Upper bound on interpretation passes per user turn in
/// `interpret_response`; since the first pass consumes the initial reply,
/// a RAG tool call can trigger at most `MAX_RAG_ATTEMPTS - 1` re-queries.
const MAX_RAG_ATTEMPTS: usize = 3;
22
/// Abstraction over the chat-completion backend so tests can substitute a
/// scripted fake for the real HTTP client.
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Returns the assistant's reply text for the given config and history.
    async fn call(&self, config: &LlmConfig, history: &[ChatMessage]) -> Result<String>;
}
27
/// Production provider that talks to an OpenAI-compatible HTTP endpoint.
struct DefaultLlmProvider {
    // Single reqwest client reused for every request made by this provider.
    client: Client,
}
31
32impl DefaultLlmProvider {
33    fn new() -> Self {
34        Self {
35            client: Client::new(),
36        }
37    }
38}
39
40#[async_trait]
41impl LlmProvider for DefaultLlmProvider {
42    async fn call(&self, config: &LlmConfig, history: &[ChatMessage]) -> Result<String> {
43        let mut url = config
44            .base_url
45            .clone()
46            .unwrap_or_else(|| "https://api.openai.com/v1/chat/completions".to_string());
47        let model = config
48            .model
49            .clone()
50            .unwrap_or_else(|| "gpt-3.5-turbo".to_string());
51        let api_key = config.api_key.clone().unwrap_or_default();
52
53        if !url.ends_with("/chat/completions") {
54            url = format!("{}/chat/completions", url.trim_end_matches('/'));
55        }
56
57        let body = json!({
58            "model": model,
59            "messages": history,
60        });
61
62        let res = self
63            .client
64            .post(&url)
65            .header("Authorization", format!("Bearer {}", api_key))
66            .json(&body)
67            .send()
68            .await?;
69
70        if !res.status().is_success() {
71            return Err(anyhow!("LLM request failed: {}", res.status()));
72        }
73
74        let json: serde_json::Value = res.json().await?;
75        let content = json["choices"][0]["message"]["content"]
76            .as_str()
77            .ok_or_else(|| anyhow!("Invalid LLM response"))?
78            .to_string();
79
80        Ok(content)
81    }
82}
83
/// Retrieval hook invoked when the LLM emits a "rag" tool call; returns a
/// text snippet that is fed back into the conversation as context.
#[async_trait]
pub trait RagRetriever: Send + Sync {
    async fn retrieve(&self, query: &str) -> Result<String>;
}
88
/// Default retriever used by `LlmHandler::new`: RAG effectively disabled.
struct NoopRagRetriever;

#[async_trait]
impl RagRetriever for NoopRagRetriever {
    async fn retrieve(&self, _query: &str) -> Result<String> {
        // Always succeeds with an empty context snippet.
        Ok(String::new())
    }
}
97
/// Structured (JSON) shape an LLM reply may take instead of plain prose:
/// optional spoken text, an optional wait-for-input timeout override, and
/// optional tool invocations. All fields are camelCase on the wire.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct StructuredResponse {
    text: Option<String>,
    wait_input_timeout: Option<u32>,
    tools: Option<Vec<ToolInvocation>>,
}
105
/// Tool calls the LLM may request, discriminated by the JSON `name` field
/// (serialized as "hangup" / "refer" / "rag").
#[derive(Debug, Deserialize)]
#[serde(tag = "name", rename_all = "lowercase")]
enum ToolInvocation {
    /// End the call.
    #[serde(rename_all = "camelCase")]
    Hangup {
        reason: Option<String>,
        initiator: Option<String>,
    },
    /// Transfer the call to another party.
    #[serde(rename_all = "camelCase")]
    Refer {
        caller: String,
        callee: String,
        options: Option<ReferOption>,
    },
    /// Fetch external context, then re-query the LLM with it.
    #[serde(rename_all = "camelCase")]
    Rag {
        query: String,
        source: Option<String>,
    },
}
126
/// Dialogue handler that drives a call via an LLM: it maintains the chat
/// history, interprets structured replies (text + tool calls), and turns
/// them into call `Command`s.
pub struct LlmHandler {
    config: LlmConfig,
    // Full transcript: optional system prompt, user/assistant turns, and
    // system-role RAG result notes.
    history: Vec<ChatMessage>,
    provider: Box<dyn LlmProvider>,
    rag_retriever: Arc<dyn RagRetriever>,
    // Set when a TTS command is emitted, cleared on TrackEnd or barge-in;
    // gates the interruption handling in `on_event`.
    is_speaking: bool,
}
134
impl LlmHandler {
    /// Builds a handler with the default OpenAI-compatible HTTP provider
    /// and a no-op RAG retriever.
    pub fn new(config: LlmConfig) -> Self {
        Self::with_provider(
            config,
            Box::new(DefaultLlmProvider::new()),
            Arc::new(NoopRagRetriever),
        )
    }

    /// Builds a handler with explicit provider/retriever implementations
    /// (used by the tests below to inject scripted fakes).
    ///
    /// If the config carries a prompt, it is seeded into the history as the
    /// initial system message.
    pub fn with_provider(
        config: LlmConfig,
        provider: Box<dyn LlmProvider>,
        rag_retriever: Arc<dyn RagRetriever>,
    ) -> Self {
        let mut history = Vec::new();
        if let Some(prompt) = &config.prompt {
            history.push(ChatMessage {
                role: "system".to_string(),
                content: prompt.clone(),
            });
        }

        Self {
            config,
            history,
            provider,
            rag_retriever,
            is_speaking: false,
        }
    }

    /// Sends the entire current history to the configured provider.
    async fn call_llm(&self) -> Result<String> {
        self.provider.call(&self.config, &self.history).await
    }

    /// Wraps `text` in a TTS command. `wait_input_timeout` defaults to
    /// 10000 (presumably milliseconds — confirm against Command::Tts
    /// consumers); all other TTS options are left unset.
    fn create_tts_command(&self, text: String, wait_input_timeout: Option<u32>) -> Command {
        let timeout = wait_input_timeout.unwrap_or(10000);
        Command::Tts {
            text,
            speaker: None,
            play_id: None,
            auto_hangup: None,
            streaming: None,
            end_of_stream: None,
            option: None,
            wait_input_timeout: Some(timeout),
            base64: None,
        }
    }

    /// Queries the LLM once, then interprets (and possibly re-queries for
    /// RAG) until a final set of commands is produced.
    async fn generate_response(&mut self) -> Result<Vec<Command>> {
        let initial = self.call_llm().await?;
        self.interpret_response(initial).await
    }

    /// Turns a raw LLM reply into commands.
    ///
    /// If the reply parses as a `StructuredResponse`, its tool calls are
    /// applied: Hangup/Refer become commands appended after the TTS command;
    /// Rag fetches context, records it in the history as a system message,
    /// and re-queries the LLM (bounded by `MAX_RAG_ATTEMPTS` passes total).
    /// Non-JSON replies are spoken verbatim.
    ///
    /// Note: `tool_commands` accumulates across RAG iterations, so tools
    /// requested alongside a Rag call in an earlier pass are still emitted.
    async fn interpret_response(&mut self, initial: String) -> Result<Vec<Command>> {
        let mut tool_commands = Vec::new();
        let mut wait_input_timeout = None;
        let mut attempts = 0;
        let final_text: Option<String>;
        let mut raw = initial;

        loop {
            attempts += 1;
            let mut rerun_for_rag = false;

            if let Some(structured) = parse_structured_response(&raw) {
                // Only the first structured reply's timeout override wins.
                if wait_input_timeout.is_none() {
                    wait_input_timeout = structured.wait_input_timeout;
                }

                if let Some(tools) = structured.tools {
                    for tool in tools {
                        match tool {
                            ToolInvocation::Hangup { reason, initiator } => {
                                tool_commands.push(Command::Hangup { reason, initiator });
                            }
                            ToolInvocation::Refer {
                                caller,
                                callee,
                                options,
                            } => {
                                tool_commands.push(Command::Refer {
                                    caller,
                                    callee,
                                    options,
                                });
                            }
                            ToolInvocation::Rag { query, source } => {
                                // Fetch context and surface it to the LLM as
                                // a system message, then ask again.
                                let rag_result = self.rag_retriever.retrieve(&query).await?;
                                let summary = if let Some(source) = source {
                                    format!("[{}] {}", source, rag_result)
                                } else {
                                    rag_result
                                };
                                self.history.push(ChatMessage {
                                    role: "system".to_string(),
                                    content: format!("RAG result for {}: {}", query, summary),
                                });
                                rerun_for_rag = true;
                            }
                        }
                    }
                }

                if rerun_for_rag {
                    if attempts >= MAX_RAG_ATTEMPTS {
                        // Give up re-querying; speak whatever we last got.
                        warn!("Reached RAG iteration limit, using last response");
                        final_text = structured.text.or_else(|| Some(raw.clone()));
                        break;
                    }
                    raw = self.call_llm().await?;
                    continue;
                }

                // NOTE(review): when a structured reply carries tools but no
                // "text", the raw JSON string itself becomes the spoken
                // text — confirm this fallback is intentional.
                final_text = Some(structured.text.unwrap_or_else(|| raw.clone()));
                break;
            }

            // Unstructured reply: speak it as-is.
            final_text = Some(raw.clone());
            break;
        }

        let mut commands = Vec::new();
        if let Some(text) = final_text {
            if !text.trim().is_empty() {
                // Record our own utterance so the next LLM call sees it.
                self.history.push(ChatMessage {
                    role: "assistant".to_string(),
                    content: text.clone(),
                });
                self.is_speaking = true;
                commands.push(self.create_tts_command(text, wait_input_timeout));
            }
        }

        // TTS first, then tool commands (e.g. speak goodbye before Hangup).
        commands.extend(tool_commands);

        Ok(commands)
    }
}
275
276fn parse_structured_response(raw: &str) -> Option<StructuredResponse> {
277    let payload = extract_json_block(raw)?;
278    serde_json::from_str(payload).ok()
279}
280
/// Extracts the JSON payload from a raw LLM reply.
///
/// Handles two shapes:
/// - a fenced code block (```…``` with an optional `json` language tag),
/// - a bare JSON object or array.
///
/// Returns `None` when no JSON-looking payload is present.
fn extract_json_block(raw: &str) -> Option<&str> {
    let trimmed = raw.trim();

    // Fenced block: require a full "```" opener. The previous code accepted
    // any leading backtick and then byte-sliced `trimmed[3..end]`, which
    // could panic mid-character on non-ASCII input (e.g. "`日```").
    if let Some(rest) = trimmed.strip_prefix("```") {
        // Closing fence; unclosed fences yield None.
        let end = rest.rfind("```")?;
        if end == 0 {
            // Nothing between the fences (e.g. "``````").
            return None;
        }
        let mut inner = rest[..end].trim();

        // Drop an optional "json" language tag right after the opener.
        if inner.to_lowercase().starts_with("json") {
            inner = match inner.find('\n') {
                // Tag on its own line: skip the whole first line.
                Some(pos) => inner[pos + 1..].trim(),
                // Inline tag: skip the 4 tag bytes, char-boundary safe.
                None => inner.get(4..).unwrap_or("").trim(),
            };
        }
        return Some(inner);
    }

    // Bare JSON object or array.
    if trimmed.starts_with('{') || trimmed.starts_with('[') {
        return Some(trimmed);
    }

    None
}
306
307#[async_trait]
308impl DialogueHandler for LlmHandler {
309    async fn on_start(&mut self) -> Result<Vec<Command>> {
310        if let Some(greeting) = &self.config.greeting {
311            self.is_speaking = true;
312            return Ok(vec![self.create_tts_command(greeting.clone(), None)]);
313        }
314
315        self.generate_response().await
316    }
317
318    async fn on_event(&mut self, event: &SessionEvent) -> Result<Vec<Command>> {
319        match event {
320            SessionEvent::AsrFinal { text, .. } => {
321                if text.trim().is_empty() {
322                    return Ok(vec![]);
323                }
324
325                self.history.push(ChatMessage {
326                    role: "user".to_string(),
327                    content: text.clone(),
328                });
329
330                self.generate_response().await
331            }
332
333            SessionEvent::AsrDelta { .. } | SessionEvent::Speaking { .. } => {
334                if self.is_speaking {
335                    info!("Interruption detected, stopping playback");
336                    self.is_speaking = false;
337                    return Ok(vec![Command::Interrupt {
338                        graceful: Some(true),
339                    }]);
340                }
341                Ok(vec![])
342            }
343
344            SessionEvent::Silence { .. } => {
345                info!("Silence timeout detected, triggering follow-up");
346                self.generate_response().await
347            }
348
349            SessionEvent::TrackEnd { .. } => {
350                self.is_speaking = false;
351                Ok(vec![])
352            }
353
354            _ => Ok(vec![]),
355        }
356    }
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::{Result, anyhow};
    use async_trait::async_trait;
    use std::collections::VecDeque;
    use std::sync::Mutex;
    use voice_engine::event::SessionEvent;

    /// Scripted `LlmProvider` that replays a fixed queue of canned replies.
    struct TestProvider {
        responses: Mutex<VecDeque<String>>,
    }

    impl TestProvider {
        fn new(responses: Vec<String>) -> Self {
            Self {
                responses: Mutex::new(VecDeque::from(responses)),
            }
        }
    }

    #[async_trait]
    impl LlmProvider for TestProvider {
        async fn call(&self, _config: &LlmConfig, _history: &[ChatMessage]) -> Result<String> {
            let mut guard = self.responses.lock().unwrap();
            // Erroring when the script is exhausted makes a test fail loudly
            // if the handler queries the LLM more often than expected.
            guard
                .pop_front()
                .ok_or_else(|| anyhow!("Test provider ran out of responses"))
        }
    }

    /// `RagRetriever` stub that records every query it receives and returns
    /// a deterministic snippet.
    struct RecordingRag {
        queries: Mutex<Vec<String>>,
    }

    impl RecordingRag {
        fn new() -> Self {
            Self {
                queries: Mutex::new(Vec::new()),
            }
        }

        // Snapshot of all queries retrieved so far, in call order.
        fn recorded_queries(&self) -> Vec<String> {
            self.queries.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl RagRetriever for RecordingRag {
        async fn retrieve(&self, query: &str) -> Result<String> {
            self.queries.lock().unwrap().push(query.to_string());
            Ok(format!("retrieved {}", query))
        }
    }

    /// A structured reply with text + hangup + refer tools must yield a TTS
    /// command (with the reply's timeout) followed by both tool commands.
    #[tokio::test]
    async fn handler_applies_tool_instructions() -> Result<()> {
        let response = r#"{
            "text": "Goodbye",
            "waitInputTimeout": 15000,
            "tools": [
                {"name": "hangup", "reason": "done", "initiator": "agent"},
                {"name": "refer", "caller": "sip:bot", "callee": "sip:lead"}
            ]
        }"#;

        let provider = Box::new(TestProvider::new(vec![response.to_string()]));
        let mut handler =
            LlmHandler::with_provider(LlmConfig::default(), provider, Arc::new(NoopRagRetriever));

        let event = SessionEvent::AsrFinal {
            track_id: "track-1".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "hello".to_string(),
        };

        let commands = handler.on_event(&event).await?;
        // TTS comes first so the goodbye is spoken before hanging up.
        assert!(matches!(
            commands.get(0),
            Some(Command::Tts {
                text,
                wait_input_timeout: Some(15000),
                ..
            }) if text == "Goodbye"
        ));
        assert!(commands.iter().any(|cmd| matches!(
            cmd,
            Command::Hangup {
                reason: Some(reason),
                initiator: Some(origin),
            } if reason == "done" && origin == "agent"
        )));
        assert!(commands.iter().any(|cmd| matches!(
            cmd,
            Command::Refer {
                caller,
                callee,
                ..
            } if caller == "sip:bot" && callee == "sip:lead"
        )));

        Ok(())
    }

    /// A "rag" tool call must trigger a retrieval and a second LLM query;
    /// the second reply becomes the spoken text with the default timeout.
    #[tokio::test]
    async fn handler_requeries_after_rag() -> Result<()> {
        let rag_instruction = r#"{"tools": [{"name": "rag", "query": "policy"}]}"#;
        let provider = Box::new(TestProvider::new(vec![
            rag_instruction.to_string(),
            "Final answer".to_string(),
        ]));
        let rag = Arc::new(RecordingRag::new());
        let mut handler = LlmHandler::with_provider(LlmConfig::default(), provider, rag.clone());

        let event = SessionEvent::AsrFinal {
            track_id: "track-2".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "reep".to_string(),
        };

        let commands = handler.on_event(&event).await?;
        assert!(matches!(
            commands.get(0),
            Some(Command::Tts {
                text,
                wait_input_timeout: Some(timeout),
                ..
            }) if text == "Final answer" && *timeout == 10000
        ));
        assert_eq!(rag.recorded_queries(), vec!["policy".to_string()]);

        Ok(())
    }
}