use super::support::*;
/// The parts a [`ProviderHandle`] is built from: the provider itself plus the model policy,
/// failure classifier and rate limiter that are applied around its calls.
#[derive(Debug)]
pub struct ProviderComponents {
/// The provider implementation that actually performs completions.
pub provider: Box<dyn Provider>,
/// Per-model policy: supported variants and cached-token usage accounting.
pub model_policy: Arc<dyn ProviderModelPolicy>,
/// Classifies provider errors; the classified error's `retryable` / `retry_after` fields
/// drive the retry loop in `ProviderHandle::complete`.
pub failure_classifier: Arc<dyn ProviderFailureClassifier>,
/// Admission control awaited before every provider call. Held in an `Arc`, so clones of
/// these components share the same limiter.
pub rate_limiter: Arc<ProviderRateLimiter>,
}
impl ProviderComponents {
    /// Assembles components for `provider` using `model_policy`.
    ///
    /// The failure classifier starts out as [`DefaultProviderFailureClassifier`] (replace it
    /// with [`Self::with_failure_classifier`]). A fresh rate limiter is built from the
    /// provider's current `reliability.rate_limits` option.
    pub fn new(provider: Box<dyn Provider>, model_policy: Arc<dyn ProviderModelPolicy>) -> Self {
        let rate_limits = provider.options().reliability.rate_limits;
        let rate_limiter = Arc::new(ProviderRateLimiter::new(rate_limits));
        let failure_classifier: Arc<dyn ProviderFailureClassifier> =
            Arc::new(DefaultProviderFailureClassifier);
        Self {
            provider,
            model_policy,
            failure_classifier,
            rate_limiter,
        }
    }

    /// Replaces the provider with `map(provider)`, e.g. to wrap it in a decorator.
    ///
    /// Every other component is kept as-is. In particular, the rate limiter is not rebuilt
    /// from the mapped provider's options.
    pub fn map_provider(
        self,
        map: impl FnOnce(Box<dyn Provider>) -> Box<dyn Provider>,
    ) -> Self {
        Self {
            provider: map(self.provider),
            ..self
        }
    }

    /// Swaps in a custom failure classifier, keeping all other components unchanged.
    pub fn with_failure_classifier(
        self,
        classifier: Arc<dyn ProviderFailureClassifier>,
    ) -> Self {
        Self {
            failure_classifier: classifier,
            ..self
        }
    }
}
impl Clone for ProviderComponents {
    /// Duplicates the provider through `clone_boxed`. The model policy, failure classifier
    /// and rate limiter are `Arc`s, so the clone shares them with `self`.
    fn clone(&self) -> Self {
        let Self {
            provider,
            model_policy,
            failure_classifier,
            rate_limiter,
        } = self;
        Self {
            provider: provider.clone_boxed(),
            model_policy: Arc::clone(model_policy),
            failure_classifier: Arc::clone(failure_classifier),
            rate_limiter: Arc::clone(rate_limiter),
        }
    }
}
/// Entry point for talking to a provider. It wraps [`ProviderComponents`] and adds variant
/// validation, rate-limit admission and retry handling around completions.
pub struct ProviderHandle {
/// The provider plus its policy, classifier and rate limiter.
components: ProviderComponents,
}
impl ProviderHandle {
    /// Wraps an already-assembled [`ProviderComponents`] bundle.
    pub fn new(components: ProviderComponents) -> Self {
        Self { components }
    }

    /// Shared access to the wrapped components.
    pub fn components(&self) -> &ProviderComponents {
        &self.components
    }

    /// Mutable access to the wrapped components.
    ///
    /// NOTE(review): changing the provider through this accessor does not reconfigure the
    /// rate limiter. [`Self::set_options`] is the path that does.
    pub fn components_mut(&mut self) -> &mut ProviderComponents {
        &mut self.components
    }

