1pub mod anthropic;
8pub mod gemini;
9
10use async_trait::async_trait;
11use serde_json::Value;
12use std::collections::HashMap;
13
14use crate::error::Error;
15use crate::protocol::v2::capabilities::Capability;
16use crate::protocol::v2::manifest::ApiStyle;
17use crate::protocol::ProtocolError;
18use crate::types::events::StreamingEvent;
19use crate::types::execution_result::ExecutionUsage;
20use crate::types::message::{Message, MessageContent};
21
22pub use anthropic::AnthropicDriver;
23pub use gemini::GeminiDriver;
24
/// A fully prepared HTTP request for a provider API, produced by
/// [`ProviderDriver::build_request`].
#[derive(Debug, Clone)]
pub struct DriverRequest {
    /// Target endpoint URL. May be left empty by a driver (see
    /// `OpenAiDriver::build_request`); presumably the transport layer fills
    /// in the provider base URL — confirm against the caller.
    pub url: String,
    /// HTTP method, e.g. "POST".
    pub method: String,
    /// Extra HTTP headers to send with the request (auth, content type, ...).
    pub headers: HashMap<String, String>,
    /// JSON request body in the provider's wire format.
    pub body: Value,
    /// Whether the request expects a streaming (SSE) response.
    pub stream: bool,
}
39
/// Provider-agnostic view of a completed (non-streaming) API response,
/// produced by [`ProviderDriver::parse_response`].
#[derive(Debug, Clone)]
pub struct DriverResponse {
    /// Assistant text content, if the response carried any.
    pub content: Option<String>,
    /// Provider-reported finish reason (e.g. "stop"), if present.
    pub finish_reason: Option<String>,
    /// Token accounting, if the provider reported usage.
    pub usage: Option<UsageInfo>,
    /// Raw tool-call objects exactly as the provider returned them.
    pub tool_calls: Vec<Value>,
    /// The full, unmodified response body for callers that need
    /// provider-specific fields.
    pub raw: Value,
}
54
/// Token usage accounting in a provider-neutral shape.
///
/// Convertible into [`ExecutionUsage`] via `From`.
#[derive(Debug, Clone, Default, serde::Serialize)]
pub struct UsageInfo {
    /// Tokens consumed by the prompt/input.
    pub prompt_tokens: u64,
    /// Tokens produced in the completion/output.
    pub completion_tokens: u64,
    /// Provider-reported total (not necessarily prompt + completion).
    pub total_tokens: u64,
    /// Reasoning/thinking tokens, when the provider reports them separately.
    pub reasoning_tokens: Option<u64>,
    /// Prompt tokens served from a provider-side cache, if reported.
    pub cache_read_tokens: Option<u64>,
    /// Tokens written into a provider-side cache, if reported.
    pub cache_creation_tokens: Option<u64>,
}
67
68impl From<UsageInfo> for ExecutionUsage {
69 fn from(u: UsageInfo) -> Self {
70 Self {
71 prompt_tokens: u.prompt_tokens,
72 completion_tokens: u.completion_tokens,
73 total_tokens: u.total_tokens,
74 reasoning_tokens: u.reasoning_tokens,
75 cache_read_tokens: u.cache_read_tokens,
76 cache_creation_tokens: u.cache_creation_tokens,
77 }
78 }
79}
80
/// Abstraction over one provider's wire protocol: request construction,
/// response parsing, and streaming-event decoding.
///
/// Implementations are stateless translators; they do not perform I/O
/// themselves (the returned [`DriverRequest`] is executed elsewhere).
#[async_trait]
pub trait ProviderDriver: Send + Sync + std::fmt::Debug {
    /// Stable identifier of the provider this driver speaks for.
    fn provider_id(&self) -> &str;

    /// The wire-format family this driver implements.
    fn api_style(&self) -> ApiStyle;

    /// Translates generic chat messages and options into a provider-specific
    /// HTTP request. `extra` carries provider-specific overrides.
    fn build_request(
        &self,
        messages: &[Message],
        model: &str,
        temperature: Option<f64>,
        max_tokens: Option<u32>,
        stream: bool,
        extra: Option<&Value>,
    ) -> Result<DriverRequest, Error>;

    /// Parses a complete (non-streaming) response body.
    fn parse_response(&self, body: &Value) -> Result<DriverResponse, Error>;

    /// Parses one SSE `data:` payload into a streaming event; returns
    /// `Ok(None)` when the payload carries no event (empty or terminator).
    fn parse_stream_event(&self, data: &str) -> Result<Option<StreamingEvent>, Error>;

    /// Capabilities this driver advertises for its provider.
    fn supported_capabilities(&self) -> &[Capability];

