//! `forgeai_adapter_openai` — OpenAI Chat Completions adapter for the ForgeAI
//! chat abstraction (`lib.rs`).
1use async_stream::try_stream;
2use async_trait::async_trait;
3use forgeai_core::{
4    AdapterInfo, CapabilityMatrix, ChatAdapter, ChatRequest, ChatResponse, ForgeError, Role,
5    StreamEvent, StreamResult, ToolCall, Usage,
6};
7use futures_util::StreamExt;
8use reqwest::{Client as HttpClient, StatusCode};
9use serde_json::{json, Map, Value};
10use std::env;
11use url::Url;
12
/// Chat adapter that talks to the OpenAI Chat Completions API.
///
/// Construct with [`OpenAiAdapter::new`], [`OpenAiAdapter::with_base_url`], or
/// [`OpenAiAdapter::from_env`].
#[derive(Clone, Debug)]
pub struct OpenAiAdapter {
    /// Bearer token sent in the `Authorization` header of every request.
    pub api_key: String,
    /// Base URL that `/v1/chat/completions` is resolved against
    /// (defaults to `https://api.openai.com`).
    pub base_url: Url,
    // Built once at construction and reused for every request.
    client: HttpClient,
}
19
20impl OpenAiAdapter {
21    pub fn new(api_key: impl Into<String>) -> Result<Self, ForgeError> {
22        let base_url = Url::parse("https://api.openai.com")
23            .map_err(|e| ForgeError::Internal(e.to_string()))?;
24        Self::with_base_url(api_key, base_url)
25    }
26
27    pub fn with_base_url(api_key: impl Into<String>, base_url: Url) -> Result<Self, ForgeError> {
28        let client = HttpClient::builder()
29            .build()
30            .map_err(|e| ForgeError::Internal(format!("failed to build http client: {e}")))?;
31        Ok(Self {
32            api_key: api_key.into(),
33            base_url,
34            client,
35        })
36    }
37
38    pub fn from_env() -> Result<Self, ForgeError> {
39        let api_key = env::var("OPENAI_API_KEY").map_err(|_| ForgeError::Authentication)?;
40        match env::var("OPENAI_BASE_URL") {
41            Ok(raw) => {
42                let base_url = Url::parse(&raw)
43                    .map_err(|e| ForgeError::Validation(format!("invalid OPENAI_BASE_URL: {e}")))?;
44                Self::with_base_url(api_key, base_url)
45            }
46            Err(_) => Self::new(api_key),
47        }
48    }
49
50    fn chat_completions_url(&self) -> Result<Url, ForgeError> {
51        self.base_url
52            .join("v1/chat/completions")
53            .map_err(|e| ForgeError::Internal(format!("failed to construct endpoint url: {e}")))
54    }
55}
56
57#[async_trait]
58impl ChatAdapter for OpenAiAdapter {
59    fn info(&self) -> AdapterInfo {
60        AdapterInfo {
61            name: "openai".to_string(),
62            base_url: Some(self.base_url.clone()),
63            capabilities: CapabilityMatrix {
64                streaming: true,
65                tools: true,
66                structured_output: true,
67                multimodal_input: true,
68                citations: false,
69            },
70        }
71    }
72
73    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, ForgeError> {
74        let response = self
75            .client
76            .post(self.chat_completions_url()?)
77            .bearer_auth(&self.api_key)
78            .json(&build_chat_body(request, false))
79            .send()
80            .await
81            .map_err(|e| ForgeError::Transport(format!("request failed: {e}")))?;
82
83        if !response.status().is_success() {
84            let status = response.status();
85            let text = response
86                .text()
87                .await
88                .unwrap_or_else(|_| "failed to read error body".to_string());
89            return Err(parse_http_error(status, text));
90        }
91
92        let payload = response
93            .json::<Value>()
94            .await
95            .map_err(|e| ForgeError::Provider(format!("invalid json response: {e}")))?;
96        parse_chat_response(payload)
97    }
98
99    async fn chat_stream(
100        &self,
101        request: ChatRequest,
102    ) -> Result<StreamResult<StreamEvent>, ForgeError> {
103        let response = self
104            .client
105            .post(self.chat_completions_url()?)
106            .bearer_auth(&self.api_key)
107            .json(&build_chat_body(request, true))
108            .send()
109            .await
110            .map_err(|e| ForgeError::Transport(format!("stream request failed: {e}")))?;
111
112        if !response.status().is_success() {
113            let status = response.status();
114            let text = response
115                .text()
116                .await
117                .unwrap_or_else(|_| "failed to read error body".to_string());
118            return Err(parse_http_error(status, text));
119        }
120
121        let mut bytes = response.bytes_stream();
122        let stream = try_stream! {
123            let mut buffer = String::new();
124            let mut saw_done = false;
125
126            while let Some(chunk) = bytes.next().await {
127                let chunk = chunk.map_err(|e| ForgeError::Transport(format!("stream chunk error: {e}")))?;
128                let chunk_text = std::str::from_utf8(&chunk)
129                    .map_err(|e| ForgeError::Transport(format!("invalid utf8 stream chunk: {e}")))?;
130                buffer.push_str(chunk_text);
131
132                while let Some(line_end) = buffer.find('\n') {
133                    let mut line = buffer[..line_end].to_string();
134                    buffer.drain(..=line_end);
135                    if line.ends_with('\r') {
136                        line.pop();
137                    }
138                    if line.trim().is_empty() {
139                        continue;
140                    }
141                    if let Some(data) = line.strip_prefix("data:") {
142                        let payload = data.trim();
143                        if payload == "[DONE]" {
144                            saw_done = true;
145                            yield StreamEvent::Done;
146                            continue;
147                        }
148                        for event in parse_stream_payload(payload)? {
149                            yield event;
150                        }
151                    }
152                }
153            }
154
155            if !buffer.trim().is_empty() {
156                let line = buffer.trim();
157                if let Some(data) = line.strip_prefix("data:") {
158                    let payload = data.trim();
159                    if payload == "[DONE]" {
160                        saw_done = true;
161                        yield StreamEvent::Done;
162                    } else {
163                        for event in parse_stream_payload(payload)? {
164                            yield event;
165                        }
166                    }
167                }
168            }
169
170            if !saw_done {
171                yield StreamEvent::Done;
172            }
173        };
174
175        Ok(Box::pin(stream))
176    }
177}
178
179fn build_chat_body(request: ChatRequest, stream: bool) -> Value {
180    let mut body = Map::new();
181    body.insert("model".to_string(), Value::String(request.model));
182    body.insert(
183        "messages".to_string(),
184        Value::Array(
185            request
186                .messages
187                .into_iter()
188                .map(|m| {
189                    json!({
190                        "role": role_to_openai(&m.role),
191                        "content": m.content
192                    })
193                })
194                .collect(),
195        ),
196    );
197    if let Some(temperature) = request.temperature {
198        body.insert("temperature".to_string(), json!(temperature));
199    }
200    if let Some(max_tokens) = request.max_tokens {
201        body.insert("max_tokens".to_string(), json!(max_tokens));
202    }
203    if !request.tools.is_empty() {
204        body.insert(
205            "tools".to_string(),
206            Value::Array(
207                request
208                    .tools
209                    .into_iter()
210                    .map(|tool| {
211                        json!({
212                            "type": "function",
213                            "function": {
214                                "name": tool.name,
215                                "description": tool.description,
216                                "parameters": tool.input_schema,
217                            }
218                        })
219                    })
220                    .collect(),
221            ),
222        );
223    }
224    if stream {
225        body.insert("stream".to_string(), Value::Bool(true));
226        body.insert("stream_options".to_string(), json!({"include_usage": true}));
227    }
228    Value::Object(body)
229}
230
231fn role_to_openai(role: &Role) -> &'static str {
232    match role {
233        Role::System => "system",
234        Role::User => "user",
235        Role::Assistant => "assistant",
236        Role::Tool => "tool",
237    }
238}
239
240fn parse_http_error(status: StatusCode, body: String) -> ForgeError {
241    let message = extract_provider_error(body);
242    match status {
243        StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => ForgeError::Authentication,
244        StatusCode::TOO_MANY_REQUESTS => ForgeError::RateLimited,
245        _ => ForgeError::Provider(message),
246    }
247}
248
249fn extract_provider_error(body: String) -> String {
250    serde_json::from_str::<Value>(&body)
251        .ok()
252        .and_then(|v| {
253            v.get("error")
254                .and_then(|e| e.get("message"))
255                .and_then(Value::as_str)
256                .map(ToString::to_string)
257        })
258        .unwrap_or(body)
259}
260
261fn parse_chat_response(payload: Value) -> Result<ChatResponse, ForgeError> {
262    let id = payload
263        .get("id")
264        .and_then(Value::as_str)
265        .unwrap_or_default()
266        .to_string();
267    let model = payload
268        .get("model")
269        .and_then(Value::as_str)
270        .unwrap_or_default()
271        .to_string();
272
273    let choice = payload
274        .get("choices")
275        .and_then(Value::as_array)
276        .and_then(|choices| choices.first());
277
278    let message = choice
279        .and_then(|c| c.get("message"))
280        .unwrap_or(&Value::Null);
281    let output_text = extract_text_content(message.get("content"));
282    let tool_calls = extract_tool_calls(message.get("tool_calls"));
283    let usage = extract_usage(payload.get("usage"));
284
285    Ok(ChatResponse {
286        id,
287        model,
288        output_text,
289        tool_calls,
290        usage,
291    })
292}
293
294fn extract_text_content(content: Option<&Value>) -> String {
295    match content {
296        Some(Value::String(text)) => text.clone(),
297        Some(Value::Array(parts)) => parts
298            .iter()
299            .filter_map(|part| part.get("text").and_then(Value::as_str))
300            .collect::<Vec<_>>()
301            .join(""),
302        _ => String::new(),
303    }
304}
305
306fn extract_tool_calls(raw: Option<&Value>) -> Vec<ToolCall> {
307    raw.and_then(Value::as_array)
308        .map(|items| {
309            items
310                .iter()
311                .map(|item| {
312                    let id = item
313                        .get("id")
314                        .and_then(Value::as_str)
315                        .unwrap_or_default()
316                        .to_string();
317                    let function = item.get("function").unwrap_or(&Value::Null);
318                    let name = function
319                        .get("name")
320                        .and_then(Value::as_str)
321                        .unwrap_or_default()
322                        .to_string();
323                    let arguments = function
324                        .get("arguments")
325                        .and_then(Value::as_str)
326                        .and_then(|raw_args| serde_json::from_str::<Value>(raw_args).ok())
327                        .unwrap_or_else(|| {
328                            function.get("arguments").cloned().unwrap_or(Value::Null)
329                        });
330                    ToolCall {
331                        id,
332                        name,
333                        arguments,
334                    }
335                })
336                .collect()
337        })
338        .unwrap_or_default()
339}
340
341fn extract_usage(raw: Option<&Value>) -> Option<Usage> {
342    let usage = raw?;
343    let input_tokens = usage.get("prompt_tokens")?.as_u64()? as u32;
344    let output_tokens = usage.get("completion_tokens")?.as_u64()? as u32;
345    let total_tokens = usage.get("total_tokens")?.as_u64()? as u32;
346    Some(Usage {
347        input_tokens,
348        output_tokens,
349        total_tokens,
350    })
351}
352
353fn parse_stream_payload(payload: &str) -> Result<Vec<StreamEvent>, ForgeError> {
354    let value = serde_json::from_str::<Value>(payload)
355        .map_err(|e| ForgeError::Provider(format!("invalid stream payload: {e}")))?;
356
357    let mut events = Vec::new();
358    if let Some(usage) = extract_usage(value.get("usage")) {
359        events.push(StreamEvent::Usage { usage });
360    }
361
362    if let Some(choices) = value.get("choices").and_then(Value::as_array) {
363        for choice in choices {
364            if let Some(content) = choice
365                .get("delta")
366                .and_then(|d| d.get("content"))
367                .and_then(Value::as_str)
368                .filter(|s| !s.is_empty())
369            {
370                events.push(StreamEvent::TextDelta {
371                    delta: content.to_string(),
372                });
373            }
374
375            if let Some(tool_calls) = choice
376                .get("delta")
377                .and_then(|d| d.get("tool_calls"))
378                .and_then(Value::as_array)
379            {
380                for tool_call in tool_calls {
381                    let call_id = tool_call
382                        .get("id")
383                        .and_then(Value::as_str)
384                        .unwrap_or_default()
385                        .to_string();
386                    events.push(StreamEvent::ToolCallDelta {
387                        call_id,
388                        delta: tool_call.clone(),
389                    });
390                }
391            }
392        }
393    }
394
395    Ok(events)
396}
397
#[cfg(test)]
mod tests {
    use super::*;
    use forgeai_core::{ChatRequest, Message, Role};
    use futures_util::StreamExt;
    use wiremock::matchers::{body_partial_json, header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    // Minimal single-message request shared by both contract tests.
    fn sample_request() -> ChatRequest {
        ChatRequest {
            model: "gpt-4o-mini".to_string(),
            messages: vec![Message {
                role: Role::User,
                content: "Say hello".to_string(),
            }],
            temperature: Some(0.2),
            max_tokens: Some(32),
            tools: vec![],
            metadata: json!({}),
        }
    }

    // Contract: `chat` posts to /v1/chat/completions with a bearer token and
    // maps id/model/content/usage out of the JSON response.
    #[tokio::test]
    async fn chat_contract_parses_response_and_usage() {
        let server = MockServer::start().await;
        Mock::given(method("POST"))
            .and(path("/v1/chat/completions"))
            .and(header("authorization", "Bearer test-key"))
            .and(body_partial_json(json!({"model": "gpt-4o-mini"})))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "id": "chatcmpl-123",
                "model": "gpt-4o-mini",
                "choices": [{
                    "index": 0,
                    "message": {"role": "assistant", "content": "Hello from OpenAI"}
                }],
                "usage": {"prompt_tokens": 10, "completion_tokens": 4, "total_tokens": 14}
            })))
            .mount(&server)
            .await;

        let adapter =
            OpenAiAdapter::with_base_url("test-key", Url::parse(&server.uri()).unwrap()).unwrap();
        let response = adapter.chat(sample_request()).await.unwrap();

        assert_eq!(response.id, "chatcmpl-123");
        assert_eq!(response.model, "gpt-4o-mini");
        assert_eq!(response.output_text, "Hello from OpenAI");
        assert_eq!(response.usage.unwrap().total_tokens, 14);
    }

    // Contract: `chat_stream` decodes SSE `data:` lines into text deltas, a
    // usage event (from the stream_options chunk), and a terminal Done.
    #[tokio::test]
    async fn chat_stream_contract_parses_sse_events() {
        let server = MockServer::start().await;
        // Two content deltas, one usage-only chunk, then the [DONE] sentinel —
        // mirroring OpenAI's wire format, blank line between events.
        let sse_body = concat!(
            "data: {\"id\":\"chatcmpl-1\",\"model\":\"gpt-4o-mini\",\"choices\":[{\"delta\":{\"content\":\"Hello\"},\"index\":0}]}\n\n",
            "data: {\"id\":\"chatcmpl-1\",\"model\":\"gpt-4o-mini\",\"choices\":[{\"delta\":{\"content\":\" world\"},\"index\":0}]}\n\n",
            "data: {\"usage\":{\"prompt_tokens\":10,\"completion_tokens\":2,\"total_tokens\":12},\"choices\":[]}\n\n",
            "data: [DONE]\n\n"
        );

        Mock::given(method("POST"))
            .and(path("/v1/chat/completions"))
            .and(header("authorization", "Bearer test-key"))
            .and(body_partial_json(json!({"stream": true})))
            .respond_with(ResponseTemplate::new(200).set_body_raw(sse_body, "text/event-stream"))
            .mount(&server)
            .await;

        let adapter =
            OpenAiAdapter::with_base_url("test-key", Url::parse(&server.uri()).unwrap()).unwrap();
        let mut stream = adapter.chat_stream(sample_request()).await.unwrap();
        let mut events = Vec::new();
        // Collect events until the terminal Done marker.
        while let Some(item) = stream.next().await {
            let event = item.unwrap();
            let done = matches!(event, StreamEvent::Done);
            events.push(event);
            if done {
                break;
            }
        }

        assert!(events
            .iter()
            .any(|e| matches!(e, StreamEvent::TextDelta { delta } if delta == "Hello")));
        assert!(events
            .iter()
            .any(|e| matches!(e, StreamEvent::TextDelta { delta } if delta == " world")));
        assert!(events.iter().any(|e| matches!(
            e,
            StreamEvent::Usage { usage } if usage.total_tokens == 12
        )));
        assert!(events.iter().any(|e| matches!(e, StreamEvent::Done)));
    }
}
492}