use std::fmt;
use std::time::Duration;
use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Canonical benchmark inputs, frozen so results stay comparable across runs.
pub mod canonical_inputs {
    /// Version of this canonical input set.
    // NOTE(review): VERSION (1.0.0) differs from SPEC_VERSION (1.0.1) below —
    // confirm whether the two are meant to track each other.
    pub const VERSION: &str = "1.0.0";
    /// Fixed prompt used for latency measurements.
    pub const LATENCY_PROMPT: &str = "Explain the concept of machine learning in one sentence.";
    /// Fixed token sequence for throughput measurements (presumably token ids — confirm).
    pub const THROUGHPUT_TOKENS: &[u32] = &[1, 2, 3, 4, 5, 6, 7, 8];
    /// Generation cap for canonical runs.
    pub const MAX_TOKENS: usize = 50;
    /// Version of the measurement specification these inputs belong to.
    pub const SPEC_VERSION: &str = "1.0.1";
}
/// Errors produced by preflight validation checks.
#[derive(Debug, Error)]
pub enum PreflightError {
    /// A connection to the server could not be established at all.
    #[error("Server unreachable at {url}: {reason}")]
    ServerUnreachable {
        url: String,
        reason: String,
    },
    /// The health endpoint answered with a non-2xx HTTP status.
    #[error("Health check failed at {url}: HTTP {status}")]
    HealthCheckFailed {
        url: String,
        status: u16,
    },
    /// The requested model is not in the server's reported model list.
    #[error("Model not found: requested '{requested}', available: {available:?}")]
    ModelNotFound {
        requested: String,
        available: Vec<String>,
    },
    /// A required field was absent from a response document.
    #[error("Schema mismatch: missing field '{missing_field}'")]
    SchemaMismatch {
        missing_field: String,
    },
    /// A response field was present but had the wrong JSON type.
    #[error("Field type mismatch: '{field}' expected {expected}, got {actual}")]
    FieldTypeMismatch {
        field: String,
        expected: String,
        actual: String,
    },
    /// A response could not be parsed (e.g. root was not a JSON object).
    #[error("Response parse error: {reason}")]
    ResponseParseError {
        reason: String,
    },
    /// An operation exceeded its deadline.
    // NOTE(review): not constructed anywhere in this file — presumably raised
    // by the callers that perform the actual network I/O.
    #[error("Timeout after {duration:?} during {operation}")]
    Timeout {
        duration: Duration,
        operation: String,
    },
    /// A check was misconfigured, or a precondition for running it was unmet.
    #[error("Configuration error: {reason}")]
    ConfigError {
        reason: String,
    },
}
/// Convenience alias: every preflight operation fails with [`PreflightError`].
pub type PreflightResult<T> = Result<T, PreflightError>;
/// A named validation that must pass before a benchmark run is trusted.
pub trait PreflightCheck: fmt::Debug + Send + Sync {
    /// Stable machine-readable identifier (recorded by `PreflightRunner` on success).
    fn name(&self) -> &'static str;
    /// Performs the check, returning the first problem found.
    fn validate(&self) -> PreflightResult<()>;
    /// Human-readable summary of what this check verifies.
    fn description(&self) -> &'static str {
        "Preflight validation check"
    }
}
/// Sampling parameters pinned so that repeated runs produce identical output.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeterministicInferenceConfig {
    /// Sampling temperature; 0.0 disables randomness.
    pub temperature: f64,
    /// RNG seed forwarded to the inference server.
    pub seed: u64,
    /// Candidate-token pool size; 1 forces greedy decoding.
    pub top_k: usize,
    /// Top-p (nucleus) threshold; not checked by `validate_determinism`.
    pub top_p: f64,
}

impl Default for DeterministicInferenceConfig {
    fn default() -> Self {
        DeterministicInferenceConfig {
            temperature: 0.0,
            seed: 42,
            top_k: 1,
            top_p: 1.0,
        }
    }
}

impl DeterministicInferenceConfig {
    /// Builds the default deterministic config with a caller-chosen seed.
    #[must_use]
    pub fn with_seed(seed: u64) -> Self {
        let mut config = Self::default();
        config.seed = seed;
        config
    }

    /// Rejects any setting that would let the sampler behave non-deterministically.
    ///
    /// # Errors
    /// Returns `PreflightError::ConfigError` when `temperature > 0.0` or
    /// `top_k != 1`.
    pub fn validate_determinism(&self) -> PreflightResult<()> {
        if self.temperature > 0.0 {
            return Err(PreflightError::ConfigError {
                reason: format!(
                    "Temperature {} > 0.0 allows randomness; set to 0.0 for determinism",
                    self.temperature
                ),
            });
        }
        if self.top_k != 1 {
            return Err(PreflightError::ConfigError {
                reason: format!(
                    "top_k {} != 1 allows multiple token choices; set to 1 for determinism",
                    self.top_k
                ),
            });
        }
        Ok(())
    }
}
/// Why adaptive sampling stopped (or, for `Continue`, why it has not yet).
#[derive(Debug, Clone, PartialEq)]
pub enum StopReason {
    /// Coefficient of variation fell below the threshold; payload is the CV at stop.
    CvConverged(f64),
    /// The hard cap on sample count was reached without convergence.
    MaxSamples,
    /// Sampling has not stopped.
    Continue,
}

/// Decision returned by [`CvStoppingCriterion::should_stop`].
#[derive(Debug, Clone, PartialEq)]
pub enum StopDecision {
    Continue,
    Stop(StopReason),
}

