use async_trait::async_trait;
use reqwest::{Client, header};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use once_cell::sync::Lazy;
use futures_util::{Stream, StreamExt};
use pin_project::pin_project;
use std::pin::Pin;
use std::task::{Context, Poll};
use bytes::Bytes;
use tokio::sync::mpsc;
use std::time::{Duration, Instant};
use tokio::time::sleep;
use parking_lot::Mutex;
use std::sync::Arc;
use futures::future;

use super::{
    Provider, Model, Message, MessageRole, MessageContent, MessagePart,
    GenerateOptions, GenerateResult, StreamChunk, StreamOptions,
    ToolCall, ToolDefinition, Usage, FinishReason, LanguageModel,
    ModelInfo, ModelCapabilities, ModelLimits, ModelPricing, ModelStatus,
    ProviderHealth, ProviderConfig, RateLimitInfo, UsageStats,
    ModelConfig,
};
use super::provider::ModelMetadata;
use crate::auth::{Auth, AuthCredentials};
struct SimpleAnthropicAuth;

impl SimpleAnthropicAuth {
    fn new() -> Self {
        Self
    }
}

#[async_trait]
impl Auth for SimpleAnthropicAuth {
    fn provider_id(&self) -> &str {
        "anthropic"
    }

    async fn get_credentials(&self) -> crate::Result<AuthCredentials> {
        if let Ok(api_key) = std::env::var("ANTHROPIC_API_KEY") {
            Ok(AuthCredentials::ApiKey { key: api_key })
        } else {
            Err(crate::Error::AuthenticationFailed(
                "No ANTHROPIC_API_KEY environment variable found".to_string()
            ))
        }
    }

    async fn set_credentials(&self, _credentials: AuthCredentials) -> crate::Result<()> {
        Err(crate::Error::AuthenticationFailed(
            "Setting credentials not supported in simple auth".to_string()
        ))
    }

    async fn remove_credentials(&self) -> crate::Result<()> {
        Ok(())
    }

    async fn has_credentials(&self) -> bool {
        std::env::var("ANTHROPIC_API_KEY").is_ok()
    }
}

#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AnthropicMode {
    Standard,
    Batch,
    RealTime,
}

impl Default for AnthropicMode {
    fn default() -> Self {
        Self::Standard
    }
}

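/// Provider implementation backed by the Anthropic Messages API.
///
/// A minimal construction sketch (illustrative only; it assumes an
/// `ANTHROPIC_API_KEY` is available via the environment-backed
/// `SimpleAnthropicAuth` helper above):
///
/// ```ignore
/// let provider = AnthropicProvider::new(Box::new(SimpleAnthropicAuth::new()));
/// let models = provider.list_models().await?;
/// let model = provider.get_model("claude-3-5-sonnet-20241022").await?;
/// ```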
#[derive(Clone)]
pub struct AnthropicProvider {
    client: Client,
    auth: Arc<dyn Auth>,
    rate_limiter: Arc<RateLimiter>,
    config: ProviderConfig,
    mode: AnthropicMode,
}

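/// Simple client-side throttle: enforces a minimum interval between requests
/// (100 ms by default, i.e. at most roughly ten requests per second).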
pub(crate) struct RateLimiter {
    pub(crate) last_request: Mutex<Option<Instant>>,
    pub(crate) min_interval: Duration,
}

impl RateLimiter {
    pub(crate) fn new() -> Self {
        Self {
            last_request: Mutex::new(None),
            min_interval: Duration::from_millis(100),
        }
    }

    pub(crate) async fn acquire(&self) {
        let should_wait = {
            let mut last = self.last_request.lock();
            if let Some(last_time) = *last {
                let elapsed = last_time.elapsed();
                if elapsed < self.min_interval {
                    Some(self.min_interval - elapsed)
                } else {
                    *last = Some(Instant::now());
                    None
                }
            } else {
                *last = Some(Instant::now());
                None
            }
        };

        if let Some(wait_time) = should_wait {
            sleep(wait_time).await;
            self.last_request.lock().replace(Instant::now());
        }
    }
}

impl AnthropicProvider {
    pub fn new(auth: Box<dyn Auth>) -> Self {
        let client = Client::builder()
            .default_headers({
                let mut headers = header::HeaderMap::new();
                headers.insert("anthropic-version", "2023-06-01".parse().unwrap());
                headers.insert("accept", "application/json".parse().unwrap());
                headers.insert("content-type", "application/json".parse().unwrap());
                headers.insert("user-agent", "code-mesh/0.1.0".parse().unwrap());
                headers
            })
            .timeout(Duration::from_secs(300))
            .build()
            .unwrap();

        Self {
            client,
            auth: Arc::from(auth),
            rate_limiter: Arc::new(RateLimiter::new()),
            config: ProviderConfig {
                provider_id: "anthropic".to_string(),
                ..Default::default()
            },
            mode: AnthropicMode::default(),
        }
    }

    async fn execute_with_retry<F, T>(&self, operation: F) -> crate::Result<T>
    where
        F: Fn() -> future::BoxFuture<'static, crate::Result<T>>,
    {
        let mut attempts = 0;
        let max_attempts = 3;

        loop {
            self.rate_limiter.acquire().await;

            match operation().await {
                Ok(result) => return Ok(result),
                Err(e) => {
                    attempts += 1;

                    if attempts >= max_attempts {
                        return Err(e);
                    }

                    let should_retry = match &e {
                        crate::Error::Network(req_err) => {
                            req_err.status().map_or(true, |status| {
                                status.as_u16() >= 500 || status.as_u16() == 429
                            })
                        },
                        crate::Error::Provider(msg) => {
                            msg.contains("rate_limit") || msg.contains("timeout")
                        },
                        _ => false,
                    };

                    if !should_retry {
                        return Err(e);
                    }

                    let delay = Duration::from_millis(1000 * (2_u64.pow(attempts - 1)));
                    sleep(delay).await;
                }
            }
        }
    }

    pub(crate) async fn validate_and_refresh_credentials(&self) -> crate::Result<String> {
        let credentials = self.auth.get_credentials().await?;

        match credentials {
            AuthCredentials::ApiKey { key } => {
                if !key.starts_with("sk-ant-") {
                    return Err(crate::Error::AuthenticationFailed(
                        "Invalid Anthropic API key format".to_string()
                    ));
                }
                Ok(key)
            },
            AuthCredentials::OAuth { access_token, refresh_token, expires_at } => {
                if let Some(expires_at) = expires_at {
                    let now = std::time::SystemTime::now()
                        .duration_since(std::time::UNIX_EPOCH)
                        .unwrap()
                        .as_secs();

                    if now >= expires_at {
                        if let Some(refresh_token) = refresh_token {
                            return self.refresh_oauth_token(refresh_token).await;
                        } else {
                            return Err(crate::Error::AuthenticationFailed(
                                "OAuth token expired and no refresh token available".to_string()
                            ));
                        }
                    }
                }
                Ok(access_token)
            },
            _ => Err(crate::Error::AuthenticationFailed(
                "Unsupported credential type for Anthropic".to_string()
            )),
        }
    }

    async fn refresh_oauth_token(&self, refresh_token: String) -> crate::Result<String> {
        let refresh_request = self.client
            .post("https://api.anthropic.com/oauth/token")
            .json(&serde_json::json!({
                "grant_type": "refresh_token",
                "refresh_token": refresh_token
            }))
            .send()
            .await?;

        if !refresh_request.status().is_success() {
            return Err(crate::Error::AuthenticationFailed(
                "Failed to refresh OAuth token".to_string()
            ));
        }

        let refresh_response: serde_json::Value = refresh_request.json().await?;

        let new_access_token = refresh_response["access_token"]
            .as_str()
            .ok_or_else(|| crate::Error::AuthenticationFailed(
                "Invalid refresh response".to_string()
            ))?
            .to_string();

        let new_refresh_token = refresh_response["refresh_token"]
            .as_str()
            .map(|s| s.to_string());

        let expires_in = refresh_response["expires_in"]
            .as_u64()
            .unwrap_or(3600);

        let expires_at = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap()
            .as_secs() + expires_in;

        let new_credentials = AuthCredentials::OAuth {
            access_token: new_access_token.clone(),
            refresh_token: new_refresh_token,
            expires_at: Some(expires_at),
        };

        self.auth.set_credentials(new_credentials).await?;

        Ok(new_access_token)
    }
}

#[async_trait]
impl Provider for AnthropicProvider {
    fn id(&self) -> &str {
        "anthropic"
    }

    fn name(&self) -> &str {
        "Anthropic"
    }

    fn base_url(&self) -> &str {
        "https://api.anthropic.com"
    }

    fn api_version(&self) -> &str {
        "2023-06-01"
    }

    async fn list_models(&self) -> crate::Result<Vec<ModelInfo>> {
        Ok(vec![
            ModelInfo {
                id: "claude-3-5-sonnet-20241022".to_string(),
                name: "Claude 3.5 Sonnet".to_string(),
                description: Some("Latest flagship model with improved performance".to_string()),
                capabilities: ModelCapabilities {
                    text_generation: true,
                    tool_calling: true,
                    vision: true,
                    streaming: true,
                    caching: true,
                    json_mode: true,
                    reasoning: true,
                    code_generation: true,
                    multilingual: true,
                    custom: HashMap::new(),
                },
                limits: ModelLimits {
                    max_context_tokens: 200000,
                    max_output_tokens: 8192,
                    max_image_size_bytes: Some(5 * 1024 * 1024),
                    max_images_per_request: Some(20),
                    max_tool_calls: Some(20),
                    rate_limits: RateLimitInfo {
                        requests_per_minute: Some(100),
                        tokens_per_minute: Some(40000),
                        tokens_per_day: None,
                        concurrent_requests: Some(10),
                        current_usage: None,
                    },
                },
                pricing: ModelPricing {
                    input_cost_per_1k: 3.0,
                    output_cost_per_1k: 15.0,
                    cache_read_cost_per_1k: Some(0.3),
                    cache_write_cost_per_1k: Some(3.75),
                    currency: "USD".to_string(),
                },
                release_date: Some(chrono::DateTime::parse_from_rfc3339("2024-10-22T00:00:00Z").unwrap().with_timezone(&chrono::Utc)),
                status: ModelStatus::Active,
            },
            ModelInfo {
                id: "claude-3-5-haiku-20241022".to_string(),
                name: "Claude 3.5 Haiku".to_string(),
                description: Some("Fast and efficient model".to_string()),
                capabilities: ModelCapabilities {
                    text_generation: true,
                    tool_calling: true,
                    vision: true,
                    streaming: true,
                    caching: false,
                    json_mode: true,
                    reasoning: true,
                    code_generation: true,
                    multilingual: true,
                    custom: HashMap::new(),
                },
                limits: ModelLimits {
                    max_context_tokens: 200000,
                    max_output_tokens: 8192,
                    max_image_size_bytes: Some(5 * 1024 * 1024),
                    max_images_per_request: Some(20),
                    max_tool_calls: Some(20),
                    rate_limits: RateLimitInfo {
                        requests_per_minute: Some(200),
                        tokens_per_minute: Some(80000),
                        tokens_per_day: None,
                        concurrent_requests: Some(20),
                        current_usage: None,
                    },
                },
                pricing: ModelPricing {
                    input_cost_per_1k: 1.0,
                    output_cost_per_1k: 5.0,
                    cache_read_cost_per_1k: None,
                    cache_write_cost_per_1k: None,
                    currency: "USD".to_string(),
                },
                release_date: Some(chrono::DateTime::parse_from_rfc3339("2024-10-22T00:00:00Z").unwrap().with_timezone(&chrono::Utc)),
                status: ModelStatus::Active,
            },
        ])
    }

    async fn get_model(&self, model_id: &str) -> crate::Result<Arc<dyn Model>> {
        let model = AnthropicModel::new(
            self.clone(),
            model_id.to_string(),
        );

        let model_with_provider = AnthropicModelWithProvider::new(model, self.clone());

        let wrapper = AnthropicModelWrapper::new(model_with_provider);

        Ok(Arc::new(wrapper))
    }

    async fn health_check(&self) -> crate::Result<ProviderHealth> {
        let start = std::time::Instant::now();

        match self.auth.get_credentials().await {
            Ok(_) => {
                Ok(ProviderHealth {
                    available: true,
                    latency_ms: Some(start.elapsed().as_millis() as u64),
                    error: None,
                    last_check: chrono::Utc::now(),
                    details: HashMap::new(),
                })
            }
            Err(e) => {
                Ok(ProviderHealth {
                    available: false,
                    latency_ms: Some(start.elapsed().as_millis() as u64),
                    error: Some(e.to_string()),
                    last_check: chrono::Utc::now(),
                    details: HashMap::new(),
                })
            }
        }
    }

    fn get_config(&self) -> &ProviderConfig {
        &self.config
    }

    async fn update_config(&mut self, config: ProviderConfig) -> crate::Result<()> {
        self.config = config;
        Ok(())
    }

    async fn get_rate_limits(&self) -> crate::Result<RateLimitInfo> {
        Ok(RateLimitInfo {
            requests_per_minute: Some(100),
            tokens_per_minute: Some(40000),
            tokens_per_day: None,
            concurrent_requests: Some(10),
            current_usage: None,
        })
    }

    async fn get_usage(&self) -> crate::Result<UsageStats> {
        Ok(UsageStats {
            total_requests: 0,
            total_tokens: 0,
            total_cost: 0.0,
            currency: "USD".to_string(),
            by_model: HashMap::new(),
            by_period: HashMap::new(),
        })
    }
}

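/// A concrete Claude model handle that owns everything needed to call the
/// Messages API: the shared HTTP client, the auth source, and the rate limiter.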
pub struct AnthropicModel {
    id: String,
    provider: AnthropicProvider,
    client: Client,
    auth: Arc<dyn Auth>,
    rate_limiter: Arc<RateLimiter>,
    model_id: String,
}

impl AnthropicModel {
    pub fn new(provider: AnthropicProvider, model_id: String) -> Self {
        Self {
            id: model_id.clone(),
            client: provider.client.clone(),
            auth: provider.auth.clone(),
            rate_limiter: provider.rate_limiter.clone(),
            model_id,
            provider,
        }
    }

    pub fn id(&self) -> &str {
        &self.id
    }

    pub fn name(&self) -> &str {
        &self.model_id
    }

    pub fn capabilities(&self) -> &ModelCapabilities {
        static CAPABILITIES: Lazy<ModelCapabilities> = Lazy::new(|| ModelCapabilities {
            text_generation: true,
            tool_calling: true,
            vision: true,
            streaming: true,
            caching: true,
            json_mode: true,
            reasoning: true,
            code_generation: true,
            multilingual: true,
            custom: HashMap::new(),
        });
        &*CAPABILITIES
    }

    pub fn config(&self) -> &ModelConfig {
        static CONFIG: Lazy<ModelConfig> = Lazy::new(|| ModelConfig::default());
        &*CONFIG
    }

    pub fn metadata(&self) -> &ModelMetadata {
        static METADATA: Lazy<ModelMetadata> = Lazy::new(|| ModelMetadata {
            family: "claude".to_string(),
            ..Default::default()
        });
        &*METADATA
    }

    pub async fn count_tokens(&self, messages: &[Message]) -> crate::Result<u32> {
        let mut total_tokens = 0u32;
        for message in messages {
            match &message.content {
                MessageContent::Text(text) => {
                    total_tokens += (text.len() / 4) as u32;
                }
                MessageContent::Parts(parts) => {
                    for part in parts {
                        match part {
                            MessagePart::Text { text } => {
                                total_tokens += (text.len() / 4) as u32;
                            }
                            MessagePart::Image { .. } => {
                                total_tokens += 1000;
                            }
                        }
                    }
                }
            }
        }
        Ok(total_tokens)
    }

    pub async fn estimate_cost(&self, input_tokens: u32, output_tokens: u32) -> crate::Result<f64> {
        let input_cost_per_1k = match self.model_id.as_str() {
            "claude-3-opus-20240229" => 15.0,
            "claude-3-sonnet-20240229" => 3.0,
            "claude-3-haiku-20240307" => 0.25,
            _ => 3.0,
        };

        let output_cost_per_1k = match self.model_id.as_str() {
            "claude-3-opus-20240229" => 75.0,
            "claude-3-sonnet-20240229" => 15.0,
            "claude-3-haiku-20240307" => 1.25,
            _ => 15.0,
        };

        let input_cost = (input_tokens as f64 / 1000.0) * input_cost_per_1k;
        let output_cost = (output_tokens as f64 / 1000.0) * output_cost_per_1k;

        Ok(input_cost + output_cost)
    }

    async fn execute_with_retry_simple<F, Fut, T>(&self, mut operation: F) -> crate::Result<T>
    where
        F: FnMut() -> Fut,
        Fut: std::future::Future<Output = crate::Result<T>>,
    {
        let mut attempts = 0;
        let max_attempts = 3;

        loop {
            self.rate_limiter.acquire().await;

            match operation().await {
                Ok(result) => return Ok(result),
                Err(e) => {
                    attempts += 1;

                    if attempts >= max_attempts {
                        return Err(e);
                    }

                    let should_retry = match &e {
                        crate::Error::Network(req_err) => {
                            req_err.status().map_or(true, |status| {
                                status.as_u16() >= 500 || status.as_u16() == 429
                            })
                        },
                        crate::Error::Provider(msg) => {
                            msg.contains("rate_limit") || msg.contains("timeout")
                        },
                        _ => false,
                    };

                    if !should_retry {
                        return Err(e);
                    }

                    let delay = Duration::from_millis(1000 * (2_u64.pow(attempts - 1)));
                    sleep(delay).await;
                }
            }
        }
    }
}

#[async_trait]
impl LanguageModel for AnthropicModel {
    async fn generate(
        &self,
        messages: Vec<Message>,
        options: GenerateOptions,
    ) -> crate::Result<GenerateResult> {
        let credentials = self.auth.get_credentials().await?;
        let api_key = match credentials {
            AuthCredentials::ApiKey { key } => {
                if !key.starts_with("sk-ant-") {
                    return Err(crate::Error::AuthenticationFailed(
                        "Invalid Anthropic API key format".to_string()
                    ));
                }
                key
            },
            AuthCredentials::OAuth { access_token, .. } => access_token,
            _ => return Err(crate::Error::AuthenticationFailed(
                "Unsupported credential type for Anthropic".to_string()
            )),
        };

        let (system_prompt, anthropic_messages) = convert_messages_with_system(messages)?;

        let mut request_body = serde_json::json!({
            "model": self.model_id,
            "messages": anthropic_messages,
            "max_tokens": options.max_tokens.unwrap_or(4096),
        });

        if let Some(system) = system_prompt {
            request_body["system"] = serde_json::json!(system);
        }

        if let Some(temp) = options.temperature {
            request_body["temperature"] = serde_json::json!(temp);
        }

        if !options.stop_sequences.is_empty() {
            request_body["stop_sequences"] = serde_json::json!(options.stop_sequences);
        }

        if !options.tools.is_empty() {
            request_body["tools"] = serde_json::json!(convert_tools_to_anthropic(options.tools));
        }

        let client = self.client.clone();
        let response = self.execute_with_retry_simple(|| {
            let client = client.clone();
            let api_key = api_key.clone();
            let request_body = request_body.clone();

            Box::pin(async move {
                let response = client
                    .post("https://api.anthropic.com/v1/messages")
                    .header("x-api-key", api_key)
                    .json(&request_body)
                    .send()
                    .await?;

                if !response.status().is_success() {
                    let status = response.status();
                    let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());

                    return Err(crate::Error::Provider(format!(
                        "Anthropic API error ({}): {}",
                        status.as_u16(),
                        error_text
                    )));
                }

                Ok(response)
            })
        }).await?;

        let api_response: AnthropicResponse = response.json().await?;

        Ok(GenerateResult {
            content: extract_content(&api_response),
            tool_calls: extract_tool_calls(&api_response),
            usage: Usage {
                prompt_tokens: api_response.usage.input_tokens,
                completion_tokens: api_response.usage.output_tokens,
                total_tokens: api_response.usage.input_tokens + api_response.usage.output_tokens,
            },
            finish_reason: convert_finish_reason(&api_response.stop_reason),
        })
    }

    async fn stream(
        &self,
        messages: Vec<Message>,
        options: StreamOptions,
    ) -> crate::Result<Box<dyn futures::Stream<Item = crate::Result<StreamChunk>> + Send + Unpin>> {
        let credentials = self.auth.get_credentials().await?;
        let api_key = match credentials {
            AuthCredentials::ApiKey { key } => {
                if !key.starts_with("sk-ant-") {
                    return Err(crate::Error::AuthenticationFailed(
                        "Invalid Anthropic API key format".to_string()
                    ));
                }
                key
            },
            AuthCredentials::OAuth { access_token, .. } => access_token,
            _ => return Err(crate::Error::AuthenticationFailed(
                "Unsupported credential type for Anthropic".to_string()
            )),
        };

        let (system_prompt, anthropic_messages) = convert_messages_with_system(messages)?;

        let mut request_body = serde_json::json!({
            "model": self.model_id,
            "messages": anthropic_messages,
            "max_tokens": options.max_tokens.unwrap_or(4096),
            "stream": true
        });

        if let Some(system) = system_prompt {
            request_body["system"] = serde_json::json!(system);
        }

        if let Some(temp) = options.temperature {
            request_body["temperature"] = serde_json::json!(temp);
        }

        if !options.stop_sequences.is_empty() {
            request_body["stop_sequences"] = serde_json::json!(options.stop_sequences);
        }

        if !options.tools.is_empty() {
            request_body["tools"] = serde_json::json!(convert_tools_to_anthropic(options.tools));
        }

        self.rate_limiter.acquire().await;

        let response = self.client
            .post("https://api.anthropic.com/v1/messages")
            .header("x-api-key", api_key)
            .header("accept", "text/event-stream")
            .json(&request_body)
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_else(|_| "Unknown error".to_string());

            return Err(crate::Error::Provider(format!(
                "Anthropic streaming API error ({}): {}",
                status.as_u16(),
                error_text
            )));
        }

        let stream = AnthropicStream::new(response.bytes_stream());
        Ok(Box::new(stream))
    }

    fn supports_tools(&self) -> bool {
        matches!(self.model_id.as_str(),
            "claude-3-5-sonnet-20241022" |
            "claude-3-5-haiku-20241022" |
            "claude-3-opus-20240229" |
            "claude-3-sonnet-20240229" |
            "claude-3-haiku-20240307"
        )
    }

    fn supports_vision(&self) -> bool {
        self.model_id.contains("claude-3")
    }

    fn supports_caching(&self) -> bool {
        matches!(self.model_id.as_str(),
            "claude-3-5-sonnet-20241022" |
            "claude-3-opus-20240229"
        )
    }
}

#[derive(Debug, Clone)]
pub(crate) struct AnthropicModelInfo {
    id: String,
    name: String,
    provider_id: String,
    capabilities: ModelCapabilities,
    config: ModelConfig,
    metadata: ModelMetadata,
}

impl AnthropicModelInfo {
    pub fn new(id: String, name: String) -> Self {
        Self {
            id: id.clone(),
            name,
            provider_id: "anthropic".to_string(),
            capabilities: ModelCapabilities {
                text_generation: true,
                tool_calling: true,
                vision: true,
                streaming: true,
                caching: true,
                json_mode: true,
                reasoning: true,
                code_generation: true,
                multilingual: true,
                custom: HashMap::new(),
            },
            config: ModelConfig {
                model_id: id,
                ..Default::default()
            },
            metadata: ModelMetadata {
                family: "claude".to_string(),
                ..Default::default()
            },
        }
    }
}

#[async_trait]
impl Model for AnthropicModelInfo {
    fn id(&self) -> &str { &self.id }
    fn name(&self) -> &str { &self.name }
    fn provider_id(&self) -> &str { &self.provider_id }
    fn capabilities(&self) -> &ModelCapabilities { &self.capabilities }
    fn config(&self) -> &ModelConfig { &self.config }

    async fn generate(
        &self,
        _messages: Vec<Message>,
        _options: GenerateOptions,
    ) -> crate::Result<GenerateResult> {
        Err(crate::Error::Other(anyhow::anyhow!("AnthropicModelInfo does not support generation directly")))
    }

    async fn stream(
        &self,
        _messages: Vec<Message>,
        _options: GenerateOptions,
    ) -> crate::Result<Pin<Box<dyn Stream<Item = crate::Result<StreamChunk>> + Send>>> {
        Err(crate::Error::Other(anyhow::anyhow!("AnthropicModelInfo does not support streaming directly")))
    }

    async fn count_tokens(&self, _messages: &[Message]) -> crate::Result<u32> {
        Ok(0)
    }

    async fn estimate_cost(&self, input_tokens: u32, output_tokens: u32) -> crate::Result<f64> {
        let input_cost = (input_tokens as f64 / 1000.0) * 3.0;
        let output_cost = (output_tokens as f64 / 1000.0) * 15.0;
        Ok(input_cost + output_cost)
    }

    fn metadata(&self) -> &ModelMetadata { &self.metadata }
}

#[derive(Deserialize)]
struct AnthropicResponse {
    content: Vec<AnthropicContent>,
    stop_reason: Option<String>,
    usage: AnthropicUsage,
}

#[derive(Deserialize)]
struct AnthropicContent {
    #[serde(rename = "type")]
    content_type: String,
    text: Option<String>,
    // Present on `tool_use` blocks; used to correlate tool results back to the call.
    id: Option<String>,
    name: Option<String>,
    input: Option<serde_json::Value>,
}

#[derive(Debug, Deserialize)]
struct AnthropicUsage {
    input_tokens: u32,
    output_tokens: u32,
}

#[derive(Debug, Deserialize)]
struct StreamEvent {
    #[serde(rename = "type")]
    event_type: String,
    #[serde(flatten)]
    data: serde_json::Value,
}

#[derive(Debug, Deserialize)]
struct MessageStart {
    message: MessageStartData,
}

#[derive(Debug, Deserialize)]
struct MessageStartData {
    id: String,
    #[serde(rename = "type")]
    message_type: String,
    role: String,
    model: String,
    content: Vec<serde_json::Value>,
    stop_reason: Option<String>,
    stop_sequence: Option<String>,
    usage: AnthropicUsage,
}

#[derive(Debug, Deserialize)]
struct ContentBlockStart {
    index: u32,
    content_block: ContentBlock,
}

#[derive(Debug, Deserialize)]
struct ContentBlock {
    #[serde(rename = "type")]
    block_type: String,
    text: Option<String>,
    name: Option<String>,
    input: Option<serde_json::Value>,
}

#[derive(Debug, Deserialize)]
struct ContentBlockDelta {
    index: u32,
    delta: ContentDelta,
}

#[derive(Debug, Deserialize)]
struct ContentDelta {
    #[serde(rename = "type")]
    delta_type: String,
    text: Option<String>,
    partial_json: Option<String>,
}

#[derive(Debug, Deserialize)]
struct MessageDelta {
    delta: MessageDeltaData,
    usage: AnthropicUsage,
}

#[derive(Debug, Deserialize)]
struct MessageDeltaData {
    stop_reason: Option<String>,
    stop_sequence: Option<String>,
}

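/// Adapter that turns the raw SSE byte stream from the Messages API into
/// `StreamChunk` items, buffering partial lines and accumulating tool-call JSON
/// per content-block index until the matching `content_block_stop` event arrives.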
#[pin_project]
struct AnthropicStream {
    #[pin]
    inner: futures_util::stream::BoxStream<'static, std::result::Result<Bytes, reqwest::Error>>,
    buffer: String,
    current_tool_calls: Vec<ToolCall>,
    tool_call_buffer: HashMap<u32, (String, String)>,
    finished: bool,
}

impl AnthropicStream {
    fn new(stream: impl Stream<Item = std::result::Result<Bytes, reqwest::Error>> + Send + 'static) -> Self {
        Self {
            inner: stream.boxed(),
            buffer: String::new(),
            current_tool_calls: Vec::new(),
            tool_call_buffer: HashMap::new(),
            finished: false,
        }
    }

    fn parse_sse_line(&mut self, line: &str) -> Option<crate::Result<StreamChunk>> {
        if line.is_empty() || line.starts_with(':') {
            return None;
        }

        if !line.starts_with("data: ") {
            return None;
        }

        let data = &line[6..];
        if data == "[DONE]" {
            self.finished = true;
            return Some(Ok(StreamChunk {
                delta: String::new(),
                tool_calls: Vec::new(),
                finish_reason: Some(FinishReason::Stop),
            }));
        }

        let event: StreamEvent = match serde_json::from_str(data) {
            Ok(event) => event,
            Err(e) => {
                return Some(Err(crate::Error::Provider(
                    format!("Failed to parse SSE event: {}", e)
                )));
            }
        };

        self.process_stream_event(event)
    }

    fn process_stream_event(&mut self, event: StreamEvent) -> Option<crate::Result<StreamChunk>> {
        process_stream_event_static(event, &mut self.tool_call_buffer, &mut self.current_tool_calls, &mut self.finished)
    }
}

fn process_sse_line_static(
    line: &str,
    tool_call_buffer: &mut HashMap<u32, (String, String)>,
    _current_tool_calls: &mut Vec<ToolCall>,
    finished: &mut bool,
) -> Option<crate::Result<StreamChunk>> {
    if line.is_empty() || line.starts_with(':') {
        return None;
    }

    if !line.starts_with("data: ") {
        return None;
    }

    let data = &line[6..];
    if data == "[DONE]" {
        *finished = true;
        return Some(Ok(StreamChunk {
            delta: String::new(),
            tool_calls: Vec::new(),
            finish_reason: Some(FinishReason::Stop),
        }));
    }

    let event: StreamEvent = match serde_json::from_str(data) {
        Ok(event) => event,
        Err(e) => {
            return Some(Err(crate::Error::Provider(
                format!("Failed to parse SSE event: {}", e)
            )));
        }
    };

    process_stream_event_static(event, tool_call_buffer, _current_tool_calls, finished)
}

fn process_stream_event_static(
    event: StreamEvent,
    tool_call_buffer: &mut HashMap<u32, (String, String)>,
    current_tool_calls: &mut Vec<ToolCall>,
    finished: &mut bool,
) -> Option<crate::Result<StreamChunk>> {
    match event.event_type.as_str() {
        "message_start" => {
            None
        }
        "content_block_start" => {
            if let Ok(block_start) = serde_json::from_value::<ContentBlockStart>(event.data) {
                if block_start.content_block.block_type == "tool_use" {
                    if let (Some(name), Some(_)) = (&block_start.content_block.name, &block_start.content_block.input) {
                        tool_call_buffer.insert(block_start.index, (name.clone(), String::new()));
                    }
                }
            }
            None
        }
        "content_block_delta" => {
            if let Ok(delta) = serde_json::from_value::<ContentBlockDelta>(event.data) {
                match delta.delta.delta_type.as_str() {
                    "text_delta" => {
                        if let Some(text) = delta.delta.text {
                            return Some(Ok(StreamChunk {
                                delta: text,
                                tool_calls: Vec::new(),
                                finish_reason: None,
                            }));
                        }
                    }
                    "input_json_delta" => {
                        if let Some(partial_json) = delta.delta.partial_json {
                            if let Some((_name, existing_json)) = tool_call_buffer.get_mut(&delta.index) {
                                existing_json.push_str(&partial_json);
                            }
                        }
                    }
                    _ => {}
                }
            }
            None
        }
        "content_block_stop" => {
            if let Ok(block_stop) = serde_json::from_value::<serde_json::Value>(event.data) {
                if let Some(index) = block_stop.get("index").and_then(|i| i.as_u64()) {
                    if let Some((name, json_str)) = tool_call_buffer.remove(&(index as u32)) {
                        if let Ok(arguments) = serde_json::from_str::<serde_json::Value>(&json_str) {
                            let tool_call = ToolCall {
                                id: format!("call_{}", uuid::Uuid::new_v4()),
                                name,
                                arguments,
                            };
                            current_tool_calls.push(tool_call.clone());
                            return Some(Ok(StreamChunk {
                                delta: String::new(),
                                tool_calls: vec![tool_call],
                                finish_reason: None,
                            }));
                        }
                    }
                }
            }
            None
        }
        "message_delta" => {
            if let Ok(msg_delta) = serde_json::from_value::<MessageDelta>(event.data) {
                if let Some(stop_reason) = msg_delta.delta.stop_reason {
                    let finish_reason = match stop_reason.as_str() {
                        "end_turn" => FinishReason::Stop,
                        "max_tokens" => FinishReason::Length,
                        "tool_use" => FinishReason::ToolCalls,
                        "stop_sequence" => FinishReason::Stop,
                        _ => FinishReason::Stop,
                    };
                    return Some(Ok(StreamChunk {
                        delta: String::new(),
                        tool_calls: Vec::new(),
                        finish_reason: Some(finish_reason),
                    }));
                }
            }
            None
        }
        "message_stop" => {
            *finished = true;
            Some(Ok(StreamChunk {
                delta: String::new(),
                tool_calls: Vec::new(),
                finish_reason: Some(FinishReason::Stop),
            }))
        }
        "error" => {
            if let Some(error_msg) = event.data.get("error").and_then(|e| e.as_str()) {
                Some(Err(crate::Error::Provider(format!("Anthropic streaming error: {}", error_msg))))
            } else {
                Some(Err(crate::Error::Provider("Unknown Anthropic streaming error".to_string())))
            }
        }
        _ => None,
    }
}

impl Stream for AnthropicStream {
    type Item = crate::Result<StreamChunk>;

    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let mut this = self.project();

        if *this.finished {
            return Poll::Ready(None);
        }

        loop {
            match this.inner.as_mut().poll_next(cx) {
                Poll::Ready(Some(Ok(chunk))) => {
                    let chunk_str = String::from_utf8_lossy(&chunk);
                    this.buffer.push_str(&chunk_str);

                    while let Some(line_end) = this.buffer.find('\n') {
                        let line = this.buffer[..line_end].trim_end_matches('\r').to_string();
                        this.buffer.drain(..=line_end);

                        if let Some(result) = process_sse_line_static(&line, &mut this.tool_call_buffer, &mut this.current_tool_calls, &mut this.finished) {
                            return Poll::Ready(Some(result));
                        }
                    }
                }
                Poll::Ready(Some(Err(e))) => {
                    return Poll::Ready(Some(Err(crate::Error::Other(e.into()))));
                }
                Poll::Ready(None) => {
                    *this.finished = true;
                    return Poll::Ready(None);
                }
                Poll::Pending => return Poll::Pending,
            }
        }
    }
}

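/// Converts internal `Message`s into Anthropic message objects. System messages
/// are dropped here (they are hoisted into the top-level `system` field by
/// `convert_messages_with_system`), tool results become `tool_result` blocks on a
/// user turn, and assistant tool calls become `tool_use` blocks.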
fn convert_messages(messages: Vec<Message>) -> Vec<serde_json::Value> {
    messages.into_iter().filter_map(|msg| {
        if matches!(msg.role, MessageRole::System) {
            return None;
        }

        let role = match msg.role {
            MessageRole::User => "user",
            MessageRole::Assistant => "assistant",
            MessageRole::Tool => "user",
            MessageRole::System => return None,
        };

        let content = match msg.content {
            MessageContent::Text(text) => {
                if matches!(msg.role, MessageRole::Tool) {
                    serde_json::json!([{
                        "type": "tool_result",
                        "tool_use_id": msg.tool_call_id.unwrap_or_else(|| "unknown".to_string()),
                        "content": text
                    }])
                } else {
                    serde_json::json!(text)
                }
            }
            MessageContent::Parts(parts) => {
                let anthropic_content: Vec<serde_json::Value> = parts.into_iter().map(|part| {
                    match part {
                        super::MessagePart::Text { text } => serde_json::json!({
                            "type": "text",
                            "text": text
                        }),
                        super::MessagePart::Image { image } => {
                            if let Some(base64) = image.base64 {
                                serde_json::json!({
                                    "type": "image",
                                    "source": {
                                        "type": "base64",
                                        "media_type": image.mime_type,
                                        "data": base64
                                    }
                                })
                            } else if let Some(url) = image.url {
                                serde_json::json!({
                                    "type": "image",
                                    "source": {
                                        "type": "url",
                                        "url": url
                                    }
                                })
                            } else {
                                serde_json::json!({
                                    "type": "text",
                                    "text": "[Invalid image data]"
                                })
                            }
                        }
                    }
                }).collect();
                serde_json::json!(anthropic_content)
            }
        };

        let mut obj = serde_json::json!({
            "role": role,
            "content": content,
        });

        if let Some(tool_calls) = &msg.tool_calls {
            if !tool_calls.is_empty() {
                let mut content_array = vec![];

                if let serde_json::Value::String(text) = &content {
                    if !text.trim().is_empty() {
                        content_array.push(serde_json::json!({
                            "type": "text",
                            "text": text
                        }));
                    }
                }

                for tool_call in tool_calls {
                    content_array.push(serde_json::json!({
                        "type": "tool_use",
                        "id": tool_call.id,
                        "name": tool_call.name,
                        "input": tool_call.arguments
                    }));
                }

                obj["content"] = serde_json::json!(content_array);
            }
        }

        Some(obj)
    }).collect()
}

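/// Splits system messages out of the conversation: their text is concatenated
/// (space-separated) into the optional system prompt, which the Messages API
/// expects as a separate `system` field, and the remaining messages are converted
/// with `convert_messages`.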
pub(crate) fn convert_messages_with_system(messages: Vec<Message>) -> crate::Result<(Option<String>, Vec<serde_json::Value>)> {
    let mut system_prompt = None;
    let mut filtered_messages = Vec::new();

    for msg in messages {
        match msg.role {
            MessageRole::System => {
                match msg.content {
                    MessageContent::Text(text) => {
                        if system_prompt.is_some() {
                            let existing = system_prompt.take().unwrap();
                            system_prompt = Some(format!("{} {}", existing, text));
                        } else {
                            system_prompt = Some(text);
                        }
                    }
                    MessageContent::Parts(_) => {
                        return Err(crate::Error::Other(anyhow::anyhow!(
                            "System messages with parts are not supported by Anthropic API"
                        )));
                    }
                }
            }
            _ => {
                filtered_messages.push(msg);
            }
        }
    }

    Ok((system_prompt, convert_messages(filtered_messages)))
}

fn convert_tools_to_anthropic(tools: Vec<ToolDefinition>) -> Vec<serde_json::Value> {
    tools.into_iter().map(|tool| {
        serde_json::json!({
            "name": tool.name,
            "description": tool.description,
            "input_schema": tool.parameters,
        })
    }).collect()
}

fn extract_content(response: &AnthropicResponse) -> String {
    response.content.iter()
        .filter_map(|c| c.text.as_ref())
        .cloned()
        .collect::<Vec<_>>()
        .join("")
}

fn extract_tool_calls(response: &AnthropicResponse) -> Vec<ToolCall> {
    response.content.iter()
        .filter(|c| c.content_type == "tool_use")
        .filter_map(|c| {
            Some(ToolCall {
                // Prefer the tool_use block's own id so later tool_result messages can
                // reference it; fall back to a generated id only if it is missing.
                id: c.id.clone().unwrap_or_else(|| uuid::Uuid::new_v4().to_string()),
                name: c.name.clone()?,
                arguments: c.input.clone()?,
            })
        })
        .collect()
}

fn convert_finish_reason(stop_reason: &Option<String>) -> FinishReason {
    match stop_reason.as_deref() {
        Some("end_turn") => FinishReason::Stop,
        Some("max_tokens") => FinishReason::Length,
        Some("tool_use") => FinishReason::ToolCalls,
        _ => FinishReason::Stop,
    }
}

pub struct AnthropicModelWithProvider {
    model: AnthropicModel,
    provider: AnthropicProvider,
}

pub struct AnthropicModelWrapper {
    inner: AnthropicModelWithProvider,
}

impl AnthropicModelWrapper {
    pub fn new(model_with_provider: AnthropicModelWithProvider) -> Self {
        Self { inner: model_with_provider }
    }
}

impl AnthropicModelWithProvider {
    pub fn new(model: AnthropicModel, provider: AnthropicProvider) -> Self {
        Self { model, provider }
    }

    pub fn model(&self) -> &AnthropicModel {
        &self.model
    }

    pub fn provider(&self) -> &AnthropicProvider {
        &self.provider
    }
}

#[async_trait]
impl Model for AnthropicModelWrapper {
    fn id(&self) -> &str { self.inner.model.id() }
    fn name(&self) -> &str { self.inner.model.name() }
    fn provider_id(&self) -> &str { "anthropic" }
    fn capabilities(&self) -> &ModelCapabilities { self.inner.model.capabilities() }
    fn config(&self) -> &ModelConfig { self.inner.model.config() }

    async fn generate(
        &self,
        messages: Vec<Message>,
        options: GenerateOptions,
    ) -> crate::Result<GenerateResult> {
        self.inner.model.generate(messages, options).await
    }

    async fn stream(
        &self,
        messages: Vec<Message>,
        options: GenerateOptions,
    ) -> crate::Result<Pin<Box<dyn Stream<Item = crate::Result<StreamChunk>> + Send>>> {
        let stream_options = StreamOptions {
            temperature: options.temperature,
            max_tokens: options.max_tokens,
            tools: options.tools,
            stop_sequences: options.stop_sequences,
        };

        let stream = self.inner.model.stream(messages, stream_options).await?;
        Ok(Box::pin(stream))
    }

    async fn count_tokens(&self, messages: &[Message]) -> crate::Result<u32> {
        self.inner.model.count_tokens(messages).await
    }

    async fn estimate_cost(&self, input_tokens: u32, output_tokens: u32) -> crate::Result<f64> {
        self.inner.model.estimate_cost(input_tokens, output_tokens).await
    }

    fn metadata(&self) -> &ModelMetadata {
        self.inner.model.metadata()
    }
}

#[async_trait]
impl LanguageModel for AnthropicModelWrapper {
    async fn generate(
        &self,
        messages: Vec<Message>,
        options: GenerateOptions,
    ) -> crate::Result<GenerateResult> {
        self.inner.generate(messages, options).await
    }

    async fn stream(
        &self,
        messages: Vec<Message>,
        options: StreamOptions,
    ) -> crate::Result<Box<dyn futures::Stream<Item = crate::Result<StreamChunk>> + Send + Unpin>> {
        self.inner.stream(messages, options).await
    }

    fn supports_tools(&self) -> bool {
        self.inner.supports_tools()
    }

    fn supports_vision(&self) -> bool {
        self.inner.supports_vision()
    }

    fn supports_caching(&self) -> bool {
        self.inner.supports_caching()
    }
}

#[async_trait]
impl LanguageModel for AnthropicModelWithProvider {
    async fn generate(
        &self,
        messages: Vec<Message>,
        options: GenerateOptions,
    ) -> crate::Result<GenerateResult> {
        self.model.generate(messages, options).await
    }

    async fn stream(
        &self,
        messages: Vec<Message>,
        options: StreamOptions,
    ) -> crate::Result<Box<dyn futures::Stream<Item = crate::Result<StreamChunk>> + Send + Unpin>> {
        self.model.stream(messages, options).await
    }

    fn supports_tools(&self) -> bool {
        true
    }

    fn supports_vision(&self) -> bool {
        true
    }

    fn supports_caching(&self) -> bool {
        true
    }
}
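
// Illustrative tests added as a minimal sketch: they only exercise the pure
// helpers defined in this module and make no network calls. The expected values
// mirror the hard-coded rates and mappings above.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn finish_reason_mapping_matches_anthropic_stop_reasons() {
        assert!(matches!(convert_finish_reason(&Some("end_turn".to_string())), FinishReason::Stop));
        assert!(matches!(convert_finish_reason(&Some("max_tokens".to_string())), FinishReason::Length));
        assert!(matches!(convert_finish_reason(&Some("tool_use".to_string())), FinishReason::ToolCalls));
        assert!(matches!(convert_finish_reason(&None), FinishReason::Stop));
    }

    #[test]
    fn model_info_cost_estimate_uses_flat_sonnet_rates() {
        let info = AnthropicModelInfo::new(
            "claude-3-5-sonnet-20241022".to_string(),
            "Claude 3.5 Sonnet".to_string(),
        );
        // 1k input at 3.0 per 1k plus 1k output at 15.0 per 1k.
        let cost = futures::executor::block_on(info.estimate_cost(1000, 1000)).unwrap();
        assert!((cost - 18.0).abs() < f64::EPSILON);
    }
}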