use std::io::{self, Write};
use std::time::Duration;
use anyhow::{Context, Result};
use indicatif::{ProgressBar, ProgressStyle};
use serde::{Deserialize, Serialize};
use sysinfo::System;
use crate::llm::ollama::OllamaClient;
/// A chat model offered in the selection table, with the hints shown next to it.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct RecommendedModel {
/// Model tag that is compared against the names reported by the Ollama client.
pub name: String,
/// Value printed in the table's "RAM" column, in gigabytes.
pub ram_gb: f32,
/// Speed score rendered as stars; `format_stars` caps the display at 5.
pub speed_rating: u8,
/// Quality score rendered as stars; `format_stars` caps the display at 5.
pub quality_rating: u8,
/// Set by `get_available_models` when a matching installed model is found.
pub installed: bool,
}
/// Compile-time table holding the same RAM/speed/quality figures, in the same order,
/// as the entries built by `get_default_models`.
///
/// Every `name` here is an empty `String`; callers that need model names must use
/// `get_default_models` instead.
/// NOTE(review): presumably the names are empty because an owned `String` with
/// content cannot be constructed in a `const` — confirm whether this table is still used.
pub const RECOMMENDED_MODELS: &[RecommendedModel] = &[
// Same figures as the "llama3.2:3b" entry of `get_default_models`.
RecommendedModel {
name: String::new(), ram_gb: 4.0,
speed_rating: 5,
quality_rating: 3,
installed: false,
},
// Same figures as the "mistral:7b" entry of `get_default_models`.
RecommendedModel {
name: String::new(),
ram_gb: 6.0,
speed_rating: 4,
quality_rating: 4,
installed: false,
},
// Same figures as the "llama3.1:8b" entry of `get_default_models`.
RecommendedModel {
name: String::new(),
ram_gb: 8.0,
speed_rating: 3,
quality_rating: 5,
installed: false,
},
// Same figures as the "phi3:mini" entry of `get_default_models`.
RecommendedModel {
name: String::new(),
ram_gb: 3.0,
speed_rating: 5,
quality_rating: 2,
installed: false,
},
];
/// Builds the built-in catalogue of chat models, all marked as not installed.
///
/// The order is significant: `get_recommended_model_index` returns positions
/// into this list (0 = llama3.2:3b, 1 = mistral:7b, 2 = llama3.1:8b, 3 = phi3:mini).
fn get_default_models() -> Vec<RecommendedModel> {
    // (name, ram_gb, speed_rating, quality_rating)
    const CATALOGUE: [(&str, f32, u8, u8); 4] = [
        ("llama3.2:3b", 4.0, 5, 3),
        ("mistral:7b", 6.0, 4, 4),
        ("llama3.1:8b", 8.0, 3, 5),
        ("phi3:mini", 3.0, 5, 2),
    ];

    CATALOGUE
        .iter()
        .map(
            |&(tag, ram_gb, speed_rating, quality_rating)| RecommendedModel {
                name: tag.to_owned(),
                ram_gb,
                speed_rating,
                quality_rating,
                installed: false,
            },
        )
        .collect()
}
/// Returns the default chat-model catalogue with `installed` filled in from Ollama.
///
/// A catalogue entry counts as installed when some installed tag
/// - equals the entry name exactly, or
/// - is `<family>:latest`, where `<family>` is the entry name up to its first `:`, or
/// - has a family for which the entry name is `<family>:latest`.
///
/// # Errors
/// Fails when the list of installed models cannot be fetched from the client.
pub async fn get_available_models(client: &OllamaClient) -> Result<Vec<RecommendedModel>> {
    /// Text before the first `:` of a tag (the whole tag when there is none).
    fn family(tag: &str) -> &str {
        tag.split(':').next().unwrap_or(tag)
    }

    let installed = client
        .list_models()
        .await
        .context("Failed to fetch installed models from Ollama")?;
    let installed_tags: Vec<&str> = installed.iter().map(|m| m.name.as_str()).collect();

    let mut catalogue = get_default_models();
    for entry in catalogue.iter_mut() {
        let wanted = entry.name.as_str();
        let is_installed = installed_tags.iter().any(|&tag| {
            tag == wanted
                || tag.strip_suffix(":latest") == Some(family(wanted))
                || wanted.strip_suffix(":latest") == Some(family(tag))
        });
        entry.installed = is_installed;
    }
    Ok(catalogue)
}
/// Returns the machine's total memory as reported by `sysinfo`, divided by 1024³.
///
/// Only the memory statistics are refreshed. Building a fully populated `System`
/// would also gather data this function never reads.
#[must_use]
pub fn get_system_ram_gb() -> f32 {
    let mut sys = System::new();
    sys.refresh_memory();
    let total_memory_bytes = sys.total_memory();
    // The result only feeds display output and coarse >= thresholds, so the
    // precision lost converting u64 -> f32 is acceptable.
    #[allow(clippy::cast_precision_loss)]
    let ram_gb = total_memory_bytes as f32 / (1024.0 * 1024.0 * 1024.0);
    ram_gb
}
pub fn display_model_selection(models: &[RecommendedModel], system_ram_gb: f32) -> Result<usize> {
let recommended_idx = get_recommended_model_index(system_ram_gb);
println!("\nAvailable Models:\n");
println!(
" # {:<14} {:<7} {:<7} {:<7} Status",
"Model", "RAM", "Speed", "Quality"
);
println!("{}", "-".repeat(60));
for (i, model) in models.iter().enumerate() {
let speed_stars = format_stars(model.speed_rating);
let quality_stars = format_stars(model.quality_rating);
let status = if model.installed && i == recommended_idx {
"[Installed] [Recommended]"
} else if model.installed {
"[Installed]"
} else if i == recommended_idx {
"[Recommended]"
} else {
""
};
println!(
" {} {:<14} {:<7} {:<7} {:<7} {}",
i + 1,
model.name,
format!("{:.0} GB", model.ram_gb),
speed_stars,
quality_stars,
status
);
}
println!();
println!(
"Your system has {:.1} GB RAM. Recommended: {}",
system_ram_gb,
models.get(recommended_idx).map_or("unknown", |m| &m.name)
);
println!();
loop {
print!("Select model (1-{}): ", models.len());
io::stdout().flush().context("Failed to flush stdout")?;
let mut input = String::new();
io::stdin()
.read_line(&mut input)
.context("Failed to read user input")?;
match input.trim().parse::<usize>() {
Ok(n) if n >= 1 && n <= models.len() => return Ok(n - 1),
_ => {
println!(
"Invalid selection. Please enter a number between 1 and {}.",
models.len()
);
}
}
}
}
/// Renders `rating` as a five-character star bar: filled stars first, hollow after.
///
/// Ratings above 5 are shown as five filled stars.
fn format_stars(rating: u8) -> String {
    const SLOTS: usize = 5;
    let lit = usize::from(rating).min(SLOTS);
    (0..SLOTS)
        .map(|slot| if slot < lit { '★' } else { '☆' })
        .collect()
}
/// Maps the amount of system RAM (in GB) to a position in the default model list.
///
/// Positions refer to the order produced by `get_default_models`:
/// 8 GB or more -> 2, 6 GB or more -> 1, 4 GB or more -> 0, anything else
/// (including NaN) -> 3.
fn get_recommended_model_index(ram_gb: f32) -> usize {
    match ram_gb {
        gb if gb >= 8.0 => 2,
        gb if gb >= 6.0 => 1,
        gb if gb >= 4.0 => 0,
        _ => 3,
    }
}
/// An embedding model offered in the embedding selection table.
#[derive(Debug, Clone)]
pub struct EmbeddingModel {
/// Model name that is compared against the names reported by the Ollama client.
pub name: String,
/// Value printed in the table's "Dimension" column.
pub dimension: u32,
/// Speed score rendered as stars; `format_stars` caps the display at 5.
pub speed_rating: u8,
/// Set by `get_available_embedding_models` when a matching installed model is found.
pub installed: bool,
}
/// Builds the built-in list of embedding models, all marked as not installed.
///
/// The first entry is the one `display_embedding_model_selection` labels as recommended.
fn get_default_embedding_models() -> Vec<EmbeddingModel> {
    let entry = |name: &str, dimension: u32, speed_rating: u8| EmbeddingModel {
        name: name.to_owned(),
        dimension,
        speed_rating,
        installed: false,
    };

    vec![
        entry("nomic-embed-text", 768, 5),
        entry("mxbai-embed-large", 1024, 4),
        entry("all-minilm", 384, 5),
    ]
}
/// Returns the default embedding-model list with `installed` filled in from Ollama.
///
/// An entry counts as installed when some installed tag
/// - equals the entry name exactly, or
/// - starts with the entry name followed by `:`, or
/// - has the entry name as its text before the first `:`.
///
/// # Errors
/// Fails when the list of installed models cannot be fetched from the client.
pub async fn get_available_embedding_models(client: &OllamaClient) -> Result<Vec<EmbeddingModel>> {
    let installed = client
        .list_models()
        .await
        .context("Failed to fetch installed models from Ollama")?;
    let installed_tags: Vec<&str> = installed.iter().map(|m| m.name.as_str()).collect();

    let mut catalogue = get_default_embedding_models();
    for entry in catalogue.iter_mut() {
        let wanted = entry.name.as_str();
        let is_installed = installed_tags.iter().any(|&tag| {
            tag == wanted
                || tag
                    .strip_prefix(wanted)
                    .map_or(false, |rest| rest.starts_with(':'))
                || tag.split(':').next().unwrap_or(tag) == wanted
        });
        entry.installed = is_installed;
    }
    Ok(catalogue)
}
pub fn display_embedding_model_selection(models: &[EmbeddingModel]) -> Result<usize> {
println!("\nEmbedding Models (for search indexing):\n");
println!(
" # {:<20} {:<10} {:<7} Status",
"Model", "Dimension", "Speed"
);
println!("{}", "-".repeat(55));
for (i, model) in models.iter().enumerate() {
let speed_stars = format_stars(model.speed_rating);
let status = if model.installed && i == 0 {
"[Installed] [Recommended]"
} else if model.installed {
"[Installed]"
} else if i == 0 {
"[Recommended]"
} else {
""
};
println!(
" {} {:<20} {:<10} {:<7} {}",
i + 1,
model.name,
model.dimension,
speed_stars,
status
);
}
println!();
println!("Recommended: {} (fast, good quality)", models[0].name);
println!();
loop {
print!("Select embedding model (1-{}): ", models.len());
io::stdout().flush().context("Failed to flush stdout")?;
let mut input = String::new();
io::stdin()
.read_line(&mut input)
.context("Failed to read user input")?;
match input.trim().parse::<usize>() {
Ok(n) if n >= 1 && n <= models.len() => return Ok(n - 1),
_ => {
println!(
"Invalid selection. Please enter a number between 1 and {}.",
models.len()
);
}
}
}
}
/// A named pairing of one embedding model and one chat model, shown as a single
/// row in the preset selection table.
#[derive(Debug, Clone)]
pub struct ModelPreset {
/// Label printed in the table's "Preset" column.
pub name: &'static str,
/// Short text printed in the table's "Description" column.
pub description: &'static str,
/// Name of the embedding model this preset selects.
pub embedding_model: &'static str,
/// Name of the chat model this preset selects.
pub llm_model: &'static str,
/// Value printed in the table's "RAM" column, in gigabytes.
pub ram_gb: f32,
}
/// Presets offered by `display_preset_selection`, in table order.
///
/// `PresetSelection::Preset` carries an index into this slice. The model names
/// used here also appear in the default chat and embedding model lists.
pub const MODEL_PRESETS: &[ModelPreset] = &[
// Index 0: what `display_preset_selection` recommends below 6 GB of RAM.
ModelPreset {
name: "Fast",
description: "Quick responses, lower RAM",
embedding_model: "nomic-embed-text",
llm_model: "llama3.2:3b",
ram_gb: 4.0,
},
// Index 1: recommended from 6 GB of RAM.
ModelPreset {
name: "Balanced",
description: "Good speed and quality",
embedding_model: "nomic-embed-text",
llm_model: "mistral:7b",
ram_gb: 6.0,
},
// Index 2: recommended from 8 GB of RAM.
ModelPreset {
name: "Quality",
description: "Best results, more RAM",
embedding_model: "mxbai-embed-large",
llm_model: "llama3.1:8b",
ram_gb: 8.0,
},
];
/// Outcome of the preset prompt shown by `display_preset_selection`.
#[derive(Debug)]
pub enum PresetSelection {
/// One of the rows of `MODEL_PRESETS` was chosen; the payload is its zero-based index.
Preset(usize),
/// The extra "Custom" row was chosen instead of a preset.
Custom,
}
pub fn display_preset_selection(system_ram_gb: f32) -> Result<PresetSelection> {
#[allow(clippy::bool_to_int_with_if)]
let recommended_idx = if system_ram_gb >= 8.0 {
2 } else if system_ram_gb >= 6.0 {
1 } else {
0 };
println!("\nModel Configuration:\n");
println!(" # {:<12} {:<30} {:<7}", "Preset", "Description", "RAM");
println!("{}", "-".repeat(55));
for (i, preset) in MODEL_PRESETS.iter().enumerate() {
let status = if i == recommended_idx {
"[Recommended]"
} else {
""
};
println!(
" {} {:<12} {:<30} {:<7} {}",
i + 1,
preset.name,
preset.description,
format!("{:.0} GB", preset.ram_gb),
status
);
}
println!(
" {} {:<12} {:<30}",
MODEL_PRESETS.len() + 1,
"Custom",
"Choose models individually"
);
println!();
println!("Your system has {system_ram_gb:.1} GB RAM.");
println!();
loop {
print!("Select option (1-{}): ", MODEL_PRESETS.len() + 1);
io::stdout().flush().context("Failed to flush stdout")?;
let mut input = String::new();
io::stdin()
.read_line(&mut input)
.context("Failed to read user input")?;
match input.trim().parse::<usize>() {
Ok(n) if n >= 1 && n <= MODEL_PRESETS.len() => {
return Ok(PresetSelection::Preset(n - 1));
}
Ok(n) if n == MODEL_PRESETS.len() + 1 => {
return Ok(PresetSelection::Custom);
}
_ => {
println!(
"Invalid selection. Please enter a number between 1 and {}.",
MODEL_PRESETS.len() + 1
);
}
}
}
}
/// One JSON status line of the `/api/pull` response, as parsed by
/// `pull_model_with_progress`. Only `status` is required; the other fields
/// deserialize to `None` when absent.
#[derive(Debug, Clone, Deserialize)]
pub struct PullProgress {
/// Status text; "success" marks completion, a few other values map to fixed
/// progress-bar messages, anything else is treated as a transfer update.
pub status: String,
/// Identifier shown (shortened to 12 bytes) in the "Layer ..." progress message.
#[serde(default)]
pub digest: Option<String>,
/// Denominator of the percentage calculation (presumably a byte count — confirm
/// against the Ollama API).
#[serde(default)]
pub total: Option<u64>,
/// Numerator of the percentage calculation, in the same unit as `total`.
#[serde(default)]
pub completed: Option<u64>,
}
/// JSON body sent to the `/api/pull` endpoint by `pull_model_with_progress`.
#[derive(Debug, Clone, Serialize)]
struct PullRequest {
/// Name of the model to pull.
name: String,
/// Streaming flag; `pull_model_with_progress` always sets it to `true`.
stream: bool,
}
pub async fn pull_model_with_progress(client: &OllamaClient, model_name: &str) -> Result<()> {
let url = format!("{}/api/pull", client.base_url());
let request = PullRequest {
name: model_name.to_string(),
stream: true,
};
println!("Pulling model '{model_name}'...");
let http_client = reqwest::Client::builder()
.timeout(Duration::from_secs(1800)) .build()
.context("Failed to create HTTP client")?;
let response = http_client
.post(&url)
.json(&request)
.send()
.await
.with_context(|| format!("Failed to start pulling model '{model_name}'"))?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
anyhow::bail!("Failed to pull model '{model_name}' ({status}): {body}");
}
let body = response
.text()
.await
.context("Failed to read pull response")?;
let pb = ProgressBar::new(100);
pb.set_style(
ProgressStyle::default_bar()
.template("{spinner:.green} [{bar:40.cyan/blue}] {percent}% {msg}")
.map_err(|e| anyhow::anyhow!("Failed to set progress style: {e}"))?
.progress_chars("█▓░"),
);
let mut last_digest = String::new();
let mut success = false;
for line in body.lines() {
if line.trim().is_empty() {
continue;
}
let progress: PullProgress = serde_json::from_str(line)
.with_context(|| format!("Failed to parse progress: {line}"))?;
match progress.status.as_str() {
"success" => {
success = true;
pb.finish_with_message("Complete!");
}
"pulling manifest" => {
pb.set_message("Pulling manifest...");
}
"verifying sha256 digest" => {
pb.set_message("Verifying...");
}
"writing manifest" => {
pb.set_message("Writing manifest...");
}
_ => {
if let (Some(total), Some(completed)) = (progress.total, progress.completed) {
if total > 0 {
let percent = completed * 100 / total;
pb.set_position(percent);
if let Some(ref digest) = progress.digest {
if digest != &last_digest {
last_digest.clone_from(digest);
let short_digest = &digest[..digest.len().min(12)];
pb.set_message(format!("Layer {short_digest}..."));
}
}
}
}
}
}
}
if !success {
anyhow::bail!("Pull did not complete successfully");
}
println!("✓ Model '{model_name}' ready");
Ok(())
}
// Unit tests for the pure helpers and the serde shapes defined in this file.
// Nothing here talks to an Ollama server or reads from stdin.
#[cfg(test)]
mod tests {
use super::*;
// Serializing a model emits every field under its own name.
#[test]
fn test_recommended_model_serialization() {
let model = RecommendedModel {
name: "llama3.2:3b".to_string(),
ram_gb: 4.0,
speed_rating: 5,
quality_rating: 3,
installed: true,
};
let json = serde_json::to_string(&model).expect("Failed to serialize");
assert!(json.contains("\"name\":\"llama3.2:3b\""));
assert!(json.contains("\"ram_gb\":4.0"));
assert!(json.contains("\"speed_rating\":5"));
assert!(json.contains("\"quality_rating\":3"));
assert!(json.contains("\"installed\":true"));
}
// A JSON object with all five fields deserializes into the matching struct values.
#[test]
fn test_recommended_model_deserialization() {
let json = r#"{
"name": "mistral:7b",
"ram_gb": 6.0,
"speed_rating": 4,
"quality_rating": 4,
"installed": false
}"#;
let model: RecommendedModel = serde_json::from_str(json).expect("Failed to deserialize");
assert_eq!(model.name, "mistral:7b");
assert_eq!(model.ram_gb, 6.0);
assert_eq!(model.speed_rating, 4);
assert_eq!(model.quality_rating, 4);
assert!(!model.installed);
}
// The default catalogue has four entries, starts with llama3.2:3b, and every
// entry has a non-empty name and ratings within 1..=5.
#[test]
fn test_default_models_content() {
let models = get_default_models();
assert_eq!(models.len(), 4);
assert_eq!(models[0].name, "llama3.2:3b");
assert_eq!(models[0].ram_gb, 4.0);
assert_eq!(models[0].speed_rating, 5);
assert_eq!(models[0].quality_rating, 3);
assert!(!models[0].installed);
for model in &models {
assert!(!model.name.is_empty());
assert!(model.speed_rating >= 1 && model.speed_rating <= 5);
assert!(model.quality_rating >= 1 && model.quality_rating <= 5);
}
}
// A clone compares equal to its source through the derived `PartialEq`.
#[test]
fn test_model_equality() {
let model1 = RecommendedModel {
name: "test".to_string(),
ram_gb: 4.0,
speed_rating: 5,
quality_rating: 3,
installed: false,
};
let model2 = model1.clone();
assert_eq!(model1, model2);
}
// Star bars are always five characters: filled stars first, hollow stars after.
#[test]
fn test_format_stars() {
assert_eq!(format_stars(5), "★★★★★");
assert_eq!(format_stars(3), "★★★☆☆");
assert_eq!(format_stars(1), "★☆☆☆☆");
assert_eq!(format_stars(0), "☆☆☆☆☆");
}
// Checks both sides of each RAM threshold (8, 6 and 4 GB) plus the low-RAM fallback.
#[test]
fn test_get_recommended_model_index() {
assert_eq!(get_recommended_model_index(16.0), 2);
assert_eq!(get_recommended_model_index(8.0), 2);
assert_eq!(get_recommended_model_index(7.0), 1);
assert_eq!(get_recommended_model_index(6.0), 1);
assert_eq!(get_recommended_model_index(5.0), 0);
assert_eq!(get_recommended_model_index(4.0), 0);
assert_eq!(get_recommended_model_index(3.0), 3);
assert_eq!(get_recommended_model_index(2.0), 3);
}
// Sanity bounds only: the value depends on the machine running the test.
#[test]
fn test_get_system_ram_gb() {
let ram = get_system_ram_gb();
assert!(ram > 0.5);
assert!(ram < 1024.0);
}
// A progress line carrying every field populates all of the optional members.
#[test]
fn test_pull_progress_deserialization() {
let json = r#"{"status": "downloading", "digest": "sha256:abc123", "total": 2000000000, "completed": 500000000}"#;
let progress: PullProgress = serde_json::from_str(json).expect("Failed to deserialize");
assert_eq!(progress.status, "downloading");
assert_eq!(progress.digest, Some("sha256:abc123".to_string()));
assert_eq!(progress.total, Some(2_000_000_000));
assert_eq!(progress.completed, Some(500_000_000));
}
// A line with only `status` leaves the optional members as `None`.
#[test]
fn test_pull_progress_minimal() {
let json = r#"{"status": "success"}"#;
let progress: PullProgress = serde_json::from_str(json).expect("Failed to deserialize");
assert_eq!(progress.status, "success");
assert_eq!(progress.digest, None);
assert_eq!(progress.total, None);
assert_eq!(progress.completed, None);
}
// Repeats the integer percentage formula used while pulling: completed * 100 / total.
#[test]
fn test_pull_progress_percentage_calculation() {
let total: u64 = 2_000_000_000;
let completed: u64 = 500_000_000;
let percent = (completed * 100 / total) as u64;
assert_eq!(percent, 25);
}
}