1use crate::call::Command;
2use anyhow::{Result, anyhow};
3use async_trait::async_trait;
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6use serde_json::json;
7use std::sync::Arc;
8use tracing::{info, warn};
9use voice_engine::ReferOption;
10use voice_engine::event::SessionEvent;
11
12use super::LlmConfig;
13use super::dialogue::DialogueHandler;
14
/// One entry of the chat transcript exchanged with the LLM
/// (OpenAI-style `role` + `content` message).
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct ChatMessage {
    // Speaker role; this file uses "system", "user" and "assistant".
    pub role: String,
    // Message text.
    pub content: String,
}
20
21const MAX_RAG_ATTEMPTS: usize = 3;
22
/// Abstraction over the chat-completion backend, allowing tests to inject
/// a scripted provider in place of the real HTTP client.
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Produce the assistant's next reply given the configuration and the
    /// full conversation history.
    async fn call(&self, config: &LlmConfig, history: &[ChatMessage]) -> Result<String>;
}
27
/// Production [`LlmProvider`] that posts to an OpenAI-compatible
/// `/chat/completions` endpoint.
struct DefaultLlmProvider {
    // HTTP client reused across requests.
    client: Client,
}
31
32impl DefaultLlmProvider {
33 fn new() -> Self {
34 Self {
35 client: Client::new(),
36 }
37 }
38}
39
40#[async_trait]
41impl LlmProvider for DefaultLlmProvider {
42 async fn call(&self, config: &LlmConfig, history: &[ChatMessage]) -> Result<String> {
43 let mut url = config
44 .base_url
45 .clone()
46 .unwrap_or_else(|| "https://api.openai.com/v1/chat/completions".to_string());
47 let model = config
48 .model
49 .clone()
50 .unwrap_or_else(|| "gpt-3.5-turbo".to_string());
51 let api_key = config.api_key.clone().unwrap_or_default();
52
53 if !url.ends_with("/chat/completions") {
54 url = format!("{}/chat/completions", url.trim_end_matches('/'));
55 }
56
57 let body = json!({
58 "model": model,
59 "messages": history,
60 });
61
62 let res = self
63 .client
64 .post(&url)
65 .header("Authorization", format!("Bearer {}", api_key))
66 .json(&body)
67 .send()
68 .await?;
69
70 if !res.status().is_success() {
71 return Err(anyhow!("LLM request failed: {}", res.status()));
72 }
73
74 let json: serde_json::Value = res.json().await?;
75 let content = json["choices"][0]["message"]["content"]
76 .as_str()
77 .ok_or_else(|| anyhow!("Invalid LLM response"))?
78 .to_string();
79
80 Ok(content)
81 }
82}
83
/// Knowledge-base lookup used when the model emits a `rag` tool call.
#[async_trait]
pub trait RagRetriever: Send + Sync {
    /// Fetch context for `query`; the result is fed back into the chat
    /// history before the LLM is re-queried.
    async fn retrieve(&self, query: &str) -> Result<String>;
}
88
89struct NoopRagRetriever;
90
91#[async_trait]
92impl RagRetriever for NoopRagRetriever {
93 async fn retrieve(&self, _query: &str) -> Result<String> {
94 Ok(String::new())
95 }
96}
97
/// Optional structured JSON payload the LLM may return instead of plain
/// text (camelCase keys; see the tests in this file for examples).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "camelCase")]
struct StructuredResponse {
    // Text to speak; when absent, the raw reply is used as fallback.
    text: Option<String>,
    // Per-response override for the TTS wait-input timeout.
    wait_input_timeout: Option<u32>,
    // Tool invocations requested by the model.
    tools: Option<Vec<ToolInvocation>>,
}
105
/// Tool calls the model can request, discriminated by the lowercase
/// `"name"` field of each JSON object (internally tagged enum).
#[derive(Debug, Deserialize)]
#[serde(tag = "name", rename_all = "lowercase")]
enum ToolInvocation {
    /// End the call.
    #[serde(rename_all = "camelCase")]
    Hangup {
        reason: Option<String>,
        initiator: Option<String>,
    },
    /// Transfer the call to another party.
    #[serde(rename_all = "camelCase")]
    Refer {
        caller: String,
        callee: String,
        options: Option<ReferOption>,
    },
    /// Retrieve knowledge-base context, then re-query the LLM.
    #[serde(rename_all = "camelCase")]
    Rag {
        query: String,
        source: Option<String>,
    },
}
126
/// Dialogue handler that drives a call with an LLM: keeps the chat
/// history, executes tool calls, and emits TTS/hangup/refer commands.
pub struct LlmHandler {
    // LLM configuration (prompt, greeting, endpoint, credentials).
    config: LlmConfig,
    // Full conversation transcript sent on every LLM call.
    history: Vec<ChatMessage>,
    // Chat-completion backend.
    provider: Box<dyn LlmProvider>,
    // Knowledge-base lookup for `rag` tool calls.
    rag_retriever: Arc<dyn RagRetriever>,
    // True while TTS playback is (believed to be) in progress; used to
    // decide whether user speech should interrupt playback.
    is_speaking: bool,
}
134
impl LlmHandler {
    /// Build a handler with the default OpenAI-compatible HTTP provider
    /// and a no-op RAG retriever.
    pub fn new(config: LlmConfig) -> Self {
        Self::with_provider(
            config,
            Box::new(DefaultLlmProvider::new()),
            Arc::new(NoopRagRetriever),
        )
    }

