use std::path::Path;
use thiserror::Error;
/// Errors produced while validating configuration values.
///
/// The provided helper methods on `ConfigValidator` in this module return
/// these variants. Each `#[error]` string is the `Display` text generated by
/// `thiserror`.
#[derive(Error, Debug)]
pub enum ConfigError {
/// A batch size of 0 was supplied.
#[error("batch size must be greater than 0")]
InvalidBatchSize,
/// A batch size larger than the allowed maximum (`max`) was supplied.
#[error(
"batch size {actual} exceeds maximum allowed {max}; consider reducing batch size or increasing memory limits"
)]
BatchSizeTooLarge { actual: usize, max: usize },
/// The model path was reported as non-existent by `Path::exists`.
///
/// `Path::exists` also returns `false` when the path cannot be inspected,
/// for example because of a permission error.
#[error("model path does not exist: {path}; ensure the model file has been downloaded")]
ModelPathNotFound { path: std::path::PathBuf },
/// The model path exists but `Path::is_file` is false for it (e.g. a directory).
#[error("model path '{path}' is not a file; expected a .onnx model file")]
ModelPathNotFile { path: std::path::PathBuf },
/// A field holds a value other than what was expected.
///
/// `suggestion` is appended immediately after `actual` with no separator.
/// Callers must include any leading space or punctuation themselves, or pass
/// an empty string when there is no suggestion.
#[error("invalid value for field '{field}': expected {expected}, got {actual}{suggestion}")]
InvalidFieldValue {
field: String,
expected: String,
actual: String,
suggestion: String,
},
/// A required field is absent.
///
/// `suggestion` is appended immediately after the quoted field name with no
/// separator (pass an empty string for none).
#[error("missing required configuration field '{field}'{suggestion}")]
MissingRequiredField { field: String, suggestion: String },
/// General invalid-configuration error whose details are in `message`.
///
/// This is the variant used by most of the single-value helpers in this module.
#[error("invalid configuration: {message}")]
InvalidConfig { message: String },
/// Validation failure whose details are in `message`.
#[error("validation failed: {message}")]
ValidationFailed { message: String },
/// A resource setting is above its allowed maximum.
///
/// In this module, it is raised for the memory limit and the thread count.
#[error("resource limit exceeded: {message}")]
ResourceLimitExceeded { message: String },
/// Dependency-related error whose details are in `message`.
///
/// None of the helpers shown in this module raise it.
#[error("dependency error: {message}")]
DependencyError { message: String },
/// Type-mismatch error whose details are in `message`.
///
/// None of the helpers shown in this module raise it.
#[error("type mismatch: {message}")]
TypeMismatch { message: String },
}
pub trait ConfigDefaults: Sized {
fn defaults() -> Self;
}
impl<T: ConfigValidator> ConfigDefaults for T {
fn defaults() -> Self {
T::get_defaults()
}
}
/// Validation contract for configuration types, plus a toolbox of
/// single-value checks that implementors can call from [`validate`].
///
/// Only [`validate`] and [`get_defaults`] must be implemented. Every other
/// method is a provided helper that checks one value and reports problems as
/// a [`ConfigError`].
///
/// [`validate`]: ConfigValidator::validate
/// [`get_defaults`]: ConfigValidator::get_defaults
pub trait ConfigValidator {
    /// Checks the whole configuration.
    ///
    /// # Errors
    ///
    /// Returns a [`ConfigError`] describing what is wrong.
    fn validate(&self) -> Result<(), ConfigError>;

    /// Returns the default instance of this configuration type.
    fn get_defaults() -> Self
    where
        Self: Sized;

    /// Checks that `batch_size` is non-zero.
    ///
    /// # Errors
    ///
    /// [`ConfigError::InvalidBatchSize`] if `batch_size == 0`.
    fn validate_batch_size(&self, batch_size: usize) -> Result<(), ConfigError> {
        if batch_size == 0 {
            Err(ConfigError::InvalidBatchSize)
        } else {
            Ok(())
        }
    }

    /// Checks that `batch_size` lies in `1..=max_batch_size`.
    ///
    /// # Errors
    ///
    /// * [`ConfigError::InvalidBatchSize`] if `batch_size == 0`.
    /// * [`ConfigError::BatchSizeTooLarge`] if `batch_size > max_batch_size`.
    fn validate_batch_size_with_limits(
        &self,
        batch_size: usize,
        max_batch_size: usize,
    ) -> Result<(), ConfigError> {
        // Deliberately inlined rather than calling `validate_batch_size`, so an
        // implementor overriding that method does not change this check.
        if batch_size == 0 {
            return Err(ConfigError::InvalidBatchSize);
        }
        if batch_size > max_batch_size {
            return Err(ConfigError::BatchSizeTooLarge {
                actual: batch_size,
                max: max_batch_size,
            });
        }
        Ok(())
    }

