use serde::{Deserialize, Serialize};
use std::fmt;
use std::path::PathBuf;
use thiserror::Error;
/// A progress snapshot for an in-flight model pull/download.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct PullProgress {
    /// Human-readable phase label (e.g. "downloading", "complete").
    pub status: String,
    /// Units completed so far (presumably bytes — confirm against the backend feeding this).
    pub completed: u64,
    /// Total units expected; 0 means the total is not yet known.
    pub total: u64,
}
impl PullProgress {
    /// Builds a snapshot from a status label and progress counters.
    #[must_use]
    pub fn new(status: impl Into<String>, completed: u64, total: u64) -> Self {
        Self {
            status: status.into(),
            completed,
            total,
        }
    }

    /// Completion percentage; 0.0 when `total` is unknown (zero).
    /// Not clamped: values above 100 are possible if `completed > total`.
    #[must_use]
    pub fn percent(&self) -> f64 {
        match self.total {
            0 => 0.0,
            total => (self.completed as f64 / total as f64) * 100.0,
        }
    }

    /// True once `completed` has reached a known (non-zero) `total`.
    #[must_use]
    pub fn is_complete(&self) -> bool {
        !(self.total == 0 || self.completed < self.total)
    }
}
impl fmt::Display for PullProgress {
    /// Renders as `"<status>: <percent>%"` with one decimal place.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let pct = self.percent();
        write!(f, "{}: {pct:.1}%", self.status)
    }
}
/// Metadata describing a locally known model.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Model identifier, e.g. "llama3.2:7b" (see tests).
    pub name: String,
    /// Size in bytes (converted with decimal units by `size_gb`/`size_human`).
    pub size: u64,
    /// Quantization label, e.g. "Q4_K_M", when known.
    pub quantization: Option<String>,
    /// Parameter-count label, e.g. "7B", when known.
    pub parameters: Option<String>,
    /// Content digest, when reported by the backend.
    pub digest: Option<String>,
}
impl ModelInfo {
    /// Size in decimal gigabytes (1 GB = 10^9 bytes).
    #[must_use]
    pub fn size_gb(&self) -> f64 {
        self.size as f64 / 1_000_000_000.0
    }

    /// Human-readable size: `"X.X GB"` at >= 1 GB, otherwise `"X MB"`.
    #[must_use]
    pub fn size_human(&self) -> String {
        let gb = self.size_gb();
        if gb < 1.0 {
            let mb = self.size as f64 / 1_000_000.0;
            format!("{mb:.0} MB")
        } else {
            format!("{gb:.1} GB")
        }
    }
}
impl fmt::Display for ModelInfo {
    /// Renders as `"<name> (<human-readable size>)"`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let size = self.size_human();
        write!(f, "{} ({size})", self.name)
    }
}
/// A request to fetch a model, either from the curated catalogue
/// (`model`) or from an explicit Hugging Face repo (`hf_repo` + `filename`).
#[derive(Debug, Clone)]
pub struct DownloadRequest<'a> {
    /// Curated catalogue entry, when downloading a known model.
    pub model: Option<&'a super::models::KnownModel>,
    /// Hugging Face repository id, when downloading an arbitrary model.
    pub hf_repo: Option<String>,
    /// Explicit target filename; overrides catalogue lookup in `target_filename`.
    pub filename: Option<String>,
    /// Requested quantization; `target_filename` falls back to Q4_K_M when unset.
    pub quantization: Option<super::models::Quantization>,
    /// Force flag — presumably re-download even if already present; confirm against the downloader.
    pub force: bool,
}
impl<'a> DownloadRequest<'a> {
    /// Request for a model from the curated catalogue, with defaults
    /// (no explicit filename/quantization, not forced).
    #[must_use]
    pub fn curated(model: &'a super::models::KnownModel) -> Self {
        Self {
            model: Some(model),
            hf_repo: None,
            filename: None,
            quantization: None,
            force: false,
        }
    }

