lash_core/provider/
handle.rs1use super::support::*;
2
3#[derive(Debug)]
5pub struct ProviderComponents {
6 pub provider: Box<dyn Provider>,
7 pub model_policy: Arc<dyn ProviderModelPolicy>,
8 pub failure_classifier: Arc<dyn ProviderFailureClassifier>,
9 pub rate_limiter: Arc<ProviderRateLimiter>,
10}
11
12impl ProviderComponents {
13 pub fn new(provider: Box<dyn Provider>, model_policy: Arc<dyn ProviderModelPolicy>) -> Self {
14 let options = provider.options();
15 Self {
16 provider,
17 model_policy,
18 failure_classifier: Arc::new(DefaultProviderFailureClassifier),
19 rate_limiter: Arc::new(ProviderRateLimiter::new(options.reliability.rate_limits)),
20 }
21 }
22
23 pub fn map_provider(
25 mut self,
26 map: impl FnOnce(Box<dyn Provider>) -> Box<dyn Provider>,
27 ) -> Self {
28 self.provider = map(self.provider);
29 self
30 }
31
32 pub fn with_failure_classifier(
33 mut self,
34 classifier: Arc<dyn ProviderFailureClassifier>,
35 ) -> Self {
36 self.failure_classifier = classifier;
37 self
38 }
39
40 pub fn with_clock(mut self, clock: Arc<dyn crate::Clock>) -> Self {
41 let options = self.provider.options();
42 self.rate_limiter = Arc::new(ProviderRateLimiter::with_clock(
43 options.reliability.rate_limits,
44 clock,
45 ));
46 self
47 }
48}
49
50impl Clone for ProviderComponents {
51 fn clone(&self) -> Self {
52 Self {
53 provider: self.provider.clone_boxed(),
54 model_policy: Arc::clone(&self.model_policy),
55 failure_classifier: Arc::clone(&self.failure_classifier),
56 rate_limiter: Arc::clone(&self.rate_limiter),
57 }
58 }
59}
60
61pub struct ProviderHandle {
64 components: ProviderComponents,
65}
66
67impl ProviderHandle {
68 pub fn new(components: ProviderComponents) -> Self {
69 Self { components }
70 }
71
72 pub fn components(&self) -> &ProviderComponents {
73 &self.components
74 }
75
76 pub fn components_mut(&mut self) -> &mut ProviderComponents {
77 &mut self.components
78 }
79
80 pub fn with_clock(mut self, clock: Arc<dyn crate::Clock>) -> Self {
81 self.components = self.components.with_clock(clock);
82 self
83 }
84
85 pub fn kind(&self) -> &'static str {
86 self.components.provider.kind()
87 }
88
89 pub fn supported_variants(&self, model: &str) -> &'static [&'static str] {
90 self.components.model_policy.supported_variants(model)
91 }
92
93 pub fn validate_variant(&self, model: &str, variant: &str) -> Result<(), String> {
94 let variants = self.supported_variants(model);
95 if variants.is_empty() {
96 return Err(format!(
97 "Model `{}` on {} does not expose configurable variants.",
98 model,
99 self.kind()
100 ));
101 }
102 if variants.contains(&variant) {
103 return Ok(());
104 }
105 Err(format!(
106 "Unsupported variant `{}` for `{}` on {}. Available: {}",
107 variant,
108 model,
109 self.kind(),
110 variants.join(", ")
111 ))
112 }
113
114 pub fn input_usage_excludes_cached_tokens(&self) -> bool {
115 self.components
116 .model_policy
117 .input_usage_excludes_cached_tokens()
118 }
119
120 pub fn options(&self) -> ProviderOptions {
121 self.components.provider.options()
122 }
123
124 pub fn set_options(&mut self, options: ProviderOptions) {
125 self.components
126 .rate_limiter
127 .configure(options.reliability.rate_limits.clone());
128 self.components.provider.set_options(options)
129 }
130
131 pub fn requires_streaming(&self) -> bool {
132 self.components.provider.requires_streaming()
133 }
134
135 pub async fn complete(
136 &mut self,
137 request: LlmRequest,
138 ) -> Result<LlmResponse, LlmTransportError> {
139 let reliability = self.options().reliability;
140 let attempts = reliability.retry.attempts();
141 let mut attempt = 0;
142 loop {
143 let _permit = self.components.rate_limiter.admit(&request).await;
144 let result = self.components.provider.complete(request.clone()).await;
145 match result {
146 Ok(response) => return Ok(response),
147 Err(failure) => {
148 let failure = self.components.failure_classifier.classify(failure);
149 if attempt + 1 >= attempts || !failure.retryable {
150 return Err(failure);
151 }
152 let delay = reliability
153 .retry
154 .delay_for_attempt(attempt, failure.retry_after);
155 tracing::debug!(
156 target: "lash_core::provider::reliability",
157 provider = self.kind(),
158 attempt = attempt + 1,
159 max_attempts = attempts,
160 delay_ms = delay.as_millis() as u64,
161 err = %failure.message,
162 "provider call failed with retryable failure; sleeping before retry"
163 );
164 if let Some(events) = request.stream_events.as_ref() {
165 events.send(crate::llm::types::LlmStreamEvent::RetryStatus {
166 wait_seconds: delay.as_secs(),
167 attempt: (attempt + 1) as usize,
168 max_attempts: attempts as usize,
169 reason: failure.message.clone(),
170 });
171 }
172 self.components.rate_limiter.clock().sleep(delay).await;
173 attempt += 1;
174 }
175 }
176 }
177 }
178
179 pub fn to_spec(&self) -> ProviderSpec {
180 ProviderSpec {
181 kind: self.kind().to_string(),
182 config: self.components.provider.serialize_config(),
183 }
184 }
185
186 pub fn validate_model_name(&self, model: &str) -> Result<(), String> {
188 let m = model.trim();
189 if m.is_empty() {
190 return Err("model cannot be empty".to_string());
191 }
192 if m.contains(char::is_whitespace) {
193 return Err("model cannot contain whitespace".to_string());
194 }
195 Ok(())
196 }
197}
198
199impl std::fmt::Debug for ProviderHandle {
200 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
201 self.components.fmt(f)
202 }
203}
204
205impl Clone for ProviderHandle {
206 fn clone(&self) -> Self {
207 Self {
208 components: self.components.clone(),
209 }
210 }
211}
212
213impl PartialEq for ProviderHandle {
214 fn eq(&self, other: &Self) -> bool {
215 self.kind() == other.kind() && self.to_spec().config == other.to_spec().config
216 }
217}
218
219impl Eq for ProviderHandle {}
220
221impl Default for ProviderHandle {
222 fn default() -> Self {
223 Self::new(UnconfiguredProvider::default().into_components())
224 }
225}
226
227#[derive(Clone, Debug, Default)]
233pub struct UnconfiguredProvider {
234 options: ProviderOptions,
235}
236
237impl UnconfiguredProvider {
238 fn into_components(self) -> ProviderComponents {
239 ProviderComponents::new(Box::new(self), Arc::new(StaticModelPolicy::new()))
240 }
241}
242
243#[async_trait]
244impl Provider for UnconfiguredProvider {
245 fn kind(&self) -> &'static str {
246 "unconfigured"
247 }
248
249 fn options(&self) -> ProviderOptions {
250 self.options.clone()
251 }
252
253 fn set_options(&mut self, options: ProviderOptions) {
254 self.options = options;
255 }
256
257 fn serialize_config(&self) -> serde_json::Value {
258 serde_json::Value::Object(Default::default())
259 }
260
261 async fn complete(&mut self, _request: LlmRequest) -> Result<LlmResponse, LlmTransportError> {
262 Err(LlmTransportError::new(
263 "no provider configured: host must set SessionPolicy.provider before running a turn",
264 ))
265 }
266
267 fn clone_boxed(&self) -> Box<dyn Provider> {
268 Box::new(self.clone())
269 }
270}