use async_trait::async_trait;
use seedframe::completion::{Client, CompletionError, CompletionModel, Message, TokenUsage};
use seedframe::embeddings::Embedder;
use seedframe::tools::{ToolCall, ToolResponse, ToolSet};
use serde::{Deserialize, Serialize};
use serde_json::json;
/// Name of the environment variable read for the API key when the config does not name another.
const API_KEY_ENV_VAR: &str = "SEEDFRAME_ANTHROPIC_API_KEY";
/// Default Messages endpoint. Must not carry surrounding whitespace: it is handed to the HTTP
/// client verbatim.
const URL: &str = "https://api.anthropic.com/v1/messages";
/// Model used when the config does not name one.
const DEFAULT_MODEL: &str = "claude-3-7-sonnet-20250219";
/// Sampling temperature handed to `Client::new` by `build_client`.
const DEFAULT_TEMP: f64 = 1.0;
/// Output-token budget handed to `Client::new` by `build_client`.
const DEFAULT_TOKENS: usize = 1024;
// Submodule providing `parse_content_blocks`, which the `From<Message>` conversion below uses
// to split message text into content blocks.
mod utils;
/// Conversation history handed to `send`. A leading `Message::Preamble`, if present, is
/// treated there as the system prompt rather than as a chat message.
type MessageHistory = Vec<Message>;
/// Optional overrides accepted by `AnthropicCompletionModel::new` as a JSON string.
///
/// Every field may be omitted; unknown keys are rejected during deserialization
/// (`deny_unknown_fields`).
#[derive(Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
struct ModelConfig {
/// Name of the environment variable that holds the API key (not the key itself);
/// `new` falls back to `API_KEY_ENV_VAR` when this is absent.
api_key: Option<String>,
/// Endpoint URL requests are posted to; `new` falls back to `URL` when this is absent.
api_url: Option<String>,
/// Model identifier; `new` falls back to `DEFAULT_MODEL` when this is absent.
model: Option<String>,
}
/// Completion backend that posts chat requests to an Anthropic-style messages endpoint.
pub struct AnthropicCompletionModel {
/// API key value, read from the environment by `new`.
api_key: String,
/// Endpoint URL that `send` posts to.
api_url: String,
/// HTTP client used for the requests made by `send`.
client: reqwest::Client,
/// Model identifier placed in the `model` field of each request body.
model: String,
/// System prompt. Starts as `None`; `send` fills it from a leading `Message::Preamble`
/// in the history and then includes it as the request's `system` field.
system: Option<String>,
}
impl AnthropicCompletionModel {
#[must_use]
pub fn new(config_json: Option<&str>) -> Self {
let (api_key_var, api_url, model) = if let Some(json) = config_json {
let config: ModelConfig = serde_json::from_str(json).unwrap();
(
config.api_key.unwrap_or(API_KEY_ENV_VAR.to_string()),
config.api_url.unwrap_or(URL.to_string()),
config.model.unwrap_or(DEFAULT_MODEL.to_string()),
)
} else {
(
API_KEY_ENV_VAR.to_string(),
URL.to_string(),
DEFAULT_MODEL.to_string(),
)
};
let api_key = std::env::var(api_key_var).unwrap();
Self {
api_key,
api_url,
client: reqwest::Client::new(),
model,
system: None,
}
}
}
/// One chat turn in the request's `messages` array.
///
/// With `tag = "role", content = "content"` serde emits each value as
/// `{"role": <variant name>, "content": [...]}`. The variants are deliberately lowercase so
/// that the serialized role is exactly `user` / `assistant`.
#[derive(Serialize, Deserialize, Debug, Eq, PartialEq)]
#[serde(tag = "role", content = "content")]
#[allow(non_camel_case_types)]
pub(crate) enum AnthropicMessage {
/// A turn carrying the `user` role.
user(Vec<ContentBlock>),
/// A turn carrying the `assistant` role.
assistant(Vec<ContentBlock>),
}
/// A single element of a message's `content` array.
///
/// The enum is internally tagged: each variant serializes as a JSON object whose `type`
/// field holds the variant's renamed tag, alongside the variant's own fields.
#[derive(Serialize, Deserialize, Eq, PartialEq, Clone, Debug)]
#[serde(tag = "type")]
pub(crate) enum ContentBlock {
    /// Plain text.
    #[serde(rename = "text")]
    Text { text: String },
    /// A thinking block together with its signature.
    #[serde(rename = "thinking")]
    Thinking { thinking: String, signature: String },
    /// A redacted thinking block carrying opaque data.
    #[serde(rename = "redacted_thinking")]
    RedactedThinking { data: String },
    /// A tool invocation. Tagged `tool_use`, the same tag the response parser in `send`
    /// matches on, instead of the Rust variant name.
    #[serde(rename = "tool_use")]
    ToolUse(ToolCall),
    /// The result of a tool invocation. Tagged `tool_result` instead of the Rust variant name.
    #[serde(rename = "tool_result")]
    ToolResult(ToolResponse),
}
impl From<Message> for AnthropicMessage {
    /// Converts a crate-level [`Message`] into its wire representation.
    ///
    /// * `Message::User` becomes a `user` turn: the parsed content blocks followed by one
    ///   `ToolResult` block per tool response.
    /// * `Message::Assistant` becomes an `assistant` turn: the parsed content blocks followed
    ///   by one `ToolUse` block per tool call.
    /// * `Message::Preamble` becomes a `user` turn holding a single empty text block. `send`
    ///   lifts a leading preamble into the `system` field before converting, so this arm is
    ///   only reached for a preamble that appears elsewhere.
    fn from(value: Message) -> Self {
        match value {
            Message::Preamble(_) => Self::user(vec![ContentBlock::Text {
                text: String::new(),
            }]),
            Message::User {
                content,
                tool_responses,
            } => {
                let mut blocks: Vec<ContentBlock> = Vec::new();
                // Empty text is skipped entirely rather than parsed.
                if !content.is_empty() {
                    blocks.extend(utils::parse_content_blocks(&content).iter().cloned());
                }
                if let Some(responses) = tool_responses {
                    blocks.extend(responses.iter().cloned().map(ContentBlock::ToolResult));
                }
                Self::user(blocks)
            }
            Message::Assistant {
                content,
                tool_calls,
            } => {
                let mut blocks: Vec<ContentBlock> = Vec::new();
                if !content.is_empty() {
                    blocks.extend(utils::parse_content_blocks(&content).iter().cloned());
                }
                if let Some(calls) = tool_calls {
                    blocks.extend(calls.iter().cloned().map(ContentBlock::ToolUse));
                }
                // Assistant turns must keep the assistant role; tagging them as `user`
                // would misattribute the model's own output in the conversation.
                Self::assistant(blocks)
            }
        }
    }
}
#[allow(refining_impl_trait)]
#[async_trait]
impl CompletionModel for AnthropicCompletionModel {
    /// Wraps this model in a [`Client`] configured with the given preamble, embedders and
    /// tools, using `DEFAULT_TEMP` and `DEFAULT_TOKENS` as the sampling defaults.
    fn build_client(
        self,
        preamble: impl AsRef<str>,
        embedder_instances: Vec<Embedder>,
        tools: ToolSet,
    ) -> Client<Self> {
        Client::new(
            self,
            preamble,
            DEFAULT_TEMP,
            DEFAULT_TOKENS,
            embedder_instances,
            tools,
        )
    }

