use std::{sync::Arc, time::Instant};
#[cfg(feature = "aprender-serve")]
use aprender::{
classification::{GaussianNB, KNearestNeighbors, LinearSVM, LogisticRegression},
linear_model::LinearRegression,
primitives::Matrix,
tree::{DecisionTreeClassifier, GradientBoostingClassifier, RandomForestClassifier},
AprenderError, Estimator,
};
use axum::{
extract::State,
http::StatusCode,
routing::{get, post},
Json, Router,
};
use serde::{Deserialize, Serialize};
/// A trained aprender model ready for serving, one variant per supported
/// estimator type. Each model is behind an `Arc` so the enclosing state
/// stays cheap to clone on every request.
#[cfg(feature = "aprender-serve")]
#[derive(Clone)]
pub enum LoadedModel {
    LogisticRegression(Arc<LogisticRegression>),
    KNearestNeighbors(Arc<KNearestNeighbors>),
    GaussianNB(Arc<GaussianNB>),
    LinearSVM(Arc<LinearSVM>),
    DecisionTreeClassifier(Arc<DecisionTreeClassifier>),
    RandomForestClassifier(Arc<RandomForestClassifier>),
    GradientBoostingClassifier(Arc<GradientBoostingClassifier>),
    LinearRegression(Arc<LinearRegression>),
}
/// Shared state for the serving router. Cloned by axum for every request,
/// so the heavyweight members (model, counter) are reference-counted.
#[derive(Clone)]
pub struct ServeState {
    /// The loaded model, if any; `None` until one is installed.
    #[cfg(feature = "aprender-serve")]
    model: Option<LoadedModel>,
    /// Human-readable model identifier reported by `/ready` and `/models`.
    model_name: String,
    /// Version string echoed in every prediction response.
    model_version: String,
    /// Expected feature-vector length; 0 disables input validation.
    input_dim: usize,
    /// Monotone request counter exposed via `/metrics`.
    request_count: Arc<std::sync::atomic::AtomicU64>,
}
impl ServeState {
    /// Create a state with no model loaded: `has_model()` reports `false`
    /// and the predict endpoints answer 503 until a model is installed.
    #[must_use]
    pub fn new(model_name: String, model_version: String) -> Self {
        Self {
            #[cfg(feature = "aprender-serve")]
            model: None,
            model_name,
            model_version,
            input_dim: 0,
            request_count: Arc::new(std::sync::atomic::AtomicU64::new(0)),
        }
    }

    /// Wrap an already-trained logistic-regression model for serving.
    ///
    /// `input_dim` is the expected feature-vector length; pass 0 to skip
    /// request-dimension validation.
    #[cfg(feature = "aprender-serve")]
    #[must_use]
    pub fn with_logistic_regression(
        model: LogisticRegression,
        model_version: String,
        input_dim: usize,
    ) -> Self {
        Self {
            model: Some(LoadedModel::LogisticRegression(Arc::new(model))),
            model_name: "LogisticRegression".to_string(),
            model_version,
            input_dim,
            request_count: Arc::new(std::sync::atomic::AtomicU64::new(0)),
        }
    }

    /// Load a logistic-regression model from an `.apr` file on disk.
    ///
    /// # Errors
    /// Propagates any error from `aprender::format::load` (I/O failure or
    /// model-type mismatch).
    #[cfg(feature = "aprender-serve")]
    pub fn load_apr(
        path: impl AsRef<std::path::Path>,
        model_version: String,
        input_dim: usize,
    ) -> Result<Self, anyhow::Error> {
        use aprender::format::{load, ModelType};
        let model: LogisticRegression = load(path, ModelType::LogisticRegression)?;
        // Delegate to the shared constructor instead of duplicating the
        // field-by-field initialization.
        Ok(Self::with_logistic_regression(model, model_version, input_dim))
    }

