1use std::sync::Arc;
2
3use async_trait::async_trait;
4use serde_json::{json, Value};
5use synaptic_core::{
6 AIMessageChunk, ChatModel, ChatRequest, ChatResponse, ChatStream, Message, SynapticError,
7 TokenUsage, ToolCall, ToolChoice, ToolDefinition,
8};
9
10use crate::backend::{ProviderBackend, ProviderRequest, ProviderResponse};
11
/// Connection and sampling settings for [`OpenAiChatModel`].
///
/// Optional sampling fields left as `None` are omitted from the request body.
#[derive(Clone)]
pub struct OpenAiConfig {
    /// Secret key, sent as a `Bearer` token in the `Authorization` header.
    pub api_key: String,
    /// Model identifier sent as the `model` field.
    pub model: String,
    /// API root; `/chat/completions` is appended to form the request URL.
    pub base_url: String,
    /// Sent as `max_tokens` when set.
    pub max_tokens: Option<u32>,
    /// Sent as `temperature` when set.
    pub temperature: Option<f64>,
    /// Sent as `top_p` when set.
    pub top_p: Option<f64>,
    /// Sent as `stop` when set.
    pub stop: Option<Vec<String>>,
    /// Sent as `seed` when set.
    pub seed: Option<u64>,
}

// `Debug` is written by hand so the API key is never printed by `{:?}` (logs,
// panic messages, error reports); every other field is shown as usual.
impl std::fmt::Debug for OpenAiConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAiConfig")
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("base_url", &self.base_url)
            .field("max_tokens", &self.max_tokens)
            .field("temperature", &self.temperature)
            .field("top_p", &self.top_p)
            .field("stop", &self.stop)
            .field("seed", &self.seed)
            .finish()
    }
}
23
24impl OpenAiConfig {
25 pub fn new(api_key: impl Into<String>, model: impl Into<String>) -> Self {
26 Self {
27 api_key: api_key.into(),
28 model: model.into(),
29 base_url: "https://api.openai.com/v1".to_string(),
30 max_tokens: None,
31 temperature: None,
32 top_p: None,
33 stop: None,
34 seed: None,
35 }
36 }
37
38 pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
39 self.base_url = url.into();
40 self
41 }
42
43 pub fn with_max_tokens(mut self, max_tokens: u32) -> Self {
44 self.max_tokens = Some(max_tokens);
45 self
46 }
47
48 pub fn with_temperature(mut self, temperature: f64) -> Self {
49 self.temperature = Some(temperature);
50 self
51 }
52
53 pub fn with_top_p(mut self, top_p: f64) -> Self {
54 self.top_p = Some(top_p);
55 self
56 }
57
58 pub fn with_stop(mut self, stop: Vec<String>) -> Self {
59 self.stop = Some(stop);
60 self
61 }
62
63 pub fn with_seed(mut self, seed: u64) -> Self {
64 self.seed = Some(seed);
65 self
66 }
67}
68
69pub struct OpenAiChatModel {
70 config: OpenAiConfig,
71 backend: Arc<dyn ProviderBackend>,
72}
73
74impl OpenAiChatModel {
75 pub fn new(config: OpenAiConfig, backend: Arc<dyn ProviderBackend>) -> Self {
76 Self { config, backend }
77 }
78
79 fn build_request(&self, request: &ChatRequest, stream: bool) -> ProviderRequest {
80 let messages: Vec<Value> = request.messages.iter().map(message_to_openai).collect();
81
82 let mut body = json!({
83 "model": self.config.model,
84 "messages": messages,
85 "stream": stream,
86 });
87
88 if let Some(max_tokens) = self.config.max_tokens {
89 body["max_tokens"] = json!(max_tokens);
90 }
91 if let Some(temp) = self.config.temperature {
92 body["temperature"] = json!(temp);
93 }
94 if let Some(top_p) = self.config.top_p {
95 body["top_p"] = json!(top_p);
96 }
97 if let Some(ref stop) = self.config.stop {
98 body["stop"] = json!(stop);
99 }
100 if let Some(seed) = self.config.seed {
101 body["seed"] = json!(seed);
102 }
103 if !request.tools.is_empty() {
104 body["tools"] = json!(request
105 .tools
106 .iter()
107 .map(tool_def_to_openai)
108 .collect::<Vec<_>>());
109 }
110 if let Some(ref choice) = request.tool_choice {
111 body["tool_choice"] = match choice {
112 ToolChoice::Auto => json!("auto"),
113 ToolChoice::Required => json!("required"),
114 ToolChoice::None => json!("none"),
115 ToolChoice::Specific(name) => json!({
116 "type": "function",
117 "function": {"name": name}
118 }),
119 };
120 }
121
122 ProviderRequest {
123 url: format!("{}/chat/completions", self.config.base_url),
124 headers: vec![
125 (
126 "Authorization".to_string(),
127 format!("Bearer {}", self.config.api_key),
128 ),
129 ("Content-Type".to_string(), "application/json".to_string()),
130 ],
131 body,
132 }
133 }
134}
135
136fn message_to_openai(msg: &Message) -> Value {
137 match msg {
138 Message::System { content, .. } => json!({
139 "role": "system",
140 "content": content,
141 }),
142 Message::Human { content, .. } => json!({
143 "role": "user",
144 "content": content,
145 }),
146 Message::AI {
147 content,
148 tool_calls,
149 ..
150 } => {
151 let mut obj = json!({
152 "role": "assistant",
153 "content": content,
154 });
155 if !tool_calls.is_empty() {
156 obj["tool_calls"] = json!(tool_calls
157 .iter()
158 .map(|tc| json!({
159 "id": tc.id,
160 "type": "function",
161 "function": {
162 "name": tc.name,
163 "arguments": tc.arguments.to_string(),
164 }
165 }))
166 .collect::<Vec<_>>());
167 }
168 obj
169 }
170 Message::Tool {
171 content,
172 tool_call_id,
173 ..
174 } => json!({
175 "role": "tool",
176 "content": content,
177 "tool_call_id": tool_call_id,
178 }),
179 Message::Chat {
180 custom_role,
181 content,
182 ..
183 } => json!({
184 "role": custom_role,
185 "content": content,
186 }),
187 Message::Remove { .. } => json!(null), }
189}
190
191fn tool_def_to_openai(def: &ToolDefinition) -> Value {
192 json!({
193 "type": "function",
194 "function": {
195 "name": def.name,
196 "description": def.description,
197 "parameters": def.parameters,
198 }
199 })
200}
201
202fn parse_response(resp: &ProviderResponse) -> Result<ChatResponse, SynapticError> {
203 check_error_status(resp)?;
204
205 let choice = &resp.body["choices"][0]["message"];
206 let content = choice["content"].as_str().unwrap_or("").to_string();
207 let tool_calls = parse_tool_calls(choice);
208
209 let usage = parse_usage(&resp.body["usage"]);
210
211 let message = if tool_calls.is_empty() {
212 Message::ai(content)
213 } else {
214 Message::ai_with_tool_calls(content, tool_calls)
215 };
216
217 Ok(ChatResponse { message, usage })
218}
219
220fn check_error_status(resp: &ProviderResponse) -> Result<(), SynapticError> {
221 if resp.status == 429 {
222 let msg = resp.body["error"]["message"]
223 .as_str()
224 .unwrap_or("rate limited")
225 .to_string();
226 return Err(SynapticError::RateLimit(msg));
227 }
228 if resp.status >= 400 {
229 let msg = resp.body["error"]["message"]
230 .as_str()
231 .unwrap_or("unknown API error")
232 .to_string();
233 return Err(SynapticError::Model(format!(
234 "OpenAI API error ({}): {}",
235 resp.status, msg
236 )));
237 }
238 Ok(())
239}
240
241fn parse_tool_calls(message: &Value) -> Vec<ToolCall> {
242 message["tool_calls"]
243 .as_array()
244 .map(|arr| {
245 arr.iter()
246 .filter_map(|tc| {
247 let id = tc["id"].as_str()?.to_string();
248 let name = tc["function"]["name"].as_str()?.to_string();
249 let args_str = tc["function"]["arguments"].as_str().unwrap_or("{}");
250 let arguments =
251 serde_json::from_str(args_str).unwrap_or(Value::Object(Default::default()));
252 Some(ToolCall {
253 id,
254 name,
255 arguments,
256 })
257 })
258 .collect()
259 })
260 .unwrap_or_default()
261}
262
263fn parse_usage(usage: &Value) -> Option<TokenUsage> {
264 if usage.is_null() {
265 return None;
266 }
267 Some(TokenUsage {
268 input_tokens: usage["prompt_tokens"].as_u64().unwrap_or(0) as u32,
269 output_tokens: usage["completion_tokens"].as_u64().unwrap_or(0) as u32,
270 total_tokens: usage["total_tokens"].as_u64().unwrap_or(0) as u32,
271 input_details: None,
272 output_details: None,
273 })
274}
275
276fn parse_stream_chunk(data: &str) -> Option<AIMessageChunk> {
277 let v: Value = serde_json::from_str(data).ok()?;
278 let delta = &v["choices"][0]["delta"];
279
280 let content = delta["content"].as_str().unwrap_or("").to_string();
281 let tool_calls = parse_tool_calls(delta);
282 let usage = parse_usage(&v["usage"]);
283
284 Some(AIMessageChunk {
285 content,
286 tool_calls,
287 usage,
288 ..Default::default()
289 })
290}
291
292#[async_trait]
293impl ChatModel for OpenAiChatModel {
294 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, SynapticError> {
295 let provider_req = self.build_request(&request, false);
296 let resp = self.backend.send(provider_req).await?;
297 parse_response(&resp)
298 }
299
300 fn stream_chat(&self, request: ChatRequest) -> ChatStream<'_> {
301 Box::pin(async_stream::stream! {
302 let provider_req = self.build_request(&request, true);
303 let byte_stream = self.backend.send_stream(provider_req).await;
304
305 let byte_stream = match byte_stream {
306 Ok(s) => s,
307 Err(e) => {
308 yield Err(e);
309 return;
310 }
311 };
312
313 use eventsource_stream::Eventsource;
314 use futures::StreamExt;
315
316 let mut event_stream = byte_stream
317 .map(|result| result.map_err(|e| std::io::Error::other(e.to_string())))
318 .eventsource();
319
320 while let Some(event) = event_stream.next().await {
321 match event {
322 Ok(ev) => {
323 if ev.data == "[DONE]" {
324 break;
325 }
326 if let Some(chunk) = parse_stream_chunk(&ev.data) {
327 yield Ok(chunk);
328 }
329 }
330 Err(e) => {
331 yield Err(SynapticError::Model(format!("SSE parse error: {e}")));
332 break;
333 }
334 }
335 }
336 })
337 }
338}