Skip to main content

ai_lib_rust/protocol/
mod.rs

1//! Protocol specification layer
2//!
3//! This module handles loading, validating, and managing AI-Protocol specifications.
4//! It provides the foundation for the protocol-driven architecture.
5
6pub mod loader;
7pub mod schema;
8pub mod validator;
9
10pub use loader::ProtocolLoader;
11pub use schema::ProtocolSchema;
12pub use validator::ProtocolValidator;
13
14use serde::{Deserialize, Serialize};
15use std::collections::HashMap;
16
/// Unified request format (for protocol compilation)
///
/// Provider-agnostic description of one call. `ProtocolManifest::compile_request`
/// reads `model`, `temperature`, `max_tokens`, `stream`, `messages`, `tools` and
/// `tool_choice` from this struct to build the provider-specific JSON body.
/// `Default` yields empty strings, an empty message list, `stream == false` and
/// `None` for every optional field.
#[derive(Debug, Clone, Default)]
pub struct UnifiedRequest {
    /// Operation intent used for endpoint routing (e.g. "chat", "completions", "embeddings")
    pub operation: String,
    /// Provider model id (e.g. "deepseek-chat", "gpt-4o-mini")
    pub model: String,
    /// Messages to send; each one is serialized to JSON in list order.
    pub messages: Vec<crate::types::message::Message>,
    /// Temperature value; `None` keeps the parameter out of the compiled request.
    pub temperature: Option<f64>,
    /// Token limit; `None` keeps the parameter out of the compiled request.
    pub max_tokens: Option<u32>,
    /// Streaming flag, written as a JSON boolean when the manifest maps `stream`.
    pub stream: bool,
    /// Tool definitions; `None` keeps the parameter out of the compiled request.
    pub tools: Option<Vec<crate::types::tool::ToolDefinition>>,
    /// OpenAI-style tool choice. Examples:
    /// - "auto"
    /// - "none"
    /// - {"type":"function","function":{"name":"web_search"}}
    pub tool_choice: Option<serde_json::Value>,
}
35
/// Protocol error types
///
/// Variants that carry a `hint` field append it to their display message on a new
/// line, prefixed with "💡 Hint: ", when it is `Some`. Use `ProtocolError::with_hint`
/// to fill that field after construction.
#[derive(Debug, thiserror::Error)]
pub enum ProtocolError {
    /// A protocol could not be loaded from `path`; `reason` carries the cause text.
    #[error("Failed to load protocol from {path}: {reason}{}", .hint.as_ref().map(|h| format!("\n💡 Hint: {}", h)).unwrap_or_default())]
    LoadError {
        path: String,
        reason: String,
        hint: Option<String>,
    },

    /// Validation failed; also returned by `ProtocolManifest::compile_request`
    /// when a request field cannot be written into the provider payload.
    #[error("Protocol validation failed: {0}")]
    ValidationError(String),

    /// The value found at `path` did not match the expected schema type.
    #[error("Schema mismatch: expected {expected}, found {actual} at {path}{}", .hint.as_ref().map(|h| format!("\n💡 Hint: {}", h)).unwrap_or_default())]
    SchemaMismatch {
        path: String,
        expected: String,
        actual: String,
        hint: Option<String>,
    },

    /// No protocol is known under `id`.
    #[error("Protocol not found: {id}{}", .hint.as_ref().map(|h| format!("\n💡 Hint: {}", h)).unwrap_or_default())]
    NotFound { id: String, hint: Option<String> },

    /// The protocol version is not supported; `max_supported` is reported in the message.
    #[error("Unsupported protocol version '{version}' (max supported: {max_supported}){}", .hint.as_ref().map(|h| format!("\n💡 Hint: {}", h)).unwrap_or_default())]
    InvalidVersion {
        version: String,
        max_supported: String,
        hint: Option<String>,
    },

    /// Problem with the configuration manifest (no hint slot).
    #[error("Configuration manifest error: {0}")]
    ManifestError(String),

