1use std::path::{Path, PathBuf};
7
8use crate::error::PiperError;
9
10#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
12pub struct ModelInfo {
13 pub name: String,
14 pub language: String,
15 pub quality: String,
17 pub description: String,
18 pub model_url: String,
19 pub config_url: String,
20 pub size_bytes: Option<u64>,
21}
22
/// Boxed progress observer receiving a [`DownloadProgress`] snapshot per update.
///
/// The `Send` bound lets the callback be moved to another thread.
pub type ProgressCallback = Box<dyn Fn(DownloadProgress) + Send>;

/// Snapshot of a download that is in progress.
#[derive(Debug, Clone)]
pub struct DownloadProgress {
    /// Number of body bytes received so far.
    pub bytes_downloaded: u64,
    /// Total body size advertised by the server, if it sent one.
    pub total_bytes: Option<u64>,
    /// Completion in percent (`0.0..=100.0`), or `None` when the total is
    /// unknown. A zero-byte total is reported as `100.0`.
    pub percentage: Option<f64>,
}
33
/// Returns the directory where models are stored by default.
///
/// Resolution order (first hit wins):
/// 1. `<platform data dir>/piper-plus/models` (see [`platform_data_dir`]);
/// 2. `$HOME/.piper-plus/models`;
/// 3. `%USERPROFILE%/.piper-plus/models`;
/// 4. `.piper-plus/models`, relative to the current working directory.
///
/// Environment variables that are set but empty are treated as unset. That way
/// an empty `HOME`, `XDG_DATA_HOME` or `APPDATA` cannot silently redirect models
/// into a directory relative to wherever the process was started.
pub fn default_model_dir() -> PathBuf {
    if let Some(dir) = platform_data_dir() {
        return dir.join("piper-plus").join("models");
    }

    env_path("HOME")
        .or_else(|| env_path("USERPROFILE"))
        .map(|home| home.join(".piper-plus").join("models"))
        .unwrap_or_else(|| PathBuf::from(".piper-plus").join("models"))
}

/// Reads environment variable `key` as a path.
///
/// Returns `None` when the variable is unset or empty. Non-UTF-8 values are
/// accepted because they are valid paths on some platforms.
fn env_path(key: &str) -> Option<PathBuf> {
    std::env::var_os(key)
        .filter(|value| !value.is_empty())
        .map(PathBuf::from)
}

/// Returns the per-user application data directory of the current platform.
///
/// * Linux: `$XDG_DATA_HOME` if it is an absolute path, otherwise
///   `$HOME/.local/share`.
/// * macOS: `$HOME/Library/Application Support`.
/// * Windows: `%APPDATA%`.
/// * Any other target: `None`.
///
/// `None` is also returned when the required variables are unset or empty.
fn platform_data_dir() -> Option<PathBuf> {
    #[cfg(target_os = "linux")]
    {
        // XDG Base Directory spec: an empty or relative XDG_DATA_HOME is
        // invalid and must be ignored in favour of $HOME/.local/share.
        env_path("XDG_DATA_HOME")
            .filter(|dir| dir.is_absolute())
            .or_else(|| env_path("HOME").map(|home| home.join(".local").join("share")))
    }

    #[cfg(target_os = "macos")]
    {
        env_path("HOME").map(|home| home.join("Library").join("Application Support"))
    }

    #[cfg(target_os = "windows")]
    {
        env_path("APPDATA")
    }

    #[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
    {
        None
    }
}
90
/// Downloads `url` and stores the response body at `dest`.
///
/// The body is streamed into a sibling `<file name>.part` file. That file is
/// renamed onto `dest` only after the whole body has been received and flushed.
/// A failed or interrupted download therefore never leaves a truncated file at
/// `dest`. This matters because `is_model_cached` treats any existing
/// `<name>.onnx` as a complete model. On failure the `.part` file is removed
/// (best effort), and a previously existing `dest` is left untouched.
///
/// Parent directories of `dest` are created as needed.
///
/// `progress` is invoked at these points:
/// * roughly every 100 KiB;
/// * when the advertised length is reached;
/// * once more at the end if the final byte count had not been reported yet.
///
/// # Errors
///
/// * `PiperError::Download` when any of these happen:
///   * the HTTP client cannot be built;
///   * the request cannot be sent;
///   * the received body length differs from the server's `Content-Length`.
/// * `PiperError::ModelLoad` for a non-success HTTP status, a body read error,
///   or any local filesystem failure.
#[cfg(feature = "download")]
pub fn download_file(
    url: &str,
    dest: &Path,
    progress: Option<ProgressCallback>,
) -> Result<(), PiperError> {
    if let Some(parent) = dest.parent() {
        std::fs::create_dir_all(parent).map_err(|e| {
            PiperError::ModelLoad(format!(
                "failed to create directory {}: {e}",
                parent.display()
            ))
        })?;
    }

    let mut part_name = dest.file_name().unwrap_or_default().to_os_string();
    part_name.push(".part");
    let part_path = dest.with_file_name(part_name);

    let outcome = fetch_into_file(url, &part_path, progress.as_deref()).and_then(|()| {
        std::fs::rename(&part_path, dest).map_err(|e| {
            PiperError::ModelLoad(format!(
                "failed to move {} to {}: {e}",
                part_path.display(),
                dest.display()
            ))
        })
    });

    if outcome.is_err() {
        // Best-effort cleanup: the original error is more useful to the caller
        // than a secondary failure to delete the partial file.
        let _ = std::fs::remove_file(&part_path);
    }
    outcome
}