    /// Request for an arbitrary Hugging Face repo + file.
    #[must_use]
    pub fn huggingface(repo: impl Into<String>, filename: impl Into<String>) -> Self {
        Self {
            model: None,
            filename: Some(filename.into()),
            hf_repo: Some(repo.into()),
            quantization: None,
            force: false,
        }
    }

    /// Builder: selects a specific quantization.
    #[must_use]
    pub fn with_quantization(mut self, quant: super::models::Quantization) -> Self {
        self.quantization = Some(quant);
        self
    }

    /// Builder: marks the request as forced.
    #[must_use]
    pub fn force(mut self) -> Self {
        self.force = true;
        self
    }

    /// Resolves the on-disk filename. An explicit `filename` always wins;
    /// otherwise the curated model's file for the requested (or default
    /// Q4_K_M) quantization is looked up. `None` when neither applies.
    #[must_use]
    pub fn target_filename(&self) -> Option<String> {
        match (&self.filename, self.model) {
            (Some(explicit), _) => Some(explicit.clone()),
            (None, Some(model)) => {
                let wanted = self
                    .quantization
                    .unwrap_or(super::models::Quantization::Q4_K_M);
                model
                    .quantizations
                    .iter()
                    .find_map(|(q, f)| (*q == wanted).then(|| (*f).to_string()))
            }
            (None, None) => None,
        }
    }
}
/// Outcome of a completed download.
#[derive(Debug, Clone)]
pub struct DownloadResult {
    /// Final path of the downloaded file.
    pub path: PathBuf,
    /// File size in bytes.
    pub size: u64,
    /// Checksum of the file, when one was computed.
    pub checksum: Option<String>,
    /// True when the file was already present — presumably served from cache; confirm against the downloader.
    pub cached: bool,
}
/// How a natively loaded model should be run.
///
/// Serialized as an internally tagged enum: `{"kind":"text_gguf"}` or
/// `{"kind":"vision_hf","model_id":...,"isq":...}` (see tests).
#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(tag = "kind", rename_all = "snake_case")]
pub enum NativeModelKind {
    /// Text-only GGUF model (the default).
    #[default]
    TextGguf,
    /// Vision-capable model fetched from Hugging Face.
    VisionHf {
        /// Hugging Face model id, e.g. "Qwen/Qwen2.5-VL-7B-Instruct".
        model_id: String,
        /// In-situ quantization label (e.g. "Q4K"), when requested.
        isq: Option<String>,
    },
}
impl NativeModelKind {
    /// Whether this kind is a vision (image-capable) model.
    #[must_use]
    pub fn is_vision(&self) -> bool {
        match self {
            Self::VisionHf { .. } => true,
            Self::TextGguf => false,
        }
    }
}
/// An in-memory image handed to a vision model.
#[derive(Debug, Clone)]
pub struct VisionImage {
    /// Raw image bytes (encoded file contents, e.g. PNG/JPEG data).
    pub bytes: Vec<u8>,
    /// MIME type describing `bytes`, e.g. "image/png".
    pub media_type: String,
}