/// Adaptive stopping rule: keep sampling until the coefficient of variation
/// (sample std-dev / |mean|) drops below `cv_threshold`, bounded below by
/// `min_samples` and above by `max_samples`.
#[derive(Debug, Clone)]
pub struct CvStoppingCriterion {
    /// Never stop on convergence before this many samples.
    pub min_samples: usize,
    /// Always stop once this many samples have been collected.
    pub max_samples: usize,
    /// CV below which the series counts as converged.
    pub cv_threshold: f64,
}

impl Default for CvStoppingCriterion {
    fn default() -> Self {
        Self {
            min_samples: 5,
            max_samples: 30,
            cv_threshold: 0.05,
        }
    }
}

impl CvStoppingCriterion {
    /// Creates a criterion with explicit bounds and threshold.
    #[must_use]
    pub fn new(min_samples: usize, max_samples: usize, cv_threshold: f64) -> Self {
        Self {
            min_samples,
            max_samples,
            cv_threshold,
        }
    }

    /// Decides whether to keep sampling.
    ///
    /// The `max_samples` cap takes priority: once reached, sampling stops with
    /// `MaxSamples` even if the CV has also converged.
    #[must_use]
    pub fn should_stop(&self, samples: &[f64]) -> StopDecision {
        if samples.len() < self.min_samples {
            return StopDecision::Continue;
        }
        if samples.len() >= self.max_samples {
            return StopDecision::Stop(StopReason::MaxSamples);
        }
        let cv = self.calculate_cv(samples);
        if cv < self.cv_threshold {
            StopDecision::Stop(StopReason::CvConverged(cv))
        } else {
            StopDecision::Continue
        }
    }

    /// Coefficient of variation: sample standard deviation over |mean|.
    ///
    /// Returns `f64::MAX` as a "not computable / not converged" sentinel when
    /// there are fewer than two samples or the mean is (near) zero.
    #[must_use]
    pub fn calculate_cv(&self, samples: &[f64]) -> f64 {
        if samples.len() < 2 {
            return f64::MAX;
        }
        let n = samples.len() as f64;
        let mean = samples.iter().sum::<f64>() / n;
        if mean.abs() < f64::EPSILON {
            return f64::MAX;
        }
        // Unbiased (n-1) sample variance.
        let variance = samples.iter().map(|x| (x - mean).powi(2)).sum::<f64>() / (n - 1.0);
        // Divide by |mean|, not mean: a negative-mean series would otherwise
        // yield a negative CV that is always < cv_threshold, making
        // `should_stop` report spurious convergence.
        variance.sqrt() / mean.abs()
    }
}
/// MAD-based outlier detector (median absolute deviation).
///
/// A sample is an outlier when its absolute deviation from the median exceeds
/// `k_factor * 1.4826 * MAD`; 1.4826 scales MAD to the standard deviation of
/// a normal distribution.
pub struct OutlierDetector {
    /// Multiplier on the scaled MAD; larger values flag fewer points.
    pub k_factor: f64,
}

impl Default for OutlierDetector {
    fn default() -> Self {
        OutlierDetector { k_factor: 3.0 }
    }
}

impl OutlierDetector {
    /// Creates a detector with a custom `k_factor`.
    #[must_use]
    pub fn new(k_factor: f64) -> Self {
        Self { k_factor }
    }

    /// Flags each sample: `true` at index i means `samples[i]` is an outlier.
    /// Fewer than three samples never yields outliers.
    #[must_use]
    pub fn detect(&self, samples: &[f64]) -> Vec<bool> {
        if samples.len() < 3 {
            return vec![false; samples.len()];
        }
        let center = Self::percentile(samples, 50.0);
        let abs_devs: Vec<f64> = samples.iter().map(|v| (v - center).abs()).collect();
        let mad = Self::percentile(&abs_devs, 50.0);
        let cutoff = 1.4826 * mad * self.k_factor;
        abs_devs.into_iter().map(|d| d > cutoff).collect()
    }

    /// Returns a copy of `samples` with detected outliers removed.
    #[must_use]
    pub fn filter(&self, samples: &[f64]) -> Vec<f64> {
        let flags = self.detect(samples);
        let mut kept = Vec::with_capacity(samples.len());
        for (value, flagged) in samples.iter().zip(flags) {
            if !flagged {
                kept.push(*value);
            }
        }
        kept
    }

    /// Nearest-rank percentile over an unsorted slice (0.0 for empty input).
    fn percentile(samples: &[f64], p: f64) -> f64 {
        if samples.is_empty() {
            return 0.0;
        }
        let mut ordered = samples.to_vec();
        // NaNs compare as equal so the sort never panics.
        ordered.sort_by(|a, b| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal));
        let last = ordered.len() - 1;
        let rank = ((p / 100.0) * last as f64).round() as usize;
        ordered[rank.min(last)]
    }
}
/// Rigor/quality metadata attached to a benchmark result.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct QualityMetrics {
    /// Coefficient of variation when sampling stopped.
    pub cv_at_stop: f64,
    /// Whether sampling stopped because the CV converged.
    pub cv_converged: bool,
    /// Number of samples flagged as outliers.
    pub outliers_detected: usize,
    /// Number of outliers actually removed from the result.
    pub outliers_excluded: usize,
    /// Names of the preflight checks that passed before measurement.
    pub preflight_checks_passed: Vec<String>,
}

