// orchard/model/resolver.rs
1//! Model resolution utilities.
2
3use std::collections::HashMap;
4use std::path::{Path, PathBuf};
5
6use hf_hub::api::tokio::{Api, ApiBuilder};
7use serde::{Deserialize, Serialize};
8
9use crate::error::{Error, Result};
10
/// Result of resolving a model identifier to a concrete on-disk model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResolvedModel {
    /// Canonical identifier: an explicit ID, a config-derived value, or the
    /// model directory name as a last resort.
    pub canonical_id: String,
    /// Path to the model directory containing `config.json` (absolute when
    /// the resolver could canonicalize it).
    pub model_path: PathBuf,
    /// Origin of the files: "local", "hf_cache", or "hf_hub".
    pub source: String,
    /// Selected `config.json` entries, stringified (e.g. `model_type`,
    /// `hidden_size`, quantization bits).
    pub metadata: HashMap<String, String>,
    /// HuggingFace repo ID, when supplied or inferable from the config.
    pub hf_repo: Option<String>,
}
20
/// Resolves model identifiers to local filesystem paths.
///
/// Holds a memoization cache of prior resolutions and a HuggingFace Hub
/// client (built from environment configuration) for repo-ID lookups.
pub struct ModelResolver {
    // Cache of previously resolved identifiers, so repeated resolves of the
    // same model skip filesystem probing and network access.
    resolved_cache: HashMap<String, ResolvedModel>,
    // HF Hub API client used by `resolve_huggingface`.
    hf_api: Api,
}
26
27impl ModelResolver {
28    /// Create a new model resolver.
29    pub fn new() -> Result<Self> {
30        Ok(Self {
31            resolved_cache: HashMap::new(),
32            hf_api: ApiBuilder::from_env()
33                .build()
34                .map_err(|e| Error::HfApiInit(e.to_string()))?,
35        })
36    }
37
38    /// Resolve a model identifier to a local filesystem path.
39    ///
40    /// # Arguments
41    /// * `requested_id` - Model identifier, which can be:
42    ///   - Local path: `/path/to/model` or `./relative/path`
43    ///   - HF repo ID: `meta-llama/Llama-3.1-8B-Instruct`
44    pub async fn resolve(&mut self, requested_id: &str) -> Result<ResolvedModel> {
45        let identifier = requested_id.trim();
46        if identifier.is_empty() {
47            return Err(Error::EmptyModelId);
48        }
49
50        // Check cache first
51        let cache_key = identifier.to_lowercase();
52        if let Some(cached) = self.resolved_cache.get(&cache_key) {
53            return Ok(cached.clone());
54        }
55
56        // 1. Try as local path
57        if let Some(resolved) = self.try_local_path(identifier).await? {
58            self.resolved_cache.insert(cache_key, resolved.clone());
59            return Ok(resolved);
60        }
61
62        // 2. Resolve via HuggingFace
63        let resolved = self.resolve_huggingface(identifier).await?;
64
65        self.resolved_cache.insert(cache_key, resolved.clone());
66        Ok(resolved)
67    }
68
69    /// Clear the resolution cache.
70    pub fn clear_cache(&mut self) {
71        self.resolved_cache.clear();
72    }
73
74    async fn try_local_path(&self, identifier: &str) -> Result<Option<ResolvedModel>> {
75        let path = PathBuf::from(identifier);
76
77        // Check absolute path
78        if path.is_absolute() && path.is_dir() {
79            return Ok(Some(
80                self.build_resolved_model(path, "local", None, None).await?,
81            ));
82        }
83
84        // Check relative path
85        if path.is_dir() {
86            let resolved = std::fs::canonicalize(&path)?;
87            return Ok(Some(
88                self.build_resolved_model(resolved, "local", None, None)
89                    .await?,
90            ));
91        }
92
93        Ok(None)
94    }
95
96    async fn resolve_huggingface(&self, repo_id: &str) -> Result<ResolvedModel> {
97        let repo = self.hf_api.model(repo_id.to_string());
98
99        // Try to get from cache first, then download if needed
100        let path = match repo.get("config.json").await {
101            Ok(config_path) => config_path
102                .parent()
103                .map(|p| p.to_path_buf())
104                .unwrap_or_else(|| config_path),
105            Err(e) => {
106                return Err(Error::DownloadFailed(repo_id.to_string(), e.to_string()));
107            }
108        };
109
110        let source = if path.to_string_lossy().contains("cache") {
111            "hf_cache"
112        } else {
113            "hf_hub"
114        };
115
116        self.build_resolved_model(path, source, Some(repo_id), Some(repo_id))
117            .await
118    }
119
120    async fn build_resolved_model(
121        &self,
122        model_path: PathBuf,
123        source: &str,
124        canonical_id: Option<&str>,
125        hf_repo: Option<&str>,
126    ) -> Result<ResolvedModel> {
127        let model_path = if model_path.is_absolute() {
128            model_path
129        } else {
130            std::fs::canonicalize(&model_path)?
131        };
132
133        // Load and parse config
134        let config = self.load_config(&model_path)?;
135        let metadata = Self::collect_metadata(&config);
136
137        // Determine canonical ID
138        let canonical_id = canonical_id
139            .map(String::from)
140            .or_else(|| Self::determine_canonical_id(&config, &model_path))
141            .unwrap_or_else(|| {
142                model_path
143                    .file_name()
144                    .map(|n| n.to_string_lossy().to_string())
145                    .unwrap_or_else(|| "unknown".to_string())
146            });
147
148        // Infer HF repo
149        let hf_repo = hf_repo
150            .map(String::from)
151            .or_else(|| Self::infer_hf_repo(&config));
152
153        Ok(ResolvedModel {
154            canonical_id,
155            model_path,
156            source: source.to_string(),
157            metadata,
158            hf_repo,
159        })
160    }
161
162    fn load_config(&self, model_dir: &Path) -> Result<serde_json::Value> {
163        let config_file = model_dir.join("config.json");
164        if !config_file.exists() {
165            return Err(Error::MissingConfig(model_dir.to_path_buf()));
166        }
167
168        let content = std::fs::read_to_string(&config_file)?;
169        serde_json::from_str(&content).map_err(Error::from)
170    }
171
172    fn determine_canonical_id(config: &serde_json::Value, model_dir: &Path) -> Option<String> {
173        config
174            .get("_name_or_path")
175            .and_then(|v| v.as_str())
176            .filter(|s| !s.is_empty())
177            .map(String::from)
178            .or_else(|| {
179                config
180                    .get("model_id")
181                    .and_then(|v| v.as_str())
182                    .map(String::from)
183            })
184            .or_else(|| {
185                model_dir
186                    .file_name()
187                    .map(|n| n.to_string_lossy().to_string())
188            })
189    }
190
191    fn infer_hf_repo(config: &serde_json::Value) -> Option<String> {
192        let candidate = config
193            .get("_name_or_path")
194            .or_else(|| config.get("original_repo"))
195            .and_then(|v| v.as_str());
196
197        candidate
198            .filter(|s| s.contains('/') && !s.starts_with('/'))
199            .map(String::from)
200    }
201
202    fn collect_metadata(config: &serde_json::Value) -> HashMap<String, String> {
203        let mut metadata = HashMap::new();
204
205        let keys = [
206            "model_type",
207            "hidden_size",
208            "num_hidden_layers",
209            "architecture",
210        ];
211
212        for key in keys {
213            if let Some(value) = config.get(key) {
214                let str_value = match value {
215                    serde_json::Value::String(s) => s.clone(),
216                    serde_json::Value::Number(n) => n.to_string(),
217                    serde_json::Value::Bool(b) => b.to_string(),
218                    serde_json::Value::Object(_) | serde_json::Value::Array(_) => {
219                        serde_json::to_string(value).unwrap_or_default()
220                    }
221                    serde_json::Value::Null => continue,
222                };
223                metadata.insert(key.to_string(), str_value);
224            }
225        }
226
227        // Handle quantization config
228        if let Some(quant_cfg) = config
229            .get("quantization_config")
230            .or_else(|| config.get("quantization"))
231        {
232            if let Some(bits) = quant_cfg
233                .get("bits")
234                .or_else(|| quant_cfg.get("num_bits"))
235                .and_then(|v| v.as_u64())
236            {
237                metadata.insert("quantization_bits".to_string(), bits.to_string());
238            }
239        }
240
241        metadata
242    }
243}
244
#[cfg(test)]
mod tests {
    use super::*;

