use axum::{
extract::State,
http::StatusCode,
response::{IntoResponse, Json, Response},
};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use crate::api_types::{StopSequences, UsageInfo};
use crate::server::AppState;
/// A prompt that is either one string or a batch of strings, mirroring the
/// OpenAI `prompt` field which accepts both JSON shapes.
///
/// `#[serde(untagged)]` matches by shape: a JSON string deserialises to
/// `Single`, a JSON array of strings to `Batch`.
#[derive(Debug, Clone, Deserialize)]
#[serde(untagged)]
pub enum PromptInput {
    /// A single prompt string.
    Single(String),
    /// A batch of prompt strings (may be empty).
    Batch(Vec<String>),
}
impl PromptInput {
pub fn as_strings(&self) -> Vec<&str> {
match self {
PromptInput::Single(s) => vec![s.as_str()],
PromptInput::Batch(v) => v.iter().map(String::as_str).collect(),
}
}
pub fn first(&self) -> &str {
match self {
PromptInput::Single(s) => s.as_str(),
PromptInput::Batch(v) => v.first().map(String::as_str).unwrap_or(""),
}
}
}
/// Request body for `POST /v1/completions`, modelled on the OpenAI
/// completions API.
///
/// NOTE(review): several accepted fields (`n`, `temperature`, `top_p`,
/// `stop`, `stream`, `logprobs`, penalties, `seed`, `suffix`, `user`) are
/// deserialised but not visibly consumed in this file — confirm whether the
/// engine honours them elsewhere.
#[derive(Debug, Deserialize)]
pub struct CompletionRequest {
    /// Model identifier; the handler substitutes a default when absent.
    pub model: Option<String>,
    /// One prompt or a batch of prompts.
    pub prompt: PromptInput,
    /// Maximum number of tokens to generate (defaults to 16 when omitted).
    #[serde(default = "default_max_tokens")]
    pub max_tokens: usize,
    pub temperature: Option<f32>,
    pub top_p: Option<f32>,
    pub n: Option<usize>,
    pub stream: Option<bool>,
    pub stop: Option<StopSequences>,
    pub presence_penalty: Option<f32>,
    pub frequency_penalty: Option<f32>,
    pub logprobs: Option<usize>,
    /// When true, the prompt text is prepended to the completion text.
    pub echo: Option<bool>,
    pub seed: Option<u64>,
    pub suffix: Option<String>,
    pub user: Option<String>,
}
/// Serde default for `CompletionRequest::max_tokens` — the OpenAI
/// completions endpoint's documented default of 16.
fn default_max_tokens() -> usize {
    const DEFAULT_MAX_TOKENS: usize = 16;
    DEFAULT_MAX_TOKENS
}
/// Log-probability payload for a completion choice, following the OpenAI
/// wire shape (arrays are presumably parallel, indexed by generated token —
/// not populated anywhere in this file, so confirm before relying on it).
#[derive(Debug, Serialize)]
pub struct CompletionLogprobs {
    /// Generated token strings.
    pub tokens: Vec<String>,
    /// Log-probability of each generated token.
    pub token_logprobs: Vec<f32>,
    /// Per-token alternatives, kept as raw JSON values.
    pub top_logprobs: Vec<serde_json::Value>,
    /// Character offset of each token within the output text.
    pub text_offset: Vec<usize>,
}
/// A single generated completion within a `CompletionResponse`.
#[derive(Debug, Serialize)]
pub struct CompletionChoice {
    /// Generated text (prompt-prefixed when the request set `echo`).
    pub text: String,
    /// Position of this choice within the `choices` array.
    pub index: usize,
    /// Always `None` in this handler — logprobs are not produced here.
    pub logprobs: Option<CompletionLogprobs>,
    /// `"stop"` or `"length"`; see `determine_finish_reason`.
    pub finish_reason: String,
}
/// Top-level JSON response for the completions endpoint
/// (`"object": "text_completion"`).
#[derive(Debug, Serialize)]
pub struct CompletionResponse {
    /// Identifier of the form `cmpl-<hex nanos>`.
    pub id: String,
    /// Always `"text_completion"`.
    pub object: String,
    /// Unix timestamp (seconds) at which the response was built.
    pub created: u64,
    /// Model name echoed back to the client.
    pub model: String,
    pub choices: Vec<CompletionChoice>,
    /// Prompt / completion / total token counts.
    pub usage: UsageInfo,
}
#[tracing::instrument(skip(state))]
pub async fn create_completion(
State(state): State<Arc<AppState>>,
Json(req): Json<CompletionRequest>,
) -> Result<Response, StatusCode> {
let request_start = std::time::Instant::now();
state.metrics().requests_total.inc();
state.metrics().active_requests.inc();
let prompt_text = req.prompt.first().to_owned();
let echo = req.echo.unwrap_or(false);
let max_tokens = req.max_tokens;
let prompt_tokens = if let Some(tok) = state.tokenizer() {
tok.encode(&prompt_text).map_err(|e| {
tracing::error!(error = %e, "tokenisation failed");
state.metrics().errors_total.inc();
state.metrics().active_requests.dec();
StatusCode::INTERNAL_SERVER_ERROR
})?
} else {
vec![151644u32]
};
let prompt_token_count = prompt_tokens.len();
state
.metrics()
.prompt_tokens_total
.inc_by(prompt_token_count as u64);
let output_tokens = {
let mut engine = state.engine_lock().await;
engine.generate(&prompt_tokens, max_tokens).map_err(|e| {
tracing::error!(error = %e, "generation failed");
state.metrics().errors_total.inc();
state.metrics().active_requests.dec();
StatusCode::INTERNAL_SERVER_ERROR
})?
};
let completion_token_count = output_tokens.len();
state
.metrics()
.tokens_generated_total
.inc_by(completion_token_count as u64);
let completion_text = if let Some(tok) = state.tokenizer() {
tok.decode(&output_tokens).map_err(|e| {
tracing::error!(error = %e, "decoding failed");
StatusCode::INTERNAL_SERVER_ERROR
})?
} else {
format!("{output_tokens:?}")
};
let completion_id = format!("cmpl-{}", completion_id_from_nanos());
let created = unix_timestamp_secs();
let model_name = req.model.unwrap_or_else(|| "bonsai-8b".to_string());
let response = build_completion_response(
&completion_id,
&prompt_text,
&completion_text,
echo,
prompt_token_count,
completion_token_count,
&model_name,
created,
);
let elapsed = request_start.elapsed().as_secs_f64();
state.metrics().request_duration_seconds.observe(elapsed);
state.metrics().active_requests.dec();
Ok(Json(response).into_response())
}
/// Assembles a single-choice `CompletionResponse`.
///
/// When `echo` is set, the prompt text is prepended to the completion text,
/// matching OpenAI's `echo` semantics.
///
/// NOTE(review): the finish reason is computed against a hard-coded limit
/// of 16 rather than the request's `max_tokens`, so it is wrong for other
/// limits. Kept as-is to preserve behaviour; pass the real limit through
/// once this signature can change.
#[allow(clippy::too_many_arguments)]
fn build_completion_response(
    id: &str,
    prompt: &str,
    completion: &str,
    echo: bool,
    prompt_tokens: usize,
    completion_tokens: usize,
    model: &str,
    created: u64,
) -> CompletionResponse {
    let mut text = String::new();
    if echo {
        text.push_str(prompt);
    }
    text.push_str(completion);

    let choice = CompletionChoice {
        text,
        index: 0,
        logprobs: None,
        finish_reason: determine_finish_reason(completion_tokens, 16),
    };

    CompletionResponse {
        id: id.to_string(),
        object: String::from("text_completion"),
        created,
        model: model.to_string(),
        choices: vec![choice],
        usage: UsageInfo {
            prompt_tokens,
            completion_tokens,
            total_tokens: prompt_tokens + completion_tokens,
        },
    }
}
/// Maps a generated-token count to an OpenAI finish reason: `"length"` when
/// the generation hit the token limit, `"stop"` otherwise.
fn determine_finish_reason(completion_tokens: usize, max_tokens: usize) -> String {
    let reason = match completion_tokens.cmp(&max_tokens) {
        std::cmp::Ordering::Less => "stop",
        _ => "length",
    };
    reason.to_owned()
}
/// Current Unix time in whole seconds; returns 0 if the system clock is
/// somehow before the epoch.
fn unix_timestamp_secs() -> u64 {
    use std::time::{SystemTime, UNIX_EPOCH};
    SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map_or(0, |d| d.as_secs())
}
/// Builds a completion-id suffix from the current time in nanoseconds,
/// rendered as lowercase hex (falls back to "0" on a pre-epoch clock).
fn completion_id_from_nanos() -> String {
    use std::time::{SystemTime, UNIX_EPOCH};
    let nanos = SystemTime::now()
        .duration_since(UNIX_EPOCH)
        .map(|d| d.as_nanos())
        .unwrap_or(0);
    format!("{nanos:x}")
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- PromptInput -----------------------------------------------------

    #[test]
    fn prompt_input_single_as_strings() {
        let input = PromptInput::Single(String::from("hello world"));
        assert_eq!(input.as_strings(), vec!["hello world"]);
    }

    #[test]
    fn prompt_input_batch_as_strings() {
        let input = PromptInput::Batch(vec![String::from("foo"), String::from("bar")]);
        assert_eq!(input.as_strings(), vec!["foo", "bar"]);
    }

    #[test]
    fn prompt_input_single_first() {
        let input = PromptInput::Single(String::from("hello"));
        assert_eq!(input.first(), "hello");
    }

    #[test]
    fn prompt_input_batch_first() {
        let input = PromptInput::Batch(vec![String::from("alpha"), String::from("beta")]);
        assert_eq!(input.first(), "alpha");
    }

    #[test]
    fn prompt_input_empty_batch_first() {
        // An empty batch must degrade to the empty string, not panic.
        let input = PromptInput::Batch(Vec::new());
        assert_eq!(input.first(), "");
    }

    // --- build_completion_response ----------------------------------------

    /// Shorthand for building a response against the default model name.
    fn response_for(
        id: &str,
        prompt: &str,
        completion: &str,
        echo: bool,
        prompt_tokens: usize,
        completion_tokens: usize,
        created: u64,
    ) -> CompletionResponse {
        build_completion_response(
            id,
            prompt,
            completion,
            echo,
            prompt_tokens,
            completion_tokens,
            "bonsai-8b",
            created,
        )
    }

    #[test]
    fn build_completion_response_no_echo() {
        let resp = response_for("cmpl-abc", "Say hello", " world", false, 4, 2, 1_000_000);
        assert_eq!(resp.object, "text_completion");
        assert_eq!(resp.choices[0].text, " world");
        assert_eq!(resp.usage.prompt_tokens, 4);
        assert_eq!(resp.usage.completion_tokens, 2);
        assert_eq!(resp.usage.total_tokens, 6);
    }

    #[test]
    fn build_completion_response_with_echo() {
        let resp = response_for("cmpl-abc", "Say hello", " world", true, 4, 2, 1_000_000);
        assert_eq!(resp.choices[0].text, "Say hello world");
    }

    #[test]
    fn build_completion_response_id_preserved() {
        let resp = response_for("cmpl-xyz", "prompt", "completion", false, 1, 1, 42);
        assert_eq!(resp.id, "cmpl-xyz");
        assert_eq!(resp.created, 42);
    }

    // --- small helpers -----------------------------------------------------

    #[test]
    fn determine_finish_reason_stop() {
        assert_eq!(determine_finish_reason(8, 16), "stop");
    }

    #[test]
    fn determine_finish_reason_length() {
        assert_eq!(determine_finish_reason(16, 16), "length");
    }

    #[test]
    fn completion_id_from_nanos_nonempty() {
        assert!(!completion_id_from_nanos().is_empty());
    }

    #[test]
    fn unix_timestamp_secs_nonzero() {
        assert!(unix_timestamp_secs() > 1_000_000_000);
    }

    #[test]
    fn serialise_completion_response() {
        let resp = response_for("cmpl-test", "prompt", "result", false, 3, 5, 99);
        let json = serde_json::to_string(&resp).expect("serialisation must succeed");
        assert!(json.contains("\"object\":\"text_completion\""));
        assert!(json.contains("\"finish_reason\""));
    }
}