// ai_lib_rust/client/builder.rs
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;

use tokio::sync::Semaphore;

use crate::client::core::AiClient;
use crate::feedback::FeedbackSink;
use crate::protocol::ProtocolLoader;
use crate::Result;

8pub struct AiClientBuilder {
12 protocol_path: Option<String>,
13 hot_reload: bool,
14 fallbacks: Vec<String>,
15 strict_streaming: bool,
16 feedback: Arc<dyn FeedbackSink>,
17 max_inflight: Option<usize>,
18 breaker: Option<Arc<crate::resilience::circuit_breaker::CircuitBreaker>>,
19 rate_limiter: Option<Arc<crate::resilience::rate_limiter::RateLimiter>>,
20 base_url_override: Option<String>,
22}
23
24impl AiClientBuilder {
25 pub fn new() -> Self {
26 Self {
27 protocol_path: None,
28 hot_reload: false,
29 fallbacks: Vec::new(),
30 strict_streaming: false,
31 feedback: crate::feedback::noop_sink(),
32 max_inflight: None,
33 breaker: None,
34 rate_limiter: None,
35 base_url_override: None,
36 }
37 }
38
39 pub fn protocol_path(mut self, path: String) -> Self {
41 self.protocol_path = Some(path);
42 self
43 }
44
45 pub fn hot_reload(mut self, enable: bool) -> Self {
47 self.hot_reload = enable;
48 self
49 }
50
51 pub fn with_fallbacks(mut self, fallbacks: Vec<String>) -> Self {
53 self.fallbacks = fallbacks;
54 self
55 }
56
57 pub fn strict_streaming(mut self, enable: bool) -> Self {
61 self.strict_streaming = enable;
62 self
63 }
64
65 pub fn feedback_sink(mut self, sink: Arc<dyn FeedbackSink>) -> Self {
67 self.feedback = sink;
68 self
69 }
70
71 pub fn max_inflight(mut self, n: usize) -> Self {
74 self.max_inflight = Some(n.max(1));
75 self
76 }
77
78 pub fn circuit_breaker_default(mut self) -> Self {
84 let threshold = std::env::var("AI_LIB_BREAKER_FAILURE_THRESHOLD")
85 .ok()
86 .and_then(|s| s.parse::<u32>().ok())
87 .unwrap_or(5);
88 let cooldown_secs = std::env::var("AI_LIB_BREAKER_COOLDOWN_SECS")
89 .ok()
90 .and_then(|s| s.parse::<u64>().ok())
91 .unwrap_or(30);
92 let cfg = crate::resilience::circuit_breaker::CircuitBreakerConfig {
93 failure_threshold: threshold.max(1),
94 cooldown: std::time::Duration::from_secs(cooldown_secs.max(1)),
95 };
96 self.breaker = Some(Arc::new(
97 crate::resilience::circuit_breaker::CircuitBreaker::new(cfg),
98 ));
99 self
100 }
101
102 pub fn rate_limit_rps(mut self, rps: f64) -> Self {
108 if let Some(cfg) = crate::resilience::rate_limiter::RateLimiterConfig::from_rps(rps) {
109 self.rate_limiter = Some(Arc::new(crate::resilience::rate_limiter::RateLimiter::new(
110 cfg,
111 )));
112 }
113 self
114 }
115
116 pub fn base_url_override(mut self, base_url: impl Into<String>) -> Self {
121 self.base_url_override = Some(base_url.into());
122 self
123 }
124
125 pub async fn build(self, model: &str) -> Result<AiClient> {
127 let mut loader = ProtocolLoader::new();
128
129 if let Some(path) = self.protocol_path {
130 loader = loader.with_base_path(path);
131 }
132
133 if self.hot_reload {
134 loader = loader.with_hot_reload(true);
135 }
136
137 let parts: Vec<&str> = model.split('/').collect();
139 let model_id = if parts.len() >= 2 {
140 parts[1..].join("/")
141 } else {
142 model.to_string()
143 };
144
145 let manifest = loader.load_model(model).await?;
146 let strict_streaming = self.strict_streaming
147 || std::env::var("AI_LIB_STRICT_STREAMING").ok().as_deref() == Some("1");
148 crate::client::validation::validate_manifest(&manifest, strict_streaming)?;
149
150 let transport = Arc::new(crate::transport::HttpTransport::new_with_base_url(
151 &manifest,
152 &model_id,
153 self.base_url_override.as_deref(),
154 )?);
155 let pipeline = Arc::new(crate::pipeline::Pipeline::from_manifest(&manifest)?);
156
157 let max_inflight = self.max_inflight.or_else(|| {
158 std::env::var("AI_LIB_MAX_INFLIGHT")
159 .ok()?
160 .parse::<usize>()
161 .ok()
162 });
163 let inflight = max_inflight.map(|n| Arc::new(Semaphore::new(n.max(1))));
164
165 let attempt_timeout = std::env::var("AI_LIB_ATTEMPT_TIMEOUT_MS")
167 .ok()
168 .and_then(|s| s.parse::<u64>().ok())
169 .filter(|ms| *ms > 0)
170 .map(std::time::Duration::from_millis);
171
172 let env_rps = std::env::var("AI_LIB_RPS")
173 .ok()
174 .and_then(|s| s.parse::<f64>().ok());
175 let env_rpm = std::env::var("AI_LIB_RPM")
176 .ok()
177 .and_then(|s| s.parse::<f64>().ok());
178 let env_rate_limiter = env_rps
179 .or_else(|| env_rpm.map(|rpm| rpm / 60.0))
180 .and_then(crate::resilience::rate_limiter::RateLimiterConfig::from_rps)
181 .map(|cfg| Arc::new(crate::resilience::rate_limiter::RateLimiter::new(cfg)));
182
183 let rate_limiter = self.rate_limiter.or(env_rate_limiter).or_else(|| {
185 if manifest.rate_limit_headers.is_some() {
186 crate::resilience::rate_limiter::RateLimiterConfig::from_rps(0.0)
187 .map(|cfg| Arc::new(crate::resilience::rate_limiter::RateLimiter::new(cfg)))
188 } else {
189 None
190 }
191 });
192
193 Ok(AiClient {
194 manifest,
195 transport,
196 pipeline,
197 loader: Arc::new(loader),
198 fallbacks: self.fallbacks,
199 model_id,
200 strict_streaming,
201 feedback: self.feedback,
202 inflight,
203 max_inflight,
204 attempt_timeout,
205 breaker: self.breaker,
206 rate_limiter,
207 })
208 }
209}
210
211impl Default for AiClientBuilder {
212 fn default() -> Self {
213 Self::new()
214 }
215}