    /// Build a handler with explicit provider and retriever (tests use
    /// this to inject fakes). Seeds the history with the configured
    /// system prompt, when one is set.
    pub fn with_provider(
        config: LlmConfig,
        provider: Box<dyn LlmProvider>,
        rag_retriever: Arc<dyn RagRetriever>,
    ) -> Self {
        let mut history = Vec::new();
        if let Some(prompt) = &config.prompt {
            history.push(ChatMessage {
                role: "system".to_string(),
                content: prompt.clone(),
            });
        }

        Self {
            config,
            history,
            provider,
            rag_retriever,
            is_speaking: false,
        }
    }

    /// Ask the provider for a completion over the current history.
    async fn call_llm(&self) -> Result<String> {
        self.provider.call(&self.config, &self.history).await
    }

    /// Wrap `text` in a TTS command. A missing `wait_input_timeout`
    /// defaults to 10000 (presumably milliseconds, matching the 15000
    /// used in tests — confirm against `Command::Tts` consumers).
    fn create_tts_command(&self, text: String, wait_input_timeout: Option<u32>) -> Command {
        let timeout = wait_input_timeout.unwrap_or(10000);
        Command::Tts {
            text,
            speaker: None,
            play_id: None,
            auto_hangup: None,
            streaming: None,
            end_of_stream: None,
            option: None,
            wait_input_timeout: Some(timeout),
            base64: None,
        }
    }

    /// Query the LLM once and turn its reply into session commands.
    async fn generate_response(&mut self) -> Result<Vec<Command>> {
        let initial = self.call_llm().await?;
        self.interpret_response(initial).await
    }

