1use async_trait::async_trait;
11use serde_json::Value;
12use std::collections::HashMap;
13
14use crate::error::Error;
15use crate::protocol::v2::capabilities::Capability;
16use crate::protocol::v2::manifest::ApiStyle;
17use crate::protocol::ProtocolError;
18use crate::types::events::StreamingEvent;
19use crate::types::message::{Message, MessageContent, MessageRole};
20
21use super::{DriverRequest, DriverResponse, ProviderDriver, UsageInfo};
22
/// `max_tokens` value written into the request body by `build_request` when
/// the caller passes `max_tokens: None`, so the body always carries the field.
const DEFAULT_MAX_TOKENS: u32 = 4096;
24
/// `ProviderDriver` implementation that speaks the Anthropic Messages API
/// style (`ApiStyle::AnthropicMessages`).
///
/// The driver is stateless apart from the identifier it was registered under
/// and the capability list it advertises.
#[derive(Debug)]
pub struct AnthropicDriver {
    /// Identifier returned verbatim by `ProviderDriver::provider_id`.
    provider_id: String,
    /// Capabilities returned verbatim by `ProviderDriver::supported_capabilities`.
    capabilities: Vec<Capability>,
}
31
32impl AnthropicDriver {
33 pub fn new(provider_id: impl Into<String>, capabilities: Vec<Capability>) -> Self {
34 Self {
35 provider_id: provider_id.into(),
36 capabilities,
37 }
38 }
39
40 fn split_system_messages(messages: &[Message]) -> (Option<String>, Vec<Value>) {
43 let mut system_parts: Vec<String> = Vec::new();
44 let mut user_messages: Vec<Value> = Vec::new();
45
46 for m in messages {
47 match m.role {
48 MessageRole::System => {
49 if let MessageContent::Text(ref s) = m.content {
50 system_parts.push(s.clone());
51 }
52 }
53 MessageRole::Tool => {
54 if let (Some(ref id), MessageContent::Text(ref s)) =
56 (&m.tool_call_id, &m.content)
57 {
58 user_messages.push(serde_json::json!({
59 "role": "user",
60 "content": [{ "type": "tool_result", "tool_use_id": id, "content": s }],
61 }));
62 }
63 }
64 _ => {
65 let role = match m.role {
66 MessageRole::User => "user",
67 MessageRole::Assistant => "assistant",
68 MessageRole::System => unreachable!(),
69 MessageRole::Tool => unreachable!(),
70 };
71 let content = match &m.content {
72 MessageContent::Text(s) => {
73 serde_json::json!([{ "type": "text", "text": s }])
74 }
75 MessageContent::Blocks(_) => {
76 serde_json::to_value(&m.content).unwrap_or(Value::Null)
77 }
78 };
79 user_messages.push(serde_json::json!({
80 "role": role,
81 "content": content,
82 }));
83 }
84 }
85 }
86
87 let system = if system_parts.is_empty() {
88 None
89 } else {
90 Some(system_parts.join("\n\n"))
91 };
92
93 (system, user_messages)
94 }
95}
96
97#[async_trait]
98impl ProviderDriver for AnthropicDriver {
99 fn provider_id(&self) -> &str {
100 &self.provider_id
101 }
102
103 fn api_style(&self) -> ApiStyle {
104 ApiStyle::AnthropicMessages
105 }
106
107 fn build_request(
108 &self,
109 messages: &[Message],
110 model: &str,
111 temperature: Option<f64>,
112 max_tokens: Option<u32>,
113 stream: bool,
114 extra: Option<&Value>,
115 ) -> Result<DriverRequest, Error> {
116 let (system, msgs) = Self::split_system_messages(messages);
117
118 let mut body = serde_json::json!({
119 "model": model,
120 "messages": msgs,
121 "max_tokens": max_tokens.unwrap_or(DEFAULT_MAX_TOKENS),
122 "stream": stream,
123 });
124
125 if let Some(sys) = system {
126 body["system"] = Value::String(sys);
127 }
128 if let Some(t) = temperature {
129 body["temperature"] = serde_json::json!(t);
130 }
131 if let Some(ext) = extra {
132 if let Value::Object(map) = ext {
133 for (k, v) in map {
134 body[k] = v.clone();
135 }
136 }
137 }
138
139 let mut headers = HashMap::new();
140 headers.insert("anthropic-version".into(), "2023-06-01".into());
141
142 Ok(DriverRequest {
143 url: String::new(),
144 method: "POST".into(),
145 headers,
146 body,
147 stream,
148 })
149 }
150
151 fn parse_response(&self, body: &Value) -> Result<DriverResponse, Error> {
152 let content = body
154 .pointer("/content/0/text")
155 .and_then(|v| v.as_str())
156 .map(String::from);
157
158 let finish_reason = body
160 .get("stop_reason")
161 .and_then(|v| v.as_str())
162 .map(|r| match r {
163 "end_turn" => "stop".to_string(),
164 "max_tokens" => "length".to_string(),
165 "tool_use" => "tool_calls".to_string(),
166 other => other.to_string(),
167 });
168
169 let usage = body.get("usage").map(|u| UsageInfo {
170 prompt_tokens: u["input_tokens"].as_u64().unwrap_or(0),
171 completion_tokens: u["output_tokens"].as_u64().unwrap_or(0),
172 total_tokens: u["input_tokens"].as_u64().unwrap_or(0)
173 + u["output_tokens"].as_u64().unwrap_or(0),
174 });
175
176 let tool_calls: Vec<Value> = body
178 .get("content")
179 .and_then(|c| c.as_array())
180 .map(|arr| {
181 arr.iter()
182 .filter(|b| b.get("type").and_then(|t| t.as_str()) == Some("tool_use"))
183 .cloned()
184 .collect()
185 })
186 .unwrap_or_default();
187
188 Ok(DriverResponse {
189 content,
190 finish_reason,
191 usage,
192 tool_calls,
193 raw: body.clone(),
194 })
195 }
196
197 fn parse_stream_event(&self, data: &str) -> Result<Option<StreamingEvent>, Error> {
198 if data.trim().is_empty() {
199 return Ok(None);
200 }
201
202 let v: Value = serde_json::from_str(data).map_err(|e| {
203 Error::Protocol(ProtocolError::ValidationError(format!(
204 "Failed to parse Anthropic SSE: {}",
205 e
206 )))
207 })?;
208
209 let event_type = v.get("type").and_then(|t| t.as_str()).unwrap_or("");
210
211 match event_type {
212 "content_block_delta" => {
213 if let Some(text) = v.pointer("/delta/text").and_then(|t| t.as_str()) {
214 if !text.is_empty() {
215 return Ok(Some(StreamingEvent::PartialContentDelta {
216 content: text.to_string(),
217 sequence_id: v.get("index").and_then(|i| i.as_u64()),
218 }));
219 }
220 }
221 if let Some(thinking) = v.pointer("/delta/thinking").and_then(|t| t.as_str()) {
223 return Ok(Some(StreamingEvent::ThinkingDelta {
224 thinking: thinking.to_string(),
225 tool_consideration: None,
226 }));
227 }
228 Ok(None)
229 }
230 "message_delta" => {
231 let reason = v.pointer("/delta/stop_reason").and_then(|r| r.as_str());
232 if let Some(r) = reason {
233 return Ok(Some(StreamingEvent::StreamEnd {
234 finish_reason: Some(match r {
235 "end_turn" => "stop".to_string(),
236 "max_tokens" => "length".to_string(),
237 other => other.to_string(),
238 }),
239 }));
240 }
241 Ok(None)
242 }
243 "message_stop" => Ok(Some(StreamingEvent::StreamEnd {
244 finish_reason: Some("stop".into()),
245 })),
246 "error" => {
247 let error = v.get("error").cloned().unwrap_or(Value::Null);
248 Ok(Some(StreamingEvent::StreamError {
249 error,
250 event_id: None,
251 }))
252 }
253 _ => Ok(None),
254 }
255 }
256
257 fn supported_capabilities(&self) -> &[Capability] {
258 &self.capabilities
259 }
260
261 fn is_stream_done(&self, _data: &str) -> bool {
262 false
265 }
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    /// System text is hoisted out; only the user turn remains in the list.
    #[test]
    fn test_system_message_extraction() {
        let history = [Message::system("You are helpful."), Message::user("Hi")];

        let (system_prompt, turns) = AnthropicDriver::split_system_messages(&history);

        assert_eq!(system_prompt.as_deref(), Some("You are helpful."));
        assert_eq!(turns.len(), 1);
        assert_eq!(turns[0]["role"], "user");
    }

    /// Explicit `max_tokens` and the model name land in the body, and the
    /// version header is present.
    #[test]
    fn test_anthropic_build_request() {
        let driver = AnthropicDriver::new("anthropic", vec![Capability::Text]);
        let request = driver
            .build_request(
                &[Message::user("Hello")],
                "claude-sonnet-4-20250514",
                None,
                Some(1024),
                false,
                None,
            )
            .unwrap();

        assert_eq!(request.body["max_tokens"], 1024);
        assert_eq!(request.body["model"], "claude-sonnet-4-20250514");
        assert!(request.headers.contains_key("anthropic-version"));
    }

    /// A simple text response yields its text, a normalized reason and the
    /// summed token usage.
    #[test]
    fn test_anthropic_parse_response() {
        let driver = AnthropicDriver::new("anthropic", vec![]);
        let payload = serde_json::json!({
            "content": [{"type": "text", "text": "Hello!"}],
            "stop_reason": "end_turn",
            "usage": {"input_tokens": 10, "output_tokens": 5}
        });

        let parsed = driver.parse_response(&payload).unwrap();

        assert_eq!(parsed.content.as_deref(), Some("Hello!"));
        assert_eq!(parsed.finish_reason.as_deref(), Some("stop"));
        assert_eq!(parsed.usage.unwrap().total_tokens, 15);
    }

    /// A `text_delta` SSE payload becomes a partial content delta.
    #[test]
    fn test_anthropic_parse_stream_delta() {
        let driver = AnthropicDriver::new("anthropic", vec![]);
        let sse = r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}"#;

        let parsed = driver.parse_stream_event(sse).unwrap();

        if let Some(StreamingEvent::PartialContentDelta { content, .. }) = parsed {
            assert_eq!(content, "Hi");
        } else {
            panic!("Expected PartialContentDelta");
        }
    }

    /// `tool_use` is reported with the OpenAI-style `tool_calls` reason.
    #[test]
    fn test_anthropic_stop_reason_normalization() {
        let driver = AnthropicDriver::new("anthropic", vec![]);
        let payload = serde_json::json!({
            "content": [{"type": "text", "text": ""}],
            "stop_reason": "tool_use",
            "usage": {"input_tokens": 0, "output_tokens": 0}
        });

        let parsed = driver.parse_response(&payload).unwrap();

        assert_eq!(parsed.finish_reason.as_deref(), Some("tool_calls"));
    }
}