impl VisionImage {
    /// Pairs raw image bytes with their MIME type.
    #[must_use]
    pub fn new(bytes: Vec<u8>, media_type: impl Into<String>) -> Self {
        let media_type = media_type.into();
        Self { bytes, media_type }
    }
}
/// Configuration for loading a model into a native backend.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct LoadConfig {
    /// Explicit GPU device ids; empty means no explicit pinning.
    pub gpu_ids: Vec<u32>,
    /// Layers to offload to GPU: 0 = CPU only, negative = all layers
    /// (see `is_cpu_only` / `is_full_gpu`).
    pub gpu_layers: i32,
    /// Context window size in tokens; `None` lets the backend choose.
    pub context_size: Option<u32>,
    /// Whether to keep the model resident after the request.
    pub keep_alive: bool,
    /// Model kind; defaults to `TextGguf` when absent from serialized input.
    #[serde(default)]
    pub model_kind: NativeModelKind,
}
impl Default for LoadConfig {
    /// Defaults: no explicit GPUs, all layers offloaded (`gpu_layers == -1`),
    /// backend-chosen context size, no keep-alive, text/GGUF model kind.
    ///
    /// Note: this cannot be `#[derive(Default)]` because `gpu_layers`
    /// must default to -1, not 0.
    fn default() -> Self {
        // Fix: the original crammed `gpu_layers` and `context_size` onto one
        // line, breaking the one-field-per-line convention used everywhere
        // else in this file (rustfmt style).
        Self {
            gpu_ids: Vec::new(),
            gpu_layers: -1,
            context_size: None,
            keep_alive: false,
            model_kind: NativeModelKind::default(),
        }
    }
}
impl LoadConfig {
    /// Equivalent to `LoadConfig::default()`.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder: pin inference to the given GPU device ids.
    #[must_use]
    pub fn with_gpus(self, gpu_ids: Vec<u32>) -> Self {
        Self { gpu_ids, ..self }
    }

    /// Builder: number of layers to offload (0 = CPU, negative = all).
    #[must_use]
    pub fn with_gpu_layers(self, layers: i32) -> Self {
        Self {
            gpu_layers: layers,
            ..self
        }
    }

    /// Builder: context window size in tokens.
    #[must_use]
    pub fn with_context_size(self, size: u32) -> Self {
        Self {
            context_size: Some(size),
            ..self
        }
    }

    /// Builder: keep the model resident after the request.
    #[must_use]
    pub fn with_keep_alive(self, keep: bool) -> Self {
        Self {
            keep_alive: keep,
            ..self
        }
    }

    /// True when no layers are offloaded to a GPU.
    #[must_use]
    pub fn is_cpu_only(&self) -> bool {
        self.gpu_layers == 0
    }

    /// True when the negative sentinel requests full GPU offload.
    #[must_use]
    pub fn is_full_gpu(&self) -> bool {
        self.gpu_layers.is_negative()
    }
}
/// Role of a chat message author; serialized lowercase ("system"/"user"/"assistant").
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum ChatRole {
    System,
    User,
    Assistant,
}
impl fmt::Display for ChatRole {
    /// Writes the lowercase wire name, matching the serde representation.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let name = match self {
            Self::System => "system",
            Self::User => "user",
            Self::Assistant => "assistant",
        };
        f.write_str(name)
    }
}
/// A single chat message: who said it and what was said.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Author role (system / user / assistant).
    pub role: ChatRole,
    /// Message text.
    pub content: String,
}
impl ChatMessage {
#[must_use]
pub fn system(content: impl Into<String>) -> Self {
Self {
role: ChatRole::System,
content: content.into(),
}
}
#[must_use]
pub fn user(content: impl Into<String>) -> Self {
Self {
role: ChatRole::User,
content: content.into(),
}
}
#[must_use]
pub fn assistant(content: impl Into<String>) -> Self {
Self {
role: ChatRole::Assistant,
content: content.into(),
}
}
}
/// Sampling/generation options for a chat request; unset fields
/// leave the backend's own defaults in effect.
#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
pub struct ChatOptions {
    /// Sampling temperature.
    pub temperature: Option<f32>,
    /// Nucleus-sampling probability mass.
    pub top_p: Option<f32>,
    /// Top-k sampling cutoff.
    pub top_k: Option<u32>,
    /// Hard cap on generated tokens.
    pub max_tokens: Option<u32>,
    /// Stop sequences; generation halts when one is produced.
    pub stop: Vec<String>,
    /// RNG seed for reproducible sampling.
    pub seed: Option<u64>,
}
impl ChatOptions {
    /// Empty option set: every field unset, no stop sequences.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Builder: sampling temperature.
    #[must_use]
    pub fn with_temperature(self, temp: f32) -> Self {
        Self {
            temperature: Some(temp),
            ..self
        }
    }

