use mullama::*;
#[cfg(test)]
mod model_error_tests {
    use super::*;

    /// Loading a path that does not exist must fail with `ModelLoadError`
    /// carrying a non-empty message.
    #[test]
    fn test_model_load_nonexistent_file() {
        let result = Model::load("nonexistent_model.gguf");
        assert!(result.is_err());
        match result.unwrap_err() {
            MullamaError::ModelLoadError(msg) => {
                assert!(!msg.is_empty());
                println!("Expected error for nonexistent file: {}", msg);
            }
            // Report the variant actually returned so a regression is diagnosable.
            other => panic!("Expected ModelLoadError, got {:?}", other),
        }
    }

    /// Empty, special-file, missing-directory and wrong-extension paths must
    /// all be rejected by `Model::load`.
    #[test]
    fn test_model_load_invalid_path() {
        let invalid_paths = [
            "",
            "/dev/null",
            "/invalid/path/to/model.gguf",
            "model_with_invalid_extension.txt",
        ];
        for path in &invalid_paths {
            let result = Model::load(path);
            assert!(result.is_err(), "Should fail for path: {}", path);
        }
    }

    /// Extreme `n_gpu_layers` values written to `ModelParams` read back unchanged.
    #[test]
    fn test_model_params_validation() {
        let mut params = ModelParams::default();
        params.n_gpu_layers = -1;
        assert_eq!(params.n_gpu_layers, -1);
        params.n_gpu_layers = i32::MAX;
        assert_eq!(params.n_gpu_layers, i32::MAX);
        params.n_gpu_layers = i32::MIN;
        assert_eq!(params.n_gpu_layers, i32::MIN);
    }

    /// `ModelKvOverride` holds empty/very long keys and extreme numeric payloads
    /// without alteration.
    #[test]
    fn test_model_kv_override_edge_cases() {
        let override_empty_key = ModelKvOverride {
            key: "".to_string(),
            value: ModelKvOverrideValue::Int(42),
        };
        assert!(override_empty_key.key.is_empty());

        // `long_key` is not needed afterwards, so move it rather than cloning it.
        let long_key = "a".repeat(1000);
        let override_long_key = ModelKvOverride {
            key: long_key,
            value: ModelKvOverrideValue::Str("test".to_string()),
        };
        assert_eq!(override_long_key.key.len(), 1000);

        let override_extreme_int = ModelKvOverride {
            key: "extreme_int".to_string(),
            value: ModelKvOverrideValue::Int(i64::MAX),
        };
        let override_extreme_float = ModelKvOverride {
            key: "extreme_float".to_string(),
            value: ModelKvOverrideValue::Float(f64::MAX),
        };
        assert_eq!(override_extreme_int.key, "extreme_int");
        assert_eq!(override_extreme_float.key, "extreme_float");

        // The payload is the point of these two cases: verify it survived intact.
        match override_extreme_int.value {
            ModelKvOverrideValue::Int(v) => assert_eq!(v, i64::MAX),
            _ => panic!("Expected ModelKvOverrideValue::Int"),
        }
        match override_extreme_float.value {
            ModelKvOverrideValue::Float(v) => assert_eq!(v, f64::MAX),
            _ => panic!("Expected ModelKvOverrideValue::Float"),
        }
    }

    /// `tensor_split` accepts empty, single, many and non-finite entries.
    #[test]
    fn test_tensor_split_validation() {
        let mut params = ModelParams::default();
        params.tensor_split = vec![];
        assert!(params.tensor_split.is_empty());
        params.tensor_split = vec![1.0];
        assert_eq!(params.tensor_split.len(), 1);
        params.tensor_split = vec![0.5; 16];
        assert_eq!(params.tensor_split.len(), 16);
        params.tensor_split = vec![0.0, f32::MAX, f32::MIN, f32::INFINITY, f32::NEG_INFINITY];
        assert_eq!(params.tensor_split.len(), 5);
    }
}
#[cfg(test)]
mod context_error_tests {
    use super::*;

    /// Boundary values written directly to the size and thread fields of
    /// `ContextParams` read back unchanged.
    #[test]
    fn test_context_params_validation() {
        let mut cfg = ContextParams::default();

        for ctx in [0, u32::MAX] {
            cfg.n_ctx = ctx;
            assert_eq!(cfg.n_ctx, ctx);
        }

        cfg.n_batch = 0;
        assert_eq!(cfg.n_batch, 0);

        // A batch size larger than the context window is representable.
        cfg.n_ctx = 1024;
        cfg.n_batch = 2048;
        assert!(cfg.n_ctx < cfg.n_batch);

        for threads in [0, i32::MAX] {
            cfg.n_threads = threads;
            assert_eq!(cfg.n_threads, threads);
        }
    }