impl Default for QualityMetrics {
    // Hand-written rather than derived: a derived Default would set
    // cv_at_stop = 0.0, which would read as "perfectly converged".
    fn default() -> Self {
        QualityMetrics {
            cv_converged: false,
            cv_at_stop: f64::MAX,
            outliers_excluded: 0,
            outliers_detected: 0,
            preflight_checks_passed: Vec::new(),
        }
    }
}
/// Preflight check wrapping [`DeterministicInferenceConfig::validate_determinism`].
#[derive(Debug)]
pub struct DeterminismCheck {
    config: DeterministicInferenceConfig,
}

impl DeterminismCheck {
    /// Wraps `config` for use with a `PreflightRunner`.
    #[must_use]
    pub fn new(config: DeterministicInferenceConfig) -> Self {
        DeterminismCheck { config }
    }
}

impl PreflightCheck for DeterminismCheck {
    fn name(&self) -> &'static str {
        "determinism_check"
    }

    /// Delegates to the wrapped config's determinism validation.
    fn validate(&self) -> PreflightResult<()> {
        self.config.validate_determinism()
    }

    fn description(&self) -> &'static str {
        "Validates inference configuration ensures deterministic output"
    }
}
/// Checks that an inference server answered its health endpoint.
///
/// This type performs no network I/O itself: the caller hits `health_url()`
/// and records the observed HTTP status via `set_health_status` before
/// `validate` is invoked.
#[derive(Debug)]
pub struct ServerAvailabilityCheck {
    url: String,
    health_path: String,
    health_status: Option<u16>,
}

impl ServerAvailabilityCheck {
    /// Creates a check for `url` with a health endpoint at `health_path`.
    #[must_use]
    pub fn new(url: String, health_path: String) -> Self {
        Self {
            health_status: None,
            url,
            health_path,
        }
    }

    /// Convenience constructor for a local llama.cpp server.
    #[must_use]
    pub fn llama_cpp(port: u16) -> Self {
        Self::new(format!("http://127.0.0.1:{port}"), "/health".to_string())
    }

    /// Convenience constructor for a local Ollama server.
    #[must_use]
    pub fn ollama(port: u16) -> Self {
        Self::new(format!("http://127.0.0.1:{port}"), "/api/tags".to_string())
    }

    /// Records the HTTP status the caller observed at the health endpoint.
    pub fn set_health_status(&mut self, status: u16) {
        self.health_status = Some(status);
    }

    /// Full URL of the health endpoint (base URL + health path).
    #[must_use]
    pub fn health_url(&self) -> String {
        format!("{}{}", self.url, self.health_path)
    }

    /// Static sanity checks on the configured base URL.
    fn validate_url(&self) -> PreflightResult<()> {
        if self.url.is_empty() {
            return Err(PreflightError::ConfigError {
                reason: "Server URL cannot be empty".to_string(),
            });
        }
        let has_scheme =
            self.url.starts_with("http://") || self.url.starts_with("https://");
        if has_scheme {
            Ok(())
        } else {
            Err(PreflightError::ConfigError {
                reason: format!(
                    "Server URL must start with http:// or https://, got: {}",
                    self.url
                ),
            })
        }
    }
}
impl PreflightCheck for ServerAvailabilityCheck {
    fn name(&self) -> &'static str {
        "server_availability_check"
    }

    fn description(&self) -> &'static str {
        "Validates server is reachable at the configured URL"
    }

    /// Passes only when the URL is well-formed and a 2xx health status was recorded.
    ///
    /// # Errors
    /// - `ConfigError` if the URL is malformed or no status was recorded.
    /// - `HealthCheckFailed` if the recorded status is outside the 2xx range.
    fn validate(&self) -> PreflightResult<()> {
        self.validate_url()?;
        match self.health_status {
            // Any 2xx response counts as healthy (was a manual `>= 200 && < 300`
            // comparison; range `contains` is the idiomatic equivalent).
            Some(status) if (200..300).contains(&status) => Ok(()),
            Some(status) => Err(PreflightError::HealthCheckFailed {
                url: self.health_url(),
                status,
            }),
            None => Err(PreflightError::ConfigError {
                reason: "Health check not performed - call set_health_status() first".to_string(),
            }),
        }
    }
}
/// Checks that the requested model appears in the server's model list.
///
/// I/O-free like `ServerAvailabilityCheck`: the caller fetches the model list
/// and injects it via `set_available_models` before validation.
#[derive(Debug)]
pub struct ModelAvailabilityCheck {
    requested_model: String,
    available_models: Vec<String>,
}

impl ModelAvailabilityCheck {
    /// Creates a check for `requested_model` with an empty availability list.
    #[must_use]
    pub fn new(requested_model: String) -> Self {
        ModelAvailabilityCheck {
            available_models: Vec::new(),
            requested_model,
        }
    }

    /// Records the model names reported by the server.
    pub fn set_available_models(&mut self, models: Vec<String>) {
        self.available_models = models;
    }

