// active_call/offline/downloader.rs

1use anyhow::{Context, Result};
2use hf_hub::api::sync::{Api, ApiBuilder};
3use hf_hub::{Repo, RepoType};
4use std::fs;
5use std::path::Path;
6use tracing::{info, warn};
7
8pub struct ModelDownloader {
9    api: Api,
10}
11
/// Selects which model bundle(s) a download request covers.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ModelType {
    /// Only the SenseVoice model.
    Sensevoice,
    /// Only the Supertonic model.
    Supertonic,
    /// Every supported model.
    All,
}

impl ModelType {
    /// Accepted textual names paired with the variant each one selects.
    const NAMES: [(&'static str, ModelType); 3] = [
        ("sensevoice", ModelType::Sensevoice),
        ("supertonic", ModelType::Supertonic),
        ("all", ModelType::All),
    ];

    /// Looks `s` up in [`Self::NAMES`], ignoring ASCII case and without allocating.
    fn lookup(s: &str) -> Option<Self> {
        Self::NAMES
            .iter()
            .find(|(name, _)| name.eq_ignore_ascii_case(s))
            .map(|&(_, model_type)| model_type)
    }

    /// Parses a model type from its name, ignoring case.
    ///
    /// Returns `None` when `s` is not `"sensevoice"`, `"supertonic"` or `"all"`.
    /// Prefer `s.parse::<ModelType>()` when an error value is more useful than
    /// an `Option`.
    // Kept as an inherent method so existing `ModelType::from_str(..)` callers
    // that expect an `Option` keep compiling; the trait impl lives below.
    #[allow(clippy::should_implement_trait)]
    pub fn from_str(s: &str) -> Option<Self> {
        Self::lookup(s)
    }
}

/// Error returned when a string does not name a known [`ModelType`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ParseModelTypeError {
    /// The rejected input, kept for the error message.
    input: String,
}

impl std::fmt::Display for ParseModelTypeError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(
            f,
            "unknown model type `{}` (expected sensevoice, supertonic, or all)",
            self.input
        )
    }
}

impl std::error::Error for ParseModelTypeError {}

impl std::str::FromStr for ModelType {
    type Err = ParseModelTypeError;