    /// Interpret a raw LLM reply.
    ///
    /// If the reply parses as a [`StructuredResponse`], its tool calls
    /// are executed: hangup/refer become commands appended after the
    /// TTS command; a `rag` call pushes the retrieved context into the
    /// history as a system message and re-queries the LLM, bounded by
    /// [`MAX_RAG_ATTEMPTS`] total iterations. A non-JSON reply is
    /// spoken verbatim.
    async fn interpret_response(&mut self, initial: String) -> Result<Vec<Command>> {
        let mut tool_commands = Vec::new();
        let mut wait_input_timeout = None;
        let mut attempts = 0;
        let final_text: Option<String>;
        let mut raw = initial;

        loop {
            attempts += 1;
            let mut rerun_for_rag = false;

            if let Some(structured) = parse_structured_response(&raw) {
                // First structured reply to specify a timeout wins; later
                // RAG iterations do not overwrite it.
                if wait_input_timeout.is_none() {
                    wait_input_timeout = structured.wait_input_timeout;
                }

                if let Some(tools) = structured.tools {
                    for tool in tools {
                        match tool {
                            ToolInvocation::Hangup { reason, initiator } => {
                                tool_commands.push(Command::Hangup { reason, initiator });
                            }
                            ToolInvocation::Refer {
                                caller,
                                callee,
                                options,
                            } => {
                                tool_commands.push(Command::Refer {
                                    caller,
                                    callee,
                                    options,
                                });
                            }
                            ToolInvocation::Rag { query, source } => {
                                // Retrieve context, record it as a system
                                // message, and flag a re-query.
                                let rag_result = self.rag_retriever.retrieve(&query).await?;
                                let summary = if let Some(source) = source {
                                    format!("[{}] {}", source, rag_result)
                                } else {
                                    rag_result
                                };
                                self.history.push(ChatMessage {
                                    role: "system".to_string(),
                                    content: format!("RAG result for {}: {}", query, summary),
                                });
                                rerun_for_rag = true;
                            }
                        }
                    }
                }

                if rerun_for_rag {
                    if attempts >= MAX_RAG_ATTEMPTS {
                        // Give up re-querying; fall back to whatever text
                        // (or raw reply) the last response carried.
                        warn!("Reached RAG iteration limit, using last response");
                        final_text = structured.text.or_else(|| Some(raw.clone()));
                        break;
                    }
                    raw = self.call_llm().await?;
                    continue;
                }

                final_text = Some(structured.text.unwrap_or_else(|| raw.clone()));
                break;
            }

            // Not structured JSON: speak the raw reply as-is.
            final_text = Some(raw.clone());
            break;
        }

        let mut commands = Vec::new();
        if let Some(text) = final_text {
            if !text.trim().is_empty() {
                // Record what we are about to say so later turns see it.
                self.history.push(ChatMessage {
                    role: "assistant".to_string(),
                    content: text.clone(),
                });
                self.is_speaking = true;
                commands.push(self.create_tts_command(text, wait_input_timeout));
            }
        }

        // Tool commands (hangup/refer) come after the spoken text.
        commands.extend(tool_commands);

        Ok(commands)
    }
}
275
276fn parse_structured_response(raw: &str) -> Option<StructuredResponse> {
277 let payload = extract_json_block(raw)?;
278 serde_json::from_str(payload).ok()
279}
280
/// Extract a JSON payload from a raw LLM reply.
///
/// Accepts either a bare JSON object/array, or JSON wrapped in a
/// Markdown ``` fence with an optional `json` language tag. Returns
/// `None` when no JSON-looking payload is present.
///
/// Fix: the previous version entered the fenced branch on a *single*
/// leading backtick but then sliced three bytes off the front, which
/// mis-parsed single-backtick text containing a later fence and could
/// panic on a non-UTF-8 char boundary (e.g. a multibyte character
/// right after one backtick). Only a literal "```" opener is now
/// treated as a fence, and all slicing stays on known boundaries.
fn extract_json_block(raw: &str) -> Option<&str> {
    let trimmed = raw.trim();

    // Fenced block: require a real "```" opener, not just one backtick.
    if let Some(rest) = trimmed.strip_prefix("```") {
        // Locate the closing fence; a lone opener has no payload.
        let end = rest.rfind("```")?;
        if end == 0 {
            return None;
        }
        let mut inner = rest[..end].trim();

        // Strip an optional "json" language tag after the opening fence.
        // `get(..4)` is boundary-safe; the tag itself is ASCII.
        if inner
            .get(..4)
            .map_or(false, |tag| tag.eq_ignore_ascii_case("json"))
        {
            if let Some(newline) = inner.find('\n') {
                // Tag on its own line: drop the whole first line.
                inner = inner[newline + 1..].trim();
            } else if inner.len() > 4 {
                // Tag glued to the payload: drop just the 4 tag bytes.
                inner = inner[4..].trim();
            }
        }
        return Some(inner);
    }

    // Bare JSON object or array.
    if trimmed.starts_with('{') || trimmed.starts_with('[') {
        return Some(trimmed);
    }
    None
}
306
307#[async_trait]
308impl DialogueHandler for LlmHandler {
309 async fn on_start(&mut self) -> Result<Vec<Command>> {
310 if let Some(greeting) = &self.config.greeting {
311 self.is_speaking = true;
312 return Ok(vec![self.create_tts_command(greeting.clone(), None)]);
313 }
314
315 self.generate_response().await
316 }
317
318 async fn on_event(&mut self, event: &SessionEvent) -> Result<Vec<Command>> {
319 match event {
320 SessionEvent::AsrFinal { text, .. } => {
321 if text.trim().is_empty() {
322 return Ok(vec![]);
323 }
324
325 self.history.push(ChatMessage {
326 role: "user".to_string(),
327 content: text.clone(),
328 });
329
330 self.generate_response().await
331 }
332
333 SessionEvent::AsrDelta { .. } | SessionEvent::Speaking { .. } => {
334 if self.is_speaking {
335 info!("Interruption detected, stopping playback");
336 self.is_speaking = false;
337 return Ok(vec![Command::Interrupt {
338 graceful: Some(true),
339 }]);
340 }
341 Ok(vec![])
342 }
343
344 SessionEvent::Silence { .. } => {
345 info!("Silence timeout detected, triggering follow-up");
346 self.generate_response().await
347 }
348
349 SessionEvent::TrackEnd { .. } => {
350 self.is_speaking = false;
351 Ok(vec![])
352 }
353
354 _ => Ok(vec![]),
355 }
356 }
357}
358
#[cfg(test)]
mod tests {
    use super::*;
    use anyhow::{Result, anyhow};
    use async_trait::async_trait;
    use std::collections::VecDeque;
    use std::sync::Mutex;
    use voice_engine::event::SessionEvent;

