1use serde::{Deserialize, Serialize};
8use std::fmt;
9use std::path::PathBuf;
10use thiserror::Error;
11
12#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
18pub struct PullProgress {
19 pub status: String,
21 pub completed: u64,
23 pub total: u64,
25}
26
27impl PullProgress {
28 #[must_use]
30 pub fn new(status: impl Into<String>, completed: u64, total: u64) -> Self {
31 Self {
32 status: status.into(),
33 completed,
34 total,
35 }
36 }
37
38 #[must_use]
40 pub fn percent(&self) -> f64 {
41 if self.total == 0 {
42 0.0
43 } else {
44 (self.completed as f64 / self.total as f64) * 100.0
45 }
46 }
47
48 #[must_use]
50 pub fn is_complete(&self) -> bool {
51 self.total > 0 && self.completed >= self.total
52 }
53}
54
55impl fmt::Display for PullProgress {
56 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
57 write!(f, "{}: {:.1}%", self.status, self.percent())
58 }
59}
60
61#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
67pub struct ModelInfo {
68 pub name: String,
70 pub size: u64,
72 pub quantization: Option<String>,
74 pub parameters: Option<String>,
76 pub digest: Option<String>,
78}
79
80impl ModelInfo {
81 #[must_use]
83 pub fn size_gb(&self) -> f64 {
84 self.size as f64 / 1_000_000_000.0
85 }
86
87 #[must_use]
89 pub fn size_human(&self) -> String {
90 let gb = self.size_gb();
91 if gb >= 1.0 {
92 format!("{gb:.1} GB")
93 } else {
94 format!("{:.0} MB", self.size as f64 / 1_000_000.0)
95 }
96 }
97}
98
99impl fmt::Display for ModelInfo {
100 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
101 write!(f, "{} ({})", self.name, self.size_human())
102 }
103}
104
105#[derive(Debug, Clone)]
111pub struct DownloadRequest<'a> {
112 pub model: Option<&'a super::models::KnownModel>,
114
115 pub hf_repo: Option<String>,
117
118 pub filename: Option<String>,
120
121 pub quantization: Option<super::models::Quantization>,
123
124 pub force: bool,
126}
127
128impl<'a> DownloadRequest<'a> {
129 #[must_use]
131 pub fn curated(model: &'a super::models::KnownModel) -> Self {
132 Self {
133 model: Some(model),
134 hf_repo: None,
135 filename: None,
136 quantization: None,
137 force: false,
138 }
139 }
140
141 #[must_use]
143 pub fn huggingface(repo: impl Into<String>, filename: impl Into<String>) -> Self {
144 Self {
145 model: None,
146 hf_repo: Some(repo.into()),
147 filename: Some(filename.into()),
148 quantization: None,
149 force: false,
150 }
151 }
152
153 #[must_use]
155 pub fn with_quantization(mut self, quant: super::models::Quantization) -> Self {
156 self.quantization = Some(quant);
157 self
158 }
159
160 #[must_use]
162 pub fn force(mut self) -> Self {
163 self.force = true;
164 self
165 }
166
167 #[must_use]
169 pub fn target_filename(&self) -> Option<String> {
170 if let Some(filename) = &self.filename {
171 return Some(filename.clone());
172 }
173
174 if let Some(model) = self.model {
175 let quant = self
176 .quantization
177 .unwrap_or(super::models::Quantization::Q4_K_M);
178 return model
180 .quantizations
181 .iter()
182 .find(|(q, _)| *q == quant)
183 .map(|(_, f)| (*f).to_string());
184 }
185
186 None
187 }
188}
189
/// Outcome of a completed download.
#[derive(Debug, Clone)]
pub struct DownloadResult {
    /// Filesystem location of the downloaded file.
    pub path: PathBuf,

    /// Size of the file (presumably in bytes — confirm with the downloader).
    pub size: u64,

    /// Checksum of the file, when one is available.
    pub checksum: Option<String>,

