use car_inference::{
ContentBlock, Device, GenerateRequest, HardwareInfo, InferenceConfig, InferenceError,
ModelRegistry, ModelRole, ModelRouter, TaskComplexity,
};
use base64::Engine as _;
use image::{DynamicImage, ImageBuffer, ImageFormat, Rgb};
use serde_json::json;
use std::io::Cursor;
use std::path::PathBuf;
use tempfile::TempDir;
/// `Device::auto()` picks the accelerator matching the enabled cargo features,
/// and falls back to the CPU when neither `metal` nor `cuda` is compiled in.
#[test]
fn device_auto_detection() {
    let detected = Device::auto();

    #[cfg(not(any(feature = "metal", feature = "cuda")))]
    {
        assert_eq!(detected, Device::Cpu);
    }
    #[cfg(feature = "metal")]
    {
        assert_eq!(detected, Device::Metal);
    }
    #[cfg(feature = "cuda")]
    {
        assert!(matches!(detected, Device::Cuda(0)));
    }
}
/// Pins the default model directory, default model names, and unset device
/// of `InferenceConfig::default()`.
#[test]
fn default_config() {
    let cfg = InferenceConfig::default();

    assert!(cfg.models_dir.ends_with(".car/models"));
    assert_eq!(cfg.embedding_model, "Qwen3-Embedding-0.6B");
    assert_eq!(cfg.classification_model, "Qwen3-0.6B");
    // The default leaves `device` unset.
    assert!(cfg.device.is_none());
}
/// The built-in catalog lists six models with the expected roles, sizes, and
/// (for a fresh directory) `downloaded == false`.
///
/// A per-test temporary directory is used instead of a fixed `/tmp` path so
/// the registry cannot see files from earlier runs, and so the test also
/// works on platforms without `/tmp`.
#[test]
fn model_registry_lists_catalog() {
    let tmp = TempDir::new().unwrap();
    let registry = ModelRegistry::new(tmp.path().join("models"));
    let models = registry.list_models();
    assert_eq!(models.len(), 6);

    let emb = models
        .iter()
        .find(|m| m.name == "Qwen3-Embedding-0.6B")
        .unwrap();
    assert_eq!(emb.role, ModelRole::Embedding);

    let small = models.iter().find(|m| m.name == "Qwen3-0.6B").unwrap();
    assert_eq!(small.role, ModelRole::Small);
    assert_eq!(small.param_count, "0.6B");
    assert!(!small.downloaded);

    let moe = models.iter().find(|m| m.name == "Qwen3-30B-A3B").unwrap();
    assert_eq!(moe.role, ModelRole::Expert);
    assert_eq!(moe.param_count, "30B (3B active)");
}
/// Asking the registry to ensure a model that is not in the catalog yields
/// `InferenceError::ModelNotFound`.
///
/// Uses an isolated temporary models directory rather than a fixed `/tmp` path.
#[test]
fn model_registry_not_found() {
    let tmp = TempDir::new().unwrap();
    let registry = ModelRegistry::new(tmp.path().join("models"));
    let rt = tokio::runtime::Runtime::new().unwrap();
    let result = rt.block_on(registry.ensure_model("NonExistentModel-999B"));
    assert!(matches!(result, Err(InferenceError::ModelNotFound(_))));
}
/// Removing a catalog model that was never downloaded succeeds (it is not an error).
///
/// A fresh temporary directory guarantees the model really is absent; with a
/// shared fixed `/tmp` path, leftovers from another run could make this test
/// exercise the "remove existing" path instead.
#[test]
fn model_registry_remove_nonexistent() {
    let tmp = TempDir::new().unwrap();
    let registry = ModelRegistry::new(tmp.path().join("models"));
    let result = registry.remove_model("Qwen3-0.6B");
    assert!(result.is_ok());
}
/// With only the legacy `Qwen3-30B-A3B-MLX` directory present (marked by an empty
/// `config.json`), the unified registry reports exactly one curated upgrade to
/// the vLLM-MLX Qwen3.6 target.
#[test]
fn unified_registry_reports_curated_model_upgrades() {
    let tmp = TempDir::new().unwrap();
    let models_dir = tmp.path().join("models");
    let legacy_dir = models_dir.join("Qwen3-30B-A3B-MLX");
    std::fs::create_dir_all(&legacy_dir).unwrap();
    std::fs::write(legacy_dir.join("config.json"), "{}").unwrap();

    let registry = car_inference::UnifiedRegistry::new(models_dir);
    let upgrades = registry.available_upgrades();
    assert_eq!(upgrades.len(), 1);

    let only = &upgrades[0];
    assert_eq!(only.from_id, "mlx/qwen3-30b-a3b:4bit");
    assert_eq!(only.to_id, "vllm-mlx/qwen3.6-35b-a3b:4bit");
    assert!(!only.target_pullable);
    assert!(only.remove_old_supported);
}
/// The Qwen3.6 vLLM-MLX entry is present in the unified catalog with its family,
/// context length, source kind, and code + vision capabilities.
#[test]
fn unified_registry_includes_qwen36_vllm_target() {
    use car_inference::{ModelCapability, ModelSource};

    let tmp = TempDir::new().unwrap();
    let registry = car_inference::UnifiedRegistry::new(tmp.path().join("models"));
    let target = registry
        .get("vllm-mlx/qwen3.6-35b-a3b:4bit")
        .expect("qwen3.6 target");

    assert_eq!(target.family, "qwen3.6");
    assert_eq!(target.context_length, 262_144);
    assert!(matches!(target.source, ModelSource::VllmMlx { .. }));
    assert!(target.has_capability(ModelCapability::Code));
    assert!(target.has_capability(ModelCapability::Vision));
}
/// The service exposes 11 tool schemas, including the core inference, media and
/// speech tools. `infer` requires a `prompt`, `embed` is idempotent and cacheable,
/// and `classify` is idempotent.
#[test]
fn service_schemas_defined() {
    let schemas = car_inference::service::all_schemas();
    assert_eq!(schemas.len(), 11);

    let names: Vec<&str> = schemas.iter().map(|s| s.name.as_str()).collect();
    for expected in [
        "infer",
        "infer.grounded",
        "embed",
        "classify",
        "generate_image",
        "generate_video",
        "transcribe",
        "synthesize",
    ] {
        assert!(names.contains(&expected));
    }

    let by_name = |name: &str| schemas.iter().find(|s| s.name == name).unwrap();

    let required = by_name("infer").parameters.get("required").unwrap();
    assert!(required.as_array().unwrap().iter().any(|v| v == "prompt"));

    let embed = by_name("embed");
    assert!(embed.idempotent);
    assert!(embed.cache_ttl_secs.is_some());

    assert!(by_name("classify").idempotent);
}
/// Pins the sampling defaults of `GenerateParams`.
#[test]
fn generate_params_defaults() {
    let defaults = car_inference::GenerateParams::default();
    let close = |actual: f64, expected: f64| (actual - expected).abs() < f64::EPSILON;

    assert!(close(defaults.temperature, 0.7));
    assert!(close(defaults.top_p, 0.9));
    assert_eq!(defaults.top_k, 0);
    assert_eq!(defaults.max_tokens, 4096);
    assert!(defaults.stop.is_empty());
}
/// Partially specified JSON deserializes into `GenerateParams`; given fields are
/// honoured and omitted ones (here `top_p`) take their defaults.
#[test]
fn generate_params_serde() {
    let parsed: car_inference::GenerateParams =
        serde_json::from_str(r#"{"temperature": 0.0, "max_tokens": 100}"#).unwrap();

    assert_eq!(parsed.temperature, 0.0);
    assert_eq!(parsed.max_tokens, 100);
    assert!((parsed.top_p - 0.9).abs() < f64::EPSILON);
}
/// An engine built from an explicit config exposes the six catalog models.
///
/// The models directory is a per-test temporary directory rather than a fixed
/// `/tmp` path, so runs are isolated and portable.
#[test]
fn inference_engine_creation() {
    let tmp = TempDir::new().unwrap();
    let config = InferenceConfig {
        models_dir: tmp.path().join("models"),
        device: Some(Device::Cpu),
        generation_model: "Qwen3-0.6B".to_string(),
        preferred_generation_model: None,
        embedding_model: "Qwen3-0.6B".to_string(),
        preferred_embedding_model: None,
        classification_model: "Qwen3-0.6B".to_string(),
        preferred_classification_model: None,
    };
    let engine = car_inference::InferenceEngine::new(config);
    let models = engine.list_models();
    assert_eq!(models.len(), 6);
}
/// Loads `KEY=VALUE` pairs from the nearest `.env` file into the process
/// environment. Variables that are already set are never overridden.
///
/// The search starts at the crate's manifest directory (captured at compile
/// time by cargo), or at the current working directory if that is unavailable.
/// It then walks up through the parent directories. Only the first `.env`
/// found is read; if it cannot be read, nothing is loaded. Line syntax is
/// handled by [`parse_env_line`].
fn load_repo_env() {
    // Previously this started from a hard-coded developer checkout path, so the
    // `.env` was never found on any other machine.
    let start = option_env!("CARGO_MANIFEST_DIR")
        .map(PathBuf::from)
        .or_else(|| std::env::current_dir().ok());
    let Some(env_file) = start
        .iter()
        .flat_map(|dir| dir.ancestors())
        .map(|dir| dir.join(".env"))
        .find(|candidate| candidate.exists())
    else {
        return;
    };
    let Ok(content) = std::fs::read_to_string(&env_file) else {
        return;
    };
    for (key, value) in content.lines().filter_map(parse_env_line) {
        if std::env::var(key).is_err() {
            // SAFETY: `set_var` is unsafe because other threads may read the
            // environment concurrently. NOTE(review): libtest runs tests on
            // several threads, so this is only sound if no concurrently running
            // test reads or writes the environment; the callers here are
            // `#[ignore]`d live tests — confirm they are run in isolation.
            unsafe { std::env::set_var(key, value) };
        }
    }
}

/// Parses one `.env` line into a trimmed `(key, value)` pair.
///
/// Returns `None` for blank lines, `#` comments, lines without `=`, and lines
/// with an empty key (which `set_var` would reject with a panic). An optional
/// leading `export ` is ignored. A value wrapped in matching double or single
/// quotes has those quotes removed. Everything after the first `=` belongs to
/// the value, so `URL=a=b` yields `("URL", "a=b")`.
fn parse_env_line(line: &str) -> Option<(&str, &str)> {
    let line = line.trim();
    if line.is_empty() || line.starts_with('#') {
        return None;
    }
    let line = line.strip_prefix("export ").unwrap_or(line);
    let (key, value) = line.split_once('=')?;
    let key = key.trim();
    if key.is_empty() {
        return None;
    }
    let value = value.trim();
    let value = ['"', '\'']
        .iter()
        .find_map(|&quote| value.strip_prefix(quote)?.strip_suffix(quote))
        .unwrap_or(value);
    Some((key, value))
}
/// Points the speech stack at the managed runtime in `$HOME/.car/speech-runtime`
/// and selects `python3.13` as its interpreter.
///
/// Both `CAR_SPEECH_RUNTIME_DIR` and `CAR_SPEECH_PYTHON` are overwritten
/// unconditionally.
///
/// # Panics
/// Panics if `HOME` is unset or not valid Unicode.
fn configure_managed_speech_runtime() {
    let home = std::env::var("HOME").expect("HOME should be set");
    let runtime_dir = PathBuf::from(home).join(".car").join("speech-runtime");
    let settings = [
        ("CAR_SPEECH_RUNTIME_DIR", runtime_dir.into_os_string()),
        ("CAR_SPEECH_PYTHON", "python3.13".into()),
    ];
    for (key, value) in settings {
        // SAFETY: `set_var` is unsafe because other threads may read the
        // environment concurrently. NOTE(review): only sound if no other test
        // touches the environment while this runs; the callers are `#[ignore]`d
        // live tests — confirm they are run in isolation.
        unsafe { std::env::set_var(key, value) };
    }
}
/// Builds a CPU-only engine rooted at `models_dir`. It uses `Qwen3-0.6B` for
/// generation and classification, `Qwen3-Embedding-0.6B` for embeddings, and
/// no preferred-model overrides.
fn test_engine(models_dir: PathBuf) -> car_inference::InferenceEngine {
    let small = "Qwen3-0.6B";
    let config = InferenceConfig {
        models_dir,
        device: Some(Device::Cpu),
        generation_model: small.to_string(),
        preferred_generation_model: None,
        embedding_model: "Qwen3-Embedding-0.6B".to_string(),
        preferred_embedding_model: None,
        classification_model: small.to_string(),
        preferred_classification_model: None,
    };
    car_inference::InferenceEngine::new(config)
}
/// Loads the repo `.env` (without overriding existing variables), then panics
/// unless `key` is set to a valid-Unicode value.
fn require_env(key: &str) {
    load_repo_env();
    let present = std::env::var(key).is_ok();
    assert!(
        present,
        "{key} must be set in the environment or repo .env"
    );
}
/// Returns a 32x32 solid-black RGB PNG, encoded with the standard base64 alphabet.
///
/// # Panics
/// Panics if PNG encoding fails.
fn black_png_base64() -> String {
    let pixels = ImageBuffer::from_pixel(32, 32, Rgb([0, 0, 0]));
    let mut sink = Cursor::new(Vec::new());
    DynamicImage::ImageRgb8(pixels)
        .write_to(&mut sink, ImageFormat::Png)
        .expect("should encode PNG");
    base64::engine::general_purpose::STANDARD.encode(sink.into_inner())
}
/// Live round trip through the local speech stack. It synthesizes a short phrase
/// to a WAV file, transcribes that file back, and checks that the transcript is
/// non-empty and contains at least one word from the phrase.
#[test]
#[ignore = "live local speech roundtrip using mlx-audio and downloaded models"]
fn local_speech_roundtrip_live() {
    configure_managed_speech_runtime();
    let rt = tokio::runtime::Runtime::new().unwrap();
    rt.block_on(async {
        let tmp = TempDir::new().unwrap();
        // Same engine configuration as every other live test.
        let engine = test_engine(tmp.path().join("models"));
        let audio_path = tmp.path().join("local-roundtrip.wav");

        let synth = engine
            .synthesize(car_inference::SynthesizeRequest {
                text: "Testing CAR local speech roundtrip.".to_string(),
                voice: Some("Chelsie".to_string()),
                language: Some("en".to_string()),
                speed: Some(1.0),
                output_path: Some(audio_path.display().to_string()),
                ..car_inference::SynthesizeRequest::default()
            })
            .await
            .expect("local synthesis should succeed");
        assert!(std::fs::metadata(&synth.audio_path).is_ok());

        let transcript = engine
            .transcribe(car_inference::TranscribeRequest {
                audio_path: synth.audio_path.clone(),
                model: None,
                language: None,
                prompt: None,
                timestamps: false,
            })
            .await
            .expect("local transcription should succeed");
        assert!(!transcript.text.trim().is_empty());

        // Lowercase once; accept any recognisable word from the phrase.
        let heard = transcript.text.to_lowercase();
        assert!(
            ["testing", "car", "speech"]
                .iter()
                .any(|word| heard.contains(*word)),
            "transcript did not contain an expected word: {}",
            transcript.text
        );
    });
}
/// Live round trip through ElevenLabs. It synthesizes with `eleven_flash_v2_5`,
/// transcribes with `scribe_v1`, and checks the transcript contains at least one
/// word from the phrase.
///
/// Requires `ELEVENLABS_API_KEY` in the environment or the repo `.env`.
#[test]
#[ignore = "live ElevenLabs speech roundtrip using ELEVENLABS_API_KEY"]
fn elevenlabs_speech_roundtrip_live() {
    // Same key check (and message) as the other live provider tests.
    require_env("ELEVENLABS_API_KEY");
    configure_managed_speech_runtime();
    let rt = tokio::runtime::Runtime::new().unwrap();
    rt.block_on(async {
        let tmp = TempDir::new().unwrap();
        let engine = test_engine(tmp.path().join("models"));
        let audio_path = tmp.path().join("elevenlabs-roundtrip.wav");

        let synth = engine
            .synthesize(car_inference::SynthesizeRequest {
                text: "Testing CAR ElevenLabs speech roundtrip.".to_string(),
                model: Some("eleven_flash_v2_5".to_string()),
                voice: Some("JBFqnCBsd6RMkjVDRZzb".to_string()),
                language: Some("en".to_string()),
                output_path: Some(audio_path.display().to_string()),
                ..car_inference::SynthesizeRequest::default()
            })
            .await
            .expect("ElevenLabs synthesis should succeed");
        assert!(std::fs::metadata(&synth.audio_path).is_ok());

        let transcript = engine
            .transcribe(car_inference::TranscribeRequest {
                audio_path: synth.audio_path.clone(),
                model: Some("scribe_v1".to_string()),
                language: Some("en".to_string()),
                prompt: None,
                timestamps: false,
            })
            .await
            .expect("ElevenLabs transcription should succeed");
        assert!(!transcript.text.trim().is_empty());

        let heard = transcript.text.to_lowercase();
        assert!(
            ["testing", "car", "elevenlabs"]
                .iter()
                .any(|word| heard.contains(*word)),
            "transcript did not contain an expected word: {}",
            transcript.text
        );
    });
}
/// Live check that Gemini accepts an inline base64 PNG alongside a text prompt.
/// Only that the request succeeds is asserted; the reply is not inspected.
#[test]
#[ignore = "live Gemini vision request using GOOGLE_API_KEY"]
fn gemini_vision_live() {
    require_env("GOOGLE_API_KEY");
    let runtime = tokio::runtime::Runtime::new().unwrap();
    runtime.block_on(async {
        let tmp = TempDir::new().unwrap();
        let engine = test_engine(tmp.path().join("models"));
        let picture = ContentBlock::ImageBase64 {
            data: black_png_base64(),
            media_type: "image/png".to_string(),
        };
        let request = GenerateRequest {
            prompt: "What color is this image? Reply with one word.".to_string(),
            model: Some("gemini-2.5-flash".to_string()),
            params: car_inference::GenerateParams {
                temperature: 0.0,
                max_tokens: 64,
                ..Default::default()
            },
            context: None,
            tools: None,
            images: Some(vec![picture]),
            messages: None,
            cache_control: false,
            response_format: None,
            intent: None,
        };
        let _reply = engine
            .generate(request)
            .await
            .expect("Gemini vision request should succeed");
    });
}
/// Live check that Gemini answers with a structured `echo` tool call rather than
/// plain text, and that the call carries `value == "gemini"`.
#[test]
#[ignore = "live Gemini tool call using GOOGLE_API_KEY"]
fn gemini_tool_use_live() {
    require_env("GOOGLE_API_KEY");
    let runtime = tokio::runtime::Runtime::new().unwrap();
    runtime.block_on(async {
        let tmp = TempDir::new().unwrap();
        let engine = test_engine(tmp.path().join("models"));
        // Flat tool definition: name, description, JSON-schema parameters.
        let echo_tool = json!({
            "name": "echo",
            "description": "Echo a value back.",
            "parameters": {
                "type": "object",
                "properties": {
                    "value": {"type": "string"}
                },
                "required": ["value"]
            }
        });
        let request = GenerateRequest {
            prompt: "Call the echo tool exactly once with {\"value\":\"gemini\"}. Do not answer with plain text.".to_string(),
            model: Some("gemini-2.5-flash".to_string()),
            params: car_inference::GenerateParams {
                temperature: 0.0,
                max_tokens: 64,
                ..Default::default()
            },
            context: None,
            tools: Some(vec![echo_tool]),
            images: None,
            messages: None,
            cache_control: false,
            response_format: None,
            intent: None,
        };
        let result = engine
            .generate_tracked(request)
            .await
            .expect("Gemini tool-use request should succeed");
        assert!(result.has_tool_calls(), "expected a Gemini tool call, got text: {}", result.text);

        let call = &result.tool_calls[0];
        assert_eq!(call.name, "echo");
        assert_eq!(
            call.arguments.get("value").and_then(|v| v.as_str()),
            Some("gemini")
        );
    });
}
/// Live check that OpenAI accepts an inline base64 PNG alongside a text prompt.
/// Only that the request succeeds is asserted; the reply is not inspected.
#[test]
#[ignore = "live OpenAI vision request using OPENAI_API_KEY"]
fn openai_vision_live() {
    require_env("OPENAI_API_KEY");
    let runtime = tokio::runtime::Runtime::new().unwrap();
    runtime.block_on(async {
        let tmp = TempDir::new().unwrap();
        let engine = test_engine(tmp.path().join("models"));
        let picture = ContentBlock::ImageBase64 {
            data: black_png_base64(),
            media_type: "image/png".to_string(),
        };
        let request = GenerateRequest {
            prompt: "What color is this image? Reply with one word.".to_string(),
            model: Some("gpt-4.1-mini".to_string()),
            params: car_inference::GenerateParams {
                temperature: 0.0,
                max_tokens: 16,
                ..Default::default()
            },
            context: None,
            tools: None,
            images: Some(vec![picture]),
            messages: None,
            cache_control: false,
            response_format: None,
            intent: None,
        };
        let _reply = engine
            .generate(request)
            .await
            .expect("OpenAI vision request should succeed");
    });
}
/// Live check that OpenAI answers with a structured `echo` tool call rather than
/// plain text, and that the call carries `value == "openai"`.
#[test]
#[ignore = "live OpenAI tool call using OPENAI_API_KEY"]
fn openai_tool_use_live() {
    require_env("OPENAI_API_KEY");
    let runtime = tokio::runtime::Runtime::new().unwrap();
    runtime.block_on(async {
        let tmp = TempDir::new().unwrap();
        let engine = test_engine(tmp.path().join("models"));
        // Flat tool definition: name, description, JSON-schema parameters.
        let echo_tool = json!({
            "name": "echo",
            "description": "Echo a value back.",
            "parameters": {
                "type": "object",
                "properties": {
                    "value": {"type": "string"}
                },
                "required": ["value"]
            }
        });
        let request = GenerateRequest {
            prompt: "Call the echo tool exactly once with {\"value\":\"openai\"}. Do not answer with plain text.".to_string(),
            model: Some("gpt-4.1-mini".to_string()),
            params: car_inference::GenerateParams {
                temperature: 0.0,
                max_tokens: 64,
                ..Default::default()
            },
            context: None,
            tools: Some(vec![echo_tool]),
            images: None,
            messages: None,
            cache_control: false,
            response_format: None,
            intent: None,
        };
        let result = engine
            .generate_tracked(request)
            .await
            .expect("OpenAI tool-use request should succeed");
        assert!(result.has_tool_calls(), "expected an OpenAI tool call, got text: {}", result.text);

        let call = &result.tool_calls[0];
        assert_eq!(call.name, "echo");
        assert_eq!(
            call.arguments.get("value").and_then(|v| v.as_str()),
            Some("openai")
        );
    });
}
/// Live check that Anthropic accepts an inline base64 PNG alongside a text prompt.
/// Only that the request succeeds is asserted; the reply is not inspected.
#[test]
#[ignore = "live Anthropic vision request using ANTHROPIC_API_KEY"]
fn anthropic_vision_live() {
    require_env("ANTHROPIC_API_KEY");
    let runtime = tokio::runtime::Runtime::new().unwrap();
    runtime.block_on(async {
        let tmp = TempDir::new().unwrap();
        let engine = test_engine(tmp.path().join("models"));
        let picture = ContentBlock::ImageBase64 {
            data: black_png_base64(),
            media_type: "image/png".to_string(),
        };
        let request = GenerateRequest {
            prompt: "What color is this image? Reply with one word.".to_string(),
            model: Some("claude-haiku-4-5".to_string()),
            params: car_inference::GenerateParams {
                temperature: 0.0,
                max_tokens: 16,
                ..Default::default()
            },
            context: None,
            tools: None,
            images: Some(vec![picture]),
            messages: None,
            cache_control: false,
            response_format: None,
            intent: None,
        };
        let _reply = engine
            .generate(request)
            .await
            .expect("Anthropic vision request should succeed");
    });
}
/// Live check that Anthropic answers with a structured `echo` tool call rather
/// than plain text, and that the call carries `value == "anthropic"`.
///
/// Unlike the Gemini/OpenAI tests, the tool is given nested under a `"function"`
/// key. Presumably this exercises the engine's conversion of that wrapper shape;
/// NOTE(review): confirm this is intentional.
#[test]
#[ignore = "live Anthropic tool call using ANTHROPIC_API_KEY"]
fn anthropic_tool_use_live() {
    require_env("ANTHROPIC_API_KEY");
    let runtime = tokio::runtime::Runtime::new().unwrap();
    runtime.block_on(async {
        let tmp = TempDir::new().unwrap();
        let engine = test_engine(tmp.path().join("models"));
        let echo_tool = json!({
            "function": {
                "name": "echo",
                "description": "Echo a value back.",
                "parameters": {
                    "type": "object",
                    "properties": {
                        "value": {"type": "string"}
                    },
                    "required": ["value"]
                }
            }
        });
        let request = GenerateRequest {
            prompt: "Call the echo tool exactly once with {\"value\":\"anthropic\"}. Do not answer with plain text.".to_string(),
            model: Some("claude-haiku-4-5".to_string()),
            params: car_inference::GenerateParams {
                temperature: 0.0,
                max_tokens: 64,
                ..Default::default()
            },
            context: None,
            tools: Some(vec![echo_tool]),
            images: None,
            messages: None,
            cache_control: false,
            response_format: None,
            intent: None,
        };
        let result = engine
            .generate_tracked(request)
            .await
            .expect("Anthropic tool-use request should succeed");
        assert!(
            result.has_tool_calls(),
            "expected an Anthropic tool call, got text: {}",
            result.text
        );

        let call = &result.tool_calls[0];
        assert_eq!(call.name, "echo");
        assert_eq!(
            call.arguments.get("value").and_then(|v| v.as_str()),
            Some("anthropic")
        );
    });
}
/// A request carrying grounding `context` keeps it, and serializes it into JSON.
#[test]
fn generate_request_with_context() {
    let facts = "## Facts\n- Project uses Rust\n- Deadline is Friday";
    let req = car_inference::GenerateRequest {
        prompt: "What should I do?".into(),
        model: None,
        params: Default::default(),
        context: Some(facts.into()),
        tools: None,
        images: None,
        messages: None,
        cache_control: false,
        response_format: None,
        intent: None,
    };
    assert!(req.context.is_some());

    let encoded = serde_json::to_string(&req).unwrap();
    assert!(encoded.contains("context"));
    assert!(encoded.contains("Deadline is Friday"));
}
/// Older payloads with no `context` field still deserialize, yielding `None`.
#[test]
fn generate_request_without_context_backward_compat() {
    let legacy_payload = r#"{"prompt":"hello","params":{"temperature":0.7}}"#;
    let parsed: car_inference::GenerateRequest = serde_json::from_str(legacy_payload).unwrap();
    assert!(parsed.context.is_none());
}
/// Serializing then deserializing a request preserves its prompt, context and model.
#[test]
fn generate_request_with_context_roundtrip() {
    let original = car_inference::GenerateRequest {
        prompt: "test prompt".into(),
        model: Some("Qwen3-0.6B".into()),
        params: Default::default(),
        context: Some("some context".into()),
        tools: None,
        images: None,
        messages: None,
        cache_control: false,
        response_format: None,
        intent: None,
    };
    let restored: car_inference::GenerateRequest =
        serde_json::from_str(&serde_json::to_string(&original).unwrap()).unwrap();

    assert_eq!(restored.prompt, "test prompt");
    assert_eq!(restored.context.as_deref(), Some("some context"));
    assert_eq!(restored.model.as_deref(), Some("Qwen3-0.6B"));
}
/// An explicit JSON `null` for `context` deserializes to `None`.
#[test]
fn generate_request_with_null_context() {
    let parsed: car_inference::GenerateRequest =
        serde_json::from_str(r#"{"prompt":"hello","context":null}"#).unwrap();
    assert!(parsed.context.is_none());
}
/// `infer.grounded` is among the 11 schemas, requires `prompt`, and is not idempotent.
#[test]
fn infer_grounded_schema_defined() {
    let schemas = car_inference::service::all_schemas();
    assert_eq!(schemas.len(), 11);

    let names: Vec<&str> = schemas.iter().map(|s| s.name.as_str()).collect();
    assert!(names.contains(&"infer.grounded"));

    let grounded = schemas
        .iter()
        .find(|s| s.name == "infer.grounded")
        .unwrap();
    let required_fields = grounded.parameters.get("required").unwrap();
    assert!(required_fields
        .as_array()
        .unwrap()
        .iter()
        .any(|v| v == "prompt"));
    assert!(!grounded.idempotent);
}
/// A short factual question is classified as `Simple`.
#[test]
fn routing_simple_question() {
    assert_eq!(
        TaskComplexity::assess("What is the capital of France?"),
        TaskComplexity::Simple
    );
}
/// A request to write a function is classified as `Code`.
#[test]
fn routing_code_task() {
    assert_eq!(
        TaskComplexity::assess("Write a function to sort a list"),
        TaskComplexity::Code
    );
}
/// An open-ended trade-off analysis is classified as `Complex`.
#[test]
fn routing_complex_reasoning() {
    let prompt =
        "Analyze the trade-offs between microservices and monolithic architecture for our use case";
    assert_eq!(TaskComplexity::assess(prompt), TaskComplexity::Complex);
}
/// A long, general prompt with no code or analysis cues falls through to `Medium`.
#[test]
fn routing_medium_default() {
    let prompt = "Tell me about the history of computing and how it has evolved over the decades from early mechanical calculation devices through transistors to the modern digital systems we use today in everyday life";
    assert_eq!(TaskComplexity::assess(prompt), TaskComplexity::Medium);
}
/// With a hardware budget that only admits `Qwen3-0.6B`, a complex prompt is
/// still classified as `Complex` but is routed to the small model.
#[test]
fn router_respects_hardware() {
    let mut hw = HardwareInfo::detect();
    // Cap `max_model_mb` at 1000 and recommend the smallest model.
    hw.max_model_mb = 1000;
    hw.recommended_model = "Qwen3-0.6B".to_string();
    let router = ModelRouter::new(hw);

    // Isolated, portable models directory instead of a fixed `/tmp` path.
    let tmp = TempDir::new().unwrap();
    let registry = ModelRegistry::new(tmp.path().join("models"));

    // Previously `®istry` (a mangled `&registry`), which did not compile.
    let decision = router.route_generate(
        "Analyze complex architecture decisions step by step",
        &registry,
    );
    assert_eq!(decision.model, "Qwen3-0.6B");
    assert_eq!(decision.complexity, TaskComplexity::Complex);
}
/// Debugging requests count as `Code` tasks.
#[test]
fn routing_repair_is_code() {
    assert_eq!(
        TaskComplexity::assess("Debug this failing test case"),
        TaskComplexity::Code
    );
}
/// A prompt containing a fenced code block is classified as `Code`.
#[test]
fn routing_backtick_code() {
    let prompt = "Here is my code:\n```rust\nfn main() {}\n```\nWhat does it do?";
    assert_eq!(TaskComplexity::assess(prompt), TaskComplexity::Code);
}
/// `HardwareInfo::detect()` reports a supported OS, plausible CPU/RAM figures,
/// a recommended model from the known Qwen3 set, and a context window in
/// `2048..=131072`.
#[test]
fn hardware_detect_returns_reasonable_values() {
    const KNOWN_MODELS: [&str; 10] = [
        "Qwen3-0.6B",
        "Qwen3-1.7B",
        "Qwen3-4B",
        "Qwen3-8B",
        "Qwen3-30B-A3B",
        "Qwen3-0.6B-MLX",
        "Qwen3-1.7B-MLX",
        "Qwen3-4B-MLX",
        "Qwen3-8B-MLX",
        "Qwen3-30B-A3B-MLX",
    ];

    let info = HardwareInfo::detect();

    assert!(
        ["macos", "linux", "windows"].iter().any(|os| info.os == *os),
        "unexpected os: {}",
        info.os
    );
    assert!(!info.arch.is_empty());
    assert!(info.cpu_cores >= 1);
    assert!(
        info.total_ram_mb >= 1024,
        "unexpected RAM: {} MB",
        info.total_ram_mb
    );
    assert!(
        KNOWN_MODELS.contains(&info.recommended_model.as_str()),
        "unexpected recommended model: {}",
        info.recommended_model
    );
    assert!(info.recommended_context >= 2048);
    assert!(info.recommended_context <= 131072);
    assert!(info.max_model_mb > 0);
}
/// A prompt-only video request resolves to text-to-video mode.
#[test]
fn generate_video_request_defaults_to_t2v() {
    use car_inference::{GenerateVideoRequest, VideoMode};

    let request = GenerateVideoRequest {
        prompt: "a dog running".into(),
        ..GenerateVideoRequest::default()
    };
    assert_eq!(request.effective_mode(), VideoMode::T2v);
}
/// Supplying an `image_path` without an explicit mode implies image-to-video.
#[test]
fn generate_video_request_infers_i2v_from_image_path() {
    use car_inference::{GenerateVideoRequest, VideoMode};

    let request = GenerateVideoRequest {
        prompt: "the dog jumps".into(),
        image_path: Some("/tmp/ref.png".into()),
        ..GenerateVideoRequest::default()
    };
    assert_eq!(request.effective_mode(), VideoMode::I2v);
}
/// An image-to-video request parsed from JSON keeps its prompt and `image_path`,
/// infers I2V, and writes `image_path` back out when serialized.
#[test]
fn generate_video_request_round_trips_i2v_through_json() {
    use car_inference::{GenerateVideoRequest, VideoMode};

    let parsed: GenerateVideoRequest = serde_json::from_value(json!({
        "prompt": "cat walking",
        "image_path": "/tmp/cat.jpg",
    }))
    .expect("deserialize");
    assert_eq!(parsed.prompt, "cat walking");
    assert_eq!(parsed.image_path.as_deref(), Some("/tmp/cat.jpg"));
    assert_eq!(parsed.effective_mode(), VideoMode::I2v);

    let reencoded = serde_json::to_value(&parsed).expect("serialize");
    let image_field = reencoded.get("image_path").and_then(|v| v.as_str());
    assert_eq!(image_field, Some("/tmp/cat.jpg"));
}
/// The `generate_video` tool schema exposes `image_path`, `audio_path` and `mode`.
/// The `mode` enum lists `t2v`, `i2v`, `audio_video` and `audio_ref_video`.
#[test]
fn generate_video_schema_exposes_i2v_fields() {
    let schemas = car_inference::service::all_schemas();
    let video = schemas
        .iter()
        .find(|s| s.name == "generate_video")
        .expect("generate_video schema");
    let props = video
        .parameters
        .get("properties")
        .and_then(|v| v.as_object())
        .expect("properties object");

    for (field, why) in [
        ("image_path", "i2v field surfaced in tool schema"),
        ("audio_path", "audio reference field surfaced in tool schema"),
        ("mode", "mode selector surfaced in tool schema"),
    ] {
        assert!(props.contains_key(field), "{}", why);
    }

    let mode = props.get("mode").and_then(|v| v.as_object()).unwrap();
    let enum_vals: Vec<&str> = mode
        .get("enum")
        .and_then(|v| v.as_array())
        .unwrap()
        .iter()
        .map(|v| v.as_str().unwrap())
        .collect();
    assert!(enum_vals.contains(&"t2v") && enum_vals.contains(&"i2v"));
    assert!(enum_vals.contains(&"audio_video"), "audio_video exposed in schema");
    assert!(
        enum_vals.contains(&"audio_ref_video"),
        "audio_ref_video exposed in schema"
    );
}
/// An explicit `"audio_video"` mode is honoured and passes validation on its own.
#[test]
fn generate_video_request_explicit_audio_video_mode() {
    use car_inference::{GenerateVideoRequest, VideoMode};

    let request: GenerateVideoRequest = serde_json::from_value(json!({
        "prompt": "thunderstorm over the plains",
        "mode": "audio_video",
    }))
    .expect("deserialize");
    assert_eq!(request.effective_mode(), VideoMode::AudioVideo);
    assert!(request.validate().is_ok());
}
/// Validation rejects `audio_video` combined with an `image_path`, and the error
/// names the offending mode.
#[test]
fn generate_video_request_rejects_audio_video_with_image() {
    let request: car_inference::GenerateVideoRequest = serde_json::from_value(json!({
        "prompt": "x",
        "mode": "audio_video",
        "image_path": "/tmp/x.png",
    }))
    .expect("deserialize");
    let problem = request
        .validate()
        .expect_err("audio_video + image_path should be rejected");
    assert!(problem.contains("audio_video"), "error mentions mode: {}", problem);
}
/// Validation rejects an explicit `t2v` mode when an `image_path` is also given.
#[test]
fn generate_video_request_rejects_t2v_with_image() {
    let request: car_inference::GenerateVideoRequest = serde_json::from_value(json!({
        "prompt": "x",
        "mode": "t2v",
        "image_path": "/tmp/x.png",
    }))
    .expect("deserialize");
    assert!(request.validate().is_err());
}
/// `VideoMode::AudioVideo` serializes as `"audio_video"` and survives a JSON round trip.
#[test]
fn generate_video_request_audio_video_round_trip() {
    use car_inference::{GenerateVideoRequest, VideoMode};

    let request = GenerateVideoRequest {
        prompt: "waves crashing".into(),
        mode: Some(VideoMode::AudioVideo),
        ..GenerateVideoRequest::default()
    };
    let encoded = serde_json::to_value(&request).unwrap();
    assert_eq!(encoded.get("mode").and_then(|m| m.as_str()), Some("audio_video"));

    let decoded: GenerateVideoRequest = serde_json::from_value(encoded).unwrap();
    assert_eq!(decoded.effective_mode(), VideoMode::AudioVideo);
}
/// Supplying only `audio_path` implies audio-referenced video. The request
/// validates, and both the path and the inferred mode survive a JSON round trip.
#[test]
fn generate_video_request_audio_ref_round_trip() {
    use car_inference::{GenerateVideoRequest, VideoMode};

    let request = GenerateVideoRequest {
        prompt: "chorus explodes into color".into(),
        audio_path: Some("/tmp/chorus.wav".into()),
        ..GenerateVideoRequest::default()
    };
    assert_eq!(request.effective_mode(), VideoMode::AudioRefVideo);
    request.validate().unwrap();

    let encoded = serde_json::to_value(&request).unwrap();
    let audio_field = encoded.get("audio_path").and_then(|m| m.as_str());
    assert_eq!(audio_field, Some("/tmp/chorus.wav"));

    let decoded: GenerateVideoRequest = serde_json::from_value(encoded).unwrap();
    assert_eq!(decoded.effective_mode(), VideoMode::AudioRefVideo);
}