#![allow(clippy::missing_errors_doc)]
#![allow(clippy::needless_pass_by_value)]
#![allow(clippy::unreadable_literal)]
#![allow(clippy::case_sensitive_file_extension_comparisons)]
use crate::error::{RealizarError, Result};
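/// Available benchmark suites as `(name, description)` pairs, used by
/// `realizar bench --list` and by [`validate_suite_name`].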
pub const BENCHMARK_SUITES: &[(&str, &str)] = &[
(
"tensor_ops",
"Core tensor operations (add, mul, matmul, softmax)",
),
("inference", "End-to-end inference pipeline benchmarks"),
("cache", "KV cache operations and memory management"),
("tokenizer", "BPE and SentencePiece tokenization"),
("quantize", "Quantization/dequantization (Q4_0, Q8_0)"),
("lambda", "AWS Lambda cold start and warm invocation"),
(
"comparative",
"Framework comparison (MNIST, CIFAR-10, Iris)",
),
];
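/// Formats a byte count as a human-readable string using binary units
/// (1 KB = 1024 bytes); for example, `format_size(1536)` returns `"1.5 KB"`.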
pub fn format_size(bytes: u64) -> String {
const KB: u64 = 1024;
const MB: u64 = KB * 1024;
const GB: u64 = MB * 1024;
if bytes >= GB {
format!("{:.1} GB", bytes as f64 / GB as f64)
} else if bytes >= MB {
format!("{:.1} MB", bytes as f64 / MB as f64)
} else if bytes >= KB {
format!("{:.1} KB", bytes as f64 / KB as f64)
} else {
format!("{bytes} B")
}
}
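/// Prints a format and tensor/metadata summary for a model file, detecting
/// GGUF, SafeTensors, or APR by file extension or magic bytes.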
pub fn display_model_info(model_ref: &str, file_data: &[u8]) -> Result<()> {
use crate::format::{APR_MAGIC, GGUF_MAGIC};
if model_ref.ends_with(".gguf") || file_data.starts_with(GGUF_MAGIC) {
use crate::gguf::GGUFModel;
let gguf = GGUFModel::from_bytes(file_data)?;
println!(" Format: GGUF v{}", gguf.header.version);
println!(" Tensors: {}", gguf.header.tensor_count);
} else if model_ref.ends_with(".safetensors") {
use crate::safetensors::SafetensorsModel;
let st = SafetensorsModel::from_bytes(file_data)?;
println!(" Format: SafeTensors");
println!(" Tensors: {}", st.tensors.len());
} else if model_ref.ends_with(".apr") || file_data.starts_with(APR_MAGIC) {
use crate::model_loader::read_apr_model_type;
let model_type = read_apr_model_type(file_data).unwrap_or_else(|| "Unknown".to_string());
println!(" Format: APR (Aprender Native)");
println!(" Model Type: {model_type}");
} else {
println!(" Format: Unknown ({} bytes)", file_data.len());
}
Ok(())
}
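/// Demonstrates the benchmark visualization helpers (sparkline, ASCII
/// histogram, full report) on deterministic pseudo-random latency samples.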
pub fn run_visualization(use_color: bool, samples: usize) {
use crate::viz::{
print_benchmark_results, render_ascii_histogram, render_sparkline, BenchmarkData,
};
println!("Realizar Benchmark Visualization Demo");
println!("=====================================");
println!();
    // Deterministic LCG (fixed seed) producing synthetic latency samples for the demo.
    let mut rng_state = 42u64;
let latencies: Vec<f64> = (0..samples)
.map(|_| {
rng_state = rng_state.wrapping_mul(6364136223846793005).wrapping_add(1);
let uniform = (rng_state >> 33) as f64 / (1u64 << 31) as f64;
            let log_mean = 3.0;
            let log_std = 0.5;
(log_mean + log_std * (2.0 * uniform - 1.0)).exp()
})
.collect();
println!("1. Sparkline (latency trend)");
println!(" {}", render_sparkline(&latencies, 60));
println!();
println!("2. ASCII Histogram (latency distribution)");
let hist = render_ascii_histogram(&latencies, 12, 50);
for line in hist.lines() {
println!(" {line}");
}
println!();
println!("3. Full Benchmark Report");
let data = BenchmarkData::new("inference_latency", latencies);
print_benchmark_results(&data, use_color);
println!();
println!("4. Multi-Benchmark Comparison");
println!();
let benchmarks = [
("tensor_add", 15.2, 18.1),
("tensor_mul", 16.8, 20.3),
("matmul_128", 145.3, 172.1),
("softmax", 23.4, 28.9),
("attention", 892.1, 1024.5),
];
println!(
" {:.<20} {:>10} {:>10} {:>10}",
"Benchmark", "p50 (us)", "p99 (us)", "Trend"
);
println!(" {}", "-".repeat(55));
for (name, p50, p99) in benchmarks {
let trend: Vec<f64> = (0..20)
.map(|i| p50 + (i as f64 / 20.0) * (p99 - p50) * 0.3)
.collect();
let sparkline = render_sparkline(&trend, 10);
println!(" {:.<20} {:>10.1} {:>10.1} {}", name, p50, p99, sparkline);
}
println!();
println!("Visualization powered by trueno-viz");
}
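/// Entry point for `realizar bench`: lists suites, delegates to an external
/// runtime benchmark when both `--runtime` and `--url` are given, or runs
/// `cargo bench` locally and optionally writes JSON results to `output`.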
pub fn run_benchmarks(
suite: Option<String>,
list: bool,
runtime: Option<String>,
model: Option<String>,
url: Option<String>,
output: Option<String>,
) -> Result<()> {
if list {
println!("Available benchmark suites:");
println!();
for (name, description) in BENCHMARK_SUITES {
println!(" {name:<12} - {description}");
}
println!();
println!("Usage:");
println!(" realizar bench # Run all benchmarks");
println!(" realizar bench tensor_ops # Run specific suite");
println!(" realizar bench --list # List available suites");
println!(" realizar bench --runtime realizar # Specify runtime");
println!(" realizar bench --output results.json # Save JSON results");
println!();
println!("External Runtime Benchmarking (REAL HTTP calls):");
println!(" realizar bench --runtime ollama --url http://localhost:11434 --model llama3.2");
println!(" realizar bench --runtime vllm --url http://localhost:8000 --model meta-llama/Llama-3.2-1B");
println!(" realizar bench --runtime llama-cpp --url http://localhost:8080");
return Ok(());
}
let runtime_name = runtime.clone().unwrap_or_else(|| "realizar".to_string());
println!("Benchmark Configuration:");
println!(" Runtime: {runtime_name}");
if let Some(ref m) = model {
println!(" Model: {m}");
}
if let Some(ref u) = url {
println!(" URL: {u}");
}
if let Some(ref o) = output {
println!(" Output: {o}");
}
println!();
if let (Some(ref rt), Some(ref server_url)) = (&runtime, &url) {
return run_external_benchmark(rt, server_url, model.as_deref(), output.as_deref());
}
let mut cmd = std::process::Command::new("cargo");
cmd.arg("bench");
if let Some(ref suite_name) = suite {
if !BENCHMARK_SUITES.iter().any(|(name, _)| *name == suite_name) {
eprintln!("Error: Unknown benchmark suite '{suite_name}'");
eprintln!();
eprintln!("Available suites:");
for (name, _) in BENCHMARK_SUITES {
eprintln!(" {name}");
}
std::process::exit(1);
}
cmd.arg("--bench").arg(suite_name);
}
println!("Running benchmarks...");
println!();
let bench_output = if output.is_some() {
cmd.output()
.map_err(|e| RealizarError::UnsupportedOperation {
operation: "run_benchmarks".to_string(),
reason: format!("Failed to execute cargo bench: {e}"),
})?
} else {
let status = cmd
.status()
.map_err(|e| RealizarError::UnsupportedOperation {
operation: "run_benchmarks".to_string(),
reason: format!("Failed to execute cargo bench: {e}"),
})?;
if !status.success() {
return Err(RealizarError::UnsupportedOperation {
operation: "run_benchmarks".to_string(),
reason: format!("Benchmarks failed with exit code: {:?}", status.code()),
});
}
return Ok(());
};
if !bench_output.status.success() {
eprintln!("{}", String::from_utf8_lossy(&bench_output.stderr));
return Err(RealizarError::UnsupportedOperation {
operation: "run_benchmarks".to_string(),
reason: format!(
"Benchmarks failed with exit code: {:?}",
bench_output.status.code()
),
});
}
let stdout = String::from_utf8_lossy(&bench_output.stdout);
print!("{stdout}");
if let Some(ref output_path) = output {
use std::fs::File;
use std::io::Write;
let timestamp = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.map(|d| d.as_secs())
.unwrap_or(0);
let results = parse_cargo_bench_output(&stdout, suite.as_deref());
let json_output = serde_json::json!({
"version": "1.0",
"timestamp": timestamp,
"runtime": runtime.clone().unwrap_or_else(|| "realizar".to_string()),
"suite": suite,
"model": model,
"results": results,
"raw_output": stdout
});
let mut file = File::create(output_path).map_err(|e| RealizarError::IoError {
message: format!("Failed to create output file {output_path}: {e}"),
})?;
file.write_all(
serde_json::to_string_pretty(&json_output)
.unwrap()
.as_bytes(),
)
.map_err(|e| RealizarError::IoError {
message: format!("Failed to write to output file {output_path}: {e}"),
})?;
println!();
println!("Benchmark results written to: {output_path}");
}
Ok(())
}
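/// Extracts `{name, time_ns, suite}` records from `cargo bench` text output
/// by scanning for lines containing `bench:` and `ns/iter`.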
fn parse_cargo_bench_output(output: &str, suite: Option<&str>) -> Vec<serde_json::Value> {
let mut results = Vec::new();
for line in output.lines() {
if line.contains("bench:") && line.contains("ns/iter") {
let parts: Vec<&str> = line.split_whitespace().collect();
if parts.len() >= 5 {
if let Some(test_idx) = parts.iter().position(|&p| p == "test") {
if let Some(name) = parts.get(test_idx + 1) {
if let Some(bench_idx) = parts.iter().position(|&p| p == "bench:") {
if let Some(time_str) = parts.get(bench_idx + 1) {
if let Ok(time_ns) = time_str.replace(',', "").parse::<u64>() {
results.push(serde_json::json!({
"name": name,
"time_ns": time_ns,
"suite": suite
}));
}
}
}
}
}
}
}
}
results
}
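/// Benchmarks an external runtime (Ollama, vLLM, or llama.cpp) over HTTP,
/// reporting per-iteration TTFT, end-to-end latency, and tokens/sec.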
#[cfg(feature = "bench-http")]
fn run_external_benchmark(
runtime: &str,
url: &str,
model: Option<&str>,
output: Option<&str>,
) -> Result<()> {
use crate::http_client::{CompletionRequest, ModelHttpClient, OllamaOptions, OllamaRequest};
use std::time::Instant;
println!("=== External Runtime Benchmark (REAL HTTP) ===");
println!();
println!("This measures ACTUAL inference latency from {url}");
println!("NO MOCK DATA - real network + inference timing");
println!();
let client = ModelHttpClient::new();
let prompt = "Explain the concept of machine learning in one sentence.";
let num_iterations = 5;
let mut latencies: Vec<f64> = Vec::with_capacity(num_iterations);
let mut tokens_per_sec: Vec<f64> = Vec::with_capacity(num_iterations);
println!("Running {num_iterations} inference iterations...");
println!("Prompt: \"{prompt}\"");
println!();
for i in 0..num_iterations {
let start = Instant::now();
let timing = match runtime.to_lowercase().as_str() {
"ollama" => {
let model_name = model.unwrap_or("llama3.2");
let request = OllamaRequest {
model: model_name.to_string(),
prompt: prompt.to_string(),
stream: false,
options: Some(OllamaOptions {
num_predict: Some(50),
temperature: Some(0.7),
}),
};
client
.ollama_generate(url, &request)
.map_err(|e| RealizarError::ConnectionError(e.to_string()))?
},
"vllm" => {
let model_name = model.unwrap_or("default");
let request = CompletionRequest {
model: model_name.to_string(),
prompt: prompt.to_string(),
max_tokens: 50,
temperature: Some(0.7),
stream: false,
};
client
.openai_completion(url, &request, None)
.map_err(|e| RealizarError::ConnectionError(e.to_string()))?
},
"llama-cpp" => {
let request = CompletionRequest {
model: "default".to_string(),
prompt: prompt.to_string(),
max_tokens: 50,
temperature: Some(0.7),
stream: false,
};
client
.llamacpp_completion(url, &request)
.map_err(|e| RealizarError::ConnectionError(e.to_string()))?
},
_ => {
return Err(RealizarError::UnsupportedOperation {
operation: "external_benchmark".to_string(),
reason: format!(
"Unknown runtime: {}. Supported: ollama, vllm, llama-cpp",
runtime
),
});
},
};
let elapsed = start.elapsed();
let latency_ms = elapsed.as_secs_f64() * 1000.0;
latencies.push(latency_ms);
if timing.tokens_generated > 0 {
let tps = timing.tokens_generated as f64 / elapsed.as_secs_f64();
tokens_per_sec.push(tps);
}
println!(
" [{}/{}] TTFT: {:.0}ms, Inference: {:.0}ms, Tokens: {}, E2E: {:.0}ms",
i + 1,
num_iterations,
timing.ttft_ms,
timing.total_time_ms,
timing.tokens_generated,
latency_ms
);
}
latencies.sort_by(|a, b| a.partial_cmp(b).unwrap());
let p50 = latencies[latencies.len() / 2];
let p99_idx = (latencies.len() as f64 * 0.99) as usize;
let p99 = latencies[p99_idx.min(latencies.len() - 1)];
let mean: f64 = latencies.iter().sum::<f64>() / latencies.len() as f64;
let avg_tps = if tokens_per_sec.is_empty() {
0.0
} else {
tokens_per_sec.iter().sum::<f64>() / tokens_per_sec.len() as f64
};
println!();
println!("=== Results ===");
println!(" Runtime: {runtime}");
println!(" URL: {url}");
println!(" Model: {}", model.unwrap_or("default"));
println!(" Iterations: {num_iterations}");
println!();
println!(" Latency (ms):");
println!(" Mean: {mean:.1}");
println!(" p50: {p50:.1}");
println!(" p99: {p99:.1}");
println!();
println!(" Throughput: {avg_tps:.1} tokens/sec");
if let Some(output_path) = output {
let result = serde_json::json!({
"runtime": runtime,
"url": url,
"model": model.unwrap_or("default"),
"iterations": num_iterations,
"latency_ms": {
"mean": mean,
"p50": p50,
"p99": p99,
"samples": latencies,
},
"throughput_tokens_per_sec": avg_tps,
});
if let Ok(json) = serde_json::to_string_pretty(&result) {
let _ = std::fs::write(output_path, json);
println!();
println!("Results saved to: {output_path}");
}
}
Ok(())
}
#[cfg(not(feature = "bench-http"))]
fn run_external_benchmark(
runtime: &str,
url: &str,
_model: Option<&str>,
_output: Option<&str>,
) -> Result<()> {
Err(RealizarError::UnsupportedOperation {
operation: "external_benchmark".to_string(),
reason: format!(
"External runtime benchmarking requires the 'bench-http' feature.\n\
Run with: cargo build --features bench-http\n\
Then: realizar bench --runtime {} --url {}",
runtime, url
),
})
}
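/// Runs the convoy test (continuous batching validation) against simulated
/// latency data and reports p99 increase, HOL blocking, and KV fragmentation.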
pub fn run_convoy_test(
runtime: Option<String>,
model: Option<String>,
output: Option<String>,
) -> Result<()> {
use crate::bench::{ConvoyTestConfig, ConvoyTestResult};
let runtime_name = runtime.unwrap_or_else(|| "realizar".to_string());
println!("=== Convoy Test (Continuous Batching Validation) ===");
println!();
println!("Configuration:");
println!(" Runtime: {runtime_name}");
if let Some(ref m) = model {
println!(" Model: {m}");
}
println!();
let config = ConvoyTestConfig::default();
println!("Test Parameters:");
println!(" Long-context requests: {}", config.long_requests);
println!(" Short-QA requests: {}", config.short_requests);
println!(" Max p99 increase: {}%", config.max_p99_increase_pct);
println!(" Max HOL blocking: {}ms", config.max_hol_blocking_ms);
println!(
" Max KV fragmentation: {}%",
config.max_kv_fragmentation_pct
);
println!();
    // Simulated workload data (deterministic synthetic samples, not real measurements).
    let baseline_latencies: Vec<f64> = (0..100).map(|i| 45.0 + (i as f64) * 0.1).collect();
    let convoy_latencies: Vec<f64> = (0..100).map(|i| 60.0 + (i as f64) * 0.15).collect();
    let hol_blocking_times: Vec<f64> = vec![80.0, 120.0, 95.0, 110.0, 85.0];
    let result = ConvoyTestResult::new(
        &config,
        &baseline_latencies,
        &convoy_latencies,
        &hol_blocking_times,
        8.5, // simulated KV fragmentation percentage
    );
println!("Results:");
println!(" Baseline p99: {:.1}ms", result.baseline_short_p99_ms);
println!(" Convoy p99: {:.1}ms", result.convoy_short_p99_ms);
println!(" p99 increase: {:.1}%", result.p99_increase_pct);
println!(" Max HOL blocking: {:.1}ms", result.max_hol_blocking_ms);
println!(" Avg HOL blocking: {:.1}ms", result.avg_hol_blocking_ms);
println!(" KV fragmentation: {:.1}%", result.kv_fragmentation_pct);
println!();
if result.passed {
println!("CONVOY TEST PASSED");
} else {
println!("CONVOY TEST FAILED");
for failure in &result.failure_reasons {
println!(" - {failure}");
}
}
if let Some(ref output_path) = output {
if let Ok(json) = serde_json::to_string_pretty(&result) {
let _ = std::fs::write(output_path, json);
println!();
println!("Results saved to: {output_path}");
}
}
Ok(())
}
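/// Runs the saturation stress test against simulated throughput and latency
/// samples and reports throughput degradation and p99 increase.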
pub fn run_saturation_test(
runtime: Option<String>,
model: Option<String>,
output: Option<String>,
) -> Result<()> {
use crate::bench::{SaturationTestConfig, SaturationTestResult};
let runtime_name = runtime.unwrap_or_else(|| "realizar".to_string());
println!("=== Saturation Stress Test ===");
println!();
println!("Configuration:");
println!(" Runtime: {runtime_name}");
if let Some(ref m) = model {
println!(" Model: {m}");
}
println!();
let config = SaturationTestConfig::default();
println!("Test Parameters:");
println!(" CPU load target: {}%", config.cpu_load_pct);
println!(
" Max throughput degradation: {}%",
config.max_throughput_degradation_pct
);
println!(" Max p99 increase: {}%", config.max_p99_increase_pct);
println!();
    // Simulated throughput and latency samples (synthetic data, not real measurements).
    let baseline_throughputs: Vec<f64> = (0..50).map(|i| 95.0 + (i as f64) * 0.2).collect();
    let stressed_throughputs: Vec<f64> = (0..50).map(|i| 78.0 + (i as f64) * 0.15).collect();
    let baseline_latencies: Vec<f64> = (0..100).map(|i| 45.0 + (i as f64) * 0.1).collect();
    let stressed_latencies: Vec<f64> = (0..100).map(|i| 75.0 + (i as f64) * 0.2).collect();
let result = SaturationTestResult::new(
&config,
&baseline_throughputs,
&stressed_throughputs,
&baseline_latencies,
&stressed_latencies,
);
println!("Results:");
println!(
" Baseline throughput: {:.1} tok/s",
result.baseline_throughput
);
println!(
" Stressed throughput: {:.1} tok/s",
result.stressed_throughput
);
println!(
" Throughput degradation: {:.1}%",
result.throughput_degradation_pct
);
println!(" Baseline p99: {:.1}ms", result.baseline_p99_ms);
println!(" Stressed p99: {:.1}ms", result.stressed_p99_ms);
println!(" P99 increase: {:.1}%", result.p99_increase_pct);
println!();
if result.passed {
println!("SATURATION TEST PASSED");
} else {
println!("SATURATION TEST FAILED");
for failure in &result.failure_reasons {
println!(" - {failure}");
}
}
if let Some(ref output_path) = output {
if let Ok(json) = serde_json::to_string_pretty(&result) {
let _ = std::fs::write(output_path, json);
println!();
println!("Results saved to: {output_path}");
}
}
Ok(())
}
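/// Compares two JSON benchmark result files and reports per-metric changes,
/// flagging differences larger than `threshold` percent as significant.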
pub fn run_bench_compare(file1: &str, file2: &str, threshold: f64) -> Result<()> {
use crate::bench::{BenchmarkComparison, FullBenchmarkResult};
println!("=== Benchmark Comparison ===");
println!();
println!("File 1: {file1}");
println!("File 2: {file2}");
println!("Significance threshold: {threshold}%");
println!();
let json1 =
std::fs::read_to_string(file1).map_err(|e| RealizarError::UnsupportedOperation {
operation: "read_benchmark".to_string(),
reason: format!("Failed to read {file1}: {e}"),
})?;
let json2 =
std::fs::read_to_string(file2).map_err(|e| RealizarError::UnsupportedOperation {
operation: "read_benchmark".to_string(),
reason: format!("Failed to read {file2}: {e}"),
})?;
let result1 = FullBenchmarkResult::from_json(&json1).map_err(|e| {
RealizarError::UnsupportedOperation {
operation: "parse_benchmark".to_string(),
reason: format!("Failed to parse {file1}: {e}"),
}
})?;
let result2 = FullBenchmarkResult::from_json(&json2).map_err(|e| {
RealizarError::UnsupportedOperation {
operation: "parse_benchmark".to_string(),
reason: format!("Failed to parse {file2}: {e}"),
}
})?;
let comparison = BenchmarkComparison::compare(&result1, &result2);
println!("Comparison Results:");
println!(" TTFT p99: {:.1}% change", comparison.ttft_p99_change_pct);
println!(
" Throughput: {:.1}% change",
comparison.throughput_change_pct
);
println!(" Memory: {:.1}% change", comparison.memory_change_pct);
println!(" Energy: {:.1}% change", comparison.energy_change_pct);
println!();
println!("Winner: {}", comparison.winner);
println!("Significance (p-value): {:.4}", comparison.significance);
let ttft_significant = comparison.ttft_p99_change_pct.abs() > threshold;
let throughput_significant = comparison.throughput_change_pct.abs() > threshold;
println!();
if ttft_significant || throughput_significant {
println!("Significant differences detected (>{threshold}%)");
} else {
println!("No significant differences (threshold: {threshold}%)");
}
Ok(())
}
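/// Checks `current_path` against `baseline_path` for performance regressions;
/// `strict` uses a 0% threshold, otherwise 10%. Returns an error when a
/// regression is detected.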
pub fn run_bench_regression(baseline_path: &str, current_path: &str, strict: bool) -> Result<()> {
use crate::bench::{FullBenchmarkResult, RegressionResult};
let threshold = if strict { 0.0 } else { 10.0 };
println!("=== Regression Detection ===");
println!();
println!("Baseline: {baseline_path}");
println!("Current: {current_path}");
println!(
"Mode: {}",
if strict {
"strict (0%)"
} else {
"normal (10%)"
}
);
println!("Threshold: {threshold}%");
println!();
let baseline_json = std::fs::read_to_string(baseline_path).map_err(|e| {
RealizarError::UnsupportedOperation {
operation: "read_baseline".to_string(),
reason: format!("Failed to read {baseline_path}: {e}"),
}
})?;
let current_json =
std::fs::read_to_string(current_path).map_err(|e| RealizarError::UnsupportedOperation {
operation: "read_current".to_string(),
reason: format!("Failed to read {current_path}: {e}"),
})?;
let baseline = FullBenchmarkResult::from_json(&baseline_json).map_err(|e| {
RealizarError::UnsupportedOperation {
operation: "parse_baseline".to_string(),
reason: format!("Failed to parse {baseline_path}: {e}"),
}
})?;
let current = FullBenchmarkResult::from_json(&current_json).map_err(|e| {
RealizarError::UnsupportedOperation {
operation: "parse_current".to_string(),
reason: format!("Failed to parse {current_path}: {e}"),
}
})?;
let regression = RegressionResult::check(&baseline, &current, threshold);
println!("Regression Analysis:");
println!(" Threshold: {:.1}%", regression.threshold_pct);
println!(" Regression detected: {}", regression.regression_detected);
if !regression.regressed_metrics.is_empty() {
println!(" Regressed metrics:");
for metric in &regression.regressed_metrics {
println!(" - {metric}");
}
}
println!();
if regression.regression_detected {
println!("REGRESSION DETECTED");
return Err(RealizarError::UnsupportedOperation {
operation: "regression_check".to_string(),
reason: "Performance regression detected".to_string(),
});
}
println!("NO REGRESSION DETECTED");
Ok(())
}
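/// Prints the Realizar version and a summary of supported features.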
pub fn print_info() {
println!("Realizar v{}", crate::VERSION);
println!("Pure Rust ML inference engine");
println!();
println!("Features:");
println!(" - GGUF and Safetensors model formats");
println!(" - Transformer inference (LLaMA architecture)");
println!(" - BPE and SentencePiece tokenizers");
println!(" - Greedy, top-k, and top-p sampling");
println!(" - REST API for inference");
}
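/// Parses a GGUF file from raw bytes and prints header, metadata, and tensor
/// summaries.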
pub fn load_gguf_model(file_data: &[u8]) -> Result<()> {
use crate::gguf::GGUFModel;
println!("Parsing GGUF file...");
let gguf = GGUFModel::from_bytes(file_data)?;
println!("Successfully parsed GGUF file");
println!();
println!("Model Information:");
println!(" Version: {}", gguf.header.version);
println!(" Tensors: {}", gguf.header.tensor_count);
println!(" Metadata entries: {}", gguf.header.metadata_count);
println!();
if !gguf.metadata.is_empty() {
println!("Metadata (first 5 entries):");
for (key, _value) in gguf.metadata.iter().take(5) {
println!(" - {key}");
}
if gguf.metadata.len() > 5 {
println!(" ... and {} more", gguf.metadata.len() - 5);
}
println!();
}
if !gguf.tensors.is_empty() {
println!("Tensors (first 10):");
for tensor in gguf.tensors.iter().take(10) {
let dims: Vec<String> = tensor
.dims
.iter()
.map(std::string::ToString::to_string)
.collect();
println!(
" - {} [{}, qtype={}]",
tensor.name,
dims.join("x"),
tensor.qtype
);
}
if gguf.tensors.len() > 10 {
println!(" ... and {} more", gguf.tensors.len() - 10);
}
println!();
}
println!("Model loading infrastructure is ready!");
println!();
println!("Next steps to complete model loading:");
println!(" 1. Extract ModelConfig from metadata (vocab_size, hidden_dim, etc.)");
println!(" 2. Map tensor names to Model layers (see src/layers.rs docs)");
println!(" 3. Load weights into each layer");
println!();
println!("See documentation: cargo doc --open");
println!("Example: src/layers.rs module documentation");
Ok(())
}
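/// Parses a Safetensors file from raw bytes and prints tensor summaries.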
pub fn load_safetensors_model(file_data: &[u8]) -> Result<()> {
use crate::safetensors::SafetensorsModel;
println!("Parsing Safetensors file...");
let safetensors = SafetensorsModel::from_bytes(file_data)?;
println!("Successfully parsed Safetensors file");
println!();
println!("Model Information:");
println!(" Tensors: {}", safetensors.tensors.len());
println!(" Data size: {} bytes", safetensors.data.len());
println!();
if !safetensors.tensors.is_empty() {
println!("Tensors (first 10):");
for (name, tensor_info) in safetensors.tensors.iter().take(10) {
let shape: Vec<String> = tensor_info
.shape
.iter()
.map(std::string::ToString::to_string)
.collect();
println!(
" - {} [{}, dtype={:?}]",
name,
shape.join("x"),
tensor_info.dtype
);
}
if safetensors.tensors.len() > 10 {
println!(" ... and {} more", safetensors.tensors.len() - 10);
}
println!();
}
println!("Model loading infrastructure is ready!");
println!();
println!("Next steps to complete model loading:");
println!(" 1. Extract ModelConfig from tensor shapes");
println!(" 2. Map tensor names to Model layers (see src/layers.rs docs)");
println!(" 3. Load weights into each layer");
println!();
println!("See documentation: cargo doc --open");
println!("Example: src/layers.rs module documentation");
Ok(())
}
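/// Parses an APR (Aprender Native) file from raw bytes, verifies its magic,
/// and prints the model type and header version.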
pub fn load_apr_model(file_data: &[u8]) -> Result<()> {
use crate::format::{detect_format, ModelFormat};
use crate::model_loader::read_apr_model_type;
println!("Parsing APR file...");
let format = detect_format(file_data).map_err(|e| RealizarError::UnsupportedOperation {
operation: "detect_apr_format".to_string(),
reason: format!("Format detection failed: {e}"),
})?;
if format != ModelFormat::Apr {
return Err(RealizarError::UnsupportedOperation {
operation: "verify_apr_magic".to_string(),
reason: format!("Expected APR format, got {format}"),
});
}
let model_type = read_apr_model_type(file_data).unwrap_or_else(|| "Unknown".to_string());
println!("Successfully parsed APR file");
println!();
println!("Model Information:");
println!(" Format: APR (Aprender Native)");
println!(" Model Type: {model_type}");
println!(" File Size: {} bytes", file_data.len());
println!();
if file_data.len() >= 8 {
let version = u16::from_le_bytes([file_data[6], file_data[7]]);
println!(" Header Version: {version}");
}
println!();
println!("APR model ready for serving!");
println!("Supported model types for inference:");
println!(" - LogisticRegression, LinearRegression");
println!(" - DecisionTree, RandomForest, GradientBoosting");
println!(" - KNN, GaussianNB, LinearSVM");
println!();
println!("To serve this model, the serve API will auto-detect");
println!("the model type and dispatch to the appropriate handler.");
Ok(())
}
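/// Returns `true` if `model_ref` looks like a local file path (relative or
/// absolute prefix, or a known model file extension) rather than a registry
/// reference such as `llama3:8b` or `hf://...`.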
pub fn is_local_file_path(model_ref: &str) -> bool {
model_ref.starts_with("./")
|| model_ref.starts_with('/')
|| model_ref.ends_with(".gguf")
|| model_ref.ends_with(".safetensors")
|| model_ref.ends_with(".apr")
}
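/// Returns the user's home directory from the `HOME` environment variable.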
pub fn home_dir() -> Option<std::path::PathBuf> {
std::env::var_os("HOME").map(std::path::PathBuf::from)
}
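/// Returns `true` if `suite_name` matches an entry in [`BENCHMARK_SUITES`].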
pub fn validate_suite_name(suite_name: &str) -> bool {
BENCHMARK_SUITES.iter().any(|(name, _)| *name == suite_name)
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_format_size_bytes() {
assert_eq!(format_size(0), "0 B");
assert_eq!(format_size(500), "500 B");
assert_eq!(format_size(1023), "1023 B");
}
#[test]
fn test_format_size_kb() {
assert_eq!(format_size(1024), "1.0 KB");
assert_eq!(format_size(1536), "1.5 KB");
assert_eq!(format_size(10 * 1024), "10.0 KB");
}
#[test]
fn test_format_size_mb() {
assert_eq!(format_size(1024 * 1024), "1.0 MB");
assert_eq!(format_size(10 * 1024 * 1024), "10.0 MB");
assert_eq!(format_size(512 * 1024 * 1024), "512.0 MB");
}
#[test]
fn test_format_size_gb() {
assert_eq!(format_size(1024 * 1024 * 1024), "1.0 GB");
assert_eq!(format_size(7 * 1024 * 1024 * 1024), "7.0 GB");
}
#[test]
fn test_benchmark_suites_not_empty() {
let suites_len = BENCHMARK_SUITES.len();
assert!(suites_len > 0, "BENCHMARK_SUITES should not be empty");
assert!(suites_len >= 5, "Should have at least 5 benchmark suites");
}
#[test]
fn test_benchmark_suites_have_descriptions() {
for (name, description) in BENCHMARK_SUITES {
assert!(!name.is_empty(), "Benchmark name should not be empty");
assert!(
!description.is_empty(),
"Benchmark description should not be empty"
);
}
}
#[test]
fn test_validate_suite_name_valid() {
assert!(validate_suite_name("tensor_ops"));
assert!(validate_suite_name("inference"));
assert!(validate_suite_name("cache"));
assert!(validate_suite_name("tokenizer"));
assert!(validate_suite_name("quantize"));
assert!(validate_suite_name("lambda"));
assert!(validate_suite_name("comparative"));
}
#[test]
fn test_validate_suite_name_invalid() {
assert!(!validate_suite_name("unknown"));
assert!(!validate_suite_name(""));
assert!(!validate_suite_name("tensor"));
assert!(!validate_suite_name("TENSOR_OPS"));
}
#[test]
fn test_is_local_file_path_true() {
assert!(is_local_file_path("./model.gguf"));
assert!(is_local_file_path("/home/user/model.gguf"));
assert!(is_local_file_path("model.gguf"));
assert!(is_local_file_path("model.safetensors"));
assert!(is_local_file_path("model.apr"));
}
#[test]
fn test_is_local_file_path_false() {
assert!(!is_local_file_path("llama3:8b"));
assert!(!is_local_file_path("pacha://model:v1"));
assert!(!is_local_file_path("hf://meta-llama/Llama-3"));
}
#[test]
fn test_home_dir_returns_path() {
let home = home_dir();
if let Some(path) = home {
assert!(path.is_absolute() || path.to_string_lossy().starts_with('/'));
}
}
#[test]
fn test_display_model_info_unknown_format() {
let result = display_model_info("model.bin", &[]);
assert!(result.is_ok());
}
#[test]
fn test_display_model_info_gguf_extension_but_invalid() {
let result = display_model_info("model.gguf", &[0, 1, 2, 3]);
assert!(result.is_err());
}
#[test]
fn test_print_info_does_not_panic() {
print_info();
}
#[test]
fn test_run_visualization_does_not_panic() {
run_visualization(false, 10);
run_visualization(true, 10);
}
#[test]
fn test_load_gguf_model_invalid() {
let result = load_gguf_model(&[0, 1, 2, 3]);
assert!(result.is_err());
}
#[test]
fn test_load_safetensors_model_invalid() {
let result = load_safetensors_model(&[0, 1, 2, 3]);
assert!(result.is_err());
}
#[test]
fn test_run_convoy_test_no_output() {
let result = run_convoy_test(Some("test".to_string()), None, None);
assert!(result.is_ok());
}
#[test]
fn test_run_saturation_test_no_output() {
let result = run_saturation_test(Some("test".to_string()), None, None);
assert!(result.is_ok());
}
#[test]
fn test_run_bench_compare_file_not_found() {
let result = run_bench_compare("/nonexistent/file1.json", "/nonexistent/file2.json", 5.0);
assert!(result.is_err());
}
#[test]
fn test_run_bench_regression_file_not_found() {
let result = run_bench_regression(
"/nonexistent/baseline.json",
"/nonexistent/current.json",
false,
);
assert!(result.is_err());
}
#[test]
fn test_run_benchmarks_list_mode() {
let result = run_benchmarks(None, true, None, None, None, None);
assert!(result.is_ok());
}
#[test]
fn test_run_benchmarks_with_runtime() {
let result = run_benchmarks(None, true, Some("realizar".to_string()), None, None, None);
assert!(result.is_ok());
}
#[test]
fn test_run_benchmarks_with_model() {
let result = run_benchmarks(None, true, None, Some("model.gguf".to_string()), None, None);
assert!(result.is_ok());
}
#[test]
fn test_run_benchmarks_with_output() {
let result = run_benchmarks(
None,
true,
None,
None,
None,
Some("/tmp/output.json".to_string()),
);
assert!(result.is_ok());
}
#[test]
fn test_run_bench_compare_second_file_not_found() {
use std::io::Write;
let dir = std::env::temp_dir();
let file1 = dir.join("bench_compare_one.json");
let mut f1 = std::fs::File::create(&file1).unwrap();
f1.write_all(b"{}").unwrap();
let result = run_bench_compare(file1.to_str().unwrap(), "/nonexistent/file2.json", 5.0);
let _ = std::fs::remove_file(&file1);
assert!(result.is_err());
}
#[test]
fn test_run_bench_regression_second_file_not_found() {
use std::io::Write;
let dir = std::env::temp_dir();
let baseline = dir.join("bench_regress_base.json");
let mut f1 = std::fs::File::create(&baseline).unwrap();
f1.write_all(b"{}").unwrap();
let result = run_bench_regression(
baseline.to_str().unwrap(),
"/nonexistent/current.json",
false,
);
let _ = std::fs::remove_file(&baseline);
assert!(result.is_err());
}
#[test]
fn test_run_convoy_test_with_output() {
let dir = std::env::temp_dir();
let output = dir.join("convoy_output.json");
let result = run_convoy_test(
Some("test".to_string()),
Some("model.gguf".to_string()),
Some(output.to_str().unwrap().to_string()),
);
assert!(result.is_ok());
assert!(output.exists());
let _ = std::fs::remove_file(&output);
}
#[test]
fn test_run_saturation_test_with_output() {
let dir = std::env::temp_dir();
let output = dir.join("saturation_output.json");
let result = run_saturation_test(
Some("test".to_string()),
Some("model.gguf".to_string()),
Some(output.to_str().unwrap().to_string()),
);
assert!(result.is_ok());
assert!(output.exists());
let _ = std::fs::remove_file(&output);
}
#[test]
fn test_display_model_info_with_gguf_magic() {
let data = b"GGUF\x03\x00\x00\x00"; let result = display_model_info("test.gguf", data);
assert!(result.is_err());
}
#[test]
fn test_display_model_info_safetensors_extension() {
let result = display_model_info("test.safetensors", &[0, 1, 2, 3]);
assert!(result.is_err());
}
#[test]
fn test_display_model_info_apr_extension() {
let mut data = vec![0u8; 16];
data[0..4].copy_from_slice(b"APRN");
data[4..6].copy_from_slice(&0x0001u16.to_le_bytes()); data[6..8].copy_from_slice(&1u16.to_le_bytes()); let result = display_model_info("test.apr", &data);
assert!(result.is_ok());
}
#[test]
fn test_display_model_info_apr_magic() {
let mut data = vec![0u8; 16];
data[0..4].copy_from_slice(b"APRN");
data[4..6].copy_from_slice(&0x0002u16.to_le_bytes()); data[6..8].copy_from_slice(&1u16.to_le_bytes()); let result = display_model_info("model.bin", &data); assert!(result.is_ok());
}
#[test]
fn test_display_model_info_zero_bytes_unknown() {
let data = b"\x00\x00\x00\x00\x00\x00\x00\x00"; let result = display_model_info("test.bin", data);
assert!(result.is_ok());
}
#[test]
fn test_load_apr_model_valid() {
let mut data = vec![0u8; 16];
data[0..4].copy_from_slice(b"APRN");
data[4..6].copy_from_slice(&0x0003u16.to_le_bytes()); data[6..8].copy_from_slice(&1u16.to_le_bytes()); let result = load_apr_model(&data);
assert!(result.is_ok());
}
#[test]
fn test_load_apr_model_all_recognized_types() {
let type_codes = [
            0x0001u16, 0x0002, 0x0003, 0x0004, 0x0005, 0x0006, 0x0007, 0x0008, 0x0009,
            0x000A, 0x0010, 0x0011, 0x0012, 0x0020, 0x0021, 0x0030, 0x0040, 0x00FF,
        ];
for type_code in type_codes {
let mut data = vec![0u8; 16];
data[0..4].copy_from_slice(b"APRN");
data[4..6].copy_from_slice(&type_code.to_le_bytes());
data[6..8].copy_from_slice(&1u16.to_le_bytes());
let result = load_apr_model(&data);
assert!(result.is_ok(), "Failed for type code 0x{:04X}", type_code);
}
}
#[test]
fn test_load_apr_model_invalid_magic() {
let data = b"GGUFxxxxxxxxxxxxxxxx";
let result = load_apr_model(data);
assert!(result.is_err());
let err = result.unwrap_err();
assert!(err.to_string().contains("Expected APR format"));
}
#[test]
fn test_load_apr_model_too_short() {
let data = b"APR"; let result = load_apr_model(data);
assert!(result.is_err());
}
#[test]
fn test_load_apr_model_unknown_type() {
let mut data = vec![0u8; 16];
data[0..4].copy_from_slice(b"APRN");
data[4..6].copy_from_slice(&0xFFFEu16.to_le_bytes()); data[6..8].copy_from_slice(&1u16.to_le_bytes()); let result = load_apr_model(&data);
assert!(result.is_ok());
}
#[test]
fn test_load_apr_model_empty_data() {
let result = load_apr_model(&[]);
assert!(result.is_err());
}
}