1pub mod anthropic;
8pub mod gemini;
9
10use async_trait::async_trait;
11use serde_json::Value;
12use std::collections::HashMap;
13
14use crate::error::Error;
15use crate::protocol::v2::manifest::ApiStyle;
16use crate::protocol::v2::capabilities::Capability;
17use crate::protocol::ProtocolError;
18use crate::types::events::StreamingEvent;
19use crate::types::message::{Message, MessageContent};
20
21pub use anthropic::AnthropicDriver;
22pub use gemini::GeminiDriver;
23
/// A fully prepared HTTP request in a provider's wire format, produced by a
/// [`ProviderDriver`] and executed by a transport layer elsewhere.
#[derive(Debug, Clone)]
pub struct DriverRequest {
    /// Target endpoint URL. Drivers may leave this empty (the bundled
    /// `OpenAiDriver` does) for the caller to fill in.
    pub url: String,
    /// HTTP method, e.g. `"POST"`.
    pub method: String,
    /// Request headers. NOTE(review): auth headers appear to be added by the
    /// caller, not the driver — confirm against the transport layer.
    pub headers: HashMap<String, String>,
    /// JSON request body, already in the provider's format.
    pub body: Value,
    /// Whether the request expects a streaming (SSE) response.
    pub stream: bool,
}
38
/// A provider response normalized out of the wire format by a
/// [`ProviderDriver::parse_response`] implementation.
#[derive(Debug, Clone)]
pub struct DriverResponse {
    /// Assistant text content, if the response carried any.
    pub content: Option<String>,
    /// Provider-reported finish reason (e.g. `"stop"`), if present.
    pub finish_reason: Option<String>,
    /// Token accounting, if the provider reported it.
    pub usage: Option<UsageInfo>,
    /// Raw tool-call objects exactly as the provider sent them.
    pub tool_calls: Vec<Value>,
    /// The complete, unmodified response body for callers that need
    /// provider-specific fields.
    pub raw: Value,
}
53
/// Token usage accounting for a single request/response exchange.
#[derive(Debug, Clone, Default)]
pub struct UsageInfo {
    /// Tokens consumed by the input (prompt) side.
    pub prompt_tokens: u64,
    /// Tokens generated in the completion.
    pub completion_tokens: u64,
    /// Provider-reported total; not recomputed from the other two fields.
    pub total_tokens: u64,
}
61
/// Translates between this crate's neutral message/event types and one
/// provider's wire format. Implementations are pure request builders and
/// response parsers — no I/O happens here — and must be thread-safe
/// (`Send + Sync`).
#[async_trait]
pub trait ProviderDriver: Send + Sync + std::fmt::Debug {
    /// Stable identifier of the provider this driver targets (e.g. `"openai"`).
    fn provider_id(&self) -> &str;

    /// The wire-protocol family this driver speaks.
    fn api_style(&self) -> ApiStyle;

    /// Builds a complete [`DriverRequest`] for the given conversation.
    ///
    /// `extra` carries provider-specific request fields; how (and whether)
    /// it is merged into the body is driver-defined.
    fn build_request(
        &self,
        messages: &[Message],
        model: &str,
        temperature: Option<f64>,
        max_tokens: Option<u32>,
        stream: bool,
        extra: Option<&Value>,
    ) -> Result<DriverRequest, Error>;

    /// Parses a complete (non-streaming) JSON response body.
    fn parse_response(&self, body: &Value) -> Result<DriverResponse, Error>;

    /// Parses one SSE `data:` payload into a [`StreamingEvent`].
    /// Returns `Ok(None)` for payloads that carry no event (empty lines,
    /// the terminal sentinel, or chunks with nothing to surface).
    fn parse_stream_event(&self, data: &str) -> Result<Option<StreamingEvent>, Error>;

    /// Capabilities this driver/provider combination supports.
    fn supported_capabilities(&self) -> &[Capability];