/// Performs the HTTP GET for `url` and streams the body into `path`.
///
/// `path` is created or truncated. Progress is reported through `progress`.
/// Errors follow the contract documented on `download_file`.
#[cfg(feature = "download")]
fn fetch_into_file(
    url: &str,
    path: &Path,
    progress: Option<&(dyn Fn(DownloadProgress) + Send)>,
) -> Result<(), PiperError> {
    use std::io::{BufWriter, Read as _, Write as _};

    // Emit at most one progress update per this many received bytes.
    const PROGRESS_INTERVAL: u64 = 100 * 1024;

    let client = reqwest::blocking::Client::builder()
        .connect_timeout(std::time::Duration::from_secs(30))
        .timeout(std::time::Duration::from_secs(600))
        .build()
        .map_err(|e| PiperError::Download(format!("HTTP client error: {e}")))?;

    let mut response = client
        .get(url)
        .send()
        .map_err(|e| PiperError::Download(format!("download failed: {e}")))?;

    if !response.status().is_success() {
        return Err(PiperError::ModelLoad(format!(
            "HTTP {} for {url}",
            response.status()
        )));
    }

    let total_bytes = response.content_length();
    let report = |bytes_downloaded: u64| {
        if let Some(cb) = progress {
            let percentage = total_bytes.map(|t| {
                // A zero-length body is complete as soon as it starts.
                if t == 0 {
                    100.0
                } else {
                    (bytes_downloaded as f64 / t as f64) * 100.0
                }
            });
            cb(DownloadProgress {
                bytes_downloaded,
                total_bytes,
                percentage,
            });
        }
    };

    let file = std::fs::File::create(path).map_err(|e| {
        PiperError::ModelLoad(format!("failed to create file {}: {e}", path.display()))
    })?;
    let mut writer = BufWriter::with_capacity(256 * 1024, file);
    let mut buf = [0u8; 64 * 1024];
    let mut bytes_downloaded: u64 = 0;
    let mut last_reported: u64 = 0;
    let mut next_report = PROGRESS_INTERVAL;

    loop {
        let n = response.read(&mut buf).map_err(|e| {
            PiperError::ModelLoad(format!("failed to read response body from {url}: {e}"))
        })?;
        if n == 0 {
            break;
        }
        writer.write_all(&buf[..n]).map_err(|e| {
            PiperError::ModelLoad(format!("failed to write to {}: {e}", path.display()))
        })?;
        bytes_downloaded += n as u64;

        if bytes_downloaded >= next_report || total_bytes == Some(bytes_downloaded) {
            report(bytes_downloaded);
            last_reported = bytes_downloaded;
            next_report = bytes_downloaded + PROGRESS_INTERVAL;
        }
    }

    writer
        .flush()
        .map_err(|e| PiperError::ModelLoad(format!("failed to flush {}: {e}", path.display())))?;

    // `content_length()` is only reported when the body is not transparently
    // decompressed. A mismatch therefore means the transfer was cut short
    // (or overran), and the file must not be treated as complete.
    if let Some(expected) = total_bytes.filter(|&t| t != bytes_downloaded) {
        return Err(PiperError::Download(format!(
            "incomplete download from {url}: received {bytes_downloaded} of {expected} bytes"
        )));
    }

    // Without a known total the loop can end between reporting intervals;
    // make sure the observer sees the final byte count.
    if last_reported != bytes_downloaded {
        report(bytes_downloaded);
    }

    Ok(())
}
181
182#[cfg(not(feature = "download"))]
186pub fn download_file(
187 _url: &str,
188 _dest: &Path,
189 _progress: Option<ProgressCallback>,
190) -> Result<(), PiperError> {
191 Err(PiperError::ModelLoad(
192 "the \"download\" feature is required for download_file; \
193 rebuild with `--features download`"
194 .to_string(),
195 ))
196}
197
/// Downloads the model and config described by `model_info` into `dest_dir`.
///
/// The files are stored as `<name>.onnx` and `<name>.onnx.json`. This is the
/// layout `is_model_cached` and `resolve_model_path` look for, so a model
/// fetched once is recognised as cached on later runs.
///
/// Files are not named after their URLs. Doing so would save
/// `tsukuyomi-chan-6lang-fp16.onnx`, which is never found under the registry
/// name. It would also make every registry entry overwrite the same
/// `config.json` in a shared directory.
///
/// The small config is fetched first so a broken registry entry fails before
/// the large model transfer starts. `progress` only tracks the model file.
///
/// Returns `(model_path, config_path)`.
///
/// # Errors
///
/// `PiperError::ModelLoad` if `dest_dir` cannot be created. Otherwise, any
/// error returned by `download_file` for either file.
#[cfg(feature = "download")]
pub fn download_model(
    model_info: &ModelInfo,
    dest_dir: &Path,
    progress: Option<ProgressCallback>,
) -> Result<(PathBuf, PathBuf), PiperError> {
    std::fs::create_dir_all(dest_dir).map_err(|e| {
        PiperError::ModelLoad(format!(
            "failed to create model directory {}: {e}",
            dest_dir.display()
        ))
    })?;

    let model_path = dest_dir.join(format!("{}.onnx", model_info.name));
    let config_path = dest_dir.join(format!("{}.onnx.json", model_info.name));

    download_file(&model_info.config_url, &config_path, None)?;
    download_file(&model_info.model_url, &model_path, progress)?;

    Ok((model_path, config_path))
}
230
231#[cfg(not(feature = "download"))]
233pub fn download_model(
234 _model_info: &ModelInfo,
235 _dest_dir: &Path,
236 _progress: Option<ProgressCallback>,
237) -> Result<(PathBuf, PathBuf), PiperError> {
238 Err(PiperError::ModelLoad(
239 "the \"download\" feature is required for download_model; \
240 rebuild with `--features download`"
241 .to_string(),
242 ))
243}
244
/// Builds the direct-download ("resolve") URL for `filename` on the `main`
/// branch of the Hugging Face repository `repo` (e.g. `"owner/name"`).
///
/// Both arguments are inserted verbatim; no percent-encoding is applied, and
/// `filename` may contain `/` to address a file in a subdirectory.
pub fn huggingface_url(repo: &str, filename: &str) -> String {
    ["https://huggingface.co", repo, "resolve", "main", filename].join("/")
}
259
260pub fn parse_model_registry(json_str: &str) -> Result<Vec<ModelInfo>, PiperError> {
264 let models: Vec<ModelInfo> = serde_json::from_str(json_str)?;
265 Ok(models)
266}
267
/// Reports whether `model_name` appears to be fully present in `model_dir`.
///
/// Requires `<model_name>.onnx` plus a config file. The config may be either
/// the per-model `<model_name>.onnx.json` or a shared `config.json` in the same
/// directory. Only existence is checked; contents are not validated.
pub fn is_model_cached(model_name: &str, model_dir: &Path) -> bool {
    let present = |file_name: String| model_dir.join(file_name).exists();

    if !present(format!("{model_name}.onnx")) {
        return false;
    }
    present(format!("{model_name}.onnx.json")) || present("config.json".to_owned())
}
279
280pub fn builtin_registry() -> &'static [ModelInfo] {
285 use std::sync::OnceLock;
286 static REGISTRY: OnceLock<Vec<ModelInfo>> = OnceLock::new();
287 REGISTRY.get_or_init(|| {
288 vec![
289 ModelInfo {
290 name: "tsukuyomi-6lang-v2".to_string(),
291 language: "ja-en-zh-es-fr-pt".to_string(),
292 quality: "medium".to_string(),
293 description: "Tsukuyomi-chan 6-language model (JA/EN/ZH/ES/FR/PT)".to_string(),
294 model_url: huggingface_url(
295 "ayousanz/piper-plus-tsukuyomi-chan",
296 "tsukuyomi-chan-6lang-fp16.onnx",
297 ),
298 config_url: huggingface_url("ayousanz/piper-plus-tsukuyomi-chan", "config.json"),
299 size_bytes: None,
300 },
301 ModelInfo {
302 name: "css10-6lang".to_string(),
303 language: "ja-en-zh-es-fr-pt".to_string(),
304 quality: "medium".to_string(),
305 description:
306 "CSS10 Japanese 6-language model fine-tuned from multilingual base (FP16)"
307 .to_string(),
308 model_url: huggingface_url(
309 "ayousanz/piper-plus-css10-ja-6lang",
310 "css10-ja-6lang-fp16.onnx",
311 ),
312 config_url: huggingface_url("ayousanz/piper-plus-css10-ja-6lang", "config.json"),
313 size_bytes: Some(39_414_515),
314 },
315 ]
316 })
317}
318
319pub fn find_model(query: &str) -> Option<&'static ModelInfo> {
324 let registry = builtin_registry();
325
326 if let Some(m) = registry.iter().find(|m| m.name == query) {
328 return Some(m);
329 }
330
331 let matches: Vec<_> = registry.iter().filter(|m| m.name.contains(query)).collect();
333 if matches.len() == 1 {
334 return Some(matches[0]);
335 }
336
337 let query_lower = query.to_lowercase();
339 let desc_matches: Vec<_> = registry
340 .iter()
341 .filter(|m| m.description.to_lowercase().contains(&query_lower))
342 .collect();
343 if desc_matches.len() == 1 {
344 return Some(desc_matches[0]);
345 }
346
347 None
348}
349
350pub fn resolve_model_path(
357 model_str: &str,
358 model_dir: Option<&Path>,
359) -> Result<PathBuf, PiperError> {
360 let path = PathBuf::from(model_str);
361
362 if path.is_file() {
364 return Ok(path);
365 } else if path.is_dir() {
366 return Err(PiperError::ModelLoad(format!(
367 "Path '{}' is a directory. Please provide a model file path or a model name.",
368 path.display()
369 )));
370 }
371
372 let model_info = find_model(model_str).ok_or_else(|| {
374 PiperError::ModelLoad(format!(
375 "Model '{}' not found. Use --list-models to see available models, or specify a file path.",
376 model_str
377 ))
378 })?;
379
380 let dir = model_dir
381 .map(PathBuf::from)
382 .unwrap_or_else(default_model_dir);
383
384 if is_model_cached(&model_info.name, &dir) {
386 let model_path = dir.join(format!("{}.onnx", model_info.name));
387 return Ok(model_path);
388 }
389
390 #[cfg(feature = "download")]
392 {
393 eprintln!(
394 "Model '{}' not found locally. Downloading...",
395 model_info.name
396 );
397 let (model_path, _config_path) = download_model(
398 model_info,
399 &dir,
400 Some(Box::new(|progress| {
401 if let Some(pct) = progress.percentage {
402 eprint!("\r Downloading... {:.1}%", pct);
403 }
404 })),
405 )?;
406 eprintln!();
407 eprintln!("Model downloaded to: {}", model_path.display());
408 Ok(model_path)
409 }
410
411 #[cfg(not(feature = "download"))]
412 {
413 Err(PiperError::ModelLoad(format!(
414 "Model '{}' not cached. Download it with: --download-model {}",
415 model_str, model_info.name
416 )))
417 }
418}
419
/// Returns the last path segment of `url`, ignoring any `?query` or `#fragment`.
///
/// Everything from the first `?` or `#` onward is discarded. `None` is returned
/// when the remaining last segment is empty (e.g. the URL ends in `/`). No
/// percent-decoding is performed.
#[cfg(any(feature = "download", test))]
// Unit tests call this directly; the `allow` keeps non-test builds warning-free
// should no other caller use it.
#[cfg_attr(not(test), allow(dead_code))]
fn url_filename(url: &str) -> Option<String> {
    let end = url.find(|c: char| c == '?' || c == '#').unwrap_or(url.len());
    let last_segment = url[..end].rsplit('/').next()?;
    (!last_segment.is_empty()).then(|| last_segment.to_owned())
}
432
// Unit tests for the registry/cache/download helpers in this module.
// Filesystem tests use `tempfile` temporary directories. No test performs a
// real download: download entry points are only exercised through their
// feature-disabled stubs, and the resolve_model_path tests hit the cache or
// fail lookup first.
#[cfg(test)]
mod tests {
    use super::*;