    /// Builder: nucleus-sampling probability mass.
    #[must_use]
    pub fn with_top_p(self, top_p: f32) -> Self {
        Self {
            top_p: Some(top_p),
            ..self
        }
    }

    /// Builder: top-k sampling cutoff.
    #[must_use]
    pub fn with_top_k(self, top_k: u32) -> Self {
        Self {
            top_k: Some(top_k),
            ..self
        }
    }

    /// Builder: hard cap on generated tokens.
    #[must_use]
    pub fn with_max_tokens(self, max: u32) -> Self {
        Self {
            max_tokens: Some(max),
            ..self
        }
    }

    /// Builder: appends one stop sequence (may be called repeatedly).
    #[must_use]
    pub fn with_stop(mut self, stop: impl Into<String>) -> Self {
        self.stop.push(stop.into());
        self
    }

    /// Builder: RNG seed for reproducible sampling.
    #[must_use]
    pub fn with_seed(self, seed: u64) -> Self {
        Self {
            seed: Some(seed),
            ..self
        }
    }
}
/// A (possibly streamed) chat completion chunk or final response.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ChatResponse {
    /// The assistant's message (or partial message while streaming).
    pub message: ChatMessage,
    /// True on the final chunk of a response.
    pub done: bool,
    /// Total wall time in nanoseconds (divided by 1e9 in `tokens_per_second`).
    pub total_duration: Option<u64>,
    /// Number of generated (decode) tokens.
    pub eval_count: Option<u64>,
    /// Number of prompt (prefill) tokens.
    pub prompt_eval_count: Option<u64>,
}
impl ChatResponse {
    /// Borrow of the assistant message text.
    #[must_use]
    pub fn content(&self) -> &str {
        self.message.content.as_str()
    }

    /// Decode throughput in tokens/second. `None` unless both the token
    /// count and a non-zero duration (nanoseconds) are present.
    #[must_use]
    pub fn tokens_per_second(&self) -> Option<f64> {
        let count = self.eval_count?;
        let nanos = self.total_duration?;
        if nanos == 0 {
            None
        } else {
            Some(count as f64 / (nanos as f64 / 1_000_000_000.0))
        }
    }
}
/// Unified error type for all backend operations: process management,
/// downloads, model loading, inference, and remote APIs.
/// `Display` text comes from the `thiserror` `#[error(...)]` attributes.
#[derive(Error, Debug, Clone)]
pub enum BackendError {
    /// Backend server process is not up (retryable — see `is_retryable`).
    #[error("Backend server is not running")]
    NotRunning,
    #[error("Model not found: {0}")]
    ModelNotFound(String),
    #[error("Model already loaded: {0}")]
    AlreadyLoaded(String),
    #[error("Insufficient memory to load model")]
    InsufficientMemory,
    /// Transient network failure (retryable).
    #[error("Network error: {0}")]
    NetworkError(String),
    #[error("Process error: {0}")]
    ProcessError(String),
    /// Catch-all for backend-specific failures not covered elsewhere.
    #[error("Backend error: {0}")]
    BackendSpecific(String),
    /// No API key configured for a remote provider (auth error).
    #[error("Missing API key for provider: {0}")]
    MissingApiKey(String),
    /// Non-success HTTP response from a remote API; 401/403 count as
    /// auth errors (see `is_auth_error`).
    #[error("API error (HTTP {status}): {message}")]
    ApiError {
        status: u16,
        message: String,
    },
    #[error("Parse error: {0}")]
    ParseError(String),
    #[error("Model load error: {0}")]
    LoadError(String),
    #[error("Inference error: {0}")]
    InferenceError(String),
    #[error("Invalid configuration: {0}")]
    InvalidConfig(String),
    #[error("Storage error: {0}")]
    StorageError(String),
    /// Download failure (retryable).
    #[error("Download error: {0}")]
    DownloadError(String),
    /// Downloaded file's checksum did not match the expected value.
    #[error("Checksum mismatch: expected {expected}, got {actual}")]
    ChecksumError {
        expected: String,
        actual: String,
    },
    /// A requested path resolved outside the storage root (security check).
    #[error("Path traversal detected: '{path}' escapes storage directory")]
    PathTraversal {
        path: String,
    },
}
impl BackendError {
    /// Whether retrying the same operation could plausibly succeed
    /// (transient network, process, or download conditions).
    #[must_use]
    pub const fn is_retryable(&self) -> bool {
        matches!(
            self,
            Self::NotRunning | Self::NetworkError(_) | Self::DownloadError(_)
        )
    }

