// ai_lib_core/drivers/mod.rs
1//! Provider 驱动抽象层 — 通过 trait 实现多厂商 API 适配的动态分发
2//!
3//! Provider driver abstraction layer implementing the ProviderContract specification.
4//! Uses `Box<dyn ProviderDriver>` for runtime polymorphism, enabling the same client
5//! code to work with OpenAI, Anthropic, Gemini, and any OpenAI-compatible provider.
6
7pub mod anthropic;
8pub mod gemini;
9
10use async_trait::async_trait;
11use serde_json::Value;
12use std::collections::HashMap;
13
14use crate::error::Error;
15use crate::protocol::v2::capabilities::Capability;
16use crate::protocol::v2::manifest::ApiStyle;
17use crate::protocol::ProtocolError;
18use crate::types::events::StreamingEvent;
19use crate::types::execution_result::ExecutionUsage;
20use crate::types::message::{Message, MessageContent};
21
22pub use anthropic::AnthropicDriver;
23pub use gemini::GeminiDriver;
24
25/// Unified HTTP request representation for provider communication.
/// Unified HTTP request representation for provider communication.
///
/// Produced by [`ProviderDriver::build_request`]; drivers may leave `url` and
/// `headers` empty, in which case the client layer fills them in from the
/// provider manifest before sending.
#[derive(Debug, Clone)]
pub struct DriverRequest {
    /// Target URL (base_url + chat_path). May be empty until the client layer sets it.
    pub url: String,
    /// HTTP method (POST for chat, GET for models).
    pub method: String,
    /// Request headers.
    pub headers: HashMap<String, String>,
    /// Serialized JSON request body.
    pub body: Value,
    /// Whether streaming is requested (mirrors the `stream` field in `body`).
    pub stream: bool,
}
39
40/// Unified chat response from provider.
/// Unified chat response from provider.
///
/// Returned by [`ProviderDriver::parse_response`] after normalizing a
/// provider-specific JSON body into a common shape.
#[derive(Debug, Clone)]
pub struct DriverResponse {
    /// Extracted text content; `None` when the response carried no text.
    pub content: Option<String>,
    /// Finish reason normalized to AI-Protocol standard.
    pub finish_reason: Option<String>,
    /// Token usage statistics, when the provider reported them.
    pub usage: Option<UsageInfo>,
    /// Tool calls if any (raw provider JSON objects; empty when absent).
    pub tool_calls: Vec<Value>,
    /// Raw provider response for debugging.
    pub raw: Value,
}
54
55/// Token usage information.
/// Token usage information.
///
/// Counter fields default to 0 when a provider omits them; the optional
/// detail fields stay `None` unless the provider reports them.
#[derive(Debug, Clone, Default, serde::Serialize)]
pub struct UsageInfo {
    /// Tokens consumed by the prompt/input.
    pub prompt_tokens: u64,
    /// Tokens produced in the completion/output.
    pub completion_tokens: u64,
    /// Total as reported by the provider (not recomputed locally).
    pub total_tokens: u64,
    /// OpenAI-style `completion_tokens_details.reasoning_tokens`.
    pub reasoning_tokens: Option<u64>,
    /// Anthropic cache read / creation input tokens (when present).
    pub cache_read_tokens: Option<u64>,
    pub cache_creation_tokens: Option<u64>,
}
67
68impl From<UsageInfo> for ExecutionUsage {
69    fn from(u: UsageInfo) -> Self {
70        Self {
71            prompt_tokens: u.prompt_tokens,
72            completion_tokens: u.completion_tokens,
73            total_tokens: u.total_tokens,
74            reasoning_tokens: u.reasoning_tokens,
75            cache_read_tokens: u.cache_read_tokens,
76            cache_creation_tokens: u.cache_creation_tokens,
77        }
78    }
79}
80
81/// Core trait for provider-specific API adaptation.
82///
83/// Each provider API style (OpenAI, Anthropic, Gemini) has a concrete implementation.
84/// The trait is object-safe and supports dynamic dispatch via `Box<dyn ProviderDriver>`.
85///
86/// # Design Notes
87///
88/// Inspired by `sqlx::Database` — the trait defines the contract, concrete types
89/// implement the transformations. The runtime selects the correct driver based on
90/// the manifest's `api_style` or `provider_contract`.
/// Core trait for provider-specific API adaptation.
///
/// Each provider API style (OpenAI, Anthropic, Gemini) has a concrete implementation.
/// The trait is object-safe and supports dynamic dispatch via `Box<dyn ProviderDriver>`.
///
/// # Design Notes
///
/// Inspired by `sqlx::Database` — the trait defines the contract, concrete types
/// implement the transformations. The runtime selects the correct driver based on
/// the manifest's `api_style` or `provider_contract`.
#[async_trait]
pub trait ProviderDriver: Send + Sync + std::fmt::Debug {
    /// Unique provider identifier (matches manifest `id`).
    fn provider_id(&self) -> &str;

