Skip to main content

ai_lib_rust/drivers/
anthropic.rs

//! Anthropic Messages API driver: converts between the crate's message/event types and
//! Anthropic's request and response formats.
//!
//! Handles the key differences from OpenAI:
//! - System messages are a top-level `system` parameter, not part of `messages`.
//! - Content uses typed blocks: `[{"type": "text", "text": "..."}]`.
//! - Streaming uses `event: content_block_delta` with `delta.text`.
//! - Response text lives in typed `content` blocks instead of `choices[0].message.content`.
//! - `max_tokens` is required, not optional.
9
10use async_trait::async_trait;
11use serde_json::Value;
12use std::collections::HashMap;
13
14use crate::error::Error;
15use crate::protocol::v2::capabilities::Capability;
16use crate::protocol::v2::manifest::ApiStyle;
17use crate::protocol::ProtocolError;
18use crate::types::events::StreamingEvent;
19use crate::types::message::{Message, MessageContent, MessageRole};
20
21use super::{DriverRequest, DriverResponse, ProviderDriver, UsageInfo};
22
/// `max_tokens` sent when the caller passes `None`; Anthropic rejects requests without it.
const DEFAULT_MAX_TOKENS: u32 = 4_096;
24
25/// Anthropic Messages API driver.
26#[derive(Debug)]
27pub struct AnthropicDriver {
28    provider_id: String,
29    capabilities: Vec<Capability>,
30}
31
32impl AnthropicDriver {
33    pub fn new(provider_id: impl Into<String>, capabilities: Vec<Capability>) -> Self {
34        Self {
35            provider_id: provider_id.into(),
36            capabilities,
37        }
38    }
39
40    /// Extract system message and non-system messages separately.
41    /// Anthropic requires system as a top-level param, not in messages array.
42    fn split_system_messages(messages: &[Message]) -> (Option<String>, Vec<Value>) {
43        let mut system_parts: Vec<String> = Vec::new();
44        let mut user_messages: Vec<Value> = Vec::new();
45
46        for m in messages {
47            match m.role {
48                MessageRole::System => {
49                    if let MessageContent::Text(ref s) = m.content {
50                        system_parts.push(s.clone());
51                    }
52                }
53                _ => {
54                    let role = match m.role {
55                        MessageRole::User => "user",
56                        MessageRole::Assistant => "assistant",
57                        MessageRole::System => unreachable!(),
58                    };
59                    let content = match &m.content {
60                        MessageContent::Text(s) => {
61                            serde_json::json!([{ "type": "text", "text": s }])
62                        }
63                        MessageContent::Blocks(_) => {
64                            serde_json::to_value(&m.content).unwrap_or(Value::Null)
65                        }
66                    };
67                    user_messages.push(serde_json::json!({
68                        "role": role,
69                        "content": content,
70                    }));
71                }
72            }
73        }
74
75        let system = if system_parts.is_empty() {
76            None
77        } else {
78            Some(system_parts.join("\n\n"))
79        };
80
81        (system, user_messages)
82    }
83}
84
85#[async_trait]
86impl ProviderDriver for AnthropicDriver {
87    fn provider_id(&self) -> &str {
88        &self.provider_id
89    }
90
91    fn api_style(&self) -> ApiStyle {
92        ApiStyle::AnthropicMessages
93    }
94
95    fn build_request(
96        &self,
97        messages: &[Message],
98        model: &str,
99        temperature: Option<f64>,
100        max_tokens: Option<u32>,
101        stream: bool,
102        extra: Option<&Value>,
103    ) -> Result<DriverRequest, Error> {
104        let (system, msgs) = Self::split_system_messages(messages);
105
106        let mut body = serde_json::json!({
107            "model": model,
108            "messages": msgs,
109            "max_tokens": max_tokens.unwrap_or(DEFAULT_MAX_TOKENS),
110            "stream": stream,
111        });
112
113        if let Some(sys) = system {
114            body["system"] = Value::String(sys);
115        }
116        if let Some(t) = temperature {
117            body["temperature"] = serde_json::json!(t);
118        }
119        if let Some(ext) = extra {
120            if let Value::Object(map) = ext {
121                for (k, v) in map {
122                    body[k] = v.clone();
123                }
124            }
125        }
126
127        let mut headers = HashMap::new();
128        headers.insert("anthropic-version".into(), "2023-06-01".into());
129
130        Ok(DriverRequest {
131            url: String::new(),
132            method: "POST".into(),
133            headers,
134            body,
135            stream,
136        })
137    }
138
139    fn parse_response(&self, body: &Value) -> Result<DriverResponse, Error> {
140        // Anthropic response: { content: [{type: "text", text: "..."}], stop_reason, usage }
141        let content = body
142            .pointer("/content/0/text")
143            .and_then(|v| v.as_str())
144            .map(String::from);
145
146        // Normalize stop_reason → finish_reason
147        let finish_reason = body
148            .get("stop_reason")
149            .and_then(|v| v.as_str())
150            .map(|r| match r {
151                "end_turn" => "stop".to_string(),
152                "max_tokens" => "length".to_string(),
153                "tool_use" => "tool_calls".to_string(),
154                other => other.to_string(),
155            });
156
157        let usage = body.get("usage").map(|u| UsageInfo {
158            prompt_tokens: u["input_tokens"].as_u64().unwrap_or(0),
159            completion_tokens: u["output_tokens"].as_u64().unwrap_or(0),
160            total_tokens: u["input_tokens"].as_u64().unwrap_or(0)
161                + u["output_tokens"].as_u64().unwrap_or(0),
162        });
163
164        // Extract tool_use blocks from content array
165        let tool_calls: Vec<Value> = body
166            .get("content")
167            .and_then(|c| c.as_array())
168            .map(|arr| {
169                arr.iter()
170                    .filter(|b| b.get("type").and_then(|t| t.as_str()) == Some("tool_use"))
171                    .cloned()
172                    .collect()
173            })
174            .unwrap_or_default();
175
176        Ok(DriverResponse {
177            content,
178            finish_reason,
179            usage,
180            tool_calls,
181            raw: body.clone(),
182        })
183    }
184
185    fn parse_stream_event(&self, data: &str) -> Result<Option<StreamingEvent>, Error> {
186        if data.trim().is_empty() {
187            return Ok(None);
188        }
189
190        let v: Value = serde_json::from_str(data).map_err(|e| {
191            Error::Protocol(ProtocolError::ValidationError(format!(
192                "Failed to parse Anthropic SSE: {}",
193                e
194            )))
195        })?;
196
197        let event_type = v.get("type").and_then(|t| t.as_str()).unwrap_or("");
198
199        match event_type {
200            "content_block_delta" => {
201                if let Some(text) = v.pointer("/delta/text").and_then(|t| t.as_str()) {
202                    if !text.is_empty() {
203                        return Ok(Some(StreamingEvent::PartialContentDelta {
204                            content: text.to_string(),
205                            sequence_id: v.get("index").and_then(|i| i.as_u64()),
206                        }));
207                    }
208                }
209                // Thinking delta support
210                if let Some(thinking) = v.pointer("/delta/thinking").and_then(|t| t.as_str()) {
211                    return Ok(Some(StreamingEvent::ThinkingDelta {
212                        thinking: thinking.to_string(),
213                        tool_consideration: None,
214                    }));
215                }
216                Ok(None)
217            }
218            "message_delta" => {
219                let reason = v.pointer("/delta/stop_reason").and_then(|r| r.as_str());
220                if let Some(r) = reason {
221                    return Ok(Some(StreamingEvent::StreamEnd {
222                        finish_reason: Some(match r {
223                            "end_turn" => "stop".to_string(),
224                            "max_tokens" => "length".to_string(),
225                            other => other.to_string(),
226                        }),
227                    }));
228                }
229                Ok(None)
230            }
231            "message_stop" => Ok(Some(StreamingEvent::StreamEnd {
232                finish_reason: Some("stop".into()),
233            })),
234            "error" => {
235                let error = v.get("error").cloned().unwrap_or(Value::Null);
236                Ok(Some(StreamingEvent::StreamError {
237                    error,
238                    event_id: None,
239                }))
240            }
241            _ => Ok(None),
242        }
243    }
244
245    fn supported_capabilities(&self) -> &[Capability] {
246        &self.capabilities
247    }
248
249    fn is_stream_done(&self, _data: &str) -> bool {
250        // Anthropic signals done via event type, not a sentinel string.
251        // The `event: message_stop` is handled in parse_stream_event.
252        false
253    }
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    /// Driver with no advertised capabilities; sufficient for request/response tests.
    fn driver() -> AnthropicDriver {
        AnthropicDriver::new("anthropic", vec![])
    }

    /// Parse one SSE payload, panicking if the driver reports an error.
    fn stream_event(data: &str) -> Option<StreamingEvent> {
        driver().parse_stream_event(data).unwrap()
    }

    #[test]
    fn test_system_message_extraction() {
        let msgs = vec![
            Message::system("You are helpful."),
            Message::user("Hi"),
        ];
        let (sys, user_msgs) = AnthropicDriver::split_system_messages(&msgs);
        assert_eq!(sys.as_deref(), Some("You are helpful."));
        assert_eq!(user_msgs.len(), 1);
        assert_eq!(user_msgs[0]["role"], "user");
    }

    #[test]
    fn test_multiple_system_messages_are_joined() {
        let msgs = vec![
            Message::system("Rule one."),
            Message::user("Hi"),
            Message::system("Rule two."),
        ];
        let (sys, user_msgs) = AnthropicDriver::split_system_messages(&msgs);
        assert_eq!(sys.as_deref(), Some("Rule one.\n\nRule two."));
        assert_eq!(user_msgs.len(), 1);
    }

    #[test]
    fn test_anthropic_build_request() {
        let driver = AnthropicDriver::new("anthropic", vec![Capability::Text]);
        let messages = vec![Message::user("Hello")];
        let req = driver
            .build_request(&messages, "claude-sonnet-4-20250514", None, Some(1024), false, None)
            .unwrap();
        assert_eq!(req.body["max_tokens"], 1024);
        assert_eq!(req.body["model"], "claude-sonnet-4-20250514");
        assert!(req.headers.contains_key("anthropic-version"));
    }

    #[test]
    fn test_build_request_defaults_and_top_level_system() {
        let messages = vec![Message::system("Be brief."), Message::user("Hello")];
        let req = driver()
            .build_request(&messages, "claude-sonnet-4-20250514", Some(0.5), None, true, None)
            .unwrap();
        assert_eq!(req.body["max_tokens"], DEFAULT_MAX_TOKENS);
        assert_eq!(req.body["system"], "Be brief.");
        assert_eq!(req.body["temperature"], 0.5);
        assert_eq!(req.body["stream"], true);
        assert_eq!(req.body["messages"].as_array().map(Vec::len), Some(1));
        assert!(req.stream);
    }

    #[test]
    fn test_build_request_extra_fields_override() {
        let messages = vec![Message::user("Hi")];
        let extra = serde_json::json!({"top_k": 5, "max_tokens": 64});
        let req = driver()
            .build_request(&messages, "claude-sonnet-4-20250514", None, Some(1024), false, Some(&extra))
            .unwrap();
        assert_eq!(req.body["top_k"], 5);
        // Extras are merged last, so they win over built-in fields.
        assert_eq!(req.body["max_tokens"], 64);
    }

    #[test]
    fn test_anthropic_parse_response() {
        let driver = AnthropicDriver::new("anthropic", vec![]);
        let body = serde_json::json!({
            "content": [{"type": "text", "text": "Hello!"}],
            "stop_reason": "end_turn",
            "usage": {"input_tokens": 10, "output_tokens": 5}
        });
        let resp = driver.parse_response(&body).unwrap();
        assert_eq!(resp.content.as_deref(), Some("Hello!"));
        assert_eq!(resp.finish_reason.as_deref(), Some("stop"));
        assert_eq!(resp.usage.unwrap().total_tokens, 15);
    }

    #[test]
    fn test_anthropic_parse_response_extracts_tool_use_blocks() {
        let body = serde_json::json!({
            "content": [
                {"type": "text", "text": "Checking."},
                {"type": "tool_use", "id": "toolu_1", "name": "get_weather", "input": {}}
            ],
            "stop_reason": "tool_use"
        });
        let resp = driver().parse_response(&body).unwrap();
        assert_eq!(resp.tool_calls.len(), 1);
        assert_eq!(resp.tool_calls[0]["name"], "get_weather");
        assert!(resp.usage.is_none());
    }

    #[test]
    fn test_anthropic_parse_stream_delta() {
        let driver = AnthropicDriver::new("anthropic", vec![]);
        let data = r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}"#;
        let event = driver.parse_stream_event(data).unwrap();
        match event {
            Some(StreamingEvent::PartialContentDelta { content, .. }) => {
                assert_eq!(content, "Hi");
            }
            _ => panic!("Expected PartialContentDelta"),
        }
    }

    #[test]
    fn test_anthropic_parse_thinking_delta() {
        let data = r#"{"type":"content_block_delta","index":0,"delta":{"type":"thinking_delta","thinking":"hmm"}}"#;
        match stream_event(data) {
            Some(StreamingEvent::ThinkingDelta { thinking, .. }) => assert_eq!(thinking, "hmm"),
            _ => panic!("Expected ThinkingDelta"),
        }
    }

    #[test]
    fn test_anthropic_stream_end_events() {
        let delta = r#"{"type":"message_delta","delta":{"stop_reason":"max_tokens"}}"#;
        match stream_event(delta) {
            Some(StreamingEvent::StreamEnd { finish_reason }) => {
                assert_eq!(finish_reason.as_deref(), Some("length"));
            }
            _ => panic!("Expected StreamEnd from message_delta"),
        }
        match stream_event(r#"{"type":"message_stop"}"#) {
            Some(StreamingEvent::StreamEnd { finish_reason }) => {
                assert_eq!(finish_reason.as_deref(), Some("stop"));
            }
            _ => panic!("Expected StreamEnd from message_stop"),
        }
    }

    #[test]
    fn test_anthropic_stream_ignores_blank_and_rejects_invalid() {
        assert!(stream_event("   ").is_none());
        assert!(stream_event(r#"{"type":"ping"}"#).is_none());
        assert!(driver().parse_stream_event("not json").is_err());
    }

    #[test]
    fn test_anthropic_stop_reason_normalization() {
        let driver = AnthropicDriver::new("anthropic", vec![]);
        let body = serde_json::json!({
            "content": [{"type": "text", "text": ""}],
            "stop_reason": "tool_use",
            "usage": {"input_tokens": 0, "output_tokens": 0}
        });
        let resp = driver.parse_response(&body).unwrap();
        assert_eq!(resp.finish_reason.as_deref(), Some("tool_calls"));
    }
}