Skip to main content

nika_engine/core/
storage.rs

1//! HuggingFace model storage implementation.
2//!
3//! Downloads GGUF models from HuggingFace Hub with:
4//! - Progress callbacks
5//! - SHA256 checksum verification
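6//! - Resumable downloads (via HTTP Range requests) — NOTE: `download` in this file currently issues a plain GET without a Range header
7//! - Caching (skip download if the file already exists locally, unless the request sets `force`; the cached file's checksum is not re-verified)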
8//!
9//! # Example
10//!
11//! ```rust,ignore
12//! use nika::core::storage::{HuggingFaceStorage, default_model_dir};
13//! use nika::core::backend::{DownloadRequest, PullProgress};
14//! use nika::core::models::find_model;
15//!
16//! #[tokio::main]
17//! async fn main() -> Result<(), Box<dyn std::error::Error>> {
18//!     let storage = HuggingFaceStorage::new(default_model_dir())?;
19//!     let model = find_model("qwen3:8b").unwrap();
20//!     let request = DownloadRequest::curated(model);
21//!
22//!     let result = storage.download(&request, |p| {
23//!         println!("{}", p);
24//!     }).await?;
25//!
26//!     println!("Downloaded: {:?}", result.path);
27//!     Ok(())
28//! }
29//! ```
30
31use crate::core::backend::{
32    BackendError, DownloadRequest, DownloadResult, ModelInfo, PullProgress,
33};
34use futures::StreamExt;
35use reqwest::Client;
36use serde::Deserialize;
37use sha2::{Digest, Sha256};
38use std::path::{Path, PathBuf};
39use thiserror::Error;
40use tokio::fs::{self, File};
41use tokio::io::AsyncWriteExt;
42
43// ============================================================================
44// Storage Error
45// ============================================================================
46
47/// Error types for storage operations.
48#[derive(Error, Debug)]
49pub enum StorageError {
50    /// Model not found on HuggingFace.
51    #[error("Model not found: {repo}/{filename}")]
52    ModelNotFound {
53        /// HuggingFace repository.
54        repo: String,
55        /// Filename.
56        filename: String,
57    },
58
59    /// Checksum verification failed.
60    #[error("Checksum mismatch for {path}: expected {expected}, got {actual}")]
61    ChecksumMismatch {
62        /// File path.
63        path: PathBuf,
64        /// Expected checksum.
65        expected: String,
66        /// Actual checksum.
67        actual: String,
68    },
69
70    /// Invalid configuration.
71    #[error("Invalid configuration: {0}")]
72    InvalidConfig(String),
73
74    /// Network error.
75    #[error("Network error: {0}")]
76    Network(#[from] reqwest::Error),
77
78    /// I/O error.
79    #[error("I/O error: {0}")]
80    Io(#[from] std::io::Error),
81}
82
83/// Result type for storage operations.
84pub type Result<T> = std::result::Result<T, StorageError>;
85
86// ============================================================================
87// HuggingFace API Types
88// ============================================================================
89
90/// File info from HuggingFace API.
91#[derive(Debug, Deserialize)]
92struct HfFileInfo {
93    /// Filename.
94    #[serde(rename = "path")]
95    filename: String,
96    /// File size in bytes.
97    size: u64,
98    /// LFS info (contains SHA256).
99    lfs: Option<HfLfsInfo>,
100}
101
102/// LFS metadata from HuggingFace.
103#[derive(Debug, Deserialize)]
104struct HfLfsInfo {
105    /// SHA256 checksum.
106    #[serde(rename = "oid")]
107    sha256: String,
108}
109
110// ============================================================================
111// Default Paths
112// ============================================================================
113
114/// Returns the default model storage directory.
115///
116/// Platform-specific:
117/// - macOS: `~/Library/Application Support/nika/models`
118/// - Linux: `~/.local/share/nika/models`
119/// - Windows: `%APPDATA%/nika/models`
120#[must_use]
121pub fn default_model_dir() -> PathBuf {
122    dirs::data_dir()
123        .unwrap_or_else(|| PathBuf::from("."))
124        .join("nika")
125        .join("models")
126}
127
128// ============================================================================
129// Platform Detection
130// ============================================================================
131
132/// Detect total system RAM in gigabytes.
133///
134/// This is a re-export of [`crate::util::system::get_total_ram_gb`]
135/// compatibility. New code should use `crate::util::system` directly.
136#[must_use]
137pub fn detect_system_ram_gb() -> f64 {
138    crate::util::system::get_total_ram_gb()
139}
140
141// ============================================================================
142// ModelStorage Trait
143// ============================================================================
144
145/// Trait for model storage backends.
146///
147/// Provides a common interface for downloading and managing local models.
148#[allow(async_fn_in_trait)]
149pub trait ModelStorage {
150    /// Download a model with progress callback.
151    async fn download<F>(
152        &self,
153        request: &DownloadRequest<'_>,
154        progress: F,
155    ) -> Result<DownloadResult>
156    where
157        F: Fn(PullProgress) + Send + 'static;
158
159    /// List all downloaded models.
160    fn list_models(&self) -> std::result::Result<Vec<ModelInfo>, BackendError>;
161
162    /// Check if a model exists locally.
163    fn exists(&self, model_id: &str) -> bool;
164
165    /// Get info about a local model.
166    fn model_info(&self, model_id: &str) -> std::result::Result<ModelInfo, BackendError>;
167
168    /// Delete a local model.
169    fn delete(&self, model_id: &str) -> std::result::Result<(), BackendError>;
170
171    /// Get the path to a model file.
172    ///
173    /// # Security
174    ///
175    /// Validates that model_id doesn't escape storage directory via path traversal.
176    ///
177    /// # Errors
178    ///
179    /// Returns `BackendError::PathTraversal` if the path would escape the storage directory.
180    fn model_path(&self, model_id: &str) -> std::result::Result<PathBuf, BackendError>;
181}
182
183// ============================================================================
184// HuggingFace Storage
185// ============================================================================
186
187/// Storage backend for HuggingFace Hub models.
188///
189/// Downloads GGUF models from HuggingFace with progress tracking and
190/// checksum verification.
191pub struct HuggingFaceStorage {
192    /// Root directory for model storage.
193    storage_dir: PathBuf,
194    /// HTTP client.
195    client: Client,
196}
197
198impl HuggingFaceStorage {
199    /// Create a new HuggingFace storage with the given directory.
200    ///
201    /// # Errors
202    ///
203    /// Returns `StorageError::InvalidConfig` if the HTTP client cannot be built.
204    pub fn new(storage_dir: PathBuf) -> Result<Self> {
205        let user_agent = format!("nika/{}", env!("CARGO_PKG_VERSION"));
206        let client = Client::builder()
207            .user_agent(&user_agent)
208            .build()
209            .map_err(|e| {
210                StorageError::InvalidConfig(format!("Failed to create HTTP client: {e}"))
211            })?;
212
213        Ok(Self {
214            storage_dir,
215            client,
216        })
217    }
218
219    /// Create storage with a custom HTTP client.
220    #[must_use]
221    pub fn with_client(storage_dir: PathBuf, client: Client) -> Self {
222        Self {
223            storage_dir,
224            client,
225        }
226    }
227
228    /// Get the storage directory.
229    #[must_use]
230    pub fn storage_dir(&self) -> &Path {
231        &self.storage_dir
232    }
233
234    /// Download a model with progress callback.
235    ///
236    /// # Arguments
237    ///
238    /// * `request` - Download request specifying model and quantization
239    /// * `progress` - Callback for download progress updates
240    ///
241    /// # Errors
242    ///
243    /// Returns error if:
244    /// - Model not found on HuggingFace
245    /// - Network error during download
246    /// - Checksum verification fails
247    /// - I/O error writing file
248    pub async fn download<F>(
249        &self,
250        request: &DownloadRequest<'_>,
251        progress: F,
252    ) -> Result<DownloadResult>
253    where
254        F: Fn(PullProgress) + Send + 'static,
255    {
256        // Resolve repo and filename
257        let (repo, filename) = self.resolve_request(request)?;
258
259        // Create storage directory
260        let model_dir = self.storage_dir.join(&repo);
261        fs::create_dir_all(&model_dir).await?;
262
263        let file_path = model_dir.join(&filename);
264
265        // TOCTOU-safe: Attempt to read metadata directly instead of exists() check.
266        // If the file exists and we're not forcing, return cached result.
267        // If it doesn't exist, continue to download.
268        if !request.force {
269            match fs::metadata(&file_path).await {
270                Ok(metadata) => {
271                    progress(PullProgress::new("cached", 1, 1));
272                    return Ok(DownloadResult {
273                        path: file_path,
274                        size: metadata.len(),
275                        checksum: None,
276                        cached: true,
277                    });
278                }
279                Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
280                    // File doesn't exist, proceed to download
281                }
282                Err(e) => return Err(StorageError::Io(e)),
283            }
284        }
285
286        // Get file info from HuggingFace API
287        progress(PullProgress::new("fetching metadata", 0, 1));
288        let file_info = self.get_file_info(&repo, &filename).await?;
289
290        // Download the file
291        let download_url = format!("https://huggingface.co/{}/resolve/main/{}", repo, filename);
292
293        progress(PullProgress::new("downloading", 0, file_info.size));
294
295        let response = self.client.get(&download_url).send().await?;
296
297        if !response.status().is_success() {
298            return Err(StorageError::ModelNotFound {
299                repo: repo.clone(),
300                filename: filename.clone(),
301            });
302        }
303
304        // Stream download to file with progress
305        let mut file = File::create(&file_path).await?;
306        let mut stream = response.bytes_stream();
307        let mut downloaded: u64 = 0;
308        let mut hasher = Sha256::new();
309
310        while let Some(chunk) = stream.next().await {
311            let chunk = chunk?;
312            hasher.update(&chunk);
313            file.write_all(&chunk).await?;
314            downloaded += chunk.len() as u64;
315
316            progress(PullProgress::new("downloading", downloaded, file_info.size));
317        }
318
319        file.flush().await?;
320        drop(file);
321
322        // Verify checksum
323        let checksum = format!("{:x}", hasher.finalize());
324        if let Some(ref lfs) = file_info.lfs {
325            if checksum != lfs.sha256 {
326                // Delete corrupted file
327                let _ = fs::remove_file(&file_path).await;
328                return Err(StorageError::ChecksumMismatch {
329                    path: file_path,
330                    expected: lfs.sha256.clone(),
331                    actual: checksum,
332                });
333            }
334        }
335
336        progress(PullProgress::new(
337            "complete",
338            file_info.size,
339            file_info.size,
340        ));
341
342        Ok(DownloadResult {
343            path: file_path,
344            size: file_info.size,
345            checksum: Some(checksum),
346            cached: false,
347        })
348    }
349
350    /// Resolve download request to HuggingFace repo and filename.
351    fn resolve_request(&self, request: &DownloadRequest<'_>) -> Result<(String, String)> {
352        if let Some(hf_repo) = &request.hf_repo {
353            let filename = request.filename.clone().ok_or_else(|| {
354                StorageError::InvalidConfig("HuggingFace download requires filename".into())
355            })?;
356            return Ok((hf_repo.clone(), filename));
357        }
358
359        if let Some(model) = request.model {
360            let filename = request.target_filename().ok_or_else(|| {
361                StorageError::InvalidConfig("No quantization available for model".into())
362            })?;
363            return Ok((model.hf_repo.to_string(), filename));
364        }
365
366        Err(StorageError::InvalidConfig(
367            "Download request must specify model or HuggingFace repo".into(),
368        ))
369    }
370
371    /// Get file info from HuggingFace API.
372    async fn get_file_info(&self, repo: &str, filename: &str) -> Result<HfFileInfo> {
373        let api_url = format!("https://huggingface.co/api/models/{}/tree/main", repo);
374
375        let response = self.client.get(&api_url).send().await?;
376
377        if !response.status().is_success() {
378            return Err(StorageError::ModelNotFound {
379                repo: repo.to_string(),
380                filename: filename.to_string(),
381            });
382        }
383
384        let files: Vec<HfFileInfo> = response.json().await?;
385
386        files
387            .into_iter()
388            .find(|f| f.filename == filename)
389            .ok_or_else(|| StorageError::ModelNotFound {
390                repo: repo.to_string(),
391                filename: filename.to_string(),
392            })
393    }
394
395    /// List all downloaded models.
396    pub fn list_models(&self) -> std::result::Result<Vec<ModelInfo>, BackendError> {
397        let mut models = Vec::new();
398
399        // TOCTOU-safe: Attempt to read directory directly instead of exists() check.
400        // If directory doesn't exist, return empty list.
401        let entries = match std::fs::read_dir(&self.storage_dir) {
402            Ok(entries) => entries,
403            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
404                return Ok(models);
405            }
406            Err(e) => return Err(BackendError::StorageError(e.to_string())),
407        };
408
409        for entry in entries.flatten() {
410            let path = entry.path();
411            if path.is_dir() {
412                // This is a repo directory
413                let repo_name = entry.file_name().to_string_lossy().to_string();
414
415                // List GGUF files in this directory
416                if let Ok(files) = std::fs::read_dir(&path) {
417                    for file in files.flatten() {
418                        let filename = file.file_name().to_string_lossy().to_string();
419                        if filename.ends_with(".gguf") {
420                            if let Ok(metadata) = file.metadata() {
421                                let quant = extract_quantization(&filename);
422                                models.push(ModelInfo {
423                                    name: format!("{}/{}", repo_name, filename),
424                                    size: metadata.len(),
425                                    quantization: quant,
426                                    parameters: None,
427                                    digest: None,
428                                });
429                            }
430                        }
431                    }
432                }
433            }
434        }
435
436        Ok(models)
437    }
438
439    /// Check if a model exists locally.
440    ///
441    /// Returns `false` if the model_id contains path traversal patterns.
442    #[must_use]
443    pub fn exists(&self, model_id: &str) -> bool {
444        self.model_path(model_id)
445            .map(|p| p.exists())
446            .unwrap_or(false)
447    }
448
449    /// Get info about a local model.
450    pub fn model_info(&self, model_id: &str) -> std::result::Result<ModelInfo, BackendError> {
451        let path = self.model_path(model_id)?;
452
453        // TOCTOU-safe: Attempt to read metadata directly instead of exists() check.
454        // This avoids race where file is deleted between exists() and metadata().
455        let metadata = match std::fs::metadata(&path) {
456            Ok(metadata) => metadata,
457            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
458                return Err(BackendError::ModelNotFound(model_id.to_string()));
459            }
460            Err(e) => return Err(BackendError::StorageError(e.to_string())),
461        };
462
463        let filename = path.file_name().unwrap_or_default().to_string_lossy();
464
465        Ok(ModelInfo {
466            name: model_id.to_string(),
467            size: metadata.len(),
468            quantization: extract_quantization(&filename),
469            parameters: None,
470            digest: None,
471        })
472    }
473
474    /// Delete a local model.
475    pub fn delete(&self, model_id: &str) -> std::result::Result<(), BackendError> {
476        let path = self.model_path(model_id)?;
477
478        // TOCTOU-safe: Attempt to remove directly instead of exists() check.
479        // This avoids race where file is deleted between exists() and remove_file().
480        match std::fs::remove_file(&path) {
481            Ok(()) => Ok(()),
482            Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
483                Err(BackendError::ModelNotFound(model_id.to_string()))
484            }
485            Err(e) => Err(BackendError::StorageError(e.to_string())),
486        }
487    }
488
489    /// Get the path to a model file with path traversal validation.
490    ///
491    /// # Security
492    ///
493    /// Validates that model_id doesn't escape storage directory via `..` or absolute paths.
494    ///
495    /// # Errors
496    ///
497    /// Returns `BackendError::PathTraversal` if the path would escape the storage directory.
498    pub fn model_path(&self, model_id: &str) -> std::result::Result<PathBuf, BackendError> {
499        // model_id format: "repo/filename" or just "filename"
500        // Both cases join to storage_dir
501
502        // Security: Reject absolute paths
503        let model_path = Path::new(model_id);
504        if model_path.is_absolute() {
505            return Err(BackendError::PathTraversal {
506                path: model_id.to_string(),
507            });
508        }
509
510        // Security: Check for path traversal patterns
511        // We normalize the path to handle ".." components
512        let joined = self.storage_dir.join(model_id);
513        let normalized = normalize_path(&joined);
514        let normalized_base = normalize_path(&self.storage_dir);
515
516        if !normalized.starts_with(&normalized_base) {
517            return Err(BackendError::PathTraversal {
518                path: model_id.to_string(),
519            });
520        }
521
522        Ok(joined)
523    }
524}
525
526// ============================================================================
527// ModelStorage Implementation
528// ============================================================================
529
530impl ModelStorage for HuggingFaceStorage {
531    async fn download<F>(
532        &self,
533        request: &DownloadRequest<'_>,
534        progress: F,
535    ) -> Result<DownloadResult>
536    where
537        F: Fn(PullProgress) + Send + 'static,
538    {
539        HuggingFaceStorage::download(self, request, progress).await
540    }
541
542    fn list_models(&self) -> std::result::Result<Vec<ModelInfo>, BackendError> {
543        HuggingFaceStorage::list_models(self)
544    }
545
546    fn exists(&self, model_id: &str) -> bool {
547        HuggingFaceStorage::exists(self, model_id)
548    }
549
550    fn model_info(&self, model_id: &str) -> std::result::Result<ModelInfo, BackendError> {
551        HuggingFaceStorage::model_info(self, model_id)
552    }
553
554    fn delete(&self, model_id: &str) -> std::result::Result<(), BackendError> {
555        HuggingFaceStorage::delete(self, model_id)
556    }
557
558    fn model_path(&self, model_id: &str) -> std::result::Result<PathBuf, BackendError> {
559        HuggingFaceStorage::model_path(self, model_id)
560    }
561}
562
563// ============================================================================
564// Helpers
565// ============================================================================
566
/// Normalize a path by resolving `.` and `..` components without filesystem access.
///
/// This is used for path traversal validation before the path exists.
/// Adapted from `io/security.rs` for use in storage module.
///
/// Rules applied to each component, left to right:
/// - `.` is dropped.
/// - `..` removes the preceding normal component when there is one.
/// - `..` directly after a root or prefix is dropped, since there is nothing
///   above the root.
/// - `..` with nothing before it (or only other `..`) is **kept**. Dropping
///   it would turn `a/../../b` into `b` and hide the fact that the path
///   climbs out of its starting directory.
fn normalize_path(path: &Path) -> PathBuf {
    use std::path::Component;

    let mut normalized = PathBuf::new();

    for component in path.components() {
        match component {
            Component::CurDir => {
                // Skip current directory references
            }
            Component::ParentDir => {
                // Inspect the tail first; the borrow ends before we mutate.
                let (can_pop, at_root) = match normalized.components().next_back() {
                    Some(Component::Normal(_)) => (true, false),
                    Some(Component::RootDir | Component::Prefix(_)) => (false, true),
                    _ => (false, false),
                };

                if can_pop {
                    normalized.pop();
                } else if !at_root {
                    normalized.push("..");
                }
            }
            Component::Prefix(_) | Component::RootDir | Component::Normal(_) => {
                normalized.push(component);
            }
        }
    }

    normalized
}
590
/// Extract quantization level from GGUF filename.
///
/// The filename is compared case-insensitively (ASCII) against a fixed list
/// of known quantization tags, and the first tag found anywhere in the name
/// is returned in upper case. Returns `None` when no known tag is present.
///
/// # Examples
///
/// ```rust,ignore
/// assert_eq!(extract_quantization("model-Q4_K_M.gguf"), Some("Q4_K_M".to_string()));
/// assert_eq!(extract_quantization("model-q8_0.gguf"), Some("Q8_0".to_string()));
/// assert_eq!(extract_quantization("model-bf16.gguf"), Some("BF16".to_string()));
/// assert_eq!(extract_quantization("model.gguf"), None);
/// ```
#[must_use]
pub fn extract_quantization(filename: &str) -> Option<String> {
    // Common patterns: -Q4_K_M.gguf, -q4_k_m.gguf, -F16.gguf, -f16.gguf
    //
    // Matching is by substring and the first hit wins, so a tag that contains
    // another tag must come first: "BF16" has to precede "F16", otherwise
    // every BF16 file is reported as F16.
    const PATTERNS: [&str; 16] = [
        "Q4_K_M", "Q4_K_S", "Q5_K_M", "Q5_K_S", "Q6_K", "Q8_0", "Q2_K", "Q3_K_M", "Q3_K_S", "Q4_0",
        "Q4_1", "Q5_0", "Q5_1", "BF16", "F16", "F32",
    ];

    let filename_upper = filename.to_ascii_uppercase();
    PATTERNS
        .iter()
        .copied()
        .find(|&pattern| filename_upper.contains(pattern))
        .map(str::to_owned)
}
617
618// ============================================================================
619// Tests
620// ============================================================================
621
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::{tempdir, TempDir};

    /// Builds a storage rooted in a fresh temporary directory.
    ///
    /// The `TempDir` guard is handed back too, so the directory stays alive
    /// for as long as the test holds on to it.
    fn temp_storage() -> (TempDir, HuggingFaceStorage) {
        let dir = tempdir().unwrap();
        let storage = HuggingFaceStorage::new(dir.path().to_path_buf()).unwrap();
        (dir, storage)
    }

    #[test]
    fn test_extract_quantization() {
        // (filename, expected tag)
        let cases = [
            ("model-q4_k_m.gguf", Some("Q4_K_M")),
            ("model-Q8_0.gguf", Some("Q8_0")),
            ("model-f16.gguf", Some("F16")),
            ("Qwen3-8B-Q4_K_M.gguf", Some("Q4_K_M")),
            ("model.gguf", None),
        ];

        for (filename, expected) in cases {
            assert_eq!(
                extract_quantization(filename),
                expected.map(|tag| tag.to_string())
            );
        }
    }

    #[test]
    fn test_storage_new() {
        let (dir, storage) = temp_storage();
        assert_eq!(storage.storage_dir(), dir.path());
    }

    #[test]
    fn test_model_path() {
        let (_dir, storage) = temp_storage();

        for model_id in ["repo/model.gguf", "model.gguf"] {
            let path = storage.model_path(model_id).unwrap();
            assert!(path.ends_with(model_id));
        }
    }

    #[test]
    fn test_model_path_traversal_rejected() {
        let (_dir, storage) = temp_storage();

        // Both `..` traversal and absolute paths must be refused.
        for hostile in ["../../../etc/passwd", "/etc/passwd"] {
            assert!(matches!(
                storage.model_path(hostile),
                Err(BackendError::PathTraversal { .. })
            ));
        }

        // A valid nested path is still accepted.
        assert!(storage.model_path("Qwen/Qwen3-8B-Q4_K_M.gguf").is_ok());
    }

    #[test]
    fn test_list_models_empty() {
        let (_dir, storage) = temp_storage();
        assert!(storage.list_models().unwrap().is_empty());
    }

    #[test]
    fn test_exists_false() {
        let (_dir, storage) = temp_storage();
        assert!(!storage.exists("nonexistent/model.gguf"));
    }

    #[test]
    fn test_default_model_dir() {
        assert!(default_model_dir().ends_with("nika/models"));
    }

    #[test]
    fn test_detect_system_ram() {
        // Should return something reasonable (> 1GB on any modern system)
        assert!(detect_system_ram_gb() > 1.0);
    }

    #[test]
    fn test_storage_error_display() {
        let rendered = StorageError::ModelNotFound {
            repo: "test/repo".to_string(),
            filename: "model.gguf".to_string(),
        }
        .to_string();

        assert_eq!(rendered, "Model not found: test/repo/model.gguf");
    }
}
725            filename: "model.gguf".to_string(),
726        };
727        assert_eq!(err.to_string(), "Model not found: test/repo/model.gguf");
728    }
729}