agents_runtime/providers/anthropic.rs

1use agents_core::llm::{LanguageModel, LlmRequest, LlmResponse};
2use agents_core::messaging::{AgentMessage, MessageContent, MessageRole};
3use agents_core::tools::ToolSchema;
4use async_trait::async_trait;
5use reqwest::Client;
6use serde::{Deserialize, Serialize};
7use serde_json::Value;
8
/// Connection and generation settings for the Anthropic Messages provider.
///
/// `Debug` is implemented by hand (not derived) so that the API key and custom header
/// values — which often carry credentials — never end up in logs.
#[derive(Clone)]
pub struct AnthropicConfig {
    /// Secret sent in the `x-api-key` header of every request.
    pub api_key: String,
    /// Model identifier placed in the request body's `model` field.
    pub model: String,
    /// Value sent as the request body's `max_tokens` field.
    pub max_output_tokens: u32,
    /// Endpoint override; `None` means `https://api.anthropic.com/v1/messages`.
    pub api_url: Option<String>,
    /// `anthropic-version` header override; `None` means `2023-06-01`.
    pub api_version: Option<String>,
    /// Extra `(name, value)` headers added to every request after the built-in ones.
    pub custom_headers: Vec<(String, String)>,
}

impl AnthropicConfig {
    /// Creates a config with the default endpoint, default API version and no extra headers.
    pub fn new(
        api_key: impl Into<String>,
        model: impl Into<String>,
        max_output_tokens: u32,
    ) -> Self {
        Self {
            api_key: api_key.into(),
            model: model.into(),
            max_output_tokens,
            api_url: None,
            api_version: None,
            custom_headers: Vec::new(),
        }
    }

    /// Replaces (does not extend) the list of extra headers sent with every request.
    pub fn with_custom_headers(mut self, headers: Vec<(String, String)>) -> Self {
        self.custom_headers = headers;
        self
    }

    /// Sends requests to `api_url` instead of the default endpoint (e.g. a proxy or gateway).
    pub fn with_api_url(mut self, api_url: impl Into<String>) -> Self {
        self.api_url = Some(api_url.into());
        self
    }

    /// Sends `api_version` as the `anthropic-version` header instead of the default.
    pub fn with_api_version(mut self, api_version: impl Into<String>) -> Self {
        self.api_version = Some(api_version.into());
        self
    }
}

