1use async_stream::try_stream;
2use async_trait::async_trait;
3use forgeai_core::{
4 AdapterInfo, CapabilityMatrix, ChatAdapter, ChatRequest, ChatResponse, ForgeError, Role,
5 StreamEvent, StreamResult, ToolCall, Usage,
6};
7use futures_util::StreamExt;
8use reqwest::{Client as HttpClient, StatusCode};
9use serde_json::{json, Map, Value};
10use std::env;
11use url::Url;
12
13#[derive(Clone, Debug)]
14pub struct OpenAiAdapter {
15 pub api_key: String,
16 pub base_url: Url,
17 client: HttpClient,
18}
19
20impl OpenAiAdapter {
21 pub fn new(api_key: impl Into<String>) -> Result<Self, ForgeError> {
22 let base_url = Url::parse("https://api.openai.com")
23 .map_err(|e| ForgeError::Internal(e.to_string()))?;
24 Self::with_base_url(api_key, base_url)
25 }
26
27 pub fn with_base_url(api_key: impl Into<String>, base_url: Url) -> Result<Self, ForgeError> {
28 let client = HttpClient::builder()
29 .build()
30 .map_err(|e| ForgeError::Internal(format!("failed to build http client: {e}")))?;
31 Ok(Self {
32 api_key: api_key.into(),
33 base_url,
34 client,
35 })
36 }
37
38 pub fn from_env() -> Result<Self, ForgeError> {
39 let api_key = env::var("OPENAI_API_KEY").map_err(|_| ForgeError::Authentication)?;
40 match env::var("OPENAI_BASE_URL") {
41 Ok(raw) => {
42 let base_url = Url::parse(&raw)
43 .map_err(|e| ForgeError::Validation(format!("invalid OPENAI_BASE_URL: {e}")))?;
44 Self::with_base_url(api_key, base_url)
45 }
46 Err(_) => Self::new(api_key),
47 }
48 }
49
50 fn chat_completions_url(&self) -> Result<Url, ForgeError> {
51 self.base_url
52 .join("v1/chat/completions")
53 .map_err(|e| ForgeError::Internal(format!("failed to construct endpoint url: {e}")))
54 }
55}
56
57#[async_trait]
58impl ChatAdapter for OpenAiAdapter {
59 fn info(&self) -> AdapterInfo {
60 AdapterInfo {
61 name: "openai".to_string(),
62 base_url: Some(self.base_url.clone()),
63 capabilities: CapabilityMatrix {
64 streaming: true,
65 tools: true,
66 structured_output: true,
67 multimodal_input: true,
68 citations: false,
69 },
70 }
71 }
72
73 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, ForgeError> {
74 let response = self
75 .client
76 .post(self.chat_completions_url()?)
77 .bearer_auth(&self.api_key)
78 .json(&build_chat_body(request, false))
79 .send()
80 .await
81 .map_err(|e| ForgeError::Transport(format!("request failed: {e}")))?;
82
83 if !response.status().is_success() {
84 let status = response.status();
85 let text = response
86 .text()
87 .await
88 .unwrap_or_else(|_| "failed to read error body".to_string());
89 return Err(parse_http_error(status, text));
90 }
91
92 let payload = response
93 .json::<Value>()
94 .await
95 .map_err(|e| ForgeError::Provider(format!("invalid json response: {e}")))?;
96 parse_chat_response(payload)
97 }
98
99 async fn chat_stream(
100 &self,
101 request: ChatRequest,
102 ) -> Result<StreamResult<StreamEvent>, ForgeError> {
103 let response = self
104 .client
105 .post(self.chat_completions_url()?)
106 .bearer_auth(&self.api_key)
107 .json(&build_chat_body(request, true))
108 .send()
109 .await
110 .map_err(|e| ForgeError::Transport(format!("stream request failed: {e}")))?;
111
112 if !response.status().is_success() {
113 let status = response.status();
114 let text = response
115 .text()
116 .await
117 .unwrap_or_else(|_| "failed to read error body".to_string());
118 return Err(parse_http_error(status, text));
119 }
120
121 let mut bytes = response.bytes_stream();
122 let stream = try_stream! {
123 let mut buffer = String::new();
124 let mut saw_done = false;
125
126 while let Some(chunk) = bytes.next().await {
127 let chunk = chunk.map_err(|e| ForgeError::Transport(format!("stream chunk error: {e}")))?;
128 let chunk_text = std::str::from_utf8(&chunk)
129 .map_err(|e| ForgeError::Transport(format!("invalid utf8 stream chunk: {e}")))?;
130 buffer.push_str(chunk_text);
131
132 while let Some(line_end) = buffer.find('\n') {
133 let mut line = buffer[..line_end].to_string();
134 buffer.drain(..=line_end);
135 if line.ends_with('\r') {
136 line.pop();
137 }
138 if line.trim().is_empty() {
139 continue;
140 }
141 if let Some(data) = line.strip_prefix("data:") {
142 let payload = data.trim();
143 if payload == "[DONE]" {
144 saw_done = true;
145 yield StreamEvent::Done;
146 continue;
147 }
148 for event in parse_stream_payload(payload)? {
149 yield event;
150 }
151 }
152 }
153 }
154
155 if !buffer.trim().is_empty() {
156 let line = buffer.trim();
157 if let Some(data) = line.strip_prefix("data:") {
158 let payload = data.trim();
159 if payload == "[DONE]" {
160 saw_done = true;
161 yield StreamEvent::Done;
162 } else {
163 for event in parse_stream_payload(payload)? {
164 yield event;
165 }
166 }
167 }
168 }
169
170 if !saw_done {
171 yield StreamEvent::Done;
172 }
173 };
174
175 Ok(Box::pin(stream))
176 }
177}
178
179fn build_chat_body(request: ChatRequest, stream: bool) -> Value {
180 let mut body = Map::new();
181 body.insert("model".to_string(), Value::String(request.model));
182 body.insert(
183 "messages".to_string(),
184 Value::Array(
185 request
186 .messages
187 .into_iter()
188 .map(|m| {
189 json!({
190 "role": role_to_openai(&m.role),
191 "content": m.content
192 })
193 })
194 .collect(),
195 ),
196 );
197 if let Some(temperature) = request.temperature {
198 body.insert("temperature".to_string(), json!(temperature));
199 }
200 if let Some(max_tokens) = request.max_tokens {
201 body.insert("max_tokens".to_string(), json!(max_tokens));
202 }
203 if !request.tools.is_empty() {
204 body.insert(
205 "tools".to_string(),
206 Value::Array(
207 request
208 .tools
209 .into_iter()
210 .map(|tool| {
211 json!({
212 "type": "function",
213 "function": {
214 "name": tool.name,
215 "description": tool.description,
216 "parameters": tool.input_schema,
217 }
218 })
219 })
220 .collect(),
221 ),
222 );
223 }
224 if stream {
225 body.insert("stream".to_string(), Value::Bool(true));
226 body.insert("stream_options".to_string(), json!({"include_usage": true}));
227 }
228 Value::Object(body)
229}
230
231fn role_to_openai(role: &Role) -> &'static str {
232 match role {
233 Role::System => "system",
234 Role::User => "user",
235 Role::Assistant => "assistant",
236 Role::Tool => "tool",
237 }
238}
239
240fn parse_http_error(status: StatusCode, body: String) -> ForgeError {
241 let message = extract_provider_error(body);
242 match status {
243 StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => ForgeError::Authentication,
244 StatusCode::TOO_MANY_REQUESTS => ForgeError::RateLimited,
245 _ => ForgeError::Provider(message),
246 }
247}
248
249fn extract_provider_error(body: String) -> String {
250 serde_json::from_str::<Value>(&body)
251 .ok()
252 .and_then(|v| {
253 v.get("error")
254 .and_then(|e| e.get("message"))
255 .and_then(Value::as_str)
256 .map(ToString::to_string)
257 })
258 .unwrap_or(body)
259}
260
261fn parse_chat_response(payload: Value) -> Result<ChatResponse, ForgeError> {
262 let id = payload
263 .get("id")
264 .and_then(Value::as_str)
265 .unwrap_or_default()
266 .to_string();
267 let model = payload
268 .get("model")
269 .and_then(Value::as_str)
270 .unwrap_or_default()
271 .to_string();
272
273 let choice = payload
274 .get("choices")
275 .and_then(Value::as_array)
276 .and_then(|choices| choices.first());
277
278 let message = choice
279 .and_then(|c| c.get("message"))
280 .unwrap_or(&Value::Null);
281 let output_text = extract_text_content(message.get("content"));
282 let tool_calls = extract_tool_calls(message.get("tool_calls"));
283 let usage = extract_usage(payload.get("usage"));
284
285 Ok(ChatResponse {
286 id,
287 model,
288 output_text,
289 tool_calls,
290 usage,
291 })
292}
293
294fn extract_text_content(content: Option<&Value>) -> String {
295 match content {
296 Some(Value::String(text)) => text.clone(),
297 Some(Value::Array(parts)) => parts
298 .iter()
299 .filter_map(|part| part.get("text").and_then(Value::as_str))
300 .collect::<Vec<_>>()
301 .join(""),
302 _ => String::new(),
303 }
304}
305
306fn extract_tool_calls(raw: Option<&Value>) -> Vec<ToolCall> {
307 raw.and_then(Value::as_array)
308 .map(|items| {
309 items
310 .iter()
311 .map(|item| {
312 let id = item
313 .get("id")
314 .and_then(Value::as_str)
315 .unwrap_or_default()
316 .to_string();
317 let function = item.get("function").unwrap_or(&Value::Null);
318 let name = function
319 .get("name")
320 .and_then(Value::as_str)
321 .unwrap_or_default()
322 .to_string();
323 let arguments = function
324 .get("arguments")
325 .and_then(Value::as_str)
326 .and_then(|raw_args| serde_json::from_str::<Value>(raw_args).ok())
327 .unwrap_or_else(|| {
328 function.get("arguments").cloned().unwrap_or(Value::Null)
329 });
330 ToolCall {
331 id,
332 name,
333 arguments,
334 }
335 })
336 .collect()
337 })
338 .unwrap_or_default()
339}
340
341fn extract_usage(raw: Option<&Value>) -> Option<Usage> {
342 let usage = raw?;
343 let input_tokens = usage.get("prompt_tokens")?.as_u64()? as u32;
344 let output_tokens = usage.get("completion_tokens")?.as_u64()? as u32;
345 let total_tokens = usage.get("total_tokens")?.as_u64()? as u32;
346 Some(Usage {
347 input_tokens,
348 output_tokens,
349 total_tokens,
350 })
351}
352
353fn parse_stream_payload(payload: &str) -> Result<Vec<StreamEvent>, ForgeError> {
354 let value = serde_json::from_str::<Value>(payload)
355 .map_err(|e| ForgeError::Provider(format!("invalid stream payload: {e}")))?;
356
357 let mut events = Vec::new();
358 if let Some(usage) = extract_usage(value.get("usage")) {
359 events.push(StreamEvent::Usage { usage });
360 }
361
362 if let Some(choices) = value.get("choices").and_then(Value::as_array) {
363 for choice in choices {
364 if let Some(content) = choice
365 .get("delta")
366 .and_then(|d| d.get("content"))
367 .and_then(Value::as_str)
368 .filter(|s| !s.is_empty())
369 {
370 events.push(StreamEvent::TextDelta {
371 delta: content.to_string(),
372 });
373 }
374
375 if let Some(tool_calls) = choice
376 .get("delta")
377 .and_then(|d| d.get("tool_calls"))
378 .and_then(Value::as_array)
379 {
380 for tool_call in tool_calls {
381 let call_id = tool_call
382 .get("id")
383 .and_then(Value::as_str)
384 .unwrap_or_default()
385 .to_string();
386 events.push(StreamEvent::ToolCallDelta {
387 call_id,
388 delta: tool_call.clone(),
389 });
390 }
391 }
392 }
393 }
394
395 Ok(events)
396}
397
#[cfg(test)]
mod tests {
    use super::*;
    use forgeai_core::{ChatRequest, Message, Role};
    use futures_util::StreamExt;
    use wiremock::matchers::{body_partial_json, header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Minimal single-turn request shared by the contract tests.
    fn sample_request() -> ChatRequest {
        ChatRequest {
            model: "gpt-4o-mini".to_string(),
            messages: vec![Message {
                role: Role::User,
                content: "Say hello".to_string(),
            }],
            temperature: Some(0.2),
            max_tokens: Some(32),
            tools: vec![],
            metadata: json!({}),
        }
    }

    /// Starts a mock server whose chat endpoint always answers with `status` and a JSON
    /// `body`, and returns it with an adapter pointed at it. The server is returned so the
    /// caller keeps it alive for the duration of the test.
    async fn adapter_with_canned_response(status: u16, body: Value) -> (MockServer, OpenAiAdapter) {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/chat/completions"))
            .respond_with(ResponseTemplate::new(status).set_body_json(body))
            .mount(&server)
            .await;
        let adapter =
            OpenAiAdapter::with_base_url("test-key", Url::parse(&server.uri()).unwrap()).unwrap();
        (server, adapter)
    }

    #[tokio::test]
    async fn chat_contract_parses_response_and_usage() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/chat/completions"))
            .and(header("authorization", "Bearer test-key"))
            .and(body_partial_json(json!({"model": "gpt-4o-mini"})))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "id": "chatcmpl-123",
                "model": "gpt-4o-mini",
                "choices": [{
                    "index": 0,
                    "message": {"role": "assistant", "content": "Hello from OpenAI"}
                }],
                "usage": {"prompt_tokens": 10, "completion_tokens": 4, "total_tokens": 14}
            })))
            .mount(&server)
            .await;

        let adapter =
            OpenAiAdapter::with_base_url("test-key", Url::parse(&server.uri()).unwrap()).unwrap();
        let response = adapter.chat(sample_request()).await.unwrap();

        assert_eq!(response.id, "chatcmpl-123");
        assert_eq!(response.model, "gpt-4o-mini");
        assert_eq!(response.output_text, "Hello from OpenAI");
        assert_eq!(response.usage.unwrap().total_tokens, 14);
    }

    #[tokio::test]
    async fn chat_stream_contract_parses_sse_events() {
        let server = MockServer::start().await;
        let sse_body = concat!(
            "data: {\"id\":\"chatcmpl-1\",\"model\":\"gpt-4o-mini\",\"choices\":[{\"delta\":{\"content\":\"Hello\"},\"index\":0}]}\n\n",
            "data: {\"id\":\"chatcmpl-1\",\"model\":\"gpt-4o-mini\",\"choices\":[{\"delta\":{\"content\":\" world\"},\"index\":0}]}\n\n",
            "data: {\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":2,\"total_tokens\":12},\"choices\":[]}\n\n",
            "data: [DONE]\n\n"
        );

        Mock::given(method("POST"))
            .and(path("/v1/chat/completions"))
            .and(header("authorization", "Bearer test-key"))
            .and(body_partial_json(json!({"stream": true})))
            .respond_with(ResponseTemplate::new(200).set_body_raw(sse_body, "text/event-stream"))
            .mount(&server)
            .await;

        let adapter =
            OpenAiAdapter::with_base_url("test-key", Url::parse(&server.uri()).unwrap()).unwrap();
        let mut stream = adapter.chat_stream(sample_request()).await.unwrap();
        let mut events = Vec::new();
        while let Some(item) = stream.next().await {
            let event = item.unwrap();
            let done = matches!(event, StreamEvent::Done);
            events.push(event);
            if done {
                break;
            }
        }

        assert!(events
            .iter()
            .any(|e| matches!(e, StreamEvent::TextDelta { delta } if delta == "Hello")));
        assert!(events
            .iter()
            .any(|e| matches!(e, StreamEvent::TextDelta { delta } if delta == " world")));
        assert!(events.iter().any(|e| matches!(
            e,
            StreamEvent::Usage { usage } if usage.total_tokens == 12
        )));
        assert!(events.iter().any(|e| matches!(e, StreamEvent::Done)));
    }

    #[tokio::test]
    async fn chat_maps_unauthorized_to_authentication_error() {
        let (_server, adapter) = adapter_with_canned_response(
            401,
            json!({"error": {"message": "Incorrect API key provided"}}),
        )
        .await;
        let result = adapter.chat(sample_request()).await;
        assert!(matches!(result, Err(ForgeError::Authentication)));
    }

    #[tokio::test]
    async fn chat_maps_too_many_requests_to_rate_limited_error() {
        let (_server, adapter) =
            adapter_with_canned_response(429, json!({"error": {"message": "Rate limit reached"}}))
                .await;
        let result = adapter.chat(sample_request()).await;
        assert!(matches!(result, Err(ForgeError::RateLimited)));
    }

    #[tokio::test]
    async fn chat_surfaces_provider_error_message_for_other_statuses() {
        let (_server, adapter) = adapter_with_canned_response(
            500,
            json!({"error": {"message": "The server had an error"}}),
        )
        .await;
        let result = adapter.chat(sample_request()).await;
        assert!(matches!(
            &result,
            Err(ForgeError::Provider(message)) if message.contains("The server had an error")
        ));
    }
}