pub async fn stream_generate_handler(
State(state): State<AppState>,
Json(request): Json<GenerateRequest>,
) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, (StatusCode, Json<ErrorResponse>)> {
let (model, tokenizer) = state.get_model(request.model_id.as_deref()).map_err(|e| {
(
StatusCode::NOT_FOUND,
Json(ErrorResponse {
error: e.to_string(),
}),
)
})?;
let prompt_ids = tokenizer.encode(&request.prompt);
if prompt_ids.is_empty() {
return Err((
StatusCode::BAD_REQUEST,
Json(ErrorResponse {
error: "Prompt cannot be empty".to_string(),
}),
));
}
let prompt: Vec<usize> = prompt_ids.iter().map(|&id| id as usize).collect();
let prompt_len = prompt.len();
let strategy = match request.strategy.as_str() {
"greedy" => SamplingStrategy::Greedy,
"top_k" => SamplingStrategy::TopK { k: request.top_k },
"top_p" => SamplingStrategy::TopP { p: request.top_p },
_ => {
return Err((
StatusCode::BAD_REQUEST,
Json(ErrorResponse {
error: format!("Invalid strategy: {}", request.strategy),
}),
));
},
};
let mut config = GenerationConfig::default()
.with_max_tokens(request.max_tokens)
.with_temperature(request.temperature);
config.strategy = strategy;
if let Some(seed) = request.seed {
config = config.with_seed(seed);
}
let generated = match model.generate(&prompt, &config) {
Ok(tokens) => tokens,
Err(e) => {
return Err((
StatusCode::INTERNAL_SERVER_ERROR,
Json(ErrorResponse {
error: e.to_string(),
}),
));
},
};
let token_ids: Vec<u32> = generated
.iter()
.map(|&id| {
u32::try_from(id).map_err(|_| {
(
StatusCode::BAD_REQUEST,
Json(ErrorResponse {
error: format!("Token ID {id} exceeds u32 range"),
}),
)
})
})
.collect::<Result<Vec<_>, _>>()?;
let tokenizer_clone = tokenizer;
let stream = async_stream::stream! {
for &token_id in &token_ids[prompt_len..] {
let text = match tokenizer_clone.decode(&[token_id]) {
Ok(t) => t,
Err(_) => String::from("<error>"),
};
let event = StreamTokenEvent { token_id, text };
let data = serde_json::to_string(&event)
.unwrap_or_else(|_| r#"{"error":"serialization failed"}"#.to_string());
yield Ok::<_, Infallible>(Event::default().event("token").data(data));
}
let done_event = StreamDoneEvent {
num_generated: token_ids.len() - prompt_len,
};
let data = serde_json::to_string(&done_event)
.unwrap_or_else(|_| r#"{"error":"serialization failed"}"#.to_string());
yield Ok(Event::default().event("done").data(data));
};
Ok(Sse::new(stream))
}
// Unit tests for this module; compiled only in test builds and loaded from
// `gpu_handlers_tests.rs` via the explicit `#[path]` attribute below.
#[cfg(test)]
#[path = "gpu_handlers_tests.rs"]
mod gpu_handlers_tests;