use crate::array::Array;
use crate::error::NumRs2Error;
use std::fmt;
pub mod inference;
pub mod metrics;
pub mod optimization;
pub mod predict;
pub mod preprocessing;
pub mod registry;
pub use inference::*;
pub use metrics::*;
pub use optimization::*;
pub use predict::*;
pub use preprocessing::*;
pub use registry::*;
/// Result alias used by the serving module: the error type is fixed to [`ServingError`].
pub type Result<T> = std::result::Result<T, ServingError>;
/// Error type for the serving module.
///
/// Each variant carries the context that its `Display` implementation formats into a
/// single-line message. [`ServingError::NumRs2IntegrationError`] additionally wraps an
/// underlying [`NumRs2Error`], which is reported through `std::error::Error::source`.
#[derive(Debug, Clone)]
pub enum ServingError {
/// A model, or a specific version of it, was not found.
ModelNotFound {
/// Name of the model that was looked up.
model_name: String,
/// Version that was looked up; when `None` the message omits the version.
version: Option<String>,
},
/// A version of a model was rejected as invalid; `message` gives the reason.
InvalidVersion {
/// Name of the model the version belongs to.
model_name: String,
/// The offending version string.
version: String,
/// Explanation appended to the rendered message.
message: String,
},
/// Loading the named model failed.
ModelLoadError { model_name: String, message: String },
/// Inference with the named model failed.
InferenceError { model_name: String, message: String },
/// Validation of the named field failed.
ValidationError { field: String, message: String },
/// The named preprocessing stage failed.
PreprocessingError { stage: String, message: String },
/// A shape did not match what was expected (see `validate_shape`).
InvalidShape {
/// Expected extent per dimension; `validate_shape` accepts any extent where this is `None`.
expected: Vec<Option<usize>>,
/// The shape that was actually supplied.
actual: Vec<usize>,
},
/// A batch size differed from the expected one (see `validate_batch_size`).
BatchSizeMismatch { expected: usize, actual: usize },
/// A quantization step failed.
QuantizationError { message: String },
/// More memory was requested than was available; `Display` labels both numbers as bytes.
MemoryPoolExhausted { requested: usize, available: usize },
/// The named operation timed out; `Display` reports `timeout_ms` in milliseconds.
TimeoutError { operation: String, timeout_ms: u64 },
/// A concurrency-related failure.
ConcurrencyError { message: String },
/// A metrics-related failure.
MetricsError { message: String },
/// Wraps a [`NumRs2Error`]; produced by the `From<NumRs2Error>` conversion.
NumRs2IntegrationError { source: Box<NumRs2Error> },
/// Catch-all for failures not covered by another variant.
Other { message: String },
}
impl fmt::Display for ServingError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ServingError::ModelNotFound {
model_name,
version,
} => {
if let Some(v) = version {
write!(f, "Model '{}' version '{}' not found", model_name, v)
} else {
write!(f, "Model '{}' not found", model_name)
}
}
ServingError::InvalidVersion {
model_name,
version,
message,
} => {
write!(
f,
"Invalid version '{}' for model '{}': {}",
version, model_name, message
)
}
ServingError::ModelLoadError {
model_name,
message,
} => {
write!(f, "Failed to load model '{}': {}", model_name, message)
}
ServingError::InferenceError {
model_name,
message,
} => {
write!(f, "Inference error in model '{}': {}", model_name, message)
}
ServingError::ValidationError { field, message } => {
write!(f, "Validation error for field '{}': {}", field, message)
}
ServingError::PreprocessingError { stage, message } => {
write!(f, "Preprocessing error in stage '{}': {}", stage, message)
}
ServingError::InvalidShape { expected, actual } => {
write!(
f,
"Invalid shape: expected {:?}, got {:?}",
expected, actual
)
}
ServingError::BatchSizeMismatch { expected, actual } => {
write!(
f,
"Batch size mismatch: expected {}, got {}",
expected, actual
)
}
ServingError::QuantizationError { message } => {
write!(f, "Quantization error: {}", message)
}
ServingError::MemoryPoolExhausted {
requested,
available,
} => {
write!(
f,
"Memory pool exhausted: requested {} bytes, available {} bytes",
requested, available
)
}
ServingError::TimeoutError {
operation,
timeout_ms,
} => {
write!(
f,
"Operation '{}' timed out after {} ms",
operation, timeout_ms
)
}
ServingError::ConcurrencyError { message } => {
write!(f, "Concurrency error: {}", message)
}
ServingError::MetricsError { message } => {
write!(f, "Metrics error: {}", message)
}
ServingError::NumRs2IntegrationError { source } => {
write!(f, "NumRS2 integration error: {}", source)
}
ServingError::Other { message } => {
write!(f, "Serving error: {}", message)
}
}
}
}
impl std::error::Error for ServingError {
fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
match self {
ServingError::NumRs2IntegrationError { source } => Some(source),
_ => None,
}
}
}
impl From<NumRs2Error> for ServingError {
fn from(error: NumRs2Error) -> Self {
ServingError::NumRs2IntegrationError {
source: Box::new(error),
}
}
}
/// Checks that `actual` conforms to the `expected` shape specification.
///
/// The two slices must have the same number of dimensions. For each dimension,
/// `Some(n)` in `expected` requires the matching entry of `actual` to equal `n`,
/// while `None` accepts any extent.
///
/// # Errors
///
/// Returns [`ServingError::InvalidShape`], carrying copies of both slices, when the
/// number of dimensions differs or any fixed dimension does not match.
pub fn validate_shape(expected: &[Option<usize>], actual: &[usize]) -> Result<()> {
    // `zip` stops at the shorter input, so the rank must be compared explicitly first.
    let conforms = expected.len() == actual.len()
        && expected
            .iter()
            .zip(actual)
            .all(|(&want, &got)| want.map_or(true, |dim| dim == got));

    if conforms {
        Ok(())
    } else {
        Err(ServingError::InvalidShape {
            expected: expected.to_vec(),
            actual: actual.to_vec(),
        })
    }
}
/// Checks that `actual` equals the `expected` batch size.
///
/// # Errors
///
/// Returns [`ServingError::BatchSizeMismatch`] carrying both values when they differ.
pub fn validate_batch_size(expected: usize, actual: usize) -> Result<()> {
    if expected == actual {
        Ok(())
    } else {
        Err(ServingError::BatchSizeMismatch { expected, actual })
    }
}
#[cfg(test)]
mod module_tests {
    //! Unit tests for the shape/batch validators and the `ServingError` conversions.
    use super::*;

