1use crate::core::backend::{
32 BackendError, DownloadRequest, DownloadResult, ModelInfo, PullProgress,
33};
34use futures::StreamExt;
35use reqwest::Client;
36use serde::Deserialize;
37use sha2::{Digest, Sha256};
38use std::path::{Path, PathBuf};
39use thiserror::Error;
40use tokio::fs::{self, File};
41use tokio::io::AsyncWriteExt;
42
/// Errors produced while resolving, downloading, or verifying model files.
#[derive(Error, Debug)]
pub enum StorageError {
    /// The requested file could not be obtained from the given repository.
    ///
    /// Used both when a HuggingFace request answers with a non-success status
    /// and when the repository listing has no entry with a matching path.
    #[error("Model not found: {repo}/{filename}")]
    ModelNotFound {
        /// Repository identifier that was queried.
        repo: String,
        /// File name that was looked up inside `repo`.
        filename: String,
    },

    /// The SHA-256 digest of the downloaded bytes differs from the expected one.
    #[error("Checksum mismatch for {path}: expected {expected}, got {actual}")]
    ChecksumMismatch {
        /// Local path the download was destined for.
        path: PathBuf,
        /// Digest announced by the repository metadata.
        expected: String,
        /// Digest computed over the bytes that were actually received.
        actual: String,
    },

    /// The request or the environment is not usable, e.g. a download request
    /// that lacks required fields or an HTTP client that cannot be constructed.
    #[error("Invalid configuration: {0}")]
    InvalidConfig(String),

    /// A `reqwest` failure; converted automatically by `?` via `#[from]`.
    #[error("Network error: {0}")]
    Network(#[from] reqwest::Error),

    /// A filesystem failure; converted automatically by `?` via `#[from]`.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),
}
82
/// Result alias used by this module's download path; the error is always [`StorageError`].
///
/// The listing/lookup methods spell out `std::result::Result` because they
/// report `BackendError` instead.
pub type Result<T> = std::result::Result<T, StorageError>;
85
/// One entry of the JSON array returned by the HuggingFace repository `tree`
/// endpoint that `get_file_info` queries.
#[derive(Debug, Deserialize)]
struct HfFileInfo {
    /// Entry path inside the repository (JSON key `path`); matched against the
    /// requested file name.
    #[serde(rename = "path")]
    filename: String,
    /// Size reported by the API; used as the progress total and as the size
    /// returned to the caller after a download.
    size: u64,
    /// LFS metadata, when the API supplies it for this entry.
    lfs: Option<HfLfsInfo>,
}
101
/// LFS metadata attached to a repository file entry.
#[derive(Debug, Deserialize)]
struct HfLfsInfo {
    /// Object id (JSON key `oid`); compared against the hex SHA-256 digest of
    /// the downloaded bytes.
    #[serde(rename = "oid")]
    sha256: String,
}
109
110#[must_use]
121pub fn default_model_dir() -> PathBuf {
122 dirs::data_dir()
123 .unwrap_or_else(|| PathBuf::from("."))
124 .join("nika")
125 .join("models")
126}
127
/// Returns the value reported by `crate::util::system::get_total_ram_gb`.
///
/// NOTE(review): presumably the machine's total RAM in gigabytes, judging by
/// the helper's name — confirm the unit (GB vs GiB) against the helper.
#[must_use]
pub fn detect_system_ram_gb() -> f64 {
    crate::util::system::get_total_ram_gb()
}
140
/// Storage backend that can download model files and manage the local copies.
///
/// Download failures are reported as [`StorageError`]; the local management
/// methods report `BackendError`.
// The trait declares an `async fn`, which triggers the `async_fn_in_trait`
// lint on public traits; it is silenced deliberately here.
#[allow(async_fn_in_trait)]
pub trait ModelStorage {
    /// Downloads the file described by `request`, reporting progress through
    /// the `progress` callback.
    async fn download<F>(
        &self,
        request: &DownloadRequest<'_>,
        progress: F,
    ) -> Result<DownloadResult>
    where
        F: Fn(PullProgress) + Send + 'static;

    /// Lists the models held by this storage.
    fn list_models(&self) -> std::result::Result<Vec<ModelInfo>, BackendError>;