    /// Returns `true` when `data` is the provider's end-of-stream marker.
    fn is_stream_done(&self, data: &str) -> bool;
}
122
/// Driver for OpenAI's Chat Completions wire format, also used for
/// OpenAI-compatible providers (see [`create_driver`]).
#[derive(Debug)]
pub struct OpenAiDriver {
    // Identifier reported via ProviderDriver::provider_id.
    provider_id: String,
    // Capability set reported via ProviderDriver::supported_capabilities.
    capabilities: Vec<Capability>,
}
129
130impl OpenAiDriver {
131 pub fn new(provider_id: impl Into<String>, capabilities: Vec<Capability>) -> Self {
132 Self {
133 provider_id: provider_id.into(),
134 capabilities,
135 }
136 }
137}
138
139fn parse_openai_usage_value(u: &Value) -> UsageInfo {
140 let reasoning = u
141 .pointer("/completion_tokens_details/reasoning_tokens")
142 .and_then(|v| v.as_u64());
143 UsageInfo {
144 prompt_tokens: u["prompt_tokens"].as_u64().unwrap_or(0),
145 completion_tokens: u["completion_tokens"].as_u64().unwrap_or(0),
146 total_tokens: u["total_tokens"].as_u64().unwrap_or(0),
147 reasoning_tokens: reasoning,
148 cache_read_tokens: None,
149 cache_creation_tokens: None,
150 }
151}
152
153#[async_trait]
154impl ProviderDriver for OpenAiDriver {
155 fn provider_id(&self) -> &str {
156 &self.provider_id
157 }
158
159 fn api_style(&self) -> ApiStyle {
160 ApiStyle::OpenAiCompatible
161 }
162
163 fn build_request(
164 &self,
165 messages: &[Message],
166 model: &str,
167 temperature: Option<f64>,
168 max_tokens: Option<u32>,
169 stream: bool,
170 extra: Option<&Value>,
171 ) -> Result<DriverRequest, Error> {
172 let oai_messages: Vec<Value> = messages
173 .iter()
174 .map(|m| {
175 let role = serde_json::to_value(&m.role).unwrap_or(Value::String("user".into()));
176 let content = match &m.content {
177 MessageContent::Text(s) => Value::String(s.clone()),
178 MessageContent::Blocks(_) => {
179 serde_json::to_value(&m.content).unwrap_or(Value::Null)
180 }
181 };
182 let mut obj = serde_json::json!({ "role": role, "content": content });
183 if matches!(m.role, crate::types::message::MessageRole::Tool) {
185 if let Some(ref id) = m.tool_call_id {
186 obj["tool_call_id"] = Value::String(id.clone());
187 }
188 }
189 obj
190 })
191 .collect();
192
193 let mut body = serde_json::json!({
194 "model": model,
195 "messages": oai_messages,
196 "stream": stream,
197 });
198
199 if let Some(t) = temperature {
200 body["temperature"] = serde_json::json!(t);
201 }
202 if let Some(mt) = max_tokens {
203 body["max_tokens"] = serde_json::json!(mt);
204 }
205 if let Some(Value::Object(map)) = extra {
206 for (k, v) in map {
207 body[k] = v.clone();
208 }
209 }
210
211 Ok(DriverRequest {
212 url: String::new(), method: "POST".into(),
214 headers: HashMap::new(),
215 body,
216 stream,
217 })
218 }
219
220 fn parse_response(&self, body: &Value) -> Result<DriverResponse, Error> {
221 let content = body
222 .pointer("/choices/0/message/content")
223 .and_then(|v| v.as_str())
224 .map(String::from);
225 let finish_reason = body
226 .pointer("/choices/0/finish_reason")
227 .and_then(|v| v.as_str())
228 .map(String::from);
229 let usage = body.get("usage").map(parse_openai_usage_value);
230 let tool_calls = body
231 .pointer("/choices/0/message/tool_calls")
232 .and_then(|v| v.as_array())
233 .cloned()
234 .unwrap_or_default();
235
236 Ok(DriverResponse {
237 content,
238 finish_reason,
239 usage,
240 tool_calls,
241 raw: body.clone(),
242 })
243 }
244
245 fn parse_stream_event(&self, data: &str) -> Result<Option<StreamingEvent>, Error> {
246 if data.trim().is_empty() || self.is_stream_done(data) {
247 return Ok(None);
248 }
249 let v: Value = serde_json::from_str(data).map_err(|e| {
250 Error::Protocol(ProtocolError::ValidationError(format!(
251 "Failed to parse SSE data: {}",
252 e
253 )))
254 })?;
255
256 if let Some(content) = v
258 .pointer("/choices/0/delta/content")
259 .and_then(|c| c.as_str())
260 {
261 if !content.is_empty() {
262 return Ok(Some(StreamingEvent::PartialContentDelta {
263 content: content.to_string(),
264 sequence_id: None,
265 }));
266 }
267 }
268
269 if let Some(thinking) = v
271 .pointer("/choices/0/delta/reasoning_content")
272 .and_then(|c| c.as_str())
273 {
274 if !thinking.is_empty() {
275 return Ok(Some(StreamingEvent::ThinkingDelta {
276 thinking: thinking.to_string(),
277 tool_consideration: None,
278 }));
279 }
280 }
281
282 if let Some(reason) = v
284 .pointer("/choices/0/finish_reason")
285 .and_then(|r| r.as_str())
286 {
287 return Ok(Some(StreamingEvent::StreamEnd {
288 finish_reason: Some(reason.to_string()),
289 }));
290 }
291
292 Ok(None)
293 }
294
295 fn supported_capabilities(&self) -> &[Capability] {
296 &self.capabilities
297 }
298
299 fn is_stream_done(&self, data: &str) -> bool {
300 data.trim() == "[DONE]"
301 }
302}
303
304pub fn create_driver(
310 api_style: ApiStyle,
311 provider_id: &str,
312 capabilities: Vec<Capability>,
313) -> Box<dyn ProviderDriver> {
314 match api_style {
315 ApiStyle::OpenAiCompatible | ApiStyle::Custom => {
316 Box::new(OpenAiDriver::new(provider_id, capabilities))
317 }
318 ApiStyle::AnthropicMessages => Box::new(AnthropicDriver::new(provider_id, capabilities)),
319 ApiStyle::GeminiGenerate => Box::new(GeminiDriver::new(provider_id, capabilities)),
320 }
321}
322
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_openai_driver_build_request() {
        let driver = OpenAiDriver::new("openai", vec![Capability::Text, Capability::Streaming]);
        let msgs = vec![Message::user("Hello")];
        let req = driver
            .build_request(&msgs, "gpt-4", Some(0.7), Some(1024), true, None)
            .unwrap();
        assert!(req.stream);
        assert_eq!(req.body["model"], "gpt-4");
        assert_eq!(req.body["temperature"], 0.7);
    }

    #[test]
    fn test_openai_driver_parse_response() {
        let body = serde_json::json!({
            "choices": [{"message": {"content": "Hi there!"}, "finish_reason": "stop"}],
            "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
        });
        let resp = OpenAiDriver::new("openai", vec![]).parse_response(&body).unwrap();
        assert_eq!(resp.content.as_deref(), Some("Hi there!"));
        assert_eq!(resp.finish_reason.as_deref(), Some("stop"));
        assert_eq!(resp.usage.unwrap().total_tokens, 15);
    }

    #[test]
    fn test_openai_driver_parse_response_reasoning_tokens() {
        let body = serde_json::json!({
            "choices": [{"message": {"content": "Hello, world!"}, "finish_reason": "stop"}],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15,
                "completion_tokens_details": {"reasoning_tokens": 3}
            }
        });
        let resp = OpenAiDriver::new("openai", vec![]).parse_response(&body).unwrap();
        let usage = resp.usage.expect("usage");
        assert_eq!(usage.reasoning_tokens, Some(3));
        assert_eq!(usage.prompt_tokens, 10);
        assert_eq!(usage.completion_tokens, 5);
    }

    #[test]
    fn test_openai_driver_parse_stream_reasoning_delta() {
        let payload = r#"{"choices":[{"delta":{"reasoning_content":"Let me think..."},"index":0}]}"#;
        let event = OpenAiDriver::new("openai", vec![])
            .parse_stream_event(payload)
            .unwrap();
        if let Some(StreamingEvent::ThinkingDelta { thinking, .. }) = event {
            assert_eq!(thinking, "Let me think...");
        } else {
            panic!("Expected ThinkingDelta, got {:?}", event);
        }
    }

    #[test]
    fn test_openai_driver_parse_stream() {
        let payload = r#"{"choices":[{"delta":{"content":"Hello"},"index":0}]}"#;
        let event = OpenAiDriver::new("openai", vec![])
            .parse_stream_event(payload)
            .unwrap();
        if let Some(StreamingEvent::PartialContentDelta { content, .. }) = event {
            assert_eq!(content, "Hello");
        } else {
            panic!("Expected PartialContentDelta");
        }
    }

    #[test]
    fn test_stream_done_detection() {
        let driver = OpenAiDriver::new("openai", vec![]);
        assert!(driver.is_stream_done("[DONE]"));
        assert!(!driver.is_stream_done(r#"{"choices":[]}"#));
    }
}