atlas/providers/anthropic/
completion.rs

//! Anthropic completion API implementation
2
3use std::iter;
4
5use crate::{
6    completion::{self, CompletionError},
7    json_utils,
8};
9
10use serde::{Deserialize, Serialize};
11use serde_json::json;
12
13use super::client::Client;
14
// ================================================================
// Anthropic Completion API
// ================================================================
/// `claude-3-5-sonnet-latest` completion model
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";

/// `claude-3-5-haiku-latest` completion model
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// `claude-3-opus-latest` completion model
pub const CLAUDE_3_OPUS: &str = "claude-3-opus-latest";

/// `claude-3-sonnet-20240229` completion model
pub const CLAUDE_3_SONNET: &str = "claude-3-sonnet-20240229";

/// `claude-3-haiku-20240307` completion model
pub const CLAUDE_3_HAIKU: &str = "claude-3-haiku-20240307";

/// `anthropic-version` header value for the 2023-01-01 API revision.
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
/// `anthropic-version` header value for the 2023-06-01 API revision.
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// Most recent `anthropic-version` value supported by this module.
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
36
/// Response payload returned by the Anthropic Messages API (`POST /v1/messages`).
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    /// Ordered content blocks produced by the model (text and/or tool-use).
    pub content: Vec<Content>,
    /// Identifier of this response, assigned by Anthropic.
    pub id: String,
    /// Name of the model that generated the response.
    pub model: String,
    /// Role of the message author (presumably "assistant" — set by the API).
    pub role: String,
    /// Why generation stopped, when the API reports it.
    pub stop_reason: Option<String>,
    /// The custom stop sequence that ended generation, if any.
    pub stop_sequence: Option<String>,
    /// Token accounting for this request/response pair.
    pub usage: Usage,
}
47
/// A single content block in an Anthropic message.
///
/// Deserialized `untagged`: serde tries each variant in declaration order, so
/// a bare JSON string, a text object, and a tool-use object are all accepted.
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum Content {
    /// Plain string content.
    String(String),
    /// A text block; `r#type` carries the wire `type` field (not validated here).
    Text {
        r#type: String,
        text: String,
    },
    /// A tool invocation requested by the model, with its JSON arguments.
    ToolUse {
        r#type: String,
        id: String,
        name: String,
        input: serde_json::Value,
    },
}
63
/// Token accounting reported by Anthropic for a request/response pair.
#[derive(Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Tokens consumed by the input (prompt) side.
    pub input_tokens: u64,
    /// Tokens read from the prompt cache, when the response reports them.
    pub cache_read_input_tokens: Option<u64>,
    /// Tokens written to the prompt cache, when the response reports them.
    pub cache_creation_input_tokens: Option<u64>,
    /// Tokens generated in the response.
    pub output_tokens: u64,
}
71
72impl std::fmt::Display for Usage {
73    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74        write!(
75            f,
76            "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
77            self.input_tokens,
78            match self.cache_read_input_tokens {
79                Some(token) => token.to_string(),
80                None => "n/a".to_string(),
81            },
82            match self.cache_creation_input_tokens {
83                Some(token) => token.to_string(),
84                None => "n/a".to_string(),
85            },
86            self.output_tokens
87        )
88    }
89}
90
/// Tool specification in the shape the Anthropic API expects:
/// a name, an optional description, and a JSON-schema input definition.
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    pub name: String,
    pub description: Option<String>,
    /// JSON schema describing the tool's input parameters.
    pub input_schema: serde_json::Value,
}
97
/// Prompt-caching control marker; serializes as `{"type": "ephemeral"}`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CacheControl {
    Ephemeral,
}
103
104impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
105    type Error = CompletionError;
106
107    fn try_from(response: CompletionResponse) -> std::prelude::v1::Result<Self, Self::Error> {
108        match response.content.as_slice() {
109            [Content::String(text) | Content::Text { text, .. }, ..] => {
110                Ok(completion::CompletionResponse {
111                    choice: completion::ModelChoice::Message(text.to_string()),
112                    raw_response: response,
113                })
114            }
115            [Content::ToolUse { name, input, .. }, ..] => Ok(completion::CompletionResponse {
116                choice: completion::ModelChoice::ToolCall(name.clone(), input.clone()),
117                raw_response: response,
118            }),
119            _ => Err(CompletionError::ResponseError(
120                "Response did not contain a message or tool call".into(),
121            )),
122        }
123    }
124}
125
/// A chat message in the Anthropic wire format: a role plus plain-text content.
#[derive(Debug, Deserialize, Serialize)]
pub struct Message {
    pub role: String,
    pub content: String,
}
131
132impl From<completion::Message> for Message {
133    fn from(message: completion::Message) -> Self {
134        Self {
135            role: message.role,
136            content: message.content,
137        }
138    }
139}
140
/// Anthropic completion model handle: an API client, the model name, and a
/// per-model default for the mandatory `max_tokens` parameter.
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    /// Model identifier, e.g. one of the `CLAUDE_*` constants above.
    pub model: String,
    /// Fallback `max_tokens` derived from the model name; `None` when the
    /// model is unknown and the caller must supply the value explicitly.
    default_max_tokens: Option<u64>,
}
147
148impl CompletionModel {
149    pub fn new(client: Client, model: &str) -> Self {
150        Self {
151            client,
152            model: model.to_string(),
153            default_max_tokens: calculate_max_tokens(model),
154        }
155    }
156}
157
/// Anthropic requires a `max_tokens` parameter to be set, which is dependent
/// on the model. If not set, or if set too high, the request will fail. The
/// values below are based on the models available at the time of writing.
///
/// Returns `None` when the model name matches no known family, forcing the
/// caller to supply `max_tokens` explicitly.
fn calculate_max_tokens(model: &str) -> Option<u64> {
    // Known model-name prefixes and their maximum output-token budgets.
    const MAX_TOKENS_BY_FAMILY: [(&str, u64); 5] = [
        ("claude-3-5-sonnet", 8192),
        ("claude-3-5-haiku", 8192),
        ("claude-3-opus", 4096),
        ("claude-3-sonnet", 4096),
        ("claude-3-haiku", 4096),
    ];

    MAX_TOKENS_BY_FAMILY
        .iter()
        .find(|(family, _)| model.starts_with(family))
        .map(|&(_, max_tokens)| max_tokens)
}
175
/// Optional request metadata forwarded to Anthropic.
#[derive(Debug, Deserialize, Serialize)]
struct Metadata {
    /// External end-user identifier, if the caller supplies one.
    user_id: Option<String>,
}
180
/// Tool-selection strategy sent as the request's `tool_choice` field.
/// Serializes as a tagged object, e.g. `{"type": "auto"}` or
/// `{"type": "tool", "name": "..."}` (per the Anthropic Messages API).
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ToolChoice {
    /// Let the model decide whether to call a tool.
    Auto,
    /// Require the model to call some tool of its choosing.
    Any,
    /// Require the model to call the named tool.
    Tool { name: String },
}
188
189impl completion::CompletionModel for CompletionModel {
190    type Response = CompletionResponse;
191
192    async fn completion(
193        &self,
194        completion_request: completion::CompletionRequest,
195    ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
196        // Note: Ideally we'd introduce provider-specific Request models to handle the
197        // specific requirements of each provider. For now, we just manually check while
198        // building the request as a raw JSON document.
199
200        let prompt_with_context = completion_request.prompt_with_context();
201
202        // Check if max_tokens is set, required for Anthropic
203        let max_tokens = if let Some(tokens) = completion_request.max_tokens {
204            tokens
205        } else if let Some(tokens) = self.default_max_tokens {
206            tokens
207        } else {
208            return Err(CompletionError::RequestError(
209                "`max_tokens` must be set for Anthropic".into(),
210            ));
211        };
212
213        let mut request = json!({
214            "model": self.model,
215            "messages": completion_request
216                .chat_history
217                .into_iter()
218                .map(Message::from)
219                .chain(iter::once(Message {
220                    role: "user".to_owned(),
221                    content: prompt_with_context,
222                }))
223                .collect::<Vec<_>>(),
224            "max_tokens": max_tokens,
225            "system": completion_request.preamble.unwrap_or("".to_string()),
226        });
227
228        if let Some(temperature) = completion_request.temperature {
229            json_utils::merge_inplace(&mut request, json!({ "temperature": temperature }));
230        }
231
232        if !completion_request.tools.is_empty() {
233            json_utils::merge_inplace(
234                &mut request,
235                json!({
236                    "tools": completion_request
237                        .tools
238                        .into_iter()
239                        .map(|tool| ToolDefinition {
240                            name: tool.name,
241                            description: Some(tool.description),
242                            input_schema: tool.parameters,
243                        })
244                        .collect::<Vec<_>>(),
245                    "tool_choice": ToolChoice::Auto,
246                }),
247            );
248        }
249
250        if let Some(ref params) = completion_request.additional_params {
251            json_utils::merge_inplace(&mut request, params.clone())
252        }
253
254        let response = self
255            .client
256            .post("/v1/messages")
257            .json(&request)
258            .send()
259            .await?;
260
261        if response.status().is_success() {
262            match response.json::<ApiResponse<CompletionResponse>>().await? {
263                ApiResponse::Message(completion) => {
264                    tracing::info!(target: "rig",
265                        "Anthropic completion token usage: {}",
266                        completion.usage
267                    );
268                    completion.try_into()
269                }
270                ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message)),
271            }
272        } else {
273            Err(CompletionError::ProviderError(response.text().await?))
274        }
275    }
276}
277
/// Error payload carried inside an Anthropic error response.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
282
/// Envelope for Anthropic API replies, discriminated by the wire `type` field:
/// `"message"` for success, `"error"` for failures.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ApiResponse<T> {
    /// Successful response (`"type": "message"`).
    Message(T),
    /// Error response (`"type": "error"`).
    Error(ApiErrorResponse),
}