Skip to main content

ai_lib_rust/protocol/
loader.rs

//! Protocol loader with support for local files and remote URLs
//! (embedded assets are planned but not implemented yet).
//! Heartbeat sync - 2026-01-06
//! Includes hot-reload capability using ArcSwap
4
5use crate::protocol::{ProtocolError, ProtocolManifest};
6use arc_swap::ArcSwap;
7use lru::LruCache;
8use std::path::{Path, PathBuf};
9use std::sync::{Arc, Mutex};
10
11/// Protocol loader that supports multiple sources
12pub struct ProtocolLoader {
13    base_path: Option<PathBuf>,
14    hot_reload: bool,
15    validator: crate::protocol::validator::ProtocolValidator,
16    cache: Mutex<LruCache<String, Arc<ProtocolManifest>>>,
17}
18
19impl ProtocolLoader {
20    pub fn new() -> Self {
21        Self {
22            base_path: None,
23            hot_reload: false,
24            validator: crate::protocol::validator::ProtocolValidator::default(),
25            // Use 100 as default cache size
26            // NonZeroUsize::new(100) is guaranteed to be Some, but use expect for clarity
27            cache: Mutex::new(LruCache::new(
28                std::num::NonZeroUsize::new(100)
29                    .expect("Cache size must be non-zero (this should never happen)"),
30            )),
31        }
32    }
33
34    /// Set base path for protocol files
35    pub fn with_base_path(mut self, path: impl AsRef<Path>) -> Self {
36        self.base_path = Some(path.as_ref().to_path_buf());
37        self
38    }
39
40    /// Enable hot reload
41    pub fn with_hot_reload(mut self, enable: bool) -> Self {
42        self.hot_reload = enable;
43        self
44    }
45
46    /// Load a model configuration
47    /// Model identifier format: "provider/model-name"
48    pub async fn load_model(&self, model: &str) -> Result<ProtocolManifest, ProtocolError> {
49        // 1. Check Cache
50        {
51            let mut cache = self.cache.lock().map_err(|e| {
52                ProtocolError::Internal(format!(
53                    "Failed to acquire cache lock while loading model '{}': {}",
54                    model, e
55                ))
56            })?;
57            if let Some(manifest) = cache.get(model) {
58                return Ok(manifest.as_ref().clone());
59            }
60        }
61
62        let parts: Vec<&str> = model.split('/').collect();
63        if parts.len() != 2 {
64            return Err(ProtocolError::NotFound {
65                id: model.to_string(),
66                hint: Some("Ensure the model name follows the 'provider/model' format".to_string()),
67            });
68        }
69
70        let provider = parts[0];
71        let model_name = parts[1];
72
73        // First, try to load model registry to get provider reference.
74        // If registry doesn't contain this model (common for providers like deepseek),
75        // fall back to loading provider manifest directly using the provider segment.
76        let manifest = match self.load_model_config(model_name).await {
77            Ok(model_config) => self.load_provider(&model_config.provider).await?,
78            Err(ProtocolError::NotFound { .. }) => self.load_provider(provider).await?,
79            Err(e) => return Err(e),
80        };
81
82        // 2. Update Cache
83        {
84            let mut cache = self.cache.lock().map_err(|e| {
85                ProtocolError::Internal(format!(
86                    "Failed to acquire cache lock while caching model '{}': {}",
87                    model, e
88                ))
89            })?;
90            cache.put(model.to_string(), Arc::new(manifest.clone()));
91        }
92
93        Ok(manifest)
94    }
95
96    /// Load provider configuration
97    pub async fn load_provider(
98        &self,
99        provider_id: &str,
100    ) -> Result<ProtocolManifest, ProtocolError> {
101        // Try multiple sources in order:
102        // 1. Local file system (dist JSON) - PREFERRED
103        // 2. Local file system (source YAML) - FALLBACK
104        // 3. GitHub URL (if AI_PROTOCOL_DIR is a URL)
105        // 4. Embedded assets (future)
106
107        // Path prioritization helper
108        let mut search_locations: Vec<(PathBuf, bool)> = Vec::new(); // (path_base, is_json_preferred)
109
110        // 1. Check user-configured base_path
111        if let Some(ref base_path) = self.base_path {
112            // Priority 1: dist/v1/providers/{id}.json
113            search_locations.push((base_path.join("dist").join("v1").join("providers"), true));
114            // Priority 2: v1/providers/{id}.yaml
115            search_locations.push((base_path.join("v1").join("providers"), false));
116        }
117
118        // 2. Check AI_PROTOCOL_DIR Env Var
119        if let Ok(root) =
120            std::env::var("AI_PROTOCOL_DIR").or_else(|_| std::env::var("AI_PROTOCOL_PATH"))
121        {
122            if root.starts_with("http://") || root.starts_with("https://") {
123                // Handling URL sources (Remote)
124                // Try JSON first if it looks like a raw github url, but typically raw github urls are specific.
125                // For simplicity, we stick to the existing logic for URLs but could enhance later to try .json
126                let url = if root.ends_with('/') {
127                    format!("{}dist/v1/providers/{}.json", root, provider_id)
128                } else {
129                    format!("{}/dist/v1/providers/{}.json", root, provider_id)
130                };
131
132                // Try JSON from remote
133                if let Ok(manifest) = self.load_from_json_url(&url).await {
134                    return Ok(manifest);
135                }
136
137                // Fallback to YAML from remote
138                let url_yaml = if root.ends_with('/') {
139                    format!("{}v1/providers/{}.yaml", root, provider_id)
140                } else {
141                    format!("{}/v1/providers/{}.yaml", root, provider_id)
142                };
143                return self.load_from_url(&url_yaml).await;
144            } else {
145                // Local Path from Env
146                let root = PathBuf::from(root);
147                search_locations.push((root.join("dist").join("v1").join("providers"), true));
148                search_locations.push((root.join("v1").join("providers"), false));
149            }
150        }
151
152        // 3. Default dev locations
153        let default_roots = vec![
154            PathBuf::from("ai-protocol"),
155            PathBuf::from("../ai-protocol"),
156            PathBuf::from("../../ai-protocol"),
157            PathBuf::from("D:\\ai-protocol"),
158        ];
159
160        for root in default_roots {
161            search_locations.push((root.join("dist").join("v1").join("providers"), true));
162            search_locations.push((root.join("v1").join("providers"), false));
163        }
164
165        // Execute Search
166        for (base, prefer_json) in search_locations {
167            if prefer_json {
168                let json_path = base.join(format!("{}.json", provider_id));
169                if json_path.exists() {
170                    return self.load_from_json_file(&json_path).await;
171                }
172            } else {
173                let yaml_path = base.join(format!("{}.yaml", provider_id));
174                if yaml_path.exists() {
175                    return self.load_from_file(&yaml_path).await;
176                }
177            }
178        }
179
180        // Last resort: try GitHub raw URL (canonical source) - JSON
181        let github_json = format!(
182            "https://raw.githubusercontent.com/hiddenpath/ai-protocol/main/dist/v1/providers/{}.json",
183            provider_id
184        );
185        if let Ok(manifest) = self.load_from_json_url(&github_json).await {
186            return Ok(manifest);
187        }
188
189        // Last resort fallback: YAML
190        let github_yaml = format!(
191            "https://raw.githubusercontent.com/hiddenpath/ai-protocol/main/v1/providers/{}.yaml",
192            provider_id
193        );
194        if let Ok(manifest) = self.load_from_url(&github_yaml).await {
195            return Ok(manifest);
196        }
197
198        Err(ProtocolError::NotFound {
199            id: provider_id.to_string(),
200            hint: Some(format!(
201                "Check if the provider file '{}.json' or '{}.yaml' exists in your protocol directory",
202                provider_id, provider_id
203            )),
204        })
205    }
206
207    /// Load protocol from local JSON file (Fast Path)
208    async fn load_from_json_file(&self, path: &Path) -> Result<ProtocolManifest, ProtocolError> {
209        let content = tokio::fs::read(path)
210            .await
211            .map_err(|e| ProtocolError::LoadError {
212                path: path.to_string_lossy().to_string(),
213                reason: e.to_string(),
214                hint: Some("Check file permissions.".to_string()),
215            })?;
216
217        let manifest: ProtocolManifest = serde_json::from_slice(&content)
218            .map_err(|e| ProtocolError::ValidationError(format!("Invalid JSON manifest: {}", e)))?;
219
220        // Validate against JSON Schema (Optional but recommended even for dist)
221        // For max speed, we might skip this in release, but keeping for safety now.
222        self.validator.validate(&manifest)?;
223
224        Ok(manifest)
225    }
226
227    /// Load protocol from local YAML file (Legacy/Dev Path)
228    async fn load_from_file(&self, path: &Path) -> Result<ProtocolManifest, ProtocolError> {
229        // Read as bytes first to handle different encodings
230        let bytes = tokio::fs::read(path)
231            .await
232            .map_err(|e| ProtocolError::LoadError {
233                path: path.to_string_lossy().to_string(),
234                reason: e.to_string(),
235                hint: Some("Check if the file exists and you have read permissions.".to_string()),
236            })?;
237
238        // ... (encoding detection remains same)
239        let content = if bytes.len() >= 2 && bytes[0] == 0xFF && bytes[1] == 0xFE {
240            // UTF-16 LE with BOM
241            let utf16_bytes = &bytes[2..];
242            let mut utf16_chars = Vec::new();
243            for i in (0..utf16_bytes.len()).step_by(2) {
244                if i + 1 < utf16_bytes.len() {
245                    let code_unit = u16::from_le_bytes([utf16_bytes[i], utf16_bytes[i + 1]]);
246                    utf16_chars.push(code_unit);
247                }
248            }
249            String::from_utf16(&utf16_chars).map_err(|e| ProtocolError::LoadError {
250                path: path.to_string_lossy().to_string(),
251                reason: format!("Invalid UTF-16: {}", e),
252                hint: None,
253            })?
254        } else if bytes.len() >= 3 && bytes[0] == 0xEF && bytes[1] == 0xBB && bytes[2] == 0xBF {
255            String::from_utf8(bytes[3..].to_vec()).map_err(|e| ProtocolError::LoadError {
256                path: path.to_string_lossy().to_string(),
257                reason: format!("Invalid UTF-8 (after BOM): {}", e),
258                hint: None,
259            })?
260        } else {
261            String::from_utf8(bytes).map_err(|e| ProtocolError::LoadError {
262                path: path.to_string_lossy().to_string(),
263                reason: format!("Invalid UTF-8: {}", e),
264                hint: None,
265            })?
266        };
267
268        let manifest: ProtocolManifest = Self::parse_manifest_yaml(&content)?;
269        self.validator.validate(&manifest)?;
270        Ok(manifest)
271    }
272
273    /// Load protocol from remote JSON URL
274    async fn load_from_json_url(&self, url: &str) -> Result<ProtocolManifest, ProtocolError> {
275        let client = reqwest::Client::builder()
276            .timeout(std::time::Duration::from_secs(30))
277            .build()
278            .map_err(|e| ProtocolError::Internal(format!("Failed to create HTTP client: {}", e)))?;
279
280        let response = client
281            .get(url)
282            .send()
283            .await
284            .map_err(|e| ProtocolError::LoadError {
285                path: url.to_string(),
286                reason: format!("HTTP request failed: {}", e),
287                hint: None,
288            })?;
289
290        if !response.status().is_success() {
291            return Err(ProtocolError::LoadError {
292                path: url.to_string(),
293                reason: format!("HTTP {}", response.status()),
294                hint: None,
295            });
296        }
297
298        let content = response
299            .bytes()
300            .await
301            .map_err(|e| ProtocolError::LoadError {
302                path: url.to_string(),
303                reason: format!("Failed to read bytes: {}", e),
304                hint: None,
305            })?;
306
307        let manifest: ProtocolManifest = serde_json::from_slice(&content).map_err(|e| {
308            ProtocolError::ValidationError(format!("Invalid JSON manifest from URL: {}", e))
309        })?;
310
311        self.validator.validate(&manifest)?;
312        Ok(manifest)
313    }
314
315    /// Load protocol from remote URL (GitHub raw URL)
316    async fn load_from_url(&self, url: &str) -> Result<ProtocolManifest, ProtocolError> {
317        let client = reqwest::Client::builder()
318            .timeout(std::time::Duration::from_secs(30))
319            .build()
320            .map_err(|e| ProtocolError::Internal(format!("Failed to create HTTP client: {}", e)))?;
321
322        let response = client
323            .get(url)
324            .send()
325            .await
326            .map_err(|e| ProtocolError::LoadError {
327                path: url.to_string(),
328                reason: format!("HTTP request failed: {}", e),
329                hint: Some(
330                    "Check your internet connection and verify the URL is accessible.".to_string(),
331                ),
332            })?;
333
334        if !response.status().is_success() {
335            return Err(ProtocolError::LoadError {
336                path: url.to_string(),
337                reason: format!(
338                    "HTTP {}: {}",
339                    response.status(),
340                    response.text().await.unwrap_or_default()
341                ),
342                hint: Some(
343                    "Verify the remote registry URL and your API permissions if any.".to_string(),
344                ),
345            });
346        }
347
348        let content = response
349            .text()
350            .await
351            .map_err(|e| ProtocolError::LoadError {
352                path: url.to_string(),
353                reason: format!("Failed to read response: {}", e),
354                hint: None,
355            })?;
356
357        let manifest: ProtocolManifest = Self::parse_manifest_yaml(&content)?;
358
359        // Validate against JSON Schema
360        self.validator.validate(&manifest)?;
361
362        Ok(manifest)
363    }
364
365    /// Parse YAML into a ProtocolManifest with better error classification.
366    ///
367    /// Rationale:
368    /// - YAML syntax/encoding issues are "load" errors.
369    /// - Structural mismatches (missing required fields, wrong types) are "validation" errors.
370    fn parse_manifest_yaml(content: &str) -> Result<ProtocolManifest, ProtocolError> {
371        serde_yaml::from_str::<ProtocolManifest>(content).map_err(|e| {
372            let msg = e.to_string();
373            // Heuristic classification based on serde error messages.
374            // This keeps public error categories stable without pulling in serde internals.
375            let looks_structural = msg.contains("missing field")
376                || msg.contains("unknown field")
377                || msg.contains("invalid type")
378                || msg.contains("invalid value")
379                || msg.contains("expected");
380
381            if looks_structural {
382                ProtocolError::ValidationError(format!("Invalid manifest structure: {}", msg))
383            } else {
384                ProtocolError::YamlError(msg)
385            }
386        })
387    }
388
389    /// Load model configuration from registry
390    async fn load_model_config(&self, model_name: &str) -> Result<ModelConfig, ProtocolError> {
391        // Try to find model, scanning registries.
392        // Priority: dist/v1/models/*.json -> v1/models/*.yaml
393
394        let mut search_locations: Vec<(PathBuf, bool)> = Vec::new(); // (path_base, is_json_preferred)
395
396        // 1. Env Var AI_PROTOCOL_DIR
397        if let Ok(root) =
398            std::env::var("AI_PROTOCOL_DIR").or_else(|_| std::env::var("AI_PROTOCOL_PATH"))
399        {
400            // If HTTP, skipped here as typically model config loading implies local or full repo clone access.
401            // If we really need remote model loading, we'd need a different strategy (scanning a remote index).
402            // For now, assume local model registry for this heuristic.
403            if !root.starts_with("http://") && !root.starts_with("https://") {
404                let root = PathBuf::from(root);
405                search_locations.push((root.join("dist").join("v1").join("models"), true));
406                search_locations.push((root.join("v1").join("models"), false));
407            }
408        }
409
410        // 2. Default paths
411        let default_roots = vec![
412            PathBuf::from("ai-protocol"),
413            PathBuf::from("../ai-protocol"),
414            PathBuf::from("../../ai-protocol"),
415            PathBuf::from("D:\\ai-protocol"),
416        ];
417
418        for root in default_roots {
419            search_locations.push((root.join("dist").join("v1").join("models"), true));
420            search_locations.push((root.join("v1").join("models"), false));
421        }
422
423        for (base, prefer_json) in search_locations {
424            if !base.exists() {
425                continue;
426            }
427            let mut rd = match tokio::fs::read_dir(&base).await {
428                Ok(rd) => rd,
429                Err(_) => continue,
430            };
431
432            while let Ok(Some(entry)) = rd.next_entry().await {
433                let path = entry.path();
434                let extension = path.extension().and_then(|s| s.to_str());
435
436                let is_match = if prefer_json {
437                    extension.map(|s| s.eq_ignore_ascii_case("json")) == Some(true)
438                } else {
439                    extension
440                        .map(|s| s.eq_ignore_ascii_case("yaml") || s.eq_ignore_ascii_case("yml"))
441                        == Some(true)
442                };
443
444                if !is_match {
445                    continue;
446                }
447
448                if prefer_json {
449                    if let Ok(config) = self.load_model_registry_json(&path).await {
450                        if let Some(model) = config.models.get(model_name) {
451                            return Ok(model.clone());
452                        }
453                    }
454                } else {
455                    if let Ok(config) = self.load_model_registry_yaml(&path).await {
456                        if let Some(model) = config.models.get(model_name) {
457                            return Ok(model.clone());
458                        }
459                    }
460                }
461            }
462        }
463
464        Err(ProtocolError::NotFound {
465            id: model_name.to_string(),
466            hint: Some(
467                "Check if the model is registered in the manifests/v1/models/ directory"
468                    .to_string(),
469            ),
470        })
471    }
472
473    async fn load_model_registry_json(&self, path: &Path) -> Result<ModelRegistry, ProtocolError> {
474        let content = tokio::fs::read(path)
475            .await
476            .map_err(|e| ProtocolError::LoadError {
477                path: path.to_string_lossy().to_string(),
478                reason: e.to_string(),
479                hint: None,
480            })?;
481        let registry: ModelRegistry = serde_json::from_slice(&content).map_err(|e| {
482            ProtocolError::ValidationError(format!("Invalid JSON model registry: {}", e))
483        })?;
484        Ok(registry)
485    }
486
487    async fn load_model_registry_yaml(&self, path: &Path) -> Result<ModelRegistry, ProtocolError> {
488        let content =
489            tokio::fs::read_to_string(path)
490                .await
491                .map_err(|e| ProtocolError::LoadError {
492                    path: path.to_string_lossy().to_string(),
493                    reason: format!("Failed to read model registry: {}", e),
494                    hint: None,
495                })?;
496
497        let registry: ModelRegistry = serde_yaml::from_str(&content).map_err(|e| {
498            ProtocolError::YamlError(format!("Failed to parse model registry: {}", e))
499        })?;
500
501        Ok(registry)
502    }
503}
504
505impl Default for ProtocolLoader {
506    fn default() -> Self {
507        Self::new()
508    }
509}
510
511/// Model registry structure
512#[derive(Debug, Clone, serde::Deserialize)]
513struct ModelRegistry {
514    models: std::collections::HashMap<String, ModelConfig>,
515}
516
517/// Model configuration from registry
518#[allow(dead_code)]
519#[derive(Debug, Clone, serde::Deserialize)]
520struct ModelConfig {
521    provider: String,
522    #[serde(default)]
523    model_id: Option<String>,
524    #[serde(default)]
525    context_window: Option<u32>,
526    #[serde(default)]
527    capabilities: Vec<String>,
528}
529
530/// Hot-reloadable protocol registry
531pub struct ProtocolRegistry {
532    manifests: ArcSwap<std::collections::HashMap<String, Arc<ProtocolManifest>>>,
533    loader: ProtocolLoader,
534}
535
536impl ProtocolRegistry {
537    pub fn new() -> Self {
538        Self {
539            manifests: ArcSwap::from_pointee(std::collections::HashMap::new()),
540            loader: ProtocolLoader::new(),
541        }
542    }
543
544    /// Get or load a protocol manifest
545    pub async fn get_manifest(
546        &self,
547        provider_id: &str,
548    ) -> Result<Arc<ProtocolManifest>, ProtocolError> {
549        // Check cache first
550        let current = self.manifests.load();
551        if let Some(manifest) = current.get(provider_id) {
552            return Ok(Arc::clone(manifest));
553        }
554
555        // Load and cache
556        let manifest = self.loader.load_provider(provider_id).await?;
557        let manifest_arc = Arc::new(manifest);
558
559        // Update cache atomically
560        let mut updated_map = std::collections::HashMap::new();
561        for (k, v) in current.iter() {
562            updated_map.insert(k.clone(), v.clone());
563        }
564        updated_map.insert(provider_id.to_string(), manifest_arc.clone());
565        self.manifests.store(Arc::new(updated_map));
566
567        Ok(manifest_arc)
568    }
569}
570
571impl Default for ProtocolRegistry {
572    fn default() -> Self {
573        Self::new()
574    }
575}