    /// `n_seq_max` round-trips from zero up to `u32::MAX`.
    #[test]
    fn test_context_sequence_limits() {
        let mut cfg = ContextParams::default();
        for seqs in [0, 1, 1000, u32::MAX] {
            cfg.n_seq_max = seqs;
            assert_eq!(cfg.n_seq_max, seqs);
        }
    }

    /// Very large context/batch/micro-batch sizes can be stored together.
    #[test]
    fn test_context_memory_constraints() {
        let mut cfg = ContextParams::default();
        cfg.n_ctx = 1_000_000;
        cfg.n_batch = 100_000;
        cfg.n_ubatch = 50_000;
        assert_eq!(
            (cfg.n_ctx, cfg.n_batch, cfg.n_ubatch),
            (1_000_000, 100_000, 50_000)
        );
    }
}
#[cfg(test)]
mod sampling_error_tests {
    use super::*;

    /// Shorthand for building one `TokenData` entry.
    fn token_data(id: TokenId, logit: f32, p: f32) -> TokenData {
        TokenData { id, logit, p }
    }

    /// Out-of-range and non-finite settings written to `SamplerParams` read
    /// back unchanged.
    #[test]
    fn test_sampler_params_edge_cases() {
        let mut sp = SamplerParams::default();

        for temp in [0.0, -1.0, f32::MAX] {
            sp.temperature = temp;
            assert_eq!(sp.temperature, temp);
        }
        sp.temperature = f32::INFINITY;
        assert!(sp.temperature.is_infinite());

        for p in [-0.5, 1.5] {
            sp.top_p = p;
            assert_eq!(sp.top_p, p);
        }

        sp.min_p = 2.0;
        assert_eq!(sp.min_p, 2.0);

        for k in [0, -10] {
            sp.top_k = k;
            assert_eq!(sp.top_k, k);
        }
    }

    /// `LogitBias` stores negative/maximal token ids and non-finite biases.
    #[test]
    fn test_logit_bias_edge_cases() {
        let negative = LogitBias { token: -1, bias: 1.0 };
        assert_eq!(negative.token, -1);

        let at_max = LogitBias { token: TokenId::MAX, bias: 0.5 };
        assert_eq!(at_max.token, TokenId::MAX);

        let infinite = LogitBias { token: 100, bias: f32::INFINITY };
        assert!(infinite.bias.is_infinite());

        let not_a_number = LogitBias { token: 100, bias: f32::NAN };
        assert!(not_a_number.bias.is_nan());
    }

    /// `TokenDataArray` keeps every entry it is given: none, one, extreme or
    /// non-finite values, and duplicate ids.
    #[test]
    fn test_token_data_array_edge_cases() {
        let empty = TokenDataArray::new(vec![]);
        assert!(empty.is_empty());
        assert_eq!(empty.len(), 0);

        let single = TokenDataArray::new(vec![token_data(1, 1.0, 1.0)]);
        assert_eq!(single.len(), 1);
        assert!(!single.is_empty());

        // All five extreme / non-finite entries are kept.
        let extremes = TokenDataArray::new(vec![
            token_data(TokenId::MAX, f32::MAX, 1.0),
            token_data(TokenId::MIN, f32::MIN, 0.0),
            token_data(0, f32::INFINITY, 0.5),
            token_data(1, f32::NEG_INFINITY, 0.25),
            token_data(2, f32::NAN, f32::NAN),
        ]);
        assert_eq!(extremes.len(), 5);

        // Entries sharing an id are both kept.
        let duplicates =
            TokenDataArray::new(vec![token_data(1, 1.0, 0.5), token_data(1, 2.0, 0.5)]);
        assert_eq!(duplicates.len(), 2);
    }

    /// A sampler chain can be created with performance tracking disabled.
    #[test]
    fn test_sampler_chain_edge_cases() {
        let _chain = SamplerChain::new(SamplerChainParams { no_perf: true });
    }
}
#[cfg(test)]
mod batch_error_tests {
    use super::*;

    /// An empty token slice yields an empty batch.
    #[test]
    fn test_batch_with_empty_tokens() {
        let batch = Batch::from_tokens(&[]);
        assert!(batch.is_empty());
    }

