mod api;
mod api_errors;
mod auto_discovery;
mod cli;
mod engine;
mod main_integration;
mod model_registry;
mod openai_compat;
mod port_manager;
mod server;
mod templates;
mod util {
pub mod diag;
}
use clap::Parser;
use model_registry::{ModelEntry, Registry};
use std::net::SocketAddr;
use std::sync::Arc;
use tracing::info;
/// Shared application state: built once in `main`, wrapped in an `Arc`, and
/// passed to `server::run` or used directly by the CLI subcommands.
pub struct AppState {
/// Inference backend; the CLI subcommands call `load` on it to obtain a model.
pub engine: Box<dyn engine::InferenceEngine>,
/// Model registry holding manually registered and auto-discovered models.
pub registry: Registry,
}
/// Prints the "no models available" guidance to stderr and terminates the
/// process with exit status 1.
fn exit_no_models_available() -> ! {
    eprintln!("❌ No models available. Please:");
    eprintln!(" • Set SHIMMY_BASE_GGUF environment variable, or");
    eprintln!(" • Place .gguf files in ./models/ directory, or");
    eprintln!(" • Place .gguf files in ~/.cache/huggingface/hub/");
    std::process::exit(1)
}

/// Returns the longest prefix of `text` that is at most `max_bytes` bytes long
/// and ends on a UTF-8 character boundary.
///
/// Slicing a `str` at an arbitrary byte offset panics when the offset falls
/// inside a multi-byte character; this helper backs off to the previous
/// boundary instead.
fn truncate_at_char_boundary(text: &str, max_bytes: usize) -> &str {
    let mut end = text.len().min(max_bytes);
    // Offset 0 is always a boundary, so this loop terminates.
    while !text.is_char_boundary(end) {
        end -= 1;
    }
    &text[..end]
}

