1use crate::adapter::{
8 ChatMessage, ChatRole, ModelAdapter, ModelConfig, ModelError, ModelRequest, ModelResponse,
9 StructuredRequest,
10};
11use async_trait::async_trait;
12use serde_json::{json, Value};
13use tracing::{debug, instrument};
14
/// Base URL a freshly constructed adapter sends requests to; the
/// `/v1/chat/completions` path is appended to it for every call.
const OPENAI_API_BASE: &str = "https://api.openai.com";

/// Model name a freshly constructed adapter falls back to when a request's
/// config does not name a model.
const DEFAULT_MODEL: &str = "gpt-4o";

/// Value sent as `max_tokens` when a request's config leaves it unset.
const DEFAULT_MAX_TOKENS: u32 = 4_096;
18
19pub struct OpenAiAdapter {
21 client: reqwest::Client,
22 api_key: String,
23 base_url: String,
24 default_model: String,
25}
26
27impl OpenAiAdapter {
28 pub fn new(api_key: impl Into<String>) -> Self {
29 Self {
30 client: reqwest::Client::new(),
31 api_key: api_key.into(),
32 base_url: OPENAI_API_BASE.into(),
33 default_model: DEFAULT_MODEL.into(),
34 }
35 }
36
37 pub fn from_env() -> Result<Self, ModelError> {
39 let key = std::env::var("OPENAI_API_KEY")
40 .map_err(|_| ModelError::Network("OPENAI_API_KEY not set".into()))?;
41 Ok(Self::new(key))
42 }
43
44 pub fn with_base_url(mut self, base_url: impl Into<String>) -> Self {
46 self.base_url = base_url.into();
47 self
48 }
49
50 pub fn with_default_model(mut self, model: impl Into<String>) -> Self {
51 self.default_model = model.into();
52 self
53 }
54
55 async fn call_api(&self, body: Value) -> Result<Value, ModelError> {
56 let resp = self
57 .client
58 .post(format!("{}/v1/chat/completions", self.base_url))
59 .bearer_auth(&self.api_key)
60 .json(&body)
61 .send()
62 .await
63 .map_err(|e| ModelError::Network(e.to_string()))?;
64
65 let status = resp.status().as_u16();
66 let body_text = resp
67 .text()
68 .await
69 .map_err(|e| ModelError::Network(e.to_string()))?;
70
71 if status == 429 {
72 return Err(ModelError::RateLimited {
73 retry_after_secs: 60,
74 });
75 }
76 if status != 200 {
77 return Err(ModelError::Api {
78 status,
79 body: body_text,
80 });
81 }
82
83 serde_json::from_str(&body_text).map_err(|e| ModelError::Serialization(e.to_string()))
84 }
85
86 fn build_request_body(
87 &self,
88 messages: &[ChatMessage],
89 config: &ModelConfig,
90 response_format: Option<Value>,
91 ) -> Value {
92 let model = config.model.as_deref().unwrap_or(&self.default_model);
93 let max_tokens = config.max_tokens.unwrap_or(DEFAULT_MAX_TOKENS);
94
95 let openai_messages: Vec<Value> = messages
96 .iter()
97 .map(|m| {
98 let role = match m.role {
99 ChatRole::System => "system",
100 ChatRole::User => "user",
101 ChatRole::Assistant => "assistant",
102 ChatRole::Tool => "tool",
103 };
104 json!({ "role": role, "content": m.content })
106 })
107 .collect();
108
109 let mut final_messages = openai_messages;
111 if let Some(sys) = &config.system_prompt {
112 final_messages.insert(0, json!({ "role": "system", "content": sys }));
113 }
114
115 let mut body = json!({
116 "model": model,
117 "max_tokens": max_tokens,
118 "messages": final_messages,
119 });
120
121 if let Some(temp) = config.temperature {
122 body["temperature"] = json!(temp);
123 }
124 if let Some(stops) = &config.stop_sequences {
125 body["stop"] = json!(stops);
126 }
127 if let Some(fmt) = response_format {
128 body["response_format"] = fmt;
129 }
130
131 body
132 }
133
134 fn parse_response(&self, resp: Value) -> Result<ModelResponse, ModelError> {
135 let model = resp["model"]
136 .as_str()
137 .unwrap_or(&self.default_model)
138 .to_string();
139
140 let choice = resp["choices"]
141 .as_array()
142 .and_then(|cs| cs.first())
143 .ok_or_else(|| ModelError::Api {
144 status: 200,
145 body: "no choices in response".into(),
146 })?;
147
148 let content = choice["message"]["content"]
149 .as_str()
150 .unwrap_or("")
151 .to_string();
152
153 let finish_reason = choice["finish_reason"]
154 .as_str()
155 .unwrap_or("stop")
156 .to_string();
157 let input_tokens = resp["usage"]["prompt_tokens"].as_u64().unwrap_or(0);
158 let output_tokens = resp["usage"]["completion_tokens"].as_u64().unwrap_or(0);
159
160 Ok(ModelResponse {
161 content,
162 model,
163 finish_reason,
164 input_tokens,
165 output_tokens,
166 structured: None,
167 })
168 }
169}
170
171#[async_trait]
172impl ModelAdapter for OpenAiAdapter {
173 fn system_name(&self) -> &'static str {
174 "openai"
175 }
176
177 fn default_model(&self) -> &str {
178 &self.default_model
179 }
180
181 #[instrument(skip(self, request), fields(
182 gen_ai.system = "openai",
183 gen_ai.request.model = tracing::field::Empty,
184 gen_ai.usage.input_tokens = tracing::field::Empty,
185 gen_ai.usage.output_tokens = tracing::field::Empty,
186 ))]
187 async fn chat(&self, request: ModelRequest) -> Result<ModelResponse, ModelError> {
188 let model = request
189 .config
190 .model
191 .as_deref()
192 .unwrap_or(&self.default_model)
193 .to_string();
194 tracing::Span::current().record("gen_ai.request.model", model.as_str());
195
196 debug!(model = %model, "Calling OpenAI Chat Completions API");
197
198 let body = self.build_request_body(&request.messages, &request.config, None);
199 let resp_json = self.call_api(body).await?;
200 let response = self.parse_response(resp_json)?;
201
202 tracing::Span::current()
203 .record("gen_ai.usage.input_tokens", response.input_tokens)
204 .record("gen_ai.usage.output_tokens", response.output_tokens);
205
206 Ok(response)
207 }
208
209 #[instrument(skip(self, request), fields(
210 gen_ai.system = "openai",
211 gen_ai.request.model = tracing::field::Empty,
212 ))]
213 async fn structured_output(
214 &self,
215 request: StructuredRequest,
216 ) -> Result<ModelResponse, ModelError> {
217 let model = request
218 .config
219 .model
220 .as_deref()
221 .unwrap_or(&self.default_model)
222 .to_string();
223 tracing::Span::current().record("gen_ai.request.model", model.as_str());
224
225 let response_format = json!({ "type": "json_object" });
228
229 let mut config = request.config.clone();
230 let schema_str = serde_json::to_string_pretty(&request.output_schema)
231 .map_err(|e| ModelError::Serialization(e.to_string()))?;
232 let system = config.system_prompt.get_or_insert_with(String::new);
233 system.push_str(&format!(
234 "\n\nRespond ONLY with a valid JSON object matching this schema:\n{schema_str}"
235 ));
236
237 let body = self.build_request_body(&request.messages, &config, Some(response_format));
238 let resp_json = self.call_api(body).await?;
239 let mut response = self.parse_response(resp_json)?;
240
241 let structured =
242 serde_json::from_str::<serde_json::Value>(&response.content).map_err(|e| {
243 ModelError::Serialization(format!("structured output parse error: {e}"))
244 })?;
245 response.structured = Some(structured);
246
247 Ok(response)
248 }
249}