forgeai_adapter_anthropic/
lib.rs1use async_stream::try_stream;
2use async_trait::async_trait;
3use forgeai_core::{
4 AdapterInfo, CapabilityMatrix, ChatAdapter, ChatRequest, ChatResponse, ForgeError, Role,
5 StreamEvent, StreamResult, ToolCall, Usage,
6};
7use futures_util::StreamExt;
8use reqwest::{Client as HttpClient, StatusCode};
9use serde_json::{json, Map, Value};
10use std::env;
11use url::Url;
12
13#[derive(Clone, Debug)]
14pub struct AnthropicAdapter {
15 pub api_key: String,
16 pub base_url: Url,
17 pub api_version: String,
18 client: HttpClient,
19}
20
21impl AnthropicAdapter {
22 pub fn new(api_key: impl Into<String>) -> Result<Self, ForgeError> {
23 let base_url = Url::parse("https://api.anthropic.com")
24 .map_err(|e| ForgeError::Internal(e.to_string()))?;
25 Self::with_base_url(api_key, base_url)
26 }
27
28 pub fn with_base_url(api_key: impl Into<String>, base_url: Url) -> Result<Self, ForgeError> {
29 let client = HttpClient::builder()
30 .build()
31 .map_err(|e| ForgeError::Internal(format!("failed to build http client: {e}")))?;
32 Ok(Self {
33 api_key: api_key.into(),
34 base_url,
35 api_version: "2023-06-01".to_string(),
36 client,
37 })
38 }
39
40 pub fn from_env() -> Result<Self, ForgeError> {
41 let api_key = env::var("ANTHROPIC_API_KEY").map_err(|_| ForgeError::Authentication)?;
42 match env::var("ANTHROPIC_BASE_URL") {
43 Ok(raw) => {
44 let base_url = Url::parse(&raw).map_err(|e| {
45 ForgeError::Validation(format!("invalid ANTHROPIC_BASE_URL: {e}"))
46 })?;
47 Self::with_base_url(api_key, base_url)
48 }
49 Err(_) => Self::new(api_key),
50 }
51 }
52
53 fn messages_url(&self) -> Result<Url, ForgeError> {
54 self.base_url
55 .join("v1/messages")
56 .map_err(|e| ForgeError::Internal(format!("failed to construct endpoint url: {e}")))
57 }
58}
59
60#[async_trait]
61impl ChatAdapter for AnthropicAdapter {
62 fn info(&self) -> AdapterInfo {
63 AdapterInfo {
64 name: "anthropic".to_string(),
65 base_url: Url::parse("https://api.anthropic.com").ok(),
66 capabilities: CapabilityMatrix {
67 streaming: true,
68 tools: true,
69 structured_output: true,
70 multimodal_input: true,
71 citations: false,
72 },
73 }
74 }
75
76 async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, ForgeError> {
77 let response = self
78 .client
79 .post(self.messages_url()?)
80 .header("x-api-key", &self.api_key)
81 .header("anthropic-version", &self.api_version)
82 .json(&build_messages_body(request, false))
83 .send()
84 .await
85 .map_err(|e| ForgeError::Transport(format!("request failed: {e}")))?;
86
87 if !response.status().is_success() {
88 let status = response.status();
89 let text = response
90 .text()
91 .await
92 .unwrap_or_else(|_| "failed to read error body".to_string());
93 return Err(parse_http_error(status, text));
94 }
95
96 let payload = response
97 .json::<Value>()
98 .await
99 .map_err(|e| ForgeError::Provider(format!("invalid json response: {e}")))?;
100 parse_chat_response(payload)
101 }
102
103 async fn chat_stream(
104 &self,
105 request: ChatRequest,
106 ) -> Result<StreamResult<StreamEvent>, ForgeError> {
107 let response = self
108 .client
109 .post(self.messages_url()?)
110 .header("x-api-key", &self.api_key)
111 .header("anthropic-version", &self.api_version)
112 .json(&build_messages_body(request, true))
113 .send()
114 .await
115 .map_err(|e| ForgeError::Transport(format!("stream request failed: {e}")))?;
116
117 if !response.status().is_success() {
118 let status = response.status();
119 let text = response
120 .text()
121 .await
122 .unwrap_or_else(|_| "failed to read error body".to_string());
123 return Err(parse_http_error(status, text));
124 }
125
126 let mut bytes = response.bytes_stream();
127 let stream = try_stream! {
128 let mut buffer = String::new();
129 let mut saw_done = false;
130 let mut event_name: Option<String> = None;
131 let mut data_lines: Vec<String> = Vec::new();
132
133 while let Some(chunk) = bytes.next().await {
134 let chunk = chunk.map_err(|e| ForgeError::Transport(format!("stream chunk error: {e}")))?;
135 let text = std::str::from_utf8(&chunk)
136 .map_err(|e| ForgeError::Transport(format!("invalid utf8 stream chunk: {e}")))?;
137 buffer.push_str(text);
138
139 while let Some(line_end) = buffer.find('\n') {
140 let mut line = buffer[..line_end].to_string();
141 buffer.drain(..=line_end);
142 if line.ends_with('\r') {
143 line.pop();
144 }
145 if line.is_empty() {
146 if !data_lines.is_empty() {
147 let payload = data_lines.join("\n");
148 let events = parse_stream_payload(&payload, event_name.as_deref())?;
149 for event in events {
150 if matches!(event, StreamEvent::Done) {
151 saw_done = true;
152 }
153 yield event;
154 }
155 data_lines.clear();
156 event_name = None;
157 }
158 continue;
159 }
160 if let Some(name) = line.strip_prefix("event:") {
161 event_name = Some(name.trim().to_string());
162 continue;
163 }
164 if let Some(data) = line.strip_prefix("data:") {
165 data_lines.push(data.trim().to_string());
166 }
167 }
168 }
169
170 if !data_lines.is_empty() {
171 let payload = data_lines.join("\n");
172 let events = parse_stream_payload(&payload, event_name.as_deref())?;
173 for event in events {
174 if matches!(event, StreamEvent::Done) {
175 saw_done = true;
176 }
177 yield event;
178 }
179 }
180
181 if !saw_done {
182 yield StreamEvent::Done;
183 }
184 };
185
186 Ok(Box::pin(stream))
187 }
188}
189
190fn build_messages_body(request: ChatRequest, stream: bool) -> Value {
191 let mut body = Map::new();
192 body.insert("model".to_string(), Value::String(request.model));
193 body.insert(
194 "max_tokens".to_string(),
195 Value::Number((request.max_tokens.unwrap_or(1024)).into()),
196 );
197
198 if let Some(temperature) = request.temperature {
199 body.insert("temperature".to_string(), json!(temperature));
200 }
201
202 let mut system_chunks = Vec::new();
203 let mut messages = Vec::new();
204 for message in request.messages {
205 if matches!(message.role, Role::System) {
206 system_chunks.push(message.content);
207 continue;
208 }
209 let role = match message.role {
210 Role::Assistant => "assistant",
211 _ => "user",
212 };
213 messages.push(json!({
214 "role": role,
215 "content": [{ "type": "text", "text": message.content }]
216 }));
217 }
218 body.insert("messages".to_string(), Value::Array(messages));
219
220 if !system_chunks.is_empty() {
221 body.insert(
222 "system".to_string(),
223 Value::String(system_chunks.join("\n\n")),
224 );
225 }
226
227 if !request.tools.is_empty() {
228 body.insert(
229 "tools".to_string(),
230 Value::Array(
231 request
232 .tools
233 .into_iter()
234 .map(|tool| {
235 json!({
236 "name": tool.name,
237 "description": tool.description,
238 "input_schema": tool.input_schema
239 })
240 })
241 .collect(),
242 ),
243 );
244 }
245
246 if stream {
247 body.insert("stream".to_string(), Value::Bool(true));
248 }
249
250 Value::Object(body)
251}
252
253fn parse_http_error(status: StatusCode, body: String) -> ForgeError {
254 let message = extract_provider_error(body);
255 match status {
256 StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => ForgeError::Authentication,
257 StatusCode::TOO_MANY_REQUESTS => ForgeError::RateLimited,
258 _ => ForgeError::Provider(message),
259 }
260}
261
262fn extract_provider_error(body: String) -> String {
263 serde_json::from_str::<Value>(&body)
264 .ok()
265 .and_then(|v| {
266 v.get("error")
267 .and_then(|e| e.get("message"))
268 .and_then(Value::as_str)
269 .map(ToString::to_string)
270 })
271 .unwrap_or(body)
272}
273
274fn parse_chat_response(payload: Value) -> Result<ChatResponse, ForgeError> {
275 let id = payload
276 .get("id")
277 .and_then(Value::as_str)
278 .unwrap_or_default()
279 .to_string();
280 let model = payload
281 .get("model")
282 .and_then(Value::as_str)
283 .unwrap_or_default()
284 .to_string();
285
286 let content = payload
287 .get("content")
288 .and_then(Value::as_array)
289 .cloned()
290 .unwrap_or_default();
291 let output_text = extract_text_blocks(&content);
292 let tool_calls = extract_tool_calls_from_blocks(&content);
293 let usage = extract_usage(payload.get("usage"));
294
295 Ok(ChatResponse {
296 id,
297 model,
298 output_text,
299 tool_calls,
300 usage,
301 })
302}
303
304fn extract_text_blocks(content: &[Value]) -> String {
305 content
306 .iter()
307 .filter(|block| block.get("type").and_then(Value::as_str) == Some("text"))
308 .filter_map(|block| block.get("text").and_then(Value::as_str))
309 .collect::<Vec<_>>()
310 .join("")
311}
312
313fn extract_tool_calls_from_blocks(content: &[Value]) -> Vec<ToolCall> {
314 content
315 .iter()
316 .filter(|block| block.get("type").and_then(Value::as_str) == Some("tool_use"))
317 .map(|block| ToolCall {
318 id: block
319 .get("id")
320 .and_then(Value::as_str)
321 .unwrap_or_default()
322 .to_string(),
323 name: block
324 .get("name")
325 .and_then(Value::as_str)
326 .unwrap_or_default()
327 .to_string(),
328 arguments: block.get("input").cloned().unwrap_or(Value::Null),
329 })
330 .collect()
331}
332
333fn extract_usage(raw: Option<&Value>) -> Option<Usage> {
334 let usage = raw?;
335 let input_tokens = usage
336 .get("input_tokens")
337 .and_then(Value::as_u64)
338 .unwrap_or(0) as u32;
339 let output_tokens = usage
340 .get("output_tokens")
341 .and_then(Value::as_u64)
342 .unwrap_or(0) as u32;
343 Some(Usage {
344 input_tokens,
345 output_tokens,
346 total_tokens: input_tokens.saturating_add(output_tokens),
347 })
348}
349
350fn parse_stream_payload(
351 payload: &str,
352 event: Option<&str>,
353) -> Result<Vec<StreamEvent>, ForgeError> {
354 let value = serde_json::from_str::<Value>(payload)
355 .map_err(|e| ForgeError::Provider(format!("invalid stream payload: {e}")))?;
356 let event_type = event
357 .map(ToString::to_string)
358 .or_else(|| {
359 value
360 .get("type")
361 .and_then(Value::as_str)
362 .map(ToString::to_string)
363 })
364 .unwrap_or_default();
365
366 let mut events = Vec::new();
367
368 if let Some(usage) = value
369 .get("usage")
370 .and_then(|v| extract_usage(Some(v)))
371 .or_else(|| {
372 value
373 .get("message")
374 .and_then(|m| m.get("usage"))
375 .and_then(|v| extract_usage(Some(v)))
376 })
377 {
378 events.push(StreamEvent::Usage { usage });
379 }
380
381 if event_type == "content_block_delta" {
382 if let Some(delta_text) = value
383 .get("delta")
384 .and_then(|d| d.get("text"))
385 .and_then(Value::as_str)
386 .filter(|s| !s.is_empty())
387 {
388 events.push(StreamEvent::TextDelta {
389 delta: delta_text.to_string(),
390 });
391 }
392 }
393
394 if event_type == "content_block_start" {
395 if let Some(block) = value.get("content_block") {
396 if block.get("type").and_then(Value::as_str) == Some("tool_use") {
397 let call_id = block
398 .get("id")
399 .and_then(Value::as_str)
400 .unwrap_or_default()
401 .to_string();
402 events.push(StreamEvent::ToolCallDelta {
403 call_id,
404 delta: block.clone(),
405 });
406 }
407 }
408 }
409
410 if event_type == "message_stop" {
411 events.push(StreamEvent::Done);
412 }
413
414 Ok(events)
415}
416
#[cfg(test)]
mod tests {
    use super::*;
    use forgeai_core::{ChatRequest, Message, Role};
    use futures_util::StreamExt;
    use wiremock::matchers::{body_partial_json, header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Single-turn user request shared by both contract tests.
    fn sample_request() -> ChatRequest {
        let prompt = Message {
            role: Role::User,
            content: "Say hello".to_string(),
        };
        ChatRequest {
            model: "claude-3-5-sonnet-latest".to_string(),
            messages: vec![prompt],
            temperature: Some(0.2),
            max_tokens: Some(128),
            tools: vec![],
            metadata: json!({}),
        }
    }

    /// The non-streaming path sends the expected headers and body, and maps
    /// the JSON reply (id, model, text, usage) onto `ChatResponse`.
    #[tokio::test]
    async fn chat_contract_parses_response_and_usage() {
        let server = MockServer::start().await;

        let canned_reply = json!({
            "id": "msg_123",
            "model": "claude-3-5-sonnet-latest",
            "content": [{ "type": "text", "text": "Hello from Anthropic" }],
            "usage": {"input_tokens": 12, "output_tokens": 5}
        });
        let expectation = Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .and(header("x-api-key", "test-key"))
            .and(header("anthropic-version", "2023-06-01"))
            .and(body_partial_json(
                json!({"model": "claude-3-5-sonnet-latest"}),
            ))
            .respond_with(ResponseTemplate::new(200).set_body_json(canned_reply));
        expectation.mount(&server).await;

        let endpoint = Url::parse(&server.uri()).unwrap();
        let adapter = AnthropicAdapter::with_base_url("test-key", endpoint).unwrap();
        let response = adapter.chat(sample_request()).await.unwrap();

        assert_eq!(response.id, "msg_123");
        assert_eq!(response.model, "claude-3-5-sonnet-latest");
        assert_eq!(response.output_text, "Hello from Anthropic");
        let usage = response.usage.unwrap();
        assert_eq!(usage.total_tokens, 17);
    }

    /// The streaming path requests `stream: true` and turns the SSE body into
    /// text deltas, usage updates and a final `Done`.
    #[tokio::test]
    async fn chat_stream_contract_parses_sse_events() {
        let server = MockServer::start().await;
        let sse_body = concat!(
            "event: message_start\n",
            "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_1\",\"usage\":{\"input_tokens\":10,\"output_tokens\":0}}}\n\n",
            "event: content_block_delta\n",
            "data: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\"Hello\"}}\n\n",
            "event: content_block_delta\n",
            "data: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\" world\"}}\n\n",
            "event: message_delta\n",
            "data: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":2}}\n\n",
            "event: message_stop\n",
            "data: {\"type\":\"message_stop\"}\n\n"
        );

