mod api;
mod api_errors;
mod auto_discovery;
mod cli;
mod engine;
mod main_integration;
mod model_registry;
mod openai_compat;
mod port_manager;
mod server;
mod templates;
mod util { pub mod diag; }
use clap::Parser;
use model_registry::{ModelEntry, Registry};
use std::net::SocketAddr;
use tracing::info;
use std::sync::Arc;
pub struct AppState {
pub engine: Box<dyn engine::InferenceEngine>,
pub registry: Registry,
}
/// Entry point for the `shimmy` binary.
///
/// Initializes `tracing` with a filter taken from the environment, then parses the command
/// line. It builds the shared [`AppState`]: an `InferenceEngineAdapter` plus a [`Registry`]
/// seeded with a manual `phi3-lora` entry. Finally it dispatches to the selected subcommand.
///
/// Environment:
/// - `SHIMMY_BASE_GGUF`: base model path for `phi3-lora` (default `./models/phi3-mini.gguf`).
/// - `SHIMMY_LORA_GGUF`: optional LoRA adapter path for `phi3-lora`.
///
/// # Errors
/// Returns an error in any of these cases:
/// - the `serve` bind address does not parse as a socket address;
/// - a named model is not registered (`probe`, `bench`, `generate`);
/// - loading or generation fails for `bench`/`generate`;
/// - `server::run` returns an error.
///
/// Some failures exit the process directly instead: `serve` exits with status 1 when no
/// model is available, and `probe` exits with status 2 when loading fails.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    tracing_subscriber::fmt()
        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
        .init();
    let cli = cli::Cli::parse();

    // Always register one manual entry; its paths come from the environment.
    let mut reg = Registry::with_discovery();
    reg.register(ModelEntry {
        name: "phi3-lora".into(),
        base_path: std::env::var("SHIMMY_BASE_GGUF")
            .unwrap_or_else(|_| "./models/phi3-mini.gguf".into())
            .into(),
        lora_path: std::env::var("SHIMMY_LORA_GGUF").ok().map(Into::into),
        template: Some("chatml".into()),
        ctx_len: Some(4096),
        n_threads: None,
    });

    let engine: Box<dyn engine::InferenceEngine> =
        Box::new(engine::adapter::InferenceEngineAdapter::new());
    let state = Arc::new(AppState { engine, registry: reg });

    match cli.cmd {
        cli::Command::Serve { .. } => {
            let bind_address = cli.cmd.get_bind_address();
            // The bind address is user input: report a malformed value instead of panicking.
            let addr = bind_address
                .parse::<SocketAddr>()
                .map_err(|e| anyhow::anyhow!("bad bind address {}: {}", bind_address, e))?;
            println!("🚀 Starting Shimmy server on {}", bind_address);

            // `phi3-lora` is always registered above, so "<= 1" means nothing else was added
            // manually. In that case serve from a copy of the registry extended with every
            // auto-discovered model, backed by the llama engine.
            let serve_state = if state.registry.list().len() <= 1 {
                let mut enhanced_state = AppState {
                    engine: Box::new(engine::llama::LlamaEngine::new()),
                    registry: state.registry.clone(),
                };
                enhanced_state.registry.auto_register_discovered();
                Arc::new(enhanced_state)
            } else {
                state
            };

            let available_models = serve_state.registry.list_all_available();
            if available_models.is_empty() {
                exit_no_models();
            }
            info!(
                %addr,
                models = %available_models.len(),
                "shimmy serving with {} available models",
                available_models.len()
            );
            server::run(addr, serve_state).await?;
        }
        cli::Command::List => {
            let manual_models = state.registry.list();
            if !manual_models.is_empty() {
                println!("📋 Registered Models:");
                for e in &manual_models {
                    println!(" {} => {:?}", e.name, e.base_path);
                }
            }
            // Borrow the discovered models; printing does not need an owned copy.
            let auto_discovered = &state.registry.discovered_models;
            if !auto_discovered.is_empty() {
                if !manual_models.is_empty() {
                    println!();
                }
                println!("🔍 Auto-Discovered Models:");
                for (name, model) in auto_discovered {
                    let size_mb = model.size_bytes / (1024 * 1024);
                    let type_info = match (&model.parameter_count, &model.quantization) {
                        (Some(params), Some(quant)) => format!(" ({}·{})", params, quant),
                        (Some(params), None) => format!(" ({})", params),
                        (None, Some(quant)) => format!(" ({})", quant),
                        _ => String::new(),
                    };
                    let lora_info = if model.lora_path.is_some() { " + LoRA" } else { "" };
                    println!(
                        " {} => {:?} [{}MB{}{}]",
                        name, model.path, size_mb, type_info, lora_info
                    );
                }
            }
            let all_available = state.registry.list_all_available();
            if all_available.is_empty() {
                println!("❌ No models found. Set SHIMMY_BASE_GGUF or place .gguf files in ./models/");
            } else {
                println!("\n✅ Total available models: {}", all_available.len());
            }
        }
        cli::Command::Discover => {
            println!("🔍 Refreshing model discovery...");
            // A fresh registry re-runs discovery independently of `state`.
            let registry = Registry::with_discovery();
            let discovered = &registry.discovered_models;
            if discovered.is_empty() {
                println!("❌ No models found in search paths:");
                let discovery = crate::auto_discovery::ModelAutoDiscovery::new();
                for path in &discovery.search_paths {
                    println!(" • {:?}", path);
                }
                println!(" • Ollama models (if installed)");
                println!("\n💡 Try downloading a GGUF model or setting SHIMMY_BASE_GGUF");
            } else {
                println!("✅ Found {} models:", discovered.len());
                for (name, model) in discovered {
                    let size_mb = model.size_bytes / (1024 * 1024);
                    let lora_info = if model.lora_path.is_some() { " + LoRA" } else { "" };
                    println!(" {} [{}MB{}]", name, size_mb, lora_info);
                    println!(" Base: {:?}", model.path);
                    if let Some(lora) = &model.lora_path {
                        println!(" LoRA: {:?}", lora);
                    }
                }
            }
        }
        cli::Command::Probe { name } => {
            let Some(spec) = state.registry.to_spec(&name) else {
                anyhow::bail!("no model {name}");
            };
            match state.engine.load(&spec).await {
                Ok(_) => println!("ok: loaded {name}"),
                Err(e) => {
                    eprintln!("probe failed: {e}");
                    std::process::exit(2);
                }
            }
        }
        cli::Command::Bench { name, max_tokens } => {
            let Some(spec) = state.registry.to_spec(&name) else {
                anyhow::bail!("no model {name}");
            };
            let loaded = state.engine.load(&spec).await?;
            let t0 = std::time::Instant::now();
            let out = loaded
                .generate(
                    "Say hi.",
                    engine::GenOptions { max_tokens, stream: false, ..Default::default() },
                    None,
                )
                .await?;
            let elapsed = t0.elapsed();
            // Model output is arbitrary UTF-8. Cut on a char boundary so the preview cannot
            // panic mid-codepoint.
            println!("bench output (truncated): {}", truncate_utf8(&out, 120));
            println!("elapsed: {:?}", elapsed);
        }
        cli::Command::Generate { name, prompt, max_tokens } => {
            let Some(spec) = state.registry.to_spec(&name) else {
                anyhow::bail!("no model {name}");
            };
            let loaded = state.engine.load(&spec).await?;
            let out = loaded
                .generate(
                    &prompt,
                    engine::GenOptions { max_tokens, stream: false, ..Default::default() },
                    None,
                )
                .await?;
            println!("{}", out);
        }
    }
    Ok(())
}

