Skip to main content

brainos_cortex/llm/
openai.rs

1use std::pin::Pin;
2
3use futures::Stream;
4use serde::{Deserialize, Serialize};
5
6use super::{LlmError, LlmProvider, Message, Response, ResponseChunk, Role, Usage};
7
8#[derive(Serialize)]
9struct OpenAiRequest {
10    model: String,
11    messages: Vec<OpenAiMessage>,
12    temperature: f64,
13    max_tokens: Option<i32>,
14    stream: bool,
15}
16
17#[derive(Serialize, Deserialize)]
18struct OpenAiMessage {
19    role: String,
20    content: String,
21}
22
23#[derive(Deserialize)]
24struct OpenAiResponse {
25    choices: Vec<OpenAiChoice>,
26    usage: Option<OpenAiUsage>,
27}
28
29#[derive(Deserialize)]
30struct OpenAiChoice {
31    message: OpenAiMessage,
32    #[allow(dead_code)]
33    finish_reason: Option<String>,
34}
35
36#[derive(Deserialize)]
37struct OpenAiStreamResponse {
38    choices: Vec<OpenAiStreamChoice>,
39}
40
41#[derive(Deserialize)]
42struct OpenAiStreamChoice {
43    delta: OpenAiDelta,
44    finish_reason: Option<String>,
45}
46
47#[derive(Deserialize)]
48struct OpenAiDelta {
49    #[serde(default)]
50    content: Option<String>,
51}
52
53#[derive(Deserialize)]
54struct OpenAiUsage {
55    prompt_tokens: u32,
56    completion_tokens: u32,
57    total_tokens: u32,
58}
59
60/// OpenAI-compatible provider (works with OpenAI, OpenRouter, etc.)
61pub struct OpenAiProvider {
62    client: reqwest::Client,
63    base_url: String,
64    api_key: Option<String>,
65    model: String,
66    temperature: f64,
67    max_tokens: Option<i32>,
68}
69
70impl OpenAiProvider {
71    pub fn new(
72        base_url: &str,
73        api_key: Option<&str>,
74        model: &str,
75        temperature: f64,
76        max_tokens: Option<i32>,
77    ) -> Result<Self, LlmError> {
78        let client = reqwest::Client::builder()
79            .timeout(brain_core::timeouts::LLM_GENERATE)
80            .build()
81            .map_err(|e| {
82                LlmError::ProviderUnavailable(format!("Failed to create HTTP client: {e}"))
83            })?;
84
85        Ok(Self {
86            client,
87            base_url: base_url.trim_end_matches('/').to_string(),
88            api_key: api_key.map(|s| s.to_string()),
89            model: model.to_string(),
90            temperature,
91            max_tokens,
92        })
93    }
94
95    pub fn openai(api_key: &str, model: &str) -> Result<Self, LlmError> {
96        Self::new(
97            "https://api.openai.com/v1",
98            Some(api_key),
99            model,
100            0.7,
101            Some(4096),
102        )
103    }
104
105    pub fn openrouter(api_key: &str, model: &str) -> Result<Self, LlmError> {
106        Self::new(
107            "https://openrouter.ai/api/v1",
108            Some(api_key),
109            model,
110            0.7,
111            Some(4096),
112        )
113    }
114
115    fn convert_messages(messages: &[Message]) -> Vec<OpenAiMessage> {
116        messages
117            .iter()
118            .map(|m| OpenAiMessage {
119                role: match m.role {
120                    Role::System => "system".to_string(),
121                    Role::User => "user".to_string(),
122                    Role::Assistant => "assistant".to_string(),
123                },
124                content: m.content.clone(),
125            })
126            .collect()
127    }
128
129    fn build_request(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
130        let mut builder = builder;
131        if let Some(key) = &self.api_key {
132            builder = builder.header("Authorization", format!("Bearer {}", key));
133        }
134        builder
135    }
136}
137
138#[async_trait::async_trait]
139impl LlmProvider for OpenAiProvider {
140    async fn generate(&self, messages: &[Message]) -> Result<Response, LlmError> {
141        let url = format!("{}/chat/completions", self.base_url);
142        let request = OpenAiRequest {
143            model: self.model.clone(),
144            messages: Self::convert_messages(messages),
145            temperature: self.temperature,
146            max_tokens: self.max_tokens,
147            stream: false,
148        };
149
150        let resp = self
151            .build_request(self.client.post(&url))
152            .json(&request)
153            .send()
154            .await?;
155
156        if !resp.status().is_success() {
157            let status = resp.status();
158            let body = resp.text().await.unwrap_or_default();
159            return Err(LlmError::Api {
160                status: status.as_u16(),
161                message: body,
162            });
163        }
164
165        let data: OpenAiResponse = resp.json().await?;
166        let content = data
167            .choices
168            .first()
169            .map(|c| c.message.content.clone())
170            .unwrap_or_default();
171
172        Ok(Response {
173            content,
174            usage: data.usage.map(|u| Usage {
175                prompt_tokens: u.prompt_tokens,
176                completion_tokens: u.completion_tokens,
177                total_tokens: u.total_tokens,
178            }),
179        })
180    }
181
182    async fn generate_stream(
183        &self,
184        messages: &[Message],
185    ) -> Result<Pin<Box<dyn Stream<Item = Result<ResponseChunk, LlmError>> + Send>>, LlmError> {
186        use futures::stream::try_unfold;
187
188        let url = format!("{}/chat/completions", self.base_url);
189        let request = OpenAiRequest {
190            model: self.model.clone(),
191            messages: Self::convert_messages(messages),
192            temperature: self.temperature,
193            max_tokens: self.max_tokens,
194            stream: true,
195        };
196
197        let resp = self
198            .build_request(self.client.post(&url))
199            .json(&request)
200            .send()
201            .await?;
202
203        if !resp.status().is_success() {
204            let status = resp.status();
205            let body = resp.text().await.unwrap_or_default();
206            return Err(LlmError::Api {
207                status: status.as_u16(),
208                message: body,
209            });
210        }
211
212        let byte_stream = resp.bytes_stream();
213        let stream = try_unfold(
214            (Box::pin(byte_stream), String::new()),
215            |(mut byte_stream, mut buf)| async move {
216                use futures::TryStreamExt;
217
218                loop {
219                    if let Some(newline_pos) = buf.find('\n') {
220                        let line: String = buf[..newline_pos].to_string();
221                        buf = buf[newline_pos + 1..].to_string();
222
223                        let line = line.trim();
224                        if line.is_empty() {
225                            continue;
226                        }
227
228                        if let Some(data) = line.strip_prefix("data: ") {
229                            let data = data.trim();
230                            if data == "[DONE]" {
231                                return Ok(None);
232                            }
233
234                            match serde_json::from_str::<OpenAiStreamResponse>(data) {
235                                Ok(resp) => {
236                                    if let Some(choice) = resp.choices.first() {
237                                        let content =
238                                            choice.delta.content.clone().unwrap_or_default();
239                                        let is_done = choice.finish_reason.is_some();
240                                        let chunk = ResponseChunk { content, is_done };
241                                        return Ok(Some((chunk, (byte_stream, buf))));
242                                    }
243                                    continue;
244                                }
245                                Err(e) => {
246                                    return Err(LlmError::InvalidFormat(format!(
247                                        "Failed to parse streaming response: {e}"
248                                    )));
249                                }
250                            }
251                        }
252                        continue;
253                    }
254
255                    match byte_stream.try_next().await {
256                        Ok(Some(bytes)) => {
257                            buf.push_str(&String::from_utf8_lossy(&bytes));
258                        }
259                        Ok(None) => return Ok(None),
260                        Err(e) => return Err(LlmError::Http(e)),
261                    }
262                }
263            },
264        );
265
266        Ok(Box::pin(stream))
267    }
268
269    async fn health_check(&self) -> bool {
270        let url = format!("{}/models", self.base_url);
271        match self.build_request(self.client.get(&url)).send().await {
272            Ok(resp) => resp.status().is_success(),
273            Err(_) => false,
274        }
275    }
276
277    fn name(&self) -> &str {
278        "openai"
279    }
280
281    fn model(&self) -> &str {
282        &self.model
283    }
284
285    async fn list_models(&self) -> Result<Vec<String>, LlmError> {
286        #[derive(Deserialize)]
287        struct ModelEntry {
288            id: String,
289        }
290        #[derive(Deserialize)]
291        struct Models {
292            data: Vec<ModelEntry>,
293        }
294
295        let url = format!("{}/models", self.base_url);
296        let resp = self.build_request(self.client.get(&url)).send().await?;
297        if !resp.status().is_success() {
298            let status = resp.status();
299            let body = resp.text().await.unwrap_or_default();
300            return Err(LlmError::Api {
301                status: status.as_u16(),
302                message: body,
303            });
304        }
305        let data: Models = resp.json().await?;
306        Ok(data.data.into_iter().map(|m| m.id).collect())
307    }
308}