/// Returns the byte length of the longest suffix of `text` that is a non-empty,
/// proper prefix of one of `stops` (0 when there is none, or when `stops` is `None`).
///
/// Such a suffix may turn out to be the beginning of a stop sequence once more
/// tokens arrive, so it must not be streamed yet.
fn stop_prefix_holdback(text: &str, stops: Option<&[String]>) -> usize {
    stops
        .into_iter()
        .flatten()
        .flat_map(|stop| {
            // `skip(1)` drops the empty prefix; `char_indices` never yields
            // `stop.len()`, so the complete stop sequence is excluded as well.
            stop.char_indices()
                .skip(1)
                .map(move |(end, _)| &stop[..end])
        })
        .filter(|prefix| text.ends_with(*prefix))
        .map(str::len)
        .max()
        .unwrap_or(0)
}

/// Splits the decoded form of `token_ids` into incremental text deltas.
///
/// Each prefix `token_ids[..=i]` is decoded in full and passed through
/// `truncate_at_stop`; only the text beyond what has already been emitted
/// becomes a new delta.
///
/// * Prefixes that fail to decode are skipped.
/// * A prefix whose (untruncated) text ends in U+FFFD is skipped, so a
///   multi-byte character split across tokens is only emitted once complete.
/// * A trailing piece of text that is a proper prefix of a stop sequence is
///   held back until later tokens either complete the stop sequence (the piece
///   is then dropped together with it) or rule it out (the piece is emitted).
///   Without this, a stop sequence spanning several tokens would leak its
///   first part into the stream.
/// * Once truncation shortens the text, the remaining delta is emitted and
///   iteration ends.
///
/// Returns the deltas in emission order.
fn streaming_text_deltas(
    tokenizer: &BPETokenizer,
    token_ids: &[u32],
    stops: Option<&[String]>,
) -> Vec<String> {
    let mut deltas = Vec::new();
    // Byte offset into the decoded text up to which content has been emitted.
    let mut emitted = 0usize;
    // Latest accepted decode; its held-back tail is flushed if no stop ever fires.
    let mut pending = None;
    let mut stopped = false;

    for end in 1..=token_ids.len() {
        let Ok(raw) = tokenizer.decode(&token_ids[..end]) else {
            continue;
        };
        // Remember the length so `raw` can be moved instead of cloned.
        let raw_len = raw.len();
        let text = crate::api::realize_handlers::truncate_at_stop(raw, stops);
        let stop_hit = text.len() < raw_len;
        if !stop_hit && text.ends_with('\u{FFFD}') {
            continue;
        }

        // After a stop hit everything left is final; otherwise keep back any
        // tail that could still grow into a stop sequence.
        let safe_end = if stop_hit {
            text.len()
        } else {
            text.len() - stop_prefix_holdback(&text, stops)
        };
        if safe_end > emitted && text.is_char_boundary(emitted) {
            deltas.push(text[emitted..safe_end].to_string());
            emitted = safe_end;
        }

        if stop_hit {
            stopped = true;
            break;
        }
        pending = Some(text);
    }

    // The sequence ended without a stop: whatever was held back is real output.
    if !stopped {
        if let Some(text) = pending {
            if text.len() > emitted && text.is_char_boundary(emitted) {
                deltas.push(text[emitted..].to_string());
            }
        }
    }
    deltas
}
/// Builds the `GenerationConfig` used by the streaming chat endpoint.
///
/// * `temperature == 0.0` keeps the default sampling strategy (the tests in
///   this file expect that to be greedy), replaces the temperature with `1.0`
///   and ignores `top_p`.
/// * Any other temperature is passed through unchanged; when `top_p` is
///   present the strategy becomes `SamplingStrategy::TopP`.
///
/// `max_tokens` is applied in both cases.
fn resolve_stream_generation_config(
    temperature: f32,
    top_p: Option<f32>,
    max_tokens: usize,
) -> GenerationConfig {
    let base = GenerationConfig::default().with_max_tokens(max_tokens);

    if temperature == 0.0 {
        // A literal zero is never forwarded as the temperature; 1.0 is used instead.
        base.with_temperature(1.0)
    } else {
        let mut sampled = base.with_temperature(temperature);
        if let Some(p) = top_p {
            sampled.strategy = SamplingStrategy::TopP { p };
        }
        sampled
    }
}
/// Serves an OpenAI-style chat completion as a server-sent event stream.
///
/// The whole completion is generated and split into text deltas before the
/// stream is built; the stream then replays an initial chunk, one content
/// chunk per delta, a final chunk, and the literal `[DONE]` marker.
///
/// Defaults: `max_tokens` is 256 when absent and capped at 4096; `temperature`
/// is 0.7 when absent.
///
/// # Errors
///
/// * `404 NOT_FOUND` when `state.get_model` fails for the requested model
///   (`"default"` or an empty name selects the default model).
/// * `400 BAD_REQUEST` when the formatted prompt encodes to no tokens.
/// * `500 INTERNAL_SERVER_ERROR` when generation fails.
///
/// Every error path records a failure on `state.metrics`.
pub async fn openai_chat_completions_stream_handler(
    State(state): State<AppState>,
    Json(request): Json<ChatCompletionRequest>,
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<ErrorResponse>)> {
    let model_id = if request.model == "default" || request.model.is_empty() {
        None
    } else {
        Some(request.model.as_str())
    };
    let (model, tokenizer) = state.get_model(model_id).map_err(|e| {
        state.metrics.record_failure();
        (
            StatusCode::NOT_FOUND,
            Json(ErrorResponse {
                error: e.to_string(),
            }),
        )
    })?;

    let prompt_text = format_chat_messages(&request.messages, Some(&request.model));
    let prompt_ids = tokenizer.encode(&prompt_text);
    if prompt_ids.is_empty() {
        state.metrics.record_failure();
        return Err((
            StatusCode::BAD_REQUEST,
            Json(ErrorResponse {
                error: "Messages cannot be empty".to_string(),
            }),
        ));
    }
    let prompt_len = prompt_ids.len();
    let prompt: Vec<usize> = prompt_ids.iter().map(|&id| id as usize).collect();

    let max_tokens = request.max_tokens.unwrap_or(256).min(4096);
    let config = resolve_stream_generation_config(
        request.temperature.unwrap_or(0.7),
        request.top_p,
        max_tokens,
    );

    // Nanoseconds since the Unix epoch; falls back to 0 if the clock is before it.
    let request_id = format!(
        "chatcmpl-{}",
        std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map(|d| d.as_nanos())
            .unwrap_or(0)
    );

    // NOTE(review): this looks like a synchronous call made directly inside an
    // async handler — confirm whether it should run via `spawn_blocking`.
    let generated = model.generate(&prompt, &config).map_err(|e| {
        state.metrics.record_failure();
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse {
                error: e.to_string(),
            }),
        )
    })?;

    // Drop the echoed prompt *before* the lossy u32 conversion. Slicing the
    // converted vector at `prompt_len` panics when the output is shorter than
    // the prompt or when the conversion discarded ids; `skip` does neither.
    let generated_ids: Vec<u32> = generated
        .iter()
        .skip(prompt_len)
        .filter_map(|&id| u32::try_from(id).ok())
        .collect();

    let model_name = request.model.clone();
    let deltas = streaming_text_deltas(&tokenizer, &generated_ids, request.stop.as_deref());

    let stream = async_stream::stream! {
        let initial = ChatCompletionChunk::initial(&request_id, &model_name);
        let data = serde_json::to_string(&initial).unwrap_or_default();
        yield Ok(Event::default().data(data));

        for delta in &deltas {
            let chunk = ChatCompletionChunk::content(&request_id, &model_name, delta);
            let data = serde_json::to_string(&chunk).unwrap_or_default();
            yield Ok(Event::default().data(data));
        }

        let done = ChatCompletionChunk::done(&request_id, &model_name);
        let data = serde_json::to_string(&done).unwrap_or_default();
        yield Ok(Event::default().data(data));
        yield Ok(Event::default().data("[DONE]"));
    };
    Ok(Sse::new(stream))
}
#[cfg(test)]
mod pmat758_streaming_delta_tests {
    use super::*;