    /// Static identifier of the underlying provider (e.g. `"unconfigured"` for the placeholder).
    pub fn kind(&self) -> &'static str {
        self.components.provider.kind()
    }

    /// Variants the model policy advertises for `model`.
    ///
    /// An empty slice means the model exposes no configurable variants.
    pub fn supported_variants(&self, model: &str) -> &'static [&'static str] {
        self.components.model_policy.supported_variants(model)
    }

    /// Checks that `variant` is one of the variants supported for `model`.
    ///
    /// # Errors
    /// Returns a human-readable message in two cases:
    /// - the model exposes no variants at all;
    /// - `variant` is not in the supported list (the message lists the available ones).
    pub fn validate_variant(&self, model: &str, variant: &str) -> Result<(), String> {
        let variants = self.supported_variants(model);
        if variants.is_empty() {
            return Err(format!(
                "Model `{}` on {} does not expose configurable variants.",
                model,
                self.kind()
            ));
        }
        if variants.contains(&variant) {
            return Ok(());
        }
        Err(format!(
            "Unsupported variant `{}` for `{}` on {}. Available: {}",
            variant,
            model,
            self.kind(),
            variants.join(", ")
        ))
    }

    /// Whether the model policy reports input-token usage as excluding cached tokens.
    pub fn input_usage_excludes_cached_tokens(&self) -> bool {
        self.components
            .model_policy
            .input_usage_excludes_cached_tokens()
    }

    /// The provider's current options (returned by value).
    pub fn options(&self) -> ProviderOptions {
        self.components.provider.options()
    }

    /// Applies `options` to the provider and reconfigures the rate limiter from
    /// `options.reliability.rate_limits`.
    ///
    /// The rate limiter is an `Arc` shared with clones of this handle, so the new limits
    /// apply to those clones as well. Each clone's provider options remain separate.
    pub fn set_options(&mut self, options: ProviderOptions) {
        self.components
            .rate_limiter
            .configure(options.reliability.rate_limits.clone());
        self.components.provider.set_options(options)
    }

    /// Forwards to the provider's `requires_streaming`.
    pub fn requires_streaming(&self) -> bool {
        self.components.provider.requires_streaming()
    }

    /// Sends `request` to the provider and retries retryable failures with a delay.
    ///
    /// Each attempt first awaits admission from the rate limiter. That permit covers only
    /// the provider call and is released before any backoff sleep. Errors go through the
    /// failure classifier. A classified failure is returned immediately when it is not
    /// retryable or when the retry policy's attempt budget is exhausted.
    ///
    /// Before each retry the method does two things:
    /// - sends a `RetryStatus` event on `request.stream_events` (when present);
    /// - sleeps for the delay computed by `delay_for_attempt` from the attempt index and the
    ///   failure's `retry_after` hint.
    ///
    /// # Errors
    /// Returns the classified failure from the last attempt made.
    pub async fn complete(
        &mut self,
        request: LlmRequest,
    ) -> Result<LlmResponse, LlmTransportError> {
        let reliability = self.options().reliability;
        let attempts = reliability.retry.attempts();
        let mut attempt = 0;
        loop {
            // Scope the permit to the provider call itself. Otherwise it would stay alive
            // through the backoff sleep below and keep an admission slot occupied while no
            // request is in flight.
            let result = {
                let _permit = self.components.rate_limiter.admit(&request).await;
                self.components.provider.complete(request.clone()).await
            };
            let failure = match result {
                Ok(response) => return Ok(response),
                Err(failure) => self.components.failure_classifier.classify(failure),
            };
            // Surface the error as-is on the final attempt or when retrying cannot help.
            if attempt + 1 >= attempts || !failure.retryable {
                return Err(failure);
            }
            let delay = reliability
                .retry
                .delay_for_attempt(attempt, failure.retry_after);
            tracing::debug!(
                target: "lash_core::provider::reliability",
                provider = self.kind(),
                attempt = attempt + 1,
                max_attempts = attempts,
                delay_ms = delay.as_millis() as u64,
                err = %failure.message,
                "provider call failed with retryable failure; sleeping before retry"
            );
            if let Some(events) = request.stream_events.as_ref() {
                events.send(crate::llm::types::LlmStreamEvent::RetryStatus {
                    // NOTE(review): `as_secs` truncates, so sub-second delays report 0 here.
                    wait_seconds: delay.as_secs(),
                    attempt: (attempt + 1) as usize,
                    max_attempts: attempts as usize,
                    reason: failure.message.clone(),
                });
            }
            tokio::time::sleep(delay).await;
            attempt += 1;
        }
    }

    /// A serializable description of this provider: its kind plus `serialize_config()` output.
    pub fn to_spec(&self) -> ProviderSpec {
        ProviderSpec {
            kind: self.kind().to_string(),
            config: self.components.provider.serialize_config(),
        }
    }

    /// Rejects model names that are empty after trimming or that contain interior whitespace.
    ///
    /// This is a purely syntactic check; it does not consult the provider or the model policy.
    ///
    /// # Errors
    /// Returns a short message describing which rule was violated.
    pub fn validate_model_name(&self, model: &str) -> Result<(), String> {
        let m = model.trim();
        if m.is_empty() {
            return Err("model cannot be empty".to_string());
        }
        if m.contains(char::is_whitespace) {
            return Err("model cannot contain whitespace".to_string());
        }
        Ok(())
    }
}
impl std::fmt::Debug for ProviderHandle {
    /// Formats exactly like the wrapped [`ProviderComponents`], with no `ProviderHandle`
    /// wrapper. The formatter (and its flags, such as `{:#?}`) is passed through unchanged.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        std::fmt::Debug::fmt(&self.components, f)
    }
}
impl Clone for ProviderHandle {
    /// Clones the wrapped components. The provider is duplicated via `clone_boxed`, while
    /// the `Arc`-held policy, classifier and rate limiter are shared with `self`.
    fn clone(&self) -> Self {
        let components = self.components.clone();
        Self::new(components)
    }
}
impl PartialEq for ProviderHandle {
    /// Handles are equal when their provider kinds match and their `serialize_config()`
    /// outputs are equal.
    ///
    /// The model policy, failure classifier and rate limiter are not compared. Provider
    /// options only matter insofar as `serialize_config` includes them.
    fn eq(&self, other: &Self) -> bool {
        if self.kind() != other.kind() {
            return false;
        }
        let ours = self.components.provider.serialize_config();
        let theirs = other.components.provider.serialize_config();
        ours == theirs
    }
}
// Marker impl: the equality above is assumed reflexive, i.e. a serialized config always
// compares equal to itself.
impl Eq for ProviderHandle {}
impl Default for ProviderHandle {
fn default() -> Self {
Self::new(UnconfiguredProvider::default().into_components())
}
}
/// Placeholder provider used by `ProviderHandle::default()`. It stores options, but every
/// completion fails with a "no provider configured" error.
#[derive(Clone, Debug, Default)]
pub struct UnconfiguredProvider {
/// Options as last set through `set_options` (defaulted on construction).
options: ProviderOptions,
}
impl UnconfiguredProvider {
fn into_components(self) -> ProviderComponents {
ProviderComponents::new(Box::new(self), Arc::new(StaticModelPolicy::new()))
}
}
#[async_trait]
impl Provider for UnconfiguredProvider {
    /// Fixed identifier for the placeholder provider.
    fn kind(&self) -> &'static str {
        "unconfigured"
    }

    /// Returns a clone of the stored options.
    fn options(&self) -> ProviderOptions {
        Clone::clone(&self.options)
    }

    /// Overwrites the stored options; nothing else is affected.
    fn set_options(&mut self, options: ProviderOptions) {
        self.options = options;
    }

    /// The placeholder has no configuration, so it serializes as an empty JSON object.
    fn serialize_config(&self) -> serde_json::Value {
        serde_json::Value::Object(serde_json::Map::new())
    }

    /// Always fails: a real provider must be configured before a turn can run.
    async fn complete(&mut self, _request: LlmRequest) -> Result<LlmResponse, LlmTransportError> {
        let message =
            "no provider configured: host must set SessionPolicy.provider before running a turn";
        Err(LlmTransportError::new(message))
    }

    /// Boxes a clone of this placeholder (including its options).
    fn clone_boxed(&self) -> Box<dyn Provider> {
        let copy = Self::clone(self);
        Box::new(copy)
    }
}