//! Builder for constructing `AiClient` instances (`ai_lib_rust/client/builder.rs`).
use crate::client::core::AiClient;
use crate::feedback::FeedbackSink;
use crate::protocol::ProtocolLoader;
use crate::Result;
use std::sync::Arc;
use tokio::sync::Semaphore;
7
8/// Builder for creating clients with custom configuration.
9///
10/// Keep this surface area small and predictable (developer-friendly).
11pub struct AiClientBuilder {
12    protocol_path: Option<String>,
13    hot_reload: bool,
14    fallbacks: Vec<String>,
15    strict_streaming: bool,
16    feedback: Arc<dyn FeedbackSink>,
17    max_inflight: Option<usize>,
18    breaker: Option<Arc<crate::resilience::circuit_breaker::CircuitBreaker>>,
19    rate_limiter: Option<Arc<crate::resilience::rate_limiter::RateLimiter>>,
20    /// Override base URL (primarily for testing with mock servers)
21    base_url_override: Option<String>,
22}
23
24impl AiClientBuilder {
25    pub fn new() -> Self {
26        Self {
27            protocol_path: None,
28            hot_reload: false,
29            fallbacks: Vec::new(),
30            strict_streaming: false,
31            feedback: crate::feedback::noop_sink(),
32            max_inflight: None,
33            breaker: None,
34            rate_limiter: None,
35            base_url_override: None,
36        }
37    }
38
39    /// Set custom protocol directory path.
40    pub fn protocol_path(mut self, path: String) -> Self {
41        self.protocol_path = Some(path);
42        self
43    }
44
45    /// Enable hot reload of protocol files.
46    pub fn hot_reload(mut self, enable: bool) -> Self {
47        self.hot_reload = enable;
48        self
49    }
50
51    /// Set fallback models.
52    pub fn with_fallbacks(mut self, fallbacks: Vec<String>) -> Self {
53        self.fallbacks = fallbacks;
54        self
55    }
56
57    /// Enable strict streaming validation (fail fast when streaming config is incomplete).
58    ///
59    /// This is intentionally opt-in to preserve compatibility with partial manifests.
60    pub fn strict_streaming(mut self, enable: bool) -> Self {
61        self.strict_streaming = enable;
62        self
63    }
64
65    /// Inject a feedback sink. Default is a no-op sink.
66    pub fn feedback_sink(mut self, sink: Arc<dyn FeedbackSink>) -> Self {
67        self.feedback = sink;
68        self
69    }
70
71    /// Limit maximum number of in-flight requests/streams.
72    /// This is a simple backpressure mechanism for production safety.
73    pub fn max_inflight(mut self, n: usize) -> Self {
74        self.max_inflight = Some(n.max(1));
75        self
76    }
77
78    /// Enable a minimal circuit breaker.
79    ///
80    /// Defaults can also be enabled via env:
81    /// - `AI_LIB_BREAKER_FAILURE_THRESHOLD` (default 5)
82    /// - `AI_LIB_BREAKER_COOLDOWN_SECS` (default 30)
83    pub fn circuit_breaker_default(mut self) -> Self {
84        let threshold = std::env::var("AI_LIB_BREAKER_FAILURE_THRESHOLD")
85            .ok()
86            .and_then(|s| s.parse::<u32>().ok())
87            .unwrap_or(5);
88        let cooldown_secs = std::env::var("AI_LIB_BREAKER_COOLDOWN_SECS")
89            .ok()
90            .and_then(|s| s.parse::<u64>().ok())
91            .unwrap_or(30);
92        let cfg = crate::resilience::circuit_breaker::CircuitBreakerConfig {
93            failure_threshold: threshold.max(1),
94            cooldown: std::time::Duration::from_secs(cooldown_secs.max(1)),
95        };
96        self.breaker = Some(Arc::new(
97            crate::resilience::circuit_breaker::CircuitBreaker::new(cfg),
98        ));
99        self
100    }
101
102    /// Enable a minimal token-bucket rate limiter.
103    ///
104    /// - Prefer configuring via env to keep API surface small:
105    ///   - `AI_LIB_RPS` (requests per second)
106    ///   - `AI_LIB_RPM` (requests per minute)
107    pub fn rate_limit_rps(mut self, rps: f64) -> Self {
108        if let Some(cfg) = crate::resilience::rate_limiter::RateLimiterConfig::from_rps(rps) {
109            self.rate_limiter = Some(Arc::new(crate::resilience::rate_limiter::RateLimiter::new(
110                cfg,
111            )));
112        }
113        self
114    }
115
116    /// Override the base URL from the protocol manifest.
117    ///
118    /// This is primarily for testing with mock servers. In production, use the
119    /// base_url defined in the protocol manifest.
120    pub fn base_url_override(mut self, base_url: impl Into<String>) -> Self {
121        self.base_url_override = Some(base_url.into());
122        self
123    }
124
125    /// Build the client.
126    pub async fn build(self, model: &str) -> Result<AiClient> {
127        let mut loader = ProtocolLoader::new();
128
129        if let Some(path) = self.protocol_path {
130            loader = loader.with_base_path(path);
131        }
132
133        if self.hot_reload {
134            loader = loader.with_hot_reload(true);
135        }
136
137        // model is in form "provider/model-id" or "provider/org/model-name" (e.g. nvidia/minimaxai/minimax-m2)
138        let parts: Vec<&str> = model.split('/').collect();
139        let model_id = if parts.len() >= 2 {
140            parts[1..].join("/")
141        } else {
142            model.to_string()
143        };
144
145        let manifest = loader.load_model(model).await?;
146        let strict_streaming = self.strict_streaming
147            || std::env::var("AI_LIB_STRICT_STREAMING").ok().as_deref() == Some("1");
148        crate::client::validation::validate_manifest(&manifest, strict_streaming)?;
149
150        let transport = Arc::new(crate::transport::HttpTransport::new_with_base_url(
151            &manifest,
152            &model_id,
153            self.base_url_override.as_deref(),
154        )?);
155        let pipeline = Arc::new(crate::pipeline::Pipeline::from_manifest(&manifest)?);
156
157        let max_inflight = self.max_inflight.or_else(|| {
158            std::env::var("AI_LIB_MAX_INFLIGHT")
159                .ok()?
160                .parse::<usize>()
161                .ok()
162        });
163        let inflight = max_inflight.map(|n| Arc::new(Semaphore::new(n.max(1))));
164
165        // Optional per-attempt timeout (policy signal). Transport has its own timeout too; this is an extra guard.
166        let attempt_timeout = std::env::var("AI_LIB_ATTEMPT_TIMEOUT_MS")
167            .ok()
168            .and_then(|s| s.parse::<u64>().ok())
169            .filter(|ms| *ms > 0)
170            .map(std::time::Duration::from_millis);
171
172        let env_rps = std::env::var("AI_LIB_RPS")
173            .ok()
174            .and_then(|s| s.parse::<f64>().ok());
175        let env_rpm = std::env::var("AI_LIB_RPM")
176            .ok()
177            .and_then(|s| s.parse::<f64>().ok());
178        let env_rate_limiter = env_rps
179            .or_else(|| env_rpm.map(|rpm| rpm / 60.0))
180            .and_then(crate::resilience::rate_limiter::RateLimiterConfig::from_rps)
181            .map(|cfg| Arc::new(crate::resilience::rate_limiter::RateLimiter::new(cfg)));
182
183        // If no explicit rate limiter and manifest has rate limit headers, enable adaptive mode (rps=0)
184        let rate_limiter = self.rate_limiter.or(env_rate_limiter).or_else(|| {
185            if manifest.rate_limit_headers.is_some() {
186                crate::resilience::rate_limiter::RateLimiterConfig::from_rps(0.0)
187                    .map(|cfg| Arc::new(crate::resilience::rate_limiter::RateLimiter::new(cfg)))
188            } else {
189                None
190            }
191        });
192
193        Ok(AiClient {
194            manifest,
195            transport,
196            pipeline,
197            loader: Arc::new(loader),
198            fallbacks: self.fallbacks,
199            model_id,
200            strict_streaming,
201            feedback: self.feedback,
202            inflight,
203            max_inflight,
204            attempt_timeout,
205            breaker: self.breaker,
206            rate_limiter,
207        })
208    }
209}
210
211impl Default for AiClientBuilder {
212    fn default() -> Self {
213        Self::new()
214    }
215}