/// Prints guidance on how to make a model available, then exits with status 1.
fn exit_no_models() -> ! {
    eprintln!("❌ No models available. Please:");
    eprintln!(" • Set SHIMMY_BASE_GGUF environment variable, or");
    eprintln!(" • Place .gguf files in ./models/ directory, or");
    eprintln!(" • Place .gguf files in ~/.cache/huggingface/hub/");
    std::process::exit(1);
}

/// Returns the longest prefix of `text` that is at most `max_bytes` bytes long and ends on
/// a UTF-8 character boundary.
///
/// Plain `&text[..max_bytes]` panics when `max_bytes` lands inside a multi-byte character
/// (emoji, CJK, accented letters). For ASCII text the result is identical.
fn truncate_utf8(text: &str, max_bytes: usize) -> &str {
    let mut end = text.len().min(max_bytes);
    // Offset 0 is always a char boundary, so this loop terminates.
    while !text.is_char_boundary(end) {
        end -= 1;
    }
    &text[..end]
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::InferenceEngine;
    use std::env;
    use std::sync::{Arc, Mutex, MutexGuard};

    /// Serializes tests that read or write process environment variables.
    ///
    /// `env::set_var`/`env::remove_var` change state shared by every thread, and the default
    /// test harness runs tests concurrently. Without this lock, one test can observe or
    /// clobber another test's values.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// The variables `main` reads when registering the manual `phi3-lora` model.
    const SHIMMY_VARS: [&str; 2] = ["SHIMMY_BASE_GGUF", "SHIMMY_LORA_GGUF"];

    /// Holds [`ENV_LOCK`] and restores a fixed set of environment variables when dropped,
    /// including when the owning test panics.
    struct EnvSnapshot {
        saved: Vec<(&'static str, Option<std::ffi::OsString>)>,
        // `Drop::drop` runs before fields are dropped, so the lock is still held while
        // the saved values are written back.
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvSnapshot {
        /// Locks [`ENV_LOCK`] and records the current value of each variable in `vars`.
        ///
        /// A lock poisoned by a previously failed test is recovered, so one failure does
        /// not cascade into every later env test.
        fn take(vars: &[&'static str]) -> Self {
            let lock = ENV_LOCK.lock().unwrap_or_else(|poisoned| poisoned.into_inner());
            let saved = vars.iter().map(|&var| (var, env::var_os(var))).collect();
            Self { saved, _lock: lock }
        }
    }

    impl Drop for EnvSnapshot {
        fn drop(&mut self) {
            for (var, value) in &self.saved {
                match value {
                    Some(value) => env::set_var(var, value),
                    None => env::remove_var(var),
                }
            }
        }
    }

    /// Engine stub whose `load` always succeeds with a [`MockLoadedModel`].
    struct MockEngine;

    #[async_trait::async_trait]
    impl engine::InferenceEngine for MockEngine {
        async fn load(
            &self,
            _spec: &engine::ModelSpec,
        ) -> anyhow::Result<Box<dyn engine::LoadedModel>> {
            Ok(Box::new(MockLoadedModel))
        }
    }

    /// Loaded-model stub that echoes the prompt and `max_tokens` instead of generating.
    struct MockLoadedModel;

    #[async_trait::async_trait]
    impl engine::LoadedModel for MockLoadedModel {
        async fn generate(
            &self,
            prompt: &str,
            opts: engine::GenOptions,
            _on_token: Option<Box<dyn FnMut(String) + Send>>,
        ) -> anyhow::Result<String> {
            Ok(format!("Generated response to: {} (max_tokens: {})", prompt, opts.max_tokens))
        }
    }

    #[tokio::test]
    async fn test_main_initialization_paths() {
        let _env = EnvSnapshot::take(&SHIMMY_VARS);
        env::remove_var("SHIMMY_BASE_GGUF");
        env::remove_var("SHIMMY_LORA_GGUF");
        let mut reg = model_registry::Registry::with_discovery();
        reg.register(model_registry::ModelEntry {
            name: "phi3-lora".into(),
            base_path: "./models/phi3-mini.gguf".into(),
            lora_path: None,
            template: Some("chatml".into()),
            ctx_len: Some(4096),
            n_threads: None,
        });
        let engine: Box<dyn engine::InferenceEngine> =
            Box::new(engine::adapter::InferenceEngineAdapter::new());
        let state = Arc::new(AppState { engine, registry: reg });
        assert!(state.registry.list().iter().any(|e| e.name == "phi3-lora"));
    }

    #[tokio::test]
    async fn test_environment_variable_handling() {
        let _env = EnvSnapshot::take(&SHIMMY_VARS);
        env::set_var("SHIMMY_BASE_GGUF", "/custom/path/model.gguf");
        env::set_var("SHIMMY_LORA_GGUF", "/custom/path/lora.safetensors");
        let base_path =
            env::var("SHIMMY_BASE_GGUF").unwrap_or_else(|_| "./models/phi3-mini.gguf".into());
        let lora_path = env::var("SHIMMY_LORA_GGUF").ok();
        assert_eq!(base_path, "/custom/path/model.gguf");
        assert_eq!(lora_path.as_deref(), Some("/custom/path/lora.safetensors"));

        // With the variable unset, `main` falls back to the bundled default path.
        env::remove_var("SHIMMY_BASE_GGUF");
        let fallback =
            env::var("SHIMMY_BASE_GGUF").unwrap_or_else(|_| "./models/phi3-mini.gguf".into());
        assert_eq!(fallback, "./models/phi3-mini.gguf");
    }

    #[test]
    fn test_serve_command_address_parsing() {
        use crate::port_manager::GLOBAL_PORT_ALLOCATOR;
        let dynamic_port = GLOBAL_PORT_ALLOCATOR
            .allocate_ephemeral_port("test-serve-parsing")
            .unwrap();
        let bind_str = format!("127.0.0.1:{}", dynamic_port);
        let addr: std::net::SocketAddr = bind_str.parse().expect("bad bind address");
        assert_eq!(addr.port(), dynamic_port);
        GLOBAL_PORT_ALLOCATOR.release_port(dynamic_port);
        let invalid_bind = "invalid:address";
        assert!(invalid_bind.parse::<std::net::SocketAddr>().is_err());
    }

    #[test]
    fn test_serve_command_model_count_logic() {
        let registry = model_registry::Registry::with_discovery();
        let manual_count = registry.list().len();
        // Mirrors the `serve` decision. The outcome depends on the machine, so this is a
        // smoke test of the registry calls `serve` relies on.
        let _should_auto_register = manual_count <= 1;
    }

    #[tokio::test]
    async fn test_list_command_execution_logic() {
        let mut registry = model_registry::Registry::with_discovery();
        registry.register(model_registry::ModelEntry {
            name: "test-model".into(),
            base_path: "./test.gguf".into(),
            lora_path: None,
            template: Some("chatml".into()),
            ctx_len: Some(2048),
            n_threads: None,
        });
        let manual_models = registry.list();
        assert!(!manual_models.is_empty());
        assert!(manual_models.iter().any(|e| e.name == "test-model"));
        let _auto_discovered = registry.discovered_models.clone();
        let _all_available = registry.list_all_available();
    }

    #[tokio::test]
    async fn test_discover_command_execution_logic() {
        let registry = model_registry::Registry::with_discovery();
        for (name, model) in registry.discovered_models.clone() {
            let _size_mb = model.size_bytes / (1024 * 1024);
            let _lora_info = if model.lora_path.is_some() { " + LoRA" } else { "" };
            assert!(!name.is_empty());
        }
    }

    #[tokio::test]
    async fn test_probe_command_execution_logic() {
        let mut registry = model_registry::Registry::with_discovery();
        registry.register(model_registry::ModelEntry {
            name: "test-probe-model".into(),
            base_path: "./test.gguf".into(),
            lora_path: None,
            template: Some("chatml".into()),
            ctx_len: Some(2048),
            n_threads: None,
        });
        let engine = MockEngine;
        let spec = registry
            .to_spec("test-probe-model")
            .expect("registered model must resolve to a spec");
        assert!(engine.load(&spec).await.is_ok(), "MockEngine::load never fails");
    }

    #[tokio::test]
    async fn test_bench_command_execution_logic() {
        let mut registry = model_registry::Registry::with_discovery();
        registry.register(model_registry::ModelEntry {
            name: "test-bench-model".into(),
            base_path: "./test.gguf".into(),
            lora_path: None,
            template: Some("chatml".into()),
            ctx_len: Some(2048),
            n_threads: None,
        });
        let engine = MockEngine;
        let max_tokens = 100;
        let spec = registry
            .to_spec("test-bench-model")
            .expect("registered model must resolve to a spec");
        let loaded = engine.load(&spec).await.unwrap();
        let t0 = std::time::Instant::now();
        let out = loaded
            .generate(
                "Say hi.",
                engine::GenOptions { max_tokens, stream: false, ..Default::default() },
                None,
            )
            .await
            .unwrap();
        let elapsed = t0.elapsed();
        let truncated = &out[..out.len().min(120)];
        assert!(!truncated.is_empty());
        assert!(elapsed.as_nanos() > 0);
    }

    #[tokio::test]
    async fn test_generate_command_execution_logic() {
        let mut registry = model_registry::Registry::with_discovery();
        registry.register(model_registry::ModelEntry {
            name: "test-gen-model".into(),
            base_path: "./test.gguf".into(),
            lora_path: None,
            template: Some("chatml".into()),
            ctx_len: Some(2048),
            n_threads: None,
        });
        let engine = MockEngine;
        let prompt = "Hello, world!";
        let max_tokens = 50;
        let spec = registry
            .to_spec("test-gen-model")
            .expect("registered model must resolve to a spec");
        let loaded = engine.load(&spec).await.unwrap();
        let out = loaded
            .generate(
                prompt,
                engine::GenOptions { max_tokens, stream: false, ..Default::default() },
                None,
            )
            .await
            .unwrap();
        assert!(out.contains("Generated response to: Hello, world!"));
    }

    #[test]
    fn test_command_execution_paths() {
        use crate::cli::{Cli, Command};
        use clap::Parser;
        let gen_args = vec!["shimmy", "generate", "test-model", "--prompt", "Hello", "--max-tokens", "50"];
        let cli = Cli::try_parse_from(gen_args).unwrap();
        match cli.cmd {
            Command::Generate { name, prompt, max_tokens } => {
                assert_eq!(name, "test-model");
                assert_eq!(prompt, "Hello");
                assert_eq!(max_tokens, 50);
            }
            _ => panic!("Expected Generate command"),
        }
    }

    #[tokio::test]
    async fn test_state_initialization() {
        use crate::engine::adapter::InferenceEngineAdapter;
        use crate::model_registry::Registry;
        let registry = Registry::with_discovery();
        let engine = Box::new(InferenceEngineAdapter::new());
        let state = std::sync::Arc::new(crate::AppState { engine, registry });
        assert_ne!(std::mem::size_of_val(&state), 0);
        let _models = state.registry.list();
    }

    #[test]
    fn test_serve_enhanced_state_logic() {
        let registry = model_registry::Registry::with_discovery();
        let mut enhanced_state = AppState {
            engine: Box::new(engine::llama::LlamaEngine::new()),
            registry: registry.clone(),
        };
        enhanced_state.registry.auto_register_discovered();
        let enhanced_state_arc = Arc::new(enhanced_state);
        // Availability depends on the models present on this machine; only exercise the path.
        let _available_models = enhanced_state_arc.registry.list_all_available();
    }

    #[test]
    fn test_model_registration_with_env_vars() {
        let _env = EnvSnapshot::take(&SHIMMY_VARS);
        env::set_var("SHIMMY_BASE_GGUF", "/test/base.gguf");
        env::set_var("SHIMMY_LORA_GGUF", "/test/lora.safetensors");
        let base_path =
            env::var("SHIMMY_BASE_GGUF").unwrap_or_else(|_| "./models/phi3-mini.gguf".into());
        let lora_path = env::var("SHIMMY_LORA_GGUF").ok().map(Into::into);
        let mut reg = model_registry::Registry::with_discovery();
        reg.register(model_registry::ModelEntry {
            name: "phi3-lora".into(),
            base_path: base_path.into(),
            lora_path,
            template: Some("chatml".into()),
            ctx_len: Some(4096),
            n_threads: None,
        });
        let models = reg.list();
        assert!(!models.is_empty());
        assert!(models.iter().any(|e| e.name == "phi3-lora"));
    }

    #[test]
    fn test_registry_model_methods() {
        let mut registry = model_registry::Registry::with_discovery();
        let initial_count = registry.list().len();
        let _discovered = registry.discovered_models.clone();
        let _all_available = registry.list_all_available();
        registry.register(model_registry::ModelEntry {
            name: "test".into(),
            base_path: "./test.gguf".into(),
            lora_path: None,
            template: None,
            ctx_len: None,
            n_threads: None,
        });
        let after_count = registry.list().len();
        assert!(after_count > initial_count);
        assert!(registry.to_spec("test").is_some());
        assert!(registry.to_spec("nonexistent").is_none());
    }

    #[test]
    fn test_app_state_creation() {
        use crate::engine::adapter::InferenceEngineAdapter;
        use crate::model_registry::Registry;
        let engine: Box<dyn engine::InferenceEngine> = Box::new(InferenceEngineAdapter::new());
        let registry = Registry::with_discovery();
        let state = AppState { engine, registry };
        let _models = state.registry.list();
        assert!(state.registry.to_spec("definitely-not-registered").is_none());
    }

    #[test]
    fn test_tracing_initialization() {
        use tracing_subscriber::EnvFilter;
        let env_filter = EnvFilter::from_default_env();
        // The hint depends on the caller's environment; just ensure the filter can be queried.
        let _hint = env_filter.max_level_hint();
    }

    #[tokio::test]
    async fn test_serve_command_execution_simulation() {
        let _env = EnvSnapshot::take(&SHIMMY_VARS);
        env::set_var("SHIMMY_BASE_GGUF", "./test.gguf");
        let mut reg = model_registry::Registry::with_discovery();
        reg.register(model_registry::ModelEntry {
            name: "phi3-lora".into(),
            base_path: env::var("SHIMMY_BASE_GGUF")
                .unwrap_or_else(|_| "./models/phi3-mini.gguf".into())
                .into(),
            lora_path: env::var("SHIMMY_LORA_GGUF").ok().map(Into::into),
            template: Some("chatml".into()),
            ctx_len: Some(4096),
            n_threads: None,
        });
        let engine: Box<dyn engine::InferenceEngine> =
            Box::new(engine::adapter::InferenceEngineAdapter::new());
        let state = Arc::new(AppState { engine, registry: reg });

        // The env-provided base path must be what got registered.
        let spec = state
            .registry
            .to_spec("phi3-lora")
            .expect("phi3-lora was just registered");
        assert!(spec.base_path.to_string_lossy().contains("test.gguf"));

        use crate::port_manager::GLOBAL_PORT_ALLOCATOR;
        let dynamic_port = GLOBAL_PORT_ALLOCATOR
            .allocate_ephemeral_port("test-serve-logic")
            .unwrap();
        let bind = format!("127.0.0.1:{}", dynamic_port);
        let addr: std::net::SocketAddr = bind.parse().expect("bad bind address");

        if state.registry.list().len() <= 1 {
            let mut enhanced_state = AppState {
                engine: Box::new(engine::llama::LlamaEngine::new()),
                registry: state.registry.clone(),
            };
            enhanced_state.registry.auto_register_discovered();
            let enhanced_state_arc = Arc::new(enhanced_state);
            let _available_models = enhanced_state_arc.registry.list_all_available();
        }
        let _available_models = state.registry.list_all_available();

        assert_eq!(addr.port(), dynamic_port);
        GLOBAL_PORT_ALLOCATOR.release_port(dynamic_port);
    }

    #[tokio::test]
    async fn test_command_match_branches_coverage() {
        let mut reg = model_registry::Registry::with_discovery();
        reg.register(model_registry::ModelEntry {
            name: "test-model".into(),
            base_path: "./test.gguf".into(),
            lora_path: None,
            template: Some("chatml".into()),
            ctx_len: Some(2048),
            n_threads: None,
        });
        let state = Arc::new(AppState {
            engine: Box::new(engine::adapter::InferenceEngineAdapter::new()),
            registry: reg,
        });

        // `list` branch.
        {
            let manual_models = state.registry.list();
            for e in &manual_models {
                assert!(!e.name.is_empty());
            }
            for (name, model) in state.registry.discovered_models.clone() {
                let _size_mb = model.size_bytes / (1024 * 1024);
                let _type_info = match (&model.parameter_count, &model.quantization) {
                    (Some(params), Some(quant)) => format!(" ({}·{})", params, quant),
                    (Some(params), None) => format!(" ({})", params),
                    (None, Some(quant)) => format!(" ({})", quant),
                    _ => String::new(),
                };
                let _lora_info = if model.lora_path.is_some() { " + LoRA" } else { "" };
                assert!(!name.is_empty());
            }
            let _all_available = state.registry.list_all_available();
        }

        // `discover` branch.
        {
            let registry = model_registry::Registry::with_discovery();
            for (name, model) in registry.discovered_models.clone() {
                let _size_mb = model.size_bytes / (1024 * 1024);
                let _lora_info = if model.lora_path.is_some() { " + LoRA" } else { "" };
                assert!(!name.is_empty());
            }
        }

        // `probe`/`bench`/`generate` lookup.
        let spec = state
            .registry
            .to_spec("test-model")
            .expect("Expected to find test model");
        assert!(spec.base_path.to_string_lossy().contains("test"));
    }

    #[test]
    fn test_error_conditions_and_edge_cases() {
        let invalid_addresses = [
            "invalid-address",
            "256.256.256.256:9999",
            "127.0.0.1:99999",
            "127.0.0.1:",
            ":9999",
            "",
            "not.an.ip:port",
        ];
        for addr_str in invalid_addresses {
            let result = addr_str.parse::<std::net::SocketAddr>();
            assert!(result.is_err(), "Expected parsing to fail for: {}", addr_str);
        }
        use crate::port_manager::GLOBAL_PORT_ALLOCATOR;
        let port1 = GLOBAL_PORT_ALLOCATOR.allocate_ephemeral_port("test-valid-1").unwrap();
        let port2 = GLOBAL_PORT_ALLOCATOR.allocate_ephemeral_port("test-valid-2").unwrap();
        let port3 = GLOBAL_PORT_ALLOCATOR.allocate_ephemeral_port("test-valid-3").unwrap();
        let port4 = GLOBAL_PORT_ALLOCATOR.allocate_ephemeral_port("test-valid-4").unwrap();
        let valid_addresses = [
            format!("127.0.0.1:{}", port1),
            format!("0.0.0.0:{}", port2),
            format!("192.168.1.100:{}", port3),
            format!("[::1]:{}", port4),
        ];
        GLOBAL_PORT_ALLOCATOR.release_port(port1);
        GLOBAL_PORT_ALLOCATOR.release_port(port2);
        GLOBAL_PORT_ALLOCATOR.release_port(port3);
        GLOBAL_PORT_ALLOCATOR.release_port(port4);
        for addr_str in valid_addresses {
            let result = addr_str.parse::<std::net::SocketAddr>();
            assert!(result.is_ok(), "Expected parsing to succeed for: {}", addr_str);
        }
    }

    #[test]
    fn test_registry_edge_cases() {
        let registry = model_registry::Registry::with_discovery();
        let nonexistent_names = [
            "nonexistent-model",
            "",
            "model-with-special-chars!@#",
            "very-long-model-name-that-might-cause-issues-in-some-systems-if-not-handled-properly",
        ];
        for name in nonexistent_names {
            let spec = registry.to_spec(name);
            assert!(spec.is_none(), "Expected no spec for nonexistent model: {}", name);
        }
    }

    #[test]
    fn test_model_entry_variants() {
        let mut registry = model_registry::Registry::with_discovery();
        registry.register(model_registry::ModelEntry {
            name: "minimal".to_string(),
            base_path: "./minimal.gguf".into(),
            lora_path: None,
            template: None,
            ctx_len: None,
            n_threads: None,
        });
        registry.register(model_registry::ModelEntry {
            name: "maximal".to_string(),
            base_path: "./maximal.gguf".into(),
            lora_path: Some("./maximal.lora".into()),
            template: Some("llama3".to_string()),
            ctx_len: Some(8192),
            n_threads: Some(8),
        });
        let models = registry.list();
        assert!(models.len() >= 2);
        let minimal = models.iter().find(|e| e.name == "minimal").unwrap();
        assert!(minimal.lora_path.is_none());
        assert!(minimal.template.is_none());
        let maximal = models.iter().find(|e| e.name == "maximal").unwrap();
        assert!(maximal.lora_path.is_some());
        assert_eq!(maximal.template.as_ref().unwrap(), "llama3");
        assert_eq!(maximal.ctx_len.unwrap(), 8192);
        assert_eq!(maximal.n_threads.unwrap(), 8);
    }

    #[tokio::test]
    async fn test_mock_engine_behavior() {
        let engine = MockEngine;
        let minimal_spec = crate::engine::ModelSpec {
            name: "test".to_string(),
            base_path: "./test.gguf".into(),
            lora_path: None,
            template: None,
            ctx_len: 1024,
            n_threads: None,
        };
        let loaded = engine.load(&minimal_spec).await.unwrap();
        let output = loaded
            .generate("Test prompt", crate::engine::GenOptions::default(), None)
            .await
            .unwrap();
        assert!(output.contains("Generated response to: Test prompt"));
        assert!(output.contains("max_tokens:"));
        let opts = crate::engine::GenOptions {
            max_tokens: 150,
            temperature: 0.8,
            ..Default::default()
        };
        let output = loaded.generate("Another test", opts, None).await.unwrap();
        assert!(output.contains("Another test"));
        assert!(output.contains("150"));
    }

    #[test]
    fn test_auto_discovery_models_access() {
        let registry = model_registry::Registry::with_discovery();
        let discovered = registry.discovered_models.clone();
        for (_name, model) in &discovered {
            let _type_info = match (&model.parameter_count, &model.quantization) {
                (Some(params), Some(quant)) => {
                    let info = format!(" ({}·{})", params, quant);
                    assert!(info.contains(params));
                    assert!(info.contains(quant));
                    info
                }
                (Some(params), None) => {
                    let info = format!(" ({})", params);
                    assert!(info.contains(params));
                    info
                }
                (None, Some(quant)) => {
                    let info = format!(" ({})", quant);
                    assert!(info.contains(quant));
                    info
                }
                _ => String::new(),
            };
            let _lora_info = if model.lora_path.is_some() { " + LoRA" } else { "" };
        }
        let _all_available = registry.list_all_available();
    }

    #[tokio::test]
    async fn test_serve_command_edge_cases() {
        let empty_registry = model_registry::Registry::with_discovery();
        if empty_registry.list().len() <= 1 {
            let mut enhanced_state = AppState {
                engine: Box::new(engine::llama::LlamaEngine::new()),
                registry: empty_registry.clone(),
            };
            enhanced_state.registry.auto_register_discovered();
            // Availability depends on the local machine; only exercise the path.
            let _available_models = enhanced_state.registry.list_all_available();
        }
    }

    #[test]
    fn test_string_truncation_logic() {
        let long_string = "A".repeat(150);
        let expected_120 = "A".repeat(120);
        // 15 + 105 = exactly 120 characters, so it must survive truncation unchanged.
        let exactly_120 = "Exactly120chars".to_string() + &"A".repeat(105);
        let test_strings = vec![
            ("Short", 120, "Short"),
            (&long_string, 120, &expected_120),
            ("", 120, ""),
            (&exactly_120, 120, &exactly_120),
        ];
        for (input, limit, expected) in test_strings {
            let truncated = &input[..input.len().min(limit)];
            assert_eq!(truncated, expected);
        }
    }

    #[test]
    fn test_duration_and_timing() {
        let start = std::time::Instant::now();
        std::thread::sleep(std::time::Duration::from_millis(1));
        let elapsed = start.elapsed();
        assert!(elapsed.as_nanos() > 0);
        assert!(elapsed.as_millis() >= 1);
        let duration_str = format!("{:?}", elapsed);
        assert!(
            duration_str.contains("ms") || duration_str.contains("µs") || duration_str.contains("ns")
        );
    }

    #[tokio::test]
    async fn test_discover_command_execution() {
        let registry = model_registry::Registry::with_discovery();
        for (name, model) in registry.discovered_models.clone() {
            let _size_mb = model.size_bytes / (1024 * 1024);
            let lora_info = if model.lora_path.is_some() { " + LoRA" } else { "" };
            assert_eq!(lora_info.is_empty(), model.lora_path.is_none());
            assert!(!name.is_empty());
        }
    }

    #[tokio::test]
    async fn test_probe_command_execution() {
        let mut registry = model_registry::Registry::with_discovery();
        registry.register(model_registry::ModelEntry {
            name: "probe-test".to_string(),
            base_path: "./probe-test.gguf".into(),
            lora_path: None,
            template: Some("chatml".into()),
            ctx_len: Some(2048),
            n_threads: None,
        });
        let engine = MockEngine;
        let spec = registry
            .to_spec("probe-test")
            .expect("Expected to find probe-test model");
        assert!(engine.load(&spec).await.is_ok(), "MockEngine should not fail");
    }

    #[tokio::test]
    async fn test_bench_command_execution() {
        let mut registry = model_registry::Registry::with_discovery();
        registry.register(model_registry::ModelEntry {
            name: "bench-test".to_string(),
            base_path: "./bench-test.gguf".into(),
            lora_path: None,
            template: Some("chatml".into()),
            ctx_len: Some(2048),
            n_threads: None,
        });
        let engine = MockEngine;
        let max_tokens = 64;
        let spec = registry
            .to_spec("bench-test")
            .expect("Expected to find bench-test model");
        let loaded = engine.load(&spec).await.unwrap();
        let t0 = std::time::Instant::now();
        let out = loaded
            .generate(
                "Say hi.",
                engine::GenOptions { max_tokens, stream: false, ..Default::default() },
                None,
            )
            .await
            .unwrap();
        let elapsed = t0.elapsed();
        let truncated = &out[..out.len().min(120)];
        assert!(!truncated.is_empty());
        assert!(elapsed.as_nanos() > 0);
        assert!(out.contains("Generated response to: Say hi."));
    }

    #[test]
    fn test_environment_cleanup() {
        // Only the SHIMMY_* variables are touched. HOME/USERPROFILE are process-wide too, and
        // model discovery running concurrently in other tests may read them (e.g. to locate
        // ~/.cache/huggingface/hub/).
        let _env = EnvSnapshot::take(&SHIMMY_VARS);
        let original_values: Vec<_> = SHIMMY_VARS
            .iter()
            .map(|&var| (var, env::var(var).ok()))
            .collect();
        for var in SHIMMY_VARS {
            env::set_var(var, "/test/path");
        }
        for var in SHIMMY_VARS {
            assert_eq!(env::var(var).unwrap(), "/test/path");
        }
        for (var, original_value) in &original_values {
            match original_value {
                Some(value) => env::set_var(var, value),
                None => env::remove_var(var),
            }
        }
        for (var, original_value) in &original_values {
            assert_eq!(&env::var(var).ok(), original_value);
        }
    }

    #[test]
    fn test_size_calculations() {
        let test_cases = [
            (0u64, 0u64),
            (1024, 0),
            (1024 * 1024, 1),
            (1024 * 1024 * 5, 5),
            (1024 * 1024 * 1536, 1536),
        ];
        for (size_bytes, expected_mb) in test_cases {
            let size_mb = size_bytes / (1024 * 1024);
            assert_eq!(size_mb, expected_mb, "Size calculation failed for {} bytes", size_bytes);
        }
    }

    #[tokio::test]
    async fn test_model_loading_error_paths() {
        let mut registry = model_registry::Registry::with_discovery();
        registry.register(model_registry::ModelEntry {
            name: "error-test".to_string(),
            base_path: "./nonexistent.gguf".into(),
            lora_path: None,
            template: Some("chatml".into()),
            ctx_len: Some(2048),
            n_threads: None,
        });
        let engine = MockEngine;
        let spec = registry.to_spec("error-test").unwrap();
        // A real engine may reject the missing file, and that outcome is acceptable. If loading
        // does succeed, generation must succeed as well.
        if let Ok(loaded_model) = engine.load(&spec).await {
            let gen_result = loaded_model
                .generate("test prompt", crate::engine::GenOptions::default(), None)
                .await;
            assert!(gen_result.is_ok());
        }
    }
}