    /// Name of the model this check was created for.
    #[must_use]
    pub fn requested_model(&self) -> &str {
        &self.requested_model
    }
}
impl PreflightCheck for ModelAvailabilityCheck {
    fn name(&self) -> &'static str {
        "model_availability_check"
    }

    fn description(&self) -> &'static str {
        "Validates requested model is available on the server"
    }

    /// Passes when the requested model matches an available one exactly, or up
    /// to a `:tag` suffix on either side (so "phi2" matches "phi2:2.7b" and
    /// vice versa).
    ///
    /// # Errors
    /// - `ConfigError` if the model name is empty or the list was never set.
    /// - `ModelNotFound` if no available model matches.
    fn validate(&self) -> PreflightResult<()> {
        if self.requested_model.is_empty() {
            return Err(PreflightError::ConfigError {
                reason: "Model name cannot be empty".to_string(),
            });
        }
        if self.available_models.is_empty() {
            return Err(PreflightError::ConfigError {
                reason: "Available models list not set - call set_available_models() first"
                    .to_string(),
            });
        }
        // Hoisted out of the closure: the previous code rebuilt this string
        // (and a second one per candidate) on every iteration.
        let requested_prefix = format!("{}:", self.requested_model);
        let found = self.available_models.iter().any(|m| {
            m == &self.requested_model
                || m.starts_with(&requested_prefix)
                // "requested starts with '<m>:'", without allocating per candidate.
                || self
                    .requested_model
                    .strip_prefix(m.as_str())
                    .map_or(false, |rest| rest.starts_with(':'))
        });
        if found {
            Ok(())
        } else {
            Err(PreflightError::ModelNotFound {
                requested: self.requested_model.clone(),
                available: self.available_models.clone(),
            })
        }
    }
}
/// Validates that a server response JSON document has the expected fields and
/// field types.
#[derive(Debug)]
pub struct ResponseSchemaCheck {
    required_fields: Vec<String>,
    field_types: std::collections::HashMap<String, String>,
}

impl ResponseSchemaCheck {
    /// Creates a check requiring `required_fields`, with no type constraints.
    #[must_use]
    pub fn new(required_fields: Vec<String>) -> Self {
        Self {
            required_fields,
            field_types: std::collections::HashMap::new(),
        }
    }

    /// Preset schema for llama.cpp completion responses.
    #[must_use]
    pub fn llama_cpp_completion() -> Self {
        Self::new(vec![
            "content".to_string(),
            "tokens_predicted".to_string(),
            "timings".to_string(),
        ])
        .with_type_constraint("tokens_predicted".to_string(), "number".to_string())
        .with_type_constraint("content".to_string(), "string".to_string())
    }

    /// Preset schema for Ollama generate responses.
    #[must_use]
    pub fn ollama_generate() -> Self {
        Self::new(vec!["response".to_string(), "done".to_string()])
            .with_type_constraint("response".to_string(), "string".to_string())
            .with_type_constraint("done".to_string(), "boolean".to_string())
    }

    /// Adds (or replaces) a type constraint for `field`.
    #[must_use]
    pub fn with_type_constraint(mut self, field: String, expected_type: String) -> Self {
        self.field_types.insert(field, expected_type);
        self
    }

    /// JSON type name using the same vocabulary as `field_types` values.
    fn type_name(value: &serde_json::Value) -> &'static str {
        match value {
            serde_json::Value::Null => "null",
            serde_json::Value::Bool(_) => "boolean",
            serde_json::Value::Number(_) => "number",
            serde_json::Value::String(_) => "string",
            serde_json::Value::Array(_) => "array",
            serde_json::Value::Object(_) => "object",
        }
    }

    /// Checks `json` against the required fields and type constraints.
    ///
    /// # Errors
    /// - `ResponseParseError` when the root is not a JSON object.
    /// - `SchemaMismatch` for the first missing required field.
    /// - `FieldTypeMismatch` for the first present field with a wrong type.
    pub fn validate_json(&self, json: &serde_json::Value) -> PreflightResult<()> {
        let obj = json
            .as_object()
            .ok_or_else(|| PreflightError::ResponseParseError {
                reason: "Expected JSON object at root".to_string(),
            })?;
        for field in &self.required_fields {
            if !obj.contains_key(field) {
                return Err(PreflightError::SchemaMismatch {
                    missing_field: field.clone(),
                });
            }
        }
        for (field, expected_type) in &self.field_types {
            // Type constraints only apply when the field is present; absence
            // is handled (or allowed) by the required-fields pass above.
            if let Some(value) = obj.get(field) {
                let actual_type = Self::type_name(value);
                if actual_type != expected_type {
                    return Err(PreflightError::FieldTypeMismatch {
                        field: field.clone(),
                        expected: expected_type.clone(),
                        actual: actual_type.to_string(),
                    });
                }
            }
        }
        Ok(())
    }
}
impl PreflightCheck for ResponseSchemaCheck {
    fn name(&self) -> &'static str {
        "response_schema_check"
    }

    fn description(&self) -> &'static str {
        "Validates response JSON matches expected schema"
    }

    /// Static self-check: a schema with no required fields could never fail,
    /// so it is treated as a configuration mistake.
    fn validate(&self) -> PreflightResult<()> {
        if self.required_fields.is_empty() {
            Err(PreflightError::ConfigError {
                reason: "At least one required field must be specified".to_string(),
            })
        } else {
            Ok(())
        }
    }
}
/// Runs an ordered list of preflight checks, stopping at the first failure.
#[derive(Debug, Default)]
pub struct PreflightRunner {
    checks: Vec<Box<dyn PreflightCheck>>,
    passed: Vec<String>,
}

impl PreflightRunner {
    /// Creates an empty runner.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Appends a check; checks run in insertion order.
    pub fn add_check(&mut self, check: Box<dyn PreflightCheck>) {
        self.checks.push(check);
    }