    /// Sign-extreme and very large token ids still produce a non-empty batch.
    #[test]
    fn test_batch_with_extreme_tokens() {
        // A fixed-size array suffices; `vec!` would heap-allocate only to be
        // borrowed as a slice (clippy::useless_vec).
        let extreme_tokens = [TokenId::MAX, TokenId::MIN, 0, -1, 1_000_000];
        let batch = Batch::from_tokens(&extreme_tokens);
        assert!(!batch.is_empty());
    }

    /// Ten thousand sequential tokens produce a non-empty batch.
    #[test]
    fn test_batch_with_many_tokens() {
        let many_tokens: Vec<TokenId> = (0..10000).collect();
        let batch = Batch::from_tokens(&many_tokens);
        assert!(!batch.is_empty());
    }

    /// Obtaining the underlying llama batch from a populated `Batch` must not panic.
    #[test]
    fn test_batch_operations_safety() {
        let tokens = [1, 2, 3];
        let batch = Batch::from_tokens(&tokens);
        let _llama_batch = batch.get_llama_batch();
    }
}
#[cfg(test)]
mod session_error_tests {
    use super::*;

    /// A session can wrap an empty byte buffer.
    #[test]
    fn test_session_with_empty_data() {
        let session = Session { data: Vec::new() };
        assert!(session.data.is_empty());
    }

    /// A session can wrap a one-megabyte buffer.
    #[test]
    fn test_session_with_large_data() {
        const SIZE: usize = 1_000_000;
        let session = Session { data: vec![0u8; SIZE] };
        assert_eq!(session.data.len(), SIZE);
    }

    /// Arbitrary bytes (all `0xFF`) are held verbatim; building the struct
    /// literal performs no validation.
    #[test]
    fn test_session_with_invalid_data() {
        let session = Session { data: vec![0xFF; 1000] };
        assert_eq!(session.data.len(), 1000);
        assert!(!session.data.iter().any(|&byte| byte != 0xFF));
    }
}
#[cfg(test)]
mod memory_error_tests {
    use super::*;

    /// A freshly constructed `MemoryManager` reports `is_valid() == false`.
    #[test]
    fn test_memory_manager_initialization() {
        let memory_manager = MemoryManager::new();
        assert!(!memory_manager.is_valid());
    }

    /// Edge cases for `Embeddings::new`: empty input, too little data for a
    /// single vector, and extreme / non-finite values.
    ///
    /// Per the assertions below, `len()` counts complete `dimension`-sized
    /// vectors. NOTE(review): presumably `data.len() / dimension` with a zero
    /// dimension yielding 0 — confirm against `Embeddings::len`.
    #[test]
    fn test_embeddings_edge_cases() {
        let empty_embeddings = Embeddings::new(vec![], 0);
        assert_eq!(empty_embeddings.len(), 0);
        assert_eq!(empty_embeddings.dimension, 0);

        // Three values are fewer than one 5-dimensional vector.
        let mismatched = Embeddings::new(vec![1.0, 2.0, 3.0], 5);
        assert_eq!(mismatched.len(), 0);
        assert_eq!(mismatched.dimension, 5);

        // Five extreme / non-finite values make up exactly one 5-dimensional vector.
        let extreme_embeddings = Embeddings::new(
            vec![
                f32::MAX,
                f32::MIN,
                f32::INFINITY,
                f32::NEG_INFINITY,
                f32::NAN,
            ],
            5,
        );
        assert_eq!(extreme_embeddings.len(), 1);
        assert_eq!(extreme_embeddings.dimension, 5);
    }

    /// `Vocabulary::new()` initialises its placeholder field to zero.
    #[test]
    fn test_vocabulary_initialization() {
        let vocab = Vocabulary::new();
        assert_eq!(vocab._placeholder, 0);
    }
}
#[cfg(test)]
// Smoke tests that call the raw `llama_*` bindings in `mullama::sys` directly,
// bypassing the safe wrappers. They only check that the calls do not crash.
mod ffi_error_tests {
    use mullama::sys;

    /// Cycles backend initialisation and teardown five times in a row.
    #[test]
    fn test_backend_multiple_init_free() {
        // NOTE(review): the test harness runs tests on parallel threads by
        // default. If another test uses the backend while this one calls
        // `llama_backend_free`, the two could race. Confirm this is safe, or
        // run with `--test-threads=1`.
        unsafe {
            // SAFETY: both functions take no arguments, and every init is
            // immediately followed by its matching free. Assumes the C library
            // permits repeated init/free cycles — TODO confirm in its docs.
            for _ in 0..5 {
                sys::llama_backend_init();
                sys::llama_backend_free();
            }
        }
    }

