use super::state::AppState;
use crate::{
create_generator,
openai::{
ChatCompletionRequest, ChatCompletionResponse, ErrorResponse, InputItem, InputRole,
MessageContent, Model, ModelsResponse, OutputTokensDetails, ReasoningConfig,
ResponsesErrorResponse, ResponsesInput, ResponsesRequest, ResponsesResponse,
ResponsesUsage, Usage,
},
openresponses::{
self, OpenResponsesStreamBuilder, Response as OpenResponsesResponse, ResponseRequest,
Usage as OpenResponsesUsage,
},
EndpointType, ErrorInjector, LatencyProfile, ResponsesTokenStreamBuilder, TokenStreamBuilder,
};
use axum::{
body::Body,
extract::{Path, State},
http::{header, StatusCode},
response::{IntoResponse, Response},
Json,
};
use futures::StreamExt;
use std::sync::Arc;
use std::time::Instant;
/// Health-check handler.
///
/// Always answers with the fixed JSON document
/// `{"status": "ok", "service": "llmsim"}`.
pub async fn health() -> impl IntoResponse {
    let payload = serde_json::json!({
        "status": "ok",
        "service": "llmsim"
    });
    Json(payload)
}
pub async fn chat_completions(
State(state): State<Arc<AppState>>,
Json(request): Json<ChatCompletionRequest>,
) -> Result<Response, AppError> {
let request_start = Instant::now();
tracing::info!(
model = %request.model,
stream = request.stream,
messages = request.messages.len(),
"Chat completion request"
);
state.stats.record_request_start(
&request.model,
request.stream,
EndpointType::ChatCompletions,
);
let error_injector = ErrorInjector::new(state.config.error_config());
if let Some(error) = error_injector.maybe_inject() {
tracing::warn!("Injecting error: {:?}", error);
let status_code = error.status_code();
let status = match status_code {
429 => StatusCode::TOO_MANY_REQUESTS,
500 => StatusCode::INTERNAL_SERVER_ERROR,
503 => StatusCode::SERVICE_UNAVAILABLE,
504 => StatusCode::GATEWAY_TIMEOUT,
400 => StatusCode::BAD_REQUEST,
401 => StatusCode::UNAUTHORIZED,
_ => StatusCode::INTERNAL_SERVER_ERROR,
};
state.stats.record_error(status_code);
let mut response = Json(error.to_error_response()).into_response();
*response.status_mut() = status;
if let Some(retry_after) = error.retry_after() {
response.headers_mut().insert(
header::RETRY_AFTER,
retry_after.to_string().parse().unwrap(),
);
}
return Ok(response);
}
let latency =
if state.config.latency.profile.is_some() || state.config.latency.ttft_mean_ms.is_some() {
state.config.latency_profile()
} else {
LatencyProfile::from_model(&request.model)
};
let generator = create_generator(
&state.config.response.generator,
state.config.response.target_tokens,
);
let content = generator.generate(&request);
let prompt_tokens = count_request_tokens(&request);
let completion_tokens =
crate::count_tokens_default(&content).unwrap_or(content.split_whitespace().count());
let usage = Usage {
prompt_tokens: prompt_tokens as u32,
completion_tokens: completion_tokens as u32,
total_tokens: (prompt_tokens + completion_tokens) as u32,
};
if request.stream {
let stats = state.stats.clone();
let prompt_tok = usage.prompt_tokens;
let completion_tok = usage.completion_tokens;
let stream = TokenStreamBuilder::new(&request.model, content)
.latency(latency)
.usage(usage)
.on_complete(move || {
stats.record_request_end(request_start.elapsed(), prompt_tok, completion_tok);
})
.build();
let body = Body::from_stream(stream.into_stream().map(Ok::<_, std::io::Error>));
Ok(Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "text/event-stream")
.header(header::CACHE_CONTROL, "no-cache")
.header(header::CONNECTION, "keep-alive")
.body(body)
.unwrap())
} else {
let delay = latency.sample_ttft();
if !delay.is_zero() {
tokio::time::sleep(delay).await;
}
state.stats.record_request_end(
request_start.elapsed(),
usage.prompt_tokens,
usage.completion_tokens,
);
let response = ChatCompletionResponse::new(request.model.clone(), content, usage);
Ok(Json(response).into_response())
}
}
/// Returns the current snapshot of the shared stats collector as JSON.
pub async fn get_stats(State(state): State<Arc<AppState>>) -> impl IntoResponse {
    let snapshot = state.stats.snapshot();
    Json(snapshot)
}
/// Handler for OpenResponses requests.
///
/// Mirrors the chat-completions flow: log and register the request, give the
/// [`ErrorInjector`] a chance to fail it, then generate content by adapting
/// the request into a single-user-message [`ChatCompletionRequest`] and reply
/// either as a `text/event-stream` body or as one JSON document.
///
/// Every path in this function returns `Ok`; injected failures are delivered
/// as `Ok` responses carrying an error status.
pub async fn create_openresponses_response(
    State(state): State<Arc<AppState>>,
    Json(request): Json<ResponseRequest>,
) -> Result<Response, AppError> {
    let request_start = Instant::now();

    tracing::info!(
        model = %request.model,
        stream = request.stream,
        "OpenResponses request"
    );

    state
        .stats
        .record_request_start(&request.model, request.stream, EndpointType::Responses);

    let error_injector = ErrorInjector::new(state.config.error_config());
    if let Some(error) = error_injector.maybe_inject() {
        tracing::warn!("Injecting error: {:?}", error);

        let status_code = error.status_code();
        // Any code without a dedicated arm is reported as 500.
        let status = match status_code {
            429 => StatusCode::TOO_MANY_REQUESTS,
            500 => StatusCode::INTERNAL_SERVER_ERROR,
            503 => StatusCode::SERVICE_UNAVAILABLE,
            504 => StatusCode::GATEWAY_TIMEOUT,
            400 => StatusCode::BAD_REQUEST,
            401 => StatusCode::UNAUTHORIZED,
            _ => StatusCode::INTERNAL_SERVER_ERROR,
        };
        // Stats receive the raw injected code, not the mapped HTTP status.
        state.stats.record_error(status_code);

        // Build the generic error payload once and move its fields into the
        // OpenResponses-shaped body instead of constructing it per field.
        let details = error.to_error_response().error;
        let error_response =
            openresponses::ErrorResponse::new(details.message, details.error_type);

        let mut response = Json(error_response).into_response();
        *response.status_mut() = status;
        if let Some(retry_after) = error.retry_after() {
            // NOTE(review): presumably a numeric value whose string form is
            // always a valid header value, so the unwrap cannot fire — confirm.
            response.headers_mut().insert(
                header::RETRY_AFTER,
                retry_after.to_string().parse().unwrap(),
            );
        }
        return Ok(response);
    }

    // Explicit latency configuration wins; otherwise derive it from the model name.
    let latency =
        if state.config.latency.profile.is_some() || state.config.latency.ttft_mean_ms.is_some() {
            state.config.latency_profile()
        } else {
            LatencyProfile::from_model(&request.model)
        };

    let input_text = request.input.extract_text();

    let generator = create_generator(
        &state.config.response.generator,
        state.config.response.target_tokens,
    );

    // The generator consumes chat-completion requests, so the extracted input
    // is wrapped as one user message; unsupported knobs are left unset.
    let chat_request = ChatCompletionRequest {
        model: request.model.clone(),
        messages: vec![crate::openai::Message::user(&input_text)],
        temperature: request.temperature,
        top_p: request.top_p,
        n: None,
        stream: request.stream,
        stop: None,
        max_tokens: request.max_output_tokens,
        max_completion_tokens: request.max_output_tokens,
        presence_penalty: None,
        frequency_penalty: None,
        logit_bias: None,
        user: request.user.clone(),
        tools: None,
        tool_choice: None,
        response_format: None,
        seed: None,
    };
    let content = generator.generate(&chat_request);

    let input_tokens = count_openresponses_input_tokens(&request);
    // Falls back to a whitespace word count when the tokenizer yields no count.
    let output_tokens =
        crate::count_tokens_default(&content).unwrap_or(content.split_whitespace().count());
    let usage = OpenResponsesUsage {
        input_tokens: input_tokens as u32,
        output_tokens: output_tokens as u32,
        total_tokens: (input_tokens + output_tokens) as u32,
        input_tokens_details: None,
        output_tokens_details: None,
    };

    if request.stream {
        // Completion is recorded by the stream's `on_complete` callback, so the
        // counters are copied out before `usage` moves into the builder.
        let stats = state.stats.clone();
        let input_tok = usage.input_tokens;
        let output_tok = usage.output_tokens;
        let stream = OpenResponsesStreamBuilder::new(&request.model, content)
            .latency(latency)
            .usage(usage)
            .on_complete(move || {
                stats.record_request_end(request_start.elapsed(), input_tok, output_tok);
            })
            .build();
        let body = Body::from_stream(stream.into_stream().map(Ok::<_, std::io::Error>));

        Ok(Response::builder()
            .status(StatusCode::OK)
            .header(header::CONTENT_TYPE, "text/event-stream")
            .header(header::CACHE_CONTROL, "no-cache")
            .header(header::CONNECTION, "keep-alive")
            .body(body)
            .unwrap())
    } else {
        // Non-streaming: wait for the sampled delay, then send everything at once.
        let delay = latency.sample_ttft();
        if !delay.is_zero() {
            tokio::time::sleep(delay).await;
        }
        state.stats.record_request_end(
            request_start.elapsed(),
            usage.input_tokens,
            usage.output_tokens,
        );
        let response = OpenResponsesResponse::new(request.model.clone(), content, usage);
        Ok(Json(response).into_response())
    }
}
/// Estimates the input token count of an OpenResponses request.
///
/// The text extracted from `request.input` is counted with
/// `count_tokens_default`, falling back to a whitespace word count when that
/// yields no count, and a fixed overhead of 3 tokens is added.
fn count_openresponses_input_tokens(request: &ResponseRequest) -> usize {
    // NOTE(review): presumably models per-request framing tokens — confirm.
    const REQUEST_OVERHEAD_TOKENS: usize = 3;

    let text = request.input.extract_text();
    let text_tokens =
        crate::count_tokens_default(&text).unwrap_or(text.split_whitespace().count());
    text_tokens + REQUEST_OVERHEAD_TOKENS
}
/// Lists every model id configured under `config.models.available`.
///
/// Ids with a known profile are described from that profile; all others get a
/// minimal entry whose owner is inferred from the id.
pub async fn list_models(State(state): State<Arc<AppState>>) -> impl IntoResponse {
    use crate::openai::{get_model_profile, infer_model_owner};

    let mut models: Vec<Model> = Vec::new();
    for id in state.config.models.available.iter() {
        let entry = match get_model_profile(id) {
            Some(profile) => Model::from_profile(profile),
            None => Model::new(id, infer_model_owner(id)),
        };
        models.push(entry);
    }

    Json(ModelsResponse::new(models))
}
/// Looks up a single model by id.
///
/// # Errors
///
/// Returns [`AppError::NotFound`] when `model_id` is not listed in
/// `config.models.available`.
pub async fn get_model(
    State(state): State<Arc<AppState>>,
    Path(model_id): Path<String>,
) -> Result<Json<Model>, AppError> {
    use crate::openai::{get_model_profile, infer_model_owner};

    // Only configured ids are served, even if a profile exists for others.
    if !state.config.models.available.contains(&model_id) {
        return Err(AppError::NotFound(format!(
            "Model '{}' not found",
            model_id
        )));
    }

    let model = match get_model_profile(&model_id) {
        Some(profile) => Model::from_profile(profile),
        None => Model::new(&model_id, infer_model_owner(&model_id)),
    };
    Ok(Json(model))
}
pub async fn create_response(
State(state): State<Arc<AppState>>,
Json(request): Json<ResponsesRequest>,
) -> Result<Response, AppError> {
let request_start = Instant::now();
tracing::info!(
model = %request.model,
stream = request.stream,
"Responses API request"
);
state
.stats
.record_request_start(&request.model, request.stream, EndpointType::Responses);
let error_injector = ErrorInjector::new(state.config.error_config());
if let Some(error) = error_injector.maybe_inject() {
tracing::warn!("Injecting error: {:?}", error);
let status = match error.status_code() {
429 => StatusCode::TOO_MANY_REQUESTS,
500 => StatusCode::INTERNAL_SERVER_ERROR,
503 => StatusCode::SERVICE_UNAVAILABLE,
504 => StatusCode::GATEWAY_TIMEOUT,
400 => StatusCode::BAD_REQUEST,
401 => StatusCode::UNAUTHORIZED,
_ => StatusCode::INTERNAL_SERVER_ERROR,
};
state.stats.record_error(error.status_code());
let error_response = ResponsesErrorResponse {
error: crate::openai::ResponsesError::new(
error.to_error_response().error.error_type,
error.to_error_response().error.message,
),
};
let mut response = Json(error_response).into_response();
*response.status_mut() = status;
if let Some(retry_after) = error.retry_after() {
response.headers_mut().insert(
header::RETRY_AFTER,
retry_after.to_string().parse().unwrap(),
);
}
return Ok(response);
}
let latency =
if state.config.latency.profile.is_some() || state.config.latency.ttft_mean_ms.is_some() {
state.config.latency_profile()
} else {
LatencyProfile::from_model(&request.model)
};
let input_text = extract_input_text(&request.input, &request.instructions);
let chat_request = crate::openai::ChatCompletionRequest {
model: request.model.clone(),
messages: vec![crate::openai::Message::user(&input_text)],
temperature: request.temperature,
top_p: request.top_p,
n: None,
stream: request.stream,
stop: None,
max_tokens: request.max_output_tokens,
max_completion_tokens: request.max_output_tokens,
presence_penalty: None,
frequency_penalty: None,
logit_bias: None,
user: None,
tools: None,
tool_choice: None,
response_format: None,
seed: None,
};
let generator = create_generator(
&state.config.response.generator,
state.config.response.target_tokens,
);
let content = generator.generate(&chat_request);
let input_tokens =
crate::count_tokens_default(&input_text).unwrap_or(input_text.split_whitespace().count());
let output_tokens =
crate::count_tokens_default(&content).unwrap_or(content.split_whitespace().count());
let reasoning_tokens =
calculate_reasoning_tokens(&request.model, &request.reasoning, output_tokens);
let usage = ResponsesUsage {
input_tokens: input_tokens as u32,
output_tokens: output_tokens as u32,
total_tokens: (input_tokens + output_tokens + reasoning_tokens) as u32,
output_tokens_details: Some(OutputTokensDetails {
reasoning_tokens: reasoning_tokens as u32,
}),
};
if request.stream {
let stats = state.stats.clone();
let input_tok = usage.input_tokens;
let output_tok = usage.output_tokens;
let stream = ResponsesTokenStreamBuilder::new(&request.model, content)
.latency(latency)
.usage(usage)
.on_complete(move || {
stats.record_request_end(request_start.elapsed(), input_tok, output_tok);
})
.build();
let body = Body::from_stream(stream.into_stream().map(Ok::<_, std::io::Error>));
Ok(Response::builder()
.status(StatusCode::OK)
.header(header::CONTENT_TYPE, "text/event-stream")
.header(header::CACHE_CONTROL, "no-cache")
.header(header::CONNECTION, "keep-alive")
.body(body)
.unwrap())
} else {
let delay = latency.sample_ttft();
if !delay.is_zero() {
tokio::time::sleep(delay).await;
}
state.stats.record_request_end(
request_start.elapsed(),
usage.input_tokens,
usage.output_tokens,
);
let response = ResponsesResponse::new(request.model.clone(), content, usage);
Ok(Json(response).into_response())
}
}
/// Flattens Responses API input into one newline-separated prompt string.
///
/// The optional `instructions` come first, verbatim. Plain-text input is
/// appended as-is; item input contributes one `"<role>: <content>"` line per
/// message item, and non-message items are skipped. Multi-part message content
/// keeps only its `InputText` parts, joined with single spaces.
fn extract_input_text(input: &ResponsesInput, instructions: &Option<String>) -> String {
    // Zero or one leading segment, depending on whether instructions were given.
    let mut segments: Vec<String> = instructions.iter().cloned().collect();

    match input {
        ResponsesInput::Text(text) => segments.push(text.clone()),
        ResponsesInput::Items(items) => {
            let message_lines = items.iter().filter_map(|item| match item {
                InputItem::Message { role, content } => {
                    let label = match role {
                        InputRole::User => "user",
                        InputRole::Assistant => "assistant",
                        InputRole::System => "system",
                        InputRole::Developer => "developer",
                    };
                    let body = match content {
                        MessageContent::Text(text) => text.clone(),
                        MessageContent::Parts(content_parts) => {
                            let texts: Vec<_> = content_parts
                                .iter()
                                .filter_map(|part| match part {
                                    crate::openai::ContentPart::InputText { text } => {
                                        Some(text.clone())
                                    }
                                    _ => None,
                                })
                                .collect();
                            texts.join(" ")
                        }
                    };
                    Some(format!("{}: {}", label, body))
                }
                _ => None,
            });
            segments.extend(message_lines);
        }
    }

    segments.join("\n")
}
/// Estimates how many reasoning tokens to report for a response.
///
/// Only models recognised as reasoning-capable get a non-zero value: names
/// starting with `o1`/`o3`/`o4`, names embedding `-o1`/`-o3`/`-o4` (for
/// example behind a vendor prefix), and names starting with `gpt-5`. Every
/// other model yields `0`.
///
/// For recognised models the result is `output_tokens` scaled by a multiplier
/// chosen from `reasoning.effort`, truncated toward zero. A missing config, a
/// missing effort and an unrecognised effort string all behave like
/// `"medium"`.
fn calculate_reasoning_tokens(
    model: &str,
    reasoning: &Option<ReasoningConfig>,
    output_tokens: usize,
) -> usize {
    // Both lists name the same families so a prefixed id such as
    // "azure-o4-mini" is treated the same as a bare "o4-mini".
    const O_SERIES_PREFIXES: [&str; 3] = ["o1", "o3", "o4"];
    const O_SERIES_INFIXES: [&str; 3] = ["-o1", "-o3", "-o4"];

    let is_o_series = O_SERIES_PREFIXES
        .iter()
        .any(|&prefix| model.starts_with(prefix))
        || O_SERIES_INFIXES.iter().any(|&infix| model.contains(infix));
    let is_gpt5 = model.starts_with("gpt-5");

    if !is_o_series && !is_gpt5 {
        return 0;
    }

    let effort = reasoning
        .as_ref()
        .and_then(|r| r.effort.as_deref())
        .unwrap_or("medium");

    let multiplier = match effort {
        "none" => 0.0,
        "minimal" => 0.5,
        "low" => 1.5,
        "medium" => 3.0,
        "high" => 6.0,
        "xhigh" => 10.0,
        // Unknown effort strings fall back to the "medium" multiplier.
        _ => 3.0,
    };

    (output_tokens as f64 * multiplier) as usize
}
/// Estimates the prompt token count of a chat completion request.
///
/// Each message contributes the token count of its content (zero when the
/// content is absent) plus 4 tokens of overhead, and the request as a whole
/// adds 3 more. Content is counted with `count_tokens_default`, falling back
/// to a whitespace word count when that yields no count.
fn count_request_tokens(request: &ChatCompletionRequest) -> usize {
    // NOTE(review): presumably models chat-format framing tokens — confirm.
    const PER_MESSAGE_OVERHEAD: usize = 4;
    const PER_REQUEST_OVERHEAD: usize = 3;

    let message_tokens: usize = request
        .messages
        .iter()
        .map(|message| {
            let content_tokens = message.content.as_ref().map_or(0, |content| {
                crate::count_tokens_default(content).unwrap_or(content.split_whitespace().count())
            });
            content_tokens + PER_MESSAGE_OVERHEAD
        })
        .sum();

    message_tokens + PER_REQUEST_OVERHEAD
}
/// Error type returned by the handlers in this module.
///
/// Each variant carries the human-readable message that ends up in the error
/// body produced by the `IntoResponse` implementation.
// Only `NotFound` is constructed by the handlers visible in this file, hence
// the lint allowance for the remaining variants.
#[allow(dead_code)]
#[derive(Debug)]
pub enum AppError {
    /// The requested resource does not exist; rendered as HTTP 404.
    NotFound(String),
    /// The request was invalid; rendered as HTTP 400.
    BadRequest(String),
    /// An unexpected server-side failure; rendered as HTTP 500.
    Internal(String),
}
/// Renders an [`AppError`] as a JSON error body with the matching HTTP status.
impl IntoResponse for AppError {
    fn into_response(self) -> Response {
        let (status, body) = match self {
            AppError::NotFound(msg) => (
                StatusCode::NOT_FOUND,
                ErrorResponse::new(msg, "not_found_error"),
            ),
            AppError::BadRequest(msg) => {
                (StatusCode::BAD_REQUEST, ErrorResponse::invalid_request(msg))
            }
            AppError::Internal(msg) => (
                StatusCode::INTERNAL_SERVER_ERROR,
                ErrorResponse::new(msg, "internal_error"),
            ),
        };

        // The tuple form serialises the body and then overrides the status.
        (status, Json(body)).into_response()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::openai::Message;

    /// Builds a non-streaming `gpt-4` request with the given messages and
    /// every optional field unset.
    fn request_with(messages: Vec<Message>) -> ChatCompletionRequest {
        ChatCompletionRequest {
            model: "gpt-4".to_string(),
            messages,
            temperature: None,
            top_p: None,
            n: None,
            stream: false,
            stop: None,
            max_tokens: None,
            max_completion_tokens: None,
            presence_penalty: None,
            frequency_penalty: None,
            logit_bias: None,
            user: None,
            tools: None,
            tool_choice: None,
            response_format: None,
            seed: None,
        }
    }

    #[test]
    fn test_count_request_tokens() {
        let request = request_with(vec![
            Message::system("You are a helpful assistant."),
            Message::user("Hello!"),
        ]);

        let tokens = count_request_tokens(&request);

        // The fixed overhead alone (4 per message + 3 per request) is already
        // positive, so require the message content to contribute as well.
        let framing_overhead = request.messages.len() * 4 + 3;
        assert!(
            tokens > framing_overhead,
            "expected content tokens on top of {} overhead tokens, got {}",
            framing_overhead,
            tokens
        );
    }

    #[test]
    fn test_count_request_tokens_empty_request() {
        let request = request_with(Vec::new());

        // With no messages only the per-request overhead remains.
        assert_eq!(count_request_tokens(&request), 3);
    }

    #[tokio::test]
    async fn test_health_endpoint() {
        let response = health().await.into_response();
        assert_eq!(response.status(), StatusCode::OK);
    }
}