impl std::fmt::Debug for AnthropicConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Header values may be credentials (e.g. `Authorization`), so only names are shown.
        let header_names: Vec<&str> = self
            .custom_headers
            .iter()
            .map(|(name, _)| name.as_str())
            .collect();
        f.debug_struct("AnthropicConfig")
            .field("api_key", &"<redacted>")
            .field("model", &self.model)
            .field("max_output_tokens", &self.max_output_tokens)
            .field("api_url", &self.api_url)
            .field("api_version", &self.api_version)
            .field("custom_header_names", &header_names)
            .finish()
    }
}
40
41pub struct AnthropicMessagesModel {
42    client: Client,
43    config: AnthropicConfig,
44}
45
46impl AnthropicMessagesModel {
47    pub fn new(config: AnthropicConfig) -> anyhow::Result<Self> {
48        Ok(Self {
49            client: Client::builder()
50                .user_agent("rust-deep-agents-sdk/0.1")
51                .build()?,
52            config,
53        })
54    }
55}
56
57#[derive(Serialize)]
58struct AnthropicRequest {
59    model: String,
60    max_tokens: u32,
61    system: String,
62    messages: Vec<AnthropicMessage>,
63    #[serde(skip_serializing_if = "Option::is_none")]
64    tools: Option<Vec<AnthropicTool>>,
65}
66
67#[derive(Serialize)]
68struct AnthropicTool {
69    name: String,
70    description: String,
71    input_schema: Value,
72}
73
74#[derive(Serialize)]
75struct AnthropicMessage {
76    role: String,
77    content: Vec<AnthropicContentBlock>,
78}
79
80#[derive(Serialize)]
81struct AnthropicContentBlock {
82    #[serde(rename = "type")]
83    kind: &'static str,
84    text: String,
85    #[serde(skip_serializing_if = "Option::is_none")]
86    cache_control: Option<AnthropicCacheControl>,
87}
88
89#[derive(Serialize)]
90struct AnthropicCacheControl {
91    #[serde(rename = "type")]
92    cache_type: String,
93}
94
95#[derive(Deserialize)]
96struct AnthropicResponse {
97    content: Vec<AnthropicResponseBlock>,
98}
99
100#[derive(Deserialize)]
101struct AnthropicResponseBlock {
102    #[serde(rename = "type")]
103    kind: String,
104    text: Option<String>,
105    #[allow(dead_code)]
106    id: Option<String>,
107    name: Option<String>,
108    input: Option<Value>,
109}
110
111fn to_anthropic_messages(request: &LlmRequest) -> (String, Vec<AnthropicMessage>) {
112    let mut system_prompt = request.system_prompt.clone();
113    let mut messages = Vec::new();
114
115    for message in &request.messages {
116        let text = match &message.content {
117            MessageContent::Text(text) => text.clone(),
118            MessageContent::Json(value) => value.to_string(),
119        };
120
121        // Handle system messages specially - they should be part of the system prompt
122        if matches!(message.role, MessageRole::System) {
123            if !system_prompt.is_empty() {
124                system_prompt.push_str("\n\n");
125            }
126            system_prompt.push_str(&text);
127            continue;
128        }
129
130        let role = match message.role {
131            MessageRole::User => "user",
132            MessageRole::Agent => "assistant",
133            MessageRole::Tool => "user",
134            MessageRole::System => unreachable!(), // Handled above
135        };
136
137        // Convert cache control if present
138        let cache_control = message
139            .metadata
140            .as_ref()
141            .and_then(|meta| meta.cache_control.as_ref())
142            .map(|cc| AnthropicCacheControl {
143                cache_type: cc.cache_type.clone(),
144            });
145
146        messages.push(AnthropicMessage {
147            role: role.to_string(),
148            content: vec![AnthropicContentBlock {
149                kind: "text",
150                text,
151                cache_control,
152            }],
153        });
154    }
155
156    (system_prompt, messages)
157}
158
159/// Convert tool schemas to Anthropic tool format
160fn to_anthropic_tools(tools: &[ToolSchema]) -> Option<Vec<AnthropicTool>> {
161    if tools.is_empty() {
162        return None;
163    }
164
165    Some(
166        tools
167            .iter()
168            .map(|tool| AnthropicTool {
169                name: tool.name.clone(),
170                description: tool.description.clone(),
171                input_schema: serde_json::to_value(&tool.parameters)
172                    .unwrap_or_else(|_| serde_json::json!({})),
173            })
174            .collect(),
175    )
176}
177
178#[async_trait]
179impl LanguageModel for AnthropicMessagesModel {
180    async fn generate(&self, request: LlmRequest) -> anyhow::Result<LlmResponse> {
181        let (system_prompt, messages) = to_anthropic_messages(&request);
182        let tools = to_anthropic_tools(&request.tools);
183
184        // Debug logging
185        tracing::debug!(
186            "Anthropic request: model={}, messages={}, tools={}",
187            self.config.model,
188            messages.len(),
189            tools.as_ref().map(|t| t.len()).unwrap_or(0)
190        );
191
192        let body = AnthropicRequest {
193            model: self.config.model.clone(),
194            max_tokens: self.config.max_output_tokens,
195            system: system_prompt,
196            messages,
197            tools,
198        };
199
200        let url = self
201            .config
202            .api_url
203            .as_deref()
204            .unwrap_or("https://api.anthropic.com/v1/messages");
205        let version = self.config.api_version.as_deref().unwrap_or("2023-06-01");
206
207        let mut request = self
208            .client
209            .post(url)
210            .header("x-api-key", &self.config.api_key)
211            .header("anthropic-version", version);
212
213        for (key, value) in &self.config.custom_headers {
214            request = request.header(key, value);
215        }
216
217        let response = request.json(&body).send().await?.error_for_status()?;
218
219        let data: AnthropicResponse = response.json().await?;
220
221        // Check if response contains tool_use blocks
222        let tool_uses: Vec<_> = data
223            .content
224            .iter()
225            .filter(|block| block.kind == "tool_use")
226            .collect();
227
228        if !tool_uses.is_empty() {
229            // Convert Anthropic tool_use format to our JSON format
230            let tool_calls: Vec<_> = tool_uses
231                .iter()
232                .filter_map(|block| {
233                    Some(serde_json::json!({
234                        "name": block.name.as_ref()?,
235                        "args": block.input.as_ref()?
236                    }))
237                })
238                .collect();
239
240            tracing::debug!("Anthropic response contains {} tool uses", tool_calls.len());
241
242            return Ok(LlmResponse {
243                message: AgentMessage {
244                    role: MessageRole::Agent,
245                    content: MessageContent::Json(serde_json::json!({
246                        "tool_calls": tool_calls
247                    })),
248                    metadata: None,
249                },
250            });
251        }
252
253        // Regular text response
254        let text = data
255            .content
256            .into_iter()
257            .find_map(|block| (block.kind == "text").then(|| block.text.unwrap_or_default()))
258            .unwrap_or_default();
259
260        Ok(LlmResponse {
261            message: AgentMessage {
262                role: MessageRole::Agent,
263                content: MessageContent::Text(text),
264                metadata: None,
265            },
266        })
267    }
268}
269
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a plain user message carrying `text` and no metadata.
    fn user_text(text: &str) -> AgentMessage {
        AgentMessage {
            role: MessageRole::User,
            content: MessageContent::Text(text.into()),
            metadata: None,
        }
    }

    fn header(name: &str, value: &str) -> (String, String) {
        (name.to_string(), value.to_string())
    }

    #[test]
    fn anthropic_message_conversion_includes_system_prompt() {
        let request = LlmRequest::new("You are helpful", vec![user_text("Hello")]);

        let (system, messages) = to_anthropic_messages(&request);

        assert_eq!(system, "You are helpful");
        assert_eq!(messages.len(), 1);
        let first = &messages[0];
        assert_eq!(first.role, "user");
        assert_eq!(first.content[0].text, "Hello");
    }

    #[test]
    fn anthropic_config_new_initializes_empty_custom_headers() {
        let AnthropicConfig {
            api_key,
            model,
            max_output_tokens,
            api_url,
            api_version,
            custom_headers,
        } = AnthropicConfig::new("test-key", "claude-3", 1024);

        assert_eq!(api_key, "test-key");
        assert_eq!(model, "claude-3");
        assert_eq!(max_output_tokens, 1024);
        assert!(custom_headers.is_empty());
        assert!(api_url.is_none());
        assert!(api_version.is_none());
    }

    #[test]
    fn anthropic_config_with_custom_headers_sets_headers() {
        let expected = vec![
            header("X-Custom-Header", "value1"),
            header("X-Another-Header", "value2"),
        ];

        let config = AnthropicConfig::new("test-key", "claude-3", 1024)
            .with_custom_headers(expected.clone());

        // Whole-vector equality checks length, order, names and values at once.
        assert_eq!(config.custom_headers, expected);
    }
}