    /// Load a logistic-regression model from in-memory `.apr` bytes.
    ///
    /// # Errors
    /// Propagates any error from `aprender::format::load_from_bytes`.
    #[cfg(feature = "aprender-serve")]
    pub fn load_apr_from_bytes(
        bytes: &[u8],
        model_version: String,
        input_dim: usize,
    ) -> Result<Self, anyhow::Error> {
        use aprender::format::{load_from_bytes, ModelType};
        let model: LogisticRegression = load_from_bytes(bytes, ModelType::LogisticRegression)?;
        Ok(Self::with_logistic_regression(model, model_version, input_dim))
    }

    /// Whether a model is available for inference. Always `false` when the
    /// `aprender-serve` feature is compiled out.
    #[must_use]
    pub fn has_model(&self) -> bool {
        #[cfg(feature = "aprender-serve")]
        {
            self.model.is_some()
        }
        #[cfg(not(feature = "aprender-serve"))]
        {
            false
        }
    }
}
/// Body of `GET /health`: process liveness plus crate version.
#[derive(Debug, Serialize, Deserialize)]
pub struct HealthResponse {
    /// Always "healthy" when the process can answer at all.
    pub status: String,
    /// Server crate version (from `CARGO_PKG_VERSION`).
    pub version: String,
}
/// Body of `GET /ready`: readiness mirrors whether a model is loaded.
#[derive(Debug, Serialize, Deserialize)]
pub struct ReadyResponse {
    /// True when the service can accept inference traffic.
    pub ready: bool,
    /// True when a model has been loaded into the state.
    pub model_loaded: bool,
    /// Name of the configured model.
    pub model_name: String,
}
/// Body of `POST /predict`.
#[derive(Debug, Serialize, Deserialize)]
pub struct PredictRequest {
    /// Optional model selector; currently not consumed by the handlers
    /// (only a single model can be loaded).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model_id: Option<String>,
    /// Flat feature vector; length must match the configured input
    /// dimension when that dimension is non-zero.
    pub features: Vec<f32>,
    /// Optional per-request inference options.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub options: Option<PredictOptions>,
}
/// Per-request inference options.
#[derive(Debug, Serialize, Deserialize)]
pub struct PredictOptions {
    /// When true, request class probabilities (honored only by model
    /// variants that expose `predict_proba`).
    #[serde(default)]
    pub return_probabilities: bool,
    /// Requested number of top classes; currently not consumed by the
    /// handlers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub top_k: Option<usize>,
}
/// Body of a successful `POST /predict` (also one batch entry).
#[derive(Debug, Serialize, Deserialize)]
pub struct PredictResponse {
    /// Predicted class label cast to f32, or the regression value.
    pub prediction: f32,
    /// Class probabilities; present only when requested and supported.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub probabilities: Option<Vec<f32>>,
    /// Server-side inference latency in milliseconds.
    pub latency_ms: f64,
    /// Version of the model that produced this prediction.
    pub model_version: String,
}
/// Body of `POST /predict/batch`.
#[derive(Debug, Serialize, Deserialize)]
pub struct BatchPredictRequest {
    /// Optional model selector; currently not consumed by the handlers.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub model_id: Option<String>,
    /// Instances to score; must be non-empty.
    pub instances: Vec<PredictInstance>,
}
/// A single instance within a batch prediction request.
#[derive(Debug, Serialize, Deserialize)]
pub struct PredictInstance {
    /// Flat feature vector for this instance.
    pub features: Vec<f32>,
}
/// Body of a successful `POST /predict/batch`.
#[derive(Debug, Serialize, Deserialize)]
pub struct BatchPredictResponse {
    /// Per-instance results, in request order.
    pub predictions: Vec<PredictResponse>,
    /// Wall-clock time for the whole batch, in milliseconds.
    pub total_latency_ms: f64,
}
/// Body of `GET /models`: the (at most one) loaded model.
#[derive(Debug, Serialize, Deserialize)]
pub struct ModelsResponse {
    /// Empty when no model is loaded.
    pub models: Vec<ModelInfo>,
}
/// Metadata describing one served model.
#[derive(Debug, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Stable identifier; the single loaded model is reported as "default".
    pub id: String,
    /// Model type name, e.g. "LogisticRegression".
    pub model_type: String,
    /// Model version string.
    pub version: String,
    /// Whether the model is currently loaded and servable.
    pub loaded: bool,
}
/// JSON error body returned alongside a non-2xx status.
#[derive(Debug, Serialize, Deserialize)]
pub struct ErrorResponse {
    /// Human-readable description of the failure.
    pub error: String,
    /// Machine-readable error code, e.g. "E_NO_MODEL".
    #[serde(skip_serializing_if = "Option::is_none")]
    pub code: Option<String>,
}
/// Assemble the inference HTTP router: health/readiness probes, single and
/// batch prediction, model listing, and Prometheus-style metrics, all
/// sharing the given `ServeState`.
pub fn create_serve_router(state: ServeState) -> Router {
    let routes = Router::new()
        .route("/health", get(health_handler))
        .route("/ready", get(ready_handler))
        .route("/predict", post(predict_handler))
        .route("/predict/batch", post(batch_predict_handler))
        .route("/models", get(models_handler))
        .route("/metrics", get(metrics_handler));
    routes.with_state(state)
}
/// GET /health — liveness probe; always reports "healthy" plus the crate
/// version compiled into the binary.
async fn health_handler() -> Json<HealthResponse> {
    let version = env!("CARGO_PKG_VERSION").to_string();
    let body = HealthResponse {
        status: "healthy".to_string(),
        version,
    };
    Json(body)
}
/// GET /ready — readiness probe; ready exactly when a model is loaded.
async fn ready_handler(State(state): State<ServeState>) -> Json<ReadyResponse> {
    let loaded = state.has_model();
    let body = ReadyResponse {
        model_loaded: loaded,
        ready: loaded,
        model_name: state.model_name,
    };
    Json(body)
}
/// POST /predict — single-instance inference against the loaded model.
///
/// Flow: bump the request counter, 503 if no model is loaded, 400 on a
/// feature-length mismatch (when `input_dim` is known), then dispatch on
/// the concrete model variant. Probabilities are attached only when the
/// client set `options.return_probabilities` and the variant supports them.
#[cfg(feature = "aprender-serve")]
async fn predict_handler(
    State(state): State<ServeState>,
    Json(payload): Json<PredictRequest>,
) -> Result<Json<PredictResponse>, (StatusCode, Json<ErrorResponse>)> {
    use std::sync::atomic::Ordering;
    // Metrics-only counter: Relaxed ordering is sufficient.
    state.request_count.fetch_add(1, Ordering::Relaxed);
    let Some(model) = &state.model else {
        return Err((
            StatusCode::SERVICE_UNAVAILABLE,
            Json(ErrorResponse {
                error: "No model loaded".to_string(),
                code: Some("E_NO_MODEL".to_string()),
            }),
        ));
    };
    // input_dim == 0 means "unknown": dimension validation is skipped.
    if state.input_dim > 0 && payload.features.len() != state.input_dim {
        return Err((
            StatusCode::BAD_REQUEST,
            Json(ErrorResponse {
                error: format!(
                    "Invalid input dimension: expected {}, got {}",
                    state.input_dim,
                    payload.features.len()
                ),
                code: Some("E_INVALID_INPUT".to_string()),
            }),
        ));
    }
    // Latency is measured from here, so it includes matrix construction.
    let start = Instant::now();
    let n_features = payload.features.len();
    // Single-row matrix: one instance per request on this endpoint.
    let input = Matrix::from_vec(1, n_features, payload.features.clone()).map_err(|e| {
        (
            StatusCode::BAD_REQUEST,
            Json(ErrorResponse {
                error: format!("Failed to create input matrix: {e}"),
                code: Some("E_MATRIX_ERROR".to_string()),
            }),
        )
    })?;
    let return_probs = payload
        .options
        .as_ref()
        .is_some_and(|o| o.return_probabilities);
    // Shared conversion from an aprender inference error to a 500 response.
    let map_err = |e: aprender::AprenderError| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse {
                error: format!("Model inference error: {e}"),
                code: Some("E_INFERENCE_ERROR".to_string()),
            }),
        )
    };
    // Dispatch on the model variant; each arm yields (prediction, probs).
    let (prediction, probabilities) = match model {
        LoadedModel::LogisticRegression(lr) => {
            let predictions = lr.predict(&input);
            #[allow(clippy::cast_precision_loss)]
            let pred = predictions.first().copied().unwrap_or(0) as f32;
            let probs = if return_probs {
                let prob_vec = lr.predict_proba(&input);
                // NOTE(review): assumes predict_proba yields P(class 1) for
                // the single row — confirm against aprender's API docs.
                let p1 = prob_vec.as_slice().first().copied().unwrap_or(0.5);
                Some(vec![1.0 - p1, p1])
            } else {
                None
            };
            (pred, probs)
        },
        // KNN exposes no probability API here; label only.
        LoadedModel::KNearestNeighbors(knn) => {
            let predictions = knn.predict(&input).map_err(map_err)?;
            #[allow(clippy::cast_precision_loss)]
            let pred = predictions.first().copied().unwrap_or(0) as f32;
            (pred, None)
        },
        LoadedModel::GaussianNB(nb) => {
            let predictions = nb.predict(&input).map_err(map_err)?;
            #[allow(clippy::cast_precision_loss)]
            let pred = predictions.first().copied().unwrap_or(0) as f32;
            let probs = if return_probs {
                // predict_proba returns one probability vector per row;
                // take the (only) first row.
                let prob_vecs = nb.predict_proba(&input).map_err(map_err)?;
                prob_vecs.first().cloned()
            } else {
                None
            };
            (pred, probs)
        },
        LoadedModel::LinearSVM(svm) => {
            let predictions = svm.predict(&input).map_err(map_err)?;
            #[allow(clippy::cast_precision_loss)]
            let pred = predictions.first().copied().unwrap_or(0) as f32;
            (pred, None)
        },
        LoadedModel::DecisionTreeClassifier(dt) => {
            let predictions = dt.predict(&input);
            #[allow(clippy::cast_precision_loss)]
            let pred = predictions.first().copied().unwrap_or(0) as f32;
            (pred, None)
        },
        LoadedModel::RandomForestClassifier(rf) => {
            let predictions = rf.predict(&input);
            #[allow(clippy::cast_precision_loss)]
            let pred = predictions.first().copied().unwrap_or(0) as f32;
            let probs = if return_probs {
                // Probabilities come back as a (rows x classes) matrix;
                // copy out row 0.
                let prob_matrix = rf.predict_proba(&input);
                let n_classes = prob_matrix.n_cols();
                let mut probs_vec = Vec::with_capacity(n_classes);
                for j in 0..n_classes {
                    probs_vec.push(prob_matrix.get(0, j));
                }
                Some(probs_vec)
            } else {
                None
            };
            (pred, probs)
        },
        LoadedModel::GradientBoostingClassifier(gb) => {
            let predictions = gb.predict(&input).map_err(map_err)?;
            #[allow(clippy::cast_precision_loss)]
            let pred = predictions.first().copied().unwrap_or(0) as f32;
            let probs = if return_probs {
                let prob_vecs = gb.predict_proba(&input).map_err(map_err)?;
                prob_vecs.first().cloned()
            } else {
                None
            };
            (pred, probs)
        },
        // Regression output is already f32; no probability concept.
        LoadedModel::LinearRegression(lr) => {
            let predictions = lr.predict(&input);
            let pred = predictions.as_slice().first().copied().unwrap_or(0.0);
            (pred, None)
        },
    };
    let latency_ms = start.elapsed().as_secs_f64() * 1000.0;
    Ok(Json(PredictResponse {
        prediction,
        probabilities,
        latency_ms,
        model_version: state.model_version.clone(),
    }))
}
/// POST /predict stub used when the `aprender-serve` feature is compiled
/// out; always answers 501 Not Implemented.
#[cfg(not(feature = "aprender-serve"))]
async fn predict_handler(
    State(_state): State<ServeState>,
    Json(_payload): Json<PredictRequest>,
) -> Result<Json<PredictResponse>, (StatusCode, Json<ErrorResponse>)> {
    let body = ErrorResponse {
        error: "aprender-serve feature not enabled".to_string(),
        code: Some("E_NOT_IMPLEMENTED".to_string()),
    };
    Err((StatusCode::NOT_IMPLEMENTED, Json(body)))
}
/// POST /predict/batch — score a list of instances sequentially.
///
/// Rejects an empty batch with 400 and a missing model with 503. The first
/// invalid instance aborts the whole batch; on success every instance gets
/// its own latency, plus a total batch latency.
#[cfg(feature = "aprender-serve")]
async fn batch_predict_handler(
    State(state): State<ServeState>,
    Json(payload): Json<BatchPredictRequest>,
) -> Result<Json<BatchPredictResponse>, (StatusCode, Json<ErrorResponse>)> {
    use std::sync::atomic::Ordering;
    let Some(model) = &state.model else {
        return Err((
            StatusCode::SERVICE_UNAVAILABLE,
            Json(ErrorResponse {
                error: "No model loaded".to_string(),
                code: Some("E_NO_MODEL".to_string()),
            }),
        ));
    };
    if payload.instances.is_empty() {
        return Err((
            StatusCode::BAD_REQUEST,
            Json(ErrorResponse {
                error: "Empty batch".to_string(),
                code: Some("E_EMPTY_BATCH".to_string()),
            }),
        ));
    }
    let batch_start = Instant::now();
    let mut predictions = Vec::with_capacity(payload.instances.len());
    // Count all instances up front; a mid-batch failure still counts them.
    state
        .request_count
        .fetch_add(payload.instances.len() as u64, Ordering::Relaxed);
    for instance in &payload.instances {
        // input_dim == 0 means "unknown": dimension validation is skipped.
        if state.input_dim > 0 && instance.features.len() != state.input_dim {
            return Err((
                StatusCode::BAD_REQUEST,
                Json(ErrorResponse {
                    error: format!(
                        "Invalid input dimension: expected {}, got {}",
                        state.input_dim,
                        instance.features.len()
                    ),
                    code: Some("E_INVALID_INPUT".to_string()),
                }),
            ));
        }
        let start = Instant::now();
        let n_features = instance.features.len();
        // One single-row matrix per instance; instances are not fused into
        // one multi-row predict call.
        let input = Matrix::from_vec(1, n_features, instance.features.clone()).map_err(|e| {
            (
                StatusCode::BAD_REQUEST,
                Json(ErrorResponse {
                    error: format!("Failed to create input matrix: {e}"),
                    code: Some("E_MATRIX_ERROR".to_string()),
                }),
            )
        })?;
        // Maps an aprender inference error to a 500 response.
        let map_err = |e: AprenderError| {
            (
                StatusCode::INTERNAL_SERVER_ERROR,
                Json(ErrorResponse {
                    error: format!("Model inference error: {e}"),
                    code: Some("E_INFERENCE_ERROR".to_string()),
                }),
            )
        };
        // Unlike the single-predict endpoint, the batch path never returns
        // probabilities (every arm yields None).
        let (prediction, probabilities) = match model {
            LoadedModel::LogisticRegression(lr) => {
                let preds = lr.predict(&input);
                #[allow(clippy::cast_precision_loss)]
                let pred = preds.first().copied().unwrap_or(0) as f32;
                (pred, None)
            },
            LoadedModel::KNearestNeighbors(knn) => {
                let preds = knn.predict(&input).map_err(map_err)?;
                #[allow(clippy::cast_precision_loss)]
                let pred = preds.first().copied().unwrap_or(0) as f32;
                (pred, None)
            },
            LoadedModel::GaussianNB(nb) => {
                let preds = nb.predict(&input).map_err(map_err)?;
                #[allow(clippy::cast_precision_loss)]
                let pred = preds.first().copied().unwrap_or(0) as f32;
                (pred, None)
            },
            LoadedModel::LinearSVM(svm) => {
                let preds = svm.predict(&input).map_err(map_err)?;
                #[allow(clippy::cast_precision_loss)]
                let pred = preds.first().copied().unwrap_or(0) as f32;
                (pred, None)
            },
            LoadedModel::DecisionTreeClassifier(dt) => {
                let preds = dt.predict(&input);
                #[allow(clippy::cast_precision_loss)]
                let pred = preds.first().copied().unwrap_or(0) as f32;
                (pred, None)
            },
            LoadedModel::RandomForestClassifier(rf) => {
                let preds = rf.predict(&input);
                #[allow(clippy::cast_precision_loss)]
                let pred = preds.first().copied().unwrap_or(0) as f32;
                (pred, None)
            },
            LoadedModel::GradientBoostingClassifier(gb) => {
                let preds = gb.predict(&input).map_err(map_err)?;
                #[allow(clippy::cast_precision_loss)]
                let pred = preds.first().copied().unwrap_or(0) as f32;
                (pred, None)
            },
            // Regression output is already f32.
            LoadedModel::LinearRegression(lr) => {
                let preds = lr.predict(&input);
                let pred = preds.as_slice().first().copied().unwrap_or(0.0);
                (pred, None)
            },
        };
        let latency_ms = start.elapsed().as_secs_f64() * 1000.0;
        predictions.push(PredictResponse {
            prediction,
            probabilities,
            latency_ms,
            model_version: state.model_version.clone(),
        });
    }
    let total_latency_ms = batch_start.elapsed().as_secs_f64() * 1000.0;
    Ok(Json(BatchPredictResponse {
        predictions,
        total_latency_ms,
    }))
}
/// POST /predict/batch stub used when the `aprender-serve` feature is
/// compiled out; always answers 501 Not Implemented.
#[cfg(not(feature = "aprender-serve"))]
async fn batch_predict_handler(
    State(_state): State<ServeState>,
    Json(_payload): Json<BatchPredictRequest>,
) -> Result<Json<BatchPredictResponse>, (StatusCode, Json<ErrorResponse>)> {
    let body = ErrorResponse {
        error: "aprender-serve feature not enabled".to_string(),
        code: Some("E_NOT_IMPLEMENTED".to_string()),
    };
    Err((StatusCode::NOT_IMPLEMENTED, Json(body)))
}
/// GET /models — list the (at most one) loaded model; empty when none.
async fn models_handler(
    State(state): State<ServeState>,
) -> Result<Json<ModelsResponse>, (StatusCode, Json<ErrorResponse>)> {
    let mut models = Vec::new();
    if state.has_model() {
        // A single loaded model is always exposed under the id "default".
        models.push(ModelInfo {
            id: "default".to_string(),
            model_type: state.model_name.clone(),
            version: state.model_version,
            loaded: true,
        });
    }
    Ok(Json(ModelsResponse { models }))
}
/// GET /metrics — Prometheus text exposition: request count, model-loaded
/// flag, and the configured input dimension.
async fn metrics_handler(State(state): State<ServeState>) -> String {
    use std::sync::atomic::Ordering;
    let requests = state.request_count.load(Ordering::Relaxed);
    let loaded = i32::from(state.has_model());
    // Build the exposition line by line; each metric gets HELP/TYPE/value.
    let mut out = String::new();
    out.push_str("# HELP requests_total Total number of inference requests\n");
    out.push_str("# TYPE requests_total counter\n");
    out.push_str(&format!("requests_total {requests}\n"));
    out.push_str("# HELP model_loaded Whether a model is loaded (1=yes, 0=no)\n");
    out.push_str("# TYPE model_loaded gauge\n");
    out.push_str(&format!("model_loaded {loaded}\n"));
    out.push_str("# HELP input_dimension Expected input feature dimension\n");
    out.push_str("# TYPE input_dimension gauge\n");
    out.push_str(&format!("input_dimension {}\n", state.input_dim));
    out
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_serve_state_creation() {
        let state = ServeState::new("test-model".to_string(), "v1.0".to_string());
        assert_eq!(state.model_name, "test-model");
        assert_eq!(state.model_version, "v1.0");
    }

    #[test]
    fn test_predict_request_serialization() {
        let request = PredictRequest {
            model_id: Some("sentiment-v1".to_string()),
            features: vec![0.5, 1.2, -0.3, 0.8],
            options: Some(PredictOptions {
                return_probabilities: true,
                top_k: Some(3),
            }),
        };
        let json = serde_json::to_string(&request).expect("serialization failed");
        assert!(json.contains("sentiment-v1"));
        assert!(json.contains("0.5"));
        assert!(json.contains("return_probabilities"));
    }

    #[test]
    fn test_predict_response_serialization() {
        let response = PredictResponse {
            prediction: 1.0,
            probabilities: Some(vec![0.12, 0.85, 0.03]),
            latency_ms: 2.3,
            model_version: "v1.2.0".to_string(),
        };
        let json = serde_json::to_string(&response).expect("serialization failed");
        assert!(json.contains("v1.2.0"));
        assert!(json.contains("2.3"));
        assert!(json.contains("0.12"));
    }

    #[test]
    fn test_batch_predict_request_serialization() {
        let request = BatchPredictRequest {
            model_id: Some("model-v1".to_string()),
            instances: vec![
                PredictInstance {
                    features: vec![0.5, 1.2],
                },
                PredictInstance {
                    features: vec![0.1, 0.9],
                },
            ],
        };
        let json = serde_json::to_string(&request).expect("serialization failed");
        assert!(json.contains("model-v1"));
        assert!(json.contains("instances"));
        assert!(json.contains("0.5"));
        assert!(json.contains("0.9"));
    }

    #[test]
    fn test_health_response_format() {
        let response = HealthResponse {
            status: "healthy".to_string(),
            version: "0.2.0".to_string(),
        };
        let json = serde_json::to_string(&response).expect("serialization failed");
        assert!(json.contains("healthy"));
        assert!(json.contains("0.2.0"));
    }

    #[test]
    fn test_ready_response_format() {
        let response = ReadyResponse {
            ready: true,
            model_loaded: true,
            model_name: "test-model".to_string(),
        };
        let json = serde_json::to_string(&response).expect("serialization failed");
        assert!(json.contains("true"));
        assert!(json.contains("test-model"));
    }

    #[test]
    fn test_error_response_format() {
        let response = ErrorResponse {
            error: "Model not found".to_string(),
            code: Some("E404".to_string()),
        };
        let json = serde_json::to_string(&response).expect("serialization failed");
        assert!(json.contains("Model not found"));
        assert!(json.contains("E404"));
    }

    #[tokio::test]
    async fn test_health_handler() {
        let response = health_handler().await;
        assert_eq!(response.0.status, "healthy");
        assert!(!response.0.version.is_empty());
    }

    #[tokio::test]
    async fn test_ready_handler_no_model() {
        let state = ServeState::new("test-model".to_string(), "v1.0".to_string());
        let response = ready_handler(State(state)).await;
        assert!(!response.0.ready);
        assert!(!response.0.model_loaded);
        assert_eq!(response.0.model_name, "test-model");
    }

    #[test]
    fn test_serve_state_has_model() {
        let state = ServeState::new("test".to_string(), "v1".to_string());
        assert!(!state.has_model());
    }

    #[test]
    fn test_models_info_serialization() {
        let info = ModelInfo {
            id: "mnist-v1".to_string(),
            model_type: "LogisticRegression".to_string(),
            version: "1.0.0".to_string(),
            loaded: true,
        };
        let json = serde_json::to_string(&info).expect("serialization failed");
        assert!(json.contains("mnist-v1"));
        assert!(json.contains("LogisticRegression"));
    }

    #[cfg(feature = "aprender-serve")]
    #[tokio::test]
    async fn test_predict_with_loaded_model() {
        // Linearly separable training data: the label equals the first
        // feature of each row.
        let x = Matrix::from_vec(4, 2, vec![0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0])
            .expect("4x2 matrix");
        let y = vec![0, 0, 1, 1];
        let mut model = LogisticRegression::new()
            .with_learning_rate(0.1)
            .with_max_iter(100);
        model.fit(&x, &y).expect("Training should succeed");
        let state = ServeState::with_logistic_regression(model, "test-v1".to_string(), 2);
        assert!(state.has_model());
        let request = PredictRequest {
            model_id: None,
            features: vec![0.9, 0.9],
            options: Some(PredictOptions {
                return_probabilities: true,
                top_k: None,
            }),
        };
        let result = predict_handler(State(state.clone()), Json(request)).await;
        let response = result.expect("Prediction should succeed");
        assert_eq!(response.prediction, 1.0);
        assert!(response.probabilities.is_some());
        assert!(response.latency_ms < 10.0);
        assert_eq!(response.model_version, "test-v1");
        let request_0 = PredictRequest {
            model_id: None,
            features: vec![0.0, 0.0],
            options: None,
        };
        let result_0 = predict_handler(State(state), Json(request_0)).await;
        let response_0 = result_0.expect("Prediction should succeed");
        assert_eq!(response_0.prediction, 0.0);
    }

    #[cfg(feature = "aprender-serve")]
    #[tokio::test]
    async fn test_batch_predict_with_loaded_model() {
        let x = Matrix::from_vec(4, 2, vec![0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0])
            .expect("4x2 matrix");
        let y = vec![0, 0, 1, 1];
        let mut model = LogisticRegression::new()
            .with_learning_rate(0.1)
            .with_max_iter(100);
        model.fit(&x, &y).expect("Training should succeed");
        let state = ServeState::with_logistic_regression(model, "batch-v1".to_string(), 2);
        let request = BatchPredictRequest {
            model_id: None,
            instances: vec![
                PredictInstance {
                    features: vec![0.0, 0.0],
                },
                PredictInstance {
                    features: vec![1.0, 1.0],
                },
            ],
        };
        let result = batch_predict_handler(State(state), Json(request)).await;
        let response = result.expect("Batch prediction should succeed");
        assert_eq!(response.predictions.len(), 2);
        assert_eq!(response.predictions[0].prediction, 0.0);
        assert_eq!(response.predictions[1].prediction, 1.0);
        assert!(response.total_latency_ms < 10.0);
    }

    #[cfg(feature = "aprender-serve")]
    #[tokio::test]
    async fn test_predict_invalid_dimensions() {
        let x = Matrix::from_vec(4, 2, vec![0.0, 0.0, 0.0, 1.0, 1.0, 0.0, 1.0, 1.0])
            .expect("4x2 matrix");
        let y = vec![0, 0, 1, 1];
        let mut model = LogisticRegression::new();
        model.fit(&x, &y).expect("Training should succeed");
        let state = ServeState::with_logistic_regression(model, "v1".to_string(), 2);
        // Three features against an input_dim of 2 must be rejected.
        let request = PredictRequest {
            model_id: None,
            features: vec![1.0, 2.0, 3.0],
            options: None,
        };
        let result = predict_handler(State(state), Json(request)).await;
        assert!(result.is_err());
        let (status, error) = result.unwrap_err();
        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert!(error.error.contains("Invalid input dimension"));
    }
}