use std::sync::atomic::Ordering;
use std::sync::Arc;
use axum::{
extract::{Json, State},
http::StatusCode,
response::{sse::Event, IntoResponse, Response, Sse},
};
use futures::stream::{self, Stream};
use serde::{Deserialize, Serialize};
use tokio_stream::wrappers::ReceiverStream;
use tokio_stream::StreamExt as _;
use super::models::RequestGuard;
use super::protocol::ChatMessage;
use super::server::Daemon;
/// Axum state shared by the handlers in this module: the daemon behind an `Arc`.
pub type AppState = Arc<Daemon>;
use super::server::resolve_chat_stop_sequences;
/// Prepend a model's configured default system prompt to `messages`.
///
/// `messages` is returned untouched when:
/// - no prompt is configured;
/// - the prompt is empty or whitespace-only;
/// - the conversation already contains a message whose role is `system` (ASCII
///   case-insensitive).
///
/// Otherwise the prompt is inserted as the first message.
fn apply_default_system_prompt(
    messages: Vec<ChatMessage>,
    system_prompt: Option<&str>,
) -> Vec<ChatMessage> {
    match system_prompt {
        Some(prompt)
            if !prompt.trim().is_empty()
                && !messages
                    .iter()
                    .any(|m| m.role.eq_ignore_ascii_case("system")) =>
        {
            let default_turn = ChatMessage {
                role: "system".to_string(),
                content: prompt.to_string().into(),
                name: None,
                tool_calls: None,
                tool_call_id: None,
            };
            std::iter::once(default_turn).chain(messages).collect()
        }
        _ => messages,
    }
}
#[derive(Debug, Deserialize)]
pub struct MessagesRequest {
pub model: Option<String>,
pub max_tokens: u32,
pub messages: Vec<AnthropicMessage>,
#[serde(default)]
pub system: Option<String>,
#[serde(default)]
pub stream: bool,
#[serde(default)]
pub temperature: Option<f32>,
#[serde(default)]
pub top_p: Option<f32>,
#[serde(default)]
pub top_k: Option<i32>,
#[serde(default)]
pub stop_sequences: Option<Vec<String>>,
#[serde(default)]
pub metadata: Option<serde_json::Value>,
}
/// One conversation turn from the request's `messages` array.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct AnthropicMessage {
/// Speaker role as sent by the client; copied verbatim into the converted `ChatMessage`.
pub role: String,
/// Turn body: either a plain string or an array of typed content blocks.
pub content: MessageContent,
}
/// Message content in either of its two wire shapes.
///
/// `untagged`: serde tries the variants in declaration order. A bare JSON string becomes
/// `Text`; a JSON array of typed blocks becomes `Blocks`.
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(untagged)]
pub enum MessageContent {
/// Plain-text content.
Text(String),
/// A list of typed content blocks.
Blocks(Vec<ContentBlock>),
}
impl MessageContent {
    /// Flatten the content into a single owned string.
    ///
    /// `Text` is returned as-is. For `Blocks`, only `text` blocks contribute; their texts are
    /// joined with `"\n"` and every other block kind is skipped.
    pub fn as_text(&self) -> String {
        let MessageContent::Blocks(blocks) = self else {
            let MessageContent::Text(text) = self else {
                unreachable!("MessageContent has exactly two variants")
            };
            return text.clone();
        };
        let texts: Vec<&str> = blocks
            .iter()
            .filter_map(|block| {
                if let ContentBlock::Text { text } = block {
                    Some(text.as_str())
                } else {
                    None
                }
            })
            .collect();
        texts.join("\n")
    }
}
#[derive(Debug, Clone, Deserialize, Serialize)]
#[serde(tag = "type")]
pub enum ContentBlock {
#[serde(rename = "text")]
Text { text: String },
#[serde(rename = "image")]
Image { source: ImageSource },
#[serde(rename = "tool_use")]
ToolUse {
id: String,
name: String,
input: serde_json::Value,
},
#[serde(rename = "tool_result")]
ToolResult {
tool_use_id: String,
content: String,
},
}
/// Source descriptor of an `image` content block.
///
/// `media_type` and `data` are only present for base64 sources. Other source kinds (e.g. a
/// `"url"` source) omit them, so they default to empty strings instead of failing the
/// whole request.
#[derive(Debug, Clone, Deserialize, Serialize)]
pub struct ImageSource {
    /// Source kind as sent by the client (serialised as `"type"`).
    #[serde(rename = "type")]
    pub source_type: String,
    /// MIME type of the image; empty when the source does not provide one.
    #[serde(default)]
    pub media_type: String,
    /// Encoded image payload; empty when the source does not provide one.
    #[serde(default)]
    pub data: String,
}
/// Body of a non-streaming messages response, serialised as JSON with status 200.
#[derive(Debug, Serialize)]
pub struct MessagesResponse {
/// Generated id of the form `msg_<24 alphanumerics>`.
pub id: String,
/// Serialised as `"type"`; the handler sets it to `"message"`.
#[serde(rename = "type")]
pub object_type: String,
/// Always `"assistant"` in responses built by this file.
pub role: String,
/// Generated content; the handler emits a single text block.
pub content: Vec<ResponseContentBlock>,
/// Alias of the model that served the request.
pub model: String,
/// Why generation ended, if known.
pub stop_reason: Option<String>,
/// The stop sequence that ended generation, if reported.
pub stop_sequence: Option<String>,
/// Prompt and completion token counts.
pub usage: AnthropicUsage,
}
/// Content block in a response, tagged by `"type"`. Only text blocks are produced.
#[derive(Debug, Clone, Serialize)]
#[serde(tag = "type")]
pub enum ResponseContentBlock {
/// Generated text, serialised as `{"type":"text","text":...}`.
#[serde(rename = "text")]
Text { text: String },
}
/// Token accounting reported in responses and stream events.
#[derive(Debug, Clone, Serialize)]
pub struct AnthropicUsage {
/// Tokens in the rendered prompt.
pub input_tokens: u32,
/// Tokens generated for the reply.
pub output_tokens: u32,
}
/// Top-level error envelope: `{"type": ..., "error": {...}}`.
#[derive(Debug, Serialize)]
pub struct AnthropicError {
/// Envelope type; `ApiError::into_response` sets it to `"error"`.
#[serde(rename = "type")]
pub error_type: String,
/// The specific error kind and message.
pub error: AnthropicErrorDetail,
}
/// Error kind plus human-readable message, nested inside [`AnthropicError`].
#[derive(Debug, Serialize)]
pub struct AnthropicErrorDetail {
/// Error kind, e.g. `"not_found_error"` or `"api_error"` as produced by `ApiError`.
#[serde(rename = "type")]
pub error_type: String,
/// Description of what went wrong.
pub message: String,
}
/// Server-Sent Event payloads for streaming responses, tagged by `"type"`.
///
/// Each variant's tag matches the SSE `event:` name that the streaming handler sends alongside it.
// Some variants (e.g. `Ping`, `Error`) are not constructed by the handlers in this file.
#[allow(dead_code)]
#[derive(Debug, Serialize)]
#[serde(tag = "type")]
pub enum StreamEvent {
/// First event: message metadata with empty content.
#[serde(rename = "message_start")]
MessageStart { message: MessageStartData },
/// Opens content block `index`.
#[serde(rename = "content_block_start")]
ContentBlockStart {
index: usize,
content_block: ContentBlockStartData,
},
/// Incremental content for block `index`.
#[serde(rename = "content_block_delta")]
ContentBlockDelta {
index: usize,
delta: ContentBlockDelta,
},
/// Closes content block `index`.
#[serde(rename = "content_block_stop")]
ContentBlockStop { index: usize },
/// Message-level stop information and usage.
#[serde(rename = "message_delta")]
MessageDelta {
delta: MessageDeltaData,
usage: AnthropicUsage,
},
/// Final event of the stream.
#[serde(rename = "message_stop")]
MessageStop,
/// Keep-alive payload.
#[serde(rename = "ping")]
Ping,
/// In-stream error report.
#[serde(rename = "error")]
Error { error: AnthropicErrorDetail },
}
/// Message metadata carried by the `message_start` event.
#[derive(Debug, Serialize)]
pub struct MessageStartData {
/// Generated message id.
pub id: String,
/// Serialised as `"type"`; the handler sets it to `"message"`.
#[serde(rename = "type")]
pub object_type: String,
/// Always `"assistant"` in events built by this file.
pub role: String,
/// Sent empty; content arrives via `content_block_*` events.
pub content: Vec<serde_json::Value>,
/// Alias of the model that serves the request.
pub model: String,
/// Not yet known at stream start.
pub stop_reason: Option<String>,
/// Not yet known at stream start.
pub stop_sequence: Option<String>,
/// Prompt tokens; output tokens are 0 at stream start.
pub usage: AnthropicUsage,
}
/// Initial shape of a content block announced by `content_block_start`.
#[derive(Debug, Serialize)]
pub struct ContentBlockStartData {
/// Block kind, serialised as `"type"`; the handler uses `"text"`.
#[serde(rename = "type")]
pub block_type: String,
/// Initial text; the handler sends it empty.
pub text: String,
}
/// Incremental update to an open content block, tagged by `"type"`.
#[derive(Debug, Serialize)]
#[serde(tag = "type")]
pub enum ContentBlockDelta {
/// Text to append to the block.
#[serde(rename = "text_delta")]
TextDelta { text: String },
}
/// Message-level stop information carried by the `message_delta` event.
#[derive(Debug, Serialize)]
pub struct MessageDeltaData {
/// Why generation ended, if known.
pub stop_reason: Option<String>,
/// The stop sequence that ended generation, if reported.
pub stop_sequence: Option<String>,
}
pub async fn messages_handler(
State(daemon): State<AppState>,
Json(request): Json<MessagesRequest>,
) -> Response {
if request.stream {
match handle_messages_streaming(daemon, request).await {
Ok(stream) => stream.into_response(),
Err(e) => e.into_response(),
}
} else {
match handle_messages(daemon, request).await {
Ok(response) => (StatusCode::OK, Json(response)).into_response(),
Err(e) => e.into_response(),
}
}
}
/// Serve a non-streaming messages request.
///
/// Steps:
/// 1. Resolve the requested model and prepend its default system prompt when applicable.
/// 2. Render the chat prompt and run generation to completion.
/// 3. Wrap the text in a [`MessagesResponse`].
///
/// The daemon's `active_requests` counter is held incremented for the duration of
/// generation, including when this future is cancelled.
///
/// # Errors
/// - [`ApiError::model_not_found`] when the model lookup fails.
/// - [`ApiError::generation_failed`] when text generation fails.
async fn handle_messages(
    daemon: Arc<Daemon>,
    request: MessagesRequest,
) -> Result<MessagesResponse, ApiError> {
    /// Runs the wrapped closure when dropped, so cleanup also happens on early return
    /// or if the enclosing future is dropped mid-`.await`.
    struct OnDrop<F: FnMut()>(F);
    impl<F: FnMut()> Drop for OnDrop<F> {
        fn drop(&mut self) {
            (self.0)();
        }
    }

    let mut messages = convert_messages(&request)?;
    let stop = request.stop_sequences.unwrap_or_default();
    let loaded = daemon
        .models
        .get(request.model.as_deref())
        .await
        .map_err(|e| ApiError::model_not_found(e.to_string()))?;
    let _guard = RequestGuard::new(loaded.clone());
    daemon.active_requests.fetch_add(1, Ordering::Relaxed);
    // The decrement lives in a drop guard: axum drops this future when the client
    // disconnects, and a plain `fetch_sub` after the `.await` would then never run,
    // leaking the count.
    let active_slot = OnDrop(|| {
        daemon.active_requests.fetch_sub(1, Ordering::Relaxed);
    });
    messages = apply_default_system_prompt(messages, loaded.config.system_prompt.as_deref());
    let model_alias = loaded.alias.clone();
    let prompt = daemon.build_chat_prompt(&loaded.model, &messages);
    let all_stops = resolve_chat_stop_sequences(&loaded, stop);
    let mut sampler_params = crate::SamplerParams::default();
    sampler_params.temperature = request
        .temperature
        .or(loaded.config.temperature)
        .unwrap_or(1.0);
    sampler_params.top_p = request
        .top_p
        .or(loaded.config.top_p)
        .unwrap_or(sampler_params.top_p);
    sampler_params.top_k = request
        .top_k
        .or(loaded.config.top_k)
        .unwrap_or(sampler_params.top_k);
    let result = daemon
        .generate_text(
            &loaded,
            &prompt,
            request.max_tokens,
            sampler_params,
            &all_stops,
            None,
        )
        .await;
    // Release the slot as soon as generation finishes, before building the response.
    drop(active_slot);
    let (text, prompt_tokens, completion_tokens) =
        result.map_err(|e| ApiError::generation_failed(e.to_string()))?;
    // The generator does not say why it stopped. A completion that used the whole budget
    // is reported as truncated; anything shorter is treated as a natural end of turn.
    let stop_reason = if completion_tokens >= request.max_tokens {
        "max_tokens"
    } else {
        "end_turn"
    };
    Ok(MessagesResponse {
        id: generate_message_id(),
        object_type: "message".to_string(),
        role: "assistant".to_string(),
        content: vec![ResponseContentBlock::Text { text }],
        model: model_alias,
        stop_reason: Some(stop_reason.to_string()),
        stop_sequence: None,
        usage: AnthropicUsage {
            input_tokens: prompt_tokens,
            output_tokens: completion_tokens,
        },
    })
}
async fn handle_messages_streaming(
daemon: Arc<Daemon>,
request: MessagesRequest,
) -> Result<Sse<impl Stream<Item = Result<Event, std::convert::Infallible>>>, ApiError> {
let mut messages = convert_messages(&request)?;
let stop = request.stop_sequences.unwrap_or_default();
let loaded = daemon
.models
.get(request.model.as_deref())
.await
.map_err(|e| ApiError::model_not_found(e.to_string()))?;
messages = apply_default_system_prompt(messages, loaded.config.system_prompt.as_deref());
let model_alias = loaded.alias.clone();
let message_id = generate_message_id();
let prompt = daemon.build_chat_prompt(&loaded.model, &messages);
let all_stops = resolve_chat_stop_sequences(&loaded, stop);
let mut sampler_params = crate::SamplerParams::default();
sampler_params.temperature = request
.temperature
.or(loaded.config.temperature)
.unwrap_or(1.0);
sampler_params.top_p = request
.top_p
.or(loaded.config.top_p)
.unwrap_or(sampler_params.top_p);
sampler_params.top_k = request
.top_k
.or(loaded.config.top_k)
.unwrap_or(sampler_params.top_k);
let (rx, prompt_tokens, _request_id) = daemon
.generate_text_streaming(
loaded,
prompt,
request.max_tokens,
sampler_params,
all_stops,
)
.await
.map_err(|e| ApiError::generation_failed(e.to_string()))?;
let message_start = StreamEvent::MessageStart {
message: MessageStartData {
id: message_id.clone(),
object_type: "message".to_string(),
role: "assistant".to_string(),
content: vec![],
model: model_alias.clone(),
stop_reason: None,
stop_sequence: None,
usage: AnthropicUsage {
input_tokens: prompt_tokens,
output_tokens: 0,
},
},
};
let content_block_start = StreamEvent::ContentBlockStart {
index: 0,
content_block: ContentBlockStartData {
block_type: "text".to_string(),
text: String::new(),
},
};
let initial_events = vec![
Ok(Event::default().event("message_start").data(
serde_json::to_string(&message_start)
.unwrap_or_else(|e| format!(r#"{{"error":"{}"}}"#, e)),
)),
Ok(Event::default().event("content_block_start").data(
serde_json::to_string(&content_block_start)
.unwrap_or_else(|e| format!(r#"{{"error":"{}"}}"#, e)),
)),
];
let initial_stream = stream::iter(initial_events);
let _model_alias_clone = model_alias.clone();
let token_stream = ReceiverStream::new(rx).map(move |chunk| {
let delta = StreamEvent::ContentBlockDelta {
index: 0,
delta: ContentBlockDelta::TextDelta { text: chunk.delta },
};
Ok(Event::default().event("content_block_delta").data(
serde_json::to_string(&delta).unwrap_or_else(|e| format!(r#"{{"error":"{}"}}"#, e)),
))
});
let final_events = vec![
Ok(Event::default().event("content_block_stop").data(
serde_json::to_string(&StreamEvent::ContentBlockStop { index: 0 })
.unwrap_or_else(|e| format!(r#"{{"error":"{}"}}"#, e)),
)),
Ok(Event::default().event("message_stop").data(
serde_json::to_string(&StreamEvent::MessageStop)
.unwrap_or_else(|e| format!(r#"{{"error":"{}"}}"#, e)),
)),
];
let final_stream = stream::iter(final_events);
let combined = initial_stream.chain(token_stream).chain(final_stream);
Ok(Sse::new(combined).keep_alive(
axum::response::sse::KeepAlive::new()
.interval(std::time::Duration::from_secs(15))
.text("ping"),
))
}
fn convert_messages(request: &MessagesRequest) -> Result<Vec<ChatMessage>, ApiError> {
let mut messages = Vec::new();
if let Some(ref system) = request.system {
messages.push(ChatMessage {
role: "system".to_string(),
content: system.clone().into(),
name: None,
tool_calls: None,
tool_call_id: None,
});
}
for msg in &request.messages {
messages.push(ChatMessage {
role: msg.role.clone(),
content: msg.content.as_text().into(),
name: None,
tool_calls: None,
tool_call_id: None,
});
}
Ok(messages)
}
/// Build a random message id: `msg_` followed by 24 characters drawn uniformly from
/// `[0-9a-z]` using the thread-local RNG.
fn generate_message_id() -> String {
    use rand::Rng;

    const ALPHABET: &[u8; 36] = b"0123456789abcdefghijklmnopqrstuvwxyz";
    const SUFFIX_LEN: usize = 24;

    let mut rng = rand::thread_rng();
    let mut id = String::with_capacity("msg_".len() + SUFFIX_LEN);
    id.push_str("msg_");
    for _ in 0..SUFFIX_LEN {
        let pick: u8 = rng.gen_range(0..36);
        id.push(char::from(ALPHABET[usize::from(pick)]));
    }
    id
}
/// Handler failure, rendered as an Anthropic-style JSON error response.
pub struct ApiError {
/// HTTP status to respond with.
status: StatusCode,
/// Error kind placed in `error.type` of the JSON body.
error_type: String,
/// Human-readable description placed in `error.message`.
message: String,
}
impl ApiError {
fn model_not_found(message: String) -> Self {
Self {
status: StatusCode::NOT_FOUND,
error_type: "not_found_error".to_string(),
message,
}
}
fn generation_failed(message: String) -> Self {
Self {
status: StatusCode::INTERNAL_SERVER_ERROR,
error_type: "api_error".to_string(),
message,
}
}
}
impl IntoResponse for ApiError {
fn into_response(self) -> Response {
let body = AnthropicError {
error_type: "error".to_string(),
error: AnthropicErrorDetail {
error_type: self.error_type,
message: self.message,
},
};
(self.status, Json(body)).into_response()
}
}