Skip to main content

piper_plus/
model_download.rs

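//! Model download and management utilities.
//!
//! Download ONNX models and config files from HuggingFace or direct URLs.
//! Only the network functions ([`download_file`], [`download_model`]) depend on
//! the `download` Cargo feature (which brings in `reqwest`); without it they
//! return an error. Registry lookup and path helpers are always available.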
5
6use std::path::{Path, PathBuf};
7
8use crate::error::PiperError;
9
10/// Model metadata for a downloadable voice.
11#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
12pub struct ModelInfo {
13    pub name: String,
14    pub language: String,
15    /// Quality tier: "low", "medium", or "high".
16    pub quality: String,
17    pub description: String,
18    pub model_url: String,
19    pub config_url: String,
20    pub size_bytes: Option<u64>,
21}
22
/// Download progress callback.
///
/// Invoked synchronously from the downloading code with a snapshot of the
/// transfer state. The `Send` bound allows the callback to be built on one
/// thread and handed to a download running on another.
pub type ProgressCallback = Box<dyn Fn(DownloadProgress) + Send>;

/// Snapshot of an in-flight download, passed to a [`ProgressCallback`].
#[derive(Debug, Clone, PartialEq)]
pub struct DownloadProgress {
    /// Bytes of the response body received so far.
    pub bytes_downloaded: u64,
    /// Total body size from the `Content-Length` header, if the server sent one.
    pub total_bytes: Option<u64>,
    /// `bytes_downloaded / total_bytes * 100.0`, or `None` when the total is
    /// unknown. A known total of zero is reported as `100.0`.
    pub percentage: Option<f64>,
}
33
/// Default model directory based on the current platform.
///
/// - Linux: `$XDG_DATA_HOME/piper-plus/models/` when `XDG_DATA_HOME` is an
///   absolute path, otherwise `~/.local/share/piper-plus/models/`
/// - macOS: `~/Library/Application Support/piper-plus/models/`
/// - Windows: `%APPDATA%/piper-plus/models/`
///
/// If the platform directory cannot be determined, this falls back to
/// `$HOME/.piper-plus/models/`, then `%USERPROFILE%/.piper-plus/models/`.
/// As a last resort it uses the relative path `.piper-plus/models`.
///
/// Environment variables that are unset or set to an empty string are
/// skipped. Values are read as OS strings, so non-UTF-8 home directories are
/// honoured instead of being treated as missing.
pub fn default_model_dir() -> PathBuf {
    if let Some(dir) = platform_data_dir() {
        return dir.join("piper-plus").join("models");
    }

    // Fallback: HOME first, then USERPROFILE (Windows).
    ["HOME", "USERPROFILE"]
        .into_iter()
        .find_map(env_dir)
        .unwrap_or_default()
        .join(".piper-plus")
        .join("models")
}

/// Platform-specific data directory without pulling in the `dirs` crate.
fn platform_data_dir() -> Option<PathBuf> {
    #[cfg(target_os = "linux")]
    {
        // XDG Base Directory spec: relative values must be ignored, in which
        // case the default $HOME/.local/share applies.
        if let Some(xdg) = env_dir("XDG_DATA_HOME").filter(|p| p.is_absolute()) {
            return Some(xdg);
        }
        env_dir("HOME").map(|h| h.join(".local").join("share"))
    }

    #[cfg(target_os = "macos")]
    {
        env_dir("HOME").map(|h| h.join("Library").join("Application Support"))
    }

    #[cfg(target_os = "windows")]
    {
        env_dir("APPDATA")
    }

    #[cfg(not(any(target_os = "linux", target_os = "macos", target_os = "windows")))]
    {
        None
    }
}

/// Read environment variable `name` as a directory path.
///
/// Returns `None` when the variable is unset or empty.
fn env_dir(name: &str) -> Option<PathBuf> {
    non_empty_path(std::env::var_os(name))
}

/// Convert a raw environment value into a path, treating empty as unset.
///
/// An empty value would otherwise produce a path relative to the current
/// working directory.
fn non_empty_path(value: Option<std::ffi::OsString>) -> Option<PathBuf> {
    value.filter(|v| !v.is_empty()).map(PathBuf::from)
}
90
/// Download a file from `url` to `dest`, calling `progress` periodically.
///
/// The body is streamed into a sibling `<dest>.part` file. That file is
/// renamed onto `dest` only after the body has been received and flushed
/// completely. A failed or interrupted transfer therefore never leaves a
/// truncated file at `dest`, which [`is_model_cached`] would mistake for a
/// usable model. The partial file is removed on a best-effort basis.
///
/// When `progress` is given, it is called:
/// - roughly every 100 KiB;
/// - when the advertised `Content-Length` is reached;
/// - once more after the last chunk, so callers always observe the final
///   byte count.
///
/// This is the feature-gated implementation that requires the `download`
/// Cargo feature (which brings in `reqwest`).
///
/// # Errors
///
/// * [`PiperError::Download`]: the HTTP client cannot be built, the request
///   fails, or the number of bytes received differs from `Content-Length`.
/// * [`PiperError::ModelLoad`]: a non-success HTTP status, an error while
///   reading the response body, or a local filesystem failure (creating the
///   directory or file, writing, flushing, renaming).
#[cfg(feature = "download")]
pub fn download_file(
    url: &str,
    dest: &Path,
    progress: Option<ProgressCallback>,
) -> Result<(), PiperError> {
    // Ensure the parent directory exists.
    if let Some(parent) = dest.parent() {
        std::fs::create_dir_all(parent).map_err(|e| {
            PiperError::ModelLoad(format!(
                "failed to create directory {}: {e}",
                parent.display()
            ))
        })?;
    }

    let client = reqwest::blocking::Client::builder()
        .connect_timeout(std::time::Duration::from_secs(30))
        .timeout(std::time::Duration::from_secs(600)) // 10 min for large models
        .build()
        .map_err(|e| PiperError::Download(format!("HTTP client error: {e}")))?;

    let response = client
        .get(url)
        .send()
        .map_err(|e| PiperError::Download(format!("download failed: {e}")))?;

    if !response.status().is_success() {
        return Err(PiperError::ModelLoad(format!(
            "HTTP {} for {url}",
            response.status()
        )));
    }

    let part_path = partial_download_path(dest);
    let result = save_response(response, url, &part_path, progress.as_ref()).and_then(|()| {
        std::fs::rename(&part_path, dest).map_err(|e| {
            PiperError::ModelLoad(format!(
                "failed to move {} to {}: {e}",
                part_path.display(),
                dest.display()
            ))
        })
    });
    if result.is_err() {
        // Best effort: the original error is what the caller needs to see.
        let _ = std::fs::remove_file(&part_path);
    }
    result
}

/// Stream `response` into a newly created file at `path`.
///
/// Reports progress through `progress`. Fails if the body length disagrees
/// with the advertised `Content-Length`.
#[cfg(feature = "download")]
fn save_response(
    mut response: reqwest::blocking::Response,
    url: &str,
    path: &Path,
    progress: Option<&ProgressCallback>,
) -> Result<(), PiperError> {
    use std::io::{BufWriter, Write as _};

    // ~100 KB progress granularity.
    const PROGRESS_INTERVAL: u64 = 100 * 1024;

    // `None` when the server omits the header or the body is transparently
    // decompressed, in which case no length check is possible.
    let total_bytes = response.content_length();

    let file = std::fs::File::create(path).map_err(|e| {
        PiperError::ModelLoad(format!("failed to create file {}: {e}", path.display()))
    })?;
    let mut writer = BufWriter::with_capacity(256 * 1024, file); // 256KB buffer

    let received = copy_with_progress(
        &mut response,
        &mut writer,
        total_bytes,
        PROGRESS_INTERVAL,
        |bytes_downloaded| {
            if let Some(cb) = progress {
                cb(DownloadProgress {
                    bytes_downloaded,
                    total_bytes,
                    percentage: progress_percentage(bytes_downloaded, total_bytes),
                });
            }
        },
    )
    .map_err(|err| match err {
        StreamError::Read(e) => {
            PiperError::ModelLoad(format!("failed to read response body from {url}: {e}"))
        }
        StreamError::Write(e) => {
            PiperError::ModelLoad(format!("failed to write to {}: {e}", path.display()))
        }
    })?;

    writer
        .flush()
        .map_err(|e| PiperError::ModelLoad(format!("failed to flush {}: {e}", path.display())))?;

    // Guard against a connection that closed early but looked like a clean EOF.
    if let Some(expected) = total_bytes.filter(|&expected| expected != received) {
        return Err(PiperError::Download(format!(
            "size mismatch for {url}: expected {expected} bytes, received {received}"
        )));
    }
    Ok(())
}

/// Temporary path a download is streamed into before being renamed to `dest`.
///
/// For example, `model.onnx` becomes `model.onnx.part`.
#[cfg(any(feature = "download", test))]
#[cfg_attr(not(feature = "download"), allow(dead_code))] // only exercised by tests here
fn partial_download_path(dest: &Path) -> PathBuf {
    let mut name = dest.as_os_str().to_owned();
    name.push(".part");
    PathBuf::from(name)
}

/// Which side of [`copy_with_progress`] failed.
#[cfg(any(feature = "download", test))]
#[cfg_attr(not(feature = "download"), allow(dead_code))] // only exercised by tests here
#[derive(Debug)]
enum StreamError {
    Read(std::io::Error),
    Write(std::io::Error),
}

/// Percentage complete for `bytes_downloaded` out of `total_bytes`.
///
/// Returns `None` when the total is unknown. An empty body (total `0`)
/// counts as 100 % to avoid dividing by zero.
#[cfg(any(feature = "download", test))]
#[cfg_attr(not(feature = "download"), allow(dead_code))] // only exercised by tests here
fn progress_percentage(bytes_downloaded: u64, total_bytes: Option<u64>) -> Option<f64> {
    total_bytes.map(|t| {
        if t == 0 {
            100.0
        } else {
            (bytes_downloaded as f64 / t as f64) * 100.0
        }
    })
}

/// Copy `reader` into `writer` in 64 KiB chunks and return the byte count.
///
/// `report` receives the running byte count in three situations:
/// - once at least `interval` bytes have arrived since the last report;
/// - when the count equals `total_bytes`;
/// - once at the end if the final count was not already reported.
///
/// `Interrupted` reads are retried, matching `std::io::copy`.
#[cfg(any(feature = "download", test))]
#[cfg_attr(not(feature = "download"), allow(dead_code))] // only exercised by tests here
fn copy_with_progress<R, W, F>(
    reader: &mut R,
    writer: &mut W,
    total_bytes: Option<u64>,
    interval: u64,
    mut report: F,
) -> Result<u64, StreamError>
where
    R: std::io::Read + ?Sized,
    W: std::io::Write + ?Sized,
    F: FnMut(u64),
{
    let mut buf = [0u8; 64 * 1024];
    let mut copied: u64 = 0;
    let mut next_report = interval;
    let mut last_reported: Option<u64> = None;

    loop {
        let n = match reader.read(&mut buf) {
            Ok(0) => break,
            Ok(n) => n,
            Err(e) if e.kind() == std::io::ErrorKind::Interrupted => continue,
            Err(e) => return Err(StreamError::Read(e)),
        };
        writer.write_all(&buf[..n]).map_err(StreamError::Write)?;
        copied += n as u64;

        if copied >= next_report || total_bytes == Some(copied) {
            report(copied);
            last_reported = Some(copied);
            next_report = copied.saturating_add(interval);
        }
    }

    // Without a known total the loop may end between reporting thresholds.
    if last_reported != Some(copied) {
        report(copied);
    }
    Ok(copied)
}
181
182/// Stub when the `download` feature is not enabled.
183///
184/// Returns an error indicating that the feature must be enabled.
185#[cfg(not(feature = "download"))]
186pub fn download_file(
187    _url: &str,
188    _dest: &Path,
189    _progress: Option<ProgressCallback>,
190) -> Result<(), PiperError> {
191    Err(PiperError::ModelLoad(
192        "the \"download\" feature is required for download_file; \
193         rebuild with `--features download`"
194            .to_string(),
195    ))
196}
197
/// Download a model (ONNX + JSON config) from its registry URLs into `dest_dir`.
///
/// The files are saved as `{name}.onnx` and `{name}.onnx.json`. This is the
/// layout that [`is_model_cached`] checks and that [`resolve_model_path`]
/// returns on a cache hit. As a result, a freshly downloaded model is
/// recognised on later runs, and several models can share one directory
/// without overwriting each other's config.
///
/// If `model_info.name` cannot be used as a single file name (empty,
/// `.`/`..`, or containing a path separator), the last path segment of each
/// URL is used instead, with `config.json` as the config fallback.
///
/// Creates `dest_dir` if it does not exist. Returns `(model_path, config_path)`.
/// Only the model download reports `progress`; the config is typically tiny.
///
/// # Errors
///
/// * [`PiperError::ModelLoad`] if `dest_dir` cannot be created.
/// * [`PiperError::Download`] if no safe file name can be derived for the
///   model.
/// * Errors from [`download_file`] are propagated unchanged.
#[cfg(feature = "download")]
pub fn download_model(
    model_info: &ModelInfo,
    dest_dir: &Path,
    progress: Option<ProgressCallback>,
) -> Result<(PathBuf, PathBuf), PiperError> {
    std::fs::create_dir_all(dest_dir).map_err(|e| {
        PiperError::ModelLoad(format!(
            "failed to create model directory {}: {e}",
            dest_dir.display()
        ))
    })?;

    let (model_filename, config_filename) = match plain_file_name(&model_info.name) {
        Some(stem) => (format!("{stem}.onnx"), format!("{stem}.onnx.json")),
        None => {
            // Only accept URL-derived names that are a single plain path
            // component, so an untrusted registry cannot escape `dest_dir`.
            let safe_url_name =
                |url: &str| url_filename(url).filter(|f| plain_file_name(f).is_some());
            let model = safe_url_name(&model_info.model_url).ok_or_else(|| {
                PiperError::Download(format!(
                    "cannot derive a file name for model '{}' from {}",
                    model_info.name, model_info.model_url
                ))
            })?;
            let config = safe_url_name(&model_info.config_url)
                .unwrap_or_else(|| "config.json".to_string());
            (model, config)
        }
    };

    let model_path = dest_dir.join(&model_filename);
    let config_path = dest_dir.join(&config_filename);

    // Download model file (with progress).
    download_file(&model_info.model_url, &model_path, progress)?;

    // Download config file (no progress -- typically tiny).
    download_file(&model_info.config_url, &config_path, None)?;

    Ok((model_path, config_path))
}

/// Return `candidate` if joining it onto a directory yields a direct child.
///
/// Accepted values are a single ordinary path component: not empty, not
/// `.`/`..`, no separators, and no root or drive prefix. Anything else
/// returns `None`.
#[cfg(any(feature = "download", test))]
#[cfg_attr(not(feature = "download"), allow(dead_code))] // only exercised by tests here
fn plain_file_name(candidate: &str) -> Option<&str> {
    let mut components = Path::new(candidate).components();
    match (components.next(), components.next()) {
        (Some(std::path::Component::Normal(only)), None) if only.to_str() == Some(candidate) => {
            Some(candidate)
        }
        _ => None,
    }
}
230
231/// Stub when the `download` feature is not enabled.
232#[cfg(not(feature = "download"))]
233pub fn download_model(
234    _model_info: &ModelInfo,
235    _dest_dir: &Path,
236    _progress: Option<ProgressCallback>,
237) -> Result<(PathBuf, PathBuf), PiperError> {
238    Err(PiperError::ModelLoad(
239        "the \"download\" feature is required for download_model; \
240         rebuild with `--features download`"
241            .to_string(),
242    ))
243}
244
/// Build the HuggingFace "resolve" URL for `filename` inside repository `repo`.
///
/// The pieces are joined verbatim, with no percent-encoding, into
/// `https://huggingface.co/{repo}/resolve/main/{filename}`. Files are always
/// taken from the `main` branch.
///
/// # Examples
///
/// ```
/// # use piper_plus::model_download::huggingface_url;
/// let url = huggingface_url("ayousanz/piper-plus-tsukuyomi-chan", "model.onnx");
/// assert_eq!(url, "https://huggingface.co/ayousanz/piper-plus-tsukuyomi-chan/resolve/main/model.onnx");
/// ```
pub fn huggingface_url(repo: &str, filename: &str) -> String {
    const HOST: &str = "https://huggingface.co";
    const REVISION: &str = "resolve/main";
    [HOST, repo, REVISION, filename].join("/")
}
259
260/// Parse a model registry from a JSON string.
261///
262/// The JSON should be an array of [`ModelInfo`] objects.
263pub fn parse_model_registry(json_str: &str) -> Result<Vec<ModelInfo>, PiperError> {
264    let models: Vec<ModelInfo> = serde_json::from_str(json_str)?;
265    Ok(models)
266}
267
/// Check whether a model named `model_name` is already cached in `model_dir`.
///
/// A model is considered cached when `{model_name}.onnx` exists inside the
/// directory together with either `{model_name}.onnx.json` or `config.json`.
/// All of these must be regular files (symlinks to files are followed), so a
/// directory that merely shares the name does not count as a cached model.
pub fn is_model_cached(model_name: &str, model_dir: &Path) -> bool {
    let model_file = model_dir.join(format!("{model_name}.onnx"));
    if !model_file.is_file() {
        return false;
    }
    let sidecar = model_dir.join(format!("{model_name}.onnx.json"));
    sidecar.is_file() || model_dir.join("config.json").is_file()
}
279
280/// Built-in model registry with known Piper-Plus models.
281///
282/// The registry is lazily initialised once and then shared for the lifetime
283/// of the process, avoiding repeated heap allocations on every call.
284pub fn builtin_registry() -> &'static [ModelInfo] {
285    use std::sync::OnceLock;
286    static REGISTRY: OnceLock<Vec<ModelInfo>> = OnceLock::new();
287    REGISTRY.get_or_init(|| {
288        vec![
289            ModelInfo {
290                name: "tsukuyomi-6lang-v2".to_string(),
291                language: "ja-en-zh-es-fr-pt".to_string(),
292                quality: "medium".to_string(),
293                description: "Tsukuyomi-chan 6-language model (JA/EN/ZH/ES/FR/PT)".to_string(),
294                model_url: huggingface_url(
295                    "ayousanz/piper-plus-tsukuyomi-chan",
296                    "tsukuyomi-chan-6lang-fp16.onnx",
297                ),
298                config_url: huggingface_url("ayousanz/piper-plus-tsukuyomi-chan", "config.json"),
299                size_bytes: None,
300            },
301            ModelInfo {
302                name: "css10-6lang".to_string(),
303                language: "ja-en-zh-es-fr-pt".to_string(),
304                quality: "medium".to_string(),
305                description:
306                    "CSS10 Japanese 6-language model fine-tuned from multilingual base (FP16)"
307                        .to_string(),
308                model_url: huggingface_url(
309                    "ayousanz/piper-plus-css10-ja-6lang",
310                    "css10-ja-6lang-fp16.onnx",
311                ),
312                config_url: huggingface_url("ayousanz/piper-plus-css10-ja-6lang", "config.json"),
313                size_bytes: Some(39_414_515),
314            },
315        ]
316    })
317}
318
319/// Find a model by name or alias in the built-in registry.
320///
321/// Supports exact name match, unique partial match (contains), and unique
322/// description match (case-insensitive).
323pub fn find_model(query: &str) -> Option<&'static ModelInfo> {
324    let registry = builtin_registry();
325
326    // 1. Exact name match
327    if let Some(m) = registry.iter().find(|m| m.name == query) {
328        return Some(m);
329    }
330
331    // 2. Partial name match (contains)
332    let matches: Vec<_> = registry.iter().filter(|m| m.name.contains(query)).collect();
333    if matches.len() == 1 {
334        return Some(matches[0]);
335    }
336
337    // 3. Check if query matches any part of the description
338    let query_lower = query.to_lowercase();
339    let desc_matches: Vec<_> = registry
340        .iter()
341        .filter(|m| m.description.to_lowercase().contains(&query_lower))
342        .collect();
343    if desc_matches.len() == 1 {
344        return Some(desc_matches[0]);
345    }
346
347    None
348}
349
350/// Resolve a model path from a name, alias, or file path.
351///
352/// 1. If the string is a path to an existing file, return it directly.
353/// 2. If it matches a model name in the registry, look in `model_dir` for a
354///    cached copy.
355/// 3. If not cached, auto-download when the `download` feature is enabled.
356pub fn resolve_model_path(
357    model_str: &str,
358    model_dir: Option<&Path>,
359) -> Result<PathBuf, PiperError> {
360    let path = PathBuf::from(model_str);
361
362    // 1. Direct file path
363    if path.is_file() {
364        return Ok(path);
365    } else if path.is_dir() {
366        return Err(PiperError::ModelLoad(format!(
367            "Path '{}' is a directory. Please provide a model file path or a model name.",
368            path.display()
369        )));
370    }
371
372    // 2. Try as model name
373    let model_info = find_model(model_str).ok_or_else(|| {
374        PiperError::ModelLoad(format!(
375            "Model '{}' not found. Use --list-models to see available models, or specify a file path.",
376            model_str
377        ))
378    })?;
379
380    let dir = model_dir
381        .map(PathBuf::from)
382        .unwrap_or_else(default_model_dir);
383
384    // Check if already cached
385    if is_model_cached(&model_info.name, &dir) {
386        let model_path = dir.join(format!("{}.onnx", model_info.name));
387        return Ok(model_path);
388    }
389
390    // 3. Auto-download
391    #[cfg(feature = "download")]
392    {
393        eprintln!(
394            "Model '{}' not found locally. Downloading...",
395            model_info.name
396        );
397        let (model_path, _config_path) = download_model(
398            model_info,
399            &dir,
400            Some(Box::new(|progress| {
401                if let Some(pct) = progress.percentage {
402                    eprint!("\r  Downloading... {:.1}%", pct);
403                }
404            })),
405        )?;
406        eprintln!();
407        eprintln!("Model downloaded to: {}", model_path.display());
408        Ok(model_path)
409    }
410
411    #[cfg(not(feature = "download"))]
412    {
413        Err(PiperError::ModelLoad(format!(
414            "Model '{}' not cached. Download it with: --download-model {}",
415            model_str, model_info.name
416        )))
417    }
418}
419
/// Extract the filename component from a URL path.
///
/// Any query string (`?…`) and fragment (`#…`) are dropped first. For
/// absolute URLs (`scheme://authority/path`), the scheme and authority are
/// skipped, so a bare host is never mistaken for a file name. Strings without
/// `://` are treated as plain paths.
///
/// Returns `None` if the URL has no path, or if the last path segment is
/// empty (e.g. a trailing `/`).
#[cfg(any(feature = "download", test))]
fn url_filename(url: &str) -> Option<String> {
    let end = url.find(|c| c == '?' || c == '#').unwrap_or(url.len());
    let without_suffix = &url[..end];
    let path = match without_suffix.split_once("://") {
        // No '/' after the authority means there is no path at all.
        Some((_, rest)) => &rest[rest.find('/')?..],
        None => without_suffix,
    };
    path.rsplit('/')
        .next()
        .filter(|segment| !segment.is_empty())
        .map(str::to_string)
}
432
// ---------------------------------------------------------------------------
// Tests
// ---------------------------------------------------------------------------
#[cfg(test)]
mod tests {
    use super::*;

    // -- huggingface_url construction -----------------------------------------

    #[test]
    fn test_huggingface_url_basic() {
        let url = huggingface_url("owner/repo", "model.onnx");
        assert_eq!(
            url,
            "https://huggingface.co/owner/repo/resolve/main/model.onnx"
        );
    }

    #[test]
    fn test_huggingface_url_with_subdirectory_filename() {
        let url = huggingface_url("ayousanz/piper-plus-tsukuyomi-chan", "models/v2.onnx");
        assert_eq!(
            url,
            "https://huggingface.co/ayousanz/piper-plus-tsukuyomi-chan/resolve/main/models/v2.onnx"
        );
    }

    // -- parse_model_registry -------------------------------------------------

    #[test]
    fn test_parse_model_registry_valid() {
        let json = r#"[
            {
                "name": "test-model",
                "language": "ja",
                "quality": "medium",
                "description": "A test model",
                "model_url": "https://example.com/model.onnx",
                "config_url": "https://example.com/config.json",
                "size_bytes": 1024
            }
        ]"#;
        let models = parse_model_registry(json).unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "test-model");
        assert_eq!(models[0].size_bytes, Some(1024));
    }

    #[test]
    fn test_parse_model_registry_empty_array() {
        let models = parse_model_registry("[]").unwrap();
        assert!(models.is_empty());
    }

    #[test]
    fn test_parse_model_registry_invalid_json() {
        let result = parse_model_registry("not valid json");
        assert!(result.is_err());
    }

    #[test]
    fn test_parse_model_registry_missing_required_fields() {
        let json = r#"[{"name": "incomplete"}]"#;
        let result = parse_model_registry(json);
        assert!(result.is_err());
    }

    // -- is_model_cached ------------------------------------------------------

    #[test]
    fn test_is_model_cached_with_onnx_json() {
        let dir = tempfile::tempdir().unwrap();
        let dir_path = dir.path();

        // Neither file exists -- not cached.
        assert!(!is_model_cached("voice", dir_path));

        // Only ONNX -- still not cached.
        std::fs::write(dir_path.join("voice.onnx"), b"fake").unwrap();
        assert!(!is_model_cached("voice", dir_path));

        // ONNX + onnx.json -- cached.
        std::fs::write(dir_path.join("voice.onnx.json"), b"{}").unwrap();
        assert!(is_model_cached("voice", dir_path));
    }

    #[test]
    fn test_is_model_cached_with_config_json() {
        let dir = tempfile::tempdir().unwrap();
        let dir_path = dir.path();

        std::fs::write(dir_path.join("voice.onnx"), b"fake").unwrap();
        std::fs::write(dir_path.join("config.json"), b"{}").unwrap();
        assert!(is_model_cached("voice", dir_path));
    }

    #[test]
    fn test_is_model_cached_missing_onnx() {
        let dir = tempfile::tempdir().unwrap();
        let dir_path = dir.path();

        // Config exists but ONNX does not -- not cached.
        std::fs::write(dir_path.join("config.json"), b"{}").unwrap();
        assert!(!is_model_cached("voice", dir_path));
    }

    // -- default_model_dir ----------------------------------------------------

    #[test]
    fn test_default_model_dir_is_non_empty() {
        let dir = default_model_dir();
        assert!(
            !dir.as_os_str().is_empty(),
            "default_model_dir must not be empty"
        );
        // Should always end with "models".
        assert_eq!(
            dir.file_name().and_then(|s| s.to_str()),
            Some("models"),
            "expected path to end with 'models', got: {dir:?}"
        );
    }

    // -- ModelInfo serialization roundtrip -------------------------------------

    #[test]
    fn test_model_info_roundtrip() {
        let info = ModelInfo {
            name: "roundtrip-test".to_string(),
            language: "en".to_string(),
            quality: "high".to_string(),
            description: "Roundtrip test model".to_string(),
            model_url: "https://example.com/m.onnx".to_string(),
            config_url: "https://example.com/c.json".to_string(),
            size_bytes: Some(42),
        };

        let json = serde_json::to_string(&info).unwrap();
        let deserialized: ModelInfo = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.name, info.name);
        assert_eq!(deserialized.language, info.language);
        assert_eq!(deserialized.quality, info.quality);
        assert_eq!(deserialized.description, info.description);
        assert_eq!(deserialized.model_url, info.model_url);
        assert_eq!(deserialized.config_url, info.config_url);
        assert_eq!(deserialized.size_bytes, info.size_bytes);
    }

    #[test]
    fn test_model_info_size_bytes_optional() {
        let json = r#"{
            "name": "n",
            "language": "ja",
            "quality": "low",
            "description": "d",
            "model_url": "https://example.com/m.onnx",
            "config_url": "https://example.com/c.json",
            "size_bytes": null
        }"#;
        let info: ModelInfo = serde_json::from_str(json).unwrap();
        assert!(info.size_bytes.is_none());
    }

    // -- builtin_registry -----------------------------------------------------

    #[test]
    fn test_builtin_registry_non_empty() {
        let models = builtin_registry();
        assert!(
            models.len() >= 2,
            "builtin registry should contain at least 2 models"
        );
        // Every entry should have valid-looking URLs.
        for m in models {
            assert!(
                m.model_url.starts_with("https://"),
                "bad model_url: {}",
                m.model_url
            );
            assert!(
                m.config_url.starts_with("https://"),
                "bad config_url: {}",
                m.config_url
            );
            assert!(!m.name.is_empty());
        }
    }

    // -- DownloadProgress percentage ------------------------------------------

    #[test]
    fn test_download_progress_percentage() {
        let progress = DownloadProgress {
            bytes_downloaded: 50,
            total_bytes: Some(200),
            percentage: Some(25.0),
        };
        assert_eq!(progress.percentage, Some(25.0));
        assert_eq!(progress.bytes_downloaded, 50);
        assert_eq!(progress.total_bytes, Some(200));
    }

    #[test]
    fn test_download_progress_unknown_total() {
        let progress = DownloadProgress {
            bytes_downloaded: 1024,
            total_bytes: None,
            percentage: None,
        };
        assert!(progress.total_bytes.is_none());
        assert!(progress.percentage.is_none());
    }

    // -- url_filename (internal helper) ---------------------------------------

    #[test]
    fn test_url_filename_extraction() {
        assert_eq!(
            url_filename("https://example.com/path/to/model.onnx"),
            Some("model.onnx".to_string())
        );
        assert_eq!(url_filename("https://example.com/"), None);
        assert_eq!(url_filename("model.onnx"), Some("model.onnx".to_string()));
    }

    #[test]
    fn test_url_filename_strips_query_string() {
        assert_eq!(
            url_filename("https://example.com/model.onnx?token=abc123"),
            Some("model.onnx".to_string()),
        );
    }

    #[test]
    fn test_url_filename_strips_fragment() {
        assert_eq!(
            url_filename("https://example.com/model.onnx#section"),
            Some("model.onnx".to_string()),
        );
    }

    #[test]
    fn test_url_filename_strips_query_and_fragment() {
        assert_eq!(
            url_filename("https://example.com/model.onnx?v=2#top"),
            Some("model.onnx".to_string()),
        );
    }

    // -- download_file stub (non-download feature) ----------------------------

    #[cfg(not(feature = "download"))]
    #[test]
    fn test_download_file_stub_returns_error() {
        let dir = tempfile::tempdir().unwrap();
        let dest = dir.path().join("out.onnx");
        let result = download_file("https://example.com/model.onnx", &dest, None);
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(
            msg.contains("download"),
            "error should mention the download feature: {msg}"
        );
    }

    #[cfg(not(feature = "download"))]
    #[test]
    fn test_download_model_stub_returns_error() {
        let dir = tempfile::tempdir().unwrap();
        let info = ModelInfo {
            name: "test".to_string(),
            language: "en".to_string(),
            quality: "medium".to_string(),
            description: "test".to_string(),
            model_url: "https://example.com/m.onnx".to_string(),
            config_url: "https://example.com/c.json".to_string(),
            size_bytes: None,
        };
        let result = download_model(&info, dir.path(), None);
        assert!(result.is_err());
    }

    // -- TDD additions: feature-gated paths & error handling ------------------

    #[test]
    fn test_download_progress_percentage_zero_total() {
        // NOTE: this mirrors (rather than calls) the percentage convention used
        // by download_file: a known total of 0 is reported as 100 % instead
        // of dividing by zero.
        let total: Option<u64> = Some(0);
        let percentage = total.map(|t| {
            if t == 0 {
                100.0
            } else {
                (50_f64 / t as f64) * 100.0
            }
        });
        let progress = DownloadProgress {
            bytes_downloaded: 50,
            total_bytes: total,
            percentage,
        };
        assert_eq!(progress.percentage, Some(100.0));
        assert_eq!(progress.total_bytes, Some(0));
    }

    #[test]
    fn test_model_info_empty_fields() {
        // All-empty strings are structurally valid -- no runtime panic.
        let info = ModelInfo {
            name: String::new(),
            language: String::new(),
            quality: String::new(),
            description: String::new(),
            model_url: String::new(),
            config_url: String::new(),
            size_bytes: None,
        };
        assert!(info.name.is_empty());
        assert!(info.size_bytes.is_none());

        // Roundtrip through JSON should also succeed.
        let json = serde_json::to_string(&info).unwrap();
        let back: ModelInfo = serde_json::from_str(&json).unwrap();
        assert!(back.name.is_empty());
    }

    #[test]
    fn test_huggingface_url_special_chars() {
        // Repo names with spaces or special characters -- the function does
        // plain string interpolation so they must appear verbatim in the URL.
        let url = huggingface_url("owner/repo with spaces", "model (v2).onnx");
        assert!(url.starts_with("https://huggingface.co/"));
        assert!(url.contains("repo with spaces"));
        assert!(url.contains("model (v2).onnx"));

        // Unicode characters in repo name.
        let url2 = huggingface_url("user/日本語モデル", "model.onnx");
        assert!(url2.contains("日本語モデル"));
    }

    #[test]
    fn test_is_model_cached_empty_model_name() {
        let dir = tempfile::tempdir().unwrap();
        let dir_path = dir.path();

        // Empty model name produces ".onnx" and ".onnx.json" lookups.
        // Nothing exists so it must return false without panicking.
        assert!(!is_model_cached("", dir_path));

        // Even if we create the degenerate files, the logic should work.
        std::fs::write(dir_path.join(".onnx"), b"fake").unwrap();
        std::fs::write(dir_path.join(".onnx.json"), b"{}").unwrap();
        assert!(is_model_cached("", dir_path));
    }

    #[test]
    fn test_is_model_cached_nonexistent_dir() {
        // A model_dir that does not exist on disk should return false and
        // never panic. Built under a temp dir so the test is portable.
        let parent = tempfile::tempdir().unwrap();
        let nonexistent = parent.path().join("does-not-exist");
        assert!(!is_model_cached("some-model", &nonexistent));
    }

    #[test]
    fn test_parse_model_registry_extra_fields() {
        // serde by default ignores unknown fields (no deny_unknown_fields).
        let json = r#"[
            {
                "name": "test",
                "language": "en",
                "quality": "medium",
                "description": "desc",
                "model_url": "https://example.com/m.onnx",
                "config_url": "https://example.com/c.json",
                "size_bytes": null,
                "author": "someone",
                "license": "MIT",
                "extra_nested": {"a": 1}
            }
        ]"#;
        let models = parse_model_registry(json).unwrap();
        assert_eq!(models.len(), 1);
        assert_eq!(models[0].name, "test");
    }

    #[test]
    fn test_parse_model_registry_unicode() {
        // Japanese/Chinese characters in name and description.
        let json = r#"[
            {
                "name": "つくよみちゃん",
                "language": "ja",
                "quality": "medium",
                "description": "高品質な日本語音声合成 — 中文描述也可以",
                "model_url": "https://example.com/model.onnx",
                "config_url": "https://example.com/config.json",
                "size_bytes": 999
            }
        ]"#;
        let models = parse_model_registry(json).unwrap();
        assert_eq!(models[0].name, "つくよみちゃん");
        assert!(models[0].description.contains("中文"));
    }

    #[test]
    fn test_builtin_registry_urls_format() {
        // Every URL in the builtin registry must start with https://
        // and reference huggingface.co.
        for m in builtin_registry() {
            assert!(
                m.model_url.starts_with("https://") && m.model_url.contains("huggingface"),
                "model_url must be an HTTPS HuggingFace URL, got: {}",
                m.model_url,
            );
            assert!(
                m.config_url.starts_with("https://") && m.config_url.contains("huggingface"),
                "config_url must be an HTTPS HuggingFace URL, got: {}",
                m.config_url,
            );
        }
    }

    #[test]
    fn test_default_model_dir_consistent() {
        // Calling twice must return the exact same path -- no randomness
        // or time-dependent components.
        let a = default_model_dir();
        let b = default_model_dir();
        assert_eq!(a, b, "default_model_dir should be deterministic");
    }

    // -- find_model -----------------------------------------------------------

    #[test]
    fn test_find_model_exact_name() {
        let m = find_model("tsukuyomi-6lang-v2");
        assert!(m.is_some());
        assert_eq!(m.unwrap().name, "tsukuyomi-6lang-v2");
    }

    #[test]
    fn test_find_model_partial_name() {
        // "css10" is a unique substring across all model names.
        let m = find_model("css10");
        assert!(m.is_some());
        assert!(
            m.unwrap().name.contains("css10"),
            "partial name match should return a model containing the query string"
        );
    }

    #[test]
    fn test_find_model_description_match() {
        // "Tsukuyomi" (capitalised) is not a substring of any name, because
        // name matching is case-sensitive, so this reaches the description step.
        let m = find_model("Tsukuyomi");
        assert!(m.is_some());
        assert!(
            m.unwrap().description.to_lowercase().contains("tsukuyomi"),
            "description match should return a model whose description contains the query"
        );
    }

    #[test]
    fn test_find_model_case_insensitive_description() {
        // Lowercase "tsukuyomi" would already hit the partial *name* match, so
        // use a string that appears only in a description ("Tsukuyomi-chan"),
        // in a different case.
        let m = find_model("TSUKUYOMI-CHAN");
        assert!(m.is_some(), "case-insensitive description match should find a model");
        assert_eq!(m.unwrap().name, "tsukuyomi-6lang-v2");
    }

    #[test]
    fn test_find_model_no_match() {
        let m = find_model("nonexistent-model-xyz");
        assert!(m.is_none());
    }

    #[test]
    fn test_find_model_empty_query_returns_none() {
        assert!(find_model("").is_none());
    }

    #[test]
    fn test_find_model_ambiguous_returns_none() {
        // "6lang" appears in both model names, so partial match is ambiguous.
        let m = find_model("6lang");
        assert!(m.is_none(), "ambiguous partial match should return None");
    }

    // -- resolve_model_path ---------------------------------------------------

    #[test]
    fn test_resolve_model_path_existing_file() {
        let dir = tempfile::tempdir().unwrap();
        let file = dir.path().join("my-model.onnx");
        std::fs::write(&file, b"fake onnx").unwrap();

        let resolved = resolve_model_path(file.to_str().unwrap(), None).unwrap();
        assert_eq!(resolved, file);
    }

    #[test]
    fn test_resolve_model_path_directory_error() {
        let dir = tempfile::tempdir().unwrap();
        let result = resolve_model_path(dir.path().to_str().unwrap(), None);
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("is a directory"), "error message: {msg}");
    }

    #[test]
    fn test_resolve_model_path_cached_model() {
        let dir = tempfile::tempdir().unwrap();
        let dir_path = dir.path();

        // Create cached files for tsukuyomi-6lang-v2
        std::fs::write(dir_path.join("tsukuyomi-6lang-v2.onnx"), b"fake").unwrap();
        std::fs::write(dir_path.join("tsukuyomi-6lang-v2.onnx.json"), b"{}").unwrap();

        let resolved = resolve_model_path("tsukuyomi-6lang-v2", Some(dir_path)).unwrap();
        assert_eq!(resolved, dir_path.join("tsukuyomi-6lang-v2.onnx"));
    }

    #[test]
    fn test_resolve_model_path_cached_via_alias() {
        let dir = tempfile::tempdir().unwrap();
        let dir_path = dir.path();

        // "css10" partial match resolves to "css10-6lang"
        std::fs::write(dir_path.join("css10-6lang.onnx"), b"fake").unwrap();
        std::fs::write(dir_path.join("css10-6lang.onnx.json"), b"{}").unwrap();

        let resolved = resolve_model_path("css10", Some(dir_path)).unwrap();
        assert_eq!(resolved, dir_path.join("css10-6lang.onnx"));
    }

    #[test]
    fn test_resolve_model_path_unknown_model_error() {
        let result = resolve_model_path("nonexistent-model-xyz", None);
        assert!(result.is_err());
        let msg = format!("{}", result.unwrap_err());
        assert!(msg.contains("not found"), "error message: {msg}");
    }
}