Skip to main content

ai_lib_rust/drivers/
anthropic.rs

1//! Anthropic Messages API 驱动 — 实现 Anthropic 特有的请求/响应格式转换
2//!
3//! Anthropic Messages API driver. Handles the key differences from OpenAI:
4//! - System messages are a top-level `system` parameter, not part of `messages`.
5//! - Content uses typed blocks: `[{"type": "text", "text": "..."}]`.
6//! - Streaming uses `event: content_block_delta` with `delta.text`.
7//! - Response uses `content[0].text` instead of `choices[0].message.content`.
8//! - `max_tokens` is required, not optional.
9
10use async_trait::async_trait;
11use serde_json::Value;
12use std::collections::HashMap;
13
14use crate::error::Error;
15use crate::protocol::v2::capabilities::Capability;
16use crate::protocol::v2::manifest::ApiStyle;
17use crate::protocol::ProtocolError;
18use crate::types::events::StreamingEvent;
19use crate::types::message::{Message, MessageContent, MessageRole};
20
21use super::{DriverRequest, DriverResponse, ProviderDriver, UsageInfo};
22
/// `max_tokens` value placed in the request body when the caller passes
/// `None` to `build_request`.
const DEFAULT_MAX_TOKENS: u32 = 4096;
24
/// Anthropic Messages API driver.
///
/// Stores the provider identifier and capability list given to
/// `AnthropicDriver::new`; both are handed back unchanged through the
/// `ProviderDriver` accessors.
#[derive(Debug)]
pub struct AnthropicDriver {
    /// Identifier returned by `provider_id()`.
    provider_id: String,
    /// Capability list returned by `supported_capabilities()`.
    capabilities: Vec<Capability>,
}
31
32impl AnthropicDriver {
33    pub fn new(provider_id: impl Into<String>, capabilities: Vec<Capability>) -> Self {
34        Self {
35            provider_id: provider_id.into(),
36            capabilities,
37        }
38    }
39
40    /// Extract system message and non-system messages separately.
41    /// Anthropic requires system as a top-level param, not in messages array.
42    fn split_system_messages(messages: &[Message]) -> (Option<String>, Vec<Value>) {
43        let mut system_parts: Vec<String> = Vec::new();
44        let mut user_messages: Vec<Value> = Vec::new();
45
46        for m in messages {
47            match m.role {
48                MessageRole::System => {
49                    if let MessageContent::Text(ref s) = m.content {
50                        system_parts.push(s.clone());
51                    }
52                }
53                MessageRole::Tool => {
54                    // Anthropic: tool results are sent as user message with tool_result block
55                    if let (Some(ref id), MessageContent::Text(ref s)) =
56                        (&m.tool_call_id, &m.content)
57                    {
58                        user_messages.push(serde_json::json!({
59                            "role": "user",
60                            "content": [{ "type": "tool_result", "tool_use_id": id, "content": s }],
61                        }));
62                    }
63                }
64                _ => {
65                    let role = match m.role {
66                        MessageRole::User => "user",
67                        MessageRole::Assistant => "assistant",
68                        MessageRole::System => unreachable!(),
69                        MessageRole::Tool => unreachable!(),
70                    };
71                    let content = match &m.content {
72                        MessageContent::Text(s) => {
73                            serde_json::json!([{ "type": "text", "text": s }])
74                        }
75                        MessageContent::Blocks(_) => {
76                            serde_json::to_value(&m.content).unwrap_or(Value::Null)
77                        }
78                    };
79                    user_messages.push(serde_json::json!({
80                        "role": role,
81                        "content": content,
82                    }));
83                }
84            }
85        }
86
87        let system = if system_parts.is_empty() {
88            None
89        } else {
90            Some(system_parts.join("\n\n"))
91        };
92
93        (system, user_messages)
94    }
95}
96
97#[async_trait]
98impl ProviderDriver for AnthropicDriver {
99    fn provider_id(&self) -> &str {
100        &self.provider_id
101    }
102
103    fn api_style(&self) -> ApiStyle {
104        ApiStyle::AnthropicMessages
105    }
106
107    fn build_request(
108        &self,
109        messages: &[Message],
110        model: &str,
111        temperature: Option<f64>,
112        max_tokens: Option<u32>,
113        stream: bool,
114        extra: Option<&Value>,
115    ) -> Result<DriverRequest, Error> {
116        let (system, msgs) = Self::split_system_messages(messages);
117
118        let mut body = serde_json::json!({
119            "model": model,
120            "messages": msgs,
121            "max_tokens": max_tokens.unwrap_or(DEFAULT_MAX_TOKENS),
122            "stream": stream,
123        });
124
125        if let Some(sys) = system {
126            body["system"] = Value::String(sys);
127        }
128        if let Some(t) = temperature {
129            body["temperature"] = serde_json::json!(t);
130        }
131        if let Some(ext) = extra {
132            if let Value::Object(map) = ext {
133                for (k, v) in map {
134                    body[k] = v.clone();
135                }
136            }
137        }
138
139        let mut headers = HashMap::new();
140        headers.insert("anthropic-version".into(), "2023-06-01".into());
141
142        Ok(DriverRequest {
143            url: String::new(),
144            method: "POST".into(),
145            headers,
146            body,
147            stream,
148        })
149    }
150
151    fn parse_response(&self, body: &Value) -> Result<DriverResponse, Error> {
152        // Anthropic response: { content: [{type: "text", text: "..."}], stop_reason, usage }
153        let content = body
154            .pointer("/content/0/text")
155            .and_then(|v| v.as_str())
156            .map(String::from);
157
158        // Normalize stop_reason → finish_reason
159        let finish_reason = body
160            .get("stop_reason")
161            .and_then(|v| v.as_str())
162            .map(|r| match r {
163                "end_turn" => "stop".to_string(),
164                "max_tokens" => "length".to_string(),
165                "tool_use" => "tool_calls".to_string(),
166                other => other.to_string(),
167            });
168
169        let usage = body.get("usage").map(|u| UsageInfo {
170            prompt_tokens: u["input_tokens"].as_u64().unwrap_or(0),
171            completion_tokens: u["output_tokens"].as_u64().unwrap_or(0),
172            total_tokens: u["input_tokens"].as_u64().unwrap_or(0)
173                + u["output_tokens"].as_u64().unwrap_or(0),
174        });
175
176        // Extract tool_use blocks from content array
177        let tool_calls: Vec<Value> = body
178            .get("content")
179            .and_then(|c| c.as_array())
180            .map(|arr| {
181                arr.iter()
182                    .filter(|b| b.get("type").and_then(|t| t.as_str()) == Some("tool_use"))
183                    .cloned()
184                    .collect()
185            })
186            .unwrap_or_default();
187
188        Ok(DriverResponse {
189            content,
190            finish_reason,
191            usage,
192            tool_calls,
193            raw: body.clone(),
194        })
195    }
196
197    fn parse_stream_event(&self, data: &str) -> Result<Option<StreamingEvent>, Error> {
198        if data.trim().is_empty() {
199            return Ok(None);
200        }
201
202        let v: Value = serde_json::from_str(data).map_err(|e| {
203            Error::Protocol(ProtocolError::ValidationError(format!(
204                "Failed to parse Anthropic SSE: {}",
205                e
206            )))
207        })?;
208
209        let event_type = v.get("type").and_then(|t| t.as_str()).unwrap_or("");
210
211        match event_type {
212            "content_block_delta" => {
213                if let Some(text) = v.pointer("/delta/text").and_then(|t| t.as_str()) {
214                    if !text.is_empty() {
215                        return Ok(Some(StreamingEvent::PartialContentDelta {
216                            content: text.to_string(),
217                            sequence_id: v.get("index").and_then(|i| i.as_u64()),
218                        }));
219                    }
220                }
221                // Thinking delta support
222                if let Some(thinking) = v.pointer("/delta/thinking").and_then(|t| t.as_str()) {
223                    return Ok(Some(StreamingEvent::ThinkingDelta {
224                        thinking: thinking.to_string(),
225                        tool_consideration: None,
226                    }));
227                }
228                Ok(None)
229            }
230            "message_delta" => {
231                let reason = v.pointer("/delta/stop_reason").and_then(|r| r.as_str());
232                if let Some(r) = reason {
233                    return Ok(Some(StreamingEvent::StreamEnd {
234                        finish_reason: Some(match r {
235                            "end_turn" => "stop".to_string(),
236                            "max_tokens" => "length".to_string(),
237                            other => other.to_string(),
238                        }),
239                    }));
240                }
241                Ok(None)
242            }
243            "message_stop" => Ok(Some(StreamingEvent::StreamEnd {
244                finish_reason: Some("stop".into()),
245            })),
246            "error" => {
247                let error = v.get("error").cloned().unwrap_or(Value::Null);
248                Ok(Some(StreamingEvent::StreamError {
249                    error,
250                    event_id: None,
251                }))
252            }
253            _ => Ok(None),
254        }
255    }
256
257    fn supported_capabilities(&self) -> &[Capability] {
258        &self.capabilities
259    }
260
261    fn is_stream_done(&self, _data: &str) -> bool {
262        // Anthropic signals done via event type, not a sentinel string.
263        // The `event: message_stop` is handled in parse_stream_event.
264        false
265    }
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    /// Driver with an empty capability list, for tests that only parse.
    fn bare_driver() -> AnthropicDriver {
        AnthropicDriver::new("anthropic", vec![])
    }

    /// The system message is lifted out and only the user turn remains.
    #[test]
    fn test_system_message_extraction() {
        let conversation = [Message::system("You are helpful."), Message::user("Hi")];

        let (system_prompt, remaining) = AnthropicDriver::split_system_messages(&conversation);

        assert_eq!(system_prompt, Some(String::from("You are helpful.")));
        assert_eq!(remaining.len(), 1);
        assert_eq!(remaining[0]["role"], "user");
    }

    /// Explicit `max_tokens`, the model name and the version header all
    /// reach the request.
    #[test]
    fn test_anthropic_build_request() {
        let driver = AnthropicDriver::new("anthropic", vec![Capability::Text]);
        let conversation = vec![Message::user("Hello")];

        let request = driver
            .build_request(
                &conversation,
                "claude-sonnet-4-20250514",
                None,
                Some(1024),
                false,
                None,
            )
            .unwrap();

        assert_eq!(request.body["max_tokens"], 1024);
        assert_eq!(request.body["model"], "claude-sonnet-4-20250514");
        assert!(request.headers.contains_key("anthropic-version"));
    }

    /// Text, finish reason and summed usage are read from a plain response.
    #[test]
    fn test_anthropic_parse_response() {
        let payload = serde_json::json!({
            "content": [{"type": "text", "text": "Hello!"}],
            "stop_reason": "end_turn",
            "usage": {"input_tokens": 10, "output_tokens": 5}
        });

        let parsed = bare_driver().parse_response(&payload).unwrap();

        assert_eq!(parsed.content, Some(String::from("Hello!")));
        assert_eq!(parsed.finish_reason, Some(String::from("stop")));
        assert_eq!(parsed.usage.unwrap().total_tokens, 15);
    }

    /// A `content_block_delta` carrying text becomes a partial content delta.
    #[test]
    fn test_anthropic_parse_stream_delta() {
        let payload = r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}"#;

        let parsed = bare_driver().parse_stream_event(payload).unwrap();

        if let Some(StreamingEvent::PartialContentDelta { content, .. }) = parsed {
            assert_eq!(content, "Hi");
        } else {
            panic!("Expected PartialContentDelta");
        }
    }

    /// `tool_use` is reported as `tool_calls` in non-streaming responses.
    #[test]
    fn test_anthropic_stop_reason_normalization() {
        let payload = serde_json::json!({
            "content": [{"type": "text", "text": ""}],
            "stop_reason": "tool_use",
            "usage": {"input_tokens": 0, "output_tokens": 0}
        });

        let parsed = bare_driver().parse_response(&payload).unwrap();

        assert_eq!(parsed.finish_reason, Some(String::from("tool_calls")));
    }
}