// brainos_cortex/llm/openai.rs

1use std::pin::Pin;
2
3use futures::Stream;
4use serde::{Deserialize, Serialize};
5
6use super::{
7    build_http_client, ensure_ok, LlmError, LlmProvider, Message, ProposedToolCall, Response,
8    ResponseChunk, ToolDef, Usage,
9};
10
11#[derive(Serialize)]
12struct OpenAiRequest {
13    model: String,
14    messages: Vec<OpenAiMessage>,
15    temperature: f64,
16    max_tokens: Option<i32>,
17    stream: bool,
18    /// Advertised tools (OpenAI function-calling shape). Omitted entirely
19    /// from a plain-text request so behaviour is unchanged when no tools
20    /// channel is in play.
21    #[serde(skip_serializing_if = "Option::is_none")]
22    tools: Option<Vec<OpenAiTool>>,
23    /// `"auto"` lets the model answer in plain text or propose a call. We
24    /// never force tool use, so chat stays able to just talk.
25    #[serde(skip_serializing_if = "Option::is_none")]
26    tool_choice: Option<&'static str>,
27}
28
29#[derive(Serialize, Deserialize, Default)]
30struct OpenAiMessage {
31    role: String,
32    /// Optional on the response side: a tool-call turn carries `null`
33    /// content. Always `Some` on the request side.
34    #[serde(default, skip_serializing_if = "Option::is_none")]
35    content: Option<String>,
36    #[serde(default, skip_serializing_if = "Option::is_none")]
37    tool_calls: Option<Vec<OpenAiToolCall>>,
38    /// Set only on a `role:"tool"` result turn — links the result to the
39    /// assistant `tool_calls` entry it answers.
40    #[serde(default, skip_serializing_if = "Option::is_none")]
41    tool_call_id: Option<String>,
42}
43
44/// One advertised tool in the OpenAI request (`{"type":"function", ...}`).
45#[derive(Serialize)]
46struct OpenAiTool {
47    #[serde(rename = "type")]
48    kind: &'static str,
49    function: OpenAiFunctionDef,
50}
51
52#[derive(Serialize)]
53struct OpenAiFunctionDef {
54    name: String,
55    description: String,
56    parameters: serde_json::Value,
57}
58
59/// A tool call in the response. `function.arguments` is a JSON-encoded
60/// string per the OpenAI wire format. The same shape is replayed on the
61/// request side for an assistant tool-call turn, where `type` must be
62/// `"function"`.
63#[derive(Serialize, Deserialize)]
64struct OpenAiToolCall {
65    #[serde(default)]
66    id: Option<String>,
67    #[serde(rename = "type", default = "function_kind")]
68    kind: String,
69    function: OpenAiFunctionCall,
70}
71
/// The tool-call `type` string `"function"`.
///
/// Serves as the serde default for `OpenAiToolCall::kind` and as the value
/// written when a call is rendered back onto the wire.
fn function_kind() -> String {
    String::from("function")
}
75
76#[derive(Serialize, Deserialize)]
77struct OpenAiFunctionCall {
78    name: String,
79    #[serde(default)]
80    arguments: String,
81}
82
83#[derive(Deserialize)]
84struct OpenAiResponse {
85    choices: Vec<OpenAiChoice>,
86    usage: Option<OpenAiUsage>,
87}
88
89#[derive(Deserialize)]
90struct OpenAiChoice {
91    message: OpenAiMessage,
92    #[allow(dead_code)]
93    finish_reason: Option<String>,
94}
95
96#[derive(Deserialize)]
97struct OpenAiStreamResponse {
98    choices: Vec<OpenAiStreamChoice>,
99}
100
101#[derive(Deserialize)]
102struct OpenAiStreamChoice {
103    delta: OpenAiDelta,
104    finish_reason: Option<String>,
105}
106
107#[derive(Deserialize)]
108struct OpenAiDelta {
109    #[serde(default)]
110    content: Option<String>,
111}
112
113#[derive(Deserialize)]
114struct OpenAiUsage {
115    prompt_tokens: u32,
116    completion_tokens: u32,
117    total_tokens: u32,
118}
119
120/// OpenAI-compatible provider (works with OpenAI, OpenRouter, etc.)
121pub struct OpenAiProvider {
122    client: reqwest::Client,
123    base_url: String,
124    api_key: Option<String>,
125    model: String,
126    temperature: f64,
127    max_tokens: Option<i32>,
128}
129
130impl OpenAiProvider {
131    pub fn new(
132        base_url: &str,
133        api_key: Option<&str>,
134        model: &str,
135        temperature: f64,
136        max_tokens: Option<i32>,
137    ) -> Result<Self, LlmError> {
138        let client = build_http_client(brain::timeouts::LLM_GENERATE)?;
139        Ok(Self {
140            client,
141            base_url: base_url.trim_end_matches('/').to_string(),
142            api_key: api_key.map(|s| s.to_string()),
143            model: model.to_string(),
144            temperature,
145            max_tokens,
146        })
147    }
148
149    pub fn openai(api_key: &str, model: &str) -> Result<Self, LlmError> {
150        Self::new(
151            "https://api.openai.com/v1",
152            Some(api_key),
153            model,
154            0.7,
155            Some(4096),
156        )
157    }
158
159    pub fn openrouter(api_key: &str, model: &str) -> Result<Self, LlmError> {
160        Self::new(
161            "https://openrouter.ai/api/v1",
162            Some(api_key),
163            model,
164            0.7,
165            Some(4096),
166        )
167    }
168
169    fn convert_messages(messages: &[Message]) -> Vec<OpenAiMessage> {
170        messages.iter().map(Self::convert_message).collect()
171    }
172
173    /// Translate one kernel [`Message`] into the OpenAI wire shape. An
174    /// assistant turn that proposed tool calls replays them (with `null`
175    /// content when it carried no prose); a [`Role::Tool`] result turn
176    /// carries its `tool_call_id`; every other turn is plain content.
177    fn convert_message(m: &Message) -> OpenAiMessage {
178        let role = m.role.as_wire_str().to_string();
179        if !m.tool_calls.is_empty() {
180            return OpenAiMessage {
181                role,
182                content: (!m.content.is_empty()).then(|| m.content.clone()),
183                tool_calls: Some(m.tool_calls.iter().map(convert_proposed_call).collect()),
184                tool_call_id: None,
185            };
186        }
187        OpenAiMessage {
188            role,
189            content: Some(m.content.clone()),
190            tool_calls: None,
191            tool_call_id: m.tool_call_id.clone(),
192        }
193    }
194
195    /// Translate the kernel's provider-agnostic [`ToolDef`]s into the
196    /// OpenAI function-calling request shape.
197    fn convert_tools(tools: &[ToolDef]) -> Vec<OpenAiTool> {
198        tools
199            .iter()
200            .map(|t| OpenAiTool {
201                kind: "function",
202                function: OpenAiFunctionDef {
203                    name: t.name.clone(),
204                    description: t.description.clone(),
205                    parameters: t.parameters.clone(),
206                },
207            })
208            .collect()
209    }
210
211    /// Map a response message's `tool_calls` into provider-agnostic
212    /// [`ProposedToolCall`]s, parsing each JSON-string argument blob into a
213    /// [`serde_json::Value`] (empty / unparseable args become an empty
214    /// object so the caller never has to re-parse or guard for null).
215    fn extract_tool_calls(message: &OpenAiMessage) -> Vec<ProposedToolCall> {
216        message
217            .tool_calls
218            .iter()
219            .flatten()
220            .map(|tc| ProposedToolCall {
221                id: tc.id.clone(),
222                name: tc.function.name.clone(),
223                arguments: parse_arguments(&tc.function.arguments),
224            })
225            .collect()
226    }
227
228    fn build_request(&self, builder: reqwest::RequestBuilder) -> reqwest::RequestBuilder {
229        let mut builder = builder;
230        if let Some(key) = &self.api_key {
231            builder = builder.header("Authorization", format!("Bearer {}", key));
232        }
233        builder
234    }
235}
236
237#[async_trait::async_trait]
238impl LlmProvider for OpenAiProvider {
239    async fn generate(&self, messages: &[Message]) -> Result<Response, LlmError> {
240        let url = format!("{}/chat/completions", self.base_url);
241        let request = OpenAiRequest {
242            model: self.model.clone(),
243            messages: Self::convert_messages(messages),
244            temperature: self.temperature,
245            max_tokens: self.max_tokens,
246            stream: false,
247            tools: None,
248            tool_choice: None,
249        };
250
251        let resp = self
252            .build_request(self.client.post(&url))
253            .json(&request)
254            .send()
255            .await?;
256        let resp = ensure_ok(resp).await?;
257
258        let data: OpenAiResponse = resp.json().await?;
259        let content = data
260            .choices
261            .first()
262            .and_then(|c| c.message.content.clone())
263            .unwrap_or_default();
264
265        Ok(Response::text(content, convert_usage(data.usage)))
266    }
267
268    async fn generate_with_tools(
269        &self,
270        messages: &[Message],
271        tools: &[ToolDef],
272    ) -> Result<Response, LlmError> {
273        // No tools to advertise → identical to a plain generate.
274        if tools.is_empty() {
275            return self.generate(messages).await;
276        }
277
278        let url = format!("{}/chat/completions", self.base_url);
279        let request = OpenAiRequest {
280            model: self.model.clone(),
281            messages: Self::convert_messages(messages),
282            temperature: self.temperature,
283            max_tokens: self.max_tokens,
284            stream: false,
285            tools: Some(Self::convert_tools(tools)),
286            // Never force a call — the model may answer in plain text.
287            tool_choice: Some("auto"),
288        };
289
290        let resp = self
291            .build_request(self.client.post(&url))
292            .json(&request)
293            .send()
294            .await?;
295        let resp = ensure_ok(resp).await?;
296
297        let data: OpenAiResponse = resp.json().await?;
298        let (content, tool_calls) = match data.choices.first() {
299            Some(choice) => (
300                choice.message.content.clone().unwrap_or_default(),
301                Self::extract_tool_calls(&choice.message),
302            ),
303            None => (String::new(), Vec::new()),
304        };
305
306        Ok(Response {
307            content,
308            usage: convert_usage(data.usage),
309            tool_calls,
310        })
311    }
312
313    async fn generate_stream(
314        &self,
315        messages: &[Message],
316    ) -> Result<Pin<Box<dyn Stream<Item = Result<ResponseChunk, LlmError>> + Send>>, LlmError> {
317        use futures::stream::try_unfold;
318
319        let url = format!("{}/chat/completions", self.base_url);
320        let request = OpenAiRequest {
321            model: self.model.clone(),
322            messages: Self::convert_messages(messages),
323            temperature: self.temperature,
324            max_tokens: self.max_tokens,
325            stream: true,
326            tools: None,
327            tool_choice: None,
328        };
329
330        let resp = self
331            .build_request(self.client.post(&url))
332            .json(&request)
333            .send()
334            .await?;
335        let resp = ensure_ok(resp).await?;
336
337        let byte_stream = resp.bytes_stream();
338        let stream = try_unfold(
339            (Box::pin(byte_stream), String::new()),
340            |(mut byte_stream, mut buf)| async move {
341                use futures::TryStreamExt;
342
343                loop {
344                    if let Some(newline_pos) = buf.find('\n') {
345                        let line: String = buf[..newline_pos].to_string();
346                        buf = buf[newline_pos + 1..].to_string();
347
348                        let line = line.trim();
349                        if line.is_empty() {
350                            continue;
351                        }
352
353                        if let Some(data) = line.strip_prefix("data: ") {
354                            let data = data.trim();
355                            if data == "[DONE]" {
356                                return Ok(None);
357                            }
358
359                            match serde_json::from_str::<OpenAiStreamResponse>(data) {
360                                Ok(resp) => {
361                                    if let Some(choice) = resp.choices.first() {
362                                        let content =
363                                            choice.delta.content.clone().unwrap_or_default();
364                                        let is_done = choice.finish_reason.is_some();
365                                        let chunk = ResponseChunk { content, is_done };
366                                        return Ok(Some((chunk, (byte_stream, buf))));
367                                    }
368                                    continue;
369                                }
370                                Err(e) => {
371                                    return Err(LlmError::InvalidFormat(format!(
372                                        "Failed to parse streaming response: {e}"
373                                    )));
374                                }
375                            }
376                        }
377                        continue;
378                    }
379
380                    match byte_stream.try_next().await {
381                        Ok(Some(bytes)) => {
382                            buf.push_str(&String::from_utf8_lossy(&bytes));
383                        }
384                        Ok(None) => return Ok(None),
385                        Err(e) => return Err(LlmError::Http(e)),
386                    }
387                }
388            },
389        );
390
391        Ok(Box::pin(stream))
392    }
393
394    async fn health_check(&self) -> bool {
395        let url = format!("{}/models", self.base_url);
396        match self.build_request(self.client.get(&url)).send().await {
397            Ok(resp) => resp.status().is_success(),
398            Err(_) => false,
399        }
400    }
401
402    fn name(&self) -> &str {
403        "openai"
404    }
405
406    fn model(&self) -> &str {
407        &self.model
408    }
409
410    async fn list_models(&self) -> Result<Vec<String>, LlmError> {
411        #[derive(Deserialize)]
412        struct ModelEntry {
413            id: String,
414        }
415        #[derive(Deserialize)]
416        struct Models {
417            data: Vec<ModelEntry>,
418        }
419
420        let url = format!("{}/models", self.base_url);
421        let resp = self.build_request(self.client.get(&url)).send().await?;
422        let resp = ensure_ok(resp).await?;
423        let data: Models = resp.json().await?;
424        Ok(data.data.into_iter().map(|m| m.id).collect())
425    }
426
427    async fn fetch_context_window(&self) -> Option<usize> {
428        // 1. API-based detection: some providers (OpenRouter) advertise
429        //    `context_length` per model in their /models response.
430        #[derive(Deserialize)]
431        struct ModelDetail {
432            id: String,
433            #[serde(default)]
434            context_length: Option<usize>,
435        }
436        #[derive(Deserialize)]
437        struct ModelsResponse {
438            data: Vec<ModelDetail>,
439        }
440
441        let from_api = (async {
442            let url = format!("{}/models", self.base_url);
443            let resp = self
444                .build_request(self.client.get(&url))
445                .send()
446                .await
447                .ok()?;
448            let resp = ensure_ok(resp).await.ok()?;
449            let data: ModelsResponse = resp.json().await.ok()?;
450            let active = self.model();
451            // Exact match first.
452            for model in &data.data {
453                if model.id == active {
454                    return model.context_length;
455                }
456            }
457            // Prefix match for OpenRouter model IDs like "openai/gpt-4o"
458            // where the config stores just "gpt-4o".
459            for model in &data.data {
460                if model.id.ends_with(active) || model.id.contains(active) {
461                    return model.context_length;
462                }
463            }
464            None
465        })
466        .await;
467        if from_api.is_some() {
468            return from_api;
469        }
470
471        // 2. Model-name heuristics (covers OpenAI, Groq, DeepSeek, etc.).
472        super::known_context_window(self.model())
473    }
474}
475
476/// Map the wire usage block into the kernel's [`Usage`].
477fn convert_usage(usage: Option<OpenAiUsage>) -> Option<Usage> {
478    usage.map(|u| Usage {
479        prompt_tokens: u.prompt_tokens,
480        completion_tokens: u.completion_tokens,
481        total_tokens: u.total_tokens,
482    })
483}
484
485/// Reverse of [`OpenAiProvider::extract_tool_calls`]: render a kernel
486/// [`ProposedToolCall`] back into the OpenAI wire shape for an assistant
487/// tool-call turn. `arguments` go back out as a JSON-encoded string.
488fn convert_proposed_call(call: &ProposedToolCall) -> OpenAiToolCall {
489    OpenAiToolCall {
490        id: call.id.clone(),
491        kind: function_kind(),
492        function: OpenAiFunctionCall {
493            name: call.name.clone(),
494            arguments: serde_json::to_string(&call.arguments).unwrap_or_else(|_| "{}".to_string()),
495        },
496    }
497}
498
499/// Parse a tool call's JSON-string `arguments` into a [`serde_json::Value`].
500/// An empty or unparseable blob becomes an empty object so callers always
501/// get a well-formed args object.
502fn parse_arguments(raw: &str) -> serde_json::Value {
503    let trimmed = raw.trim();
504    if trimmed.is_empty() {
505        return serde_json::json!({});
506    }
507    serde_json::from_str(trimmed).unwrap_or_else(|_| serde_json::json!({}))
508}