    /// Returns `true` if `data` is the provider's end-of-stream sentinel
    /// (e.g. `[DONE]` for OpenAI-compatible APIs).
    fn is_stream_done(&self, data: &str) -> bool;
}
103
/// Driver for OpenAI-compatible chat-completions APIs
/// ([`ApiStyle::OpenAiCompatible`]); also serves as the fallback for
/// [`ApiStyle::Custom`] in [`create_driver`].
#[derive(Debug)]
pub struct OpenAiDriver {
    /// Identifier reported by [`ProviderDriver::provider_id`].
    provider_id: String,
    /// Capabilities reported by [`ProviderDriver::supported_capabilities`].
    capabilities: Vec<Capability>,
}
110
111impl OpenAiDriver {
112 pub fn new(provider_id: impl Into<String>, capabilities: Vec<Capability>) -> Self {
113 Self {
114 provider_id: provider_id.into(),
115 capabilities,
116 }
117 }
118}
119
120#[async_trait]
121impl ProviderDriver for OpenAiDriver {
122 fn provider_id(&self) -> &str {
123 &self.provider_id
124 }
125
126 fn api_style(&self) -> ApiStyle {
127 ApiStyle::OpenAiCompatible
128 }
129
130 fn build_request(
131 &self,
132 messages: &[Message],
133 model: &str,
134 temperature: Option<f64>,
135 max_tokens: Option<u32>,
136 stream: bool,
137 extra: Option<&Value>,
138 ) -> Result<DriverRequest, Error> {
139 let oai_messages: Vec<Value> = messages
140 .iter()
141 .map(|m| {
142 let role = serde_json::to_value(&m.role).unwrap_or(Value::String("user".into()));
143 let content = match &m.content {
144 MessageContent::Text(s) => Value::String(s.clone()),
145 MessageContent::Blocks(_) => {
146 serde_json::to_value(&m.content).unwrap_or(Value::Null)
147 }
148 };
149 let mut obj = serde_json::json!({ "role": role, "content": content });
150 if matches!(m.role, crate::types::message::MessageRole::Tool) {
152 if let Some(ref id) = m.tool_call_id {
153 obj["tool_call_id"] = Value::String(id.clone());
154 }
155 }
156 obj
157 })
158 .collect();
159
160 let mut body = serde_json::json!({
161 "model": model,
162 "messages": oai_messages,
163 "stream": stream,
164 });
165
166 if let Some(t) = temperature {
167 body["temperature"] = serde_json::json!(t);
168 }
169 if let Some(mt) = max_tokens {
170 body["max_tokens"] = serde_json::json!(mt);
171 }
172 if let Some(ext) = extra {
173 if let Value::Object(map) = ext {
174 for (k, v) in map {
175 body[k] = v.clone();
176 }
177 }
178 }
179
180 Ok(DriverRequest {
181 url: String::new(), method: "POST".into(),
183 headers: HashMap::new(),
184 body,
185 stream,
186 })
187 }
188
189 fn parse_response(&self, body: &Value) -> Result<DriverResponse, Error> {
190 let content = body
191 .pointer("/choices/0/message/content")
192 .and_then(|v| v.as_str())
193 .map(String::from);
194 let finish_reason = body
195 .pointer("/choices/0/finish_reason")
196 .and_then(|v| v.as_str())
197 .map(String::from);
198 let usage = body.get("usage").map(|u| UsageInfo {
199 prompt_tokens: u["prompt_tokens"].as_u64().unwrap_or(0),
200 completion_tokens: u["completion_tokens"].as_u64().unwrap_or(0),
201 total_tokens: u["total_tokens"].as_u64().unwrap_or(0),
202 });
203 let tool_calls = body
204 .pointer("/choices/0/message/tool_calls")
205 .and_then(|v| v.as_array())
206 .cloned()
207 .unwrap_or_default();
208
209 Ok(DriverResponse {
210 content,
211 finish_reason,
212 usage,
213 tool_calls,
214 raw: body.clone(),
215 })
216 }
217
218 fn parse_stream_event(&self, data: &str) -> Result<Option<StreamingEvent>, Error> {
219 if data.trim().is_empty() || self.is_stream_done(data) {
220 return Ok(None);
221 }
222 let v: Value = serde_json::from_str(data)
223 .map_err(|e| Error::Protocol(ProtocolError::ValidationError(
224 format!("Failed to parse SSE data: {}", e),
225 )))?;
226
227 if let Some(content) = v.pointer("/choices/0/delta/content").and_then(|c| c.as_str()) {
229 if !content.is_empty() {
230 return Ok(Some(StreamingEvent::PartialContentDelta {
231 content: content.to_string(),
232 sequence_id: None,
233 }));
234 }
235 }
236
237 if let Some(reason) = v.pointer("/choices/0/finish_reason").and_then(|r| r.as_str()) {
239 return Ok(Some(StreamingEvent::StreamEnd {
240 finish_reason: Some(reason.to_string()),
241 }));
242 }
243
244 Ok(None)
245 }
246
247 fn supported_capabilities(&self) -> &[Capability] {
248 &self.capabilities
249 }
250
251 fn is_stream_done(&self, data: &str) -> bool {
252 data.trim() == "[DONE]"
253 }
254}
255
/// Constructs the driver matching `api_style`.
///
/// [`ApiStyle::Custom`] is routed to the OpenAI-compatible driver — the
/// fallback wire format for providers without a dedicated driver.
pub fn create_driver(
    api_style: ApiStyle,
    provider_id: &str,
    capabilities: Vec<Capability>,
) -> Box<dyn ProviderDriver> {
    match api_style {
        ApiStyle::OpenAiCompatible | ApiStyle::Custom => {
            Box::new(OpenAiDriver::new(provider_id, capabilities))
        }
        ApiStyle::AnthropicMessages => {
            Box::new(AnthropicDriver::new(provider_id, capabilities))
        }
        ApiStyle::GeminiGenerate => {
            Box::new(GeminiDriver::new(provider_id, capabilities))
        }
    }
}
278
#[cfg(test)]
mod tests {
    use super::*;