    /// Sends `message`, preceded by `history`, to the configured endpoint and returns the
    /// assistant's reply together with the reported token usage.
    ///
    /// A leading `Message::Preamble` in `history` is not sent as a chat turn; it is stored
    /// on `self` and transmitted as the request's `system` field. Thinking and redacted
    /// thinking blocks in the reply are folded into the returned text using the
    /// `<sf_thinking>` / `<sf_r_thinking>` markers.
    ///
    /// # Errors
    ///
    /// * `CompletionError::RequestError` if the HTTP request cannot be sent.
    /// * `CompletionError::ProviderError` if the endpoint answers with a non-success status.
    /// * `CompletionError::ParseError` if the body is not JSON or lacks an expected field.
    #[allow(clippy::too_many_lines)]
    async fn send(
        &mut self,
        message: Message,
        history: &MessageHistory,
        tools: Option<&ToolSet>,
        temperature: f64,
        max_tokens: usize,
    ) -> Result<(Message, TokenUsage), CompletionError> {
        // Lift a leading preamble out of the chat turns and remember it as the system prompt.
        let past: &[Message] = match history.first() {
            Some(Message::Preamble(preamble)) => {
                self.system = Some(preamble.clone());
                &history[1..]
            }
            _ => history.as_slice(),
        };
        let messages: Vec<AnthropicMessage> = past
            .iter()
            .cloned()
            .chain(std::iter::once(message))
            .map(AnthropicMessage::from)
            .collect();

        // Only parameters of the Messages API are sent; the former `store` flag belongs to a
        // different provider's API and is omitted.
        let mut request_body = json!({
            "model": self.model,
            "messages": messages,
            "temperature": temperature,
            "max_tokens": max_tokens,
        });
        if let Some(body) = request_body.as_object_mut() {
            if let Some(tools) = tools {
                let serialized: Vec<serde_json::Value> =
                    tools.0.iter().map(|t| t.default_serializer()).collect();
                body.insert("tools".to_string(), serde_json::Value::Array(serialized));
            }
            if let Some(preamble) = &self.system {
                body.insert(
                    "system".to_string(),
                    serde_json::Value::String(preamble.clone()),
                );
            }
        }

        let response = self
            .client
            .post(&self.api_url)
            // API keys are presented in `x-api-key`, and the API version header is mandatory.
            .header("x-api-key", self.api_key.as_str())
            .header("anthropic-version", ANTHROPIC_API_VERSION)
            .header("Content-Type", "application/json")
            .json(&request_body)
            .send()
            .await
            .map_err(|e| CompletionError::RequestError(e.to_string()))?;

        let status = response.status();
        if !status.is_success() {
            let error_msg = response
                .text()
                .await
                .unwrap_or_else(|_| "Unknown error (failed to read response body)".to_string());
            return Err(CompletionError::ProviderError(status.into(), error_msg));
        }

        let response_json: serde_json::Value = response
            .json()
            .await
            .map_err(|e| CompletionError::ParseError(e.to_string()))?;

        let (content, tool_calls) = collect_response_content(&response_json)?;
        let token_usage = read_token_usage(&response_json)?;
        let tool_calls = if tool_calls.is_empty() {
            None
        } else {
            Some(tool_calls)
        };

        Ok((
            Message::Assistant {
                content,
                tool_calls,
            },
            token_usage,
        ))
    }
}