    /// Scripted [`LlmProvider`] that replays canned responses in order.
    struct TestProvider {
        // FIFO of responses; errors when exhausted.
        responses: Mutex<VecDeque<String>>,
    }

    impl TestProvider {
        fn new(responses: Vec<String>) -> Self {
            Self {
                responses: Mutex::new(VecDeque::from(responses)),
            }
        }
    }

    #[async_trait]
    impl LlmProvider for TestProvider {
        async fn call(&self, _config: &LlmConfig, _history: &[ChatMessage]) -> Result<String> {
            let mut guard = self.responses.lock().unwrap();
            guard
                .pop_front()
                .ok_or_else(|| anyhow!("Test provider ran out of responses"))
        }
    }

    /// [`RagRetriever`] that records every query and echoes it back.
    struct RecordingRag {
        queries: Mutex<Vec<String>>,
    }

    impl RecordingRag {
        fn new() -> Self {
            Self {
                queries: Mutex::new(Vec::new()),
            }
        }

        /// Snapshot of all queries seen so far.
        fn recorded_queries(&self) -> Vec<String> {
            self.queries.lock().unwrap().clone()
        }
    }

    #[async_trait]
    impl RagRetriever for RecordingRag {
        async fn retrieve(&self, query: &str) -> Result<String> {
            self.queries.lock().unwrap().push(query.to_string());
            Ok(format!("retrieved {}", query))
        }
    }

    /// A structured reply with text, a timeout override, and hangup +
    /// refer tools must yield a TTS command followed by both tool
    /// commands.
    #[tokio::test]
    async fn handler_applies_tool_instructions() -> Result<()> {
        let response = r#"{
            "text": "Goodbye",
            "waitInputTimeout": 15000,
            "tools": [
                {"name": "hangup", "reason": "done", "initiator": "agent"},
                {"name": "refer", "caller": "sip:bot", "callee": "sip:lead"}
            ]
        }"#;

        let provider = Box::new(TestProvider::new(vec![response.to_string()]));
        let mut handler =
            LlmHandler::with_provider(LlmConfig::default(), provider, Arc::new(NoopRagRetriever));

        let event = SessionEvent::AsrFinal {
            track_id: "track-1".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "hello".to_string(),
        };

        let commands = handler.on_event(&event).await?;
        // TTS first, carrying the structured text and timeout override.
        assert!(matches!(
            commands.get(0),
            Some(Command::Tts {
                text,
                wait_input_timeout: Some(15000),
                ..
            }) if text == "Goodbye"
        ));
        assert!(commands.iter().any(|cmd| matches!(
            cmd,
            Command::Hangup {
                reason: Some(reason),
                initiator: Some(origin),
            } if reason == "done" && origin == "agent"
        )));
        assert!(commands.iter().any(|cmd| matches!(
            cmd,
            Command::Refer {
                caller,
                callee,
                ..
            } if caller == "sip:bot" && callee == "sip:lead"
        )));

        Ok(())
    }

    /// A `rag` tool call must trigger one retrieval and a second LLM
    /// round-trip whose plain-text reply becomes the TTS command (with
    /// the default 10000 timeout).
    #[tokio::test]
    async fn handler_requeries_after_rag() -> Result<()> {
        let rag_instruction = r#"{"tools": [{"name": "rag", "query": "policy"}]}"#;
        let provider = Box::new(TestProvider::new(vec![
            rag_instruction.to_string(),
            "Final answer".to_string(),
        ]));
        let rag = Arc::new(RecordingRag::new());
        let mut handler = LlmHandler::with_provider(LlmConfig::default(), provider, rag.clone());

        let event = SessionEvent::AsrFinal {
            track_id: "track-2".to_string(),
            timestamp: 0,
            index: 0,
            start_time: None,
            end_time: None,
            text: "reep".to_string(),
        };

        let commands = handler.on_event(&event).await?;
        assert!(matches!(
            commands.get(0),
            Some(Command::Tts {
                text,
                wait_input_timeout: Some(timeout),
                ..
            }) if text == "Final answer" && *timeout == 10000
        ));
        assert_eq!(rag.recorded_queries(), vec!["policy".to_string()]);

        Ok(())
    }
}