    /// Cache flag; presumably `true` when an existing local copy was reused — confirm.
    pub cached: bool,
}
205
206#[derive(Debug, Clone, Default, PartialEq, Eq, Serialize, Deserialize)]
220#[serde(tag = "kind", rename_all = "snake_case")]
221pub enum NativeModelKind {
222 #[default]
224 TextGguf,
225 VisionHf {
227 model_id: String,
229 isq: Option<String>,
232 },
233}
234
235impl NativeModelKind {
236 #[must_use]
238 pub fn is_vision(&self) -> bool {
239 matches!(self, Self::VisionHf { .. })
240 }
241}
242
/// Encoded image payload for vision-model requests.
///
/// `Debug` is implemented by hand so that logging a value reports the payload
/// size instead of printing every byte of the (potentially multi-megabyte)
/// image.
#[derive(Clone)]
pub struct VisionImage {
    /// Encoded image data; its format is described by `media_type`.
    pub bytes: Vec<u8>,
    /// MIME type of `bytes`, e.g. `"image/png"`.
    pub media_type: String,
}

impl VisionImage {
    /// Creates an image from its encoded bytes and MIME type.
    #[must_use]
    pub fn new(bytes: Vec<u8>, media_type: impl Into<String>) -> Self {
        Self {
            bytes,
            media_type: media_type.into(),
        }
    }
}