/// Value sent in the `anthropic-version` request header.
const ANTHROPIC_API_VERSION: &str = "2023-06-01";

/// Reads `block[key]` as an owned string, reporting a `ParseError` when it is absent or not
/// a string.
fn required_block_str(block: &serde_json::Value, key: &str) -> Result<String, CompletionError> {
    block[key].as_str().map(str::to_owned).ok_or_else(|| {
        CompletionError::ParseError(format!(
            "response content block is missing string field `{key}`"
        ))
    })
}

/// Walks the response's `content` array, concatenating text / thinking blocks into one
/// string and gathering `tool_use` blocks as [`ToolCall`]s. Unknown block types are skipped.
fn collect_response_content(
    response_json: &serde_json::Value,
) -> Result<(String, Vec<ToolCall>), CompletionError> {
    let blocks = response_json["content"].as_array().ok_or_else(|| {
        CompletionError::ParseError("response is missing the `content` array".to_string())
    })?;

    let mut content = String::new();
    let mut tool_calls: Vec<ToolCall> = Vec::new();
    for block in blocks {
        match block["type"].as_str() {
            Some("text") => content.push_str(&required_block_str(block, "text")?),
            Some("thinking") => {
                let thought = required_block_str(block, "thinking")?;
                let sig = required_block_str(block, "signature")?;
                content.push_str(&format!(
                    "<sf_thinking>{thought}</sf_sig>{sig}</sf_thinking>"
                ));
            }
            Some("tool_use") => {
                // `input` arrives as a JSON object, so it is re-serialized to text; a value
                // that is already a string is passed through untouched.
                let arguments = match &block["input"] {
                    serde_json::Value::Null => {
                        return Err(CompletionError::ParseError(
                            "tool_use block is missing `input`".to_string(),
                        ));
                    }
                    serde_json::Value::String(raw) => raw.clone(),
                    other => other.to_string(),
                };
                tool_calls.push(ToolCall {
                    id: required_block_str(block, "id")?,
                    name: required_block_str(block, "name")?,
                    arguments,
                });
            }
            Some("redacted_thinking") => {
                let data = required_block_str(block, "data")?;
                content.push_str(&format!("<sf_r_thinking>{data}</sf_r_thinking>"));
            }
            _ => {}
        }
    }
    Ok((content, tool_calls))
}

/// Extracts the input / output token counts from the response's `usage` object.
fn read_token_usage(response_json: &serde_json::Value) -> Result<TokenUsage, CompletionError> {
    let usage = &response_json["usage"];
    let count = |key: &str| {
        usage[key].as_u64().ok_or_else(|| {
            CompletionError::ParseError("Failed to parse usage data from response".to_string())
        })
    };
    let input_tokens = count("input_tokens")?;
    let output_tokens = count("output_tokens")?;
    Ok(TokenUsage {
        prompt_tokens: Some(input_tokens),
        completion_tokens: Some(output_tokens),
        total_tokens: Some(input_tokens + output_tokens),
    })
}
#[cfg(test)]
mod test {
    use seedframe::completion::Message;

    use crate::{AnthropicMessage, ContentBlock};

    /// A user message whose text interleaves plain text with thinking and redacted-thinking
    /// markers must be split into the matching sequence of content blocks.
    #[test]
    fn test_proper_message_conversion() {
        let plain = |s: &str| ContentBlock::Text {
            text: s.to_string(),
        };
        let expected = AnthropicMessage::user(vec![
            plain("start"),
            ContentBlock::Thinking {
                thinking: "think1".to_string(),
                signature: "make".to_string(),
            },
            plain("middle"),
            ContentBlock::RedactedThinking {
                data: "think2".to_string(),
            },
            plain("end"),
        ]);

        let source = Message::User {
            content: "start<sf_thinking>think1</sf_sig>make</sf_thinking>middle<sf_r_thinking>think2</sf_r_thinking>end".to_string(),
            tool_responses: None,
        };
        let actual = AnthropicMessage::from(source);

        assert_eq!(actual, expected);
    }
}