// ai_lib_rust/drivers/mod.rs
1//! Provider 驱动抽象层 — 通过 trait 实现多厂商 API 适配的动态分发
2//!
3//! Provider driver abstraction layer implementing the ProviderContract specification.
4//! Uses `Box<dyn ProviderDriver>` for runtime polymorphism, enabling the same client
5//! code to work with OpenAI, Anthropic, Gemini, and any OpenAI-compatible provider.
6
7pub mod anthropic;
8pub mod gemini;
9
10use async_trait::async_trait;
11use serde_json::Value;
12use std::collections::HashMap;
13
14use crate::error::Error;
15use crate::protocol::v2::manifest::ApiStyle;
16use crate::protocol::v2::capabilities::Capability;
17use crate::protocol::ProtocolError;
18use crate::types::events::StreamingEvent;
19use crate::types::message::{Message, MessageContent};
20
21pub use anthropic::AnthropicDriver;
22pub use gemini::GeminiDriver;
23
24/// Unified HTTP request representation for provider communication.
25#[derive(Debug, Clone)]
26pub struct DriverRequest {
27    /// Target URL (base_url + chat_path).
28    pub url: String,
29    /// HTTP method (POST for chat, GET for models).
30    pub method: String,
31    /// Request headers.
32    pub headers: HashMap<String, String>,
33    /// Serialized JSON request body.
34    pub body: Value,
35    /// Whether streaming is requested.
36    pub stream: bool,
37}
38
39/// Unified chat response from provider.
40#[derive(Debug, Clone)]
41pub struct DriverResponse {
42    /// Extracted text content.
43    pub content: Option<String>,
44    /// Finish reason normalized to AI-Protocol standard.
45    pub finish_reason: Option<String>,
46    /// Token usage statistics.
47    pub usage: Option<UsageInfo>,
48    /// Tool calls if any.
49    pub tool_calls: Vec<Value>,
50    /// Raw provider response for debugging.
51    pub raw: Value,
52}
53
/// Token usage information reported by the provider.
#[derive(Debug, Clone, Default)]
pub struct UsageInfo {
    /// Tokens consumed by the input/prompt.
    pub prompt_tokens: u64,
    /// Tokens generated in the completion.
    pub completion_tokens: u64,
    /// Total tokens as reported by the provider (usually prompt + completion).
    pub total_tokens: u64,
}
61
62/// Core trait for provider-specific API adaptation.
63///
64/// Each provider API style (OpenAI, Anthropic, Gemini) has a concrete implementation.
65/// The trait is object-safe and supports dynamic dispatch via `Box<dyn ProviderDriver>`.
66///
67/// # Design Notes
68///
69/// Inspired by `sqlx::Database` — the trait defines the contract, concrete types
70/// implement the transformations. The runtime selects the correct driver based on
71/// the manifest's `api_style` or `provider_contract`.
72#[async_trait]
73pub trait ProviderDriver: Send + Sync + std::fmt::Debug {
74    /// Unique provider identifier (matches manifest `id`).
75    fn provider_id(&self) -> &str;
76
77    /// API style this driver implements.
78    fn api_style(&self) -> ApiStyle;
79
80    /// Build a provider-specific HTTP request from unified parameters.
81    fn build_request(
82        &self,
83        messages: &[Message],
84        model: &str,
85        temperature: Option<f64>,
86        max_tokens: Option<u32>,
87        stream: bool,
88        extra: Option<&Value>,
89    ) -> Result<DriverRequest, Error>;
90
91    /// Parse a non-streaming response into unified format.
92    fn parse_response(&self, body: &Value) -> Result<DriverResponse, Error>;
93
94    /// Parse a single streaming event from raw SSE/NDJSON data.
95    fn parse_stream_event(&self, data: &str) -> Result<Option<StreamingEvent>, Error>;
96
97    /// Get the list of capabilities this driver supports.
98    fn supported_capabilities(&self) -> &[Capability];
99
100    /// Check if the done signal has been received in streaming.
101    fn is_stream_done(&self, data: &str) -> bool;
102}
103
104/// OpenAI-compatible driver — works for OpenAI, DeepSeek, Moonshot, Zhipu, etc.
105#[derive(Debug)]
106pub struct OpenAiDriver {
107    provider_id: String,
108    capabilities: Vec<Capability>,
109}
110
111impl OpenAiDriver {
112    pub fn new(provider_id: impl Into<String>, capabilities: Vec<Capability>) -> Self {
113        Self {
114            provider_id: provider_id.into(),
115            capabilities,
116        }
117    }
118}
119
120#[async_trait]
121impl ProviderDriver for OpenAiDriver {
122    fn provider_id(&self) -> &str {
123        &self.provider_id
124    }
125
126    fn api_style(&self) -> ApiStyle {
127        ApiStyle::OpenAiCompatible
128    }
129
130    fn build_request(
131        &self,
132        messages: &[Message],
133        model: &str,
134        temperature: Option<f64>,
135        max_tokens: Option<u32>,
136        stream: bool,
137        extra: Option<&Value>,
138    ) -> Result<DriverRequest, Error> {
139        let oai_messages: Vec<Value> = messages
140            .iter()
141            .map(|m| {
142                let role = serde_json::to_value(&m.role).unwrap_or(Value::String("user".into()));
143                let content = match &m.content {
144                    MessageContent::Text(s) => Value::String(s.clone()),
145                    MessageContent::Blocks(_) => {
146                        serde_json::to_value(&m.content).unwrap_or(Value::Null)
147                    }
148                };
149                let mut obj = serde_json::json!({ "role": role, "content": content });
150                // OpenAI API requires tool_call_id for role "tool"
151                if matches!(m.role, crate::types::message::MessageRole::Tool) {
152                    if let Some(ref id) = m.tool_call_id {
153                        obj["tool_call_id"] = Value::String(id.clone());
154                    }
155                }
156                obj
157            })
158            .collect();
159
160        let mut body = serde_json::json!({
161            "model": model,
162            "messages": oai_messages,
163            "stream": stream,
164        });
165
166        if let Some(t) = temperature {
167            body["temperature"] = serde_json::json!(t);
168        }
169        if let Some(mt) = max_tokens {
170            body["max_tokens"] = serde_json::json!(mt);
171        }
172        if let Some(ext) = extra {
173            if let Value::Object(map) = ext {
174                for (k, v) in map {
175                    body[k] = v.clone();
176                }
177            }
178        }
179
180        Ok(DriverRequest {
181            url: String::new(), // URL is set by the client layer from manifest
182            method: "POST".into(),
183            headers: HashMap::new(),
184            body,
185            stream,
186        })
187    }
188
189    fn parse_response(&self, body: &Value) -> Result<DriverResponse, Error> {
190        let content = body
191            .pointer("/choices/0/message/content")
192            .and_then(|v| v.as_str())
193            .map(String::from);
194        let finish_reason = body
195            .pointer("/choices/0/finish_reason")
196            .and_then(|v| v.as_str())
197            .map(String::from);
198        let usage = body.get("usage").map(|u| UsageInfo {
199            prompt_tokens: u["prompt_tokens"].as_u64().unwrap_or(0),
200            completion_tokens: u["completion_tokens"].as_u64().unwrap_or(0),
201            total_tokens: u["total_tokens"].as_u64().unwrap_or(0),
202        });
203        let tool_calls = body
204            .pointer("/choices/0/message/tool_calls")
205            .and_then(|v| v.as_array())
206            .cloned()
207            .unwrap_or_default();
208
209        Ok(DriverResponse {
210            content,
211            finish_reason,
212            usage,
213            tool_calls,
214            raw: body.clone(),
215        })
216    }
217
218    fn parse_stream_event(&self, data: &str) -> Result<Option<StreamingEvent>, Error> {
219        if data.trim().is_empty() || self.is_stream_done(data) {
220            return Ok(None);
221        }
222        let v: Value = serde_json::from_str(data)
223            .map_err(|e| Error::Protocol(ProtocolError::ValidationError(
224                format!("Failed to parse SSE data: {}", e),
225            )))?;
226
227        // Content delta
228        if let Some(content) = v.pointer("/choices/0/delta/content").and_then(|c| c.as_str()) {
229            if !content.is_empty() {
230                return Ok(Some(StreamingEvent::PartialContentDelta {
231                    content: content.to_string(),
232                    sequence_id: None,
233                }));
234            }
235        }
236
237        // Finish reason
238        if let Some(reason) = v.pointer("/choices/0/finish_reason").and_then(|r| r.as_str()) {
239            return Ok(Some(StreamingEvent::StreamEnd {
240                finish_reason: Some(reason.to_string()),
241            }));
242        }
243
244        Ok(None)
245    }
246
247    fn supported_capabilities(&self) -> &[Capability] {
248        &self.capabilities
249    }
250
251    fn is_stream_done(&self, data: &str) -> bool {
252        data.trim() == "[DONE]"
253    }
254}
255
256/// Factory function to create the appropriate driver from an API style.
257///
258/// The `Custom` style falls back to OpenAI-compatible, which covers most
259/// providers that follow the OpenAI chat completions format (DeepSeek,
260/// Moonshot, Zhipu, etc.).
261pub fn create_driver(
262    api_style: ApiStyle,
263    provider_id: &str,
264    capabilities: Vec<Capability>,
265) -> Box<dyn ProviderDriver> {
266    match api_style {
267        ApiStyle::OpenAiCompatible | ApiStyle::Custom => {
268            Box::new(OpenAiDriver::new(provider_id, capabilities))
269        }
270        ApiStyle::AnthropicMessages => {
271            Box::new(AnthropicDriver::new(provider_id, capabilities))
272        }
273        ApiStyle::GeminiGenerate => {
274            Box::new(GeminiDriver::new(provider_id, capabilities))
275        }
276    }
277}
278
#[cfg(test)]
mod tests {
    use super::*;

