#![allow(unreachable_pub)]
use axum::{extract::State, http::StatusCode, Json};
use serde::{Deserialize, Serialize};
#[cfg(feature = "gpu")]
use super::ContinuousBatchRequest;
use super::{AppState, ChatMessage, ErrorResponse, Usage};
use crate::generate::{GenerationConfig, SamplingStrategy};
use crate::registry::ModelInfo;
/// Handler error type: an HTTP status paired with a JSON error body.
type RErr = (StatusCode, Json<ErrorResponse>);
/// Records a failure in the server metrics and builds an error response
/// with the given status and display-formatted message.
fn rerr(state: &AppState, status: StatusCode, msg: impl std::fmt::Display) -> RErr {
    state.metrics.record_failure();
    let body = ErrorResponse {
        error: msg.to_string(),
    };
    (status, Json(body))
}
/// Seconds since the Unix epoch; returns 0 if the system clock reads
/// earlier than 1970 (i.e. `duration_since` fails).
fn epoch_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed.as_secs(),
        Err(_) => 0,
    }
}
/// Milliseconds since the Unix epoch; returns 0 if the system clock
/// reads earlier than 1970.
fn epoch_millis() -> u128 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_millis())
}
/// Sizing policy for fitting a chat history into a model's context window.
#[derive(Debug, Clone)]
pub struct ContextWindowConfig {
    /// Total context window size, in estimated tokens.
    pub max_tokens: usize,
    /// Tokens held back from the window for the generated reply.
    pub reserved_output_tokens: usize,
    /// When true, system messages are given priority when truncating.
    pub preserve_system: bool,
}
impl Default for ContextWindowConfig {
fn default() -> Self {
Self {
max_tokens: 4096,
reserved_output_tokens: 256,
preserve_system: true,
}
}
}
impl ContextWindowConfig {
    /// Creates a config for a `max_tokens` window, keeping the default
    /// output reservation and system-message preservation.
    #[must_use]
    pub fn new(max_tokens: usize) -> Self {
        Self {
            max_tokens,
            ..Default::default()
        }
    }

    /// Builder-style override of the token count reserved for generation.
    #[must_use]
    pub fn with_reserved_output(mut self, tokens: usize) -> Self {
        self.reserved_output_tokens = tokens;
        self
    }

    /// Tokens left for the prompt after reserving output tokens.
    /// Saturates at zero when the reservation exceeds the window.
    #[must_use]
    pub fn available_tokens(&self) -> usize {
        self.max_tokens.saturating_sub(self.reserved_output_tokens)
    }
}
/// Applies a [`ContextWindowConfig`] to chat histories: estimates token
/// usage and truncates oldest non-system messages first when over budget.
pub struct ContextWindowManager {
    // Sizing policy used by all estimation/truncation methods.
    config: ContextWindowConfig,
}
impl ContextWindowManager {
    /// Creates a manager for the given window configuration.
    #[must_use]
    pub fn new(config: ContextWindowConfig) -> Self {
        Self { config }
    }

    /// Creates a manager using `ContextWindowConfig::default()`.
    #[must_use]
    pub fn default_manager() -> Self {
        Self::new(ContextWindowConfig::default())
    }

    /// Rough token estimate for one message: ~1 token per 4 bytes of
    /// content (rounded up) plus a fixed per-message overhead.
    fn estimate_tokens(text: &str) -> usize {
        // Allowance for role markers / template formatting around a message.
        const ROLE_OVERHEAD: usize = 10;
        text.len().div_ceil(4) + ROLE_OVERHEAD
    }

