use crate::anthropic::{
    self, AnthropicError, AnthropicModelMode, ContentDelta, Event, ResponseContent,
    ToolResultContent, ToolResultPart, Usage,
};
use crate::http_client::HttpClient;
use crate::model::{
    self, LanguageModel, LanguageModelCompletionError, LanguageModelCompletionEvent,
    LanguageModelId, LanguageModelName, LanguageModelProvider, LanguageModelProviderId,
    LanguageModelProviderName, LanguageModelRequest, LanguageModelToolChoice,
    LanguageModelToolResultContent, LanguageModelToolUse, MessageContent, Role, StopReason,
};
use futures::{FutureExt, Stream, StreamExt, future::BoxFuture, stream::BoxStream};
use schemars::JsonSchema;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::pin::Pin;
use std::str::FromStr;
use std::sync::Arc;
use strum::IntoEnumIterator;

const PROVIDER_ID: LanguageModelProviderId = model::ANTHROPIC_PROVIDER_ID;
const PROVIDER_NAME: LanguageModelProviderName = model::ANTHROPIC_PROVIDER_NAME;

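/// Connection settings for the Anthropic provider: the endpoint to call and
/// the API key to authenticate with. `AnthropicModel::stream_completion`
/// reads these from the global registry.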
#[derive(Default, Clone, Debug, PartialEq)]
pub struct AnthropicSettings {
    pub api_url: String,
    pub api_key: String,
}

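/// A custom model entry, typically deserialized from user settings: the
/// model's API `name`, an optional UI `display_name`, its context window size
/// (`max_tokens`), an optional model to substitute when calling tools
/// (`tool_override`), output-token and temperature defaults, any extra
/// `anthropic-beta` headers to send, and the model's `mode`.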
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize, JsonSchema)]
pub struct AvailableModel {
    pub name: String,
    pub display_name: Option<String>,
    pub max_tokens: u64,
    pub tool_override: Option<String>,
    pub max_output_tokens: Option<u64>,
    pub default_temperature: Option<f32>,
    #[serde(default)]
    pub extra_beta_headers: Vec<String>,
    pub mode: Option<ModelMode>,
}

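/// Serializable counterpart of `AnthropicModelMode`, allowing the thinking
/// configuration to round-trip through settings (see the `From` impls below).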
#[derive(Clone, Debug, Default, PartialEq, Serialize, Deserialize, JsonSchema)]
#[serde(tag = "type", rename_all = "lowercase")]
pub enum ModelMode {
    #[default]
    Default,
    Thinking {
        budget_tokens: Option<u32>,
    },
}

impl From<ModelMode> for AnthropicModelMode {
    fn from(value: ModelMode) -> Self {
        match value {
            ModelMode::Default => AnthropicModelMode::Default,
            ModelMode::Thinking { budget_tokens } => AnthropicModelMode::Thinking { budget_tokens },
        }
    }
}

impl From<AnthropicModelMode> for ModelMode {
    fn from(value: AnthropicModelMode) -> Self {
        match value {
            AnthropicModelMode::Default => ModelMode::Default,
            AnthropicModelMode::Thinking { budget_tokens } => ModelMode::Thinking { budget_tokens },
        }
    }
}

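/// `LanguageModelProvider` implementation backed by the Anthropic HTTP API.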
pub struct AnthropicLanguageModelProvider {
    http_client: Arc<dyn HttpClient>,
}

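/// Environment variable that can supply the Anthropic API key.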
const ANTHROPIC_API_KEY_VAR: &str = "ANTHROPIC_API_KEY";

impl AnthropicLanguageModelProvider {
    pub fn new(http_client: Arc<dyn HttpClient>) -> Self {
        Self { http_client }
    }

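    /// Wraps an `anthropic::Model` in the object-safe `LanguageModel`
    /// interface.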
    pub fn create_language_model(&self, model: anthropic::Model) -> Arc<dyn LanguageModel> {
        Arc::new(AnthropicModel {
            id: LanguageModelId::from(model.id().to_string()),
            model,
            http_client: self.http_client.clone(),
        })
    }
}

impl LanguageModelProvider for AnthropicLanguageModelProvider {
    fn id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn default_model(&self) -> Option<Arc<dyn LanguageModel>> {
        Some(self.create_language_model(anthropic::Model::default()))
    }

    fn default_fast_model(&self) -> Option<Arc<dyn LanguageModel>> {
        Some(self.create_language_model(anthropic::Model::default_fast()))
    }

    fn provided_models(&self) -> Vec<Arc<dyn LanguageModel>> {
        // Minimal sketch replacing a `todo!()`: expose each built-in model
        // variant. Assumes `anthropic::Model` derives `strum::EnumIter`,
        // which the `IntoEnumIterator` import suggests.
        anthropic::Model::iter()
            .map(|model| self.create_language_model(model))
            .collect()
    }
}

pub struct AnthropicModel {
    id: LanguageModelId,
    model: anthropic::Model,
    http_client: Arc<dyn HttpClient>,
}

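/// Estimates the token count of `request`. No local Anthropic tokenizer is
/// available, so the messages are flattened to plain text and counted with
/// tiktoken; treat the result as an approximation.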
pub fn count_anthropic_tokens(
    request: LanguageModelRequest,
) -> BoxFuture<'static, anyhow::Result<u64>> {
    async move {
        let messages = request.messages;
        // Image token estimation is not implemented; this stays zero.
        let tokens_from_images = 0;
        let mut string_messages = Vec::with_capacity(messages.len());

        for message in messages {
            let mut string_contents = String::new();

            for content in message.content {
                match content {
                    MessageContent::Text(text) => {
                        string_contents.push_str(&text);
                    }
                    MessageContent::Thinking { .. } => {
                        // Thinking blocks are excluded from the estimate.
                    }
                    MessageContent::RedactedThinking(_) => {}
                    MessageContent::Image(_image) => {
                        // See `tokens_from_images` above: images are currently
                        // counted as zero tokens.
                    }
                    MessageContent::ToolUse(_tool_use) => {
                        // TODO: estimate token usage for tool uses.
                    }
                    MessageContent::ToolResult(tool_result) => match &tool_result.content {
                        LanguageModelToolResultContent::Text(text) => {
                            string_contents.push_str(text);
                        }
                        // Non-text tool results are not counted.
                        _ => {}
                    },
                }
            }

            if !string_contents.is_empty() {
                string_messages.push(tiktoken_rs::ChatCompletionRequestMessage {
                    role: match message.role {
                        Role::User => "user".into(),
                        Role::Assistant => "assistant".into(),
                        Role::System => "system".into(),
                    },
                    content: Some(string_contents),
                    name: None,
                    function_call: None,
                });
            }
        }

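        // Tiktoken has no Anthropic models, so gpt-4's tokenizer serves as
        // the closest available approximation.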
        tiktoken_rs::num_tokens_from_messages("gpt-4", &string_messages)
            .map(|tokens| (tokens + tokens_from_images) as u64)
    }
    .boxed()
}

impl AnthropicModel {
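    /// Sends `request` to Anthropic's streaming endpoint, reading the API key
    /// and URL from the globally registered `AnthropicSettings`.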
    async fn stream_completion(
        &self,
        request: anthropic::Request,
    ) -> Result<
        BoxStream<'static, Result<anthropic::Event, AnthropicError>>,
        LanguageModelCompletionError,
    > {
        let http_client = self.http_client.clone();

        let anthropic_settings =
            global_registry::get!(AnthropicSettings).expect("AnthropicSettings not found");
        let api_key = anthropic_settings.api_key.clone();
        let api_url = anthropic_settings.api_url.clone();

        anthropic::stream_completion(http_client.as_ref(), &api_url, &api_key, request)
            .await
            .map_err(Into::<LanguageModelCompletionError>::into)
    }
}

#[async_trait::async_trait]
impl LanguageModel for AnthropicModel {
    fn id(&self) -> LanguageModelId {
        self.id.clone()
    }

    fn name(&self) -> LanguageModelName {
        LanguageModelName::from(self.model.display_name().to_string())
    }

    fn provider_id(&self) -> LanguageModelProviderId {
        PROVIDER_ID
    }

    fn provider_name(&self) -> LanguageModelProviderName {
        PROVIDER_NAME
    }

    fn max_token_count(&self) -> u64 {
        self.model.max_token_count()
    }

    fn max_output_tokens(&self) -> Option<u64> {
        Some(self.model.max_output_tokens())
    }

    async fn stream_completion(
        &self,
        request: LanguageModelRequest,
    ) -> Result<
        BoxStream<'static, Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>,
        LanguageModelCompletionError,
    > {
        let request = into_anthropic(
            request,
            self.model.request_id().into(),
            self.model.default_temperature(),
            self.model.max_output_tokens(),
            self.model.mode(),
        );
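        // Resolves to the inherent `stream_completion` above (the argument is
        // an `anthropic::Request`), not a recursive trait call.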
        let response = self.stream_completion(request).await?;
        let stream = AnthropicEventMapper::new().map_stream(response);
        Ok(stream.boxed())
    }

    fn supports_tools(&self) -> bool {
        true
    }

    fn supports_burn_mode(&self) -> bool {
        true
    }
}

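/// Converts a provider-agnostic `LanguageModelRequest` into an
/// `anthropic::Request`. System messages are hoisted into the top-level
/// `system` field, consecutive same-role messages are merged, and, when a
/// message is marked `cache`, its last cacheable content block gets an
/// ephemeral `cache_control` breakpoint.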
pub fn into_anthropic(
    request: LanguageModelRequest,
    model: String,
    default_temperature: f32,
    max_output_tokens: u64,
    mode: AnthropicModelMode,
) -> anthropic::Request {
    let mut new_messages: Vec<anthropic::Message> = Vec::new();
    let mut system_message = String::new();

    for message in request.messages {
        if message.contents_empty() {
            continue;
        }

        match message.role {
            Role::User | Role::Assistant => {
                let mut anthropic_message_content: Vec<anthropic::RequestContent> = message
                    .content
                    .into_iter()
                    .filter_map(|content| match content {
                        MessageContent::Text(text) => {
                            let text = if text.chars().last().is_some_and(|c| c.is_whitespace()) {
                                text.trim_end().to_string()
                            } else {
                                text
                            };
                            if !text.is_empty() {
                                Some(anthropic::RequestContent::Text {
                                    text,
                                    cache_control: None,
                                })
                            } else {
                                None
                            }
                        }
                        MessageContent::Thinking {
                            text: thinking,
                            signature,
                        } => {
                            if !thinking.is_empty() {
                                Some(anthropic::RequestContent::Thinking {
                                    thinking,
                                    signature: signature.unwrap_or_default(),
                                    cache_control: None,
                                })
                            } else {
                                None
                            }
                        }
                        MessageContent::RedactedThinking(data) => {
                            if !data.is_empty() {
                                Some(anthropic::RequestContent::RedactedThinking { data })
                            } else {
                                None
                            }
                        }
                        MessageContent::Image(image) => Some(anthropic::RequestContent::Image {
                            source: anthropic::ImageSource {
                                source_type: "base64".to_string(),
                                media_type: "image/png".to_string(),
                                data: image.source.to_string(),
                            },
                            cache_control: None,
                        }),
                        MessageContent::ToolUse(tool_use) => {
                            Some(anthropic::RequestContent::ToolUse {
                                id: tool_use.id.to_string(),
                                name: tool_use.name.to_string(),
                                input: tool_use.input,
                                cache_control: None,
                            })
                        }
                        MessageContent::ToolResult(tool_result) => {
                            Some(anthropic::RequestContent::ToolResult {
                                tool_use_id: tool_result.tool_use_id.to_string(),
                                is_error: tool_result.is_error,
                                content: match tool_result.content {
                                    LanguageModelToolResultContent::Text(text) => {
                                        ToolResultContent::Plain(text.to_string())
                                    }
                                    // Reconstructed (elided in the original): assumes
                                    // an `Image` variant mirroring `MessageContent::Image`
                                    // above, per the otherwise-unused `ToolResultPart` import.
                                    LanguageModelToolResultContent::Image(image) => {
                                        ToolResultContent::Multipart(vec![ToolResultPart::Image {
                                            source: anthropic::ImageSource {
                                                source_type: "base64".to_string(),
                                                media_type: "image/png".to_string(),
                                                data: image.source.to_string(),
                                            },
                                        }])
                                    }
                                },
                                cache_control: None,
                            })
                        }
                    })
                    .collect();
                let anthropic_role = match message.role {
                    Role::User => anthropic::Role::User,
                    Role::Assistant => anthropic::Role::Assistant,
                    Role::System => unreachable!("System role should never occur here"),
                };
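                // The Messages API expects user/assistant turns to alternate,
                // so consecutive same-role messages are folded together.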
                if let Some(last_message) = new_messages.last_mut() {
                    if last_message.role == anthropic_role {
                        last_message.content.extend(anthropic_message_content);
                        continue;
                    }
                }

                if message.cache {
                    let cache_control_value = Some(anthropic::CacheControl {
                        cache_type: anthropic::CacheControlType::Ephemeral,
                    });
                    for message_content in anthropic_message_content.iter_mut().rev() {
                        match message_content {
                            anthropic::RequestContent::RedactedThinking { .. } => {
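                                // Skipped: `RedactedThinking` is constructed without
                                // a `cache_control` field (see above), so it cannot
                                // carry the cache breakpoint.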
                            }
                            anthropic::RequestContent::Text { cache_control, .. }
                            | anthropic::RequestContent::Thinking { cache_control, .. }
                            | anthropic::RequestContent::Image { cache_control, .. }
                            | anthropic::RequestContent::ToolUse { cache_control, .. }
                            | anthropic::RequestContent::ToolResult { cache_control, .. } => {
                                *cache_control = cache_control_value;
                                break;
                            }
                        }
                    }
                }

                new_messages.push(anthropic::Message {
                    role: anthropic_role,
                    content: anthropic_message_content,
                });
            }
            Role::System => {
                if !system_message.is_empty() {
                    system_message.push_str("\n\n");
                }
                system_message.push_str(&message.string_contents());
            }
        }
    }

    anthropic::Request {
        model,
        messages: new_messages,
        max_tokens: max_output_tokens,
        system: if system_message.is_empty() {
            None
        } else {
            Some(anthropic::StringOrContents::String(system_message))
        },
        thinking: if request.thinking_allowed
            && let AnthropicModelMode::Thinking { budget_tokens } = mode
        {
            Some(anthropic::Thinking::Enabled { budget_tokens })
        } else {
            None
        },
        tools: request
            .tools
            .into_iter()
            .map(|tool| anthropic::Tool {
                name: tool.name,
                description: tool.description,
                input_schema: tool.input_schema,
            })
            .collect(),
        tool_choice: request.tool_choice.map(|choice| match choice {
            LanguageModelToolChoice::Auto => anthropic::ToolChoice::Auto,
            LanguageModelToolChoice::Any => anthropic::ToolChoice::Any,
            LanguageModelToolChoice::None => anthropic::ToolChoice::None,
        }),
        metadata: None,
        stop_sequences: Vec::new(),
        temperature: request.temperature.or(Some(default_temperature)),
        top_k: None,
        top_p: None,
    }
}

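/// Stateful translator from raw Anthropic server-sent events to
/// `LanguageModelCompletionEvent`s. It accumulates streamed tool-use JSON per
/// content-block index and tracks usage and the final stop reason.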
pub struct AnthropicEventMapper {
    tool_uses_by_index: HashMap<usize, RawToolUse>,
    usage: Usage,
    stop_reason: StopReason,
}

impl AnthropicEventMapper {
    pub fn new() -> Self {
        Self {
            tool_uses_by_index: HashMap::default(),
            usage: Usage::default(),
            stop_reason: StopReason::EndTurn,
        }
    }

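    /// Adapts a boxed Anthropic event stream into completion events,
    /// threading the mapper's state through each event via `flat_map`.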
    pub fn map_stream(
        mut self,
        events: Pin<Box<dyn Send + Stream<Item = Result<Event, AnthropicError>>>>,
    ) -> impl Stream<Item = Result<LanguageModelCompletionEvent, LanguageModelCompletionError>>
    {
        events.flat_map(move |event| {
            futures::stream::iter(match event {
                Ok(event) => self.map_event(event),
                Err(error) => vec![Err(error.into())],
            })
        })
    }

    pub fn map_event(
        &mut self,
        event: Event,
    ) -> Vec<Result<LanguageModelCompletionEvent, LanguageModelCompletionError>> {
        match event {
            Event::ContentBlockStart {
                index,
                content_block,
            } => match content_block {
                ResponseContent::Text { text } => {
                    vec![Ok(LanguageModelCompletionEvent::Text(text))]
                }
                ResponseContent::Thinking { thinking } => {
                    vec![Ok(LanguageModelCompletionEvent::Thinking {
                        text: thinking,
                        signature: None,
                    })]
                }
                ResponseContent::RedactedThinking { data } => {
                    vec![Ok(LanguageModelCompletionEvent::RedactedThinking { data })]
                }
                ResponseContent::ToolUse { id, name, .. } => {
                    self.tool_uses_by_index.insert(
                        index,
                        RawToolUse {
                            id,
                            name,
                            input_json: String::new(),
                        },
                    );
                    Vec::new()
                }
            },
            Event::ContentBlockDelta { index, delta } => match delta {
                ContentDelta::TextDelta { text } => {
                    vec![Ok(LanguageModelCompletionEvent::Text(text))]
                }
                ContentDelta::ThinkingDelta { thinking } => {
                    vec![Ok(LanguageModelCompletionEvent::Thinking {
                        text: thinking,
                        signature: None,
                    })]
                }
                ContentDelta::SignatureDelta { signature } => {
                    vec![Ok(LanguageModelCompletionEvent::Thinking {
                        text: "".to_string(),
                        signature: Some(signature),
                    })]
                }
                ContentDelta::InputJsonDelta { partial_json } => {
                    if let Some(tool_use) = self.tool_uses_by_index.get_mut(&index) {
                        tool_use.input_json.push_str(&partial_json);

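                        // Re-parse the accumulated JSON on every delta;
                        // `partial_json_fixer` closes unbalanced delimiters so
                        // incomplete tool input can be surfaced eagerly.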
                        if let Ok(input) = serde_json::Value::from_str(
                            &partial_json_fixer::fix_json(&tool_use.input_json),
                        ) {
                            return vec![Ok(LanguageModelCompletionEvent::ToolUse(
                                LanguageModelToolUse {
                                    id: tool_use.id.clone().into(),
                                    name: tool_use.name.clone().into(),
                                    is_input_complete: false,
                                    raw_input: tool_use.input_json.clone(),
                                    input,
                                },
                            ))];
                        }
                    }
                    return vec![];
                }
            },
            Event::ContentBlockStop { index } => {
                if let Some(tool_use) = self.tool_uses_by_index.remove(&index) {
                    let input_json = tool_use.input_json.trim();
                    let input_value = if input_json.is_empty() {
                        Ok(serde_json::Value::Object(serde_json::Map::default()))
                    } else {
                        serde_json::Value::from_str(input_json)
                    };
                    let event_result = match input_value {
                        Ok(input) => Ok(LanguageModelCompletionEvent::ToolUse(
                            LanguageModelToolUse {
                                id: tool_use.id.into(),
                                name: tool_use.name.into(),
                                is_input_complete: true,
                                input,
                                raw_input: tool_use.input_json.clone(),
                            },
                        )),
                        Err(json_parse_err) => {
                            Ok(LanguageModelCompletionEvent::ToolUseJsonParseError {
                                id: tool_use.id.into(),
                                tool_name: tool_use.name.into(),
                                raw_input: input_json.into(),
                                json_parse_error: json_parse_err.to_string(),
                            })
                        }
                    };

                    vec![event_result]
                } else {
                    Vec::new()
                }
            }
            Event::MessageStart { message } => {
                update_usage(&mut self.usage, &message.usage);
                vec![
                    Ok(LanguageModelCompletionEvent::UsageUpdate(convert_usage(
                        &self.usage,
                    ))),
                    Ok(LanguageModelCompletionEvent::StartMessage {
                        message_id: message.id,
                    }),
                ]
            }
            Event::MessageDelta { delta, usage } => {
                update_usage(&mut self.usage, &usage);
                if let Some(stop_reason) = delta.stop_reason.as_deref() {
                    self.stop_reason = match stop_reason {
                        "end_turn" => StopReason::EndTurn,
                        "max_tokens" => StopReason::MaxTokens,
                        "tool_use" => StopReason::ToolUse,
                        "refusal" => StopReason::Refusal,
                        _ => {
                            log::error!("Unexpected anthropic stop_reason: {stop_reason}");
                            StopReason::EndTurn
                        }
                    };
                }
                vec![Ok(LanguageModelCompletionEvent::UsageUpdate(
                    convert_usage(&self.usage),
                ))]
            }
            Event::MessageStop => {
                vec![Ok(LanguageModelCompletionEvent::Stop(self.stop_reason))]
            }
            Event::Error { error } => {
                vec![Err(error.into())]
            }
            _ => Vec::new(),
        }
    }
}

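/// A partially streamed `tool_use` block: input JSON accumulates here until
/// the content block stops.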
struct RawToolUse {
    id: String,
    name: String,
    input_json: String,
}

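/// Overwrites each field of `usage` with the corresponding field of `new`
/// when present; fields absent from `new` keep their previous values.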
fn update_usage(usage: &mut Usage, new: &Usage) {
    if let Some(input_tokens) = new.input_tokens {
        usage.input_tokens = Some(input_tokens);
    }
    if let Some(output_tokens) = new.output_tokens {
        usage.output_tokens = Some(output_tokens);
    }
    if let Some(cache_creation_input_tokens) = new.cache_creation_input_tokens {
        usage.cache_creation_input_tokens = Some(cache_creation_input_tokens);
    }
    if let Some(cache_read_input_tokens) = new.cache_read_input_tokens {
        usage.cache_read_input_tokens = Some(cache_read_input_tokens);
    }
}

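/// Converts Anthropic's optional usage counters into the model-agnostic
/// `TokenUsage`, defaulting missing fields to zero.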
fn convert_usage(usage: &Usage) -> model::TokenUsage {
    model::TokenUsage {
        input_tokens: usage.input_tokens.unwrap_or(0),
        output_tokens: usage.output_tokens.unwrap_or(0),
        cache_creation_input_tokens: usage.cache_creation_input_tokens.unwrap_or(0),
        cache_read_input_tokens: usage.cache_read_input_tokens.unwrap_or(0),
    }
}