    /// Whether the error indicates an authentication/authorization
    /// problem: a missing API key, or an HTTP 401/403 from the provider.
    #[must_use]
    pub fn is_auth_error(&self) -> bool {
        if let Self::ApiError { status, .. } = self {
            return matches!(status, 401 | 403);
        }
        matches!(self, Self::MissingApiKey(_))
    }
}
// Unit tests for the DTOs above: constructors, builders, Display impls,
// error classification, and the serde wire format of NativeModelKind /
// LoadConfig (internally tagged enum, defaulted `model_kind`).
#[cfg(test)]
mod tests {
    use super::*;

    // -- PullProgress --

    #[test]
    fn test_pull_progress() {
        let progress = PullProgress::new("downloading", 500, 1000);
        assert_eq!(progress.percent(), 50.0);
        assert!(!progress.is_complete());
        let complete = PullProgress::new("complete", 1000, 1000);
        assert!(complete.is_complete());
    }

    #[test]
    fn test_pull_progress_display() {
        let progress = PullProgress::new("pulling", 750, 1000);
        assert_eq!(progress.to_string(), "pulling: 75.0%");
    }

    // -- ModelInfo --

    #[test]
    fn test_model_info_size() {
        let info = ModelInfo {
            name: "llama3.2:7b".to_string(),
            size: 4_500_000_000,
            quantization: Some("Q4_K_M".to_string()),
            parameters: Some("7B".to_string()),
            digest: None,
        };
        // Float comparison with a small tolerance for the GB conversion.
        assert!((info.size_gb() - 4.5).abs() < 0.01);
        assert_eq!(info.size_human(), "4.5 GB");
    }

    // -- LoadConfig --

    #[test]
    fn test_load_config_default() {
        let config = LoadConfig::default();
        assert!(config.gpu_ids.is_empty());
        // -1 is the "offload everything" sentinel.
        assert_eq!(config.gpu_layers, -1);
        assert!(config.is_full_gpu());
        assert!(!config.is_cpu_only());
    }

    #[test]
    fn test_load_config_builder() {
        let config = LoadConfig::new()
            .with_gpus(vec![0, 1])
            .with_gpu_layers(32)
            .with_context_size(8192)
            .with_keep_alive(true);
        assert_eq!(config.gpu_ids, vec![0, 1]);
        assert_eq!(config.gpu_layers, 32);
        assert_eq!(config.context_size, Some(8192));
        assert!(config.keep_alive);
    }

    // -- ChatMessage / ChatOptions --

    #[test]
    fn test_chat_message_constructors() {
        let system = ChatMessage::system("You are helpful");
        assert_eq!(system.role, ChatRole::System);
        assert_eq!(system.content, "You are helpful");
        let user = ChatMessage::user("Hello");
        assert_eq!(user.role, ChatRole::User);
        let assistant = ChatMessage::assistant("Hi there!");
        assert_eq!(assistant.role, ChatRole::Assistant);
    }

    #[test]
    fn test_chat_options_builder() {
        let options = ChatOptions::new()
            .with_temperature(0.7)
            .with_top_p(0.9)
            .with_max_tokens(100);
        assert_eq!(options.temperature, Some(0.7));
        assert_eq!(options.top_p, Some(0.9));
        assert_eq!(options.max_tokens, Some(100));
    }