impl fmt::Debug for VisionImage {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.debug_struct("VisionImage")
            // Report only the length; the raw bytes are useless in logs.
            .field("bytes", &format_args!("<{} bytes>", self.bytes.len()))
            .field("media_type", &self.media_type)
            .finish()
    }
}
270
271#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
277pub struct LoadConfig {
278 pub gpu_ids: Vec<u32>,
280 pub gpu_layers: i32,
282 pub context_size: Option<u32>,
284 pub keep_alive: bool,
286 #[serde(default)]
289 pub model_kind: NativeModelKind,
290}
291
292impl Default for LoadConfig {
293 fn default() -> Self {
294 Self {
295 gpu_ids: Vec::new(),
296 gpu_layers: -1, context_size: None,
298 keep_alive: false,
299 model_kind: NativeModelKind::default(),
300 }
301 }
302}
303
304impl LoadConfig {
305 #[must_use]
307 pub fn new() -> Self {
308 Self::default()
309 }
310
311 #[must_use]
313 pub fn with_gpus(mut self, gpu_ids: Vec<u32>) -> Self {
314 self.gpu_ids = gpu_ids;
315 self
316 }
317
318 #[must_use]
320 pub fn with_gpu_layers(mut self, layers: i32) -> Self {
321 self.gpu_layers = layers;
322 self
323 }
324
325 #[must_use]
327 pub fn with_context_size(mut self, size: u32) -> Self {
328 self.context_size = Some(size);
329 self
330 }
331
332 #[must_use]
334 pub fn with_keep_alive(mut self, keep: bool) -> Self {
335 self.keep_alive = keep;
336 self
337 }
338
339 #[must_use]
341 pub fn is_cpu_only(&self) -> bool {
342 self.gpu_layers == 0
343 }
344
345 #[must_use]
347 pub fn is_full_gpu(&self) -> bool {
348 self.gpu_layers < 0
349 }
350}
351
352#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
358#[serde(rename_all = "lowercase")]
359pub enum ChatRole {
360 System,
362 User,
364 Assistant,
366}
367
368impl fmt::Display for ChatRole {
369 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
370 match self {
371 Self::System => write!(f, "system"),
372 Self::User => write!(f, "user"),
373 Self::Assistant => write!(f, "assistant"),
374 }
375 }
376}
377
378#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
380pub struct ChatMessage {
381 pub role: ChatRole,
383 pub content: String,
385}
386
387impl ChatMessage {
388 #[must_use]
390 pub fn system(content: impl Into<String>) -> Self {
391 Self {
392 role: ChatRole::System,
393 content: content.into(),
394 }
395 }
396
397 #[must_use]
399 pub fn user(content: impl Into<String>) -> Self {
400 Self {
401 role: ChatRole::User,
402 content: content.into(),
403 }
404 }
405
406 #[must_use]
408 pub fn assistant(content: impl Into<String>) -> Self {
409 Self {
410 role: ChatRole::Assistant,
411 content: content.into(),
412 }
413 }
414}
415
416#[derive(Debug, Clone, PartialEq, Default, Serialize, Deserialize)]
418pub struct ChatOptions {
419 pub temperature: Option<f32>,
421 pub top_p: Option<f32>,
423 pub top_k: Option<u32>,
425 pub max_tokens: Option<u32>,
427 pub stop: Vec<String>,
429 pub seed: Option<u64>,
431}
432
433impl ChatOptions {
434 #[must_use]
436 pub fn new() -> Self {
437 Self::default()
438 }
439
440 #[must_use]
442 pub fn with_temperature(mut self, temp: f32) -> Self {
443 self.temperature = Some(temp);
444 self
445 }
446
447 #[must_use]
449 pub fn with_top_p(mut self, top_p: f32) -> Self {
450 self.top_p = Some(top_p);
451 self
452 }
453
454 #[must_use]
456 pub fn with_top_k(mut self, top_k: u32) -> Self {
457 self.top_k = Some(top_k);
458 self
459 }
460
461 #[must_use]
463 pub fn with_max_tokens(mut self, max: u32) -> Self {
464 self.max_tokens = Some(max);
465 self
466 }
467
468 #[must_use]
470 pub fn with_stop(mut self, stop: impl Into<String>) -> Self {
471 self.stop.push(stop.into());
472 self
473 }
474
475 #[must_use]
477 pub fn with_seed(mut self, seed: u64) -> Self {
478 self.seed = Some(seed);
479 self
480 }
481}
482
483#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
485pub struct ChatResponse {
486 pub message: ChatMessage,
488 pub done: bool,
490 pub total_duration: Option<u64>,
492 pub eval_count: Option<u64>,
494 pub prompt_eval_count: Option<u64>,
496}
497
498impl ChatResponse {
499 #[must_use]
501 pub fn content(&self) -> &str {
502 &self.message.content
503 }
504
505 #[must_use]
507 pub fn tokens_per_second(&self) -> Option<f64> {
508 match (self.eval_count, self.total_duration) {
509 (Some(count), Some(duration)) if duration > 0 => {
510 Some(count as f64 / (duration as f64 / 1_000_000_000.0))
511 }
512 _ => None,
513 }
514 }
515}
516
517#[derive(Error, Debug, Clone)]
523pub enum BackendError {
524 #[error("Backend server is not running")]
526 NotRunning,
527
528 #[error("Model not found: {0}")]
530 ModelNotFound(String),
531
532 #[error("Model already loaded: {0}")]
534 AlreadyLoaded(String),
535
536 #[error("Insufficient memory to load model")]
538 InsufficientMemory,
539
540 #[error("Network error: {0}")]
542 NetworkError(String),
543
544 #[error("Process error: {0}")]
546 ProcessError(String),
547
548 #[error("Backend error: {0}")]
550 BackendSpecific(String),
551
552 #[error("Missing API key for provider: {0}")]
554 MissingApiKey(String),
555
556 #[error("API error (HTTP {status}): {message}")]
558 ApiError {
559 status: u16,
561 message: String,
563 },
564
565 #[error("Parse error: {0}")]
567 ParseError(String),
568
569 #[error("Model load error: {0}")]
571 LoadError(String),
572
573 #[error("Inference error: {0}")]
575 InferenceError(String),
576
577 #[error("Invalid configuration: {0}")]
579 InvalidConfig(String),
580
581 #[error("Storage error: {0}")]
583 StorageError(String),
584
585 #[error("Download error: {0}")]
587 DownloadError(String),
588
589 #[error("Checksum mismatch: expected {expected}, got {actual}")]
591 ChecksumError {
592 expected: String,
594 actual: String,
596 },
597
598 #[error("Path traversal detected: '{path}' escapes storage directory")]
600 PathTraversal {
601 path: String,
603 },
604}
605
606impl BackendError {
607 #[must_use]
609 pub const fn is_retryable(&self) -> bool {
610 matches!(
611 self,
612 Self::NetworkError(_) | Self::NotRunning | Self::DownloadError(_)
613 )
614 }
615
616 #[must_use]
618 pub fn is_auth_error(&self) -> bool {
619 match self {
620 Self::MissingApiKey(_) => true,
621 Self::ApiError { status, .. } => *status == 401 || *status == 403,
622 _ => false,
623 }
624 }
625}
626
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pull_progress() {
        let progress = PullProgress::new("downloading", 500, 1000);
        assert!((progress.percent() - 50.0).abs() < f64::EPSILON);
        assert!(!progress.is_complete());

        let complete = PullProgress::new("complete", 1000, 1000);
        assert!(complete.is_complete());
    }

    #[test]
    fn test_pull_progress_zero_total() {
        // A zero total must not divide by zero and never counts as complete.
        let progress = PullProgress::new("starting", 0, 0);
        assert!(progress.percent().abs() < f64::EPSILON);
        assert!(!progress.is_complete());
        assert_eq!(progress.to_string(), "starting: 0.0%");
    }

    #[test]
    fn test_pull_progress_display() {
        let progress = PullProgress::new("pulling", 750, 1000);
        assert_eq!(progress.to_string(), "pulling: 75.0%");
    }

    #[test]
    fn test_model_info_size() {
        let info = ModelInfo {
            name: "llama3.2:7b".to_string(),
            size: 4_500_000_000,
            quantization: Some("Q4_K_M".to_string()),
            parameters: Some("7B".to_string()),
            digest: None,
        };

        assert!((info.size_gb() - 4.5).abs() < 0.01);
        assert_eq!(info.size_human(), "4.5 GB");
        assert_eq!(info.to_string(), "llama3.2:7b (4.5 GB)");
    }

    #[test]
    fn test_model_info_megabytes() {
        let info = ModelInfo {
            name: "tiny".to_string(),
            size: 500_000_000,
            quantization: None,
            parameters: None,
            digest: None,
        };

        assert_eq!(info.size_human(), "500 MB");
        assert_eq!(info.to_string(), "tiny (500 MB)");
    }

    #[test]
    fn test_download_request_huggingface() {
        let request = DownloadRequest::huggingface("org/repo", "model.gguf").force();
        assert!(request.model.is_none());
        assert_eq!(request.hf_repo.as_deref(), Some("org/repo"));
        assert!(request.force);
        assert_eq!(request.target_filename().as_deref(), Some("model.gguf"));
    }

    #[test]
    fn test_load_config_default() {
        let config = LoadConfig::default();
        assert!(config.gpu_ids.is_empty());
        assert_eq!(config.gpu_layers, -1);
        assert!(config.is_full_gpu());
        assert!(!config.is_cpu_only());
    }

    #[test]
    fn test_load_config_cpu_only() {
        let config = LoadConfig::new().with_gpu_layers(0);
        assert!(config.is_cpu_only());
        assert!(!config.is_full_gpu());
    }

    #[test]
    fn test_load_config_builder() {
        let config = LoadConfig::new()
            .with_gpus(vec![0, 1])
            .with_gpu_layers(32)
            .with_context_size(8192)
            .with_keep_alive(true);

        assert_eq!(config.gpu_ids, vec![0, 1]);
        assert_eq!(config.gpu_layers, 32);
        assert_eq!(config.context_size, Some(8192));
        assert!(config.keep_alive);
    }

    #[test]
    fn test_chat_message_constructors() {
        let system = ChatMessage::system("You are helpful");
        assert_eq!(system.role, ChatRole::System);
        assert_eq!(system.content, "You are helpful");

        let user = ChatMessage::user("Hello");
        assert_eq!(user.role, ChatRole::User);

        let assistant = ChatMessage::assistant("Hi there!");
        assert_eq!(assistant.role, ChatRole::Assistant);
    }

    #[test]
    fn test_chat_role_display_and_serde() {
        assert_eq!(ChatRole::System.to_string(), "system");
        assert_eq!(ChatRole::User.to_string(), "user");
        assert_eq!(ChatRole::Assistant.to_string(), "assistant");

        assert_eq!(serde_json::to_string(&ChatRole::User).unwrap(), "\"user\"");
        let role: ChatRole = serde_json::from_str("\"assistant\"").unwrap();
        assert_eq!(role, ChatRole::Assistant);
    }

    #[test]
    fn test_chat_options_builder() {
        let options = ChatOptions::new()
            .with_temperature(0.7)
            .with_top_p(0.9)
            .with_max_tokens(100);

        assert_eq!(options.temperature, Some(0.7));
        assert_eq!(options.top_p, Some(0.9));
        assert_eq!(options.max_tokens, Some(100));
    }

    #[test]
    fn test_chat_options_stop_seed_top_k() {
        let options = ChatOptions::new()
            .with_stop("</s>")
            .with_stop("\n\n")
            .with_seed(42)
            .with_top_k(40);

        assert_eq!(options.stop, vec!["</s>".to_string(), "\n\n".to_string()]);
        assert_eq!(options.seed, Some(42));
        assert_eq!(options.top_k, Some(40));
    }

    #[test]
    fn test_chat_response_tokens_per_second() {
        let response = ChatResponse {
            message: ChatMessage::assistant("hi"),
            done: true,
            total_duration: Some(2_000_000_000),
            eval_count: Some(100),
            prompt_eval_count: None,
        };
        assert_eq!(response.content(), "hi");
        let tps = response.tokens_per_second().unwrap();
        assert!((tps - 50.0).abs() < 1e-9);

        let zero_duration = ChatResponse {
            total_duration: Some(0),
            ..response.clone()
        };
        assert!(zero_duration.tokens_per_second().is_none());

        let no_count = ChatResponse {
            eval_count: None,
            ..response
        };
        assert!(no_count.tokens_per_second().is_none());
    }

    #[test]
    fn test_backend_error_is_retryable() {
        assert!(BackendError::NetworkError("timeout".to_string()).is_retryable());
        assert!(BackendError::NotRunning.is_retryable());
        assert!(BackendError::DownloadError("reset".to_string()).is_retryable());
        assert!(!BackendError::ModelNotFound("model".to_string()).is_retryable());
        assert!(!BackendError::InsufficientMemory.is_retryable());
    }

    #[test]
    fn test_backend_error_is_auth_error() {
        assert!(BackendError::MissingApiKey("openai".to_string()).is_auth_error());
        let unauthorized = BackendError::ApiError {
            status: 401,
            message: "denied".to_string(),
        };
        assert!(unauthorized.is_auth_error());
        let forbidden = BackendError::ApiError {
            status: 403,
            message: "denied".to_string(),
        };
        assert!(forbidden.is_auth_error());
        let server = BackendError::ApiError {
            status: 500,
            message: "boom".to_string(),
        };
        assert!(!server.is_auth_error());
        assert!(!BackendError::NotRunning.is_auth_error());
    }

    #[test]
    fn test_backend_error_messages() {
        let api = BackendError::ApiError {
            status: 429,
            message: "slow down".to_string(),
        };
        assert_eq!(api.to_string(), "API error (HTTP 429): slow down");

        let checksum = BackendError::ChecksumError {
            expected: "abc".to_string(),
            actual: "def".to_string(),
        };
        assert_eq!(checksum.to_string(), "Checksum mismatch: expected abc, got def");

        let traversal = BackendError::PathTraversal {
            path: "../etc".to_string(),
        };
        assert_eq!(
            traversal.to_string(),
            "Path traversal detected: '../etc' escapes storage directory"
        );
    }

    #[test]
    fn test_native_model_kind_default_is_text() {
        assert_eq!(NativeModelKind::default(), NativeModelKind::TextGguf);
        assert!(!NativeModelKind::TextGguf.is_vision());
    }

    #[test]
    fn test_native_model_kind_serde_text_gguf() {
        let kind = NativeModelKind::TextGguf;
        let json = serde_json::to_string(&kind).unwrap();
        assert!(json.contains("text_gguf"));
        let roundtrip: NativeModelKind = serde_json::from_str(&json).unwrap();
        assert_eq!(roundtrip, NativeModelKind::TextGguf);
    }

    #[test]
    fn test_native_model_kind_serde_vision_hf() {
        let kind = NativeModelKind::VisionHf {
            model_id: "Qwen/Qwen2.5-VL-7B-Instruct".to_string(),
            isq: Some("Q4K".to_string()),
        };
        let json = serde_json::to_string(&kind).unwrap();
        assert!(json.contains("vision_hf"));
        assert!(json.contains("Qwen/Qwen2.5-VL-7B-Instruct"));
        assert!(json.contains("Q4K"));
        let roundtrip: NativeModelKind = serde_json::from_str(&json).unwrap();
        assert_eq!(roundtrip, kind);
    }

    #[test]
    fn test_native_model_kind_serde_vision_hf_no_isq() {
        let kind = NativeModelKind::VisionHf {
            model_id: "google/gemma-3-4b-it".to_string(),
            isq: None,
        };
        let json = serde_json::to_string(&kind).unwrap();
        assert!(json.contains("vision_hf"));
        assert!(json.contains("google/gemma-3-4b-it"));
        let roundtrip: NativeModelKind = serde_json::from_str(&json).unwrap();
        assert_eq!(roundtrip, kind);
        assert_eq!(
            roundtrip,
            NativeModelKind::VisionHf {
                model_id: "google/gemma-3-4b-it".to_string(),
                isq: None,
            }
        );
    }

    #[test]
    fn test_load_config_serde_default_model_kind() {
        // Configs serialized before `model_kind` existed must still load.
        let json = r#"{"gpu_ids":[],"gpu_layers":-1,"context_size":null,"keep_alive":false}"#;
        let config: LoadConfig = serde_json::from_str(json).unwrap();
        assert_eq!(config.model_kind, NativeModelKind::TextGguf);
    }

    #[test]
    fn test_load_config_serde_with_vision_hf() {
        let json = r#"{"gpu_ids":[],"gpu_layers":-1,"context_size":4096,"keep_alive":false,"model_kind":{"kind":"vision_hf","model_id":"Qwen/Qwen2.5-VL-7B","isq":"Q4K"}}"#;
        let config: LoadConfig = serde_json::from_str(json).unwrap();
        assert!(config.model_kind.is_vision());
        match &config.model_kind {
            NativeModelKind::VisionHf { model_id, isq } => {
                assert_eq!(model_id, "Qwen/Qwen2.5-VL-7B");
                assert_eq!(isq.as_deref(), Some("Q4K"));
            }
            NativeModelKind::TextGguf => panic!("Expected VisionHf"),
        }
    }

    #[test]
    fn test_vision_image_construction() {
        let img = VisionImage::new(vec![0x89, 0x50, 0x4E, 0x47], "image/png");
        assert_eq!(img.bytes.len(), 4);
        assert_eq!(img.media_type, "image/png");
    }

    #[test]
    fn test_vision_image_empty_bytes() {
        let img = VisionImage::new(vec![], "image/jpeg");
        assert_eq!(img.bytes.len(), 0);
        assert_eq!(img.media_type, "image/jpeg");
    }
}