1use crate::client::{
2 CompletionRequest, CompletionResponse, LlmClient, Role, TokenStream, ToolChoice, ToolUseBlock,
3};
4use crate::error::Error;
5use async_trait::async_trait;
6use futures::{stream, StreamExt};
7use reqwest_eventsource::{Event, RequestBuilderExt};
8
9pub struct OpenAiClient {
24 client: reqwest::Client,
25 api_key: String,
26 model: Option<String>,
27 base_url: String,
28}
29
30impl OpenAiClient {
31 pub fn new(api_key: String, model: Option<String>, base_url: Option<String>) -> Self {
40 let client = reqwest::Client::builder()
41 .timeout(std::time::Duration::from_secs(60))
42 .build()
43 .expect("failed to build reqwest client");
44 let base_url = base_url.unwrap_or_else(|| "https://api.openai.com".to_string());
45 Self {
46 client,
47 api_key,
48 model,
49 base_url,
50 }
51 }
52
53 pub(crate) fn embed_model() -> String {
58 std::env::var("FERRO_AI_EMBED_MODEL")
59 .unwrap_or_else(|_| "text-embedding-3-small".to_string())
60 }
61
62 pub(crate) fn build_body(
67 &self,
68 request: &CompletionRequest,
69 stream: bool,
70 ) -> serde_json::Value {
71 let model = request
72 .model_override
73 .as_deref()
74 .unwrap_or_else(|| self.default_model());
75
76 let messages: Vec<serde_json::Value> = request
77 .messages
78 .iter()
79 .map(|m| match m.role {
80 Role::Tool => {
81 let call_id = m.tool_call_id.as_deref().unwrap_or("");
84 serde_json::json!({
85 "role": "tool",
86 "tool_call_id": call_id,
87 "content": m.content,
88 })
89 }
90 Role::User => serde_json::json!({"role": "user", "content": m.content}),
91 Role::Assistant => {
92 serde_json::json!({"role": "assistant", "content": m.content})
93 }
94 })
95 .collect();
96
97 let mut body = serde_json::json!({
98 "model": model,
99 "messages": messages,
100 "max_tokens": request.max_tokens,
101 "stream": stream,
102 });
103
104 if let Some(schema) = &request.schema {
105 body["response_format"] = serde_json::json!({
106 "type": "json_schema",
107 "json_schema": {
108 "name": "output",
109 "schema": schema,
110 "strict": true,
111 }
112 });
113 }
114
115 if let Some(tools) = &request.tools {
116 let tools_json: Vec<serde_json::Value> = tools
117 .iter()
118 .map(|t| {
119 serde_json::json!({
120 "type": "function",
121 "function": {
122 "name": t.name,
123 "description": t.description,
124 "parameters": t.parameters_schema,
125 "strict": true,
126 }
127 })
128 })
129 .collect();
130 body["tools"] = serde_json::Value::Array(tools_json);
131 body["tool_choice"] = match request.tool_choice.as_ref() {
133 Some(ToolChoice::None) => serde_json::json!("none"),
134 Some(ToolChoice::Auto) | None => serde_json::json!("auto"),
135 };
136 }
137
138 body
139 }
140}
141
142pub(crate) fn parse_openai_tool_calls(json: &serde_json::Value) -> Vec<ToolUseBlock> {
144 let Some(tool_calls) = json["choices"][0]["message"]["tool_calls"].as_array() else {
145 return vec![];
146 };
147 tool_calls
148 .iter()
149 .filter_map(|c| {
150 Some(ToolUseBlock {
151 id: c["id"].as_str()?.to_string(),
152 name: c["function"]["name"].as_str()?.to_string(),
153 input: serde_json::from_str(c["function"]["arguments"].as_str()?).ok()?,
154 })
155 })
156 .collect()
157}
158
/// Outcome of parsing one server-sent event payload from a streaming completion.
#[derive(Debug, PartialEq)]
pub(crate) enum OpenAiDelta {
    /// The stream is finished and no further tokens should be read.
    Done,
    /// A non-empty piece of generated text.
    Token(String),
    /// The payload carried nothing to emit and should be ignored.
    Skip,
}
169
170pub(crate) fn parse_openai_delta(data: &str) -> OpenAiDelta {
175 if data == "[DONE]" {
176 return OpenAiDelta::Done;
177 }
178 let Ok(v) = serde_json::from_str::<serde_json::Value>(data) else {
179 return OpenAiDelta::Skip;
180 };
181 if !v["choices"][0]["finish_reason"].is_null() {
183 if let Some(reason) = v["choices"][0]["finish_reason"].as_str() {
184 if !reason.is_empty() {
185 return OpenAiDelta::Done;
186 }
187 }
188 }
189 match v["choices"][0]["delta"]["content"].as_str() {
190 Some(text) if !text.is_empty() => OpenAiDelta::Token(text.to_string()),
191 _ => OpenAiDelta::Skip,
192 }
193}
194
195pub(crate) fn parse_embedding(json: &serde_json::Value) -> Result<Vec<f32>, Error> {
197 json["data"][0]["embedding"]
198 .as_array()
199 .map(|arr| {
200 arr.iter()
201 .filter_map(|v| v.as_f64().map(|f| f as f32))
202 .collect()
203 })
204 .ok_or_else(|| Error::Deserialization("no embedding in response".into()))
205}
206
207#[async_trait]
208impl LlmClient for OpenAiClient {
209 fn default_model(&self) -> &str {
210 self.model.as_deref().unwrap_or("gpt-4o")
211 }
212
213 async fn complete(&self, request: CompletionRequest) -> Result<String, Error> {
214 let body = self.build_body(&request, false);
215
216 let resp = self
217 .client
218 .post(format!("{}/v1/chat/completions", self.base_url))
219 .bearer_auth(&self.api_key)
220 .json(&body)
221 .send()
222 .await
223 .map_err(|e| {
224 if e.is_timeout() {
225 Error::Timeout
226 } else {
227 Error::Provider {
228 status: None,
229 message: e.to_string(),
230 }
231 }
232 })?;
233
234 let status = resp.status().as_u16();
235 if !resp.status().is_success() {
236 let text = resp.text().await.unwrap_or_default();
237 return Err(Error::Provider {
238 status: Some(status),
239 message: text,
240 });
241 }
242
243 let json: serde_json::Value = resp
244 .json()
245 .await
246 .map_err(|e| Error::Deserialization(e.to_string()))?;
247
248 json["choices"][0]["message"]["content"]
249 .as_str()
250 .map(|s| s.to_string())
251 .ok_or_else(|| Error::Deserialization("no content in response".into()))
252 }
253
254 async fn complete_stream(&self, request: CompletionRequest) -> Result<TokenStream, Error> {
255 let body = self.build_body(&request, true);
256
257 let builder = self
258 .client
259 .post(format!("{}/v1/chat/completions", self.base_url))
260 .bearer_auth(&self.api_key)
261 .json(&body);
262
263 let es = builder.eventsource().map_err(|_| Error::Provider {
264 status: None,
265 message: "request not cloneable".into(),
266 })?;
267
268 let token_stream = stream::unfold(es, |mut es| async move {
269 loop {
270 match es.next().await {
271 None => return None,
272 Some(Ok(Event::Open)) => continue,
273 Some(Ok(Event::Message(msg))) => match parse_openai_delta(&msg.data) {
274 OpenAiDelta::Done => {
275 es.close();
276 return None;
277 }
278 OpenAiDelta::Token(text) => return Some((Ok(text), es)),
279 OpenAiDelta::Skip => continue,
280 },
281 Some(Err(e)) => {
282 es.close();
283 return Some((
284 Err(Error::Provider {
285 status: None,
286 message: e.to_string(),
287 }),
288 es,
289 ));
290 }
291 }
292 }
293 });
294
295 Ok(Box::pin(token_stream))
296 }
297
298 async fn embed(&self, text: &str) -> Result<Vec<f32>, Error> {
299 let body = serde_json::json!({
300 "model": Self::embed_model(),
301 "input": text,
302 });
303
304 let resp = self
305 .client
306 .post(format!("{}/v1/embeddings", self.base_url))
307 .bearer_auth(&self.api_key)
308 .json(&body)
309 .send()
310 .await
311 .map_err(|e| {
312 if e.is_timeout() {
313 Error::Timeout
314 } else {
315 Error::Provider {
316 status: None,
317 message: e.to_string(),
318 }
319 }
320 })?;
321
322 let status = resp.status().as_u16();
323 if !resp.status().is_success() {
324 let text = resp.text().await.unwrap_or_default();
325 return Err(Error::Provider {
326 status: Some(status),
327 message: text,
328 });
329 }
330
331 let json: serde_json::Value = resp
332 .json()
333 .await
334 .map_err(|e| Error::Deserialization(e.to_string()))?;
335
336 parse_embedding(&json)
337 }
338
339 async fn complete_with_tools(
340 &self,
341 request: CompletionRequest,
342 ) -> Result<CompletionResponse, Error> {
343 let body = self.build_body(&request, false);
344
345 let resp = self
346 .client
347 .post(format!("{}/v1/chat/completions", self.base_url))
348 .bearer_auth(&self.api_key)
349 .json(&body)
350 .send()
351 .await
352 .map_err(|e| {
353 if e.is_timeout() {
354 Error::Timeout
355 } else {
356 Error::Provider {
357 status: None,
358 message: e.to_string(),
359 }
360 }
361 })?;
362
363 let status = resp.status().as_u16();
364 if !resp.status().is_success() {
365 let text = resp.text().await.unwrap_or_default();
366 return Err(Error::Provider {
367 status: Some(status),
368 message: text,
369 });
370 }
371
372 let json: serde_json::Value = resp
373 .json()
374 .await
375 .map_err(|e| Error::Deserialization(e.to_string()))?;
376
377 let finish_reason = json["choices"][0]["finish_reason"].as_str().unwrap_or("");
378 if finish_reason == "tool_calls" {
379 let blocks = parse_openai_tool_calls(&json);
380 let assistant_content = json["choices"][0]["message"]["tool_calls"].to_string();
381 return Ok(CompletionResponse::ToolUse {
382 blocks,
383 assistant_content,
384 });
385 }
386
387 let text = json["choices"][0]["message"]["content"]
389 .as_str()
390 .map(|s| s.to_string())
391 .ok_or_else(|| Error::Deserialization("no content in response".into()))?;
392
393 Ok(CompletionResponse::Text(text))
394 }
395}
396
#[cfg(test)]
mod tests {
    use super::*;
    use crate::client::{Message, ToolChoice, ToolRequest};