    // ---- huggingface_url ----

    #[test]
    fn test_huggingface_url_basic() {
        let url = huggingface_url("owner/repo", "model.onnx");
        assert_eq!(
            url,
            "https://huggingface.co/owner/repo/resolve/main/model.onnx"
        );
    }

    // A filename containing `/` addresses a file in a repo subdirectory.
    #[test]
    fn test_huggingface_url_with_subdirectory_filename() {
        let url = huggingface_url("ayousanz/piper-plus-tsukuyomi-chan", "models/v2.onnx");
        assert_eq!(
            url,
            "https://huggingface.co/ayousanz/piper-plus-tsukuyomi-chan/resolve/main/models/v2.onnx"
        );
    }

    // ---- parse_model_registry ----

    #[test]
    fn test_parse_model_registry_valid() {
        let json = r#"[
 {
 "name": "test-model",
 "language": "ja",
 "quality": "medium",
 "description": "A test model",
 "model_url": "https://example.com/model.onnx",
 "config_url": "https://example.com/config.json",
 "size_bytes": 1024
 }
 ]"#;
        let models = parse_model_registry(json).unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "test-model");
        assert_eq!(models[0].size_bytes, Some(1024));
    }

    #[test]
    fn test_parse_model_registry_empty_array() {
        let models = parse_model_registry("[]").unwrap();
        assert!(models.is_empty());
    }

    #[test]
    fn test_parse_model_registry_invalid_json() {
        let result = parse_model_registry("not valid json");
        assert!(result.is_err());
    }

    // Only `name` is supplied; deserialisation must fail on the missing fields.
    #[test]
    fn test_parse_model_registry_missing_required_fields() {
        let json = r#"[{"name": "incomplete"}]"#;
        let result = parse_model_registry(json);
        assert!(result.is_err());
    }

    // ---- is_model_cached ----

    // Cached only once both `<name>.onnx` and `<name>.onnx.json` exist.
    #[test]
    fn test_is_model_cached_with_onnx_json() {
        let dir = tempfile::tempdir().unwrap();
        let dir_path = dir.path();

        assert!(!is_model_cached("voice", dir_path));

        std::fs::write(dir_path.join("voice.onnx"), b"fake").unwrap();
        assert!(!is_model_cached("voice", dir_path));

        std::fs::write(dir_path.join("voice.onnx.json"), b"{}").unwrap();
        assert!(is_model_cached("voice", dir_path));
    }

    // A shared `config.json` also satisfies the config requirement.
    #[test]
    fn test_is_model_cached_with_config_json() {
        let dir = tempfile::tempdir().unwrap();
        let dir_path = dir.path();

        std::fs::write(dir_path.join("voice.onnx"), b"fake").unwrap();
        std::fs::write(dir_path.join("config.json"), b"{}").unwrap();
        assert!(is_model_cached("voice", dir_path));
    }

    // A config alone is not enough without the `.onnx` file.
    #[test]
    fn test_is_model_cached_missing_onnx() {
        let dir = tempfile::tempdir().unwrap();
        let dir_path = dir.path();

        std::fs::write(dir_path.join("config.json"), b"{}").unwrap();
        assert!(!is_model_cached("voice", dir_path));
    }

    // ---- default_model_dir ----

    // The concrete location depends on the environment; only the tail is checked.
    #[test]
    fn test_default_model_dir_is_non_empty() {
        let dir = default_model_dir();
        assert!(
            !dir.as_os_str().is_empty(),
            "default_model_dir must not be empty"
        );
        assert_eq!(
            dir.file_name().and_then(|s| s.to_str()),
            Some("models"),
            "expected path to end with 'models', got: {dir:?}"
        );
    }

    // ---- ModelInfo serde ----

    #[test]
    fn test_model_info_roundtrip() {
        let info = ModelInfo {
            name: "roundtrip-test".to_string(),
            language: "en".to_string(),
            quality: "high".to_string(),
            description: "Roundtrip test model".to_string(),
            model_url: "https://example.com/m.onnx".to_string(),
            config_url: "https://example.com/c.json".to_string(),
            size_bytes: Some(42),
        };

        let json = serde_json::to_string(&info).unwrap();
        let deserialized: ModelInfo = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.name, info.name);
        assert_eq!(deserialized.language, info.language);
        assert_eq!(deserialized.quality, info.quality);
        assert_eq!(deserialized.description, info.description);
        assert_eq!(deserialized.model_url, info.model_url);
        assert_eq!(deserialized.config_url, info.config_url);
        assert_eq!(deserialized.size_bytes, info.size_bytes);
    }

    // An explicit JSON `null` maps to `size_bytes: None`.
    #[test]
    fn test_model_info_size_bytes_optional() {
        let json = r#"{
 "name": "n",
 "language": "ja",
 "quality": "low",
 "description": "d",
 "model_url": "https://example.com/m.onnx",
 "config_url": "https://example.com/c.json",
 "size_bytes": null
 }"#;
        let info: ModelInfo = serde_json::from_str(json).unwrap();
        assert!(info.size_bytes.is_none());
    }

    // ---- builtin_registry ----

    #[test]
    fn test_builtin_registry_non_empty() {
        let models = builtin_registry();
        assert!(
            models.len() >= 2,
            "builtin registry should contain at least 2 models"
        );
        for m in models {
            assert!(
                m.model_url.starts_with("https://"),
                "bad model_url: {}",
                m.model_url
            );
            assert!(
                m.config_url.starts_with("https://"),
                "bad config_url: {}",
                m.config_url
            );
            assert!(!m.name.is_empty());
        }
    }

    // ---- DownloadProgress ----

    #[test]
    fn test_download_progress_percentage() {
        let progress = DownloadProgress {
            bytes_downloaded: 50,
            total_bytes: Some(200),
            percentage: Some(25.0),
        };
        assert_eq!(progress.percentage, Some(25.0));
        assert_eq!(progress.bytes_downloaded, 50);
        assert_eq!(progress.total_bytes, Some(200));
    }

    #[test]
    fn test_download_progress_unknown_total() {
        let progress = DownloadProgress {
            bytes_downloaded: 1024,
            total_bytes: None,
            percentage: None,
        };
        assert!(progress.total_bytes.is_none());
        assert!(progress.percentage.is_none());
    }

    // ---- url_filename ----

    #[test]
    fn test_url_filename_extraction() {
        assert_eq!(
            url_filename("https://example.com/path/to/model.onnx"),
            Some("model.onnx".to_string())
        );
        assert_eq!(url_filename("https://example.com/"), None);
        assert_eq!(url_filename("model.onnx"), Some("model.onnx".to_string()));
    }

    #[test]
    fn test_url_filename_strips_query_string() {
        assert_eq!(
            url_filename("https://example.com/model.onnx?token=abc123"),
            Some("model.onnx".to_string()),
        );
    }

    #[test]
    fn test_url_filename_strips_fragment() {
        assert_eq!(
            url_filename("https://example.com/model.onnx#section"),
            Some("model.onnx".to_string()),
        );
    }

    #[test]
    fn test_url_filename_strips_query_and_fragment() {
        assert_eq!(
            url_filename("https://example.com/model.onnx?v=2#top"),
            Some("model.onnx".to_string()),
        );
    }

    // ---- feature-disabled stubs (compiled only without `download`) ----

    #[cfg(not(feature = "download"))]
    #[test]
    fn test_download_file_stub_returns_error() {
        let dir = tempfile::tempdir().unwrap();
        let dest = dir.path().join("out.onnx");
        let result = download_file("https://example.com/model.onnx", &dest, None);
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(
            msg.contains("download"),
            "error should mention the download feature: {msg}"
        );
    }

    #[cfg(not(feature = "download"))]
    #[test]
    fn test_download_model_stub_returns_error() {
        let dir = tempfile::tempdir().unwrap();
        let info = ModelInfo {
            name: "test".to_string(),
            language: "en".to_string(),
            quality: "medium".to_string(),
            description: "test".to_string(),
            model_url: "https://example.com/m.onnx".to_string(),
            config_url: "https://example.com/c.json".to_string(),
            size_bytes: None,
        };
        let result = download_model(&info, dir.path(), None);
        assert!(result.is_err());
    }

    // Re-implements the percentage formula used while downloading; a zero
    // total is defined as 100% rather than dividing by zero.
    #[test]
    fn test_download_progress_percentage_zero_total() {
        let total: Option<u64> = Some(0);
        let percentage = total.map(|t| {
            if t == 0 {
                100.0
            } else {
                (50_f64 / t as f64) * 100.0
            }
        });
        let progress = DownloadProgress {
            bytes_downloaded: 50,
            total_bytes: total,
            percentage,
        };
        assert_eq!(progress.percentage, Some(100.0));
        assert_eq!(progress.total_bytes, Some(0));
    }

    // Empty strings survive a serde round-trip.
    #[test]
    fn test_model_info_empty_fields() {
        let info = ModelInfo {
            name: String::new(),
            language: String::new(),
            quality: String::new(),
            description: String::new(),
            model_url: String::new(),
            config_url: String::new(),
            size_bytes: None,
        };
        assert!(info.name.is_empty());
        assert!(info.size_bytes.is_none());

        let json = serde_json::to_string(&info).unwrap();
        let back: ModelInfo = serde_json::from_str(&json).unwrap();
        assert!(back.name.is_empty());
    }

    // huggingface_url performs no escaping: spaces, parentheses and non-ASCII
    // text are inserted verbatim.
    #[test]
    fn test_huggingface_url_special_chars() {
        let url = huggingface_url("owner/repo with spaces", "model (v2).onnx");
        assert!(url.starts_with("https://huggingface.co/"));
        assert!(url.contains("repo with spaces"));
        assert!(url.contains("model (v2).onnx"));

        let url2 = huggingface_url("user/日本語モデル", "model.onnx");
        assert!(url2.contains("日本語モデル"));
    }

    // An empty model name maps to the files `.onnx` and `.onnx.json`.
    #[test]
    fn test_is_model_cached_empty_model_name() {
        let dir = tempfile::tempdir().unwrap();
        let dir_path = dir.path();

        assert!(!is_model_cached("", dir_path));

        std::fs::write(dir_path.join(".onnx"), b"fake").unwrap();
        std::fs::write(dir_path.join(".onnx.json"), b"{}").unwrap();
        assert!(is_model_cached("", dir_path));
    }

    // Despite the name, this checks a directory that does not exist at all.
    #[test]
    fn test_is_model_cached_with_subdirectory() {
        let nonexistent = PathBuf::from("/tmp/piper_test_nonexistent_dir_12345");
        assert!(!is_model_cached("some-model", &nonexistent));
    }

    // Unknown keys in registry entries are ignored rather than rejected.
    #[test]
    fn test_parse_model_registry_extra_fields() {
        let json = r#"[
 {
 "name": "test",
 "language": "en",
 "quality": "medium",
 "description": "desc",
 "model_url": "https://example.com/m.onnx",
 "config_url": "https://example.com/c.json",
 "size_bytes": null,
 "author": "someone",
 "license": "MIT",
 "extra_nested": {"a": 1}
 }
 ]"#;
        let models = parse_model_registry(json).unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "test");
    }

    #[test]
    fn test_parse_model_registry_unicode() {
        let json = r#"[
 {
 "name": "つくよみちゃん",
 "language": "ja",
 "quality": "medium",
 "description": "高品質な日本語音声合成 — 中文描述也可以",
 "model_url": "https://example.com/model.onnx",
 "config_url": "https://example.com/config.json",
 "size_bytes": 999
 }
 ]"#;
        let models = parse_model_registry(json).unwrap();
        assert_eq!(models[0].name, "つくよみちゃん");
        assert!(models[0].description.contains("中文"));
    }

    #[test]
    fn test_builtin_registry_urls_format() {
        for m in builtin_registry() {
            assert!(
                m.model_url.starts_with("https://") && m.model_url.contains("huggingface"),
                "model_url must be an HTTPS HuggingFace URL, got: {}",
                m.model_url,
            );
            assert!(
                m.config_url.starts_with("https://") && m.config_url.contains("huggingface"),
                "config_url must be an HTTPS HuggingFace URL, got: {}",
                m.config_url,
            );
        }
    }

    // Assumes the environment is not modified between the two calls.
    #[test]
    fn test_default_model_dir_consistent() {
        let a = default_model_dir();
        let b = default_model_dir();
        assert_eq!(a, b, "default_model_dir should be deterministic");
    }

    // ---- find_model ----

    #[test]
    fn test_find_model_exact_name() {
        let m = find_model("tsukuyomi-6lang-v2");
        assert!(m.is_some());
        assert_eq!(m.unwrap().name, "tsukuyomi-6lang-v2");
    }

    // "css10" is a substring of exactly one registry name.
    #[test]
    fn test_find_model_partial_name() {
        let m = find_model("css10");
        assert!(m.is_some());
        assert!(
            m.unwrap().name.contains("css10"),
            "partial name match should return a model containing the query string"
        );
    }

    // Capitalised "Tsukuyomi" misses the case-sensitive name stages and is
    // found via the case-insensitive description stage.
    #[test]
    fn test_find_model_description_match() {
        let m = find_model("Tsukuyomi");
        assert!(m.is_some());
        assert!(
            m.unwrap().description.to_lowercase().contains("tsukuyomi"),
            "description match should return a model whose description contains the query"
        );
    }

    // NOTE(review): lowercase "tsukuyomi" is already a unique partial *name*
    // match, so this never reaches the description stage and does not
    // exercise case-insensitive description matching.
    #[test]
    fn test_find_model_case_insensitive_description() {
        let m = find_model("tsukuyomi");
        assert!(m.is_some());
        assert!(
            m.unwrap().description.to_lowercase().contains("tsukuyomi"),
            "case-insensitive description match should find a model"
        );
    }

    #[test]
    fn test_find_model_no_match() {
        let m = find_model("nonexistent-model-xyz");
        assert!(m.is_none());
    }

    // "6lang" is contained in both registry names, and neither description
    // contains it, so no stage yields a unique match.
    #[test]
    fn test_find_model_ambiguous_returns_none() {
        let m = find_model("6lang");
        assert!(m.is_none(), "ambiguous partial match should return None");
    }

    // ---- resolve_model_path ----

    // An existing file path is returned as-is, bypassing the registry.
    #[test]
    fn test_resolve_model_path_existing_file() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("my-model.onnx");
        std::fs::write(&file, b"fake onnx").unwrap();

        let resolved = resolve_model_path(file.to_str().unwrap(), None).unwrap();
        assert_eq!(resolved, file);
    }

    #[test]
    fn test_resolve_model_path_cached_model() {
        let dir = tempfile::tempdir().unwrap();
        let dir_path = dir.path();

        std::fs::write(dir_path.join("tsukuyomi-6lang-v2.onnx"), b"fake").unwrap();
        std::fs::write(dir_path.join("tsukuyomi-6lang-v2.onnx.json"), b"{}").unwrap();

        let resolved = resolve_model_path("tsukuyomi-6lang-v2", Some(dir_path)).unwrap();
        assert_eq!(resolved, dir_path.join("tsukuyomi-6lang-v2.onnx"));
    }

    // A partial name ("css10") resolves to the full registry name's cache file.
    #[test]
    fn test_resolve_model_path_cached_via_alias() {
        let dir = tempfile::tempdir().unwrap();
        let dir_path = dir.path();

        std::fs::write(dir_path.join("css10-6lang.onnx"), b"fake").unwrap();
        std::fs::write(dir_path.join("css10-6lang.onnx.json"), b"{}").unwrap();

        let resolved = resolve_model_path("css10", Some(dir_path)).unwrap();
        assert_eq!(resolved, dir_path.join("css10-6lang.onnx"));
    }

    // Lookup fails before any cache check or download is attempted.
    #[test]
    fn test_resolve_model_path_unknown_model_error() {
        let result = resolve_model_path("nonexistent-model-xyz", None);
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("not found"), "error message: {msg}");
    }
}