    /// Calls each argument-less capability/introspection query once and
    /// discards the results.
    #[test]
    fn test_system_info_functions() {
        unsafe {
            // SAFETY: none of these functions take arguments. Every result is
            // bound to an unused local and never read, so no returned value is
            // relied upon.
            let _max_devices = sys::llama_max_devices();
            let _max_sequences = sys::llama_max_parallel_sequences();
            let _supports_mmap = sys::llama_supports_mmap();
            let _supports_mlock = sys::llama_supports_mlock();
            let _supports_gpu = sys::llama_supports_gpu_offload();
            let _supports_rpc = sys::llama_supports_rpc();
            let _system_info = sys::llama_print_system_info();
        }
    }
}
#[cfg(test)]
mod thread_safety_tests {
    use super::*;
    use std::sync::{Arc, Barrier};
    use std::thread;

    /// Number of threads each test releases at the same instant.
    const THREADS: usize = 4;

    /// Runs `work` on `THREADS` freshly spawned threads, all held at a shared
    /// barrier so they start together. Then joins every thread, panicking if
    /// any of them panicked.
    fn run_concurrently(work: fn()) {
        let start_line = Arc::new(Barrier::new(THREADS));
        // Collect eagerly: every thread must be spawned before the first join.
        // Otherwise the barrier would never fill and the join would hang.
        let workers: Vec<thread::JoinHandle<()>> = (0..THREADS)
            .map(|_| {
                let start_line = Arc::clone(&start_line);
                thread::spawn(move || {
                    start_line.wait();
                    work();
                })
            })
            .collect();
        for worker in workers {
            worker.join().unwrap();
        }
    }

    /// Default parameter structs can be built from several threads at once.
    #[test]
    fn test_concurrent_parameter_creation() {
        run_concurrently(|| {
            let _model_params = ModelParams::default();
            let _context_params = ContextParams::default();
            let _sampler_params = SamplerParams::default();
        });
    }

    /// Batches, sessions and embeddings can be built from several threads at once.
    #[test]
    fn test_concurrent_structure_creation() {
        run_concurrently(|| {
            let _batch = Batch::from_tokens(&[1, 2, 3]);
            let _session = Session {
                data: vec![1, 2, 3],
            };
            let _embeddings = Embeddings::new(vec![1.0, 2.0, 3.0], 3);
        });
    }
}
#[cfg(test)]
mod resource_exhaustion_tests {
    use super::*;

    /// A one-million-entry `tensor_split` can be stored in `ModelParams`.
    ///
    /// This test uses no `catch_unwind`. Building the `Vec` cannot panic at this
    /// size, and a real allocation failure calls `handle_alloc_error`, which
    /// aborts the process instead of unwinding. A "graceful failure" branch
    /// could therefore never run.
    #[test]
    fn test_large_parameter_structures() {
        let mut params = ModelParams::default();
        params.tensor_split = vec![1.0; 1_000_000];
        assert_eq!(params.tensor_split.len(), 1_000_000);
    }

    /// Ten thousand distinct integer KV overrides can be accumulated.
    ///
    /// As above, the only failure mode here is allocation failure, which aborts
    /// rather than panics, so there is nothing for `catch_unwind` to catch.
    #[test]
    fn test_many_kv_overrides() {
        let mut params = ModelParams::default();
        // Iterate in `i64` directly so no `as` cast is needed for the payload.
        for i in 0..10_000_i64 {
            params.kv_overrides.push(ModelKvOverride {
                key: format!("key_{}", i),
                value: ModelKvOverrideValue::Int(i),
            });
        }
        assert_eq!(params.kv_overrides.len(), 10_000);
    }

    /// Building a batch from one million tokens must not take the process down.
    ///
    /// Either outcome is accepted: the batch is built, or `Batch::from_tokens`
    /// panics and the panic is caught here. Only a hard crash (abort or signal)
    /// fails this test.
    #[test]
    fn test_large_token_arrays() {
        let result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
            let large_tokens: Vec<TokenId> = (0..1_000_000).collect();
            let _batch = Batch::from_tokens(&large_tokens);
        }));
        match result {
            Ok(()) => println!("Large token array handled successfully"),
            Err(_) => println!("Large token array failed gracefully"),
        }
    }
}