    /// Fits `messages` into the available context window.
    ///
    /// If the estimated total already fits, the messages are returned
    /// unchanged. Otherwise system messages are admitted first (when
    /// `preserve_system` is set) and the remaining budget is filled with
    /// the most recent other messages, preserving their original order.
    ///
    /// Returns the resulting messages and `true` iff truncation occurred.
    pub fn truncate_messages(&self, messages: &[ChatMessage]) -> (Vec<ChatMessage>, bool) {
        let available = self.config.available_tokens();
        // Single source of truth for the total estimate (was duplicated).
        if self.estimate_total_tokens(messages) <= available {
            return (messages.to_vec(), false);
        }
        let mut result = Vec::new();
        let mut used_tokens = 0;
        let (system_msgs, other_msgs): (Vec<_>, Vec<_>) = messages
            .iter()
            .partition(|m| m.role == "system" && self.config.preserve_system);
        // System messages get first claim on the budget.
        for msg in &system_msgs {
            let tokens = Self::estimate_tokens(&msg.content);
            if used_tokens + tokens <= available {
                result.push((*msg).clone());
                used_tokens += tokens;
            }
        }
        // Walk the rest newest-first; stop at the first message that no
        // longer fits so only the most recent context survives.
        let mut temp_msgs: Vec<ChatMessage> = Vec::new();
        for msg in other_msgs.iter().rev() {
            let tokens = Self::estimate_tokens(&msg.content);
            if used_tokens + tokens <= available {
                temp_msgs.push((*msg).clone());
                used_tokens += tokens;
            } else {
                break;
            }
        }
        // Restore chronological order before appending after system msgs.
        temp_msgs.reverse();
        result.extend(temp_msgs);
        (result, true)
    }

    /// Returns `true` when the estimated total exceeds the budget left
    /// after reserving output tokens.
    #[must_use]
    pub fn needs_truncation(&self, messages: &[ChatMessage]) -> bool {
        self.estimate_total_tokens(messages) > self.config.available_tokens()
    }

