//! ai_lib_rust/client/builder.rs — builder API for constructing `AiClient`.
1use crate::client::core::AiClient;
2use crate::protocol::ProtocolLoader;
3use crate::Result;
4use std::sync::Arc;
5use tokio::sync::Semaphore;
6
/// Builder for creating clients with custom configuration.
///
/// Keep this surface area small and predictable (developer-friendly).
pub struct AiClientBuilder {
    /// Custom directory to load protocol manifests from; `None` uses the loader default.
    protocol_path: Option<String>,
    /// Watch protocol files and reload them when they change.
    hot_reload: bool,
    /// Fallback model identifiers, tried in order when the primary model fails.
    fallbacks: Vec<String>,
    /// Fail fast when a manifest's streaming configuration is incomplete.
    strict_streaming: bool,
    /// Telemetry sink; defaults to a no-op implementation.
    feedback: Arc<dyn crate::telemetry::FeedbackSink>,
    /// Optional cap on concurrent in-flight requests/streams (simple backpressure).
    max_inflight: Option<usize>,
    /// Optional circuit breaker shared across requests.
    breaker: Option<Arc<crate::resilience::circuit_breaker::CircuitBreaker>>,
    /// Optional token-bucket rate limiter.
    rate_limiter: Option<Arc<crate::resilience::rate_limiter::RateLimiter>>,
    /// Override base URL (primarily for testing with mock servers)
    base_url_override: Option<String>,
}
22
23impl AiClientBuilder {
24    pub fn new() -> Self {
25        Self {
26            protocol_path: None,
27            hot_reload: false,
28            fallbacks: Vec::new(),
29            strict_streaming: false,
30            feedback: crate::telemetry::noop_sink(),
31            max_inflight: None,
32            breaker: None,
33            rate_limiter: None,
34            base_url_override: None,
35        }
36    }
37
38    /// Set custom protocol directory path.
39    pub fn protocol_path(mut self, path: String) -> Self {
40        self.protocol_path = Some(path);
41        self
42    }
43
44    /// Enable hot reload of protocol files.
45    pub fn hot_reload(mut self, enable: bool) -> Self {
46        self.hot_reload = enable;
47        self
48    }
49
50    /// Set fallback models.
51    pub fn with_fallbacks(mut self, fallbacks: Vec<String>) -> Self {
52        self.fallbacks = fallbacks;
53        self
54    }
55
56    /// Enable strict streaming validation (fail fast when streaming config is incomplete).
57    ///
58    /// This is intentionally opt-in to preserve compatibility with partial manifests.
59    pub fn strict_streaming(mut self, enable: bool) -> Self {
60        self.strict_streaming = enable;
61        self
62    }
63
64    /// Inject a feedback sink. Default is a no-op sink.
65    pub fn feedback_sink(mut self, sink: Arc<dyn crate::telemetry::FeedbackSink>) -> Self {
66        self.feedback = sink;
67        self
68    }
69
70    /// Limit maximum number of in-flight requests/streams.
71    /// This is a simple backpressure mechanism for production safety.
72    pub fn max_inflight(mut self, n: usize) -> Self {
73        self.max_inflight = Some(n.max(1));
74        self
75    }
76
77    /// Enable a minimal circuit breaker.
78    ///
79    /// Defaults can also be enabled via env:
80    /// - `AI_LIB_BREAKER_FAILURE_THRESHOLD` (default 5)
81    /// - `AI_LIB_BREAKER_COOLDOWN_SECS` (default 30)
82    pub fn circuit_breaker_default(mut self) -> Self {
83        let threshold = std::env::var("AI_LIB_BREAKER_FAILURE_THRESHOLD")
84            .ok()
85            .and_then(|s| s.parse::<u32>().ok())
86            .unwrap_or(5);
87        let cooldown_secs = std::env::var("AI_LIB_BREAKER_COOLDOWN_SECS")
88            .ok()
89            .and_then(|s| s.parse::<u64>().ok())
90            .unwrap_or(30);
91        let cfg = crate::resilience::circuit_breaker::CircuitBreakerConfig {
92            failure_threshold: threshold.max(1),
93            cooldown: std::time::Duration::from_secs(cooldown_secs.max(1)),
94        };
95        self.breaker = Some(Arc::new(
96            crate::resilience::circuit_breaker::CircuitBreaker::new(cfg),
97        ));
98        self
99    }
100
101    /// Enable a minimal token-bucket rate limiter.
102    ///
103    /// - Prefer configuring via env to keep API surface small:
104    ///   - `AI_LIB_RPS` (requests per second)
105    ///   - `AI_LIB_RPM` (requests per minute)
106    pub fn rate_limit_rps(mut self, rps: f64) -> Self {
107        if let Some(cfg) = crate::resilience::rate_limiter::RateLimiterConfig::from_rps(rps) {
108            self.rate_limiter = Some(Arc::new(crate::resilience::rate_limiter::RateLimiter::new(
109                cfg,
110            )));
111        }
112        self
113    }
114
115    /// Override the base URL from the protocol manifest.
116    ///
117    /// This is primarily for testing with mock servers. In production, use the
118    /// base_url defined in the protocol manifest.
119    pub fn base_url_override(mut self, base_url: impl Into<String>) -> Self {
120        self.base_url_override = Some(base_url.into());
121        self
122    }
123
124    /// Build the client.
125    pub async fn build(self, model: &str) -> Result<AiClient> {
126        let mut loader = ProtocolLoader::new();
127
128        if let Some(path) = self.protocol_path {
129            loader = loader.with_base_path(path);
130        }
131
132        if self.hot_reload {
133            loader = loader.with_hot_reload(true);
134        }
135
136        // model is in form "provider/model-id"
137        let parts: Vec<&str> = model.split('/').collect();
138        let model_id = parts
139            .get(1)
140            .map(|s| s.to_string())
141            .unwrap_or_else(|| model.to_string());
142
143        let manifest = loader.load_model(model).await?;
144        let strict_streaming = self.strict_streaming
145            || std::env::var("AI_LIB_STRICT_STREAMING").ok().as_deref() == Some("1");
146        crate::client::validation::validate_manifest(&manifest, strict_streaming)?;
147
148        let transport = Arc::new(crate::transport::HttpTransport::new_with_base_url(
149            &manifest,
150            &model_id,
151            self.base_url_override.as_deref(),
152        )?);
153        let pipeline = Arc::new(crate::pipeline::Pipeline::from_manifest(&manifest)?);
154
155        let max_inflight = self.max_inflight.or_else(|| {
156            std::env::var("AI_LIB_MAX_INFLIGHT")
157                .ok()?
158                .parse::<usize>()
159                .ok()
160        });
161        let inflight = max_inflight.map(|n| Arc::new(Semaphore::new(n.max(1))));
162
163        // Optional per-attempt timeout (policy signal). Transport has its own timeout too; this is an extra guard.
164        let attempt_timeout = std::env::var("AI_LIB_ATTEMPT_TIMEOUT_MS")
165            .ok()
166            .and_then(|s| s.parse::<u64>().ok())
167            .filter(|ms| *ms > 0)
168            .map(std::time::Duration::from_millis);
169
170        let env_rps = std::env::var("AI_LIB_RPS")
171            .ok()
172            .and_then(|s| s.parse::<f64>().ok());
173        let env_rpm = std::env::var("AI_LIB_RPM")
174            .ok()
175            .and_then(|s| s.parse::<f64>().ok());
176        let env_rate_limiter = env_rps
177            .or_else(|| env_rpm.map(|rpm| rpm / 60.0))
178            .and_then(crate::resilience::rate_limiter::RateLimiterConfig::from_rps)
179            .map(|cfg| Arc::new(crate::resilience::rate_limiter::RateLimiter::new(cfg)));
180
181        // If no explicit rate limiter and manifest has rate limit headers, enable adaptive mode (rps=0)
182        let rate_limiter = self.rate_limiter.or(env_rate_limiter).or_else(|| {
183            if manifest.rate_limit_headers.is_some() {
184                crate::resilience::rate_limiter::RateLimiterConfig::from_rps(0.0)
185                    .map(|cfg| Arc::new(crate::resilience::rate_limiter::RateLimiter::new(cfg)))
186            } else {
187                None
188            }
189        });
190
191        Ok(AiClient {
192            manifest,
193            transport,
194            pipeline,
195            loader: Arc::new(loader),
196            fallbacks: self.fallbacks,
197            model_id,
198            strict_streaming,
199            feedback: self.feedback,
200            inflight,
201            max_inflight,
202            attempt_timeout,
203            breaker: self.breaker,
204            rate_limiter,
205        })
206    }
207}
208
209impl Default for AiClientBuilder {
210    fn default() -> Self {
211        Self::new()
212    }
213}