    /// Builds a tokenizer whose vocabulary is exactly `vocab` (index = token id),
    /// with an empty second constructor argument and `<unk>` as the unknown token.
    fn tok(vocab: &[&str]) -> BPETokenizer {
        let entries = vocab.iter().map(|piece| String::from(*piece)).collect();
        BPETokenizer::new(entries, Vec::new(), "<unk>").expect("test tokenizer")
    }

    /// The four byte tokens F0 9F 98 80 together encode "😀"; no partial
    /// decode (which would surface as U+FFFD) may reach the output.
    #[test]
    fn holds_back_multibyte_utf8_until_complete() {
        let tokenizer = tok(&["<unk>", "<0xF0>", "<0x9F>", "<0x98>", "<0x80>"]);
        let streamed = streaming_text_deltas(&tokenizer, &[1, 2, 3, 4], None).concat();
        assert_eq!(streamed, "😀");
        assert!(
            !streamed.contains('\u{FFFD}'),
            "no replacement chars in streamed deltas"
        );
    }

    /// With stop "X", streaming "a b X c" yields only the text before the stop.
    #[test]
    fn applies_stop_and_halts_emission() {
        let tokenizer = tok(&["<unk>", "a", "b", "X", "c"]);
        let stop_list = ["X".to_string()];
        let streamed = streaming_text_deltas(&tokenizer, &[1, 2, 3, 4], Some(&stop_list)).concat();
        assert_eq!(streamed, "ab");
        assert!(!streamed.contains('X'));
    }

