Skip to main content

jamjet_models/
openai.rs

1//! OpenAI adapter (Chat Completions API).
2//!
3//! Supports gpt-4o, gpt-4o-mini, gpt-4-turbo, o1, o3, etc.
4//! Reads `OPENAI_API_KEY` from the environment.
5//! Also works with OpenAI-compatible APIs (e.g. local Ollama) via `base_url`.
6
7use crate::adapter::{
8    ChatMessage, ChatRole, ModelAdapter, ModelConfig, ModelError, ModelRequest, ModelResponse,
9    StructuredRequest,
10};
11use async_trait::async_trait;
12use serde_json::{json, Value};
13use tracing::{debug, instrument};
14
/// Model requested when neither the request config nor the adapter names one.
const DEFAULT_MODEL: &str = "gpt-4o";
/// Token cap sent when the request config leaves `max_tokens` unset.
const DEFAULT_MAX_TOKENS: u32 = 4096;
/// Root of the hosted OpenAI API; the `/v1/chat/completions` path is appended per request.
const OPENAI_API_BASE: &str = "https://api.openai.com";
18
19/// OpenAI Chat Completions adapter.
20pub struct OpenAiAdapter {
21    client: reqwest::Client,
22    api_key: String,
23    base_url: String,
24    default_model: String,
25}
26
27impl OpenAiAdapter {
28    pub fn new(api_key: impl Into<String>) -> Self {
29        Self {
30            client: reqwest::Client::new(),
31            api_key: api_key.into(),
32            base_url: OPENAI_API_BASE.into(),
33            default_model: DEFAULT_MODEL.into(),
34        }
35    }
36
37    /// Create adapter from `OPENAI_API_KEY` env var.
38    pub fn from_env() -> Result<Self, ModelError> {
39        let key = std::env::var("OPENAI_API_KEY")
40            .map_err(|_| ModelError::Network("OPENAI_API_KEY not set".into()))?;
41        Ok(Self::new(key))
42    }
43
44    /// Override the base URL (for OpenAI-compatible APIs).
45    pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
46        self.base_url = base_url.into();
47        self
48    }
49
50    pub fn with_default_model(mut self, model: impl Into<String>) -> Self {
51        self.default_model = model.into();
52        self
53    }
54
55    async fn call_api(&self, body: Value) -> Result<Value, ModelError> {
56        let resp = self
57            .client
58            .post(format!("{}/v1/chat/completions", self.base_url))
59            .bearer_auth(&self.api_key)
60            .json(&body)
61            .send()
62            .await
63            .map_err(|e| ModelError::Network(e.to_string()))?;
64
65        let status = resp.status().as_u16();
66        let body_text = resp
67            .text()
68            .await
69            .map_err(|e| ModelError::Network(e.to_string()))?;
70
71        if status == 429 {
72            return Err(ModelError::RateLimited {
73                retry_after_secs: 60,
74            });
75        }
76        if status != 200 {
77            return Err(ModelError::Api {
78                status,
79                body: body_text,
80            });
81        }
82
83        serde_json::from_str(&body_text).map_err(|e| ModelError::Serialization(e.to_string()))
84    }
85
86    fn build_request_body(
87        &self,
88        messages: &[ChatMessage],
89        config: &ModelConfig,
90        response_format: Option<Value>,
91    ) -> Value {
92        let model = config.model.as_deref().unwrap_or(&self.default_model);
93        let max_tokens = config.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS);
94
95        let openai_messages: Vec<Value> = messages
96            .iter()
97            .map(|m| {
98                let role = match m.role {
99                    ChatRole::System => "system",
100                    ChatRole::User => "user",
101                    ChatRole::Assistant => "assistant",
102                    ChatRole::Tool => "tool",
103                };
104                // Prepend system_prompt as system message if provided.
105                json!({ "role": role, "content": m.content })
106            })
107            .collect();
108
109        // If a system_prompt is in config, prepend it.
110        let mut final_messages = openai_messages;
111        if let Some(sys) = &config.system_prompt {
112            final_messages.insert(0, json!({ "role": "system", "content": sys }));
113        }
114
115        let mut body = json!({
116            "model": model,
117            "max_tokens": max_tokens,
118            "messages": final_messages,
119        });
120
121        if let Some(temp) = config.temperature {
122            body["temperature"] = json!(temp);
123        }
124        if let Some(stops) = &config.stop_sequences {
125            body["stop"] = json!(stops);
126        }
127        if let Some(fmt) = response_format {
128            body["response_format"] = fmt;
129        }
130
131        body
132    }
133
134    fn parse_response(&self, resp: Value) -> Result<ModelResponse, ModelError> {
135        let model = resp["model"]
136            .as_str()
137            .unwrap_or(&self.default_model)
138            .to_string();
139
140        let choice = resp["choices"]
141            .as_array()
142            .and_then(|cs| cs.first())
143            .ok_or_else(|| ModelError::Api {
144                status: 200,
145                body: "no choices in response".into(),
146            })?;
147
148        let content = choice["message"]["content"]
149            .as_str()
150            .unwrap_or("")
151            .to_string();
152
153        let finish_reason = choice["finish_reason"]
154            .as_str()
155            .unwrap_or("stop")
156            .to_string();
157        let input_tokens = resp["usage"]["prompt_tokens"].as_u64().unwrap_or(0);
158        let output_tokens = resp["usage"]["completion_tokens"].as_u64().unwrap_or(0);
159
160        Ok(ModelResponse {
161            content,
162            model,
163            finish_reason,
164            input_tokens,
165            output_tokens,
166            structured: None,
167        })
168    }
169}
170
171#[async_trait]
172impl ModelAdapter for OpenAiAdapter {
173    fn system_name(&self) -> &'static str {
174        "openai"
175    }
176
177    fn default_model(&self) -> &str {
178        &self.default_model
179    }
180
181    #[instrument(skip(self, request), fields(
182        gen_ai.system = "openai",
183        gen_ai.request.model = tracing::field::Empty,
184        gen_ai.usage.input_tokens = tracing::field::Empty,
185        gen_ai.usage.output_tokens = tracing::field::Empty,
186    ))]
187    async fn chat(&self, request: ModelRequest) -> Result<ModelResponse, ModelError> {
188        let model = request
189            .config
190            .model
191            .as_deref()
192            .unwrap_or(&self.default_model)
193            .to_string();
194        tracing::Span::current().record("gen_ai.request.model", model.as_str());
195
196        debug!(model = %model, "Calling OpenAI Chat Completions API");
197
198        let body = self.build_request_body(&request.messages, &request.config, None);
199        let resp_json = self.call_api(body).await?;
200        let response = self.parse_response(resp_json)?;
201
202        tracing::Span::current()
203            .record("gen_ai.usage.input_tokens", response.input_tokens)
204            .record("gen_ai.usage.output_tokens", response.output_tokens);
205
206        Ok(response)
207    }
208
209    #[instrument(skip(self, request), fields(
210        gen_ai.system = "openai",
211        gen_ai.request.model = tracing::field::Empty,
212    ))]
213    async fn structured_output(
214        &self,
215        request: StructuredRequest,
216    ) -> Result<ModelResponse, ModelError> {
217        let model = request
218            .config
219            .model
220            .as_deref()
221            .unwrap_or(&self.default_model)
222            .to_string();
223        tracing::Span::current().record("gen_ai.request.model", model.as_str());
224
225        // Use OpenAI's native JSON mode (response_format: json_object).
226        // For models that support json_schema, we pass the schema directly.
227        let response_format = json!({ "type": "json_object" });
228
229        let mut config = request.config.clone();
230        let schema_str = serde_json::to_string_pretty(&request.output_schema)
231            .map_err(|e| ModelError::Serialization(e.to_string()))?;
232        let system = config.system_prompt.get_or_insert_with(String::new);
233        system.push_str(&format!(
234            "\n\nRespond ONLY with a valid JSON object matching this schema:\n{schema_str}"
235        ));
236
237        let body = self.build_request_body(&request.messages, &config, Some(response_format));
238        let resp_json = self.call_api(body).await?;
239        let mut response = self.parse_response(resp_json)?;
240
241        let structured =
242            serde_json::from_str::<serde_json::Value>(&response.content).map_err(|e| {
243                ModelError::Serialization(format!("structured output parse error: {e}"))
244            })?;
245        response.structured = Some(structured);
246
247        Ok(response)
248    }
249}