active_call/playbook/handler/provider.rs

use anyhow::{Result, anyhow};
use async_trait::async_trait;
use futures::Stream;
use reqwest::Client;
use serde_json::json;
use std::pin::Pin;

use super::super::{LlmConfig, ChatMessage};
use super::types::ToolInvocation;

/// An incremental event from a streaming chat-completions call.
#[derive(Debug, Clone)]
pub enum LlmStreamEvent {
    /// A delta of the assistant's visible reply text.
    Content(String),
    /// A delta from the model's reasoning channel, for backends that
    /// expose one (surfaced as `reasoning_content` in the SSE payload).
    Reasoning(String),
}

/// A text-in, text-out LLM backend.
#[async_trait]
pub trait LlmProvider: Send + Sync {
    /// Sends the full conversation history and returns the complete reply.
    async fn call(&self, config: &LlmConfig, history: &[ChatMessage]) -> Result<String>;

    /// Sends the full conversation history and returns a stream of
    /// incremental [`LlmStreamEvent`]s.
    async fn call_stream(
        &self,
        config: &LlmConfig,
        history: &[ChatMessage],
    ) -> Result<Pin<Box<dyn Stream<Item = Result<LlmStreamEvent>> + Send>>>;
}

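// A minimal consumption sketch (illustrative only, not part of this module's
// API): assuming a `provider: &dyn LlmProvider` plus a `config` and `history`
// in scope, the stream is drained with `futures::StreamExt::next`:
//
//     use futures::StreamExt;
//
//     let mut events = provider.call_stream(&config, &history).await?;
//     while let Some(event) = events.next().await {
//         match event? {
//             LlmStreamEvent::Content(text) => print!("{text}"),
//             LlmStreamEvent::Reasoning(text) => eprintln!("[reasoning] {text}"),
//         }
//     }
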
/// A single event from a realtime (speech-to-speech) session.
pub struct RealtimeResponse {
    /// A chunk of synthesized output audio, if any.
    pub audio_delta: Option<Vec<u8>>,
    /// A chunk of transcript or reply text, if any.
    pub text_delta: Option<String>,
    /// A tool call requested by the model, if any.
    pub function_call: Option<ToolInvocation>,
    /// True when the backend detected the caller starting to speak
    /// (the usual trigger for barge-in handling).
    pub speech_started: bool,
}

/// A bidirectional, audio-native LLM backend.
#[async_trait]
pub trait RealtimeProvider: Send + Sync {
    /// Opens the realtime session described by `config`.
    async fn connect(&self, config: &LlmConfig) -> Result<()>;

    /// Pushes a frame of 16-bit PCM caller audio into the session.
    async fn send_audio(&self, audio: &[i16]) -> Result<()>;

    /// Returns the stream of events produced by the session.
    async fn subscribe(
        &self,
    ) -> Result<Pin<Box<dyn Stream<Item = Result<RealtimeResponse>> + Send>>>;
}

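// A sketch of a realtime read loop (illustrative; the surrounding handler
// wiring is assumed, not defined here): each `RealtimeResponse` fans out to
// the audio path, the transcript, and the tool dispatcher.
//
//     use futures::StreamExt;
//
//     let mut events = realtime.subscribe().await?;
//     while let Some(event) = events.next().await {
//         let event = event?;
//         if event.speech_started { /* stop any in-progress playback */ }
//         if let Some(pcm) = event.audio_delta { /* enqueue for playback */ }
//         if let Some(text) = event.text_delta { /* append to transcript */ }
//         if let Some(call) = event.function_call { /* dispatch the tool */ }
//     }
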
/// An [`LlmProvider`] backed by any OpenAI-compatible chat-completions API.
pub struct DefaultLlmProvider {
    client: Client,
}

impl DefaultLlmProvider {
    pub fn new() -> Self {
        Self {
            client: Client::new(),
        }
    }
}

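// `reqwest::Client` keeps an internal connection pool, so a single
// `DefaultLlmProvider` is intended to be built once and shared across calls:
//
//     let provider = DefaultLlmProvider::new();
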
#[async_trait]
impl LlmProvider for DefaultLlmProvider {
    async fn call(&self, config: &LlmConfig, history: &[ChatMessage]) -> Result<String> {
        // Fall back to OpenAI defaults for any unset config fields.
        let mut url = config
            .base_url
            .clone()
            .unwrap_or_else(|| "https://api.openai.com/v1".to_string());
        let model = config
            .model
            .clone()
            .unwrap_or_else(|| "gpt-4-turbo".to_string());
        let api_key = config.api_key.clone().unwrap_or_default();

        // Accept either a bare base URL or a full chat-completions endpoint.
        if !url.ends_with("/chat/completions") {
            url = format!("{}/chat/completions", url.trim_end_matches('/'));
        }

        let body = json!({
            "model": model,
            "messages": history,
        });

        let res = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", api_key))
            .json(&body)
            .send()
            .await?;

        if !res.status().is_success() {
            return Err(anyhow!("LLM request failed: {}", res.status()));
        }

        // Only `choices[0].message.content` is read from the response body.
        let json: serde_json::Value = res.json().await?;
        let content = json["choices"][0]["message"]["content"]
            .as_str()
            .ok_or_else(|| anyhow!("Invalid LLM response"))?
            .to_string();

        Ok(content)
    }

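    // For reference, the standard chat-completions response shape read above
    // (fields not shown are ignored):
    //
    //     {
    //       "choices": [
    //         { "message": { "role": "assistant", "content": "..." } }
    //       ]
    //     }
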
    async fn call_stream(
        &self,
        config: &LlmConfig,
        history: &[ChatMessage],
    ) -> Result<Pin<Box<dyn Stream<Item = Result<LlmStreamEvent>> + Send>>> {
        // Fall back to OpenAI defaults for any unset config fields.
        let mut url = config
            .base_url
            .clone()
            .unwrap_or_else(|| "https://api.openai.com/v1".to_string());
        let model = config
            .model
            .clone()
            .unwrap_or_else(|| "gpt-4-turbo".to_string());
        let api_key = config.api_key.clone().unwrap_or_default();

        // Accept either a bare base URL or a full chat-completions endpoint.
        if !url.ends_with("/chat/completions") {
            url = format!("{}/chat/completions", url.trim_end_matches('/'));
        }

        let body = json!({
            "model": model,
            "messages": history,
            "stream": true,
        });

        let res = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", api_key))
            .json(&body)
            .send()
            .await?;

        if !res.status().is_success() {
            return Err(anyhow!("LLM request failed: {}", res.status()));
        }

        // Re-chunk the byte stream into SSE lines of the form `data: {json}`,
        // terminated by a literal `data: [DONE]`.
        let stream = res.bytes_stream();
        let s = async_stream::stream! {
            let mut buffer = String::new();
            for await chunk in stream {
                match chunk {
                    Ok(bytes) => {
                        buffer.push_str(&String::from_utf8_lossy(&bytes));

                        // Process every complete line; a trailing partial line
                        // stays buffered until the next chunk arrives.
                        while let Some(line_end) = buffer.find('\n') {
                            let line: String = buffer.drain(..=line_end).collect();
                            let line = line.trim();
                            if let Some(data) = line.strip_prefix("data:") {
                                let data = data.trim();
                                if data == "[DONE]" {
                                    // End of stream: stop yielding entirely.
                                    return;
                                }
                                if let Ok(json) = serde_json::from_str::<serde_json::Value>(data) {
                                    if let Some(delta) = json["choices"][0].get("delta") {
                                        if let Some(thinking) = delta.get("reasoning_content").and_then(|v| v.as_str()) {
                                            yield Ok(LlmStreamEvent::Reasoning(thinking.to_string()));
                                        }
                                        if let Some(content) = delta.get("content").and_then(|v| v.as_str()) {
                                            yield Ok(LlmStreamEvent::Content(content.to_string()));
                                        }
                                    }
                                }
                            }
                        }
                    }
                    Err(e) => yield Err(anyhow!(e)),
                }
            }
        };

        Ok(Box::pin(s))
    }
}
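
// A test double for [`LlmProvider`] (a sketch): it assumes `tokio` is available
// as a dev-dependency for the test runtime and that `LlmConfig` implements
// `Default`; both are assumptions, not guarantees made by this module.
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{StreamExt, stream};

    /// Replays a canned reply instead of calling a real backend.
    struct CannedProvider {
        reply: String,
    }

    #[async_trait]
    impl LlmProvider for CannedProvider {
        async fn call(&self, _config: &LlmConfig, _history: &[ChatMessage]) -> Result<String> {
            Ok(self.reply.clone())
        }

        async fn call_stream(
            &self,
            _config: &LlmConfig,
            _history: &[ChatMessage],
        ) -> Result<Pin<Box<dyn Stream<Item = Result<LlmStreamEvent>> + Send>>> {
            // Emit the reply as one-word content deltas, mimicking SSE chunking.
            let events: Vec<Result<LlmStreamEvent>> = self
                .reply
                .split_whitespace()
                .map(|w| Ok(LlmStreamEvent::Content(format!("{w} "))))
                .collect();
            Ok(Box::pin(stream::iter(events)))
        }
    }

    #[tokio::test]
    async fn canned_provider_streams_all_content() -> Result<()> {
        let provider = CannedProvider { reply: "hello world".into() };
        // Hypothetical: substitute a real constructor if `LlmConfig` has no `Default`.
        let config = LlmConfig::default();

        let mut events = provider.call_stream(&config, &[]).await?;
        let mut out = String::new();
        while let Some(event) = events.next().await {
            if let LlmStreamEvent::Content(text) = event? {
                out.push_str(&text);
            }
        }
        assert_eq!(out.trim(), "hello world");
        Ok(())
    }
}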