use axum::http::StatusCode;
use axum::response::{IntoResponse, Response};
use thiserror::Error;
/// Shorthand for a [`std::result::Result`] whose error type is [`ServerError`].
pub type ServerResult<T> = std::result::Result<T, ServerError>;
/// Errors produced by the server layer.
///
/// Each variant's display text comes from its `#[error(...)]` attribute (derived via
/// `thiserror::Error`). The `IntoResponse` impl for this type sends that text to the
/// client as `error.message` in a JSON body, together with a per-variant HTTP status
/// and an `error.type` category string.
#[derive(Error, Debug)]
pub enum ServerError {
/// Could not bind to `addr`. Responds with HTTP 500.
///
/// NOTE(review): `thiserror` treats a field named `source` as the error source, and the
/// message also interpolates `{source}`, so chain-printing reporters may show the I/O
/// error twice — confirm this is intended.
#[error("failed to bind to {addr}: {source}")]
BindError {
/// Address the bind was attempted on.
addr: String,
/// Underlying I/O error.
source: std::io::Error,
},
/// Error propagated from `oxillama_runtime`; `#[from]` lets `?` convert a
/// `RuntimeError` into this variant. Responds with HTTP 500.
#[error("runtime error: {0}")]
Runtime(#[from] oxillama_runtime::RuntimeError),
/// JSON (de)serialization failure from `serde_json`; `#[from]` lets `?` convert a
/// `serde_json::Error` into this variant. Responds with HTTP 500.
///
/// NOTE(review): if this is ever produced while parsing client input, the client gets a
/// 500 rather than a 400 — confirm request-parse failures are mapped to `InvalidRequest`.
#[error("serialization error: {0}")]
Serialization(#[from] serde_json::Error),
/// The request was rejected as invalid. Responds with HTTP 400.
#[error("invalid request: {message}")]
InvalidRequest {
/// Human-readable reason, echoed back in the response message.
message: String,
},
/// The model is not ready. Responds with HTTP 503.
#[error("model not ready")]
ModelNotReady,
/// The inference queue is full. Responds with HTTP 429.
#[error("inference queue is full — server overloaded")]
QueueFull,
/// The inference worker is no longer running. Responds with HTTP 503.
#[error("inference worker is no longer running")]
WorkerDead,
/// Thread lookup failed; the payload names the missing thread. Responds with HTTP 404.
#[error("thread not found: {0}")]
ThreadNotFound(String),
/// Run lookup failed; the payload names the missing run. Responds with HTTP 404.
#[error("run not found: {0}")]
RunNotFound(String),
/// The run is already in a terminal state; the payload describes it.
/// Responds with HTTP 409.
#[error("run is in terminal state: {0}")]
RunInTerminalState(String),
/// File lookup failed; the payload names the missing file. Responds with HTTP 404.
#[error("file not found: {0}")]
FileNotFound(String),
/// A file exceeded the allowed size; the payload describes it. Responds with HTTP 413.
#[error("file too large: {0}")]
FileTooLarge(String),
/// The file store reported a failure. Responds with HTTP 500.
#[error("file store error: {0}")]
FileStoreError(String),
/// Run-step lookup failed; the payload names the missing step. Responds with HTTP 404.
#[error("run step not found: {0}")]
RunStepNotFound(String),
/// An I/O operation failed. Responds with HTTP 500.
#[error("I/O error ({context}): {source}")]
IoError {
/// Short description of the operation that failed, shown in parentheses.
context: String,
/// Underlying I/O error.
source: std::io::Error,
},
/// Response lookup failed; the payload names the missing response. Responds with HTTP 404.
#[error("response {0} not found")]
ResponseNotFound(String),
/// The referenced previous response does not exist. Responds with HTTP 404.
#[error("previous response {0} not found")]
PreviousResponseNotFound(String),
}
impl IntoResponse for ServerError {
fn into_response(self) -> Response {
let status = match &self {
ServerError::InvalidRequest { .. } => StatusCode::BAD_REQUEST,
ServerError::ModelNotReady => StatusCode::SERVICE_UNAVAILABLE,
ServerError::QueueFull => StatusCode::TOO_MANY_REQUESTS,
ServerError::WorkerDead => StatusCode::SERVICE_UNAVAILABLE,
ServerError::ThreadNotFound(_) => StatusCode::NOT_FOUND,
ServerError::RunNotFound(_) => StatusCode::NOT_FOUND,
ServerError::RunInTerminalState(_) => StatusCode::CONFLICT,
ServerError::FileNotFound(_) => StatusCode::NOT_FOUND,
ServerError::FileTooLarge(_) => StatusCode::PAYLOAD_TOO_LARGE,
ServerError::FileStoreError(_) => StatusCode::INTERNAL_SERVER_ERROR,
ServerError::RunStepNotFound(_) => StatusCode::NOT_FOUND,
ServerError::ResponseNotFound(_) => StatusCode::NOT_FOUND,
ServerError::PreviousResponseNotFound(_) => StatusCode::NOT_FOUND,
_ => StatusCode::INTERNAL_SERVER_ERROR,
};
let error_type = match &self {
ServerError::InvalidRequest { .. } => "invalid_request_error",
ServerError::ModelNotReady => "service_unavailable",
ServerError::QueueFull => "rate_limit_error",
ServerError::WorkerDead => "service_unavailable",
ServerError::ThreadNotFound(_) => "not_found_error",
ServerError::RunNotFound(_) => "not_found_error",
ServerError::RunInTerminalState(_) => "conflict_error",
ServerError::FileNotFound(_) => "not_found_error",
ServerError::FileTooLarge(_) => "payload_too_large",
ServerError::FileStoreError(_) => "internal_error",
ServerError::RunStepNotFound(_) => "not_found_error",
ServerError::ResponseNotFound(_) => "not_found_error",
ServerError::PreviousResponseNotFound(_) => "not_found_error",
_ => "internal_error",
};
let body = serde_json::json!({
"error": {
"message": self.to_string(),
"type": error_type,
}
});
(status, axum::Json(body)).into_response()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::response::IntoResponse;

    /// Converts `err` into an HTTP response and returns only its status code.
    fn status_of(err: ServerError) -> StatusCode {
        let resp = err.into_response();
        resp.status()
    }

    /// Builds a throwaway I/O error for the variants that wrap one.
    fn io_error() -> std::io::Error {
        std::io::Error::from(std::io::ErrorKind::AddrInUse)
    }

    #[test]
    fn test_invalid_request_returns_400() {
        let err = ServerError::InvalidRequest {
            message: "bad param".to_string(),
        };
        assert_eq!(status_of(err), StatusCode::BAD_REQUEST);
    }

    #[test]
    fn test_model_not_ready_returns_503() {
        assert_eq!(
            status_of(ServerError::ModelNotReady),
            StatusCode::SERVICE_UNAVAILABLE
        );
    }

    #[test]
    fn test_queue_full_returns_429() {
        assert_eq!(
            status_of(ServerError::QueueFull),
            StatusCode::TOO_MANY_REQUESTS
        );
    }

    #[test]
    fn test_worker_dead_returns_503() {
        assert_eq!(
            status_of(ServerError::WorkerDead),
            StatusCode::SERVICE_UNAVAILABLE
        );
    }

    #[test]
    fn test_serialization_error_returns_500() {
        let json_err = serde_json::from_str::<serde_json::Value>("not json")
            .expect_err("parsing invalid JSON should fail");
        let err = ServerError::Serialization(json_err);
        assert_eq!(status_of(err), StatusCode::INTERNAL_SERVER_ERROR);
    }

    #[test]
    fn test_error_display_invalid_request() {
        let err = ServerError::InvalidRequest {
            message: "missing field".to_string(),
        };
        assert_eq!(err.to_string(), "invalid request: missing field");
    }

    #[test]
    fn test_error_display_model_not_ready() {
        assert_eq!(ServerError::ModelNotReady.to_string(), "model not ready");
    }

    #[test]
    fn test_error_display_queue_full() {
        assert_eq!(
            ServerError::QueueFull.to_string(),
            "inference queue is full — server overloaded"
        );
    }

    #[test]
    fn test_error_display_worker_dead() {
        assert_eq!(
            ServerError::WorkerDead.to_string(),
            "inference worker is no longer running"
        );
    }

    #[test]
    fn test_thread_not_found_returns_404() {
        assert_eq!(
            status_of(ServerError::ThreadNotFound("thread_xyz".into())),
            StatusCode::NOT_FOUND
        );
    }

    #[test]
    fn test_run_not_found_returns_404() {
        assert_eq!(
            status_of(ServerError::RunNotFound("run_xyz".into())),
            StatusCode::NOT_FOUND
        );
    }

    #[test]
    fn test_run_in_terminal_state_returns_409() {
        assert_eq!(
            status_of(ServerError::RunInTerminalState(
                "run_xyz is completed".into()
            )),
            StatusCode::CONFLICT
        );
    }

    #[test]
    fn test_other_not_found_variants_return_404() {
        let cases = [
            ServerError::FileNotFound("file_xyz".into()),
            ServerError::RunStepNotFound("step_xyz".into()),
            ServerError::ResponseNotFound("resp_xyz".into()),
            ServerError::PreviousResponseNotFound("resp_prev".into()),
        ];
        for err in cases {
            let label = err.to_string();
            assert_eq!(status_of(err), StatusCode::NOT_FOUND, "{label}");
        }
    }

    #[test]
    fn test_file_too_large_returns_413() {
        assert_eq!(
            status_of(ServerError::FileTooLarge("upload.bin".into())),
            StatusCode::PAYLOAD_TOO_LARGE
        );
    }

    #[test]
    fn test_internal_variants_return_500() {
        let cases = [
            ServerError::FileStoreError("disk full".into()),
            ServerError::IoError {
                context: "reading upload".into(),
                source: io_error(),
            },
            ServerError::BindError {
                addr: "127.0.0.1:8080".into(),
                source: io_error(),
            },
        ];
        for err in cases {
            let label = err.to_string();
            assert_eq!(status_of(err), StatusCode::INTERNAL_SERVER_ERROR, "{label}");
        }
    }

    #[test]
    fn test_error_display_includes_identifiers() {
        assert_eq!(
            ServerError::ResponseNotFound("resp_1".into()).to_string(),
            "response resp_1 not found"
        );
        assert_eq!(
            ServerError::PreviousResponseNotFound("resp_0".into()).to_string(),
            "previous response resp_0 not found"
        );
        assert_eq!(
            ServerError::FileTooLarge("upload.bin".into()).to_string(),
            "file too large: upload.bin"
        );
    }

    #[test]
    fn test_error_display_io_and_bind_include_source() {
        let source_msg = io_error().to_string();
        let io_msg = ServerError::IoError {
            context: "reading upload".into(),
            source: io_error(),
        }
        .to_string();
        assert_eq!(io_msg, format!("I/O error (reading upload): {source_msg}"));
        let bind_msg = ServerError::BindError {
            addr: "127.0.0.1:8080".into(),
            source: io_error(),
        }
        .to_string();
        assert_eq!(
            bind_msg,
            format!("failed to bind to 127.0.0.1:8080: {source_msg}")
        );
    }
}