// ai_lib_rust/client/builder.rs
use crate::client::core::AiClient;
use crate::protocol::ProtocolLoader;
use crate::Result;
use std::sync::Arc;
use tokio::sync::Semaphore;

7pub struct AiClientBuilder {
11 protocol_path: Option<String>,
12 hot_reload: bool,
13 fallbacks: Vec<String>,
14 strict_streaming: bool,
15 feedback: Arc<dyn crate::telemetry::FeedbackSink>,
16 max_inflight: Option<usize>,
17 breaker: Option<Arc<crate::resilience::circuit_breaker::CircuitBreaker>>,
18 rate_limiter: Option<Arc<crate::resilience::rate_limiter::RateLimiter>>,
19 base_url_override: Option<String>,
21}
22
23impl AiClientBuilder {
24 pub fn new() -> Self {
25 Self {
26 protocol_path: None,
27 hot_reload: false,
28 fallbacks: Vec::new(),
29 strict_streaming: false,
30 feedback: crate::telemetry::noop_sink(),
31 max_inflight: None,
32 breaker: None,
33 rate_limiter: None,
34 base_url_override: None,
35 }
36 }
37
38 pub fn protocol_path(mut self, path: String) -> Self {
40 self.protocol_path = Some(path);
41 self
42 }
43
44 pub fn hot_reload(mut self, enable: bool) -> Self {
46 self.hot_reload = enable;
47 self
48 }
49
50 pub fn with_fallbacks(mut self, fallbacks: Vec<String>) -> Self {
52 self.fallbacks = fallbacks;
53 self
54 }
55
56 pub fn strict_streaming(mut self, enable: bool) -> Self {
60 self.strict_streaming = enable;
61 self
62 }
63
64 pub fn feedback_sink(mut self, sink: Arc<dyn crate::telemetry::FeedbackSink>) -> Self {
66 self.feedback = sink;
67 self
68 }
69
70 pub fn max_inflight(mut self, n: usize) -> Self {
73 self.max_inflight = Some(n.max(1));
74 self
75 }
76
77 pub fn circuit_breaker_default(mut self) -> Self {
83 let threshold = std::env::var("AI_LIB_BREAKER_FAILURE_THRESHOLD")
84 .ok()
85 .and_then(|s| s.parse::<u32>().ok())
86 .unwrap_or(5);
87 let cooldown_secs = std::env::var("AI_LIB_BREAKER_COOLDOWN_SECS")
88 .ok()
89 .and_then(|s| s.parse::<u64>().ok())
90 .unwrap_or(30);
91 let cfg = crate::resilience::circuit_breaker::CircuitBreakerConfig {
92 failure_threshold: threshold.max(1),
93 cooldown: std::time::Duration::from_secs(cooldown_secs.max(1)),
94 };
95 self.breaker = Some(Arc::new(
96 crate::resilience::circuit_breaker::CircuitBreaker::new(cfg),
97 ));
98 self
99 }
100
101 pub fn rate_limit_rps(mut self, rps: f64) -> Self {
107 if let Some(cfg) = crate::resilience::rate_limiter::RateLimiterConfig::from_rps(rps) {
108 self.rate_limiter = Some(Arc::new(crate::resilience::rate_limiter::RateLimiter::new(
109 cfg,
110 )));
111 }
112 self
113 }
114
115 pub fn base_url_override(mut self, base_url: impl Into<String>) -> Self {
120 self.base_url_override = Some(base_url.into());
121 self
122 }
123
124 pub async fn build(self, model: &str) -> Result<AiClient> {
126 let mut loader = ProtocolLoader::new();
127
128 if let Some(path) = self.protocol_path {
129 loader = loader.with_base_path(path);
130 }
131
132 if self.hot_reload {
133 loader = loader.with_hot_reload(true);
134 }
135
136 let parts: Vec<&str> = model.split('/').collect();
138 let model_id = parts
139 .get(1)
140 .map(|s| s.to_string())
141 .unwrap_or_else(|| model.to_string());
142
143 let manifest = loader.load_model(model).await?;
144 let strict_streaming = self.strict_streaming
145 || std::env::var("AI_LIB_STRICT_STREAMING").ok().as_deref() == Some("1");
146 crate::client::validation::validate_manifest(&manifest, strict_streaming)?;
147
148 let transport = Arc::new(crate::transport::HttpTransport::new_with_base_url(
149 &manifest,
150 &model_id,
151 self.base_url_override.as_deref(),
152 )?);
153 let pipeline = Arc::new(crate::pipeline::Pipeline::from_manifest(&manifest)?);
154
155 let max_inflight = self.max_inflight.or_else(|| {
156 std::env::var("AI_LIB_MAX_INFLIGHT")
157 .ok()?
158 .parse::<usize>()
159 .ok()
160 });
161 let inflight = max_inflight.map(|n| Arc::new(Semaphore::new(n.max(1))));
162
163 let attempt_timeout = std::env::var("AI_LIB_ATTEMPT_TIMEOUT_MS")
165 .ok()
166 .and_then(|s| s.parse::<u64>().ok())
167 .filter(|ms| *ms > 0)
168 .map(std::time::Duration::from_millis);
169
170 let env_rps = std::env::var("AI_LIB_RPS")
171 .ok()
172 .and_then(|s| s.parse::<f64>().ok());
173 let env_rpm = std::env::var("AI_LIB_RPM")
174 .ok()
175 .and_then(|s| s.parse::<f64>().ok());
176 let env_rate_limiter = env_rps
177 .or_else(|| env_rpm.map(|rpm| rpm / 60.0))
178 .and_then(crate::resilience::rate_limiter::RateLimiterConfig::from_rps)
179 .map(|cfg| Arc::new(crate::resilience::rate_limiter::RateLimiter::new(cfg)));
180
181 let rate_limiter = self.rate_limiter.or(env_rate_limiter).or_else(|| {
183 if manifest.rate_limit_headers.is_some() {
184 crate::resilience::rate_limiter::RateLimiterConfig::from_rps(0.0)
185 .map(|cfg| Arc::new(crate::resilience::rate_limiter::RateLimiter::new(cfg)))
186 } else {
187 None
188 }
189 });
190
191 Ok(AiClient {
192 manifest,
193 transport,
194 pipeline,
195 loader: Arc::new(loader),
196 fallbacks: self.fallbacks,
197 model_id,
198 strict_streaming,
199 feedback: self.feedback,
200 inflight,
201 max_inflight,
202 attempt_timeout,
203 breaker: self.breaker,
204 rate_limiter,
205 })
206 }
207}
208
209impl Default for AiClientBuilder {
210 fn default() -> Self {
211 Self::new()
212 }
213}