orchard/model/
resolver.rs1use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5
6use hf_hub::api::tokio::{Api, ApiBuilder};
7use serde::{Deserialize, Serialize};
8
9use crate::error::{Error, Result};
10
/// A model identifier resolved to a concrete directory on disk, as produced
/// by `ModelResolver::resolve`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResolvedModel {
    /// Identifier the model is known by. When resolved from the Hub this is
    /// the requested repo id; otherwise it is derived from `config.json`
    /// fields or the directory name (falling back to `"unknown"`).
    pub canonical_id: String,
    /// Absolute path of the directory that contains the model's `config.json`.
    pub model_path: PathBuf,
    /// Where the model was found: `"local"`, `"hf_cache"` or `"hf_hub"`.
    pub source: String,
    /// Selected `config.json` fields, each rendered as a string.
    pub metadata: HashMap<String, String>,
    /// Hugging Face Hub repository id, if one was requested or could be
    /// inferred from `config.json`.
    pub hf_repo: Option<String>,
}
20
21pub struct ModelResolver {
23 resolved_cache: HashMap<String, ResolvedModel>,
24 hf_api: Api,
25}
26
27impl ModelResolver {
28 pub fn new() -> Result<Self> {
30 Ok(Self {
31 resolved_cache: HashMap::new(),
32 hf_api: ApiBuilder::from_env()
33 .build()
34 .map_err(|e| Error::HfApiInit(e.to_string()))?,
35 })
36 }
37
38 pub async fn resolve(&mut self, requested_id: &str) -> Result<ResolvedModel> {
45 let identifier = requested_id.trim();
46 if identifier.is_empty() {
47 return Err(Error::EmptyModelId);
48 }
49
50 let cache_key = identifier.to_lowercase();
52 if let Some(cached) = self.resolved_cache.get(&cache_key) {
53 return Ok(cached.clone());
54 }
55
56 if let Some(resolved) = self.try_local_path(identifier).await? {
58 self.resolved_cache.insert(cache_key, resolved.clone());
59 return Ok(resolved);
60 }
61
62 let resolved = self.resolve_huggingface(identifier).await?;
64
65 self.resolved_cache.insert(cache_key, resolved.clone());
66 Ok(resolved)
67 }
68
69 pub fn clear_cache(&mut self) {
71 self.resolved_cache.clear();
72 }
73
74 async fn try_local_path(&self, identifier: &str) -> Result<Option<ResolvedModel>> {
75 let path = PathBuf::from(identifier);
76
77 if path.is_absolute() && path.is_dir() {
79 return Ok(Some(
80 self.build_resolved_model(path, "local", None, None).await?,
81 ));
82 }
83
84 if path.is_dir() {
86 let resolved = std::fs::canonicalize(&path)?;
87 return Ok(Some(
88 self.build_resolved_model(resolved, "local", None, None)
89 .await?,
90 ));
91 }
92
93 Ok(None)
94 }
95
96 async fn resolve_huggingface(&self, repo_id: &str) -> Result<ResolvedModel> {
97 let repo = self.hf_api.model(repo_id.to_string());
98
99 let path = match repo.get("config.json").await {
101 Ok(config_path) => config_path
102 .parent()
103 .map(|p| p.to_path_buf())
104 .unwrap_or_else(|| config_path),
105 Err(e) => {
106 return Err(Error::DownloadFailed(repo_id.to_string(), e.to_string()));
107 }
108 };
109
110 let source = if path.to_string_lossy().contains("cache") {
111 "hf_cache"
112 } else {
113 "hf_hub"
114 };
115
116 self.build_resolved_model(path, source, Some(repo_id), Some(repo_id))
117 .await
118 }
119
120 async fn build_resolved_model(
121 &self,
122 model_path: PathBuf,
123 source: &str,
124 canonical_id: Option<&str>,
125 hf_repo: Option<&str>,
126 ) -> Result<ResolvedModel> {
127 let model_path = if model_path.is_absolute() {
128 model_path
129 } else {
130 std::fs::canonicalize(&model_path)?
131 };
132
133 let config = self.load_config(&model_path)?;
135 let metadata = Self::collect_metadata(&config);
136
137 let canonical_id = canonical_id
139 .map(String::from)
140 .or_else(|| Self::determine_canonical_id(&config, &model_path))
141 .unwrap_or_else(|| {
142 model_path
143 .file_name()
144 .map(|n| n.to_string_lossy().to_string())
145 .unwrap_or_else(|| "unknown".to_string())
146 });
147
148 let hf_repo = hf_repo
150 .map(String::from)
151 .or_else(|| Self::infer_hf_repo(&config));
152
153 Ok(ResolvedModel {
154 canonical_id,
155 model_path,
156 source: source.to_string(),
157 metadata,
158 hf_repo,
159 })
160 }
161
162 fn load_config(&self, model_dir: &Path) -> Result<serde_json::Value> {
163 let config_file = model_dir.join("config.json");
164 if !config_file.exists() {
165 return Err(Error::MissingConfig(model_dir.to_path_buf()));
166 }
167
168 let content = std::fs::read_to_string(&config_file)?;
169 serde_json::from_str(&content).map_err(Error::from)
170 }
171
172 fn determine_canonical_id(config: &serde_json::Value, model_dir: &Path) -> Option<String> {
173 config
174 .get("_name_or_path")
175 .and_then(|v| v.as_str())
176 .filter(|s| !s.is_empty())
177 .map(String::from)
178 .or_else(|| {
179 config
180 .get("model_id")
181 .and_then(|v| v.as_str())
182 .map(String::from)
183 })
184 .or_else(|| {
185 model_dir
186 .file_name()
187 .map(|n| n.to_string_lossy().to_string())
188 })
189 }
190
191 fn infer_hf_repo(config: &serde_json::Value) -> Option<String> {
192 let candidate = config
193 .get("_name_or_path")
194 .or_else(|| config.get("original_repo"))
195 .and_then(|v| v.as_str());
196
197 candidate
198 .filter(|s| s.contains('/') && !s.starts_with('/'))
199 .map(String::from)
200 }
201
202 fn collect_metadata(config: &serde_json::Value) -> HashMap<String, String> {
203 let mut metadata = HashMap::new();
204
205 let keys = [
206 "model_type",
207 "hidden_size",
208 "num_hidden_layers",
209 "architecture",
210 ];
211
212 for key in keys {
213 if let Some(value) = config.get(key) {
214 let str_value = match value {
215 serde_json::Value::String(s) => s.clone(),
216 serde_json::Value::Number(n) => n.to_string(),
217 serde_json::Value::Bool(b) => b.to_string(),
218 serde_json::Value::Object(_) | serde_json::Value::Array(_) => {
219 serde_json::to_string(value).unwrap_or_default()
220 }
221 serde_json::Value::Null => continue,
222 };
223 metadata.insert(key.to_string(), str_value);
224 }
225 }
226
227 if let Some(quant_cfg) = config
229 .get("quantization_config")
230 .or_else(|| config.get("quantization"))
231 {
232 if let Some(bits) = quant_cfg
233 .get("bits")
234 .or_else(|| quant_cfg.get("num_bits"))
235 .and_then(|v| v.as_u64())
236 {
237 metadata.insert("quantization_bits".to_string(), bits.to_string());
238 }
239 }
240
241 metadata
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    /// A newly built resolver has nothing memoised yet.
    #[test]
    fn test_resolver_creation() {
        let resolver = ModelResolver::new().unwrap();
        assert!(resolver.resolved_cache.is_empty());
    }

    /// Scalar fields are stringified and quantization bits are surfaced.
    #[test]
    fn test_collect_metadata() {
        let metadata = ModelResolver::collect_metadata(&serde_json::json!({
            "model_type": "llama",
            "hidden_size": 4096,
            "num_hidden_layers": 32,
            "quantization_config": { "bits": 4 }
        }));

        assert_eq!(metadata.get("model_type").map(String::as_str), Some("llama"));
        assert_eq!(metadata.get("hidden_size").map(String::as_str), Some("4096"));
        assert_eq!(metadata.get("quantization_bits").map(String::as_str), Some("4"));
    }

    /// A `namespace/name` value in `_name_or_path` is taken as the Hub repo.
    #[test]
    fn test_infer_hf_repo() {
        let repo_id = "meta-llama/Llama-3.1-8B-Instruct";
        let config = serde_json::json!({ "_name_or_path": repo_id });

        assert_eq!(ModelResolver::infer_hf_repo(&config).as_deref(), Some(repo_id));
    }

    /// An absolute filesystem path is not mistaken for a Hub repo id.
    #[test]
    fn test_infer_hf_repo_local_path() {
        let config = serde_json::json!({ "_name_or_path": "/local/path/to/model" });

        assert!(ModelResolver::infer_hf_repo(&config).is_none());
    }
}