// forgeai_adapter_anthropic/lib.rs

1use async_stream::try_stream;
2use async_trait::async_trait;
3use forgeai_core::{
4    AdapterInfo, CapabilityMatrix, ChatAdapter, ChatRequest, ChatResponse, ForgeError, Role,
5    StreamEvent, StreamResult, ToolCall, Usage,
6};
7use futures_util::StreamExt;
8use reqwest::{Client as HttpClient, StatusCode};
9use serde_json::{json, Map, Value};
10use std::env;
11use url::Url;
12
13#[derive(Clone, Debug)]
14pub struct AnthropicAdapter {
15    pub api_key: String,
16    pub base_url: Url,
17    pub api_version: String,
18    client: HttpClient,
19}
20
21impl AnthropicAdapter {
22    pub fn new(api_key: impl Into<String>) -> Result<Self, ForgeError> {
23        let base_url = Url::parse("https://api.anthropic.com")
24            .map_err(|e| ForgeError::Internal(e.to_string()))?;
25        Self::with_base_url(api_key, base_url)
26    }
27
28    pub fn with_base_url(api_key: impl Into<String>, base_url: Url) -> Result<Self, ForgeError> {
29        let client = HttpClient::builder()
30            .build()
31            .map_err(|e| ForgeError::Internal(format!("failed to build http client: {e}")))?;
32        Ok(Self {
33            api_key: api_key.into(),
34            base_url,
35            api_version: "2023-06-01".to_string(),
36            client,
37        })
38    }
39
40    pub fn from_env() -> Result<Self, ForgeError> {
41        let api_key = env::var("ANTHROPIC_API_KEY").map_err(|_| ForgeError::Authentication)?;
42        match env::var("ANTHROPIC_BASE_URL") {
43            Ok(raw) => {
44                let base_url = Url::parse(&raw).map_err(|e| {
45                    ForgeError::Validation(format!("invalid ANTHROPIC_BASE_URL: {e}"))
46                })?;
47                Self::with_base_url(api_key, base_url)
48            }
49            Err(_) => Self::new(api_key),
50        }
51    }
52
53    fn messages_url(&self) -> Result<Url, ForgeError> {
54        self.base_url
55            .join("v1/messages")
56            .map_err(|e| ForgeError::Internal(format!("failed to construct endpoint url: {e}")))
57    }
58}
59
60#[async_trait]
61impl ChatAdapter for AnthropicAdapter {
62    fn info(&self) -> AdapterInfo {
63        AdapterInfo {
64            name: "anthropic".to_string(),
65            base_url: Url::parse("https://api.anthropic.com").ok(),
66            capabilities: CapabilityMatrix {
67                streaming: true,
68                tools: true,
69                structured_output: true,
70                multimodal_input: true,
71                citations: false,
72            },
73        }
74    }
75
76    async fn chat(&self, request: ChatRequest) -> Result<ChatResponse, ForgeError> {
77        let response = self
78            .client
79            .post(self.messages_url()?)
80            .header("x-api-key", &self.api_key)
81            .header("anthropic-version", &self.api_version)
82            .json(&build_messages_body(request, false))
83            .send()
84            .await
85            .map_err(|e| ForgeError::Transport(format!("request failed: {e}")))?;
86
87        if !response.status().is_success() {
88            let status = response.status();
89            let text = response
90                .text()
91                .await
92                .unwrap_or_else(|_| "failed to read error body".to_string());
93            return Err(parse_http_error(status, text));
94        }
95
96        let payload = response
97            .json::<Value>()
98            .await
99            .map_err(|e| ForgeError::Provider(format!("invalid json response: {e}")))?;
100        parse_chat_response(payload)
101    }
102
103    async fn chat_stream(
104        &self,
105        request: ChatRequest,
106    ) -> Result<StreamResult<StreamEvent>, ForgeError> {
107        let response = self
108            .client
109            .post(self.messages_url()?)
110            .header("x-api-key", &self.api_key)
111            .header("anthropic-version", &self.api_version)
112            .json(&build_messages_body(request, true))
113            .send()
114            .await
115            .map_err(|e| ForgeError::Transport(format!("stream request failed: {e}")))?;
116
117        if !response.status().is_success() {
118            let status = response.status();
119            let text = response
120                .text()
121                .await
122                .unwrap_or_else(|_| "failed to read error body".to_string());
123            return Err(parse_http_error(status, text));
124        }
125
126        let mut bytes = response.bytes_stream();
127        let stream = try_stream! {
128            let mut buffer = String::new();
129            let mut saw_done = false;
130            let mut event_name: Option<String> = None;
131            let mut data_lines: Vec<String> = Vec::new();
132
133            while let Some(chunk) = bytes.next().await {
134                let chunk = chunk.map_err(|e| ForgeError::Transport(format!("stream chunk error: {e}")))?;
135                let text = std::str::from_utf8(&chunk)
136                    .map_err(|e| ForgeError::Transport(format!("invalid utf8 stream chunk: {e}")))?;
137                buffer.push_str(text);
138
139                while let Some(line_end) = buffer.find('\n') {
140                    let mut line = buffer[..line_end].to_string();
141                    buffer.drain(..=line_end);
142                    if line.ends_with('\r') {
143                        line.pop();
144                    }
145                    if line.is_empty() {
146                        if !data_lines.is_empty() {
147                            let payload = data_lines.join("\n");
148                            let events = parse_stream_payload(&payload, event_name.as_deref())?;
149                            for event in events {
150                                if matches!(event, StreamEvent::Done) {
151                                    saw_done = true;
152                                }
153                                yield event;
154                            }
155                            data_lines.clear();
156                            event_name = None;
157                        }
158                        continue;
159                    }
160                    if let Some(name) = line.strip_prefix("event:") {
161                        event_name = Some(name.trim().to_string());
162                        continue;
163                    }
164                    if let Some(data) = line.strip_prefix("data:") {
165                        data_lines.push(data.trim().to_string());
166                    }
167                }
168            }
169
170            if !data_lines.is_empty() {
171                let payload = data_lines.join("\n");
172                let events = parse_stream_payload(&payload, event_name.as_deref())?;
173                for event in events {
174                    if matches!(event, StreamEvent::Done) {
175                        saw_done = true;
176                    }
177                    yield event;
178                }
179            }
180
181            if !saw_done {
182                yield StreamEvent::Done;
183            }
184        };
185
186        Ok(Box::pin(stream))
187    }
188}
189
190fn build_messages_body(request: ChatRequest, stream: bool) -> Value {
191    let mut body = Map::new();
192    body.insert("model".to_string(), Value::String(request.model));
193    body.insert(
194        "max_tokens".to_string(),
195        Value::Number((request.max_tokens.unwrap_or(1024)).into()),
196    );
197
198    if let Some(temperature) = request.temperature {
199        body.insert("temperature".to_string(), json!(temperature));
200    }
201
202    let mut system_chunks = Vec::new();
203    let mut messages = Vec::new();
204    for message in request.messages {
205        if matches!(message.role, Role::System) {
206            system_chunks.push(message.content);
207            continue;
208        }
209        let role = match message.role {
210            Role::Assistant => "assistant",
211            _ => "user",
212        };
213        messages.push(json!({
214            "role": role,
215            "content": [{ "type": "text", "text": message.content }]
216        }));
217    }
218    body.insert("messages".to_string(), Value::Array(messages));
219
220    if !system_chunks.is_empty() {
221        body.insert(
222            "system".to_string(),
223            Value::String(system_chunks.join("\n\n")),
224        );
225    }
226
227    if !request.tools.is_empty() {
228        body.insert(
229            "tools".to_string(),
230            Value::Array(
231                request
232                    .tools
233                    .into_iter()
234                    .map(|tool| {
235                        json!({
236                            "name": tool.name,
237                            "description": tool.description,
238                            "input_schema": tool.input_schema
239                        })
240                    })
241                    .collect(),
242            ),
243        );
244    }
245
246    if stream {
247        body.insert("stream".to_string(), Value::Bool(true));
248    }
249
250    Value::Object(body)
251}
252
253fn parse_http_error(status: StatusCode, body: String) -> ForgeError {
254    let message = extract_provider_error(body);
255    match status {
256        StatusCode::UNAUTHORIZED | StatusCode::FORBIDDEN => ForgeError::Authentication,
257        StatusCode::TOO_MANY_REQUESTS => ForgeError::RateLimited,
258        _ => ForgeError::Provider(message),
259    }
260}
261
262fn extract_provider_error(body: String) -> String {
263    serde_json::from_str::<Value>(&body)
264        .ok()
265        .and_then(|v| {
266            v.get("error")
267                .and_then(|e| e.get("message"))
268                .and_then(Value::as_str)
269                .map(ToString::to_string)
270        })
271        .unwrap_or(body)
272}
273
274fn parse_chat_response(payload: Value) -> Result<ChatResponse, ForgeError> {
275    let id = payload
276        .get("id")
277        .and_then(Value::as_str)
278        .unwrap_or_default()
279        .to_string();
280    let model = payload
281        .get("model")
282        .and_then(Value::as_str)
283        .unwrap_or_default()
284        .to_string();
285
286    let content = payload
287        .get("content")
288        .and_then(Value::as_array)
289        .cloned()
290        .unwrap_or_default();
291    let output_text = extract_text_blocks(&content);
292    let tool_calls = extract_tool_calls_from_blocks(&content);
293    let usage = extract_usage(payload.get("usage"));
294
295    Ok(ChatResponse {
296        id,
297        model,
298        output_text,
299        tool_calls,
300        usage,
301    })
302}
303
304fn extract_text_blocks(content: &[Value]) -> String {
305    content
306        .iter()
307        .filter(|block| block.get("type").and_then(Value::as_str) == Some("text"))
308        .filter_map(|block| block.get("text").and_then(Value::as_str))
309        .collect::<Vec<_>>()
310        .join("")
311}
312
313fn extract_tool_calls_from_blocks(content: &[Value]) -> Vec<ToolCall> {
314    content
315        .iter()
316        .filter(|block| block.get("type").and_then(Value::as_str) == Some("tool_use"))
317        .map(|block| ToolCall {
318            id: block
319                .get("id")
320                .and_then(Value::as_str)
321                .unwrap_or_default()
322                .to_string(),
323            name: block
324                .get("name")
325                .and_then(Value::as_str)
326                .unwrap_or_default()
327                .to_string(),
328            arguments: block.get("input").cloned().unwrap_or(Value::Null),
329        })
330        .collect()
331}
332
333fn extract_usage(raw: Option<&Value>) -> Option<Usage> {
334    let usage = raw?;
335    let input_tokens = usage
336        .get("input_tokens")
337        .and_then(Value::as_u64)
338        .unwrap_or(0) as u32;
339    let output_tokens = usage
340        .get("output_tokens")
341        .and_then(Value::as_u64)
342        .unwrap_or(0) as u32;
343    Some(Usage {
344        input_tokens,
345        output_tokens,
346        total_tokens: input_tokens.saturating_add(output_tokens),
347    })
348}
349
350fn parse_stream_payload(
351    payload: &str,
352    event: Option<&str>,
353) -> Result<Vec<StreamEvent>, ForgeError> {
354    let value = serde_json::from_str::<Value>(payload)
355        .map_err(|e| ForgeError::Provider(format!("invalid stream payload: {e}")))?;
356    let event_type = event
357        .map(ToString::to_string)
358        .or_else(|| {
359            value
360                .get("type")
361                .and_then(Value::as_str)
362                .map(ToString::to_string)
363        })
364        .unwrap_or_default();
365
366    let mut events = Vec::new();
367
368    if let Some(usage) = value
369        .get("usage")
370        .and_then(|v| extract_usage(Some(v)))
371        .or_else(|| {
372            value
373                .get("message")
374                .and_then(|m| m.get("usage"))
375                .and_then(|v| extract_usage(Some(v)))
376        })
377    {
378        events.push(StreamEvent::Usage { usage });
379    }
380
381    if event_type == "content_block_delta" {
382        if let Some(delta_text) = value
383            .get("delta")
384            .and_then(|d| d.get("text"))
385            .and_then(Value::as_str)
386            .filter(|s| !s.is_empty())
387        {
388            events.push(StreamEvent::TextDelta {
389                delta: delta_text.to_string(),
390            });
391        }
392    }
393
394    if event_type == "content_block_start" {
395        if let Some(block) = value.get("content_block") {
396            if block.get("type").and_then(Value::as_str) == Some("tool_use") {
397                let call_id = block
398                    .get("id")
399                    .and_then(Value::as_str)
400                    .unwrap_or_default()
401                    .to_string();
402                events.push(StreamEvent::ToolCallDelta {
403                    call_id,
404                    delta: block.clone(),
405                });
406            }
407        }
408    }
409
410    if event_type == "message_stop" {
411        events.push(StreamEvent::Done);
412    }
413
414    Ok(events)
415}
416
#[cfg(test)]
mod tests {
    use super::*;
    use forgeai_core::{ChatRequest, Message, Role};
    use futures_util::StreamExt;
    use wiremock::matchers::{body_partial_json, header, method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// A minimal single-turn request shared by both contract tests.
    fn sample_request() -> ChatRequest {
        let user_turn = Message {
            role: Role::User,
            content: "Say hello".to_string(),
        };
        ChatRequest {
            model: "claude-3-5-sonnet-latest".to_string(),
            messages: vec![user_turn],
            temperature: Some(0.2),
            max_tokens: Some(128),
            tools: vec![],
            metadata: json!({}),
        }
    }

    /// Builds an adapter that points at the mock server.
    fn adapter_for(server: &MockServer) -> AnthropicAdapter {
        let base_url = Url::parse(&server.uri()).unwrap();
        AnthropicAdapter::with_base_url("test-key", base_url).unwrap()
    }

    /// A non-streaming reply is parsed into id, model, text and usage.
    #[tokio::test]
    async fn chat_contract_parses_response_and_usage() {
        let server = MockServer::start().await;
        let reply = json!({
            "id": "msg_123",
            "model": "claude-3-5-sonnet-latest",
            "content": [{ "type": "text", "text": "Hello from Anthropic" }],
            "usage": {"input_tokens": 12, "output_tokens": 5}
        });
        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .and(header("x-api-key", "test-key"))
            .and(header("anthropic-version", "2023-06-01"))
            .and(body_partial_json(
                json!({"model": "claude-3-5-sonnet-latest"}),
            ))
            .respond_with(ResponseTemplate::new(200).set_body_json(reply))
            .mount(&server)
            .await;

        let response = adapter_for(&server)
            .chat(sample_request())
            .await
            .unwrap();

        assert_eq!(response.id, "msg_123");
        assert_eq!(response.model, "claude-3-5-sonnet-latest");
        assert_eq!(response.output_text, "Hello from Anthropic");
        assert_eq!(response.usage.unwrap().total_tokens, 17);
    }

    /// A server-sent event body is turned into text, usage and done events.
    #[tokio::test]
    async fn chat_stream_contract_parses_sse_events() {
        let server = MockServer::start().await;
        let sse_body = concat!(
            "event: message_start\n",
            "data: {\"type\":\"message_start\",\"message\":{\"id\":\"msg_1\",\"usage\":{\"input_tokens\":10,\"output_tokens\":0}}}\n\n",
            "event: content_block_delta\n",
            "data: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\"Hello\"}}\n\n",
            "event: content_block_delta\n",
            "data: {\"type\":\"content_block_delta\",\"delta\":{\"type\":\"text_delta\",\"text\":\" world\"}}\n\n",
            "event: message_delta\n",
            "data: {\"type\":\"message_delta\",\"usage\":{\"output_tokens\":2}}\n\n",
            "event: message_stop\n",
            "data: {\"type\":\"message_stop\"}\n\n"
        );

        Mock::given(method("POST"))
            .and(path("/v1/messages"))
            .and(header("x-api-key", "test-key"))
            .and(header("anthropic-version", "2023-06-01"))
            .and(body_partial_json(json!({"stream": true})))
            .respond_with(ResponseTemplate::new(200).set_body_raw(sse_body, "text/event-stream"))
            .mount(&server)
            .await;

        let adapter = adapter_for(&server);
        let mut stream = adapter.chat_stream(sample_request()).await.unwrap();

        // Collect events up to and including the first `Done`.
        let mut events = Vec::new();
        while let Some(item) = stream.next().await {
            let event = item.unwrap();
            let finished = matches!(event, StreamEvent::Done);
            events.push(event);
            if finished {
                break;
            }
        }

        let saw_hello = events
            .iter()
            .any(|e| matches!(e, StreamEvent::TextDelta { delta } if delta == "Hello"));
        let saw_world = events
            .iter()
            .any(|e| matches!(e, StreamEvent::TextDelta { delta } if delta == " world"));
        let saw_input_usage = events
            .iter()
            .any(|e| matches!(e, StreamEvent::Usage { usage } if usage.input_tokens == 10));
        let saw_output_usage = events
            .iter()
            .any(|e| matches!(e, StreamEvent::Usage { usage } if usage.output_tokens == 2));
        let saw_done = events.iter().any(|e| matches!(e, StreamEvent::Done));

        assert!(saw_hello);
        assert!(saw_world);
        assert!(saw_input_usage);
        assert!(saw_output_usage);
        assert!(saw_done);
    }
}