ai_lib_rust/protocol/
loader.rs

//! Protocol loader with support for local files and remote URLs (embedded assets are planned, not yet implemented)
2//! Heartbeat sync - 2026-01-06
3//! Includes hot-reload capability using ArcSwap
4
5use crate::protocol::{ProtocolError, ProtocolManifest};
6use arc_swap::ArcSwap;
7use lru::LruCache;
8use std::path::{Path, PathBuf};
9use std::sync::{Arc, Mutex};
10
/// Protocol loader that supports multiple sources.
///
/// Provider manifests are looked up on the local file system and over HTTP(S);
/// manifests returned by `load_model` are additionally memoized in an LRU cache.
pub struct ProtocolLoader {
    /// Optional root directory checked first by `load_provider`
    /// (`<base_path>/v1/providers/<provider_id>.yaml`).
    base_path: Option<PathBuf>,
    /// Flag set via `with_hot_reload`. It is stored but not read by any code
    /// visible in this file — NOTE(review): confirm where it is consumed.
    hot_reload: bool,
    /// Validator applied to every parsed manifest before it is returned.
    validator: crate::protocol::validator::ProtocolValidator,
    /// LRU cache of manifests keyed by the model identifier string passed to
    /// `load_model`; guarded by a std `Mutex`.
    cache: Mutex<LruCache<String, Arc<ProtocolManifest>>>,
}
18
19impl ProtocolLoader {
20    pub fn new() -> Self {
21        Self {
22            base_path: None,
23            hot_reload: false,
24            validator: crate::protocol::validator::ProtocolValidator::default(),
25            // Use 100 as default cache size
26            // NonZeroUsize::new(100) is guaranteed to be Some, but use expect for clarity
27            cache: Mutex::new(LruCache::new(
28                std::num::NonZeroUsize::new(100)
29                    .expect("Cache size must be non-zero (this should never happen)"),
30            )),
31        }
32    }
33
34    /// Set base path for protocol files
35    pub fn with_base_path(mut self, path: impl AsRef<Path>) -> Self {
36        self.base_path = Some(path.as_ref().to_path_buf());
37        self
38    }
39
40    /// Enable hot reload
41    pub fn with_hot_reload(mut self, enable: bool) -> Self {
42        self.hot_reload = enable;
43        self
44    }
45
46    /// Load a model configuration
47    /// Model identifier format: "provider/model-name"
48    pub async fn load_model(&self, model: &str) -> Result<ProtocolManifest, ProtocolError> {
49        // 1. Check Cache
50        {
51            let mut cache = self.cache.lock().map_err(|e| {
52                ProtocolError::Internal(format!(
53                    "Failed to acquire cache lock while loading model '{}': {}",
54                    model, e
55                ))
56            })?;
57            if let Some(manifest) = cache.get(model) {
58                return Ok(manifest.as_ref().clone());
59            }
60        }
61
62        let parts: Vec<&str> = model.split('/').collect();
63        if parts.len() != 2 {
64            return Err(ProtocolError::NotFound {
65                id: model.to_string(),
66                hint: Some("Ensure the model name follows the 'provider/model' format".to_string()),
67            });
68        }
69
70        let provider = parts[0];
71        let model_name = parts[1];
72
73        // First, try to load model registry to get provider reference.
74        // If registry doesn't contain this model (common for providers like deepseek),
75        // fall back to loading provider manifest directly using the provider segment.
76        let manifest = match self.load_model_config(model_name).await {
77            Ok(model_config) => self.load_provider(&model_config.provider).await?,
78            Err(ProtocolError::NotFound { .. }) => self.load_provider(provider).await?,
79            Err(e) => return Err(e),
80        };
81
82        // 2. Update Cache
83        {
84            let mut cache = self.cache.lock().map_err(|e| {
85                ProtocolError::Internal(format!(
86                    "Failed to acquire cache lock while caching model '{}': {}",
87                    model, e
88                ))
89            })?;
90            cache.put(model.to_string(), Arc::new(manifest.clone()));
91        }
92
93        Ok(manifest)
94    }
95
96    /// Load provider configuration
97    pub async fn load_provider(
98        &self,
99        provider_id: &str,
100    ) -> Result<ProtocolManifest, ProtocolError> {
101        // Try multiple sources in order:
102        // 1. Local file system (if base_path is set)
103        // 2. GitHub URL (if AI_PROTOCOL_DIR is a URL)
104        // 3. Local file system (default paths)
105        // 4. Embedded assets (future: compile-time inclusion)
106
107        if let Some(ref base_path) = self.base_path {
108            let provider_path = base_path
109                .join("v1")
110                .join("providers")
111                .join(format!("{}.yaml", provider_id));
112
113            if provider_path.exists() {
114                return self.load_from_file(&provider_path).await;
115            }
116        }
117
118        // Check if AI_PROTOCOL_DIR is a GitHub URL
119        if let Ok(root) =
120            std::env::var("AI_PROTOCOL_DIR").or_else(|_| std::env::var("AI_PROTOCOL_PATH"))
121        {
122            // Check if it's a URL (starts with http:// or https://)
123            if root.starts_with("http://") || root.starts_with("https://") {
124                // Construct GitHub raw URL for provider manifest
125                let url = if root.ends_with('/') {
126                    format!("{}v1/providers/{}.yaml", root, provider_id)
127                } else {
128                    format!("{}/v1/providers/{}.yaml", root, provider_id)
129                };
130                return self.load_from_url(&url).await;
131            }
132        }
133
134        // Default search paths (local file system):
135        // - env `AI_PROTOCOL_DIR` / `AI_PROTOCOL_PATH` pointing to the ai-protocol repo root
136        // - relative paths for submodule/sibling setups
137        // - (dev convenience) `D:\ai-protocol\...` if present
138        let mut default_paths: Vec<PathBuf> = Vec::new();
139        if let Ok(root) =
140            std::env::var("AI_PROTOCOL_DIR").or_else(|_| std::env::var("AI_PROTOCOL_PATH"))
141        {
142            // Only add if it's not a URL (already handled above)
143            if !root.starts_with("http://") && !root.starts_with("https://") {
144                let root = PathBuf::from(root);
145                default_paths.push(root.join("v1").join("providers"));
146            }
147        }
148        default_paths.push(PathBuf::from("ai-protocol/v1/providers"));
149        default_paths.push(PathBuf::from("../ai-protocol/v1/providers"));
150        default_paths.push(PathBuf::from("../../ai-protocol/v1/providers"));
151        let win_dev = PathBuf::from("D:\\ai-protocol\\v1\\providers");
152        if win_dev.exists() {
153            default_paths.push(win_dev);
154        }
155
156        for base in default_paths {
157            let provider_path = base.join(format!("{}.yaml", provider_id));
158            if provider_path.exists() {
159                return self.load_from_file(&provider_path).await;
160            }
161        }
162
163        // Last resort: try GitHub raw URL (canonical source)
164        let github_url = format!(
165            "https://raw.githubusercontent.com/hiddenpath/ai-protocol/main/v1/providers/{}.yaml",
166            provider_id
167        );
168        if let Ok(manifest) = self.load_from_url(&github_url).await {
169            return Ok(manifest);
170        }
171
172        Err(ProtocolError::NotFound {
173            id: provider_id.to_string(),
174            hint: Some(format!(
175                "Check if the provider file '{}.yaml' exists in your protocol directory",
176                provider_id
177            )),
178        })
179    }
180
181    /// Load protocol from local file
182    async fn load_from_file(&self, path: &Path) -> Result<ProtocolManifest, ProtocolError> {
183        // Read as bytes first to handle different encodings
184        let bytes = tokio::fs::read(path)
185            .await
186            .map_err(|e| ProtocolError::LoadError {
187                path: path.to_string_lossy().to_string(),
188                reason: e.to_string(),
189                hint: Some("Check if the file exists and you have read permissions.".to_string()),
190            })?;
191
192        // Detect encoding and convert to UTF-8 string
193        let content = if bytes.len() >= 2 && bytes[0] == 0xFF && bytes[1] == 0xFE {
194            // UTF-16 LE with BOM
195            let utf16_bytes = &bytes[2..];
196            // Convert UTF-16 LE bytes to u16 array
197            let mut utf16_chars = Vec::new();
198            for i in (0..utf16_bytes.len()).step_by(2) {
199                if i + 1 < utf16_bytes.len() {
200                    let code_unit = u16::from_le_bytes([utf16_bytes[i], utf16_bytes[i + 1]]);
201                    utf16_chars.push(code_unit);
202                }
203            }
204            String::from_utf16(&utf16_chars).map_err(|e| ProtocolError::LoadError {
205                path: path.to_string_lossy().to_string(),
206                reason: format!(
207                    "Invalid UTF-16: {}. Please convert the file to UTF-8 encoding.",
208                    e
209                ),
210                hint: Some(
211                    "The runtime expects UTF-8 manifests. Try converting the file encoding."
212                        .to_string(),
213                ),
214            })?
215        } else if bytes.len() >= 3 && bytes[0] == 0xEF && bytes[1] == 0xBB && bytes[2] == 0xBF {
216            // UTF-8 with BOM, skip BOM
217            String::from_utf8(bytes[3..].to_vec()).map_err(|e| ProtocolError::LoadError {
218                path: path.to_string_lossy().to_string(),
219                reason: format!("Invalid UTF-8 (after BOM): {}", e),
220                hint: Some(
221                    "Remove Byte Order Mark (BOM) and ensure the file is valid UTF-8.".to_string(),
222                ),
223            })?
224        } else {
225            // Regular UTF-8 (no BOM)
226            String::from_utf8(bytes).map_err(|e| ProtocolError::LoadError {
227                path: path.to_string_lossy().to_string(),
228                reason: format!(
229                    "Invalid UTF-8: {}. Please convert the file to UTF-8 encoding.",
230                    e
231                ),
232                hint: Some("Verify the file content is valid UTF-8.".to_string()),
233            })?
234        };
235
236        let manifest: ProtocolManifest = Self::parse_manifest_yaml(&content)?;
237
238        // Validate against JSON Schema
239        self.validator.validate(&manifest)?;
240
241        Ok(manifest)
242    }
243
244    /// Load protocol from remote URL (GitHub raw URL)
245    async fn load_from_url(&self, url: &str) -> Result<ProtocolManifest, ProtocolError> {
246        let client = reqwest::Client::builder()
247            .timeout(std::time::Duration::from_secs(30))
248            .build()
249            .map_err(|e| ProtocolError::Internal(format!("Failed to create HTTP client: {}", e)))?;
250
251        let response = client
252            .get(url)
253            .send()
254            .await
255            .map_err(|e| ProtocolError::LoadError {
256                path: url.to_string(),
257                reason: format!("HTTP request failed: {}", e),
258                hint: Some(
259                    "Check your internet connection and verify the URL is accessible.".to_string(),
260                ),
261            })?;
262
263        if !response.status().is_success() {
264            return Err(ProtocolError::LoadError {
265                path: url.to_string(),
266                reason: format!(
267                    "HTTP {}: {}",
268                    response.status(),
269                    response.text().await.unwrap_or_default()
270                ),
271                hint: Some(
272                    "Verify the remote registry URL and your API permissions if any.".to_string(),
273                ),
274            });
275        }
276
277        let content = response
278            .text()
279            .await
280            .map_err(|e| ProtocolError::LoadError {
281                path: url.to_string(),
282                reason: format!("Failed to read response: {}", e),
283                hint: None,
284            })?;
285
286        let manifest: ProtocolManifest = Self::parse_manifest_yaml(&content)?;
287
288        // Validate against JSON Schema
289        self.validator.validate(&manifest)?;
290
291        Ok(manifest)
292    }
293
294    /// Parse YAML into a ProtocolManifest with better error classification.
295    ///
296    /// Rationale:
297    /// - YAML syntax/encoding issues are "load" errors.
298    /// - Structural mismatches (missing required fields, wrong types) are "validation" errors.
299    fn parse_manifest_yaml(content: &str) -> Result<ProtocolManifest, ProtocolError> {
300        serde_yaml::from_str::<ProtocolManifest>(content).map_err(|e| {
301            let msg = e.to_string();
302            // Heuristic classification based on serde error messages.
303            // This keeps public error categories stable without pulling in serde internals.
304            let looks_structural = msg.contains("missing field")
305                || msg.contains("unknown field")
306                || msg.contains("invalid type")
307                || msg.contains("invalid value")
308                || msg.contains("expected");
309
310            if looks_structural {
311                ProtocolError::ValidationError(format!("Invalid manifest structure: {}", msg))
312            } else {
313                ProtocolError::YamlError(msg)
314            }
315        })
316    }
317
318    /// Load model configuration from registry
319    async fn load_model_config(&self, model_name: &str) -> Result<ModelConfig, ProtocolError> {
320        // Try to find model in v1/models/ directory, scanning all `*.yaml` registries.
321        let mut model_paths: Vec<PathBuf> = Vec::new();
322        if let Ok(root) =
323            std::env::var("AI_PROTOCOL_DIR").or_else(|_| std::env::var("AI_PROTOCOL_PATH"))
324        {
325            let root = PathBuf::from(root);
326            model_paths.push(root.join("v1").join("models"));
327        }
328        model_paths.push(PathBuf::from("ai-protocol/v1/models"));
329        model_paths.push(PathBuf::from("../ai-protocol/v1/models"));
330        model_paths.push(PathBuf::from("../../ai-protocol/v1/models"));
331        let win_dev = PathBuf::from("D:\\ai-protocol\\v1\\models");
332        if win_dev.exists() {
333            model_paths.push(win_dev);
334        }
335
336        for base in model_paths {
337            if !base.exists() {
338                continue;
339            }
340            let mut rd = match tokio::fs::read_dir(&base).await {
341                Ok(rd) => rd,
342                Err(_) => continue,
343            };
344            while let Ok(Some(entry)) = rd.next_entry().await {
345                let path = entry.path();
346                if path
347                    .extension()
348                    .and_then(|s| s.to_str())
349                    .map(|s| s.eq_ignore_ascii_case("yaml") || s.eq_ignore_ascii_case("yml"))
350                    != Some(true)
351                {
352                    continue;
353                }
354                if let Ok(config) = self.load_model_registry(&path).await {
355                    if let Some(model) = config.models.get(model_name) {
356                        return Ok(model.clone());
357                    }
358                }
359            }
360        }
361
362        Err(ProtocolError::NotFound {
363            id: model_name.to_string(),
364            hint: Some(
365                "Check if the model is registered in the manifests/v1/models/ directory"
366                    .to_string(),
367            ),
368        })
369    }
370
371    async fn load_model_registry(&self, path: &Path) -> Result<ModelRegistry, ProtocolError> {
372        let content =
373            tokio::fs::read_to_string(path)
374                .await
375                .map_err(|e| ProtocolError::LoadError {
376                    path: path.to_string_lossy().to_string(),
377                    reason: format!("Failed to read model registry: {}", e),
378                    hint: None,
379                })?;
380
381        let registry: ModelRegistry = serde_yaml::from_str(&content).map_err(|e| {
382            ProtocolError::YamlError(format!("Failed to parse model registry: {}", e))
383        })?;
384
385        Ok(registry)
386    }
387}
388
389impl Default for ProtocolLoader {
390    fn default() -> Self {
391        Self::new()
392    }
393}
394
/// Model registry structure.
///
/// Deserialized from a YAML registry file by `load_model_registry`.
#[derive(Debug, Clone, serde::Deserialize)]
struct ModelRegistry {
    /// Registered models, keyed by the name that `load_model_config` looks up.
    models: std::collections::HashMap<String, ModelConfig>,
}
400
/// Model configuration from registry.
///
/// Only `provider` is read by the loader code in this file; the remaining
/// fields are deserialized but unused here, hence the `dead_code` allowance.
#[allow(dead_code)]
#[derive(Debug, Clone, serde::Deserialize)]
struct ModelConfig {
    /// Identifier of the provider whose manifest serves this model; passed to
    /// `load_provider` by `load_model`.
    provider: String,
    /// Optional model identifier; `None` when absent from the registry entry.
    /// NOTE(review): presumably the provider-side model id — confirm.
    #[serde(default)]
    model_id: Option<String>,
    /// Optional context window; `None` when absent from the registry entry.
    /// NOTE(review): unit is presumably tokens — confirm.
    #[serde(default)]
    context_window: Option<u32>,
    /// Capability names; empty when absent from the registry entry.
    #[serde(default)]
    capabilities: Vec<String>,
}
413
/// Hot-reloadable protocol registry.
///
/// Keeps loaded provider manifests in a map behind an `ArcSwap`, so the whole
/// map can be replaced by swapping in a new `Arc`.
pub struct ProtocolRegistry {
    /// Current snapshot of loaded manifests, keyed by provider id.
    manifests: ArcSwap<std::collections::HashMap<String, Arc<ProtocolManifest>>>,
    /// Loader used to fetch manifests that are not in the snapshot yet.
    loader: ProtocolLoader,
}
419
420impl ProtocolRegistry {
421    pub fn new() -> Self {
422        Self {
423            manifests: ArcSwap::from_pointee(std::collections::HashMap::new()),
424            loader: ProtocolLoader::new(),
425        }
426    }
427
428    /// Get or load a protocol manifest
429    pub async fn get_manifest(
430        &self,
431        provider_id: &str,
432    ) -> Result<Arc<ProtocolManifest>, ProtocolError> {
433        // Check cache first
434        let current = self.manifests.load();
435        if let Some(manifest) = current.get(provider_id) {
436            return Ok(Arc::clone(manifest));
437        }
438
439        // Load and cache
440        let manifest = self.loader.load_provider(provider_id).await?;
441        let manifest_arc = Arc::new(manifest);
442
443        // Update cache atomically
444        let mut updated_map = std::collections::HashMap::new();
445        for (k, v) in current.iter() {
446            updated_map.insert(k.clone(), v.clone());
447        }
448        updated_map.insert(provider_id.to_string(), manifest_arc.clone());
449        self.manifests.store(Arc::new(updated_map));
450
451        Ok(manifest_arc)
452    }
453}
454
455impl Default for ProtocolRegistry {
456    fn default() -> Self {
457        Self::new()
458    }
459}