    /// Estimated token count for the whole conversation.
    #[must_use]
    pub fn estimate_total_tokens(&self, messages: &[ChatMessage]) -> usize {
        messages
            .iter()
            .map(|m| Self::estimate_tokens(&m.content))
            .sum()
    }
}
/// Renders chat messages into a single prompt string.
///
/// Tries the model-aware chat template first; if templating fails, falls
/// back to concatenating each message's content followed by a newline
/// (trailing newline included).
pub fn format_chat_messages(messages: &[ChatMessage], model_name: Option<&str>) -> String {
    use crate::chat_template::{self, ChatMessage as TemplateMessage};
    let template_messages: Vec<TemplateMessage> = messages
        .iter()
        .map(|m| TemplateMessage::new(&m.role, &m.content))
        .collect();
    if let Ok(rendered) = chat_template::format_messages(&template_messages, model_name) {
        return rendered;
    }
    // Template failed — degrade to a plain newline-terminated dump.
    messages.iter().fold(String::new(), |mut prompt, msg| {
        prompt.push_str(&msg.content);
        prompt.push('\n');
        prompt
    })
}
/// Strips model chatter from generated text: cuts everything at the
/// earliest known stop sequence, then trims surrounding whitespace.
///
/// Avoids the needless owned-`String` scan of the original by computing
/// the minimum match offset over a borrowed slice.
pub fn clean_chat_output(text: &str) -> String {
    // End-of-turn / special tokens after which output is discarded.
    const STOP_SEQUENCES: &[&str] = &[
        "<|im_end|>",
        "<|endoftext|>",
        "<|end|>",
        "</s>",
        "\nHuman:",
        "\nUser:",
        "\n\nHuman:",
        "\n\nUser:",
        "<|im_start|>",
    ];
    // `str::find` returns the byte offset of the match start, which is
    // always a valid char boundary, so slicing below cannot panic.
    let cut = STOP_SEQUENCES
        .iter()
        .filter_map(|stop| text.find(stop))
        .min()
        .unwrap_or(text.len());
    text[..cut].trim().to_string()
}
/// Request body for the embeddings endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingRequest {
    /// Text to embed.
    pub input: String,
    /// Optional model override; server default used when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
}
/// Response body for the embeddings endpoint (OpenAI-compatible shape).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingResponse {
    /// Object type tag (e.g. "list" in OpenAI-style APIs — confirm at call site).
    pub object: String,
    /// One entry per embedded input.
    pub data: Vec<EmbeddingData>,
    /// Model that produced the embeddings.
    pub model: String,
    /// Token accounting for the request.
    pub usage: EmbeddingUsage,
}
/// A single embedding vector within an [`EmbeddingResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingData {
    /// Object type tag for this entry.
    pub object: String,
    /// Position of the corresponding input.
    pub index: usize,
    /// The embedding vector itself.
    pub embedding: Vec<f32>,
}
/// Token accounting for an embeddings request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingUsage {
    /// Tokens consumed by the input text.
    pub prompt_tokens: usize,
    /// Total tokens billed for the request.
    pub total_tokens: usize,
}
/// Metadata describing one model, as returned by the model-info endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelMetadataResponse {
    /// Identifier used to address the model in API requests.
    pub id: String,
    /// Human-readable model name.
    pub name: String,
    /// Model storage/file format identifier.
    pub format: String,
    /// On-disk size of the model artifact.
    pub size_bytes: u64,
    /// Quantization scheme, if the model is quantized.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub quantization: Option<String>,
    /// Context window length (tokens — presumably; verify against registry).
    pub context_length: usize,
    /// Provenance information, when available.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub lineage: Option<ModelLineage>,
    /// Whether the model is currently loaded into memory.
    pub loaded: bool,
}
/// Provenance record for a model artifact.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelLineage {
    /// Source URI the model was obtained from.
    pub uri: String,
    /// Version string of the artifact.
    pub version: String,
    /// Build/conversion recipe, when recorded.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub recipe: Option<String>,
    /// Parent model this one was derived from, when recorded.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub parent: Option<String>,
    /// Content hash of the artifact for integrity checks.
    pub content_hash: String,
}
/// Request body for the model-reload endpoint. Both fields optional;
/// their interaction is decided by the handler (see reload handler).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReloadRequest {
    /// Name/id of the model to (re)load.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model: Option<String>,
    /// Filesystem path to load the model from.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub path: Option<String>,
}
/// Response body for the model-reload endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ReloadResponse {
    /// Whether the reload completed successfully.
    pub success: bool,
    /// Human-readable status or error description.
    pub message: String,
    /// Wall-clock time the reload took, in milliseconds.
    pub reload_time_ms: u64,
}
/// OpenAI-style (legacy `/v1/completions`) request body.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionRequest {
    /// Model to run the completion with.
    pub model: String,
    /// Raw prompt text to complete.
    pub prompt: String,
    /// Cap on generated tokens; server default when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub max_tokens: Option<usize>,
    /// Sampling temperature.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub temperature: Option<f64>,
    /// Nucleus (top-p) sampling parameter.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_p: Option<f64>,
    /// Sequences at which generation stops.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop: Option<Vec<String>>,
}
/// OpenAI-style (legacy `/v1/completions`) response body.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionResponse {
    /// Unique id for this completion.
    pub id: String,
    /// Object type tag.
    pub object: String,
    /// Creation time, seconds since the Unix epoch.
    pub created: u64,
    /// Model that produced the completion.
    pub model: String,
    /// Generated completions (typically one).
    pub choices: Vec<CompletionChoice>,
    /// Token accounting for the request.
    pub usage: Usage,
}
/// One generated completion within a [`CompletionResponse`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompletionChoice {
    /// The generated text.
    pub text: String,
    /// Index of this choice in the response.
    pub index: usize,
    /// Per-token log probabilities, when requested (free-form JSON).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub logprobs: Option<serde_json::Value>,
    /// Why generation stopped (e.g. "stop" / "length" — confirm against handler).
    pub finish_reason: String,
}
// Handler bodies live in sibling files and are spliced into this module
// via `include!` so they share its imports and private helpers above.
include!("realize_handlers_embed_completion.rs");
include!("gpu_completions_handler.rs");
include!("realize_handlers_model_lineage.rs");