    /// Client with a dummy key, the default model and the default base URL.
    fn default_client() -> OpenAiClient {
        OpenAiClient::new("k".into(), None, None)
    }

    /// A user message with the given content.
    fn user_message(content: &'static str) -> Message {
        Message {
            role: Role::User,
            content: content.into(),
            tool_call_id: None,
        }
    }

    /// A request with the given messages and every optional field unset.
    fn request_with(messages: Vec<Message>) -> CompletionRequest {
        CompletionRequest {
            system: None,
            messages,
            max_tokens: 100,
            model_override: None,
            schema: None,
            tools: None,
            tool_choice: None,
        }
    }

    /// A single-tool list used by the `tool_choice` tests.
    fn sample_tools() -> Vec<ToolRequest> {
        vec![ToolRequest {
            name: "my_tool".into(),
            description: "does stuff".into(),
            parameters_schema: serde_json::json!({"type": "object"}),
        }]
    }

    /// A "hi" request that offers `sample_tools` with the given tool choice.
    fn tool_request(tool_choice: Option<ToolChoice>) -> CompletionRequest {
        CompletionRequest {
            tools: Some(sample_tools()),
            tool_choice,
            ..request_with(vec![user_message("hi")])
        }
    }

    #[test]
    fn test_openai_default_model() {
        assert_eq!(default_client().default_model(), "gpt-4o");
    }