    /// Request building populates model, temperature, and the stream flag.
    #[test]
    fn test_openai_driver_build_request() {
        let driver = OpenAiDriver::new("openai", vec![Capability::Text, Capability::Streaming]);
        let msgs = vec![Message::user("Hello")];
        let req = driver
            .build_request(&msgs, "gpt-4", Some(0.7), Some(1024), true, None)
            .unwrap();
        assert_eq!(req.body["model"], "gpt-4");
        assert_eq!(req.body["temperature"], 0.7);
        assert!(req.stream);
    }

    /// A complete response yields content, finish reason, and usage totals.
    #[test]
    fn test_openai_driver_parse_response() {
        let driver = OpenAiDriver::new("openai", vec![]);
        let body = serde_json::json!({
            "choices": [{"message": {"content": "Hi there!"}, "finish_reason": "stop"}],
            "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
        });
        let parsed = driver.parse_response(&body).unwrap();
        assert_eq!(parsed.content.as_deref(), Some("Hi there!"));
        assert_eq!(parsed.finish_reason.as_deref(), Some("stop"));
        assert_eq!(parsed.usage.unwrap().total_tokens, 15);
    }

    /// A delta chunk surfaces as a PartialContentDelta event.
    #[test]
    fn test_openai_driver_parse_stream() {
        let driver = OpenAiDriver::new("openai", vec![]);
        let data = r#"{"choices":[{"delta":{"content":"Hello"},"index":0}]}"#;
        let event = driver.parse_stream_event(data).unwrap();
        if let Some(StreamingEvent::PartialContentDelta { content, .. }) = event {
            assert_eq!(content, "Hello");
        } else {
            panic!("Expected PartialContentDelta");
        }
    }

    /// Only the exact `[DONE]` sentinel terminates the stream.
    #[test]
    fn test_stream_done_detection() {
        let driver = OpenAiDriver::new("openai", vec![]);
        assert!(driver.is_stream_done("[DONE]"));
        assert!(!driver.is_stream_done(r#"{"choices":[]}"#));
    }
}