    #[test]
    fn test_validate_shape_exact_match() {
        let outcome = validate_shape(&[Some(2), Some(3)], &[2, 3]);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_shape_with_none() {
        // A `None` dimension accepts any extent.
        let outcome = validate_shape(&[None, Some(3)], &[5, 3]);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_shape_mismatch() {
        let outcome = validate_shape(&[Some(2), Some(3)], &[2, 4]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_shape_dimension_mismatch() {
        // Same leading extent, but the rank differs.
        let outcome = validate_shape(&[Some(2), Some(3)], &[2]);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_validate_batch_size_match() {
        let outcome = validate_batch_size(32, 32);
        assert!(outcome.is_ok());
    }

    #[test]
    fn test_validate_batch_size_mismatch() {
        let outcome = validate_batch_size(32, 16);
        assert!(outcome.is_err());
    }

    #[test]
    fn test_error_display() {
        let rendered = ServingError::ModelNotFound {
            model_name: "test_model".to_string(),
            version: Some("v1.0".to_string()),
        }
        .to_string();
        for needle in ["test_model", "v1.0"].iter() {
            assert!(rendered.contains(needle));
        }
    }

    #[test]
    fn test_error_from_numrs2() {
        let converted = ServingError::from(NumRs2Error::DimensionMismatch("test".to_string()));
        assert!(
            matches!(converted, ServingError::NumRs2IntegrationError { .. }),
            "Expected NumRs2IntegrationError"
        );
    }
}