    /// Internal failure inside the protocol layer (no hint slot).
    #[error("Internal protocol error: {0}")]
    Internal(String),

    /// YAML syntax problem (no hint slot).
    #[error("YAML syntax error: {0}")]
    YamlError(String),
}
76
77impl ProtocolError {
78    /// Attach an actionable hint to the error
79    pub fn with_hint(mut self, hint: impl Into<String>) -> Self {
80        let hint_val = Some(hint.into());
81        match self {
82            ProtocolError::LoadError { ref mut hint, .. } => *hint = hint_val,
83            ProtocolError::SchemaMismatch { ref mut hint, .. } => *hint = hint_val,
84            ProtocolError::NotFound { ref mut hint, .. } => *hint = hint_val,
85            ProtocolError::InvalidVersion { ref mut hint, .. } => *hint = hint_val,
86            _ => (),
87        }
88        self
89    }
90}
91
/// Protocol manifest structure (parsed from YAML)
///
/// Required fields per schema: id, protocol_version, endpoint, availability, capabilities
///
/// NOTE(review): `status`, `category`, `official_url`, `support_contact`, `auth` and
/// `parameter_mappings` are also declared without `Option` or `#[serde(default)]`,
/// so deserializing a manifest that omits any of them fails as well — confirm this
/// matches the schema.
///
/// Every `Option` field below is left out of the serialized output when it is `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProtocolManifest {
    /// Serialized under the key `$schema`.
    #[serde(rename = "$schema", skip_serializing_if = "Option::is_none")]
    pub schema: Option<String>,

    // Required fields
    pub id: String,
    pub protocol_version: String,
    /// Endpoint definition; its `base_url` is what `get_base_url` returns.
    pub endpoint: EndpointDefinition,
    pub availability: AvailabilityConfig,
    /// Capability flags queried by `supports_capability`.
    pub capabilities: Capabilities,

    // Provider metadata: `name`, `provider_id` and `version` are optional,
    // the remaining four fields must be present for deserialization to succeed.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub name: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub provider_id: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub version: Option<String>,
    pub status: String,   // stable/beta/deprecated
    pub category: String, // ai_provider / model_provider / third_party_aggregator
    pub official_url: String,
    pub support_contact: String,

    // Auth and configuration
    pub auth: AuthConfig,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub payload_format: Option<String>,
    /// Maps unified parameter names (`model`, `temperature`, `max_tokens`, `stream`,
    /// `messages`, `tools`, `tool_choice`) to the path `compile_request` writes them to.
    pub parameter_mappings: HashMap<String, String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_format: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_paths: Option<HashMap<String, String>>,

    // Streaming and features
    #[serde(skip_serializing_if = "Option::is_none")]
    pub streaming: Option<StreamingConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub features: Option<FeaturesConfig>,

    // Endpoints and services
    #[serde(skip_serializing_if = "Option::is_none")]
    pub endpoints: Option<HashMap<String, EndpointConfig>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub services: Option<HashMap<String, ServiceConfig>>,

    // API families
    #[serde(skip_serializing_if = "Option::is_none")]
    pub api_families: Option<Vec<String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub default_api_family: Option<String>,

    // Tooling and termination
    #[serde(skip_serializing_if = "Option::is_none")]
    pub termination: Option<TerminationConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub tooling: Option<ToolingConfig>,

    // Error handling and resilience
    #[serde(skip_serializing_if = "Option::is_none")]
    pub retry_policy: Option<RetryPolicy>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub error_classification: Option<ErrorClassification>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub rate_limit_headers: Option<RateLimitHeaders>,

    // Experimental features
    #[serde(skip_serializing_if = "Option::is_none")]
    pub experimental_features: Option<Vec<String>>,
}
165
/// Configuration of a single named endpoint.
///
/// `Serialize` is derived, but `Deserialize` is implemented by hand so that an
/// endpoint can be written either as a bare path string or as a full object.
#[derive(Debug, Clone, Serialize)]
pub struct EndpointConfig {
    /// Endpoint path (e.g. "/v1/chat/completions").
    pub path: String,
    /// Method string; the hand-written deserializer fills in `default_method()`
    /// when the input does not provide one.
    pub method: String,
    /// Optional adapter name; omitted from serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub adapter: Option<String>,
}
173
174impl<'de> Deserialize<'de> for EndpointConfig {
175    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
176    where
177        D: serde::Deserializer<'de>,
178    {
179        #[derive(Deserialize)]
180        #[serde(untagged)]
181        enum Input {
182            // Shorthand: endpoint: "/v1/chat/completions"
183            Path(String),
184            // Full form
185            Obj {
186                path: String,
187                #[serde(default = "default_method")]
188                method: String,
189                #[serde(default)]
190                adapter: Option<String>,
191            },
192        }
193
194        match Input::deserialize(deserializer)? {
195            Input::Path(path) => Ok(EndpointConfig {
196                path,
197                method: default_method(),
198                adapter: None,
199            }),
200            Input::Obj {
201                path,
202                method,
203                adapter,
204            } => Ok(EndpointConfig {
205                path,
206                method,
207                adapter,
208            }),
209        }
210    }
211}
212
/// Serde default for an endpoint's `method`: the string "POST".
fn default_method() -> String {
    String::from("POST")
}
216
/// Configuration of a single named service.
///
/// Only `path` is mandatory on input. `method` falls back to
/// `default_method_get()` and the optional fields are omitted from serialized
/// output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ServiceConfig {
    pub path: String,
    /// Method string; defaults to "GET" when absent from the input.
    #[serde(default = "default_method_get")]
    pub method: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub headers: Option<HashMap<String, String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub query_params: Option<HashMap<String, String>>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub response_binding: Option<String>,
}
229
/// Serde default for a service's `method`: the string "GET".
fn default_method_get() -> String {
    String::from("GET")
}
233
/// Structured endpoint definition (v1.1+ extension)
///
/// `base_url` is mandatory; the optional fields are omitted from serialized
/// output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EndpointDefinition {
    /// Base URL, returned as-is by `ProtocolManifest::get_base_url`.
    pub base_url: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub protocol: Option<String>, // https, http, ws, wss
    #[serde(skip_serializing_if = "Option::is_none")]
    pub timeout_ms: Option<u32>,
}
243
/// Capabilities object format (v1.1+)
/// Required fields: streaming, tools, vision
///
/// The remaining flags default to `false` when absent from the input and are
/// omitted from serialized output while they are `false`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Capabilities {
    pub streaming: bool,
    pub tools: bool,
    pub vision: bool,
    #[serde(default, skip_serializing_if = "is_false")]
    pub agentic: bool,
    #[serde(default, skip_serializing_if = "is_false")]
    pub parallel_tools: bool,
    #[serde(default, skip_serializing_if = "is_false")]
    pub reasoning: bool,
    /// Note: `ProtocolManifest::supports_capability("multimodal")` also reports
    /// `true` when `vision` or `audio` is set, even if this flag is `false`.
    #[serde(default, skip_serializing_if = "is_false")]
    pub multimodal: bool,
    #[serde(default, skip_serializing_if = "is_false")]
    pub audio: bool,
}
262
/// Serde `skip_serializing_if` predicate: `true` exactly when the flag is `false`.
fn is_false(b: &bool) -> bool {
    matches!(*b, false)
}
266
267impl ProtocolManifest {
268    /// Check if protocol supports a specific capability
269    pub fn supports_capability(&self, capability: &str) -> bool {
270        match capability {
271            "streaming" => self.capabilities.streaming,
272            "tools" => self.capabilities.tools,
273            "vision" => self.capabilities.vision,
274            "agentic" => self.capabilities.agentic,
275            "parallel_tools" => self.capabilities.parallel_tools,
276            "reasoning" => self.capabilities.reasoning,
277            "multimodal" => {
278                self.capabilities.multimodal || self.capabilities.vision || self.capabilities.audio
279            }
280            "audio" => self.capabilities.audio,
281            _ => false,
282        }
283    }
284
285    /// Get base URL from endpoint definition
286    pub fn get_base_url(&self) -> &str {
287        &self.endpoint.base_url
288    }
289
290    /// Compile unified request to provider-specific format
291    pub fn compile_request(
292        &self,
293        request: &UnifiedRequest,
294    ) -> Result<serde_json::Value, ProtocolError> {
295        use crate::utils::PathMapper;
296
297        let mut provider_request = serde_json::json!({});
298
299        // Model is required for most OpenAI-compatible APIs
300        let model_path = self
301            .parameter_mappings
302            .get("model")
303            .map(|s| s.as_str())
304            .unwrap_or("model");
305        PathMapper::set_path(
306            &mut provider_request,
307            model_path,
308            serde_json::Value::String(request.model.clone()),
309        )
310        .map_err(|e| ProtocolError::ValidationError(format!("Failed to set model: {}", e)))?;
311
312        // Map standard parameters to provider-specific names using PathMapper
313        if let Some(temp) = request.temperature {
314            if let Some(mapped) = self.parameter_mappings.get("temperature") {
315                PathMapper::set_path(
316                    &mut provider_request,
317                    mapped,
318                    serde_json::Value::Number(serde_json::Number::from_f64(temp).ok_or_else(
319                        || ProtocolError::ValidationError("Invalid temperature".to_string()),
320                    )?),
321                )
322                .map_err(|e| {
323                    ProtocolError::ValidationError(format!("Failed to set temperature: {}", e))
324                })?;
325            }
326        }
327
328        if let Some(max) = request.max_tokens {
329            if let Some(mapped) = self.parameter_mappings.get("max_tokens") {
330                PathMapper::set_path(
331                    &mut provider_request,
332                    mapped,
333                    serde_json::Value::Number(max.into()),
334                )
335                .map_err(|e| {
336                    ProtocolError::ValidationError(format!("Failed to set max_tokens: {}", e))
337                })?;
338            }
339        }
340
341        if let Some(mapped) = self.parameter_mappings.get("stream") {
342            PathMapper::set_path(
343                &mut provider_request,
344                mapped,
345                serde_json::Value::Bool(request.stream),
346            )
347            .map_err(|e| ProtocolError::ValidationError(format!("Failed to set stream: {}", e)))?;
348        }
349
350        // Map messages (format depends on payload_format)
351        let messages_path = self
352            .parameter_mappings
353            .get("messages")
354            .map(|s| s.as_str())
355            .unwrap_or("messages");
356        let messages: Vec<serde_json::Value> = request
357            .messages
358            .iter()
359            .map(|m| serde_json::to_value(m).unwrap())
360            .collect();
361        PathMapper::set_path(
362            &mut provider_request,
363            messages_path,
364            serde_json::Value::Array(messages),
365        )
366        .map_err(|e| ProtocolError::ValidationError(format!("Failed to set messages: {}", e)))?;
367
368        // Map tools if present
369        if let Some(tools) = &request.tools {
370            if let Some(mapped) = self.parameter_mappings.get("tools") {
371                let tools_value: Vec<serde_json::Value> = tools
372                    .iter()
373                    .map(|t| serde_json::to_value(t).unwrap())
374                    .collect();
375                PathMapper::set_path(
376                    &mut provider_request,
377                    mapped,
378                    serde_json::Value::Array(tools_value),
379                )
380                .map_err(|e| {
381                    ProtocolError::ValidationError(format!("Failed to set tools: {}", e))
382                })?;
383            }
384        }
385
386        // Map tool_choice if present
387        if let Some(tool_choice) = &request.tool_choice {
388            if let Some(mapped) = self.parameter_mappings.get("tool_choice") {
389                PathMapper::set_path(&mut provider_request, mapped, tool_choice.clone()).map_err(
390                    |e| ProtocolError::ValidationError(format!("Failed to set tool_choice: {}", e)),
391                )?;
392            }
393        }
394
395        Ok(provider_request)
396    }
397}
398
/// Authentication settings of a protocol manifest.
///
/// Only the `type` key is mandatory. Every other field defaults to `None` when
/// absent and is omitted from serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AuthConfig {
    /// Serialized under the key `type`.
    #[serde(rename = "type")]
    pub auth_type: String,
    // presumably the name of an environment variable holding the token — confirm with the consumer
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub token_env: Option<String>,
    // presumably the name of an environment variable holding the key — confirm with the consumer
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub key_env: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub param_name: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub header_name: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub extra_headers: Option<Vec<HeaderConfig>>,
}
414
/// A single name/value header entry, used by `AuthConfig::extra_headers`.
/// Both fields are mandatory.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HeaderConfig {
    pub name: String,
    pub value: String,
}
420
/// Streaming section of a protocol manifest.
///
/// No field is mandatory: `Option` fields are `None` when absent and omitted from
/// serialized output when `None`; `event_map` defaults to an empty list.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StreamingConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub event_format: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub decoder: Option<DecoderConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub frame_selector: Option<String>,
    /// Common path for content delta in streaming frames (provider-specific)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub content_path: Option<String>,
    /// Common path for tool call delta in streaming frames (provider-specific)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_call_path: Option<String>,
    /// Common path for usage metadata in streaming frames (provider-specific)
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub usage_path: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidate: Option<CandidateConfig>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub accumulator: Option<AccumulatorConfig>,
    /// Event mapping rules; empty when absent from the input. Unlike the `Option`
    /// fields this one is always serialized, even when empty.
    #[serde(default)]
    pub event_map: Vec<EventMapRule>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub stop_condition: Option<String>,
}
447
/// Decoder settings referenced by `StreamingConfig::decoder`.
///
/// `format` is mandatory; the optional fields are omitted from serialized output
/// when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DecoderConfig {
    pub format: String,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub strategy: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub delimiter: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub prefix: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub done_signal: Option<String>,
}
460
/// Candidate settings referenced by `StreamingConfig::candidate`.
///
/// Both fields are optional and omitted from serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CandidateConfig {
    #[serde(skip_serializing_if = "Option::is_none")]
    pub candidate_id_path: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub fan_out: Option<bool>,
}
468
/// Accumulator settings referenced by `StreamingConfig::accumulator`.
///
/// No field is mandatory on input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AccumulatorConfig {
    /// Defaults to `false` when absent; always present in serialized output.
    #[serde(default)]
    pub stateful_tool_parsing: bool,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub key_path: Option<String>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub flush_on: Option<String>,
}
478
/// One entry of `StreamingConfig::event_map`.
///
/// `match` and `emit` are mandatory; `fields` is optional.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EventMapRule {
    /// Serialized under the key `match` (renamed because `match` is a Rust keyword).
    #[serde(rename = "match")]
    pub match_expr: String,
    pub emit: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub fields: Option<HashMap<String, String>>,
}
487
/// Features section of a protocol manifest.
///
/// Both sub-sections are optional: `None` when absent, omitted from serialized
/// output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeaturesConfig {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub multi_candidate: Option<MultiCandidateConfig>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub response_mapping: Option<ResponseMappingConfig>,
}
495
/// Multi-candidate settings referenced by `FeaturesConfig::multi_candidate`.
///
/// `support_type` is mandatory; the other fields are optional and omitted from
/// serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MultiCandidateConfig {
    pub support_type: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub param_name: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_concurrent: Option<u32>,
}
504
/// Response mapping settings referenced by `FeaturesConfig::response_mapping`.
///
/// Both mappings are optional and omitted from serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ResponseMappingConfig {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_calls: Option<ToolCallsMapping>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub error: Option<ErrorMapping>,
}
512
/// Tool-call mapping referenced by `ResponseMappingConfig::tool_calls`.
///
/// `path` and `fields` are mandatory; `filter` and `array_fan_out` are optional
/// and omitted from serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolCallsMapping {
    pub path: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub filter: Option<String>,
    pub fields: HashMap<String, String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub array_fan_out: Option<bool>,
}
522
/// Error mapping referenced by `ResponseMappingConfig::error`.
///
/// All fields are optional: `None` when absent, omitted from serialized output
/// when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorMapping {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub message_path: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub code_path: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub type_path: Option<String>,
}
532
/// Termination section of a protocol manifest.
///
/// `source_field` is mandatory; `mapping` is optional and omitted from serialized
/// output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TerminationConfig {
    pub source_field: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub mapping: Option<HashMap<String, String>>,
}
539
/// Tooling section of a protocol manifest.
///
/// `source_model` is mandatory; the two mappings are optional and omitted from
/// serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolingConfig {
    pub source_model: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_use: Option<ToolUseMapping>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tool_result: Option<ToolResultMapping>,
}
548
/// Tool-use mapping referenced by `ToolingConfig::tool_use`.
///
/// All fields are optional: `None` when absent, omitted from serialized output
/// when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolUseMapping {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id_path: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub name_path: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub input_path: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub input_format: Option<String>,
}
560
/// Tool-result mapping referenced by `ToolingConfig::tool_result`.
///
/// All fields are optional: `None` when absent, omitted from serialized output
/// when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ToolResultMapping {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub id_path: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub name_path: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub response_path: Option<String>,
}
570
/// Retry policy section of a protocol manifest.
///
/// `strategy` is mandatory; every other field is optional and omitted from
/// serialized output when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryPolicy {
    pub strategy: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_retries: Option<u32>,
    // Delay bounds; the `_ms` suffix suggests milliseconds — confirm with the consumer.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub min_delay_ms: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_delay_ms: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub jitter: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub retry_on_http_status: Option<Vec<u16>>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub retry_on_error_status: Option<Vec<String>>,
}
587
/// Error classification section of a protocol manifest.
///
/// Both tables are optional string-to-string maps, omitted from serialized output
/// when `None`. Note that `by_http_status` is keyed by strings, not integers.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ErrorClassification {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub by_http_status: Option<HashMap<String, String>>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub by_error_status: Option<HashMap<String, String>>,
}
595
/// Availability and health checking configuration (v1.1+ extension)
/// Required fields: required, regions, check
///
/// `notes` is the only optional field and is omitted from serialized output when
/// `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AvailabilityConfig {
    pub required: bool,
    pub regions: Vec<String>, // cn, global, us, eu
    pub check: HealthCheckConfig,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub notes: Option<Vec<String>>,
}
606
/// Health check endpoint configuration
/// Required fields: method, path, expected_status
///
/// `timeout_ms` is the only optional field and is omitted from serialized output
/// when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HealthCheckConfig {
    pub method: String, // HEAD, GET
    pub path: String,
    /// Status codes listed for the check; a list, so more than one may be given.
    pub expected_status: Vec<u16>,
    #[serde(skip_serializing_if = "Option::is_none")]
    pub timeout_ms: Option<u32>,
}
617
/// Rate-limit header section of a protocol manifest.
///
/// All fields are optional strings: `None` when absent, omitted from serialized
/// output when `None`.
// NOTE(review): presumably each value is the name of a response header to read —
// confirm against the code that consumes this struct.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RateLimitHeaders {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub requests_limit: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub requests_remaining: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub requests_reset: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tokens_limit: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tokens_remaining: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tokens_reset: Option<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub retry_after: Option<String>,
}