    /// API style this driver implements.
    fn api_style(&self) -> ApiStyle;

    /// Build a provider-specific HTTP request from unified parameters.
    ///
    /// `extra` is an optional JSON object of provider-specific parameters;
    /// drivers merge its top-level keys into the generated request body
    /// (see the OpenAI driver, where `extra` keys override generated ones).
    fn build_request(
        &self,
        messages: &[Message],
        model: &str,
        temperature: Option<f64>,
        max_tokens: Option<u32>,
        stream: bool,
        extra: Option<&Value>,
    ) -> Result<DriverRequest, Error>;

    /// Parse a non-streaming response into unified format.
    fn parse_response(&self, body: &Value) -> Result<DriverResponse, Error>;

    /// Parse a single streaming event from raw SSE/NDJSON data.
    ///
    /// Returns `Ok(None)` when the payload carries no event worth surfacing
    /// (e.g. blank keep-alive lines or the stream-done marker).
    fn parse_stream_event(&self, data: &str) -> Result<Option<StreamingEvent>, Error>;

    /// Get the list of capabilities this driver supports.
    fn supported_capabilities(&self) -> &[Capability];

    /// Check if the done signal has been received in streaming.
    fn is_stream_done(&self, data: &str) -> bool;
}
122
123/// OpenAI-compatible driver — works for OpenAI, DeepSeek, Moonshot, Zhipu, etc.
/// OpenAI-compatible driver — works for OpenAI, DeepSeek, Moonshot, Zhipu, etc.
///
/// A single implementation parameterized by provider id and capability list,
/// since OpenAI-compatible providers share the same wire format.
#[derive(Debug)]
pub struct OpenAiDriver {
    /// Identifier reported by `provider_id()` (matches manifest `id`).
    provider_id: String,
    /// Capability list returned verbatim by `supported_capabilities()`.
    capabilities: Vec<Capability>,
}
129
130impl OpenAiDriver {
131    pub fn new(provider_id: impl Into<String>, capabilities: Vec<Capability>) -> Self {
132        Self {
133            provider_id: provider_id.into(),
134            capabilities,
135        }
136    }
137}
138
139fn parse_openai_usage_value(u: &Value) -> UsageInfo {
140    let reasoning = u
141        .pointer("/completion_tokens_details/reasoning_tokens")
142        .and_then(|v| v.as_u64());
143    UsageInfo {
144        prompt_tokens: u["prompt_tokens"].as_u64().unwrap_or(0),
145        completion_tokens: u["completion_tokens"].as_u64().unwrap_or(0),
146        total_tokens: u["total_tokens"].as_u64().unwrap_or(0),
147        reasoning_tokens: reasoning,
148        cache_read_tokens: None,
149        cache_creation_tokens: None,
150    }
151}
152
153#[async_trait]
154impl ProviderDriver for OpenAiDriver {
155    fn provider_id(&self) -> &str {
156        &self.provider_id
157    }
158
159    fn api_style(&self) -> ApiStyle {
160        ApiStyle::OpenAiCompatible
161    }
162
163    fn build_request(
164        &self,
165        messages: &[Message],
166        model: &str,
167        temperature: Option<f64>,
168        max_tokens: Option<u32>,
169        stream: bool,
170        extra: Option<&Value>,
171    ) -> Result<DriverRequest, Error> {
172        let oai_messages: Vec<Value> = messages
173            .iter()
174            .map(|m| {
175                let role = serde_json::to_value(&m.role).unwrap_or(Value::String("user".into()));
176                let content = match &m.content {
177                    MessageContent::Text(s) => Value::String(s.clone()),
178                    MessageContent::Blocks(_) => {
179                        serde_json::to_value(&m.content).unwrap_or(Value::Null)
180                    }
181                };
182                let mut obj = serde_json::json!({ "role": role, "content": content });
183                // OpenAI API requires tool_call_id for role "tool"
184                if matches!(m.role, crate::types::message::MessageRole::Tool) {
185                    if let Some(ref id) = m.tool_call_id {
186                        obj["tool_call_id"] = Value::String(id.clone());
187                    }
188                }
189                obj
190            })
191            .collect();
192
193        let mut body = serde_json::json!({
194            "model": model,
195            "messages": oai_messages,
196            "stream": stream,
197        });
198
199        if let Some(t) = temperature {
200            body["temperature"] = serde_json::json!(t);
201        }
202        if let Some(mt) = max_tokens {
203            body["max_tokens"] = serde_json::json!(mt);
204        }
205        if let Some(Value::Object(map)) = extra {
206            for (k, v) in map {
207                body[k] = v.clone();
208            }
209        }
210
211        Ok(DriverRequest {
212            url: String::new(), // URL is set by the client layer from manifest
213            method: "POST".into(),
214            headers: HashMap::new(),
215            body,
216            stream,
217        })
218    }
219
220    fn parse_response(&self, body: &Value) -> Result<DriverResponse, Error> {
221        let content = body
222            .pointer("/choices/0/message/content")
223            .and_then(|v| v.as_str())
224            .map(String::from);
225        let finish_reason = body
226            .pointer("/choices/0/finish_reason")
227            .and_then(|v| v.as_str())
228            .map(String::from);
229        let usage = body.get("usage").map(parse_openai_usage_value);
230        let tool_calls = body
231            .pointer("/choices/0/message/tool_calls")
232            .and_then(|v| v.as_array())
233            .cloned()
234            .unwrap_or_default();
235
236        Ok(DriverResponse {
237            content,
238            finish_reason,
239            usage,
240            tool_calls,
241            raw: body.clone(),
242        })
243    }
244
245    fn parse_stream_event(&self, data: &str) -> Result<Option<StreamingEvent>, Error> {
246        if data.trim().is_empty() || self.is_stream_done(data) {
247            return Ok(None);
248        }
249        let v: Value = serde_json::from_str(data).map_err(|e| {
250            Error::Protocol(ProtocolError::ValidationError(format!(
251                "Failed to parse SSE data: {}",
252                e
253            )))
254        })?;
255
256        // Content delta
257        if let Some(content) = v
258            .pointer("/choices/0/delta/content")
259            .and_then(|c| c.as_str())
260        {
261            if !content.is_empty() {
262                return Ok(Some(StreamingEvent::PartialContentDelta {
263                    content: content.to_string(),
264                    sequence_id: None,
265                }));
266            }
267        }
268
269        // Reasoning / extended thinking (OpenAI-compatible `reasoning_content` delta)
270        if let Some(thinking) = v
271            .pointer("/choices/0/delta/reasoning_content")
272            .and_then(|c| c.as_str())
273        {
274            if !thinking.is_empty() {
275                return Ok(Some(StreamingEvent::ThinkingDelta {
276                    thinking: thinking.to_string(),
277                    tool_consideration: None,
278                }));
279            }
280        }
281
282        // Finish reason
283        if let Some(reason) = v
284            .pointer("/choices/0/finish_reason")
285            .and_then(|r| r.as_str())
286        {
287            return Ok(Some(StreamingEvent::StreamEnd {
288                finish_reason: Some(reason.to_string()),
289            }));
290        }
291
292        Ok(None)
293    }
294
295    fn supported_capabilities(&self) -> &[Capability] {
296        &self.capabilities
297    }
298
299    fn is_stream_done(&self, data: &str) -> bool {
300        data.trim() == "[DONE]"
301    }
302}
303
304/// Factory function to create the appropriate driver from an API style.
305///
306/// The `Custom` style falls back to OpenAI-compatible, which covers most
307/// providers that follow the OpenAI chat completions format (DeepSeek,
308/// Moonshot, Zhipu, etc.).
309pub fn create_driver(
310    api_style: ApiStyle,
311    provider_id: &str,
312    capabilities: Vec<Capability>,
313) -> Box<dyn ProviderDriver> {
314    match api_style {
315        ApiStyle::OpenAiCompatible | ApiStyle::Custom => {
316            Box::new(OpenAiDriver::new(provider_id, capabilities))
317        }
318        ApiStyle::AnthropicMessages => Box::new(AnthropicDriver::new(provider_id, capabilities)),
319        ApiStyle::GeminiGenerate => Box::new(GeminiDriver::new(provider_id, capabilities)),
320    }
321}
322
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_openai_driver_build_request() {
        let d = OpenAiDriver::new("openai", vec![Capability::Text, Capability::Streaming]);
        let msgs = [Message::user("Hello")];
        let req = d
            .build_request(&msgs, "gpt-4", Some(0.7), Some(1024), true, None)
            .unwrap();
        assert!(req.stream);
        assert_eq!(req.body["model"], "gpt-4");
        assert_eq!(req.body["temperature"], 0.7);
    }

    #[test]
    fn test_openai_driver_parse_response() {
        let d = OpenAiDriver::new("openai", vec![]);
        let body = serde_json::json!({
            "choices": [{"message": {"content": "Hi there!"}, "finish_reason": "stop"}],
            "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
        });
        let parsed = d.parse_response(&body).unwrap();
        assert_eq!(parsed.content.as_deref(), Some("Hi there!"));
        assert_eq!(parsed.finish_reason.as_deref(), Some("stop"));
        assert_eq!(parsed.usage.unwrap().total_tokens, 15);
    }

    #[test]
    fn test_openai_driver_parse_response_reasoning_tokens() {
        let d = OpenAiDriver::new("openai", vec![]);
        let body = serde_json::json!({
            "choices": [{"message": {"content": "Hello, world!"}, "finish_reason": "stop"}],
            "usage": {
                "prompt_tokens": 10,
                "completion_tokens": 5,
                "total_tokens": 15,
                "completion_tokens_details": {"reasoning_tokens": 3}
            }
        });
        let usage = d.parse_response(&body).unwrap().usage.expect("usage");
        assert_eq!(usage.reasoning_tokens, Some(3));
        assert_eq!(usage.prompt_tokens, 10);
        assert_eq!(usage.completion_tokens, 5);
    }

    #[test]
    fn test_openai_driver_parse_stream_reasoning_delta() {
        let d = OpenAiDriver::new("openai", vec![]);
        let data = r#"{"choices":[{"delta":{"reasoning_content":"Let me think..."},"index":0}]}"#;
        let event = d.parse_stream_event(data).unwrap();
        match event {
            Some(StreamingEvent::ThinkingDelta { thinking, .. }) => {
                assert_eq!(thinking, "Let me think...");
            }
            _ => panic!("Expected ThinkingDelta, got {:?}", event),
        }
    }

    #[test]
    fn test_openai_driver_parse_stream() {
        let d = OpenAiDriver::new("openai", vec![]);
        let data = r#"{"choices":[{"delta":{"content":"Hello"},"index":0}]}"#;
        match d.parse_stream_event(data).unwrap() {
            Some(StreamingEvent::PartialContentDelta { content, .. }) => {
                assert_eq!(content, "Hello");
            }
            _ => panic!("Expected PartialContentDelta"),
        }
    }

    #[test]
    fn test_stream_done_detection() {
        let d = OpenAiDriver::new("openai", vec![]);
        assert!(d.is_stream_done("[DONE]"));
        assert!(!d.is_stream_done(r#"{"choices":[]}"#));
    }
}
403}