    /// Reports whether the model identified by `model_id` is present.
    fn exists(&self, model_id: &str) -> bool;

    /// Returns metadata for the model identified by `model_id`.
    fn model_info(&self, model_id: &str) -> std::result::Result<ModelInfo, BackendError>;

    /// Removes the model identified by `model_id`.
    fn delete(&self, model_id: &str) -> std::result::Result<(), BackendError>;

    /// Returns the path associated with `model_id`.
    fn model_path(&self, model_id: &str) -> std::result::Result<PathBuf, BackendError>;
}
182
/// Model storage that downloads files from HuggingFace into a local directory.
pub struct HuggingFaceStorage {
    /// Root directory; files are stored as `<storage_dir>/<repo>/<filename>`.
    storage_dir: PathBuf,
    /// HTTP client used for the metadata and download requests.
    client: Client,
}
197
198impl HuggingFaceStorage {
199 pub fn new(storage_dir: PathBuf) -> Result<Self> {
205 let user_agent = format!("nika/{}", env!("CARGO_PKG_VERSION"));
206 let client = Client::builder()
207 .user_agent(&user_agent)
208 .build()
209 .map_err(|e| {
210 StorageError::InvalidConfig(format!("Failed to create HTTP client: {e}"))
211 })?;
212
213 Ok(Self {
214 storage_dir,
215 client,
216 })
217 }
218
219 #[must_use]
221 pub fn with_client(storage_dir: PathBuf, client: Client) -> Self {
222 Self {
223 storage_dir,
224 client,
225 }
226 }
227
228 #[must_use]
230 pub fn storage_dir(&self) -> &Path {
231 &self.storage_dir
232 }
233
234 pub async fn download<F>(
249 &self,
250 request: &DownloadRequest<'_>,
251 progress: F,
252 ) -> Result<DownloadResult>
253 where
254 F: Fn(PullProgress) + Send + 'static,
255 {
256 let (repo, filename) = self.resolve_request(request)?;
258
259 let model_dir = self.storage_dir.join(&repo);
261 fs::create_dir_all(&model_dir).await?;
262
263 let file_path = model_dir.join(&filename);
264
265 if !request.force {
269 match fs::metadata(&file_path).await {
270 Ok(metadata) => {
271 progress(PullProgress::new("cached", 1, 1));
272 return Ok(DownloadResult {
273 path: file_path,
274 size: metadata.len(),
275 checksum: None,
276 cached: true,
277 });
278 }
279 Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
280 }
282 Err(e) => return Err(StorageError::Io(e)),
283 }
284 }
285
286 progress(PullProgress::new("fetching metadata", 0, 1));
288 let file_info = self.get_file_info(&repo, &filename).await?;
289
290 let download_url = format!("https://huggingface.co/{}/resolve/main/{}", repo, filename);
292
293 progress(PullProgress::new("downloading", 0, file_info.size));
294
295 let response = self.client.get(&download_url).send().await?;
296
297 if !response.status().is_success() {
298 return Err(StorageError::ModelNotFound {
299 repo: repo.clone(),
300 filename: filename.clone(),
301 });
302 }
303
304 let mut file = File::create(&file_path).await?;
306 let mut stream = response.bytes_stream();
307 let mut downloaded: u64 = 0;
308 let mut hasher = Sha256::new();
309
310 while let Some(chunk) = stream.next().await {
311 let chunk = chunk?;
312 hasher.update(&chunk);
313 file.write_all(&chunk).await?;
314 downloaded += chunk.len() as u64;
315
316 progress(PullProgress::new("downloading", downloaded, file_info.size));
317 }
318
319 file.flush().await?;
320 drop(file);
321
322 let checksum = format!("{:x}", hasher.finalize());
324 if let Some(ref lfs) = file_info.lfs {
325 if checksum != lfs.sha256 {
326 let _ = fs::remove_file(&file_path).await;
328 return Err(StorageError::ChecksumMismatch {
329 path: file_path,
330 expected: lfs.sha256.clone(),
331 actual: checksum,
332 });
333 }
334 }
335
336 progress(PullProgress::new(
337 "complete",
338 file_info.size,
339 file_info.size,
340 ));
341
342 Ok(DownloadResult {
343 path: file_path,
344 size: file_info.size,
345 checksum: Some(checksum),
346 cached: false,
347 })
348 }
349
350 fn resolve_request(&self, request: &DownloadRequest<'_>) -> Result<(String, String)> {
352 if let Some(hf_repo) = &request.hf_repo {
353 let filename = request.filename.clone().ok_or_else(|| {
354 StorageError::InvalidConfig("HuggingFace download requires filename".into())
355 })?;
356 return Ok((hf_repo.clone(), filename));
357 }
358
359 if let Some(model) = request.model {
360 let filename = request.target_filename().ok_or_else(|| {
361 StorageError::InvalidConfig("No quantization available for model".into())
362 })?;
363 return Ok((model.hf_repo.to_string(), filename));
364 }
365
366 Err(StorageError::InvalidConfig(
367 "Download request must specify model or HuggingFace repo".into(),
368 ))
369 }
370
371 async fn get_file_info(&self, repo: &str, filename: &str) -> Result<HfFileInfo> {
373 let api_url = format!("https://huggingface.co/api/models/{}/tree/main", repo);
374
375 let response = self.client.get(&api_url).send().await?;
376
377 if !response.status().is_success() {
378 return Err(StorageError::ModelNotFound {
379 repo: repo.to_string(),
380 filename: filename.to_string(),
381 });
382 }
383
384 let files: Vec<HfFileInfo> = response.json().await?;
385
386 files
387 .into_iter()
388 .find(|f| f.filename == filename)
389 .ok_or_else(|| StorageError::ModelNotFound {
390 repo: repo.to_string(),
391 filename: filename.to_string(),
392 })
393 }
394
395 pub fn list_models(&self) -> std::result::Result<Vec<ModelInfo>, BackendError> {
397 let mut models = Vec::new();
398
399 let entries = match std::fs::read_dir(&self.storage_dir) {
402 Ok(entries) => entries,
403 Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
404 return Ok(models);
405 }
406 Err(e) => return Err(BackendError::StorageError(e.to_string())),
407 };
408
409 for entry in entries.flatten() {
410 let path = entry.path();
411 if path.is_dir() {
412 let repo_name = entry.file_name().to_string_lossy().to_string();
414
415 if let Ok(files) = std::fs::read_dir(&path) {
417 for file in files.flatten() {
418 let filename = file.file_name().to_string_lossy().to_string();
419 if filename.ends_with(".gguf") {
420 if let Ok(metadata) = file.metadata() {
421 let quant = extract_quantization(&filename);
422 models.push(ModelInfo {
423 name: format!("{}/{}", repo_name, filename),
424 size: metadata.len(),
425 quantization: quant,
426 parameters: None,
427 digest: None,
428 });
429 }
430 }
431 }
432 }
433 }
434 }
435
436 Ok(models)
437 }
438
439 #[must_use]
443 pub fn exists(&self, model_id: &str) -> bool {
444 self.model_path(model_id)
445 .map(|p| p.exists())
446 .unwrap_or(false)
447 }
448
449 pub fn model_info(&self, model_id: &str) -> std::result::Result<ModelInfo, BackendError> {
451 let path = self.model_path(model_id)?;
452
453 let metadata = match std::fs::metadata(&path) {
456 Ok(metadata) => metadata,
457 Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
458 return Err(BackendError::ModelNotFound(model_id.to_string()));
459 }
460 Err(e) => return Err(BackendError::StorageError(e.to_string())),
461 };
462
463 let filename = path.file_name().unwrap_or_default().to_string_lossy();
464
465 Ok(ModelInfo {
466 name: model_id.to_string(),
467 size: metadata.len(),
468 quantization: extract_quantization(&filename),
469 parameters: None,
470 digest: None,
471 })
472 }
473
474 pub fn delete(&self, model_id: &str) -> std::result::Result<(), BackendError> {
476 let path = self.model_path(model_id)?;
477
478 match std::fs::remove_file(&path) {
481 Ok(()) => Ok(()),
482 Err(e) if e.kind() == std::io::ErrorKind::NotFound => {
483 Err(BackendError::ModelNotFound(model_id.to_string()))
484 }
485 Err(e) => Err(BackendError::StorageError(e.to_string())),
486 }
487 }
488
489 pub fn model_path(&self, model_id: &str) -> std::result::Result<PathBuf, BackendError> {
499 let model_path = Path::new(model_id);
504 if model_path.is_absolute() {
505 return Err(BackendError::PathTraversal {
506 path: model_id.to_string(),
507 });
508 }
509
510 let joined = self.storage_dir.join(model_id);
513 let normalized = normalize_path(&joined);
514 let normalized_base = normalize_path(&self.storage_dir);
515
516 if !normalized.starts_with(&normalized_base) {
517 return Err(BackendError::PathTraversal {
518 path: model_id.to_string(),
519 });
520 }
521
522 Ok(joined)
523 }
524}
525
/// Exposes the inherent `HuggingFaceStorage` methods through [`ModelStorage`].
///
/// Every method forwards to the inherent method of the same name. The calls
/// are written as `HuggingFaceStorage::name(self, ..)`; inherent associated
/// functions take precedence over trait methods in path resolution, so these
/// reach the inherent implementations rather than recursing into this impl.
impl ModelStorage for HuggingFaceStorage {
    /// Forwards to the inherent `download`.
    async fn download<F>(
        &self,
        request: &DownloadRequest<'_>,
        progress: F,
    ) -> Result<DownloadResult>
    where
        F: Fn(PullProgress) + Send + 'static,
    {
        HuggingFaceStorage::download(self, request, progress).await
    }