    #[test]
    fn test_openai_default_base_url() {
        assert_eq!(default_client().base_url, "https://api.openai.com");
    }

    #[test]
    fn test_openai_groq_base_url() {
        let client =
            OpenAiClient::new("k".into(), None, Some("https://api.groq.com/openai".into()));
        assert_eq!(client.base_url, "https://api.groq.com/openai");
    }

    #[test]
    fn test_build_body_response_format_with_schema() {
        let schema = serde_json::json!({"type": "object", "properties": {"x": {"type": "string"}}});
        let request = CompletionRequest {
            schema: Some(schema.clone()),
            ..request_with(vec![user_message("hi")])
        };

        let body = default_client().build_body(&request, false);
        let format = &body["response_format"];

        assert_eq!(format["type"], "json_schema");
        assert_eq!(format["json_schema"]["name"], "output");
        assert_eq!(format["json_schema"]["schema"], schema);
        assert_eq!(format["json_schema"]["strict"], true);
    }

    #[test]
    fn test_build_body_no_response_format_without_schema() {
        let request = request_with(vec![user_message("hi")]);
        let body = default_client().build_body(&request, false);
        assert!(body.get("response_format").is_none());
    }

    #[test]
    fn test_parse_openai_delta_done() {
        assert_eq!(parse_openai_delta("[DONE]"), OpenAiDelta::Done);
    }