    /// Checks that `path` exists and is a regular file.
    ///
    /// [`Path::exists`] also returns `false` when the path cannot be inspected
    /// (e.g. a permission error), so such paths are reported as not found.
    ///
    /// # Errors
    ///
    /// * [`ConfigError::ModelPathNotFound`] if `path.exists()` is false.
    /// * [`ConfigError::ModelPathNotFile`] if the path exists but is not a file.
    fn validate_model_path(&self, path: &Path) -> Result<(), ConfigError> {
        if !path.exists() {
            Err(ConfigError::ModelPathNotFound {
                path: path.to_path_buf(),
            })
        } else if !path.is_file() {
            Err(ConfigError::ModelPathNotFile {
                path: path.to_path_buf(),
            })
        } else {
            Ok(())
        }
    }

    /// Checks that both image dimensions are non-zero.
    ///
    /// # Errors
    ///
    /// [`ConfigError::InvalidConfig`] if `width` or `height` is 0.
    fn validate_image_dimensions(&self, width: u32, height: u32) -> Result<(), ConfigError> {
        if width == 0 || height == 0 {
            Err(ConfigError::InvalidConfig {
                message: "Image dimensions must be positive".to_string(),
            })
        } else {
            Ok(())
        }
    }

    /// Checks that `threshold` lies in `0.0..=1.0`.
    ///
    /// `NaN` is rejected because `RangeInclusive::contains` is false for it.
    ///
    /// # Errors
    ///
    /// [`ConfigError::InvalidConfig`] if `threshold` is outside the range or `NaN`.
    fn validate_confidence_threshold(&self, threshold: f32) -> Result<(), ConfigError> {
        if !(0.0..=1.0).contains(&threshold) {
            Err(ConfigError::InvalidConfig {
                message: format!(
                    "Confidence threshold must be between 0.0 and 1.0, got {}",
                    threshold
                ),
            })
        } else {
            Ok(())
        }
    }

    /// Checks that `limit_mb` does not exceed 32 GiB (expressed in MB).
    ///
    /// Only an upper bound is enforced; a limit of 0 is accepted.
    ///
    /// # Errors
    ///
    /// [`ConfigError::ResourceLimitExceeded`] if `limit_mb > 32 * 1024`.
    fn validate_memory_limit(&self, limit_mb: usize) -> Result<(), ConfigError> {
        const MAX_REASONABLE_MEMORY_MB: usize = 32 * 1024;
        if limit_mb > MAX_REASONABLE_MEMORY_MB {
            Err(ConfigError::ResourceLimitExceeded {
                message: format!(
                    "Memory limit {} MB exceeds reasonable maximum of {} MB",
                    limit_mb, MAX_REASONABLE_MEMORY_MB
                ),
            })
        } else {
            Ok(())
        }
    }

    /// Checks that `thread_count` lies in `1..=256`.
    ///
    /// # Errors
    ///
    /// * [`ConfigError::InvalidConfig`] if `thread_count == 0`.
    /// * [`ConfigError::ResourceLimitExceeded`] if `thread_count > 256`.
    fn validate_thread_count(&self, thread_count: usize) -> Result<(), ConfigError> {
        const MAX_REASONABLE_THREADS: usize = 256;
        if thread_count == 0 {
            Err(ConfigError::InvalidConfig {
                message: "Thread count must be greater than 0".to_string(),
            })
        } else if thread_count > MAX_REASONABLE_THREADS {
            Err(ConfigError::ResourceLimitExceeded {
                message: format!(
                    "Thread count {} exceeds reasonable maximum of {}",
                    thread_count, MAX_REASONABLE_THREADS
                ),
            })
        } else {
            Ok(())
        }
    }

    /// Checks that `value` lies in the inclusive range `min..=max`.
    ///
    /// `NaN` (for `value` or either bound) is rejected.
    ///
    /// # Errors
    ///
    /// [`ConfigError::InvalidConfig`] naming `field_name` if the check fails.
    fn validate_f32_range(
        &self,
        value: f32,
        min: f32,
        max: f32,
        field_name: &str,
    ) -> Result<(), ConfigError> {
        // Positive membership test: every comparison with NaN is false, so
        // `value < min || value > max` used to let NaN through.
        if (min..=max).contains(&value) {
            Ok(())
        } else {
            Err(ConfigError::InvalidConfig {
                message: format!(
                    "{} must be between {} and {}, got {}",
                    field_name, min, max, value
                ),
            })
        }
    }

