// ai_lib_rust/client/builder.rs

use crate::client::core::AiClient;
use crate::protocol::ProtocolLoader;
use crate::Result;
use std::sync::Arc;
use tokio::sync::Semaphore;
6
/// Builder for creating clients with custom configuration.
///
/// Keep this surface area small and predictable (developer-friendly).
pub struct AiClientBuilder {
    /// Custom directory for protocol manifest files; `None` uses the loader default.
    protocol_path: Option<String>,
    /// Whether the protocol loader should hot-reload manifest files.
    hot_reload: bool,
    /// Fallback model identifiers handed to the built client.
    fallbacks: Vec<String>,
    /// Fail fast when streaming configuration is incomplete (opt-in).
    strict_streaming: bool,
    /// Telemetry sink; defaults to a no-op implementation.
    feedback: Arc<dyn crate::telemetry::FeedbackSink>,
    /// Maximum number of in-flight requests/streams (simple backpressure), if set.
    max_inflight: Option<usize>,
    /// Optional circuit breaker, passed through to the built client.
    breaker: Option<Arc<crate::resilience::circuit_breaker::CircuitBreaker>>,
    /// Optional token-bucket rate limiter, passed through to the built client.
    rate_limiter: Option<Arc<crate::resilience::rate_limiter::RateLimiter>>,
    /// Override base URL (primarily for testing with mock servers)
    base_url_override: Option<String>,
}
22
23impl AiClientBuilder {
24    pub fn new() -> Self {
25        Self {
26            protocol_path: None,
27            hot_reload: false,
28            fallbacks: Vec::new(),
29            strict_streaming: false,
30            feedback: crate::telemetry::noop_sink(),
31            max_inflight: None,
32            breaker: None,
33            rate_limiter: None,
34            base_url_override: None,
35        }
36    }
37
38    /// Set custom protocol directory path.
39    pub fn protocol_path(mut self, path: String) -> Self {
40        self.protocol_path = Some(path);
41        self
42    }
43
44    /// Enable hot reload of protocol files.
45    pub fn hot_reload(mut self, enable: bool) -> Self {
46        self.hot_reload = enable;
47        self
48    }
49
50    /// Set fallback models.
51    pub fn with_fallbacks(mut self, fallbacks: Vec<String>) -> Self {
52        self.fallbacks = fallbacks;
53        self
54    }
55
56    /// Enable strict streaming validation (fail fast when streaming config is incomplete).
57    ///
58    /// This is intentionally opt-in to preserve compatibility with partial manifests.
59    pub fn strict_streaming(mut self, enable: bool) -> Self {
60        self.strict_streaming = enable;
61        self
62    }
63
64    /// Inject a feedback sink. Default is a no-op sink.
65    pub fn feedback_sink(mut self, sink: Arc<dyn crate::telemetry::FeedbackSink>) -> Self {
66        self.feedback = sink;
67        self
68    }
69
70    /// Limit maximum number of in-flight requests/streams.
71    /// This is a simple backpressure mechanism for production safety.
72    pub fn max_inflight(mut self, n: usize) -> Self {
73        self.max_inflight = Some(n.max(1));
74        self
75    }
76
77    /// Enable a minimal circuit breaker.
78    ///
79    /// Defaults can also be enabled via env:
80    /// - `AI_LIB_BREAKER_FAILURE_THRESHOLD` (default 5)
81    /// - `AI_LIB_BREAKER_COOLDOWN_SECS` (default 30)
82    pub fn circuit_breaker_default(mut self) -> Self {
83        let threshold = std::env::var("AI_LIB_BREAKER_FAILURE_THRESHOLD")
84            .ok()
85            .and_then(|s| s.parse::<u32>().ok())
86            .unwrap_or(5);
87        let cooldown_secs = std::env::var("AI_LIB_BREAKER_COOLDOWN_SECS")
88            .ok()
89            .and_then(|s| s.parse::<u64>().ok())
90            .unwrap_or(30);
91        let cfg = crate::resilience::circuit_breaker::CircuitBreakerConfig {
92            failure_threshold: threshold.max(1),
93            cooldown: std::time::Duration::from_secs(cooldown_secs.max(1)),
94        };
95        self.breaker = Some(Arc::new(
96            crate::resilience::circuit_breaker::CircuitBreaker::new(cfg),
97        ));
98        self
99    }
100
101    /// Enable a minimal token-bucket rate limiter.
102    ///
103    /// - Prefer configuring via env to keep API surface small:
104    ///   - `AI_LIB_RPS` (requests per second)
105    ///   - `AI_LIB_RPM` (requests per minute)
106    pub fn rate_limit_rps(mut self, rps: f64) -> Self {
107        if let Some(cfg) = crate::resilience::rate_limiter::RateLimiterConfig::from_rps(rps) {
108            self.rate_limiter = Some(Arc::new(crate::resilience::rate_limiter::RateLimiter::new(
109                cfg,
110            )));
111        }
112        self
113    }
114
115    /// Override the base URL from the protocol manifest.
116    ///
117    /// This is primarily for testing with mock servers. In production, use the
118    /// base_url defined in the protocol manifest.
119    pub fn base_url_override(mut self, base_url: impl Into<String>) -> Self {
120        self.base_url_override = Some(base_url.into());
121        self
122    }
123
124    /// Build the client.
125    pub async fn build(self, model: &str) -> Result<AiClient> {
126        let mut loader = ProtocolLoader::new();
127
128        if let Some(path) = self.protocol_path {
129            loader = loader.with_base_path(path);
130        }
131
132        if self.hot_reload {
133            loader = loader.with_hot_reload(true);
134        }
135
136        // model is in form "provider/model-id" or "provider/org/model-name" (e.g. nvidia/minimaxai/minimax-m2)
137        let parts: Vec<&str> = model.split('/').collect();
138        let model_id = if parts.len() >= 2 {
139            parts[1..].join("/")
140        } else {
141            model.to_string()
142        };
143
144        let manifest = loader.load_model(model).await?;
145        let strict_streaming = self.strict_streaming
146            || std::env::var("AI_LIB_STRICT_STREAMING").ok().as_deref() == Some("1");
147        crate::client::validation::validate_manifest(&manifest, strict_streaming)?;
148
149        let transport = Arc::new(crate::transport::HttpTransport::new_with_base_url(
150            &manifest,
151            &model_id,
152            self.base_url_override.as_deref(),
153        )?);
154        let pipeline = Arc::new(crate::pipeline::Pipeline::from_manifest(&manifest)?);
155
156        let max_inflight = self.max_inflight.or_else(|| {
157            std::env::var("AI_LIB_MAX_INFLIGHT")
158                .ok()?
159                .parse::<usize>()
160                .ok()
161        });
162        let inflight = max_inflight.map(|n| Arc::new(Semaphore::new(n.max(1))));
163
164        // Optional per-attempt timeout (policy signal). Transport has its own timeout too; this is an extra guard.
165        let attempt_timeout = std::env::var("AI_LIB_ATTEMPT_TIMEOUT_MS")
166            .ok()
167            .and_then(|s| s.parse::<u64>().ok())
168            .filter(|ms| *ms > 0)
169            .map(std::time::Duration::from_millis);
170
171        let env_rps = std::env::var("AI_LIB_RPS")
172            .ok()
173            .and_then(|s| s.parse::<f64>().ok());
174        let env_rpm = std::env::var("AI_LIB_RPM")
175            .ok()
176            .and_then(|s| s.parse::<f64>().ok());
177        let env_rate_limiter = env_rps
178            .or_else(|| env_rpm.map(|rpm| rpm / 60.0))
179            .and_then(crate::resilience::rate_limiter::RateLimiterConfig::from_rps)
180            .map(|cfg| Arc::new(crate::resilience::rate_limiter::RateLimiter::new(cfg)));
181
182        // If no explicit rate limiter and manifest has rate limit headers, enable adaptive mode (rps=0)
183        let rate_limiter = self.rate_limiter.or(env_rate_limiter).or_else(|| {
184            if manifest.rate_limit_headers.is_some() {
185                crate::resilience::rate_limiter::RateLimiterConfig::from_rps(0.0)
186                    .map(|cfg| Arc::new(crate::resilience::rate_limiter::RateLimiter::new(cfg)))
187            } else {
188                None
189            }
190        });
191
192        Ok(AiClient {
193            manifest,
194            transport,
195            pipeline,
196            loader: Arc::new(loader),
197            fallbacks: self.fallbacks,
198            model_id,
199            strict_streaming,
200            feedback: self.feedback,
201            inflight,
202            max_inflight,
203            attempt_timeout,
204            breaker: self.breaker,
205            rate_limiter,
206        })
207    }
208}
209
210impl Default for AiClientBuilder {
211    fn default() -> Self {
212        Self::new()
213    }
214}