    /// Without stop sequences the concatenated deltas equal the full decode.
    #[test]
    fn no_stop_streams_full_text() {
        let tokenizer = tok(&["<unk>", "a", "b", "X", "c"]);
        let streamed = streaming_text_deltas(&tokenizer, &[1, 2, 3, 4], None).concat();
        assert_eq!(streamed, "abXc");
    }
}
#[cfg(test)]
mod pmat790_stream_temperature_zero_tests {
    use super::resolve_stream_generation_config;
    use crate::generate::{sample_token, SamplingStrategy};
    use crate::tensor::Tensor;

    /// Temperature 0 must produce a greedy config that `sample_token` accepts
    /// and that picks the largest logit (index 2 here).
    #[test]
    fn temperature_zero_resolves_to_runnable_greedy_config() {
        let resolved = resolve_stream_generation_config(0.0, None, 16);
        assert_eq!(
            resolved.strategy,
            SamplingStrategy::Greedy,
            "temperature 0 must request deterministic (greedy) decoding"
        );

        let scores = Tensor::from_vec(vec![4], vec![0.1, 0.2, 0.9, 0.3]).expect("tensor");
        let picked = sample_token(&scores, &resolved, 0.5)
            .expect("temperature-0 config must sample without error (was HTTP 500)");
        assert_eq!(picked, 2, "greedy must select the argmax token");
    }

    /// A `top_p` value supplied alongside temperature 0 has no effect.
    #[test]
    fn temperature_zero_ignores_top_p_and_stays_greedy() {
        let resolved = resolve_stream_generation_config(0.0, Some(0.9), 16);
        assert_eq!(resolved.strategy, SamplingStrategy::Greedy);
    }

    /// Non-zero temperatures pass through untouched, and `top_p` selects
    /// nucleus sampling.
    #[test]
    fn positive_temperature_unchanged() {
        let plain = resolve_stream_generation_config(0.7, None, 16);
        assert_eq!(plain.strategy, SamplingStrategy::Greedy);
        assert!((plain.temperature - 0.7).abs() < 1e-6);

        let with_top_p = resolve_stream_generation_config(0.7, Some(0.8), 16);
        assert!(matches!(
            with_top_p.strategy,
            SamplingStrategy::TopP { p } if (p - 0.8).abs() < 1e-6
        ));

        let scores = Tensor::from_vec(vec![3], vec![0.1, 2.0, 0.3]).expect("tensor");
        let outcome = sample_token(&scores, &plain, 0.5);
        assert!(
            outcome.is_ok(),
            "positive-temperature config must remain runnable"
        );
    }
}