    /// Request body should carry model, sampling params, and stream flag.
    #[test]
    fn test_openai_driver_build_request() {
        let driver = OpenAiDriver::new("openai", vec![Capability::Text, Capability::Streaming]);
        let messages = vec![Message::user("Hello")];
        let req = driver
            .build_request(&messages, "gpt-4", Some(0.7), Some(1024), true, None)
            .unwrap();
        assert!(req.stream);
        assert_eq!(req.body["model"], "gpt-4");
        assert_eq!(req.body["temperature"], 0.7);
    }

    /// Content, finish reason, and usage are extracted from a standard body.
    #[test]
    fn test_openai_driver_parse_response() {
        let driver = OpenAiDriver::new("openai", vec![]);
        let body = serde_json::json!({
            "choices": [{"message": {"content": "Hi there!"}, "finish_reason": "stop"}],
            "usage": {"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15}
        });
        let resp = driver.parse_response(&body).unwrap();
        assert_eq!(resp.content.as_deref(), Some("Hi there!"));
        assert_eq!(resp.finish_reason.as_deref(), Some("stop"));
        assert_eq!(resp.usage.unwrap().total_tokens, 15);
    }

    /// A delta chunk yields a PartialContentDelta event.
    #[test]
    fn test_openai_driver_parse_stream() {
        let driver = OpenAiDriver::new("openai", vec![]);
        let data = r#"{"choices":[{"delta":{"content":"Hello"},"index":0}]}"#;
        let event = driver.parse_stream_event(data).unwrap();
        match event {
            Some(StreamingEvent::PartialContentDelta { content, .. }) => {
                assert_eq!(content, "Hello");
            }
            _ => panic!("Expected PartialContentDelta"),
        }
    }

    /// Only the literal `[DONE]` marker ends the stream.
    #[test]
    fn test_stream_done_detection() {
        let driver = OpenAiDriver::new("openai", vec![]);
        assert!(driver.is_stream_done("[DONE]"));
        assert!(!driver.is_stream_done(r#"{"choices":[]}"#));
    }
}