use crate::{api::ChatMessage, AppState};
use axum::{extract::State, response::IntoResponse, Json};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use uuid::Uuid;
#[derive(Debug, Deserialize)]
pub struct AnthropicMessageRequest {
pub model: String,
pub max_tokens: usize, pub messages: Vec<AnthropicMessage>,
#[serde(default)]
pub system: Option<String>,
#[serde(default)]
pub temperature: Option<f32>,
#[serde(default)]
pub top_p: Option<f32>,
#[serde(default)]
pub top_k: Option<i32>,
#[serde(default)]
pub stream: Option<bool>,
}
/// One conversation turn in an Anthropic-style request.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct AnthropicMessage {
    /// Speaker of the turn (e.g. `"user"` or `"assistant"`).
    pub role: String,
    /// Turn body: either a bare string or a list of typed blocks.
    pub content: AnthropicContent,
}

/// Message content as it appears on the wire.
///
/// `#[serde(untagged)]` tries the variants in declaration order, so a JSON string
/// becomes `Text` and a JSON array of block objects becomes `Blocks`.
#[derive(Debug, Deserialize, Serialize, Clone)]
#[serde(untagged)]
pub enum AnthropicContent {
    Text(String),
    Blocks(Vec<ContentBlock>),
}

/// A single typed content block (`"text"`, `"image"`, ...).
///
/// Only the keys this server looks at are modelled; any other keys in the block
/// object are ignored during deserialization.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ContentBlock {
    /// Block kind, taken from the JSON `type` key.
    #[serde(rename = "type")]
    pub content_type: String,
    /// Text payload; expected on `"text"` blocks.
    pub text: Option<String>,
    /// Media source; expected on image-like blocks.
    pub source: Option<ImageSource>,
}
/// `source` object of an image-like content block.
///
/// Only base64 sources carry `media_type` and `data`; other source kinds (such as the
/// URL form) omit them. Both therefore default to an empty string instead of failing
/// deserialization of the whole request. The message conversion in this file reads
/// only a block's `type` and `text`, never its source payload.
#[derive(Debug, Deserialize, Serialize, Clone)]
pub struct ImageSource {
    /// Source kind, taken from the JSON `type` key (e.g. `"base64"`).
    #[serde(rename = "type")]
    pub source_type: String,
    /// MIME type of the payload; empty when the source kind does not provide one.
    #[serde(default)]
    pub media_type: String,
    /// Encoded payload; empty when the source kind does not provide one.
    #[serde(default)]
    pub data: String,
}
/// Non-streaming response body in the Anthropic Messages format.
///
/// Field order here is the serialized key order.
#[derive(Debug, Serialize)]
pub struct AnthropicMessageResponse {
    /// Unique message identifier.
    pub id: String,
    /// Object kind, serialized under the JSON key `type`.
    #[serde(rename = "type")]
    pub response_type: String,
    /// Author of the generated message.
    pub role: String,
    /// Generated content blocks.
    pub content: Vec<AnthropicContentBlock>,
    /// Model name echoed back to the client.
    pub model: String,
    /// Why generation stopped.
    pub stop_reason: String,
    /// Stop sequence that ended generation, serialized as `null` when absent.
    pub stop_sequence: Option<String>,
    /// Token accounting for the request.
    pub usage: AnthropicUsage,
}

/// One output content block; `content_type` is serialized under the JSON key `type`.
#[derive(Debug, Serialize)]
pub struct AnthropicContentBlock {
    #[serde(rename = "type")]
    pub content_type: String,
    pub text: String,
}

