Skip to main content

ferro_ai/client/
openai.rs

1use crate::client::{
2    CompletionRequest, CompletionResponse, LlmClient, Role, TokenStream, ToolChoice, ToolUseBlock,
3};
4use crate::error::Error;
5use async_trait::async_trait;
6use futures::{stream, StreamExt};
7use reqwest_eventsource::{Event, RequestBuilderExt};
8
9/// OpenAI Chat Completions API client.
10///
11/// Implements [`LlmClient`] against `{base_url}/v1/chat/completions`. Also
12/// serves as the Groq client when `base_url` is set to
13/// `https://api.groq.com/openai` — Groq exposes an OpenAI-compatible API.
14///
15/// # Authentication
16///
17/// Uses `Authorization: Bearer {api_key}` for all requests.
18///
19/// # Embeddings
20///
21/// `embed()` posts to `{base_url}/v1/embeddings` using model
22/// `text-embedding-3-small` and extracts `data[0].embedding` as `Vec<f32>`.
23pub struct OpenAiClient {
24    client: reqwest::Client,
25    api_key: String,
26    model: Option<String>,
27    base_url: String,
28}
29
30impl OpenAiClient {
31    /// Create a new client.
32    ///
33    /// - `model`: optional model override; `None` resolves to `default_model()`.
34    /// - `base_url`: optional base URL override; `None` defaults to
35    ///   `https://api.openai.com`. Pass `Some("https://api.groq.com/openai")`
36    ///   for Groq compatibility.
37    ///
38    /// The internal `reqwest::Client` uses a 60-second timeout (T-165-04).
39    pub fn new(api_key: String, model: Option<String>, base_url: Option<String>) -> Self {
40        let client = reqwest::Client::builder()
41            .timeout(std::time::Duration::from_secs(60))
42            .build()
43            .expect("failed to build reqwest client");
44        let base_url = base_url.unwrap_or_else(|| "https://api.openai.com".to_string());
45        Self {
46            client,
47            api_key,
48            model,
49            base_url,
50        }
51    }
52
53    /// The embedding model for `/v1/embeddings`.
54    ///
55    /// Reads `FERRO_AI_EMBED_MODEL`; falls back to `"text-embedding-3-small"`.
56    /// Intentionally separate from `default_model()` (the chat model).
57    pub(crate) fn embed_model() -> String {
58        std::env::var("FERRO_AI_EMBED_MODEL")
59            .unwrap_or_else(|_| "text-embedding-3-small".to_string())
60    }
61
62    /// Build the request body for the Chat Completions API.
63    ///
64    /// Includes `response_format.type = "json_schema"` only when
65    /// `request.schema` is `Some`. Sets `"stream": stream`.
66    pub(crate) fn build_body(
67        &self,
68        request: &CompletionRequest,
69        stream: bool,
70    ) -> serde_json::Value {
71        let model = request
72            .model_override
73            .as_deref()
74            .unwrap_or_else(|| self.default_model());
75
76        let messages: Vec<serde_json::Value> = request
77            .messages
78            .iter()
79            .map(|m| match m.role {
80                Role::Tool => {
81                    // OpenAI wire format: role "tool" with tool_call_id as a real field.
82                    // tool_call_id must not be embedded in the content string.
83                    let call_id = m.tool_call_id.as_deref().unwrap_or("");
84                    serde_json::json!({
85                        "role": "tool",
86                        "tool_call_id": call_id,
87                        "content": m.content,
88                    })
89                }
90                Role::User => serde_json::json!({"role": "user", "content": m.content}),
91                Role::Assistant => {
92                    serde_json::json!({"role": "assistant", "content": m.content})
93                }
94            })
95            .collect();
96
97        let mut body = serde_json::json!({
98            "model": model,
99            "messages": messages,
100            "max_tokens": request.max_tokens,
101            "stream": stream,
102        });
103
104        if let Some(schema) = &request.schema {
105            body["response_format"] = serde_json::json!({
106                "type": "json_schema",
107                "json_schema": {
108                    "name": "output",
109                    "schema": schema,
110                    "strict": true,
111                }
112            });
113        }
114
115        if let Some(tools) = &request.tools {
116            let tools_json: Vec<serde_json::Value> = tools
117                .iter()
118                .map(|t| {
119                    serde_json::json!({
120                        "type": "function",
121                        "function": {
122                            "name": t.name,
123                            "description": t.description,
124                            "parameters": t.parameters_schema,
125                            "strict": true,
126                        }
127                    })
128                })
129                .collect();
130            body["tools"] = serde_json::Value::Array(tools_json);
131            // WR-01: honor request.tool_choice; default to "auto" when not specified.
132            body["tool_choice"] = match request.tool_choice.as_ref() {
133                Some(ToolChoice::None) => serde_json::json!("none"),
134                Some(ToolChoice::Auto) | None => serde_json::json!("auto"),
135            };
136        }
137
138        body
139    }
140}
141
142/// Parse tool_calls from an OpenAI response into [`ToolUseBlock`]s.
143pub(crate) fn parse_openai_tool_calls(json: &serde_json::Value) -> Vec<ToolUseBlock> {
144    let Some(tool_calls) = json["choices"][0]["message"]["tool_calls"].as_array() else {
145        return vec![];
146    };
147    tool_calls
148        .iter()
149        .filter_map(|c| {
150            Some(ToolUseBlock {
151                id: c["id"].as_str()?.to_string(),
152                name: c["function"]["name"].as_str()?.to_string(),
153                input: serde_json::from_str(c["function"]["arguments"].as_str()?).ok()?,
154            })
155        })
156        .collect()
157}
158
/// Outcome of interpreting the `data` payload of one OpenAI SSE event.
#[derive(Debug, PartialEq)]
pub(crate) enum OpenAiDelta {
    /// No more tokens will follow (for example `data: [DONE]`).
    Done,
    /// A piece of generated text; never empty.
    Token(String),
    /// The event carried nothing to emit (such as an empty or role-only
    /// delta) and should be ignored.
    Skip,
}
169
170/// Parse one OpenAI SSE chunk data string.
171///
172/// Handles the `[DONE]` sentinel (Pitfall 4) before JSON-parsing.
173/// Returns `Done` on termination, `Token(text)` for content, `Skip` otherwise.
174pub(crate) fn parse_openai_delta(data: &str) -> OpenAiDelta {
175    if data == "[DONE]" {
176        return OpenAiDelta::Done;
177    }
178    let Ok(v) = serde_json::from_str::<serde_json::Value>(data) else {
179        return OpenAiDelta::Skip;
180    };
181    // Terminate on finish_reason being non-null
182    if !v["choices"][0]["finish_reason"].is_null() {
183        if let Some(reason) = v["choices"][0]["finish_reason"].as_str() {
184            if !reason.is_empty() {
185                return OpenAiDelta::Done;
186            }
187        }
188    }
189    match v["choices"][0]["delta"]["content"].as_str() {
190        Some(text) if !text.is_empty() => OpenAiDelta::Token(text.to_string()),
191        _ => OpenAiDelta::Skip,
192    }
193}
194
195/// Parse the embeddings response, extracting `data[0].embedding` as `Vec<f32>`.
196pub(crate) fn parse_embedding(json: &serde_json::Value) -> Result<Vec<f32>, Error> {
197    json["data"][0]["embedding"]
198        .as_array()
199        .map(|arr| {
200            arr.iter()
201                .filter_map(|v| v.as_f64().map(|f| f as f32))
202                .collect()
203        })
204        .ok_or_else(|| Error::Deserialization("no embedding in response".into()))
205}
206
207#[async_trait]
208impl LlmClient for OpenAiClient {
209    fn default_model(&self) -> &str {
210        self.model.as_deref().unwrap_or("gpt-4o")
211    }
212
213    async fn complete(&self, request: CompletionRequest) -> Result<String, Error> {
214        let body = self.build_body(&request, false);
215
216        let resp = self
217            .client
218            .post(format!("{}/v1/chat/completions", self.base_url))
219            .bearer_auth(&self.api_key)
220            .json(&body)
221            .send()
222            .await
223            .map_err(|e| {
224                if e.is_timeout() {
225                    Error::Timeout
226                } else {
227                    Error::Provider {
228                        status: None,
229                        message: e.to_string(),
230                    }
231                }
232            })?;
233
234        let status = resp.status().as_u16();
235        if !resp.status().is_success() {
236            let text = resp.text().await.unwrap_or_default();
237            return Err(Error::Provider {
238                status: Some(status),
239                message: text,
240            });
241        }
242
243        let json: serde_json::Value = resp
244            .json()
245            .await
246            .map_err(|e| Error::Deserialization(e.to_string()))?;
247
248        json["choices"][0]["message"]["content"]
249            .as_str()
250            .map(|s| s.to_string())
251            .ok_or_else(|| Error::Deserialization("no content in response".into()))
252    }
253
254    async fn complete_stream(&self, request: CompletionRequest) -> Result<TokenStream, Error> {
255        let body = self.build_body(&request, true);
256
257        let builder = self
258            .client
259            .post(format!("{}/v1/chat/completions", self.base_url))
260            .bearer_auth(&self.api_key)
261            .json(&body);
262
263        let es = builder.eventsource().map_err(|_| Error::Provider {
264            status: None,
265            message: "request not cloneable".into(),
266        })?;
267
268        let token_stream = stream::unfold(es, |mut es| async move {
269            loop {
270                match es.next().await {
271                    None => return None,
272                    Some(Ok(Event::Open)) => continue,
273                    Some(Ok(Event::Message(msg))) => match parse_openai_delta(&msg.data) {
274                        OpenAiDelta::Done => {
275                            es.close();
276                            return None;
277                        }
278                        OpenAiDelta::Token(text) => return Some((Ok(text), es)),
279                        OpenAiDelta::Skip => continue,
280                    },
281                    Some(Err(e)) => {
282                        es.close();
283                        return Some((
284                            Err(Error::Provider {
285                                status: None,
286                                message: e.to_string(),
287                            }),
288                            es,
289                        ));
290                    }
291                }
292            }
293        });
294
295        Ok(Box::pin(token_stream))
296    }
297
298    async fn embed(&self, text: &str) -> Result<Vec<f32>, Error> {
299        let body = serde_json::json!({
300            "model": Self::embed_model(),
301            "input": text,
302        });
303
304        let resp = self
305            .client
306            .post(format!("{}/v1/embeddings", self.base_url))
307            .bearer_auth(&self.api_key)
308            .json(&body)
309            .send()
310            .await
311            .map_err(|e| {
312                if e.is_timeout() {
313                    Error::Timeout
314                } else {
315                    Error::Provider {
316                        status: None,
317                        message: e.to_string(),
318                    }
319                }
320            })?;
321
322        let status = resp.status().as_u16();
323        if !resp.status().is_success() {
324            let text = resp.text().await.unwrap_or_default();
325            return Err(Error::Provider {
326                status: Some(status),
327                message: text,
328            });
329        }
330
331        let json: serde_json::Value = resp
332            .json()
333            .await
334            .map_err(|e| Error::Deserialization(e.to_string()))?;
335
336        parse_embedding(&json)
337    }
338
339    async fn complete_with_tools(
340        &self,
341        request: CompletionRequest,
342    ) -> Result<CompletionResponse, Error> {
343        let body = self.build_body(&request, false);
344
345        let resp = self
346            .client
347            .post(format!("{}/v1/chat/completions", self.base_url))
348            .bearer_auth(&self.api_key)
349            .json(&body)
350            .send()
351            .await
352            .map_err(|e| {
353                if e.is_timeout() {
354                    Error::Timeout
355                } else {
356                    Error::Provider {
357                        status: None,
358                        message: e.to_string(),
359                    }
360                }
361            })?;
362
363        let status = resp.status().as_u16();
364        if !resp.status().is_success() {
365            let text = resp.text().await.unwrap_or_default();
366            return Err(Error::Provider {
367                status: Some(status),
368                message: text,
369            });
370        }
371
372        let json: serde_json::Value = resp
373            .json()
374            .await
375            .map_err(|e| Error::Deserialization(e.to_string()))?;
376
377        let finish_reason = json["choices"][0]["finish_reason"].as_str().unwrap_or("");
378        if finish_reason == "tool_calls" {
379            let blocks = parse_openai_tool_calls(&json);
380            let assistant_content = json["choices"][0]["message"]["tool_calls"].to_string();
381            return Ok(CompletionResponse::ToolUse {
382                blocks,
383                assistant_content,
384            });
385        }
386
387        // stop or any other finish_reason → extract text content
388        let text = json["choices"][0]["message"]["content"]
389            .as_str()
390            .map(|s| s.to_string())
391            .ok_or_else(|| Error::Deserialization("no content in response".into()))?;
392
393        Ok(CompletionResponse::Text(text))
394    }
395}
396
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::{Message, ToolChoice, ToolRequest};

    /// Client with a dummy key and every optional setting left at its default.
    fn default_client() -> OpenAiClient {
        OpenAiClient::new("k".into(), None, None)
    }

    /// A plain user message carrying `text`.
    fn user_message(text: &str) -> Message {
        Message {
            role: Role::User,
            content: text.into(),
            tool_call_id: None,
        }
    }

    /// A request holding `messages` with no system prompt, schema or tools.
    fn request_with(messages: Vec<Message>) -> CompletionRequest {
        CompletionRequest {
            system: None,
            messages,
            max_tokens: 100,
            model_override: None,
            schema: None,
            tools: None,
            tool_choice: None,
        }
    }

    /// A single-tool list shared by the tool_choice tests.
    fn sample_tools() -> Vec<ToolRequest> {
        vec![ToolRequest {
            name: "my_tool".into(),
            description: "does stuff".into(),
            parameters_schema: serde_json::json!({"type": "object"}),
        }]
    }

    #[test]
    fn test_openai_default_model() {
        assert_eq!(default_client().default_model(), "gpt-4o");
    }

    #[test]
    fn test_openai_default_base_url() {
        assert_eq!(default_client().base_url, "https://api.openai.com");
    }

    #[test]
    fn test_openai_groq_base_url() {
        let groq = OpenAiClient::new("k".into(), None, Some("https://api.groq.com/openai".into()));
        assert_eq!(groq.base_url, "https://api.groq.com/openai");
    }

    #[test]
    fn test_build_body_response_format_with_schema() {
        let schema = serde_json::json!({"type": "object", "properties": {"x": {"type": "string"}}});
        let request = CompletionRequest {
            schema: Some(schema.clone()),
            ..request_with(vec![user_message("hi")])
        };

        let body = default_client().build_body(&request, false);
        let format = &body["response_format"];

        assert_eq!(format["type"], "json_schema");
        assert_eq!(format["json_schema"]["name"], "output");
        assert_eq!(format["json_schema"]["schema"], schema);
        assert_eq!(format["json_schema"]["strict"], true);
    }

    #[test]
    fn test_build_body_no_response_format_without_schema() {
        let request = request_with(vec![user_message("hi")]);
        let body = default_client().build_body(&request, false);
        assert!(body.get("response_format").is_none());
    }

    #[test]
    fn test_parse_openai_delta_done() {
        assert_eq!(parse_openai_delta("[DONE]"), OpenAiDelta::Done);
    }

    #[test]
    fn test_parse_openai_delta_token() {
        // Fixture from RESEARCH line 635
        let data = r#"{"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}"#;
        let expected = OpenAiDelta::Token("Hello".to_string());
        assert_eq!(parse_openai_delta(data), expected);
    }

    #[test]
    fn test_parse_openai_delta_skip_empty_content() {
        let data = r#"{"choices":[{"index":0,"delta":{"role":"assistant","content":null},"finish_reason":null}]}"#;
        assert_eq!(parse_openai_delta(data), OpenAiDelta::Skip);
    }

    #[test]
    fn test_parse_openai_delta_finish_reason() {
        let data = r#"{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}"#;
        assert_eq!(parse_openai_delta(data), OpenAiDelta::Done);
    }

    #[test]
    fn test_parse_embedding() {
        let json = serde_json::json!({
            "data": [{"embedding": [0.1, -0.2, 0.3], "index": 0}],
            "usage": {}
        });
        let result = parse_embedding(&json).unwrap();
        assert_eq!(result.len(), 3);
        assert!((result[0] - 0.1f32).abs() < 1e-6);
        assert!((result[1] - (-0.2f32)).abs() < 1e-6);
        assert!((result[2] - 0.3f32).abs() < 1e-6);
    }

    #[test]
    fn test_parse_embedding_missing() {
        let json = serde_json::json!({"data": []});
        let outcome = parse_embedding(&json);
        assert!(matches!(outcome, Err(Error::Deserialization(_))));
    }

    #[test]
    fn test_openai_is_object_safe() {
        let _: Box<dyn LlmClient> = Box::new(default_client());
    }

    /// CR-03 regression: a tool result must carry `tool_call_id` as its own
    /// top-level field, and the id must not leak into the content string.
    #[test]
    fn test_build_body_tool_result_wire_format() {
        let tool_result = Message {
            role: Role::Tool,
            content: "4".into(),
            tool_call_id: Some("call_abc123".into()),
        };
        let request = request_with(vec![user_message("what is 2+2?"), tool_result]);

        let body = default_client().build_body(&request, false);
        let msgs = body["messages"].as_array().expect("messages must be array");
        assert_eq!(msgs.len(), 2);

        let tool_msg = &msgs[1];
        assert_eq!(tool_msg["role"], "tool");
        assert_eq!(
            tool_msg["tool_call_id"], "call_abc123",
            "tool_call_id must be a real top-level field"
        );
        assert_eq!(tool_msg["content"], "4");
        // The id must not also appear embedded in the content string.
        let content = tool_msg["content"].as_str().unwrap_or("");
        assert!(
            !content.contains("call_abc123"),
            "tool_call_id must not be embedded in content"
        );
    }

    /// WR-01 regression: `build_body` must honor `request.tool_choice`.
    #[test]
    fn test_build_body_tool_choice_none() {
        let request = CompletionRequest {
            tools: Some(sample_tools()),
            tool_choice: Some(ToolChoice::None),
            ..request_with(vec![user_message("hi")])
        };
        let body = default_client().build_body(&request, false);
        assert_eq!(
            body["tool_choice"], "none",
            "ToolChoice::None must emit tool_choice: 'none'"
        );
    }

    /// WR-01: an explicit `Auto` and an unset tool_choice both emit "auto".
    #[test]
    fn test_build_body_tool_choice_auto() {
        let client = default_client();

        // Explicit Auto.
        let req_auto = CompletionRequest {
            tools: Some(sample_tools()),
            tool_choice: Some(ToolChoice::Auto),
            ..request_with(vec![user_message("hi")])
        };
        let body = client.build_body(&req_auto, false);
        assert_eq!(body["tool_choice"], "auto");

        // Default None → also "auto".
        let req_default = CompletionRequest {
            tools: Some(sample_tools()),
            ..request_with(vec![user_message("hi")])
        };
        let body2 = client.build_body(&req_default, false);
        assert_eq!(body2["tool_choice"], "auto");
    }

    #[test]
    fn embed_model_default_is_text_embedding_3_small() {
        // Serialize access to the process environment across tests.
        let _g = crate::ENV_LOCK.lock().unwrap();
        std::env::remove_var("FERRO_AI_EMBED_MODEL");
        assert_eq!(OpenAiClient::embed_model(), "text-embedding-3-small");
    }

    #[test]
    fn embed_model_from_env() {
        // Serialize access to the process environment across tests.
        let _g = crate::ENV_LOCK.lock().unwrap();
        std::env::set_var("FERRO_AI_EMBED_MODEL", "text-embedding-ada-002");
        assert_eq!(OpenAiClient::embed_model(), "text-embedding-ada-002");
        std::env::remove_var("FERRO_AI_EMBED_MODEL");
    }
}