Skip to main content

ai_lib_rust/protocol/
loader.rs

1//! Protocol loader with support for local files, embedded assets, and remote URLs
2//! Heartbeat sync - 2026-01-06
3//! Includes hot-reload capability using ArcSwap
4
5use crate::protocol::{ProtocolError, ProtocolManifest};
6use arc_swap::ArcSwap;
7use lru::LruCache;
8use std::path::{Path, PathBuf};
9use std::sync::{Arc, Mutex};
10
11/// Protocol loader that supports multiple sources
12pub struct ProtocolLoader {
13    base_path: Option<PathBuf>,
14    hot_reload: bool,
15    validator: crate::protocol::validator::ProtocolValidator,
16    cache: Mutex<LruCache<String, Arc<ProtocolManifest>>>,
17}
18
19impl ProtocolLoader {
20    pub fn new() -> Self {
21        Self {
22            base_path: None,
23            hot_reload: false,
24            validator: crate::protocol::validator::ProtocolValidator::default(),
25            // Use 100 as default cache size
26            // NonZeroUsize::new(100) is guaranteed to be Some, but use expect for clarity
27            cache: Mutex::new(LruCache::new(
28                std::num::NonZeroUsize::new(100)
29                    .expect("Cache size must be non-zero (this should never happen)"),
30            )),
31        }
32    }
33
34    /// Set base path for protocol files
35    pub fn with_base_path(mut self, path: impl AsRef<Path>) -> Self {
36        self.base_path = Some(path.as_ref().to_path_buf());
37        self
38    }
39
40    /// Enable hot reload
41    pub fn with_hot_reload(mut self, enable: bool) -> Self {
42        self.hot_reload = enable;
43        self
44    }
45
46    /// Load a model configuration
47    /// Model identifier format: "provider/model-name"
48    pub async fn load_model(&self, model: &str) -> Result<ProtocolManifest, ProtocolError> {
49        // 1. Check Cache
50        {
51            let mut cache = self.cache.lock().map_err(|e| {
52                ProtocolError::Internal(format!(
53                    "Failed to acquire cache lock while loading model '{}': {}",
54                    model, e
55                ))
56            })?;
57            if let Some(manifest) = cache.get(model) {
58                return Ok(manifest.as_ref().clone());
59            }
60        }
61
62        // Allow "provider/model" or "provider/org/model-name" (e.g. nvidia/minimaxai/minimax-m2)
63        let parts: Vec<&str> = model.split('/').collect();
64        if parts.len() < 2 {
65            return Err(ProtocolError::NotFound {
66                id: model.to_string(),
67                hint: Some("Ensure the model name follows the 'provider/model' format".to_string()),
68            });
69        }
70
71        let provider = parts[0];
72        let model_name = parts[1..].join("/");
73
74        // First, try to load model registry to get provider reference.
75        // If registry doesn't contain this model (common for providers like deepseek),
76        // fall back to loading provider manifest directly using the provider segment.
77        let manifest = match self.load_model_config(&model_name).await {
78            Ok(model_config) => self.load_provider(&model_config.provider).await?,
79            Err(ProtocolError::NotFound { .. }) => self.load_provider(provider).await?,
80            Err(e) => return Err(e),
81        };
82
83        // 2. Update Cache
84        {
85            let mut cache = self.cache.lock().map_err(|e| {
86                ProtocolError::Internal(format!(
87                    "Failed to acquire cache lock while caching model '{}': {}",
88                    model, e
89                ))
90            })?;
91            cache.put(model.to_string(), Arc::new(manifest.clone()));
92        }
93
94        Ok(manifest)
95    }
96
97    /// Load provider configuration
98    pub async fn load_provider(
99        &self,
100        provider_id: &str,
101    ) -> Result<ProtocolManifest, ProtocolError> {
102        // Try multiple sources in order:
103        // 1. Local file system (dist JSON) - PREFERRED
104        // 2. Local file system (source YAML) - FALLBACK
105        // 3. GitHub URL (if AI_PROTOCOL_DIR is a URL)
106        // 4. Embedded assets (future)
107
108        // Path prioritization helper
109        let mut search_locations: Vec<(PathBuf, bool)> = Vec::new(); // (path_base, is_json_preferred)
110
111        // 1. Check user-configured base_path
112        if let Some(ref base_path) = self.base_path {
113            // Priority 1: dist/v1/providers/{id}.json
114            search_locations.push((base_path.join("dist").join("v1").join("providers"), true));
115            // Priority 2: v1/providers/{id}.yaml
116            search_locations.push((base_path.join("v1").join("providers"), false));
117        }
118
119        // 2. Check AI_PROTOCOL_DIR Env Var
120        if let Ok(root) =
121            std::env::var("AI_PROTOCOL_DIR").or_else(|_| std::env::var("AI_PROTOCOL_PATH"))
122        {
123            if root.starts_with("http://") || root.starts_with("https://") {
124                // Handling URL sources (Remote)
125                // Try JSON first if it looks like a raw github url, but typically raw github urls are specific.
126                // For simplicity, we stick to the existing logic for URLs but could enhance later to try .json
127                let url = if root.ends_with('/') {
128                    format!("{}dist/v1/providers/{}.json", root, provider_id)
129                } else {
130                    format!("{}/dist/v1/providers/{}.json", root, provider_id)
131                };
132
133                // Try JSON from remote
134                if let Ok(manifest) = self.load_from_json_url(&url).await {
135                    return Ok(manifest);
136                }
137
138                // Fallback to YAML from remote
139                let url_yaml = if root.ends_with('/') {
140                    format!("{}v1/providers/{}.yaml", root, provider_id)
141                } else {
142                    format!("{}/v1/providers/{}.yaml", root, provider_id)
143                };
144                return self.load_from_url(&url_yaml).await;
145            } else {
146                // Local Path from Env
147                let root = PathBuf::from(root);
148                search_locations.push((root.join("dist").join("v1").join("providers"), true));
149                search_locations.push((root.join("v1").join("providers"), false));
150            }
151        }
152
153        // 3. Default dev locations
154        let default_roots = vec![
155            PathBuf::from("ai-protocol"),
156            PathBuf::from("../ai-protocol"),
157            PathBuf::from("../../ai-protocol"),
158            PathBuf::from("D:\\ai-protocol"),
159        ];
160
161        for root in default_roots {
162            search_locations.push((root.join("dist").join("v1").join("providers"), true));
163            search_locations.push((root.join("v1").join("providers"), false));
164        }
165
166        // Execute Search
167        for (base, prefer_json) in search_locations {
168            if prefer_json {
169                let json_path = base.join(format!("{}.json", provider_id));
170                if json_path.exists() {
171                    return self.load_from_json_file(&json_path).await;
172                }
173            } else {
174                let yaml_path = base.join(format!("{}.yaml", provider_id));
175                if yaml_path.exists() {
176                    return self.load_from_file(&yaml_path).await;
177                }
178            }
179        }
180
181        // Last resort: try GitHub raw URL (canonical source) - JSON
182        let github_json = format!(
183            "https://raw.githubusercontent.com/hiddenpath/ai-protocol/main/dist/v1/providers/{}.json",
184            provider_id
185        );
186        if let Ok(manifest) = self.load_from_json_url(&github_json).await {
187            return Ok(manifest);
188        }
189
190        // Last resort fallback: YAML
191        let github_yaml = format!(
192            "https://raw.githubusercontent.com/hiddenpath/ai-protocol/main/v1/providers/{}.yaml",
193            provider_id
194        );
195        if let Ok(manifest) = self.load_from_url(&github_yaml).await {
196            return Ok(manifest);
197        }
198
199        Err(ProtocolError::NotFound {
200            id: provider_id.to_string(),
201            hint: Some(format!(
202                "Check if the provider file '{}.json' or '{}.yaml' exists in your protocol directory",
203                provider_id, provider_id
204            )),
205        })
206    }
207
208    /// Load protocol from local JSON file (Fast Path)
209    async fn load_from_json_file(&self, path: &Path) -> Result<ProtocolManifest, ProtocolError> {
210        let content = tokio::fs::read(path)
211            .await
212            .map_err(|e| ProtocolError::LoadError {
213                path: path.to_string_lossy().to_string(),
214                reason: e.to_string(),
215                hint: Some("Check file permissions.".to_string()),
216            })?;
217
218        let manifest: ProtocolManifest = serde_json::from_slice(&content)
219            .map_err(|e| ProtocolError::ValidationError(format!("Invalid JSON manifest: {}", e)))?;
220
221        // Validate against JSON Schema (Optional but recommended even for dist)
222        // For max speed, we might skip this in release, but keeping for safety now.
223        self.validator.validate(&manifest)?;
224
225        Ok(manifest)
226    }
227
228    /// Load protocol from local YAML file (Legacy/Dev Path)
229    async fn load_from_file(&self, path: &Path) -> Result<ProtocolManifest, ProtocolError> {
230        // Read as bytes first to handle different encodings
231        let bytes = tokio::fs::read(path)
232            .await
233            .map_err(|e| ProtocolError::LoadError {
234                path: path.to_string_lossy().to_string(),
235                reason: e.to_string(),
236                hint: Some("Check if the file exists and you have read permissions.".to_string()),
237            })?;
238
239        // ... (encoding detection remains same)
240        let content = if bytes.len() >= 2 && bytes[0] == 0xFF && bytes[1] == 0xFE {
241            // UTF-16 LE with BOM
242            let utf16_bytes = &bytes[2..];
243            let mut utf16_chars = Vec::new();
244            for i in (0..utf16_bytes.len()).step_by(2) {
245                if i + 1 < utf16_bytes.len() {
246                    let code_unit = u16::from_le_bytes([utf16_bytes[i], utf16_bytes[i + 1]]);
247                    utf16_chars.push(code_unit);
248                }
249            }
250            String::from_utf16(&utf16_chars).map_err(|e| ProtocolError::LoadError {
251                path: path.to_string_lossy().to_string(),
252                reason: format!("Invalid UTF-16: {}", e),
253                hint: None,
254            })?
255        } else if bytes.len() >= 3 && bytes[0] == 0xEF && bytes[1] == 0xBB && bytes[2] == 0xBF {
256            String::from_utf8(bytes[3..].to_vec()).map_err(|e| ProtocolError::LoadError {
257                path: path.to_string_lossy().to_string(),
258                reason: format!("Invalid UTF-8 (after BOM): {}", e),
259                hint: None,
260            })?
261        } else {
262            String::from_utf8(bytes).map_err(|e| ProtocolError::LoadError {
263                path: path.to_string_lossy().to_string(),
264                reason: format!("Invalid UTF-8: {}", e),
265                hint: None,
266            })?
267        };
268
269        let manifest: ProtocolManifest = Self::parse_manifest_yaml(&content)?;
270        self.validator.validate(&manifest)?;
271        Ok(manifest)
272    }
273
274    /// Load protocol from remote JSON URL
275    async fn load_from_json_url(&self, url: &str) -> Result<ProtocolManifest, ProtocolError> {
276        let client = reqwest::Client::builder()
277            .timeout(std::time::Duration::from_secs(30))
278            .build()
279            .map_err(|e| ProtocolError::Internal(format!("Failed to create HTTP client: {}", e)))?;
280
281        let response = client
282            .get(url)
283            .send()
284            .await
285            .map_err(|e| ProtocolError::LoadError {
286                path: url.to_string(),
287                reason: format!("HTTP request failed: {}", e),
288                hint: None,
289            })?;
290
291        if !response.status().is_success() {
292            return Err(ProtocolError::LoadError {
293                path: url.to_string(),
294                reason: format!("HTTP {}", response.status()),
295                hint: None,
296            });
297        }
298
299        let content = response
300            .bytes()
301            .await
302            .map_err(|e| ProtocolError::LoadError {
303                path: url.to_string(),
304                reason: format!("Failed to read bytes: {}", e),
305                hint: None,
306            })?;
307
308        let manifest: ProtocolManifest = serde_json::from_slice(&content).map_err(|e| {
309            ProtocolError::ValidationError(format!("Invalid JSON manifest from URL: {}", e))
310        })?;
311
312        self.validator.validate(&manifest)?;
313        Ok(manifest)
314    }
315
316    /// Load protocol from remote URL (GitHub raw URL)
317    async fn load_from_url(&self, url: &str) -> Result<ProtocolManifest, ProtocolError> {
318        let client = reqwest::Client::builder()
319            .timeout(std::time::Duration::from_secs(30))
320            .build()
321            .map_err(|e| ProtocolError::Internal(format!("Failed to create HTTP client: {}", e)))?;
322
323        let response = client
324            .get(url)
325            .send()
326            .await
327            .map_err(|e| ProtocolError::LoadError {
328                path: url.to_string(),
329                reason: format!("HTTP request failed: {}", e),
330                hint: Some(
331                    "Check your internet connection and verify the URL is accessible.".to_string(),
332                ),
333            })?;
334
335        if !response.status().is_success() {
336            return Err(ProtocolError::LoadError {
337                path: url.to_string(),
338                reason: format!(
339                    "HTTP {}: {}",
340                    response.status(),
341                    response.text().await.unwrap_or_default()
342                ),
343                hint: Some(
344                    "Verify the remote registry URL and your API permissions if any.".to_string(),
345                ),
346            });
347        }
348
349        let content = response
350            .text()
351            .await
352            .map_err(|e| ProtocolError::LoadError {
353                path: url.to_string(),
354                reason: format!("Failed to read response: {}", e),
355                hint: None,
356            })?;
357
358        let manifest: ProtocolManifest = Self::parse_manifest_yaml(&content)?;
359
360        // Validate against JSON Schema
361        self.validator.validate(&manifest)?;
362
363        Ok(manifest)
364    }
365
366    /// Parse YAML into a ProtocolManifest with better error classification.
367    ///
368    /// Rationale:
369    /// - YAML syntax/encoding issues are "load" errors.
370    /// - Structural mismatches (missing required fields, wrong types) are "validation" errors.
371    fn parse_manifest_yaml(content: &str) -> Result<ProtocolManifest, ProtocolError> {
372        serde_yaml::from_str::<ProtocolManifest>(content).map_err(|e| {
373            let msg = e.to_string();
374            // Heuristic classification based on serde error messages.
375            // This keeps public error categories stable without pulling in serde internals.
376            let looks_structural = msg.contains("missing field")
377                || msg.contains("unknown field")
378                || msg.contains("invalid type")
379                || msg.contains("invalid value")
380                || msg.contains("expected");
381
382            if looks_structural {
383                ProtocolError::ValidationError(format!("Invalid manifest structure: {}", msg))
384            } else {
385                ProtocolError::YamlError(msg)
386            }
387        })
388    }
389
390    /// Load model configuration from registry
391    async fn load_model_config(&self, model_name: &str) -> Result<ModelConfig, ProtocolError> {
392        // Try to find model, scanning registries.
393        // Priority: dist/v1/models/*.json -> v1/models/*.yaml
394
395        let mut search_locations: Vec<(PathBuf, bool)> = Vec::new(); // (path_base, is_json_preferred)
396
397        // 1. Env Var AI_PROTOCOL_DIR
398        if let Ok(root) =
399            std::env::var("AI_PROTOCOL_DIR").or_else(|_| std::env::var("AI_PROTOCOL_PATH"))
400        {
401            // If HTTP, skipped here as typically model config loading implies local or full repo clone access.
402            // If we really need remote model loading, we'd need a different strategy (scanning a remote index).
403            // For now, assume local model registry for this heuristic.
404            if !root.starts_with("http://") && !root.starts_with("https://") {
405                let root = PathBuf::from(root);
406                search_locations.push((root.join("dist").join("v1").join("models"), true));
407                search_locations.push((root.join("v1").join("models"), false));
408            }
409        }
410
411        // 2. Default paths
412        let default_roots = vec![
413            PathBuf::from("ai-protocol"),
414            PathBuf::from("../ai-protocol"),
415            PathBuf::from("../../ai-protocol"),
416            PathBuf::from("D:\\ai-protocol"),
417        ];
418
419        for root in default_roots {
420            search_locations.push((root.join("dist").join("v1").join("models"), true));
421            search_locations.push((root.join("v1").join("models"), false));
422        }
423
424        for (base, prefer_json) in search_locations {
425            if !base.exists() {
426                continue;
427            }
428            let mut rd = match tokio::fs::read_dir(&base).await {
429                Ok(rd) => rd,
430                Err(_) => continue,
431            };
432
433            while let Ok(Some(entry)) = rd.next_entry().await {
434                let path = entry.path();
435                let extension = path.extension().and_then(|s| s.to_str());
436
437                let is_match = if prefer_json {
438                    extension.map(|s| s.eq_ignore_ascii_case("json")) == Some(true)
439                } else {
440                    extension
441                        .map(|s| s.eq_ignore_ascii_case("yaml") || s.eq_ignore_ascii_case("yml"))
442                        == Some(true)
443                };
444
445                if !is_match {
446                    continue;
447                }
448
449                if prefer_json {
450                    if let Ok(config) = self.load_model_registry_json(&path).await {
451                        if let Some(model) = config.models.get(model_name) {
452                            return Ok(model.clone());
453                        }
454                    }
455                } else {
456                    if let Ok(config) = self.load_model_registry_yaml(&path).await {
457                        if let Some(model) = config.models.get(model_name) {
458                            return Ok(model.clone());
459                        }
460                    }
461                }
462            }
463        }
464
465        Err(ProtocolError::NotFound {
466            id: model_name.to_string(),
467            hint: Some(
468                "Check if the model is registered in the manifests/v1/models/ directory"
469                    .to_string(),
470            ),
471        })
472    }
473
474    async fn load_model_registry_json(&self, path: &Path) -> Result<ModelRegistry, ProtocolError> {
475        let content = tokio::fs::read(path)
476            .await
477            .map_err(|e| ProtocolError::LoadError {
478                path: path.to_string_lossy().to_string(),
479                reason: e.to_string(),
480                hint: None,
481            })?;
482        let registry: ModelRegistry = serde_json::from_slice(&content).map_err(|e| {
483            ProtocolError::ValidationError(format!("Invalid JSON model registry: {}", e))
484        })?;
485        Ok(registry)
486    }
487
488    async fn load_model_registry_yaml(&self, path: &Path) -> Result<ModelRegistry, ProtocolError> {
489        let content =
490            tokio::fs::read_to_string(path)
491                .await
492                .map_err(|e| ProtocolError::LoadError {
493                    path: path.to_string_lossy().to_string(),
494                    reason: format!("Failed to read model registry: {}", e),
495                    hint: None,
496                })?;
497
498        let registry: ModelRegistry = serde_yaml::from_str(&content).map_err(|e| {
499            ProtocolError::YamlError(format!("Failed to parse model registry: {}", e))
500        })?;
501
502        Ok(registry)
503    }
504}
505
506impl Default for ProtocolLoader {
507    fn default() -> Self {
508        Self::new()
509    }
510}
511
512/// Model registry structure
513#[derive(Debug, Clone, serde::Deserialize)]
514struct ModelRegistry {
515    models: std::collections::HashMap<String, ModelConfig>,
516}
517
518/// Model configuration from registry
519#[allow(dead_code)]
520#[derive(Debug, Clone, serde::Deserialize)]
521struct ModelConfig {
522    provider: String,
523    #[serde(default)]
524    model_id: Option<String>,
525    #[serde(default)]
526    context_window: Option<u32>,
527    #[serde(default)]
528    capabilities: Vec<String>,
529}
530
531/// Hot-reloadable protocol registry
532pub struct ProtocolRegistry {
533    manifests: ArcSwap<std::collections::HashMap<String, Arc<ProtocolManifest>>>,
534    loader: ProtocolLoader,
535}
536
537impl ProtocolRegistry {
538    pub fn new() -> Self {
539        Self {
540            manifests: ArcSwap::from_pointee(std::collections::HashMap::new()),
541            loader: ProtocolLoader::new(),
542        }
543    }
544
545    /// Get or load a protocol manifest
546    pub async fn get_manifest(
547        &self,
548        provider_id: &str,
549    ) -> Result<Arc<ProtocolManifest>, ProtocolError> {
550        // Check cache first
551        let current = self.manifests.load();
552        if let Some(manifest) = current.get(provider_id) {
553            return Ok(Arc::clone(manifest));
554        }
555
556        // Load and cache
557        let manifest = self.loader.load_provider(provider_id).await?;
558        let manifest_arc = Arc::new(manifest);
559
560        // Update cache atomically
561        let mut updated_map = std::collections::HashMap::new();
562        for (k, v) in current.iter() {
563            updated_map.insert(k.clone(), v.clone());
564        }
565        updated_map.insert(provider_id.to_string(), manifest_arc.clone());
566        self.manifests.store(Arc::new(updated_map));
567
568        Ok(manifest_arc)
569    }
570}
571
572impl Default for ProtocolRegistry {
573    fn default() -> Self {
574        Self::new()
575    }
576}