/// Input/output token counts reported in a response.
#[derive(Debug, Serialize)]
pub struct AnthropicUsage {
    pub input_tokens: usize,
    pub output_tokens: usize,
}
/// Flattens an Anthropic message into the engine's plain-text `ChatMessage`.
///
/// String content is used as-is. For block content:
/// - each `"text"` block contributes its text (a `"text"` block with no `text` is skipped);
/// - every other block type becomes a `[<type> content]` placeholder.
///
/// The pieces are joined with `\n`. The role is copied unchanged.
impl From<AnthropicMessage> for ChatMessage {
    fn from(msg: AnthropicMessage) -> Self {
        let content = match msg.content {
            AnthropicContent::Text(text) => text,
            // The blocks are owned here, so their text is moved out rather than cloned.
            AnthropicContent::Blocks(blocks) => blocks
                .into_iter()
                .filter_map(|block| match block.content_type.as_str() {
                    "text" => block.text,
                    other => Some(format!("[{other} content]")),
                })
                .collect::<Vec<_>>()
                .join("\n"),
        };
        ChatMessage {
            role: msg.role,
            content,
        }
    }
}
pub async fn messages(
State(state): State<Arc<AppState>>,
Json(req): Json<AnthropicMessageRequest>,
) -> impl IntoResponse {
let internal_messages: Vec<ChatMessage> =
req.messages.into_iter().map(|msg| msg.into()).collect();
let Some(spec) = state.registry.to_spec(&req.model) else {
tracing::error!("Model '{}' not found in registry", req.model);
return axum::http::StatusCode::NOT_FOUND.into_response();
};
let system_message = req.system.clone();
let mut options = crate::engine::GenOptions {
max_tokens: req.max_tokens,
stream: req.stream.unwrap_or(false),
..Default::default()
};
if let Some(temp) = req.temperature {
options.temperature = temp;
}
if let Some(p) = req.top_p {
options.top_p = p;
}
if let Some(k) = req.top_k {
options.top_k = k;
}
let (system_prompt, conversation_pairs) =
extract_system_and_pairs(&internal_messages, system_message);
let mut prompt = String::new();
if let Some(system) = system_prompt {
prompt.push_str(&format!("System: {}\n\n", system));
}
for (user_msg, assistant_msg) in conversation_pairs {
prompt.push_str(&format!("Human: {}\n", user_msg));
if let Some(assistant) = assistant_msg {
prompt.push_str(&format!("Assistant: {}\n", assistant));
} else {
prompt.push_str("Assistant: ");
}
}
let Ok(loaded_model) = state.engine.load(&spec).await else {
tracing::error!("Failed to load model '{}'", req.model);
return axum::http::StatusCode::INTERNAL_SERVER_ERROR.into_response();
};
match loaded_model.generate(&prompt, options, None).await {
Ok(response) => {
let anthropic_response = AnthropicMessageResponse {
id: format!("msg_{}", Uuid::new_v4()),
response_type: "message".to_string(),
role: "assistant".to_string(),
content: vec![AnthropicContentBlock {
content_type: "text".to_string(),
text: response.clone(),
}],
model: req.model,
stop_reason: "end_turn".to_string(),
stop_sequence: None,
usage: AnthropicUsage {
input_tokens: estimate_tokens(&prompt),
output_tokens: estimate_tokens(&response),
},
};
Json(anthropic_response).into_response()
}
Err(e) => {
tracing::error!("Generation failed: {}", e);
axum::http::StatusCode::INTERNAL_SERVER_ERROR.into_response()
}
}
}
/// Splits a converted conversation into a system prompt and `(user, assistant)` turns.
///
/// The system prompt is built from two sources:
/// - the request's top-level `system` value (`explicit_system`);
/// - the content of a leading `"system"`-role message, if the first message has that role.
///
/// If both are present they are joined with a blank line, explicit first, so neither is
/// silently discarded.
///
/// After the optional leading system message, each `"user"` message starts a turn. If the
/// message right after it is an `"assistant"` message, that becomes the turn's reply and is
/// consumed. Other messages (other roles, or assistant messages not directly preceded by a
/// user message) are skipped.
///
/// The returned string slices borrow from `messages`.
fn extract_system_and_pairs(
    messages: &[ChatMessage],
    explicit_system: Option<String>,
) -> (Option<String>, Vec<(&str, Option<&str>)>) {
    let (inline_system, rest) = match messages.split_first() {
        Some((first, rest)) if first.role == "system" => (Some(first.content.as_str()), rest),
        _ => (None, messages),
    };
    let system_message = match (explicit_system, inline_system) {
        (Some(explicit), Some(inline)) => Some(format!("{explicit}\n\n{inline}")),
        (Some(explicit), None) => Some(explicit),
        (None, inline) => inline.map(str::to_owned),
    };

    let mut pairs = Vec::new();
    let mut turns = rest.iter().peekable();
    while let Some(msg) = turns.next() {
        if msg.role != "user" {
            continue;
        }
        // Consume the following message only if it is the assistant's reply to this turn.
        let reply = turns
            .next_if(|next| next.role == "assistant")
            .map(|next| next.content.as_str());
        pairs.push((msg.content.as_str(), reply));
    }
    (system_message, pairs)
}
/// Rough token count used for the `usage` block: one token per started group of four
/// UTF-8 bytes (`""` → 0, `"test"` → 1, `"hello"` → 2).
///
/// Uses exact integer ceil-division. Routing the length through `f32` would lose
/// precision above 2^24 bytes and under-count.
fn estimate_tokens(text: &str) -> usize {
    let bytes = text.len();
    bytes / 4 + usize::from(bytes % 4 != 0)
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::ChatMessage;

    /// Builds a `ChatMessage` fixture from borrowed role/content strings.
    fn chat(role: &str, content: &str) -> ChatMessage {
        ChatMessage {
            role: role.to_string(),
            content: content.to_string(),
        }
    }

    /// Builds a `"text"` content block carrying `text` and no source.
    fn text_block(text: &str) -> ContentBlock {
        ContentBlock {
            content_type: "text".to_string(),
            text: Some(text.to_string()),
            source: None,
        }
    }

    #[test]
    fn test_anthropic_message_conversion() {
        let converted = ChatMessage::from(AnthropicMessage {
            role: "user".to_string(),
            content: AnthropicContent::Text("Hello, world!".to_string()),
        });
        assert_eq!(converted.role, "user");
        assert_eq!(converted.content, "Hello, world!");
    }

    #[test]
    fn test_anthropic_content_blocks_conversion() {
        let converted = ChatMessage::from(AnthropicMessage {
            role: "user".to_string(),
            content: AnthropicContent::Blocks(vec![text_block("Hello"), text_block("World")]),
        });
        assert_eq!(converted.content, "Hello\nWorld");
    }

    #[test]
    fn test_extract_system_and_pairs() {
        let conversation = [
            chat("system", "You are a helpful assistant"),
            chat("user", "Hello"),
            chat("assistant", "Hi there!"),
            chat("user", "How are you?"),
        ];
        let (system, pairs) = extract_system_and_pairs(&conversation, None);
        assert_eq!(system.as_deref(), Some("You are a helpful assistant"));
        assert_eq!(
            pairs,
            vec![("Hello", Some("Hi there!")), ("How are you?", None)]
        );
    }

    #[test]
    fn test_explicit_system_message() {
        let conversation = [chat("user", "Hello")];
        let (system, pairs) =
            extract_system_and_pairs(&conversation, Some("Custom system".to_string()));
        assert_eq!(system.as_deref(), Some("Custom system"));
        assert_eq!(pairs, vec![("Hello", None)]);
    }

    #[test]
    fn test_token_estimation() {
        assert_eq!(estimate_tokens(""), 0);
        assert_eq!(estimate_tokens("test"), 1);
        assert_eq!(estimate_tokens("hello world"), 3);
    }
}