    // -- BackendError --

    #[test]
    fn test_backend_error_is_retryable() {
        assert!(BackendError::NetworkError("timeout".to_string()).is_retryable());
        assert!(BackendError::NotRunning.is_retryable());
        assert!(!BackendError::ModelNotFound("model".to_string()).is_retryable());
        assert!(!BackendError::InsufficientMemory.is_retryable());
    }

    // -- NativeModelKind serde round-trips --

    #[test]
    fn test_native_model_kind_serde_text_gguf() {
        let kind = NativeModelKind::TextGguf;
        let json = serde_json::to_string(&kind).unwrap();
        assert!(json.contains("text_gguf"));
        let roundtrip: NativeModelKind = serde_json::from_str(&json).unwrap();
        assert_eq!(roundtrip, NativeModelKind::TextGguf);
    }

    #[test]
    fn test_native_model_kind_serde_vision_hf() {
        let kind = NativeModelKind::VisionHf {
            model_id: "Qwen/Qwen2.5-VL-7B-Instruct".to_string(),
            isq: Some("Q4K".to_string()),
        };
        let json = serde_json::to_string(&kind).unwrap();
        assert!(json.contains("vision_hf"));
        assert!(json.contains("Qwen/Qwen2.5-VL-7B-Instruct"));
        assert!(json.contains("Q4K"));
        let roundtrip: NativeModelKind = serde_json::from_str(&json).unwrap();
        assert_eq!(roundtrip, kind);
    }

    #[test]
    fn test_native_model_kind_serde_vision_hf_no_isq() {
        let kind = NativeModelKind::VisionHf {
            model_id: "google/gemma-3-4b-it".to_string(),
            isq: None,
        };
        let json = serde_json::to_string(&kind).unwrap();
        assert!(json.contains("vision_hf"));
        assert!(json.contains("google/gemma-3-4b-it"));
        let roundtrip: NativeModelKind = serde_json::from_str(&json).unwrap();
        assert_eq!(roundtrip, kind);
        assert_eq!(
            roundtrip,
            NativeModelKind::VisionHf {
                model_id: "google/gemma-3-4b-it".to_string(),
                isq: None,
            }
        );
    }

    // `model_kind` is #[serde(default)], so old payloads without it parse.
    #[test]
    fn test_load_config_serde_default_model_kind() {
        let json = r#"{"gpu_ids":[],"gpu_layers":-1,"context_size":null,"keep_alive":false}"#;
        let config: LoadConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.model_kind, NativeModelKind::TextGguf);
    }

    #[test]
    fn test_load_config_serde_with_vision_hf() {
        let json = r#"{"gpu_ids":[],"gpu_layers":-1,"context_size":4096,"keep_alive":false,"model_kind":{"kind":"vision_hf","model_id":"Qwen/Qwen2.5-VL-7B","isq":"Q4K"}}"#;
        let config: LoadConfig = serde_json::from_str(json).unwrap();
        assert!(config.model_kind.is_vision());
        match &config.model_kind {
            NativeModelKind::VisionHf { model_id, isq } => {
                assert_eq!(model_id, "Qwen/Qwen2.5-VL-7B");
                assert_eq!(isq.as_deref(), Some("Q4K"));
            }
            _ => panic!("Expected VisionHf"),
        }
    }

    // -- VisionImage --

    #[test]
    fn test_vision_image_construction() {
        let img = VisionImage::new(vec![0x89, 0x50, 0x4E, 0x47], "image/png");
        assert_eq!(img.bytes.len(), 4);
        assert_eq!(img.media_type, "image/png");
    }

    #[test]
    fn test_vision_image_empty_bytes() {
        let img = VisionImage::new(vec![], "image/jpeg");
        assert_eq!(img.bytes.len(), 0);
        assert_eq!(img.media_type, "image/jpeg");
    }
}