    /// Checks that `value` is strictly greater than zero.
    ///
    /// `NaN` is rejected.
    ///
    /// # Errors
    ///
    /// [`ConfigError::InvalidConfig`] naming `field_name` if `value` is not `> 0`.
    fn validate_positive_f32(&self, value: f32, field_name: &str) -> Result<(), ConfigError> {
        // `value <= 0.0` is false for NaN, so test the accepting condition instead.
        if value > 0.0 {
            Ok(())
        } else {
            Err(ConfigError::InvalidConfig {
                message: format!("{} must be greater than 0, got {}", field_name, value),
            })
        }
    }

    /// Checks that `value` is non-zero.
    ///
    /// # Errors
    ///
    /// [`ConfigError::InvalidConfig`] naming `field_name` if `value == 0`.
    fn validate_positive_usize(&self, value: usize, field_name: &str) -> Result<(), ConfigError> {
        if value == 0 {
            Err(ConfigError::InvalidConfig {
                message: format!("{} must be greater than 0, got {}", field_name, value),
            })
        } else {
            Ok(())
        }
    }
}
/// A no-op validator: [`ConfigValidator::validate`] always succeeds.
///
/// Derives the usual traits for a public unit struct so it can be debug-printed,
/// copied, and built with `Default::default()`.
#[derive(Debug, Clone, Copy, Default)]
pub struct DefaultValidator;

impl ConfigValidator for DefaultValidator {
    /// Always returns `Ok(())`; there is no state to validate.
    fn validate(&self) -> Result<(), ConfigError> {
        Ok(())
    }

    /// Returns the (only) `DefaultValidator` value.
    fn get_defaults() -> Self {
        Self
    }
}
/// Adapters that run [`ConfigValidator::validate`] and hand the value back on
/// success, converting any [`ConfigError`] into a caller-facing error type.
pub trait ConfigValidatorExt: ConfigValidator {
    /// Validates `self` and returns it unchanged on success.
    ///
    /// # Errors
    ///
    /// On failure, returns `OCRError::ConfigError` whose `message` is the
    /// `Display` text of the underlying [`ConfigError`].
    fn validate_and_wrap_ocr_error(self) -> Result<Self, super::super::errors::OCRError>
    where
        Self: Sized,
    {
        match self.validate() {
            Ok(()) => Ok(self),
            Err(cause) => Err(super::super::errors::OCRError::ConfigError {
                message: cause.to_string(),
            }),
        }
    }

    /// Validates `self` and returns it unchanged on success.
    ///
    /// # Errors
    ///
    /// On failure, returns the [`ConfigError`] boxed as a
    /// `dyn std::error::Error + Send + Sync` trait object.
    fn validate_and_wrap_generic(self) -> Result<Self, Box<dyn std::error::Error + Send + Sync>>
    where
        Self: Sized,
    {
        match self.validate() {
            Ok(()) => Ok(self),
            Err(cause) => Err(cause.into()),
        }
    }
}