    /// Validates every check in order, recording the names of those that pass.
    ///
    /// # Errors
    /// Propagates the first failing check's error; `passed_checks` then lists
    /// only the checks that succeeded before the failure.
    pub fn run(&mut self) -> PreflightResult<Vec<String>> {
        self.passed.clear();
        for check in &self.checks {
            check.validate()?;
            self.passed.push(check.name().to_string());
        }
        Ok(self.passed.clone())
    }

    /// Names of the checks that passed during the most recent `run`.
    #[must_use]
    pub fn passed_checks(&self) -> &[String] {
        &self.passed
    }
}
// Unit tests. The `test_imp_143a/b` cases are integration tests marked
// `#[ignore]` because they require a live llama.cpp / Ollama server.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- canonical_inputs constants ----

    #[test]
    fn test_canonical_inputs_version_is_semver() {
        let version = canonical_inputs::VERSION;
        let parts: Vec<&str> = version.split('.').collect();
        assert_eq!(
            parts.len(),
            3,
            "Version should be semver (major.minor.patch)"
        );
        for part in parts {
            assert!(
                part.parse::<u32>().is_ok(),
                "Version part '{}' should be numeric",
                part
            );
        }
    }

    #[test]
    fn test_canonical_inputs_prompt_not_empty() {
        let prompt_len = canonical_inputs::LATENCY_PROMPT.len();
        assert!(
            prompt_len >= 10,
            "Latency prompt should have at least 10 chars, got {}",
            prompt_len
        );
    }

    #[test]
    fn test_canonical_inputs_tokens_not_empty() {
        let token_count = canonical_inputs::THROUGHPUT_TOKENS.len();
        assert!(
            token_count >= 4,
            "Throughput tokens should have at least 4 tokens, got {}",
            token_count
        );
    }

    #[test]
    fn test_canonical_inputs_max_tokens_reasonable() {
        let max_tokens = canonical_inputs::MAX_TOKENS;
        assert!(
            max_tokens > 0,
            "Max tokens should be positive, got {}",
            max_tokens
        );
        assert!(
            max_tokens <= 1000,
            "Max tokens should be <= 1000, got {}",
            max_tokens
        );
    }

    // ---- DeterministicInferenceConfig ----

    #[test]
    fn test_deterministic_config_default_is_deterministic() {
        let config = DeterministicInferenceConfig::default();
        assert!(
            config.validate_determinism().is_ok(),
            "Default config should be deterministic"
        );
    }

    #[test]
    fn test_deterministic_config_default_values() {
        let config = DeterministicInferenceConfig::default();
        assert_eq!(config.temperature, 0.0);
        assert_eq!(config.seed, 42);
        assert_eq!(config.top_k, 1);
        assert_eq!(config.top_p, 1.0);
    }

    #[test]
    fn test_deterministic_config_with_seed() {
        let config = DeterministicInferenceConfig::with_seed(12345);
        assert_eq!(config.seed, 12345);
        assert!(config.validate_determinism().is_ok());
    }

    #[test]
    fn test_deterministic_config_rejects_nonzero_temperature() {
        let config = DeterministicInferenceConfig {
            temperature: 0.7,
            ..Default::default()
        };
        let result = config.validate_determinism();
        assert!(result.is_err());
        let err = result.unwrap_err();
        assert!(matches!(err, PreflightError::ConfigError { .. }));
    }

    #[test]
    fn test_deterministic_config_rejects_topk_not_one() {
        let config = DeterministicInferenceConfig {
            top_k: 50,
            ..Default::default()
        };
        let result = config.validate_determinism();
        assert!(result.is_err());
    }

    // ---- CvStoppingCriterion ----

    #[test]
    fn test_cv_criterion_default_values() {
        let criterion = CvStoppingCriterion::default();
        assert_eq!(criterion.min_samples, 5);
        assert_eq!(criterion.max_samples, 30);
        assert!((criterion.cv_threshold - 0.05).abs() < 0.001);
    }

    #[test]
    fn test_cv_criterion_continues_below_min_samples() {
        let criterion = CvStoppingCriterion::new(5, 30, 0.05);
        let samples = vec![100.0, 100.0, 100.0];
        assert_eq!(criterion.should_stop(&samples), StopDecision::Continue);
    }

    #[test]
    fn test_cv_criterion_stops_at_max_samples() {
        let criterion = CvStoppingCriterion::new(5, 10, 0.01);
        let samples: Vec<f64> = (1..=10).map(|x| x as f64 * 10.0).collect();
        assert_eq!(
            criterion.should_stop(&samples),
            StopDecision::Stop(StopReason::MaxSamples)
        );
    }

    #[test]
    fn test_cv_criterion_converges_on_identical_values() {
        let criterion = CvStoppingCriterion::new(5, 30, 0.05);
        let samples = vec![100.0, 100.0, 100.0, 100.0, 100.0];
        let cv = criterion.calculate_cv(&samples);
        assert!(
            cv < 0.001,
            "CV of identical values should be ~0, got {}",
            cv
        );
        match criterion.should_stop(&samples) {
            StopDecision::Stop(StopReason::CvConverged(cv)) => {
                assert!(cv < 0.05);
            },
            other => panic!("Expected CvConverged, got {:?}", other),
        }
    }

    #[test]
    fn test_cv_criterion_continues_on_high_variance() {
        let criterion = CvStoppingCriterion::new(5, 30, 0.05);
        let samples = vec![10.0, 100.0, 10.0, 100.0, 10.0];
        assert_eq!(criterion.should_stop(&samples), StopDecision::Continue);
    }

    #[test]
    fn test_cv_calculation_single_value() {
        let criterion = CvStoppingCriterion::default();
        let samples = vec![100.0];
        let cv = criterion.calculate_cv(&samples);
        assert_eq!(cv, f64::MAX);
    }

    #[test]
    fn test_cv_calculation_empty() {
        let criterion = CvStoppingCriterion::default();
        let samples: Vec<f64> = vec![];
        let cv = criterion.calculate_cv(&samples);
        assert_eq!(cv, f64::MAX);
    }

    #[test]
    fn test_cv_calculation_known_values() {
        let criterion = CvStoppingCriterion::default();
        let samples = vec![90.0, 95.0, 100.0, 105.0, 110.0];
        let cv = criterion.calculate_cv(&samples);
        assert!(cv > 0.07 && cv < 0.09, "Expected CV ~0.079, got {}", cv);
    }

    // ---- OutlierDetector (MAD-based) ----

    #[test]
    fn test_outlier_detector_default_k_factor() {
        let detector = OutlierDetector::default();
        assert!((detector.k_factor - 3.0).abs() < 0.001);
    }

    #[test]
    fn test_outlier_detector_no_outliers_uniform() {
        let detector = OutlierDetector::default();
        let samples = vec![100.0, 101.0, 99.0, 100.5, 99.5];
        let outliers = detector.detect(&samples);
        assert!(
            !outliers.iter().any(|&x| x),
            "Uniform samples should have no outliers"
        );
    }

    #[test]
    fn test_outlier_detector_finds_extreme_outlier() {
        let detector = OutlierDetector::default();
        let samples = vec![100.0, 101.0, 99.0, 100.0, 1000.0];
        let outliers = detector.detect(&samples);
        assert!(outliers[4], "1000.0 should be detected as outlier");
    }

    #[test]
    fn test_outlier_detector_filter_removes_outliers() {
        let detector = OutlierDetector::default();
        let samples = vec![100.0, 101.0, 99.0, 100.0, 1000.0];
        let filtered = detector.filter(&samples);
        assert!(
            !filtered.contains(&1000.0),
            "Filtered should not contain outlier"
        );
        assert_eq!(filtered.len(), 4);
    }

    #[test]
    fn test_outlier_detector_handles_small_samples() {
        let detector = OutlierDetector::default();
        let samples = vec![100.0, 200.0];
        let outliers = detector.detect(&samples);
        assert_eq!(
            outliers,
            vec![false, false],
            "Should not detect outliers with < 3 samples"
        );
    }

    #[test]
    fn test_outlier_detector_percentile() {
        let samples = vec![1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0];
        let p50 = OutlierDetector::percentile(&samples, 50.0);
        assert!(
            (p50 - 5.5).abs() < 1.0,
            "p50 should be ~5.5 (nearest rank gives 6), got {}",
            p50
        );
        let p99 = OutlierDetector::percentile(&samples, 99.0);
        assert!((p99 - 10.0).abs() < 0.5, "p99 should be ~10.0, got {}", p99);
    }

    // ---- QualityMetrics ----

    #[test]
    fn test_quality_metrics_default() {
        let metrics = QualityMetrics::default();
        assert_eq!(metrics.cv_at_stop, f64::MAX);
        assert!(!metrics.cv_converged);
        assert_eq!(metrics.outliers_detected, 0);
        assert!(metrics.preflight_checks_passed.is_empty());
    }

    #[test]
    fn test_quality_metrics_serialization() {
        let metrics = QualityMetrics {
            cv_at_stop: 0.03,
            cv_converged: true,
            outliers_detected: 2,
            outliers_excluded: 1,
            preflight_checks_passed: vec!["server_check".to_string()],
        };
        let json = serde_json::to_string(&metrics).expect("serialization");
        assert!(json.contains("0.03"));
        assert!(json.contains("server_check"));
    }

    // ---- DeterminismCheck ----

    #[test]
    fn test_determinism_check_trait_impl() {
        let config = DeterministicInferenceConfig::default();
        let check = DeterminismCheck::new(config);
        assert_eq!(check.name(), "determinism_check");
        assert!(check.validate().is_ok());
    }

    #[test]
    fn test_determinism_check_fails_on_bad_config() {
        let config = DeterministicInferenceConfig {
            temperature: 0.5,
            ..Default::default()
        };
        let check = DeterminismCheck::new(config);
        assert!(check.validate().is_err());
    }

    // ---- PreflightError Display formatting ----

    #[test]
    fn test_preflight_error_display() {
        let err = PreflightError::ModelNotFound {
            requested: "phi".to_string(),
            available: vec!["phi2:2.7b".to_string(), "llama2".to_string()],
        };
        let msg = format!("{}", err);
        assert!(msg.contains("phi"));
        assert!(msg.contains("phi2:2.7b"));
    }

    #[test]
    fn test_preflight_error_schema_mismatch() {
        let err = PreflightError::SchemaMismatch {
            missing_field: "eval_count".to_string(),
        };
        let msg = format!("{}", err);
        assert!(msg.contains("eval_count"));
    }

    #[test]
    fn test_preflight_error_type_mismatch() {
        let err = PreflightError::FieldTypeMismatch {
            field: "tokens".to_string(),
            expected: "number".to_string(),
            actual: "string".to_string(),
        };
        let msg = format!("{}", err);
        assert!(msg.contains("tokens"));
        assert!(msg.contains("number"));
    }

    // ---- ServerAvailabilityCheck ----

    #[test]
    fn test_server_check_llama_cpp_defaults() {
        let check = ServerAvailabilityCheck::llama_cpp(8082);
        assert_eq!(check.health_url(), "http://127.0.0.1:8082/health");
        assert_eq!(check.name(), "server_availability_check");
    }

    #[test]
    fn test_server_check_ollama_defaults() {
        let check = ServerAvailabilityCheck::ollama(11434);
        assert_eq!(check.health_url(), "http://127.0.0.1:11434/api/tags");
    }

    #[test]
    fn test_server_check_validates_url_format() {
        let check = ServerAvailabilityCheck::new("invalid-url".to_string(), "/health".to_string());
        let result = check.validate();
        assert!(result.is_err());
    }

    #[test]
    fn test_server_check_rejects_empty_url() {
        let check = ServerAvailabilityCheck::new(String::new(), "/health".to_string());
        let result = check.validate();
        assert!(matches!(result, Err(PreflightError::ConfigError { .. })));
    }

    #[test]
    fn test_server_check_requires_health_status() {
        let check = ServerAvailabilityCheck::llama_cpp(8082);
        let result = check.validate();
        assert!(result.is_err());
    }

    #[test]
    fn test_server_check_accepts_200_status() {
        let mut check = ServerAvailabilityCheck::llama_cpp(8082);
        check.set_health_status(200);
        assert!(check.validate().is_ok());
    }

    #[test]
    fn test_server_check_accepts_204_status() {
        let mut check = ServerAvailabilityCheck::llama_cpp(8082);
        check.set_health_status(204);
        assert!(check.validate().is_ok());
    }

    #[test]
    fn test_server_check_rejects_500_status() {
        let mut check = ServerAvailabilityCheck::llama_cpp(8082);
        check.set_health_status(500);
        let result = check.validate();
        assert!(matches!(
            result,
            Err(PreflightError::HealthCheckFailed { status: 500, .. })
        ));
    }

    #[test]
    fn test_server_check_rejects_404_status() {
        let mut check = ServerAvailabilityCheck::llama_cpp(8082);
        check.set_health_status(404);
        let result = check.validate();
        assert!(result.is_err());
    }

    // ---- ModelAvailabilityCheck ----

    #[test]
    fn test_model_check_finds_exact_match() {
        let mut check = ModelAvailabilityCheck::new("phi2:2.7b".to_string());
        check.set_available_models(vec!["phi2:2.7b".to_string(), "llama2".to_string()]);
        assert!(check.validate().is_ok());
    }

    #[test]
    fn test_model_check_finds_partial_match() {
        let mut check = ModelAvailabilityCheck::new("phi2".to_string());
        check.set_available_models(vec!["phi2:2.7b".to_string(), "llama2".to_string()]);
        assert!(check.validate().is_ok());
    }

    #[test]
    fn test_model_check_fails_on_missing_model() {
        let mut check = ModelAvailabilityCheck::new("gpt4".to_string());
        check.set_available_models(vec!["phi2:2.7b".to_string(), "llama2".to_string()]);
        let result = check.validate();
        assert!(matches!(result, Err(PreflightError::ModelNotFound { .. })));
    }

    #[test]
    fn test_model_check_rejects_empty_model_name() {
        let check = ModelAvailabilityCheck::new(String::new());
        let result = check.validate();
        assert!(result.is_err());
    }

    #[test]
    fn test_model_check_requires_available_models() {
        let check = ModelAvailabilityCheck::new("phi2".to_string());
        let result = check.validate();
        assert!(result.is_err());
    }

    #[test]
    fn test_model_check_name() {
        let check = ModelAvailabilityCheck::new("phi2".to_string());
        assert_eq!(check.name(), "model_availability_check");
        assert_eq!(check.requested_model(), "phi2");
    }

    // ---- ResponseSchemaCheck ----

    #[test]
    fn test_schema_check_llama_cpp_completion() {
        let check = ResponseSchemaCheck::llama_cpp_completion();
        assert_eq!(check.name(), "response_schema_check");
        assert!(check.validate().is_ok());
    }

    #[test]
    fn test_schema_check_ollama_generate() {
        let check = ResponseSchemaCheck::ollama_generate();
        assert!(check.validate().is_ok());
    }

    #[test]
    fn test_schema_check_validates_required_fields() {
        let check = ResponseSchemaCheck::new(vec!["content".to_string(), "tokens".to_string()]);
        let json: serde_json::Value = serde_json::json!({
            "content": "Hello",
            "tokens": 5
        });
        assert!(check.validate_json(&json).is_ok());
    }

    #[test]
    fn test_schema_check_fails_on_missing_field() {
        let check = ResponseSchemaCheck::new(vec!["content".to_string(), "tokens".to_string()]);
        let json: serde_json::Value = serde_json::json!({
            "content": "Hello"
        });
        let result = check.validate_json(&json);
        assert!(
            matches!(result, Err(PreflightError::SchemaMismatch { missing_field }) if missing_field == "tokens")
        );
    }

    #[test]
    fn test_schema_check_validates_field_types() {
        let check = ResponseSchemaCheck::new(vec!["count".to_string()])
            .with_type_constraint("count".to_string(), "number".to_string());
        let json: serde_json::Value = serde_json::json!({ "count": 42 });
        assert!(check.validate_json(&json).is_ok());
        let json: serde_json::Value = serde_json::json!({ "count": "42" });
        let result = check.validate_json(&json);
        assert!(matches!(
            result,
            Err(PreflightError::FieldTypeMismatch { .. })
        ));
    }

    #[test]
    fn test_schema_check_rejects_non_object() {
        let check = ResponseSchemaCheck::new(vec!["content".to_string()]);
        let json: serde_json::Value = serde_json::json!("not an object");
        let result = check.validate_json(&json);
        assert!(matches!(
            result,
            Err(PreflightError::ResponseParseError { .. })
        ));
    }

    #[test]
    fn test_schema_check_rejects_empty_required_fields() {
        let check = ResponseSchemaCheck::new(vec![]);
        let result = check.validate();
        assert!(result.is_err());
    }

    // ---- PreflightRunner ----

    #[test]
    fn test_runner_runs_all_checks() {
        let mut runner = PreflightRunner::new();
        let config = DeterministicInferenceConfig::default();
        runner.add_check(Box::new(DeterminismCheck::new(config)));
        let schema = ResponseSchemaCheck::new(vec!["foo".to_string()]);
        runner.add_check(Box::new(schema));
        let result = runner.run();
        assert!(result.is_ok());
        assert_eq!(result.unwrap().len(), 2);
    }

    #[test]
    fn test_runner_stops_on_first_failure_jidoka() {
        let mut runner = PreflightRunner::new();
        let config = DeterministicInferenceConfig::default();
        runner.add_check(Box::new(DeterminismCheck::new(config)));
        // Second check is invalid (no required fields), so the third never runs.
        let schema = ResponseSchemaCheck::new(vec![]);
        runner.add_check(Box::new(schema));
        let config2 = DeterministicInferenceConfig::default();
        runner.add_check(Box::new(DeterminismCheck::new(config2)));
        let result = runner.run();
        assert!(result.is_err());
        assert_eq!(runner.passed_checks().len(), 1);
    }

    #[test]
    fn test_runner_empty_passes() {
        let mut runner = PreflightRunner::new();
        let result = runner.run();
        assert!(result.is_ok());
        assert!(result.unwrap().is_empty());
    }

    #[test]
    fn test_runner_clears_passed_on_rerun() {
        let mut runner = PreflightRunner::new();
        let config = DeterministicInferenceConfig::default();
        runner.add_check(Box::new(DeterminismCheck::new(config)));
        let _ = runner.run();
        assert_eq!(runner.passed_checks().len(), 1);
        let _ = runner.run();
        assert_eq!(runner.passed_checks().len(), 1);
    }

    // ---- Integration tests against live servers (run with --ignored) ----

    #[test]
    #[ignore = "Requires running llama.cpp server on port 8082"]
    fn test_imp_143a_llamacpp_real_server_check() {
        let mut check = ServerAvailabilityCheck::llama_cpp(8082);
        let client = reqwest::blocking::Client::builder()
            .timeout(std::time::Duration::from_secs(5))
            .build()
            .expect("IMP-143a: Should create HTTP client");
        let health_url = check.health_url();
        match client.get(&health_url).send() {
            Ok(response) => {
                let status = response.status().as_u16();
                check.set_health_status(status);
                let result = check.validate();
                assert!(
                    result.is_ok(),
                    "IMP-143a: llama.cpp server check should pass when server is running. Status: {}, Error: {:?}",
                    status,
                    result.err()
                );
            },
            Err(e) => {
                panic!(
                    "IMP-143a: Could not connect to llama.cpp server at {}. \
                    Start with: llama-server -m model.gguf --host 127.0.0.1 --port 8082. \
                    Error: {}",
                    health_url, e
                );
            },
        }
    }

    #[test]
    #[ignore = "Requires running Ollama server on port 11434"]
    fn test_imp_143b_ollama_real_server_check() {
        let mut check = ServerAvailabilityCheck::ollama(11434);
        let client = reqwest::blocking::Client::builder()
            .timeout(std::time::Duration::from_secs(5))
            .build()
            .expect("IMP-143b: Should create HTTP client");
        let health_url = check.health_url();
        match client.get(&health_url).send() {
            Ok(response) => {
                let status = response.status().as_u16();
                check.set_health_status(status);
                let result = check.validate();
                assert!(
                    result.is_ok(),
                    "IMP-143b: Ollama server check should pass when server is running. Status: {}, Error: {:?}",
                    status,
                    result.err()
                );
            },
            Err(e) => {
                panic!(
                    "IMP-143b: Could not connect to Ollama server at {}. \
                    Start with: ollama serve. \
                    Error: {}",
                    health_url, e
                );
            },
        }
    }

    #[test]
    fn test_imp_143c_preflight_detects_unavailable_server() {
        let mut check = ServerAvailabilityCheck::llama_cpp(59999);
        // Status 0 simulates "no successful HTTP response"; any non-2xx fails.
        check.set_health_status(0);
        let result = check.validate();
        assert!(
            result.is_err(),
            "IMP-143c: Preflight should detect unavailable server"
        );
    }

    #[test]
    fn test_imp_143d_preflight_error_reporting() {
        let mut check = ServerAvailabilityCheck::llama_cpp(59998);
        check.set_health_status(503);
        let result = check.validate();
        match result {
            Err(PreflightError::HealthCheckFailed { status, url, .. }) => {
                assert_eq!(status, 503, "IMP-143d: Should report correct status code");
                assert!(url.contains("59998"), "IMP-143d: Should report correct URL");
            },
            _ => panic!("IMP-143d: Should return HealthCheckFailed error"),
        }
    }
}