/// Mean-pools per-token hidden states into one `hidden_dim`-wide vector.
///
/// `hidden` is read as a flat buffer in which the row for token position `t`
/// occupies `t * hidden_dim .. (t + 1) * hidden_dim`. Positions whose token id
/// the tokenizer reports as special are left out of the average; if that
/// excludes every position, all positions are averaged instead. An empty
/// `token_ids` yields an all-zero vector.
///
/// # Panics
///
/// Panics if the row of a pooled position lies outside the hidden-state
/// buffer (plain slice indexing is used).
fn mean_pool_hidden_states(
    hidden: &crate::tensor::Tensor<f32>,
    token_ids: &[u32],
    hidden_dim: usize,
    tokenizer: &crate::tokenizer::BPETokenizer,
) -> Vec<f32> {
    let values = hidden.data();
    let positions = token_ids.len();
    let mut pooled = vec![0.0f32; hidden_dim];

    // Adds the hidden-state row at `pos` into the running sum.
    let mut add_row = |pos: usize| {
        let start = pos * hidden_dim;
        let row = &values[start..start + hidden_dim];
        for (acc, &x) in pooled.iter_mut().zip(row.iter()) {
            *acc += x;
        }
    };

    // First pass: only non-special tokens contribute.
    let mut used = 0usize;
    for (pos, &id) in token_ids.iter().enumerate() {
        if !tokenizer.is_special_token(id) {
            add_row(pos);
            used += 1;
        }
    }

    // Fallback: every token was special, so average over all of them.
    if used == 0 {
        for pos in 0..positions {
            add_row(pos);
        }
        used = positions;
    }

    // Nothing was accumulated (empty input): return the zero vector as is.
    if used == 0 {
        return pooled;
    }

    let scale = 1.0 / used as f32;
    for value in &mut pooled {
        *value *= scale;
    }
    pooled
}
pub async fn realize_embed_handler(
State(state): State<AppState>,
Json(request): Json<EmbeddingRequest>,
) -> Result<Json<EmbeddingResponse>, (StatusCode, Json<ErrorResponse>)> {
let model_id = request.model.as_deref();
let (model, tokenizer) = state.get_model(model_id).map_err(|e| {
(
StatusCode::NOT_FOUND,
Json(ErrorResponse {
error: e.to_string(),
}),
)
})?;
let hidden_dim = model.config().hidden_dim;
let mut data = Vec::with_capacity(request.input.len());
let mut prompt_tokens = 0usize;
for (index, text) in request.input.iter().enumerate() {
let token_ids = tokenizer.encode(text);
if token_ids.is_empty() {
return Err((
StatusCode::BAD_REQUEST,
Json(ErrorResponse {
error: format!("Input at index {index} cannot be empty"),
}),
));
}
prompt_tokens += token_ids.len();
let usize_ids: Vec<usize> = token_ids.iter().map(|&t| t as usize).collect();
let hidden = model.forward_hidden(&usize_ids).map_err(|e| {
(
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: format!("Embedding forward pass failed: {e}"),
}),
)
})?;
let mut embedding = mean_pool_hidden_states(&hidden, &token_ids, hidden_dim, &tokenizer);
let norm: f32 = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
if norm > 0.0 {
for v in &mut embedding {
*v /= norm;
}
}
data.push(EmbeddingData {
object: "embedding".to_string(),
index,
embedding,
});
}
Ok(Json(EmbeddingResponse {
object: "list".to_string(),
data,
model: request.model.unwrap_or_else(|| "default".to_string()),
usage: EmbeddingUsage {
prompt_tokens,
total_tokens: prompt_tokens,
},
}))
}
/// Returns metadata for the served model.
///
/// With a registry configured, the first entry of `registry.list()` is
/// reported; without one, a fixed "default" single-model entry is used.
///
/// `size_bytes`, `quantization`, `context_length`, the lineage `version` and
/// the lineage `content_hash` are hard-coded in this handler rather than read
/// from the model.
///
/// # Errors
///
/// `404 Not Found` when a registry is configured but lists no models.
pub async fn realize_model_handler(
    State(state): State<AppState>,
) -> Result<Json<ModelMetadataResponse>, (StatusCode, Json<ErrorResponse>)> {
    let model_info = if let Some(registry) = &state.registry {
        let models = registry.list();
        models.first().cloned()
    } else {
        Some(ModelInfo {
            id: "default".to_string(),
            name: "Default Model".to_string(),
            description: "Single model deployment".to_string(),
            format: "gguf".to_string(),
            loaded: true,
        })
    };
    let info = model_info.ok_or_else(|| {
        (
            StatusCode::NOT_FOUND,
            Json(ErrorResponse {
                error: "No model loaded".to_string(),
            }),
        )
    })?;
    Ok(Json(ModelMetadataResponse {
        id: info.id.clone(),
        name: info.name,
        format: info.format,
        size_bytes: 0,
        quantization: Some("Q4_K_M".to_string()),
        context_length: 4096,
        lineage: Some(ModelLineage {
            uri: format!("pacha://{}:latest", info.id),
            version: "1.0.0".to_string(),
            recipe: None,
            parent: None,
            // All-zero digest: one `blake3:` prefix followed by 64 hex digits
            // (BLAKE3's default 32-byte output). Repeating the whole
            // "blake3:0" token would embed the prefix 16 times.
            content_hash: format!("blake3:{}", "0".repeat(64)),
        }),
        loaded: info.loaded,
    }))
}
pub async fn realize_reload_handler(
State(state): State<AppState>,
Json(request): Json<ReloadRequest>,
) -> Result<Json<ReloadResponse>, (StatusCode, Json<ErrorResponse>)> {
let start = std::time::Instant::now();
let model_id = request.model.unwrap_or_else(|| "default".to_string());
let registry = state.registry.as_ref().ok_or_else(|| {
(
StatusCode::NOT_IMPLEMENTED,
Json(ErrorResponse {
error: "Hot-reload requires registry mode. Start server with --registry flag."
.to_string(),
}),
)
})?;
let model_path = request.path.ok_or_else(|| {
(
StatusCode::BAD_REQUEST,
Json(ErrorResponse {
error: "Model path is required for reload. Provide 'path' field with path to model file.".to_string(),
}),
)
})?;
if !registry.contains(&model_id) {
return Err((
StatusCode::NOT_FOUND,
Json(ErrorResponse {
error: format!(
"Model '{}' not found in registry. Use POST /realize/models to register first.",
model_id
),
}),
));
}
if !std::path::Path::new(&model_path).exists() {
return Err((
StatusCode::BAD_REQUEST,
Json(ErrorResponse {
error: format!("Model file not found: {}", model_path),
}),
));
}
Ok(Json(ReloadResponse {
success: true,
message: format!(
"Model '{}' reload validated from '{}'. Atomic swap ready.",
model_id, model_path
),
reload_time_ms: start.elapsed().as_millis() as u64,
}))
}
/// Assembles a single-choice `CompletionResponse`.
///
/// The response id is `id_prefix`, a dash, and the current `epoch_millis()`
/// value. `finish_reason` is `"length"` when `completion_tokens` has reached
/// `max_tokens` and `"stop"` otherwise. `total_tokens` is the sum of prompt and
/// completion tokens.
fn completion_resp(
    id_prefix: &str,
    model: String,
    text: String,
    prompt_tokens: usize,
    completion_tokens: usize,
    max_tokens: usize,
) -> CompletionResponse {
    let hit_limit = completion_tokens >= max_tokens;
    let reason = if hit_limit { "length" } else { "stop" };

    let id = format!("{id_prefix}-{}", epoch_millis());
    let created = epoch_secs();

    let choice = CompletionChoice {
        text,
        index: 0,
        logprobs: None,
        finish_reason: reason.to_string(),
    };
    let usage = Usage {
        prompt_tokens,
        completion_tokens,
        total_tokens: prompt_tokens + completion_tokens,
    };

    CompletionResponse {
        id,
        object: "text_completion".to_string(),
        created,
        model,
        choices: vec![choice],
        usage,
    }
}
/// Tries to serve a completion through the continuous-batching channel.
///
/// Returns `Ok(None)` so the caller can fall back to another path when
/// batching is disabled, no request channel is available, the request cannot
/// be sent, or the response channel closes without a reply. On success the
/// generated tokens are decoded, the success metric is recorded, and a
/// `cmpl-batch` response is returned whose model name embeds the batch size.
///
/// # Errors
///
/// `500 Internal Server Error` (via `rerr`) when decoding the generated tokens
/// fails.
#[cfg(feature = "gpu")]
async fn try_batch_completion(
    state: &AppState,
    tokenizer: &crate::tokenizer::BPETokenizer,
    prompt_ids: &[u32],
    prompt_tokens: usize,
    max_tokens: usize,
    temperature: f32,
    start: std::time::Instant,
) -> Result<Option<CompletionResponse>, RErr> {
    if !state.batch_enabled() {
        return Ok(None);
    }
    let Some(batch_tx) = state.batch_request_tx() else {
        return Ok(None);
    };

    // Zero temperature means greedy decoding, i.e. only the top candidate.
    let greedy = temperature == 0.0;
    let (reply_tx, reply_rx) = tokio::sync::oneshot::channel();
    let batch_request = ContinuousBatchRequest {
        prompt_tokens: prompt_ids.to_vec(),
        max_tokens,
        temperature,
        top_k: if greedy { 1 } else { 40 },
        response_tx: reply_tx,
        submitted_at: std::time::Instant::now(),
    };

    // A failed send or a dropped reply is treated as "batching unavailable".
    if batch_tx.send(batch_request).await.is_err() {
        return Ok(None);
    }
    let Ok(batch_response) = reply_rx.await else {
        return Ok(None);
    };

    let generated = batch_response.generated_tokens().to_vec();
    let completion_tokens = generated.len();
    let text = tokenizer
        .decode(&generated)
        .map_err(|e| rerr(state, StatusCode::INTERNAL_SERVER_ERROR, e))?;

    state
        .metrics
        .record_success(completion_tokens, start.elapsed());

    let model_label = format!("batch-q4k-{}", batch_response.batch_size);
    Ok(Some(completion_resp(
        "cmpl-batch",
        model_label,
        text,
        prompt_tokens,
        completion_tokens,
        max_tokens,
    )))
}
/// Cuts `text` at the earliest occurrence of any stop string.
///
/// The cut is made at the smallest byte offset at which any non-empty entry of
/// `stops` occurs, regardless of the order in which the stops are listed; the
/// matched stop string itself is not kept. Empty stop strings are ignored.
/// `text` is returned unchanged when `stops` is `None`, empty, or has no
/// match.
pub(crate) fn truncate_at_stop(mut text: String, stops: Option<&[String]>) -> String {
    let mut earliest: Option<usize> = None;

    for stop in stops.into_iter().flatten() {
        if stop.is_empty() {
            continue;
        }
        if let Some(at) = text.find(stop.as_str()) {
            earliest = Some(earliest.map_or(at, |best| best.min(at)));
        }
    }

    // `find` yields an offset at a char boundary, so truncating there is safe.
    if let Some(at) = earliest {
        text.truncate(at);
    }
    text
}
/// Tries to serve a completion from the cached model.
///
/// Returns `Ok(None)` when no cached model is configured so the caller can try
/// another path. Otherwise the prompt is tokenized and `try_batch_completion`
/// is attempted first; if that declines, generation runs directly on the
/// cached model (through `generate_with_cache_adaptive` when dispatch metrics
/// are available, `generate_with_cache` otherwise).
///
/// On both paths the returned text is cut at the earliest of the request's
/// stop strings. `completion_tokens` reflects the number of generated tokens
/// before that cut.
///
/// # Errors
///
/// * `500 Internal Server Error` when no tokenizer is available, generation
///   fails, or decoding fails.
/// * `400 Bad Request` when the prompt tokenizes to zero tokens.
#[cfg(feature = "gpu")]
async fn try_cached_completions(
    state: &AppState,
    request: &CompletionRequest,
    max_tokens: usize,
    temperature: f32,
    start: std::time::Instant,
) -> Result<Option<CompletionResponse>, RErr> {
    use crate::gguf::QuantizedGenerateConfig;
    let cached_model = match state.cached_model() {
        Some(m) => m,
        None => return Ok(None),
    };
    let tokenizer = state.tokenizer.clone().ok_or_else(|| {
        rerr(
            state,
            StatusCode::INTERNAL_SERVER_ERROR,
            "No tokenizer available",
        )
    })?;
    let prompt_ids = tokenizer.encode(&request.prompt);
    if prompt_ids.is_empty() {
        return Err(rerr(
            state,
            StatusCode::BAD_REQUEST,
            "Prompt cannot be empty",
        ));
    }
    let prompt_tokens = prompt_ids.len();
    if let Some(mut batched) = try_batch_completion(
        state,
        &tokenizer,
        &prompt_ids,
        prompt_tokens,
        max_tokens,
        temperature,
        start,
    )
    .await?
    {
        // The batch path is not given the request's stop strings, so apply
        // them here; otherwise `stop` would be honoured only when the request
        // happens to fall through to the non-batched path below.
        let stops = request.stop.as_deref();
        for choice in &mut batched.choices {
            choice.text = truncate_at_stop(std::mem::take(&mut choice.text), stops);
        }
        return Ok(Some(batched));
    }
    let q_config = QuantizedGenerateConfig {
        max_tokens,
        temperature,
        // Zero temperature means greedy decoding, i.e. only the top candidate.
        top_k: if temperature == 0.0 { 1 } else { 40 },
        stop_tokens: Vec::new(),
        trace: state.is_trace_enabled(),
        ..Default::default()
    };
    let generated = if let Some(metrics) = state.dispatch_metrics() {
        cached_model
            .generate_with_cache_adaptive(&prompt_ids, &q_config, metrics)
            .map_err(|e| rerr(state, StatusCode::INTERNAL_SERVER_ERROR, e))?
    } else {
        cached_model
            .generate_with_cache(&prompt_ids, &q_config)
            .map_err(|e| rerr(state, StatusCode::INTERNAL_SERVER_ERROR, e))?
    };
    // Everything past the first `prompt_tokens` ids is treated as the completion.
    let token_ids: Vec<u32> = generated.iter().skip(prompt_tokens).copied().collect();
    let completion_tokens = token_ids.len();
    let text = tokenizer
        .decode(&token_ids)
        .map_err(|e| rerr(state, StatusCode::INTERNAL_SERVER_ERROR, e))?;
    let text = truncate_at_stop(text, request.stop.as_deref());
    state
        .metrics
        .record_success(completion_tokens, start.elapsed());
    Ok(Some(completion_resp(
        "cmpl-cached",
        "cached-q4k".to_string(),
        text,
        prompt_tokens,
        completion_tokens,
        max_tokens,
    )))
}
/// Tries to serve a completion from the quantized model.
///
/// Returns `Ok(None)` when no quantized model is configured so the caller can
/// try another path. Otherwise the prompt is tokenized, generation runs with
/// `generate_with_cache`, the ids past the prompt length are decoded, and the
/// text is cut at the earliest of the request's stop strings.
/// `completion_tokens` reflects the number of generated tokens before that
/// cut.
///
/// # Errors
///
/// * `500 Internal Server Error` when no tokenizer is available, generation
///   fails, or decoding fails.
/// * `400 Bad Request` when the prompt tokenizes to zero tokens.
fn try_quantized_completions(
    state: &AppState,
    request: &CompletionRequest,
    max_tokens: usize,
    temperature: f32,
    start: std::time::Instant,
) -> Result<Option<CompletionResponse>, RErr> {
    use crate::gguf::QuantizedGenerateConfig;

    let Some(quantized_model) = state.quantized_model() else {
        return Ok(None);
    };
    let Some(tokenizer) = state.tokenizer.clone() else {
        return Err(rerr(
            state,
            StatusCode::INTERNAL_SERVER_ERROR,
            "No tokenizer available",
        ));
    };

    let prompt_ids = tokenizer.encode(&request.prompt);
    let prompt_tokens = prompt_ids.len();
    if prompt_tokens == 0 {
        return Err(rerr(
            state,
            StatusCode::BAD_REQUEST,
            "Prompt cannot be empty",
        ));
    }

    // Zero temperature means greedy decoding, i.e. only the top candidate.
    let greedy = temperature == 0.0;
    let q_config = QuantizedGenerateConfig {
        max_tokens,
        temperature,
        top_k: if greedy { 1 } else { 40 },
        stop_tokens: Vec::new(),
        trace: state.is_trace_enabled(),
        ..Default::default()
    };

    let generated = quantized_model
        .generate_with_cache(&prompt_ids, &q_config)
        .map_err(|e| rerr(state, StatusCode::INTERNAL_SERVER_ERROR, e))?;

    // Everything past the first `prompt_tokens` ids is treated as the completion.
    let new_ids: Vec<u32> = generated.iter().copied().skip(prompt_tokens).collect();
    let completion_tokens = new_ids.len();

    let decoded = tokenizer
        .decode(&new_ids)
        .map_err(|e| rerr(state, StatusCode::INTERNAL_SERVER_ERROR, e))?;
    let text = truncate_at_stop(decoded, request.stop.as_deref());

    state
        .metrics
        .record_success(completion_tokens, start.elapsed());

    Ok(Some(completion_resp(
        "cmpl-q4k",
        request.model.clone(),
        text,
        prompt_tokens,
        completion_tokens,
        max_tokens,
    )))
}
/// Unit tests for `truncate_at_stop`.
#[cfg(test)]
mod pmat754_stop_truncation_tests {
    use super::truncate_at_stop;

    /// Calls `truncate_at_stop` with borrowed inputs, building the owned stop
    /// list it expects. `None` is passed through as `None`.
    fn cut(text: &str, stops: Option<&[&str]>) -> String {
        let owned: Option<Vec<String>> =
            stops.map(|list| list.iter().map(|s| (*s).to_owned()).collect());
        truncate_at_stop(text.to_owned(), owned.as_deref())
    }

    #[test]
    fn no_stops_returns_unchanged() {
        assert_eq!(cut("hello world", None), "hello world");
        assert_eq!(cut("hello", Some(&[])), "hello");
    }

    #[test]
    fn truncates_at_earliest_position_not_first_listed() {
        // "hello" matches at offset 0 even though "world" is listed first.
        assert_eq!(cut("hello world", Some(&["world", "hello"])), "");
        assert_eq!(cut("keep thisENDdrop that", Some(&["END"])), "keep this");
    }

    #[test]
    fn stop_absent_keeps_text() {
        assert_eq!(cut("hello", Some(&["XYZ"])), "hello");
    }

    #[test]
    fn empty_stop_strings_ignored() {
        assert_eq!(cut("a stop b", Some(&["", "stop"])), "a ");
    }
}