    /// Parses a model type from its name, ignoring case.
    ///
    /// # Errors
    ///
    /// Returns [`ParseModelTypeError`] when `s` is not a recognised name.
    fn from_str(s: &str) -> std::result::Result<Self, Self::Err> {
        ModelType::lookup(s).ok_or_else(|| ParseModelTypeError {
            input: s.to_string(),
        })
    }
}
29
30impl ModelDownloader {
31    pub fn new() -> Result<Self> {
32        let api = ApiBuilder::from_env()
33            .build()
34            .context("failed to create HuggingFace API client")?;
35        Ok(Self { api })
36    }
37
38    pub fn download(&self, model_type: ModelType, dest_dir: &Path) -> Result<()> {
39        match model_type {
40            ModelType::Sensevoice => self.download_sensevoice(dest_dir),
41            ModelType::Supertonic => self.download_supertonic(dest_dir),
42            ModelType::All => {
43                self.download_sensevoice(dest_dir)?;
44                self.download_supertonic(dest_dir)?;
45                Ok(())
46            }
47        }
48    }
49
50    fn download_file(
51        &self,
52        repo_id: &str,
53        revision: &str,
54        filename: &str,
55        dest: &Path,
56    ) -> Result<()> {
57        if dest.exists() {
58            let meta = fs::metadata(dest).context("failed to get metadata")?;
59            if meta.len() < 100 && filename.ends_with(".onnx") {
60                warn!(
61                    "  File {} exists but is too small ({} bytes). Deleting...",
62                    filename,
63                    meta.len()
64                );
65                fs::remove_file(dest).context("failed to remove corrupted file")?;
66            } else {
67                info!("  {} already exists, skipping", filename);
68                return Ok(());
69            }
70        }
71
72        info!("  Downloading {}...", filename);
73        let repo = self.api.repo(Repo::with_revision(
74            repo_id.to_string(),
75            RepoType::Model,
76            revision.to_string(),
77        ));
78
79        match repo.get(filename) {
80            Ok(src_path) => {
81                fs::copy(&src_path, dest).with_context(|| {
82                    format!(
83                        "failed to copy {} to {}",
84                        src_path.display(),
85                        dest.display()
86                    )
87                })?;
88                info!("  ✓ Downloaded {}", filename);
89                Ok(())
90            }
91            Err(err) => {
92                warn!(
93                    "  hf-hub failed for {}: {}. Attempting manual download...",
94                    filename, err
95                );
96
97                // Fallback using curl
98                let endpoint =
99                    std::env::var("HF_ENDPOINT").unwrap_or("https://huggingface.co".to_string());
100                // Handle nested paths in URL
101                let url = format!("{}/{}/resolve/{}/{}", endpoint, repo_id, revision, filename);
102
103                let status = std::process::Command::new("curl")
104                    .arg("-f") // Fail on 404
105                    .arg("-L")
106                    .arg("-o")
107                    .arg(dest)
108                    .arg(&url)
109                    .status()
110                    .context("failed to execute curl")?;
111
112                if status.success() {
113                    info!("  ✓ Downloaded {} (curl fallback)", filename);
114                    Ok(())
115                } else {
116                    anyhow::bail!("Both hf-hub and curl failed to download {}", filename);
117                }
118            }
119        }
120    }
121
122    fn download_sensevoice(&self, dest_dir: &Path) -> Result<()> {
123        info!("Downloading SenseVoice model...");
124        let sensevoice_dir = dest_dir.join("sensevoice");
125        fs::create_dir_all(&sensevoice_dir).context("failed to create sensevoice directory")?;
126
127        let repo_id = "csukuangfj/sherpa-onnx-sense-voice-zh-en-ja-ko-yue-int8-2025-09-09";
128        let revision = "main";
129        let files = vec!["model.int8.onnx", "tokens.txt"];
130
131        for file in files {
132            let dest = sensevoice_dir.join(file);
133            self.download_file(repo_id, revision, file, &dest)?;
134        }
135
136        info!("✓ SenseVoice model downloaded successfully");
137        Ok(())
138    }
139
140    fn download_supertonic(&self, dest_dir: &Path) -> Result<()> {
141        info!("Downloading Supertonic model...");
142        let supertonic_dir = dest_dir.join("supertonic");
143        fs::create_dir_all(&supertonic_dir).context("failed to create supertonic directory")?;
144
145        let repo_id = "Supertone/supertonic-2";
146        let revision = "main";
147        let files = vec![
148            "onnx/duration_predictor.onnx",
149            "onnx/text_encoder.onnx",
150            "onnx/vector_estimator.onnx",
151            "onnx/vocoder.onnx",
152            "onnx/unicode_indexer.json",
153            "onnx/tts.json",
154            "voice_styles/M1.json",
155            "voice_styles/M2.json",
156            "voice_styles/F1.json",
157            "voice_styles/F2.json",
158        ];
159
160        for file in files {
161            let dest = supertonic_dir.join(file);
162
163            // Create parent directory if needed
164            if let Some(parent) = dest.parent() {
165                fs::create_dir_all(parent)
166                    .with_context(|| format!("failed to create directory {}", parent.display()))?;
167            }
168
169            self.download_file(repo_id, revision, file, &dest)?;
170        }
171
172        info!("✓ Supertonic model downloaded successfully");
173        Ok(())
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    /// `ModelType::from_str` accepts the known names regardless of case and
    /// rejects anything else.
    #[test]
    fn test_model_type_from_str() {
        let cases = [
            ("sensevoice", Some(ModelType::Sensevoice)),
            ("SUPERTONIC", Some(ModelType::Supertonic)),
            ("all", Some(ModelType::All)),
            ("invalid", None),
        ];

        for (input, expected) in cases.iter() {
            assert_eq!(ModelType::from_str(input), *expected, "input: {:?}", input);
        }
    }
}