    // A freshly built resolver must start with an empty resolution cache.
    #[test]
    fn test_resolver_creation() {
        let resolver = ModelResolver::new().unwrap();
        assert!(resolver.resolved_cache.is_empty());
    }

    // Metadata extraction picks up plain keys, numeric keys, and the nested
    // quantization bit-width.
    #[test]
    fn test_collect_metadata() {
        let config = serde_json::json!({
            "model_type": "llama",
            "hidden_size": 4096,
            "num_hidden_layers": 32,
            "quantization_config": { "bits": 4 }
        });

        let metadata = ModelResolver::collect_metadata(&config);

        let expect = [
            ("model_type", "llama"),
            ("hidden_size", "4096"),
            ("quantization_bits", "4"),
        ];
        for (key, want) in expect {
            assert_eq!(metadata.get(key).map(String::as_str), Some(want));
        }
    }

    // An org/name value in `_name_or_path` is treated as an HF repo ID.
    #[test]
    fn test_infer_hf_repo() {
        let config = serde_json::json!({
            "_name_or_path": "meta-llama/Llama-3.1-8B-Instruct"
        });

        assert_eq!(
            ModelResolver::infer_hf_repo(&config).as_deref(),
            Some("meta-llama/Llama-3.1-8B-Instruct")
        );
    }

    // An absolute local path must not be mistaken for an HF repo ID.
    #[test]
    fn test_infer_hf_repo_local_path() {
        let config = serde_json::json!({
            "_name_or_path": "/local/path/to/model"
        });

        assert!(ModelResolver::infer_hf_repo(&config).is_none());
    }
}