1use async_trait::async_trait;
11use serde_json::Value;
12use std::collections::HashMap;
13
14use crate::error::Error;
15use crate::protocol::v2::capabilities::Capability;
16use crate::protocol::v2::manifest::ApiStyle;
17use crate::protocol::ProtocolError;
18use crate::types::events::StreamingEvent;
19use crate::types::message::{Message, MessageContent, MessageRole};
20
21use super::{DriverRequest, DriverResponse, ProviderDriver, UsageInfo};
22
/// Value sent as `max_tokens` when the caller of `build_request` passes `None`.
const DEFAULT_MAX_TOKENS: u32 = 4096;
24
/// Provider driver that reports `ApiStyle::AnthropicMessages` and builds/parses
/// JSON payloads in that style.
#[derive(Debug)]
pub struct AnthropicDriver {
    /// Identifier returned verbatim by `provider_id()`.
    provider_id: String,
    /// Capability list returned by `supported_capabilities()`.
    capabilities: Vec<Capability>,
}
31
32impl AnthropicDriver {
33 pub fn new(provider_id: impl Into<String>, capabilities: Vec<Capability>) -> Self {
34 Self {
35 provider_id: provider_id.into(),
36 capabilities,
37 }
38 }
39
40 fn split_system_messages(messages: &[Message]) -> (Option<String>, Vec<Value>) {
43 let mut system_parts: Vec<String> = Vec::new();
44 let mut user_messages: Vec<Value> = Vec::new();
45
46 for m in messages {
47 match m.role {
48 MessageRole::System => {
49 if let MessageContent::Text(ref s) = m.content {
50 system_parts.push(s.clone());
51 }
52 }
53 _ => {
54 let role = match m.role {
55 MessageRole::User => "user",
56 MessageRole::Assistant => "assistant",
57 MessageRole::System => unreachable!(),
58 };
59 let content = match &m.content {
60 MessageContent::Text(s) => {
61 serde_json::json!([{ "type": "text", "text": s }])
62 }
63 MessageContent::Blocks(_) => {
64 serde_json::to_value(&m.content).unwrap_or(Value::Null)
65 }
66 };
67 user_messages.push(serde_json::json!({
68 "role": role,
69 "content": content,
70 }));
71 }
72 }
73 }
74
75 let system = if system_parts.is_empty() {
76 None
77 } else {
78 Some(system_parts.join("\n\n"))
79 };
80
81 (system, user_messages)
82 }
83}
84
85#[async_trait]
86impl ProviderDriver for AnthropicDriver {
87 fn provider_id(&self) -> &str {
88 &self.provider_id
89 }
90
91 fn api_style(&self) -> ApiStyle {
92 ApiStyle::AnthropicMessages
93 }
94
95 fn build_request(
96 &self,
97 messages: &[Message],
98 model: &str,
99 temperature: Option<f64>,
100 max_tokens: Option<u32>,
101 stream: bool,
102 extra: Option<&Value>,
103 ) -> Result<DriverRequest, Error> {
104 let (system, msgs) = Self::split_system_messages(messages);
105
106 let mut body = serde_json::json!({
107 "model": model,
108 "messages": msgs,
109 "max_tokens": max_tokens.unwrap_or(DEFAULT_MAX_TOKENS),
110 "stream": stream,
111 });
112
113 if let Some(sys) = system {
114 body["system"] = Value::String(sys);
115 }
116 if let Some(t) = temperature {
117 body["temperature"] = serde_json::json!(t);
118 }
119 if let Some(ext) = extra {
120 if let Value::Object(map) = ext {
121 for (k, v) in map {
122 body[k] = v.clone();
123 }
124 }
125 }
126
127 let mut headers = HashMap::new();
128 headers.insert("anthropic-version".into(), "2023-06-01".into());
129
130 Ok(DriverRequest {
131 url: String::new(),
132 method: "POST".into(),
133 headers,
134 body,
135 stream,
136 })
137 }
138
139 fn parse_response(&self, body: &Value) -> Result<DriverResponse, Error> {
140 let content = body
142 .pointer("/content/0/text")
143 .and_then(|v| v.as_str())
144 .map(String::from);
145
146 let finish_reason = body
148 .get("stop_reason")
149 .and_then(|v| v.as_str())
150 .map(|r| match r {
151 "end_turn" => "stop".to_string(),
152 "max_tokens" => "length".to_string(),
153 "tool_use" => "tool_calls".to_string(),
154 other => other.to_string(),
155 });
156
157 let usage = body.get("usage").map(|u| UsageInfo {
158 prompt_tokens: u["input_tokens"].as_u64().unwrap_or(0),
159 completion_tokens: u["output_tokens"].as_u64().unwrap_or(0),
160 total_tokens: u["input_tokens"].as_u64().unwrap_or(0)
161 + u["output_tokens"].as_u64().unwrap_or(0),
162 });
163
164 let tool_calls: Vec<Value> = body
166 .get("content")
167 .and_then(|c| c.as_array())
168 .map(|arr| {
169 arr.iter()
170 .filter(|b| b.get("type").and_then(|t| t.as_str()) == Some("tool_use"))
171 .cloned()
172 .collect()
173 })
174 .unwrap_or_default();
175
176 Ok(DriverResponse {
177 content,
178 finish_reason,
179 usage,
180 tool_calls,
181 raw: body.clone(),
182 })
183 }
184
185 fn parse_stream_event(&self, data: &str) -> Result<Option<StreamingEvent>, Error> {
186 if data.trim().is_empty() {
187 return Ok(None);
188 }
189
190 let v: Value = serde_json::from_str(data).map_err(|e| {
191 Error::Protocol(ProtocolError::ValidationError(format!(
192 "Failed to parse Anthropic SSE: {}",
193 e
194 )))
195 })?;
196
197 let event_type = v.get("type").and_then(|t| t.as_str()).unwrap_or("");
198
199 match event_type {
200 "content_block_delta" => {
201 if let Some(text) = v.pointer("/delta/text").and_then(|t| t.as_str()) {
202 if !text.is_empty() {
203 return Ok(Some(StreamingEvent::PartialContentDelta {
204 content: text.to_string(),
205 sequence_id: v.get("index").and_then(|i| i.as_u64()),
206 }));
207 }
208 }
209 if let Some(thinking) = v.pointer("/delta/thinking").and_then(|t| t.as_str()) {
211 return Ok(Some(StreamingEvent::ThinkingDelta {
212 thinking: thinking.to_string(),
213 tool_consideration: None,
214 }));
215 }
216 Ok(None)
217 }
218 "message_delta" => {
219 let reason = v.pointer("/delta/stop_reason").and_then(|r| r.as_str());
220 if let Some(r) = reason {
221 return Ok(Some(StreamingEvent::StreamEnd {
222 finish_reason: Some(match r {
223 "end_turn" => "stop".to_string(),
224 "max_tokens" => "length".to_string(),
225 other => other.to_string(),
226 }),
227 }));
228 }
229 Ok(None)
230 }
231 "message_stop" => Ok(Some(StreamingEvent::StreamEnd {
232 finish_reason: Some("stop".into()),
233 })),
234 "error" => {
235 let error = v.get("error").cloned().unwrap_or(Value::Null);
236 Ok(Some(StreamingEvent::StreamError {
237 error,
238 event_id: None,
239 }))
240 }
241 _ => Ok(None),
242 }
243 }
244
245 fn supported_capabilities(&self) -> &[Capability] {
246 &self.capabilities
247 }
248
249 fn is_stream_done(&self, _data: &str) -> bool {
250 false
253 }
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a driver under the provider id used throughout these tests.
    fn driver_with(capabilities: Vec<Capability>) -> AnthropicDriver {
        AnthropicDriver::new("anthropic", capabilities)
    }

    /// A system message is lifted out and only the user message remains.
    #[test]
    fn test_system_message_extraction() {
        let conversation = [Message::system("You are helpful."), Message::user("Hi")];

        let (system, rest) = AnthropicDriver::split_system_messages(&conversation);

        assert_eq!(system.as_deref(), Some("You are helpful."));
        assert_eq!(rest.len(), 1);
        assert_eq!(rest[0]["role"], "user");
    }

    /// Explicit `max_tokens`, the model name and the version header all land
    /// in the built request.
    #[test]
    fn test_anthropic_build_request() {
        let driver = driver_with(vec![Capability::Text]);
        let conversation = [Message::user("Hello")];

        let request = driver
            .build_request(
                &conversation,
                "claude-sonnet-4-20250514",
                None,
                Some(1024),
                false,
                None,
            )
            .unwrap();

        assert_eq!(request.body["max_tokens"], 1024);
        assert_eq!(request.body["model"], "claude-sonnet-4-20250514");
        assert!(request.headers.contains_key("anthropic-version"));
    }

    /// Text, finish reason and token totals are read from a plain response.
    #[test]
    fn test_anthropic_parse_response() {
        let driver = driver_with(vec![]);
        let payload = serde_json::json!({
            "content": [{"type": "text", "text": "Hello!"}],
            "stop_reason": "end_turn",
            "usage": {"input_tokens": 10, "output_tokens": 5}
        });

        let parsed = driver.parse_response(&payload).unwrap();

        assert_eq!(parsed.content.as_deref(), Some("Hello!"));
        assert_eq!(parsed.finish_reason.as_deref(), Some("stop"));
        assert_eq!(parsed.usage.unwrap().total_tokens, 15);
    }

    /// A text delta event becomes a `PartialContentDelta` with the same text.
    #[test]
    fn test_anthropic_parse_stream_delta() {
        let driver = driver_with(vec![]);
        let raw = r#"{"type":"content_block_delta","index":0,"delta":{"type":"text_delta","text":"Hi"}}"#;

        let parsed = driver.parse_stream_event(raw).unwrap();

        if let Some(StreamingEvent::PartialContentDelta { content, .. }) = parsed {
            assert_eq!(content, "Hi");
        } else {
            panic!("Expected PartialContentDelta");
        }
    }

    /// `tool_use` is reported as `tool_calls`.
    #[test]
    fn test_anthropic_stop_reason_normalization() {
        let driver = driver_with(vec![]);
        let payload = serde_json::json!({
            "content": [{"type": "text", "text": ""}],
            "stop_reason": "tool_use",
            "usage": {"input_tokens": 0, "output_tokens": 0}
        });

        let parsed = driver.parse_response(&payload).unwrap();

        assert_eq!(parsed.finish_reason.as_deref(), Some("tool_calls"));
    }
}