lash_core/provider/
handle.rs1use super::support::*;
2
3#[derive(Debug)]
5pub struct ProviderComponents {
6 pub provider: Box<dyn Provider>,
7 pub model_policy: Arc<dyn ProviderModelPolicy>,
8 pub failure_classifier: Arc<dyn ProviderFailureClassifier>,
9 pub rate_limiter: Arc<ProviderRateLimiter>,
10}
11
12impl ProviderComponents {
13 pub fn new(provider: Box<dyn Provider>, model_policy: Arc<dyn ProviderModelPolicy>) -> Self {
14 let options = provider.options();
15 Self {
16 provider,
17 model_policy,
18 failure_classifier: Arc::new(DefaultProviderFailureClassifier),
19 rate_limiter: Arc::new(ProviderRateLimiter::new(options.reliability.rate_limits)),
20 }
21 }
22
23 pub fn map_provider(
25 mut self,
26 map: impl FnOnce(Box<dyn Provider>) -> Box<dyn Provider>,
27 ) -> Self {
28 self.provider = map(self.provider);
29 self
30 }
31
32 pub fn with_failure_classifier(
33 mut self,
34 classifier: Arc<dyn ProviderFailureClassifier>,
35 ) -> Self {
36 self.failure_classifier = classifier;
37 self
38 }
39
40 pub fn with_clock(mut self, clock: Arc<dyn crate::Clock>) -> Self {
41 let options = self.provider.options();
42 self.rate_limiter = Arc::new(ProviderRateLimiter::with_clock(
43 options.reliability.rate_limits,
44 clock,
45 ));
46 self
47 }
48}
49
50impl Clone for ProviderComponents {
51 fn clone(&self) -> Self {
52 Self {
53 provider: self.provider.clone_boxed(),
54 model_policy: Arc::clone(&self.model_policy),
55 failure_classifier: Arc::clone(&self.failure_classifier),
56 rate_limiter: Arc::clone(&self.rate_limiter),
57 }
58 }
59}
60
61pub struct ProviderHandle {
64 components: ProviderComponents,
65}
66
67impl ProviderHandle {
68 pub fn new(components: ProviderComponents) -> Self {
69 Self { components }
70 }
71
72 pub fn unconfigured() -> Self {
73 Self::new(UnconfiguredProvider::default().into_components())
74 }
75
76 pub fn components(&self) -> &ProviderComponents {
77 &self.components
78 }
79
80 pub fn components_mut(&mut self) -> &mut ProviderComponents {
81 &mut self.components
82 }
83
84 pub fn with_clock(mut self, clock: Arc<dyn crate::Clock>) -> Self {
85 self.components = self.components.with_clock(clock);
86 self
87 }
88
89 pub fn kind(&self) -> &'static str {
90 self.components.provider.kind()
91 }
92
93 pub fn supported_variants(&self, model: &str) -> &'static [&'static str] {
94 self.components.model_policy.supported_variants(model)
95 }
96
97 pub fn validate_variant(&self, model: &str, variant: &str) -> Result<(), String> {
98 let variants = self.supported_variants(model);
99 if variants.is_empty() {
100 return Err(format!(
101 "Model `{}` on {} does not expose configurable variants.",
102 model,
103 self.kind()
104 ));
105 }
106 if variants.contains(&variant) {
107 return Ok(());
108 }
109 Err(format!(
110 "Unsupported variant `{}` for `{}` on {}. Available: {}",
111 variant,
112 model,
113 self.kind(),
114 variants.join(", ")
115 ))
116 }
117
118 pub fn input_usage_excludes_cached_tokens(&self) -> bool {
119 self.components
120 .model_policy
121 .input_usage_excludes_cached_tokens()
122 }
123
124 pub fn options(&self) -> ProviderOptions {
125 self.components.provider.options()
126 }
127
128 pub fn set_options(&mut self, options: ProviderOptions) {
129 self.components
130 .rate_limiter
131 .configure(options.reliability.rate_limits.clone());
132 self.components.provider.set_options(options)
133 }
134
135 pub fn requires_streaming(&self) -> bool {
136 self.components.provider.requires_streaming()
137 }
138
139 pub async fn complete(
140 &mut self,
141 request: LlmRequest,
142 ) -> Result<LlmResponse, LlmTransportError> {
143 let reliability = self.options().reliability;
144 let attempts = reliability.retry.attempts();
145 let mut attempt = 0;
146 loop {
147 let _permit = self.components.rate_limiter.admit(&request).await;
148 let result = self.components.provider.complete(request.clone()).await;
149 match result {
150 Ok(response) => return Ok(response),
151 Err(failure) => {
152 let failure = self.components.failure_classifier.classify(failure);
153 if attempt + 1 >= attempts || !failure.retryable {
154 return Err(failure);
155 }
156 let delay = reliability
157 .retry
158 .delay_for_attempt(attempt, failure.retry_after);
159 tracing::debug!(
160 target: "lash_core::provider::reliability",
161 provider = self.kind(),
162 attempt = attempt + 1,
163 max_attempts = attempts,
164 delay_ms = delay.as_millis() as u64,
165 err = %failure.message,
166 "provider call failed with retryable failure; sleeping before retry"
167 );
168 if let Some(events) = request.stream_events.as_ref() {
169 events.send(crate::llm::types::LlmStreamEvent::RetryStatus {
170 wait_seconds: delay.as_secs(),
171 attempt: (attempt + 1) as usize,
172 max_attempts: attempts as usize,
173 reason: failure.message.clone(),
174 });
175 }
176 self.components.rate_limiter.clock().sleep(delay).await;
177 attempt += 1;
178 }
179 }
180 }
181 }
182
183 pub fn to_spec(&self) -> ProviderSpec {
184 ProviderSpec {
185 kind: self.kind().to_string(),
186 config: self.components.provider.serialize_config(),
187 }
188 }
189
190 pub fn validate_model_name(&self, model: &str) -> Result<(), String> {
192 let m = model.trim();
193 if m.is_empty() {
194 return Err("model cannot be empty".to_string());
195 }
196 if m.contains(char::is_whitespace) {
197 return Err("model cannot contain whitespace".to_string());
198 }
199 Ok(())
200 }
201}
202
203impl std::fmt::Debug for ProviderHandle {
204 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205 self.components.fmt(f)
206 }
207}
208
209impl Clone for ProviderHandle {
210 fn clone(&self) -> Self {
211 Self {
212 components: self.components.clone(),
213 }
214 }
215}
216
217impl PartialEq for ProviderHandle {
218 fn eq(&self, other: &Self) -> bool {
219 self.kind() == other.kind() && self.to_spec().config == other.to_spec().config
220 }
221}
222
223impl Eq for ProviderHandle {}
224
225#[derive(Clone, Debug, Default)]
229pub struct UnconfiguredProvider {
230 options: ProviderOptions,
231}
232
233impl UnconfiguredProvider {
234 fn into_components(self) -> ProviderComponents {
235 ProviderComponents::new(Box::new(self), Arc::new(StaticModelPolicy::new()))
236 }
237}
238
239#[async_trait]
240impl Provider for UnconfiguredProvider {
241 fn kind(&self) -> &'static str {
242 "unconfigured"
243 }
244
245 fn options(&self) -> ProviderOptions {
246 self.options.clone()
247 }
248
249 fn set_options(&mut self, options: ProviderOptions) {
250 self.options = options;
251 }
252
253 fn serialize_config(&self) -> serde_json::Value {
254 serde_json::Value::Object(Default::default())
255 }
256
257 async fn complete(&mut self, _request: LlmRequest) -> Result<LlmResponse, LlmTransportError> {
258 Err(LlmTransportError::new(
259 "no provider configured: host must set SessionPolicy.provider before running a turn",
260 ))
261 }
262
263 fn clone_boxed(&self) -> Box<dyn Provider> {
264 Box::new(self.clone())
265 }
266}