1use std::path::{Path, PathBuf};
4
5use serde::{Deserialize, Serialize};
6use tracing::info;
7
8use crate::InferenceError;
9
/// Role tag attached to every catalog entry.
///
/// Serialized and deserialized in `snake_case` (`"small"`, `"medium"`,
/// `"large"`, `"expert"`, `"embedding"`) because of the `rename_all`
/// attribute below.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "snake_case")]
pub enum ModelRole {
    /// Used by the 0.6B generative entry of the built-in catalog.
    Small,
    /// Used by the 1.7B and 4B entries of the built-in catalog.
    Medium,
    /// Used by the 8B entry of the built-in catalog.
    Large,
    /// Used by the 30B (3B active) entry of the built-in catalog.
    Expert,
    /// Used by the embedding entry of the built-in catalog.
    Embedding,
}
25
/// Owned, serializable description of one catalog entry plus its local
/// download state, as produced by `ModelRegistry::list_models`.
///
/// NOTE(review): `param_count` is `&'static str` while the struct derives
/// `Deserialize`; with serde's derive this presumably limits deserialization
/// to `'static` input (the type would not be `DeserializeOwned`) — confirm
/// whether that is intended.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelInfo {
    /// Catalog name of the model; also the name of its directory under the
    /// registry's models directory.
    pub name: String,
    /// Hugging Face repository the weights file is fetched from.
    pub hf_repo: String,
    /// File name of the weights inside `hf_repo`.
    pub hf_filename: String,
    /// Hugging Face repository `tokenizer.json` is fetched from.
    pub tokenizer_repo: String,
    /// Role tag copied from the catalog entry.
    pub role: ModelRole,
    /// Human-readable parameter count label (e.g. `"0.6B"`).
    pub param_count: &'static str,
    /// Size of the quantized weights in megabytes, as recorded in the
    /// catalog (presumably approximate — verify against the actual files).
    pub quantized_size_mb: u64,
    /// Whether `<models_dir>/<name>/model.gguf` existed when this value was
    /// built; the tokenizer file is not taken into account.
    pub downloaded: bool,
}
38
39pub struct ModelRegistry {
41 models_dir: PathBuf,
42 catalog: Vec<ModelSpec>,
43}
44
45struct ModelSpec {
46 name: &'static str,
47 hf_repo: &'static str,
48 hf_filename: &'static str,
49 tokenizer_repo: &'static str,
50 role: ModelRole,
51 param_count: &'static str,
52 quantized_size_mb: u64,
53}
54
55impl ModelRegistry {
56 pub fn new(models_dir: PathBuf) -> Self {
57 Self {
58 models_dir,
59 catalog: builtin_catalog(),
60 }
61 }
62
63 pub fn list_models(&self) -> Vec<ModelInfo> {
65 self.catalog
66 .iter()
67 .map(|spec| {
68 let local_path = self.models_dir.join(spec.name).join("model.gguf");
69 ModelInfo {
70 name: spec.name.to_string(),
71 hf_repo: spec.hf_repo.to_string(),
72 hf_filename: spec.hf_filename.to_string(),
73 tokenizer_repo: spec.tokenizer_repo.to_string(),
74 role: spec.role,
75 param_count: spec.param_count,
76 quantized_size_mb: spec.quantized_size_mb,
77 downloaded: local_path.exists(),
78 }
79 })
80 .collect()
81 }
82
83 fn find_spec(&self, name: &str) -> Option<&ModelSpec> {
85 self.catalog
86 .iter()
87 .find(|s| s.name.eq_ignore_ascii_case(name))
88 }
89
90 pub async fn ensure_model(&self, name: &str) -> Result<PathBuf, InferenceError> {
92 let spec = self
93 .find_spec(name)
94 .ok_or_else(|| InferenceError::ModelNotFound(name.to_string()))?;
95
96 let model_dir = self.models_dir.join(spec.name);
97 let model_path = model_dir.join("model.gguf");
98 let tokenizer_path = model_dir.join("tokenizer.json");
99
100 if model_path.exists() && tokenizer_path.exists() {
101 return Ok(model_dir);
102 }
103
104 std::fs::create_dir_all(&model_dir)?;
105
106 if !model_path.exists() {
108 info!(
109 model = spec.name,
110 repo = spec.hf_repo,
111 "downloading model weights"
112 );
113 download_file(spec.hf_repo, spec.hf_filename, &model_path).await?;
114 }
115
116 if !tokenizer_path.exists() {
118 info!(
119 model = spec.name,
120 repo = spec.tokenizer_repo,
121 "downloading tokenizer"
122 );
123 download_file(spec.tokenizer_repo, "tokenizer.json", &tokenizer_path).await?;
124 }
125
126 Ok(model_dir)
127 }
128
129 pub fn remove_model(&self, name: &str) -> Result<(), InferenceError> {
131 let _spec = self
132 .find_spec(name)
133 .ok_or_else(|| InferenceError::ModelNotFound(name.to_string()))?;
134
135 let model_dir = self.models_dir.join(name);
136 if model_dir.exists() {
137 std::fs::remove_dir_all(&model_dir)?;
138 info!(model = name, "removed model");
139 }
140 Ok(())
141 }
142}
143
144async fn download_file(repo: &str, filename: &str, dest: &Path) -> Result<(), InferenceError> {
146 let api = hf_hub::api::tokio::Api::new()
147 .map_err(|e| InferenceError::DownloadFailed(e.to_string()))?;
148
149 let repo = api.model(repo.to_string());
150 let path = repo
151 .get(filename)
152 .await
153 .map_err(|e| InferenceError::DownloadFailed(format!("{filename}: {e}")))?;
154
155 if dest.exists() {
157 return Ok(());
158 }
159
160 #[cfg(unix)]
162 {
163 if std::os::unix::fs::symlink(&path, dest).is_ok() {
164 return Ok(());
165 }
166 }
167
168 std::fs::copy(&path, dest)
169 .map_err(|e| InferenceError::DownloadFailed(format!("copy to {}: {e}", dest.display())))?;
170 Ok(())
171}
172
173fn builtin_catalog() -> Vec<ModelSpec> {
175 vec![
176 ModelSpec {
177 name: "Qwen3-Embedding-0.6B",
178 hf_repo: "Qwen/Qwen3-Embedding-0.6B-GGUF",
179 hf_filename: "Qwen3-Embedding-0.6B-Q8_0.gguf",
180 tokenizer_repo: "Qwen/Qwen3-Embedding-0.6B",
181 role: ModelRole::Embedding,
182 param_count: "0.6B",
183 quantized_size_mb: 639,
184 },
185 ModelSpec {
186 name: "Qwen3-0.6B",
187 hf_repo: "Qwen/Qwen3-0.6B-GGUF",
188 hf_filename: "Qwen3-0.6B-Q8_0.gguf",
189 tokenizer_repo: "Qwen/Qwen3-0.6B",
190 role: ModelRole::Small,
191 param_count: "0.6B",
192 quantized_size_mb: 650,
193 },
194 ModelSpec {
195 name: "Qwen3-1.7B",
196 hf_repo: "Qwen/Qwen3-1.7B-GGUF",
197 hf_filename: "Qwen3-1.7B-Q8_0.gguf",
198 tokenizer_repo: "Qwen/Qwen3-1.7B",
199 role: ModelRole::Medium,
200 param_count: "1.7B",
201 quantized_size_mb: 1800,
202 },
203 ModelSpec {
204 name: "Qwen3-4B",
205 hf_repo: "Qwen/Qwen3-4B-GGUF",
206 hf_filename: "Qwen3-4B-Q4_K_M.gguf",
207 tokenizer_repo: "Qwen/Qwen3-4B",
208 role: ModelRole::Medium,
209 param_count: "4B",
210 quantized_size_mb: 2500,
211 },
212 ModelSpec {
213 name: "Qwen3-8B",
214 hf_repo: "Qwen/Qwen3-8B-GGUF",
215 hf_filename: "Qwen3-8B-Q4_K_M.gguf",
216 tokenizer_repo: "Qwen/Qwen3-8B",
217 role: ModelRole::Large,
218 param_count: "8B",
219 quantized_size_mb: 4900,
220 },
221 ModelSpec {
222 name: "Qwen3-30B-A3B",
223 hf_repo: "Qwen/Qwen3-30B-A3B-GGUF",
224 hf_filename: "Qwen3-30B-A3B-Q4_K_M.gguf",
225 tokenizer_repo: "Qwen/Qwen3-30B-A3B",
226 role: ModelRole::Expert,
227 param_count: "30B (3B active)",
228 quantized_size_mb: 17000,
229 },
230 ]
231}