        let expectation = Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .and(header("x-api-key", "test-key"))
            .and(header("anthropic-version", "2023-06-01"))
            .and(body_partial_json(json!({"stream": true})))
            .respond_with(ResponseTemplate::new(200).set_body_raw(sse_body, "text/event-stream"));
        expectation.mount(&server).await;

        let endpoint = Url::parse(&server.uri()).unwrap();
        let adapter = AnthropicAdapter::with_base_url("test-key", endpoint).unwrap();
        let mut stream = adapter.chat_stream(sample_request()).await.unwrap();

        // Drain the stream, stopping as soon as `Done` has been recorded.
        let mut events = Vec::new();
        loop {
            match stream.next().await {
                None => break,
                Some(item) => {
                    let event = item.unwrap();
                    let reached_end = matches!(event, StreamEvent::Done);
                    events.push(event);
                    if reached_end {
                        break;
                    }
                }
            }
        }

        let saw_text = |expected: &str| {
            events
                .iter()
                .any(|e| matches!(e, StreamEvent::TextDelta { delta } if delta == expected))
        };
        assert!(saw_text("Hello"));
        assert!(saw_text(" world"));

        let saw_input_usage = events.iter().any(|e| match e {
            StreamEvent::Usage { usage } => usage.input_tokens == 10,
            _ => false,
        });
        assert!(saw_input_usage);

        let saw_output_usage = events.iter().any(|e| match e {
            StreamEvent::Usage { usage } => usage.output_tokens == 2,
            _ => false,
        });
        assert!(saw_output_usage);

        let saw_done = events.iter().any(|e| matches!(e, StreamEvent::Done));
        assert!(saw_done);
    }
}