// atlas/providers/anthropic/completion.rs
use std::iter;
4
5use crate::{
6 completion::{self, CompletionError},
7 json_utils,
8};
9
10use serde::{Deserialize, Serialize};
11use serde_json::json;
12
13use super::client::Client;
14
/// `claude-3-5-sonnet-latest` completion model identifier.
pub const CLAUDE_3_5_SONNET: &str = "claude-3-5-sonnet-latest";

/// `claude-3-5-haiku-latest` completion model identifier.
pub const CLAUDE_3_5_HAIKU: &str = "claude-3-5-haiku-latest";

/// `claude-3-opus-latest` completion model identifier.
pub const CLAUDE_3_OPUS: &str = "claude-3-opus-latest";

/// `claude-3-sonnet-20240229` completion model identifier (dated snapshot).
pub const CLAUDE_3_SONNET: &str = "claude-3-sonnet-20240229";

/// `claude-3-haiku-20240307` completion model identifier (dated snapshot).
pub const CLAUDE_3_HAIKU: &str = "claude-3-haiku-20240307";

/// Anthropic API version strings (sent as the `anthropic-version` header —
/// NOTE(review): header usage presumably lives in `client.rs`; confirm there).
pub const ANTHROPIC_VERSION_2023_01_01: &str = "2023-01-01";
pub const ANTHROPIC_VERSION_2023_06_01: &str = "2023-06-01";
/// Alias for the newest API version this module supports.
pub const ANTHROPIC_VERSION_LATEST: &str = ANTHROPIC_VERSION_2023_06_01;
36
/// Raw response body deserialized from Anthropic's messages API.
#[derive(Debug, Deserialize)]
pub struct CompletionResponse {
    /// Ordered content blocks produced by the model (text and/or tool use).
    pub content: Vec<Content>,
    /// Server-assigned identifier for this response.
    pub id: String,
    /// Name of the model that produced the response.
    pub model: String,
    /// Role of the message author as reported by the API.
    pub role: String,
    /// Why generation stopped, when the API reports it.
    pub stop_reason: Option<String>,
    /// The stop sequence that ended generation, if one matched.
    pub stop_sequence: Option<String>,
    /// Token accounting for this request/response pair.
    pub usage: Usage,
}
47
/// One content block of an Anthropic response.
///
/// `#[serde(untagged)]`: variants are matched by JSON shape in declaration
/// order rather than by an explicit tag field, so a bare JSON string becomes
/// `String`, an object with a `text` field becomes `Text`, and an object with
/// `id`/`name`/`input` becomes `ToolUse`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(untagged)]
pub enum Content {
    /// Bare string content.
    String(String),
    /// A text block; `r#type` carries the API's block-type discriminator.
    Text {
        r#type: String,
        text: String,
    },
    /// A tool invocation requested by the model: the tool `name` plus its
    /// JSON `input` arguments.
    ToolUse {
        r#type: String,
        id: String,
        name: String,
        input: serde_json::Value,
    },
}
63
/// Token usage reported by Anthropic for a single completion.
#[derive(Debug, Deserialize, Serialize)]
pub struct Usage {
    /// Tokens consumed by the input/prompt.
    pub input_tokens: u64,
    /// Tokens served from the prompt cache, when the API reports them.
    pub cache_read_input_tokens: Option<u64>,
    /// Tokens written into the prompt cache, when the API reports them.
    pub cache_creation_input_tokens: Option<u64>,
    /// Tokens generated in the response.
    pub output_tokens: u64,
}
71
72impl std::fmt::Display for Usage {
73 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
74 write!(
75 f,
76 "Input tokens: {}\nCache read input tokens: {}\nCache creation input tokens: {}\nOutput tokens: {}",
77 self.input_tokens,
78 match self.cache_read_input_tokens {
79 Some(token) => token.to_string(),
80 None => "n/a".to_string(),
81 },
82 match self.cache_creation_input_tokens {
83 Some(token) => token.to_string(),
84 None => "n/a".to_string(),
85 },
86 self.output_tokens
87 )
88 }
89}
90
/// Tool specification serialized into the request's `tools` array.
#[derive(Debug, Deserialize, Serialize)]
pub struct ToolDefinition {
    pub name: String,
    pub description: Option<String>,
    /// JSON Schema describing the tool's expected input object.
    pub input_schema: serde_json::Value,
}
97
/// Prompt-caching directive; serializes as `{"type": "ephemeral"}`.
#[derive(Debug, Deserialize, Serialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum CacheControl {
    Ephemeral,
}
103
104impl TryFrom<CompletionResponse> for completion::CompletionResponse<CompletionResponse> {
105 type Error = CompletionError;
106
107 fn try_from(response: CompletionResponse) -> std::prelude::v1::Result<Self, Self::Error> {
108 match response.content.as_slice() {
109 [Content::String(text) | Content::Text { text, .. }, ..] => {
110 Ok(completion::CompletionResponse {
111 choice: completion::ModelChoice::Message(text.to_string()),
112 raw_response: response,
113 })
114 }
115 [Content::ToolUse { name, input, .. }, ..] => Ok(completion::CompletionResponse {
116 choice: completion::ModelChoice::ToolCall(name.clone(), input.clone()),
117 raw_response: response,
118 }),
119 _ => Err(CompletionError::ResponseError(
120 "Response did not contain a message or tool call".into(),
121 )),
122 }
123 }
124}
125
/// Chat message in the wire format sent to the messages API.
#[derive(Debug, Deserialize, Serialize)]
pub struct Message {
    // Author role; this module sends "user" for the prompt message and
    // otherwise copies the role from the provider-agnostic message verbatim.
    pub role: String,
    pub content: String,
}
131
132impl From<completion::Message> for Message {
133 fn from(message: completion::Message) -> Self {
134 Self {
135 role: message.role,
136 content: message.content,
137 }
138 }
139}
140
/// An Anthropic completion model bound to an API client.
#[derive(Clone)]
pub struct CompletionModel {
    client: Client,
    /// Model identifier sent in requests (see the `CLAUDE_*` constants).
    pub model: String,
    /// Fallback `max_tokens` derived from the model name; `None` when the
    /// model family is unknown and the caller must supply a value.
    default_max_tokens: Option<u64>,
}
147
148impl CompletionModel {
149 pub fn new(client: Client, model: &str) -> Self {
150 Self {
151 client,
152 model: model.to_string(),
153 default_max_tokens: calculate_max_tokens(model),
154 }
155 }
156}
157
/// Default `max_tokens` for known Claude model families.
///
/// Matching is by name prefix so both dated snapshots and `-latest` aliases
/// resolve; unknown models yield `None`, forcing the caller to supply a
/// limit explicitly.
fn calculate_max_tokens(model: &str) -> Option<u64> {
    const LIMITS: [(&str, u64); 5] = [
        ("claude-3-5-sonnet", 8192),
        ("claude-3-5-haiku", 8192),
        ("claude-3-opus", 4096),
        ("claude-3-sonnet", 4096),
        ("claude-3-haiku", 4096),
    ];

    LIMITS
        .iter()
        .find(|(prefix, _)| model.starts_with(prefix))
        .map(|&(_, limit)| limit)
}
175
/// Optional request metadata payload.
// NOTE(review): not referenced elsewhere in this file's visible code —
// presumably merged into requests via `additional_params`; confirm callers.
#[derive(Debug, Deserialize, Serialize)]
struct Metadata {
    /// Caller-supplied end-user identifier.
    user_id: Option<String>,
}
180
/// The request's `tool_choice` field; serializes tagged by `type` in
/// snake_case, e.g. `Auto` becomes `{"type": "auto"}`.
#[derive(Debug, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ToolChoice {
    // This module always sends `Auto` when tools are present (see
    // `completion` below); the other variants mirror the API's options.
    Auto,
    Any,
    Tool { name: String },
}
188
189impl completion::CompletionModel for CompletionModel {
190 type Response = CompletionResponse;
191
192 async fn completion(
193 &self,
194 completion_request: completion::CompletionRequest,
195 ) -> Result<completion::CompletionResponse<CompletionResponse>, CompletionError> {
196 let prompt_with_context = completion_request.prompt_with_context();
201
202 let max_tokens = if let Some(tokens) = completion_request.max_tokens {
204 tokens
205 } else if let Some(tokens) = self.default_max_tokens {
206 tokens
207 } else {
208 return Err(CompletionError::RequestError(
209 "`max_tokens` must be set for Anthropic".into(),
210 ));
211 };
212
213 let mut request = json!({
214 "model": self.model,
215 "messages": completion_request
216 .chat_history
217 .into_iter()
218 .map(Message::from)
219 .chain(iter::once(Message {
220 role: "user".to_owned(),
221 content: prompt_with_context,
222 }))
223 .collect::<Vec<_>>(),
224 "max_tokens": max_tokens,
225 "system": completion_request.preamble.unwrap_or("".to_string()),
226 });
227
228 if let Some(temperature) = completion_request.temperature {
229 json_utils::merge_inplace(&mut request, json!({ "temperature": temperature }));
230 }
231
232 if !completion_request.tools.is_empty() {
233 json_utils::merge_inplace(
234 &mut request,
235 json!({
236 "tools": completion_request
237 .tools
238 .into_iter()
239 .map(|tool| ToolDefinition {
240 name: tool.name,
241 description: Some(tool.description),
242 input_schema: tool.parameters,
243 })
244 .collect::<Vec<_>>(),
245 "tool_choice": ToolChoice::Auto,
246 }),
247 );
248 }
249
250 if let Some(ref params) = completion_request.additional_params {
251 json_utils::merge_inplace(&mut request, params.clone())
252 }
253
254 let response = self
255 .client
256 .post("/v1/messages")
257 .json(&request)
258 .send()
259 .await?;
260
261 if response.status().is_success() {
262 match response.json::<ApiResponse<CompletionResponse>>().await? {
263 ApiResponse::Message(completion) => {
264 tracing::info!(target: "rig",
265 "Anthropic completion token usage: {}",
266 completion.usage
267 );
268 completion.try_into()
269 }
270 ApiResponse::Error(error) => Err(CompletionError::ProviderError(error.message)),
271 }
272 } else {
273 Err(CompletionError::ProviderError(response.text().await?))
274 }
275 }
276}
277
/// Error payload carried inside an API error envelope.
#[derive(Debug, Deserialize)]
struct ApiErrorResponse {
    message: String,
}
282
/// Top-level API envelope, discriminated by the JSON `type` field:
/// `"message"` wraps a successful payload, `"error"` wraps an API error.
#[derive(Debug, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
enum ApiResponse<T> {
    Message(T),
    Error(ApiErrorResponse),
}