    /// Forwards to the inherent `list_models`.
    fn list_models(&self) -> std::result::Result<Vec<ModelInfo>, BackendError> {
        HuggingFaceStorage::list_models(self)
    }

    /// Forwards to the inherent `exists`.
    fn exists(&self, model_id: &str) -> bool {
        HuggingFaceStorage::exists(self, model_id)
    }

    /// Forwards to the inherent `model_info`.
    fn model_info(&self, model_id: &str) -> std::result::Result<ModelInfo, BackendError> {
        HuggingFaceStorage::model_info(self, model_id)
    }

    /// Forwards to the inherent `delete`.
    fn delete(&self, model_id: &str) -> std::result::Result<(), BackendError> {
        HuggingFaceStorage::delete(self, model_id)
    }

    /// Forwards to the inherent `model_path`.
    fn model_path(&self, model_id: &str) -> std::result::Result<PathBuf, BackendError> {
        HuggingFaceStorage::model_path(self, model_id)
    }
}
562
/// Lexically normalizes `path` without touching the filesystem.
///
/// * `.` components are dropped.
/// * `..` removes the preceding normal component when there is one.
/// * `..` directly after the root (or a prefix) is dropped, because the
///   parent of the root is the root itself.
/// * `..` with nothing left to remove is **kept**, so a relative path that
///   climbs above its starting point stays distinguishable:
///   `models/../../models/x` becomes `../models/x`, not `models/x`.
///   Discarding those components would let such a path pass a
///   `starts_with` containment check against a relative base directory.
///
/// Symlinks are not resolved.
fn normalize_path(path: &Path) -> PathBuf {
    use std::path::Component;

    let mut normalized = PathBuf::new();

    for component in path.components() {
        match component {
            Component::CurDir => {}
            Component::ParentDir => match normalized.components().next_back() {
                // Something to step out of.
                Some(Component::Normal(_)) => {
                    normalized.pop();
                }
                // Already at the top of an absolute path; `..` is a no-op.
                Some(Component::RootDir) | Some(Component::Prefix(_)) => {}
                // Nothing left to cancel: the path escapes upwards, record it.
                Some(Component::ParentDir) | Some(Component::CurDir) | None => {
                    normalized.push("..");
                }
            },
            other => normalized.push(other),
        }
    }

    normalized
}
590
/// Extracts the quantization tag embedded in a model file name.
///
/// The name is compared case-insensitively against a fixed list of known
/// tags; the first tag found as a substring is returned in its canonical
/// upper-case spelling. Returns `None` when no known tag occurs.
///
/// `BF16` is checked before `F16`: every name containing `BF16` also contains
/// `F16`, so the opposite order would report bfloat16 files as `F16`.
#[must_use]
pub fn extract_quantization(filename: &str) -> Option<String> {
    const PATTERNS: &[&str] = &[
        "Q4_K_M", "Q4_K_S", "Q5_K_M", "Q5_K_S", "Q6_K", "Q8_0", "Q2_K", "Q3_K_M", "Q3_K_S", "Q4_0",
        "Q4_1", "Q5_0", "Q5_1", "BF16", "F16", "F32",
    ];

    // All tags are ASCII, so ASCII case folding is sufficient.
    let upper = filename.to_ascii_uppercase();

    PATTERNS
        .iter()
        .copied()
        .find(|tag| upper.contains(*tag))
        .map(String::from)
}
617
// Unit tests for the parts of this module that work without network access.
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    // Tags are recognized regardless of case; names without a tag give `None`.
    #[test]
    fn test_extract_quantization() {
        assert_eq!(
            extract_quantization("model-q4_k_m.gguf"),
            Some("Q4_K_M".to_string())
        );
        assert_eq!(
            extract_quantization("model-Q8_0.gguf"),
            Some("Q8_0".to_string())
        );
        assert_eq!(
            extract_quantization("model-f16.gguf"),
            Some("F16".to_string())
        );
        assert_eq!(
            extract_quantization("Qwen3-8B-Q4_K_M.gguf"),
            Some("Q4_K_M".to_string())
        );
        assert_eq!(extract_quantization("model.gguf"), None);
    }

    // The constructor keeps the directory it was given.
    #[test]
    fn test_storage_new() {
        let dir = tempdir().unwrap();
        let storage = HuggingFaceStorage::new(dir.path().to_path_buf()).unwrap();
        assert_eq!(storage.storage_dir(), dir.path());
    }

    // Relative identifiers, with or without a repository part, are accepted.
    #[test]
    fn test_model_path() {
        let dir = tempdir().unwrap();
        let storage = HuggingFaceStorage::new(dir.path().to_path_buf()).unwrap();

        let path = storage.model_path("repo/model.gguf").unwrap();
        assert!(path.ends_with("repo/model.gguf"));

        let path = storage.model_path("model.gguf").unwrap();
        assert!(path.ends_with("model.gguf"));
    }

    // Identifiers that leave the storage directory are refused.
    #[test]
    fn test_model_path_traversal_rejected() {
        let dir = tempdir().unwrap();
        let storage = HuggingFaceStorage::new(dir.path().to_path_buf()).unwrap();

        // Relative escape through `..`.
        let result = storage.model_path("../../../etc/passwd");
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            BackendError::PathTraversal { .. }
        ));

        // Absolute path.
        let result = storage.model_path("/etc/passwd");
        assert!(result.is_err());
        assert!(matches!(
            result.unwrap_err(),
            BackendError::PathTraversal { .. }
        ));

        // A regular identifier is still accepted.
        let result = storage.model_path("Qwen/Qwen3-8B-Q4_K_M.gguf");
        assert!(result.is_ok());
    }

    // An empty storage directory lists no models.
    #[test]
    fn test_list_models_empty() {
        let dir = tempdir().unwrap();
        let storage = HuggingFaceStorage::new(dir.path().to_path_buf()).unwrap();
        let models = storage.list_models().unwrap();
        assert!(models.is_empty());
    }

    // A model that was never stored does not exist.
    #[test]
    fn test_exists_false() {
        let dir = tempdir().unwrap();
        let storage = HuggingFaceStorage::new(dir.path().to_path_buf()).unwrap();
        assert!(!storage.exists("nonexistent/model.gguf"));
    }

    // The default directory always ends in `nika/models`.
    #[test]
    fn test_default_model_dir() {
        let dir = default_model_dir();
        assert!(dir.ends_with("nika/models"));
    }

    // NOTE(review): depends on the machine running the test reporting more
    // than 1.0 from the RAM helper — confirm this holds on constrained CI.
    #[test]
    fn test_detect_system_ram() {
        let ram = detect_system_ram_gb();
        assert!(ram > 1.0);
    }

    // The `Display` output joins repository and file name with a slash.
    #[test]
    fn test_storage_error_display() {
        let err = StorageError::ModelNotFound {
            repo: "test/repo".to_string(),
            filename: "model.gguf".to_string(),
        };
        assert_eq!(err.to_string(), "Model not found: test/repo/model.gguf");
    }
}