    #[test]
    fn test_parse_openai_delta_token() {
        let data = r#"{"id":"chatcmpl-1","object":"chat.completion.chunk","choices":[{"index":0,"delta":{"content":"Hello"},"finish_reason":null}]}"#;
        assert_eq!(
            parse_openai_delta(data),
            OpenAiDelta::Token("Hello".to_string())
        );
    }

    #[test]
    fn test_parse_openai_delta_skip_empty_content() {
        let data = r#"{"choices":[{"index":0,"delta":{"role":"assistant","content":null},"finish_reason":null}]}"#;
        assert_eq!(parse_openai_delta(data), OpenAiDelta::Skip);
    }

    #[test]
    fn test_parse_openai_delta_finish_reason() {
        let data = r#"{"choices":[{"index":0,"delta":{},"finish_reason":"stop"}]}"#;
        assert_eq!(parse_openai_delta(data), OpenAiDelta::Done);
    }

    #[test]
    fn test_parse_embedding() {
        let json = serde_json::json!({
            "data": [{"embedding": [0.1, -0.2, 0.3], "index": 0}],
            "usage": {}
        });
        let result = parse_embedding(&json).unwrap();
        assert_eq!(result.len(), 3);
        assert!((result[0] - 0.1f32).abs() < 1e-6);
        assert!((result[1] - (-0.2f32)).abs() < 1e-6);
        assert!((result[2] - 0.3f32).abs() < 1e-6);
    }

    #[test]
    fn test_parse_embedding_missing() {
        let json = serde_json::json!({"data": []});
        assert!(matches!(
            parse_embedding(&json),
            Err(Error::Deserialization(_))
        ));
    }

    #[test]
    fn test_openai_is_object_safe() {
        let _: Box<dyn LlmClient> = Box::new(default_client());
    }

    #[test]
    fn test_build_body_tool_result_wire_format() {
        let tool_result = Message {
            role: Role::Tool,
            content: "4".into(),
            tool_call_id: Some("call_abc123".into()),
        };
        let request = request_with(vec![user_message("what is 2+2?"), tool_result]);

        let body = default_client().build_body(&request, false);
        let msgs = body["messages"].as_array().expect("messages must be array");
        assert_eq!(msgs.len(), 2);

        let tool_msg = &msgs[1];
        assert_eq!(tool_msg["role"], "tool");
        assert_eq!(
            tool_msg["tool_call_id"], "call_abc123",
            "tool_call_id must be a real top-level field"
        );
        assert_eq!(tool_msg["content"], "4");
        assert!(
            !tool_msg["content"]
                .as_str()
                .unwrap_or("")
                .contains("call_abc123"),
            "tool_call_id must not be embedded in content"
        );
    }

    #[test]
    fn test_build_body_tool_choice_none() {
        let request = tool_request(Some(ToolChoice::None));
        let body = default_client().build_body(&request, false);
        assert_eq!(
            body["tool_choice"], "none",
            "ToolChoice::None must emit tool_choice: 'none'"
        );
    }

    #[test]
    fn test_build_body_tool_choice_auto() {
        let client = default_client();

        // An explicit `Auto` choice is sent as "auto".
        let req_auto = tool_request(Some(ToolChoice::Auto));
        let body = client.build_body(&req_auto, false);
        assert_eq!(body["tool_choice"], "auto");

        // An unset choice is also sent as "auto".
        let req_default = tool_request(None);
        let body2 = client.build_body(&req_default, false);
        assert_eq!(body2["tool_choice"], "auto");
    }

    #[test]
    fn embed_model_default_is_text_embedding_3_small() {
        // The guard is held while this test touches the process environment.
        let _g = crate::ENV_LOCK.lock().unwrap();
        std::env::remove_var("FERRO_AI_EMBED_MODEL");
        assert_eq!(OpenAiClient::embed_model(), "text-embedding-3-small");
    }

    #[test]
    fn embed_model_from_env() {
        let _g = crate::ENV_LOCK.lock().unwrap();
        std::env::set_var("FERRO_AI_EMBED_MODEL", "text-embedding-ada-002");
        assert_eq!(OpenAiClient::embed_model(), "text-embedding-ada-002");
        std::env::remove_var("FERRO_AI_EMBED_MODEL");
    }
}