#![allow(unpredictable_function_pointer_comparisons)]
mod anthropic_compat;
mod api;
mod api_errors;
mod auto_discovery;
mod cache;
mod cli;
mod engine;
mod invariant_ppt;
mod model_registry;
mod observability;
mod openai_compat;
mod port_manager;
mod server;
mod templates;
mod util {
pub mod diag;
pub mod memory;
}
use clap::Parser;
use model_registry::{ModelEntry, Registry};
use std::path::PathBuf;
use std::sync::Arc;
use tracing::info;
/// Shared application state: the inference engine plus the model registry and
/// the per-process observability and response-cache helpers.
///
/// `main` wraps this in an `Arc` before handing it to `server::run` or using it
/// from the CLI subcommands.
pub struct AppState {
/// Inference backend used to load models (see `build_engine`).
pub engine: Box<dyn engine::InferenceEngine>,
/// Registry holding manually registered entries and discovered models.
pub registry: Registry,
/// Observability manager; `AppState::new` creates a fresh one.
pub observability: observability::ObservabilityManager,
/// Response cache; `AppState::new` creates a fresh one.
pub response_cache: cache::ResponseCache,
}
impl AppState {
    /// Builds the application state from an engine and a registry.
    ///
    /// The observability manager and the response cache are always created
    /// fresh; callers only supply the engine and the registry.
    pub fn new(engine: Box<dyn engine::InferenceEngine>, registry: Registry) -> Self {
        let observability = observability::ObservabilityManager::new();
        let response_cache = cache::ResponseCache::new();
        Self {
            engine,
            registry,
            observability,
            response_cache,
        }
    }
}
/// Sanity-checks the version string baked in at compile time and terminates the
/// process with exit code 1 when it looks wrong.
///
/// Three conditions are rejected, in this order:
/// 1. the placeholder version `0.1.0` (prints a detailed remediation message),
/// 2. an empty version string,
/// 3. a version whose first two dot-separated components are not both `u32`s.
fn validate_runtime_version() {
    let version = env!("CARGO_PKG_VERSION");

    if version == "0.1.0" {
        // Text printed before the "Current version" line; empty entries are blank lines.
        let explanation = [
            "",
            "❌ ERROR: Invalid shimmy version detected!",
            "",
            "This binary reports version 0.1.0, which indicates it was built incorrectly.",
            "This is the exact issue reported in GitHub Issue #63.",
            "",
            "🔧 Solutions:",
            " • Download the official release from: https://github.com/Michael-A-Kuykendall/shimmy/releases",
            " • If building from source, ensure you're building from a proper Git tag",
            " • If forking, update the version in Cargo.toml before building",
            "",
        ];
        for line in explanation.iter() {
            eprintln!("{}", line);
        }
        eprintln!("Current version: {}", version);
        eprintln!("Expected version: 1.4.1+ (not 0.1.0)");
        eprintln!();
        std::process::exit(1);
    }

    if version.is_empty() {
        eprintln!("ERROR: Empty version detected. This binary was built incorrectly.");
        std::process::exit(1);
    }

    // Only the major and minor components are checked; anything after them is ignored.
    let mut components = version.split('.');
    let major_minor_numeric = match (components.next(), components.next()) {
        (Some(major), Some(minor)) => {
            major.parse::<u32>().is_ok() && minor.parse::<u32>().is_ok()
        }
        _ => false,
    };
    if !major_minor_numeric {
        eprintln!(
            "ERROR: Invalid version format '{}'. Expected semantic versioning.",
            version
        );
        std::process::exit(1);
    }
}
/// Prints the startup banner: version, selected backend, optional MoE offload
/// settings and the number of available models.
///
/// * `version` - version string shown in the banner.
/// * `gpu_backend` - requested backend name; only read when the `llama` feature is on.
/// * `cpu_moe` / `n_cpu_moe` - MoE CPU-offload flags; only read when the `llama` feature is on.
/// * `model_count` - number printed on the "Models" line.
/// * `airframe_selected` - when true the Airframe backend line is printed and
///   `gpu_backend` is not consulted.
fn print_startup_diagnostics(
    version: &str,
    #[cfg_attr(not(feature = "llama"), allow(unused_variables))] gpu_backend: Option<&str>,
    #[cfg_attr(not(feature = "llama"), allow(unused_variables))] cpu_moe: bool,
    #[cfg_attr(not(feature = "llama"), allow(unused_variables))] n_cpu_moe: Option<usize>,
    model_count: usize,
    airframe_selected: bool,
) {
    println!("🎯 Shimmy v{}", version);

    if airframe_selected {
        println!("🔧 Backend: Airframe (GPU)");
    } else {
        #[cfg(feature = "llama")]
        {
            // What "auto" resolves to depends purely on the compiled-in GPU features.
            let auto_detected = if cfg!(feature = "llama-cuda") {
                "CUDA (auto-detected)"
            } else if cfg!(feature = "llama-vulkan") {
                "Vulkan (auto-detected)"
            } else if cfg!(feature = "llama-opencl") {
                "OpenCL (auto-detected)"
            } else {
                "CPU (no GPU acceleration)"
            };
            // A missing backend is treated exactly like an explicit "auto".
            let label = match gpu_backend.unwrap_or("auto") {
                "cpu" => "CPU only".to_string(),
                "cuda" => "CUDA (GPU acceleration)".to_string(),
                "vulkan" => "Vulkan (GPU acceleration)".to_string(),
                "opencl" => "OpenCL (GPU acceleration)".to_string(),
                "auto" => auto_detected.to_string(),
                other => format!("{} (custom)", other),
            };
            println!("🔧 Backend: {}", label);
        }
        #[cfg(not(feature = "llama"))]
        {
            println!("🔧 Backend: Stub mode (no llama feature)");
        }
    }

    #[cfg(feature = "llama")]
    {
        // An explicit layer count takes precedence over the blanket flag.
        match (n_cpu_moe, cpu_moe) {
            (Some(n), _) => println!(
                "🧠 MoE: CPU offload first {} layers (saves VRAM for large MoE models)",
                n
            ),
            (None, true) => {
                println!("🧠 MoE: CPU offload ALL expert tensors (saves ~80-85% VRAM)")
            }
            (None, false) => {}
        }
    }

    println!("📦 Models: {} available", model_count);
}
/// Decides whether the Airframe engine should be used.
///
/// Returns `false` when `legacy` is set or when the `airframe` feature is not
/// compiled in. Otherwise the choice comes from `SHIMMY_ENGINE_BACKEND`, falling
/// back to `SHIMMY_EXPERIMENTAL_BACKEND`: any value other than `llama`
/// (case-insensitive) selects Airframe, and so does leaving both unset.
fn airframe_engine_selected(legacy: bool) -> bool {
    if legacy || !cfg!(feature = "airframe") {
        return false;
    }

    let requested = std::env::var("SHIMMY_ENGINE_BACKEND")
        .or_else(|_| std::env::var("SHIMMY_EXPERIMENTAL_BACKEND"));
    match requested {
        Ok(value) => !value.eq_ignore_ascii_case("llama"),
        // Neither variable is readable: Airframe is the default.
        Err(_) => true,
    }
}
/// Constructs the inference engine selected by the CLI flags and environment.
///
/// * `gpu_backend` - backend name forwarded to the adapter engine.
/// * `legacy` - forwarded to `airframe_engine_selected`; `true` rules out Airframe.
/// * `cpu_moe` / `n_cpu_moe` - MoE CPU-offload settings. They are only read when
///   the `llama` feature is compiled in, hence the conditional
///   `allow(unused_variables)` (same convention as `print_startup_diagnostics`).
///
/// When Airframe is selected and compiled in, an `AirframeEngine` is returned.
/// Otherwise an `InferenceEngineAdapter` is returned, configured for MoE offload
/// only if one of the MoE options was actually requested.
fn build_engine(
    gpu_backend: Option<&str>,
    legacy: bool,
    #[cfg_attr(not(feature = "llama"), allow(unused_variables))] cpu_moe: bool,
    #[cfg_attr(not(feature = "llama"), allow(unused_variables))] n_cpu_moe: Option<usize>,
) -> Box<dyn engine::InferenceEngine> {
    if airframe_engine_selected(legacy) {
        #[cfg(feature = "airframe")]
        {
            return Box::new(engine::airframe::AirframeEngine::new());
        }
        #[cfg(not(feature = "airframe"))]
        {
            tracing::warn!("Airframe engine selected but not compiled in; falling back to adapter");
            return Box::new(engine::adapter::InferenceEngineAdapter::new_with_backend(
                gpu_backend,
            ));
        }
    }

    #[cfg(feature = "llama")]
    {
        let mut adapter = engine::adapter::InferenceEngineAdapter::new_with_backend(gpu_backend);
        // Leave the adapter's default configuration alone unless offload was requested.
        if cpu_moe || n_cpu_moe.is_some() {
            adapter = adapter.with_moe_config(cpu_moe, n_cpu_moe);
        }
        return Box::new(adapter);
    }

    #[cfg(not(feature = "llama"))]
    Box::new(engine::adapter::InferenceEngineAdapter::new_with_backend(
        gpu_backend,
    ))
}
/// Returns the longest prefix of `text` that is at most `max_bytes` bytes long
/// and ends on a UTF-8 character boundary.
///
/// Plain byte slicing (`&text[..max_bytes]`) panics when `max_bytes` lands in the
/// middle of a multi-byte character, which generated text can easily contain.
fn truncate_at_char_boundary(text: &str, max_bytes: usize) -> &str {
    if text.len() <= max_bytes {
        return text;
    }
    // Index 0 is always a character boundary, so this walk terminates.
    let mut end = max_bytes;
    while !text.is_char_boundary(end) {
        end -= 1;
    }
    &text[..end]
}

