// lash_core/provider/handle.rs
use super::support::*;
2
/// The pluggable parts behind a [`ProviderHandle`]: the provider implementation plus the
/// policy objects consulted around each request.
#[derive(Debug)]
pub struct ProviderComponents {
    /// The provider implementation that serves requests; its `options()` supply the
    /// reliability settings (rate limits, retry policy) used by the handle.
    pub provider: Box<dyn Provider>,
    /// Reports which variants each model supports.
    pub model_policy: Arc<dyn ProviderModelPolicy>,
    /// Classifies raw provider errors (retryable flag, failure kind, retry-after hint).
    pub failure_classifier: Arc<dyn ProviderFailureClassifier>,
    /// Admits requests before they reach the provider. It is configured from the provider's
    /// `reliability.rate_limits` and held behind an `Arc`, so clones share one limiter.
    pub rate_limiter: Arc<ProviderRateLimiter>,
}
11
12impl ProviderComponents {
13 pub fn new(provider: Box<dyn Provider>, model_policy: Arc<dyn ProviderModelPolicy>) -> Self {
14 let options = provider.options();
15 Self {
16 provider,
17 model_policy,
18 failure_classifier: Arc::new(DefaultProviderFailureClassifier),
19 rate_limiter: Arc::new(ProviderRateLimiter::new(options.reliability.rate_limits)),
20 }
21 }
22
23 pub fn map_provider(
25 mut self,
26 map: impl FnOnce(Box<dyn Provider>) -> Box<dyn Provider>,
27 ) -> Self {
28 self.provider = map(self.provider);
29 self
30 }
31
32 pub fn with_failure_classifier(
33 mut self,
34 classifier: Arc<dyn ProviderFailureClassifier>,
35 ) -> Self {
36 self.failure_classifier = classifier;
37 self
38 }
39
40 pub fn with_clock(mut self, clock: Arc<dyn crate::Clock>) -> Self {
41 let options = self.provider.options();
42 self.rate_limiter = Arc::new(ProviderRateLimiter::with_clock(
43 options.reliability.rate_limits,
44 clock,
45 ));
46 self
47 }
48}
49
50impl Clone for ProviderComponents {
51 fn clone(&self) -> Self {
52 Self {
53 provider: self.provider.clone_boxed(),
54 model_policy: Arc::clone(&self.model_policy),
55 failure_classifier: Arc::clone(&self.failure_classifier),
56 rate_limiter: Arc::clone(&self.rate_limiter),
57 }
58 }
59}
60
/// Entry point for talking to a provider: wraps [`ProviderComponents`] and layers rate
/// limiting, failure classification and retries on top of the raw provider calls.
pub struct ProviderHandle {
    components: ProviderComponents,
}
66
67impl ProviderHandle {
68 pub fn new(components: ProviderComponents) -> Self {
69 Self { components }
70 }
71
72 pub fn unconfigured() -> Self {
73 Self::new(UnconfiguredProvider::default().into_components())
74 }
75
76 pub fn components(&self) -> &ProviderComponents {
77 &self.components
78 }
79
80 pub fn components_mut(&mut self) -> &mut ProviderComponents {
81 &mut self.components
82 }
83
84 pub fn with_clock(mut self, clock: Arc<dyn crate::Clock>) -> Self {
85 self.components = self.components.with_clock(clock);
86 self
87 }
88
89 pub fn kind(&self) -> &'static str {
90 self.components.provider.kind()
91 }
92
93 pub fn supported_variants(&self, model: &str) -> &'static [&'static str] {
94 self.components.model_policy.supported_variants(model)
95 }
96
97 pub fn validate_variant(&self, model: &str, variant: &str) -> Result<(), String> {
98 let variants = self.supported_variants(model);
99 if variants.is_empty() {
100 return Err(format!(
101 "Model `{}` on {} does not expose configurable variants.",
102 model,
103 self.kind()
104 ));
105 }
106 if variants.contains(&variant) {
107 return Ok(());
108 }
109 Err(format!(
110 "Unsupported variant `{}` for `{}` on {}. Available: {}",
111 variant,
112 model,
113 self.kind(),
114 variants.join(", ")
115 ))
116 }
117
118 pub fn options(&self) -> ProviderOptions {
119 self.components.provider.options()
120 }
121
122 pub fn set_options(&mut self, options: ProviderOptions) {
123 self.components
124 .rate_limiter
125 .configure(options.reliability.rate_limits.clone());
126 self.components.provider.set_options(options)
127 }
128
129 pub fn requires_streaming(&self) -> bool {
130 self.components.provider.requires_streaming()
131 }
132
133 pub async fn complete(
134 &mut self,
135 request: LlmRequest,
136 ) -> Result<LlmResponse, LlmTransportError> {
137 let reliability = self.options().reliability;
138 let attempts = reliability.retry.attempts();
139 let mut attempt = 0;
140 let throttle_budget = Duration::from_millis(reliability.retry.throttle_wait_budget_ms);
143 let mut throttle_waited = Duration::ZERO;
144 loop {
145 let _permit = self.components.rate_limiter.admit(&request).await;
146 let result = self.components.provider.complete(request.clone()).await;
147 match result {
148 Ok(response) => return Ok(response),
149 Err(failure) => {
150 let failure = self.components.failure_classifier.classify(failure);
151 if failure.retryable
163 && failure.kind == ProviderFailureKind::Quota
164 && let Some(retry_after) = failure.retry_after
165 {
166 let wait = reliability.retry.cap_retry_after(retry_after);
167 let charge = wait.max(MIN_THROTTLE_BUDGET_CHARGE);
168 if throttle_waited.saturating_add(charge) <= throttle_budget {
171 throttle_waited += charge;
172 tracing::debug!(
173 target: "lash_core::provider::reliability",
174 provider = self.kind(),
175 attempt = attempt + 1,
176 max_attempts = attempts,
177 wait_ms = wait.as_millis() as u64,
178 throttle_waited_ms = throttle_waited.as_millis() as u64,
179 err = %failure.message,
180 "provider throttled with retry-after; waiting without consuming a retry attempt"
181 );
182 if let Some(events) = request.stream_events.as_ref() {
183 events.send(crate::llm::types::LlmStreamEvent::RetryStatus {
184 wait_seconds: wait.as_secs(),
185 attempt: (attempt + 1) as usize,
186 max_attempts: attempts as usize,
187 reason: failure.message.clone(),
188 });
189 }
190 self.components.rate_limiter.clock().sleep(wait).await;
191 continue;
192 }
193 }
194 if attempt + 1 >= attempts || !failure.retryable {
195 return Err(failure);
196 }
197 let delay = reliability
198 .retry
199 .delay_for_attempt(attempt, failure.retry_after);
200 tracing::debug!(
201 target: "lash_core::provider::reliability",
202 provider = self.kind(),
203 attempt = attempt + 1,
204 max_attempts = attempts,
205 delay_ms = delay.as_millis() as u64,
206 err = %failure.message,
207 "provider call failed with retryable failure; sleeping before retry"
208 );
209 if let Some(events) = request.stream_events.as_ref() {
210 events.send(crate::llm::types::LlmStreamEvent::RetryStatus {
211 wait_seconds: delay.as_secs(),
212 attempt: (attempt + 1) as usize,
213 max_attempts: attempts as usize,
214 reason: failure.message.clone(),
215 });
216 }
217 self.components.rate_limiter.clock().sleep(delay).await;
218 attempt += 1;
219 }
220 }
221 }
222 }
223
224 pub async fn close(&self) -> Result<(), LlmTransportError> {
232 self.components.provider.close().await
233 }
234
235 pub fn to_spec(&self) -> ProviderSpec {
236 ProviderSpec {
237 kind: self.kind().to_string(),
238 config: self.components.provider.serialize_config(),
239 }
240 }
241
242 pub fn validate_model_name(&self, model: &str) -> Result<(), String> {
244 let m = model.trim();
245 if m.is_empty() {
246 return Err("model cannot be empty".to_string());
247 }
248 if m.contains(char::is_whitespace) {
249 return Err("model cannot contain whitespace".to_string());
250 }
251 Ok(())
252 }
253}
254
255impl std::fmt::Debug for ProviderHandle {
256 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
257 self.components.fmt(f)
258 }
259}
260
261impl Clone for ProviderHandle {
262 fn clone(&self) -> Self {
263 Self {
264 components: self.components.clone(),
265 }
266 }
267}
268
269impl PartialEq for ProviderHandle {
270 fn eq(&self, other: &Self) -> bool {
271 self.kind() == other.kind() && self.to_spec().config == other.to_spec().config
272 }
273}
274
275impl Eq for ProviderHandle {}
276
/// Placeholder provider used by [`ProviderHandle::unconfigured`]. It stores options but
/// fails every `complete` call with a "no provider configured" error.
#[derive(Clone, Debug, Default)]
pub struct UnconfiguredProvider {
    // Options last applied through `set_options` (the `Default` value until then).
    options: ProviderOptions,
}
284
285impl UnconfiguredProvider {
286 fn into_components(self) -> ProviderComponents {
287 ProviderComponents::new(Box::new(self), Arc::new(StaticModelPolicy::new()))
288 }
289}
290
291#[async_trait]
292impl Provider for UnconfiguredProvider {
293 fn kind(&self) -> &'static str {
294 "unconfigured"
295 }
296
297 fn options(&self) -> ProviderOptions {
298 self.options.clone()
299 }
300
301 fn set_options(&mut self, options: ProviderOptions) {
302 self.options = options;
303 }
304
305 fn serialize_config(&self) -> serde_json::Value {
306 serde_json::Value::Object(Default::default())
307 }
308
309 async fn complete(&mut self, _request: LlmRequest) -> Result<LlmResponse, LlmTransportError> {
310 Err(LlmTransportError::new(
311 "no provider configured: host must set SessionPolicy.provider before running a turn",
312 ))
313 }
314
315 fn clone_boxed(&self) -> Box<dyn Provider> {
316 Box::new(self.clone())
317 }
318}