1use crate::protocol::{ProtocolError, ProtocolManifest};
6use arc_swap::ArcSwap;
7use lru::LruCache;
8use std::path::{Path, PathBuf};
9use std::sync::{Arc, Mutex};
10
11pub struct ProtocolLoader {
13 base_path: Option<PathBuf>,
14 hot_reload: bool,
15 validator: crate::protocol::validator::ProtocolValidator,
16 cache: Mutex<LruCache<String, Arc<ProtocolManifest>>>,
17}
18
19impl ProtocolLoader {
20 pub fn new() -> Self {
21 Self {
22 base_path: None,
23 hot_reload: false,
24 validator: crate::protocol::validator::ProtocolValidator::default(),
25 cache: Mutex::new(LruCache::new(
28 std::num::NonZeroUsize::new(100)
29 .expect("Cache size must be non-zero (this should never happen)"),
30 )),
31 }
32 }
33
34 pub fn with_base_path(mut self, path: impl AsRef<Path>) -> Self {
36 self.base_path = Some(path.as_ref().to_path_buf());
37 self
38 }
39
40 pub fn with_hot_reload(mut self, enable: bool) -> Self {
42 self.hot_reload = enable;
43 self
44 }
45
46 pub async fn load_model(&self, model: &str) -> Result<ProtocolManifest, ProtocolError> {
49 {
51 let mut cache = self.cache.lock().map_err(|e| {
52 ProtocolError::Internal(format!(
53 "Failed to acquire cache lock while loading model '{}': {}",
54 model, e
55 ))
56 })?;
57 if let Some(manifest) = cache.get(model) {
58 return Ok(manifest.as_ref().clone());
59 }
60 }
61
62 let parts: Vec<&str> = model.split('/').collect();
63 if parts.len() != 2 {
64 return Err(ProtocolError::NotFound {
65 id: model.to_string(),
66 hint: Some("Ensure the model name follows the 'provider/model' format".to_string()),
67 });
68 }
69
70 let provider = parts[0];
71 let model_name = parts[1];
72
73 let manifest = match self.load_model_config(model_name).await {
77 Ok(model_config) => self.load_provider(&model_config.provider).await?,
78 Err(ProtocolError::NotFound { .. }) => self.load_provider(provider).await?,
79 Err(e) => return Err(e),
80 };
81
82 {
84 let mut cache = self.cache.lock().map_err(|e| {
85 ProtocolError::Internal(format!(
86 "Failed to acquire cache lock while caching model '{}': {}",
87 model, e
88 ))
89 })?;
90 cache.put(model.to_string(), Arc::new(manifest.clone()));
91 }
92
93 Ok(manifest)
94 }
95
96 pub async fn load_provider(
98 &self,
99 provider_id: &str,
100 ) -> Result<ProtocolManifest, ProtocolError> {
101 let mut search_locations: Vec<(PathBuf, bool)> = Vec::new(); if let Some(ref base_path) = self.base_path {
112 search_locations.push((base_path.join("dist").join("v1").join("providers"), true));
114 search_locations.push((base_path.join("v1").join("providers"), false));
116 }
117
118 if let Ok(root) =
120 std::env::var("AI_PROTOCOL_DIR").or_else(|_| std::env::var("AI_PROTOCOL_PATH"))
121 {
122 if root.starts_with("http://") || root.starts_with("https://") {
123 let url = if root.ends_with('/') {
127 format!("{}dist/v1/providers/{}.json", root, provider_id)
128 } else {
129 format!("{}/dist/v1/providers/{}.json", root, provider_id)
130 };
131
132 if let Ok(manifest) = self.load_from_json_url(&url).await {
134 return Ok(manifest);
135 }
136
137 let url_yaml = if root.ends_with('/') {
139 format!("{}v1/providers/{}.yaml", root, provider_id)
140 } else {
141 format!("{}/v1/providers/{}.yaml", root, provider_id)
142 };
143 return self.load_from_url(&url_yaml).await;
144 } else {
145 let root = PathBuf::from(root);
147 search_locations.push((root.join("dist").join("v1").join("providers"), true));
148 search_locations.push((root.join("v1").join("providers"), false));
149 }
150 }
151
152 let default_roots = vec![
154 PathBuf::from("ai-protocol"),
155 PathBuf::from("../ai-protocol"),
156 PathBuf::from("../../ai-protocol"),
157 PathBuf::from("D:\\ai-protocol"),
158 ];
159
160 for root in default_roots {
161 search_locations.push((root.join("dist").join("v1").join("providers"), true));
162 search_locations.push((root.join("v1").join("providers"), false));
163 }
164
165 for (base, prefer_json) in search_locations {
167 if prefer_json {
168 let json_path = base.join(format!("{}.json", provider_id));
169 if json_path.exists() {
170 return self.load_from_json_file(&json_path).await;
171 }
172 } else {
173 let yaml_path = base.join(format!("{}.yaml", provider_id));
174 if yaml_path.exists() {
175 return self.load_from_file(&yaml_path).await;
176 }
177 }
178 }
179
180 let github_json = format!(
182 "https://raw.githubusercontent.com/hiddenpath/ai-protocol/main/dist/v1/providers/{}.json",
183 provider_id
184 );
185 if let Ok(manifest) = self.load_from_json_url(&github_json).await {
186 return Ok(manifest);
187 }
188
189 let github_yaml = format!(
191 "https://raw.githubusercontent.com/hiddenpath/ai-protocol/main/v1/providers/{}.yaml",
192 provider_id
193 );
194 if let Ok(manifest) = self.load_from_url(&github_yaml).await {
195 return Ok(manifest);
196 }
197
198 Err(ProtocolError::NotFound {
199 id: provider_id.to_string(),
200 hint: Some(format!(
201 "Check if the provider file '{}.json' or '{}.yaml' exists in your protocol directory",
202 provider_id, provider_id
203 )),
204 })
205 }
206
207 async fn load_from_json_file(&self, path: &Path) -> Result<ProtocolManifest, ProtocolError> {
209 let content = tokio::fs::read(path)
210 .await
211 .map_err(|e| ProtocolError::LoadError {
212 path: path.to_string_lossy().to_string(),
213 reason: e.to_string(),
214 hint: Some("Check file permissions.".to_string()),
215 })?;
216
217 let manifest: ProtocolManifest = serde_json::from_slice(&content)
218 .map_err(|e| ProtocolError::ValidationError(format!("Invalid JSON manifest: {}", e)))?;
219
220 self.validator.validate(&manifest)?;
223
224 Ok(manifest)
225 }
226
227 async fn load_from_file(&self, path: &Path) -> Result<ProtocolManifest, ProtocolError> {
229 let bytes = tokio::fs::read(path)
231 .await
232 .map_err(|e| ProtocolError::LoadError {
233 path: path.to_string_lossy().to_string(),
234 reason: e.to_string(),
235 hint: Some("Check if the file exists and you have read permissions.".to_string()),
236 })?;
237
238 let content = if bytes.len() >= 2 && bytes[0] == 0xFF && bytes[1] == 0xFE {
240 let utf16_bytes = &bytes[2..];
242 let mut utf16_chars = Vec::new();
243 for i in (0..utf16_bytes.len()).step_by(2) {
244 if i + 1 < utf16_bytes.len() {
245 let code_unit = u16::from_le_bytes([utf16_bytes[i], utf16_bytes[i + 1]]);
246 utf16_chars.push(code_unit);
247 }
248 }
249 String::from_utf16(&utf16_chars).map_err(|e| ProtocolError::LoadError {
250 path: path.to_string_lossy().to_string(),
251 reason: format!("Invalid UTF-16: {}", e),
252 hint: None,
253 })?
254 } else if bytes.len() >= 3 && bytes[0] == 0xEF && bytes[1] == 0xBB && bytes[2] == 0xBF {
255 String::from_utf8(bytes[3..].to_vec()).map_err(|e| ProtocolError::LoadError {
256 path: path.to_string_lossy().to_string(),
257 reason: format!("Invalid UTF-8 (after BOM): {}", e),
258 hint: None,
259 })?
260 } else {
261 String::from_utf8(bytes).map_err(|e| ProtocolError::LoadError {
262 path: path.to_string_lossy().to_string(),
263 reason: format!("Invalid UTF-8: {}", e),
264 hint: None,
265 })?
266 };
267
268 let manifest: ProtocolManifest = Self::parse_manifest_yaml(&content)?;
269 self.validator.validate(&manifest)?;
270 Ok(manifest)
271 }
272
273 async fn load_from_json_url(&self, url: &str) -> Result<ProtocolManifest, ProtocolError> {
275 let client = reqwest::Client::builder()
276 .timeout(std::time::Duration::from_secs(30))
277 .build()
278 .map_err(|e| ProtocolError::Internal(format!("Failed to create HTTP client: {}", e)))?;
279
280 let response = client
281 .get(url)
282 .send()
283 .await
284 .map_err(|e| ProtocolError::LoadError {
285 path: url.to_string(),
286 reason: format!("HTTP request failed: {}", e),
287 hint: None,
288 })?;
289
290 if !response.status().is_success() {
291 return Err(ProtocolError::LoadError {
292 path: url.to_string(),
293 reason: format!("HTTP {}", response.status()),
294 hint: None,
295 });
296 }
297
298 let content = response
299 .bytes()
300 .await
301 .map_err(|e| ProtocolError::LoadError {
302 path: url.to_string(),
303 reason: format!("Failed to read bytes: {}", e),
304 hint: None,
305 })?;
306
307 let manifest: ProtocolManifest = serde_json::from_slice(&content).map_err(|e| {
308 ProtocolError::ValidationError(format!("Invalid JSON manifest from URL: {}", e))
309 })?;
310
311 self.validator.validate(&manifest)?;
312 Ok(manifest)
313 }
314
315 async fn load_from_url(&self, url: &str) -> Result<ProtocolManifest, ProtocolError> {
317 let client = reqwest::Client::builder()
318 .timeout(std::time::Duration::from_secs(30))
319 .build()
320 .map_err(|e| ProtocolError::Internal(format!("Failed to create HTTP client: {}", e)))?;
321
322 let response = client
323 .get(url)
324 .send()
325 .await
326 .map_err(|e| ProtocolError::LoadError {
327 path: url.to_string(),
328 reason: format!("HTTP request failed: {}", e),
329 hint: Some(
330 "Check your internet connection and verify the URL is accessible.".to_string(),
331 ),
332 })?;
333
334 if !response.status().is_success() {
335 return Err(ProtocolError::LoadError {
336 path: url.to_string(),
337 reason: format!(
338 "HTTP {}: {}",
339 response.status(),
340 response.text().await.unwrap_or_default()
341 ),
342 hint: Some(
343 "Verify the remote registry URL and your API permissions if any.".to_string(),
344 ),
345 });
346 }
347
348 let content = response
349 .text()
350 .await
351 .map_err(|e| ProtocolError::LoadError {
352 path: url.to_string(),
353 reason: format!("Failed to read response: {}", e),
354 hint: None,
355 })?;
356
357 let manifest: ProtocolManifest = Self::parse_manifest_yaml(&content)?;
358
359 self.validator.validate(&manifest)?;
361
362 Ok(manifest)
363 }
364
365 fn parse_manifest_yaml(content: &str) -> Result<ProtocolManifest, ProtocolError> {
371 serde_yaml::from_str::<ProtocolManifest>(content).map_err(|e| {
372 let msg = e.to_string();
373 let looks_structural = msg.contains("missing field")
376 || msg.contains("unknown field")
377 || msg.contains("invalid type")
378 || msg.contains("invalid value")
379 || msg.contains("expected");
380
381 if looks_structural {
382 ProtocolError::ValidationError(format!("Invalid manifest structure: {}", msg))
383 } else {
384 ProtocolError::YamlError(msg)
385 }
386 })
387 }
388
389 async fn load_model_config(&self, model_name: &str) -> Result<ModelConfig, ProtocolError> {
391 let mut search_locations: Vec<(PathBuf, bool)> = Vec::new(); if let Ok(root) =
398 std::env::var("AI_PROTOCOL_DIR").or_else(|_| std::env::var("AI_PROTOCOL_PATH"))
399 {
400 if !root.starts_with("http://") && !root.starts_with("https://") {
404 let root = PathBuf::from(root);
405 search_locations.push((root.join("dist").join("v1").join("models"), true));
406 search_locations.push((root.join("v1").join("models"), false));
407 }
408 }
409
410 let default_roots = vec![
412 PathBuf::from("ai-protocol"),
413 PathBuf::from("../ai-protocol"),
414 PathBuf::from("../../ai-protocol"),
415 PathBuf::from("D:\\ai-protocol"),
416 ];
417
418 for root in default_roots {
419 search_locations.push((root.join("dist").join("v1").join("models"), true));
420 search_locations.push((root.join("v1").join("models"), false));
421 }
422
423 for (base, prefer_json) in search_locations {
424 if !base.exists() {
425 continue;
426 }
427 let mut rd = match tokio::fs::read_dir(&base).await {
428 Ok(rd) => rd,
429 Err(_) => continue,
430 };
431
432 while let Ok(Some(entry)) = rd.next_entry().await {
433 let path = entry.path();
434 let extension = path.extension().and_then(|s| s.to_str());
435
436 let is_match = if prefer_json {
437 extension.map(|s| s.eq_ignore_ascii_case("json")) == Some(true)
438 } else {
439 extension
440 .map(|s| s.eq_ignore_ascii_case("yaml") || s.eq_ignore_ascii_case("yml"))
441 == Some(true)
442 };
443
444 if !is_match {
445 continue;
446 }
447
448 if prefer_json {
449 if let Ok(config) = self.load_model_registry_json(&path).await {
450 if let Some(model) = config.models.get(model_name) {
451 return Ok(model.clone());
452 }
453 }
454 } else {
455 if let Ok(config) = self.load_model_registry_yaml(&path).await {
456 if let Some(model) = config.models.get(model_name) {
457 return Ok(model.clone());
458 }
459 }
460 }
461 }
462 }
463
464 Err(ProtocolError::NotFound {
465 id: model_name.to_string(),
466 hint: Some(
467 "Check if the model is registered in the manifests/v1/models/ directory"
468 .to_string(),
469 ),
470 })
471 }
472
473 async fn load_model_registry_json(&self, path: &Path) -> Result<ModelRegistry, ProtocolError> {
474 let content = tokio::fs::read(path)
475 .await
476 .map_err(|e| ProtocolError::LoadError {
477 path: path.to_string_lossy().to_string(),
478 reason: e.to_string(),
479 hint: None,
480 })?;
481 let registry: ModelRegistry = serde_json::from_slice(&content).map_err(|e| {
482 ProtocolError::ValidationError(format!("Invalid JSON model registry: {}", e))
483 })?;
484 Ok(registry)
485 }
486
487 async fn load_model_registry_yaml(&self, path: &Path) -> Result<ModelRegistry, ProtocolError> {
488 let content =
489 tokio::fs::read_to_string(path)
490 .await
491 .map_err(|e| ProtocolError::LoadError {
492 path: path.to_string_lossy().to_string(),
493 reason: format!("Failed to read model registry: {}", e),
494 hint: None,
495 })?;
496
497 let registry: ModelRegistry = serde_yaml::from_str(&content).map_err(|e| {
498 ProtocolError::YamlError(format!("Failed to parse model registry: {}", e))
499 })?;
500
501 Ok(registry)
502 }
503}
504
505impl Default for ProtocolLoader {
506 fn default() -> Self {
507 Self::new()
508 }
509}
510
511#[derive(Debug, Clone, serde::Deserialize)]
513struct ModelRegistry {
514 models: std::collections::HashMap<String, ModelConfig>,
515}
516
517#[allow(dead_code)]
519#[derive(Debug, Clone, serde::Deserialize)]
520struct ModelConfig {
521 provider: String,
522 #[serde(default)]
523 model_id: Option<String>,
524 #[serde(default)]
525 context_window: Option<u32>,
526 #[serde(default)]
527 capabilities: Vec<String>,
528}
529
530pub struct ProtocolRegistry {
532 manifests: ArcSwap<std::collections::HashMap<String, Arc<ProtocolManifest>>>,
533 loader: ProtocolLoader,
534}
535
536impl ProtocolRegistry {
537 pub fn new() -> Self {
538 Self {
539 manifests: ArcSwap::from_pointee(std::collections::HashMap::new()),
540 loader: ProtocolLoader::new(),
541 }
542 }
543
544 pub async fn get_manifest(
546 &self,
547 provider_id: &str,
548 ) -> Result<Arc<ProtocolManifest>, ProtocolError> {
549 let current = self.manifests.load();
551 if let Some(manifest) = current.get(provider_id) {
552 return Ok(Arc::clone(manifest));
553 }
554
555 let manifest = self.loader.load_provider(provider_id).await?;
557 let manifest_arc = Arc::new(manifest);
558
559 let mut updated_map = std::collections::HashMap::new();
561 for (k, v) in current.iter() {
562 updated_map.insert(k.clone(), v.clone());
563 }
564 updated_map.insert(provider_id.to_string(), manifest_arc.clone());
565 self.manifests.store(Arc::new(updated_map));
566
567 Ok(manifest_arc)
568 }
569}
570
571impl Default for ProtocolRegistry {
572 fn default() -> Self {
573 Self::new()
574 }
575}