/// Program entry point.
///
/// Validates the embedded version, initialises tracing, builds the model
/// registry and inference engine from CLI flags and environment variables, and
/// then dispatches to the selected subcommand.
///
/// Several failure paths call `std::process::exit` directly instead of
/// returning an error; the remaining ones propagate through `anyhow::Result`.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    validate_runtime_version();

    // ANSI colours only when NO_COLOR is unset, stdout is a terminal and TERM
    // names something other than "dumb".
    let use_ansi = std::env::var("NO_COLOR").is_err()
        && std::io::IsTerminal::is_terminal(&std::io::stdout())
        && std::env::var("TERM")
            .map(|t| !t.is_empty() && t != "dumb")
            .unwrap_or(false);
    tracing_subscriber::fmt()
        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
        .with_ansi(use_ansi)
        .init();

    #[cfg(all(target_arch = "aarch64", target_os = "macos", not(feature = "llama")))]
    info!("llama.cpp temporarily disabled on macOS ARM64 due to upstream i8mm build incompatibility; using SafeTensors backend");

    let cli = cli::Cli::parse();

    // Must be exported before the registry is built below.
    // NOTE(review): presumably discovery reads SHIMMY_MODEL_PATHS; confirm in auto_discovery.
    if let Some(model_dirs) = &cli.model_dirs {
        std::env::set_var("SHIMMY_MODEL_PATHS", model_dirs);
    }

    let mut reg = Registry::with_discovery();

    // Optional default model supplied purely through the environment.
    if let Ok(default_model_path) = std::env::var("SHIMMY_BASE_GGUF") {
        reg.register(ModelEntry {
            name: "tinyllama-1.1b".into(),
            base_path: default_model_path.into(),
            lora_path: std::env::var("SHIMMY_LORA_GGUF").ok().map(Into::into),
            template: Some("chatml".into()),
            ctx_len: Some(crate::model_registry::shimmy_ctx_len()),
            n_threads: None,
        });
    }

    let engine: Box<dyn engine::InferenceEngine> = build_engine(
        cli.gpu_backend.as_deref(),
        cli.legacy,
        cli.cpu_moe,
        cli.n_cpu_moe,
    );

    // `serve` with an explicit model path registers that file under its stem name.
    if let cli::Command::Serve {
        model_path: Some(ref path),
        ..
    } = cli.cmd
    {
        let path_buf = PathBuf::from(path);
        if path_buf.exists() {
            let model_name = path_buf
                .file_stem()
                .and_then(|s| s.to_str())
                .unwrap_or("direct-model")
                .to_string();
            reg.register(ModelEntry {
                name: model_name.clone(),
                base_path: path_buf.clone(),
                lora_path: None,
                template: None,
                ctx_len: None,
                n_threads: None,
            });
            println!("🎯 Direct model loaded: {} -> {}", model_name, path);
        } else {
            eprintln!("❌ Model file not found: {}", path);
            std::process::exit(1);
        }
    }

    let state = AppState::new(engine, reg);
    let state = Arc::new(state);

    match cli.cmd {
        cli::Command::Serve { ref bind, .. } => {
            let addr = port_manager::GLOBAL_PORT_ALLOCATOR
                .resolve_bind_address(bind)
                .unwrap_or_else(|e| {
                    eprintln!("❌ Failed to resolve bind address '{}': {}", bind, e);
                    eprintln!();
                    eprintln!("💡 Valid bind address examples:");
                    eprintln!(" auto # Auto-allocate (default)");
                    eprintln!(" 127.0.0.1:11435 # Specific address");
                    eprintln!(" 0.0.0.0:8080 # All interfaces");
                    eprintln!();
                    eprintln!("🔧 Environment variable: SHIMMY_BIND_ADDRESS=127.0.0.1:11435");
                    std::process::exit(1);
                });

            let use_airframe = airframe_engine_selected(cli.legacy);
            print_startup_diagnostics(
                env!("CARGO_PKG_VERSION"),
                cli.gpu_backend.as_deref(),
                cli.cpu_moe,
                cli.n_cpu_moe,
                state.registry.list().len(),
                use_airframe,
            );
            println!("🚀 Starting server on {}", addr);

            // With at most one manually registered model (and no Airframe), serve
            // from a second state whose registry also contains every discovered model.
            let manual_count = state.registry.list().len();
            if manual_count <= 1 && !use_airframe {
                let enhanced_engine: Box<dyn engine::InferenceEngine> = build_engine(
                    cli.gpu_backend.as_deref(),
                    cli.legacy,
                    cli.cpu_moe,
                    cli.n_cpu_moe,
                );
                let mut enhanced_state = AppState::new(enhanced_engine, state.registry.clone());
                enhanced_state.registry.auto_register_discovered();
                let enhanced_state = Arc::new(enhanced_state);

                let available_models = enhanced_state.registry.list_all_available();
                if available_models.is_empty() {
                    eprintln!("❌ No models available. Please:");
                    eprintln!(" • Set SHIMMY_BASE_GGUF environment variable, or");
                    eprintln!(" • Place .gguf files in ./models/ directory, or");
                    eprintln!(" • Place .gguf files in ~/.cache/huggingface/hub/");
                    std::process::exit(1);
                }
                println!("📦 Models: {} available", available_models.len());
                println!("✅ Ready to serve requests");
                println!(" • POST /api/generate (streaming + non-streaming)");
                println!(" • GET /health (health check + metrics)");
                println!(" • GET /v1/models (OpenAI-compatible)");
                info!(%addr, models=%available_models.len(), "shimmy serving with {} available models", available_models.len());
                return server::run(addr, enhanced_state).await;
            }

            let available_models = state.registry.list_all_available();
            if available_models.is_empty() {
                eprintln!("❌ No models available. Please:");
                eprintln!(" • Set SHIMMY_BASE_GGUF environment variable, or");
                eprintln!(" • Place .gguf files in ./models/ directory, or");
                eprintln!(" • Place .gguf files in ~/.cache/huggingface/hub/");
                std::process::exit(1);
            }
            println!("📦 Models: {} available", available_models.len());
            println!("✅ Ready to serve requests");
            println!(" • POST /api/generate (streaming + non-streaming)");
            println!(" • GET /health (health check + metrics)");
            println!(" • GET /v1/models (OpenAI-compatible)");
            info!(%addr, models=%available_models.len(), "shimmy serving with {} available models", available_models.len());
            server::run(addr, state).await?;
        }
        cli::Command::List { short } => {
            if short {
                // Machine-friendly output: one model name per line.
                let all_available = state.registry.list_all_available();
                for model_name in all_available {
                    println!("{}", model_name);
                }
            } else {
                let manual_models = state.registry.list();
                if !manual_models.is_empty() {
                    println!("📋 Registered Models:");
                    for e in &manual_models {
                        println!(" {} => {:?}", e.name, e.base_path);
                    }
                }
                let auto_discovered = state.registry.discovered_models.clone();
                if !auto_discovered.is_empty() {
                    // Blank line only when both sections are printed.
                    if !manual_models.is_empty() {
                        println!();
                    }
                    println!("🔍 Auto-Discovered Models:");
                    for (name, model) in auto_discovered {
                        let size_mb = model.size_bytes / (1024 * 1024);
                        let type_info = match (&model.parameter_count, &model.quantization) {
                            (Some(params), Some(quant)) => format!(" ({}·{})", params, quant),
                            (Some(params), None) => format!(" ({})", params),
                            (None, Some(quant)) => format!(" ({})", quant),
                            _ => String::new(),
                        };
                        let lora_info = if model.lora_path.is_some() {
                            " + LoRA"
                        } else {
                            ""
                        };
                        println!(
                            " {} => {:?} [{}MB{}{}]",
                            name, model.path, size_mb, type_info, lora_info
                        );
                    }
                }
                let all_available = state.registry.list_all_available();
                if all_available.is_empty() {
                    println!(
                        "❌ No models found. Set SHIMMY_BASE_GGUF or place .gguf files in ./models/"
                    );
                } else {
                    println!("\n✅ Total available models: {}", all_available.len());
                }
            }
        }
        cli::Command::Discover { llm_only } => {
            println!("🔍 Refreshing model discovery...");
            // A fresh registry is built here instead of reusing `state.registry`.
            let registry = Registry::with_discovery();
            let mut discovered = registry.discovered_models.clone();
            if llm_only {
                // Name-based heuristic: drop anything that looks like a non-LLM model.
                discovered.retain(|name, _| {
                    let name_lower = name.to_lowercase();
                    !name_lower.contains("clip")
                        && !name_lower.contains("text-to-image")
                        && !name_lower.contains("vision")
                        && !name_lower.contains("image")
                        && !name_lower.contains("video")
                        && !name_lower.contains("audio")
                        && !name_lower.contains("tts")
                        && !name_lower.contains("stt")
                        && !name_lower.contains("embedding")
                        && !name_lower.contains("encoder")
                });
                println!("🎯 Filtering to LLM models only...");
            }
            if discovered.is_empty() {
                if llm_only {
                    println!("❌ No LLM models found after filtering");
                    println!("💡 Try running without --llm-only to see all models");
                } else {
                    println!("❌ No models found in search paths:");
                    let discovery = crate::auto_discovery::ModelAutoDiscovery::new();
                    for path in &discovery.search_paths {
                        println!(" • {:?}", path);
                    }
                    println!(" • Ollama models (if installed)");
                    println!("\n💡 Try downloading a GGUF model or setting SHIMMY_BASE_GGUF");
                }
            } else {
                println!("✅ Found {} models:", discovered.len());
                for (name, model) in discovered {
                    let size_mb = model.size_bytes / (1024 * 1024);
                    let lora_info = if model.lora_path.is_some() {
                        " + LoRA"
                    } else {
                        ""
                    };
                    println!(" {} [{}MB{}]", name, size_mb, lora_info);
                    println!(" Base: {:?}", model.path);
                    if let Some(lora) = &model.lora_path {
                        println!(" LoRA: {:?}", lora);
                    }
                }
            }
        }
        cli::Command::Probe { name } => {
            let Some(spec) = state.registry.to_spec(&name) else {
                anyhow::bail!("no model {name}");
            };
            match state.engine.load(&spec).await {
                Ok(_) => println!("ok: loaded {name}"),
                Err(e) => {
                    eprintln!("probe failed: {e}");
                    // Distinct exit code for "model known but failed to load".
                    std::process::exit(2);
                }
            }
        }
        cli::Command::Bench { name, max_tokens } => {
            let Some(spec) = state.registry.to_spec(&name) else {
                anyhow::bail!("no model {name}");
            };
            let loaded = state.engine.load(&spec).await?;
            // Only generation is timed; model loading happens before the clock starts.
            let t0 = std::time::Instant::now();
            let out = loaded
                .generate(
                    "Say hi.",
                    engine::GenOptions {
                        max_tokens,
                        stream: false,
                        ..Default::default()
                    },
                    None,
                )
                .await?;
            let elapsed = t0.elapsed();
            // Truncate on a character boundary: a raw byte slice would panic when
            // byte 120 falls inside a multi-byte UTF-8 sequence.
            println!(
                "bench output (truncated): {}",
                truncate_at_char_boundary(&out, 120)
            );
            println!("elapsed: {:?}", elapsed);
        }
        cli::Command::Generate {
            name,
            prompt,
            max_tokens,
        } => {
            let Some(spec) = state.registry.to_spec(&name) else {
                anyhow::bail!("no model {name}");
            };
            let loaded = state.engine.load(&spec).await?;
            let out = loaded
                .generate(
                    &prompt,
                    engine::GenOptions {
                        max_tokens,
                        stream: false,
                        ..Default::default()
                    },
                    None,
                )
                .await?;
            println!("{}", out);
        }
        cli::Command::GpuInfo => {
            println!("🖥️ GPU Backend Information");
            println!();
            #[cfg(feature = "airframe")]
            {
                println!("⚡ Airframe Engine: ✅ Enabled (WebGPU via wgpu)");
                println!(" Adapter is selected automatically at runtime.");
                println!(" Run `shimmy serve` to see which GPU adapter is chosen.");
                println!(
                    " Supported: NVIDIA, AMD, Intel, Apple Silicon (Metal), integrated GPUs."
                );
            }
            #[cfg(not(feature = "airframe"))]
            {
                println!("⚡ Airframe Engine: Disabled (this build uses the huggingface engine)");
                println!(
                    " For GPU acceleration, download a release binary from GitHub Releases."
                );
            }
            #[cfg(feature = "llama")]
            {
                println!();
                println!("🔧 Legacy llama.cpp backend: ✅ Available (use --legacy or SHIMMY_ENGINE_BACKEND=llama)");
            }
            #[cfg(not(feature = "llama"))]
            {
                println!();
                println!("🔧 Legacy llama.cpp backend: Disabled");
            }
            println!();
            println!("💡 GPU acceleration: download a release binary from GitHub Releases.");
            println!(" https://github.com/Michael-A-Kuykendall/shimmy/releases/latest");
        }
        cli::Command::Init {
            template,
            output,
            name,
        } => {
            let result = templates::generate_template(&template, &output, name.as_deref());
            match result {
                Ok(message) => println!("{}", message),
                Err(e) => {
                    eprintln!("❌ Template generation failed: {}", e);
                    std::process::exit(1);
                }
            }
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::engine::InferenceEngine;
use std::env;
use std::sync::Arc;
/// Test double for `engine::InferenceEngine` whose `load` always succeeds.
struct MockEngine;

#[async_trait::async_trait]
impl engine::InferenceEngine for MockEngine {
    /// Ignores the spec and hands back a `MockLoadedModel`.
    async fn load(
        &self,
        _spec: &engine::ModelSpec,
    ) -> anyhow::Result<Box<dyn engine::LoadedModel>> {
        let model: Box<dyn engine::LoadedModel> = Box::new(MockLoadedModel);
        Ok(model)
    }
}
/// Test double for `engine::LoadedModel` that echoes its inputs.
struct MockLoadedModel;

#[async_trait::async_trait]
impl engine::LoadedModel for MockLoadedModel {
    /// Returns a fixed-format string containing the prompt and `opts.max_tokens`;
    /// the token callback is never invoked.
    async fn generate(
        &self,
        prompt: &str,
        opts: engine::GenOptions,
        _on_token: Option<Box<dyn FnMut(String) + Send>>,
    ) -> anyhow::Result<String> {
        let token_budget = opts.max_tokens;
        let reply = format!(
            "Generated response to: {} (max_tokens: {})",
            prompt, token_budget
        );
        Ok(reply)
    }
}
#[tokio::test]
async fn test_main_initialization_paths() {
env::remove_var("SHIMMY_BASE_GGUF");
env::remove_var("SHIMMY_LORA_GGUF");
let mut reg = model_registry::Registry::with_discovery();
reg.register(model_registry::ModelEntry {
name: "phi3-lora".into(),
base_path: "./models/phi3-mini.gguf".into(),
lora_path: None,
template: Some("chatml".into()),
ctx_len: Some(2048),
n_threads: None,
});
let engine: Box<dyn engine::InferenceEngine> =
Box::new(engine::adapter::InferenceEngineAdapter::new());
let state = AppState::new(engine, reg);
let _state_arc = Arc::new(state);
}
/// Verifies that values written to the two model environment variables are
/// read back unchanged.
///
/// NOTE(review): other tests in this module set and remove the same variables;
/// if the test harness runs them in parallel these assertions may race. Confirm.
#[tokio::test]
async fn test_environment_variable_handling() {
    const BASE_VAR: &str = "SHIMMY_BASE_GGUF";
    const LORA_VAR: &str = "SHIMMY_LORA_GGUF";

    env::remove_var(BASE_VAR);
    env::remove_var(LORA_VAR);
    env::set_var(BASE_VAR, "/custom/path/model.gguf");
    env::set_var(LORA_VAR, "/custom/path/lora.safetensors");

    let base_path: String = match env::var(BASE_VAR) {
        Ok(value) => value,
        Err(_) => "./models/phi3-mini.gguf".into(),
    };
    let lora_path = env::var(LORA_VAR).ok();

    assert_eq!(base_path, "/custom/path/model.gguf");
    assert_eq!(lora_path.as_deref(), Some("/custom/path/lora.safetensors"));

    // Leave the environment as we found it.
    env::remove_var(BASE_VAR);
    env::remove_var(LORA_VAR);
}
/// A loopback address with an allocated port parses as a `SocketAddr`, while a
/// non-numeric host/port pair does not.
#[test]
fn test_serve_command_address_parsing() {
    use crate::port_manager::GLOBAL_PORT_ALLOCATOR;
    use std::net::SocketAddr;

    let port = GLOBAL_PORT_ALLOCATOR
        .allocate_ephemeral_port("test-serve-parsing")
        .unwrap();
    let parsed = format!("127.0.0.1:{}", port)
        .parse::<SocketAddr>()
        .expect("bad bind address");
    assert_eq!(parsed.port(), port);
    GLOBAL_PORT_ALLOCATOR.release_port(port);

    assert!("invalid:address".parse::<SocketAddr>().is_err());
}
/// Exercises the "at most one manual model" predicate used by the serve path.
/// Nothing is asserted; the test only checks that the calls do not panic.
#[test]
fn test_serve_command_model_count_logic() {
    let fresh = model_registry::Registry::with_discovery();
    let registered = fresh.list().len();
    let _should_auto_register = registered <= 1;
}
#[tokio::test]
async fn test_list_command_execution_logic() {
let mut registry = model_registry::Registry::with_discovery();
registry.register(model_registry::ModelEntry {
name: "test-model".into(),
base_path: "./test.gguf".into(),
lora_path: None,
template: Some("chatml".into()),
ctx_len: Some(2048),
n_threads: None,
});
let manual_models = registry.list();
assert!(!manual_models.is_empty());
let _auto_discovered = registry.discovered_models.clone();
let _all_available = registry.list_all_available();
}
/// Walks every discovered model the way the `discover` subcommand does and
/// checks that each one has a non-empty name. Passes trivially when nothing is
/// discovered on the current machine.
#[tokio::test]
async fn test_discover_command_execution_logic() {
    let found = model_registry::Registry::with_discovery()
        .discovered_models
        .clone();

    if found.is_empty() {
        assert!(found.is_empty());
        return;
    }

    assert!(!found.is_empty());
    for (model_name, details) in found {
        let _size_mb = details.size_bytes / (1024 * 1024);
        let _lora_info = match details.lora_path {
            Some(_) => " + LoRA",
            None => "",
        };
        assert!(!model_name.is_empty());
    }
}
#[tokio::test]
async fn test_probe_command_execution_logic() {
let mut registry = model_registry::Registry::with_discovery();
registry.register(model_registry::ModelEntry {
name: "test-probe-model".into(),
base_path: "./test.gguf".into(),
lora_path: None,
template: Some("chatml".into()),
ctx_len: Some(2048),
n_threads: None,
});
let engine = MockEngine;
let name = "test-probe-model";
let spec_result = registry.to_spec(name);
if let Some(spec) = spec_result {
let load_result = engine.load(&spec).await;
match load_result {
Ok(_) => {
}
Err(_) => {
}
}
} else {
}
}
#[tokio::test]
async fn test_bench_command_execution_logic() {
let mut registry = model_registry::Registry::with_discovery();
registry.register(model_registry::ModelEntry {
name: "test-bench-model".into(),
base_path: "./test.gguf".into(),
lora_path: None,
template: Some("chatml".into()),
ctx_len: Some(2048),
n_threads: None,
});
let engine = MockEngine;
let name = "test-bench-model";
let max_tokens = 100;
if let Some(spec) = registry.to_spec(name) {
let loaded = engine.load(&spec).await.unwrap();
let t0 = std::time::Instant::now();
let out = loaded
.generate(
"Say hi.",
engine::GenOptions {
max_tokens,
stream: false,
..Default::default()
},
None,
)
.await
.unwrap();
let elapsed = t0.elapsed();
let truncated = &out[..out.len().min(120)];
assert!(!truncated.is_empty());
assert!(elapsed.as_nanos() > 0);
}
}
#[tokio::test]
async fn test_generate_command_execution_logic() {
let mut registry = model_registry::Registry::with_discovery();
registry.register(model_registry::ModelEntry {
name: "test-gen-model".into(),
base_path: "./test.gguf".into(),
lora_path: None,
template: Some("chatml".into()),
ctx_len: Some(2048),
n_threads: None,
});
let engine = MockEngine;
let name = "test-gen-model";
let prompt = "Hello, world!";
let max_tokens = 50;
if let Some(spec) = registry.to_spec(name) {
let loaded = engine.load(&spec).await.unwrap();
let out = loaded
.generate(
prompt,
engine::GenOptions {
max_tokens,
stream: false,
..Default::default()
},
None,
)
.await
.unwrap();
assert!(out.contains("Generated response to: Hello, world!"));
}
}
/// Parses a `generate` invocation and checks that every argument lands in the
/// matching field of `Command::Generate`.
#[test]
fn test_command_execution_paths() {
    use crate::cli::{Cli, Command};
    use clap::Parser;

    let argv = vec![
        "shimmy",
        "generate",
        "test-model",
        "--prompt",
        "Hello",
        "--max-tokens",
        "50",
    ];
    let parsed = Cli::try_parse_from(argv).unwrap();

    let Command::Generate {
        name,
        prompt,
        max_tokens,
    } = parsed.cmd
    else {
        panic!("Expected Generate command");
    };

    assert_eq!(name, "test-model");
    assert_eq!(prompt, "Hello");
    assert_eq!(max_tokens, 50);
}
/// Builds an `AppState` behind an `Arc` and checks that the registry is
/// reachable through it.
#[tokio::test]
async fn test_state_initialization() {
    use crate::engine::adapter::InferenceEngineAdapter;
    use crate::model_registry::Registry;

    let catalog = Registry::with_discovery();
    let adapter = Box::new(InferenceEngineAdapter::new());
    let shared = std::sync::Arc::new(crate::AppState::new(adapter, catalog));

    // The `Arc` handle itself always has a non-zero size.
    assert_ne!(std::mem::size_of_val(&shared), 0);
    let _models = shared.registry.list();
}
/// Reproduces the "enhanced state" construction from the serve path: clone the
/// registry, auto-register discovered models, then list everything available.
/// The outcome depends on the machine, so either result is accepted.
#[test]
fn test_serve_enhanced_state_logic() {
    let base_registry = model_registry::Registry::with_discovery();
    let adapter = Box::new(engine::adapter::InferenceEngineAdapter::new());

    let mut enhanced = AppState::new(adapter, base_registry.clone());
    enhanced.registry.auto_register_discovered();
    let enhanced = Arc::new(enhanced);

    let available = enhanced.registry.list_all_available();
    if available.is_empty() {
        assert!(available.is_empty());
        return;
    }
    assert!(!available.is_empty());
}
/// Registers a model whose paths come from the environment variables and checks
/// that the registry then lists at least one entry.
///
/// NOTE(review): other tests in this module set and remove the same variables;
/// parallel test execution may interleave with this one. Confirm.
#[test]
fn test_model_registration_with_env_vars() {
    const BASE_VAR: &str = "SHIMMY_BASE_GGUF";
    const LORA_VAR: &str = "SHIMMY_LORA_GGUF";

    env::set_var(BASE_VAR, "/test/base.gguf");
    env::set_var(LORA_VAR, "/test/lora.safetensors");

    let base_path: String = match env::var(BASE_VAR) {
        Ok(value) => value,
        Err(_) => "./models/phi3-mini.gguf".into(),
    };
    let lora_path = env::var(LORA_VAR).ok().map(Into::into);

    let mut catalog = model_registry::Registry::with_discovery();
    catalog.register(model_registry::ModelEntry {
        name: "phi3-lora".into(),
        base_path: base_path.into(),
        lora_path,
        template: Some("chatml".into()),
        ctx_len: Some(2048),
        n_threads: None,
    });
    assert!(!catalog.list().is_empty());

    env::remove_var(BASE_VAR);
    env::remove_var(LORA_VAR);
}
/// Registering an entry grows the manual list and makes `to_spec` resolve that
/// name, while an unknown name still resolves to `None`.
#[test]
fn test_registry_model_methods() {
    use crate::model_registry::{ModelEntry, Registry};

    let mut catalog = Registry::with_discovery();
    let count_before = catalog.list().len();
    let _discovered = catalog.discovered_models.clone();
    let _all_available = catalog.list_all_available();

    catalog.register(ModelEntry {
        name: "test".into(),
        base_path: "./test.gguf".into(),
        lora_path: None,
        template: None,
        ctx_len: None,
        n_threads: None,
    });

    let count_after = catalog.list().len();
    assert!(count_after > count_before);
    assert!(catalog.to_spec("test").is_some());
    assert!(catalog.to_spec("nonexistent").is_none());
}
/// Smoke test: `AppState::new` accepts the adapter engine and a discovery registry.
#[test]
fn test_app_state_creation() {
    use crate::engine::adapter::InferenceEngineAdapter;
    use crate::model_registry::Registry;

    let adapter: Box<dyn engine::InferenceEngine> = Box::new(InferenceEngineAdapter::new());
    let catalog = Registry::with_discovery();
    let _state = AppState::new(adapter, catalog);
}
/// Smoke test: an `EnvFilter` can be built from the environment and queried.
/// The assertion holds for any filter, so only a panic can fail this test.
#[test]
fn test_tracing_initialization() {
    use tracing_subscriber::EnvFilter;

    let filter = EnvFilter::from_default_env();
    let first_hint = filter.max_level_hint();
    let second_hint = filter.max_level_hint();
    assert!(first_hint.is_some() || second_hint.is_none());
}
/// Simulates the serve path without starting a server: environment-driven
/// registration, state creation, bind-address parsing and the enhanced-state
/// branch. Model availability depends on the machine, so either result is accepted.
#[tokio::test]
async fn test_serve_command_execution_simulation() {
    use crate::port_manager::GLOBAL_PORT_ALLOCATOR;

    env::set_var("SHIMMY_BASE_GGUF", "./test.gguf");

    let mut catalog = model_registry::Registry::with_discovery();
    let base_from_env: String = match env::var("SHIMMY_BASE_GGUF") {
        Ok(value) => value,
        Err(_) => "./models/phi3-mini.gguf".into(),
    };
    catalog.register(model_registry::ModelEntry {
        name: "phi3-lora".into(),
        base_path: base_from_env.into(),
        lora_path: env::var("SHIMMY_LORA_GGUF").ok().map(Into::into),
        template: Some("chatml".into()),
        ctx_len: Some(2048),
        n_threads: None,
    });

    let adapter: Box<dyn engine::InferenceEngine> =
        Box::new(engine::adapter::InferenceEngineAdapter::new());
    let shared = Arc::new(AppState::new(adapter, catalog));

    let port = GLOBAL_PORT_ALLOCATOR
        .allocate_ephemeral_port("test-serve-logic")
        .unwrap();
    let bind_address = format!("127.0.0.1:{}", port);
    let socket: std::net::SocketAddr = bind_address.parse().expect("bad bind address");

    // Same threshold the serve command uses to decide on auto-registration.
    if shared.registry.list().len() <= 1 {
        let mut enhanced = AppState::new(
            Box::new(engine::adapter::InferenceEngineAdapter::new()),
            shared.registry.clone(),
        );
        enhanced.registry.auto_register_discovered();
        let enhanced = Arc::new(enhanced);
        let enhanced_models = enhanced.registry.list_all_available();
        if enhanced_models.is_empty() {
            assert!(enhanced_models.is_empty());
        } else {
            assert!(!enhanced_models.is_empty());
        }
    }

    let base_models = shared.registry.list_all_available();
    if base_models.is_empty() {
        assert!(base_models.is_empty());
    }

    env::remove_var("SHIMMY_BASE_GGUF");
    assert_eq!(socket.port(), port);
    GLOBAL_PORT_ALLOCATOR.release_port(port);
}
/// Walks through the data access performed by the `list`, `discover` and
/// spec-lookup branches of `main` against a state with one registered model.
#[tokio::test]
async fn test_command_match_branches_coverage() {
    let mut catalog = model_registry::Registry::with_discovery();
    catalog.register(model_registry::ModelEntry {
        name: "test-model".into(),
        base_path: "./test.gguf".into(),
        lora_path: None,
        template: Some("chatml".into()),
        ctx_len: Some(2048),
        n_threads: None,
    });
    let _engine = MockEngine;
    let adapter = Box::new(engine::adapter::InferenceEngineAdapter::new());
    let shared = Arc::new(AppState::new(adapter, catalog));

    // --- `list` branch -------------------------------------------------------
    {
        let registered = shared.registry.list();
        for entry in &registered {
            assert!(!entry.name.is_empty());
        }

        let discovered = shared.registry.discovered_models.clone();
        for (model_name, details) in discovered {
            let _size_mb = details.size_bytes / (1024 * 1024);
            let _type_info = match (&details.parameter_count, &details.quantization) {
                (Some(params), Some(quant)) => format!(" ({}·{})", params, quant),
                (Some(params), None) => format!(" ({})", params),
                (None, Some(quant)) => format!(" ({})", quant),
                _ => String::new(),
            };
            let _lora_info = match details.lora_path {
                Some(_) => " + LoRA",
                None => "",
            };
            assert!(!model_name.is_empty());
        }

        let everything = shared.registry.list_all_available();
        if everything.is_empty() {
            assert!(everything.is_empty());
        }
    }

    // --- `discover` branch ---------------------------------------------------
    {
        let fresh = model_registry::Registry::with_discovery();
        let discovered = fresh.discovered_models.clone();
        if discovered.is_empty() {
            assert!(discovered.is_empty());
        } else {
            for (model_name, details) in discovered {
                let _size_mb = details.size_bytes / (1024 * 1024);
                let _lora_info = match details.lora_path {
                    Some(_) => " + LoRA",
                    None => "",
                };
                assert!(!model_name.is_empty());
            }
        }
    }

    // --- spec lookup shared by `probe`, `bench` and `generate` ----------------
    let Some(spec) = shared.registry.to_spec("test-model") else {
        panic!("Expected to find test model");
    };
    assert!(spec.base_path.to_string_lossy().contains("test"));
}
/// Checks `SocketAddr` parsing against a list of malformed addresses and a list
/// of well-formed ones built from freshly allocated ports.
#[test]
fn test_error_conditions_and_edge_cases() {
    use crate::port_manager::GLOBAL_PORT_ALLOCATOR;
    use std::net::SocketAddr;

    let malformed = [
        "invalid-address",
        "256.256.256.256:9999",
        "127.0.0.1:99999",
        "127.0.0.1:",
        ":9999",
        "",
        "not.an.ip:port",
    ];
    for candidate in malformed.iter() {
        assert!(
            candidate.parse::<SocketAddr>().is_err(),
            "Expected parsing to fail for: {}",
            candidate
        );
    }

    let ports = [
        GLOBAL_PORT_ALLOCATOR
            .allocate_ephemeral_port("test-valid-1")
            .unwrap(),
        GLOBAL_PORT_ALLOCATOR
            .allocate_ephemeral_port("test-valid-2")
            .unwrap(),
        GLOBAL_PORT_ALLOCATOR
            .allocate_ephemeral_port("test-valid-3")
            .unwrap(),
        GLOBAL_PORT_ALLOCATOR
            .allocate_ephemeral_port("test-valid-4")
            .unwrap(),
    ];
    let well_formed = [
        format!("127.0.0.1:{}", ports[0]),
        format!("0.0.0.0:{}", ports[1]),
        format!("192.168.1.100:{}", ports[2]),
        format!("[::1]:{}", ports[3]),
    ];
    // Only the textual form is parsed below, so the ports can be handed back first.
    for port in ports.iter() {
        GLOBAL_PORT_ALLOCATOR.release_port(*port);
    }

    for candidate in well_formed.iter() {
        assert!(
            candidate.parse::<SocketAddr>().is_ok(),
            "Expected parsing to succeed for: {}",
            candidate
        );
    }
}
/// `to_spec` must return `None` for names that were never registered,
/// including the empty string and names with unusual characters.
#[test]
fn test_registry_edge_cases() {
    let catalog = model_registry::Registry::with_discovery();
    let unknown_names = [
        "nonexistent-model",
        "",
        "model-with-special-chars!@#",
        "very-long-model-name-that-might-cause-issues-in-some-systems-if-not-handled-properly",
    ];
    for unknown in unknown_names.iter() {
        let lookup = catalog.to_spec(unknown);
        assert!(
            lookup.is_none(),
            "Expected no spec for nonexistent model: {}",
            unknown
        );
    }
}
/// Registers one entry with every optional field unset and one with every
/// optional field set, then checks both come back from `list` unchanged.
#[test]
fn test_model_entry_variants() {
    use crate::model_registry::{ModelEntry, Registry};

    let mut catalog = Registry::with_discovery();
    catalog.register(ModelEntry {
        name: "minimal".to_string(),
        base_path: "./minimal.gguf".into(),
        lora_path: None,
        template: None,
        ctx_len: None,
        n_threads: None,
    });
    catalog.register(ModelEntry {
        name: "maximal".to_string(),
        base_path: "./maximal.gguf".into(),
        lora_path: Some("./maximal.lora".into()),
        template: Some("llama3".to_string()),
        ctx_len: Some(2048),
        n_threads: Some(8),
    });

    let entries = catalog.list();
    assert!(entries.len() >= 2);

    let bare = entries.iter().find(|e| e.name == "minimal").unwrap();
    assert!(bare.lora_path.is_none());
    assert!(bare.template.is_none());

    let full = entries.iter().find(|e| e.name == "maximal").unwrap();
    assert!(full.lora_path.is_some());
    assert_eq!(full.template.as_ref().unwrap(), "llama3");
    assert_eq!(full.ctx_len.unwrap(), 2048);
    assert_eq!(full.n_threads.unwrap(), 8);
}
#[tokio::test]
async fn test_mock_engine_behavior() {
let engine = MockEngine;
let minimal_spec = crate::engine::ModelSpec {
name: "test".to_string(),
base_path: "./test.gguf".into(),
lora_path: None,
template: None,
ctx_len: 1024,
n_threads: None,
};
let loaded = engine.load(&minimal_spec).await.unwrap();
let output = loaded
.generate("Test prompt", crate::engine::GenOptions::default(), None)
.await
.unwrap();
assert!(output.contains("Generated response to: Test prompt"));
assert!(output.contains("max_tokens:"));
let opts = crate::engine::GenOptions {
max_tokens: 150,
temperature: 0.8,
..Default::default()
};
let output = loaded.generate("Another test", opts, None).await.unwrap();
assert!(output.contains("Another test"));
assert!(output.contains("150"));
}
/// For every discovered model, builds the "(params·quant)" label used by the
/// `list` subcommand and checks that it contains the pieces it was built from.
/// Passes trivially when nothing is discovered.
#[test]
fn test_auto_discovery_models_access() {
    let catalog = model_registry::Registry::with_discovery();
    let found = catalog.discovered_models.clone();

    if found.is_empty() {
        assert_eq!(found.len(), 0);
    }

    for details in found.values() {
        let params = &details.parameter_count;
        let quant = &details.quantization;

        let label = match (params, quant) {
            (Some(p), Some(q)) => format!(" ({}·{})", p, q),
            (Some(p), None) => format!(" ({})", p),
            (None, Some(q)) => format!(" ({})", q),
            _ => String::new(),
        };
        // Whatever was present must appear in the label.
        if let Some(p) = params {
            assert!(label.contains(p));
        }
        if let Some(q) = quant {
            assert!(label.contains(q));
        }

        let _lora_info = match details.lora_path {
            Some(_) => " + LoRA",
            None => "",
        };
    }

    let _all_available = catalog.list_all_available();
}
/// Serve-path edge case: starting from a registry with no manual registrations,
/// the enhanced state can still be built and queried. Availability depends on
/// the machine, so either result is accepted.
#[tokio::test]
async fn test_serve_command_edge_cases() {
    let bare_registry = model_registry::Registry::with_discovery();
    if bare_registry.list().len() > 1 {
        return;
    }

    let adapter = Box::new(engine::adapter::InferenceEngineAdapter::new());
    let mut enhanced = AppState::new(adapter, bare_registry.clone());
    enhanced.registry.auto_register_discovered();

    let available = enhanced.registry.list_all_available();
    if available.is_empty() {
        assert!(available.is_empty());
    } else {
        assert!(!available.is_empty());
    }
}
/// Checks the "first N bytes" truncation on ASCII inputs: shorter than the
/// limit, longer than the limit, empty, and exactly at the limit.
#[test]
fn test_string_truncation_logic() {
    let oversized = "A".repeat(150);
    let oversized_cut = "A".repeat(120);
    // 15 + 105 = 120 bytes: exactly on the limit, so it must come back untouched.
    let at_limit = "Exactly120chars".to_string() + &"A".repeat(105);

    let cases: [(&str, usize, &str); 4] = [
        ("Short", 120, "Short"),
        (oversized.as_str(), 120, oversized_cut.as_str()),
        ("", 120, ""),
        (at_limit.as_str(), 120, at_limit.as_str()),
    ];
    for (input, limit, expected) in cases.iter() {
        let cut = input.len().min(*limit);
        assert_eq!(&input[..cut], *expected);
    }
}
/// Returns `true` when `rendered` looks like `Duration`'s `Debug` output:
/// a number immediately followed by one of the units `ns`, `µs`, `ms` or `s`.
///
/// The bare `s` unit matters: once a duration reaches one second the `Debug`
/// output is e.g. `1.2s`, which contains none of the sub-second units.
fn duration_debug_has_unit(rendered: &str) -> bool {
    ["ns", "µs", "ms", "s"].iter().any(|&unit| {
        rendered
            .strip_suffix(unit)
            .is_some_and(|number| number.parse::<f64>().is_ok())
    })
}

/// Verifies that `Instant` measures a 1 ms sleep as at least 1 ms and that
/// the elapsed time renders with a recognizable unit.
///
/// The unit check accepts whole seconds as well, so a heavily loaded machine
/// that oversleeps past one second does not fail the test spuriously.
#[test]
fn test_duration_and_timing() {
    let start = std::time::Instant::now();
    std::thread::sleep(std::time::Duration::from_millis(1));
    let elapsed = start.elapsed();

    assert!(elapsed.as_nanos() > 0);
    assert!(elapsed.as_millis() >= 1);

    let duration_str = format!("{:?}", elapsed);
    assert!(
        duration_debug_has_unit(&duration_str),
        "unexpected Duration debug output: {}",
        duration_str
    );
}
/// Replays the per-model work of the discover command: compute the size in
/// MiB, derive the LoRA suffix, and check each discovered name is non-empty.
///
/// With no discovered models the loop does not execute and the test passes.
#[tokio::test]
async fn test_discover_command_execution() {
    let registry = model_registry::Registry::with_discovery();

    for (name, model) in registry.discovered_models.iter() {
        let _size_mb = model.size_bytes / (1024 * 1024);

        let lora_info = if model.lora_path.is_none() {
            ""
        } else {
            " + LoRA"
        };

        assert!(lora_info.is_empty() || lora_info == " + LoRA");
        assert!(!name.is_empty());
    }
}
/// Replays the probe command flow: register a model, resolve its spec, and
/// load it through `MockEngine`.
///
/// Panics if the freshly registered model cannot be resolved or if the mock
/// engine reports a load failure.
#[tokio::test]
async fn test_probe_command_execution() {
    let entry = model_registry::ModelEntry {
        name: "probe-test".to_string(),
        base_path: "./probe-test.gguf".into(),
        lora_path: None,
        template: Some("chatml".into()),
        ctx_len: Some(2048),
        n_threads: None,
    };
    let mut registry = model_registry::Registry::with_discovery();
    registry.register(entry);

    let mock = MockEngine;
    let Some(spec) = registry.to_spec("probe-test") else {
        panic!("Expected to find probe-test model");
    };

    if mock.load(&spec).await.is_err() {
        panic!("MockEngine should not fail");
    }
}
/// Replays the bench command flow against `MockEngine`: register a model,
/// resolve its spec, load it, generate once with `max_tokens = 64`, and check
/// the timed output.
///
/// The model is registered immediately before the lookup, so a missing spec
/// is treated as a test failure rather than silently skipping every
/// assertion (this matches how the probe test handles the same situation).
#[tokio::test]
async fn test_bench_command_execution() {
    let mut registry = model_registry::Registry::with_discovery();
    registry.register(model_registry::ModelEntry {
        name: "bench-test".to_string(),
        base_path: "./bench-test.gguf".into(),
        lora_path: None,
        template: Some("chatml".into()),
        ctx_len: Some(2048),
        n_threads: None,
    });
    let engine = MockEngine;
    let name = "bench-test";
    let max_tokens = 64;

    // Fail loudly: without a spec none of the checks below would run.
    let spec = registry
        .to_spec(name)
        .expect("Expected to find bench-test model");

    let loaded = engine.load(&spec).await.unwrap();
    let t0 = std::time::Instant::now();
    let out = loaded
        .generate(
            "Say hi.",
            engine::GenOptions {
                max_tokens,
                stream: false,
                ..Default::default()
            },
            None,
        )
        .await
        .unwrap();
    let elapsed = t0.elapsed();

    // Same 120-byte preview clipping as exercised by the truncation test.
    let truncated = &out[..out.len().min(120)];
    assert!(!truncated.is_empty());
    assert!(elapsed.as_nanos() > 0);
    assert!(out.contains("Generated response to: Say hi."));
}
/// Sets the two shimmy model-path environment variables to a known value,
/// confirms they read back, and then puts the process environment back the
/// way it was found (restoring previous values or removing the variables).
#[test]
fn test_environment_cleanup() {
    let names = ["SHIMMY_BASE_GGUF", "SHIMMY_LORA_GGUF"];

    // Snapshot whatever was set before the test touches anything.
    let mut snapshot = Vec::with_capacity(names.len());
    for name in names.iter() {
        snapshot.push((*name, env::var(name).ok()));
    }

    for name in names.iter() {
        env::set_var(name, "/test/path");
    }

    for name in names.iter() {
        match env::var(name) {
            Ok(current) => assert_eq!(current, "/test/path"),
            Err(_) => panic!("Environment variable {} should be set", name),
        }
    }

    for (name, previous) in snapshot {
        if let Some(value) = previous {
            env::set_var(name, value);
        } else {
            env::remove_var(name);
        }
    }
}
/// Checks the bytes-to-mebibytes integer division used when reporting model
/// sizes, including the sub-1-MiB case that rounds down to zero.
#[test]
fn test_size_calculations() {
    const BYTES_PER_MB: u64 = 1024 * 1024;

    let expectations: [(u64, u64); 5] = [
        (0, 0),
        (1024, 0),
        (BYTES_PER_MB, 1),
        (5 * BYTES_PER_MB, 5),
        (1536 * BYTES_PER_MB, 1536),
    ];

    for &(size_bytes, expected_mb) in expectations.iter() {
        assert_eq!(
            size_bytes / BYTES_PER_MB,
            expected_mb,
            "Size calculation failed for {} bytes",
            size_bytes
        );
    }
}
/// Registers a model whose base path points at a file that does not exist
/// and runs it through `MockEngine`.
///
/// A load error is tolerated; if the load succeeds, generation must succeed
/// as well.
#[tokio::test]
async fn test_model_loading_error_paths() {
    let entry = model_registry::ModelEntry {
        name: "error-test".to_string(),
        base_path: "./nonexistent.gguf".into(),
        lora_path: None,
        template: Some("chatml".into()),
        ctx_len: Some(2048),
        n_threads: None,
    };
    let mut registry = model_registry::Registry::with_discovery();
    registry.register(entry);

    let mock = MockEngine;
    let spec = registry.to_spec("error-test").unwrap();

    if let Ok(loaded_model) = mock.load(&spec).await {
        let outcome = loaded_model
            .generate("test prompt", crate::engine::GenOptions::default(), None)
            .await;
        assert!(outcome.is_ok());
    }
}
/// Calls `print_startup_diagnostics` without a backend and with `"auto"`;
/// passes as long as neither call panics.
#[test]
fn test_print_startup_diagnostics_basic() {
    for (backend, model_count) in [(None, 3), (Some("auto"), 5)] {
        print_startup_diagnostics("1.6.0", backend, false, None, model_count, true);
    }
}
/// Calls `print_startup_diagnostics` once per backend name, including an
/// unrecognized one; passes as long as no call panics.
#[test]
fn test_print_startup_diagnostics_with_backends() {
    let scenarios = [
        ("cpu", 2),
        ("cuda", 4),
        ("vulkan", 1),
        ("opencl", 6),
        ("custom-backend", 3),
    ];
    for (backend, model_count) in scenarios {
        print_startup_diagnostics("1.6.0", Some(backend), false, None, model_count, false);
    }
}
/// Calls `print_startup_diagnostics` with the CPU-MoE flag and the
/// `n_cpu_moe` count in turn; only compiled with the `llama` feature.
/// Passes as long as no call panics.
#[test]
#[cfg(feature = "llama")]
fn test_print_startup_diagnostics_with_moe() {
    let scenarios = [
        ("cuda", true, None, 2),
        ("cuda", false, Some(16), 2),
        ("auto", true, None, 5),
    ];
    for (backend, cpu_moe, n_cpu_moe, model_count) in scenarios {
        print_startup_diagnostics("1.6.0", Some(backend), cpu_moe, n_cpu_moe, model_count, false);
    }
}
/// Calls `print_startup_diagnostics` with a model count of zero; passes as
/// long as the call does not panic.
#[test]
fn test_print_startup_diagnostics_zero_models() {
    let model_count = 0;
    print_startup_diagnostics("1.6.0", None, false, None, model_count, true);
}
/// Calls `print_startup_diagnostics` with double-digit model counts; passes
/// as long as neither call panics.
#[test]
fn test_print_startup_diagnostics_many_models() {
    for (backend, cpu_moe, model_count) in [("cuda", false, 13), ("auto", true, 25)] {
        print_startup_diagnostics("1.6.0", Some(backend), cpu_moe, None, model_count, false);
    }
}
/// Calls `print_startup_diagnostics` with the crate's real package version,
/// no GPU backend, MoE options off, and zero models; passes as long as the
/// call does not panic.
#[test]
fn test_serve_diagnostics_integration() {
    print_startup_diagnostics(
        env!("CARGO_PKG_VERSION"),
        None::<&str>,
        false,
        None::<usize>,
        0,
        true,
    );
}
/// Guards the compiled-in package version: it must be non-empty and must not
/// be `0.1.0`, then the version is passed to `print_startup_diagnostics`.
#[test]
fn test_startup_diagnostics_version_display() {
    let pkg_version = env!("CARGO_PKG_VERSION");

    assert!(!pkg_version.is_empty(), "Version should not be empty");
    assert!(
        pkg_version != "0.1.0",
        "Version should not be the broken 0.1.0 from Issue #63"
    );

    print_startup_diagnostics(pkg_version, None, false, None, 1, true);
}
}