/// Program entry point: initialises tracing, builds the model registry and
/// inference engine, then dispatches on the parsed CLI subcommand.
///
/// # Errors
///
/// Returns an error when the bind address cannot be parsed, when a requested
/// model name is unknown, or when loading/generation/serving fails. The
/// `serve` command exits with status 1 when no model is available, and
/// `probe` exits with status 2 when the model fails to load.
#[tokio::main]
async fn main() -> anyhow::Result<()> {
    tracing_subscriber::fmt()
        .with_env_filter(tracing_subscriber::EnvFilter::from_default_env())
        .init();

    #[cfg(all(target_arch = "aarch64", target_os = "macos", not(feature = "llama")))]
    info!("llama.cpp temporarily disabled on macOS ARM64 due to upstream i8mm build incompatibility; using SafeTensors backend");

    let cli = cli::Cli::parse();

    // Exported before the registry is built.
    // NOTE(review): presumably read by model discovery — confirm in `auto_discovery`.
    if let Some(model_dirs) = &cli.model_dirs {
        std::env::set_var("SHIMMY_MODEL_PATHS", model_dirs);
    }

    let mut reg = Registry::with_discovery();
    // Default entry, overridable through SHIMMY_BASE_GGUF / SHIMMY_LORA_GGUF.
    reg.register(ModelEntry {
        name: "phi3-lora".into(),
        base_path: std::env::var("SHIMMY_BASE_GGUF")
            .unwrap_or_else(|_| "./models/phi3-mini.gguf".into())
            .into(),
        lora_path: std::env::var("SHIMMY_LORA_GGUF").ok().map(Into::into),
        template: Some("chatml".into()),
        ctx_len: Some(4096),
        n_threads: None,
    });

    let engine: Box<dyn engine::InferenceEngine> =
        Box::new(engine::adapter::InferenceEngineAdapter::new());
    let state = Arc::new(AppState {
        engine,
        registry: reg,
    });

    match cli.cmd {
        cli::Command::Serve { .. } => {
            let bind_address = cli.cmd.get_bind_address();
            // The address is user input: report a normal error instead of panicking.
            let addr = bind_address
                .parse::<SocketAddr>()
                .map_err(|e| anyhow::anyhow!("bad bind address '{bind_address}': {e}"))?;
            println!("🚀 Starting Shimmy server on {}", bind_address);

            // With at most one manual entry (the default registered above),
            // serve from a copy of the registry that also contains every
            // auto-discovered model; otherwise serve the manual registry as is.
            let serving_state = if state.registry.list().len() <= 1 {
                let mut enhanced_state = AppState {
                    engine: Box::new(engine::llama::LlamaEngine::new()),
                    registry: state.registry.clone(),
                };
                enhanced_state.registry.auto_register_discovered();
                Arc::new(enhanced_state)
            } else {
                state
            };

            let available_models = serving_state.registry.list_all_available();
            if available_models.is_empty() {
                exit_no_models_available();
            }
            info!(%addr, models=%available_models.len(), "shimmy serving with {} available models", available_models.len());
            server::run(addr, serving_state).await?;
        }
        cli::Command::List => {
            let manual_models = state.registry.list();
            if !manual_models.is_empty() {
                println!("📋 Registered Models:");
                for e in &manual_models {
                    println!(" {} => {:?}", e.name, e.base_path);
                }
            }

            let auto_discovered = state.registry.discovered_models.clone();
            if !auto_discovered.is_empty() {
                // Blank separator line only when both sections are printed.
                if !manual_models.is_empty() {
                    println!();
                }
                println!("🔍 Auto-Discovered Models:");
                for (name, model) in auto_discovered {
                    let size_mb = model.size_bytes / (1024 * 1024);
                    let type_info = match (&model.parameter_count, &model.quantization) {
                        (Some(params), Some(quant)) => format!(" ({}·{})", params, quant),
                        (Some(params), None) => format!(" ({})", params),
                        (None, Some(quant)) => format!(" ({})", quant),
                        _ => String::new(),
                    };
                    let lora_info = if model.lora_path.is_some() {
                        " + LoRA"
                    } else {
                        ""
                    };
                    println!(
                        " {} => {:?} [{}MB{}{}]",
                        name, model.path, size_mb, type_info, lora_info
                    );
                }
            }

            let all_available = state.registry.list_all_available();
            if all_available.is_empty() {
                println!(
                    "❌ No models found. Set SHIMMY_BASE_GGUF or place .gguf files in ./models/"
                );
            } else {
                println!("\n✅ Total available models: {}", all_available.len());
            }
        }
        cli::Command::Discover => {
            println!("🔍 Refreshing model discovery...");
            // A fresh registry re-runs discovery instead of reusing `state`.
            let registry = Registry::with_discovery();
            let discovered = registry.discovered_models.clone();
            if discovered.is_empty() {
                println!("❌ No models found in search paths:");
                let discovery = crate::auto_discovery::ModelAutoDiscovery::new();
                for path in &discovery.search_paths {
                    println!(" • {:?}", path);
                }
                println!(" • Ollama models (if installed)");
                println!("\n💡 Try downloading a GGUF model or setting SHIMMY_BASE_GGUF");
            } else {
                println!("✅ Found {} models:", discovered.len());
                for (name, model) in discovered {
                    let size_mb = model.size_bytes / (1024 * 1024);
                    let lora_info = if model.lora_path.is_some() {
                        " + LoRA"
                    } else {
                        ""
                    };
                    println!(" {} [{}MB{}]", name, size_mb, lora_info);
                    println!(" Base: {:?}", model.path);
                    if let Some(lora) = &model.lora_path {
                        println!(" LoRA: {:?}", lora);
                    }
                }
            }
        }
        cli::Command::Probe { name } => {
            let Some(spec) = state.registry.to_spec(&name) else {
                anyhow::bail!("no model {name}");
            };
            match state.engine.load(&spec).await {
                Ok(_) => println!("ok: loaded {name}"),
                Err(e) => {
                    eprintln!("probe failed: {e}");
                    std::process::exit(2);
                }
            }
        }
        cli::Command::Bench { name, max_tokens } => {
            let Some(spec) = state.registry.to_spec(&name) else {
                anyhow::bail!("no model {name}");
            };
            let loaded = state.engine.load(&spec).await?;
            let t0 = std::time::Instant::now();
            let out = loaded
                .generate(
                    "Say hi.",
                    engine::GenOptions {
                        max_tokens,
                        stream: false,
                        ..Default::default()
                    },
                    None,
                )
                .await?;
            let elapsed = t0.elapsed();
            // Model output may contain multi-byte characters; a raw
            // `&out[..120]` would panic if byte 120 is inside one.
            println!(
                "bench output (truncated): {}",
                truncate_at_char_boundary(&out, 120)
            );
            println!("elapsed: {:?}", elapsed);
        }
        cli::Command::Generate {
            name,
            prompt,
            max_tokens,
        } => {
            let Some(spec) = state.registry.to_spec(&name) else {
                anyhow::bail!("no model {name}");
            };
            let loaded = state.engine.load(&spec).await?;
            let out = loaded
                .generate(
                    &prompt,
                    engine::GenOptions {
                        max_tokens,
                        stream: false,
                        ..Default::default()
                    },
                    None,
                )
                .await?;
            println!("{}", out);
        }
    }
    Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::engine::InferenceEngine;
use std::env;
use std::sync::Arc;
/// Engine test double: `load` succeeds for every spec and returns a
/// `MockLoadedModel` without reading anything from disk.
struct MockEngine;

#[async_trait::async_trait]
impl engine::InferenceEngine for MockEngine {
    /// Ignores `_spec` and hands back a fresh mock model.
    async fn load(
        &self,
        _spec: &engine::ModelSpec,
    ) -> anyhow::Result<Box<dyn engine::LoadedModel>> {
        let model: Box<dyn engine::LoadedModel> = Box::new(MockLoadedModel);
        Ok(model)
    }
}
/// Model test double whose output deterministically echoes the prompt and
/// the requested `max_tokens`.
struct MockLoadedModel;

#[async_trait::async_trait]
impl engine::LoadedModel for MockLoadedModel {
    /// Never fails and never invokes the token callback.
    async fn generate(
        &self,
        prompt: &str,
        opts: engine::GenOptions,
        _on_token: Option<Box<dyn FnMut(String) + Send>>,
    ) -> anyhow::Result<String> {
        let token_budget = opts.max_tokens;
        let reply = format!(
            "Generated response to: {} (max_tokens: {})",
            prompt, token_budget
        );
        Ok(reply)
    }
}
/// Mirrors the start-up sequence of `main`: build a registry, register the
/// default `phi3-lora` entry and wrap everything in a shared `AppState`.
#[tokio::test]
async fn test_main_initialization_paths() {
    env::remove_var("SHIMMY_BASE_GGUF");
    env::remove_var("SHIMMY_LORA_GGUF");

    let mut reg = model_registry::Registry::with_discovery();
    reg.register(model_registry::ModelEntry {
        name: "phi3-lora".into(),
        base_path: "./models/phi3-mini.gguf".into(),
        lora_path: None,
        template: Some("chatml".into()),
        ctx_len: Some(4096),
        n_threads: None,
    });
    let engine: Box<dyn engine::InferenceEngine> =
        Box::new(engine::adapter::InferenceEngineAdapter::new());
    let state = Arc::new(AppState {
        engine,
        registry: reg,
    });

    // The previous `assert!(true)` could never fail; check that the state
    // really exposes the entry that was registered.
    assert!(
        state.registry.list().iter().any(|e| e.name == "phi3-lora"),
        "default model must be listed after registration"
    );
    assert!(
        state.registry.to_spec("phi3-lora").is_some(),
        "default model must resolve to a spec"
    );
}
/// Checks that the two model-path variables are read back exactly as set,
/// using the same fallback expression as `main`.
#[tokio::test]
async fn test_environment_variable_handling() {
    const BASE_VAR: &str = "SHIMMY_BASE_GGUF";
    const LORA_VAR: &str = "SHIMMY_LORA_GGUF";

    // Start from a clean slate, then install the values under test.
    for var in [BASE_VAR, LORA_VAR] {
        env::remove_var(var);
    }
    env::set_var(BASE_VAR, "/custom/path/model.gguf");
    env::set_var(LORA_VAR, "/custom/path/lora.safetensors");

    let resolved_base = match env::var(BASE_VAR) {
        Ok(value) => value,
        Err(_) => "./models/phi3-mini.gguf".into(),
    };
    let resolved_lora = env::var(LORA_VAR).ok();

    assert_eq!(resolved_base, "/custom/path/model.gguf");
    assert_eq!(
        resolved_lora,
        Some("/custom/path/lora.safetensors".to_string())
    );

    for var in [BASE_VAR, LORA_VAR] {
        env::remove_var(var);
    }
}
/// A loopback address with an allocated port parses; a malformed one does not.
#[test]
fn test_serve_command_address_parsing() {
    use crate::port_manager::GLOBAL_PORT_ALLOCATOR;

    let port = GLOBAL_PORT_ALLOCATOR
        .allocate_ephemeral_port("test-serve-parsing")
        .unwrap();
    let parsed = format!("127.0.0.1:{}", port)
        .parse::<std::net::SocketAddr>()
        .expect("bad bind address");
    assert_eq!(parsed.port(), port);
    GLOBAL_PORT_ALLOCATOR.release_port(port);

    let rejected = "invalid:address".parse::<std::net::SocketAddr>();
    assert!(rejected.is_err());
}
/// The serve command auto-registers discovered models only while at most one
/// model is registered manually; two manual entries must switch that off.
#[test]
fn test_serve_command_model_count_logic() {
    let mut registry = model_registry::Registry::with_discovery();
    for name in ["count-model-a", "count-model-b"] {
        registry.register(model_registry::ModelEntry {
            name: name.into(),
            base_path: "./test.gguf".into(),
            lora_path: None,
            template: None,
            ctx_len: None,
            n_threads: None,
        });
    }

    let manual_count = registry.list().len();
    let should_auto_register = manual_count <= 1;

    // The previous assertion was `x || !x`, which holds for any value.
    assert!(
        manual_count >= 2,
        "both registered models must be listed, got {}",
        manual_count
    );
    assert!(
        !should_auto_register,
        "auto-registration must be skipped once several models are registered"
    );
}
/// Exercises the three registry reads performed by the `list` command.
#[tokio::test]
async fn test_list_command_execution_logic() {
    let mut registry = model_registry::Registry::with_discovery();
    registry.register(model_registry::ModelEntry {
        name: "test-model".into(),
        base_path: "./test.gguf".into(),
        lora_path: None,
        template: Some("chatml".into()),
        ctx_len: Some(2048),
        n_threads: None,
    });

    let manual_models = registry.list();
    assert!(!manual_models.is_empty());
    // Check the entry itself rather than ending on a vacuous `assert!(true)`.
    assert!(
        manual_models.iter().any(|e| e.name == "test-model"),
        "registered model must appear in the manual listing"
    );

    // Smoke-check the remaining two data sources; their contents depend on
    // the machine the test runs on.
    let _auto_discovered = registry.discovered_models.clone();
    let _all_available = registry.list_all_available();
}
/// Walks the discovered models the way the `discover` command does.
///
/// An empty discovery result is valid (it depends on the host), so the only
/// invariant checked is that every discovered model has a name.
#[tokio::test]
async fn test_discover_command_execution_logic() {
    let registry = model_registry::Registry::with_discovery();
    for (name, model) in registry.discovered_models.clone() {
        // Same derived values the command prints.
        let _size_mb = model.size_bytes / (1024 * 1024);
        let _lora_info = if model.lora_path.is_some() {
            " + LoRA"
        } else {
            ""
        };
        assert!(!name.is_empty(), "discovered models must have a name");
    }
}
/// Probe flow against the mock engine: the registered model must resolve to
/// a spec and loading it must succeed.
#[tokio::test]
async fn test_probe_command_execution_logic() {
    let mut registry = model_registry::Registry::with_discovery();
    registry.register(model_registry::ModelEntry {
        name: "test-probe-model".into(),
        base_path: "./test.gguf".into(),
        lora_path: None,
        template: Some("chatml".into()),
        ctx_len: Some(2048),
        n_threads: None,
    });
    let engine = MockEngine;

    // Every branch used to end in `assert!(true)`, so the test could not fail.
    let spec = registry
        .to_spec("test-probe-model")
        .expect("registered model must resolve to a spec");
    let load_result = engine.load(&spec).await;
    assert!(load_result.is_ok(), "MockEngine::load never fails");
}
/// Bench flow against the mock engine: load, generate once, truncate output.
#[tokio::test]
async fn test_bench_command_execution_logic() {
    let mut registry = model_registry::Registry::with_discovery();
    registry.register(model_registry::ModelEntry {
        name: "test-bench-model".into(),
        base_path: "./test.gguf".into(),
        lora_path: None,
        template: Some("chatml".into()),
        ctx_len: Some(2048),
        n_threads: None,
    });
    let engine = MockEngine;
    let max_tokens = 100;

    // A missing spec used to skip every assertion; make it a failure.
    let spec = registry
        .to_spec("test-bench-model")
        .expect("registered model must resolve to a spec");
    let loaded = engine.load(&spec).await.unwrap();

    let t0 = std::time::Instant::now();
    let out = loaded
        .generate(
            "Say hi.",
            engine::GenOptions {
                max_tokens,
                stream: false,
                ..Default::default()
            },
            None,
        )
        .await
        .unwrap();
    // Measured as the command does, but not asserted: `Instant` only
    // guarantees non-decreasing readings, so zero elapsed time is legal.
    let _elapsed = t0.elapsed();

    // The mock output is ASCII, so byte slicing is safe here.
    let truncated = &out[..out.len().min(120)];
    assert!(!truncated.is_empty());
    assert!(out.contains("max_tokens: 100"));
}
/// Generate flow against the mock engine: the prompt must reach the model.
#[tokio::test]
async fn test_generate_command_execution_logic() {
    let mut registry = model_registry::Registry::with_discovery();
    registry.register(model_registry::ModelEntry {
        name: "test-gen-model".into(),
        base_path: "./test.gguf".into(),
        lora_path: None,
        template: Some("chatml".into()),
        ctx_len: Some(2048),
        n_threads: None,
    });
    let engine = MockEngine;
    let prompt = "Hello, world!";
    let max_tokens = 50;

    // A missing spec used to skip the assertion entirely; make it a failure.
    let spec = registry
        .to_spec("test-gen-model")
        .expect("registered model must resolve to a spec");
    let loaded = engine.load(&spec).await.unwrap();
    let out = loaded
        .generate(
            prompt,
            engine::GenOptions {
                max_tokens,
                stream: false,
                ..Default::default()
            },
            None,
        )
        .await
        .unwrap();
    assert!(out.contains("Generated response to: Hello, world!"));
}
/// Parses a `generate` invocation and checks every captured argument.
#[test]
fn test_command_execution_paths() {
    use crate::cli::{Cli, Command};
    use clap::Parser;

    let parsed = Cli::try_parse_from([
        "shimmy",
        "generate",
        "test-model",
        "--prompt",
        "Hello",
        "--max-tokens",
        "50",
    ])
    .unwrap();

    let Command::Generate {
        name,
        prompt,
        max_tokens,
    } = parsed.cmd
    else {
        panic!("Expected Generate command");
    };
    assert_eq!(name, "test-model");
    assert_eq!(prompt, "Hello");
    assert_eq!(max_tokens, 50);
}
#[tokio::test]
async fn test_state_initialization() {
use crate::engine::adapter::InferenceEngineAdapter;
use crate::model_registry::Registry;
let registry = Registry::with_discovery();
let engine = Box::new(InferenceEngineAdapter::new());
let state = std::sync::Arc::new(crate::AppState { engine, registry });
assert_ne!(std::mem::size_of_val(&state), 0);
let models = state.registry.list();
assert!(models.len() >= 0);
}
/// Builds the "enhanced" serve state and auto-registers discovered models.
///
/// What gets discovered depends on the host, so the test checks only that
/// auto-registration never loses already registered entries.
#[test]
fn test_serve_enhanced_state_logic() {
    let registry = model_registry::Registry::with_discovery();
    let manual_before = registry.list().len();

    let mut enhanced_state = AppState {
        engine: Box::new(engine::llama::LlamaEngine::new()),
        registry: registry.clone(),
    };
    enhanced_state.registry.auto_register_discovered();
    let enhanced_state = Arc::new(enhanced_state);

    // Replaces `if empty { assert empty } else { assert !empty }`.
    assert!(
        enhanced_state.registry.list().len() >= manual_before,
        "auto-registration must not drop registered models"
    );
    // Smoke-check the listing the serve command consults before starting.
    let _available_models = enhanced_state.registry.list_all_available();
}
/// Registers the default model from environment-provided paths, as `main` does.
#[test]
fn test_model_registration_with_env_vars() {
    env::set_var("SHIMMY_BASE_GGUF", "/test/base.gguf");
    env::set_var("SHIMMY_LORA_GGUF", "/test/lora.safetensors");

    // Resolve both paths before the registry is created.
    let resolved_base = match env::var("SHIMMY_BASE_GGUF") {
        Ok(path) => path,
        Err(_) => String::from("./models/phi3-mini.gguf"),
    };
    let resolved_lora = env::var("SHIMMY_LORA_GGUF").ok();

    let mut reg = model_registry::Registry::with_discovery();
    let entry = model_registry::ModelEntry {
        name: "phi3-lora".into(),
        base_path: resolved_base.into(),
        lora_path: resolved_lora.map(Into::into),
        template: Some("chatml".into()),
        ctx_len: Some(4096),
        n_threads: None,
    };
    reg.register(entry);

    let registered = reg.list();
    assert!(!registered.is_empty());

    env::remove_var("SHIMMY_BASE_GGUF");
    env::remove_var("SHIMMY_LORA_GGUF");
}
/// Registration grows the manual listing and makes the model resolvable,
/// while unknown names stay unresolved.
#[test]
fn test_registry_model_methods() {
    let mut registry = model_registry::Registry::with_discovery();
    let count_before = registry.list().len();

    // Read accessors are exercised for coverage only.
    let _discovered = registry.discovered_models.clone();
    let _all_available = registry.list_all_available();

    let entry = model_registry::ModelEntry {
        name: "test".into(),
        base_path: "./test.gguf".into(),
        lora_path: None,
        template: None,
        ctx_len: None,
        n_threads: None,
    };
    registry.register(entry);

    assert!(registry.list().len() > count_before);
    assert!(registry.to_spec("test").is_some());
    assert!(registry.to_spec("nonexistent").is_none());
}
/// `AppState` must hold exactly the registry it was constructed with.
#[test]
fn test_app_state_creation() {
    use crate::engine::adapter::InferenceEngineAdapter;
    use crate::model_registry::Registry;

    let engine: Box<dyn engine::InferenceEngine> = Box::new(InferenceEngineAdapter::new());
    let registry = Registry::with_discovery();
    let expected_manual = registry.list().len();

    let state = AppState { engine, registry };

    // `len() >= 0` is always true for `usize`; compare with the count taken
    // before the registry was moved into the state.
    assert_eq!(state.registry.list().len(), expected_manual);
}
/// The filter used by `main` can be built from the environment, and an
/// explicit directive yields the expected level hint.
#[test]
fn test_tracing_initialization() {
    use tracing_subscriber::filter::LevelFilter;
    use tracing_subscriber::EnvFilter;

    // Construction from the environment must not panic, whatever RUST_LOG
    // holds. Its level hint depends on the environment, so it is not asserted.
    let _from_env = EnvFilter::from_default_env();

    // Deterministic check replacing `is_some() || is_none()`.
    let explicit = EnvFilter::new("info");
    assert_eq!(explicit.max_level_hint(), Some(LevelFilter::INFO));
}
/// Replays the serve command's setup (everything except starting the server).
///
/// Shared resources (environment variable, allocated port) are released
/// before any assertion runs, so a failure cannot leak them into other tests.
#[tokio::test]
async fn test_serve_command_execution_simulation() {
    use crate::port_manager::GLOBAL_PORT_ALLOCATOR;

    env::set_var("SHIMMY_BASE_GGUF", "./test.gguf");
    let mut reg = model_registry::Registry::with_discovery();
    reg.register(model_registry::ModelEntry {
        name: "phi3-lora".into(),
        base_path: env::var("SHIMMY_BASE_GGUF")
            .unwrap_or_else(|_| "./models/phi3-mini.gguf".into())
            .into(),
        lora_path: env::var("SHIMMY_LORA_GGUF").ok().map(Into::into),
        template: Some("chatml".into()),
        ctx_len: Some(4096),
        n_threads: None,
    });
    let engine: Box<dyn engine::InferenceEngine> =
        Box::new(engine::adapter::InferenceEngineAdapter::new());
    let state = Arc::new(AppState {
        engine,
        registry: reg,
    });

    let dynamic_port = GLOBAL_PORT_ALLOCATOR
        .allocate_ephemeral_port("test-serve-logic")
        .unwrap();
    let bind = format!("127.0.0.1:{}", dynamic_port);
    let parsed_addr = bind.parse::<std::net::SocketAddr>();

    let manual_count = state.registry.list().len();
    // Same branch condition as the serve command.
    let enhanced_count = if manual_count <= 1 {
        let mut enhanced_state = AppState {
            engine: Box::new(engine::llama::LlamaEngine::new()),
            registry: state.registry.clone(),
        };
        enhanced_state.registry.auto_register_discovered();
        let enhanced_state = Arc::new(enhanced_state);
        let _available_models = enhanced_state.registry.list_all_available();
        Some(enhanced_state.registry.list().len())
    } else {
        None
    };
    let _available_models = state.registry.list_all_available();

    // Clean up first, assert afterwards.
    env::remove_var("SHIMMY_BASE_GGUF");
    GLOBAL_PORT_ALLOCATOR.release_port(dynamic_port);

    let addr = parsed_addr.expect("bad bind address");
    assert_eq!(addr.port(), dynamic_port);
    assert!(
        state.registry.list().iter().any(|e| e.name == "phi3-lora"),
        "default model must be registered"
    );
    if let Some(enhanced_count) = enhanced_count {
        assert!(
            enhanced_count >= manual_count,
            "auto-registration must not drop registered models"
        );
    }
}
/// Runs the data-access part of the `list`, `discover` and spec-lookup
/// branches of `main` against a registry with one manual entry.
#[tokio::test]
async fn test_command_match_branches_coverage() {
    let mut reg = model_registry::Registry::with_discovery();
    reg.register(model_registry::ModelEntry {
        name: "test-model".into(),
        base_path: "./test.gguf".into(),
        lora_path: None,
        template: Some("chatml".into()),
        ctx_len: Some(2048),
        n_threads: None,
    });
    let state = Arc::new(AppState {
        engine: Box::new(engine::adapter::InferenceEngineAdapter::new()),
        registry: reg,
    });

    // --- `list` branch ---
    {
        let manual_models = state.registry.list();
        assert!(
            manual_models.iter().any(|e| e.name == "test-model"),
            "registered model must be listed"
        );
        for e in &manual_models {
            assert!(!e.name.is_empty());
        }

        for (name, model) in state.registry.discovered_models.clone() {
            let _size_mb = model.size_bytes / (1024 * 1024);
            let _type_info = match (&model.parameter_count, &model.quantization) {
                (Some(params), Some(quant)) => format!(" ({}·{})", params, quant),
                (Some(params), None) => format!(" ({})", params),
                (None, Some(quant)) => format!(" ({})", quant),
                _ => String::new(),
            };
            let _lora_info = if model.lora_path.is_some() {
                " + LoRA"
            } else {
                ""
            };
            assert!(!name.is_empty());
        }

        // Host-dependent content; exercised as a smoke check only.
        let _all_available = state.registry.list_all_available();
    }

    // --- `discover` branch: a fresh registry re-runs discovery ---
    {
        let registry = model_registry::Registry::with_discovery();
        for (name, model) in registry.discovered_models.clone() {
            let _size_mb = model.size_bytes / (1024 * 1024);
            let _lora_info = if model.lora_path.is_some() {
                " + LoRA"
            } else {
                ""
            };
            assert!(!name.is_empty());
        }
    }

    // --- spec lookup shared by `probe`, `bench` and `generate` ---
    let spec = state
        .registry
        .to_spec("test-model")
        .expect("Expected to find test model");
    assert!(spec.base_path.to_string_lossy().contains("test"));
}
/// Socket-address parsing: malformed inputs are rejected, well-formed IPv4
/// and IPv6 inputs are accepted.
#[test]
fn test_error_conditions_and_edge_cases() {
    use crate::port_manager::GLOBAL_PORT_ALLOCATOR;

    let rejected = [
        "invalid-address",
        "256.256.256.256:9999",
        "127.0.0.1:99999",
        "127.0.0.1:",
        ":9999",
        "",
        "not.an.ip:port",
    ];
    for addr_str in rejected {
        let outcome = addr_str.parse::<std::net::SocketAddr>();
        assert!(
            outcome.is_err(),
            "Expected parsing to fail for: {}",
            addr_str
        );
    }

    // One port per host, allocated in the same order as the labels.
    let ports = ["test-valid-1", "test-valid-2", "test-valid-3", "test-valid-4"]
        .map(|label| GLOBAL_PORT_ALLOCATOR.allocate_ephemeral_port(label).unwrap());
    let hosts = ["127.0.0.1", "0.0.0.0", "192.168.1.100", "[::1]"];
    let accepted: Vec<String> = hosts
        .iter()
        .zip(ports.iter())
        .map(|(host, port)| format!("{}:{}", host, port))
        .collect();
    for port in ports {
        GLOBAL_PORT_ALLOCATOR.release_port(port);
    }

    for addr_str in accepted {
        let outcome = addr_str.parse::<std::net::SocketAddr>();
        assert!(
            outcome.is_ok(),
            "Expected parsing to succeed for: {}",
            addr_str
        );
    }
}
/// Names that were never registered must not resolve to a spec.
#[test]
fn test_registry_edge_cases() {
    let registry = model_registry::Registry::with_discovery();
    let unknown_names = [
        "nonexistent-model",
        "",
        "model-with-special-chars!@#",
        "very-long-model-name-that-might-cause-issues-in-some-systems-if-not-handled-properly",
    ];
    for name in unknown_names {
        let lookup = registry.to_spec(name);
        assert!(
            lookup.is_none(),
            "Expected no spec for nonexistent model: {}",
            name
        );
    }
}
/// Registers one entry with only mandatory fields and one with every
/// optional field set, then reads both back.
#[test]
fn test_model_entry_variants() {
    let mut registry = model_registry::Registry::with_discovery();

    let bare = model_registry::ModelEntry {
        name: "minimal".to_string(),
        base_path: "./minimal.gguf".into(),
        lora_path: None,
        template: None,
        ctx_len: None,
        n_threads: None,
    };
    let full = model_registry::ModelEntry {
        name: "maximal".to_string(),
        base_path: "./maximal.gguf".into(),
        lora_path: Some("./maximal.lora".into()),
        template: Some("llama3".to_string()),
        ctx_len: Some(8192),
        n_threads: Some(8),
    };
    registry.register(bare);
    registry.register(full);

    let listed = registry.list();
    assert!(listed.len() >= 2);

    let minimal = listed
        .iter()
        .find(|entry| entry.name == "minimal")
        .unwrap();
    assert!(minimal.lora_path.is_none());
    assert!(minimal.template.is_none());

    let maximal = listed
        .iter()
        .find(|entry| entry.name == "maximal")
        .unwrap();
    assert!(maximal.lora_path.is_some());
    assert_eq!(maximal.template.as_deref(), Some("llama3"));
    assert_eq!(maximal.ctx_len, Some(8192));
    assert_eq!(maximal.n_threads, Some(8));
}
/// The mock model echoes the prompt and the `max_tokens` it was given.
#[tokio::test]
async fn test_mock_engine_behavior() {
    let engine = MockEngine;
    let spec = crate::engine::ModelSpec {
        name: "test".to_string(),
        base_path: "./test.gguf".into(),
        lora_path: None,
        template: None,
        ctx_len: 1024,
        n_threads: None,
    };
    let loaded = engine.load(&spec).await.unwrap();

    // Default options.
    let default_reply = loaded
        .generate("Test prompt", crate::engine::GenOptions::default(), None)
        .await
        .unwrap();
    for needle in ["Generated response to: Test prompt", "max_tokens:"] {
        assert!(default_reply.contains(needle));
    }

    // Customised options.
    let tuned = crate::engine::GenOptions {
        max_tokens: 150,
        temperature: 0.8,
        ..Default::default()
    };
    let tuned_reply = loaded.generate("Another test", tuned, None).await.unwrap();
    for needle in ["Another test", "150"] {
        assert!(tuned_reply.contains(needle));
    }
}
/// Builds the type label shown by the `list` command for every discovered
/// model and checks that it contains the pieces it was built from.
///
/// An empty discovery result is valid, so nothing is asserted about the
/// number of models.
#[test]
fn test_auto_discovery_models_access() {
    let registry = model_registry::Registry::with_discovery();
    let discovered = registry.discovered_models.clone();

    for (_name, model) in &discovered {
        let _type_info = match (&model.parameter_count, &model.quantization) {
            (Some(params), Some(quant)) => {
                let info = format!(" ({}·{})", params, quant);
                assert!(info.contains(params));
                assert!(info.contains(quant));
                info
            }
            (Some(params), None) => {
                let info = format!(" ({})", params);
                assert!(info.contains(params));
                info
            }
            (None, Some(quant)) => {
                let info = format!(" ({})", quant);
                assert!(info.contains(quant));
                info
            }
            _ => String::new(),
        };
        let _lora_info = if model.lora_path.is_some() {
            " + LoRA"
        } else {
            ""
        };
    }

    // Smoke check only: the old `len() >= 0` assertion was always true.
    let _all_available = registry.list_all_available();
}
/// Serve path for a registry without manual entries: auto-registration runs
/// and must keep whatever was registered before.
#[tokio::test]
async fn test_serve_command_edge_cases() {
    let empty_registry = model_registry::Registry::with_discovery();
    let manual_count = empty_registry.list().len();

    // Same branch condition as the serve command.
    if manual_count <= 1 {
        let mut enhanced_state = AppState {
            engine: Box::new(engine::llama::LlamaEngine::new()),
            registry: empty_registry.clone(),
        };
        enhanced_state.registry.auto_register_discovered();

        // Replaces `if empty { assert empty } else { assert !empty }`.
        assert!(
            enhanced_state.registry.list().len() >= manual_count,
            "auto-registration must not drop registered models"
        );
        let _available_models = enhanced_state.registry.list_all_available();
    }
}
/// Byte-limit truncation as used for bench output: inputs at or below the
/// limit are unchanged, longer ASCII inputs are cut to the limit.
#[test]
fn test_string_truncation_logic() {
    let oversized = "A".repeat(150);
    let oversized_prefix = "A".repeat(120);
    // 15 + 105 = 120 bytes: exactly at the limit.
    let at_limit = format!("Exactly120chars{}", "A".repeat(105));

    let cases: [(&str, usize, &str); 4] = [
        ("Short", 120, "Short"),
        (&oversized, 120, &oversized_prefix),
        ("", 120, ""),
        (&at_limit, 120, &at_limit),
    ];
    for &(input, limit, expected) in cases.iter() {
        let cut = limit.min(input.len());
        assert_eq!(&input[..cut], expected);
    }
}
/// A one-millisecond sleep yields a non-zero elapsed time whose debug
/// rendering carries a sub-second unit.
#[test]
fn test_duration_and_timing() {
    let started = std::time::Instant::now();
    std::thread::sleep(std::time::Duration::from_millis(1));
    let waited = started.elapsed();

    assert!(waited.as_nanos() > 0);
    assert!(waited.as_millis() >= 1);

    let rendered = format!("{:?}", waited);
    let has_unit = ["ms", "µs", "ns"]
        .iter()
        .any(|unit| rendered.contains(*unit));
    assert!(has_unit);
}
/// Computes the per-model values printed by the `discover` command.
///
/// An empty discovery result is valid; for non-empty results every model
/// must have a name.
#[tokio::test]
async fn test_discover_command_execution() {
    let registry = model_registry::Registry::with_discovery();
    for (name, model) in registry.discovered_models.clone() {
        // The removed `size_mb >= 0` and `lora_info == a || lora_info == b`
        // assertions could not fail.
        let _size_mb = model.size_bytes / (1024 * 1024);
        let _lora_info = if model.lora_path.is_some() {
            " + LoRA"
        } else {
            ""
        };
        assert!(!name.is_empty(), "discovered models must have a name");
    }
}
/// Probe flow: the registered model resolves to a spec and the mock engine
/// loads it.
#[tokio::test]
async fn test_probe_command_execution() {
    let mut registry = model_registry::Registry::with_discovery();
    let entry = model_registry::ModelEntry {
        name: "probe-test".to_string(),
        base_path: "./probe-test.gguf".into(),
        lora_path: None,
        template: Some("chatml".into()),
        ctx_len: Some(2048),
        n_threads: None,
    };
    registry.register(entry);
    let engine = MockEngine;

    let Some(spec) = registry.to_spec("probe-test") else {
        panic!("Expected to find probe-test model");
    };
    if engine.load(&spec).await.is_err() {
        panic!("MockEngine should not fail");
    }
}
/// Bench flow against the mock engine, including the output truncation.
#[tokio::test]
async fn test_bench_command_execution() {
    let mut registry = model_registry::Registry::with_discovery();
    registry.register(model_registry::ModelEntry {
        name: "bench-test".to_string(),
        base_path: "./bench-test.gguf".into(),
        lora_path: None,
        template: Some("chatml".into()),
        ctx_len: Some(2048),
        n_threads: None,
    });
    let engine = MockEngine;
    let max_tokens = 64;

    // A missing spec used to skip every assertion; make it a failure.
    let spec = registry
        .to_spec("bench-test")
        .expect("registered model must resolve to a spec");
    let loaded = engine.load(&spec).await.unwrap();

    let t0 = std::time::Instant::now();
    let out = loaded
        .generate(
            "Say hi.",
            engine::GenOptions {
                max_tokens,
                stream: false,
                ..Default::default()
            },
            None,
        )
        .await
        .unwrap();
    // Measured as the command does, but not asserted: `Instant` only
    // guarantees non-decreasing readings, so zero elapsed time is legal.
    let _elapsed = t0.elapsed();

    // The mock output is ASCII, so byte slicing is safe here.
    let truncated = &out[..out.len().min(120)];
    assert!(!truncated.is_empty());
    assert!(out.contains("Generated response to: Say hi."));
}
/// Save / override / restore cycle for the environment variables the
/// application reads.
///
/// Originals are captured with `var_os` so that non-Unicode values survive
/// the round trip, and they are restored before any assertion runs so that a
/// failure cannot leave `HOME` or `USERPROFILE` pointing at the test path.
#[test]
fn test_environment_cleanup() {
    const TEST_VARS: [&str; 4] = [
        "SHIMMY_BASE_GGUF",
        "SHIMMY_LORA_GGUF",
        "HOME",
        "USERPROFILE",
    ];

    let original_values: Vec<_> = TEST_VARS
        .iter()
        .map(|var| (*var, env::var_os(var)))
        .collect();

    for var in &TEST_VARS {
        env::set_var(var, "/test/path");
    }
    // Record what was observed while the overrides were active.
    let observed: Vec<_> = TEST_VARS
        .iter()
        .map(|var| (*var, env::var(var).ok()))
        .collect();

    for (var, original_value) in original_values {
        match original_value {
            Some(value) => env::set_var(var, value),
            None => env::remove_var(var),
        }
    }

    for (var, value) in observed {
        assert_eq!(
            value.as_deref(),
            Some("/test/path"),
            "{} did not take the test value",
            var
        );
    }
}
/// Bytes-to-MiB conversion uses integer division, so partial MiB round down.
#[test]
fn test_size_calculations() {
    const MIB: u64 = 1024 * 1024;
    let cases: [(u64, u64); 5] = [
        (0, 0),
        (1024, 0),
        (MIB, 1),
        (5 * MIB, 5),
        (1536 * MIB, 1536),
    ];
    for &(size_bytes, expected_mb) in cases.iter() {
        let size_mb = size_bytes / MIB;
        assert_eq!(
            size_mb, expected_mb,
            "Size calculation failed for {} bytes",
            size_bytes
        );
    }
}
/// A spec pointing at a non-existent file still loads through the mock
/// engine, and the resulting model generates successfully.
#[tokio::test]
async fn test_model_loading_error_paths() {
    let mut registry = model_registry::Registry::with_discovery();
    registry.register(model_registry::ModelEntry {
        name: "error-test".to_string(),
        base_path: "./nonexistent.gguf".into(),
        lora_path: None,
        template: Some("chatml".into()),
        ctx_len: Some(2048),
        n_threads: None,
    });
    let engine = MockEngine;
    let spec = registry.to_spec("error-test").unwrap();

    // The old `Err(_) => assert!(true)` arm accepted a load failure silently,
    // although `MockEngine::load` always returns `Ok`.
    let loaded_model = engine
        .load(&spec)
        .await
        .expect("MockEngine::load never fails, even for a missing file");
    let gen_result = loaded_model
        .generate("test prompt", crate::engine::GenOptions::default(), None)
        .await;
    assert!(gen_result.is_ok());
}
}