impl<V> ConfigValidatorExt for V where V: ConfigValidator {}
/// Converts a [`ConfigError`] into its `Display` message, so configuration
/// errors can be passed to APIs that report failures as plain `String`s.
impl From<ConfigError> for String {
    fn from(err: ConfigError) -> Self {
        ToString::to_string(&err)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal validator used to exercise the provided helper methods.
    struct TestValidator;

    impl ConfigValidator for TestValidator {
        fn validate(&self) -> Result<(), ConfigError> {
            Ok(())
        }
        fn get_defaults() -> Self {
            TestValidator
        }
    }

    #[test]
    fn test_validate_batch_size() {
        let validator = TestValidator;
        assert!(validator.validate_batch_size(1).is_ok());
        assert!(validator.validate_batch_size(10).is_ok());
        assert!(validator.validate_batch_size(0).is_err());
    }

    #[test]
    fn test_validate_batch_size_with_limits() {
        let validator = TestValidator;
        assert!(validator.validate_batch_size_with_limits(1, 10).is_ok());
        // The maximum itself is allowed.
        assert!(validator.validate_batch_size_with_limits(10, 10).is_ok());
        assert!(matches!(
            validator.validate_batch_size_with_limits(0, 10),
            Err(ConfigError::InvalidBatchSize)
        ));
        assert!(matches!(
            validator.validate_batch_size_with_limits(11, 10),
            Err(ConfigError::BatchSizeTooLarge { actual: 11, max: 10 })
        ));
    }

    #[test]
    fn test_validate_model_path() {
        let validator = TestValidator;
        // The running test binary is guaranteed to be an existing regular file,
        // and its parent directory an existing non-file.
        let exe = std::env::current_exe().expect("test binary path should be resolvable");
        let dir = exe.parent().expect("test binary should live in a directory");
        assert!(validator.validate_model_path(&exe).is_ok());
        assert!(matches!(
            validator.validate_model_path(dir),
            Err(ConfigError::ModelPathNotFile { .. })
        ));
        let missing = dir.join("config-validator-test-definitely-missing-model.onnx");
        assert!(matches!(
            validator.validate_model_path(&missing),
            Err(ConfigError::ModelPathNotFound { .. })
        ));
    }

    #[test]
    fn test_validate_image_dimensions() {
        let validator = TestValidator;
        assert!(validator.validate_image_dimensions(100, 100).is_ok());
        assert!(validator.validate_image_dimensions(1, 1).is_ok());
        assert!(validator.validate_image_dimensions(0, 100).is_err());
        assert!(validator.validate_image_dimensions(100, 0).is_err());
        assert!(validator.validate_image_dimensions(0, 0).is_err());
    }

    #[test]
    fn test_validate_confidence_threshold() {
        let validator = TestValidator;
        assert!(validator.validate_confidence_threshold(0.0).is_ok());
        assert!(validator.validate_confidence_threshold(0.5).is_ok());
        assert!(validator.validate_confidence_threshold(1.0).is_ok());
        assert!(validator.validate_confidence_threshold(-0.1).is_err());
        assert!(validator.validate_confidence_threshold(1.1).is_err());
    }

    #[test]
    fn test_validate_memory_limit() {
        let validator = TestValidator;
        assert!(validator.validate_memory_limit(1024).is_ok());
        assert!(validator.validate_memory_limit(16 * 1024).is_ok());
        // Boundary: exactly 32 GiB is accepted, one MB more is not.
        assert!(validator.validate_memory_limit(32 * 1024).is_ok());
        assert!(validator.validate_memory_limit(32 * 1024 + 1).is_err());
        assert!(validator.validate_memory_limit(64 * 1024).is_err());
    }

    #[test]
    fn test_validate_thread_count() {
        let validator = TestValidator;
        assert!(validator.validate_thread_count(1).is_ok());
        assert!(validator.validate_thread_count(8).is_ok());
        assert!(validator.validate_thread_count(64).is_ok());
        assert!(validator.validate_thread_count(256).is_ok());
        assert!(validator.validate_thread_count(0).is_err());
        assert!(validator.validate_thread_count(257).is_err());
        assert!(validator.validate_thread_count(512).is_err());
    }

    #[test]
    fn test_validate_f32_range() {
        let validator = TestValidator;
        assert!(validator.validate_f32_range(0.0, 0.0, 1.0, "ratio").is_ok());
        assert!(validator.validate_f32_range(0.5, 0.0, 1.0, "ratio").is_ok());
        assert!(validator.validate_f32_range(1.0, 0.0, 1.0, "ratio").is_ok());
        assert!(validator.validate_f32_range(-0.5, 0.0, 1.0, "ratio").is_err());
        assert!(validator.validate_f32_range(1.5, 0.0, 1.0, "ratio").is_err());
    }

    #[test]
    fn test_validate_positive_values() {
        let validator = TestValidator;
        assert!(validator.validate_positive_f32(0.1, "scale").is_ok());
        assert!(validator.validate_positive_f32(0.0, "scale").is_err());
        assert!(validator.validate_positive_f32(-1.0, "scale").is_err());
        assert!(validator.validate_positive_usize(1, "count").is_ok());
        assert!(validator.validate_positive_usize(0, "count").is_err());
    }

    #[test]
    fn test_defaults_and_ext_adapters() {
        assert!(DefaultValidator::defaults().validate().is_ok());
        assert!(TestValidator.validate_and_wrap_generic().is_ok());
    }

    #[test]
    fn test_config_error_to_string() {
        let error = ConfigError::InvalidBatchSize;
        let error_string: String = error.into();
        assert_eq!(error_string, "batch size must be greater than 0");
    }

    #[test]
    fn test_config_error_display() {
        let too_large = ConfigError::BatchSizeTooLarge { actual: 11, max: 10 };
        assert_eq!(
            too_large.to_string(),
            "batch size 11 exceeds maximum allowed 10; consider reducing batch size or increasing memory limits"
        );
        // An empty suggestion leaves nothing after the quoted field name.
        let missing = ConfigError::MissingRequiredField {
            field: "model".to_string(),
            suggestion: String::new(),
        };
        assert_eq!(
            missing.to_string(),
            "missing required configuration field 'model'"
        );
    }
}