lash-core 0.1.0-alpha.40

Sans-IO turn machine and runtime kernel for the lash agent runtime.
Documentation
use super::support::*;

/// Component bundle returned by provider factories.
///
/// Groups the transport (`provider`) with the collaborators that
/// `ProviderHandle` consults around each call: model policy, failure
/// classification and rate limiting.
#[derive(Debug)]
pub struct ProviderComponents {
    /// The transport itself. Owned exclusively; duplicated through
    /// `Provider::clone_boxed` when the bundle is cloned.
    pub provider: Box<dyn Provider>,
    /// Shared model policy; `ProviderHandle` asks it for supported variants
    /// and for how cached input tokens are accounted.
    pub model_policy: Arc<dyn ProviderModelPolicy>,
    /// Classifies failures returned by `provider.complete`. `new` installs
    /// `DefaultProviderFailureClassifier`; override it with
    /// `with_failure_classifier`.
    pub failure_classifier: Arc<dyn ProviderFailureClassifier>,
    /// Rate limiter consulted before every provider call. Held in an `Arc`,
    /// so clones of this bundle share the same limiter instance.
    pub rate_limiter: Arc<ProviderRateLimiter>,
}

impl ProviderComponents {
    /// Assemble a bundle around `provider` and `model_policy`.
    ///
    /// The rate limiter is seeded from the rate limits in the provider's
    /// current options, and the failure classifier starts out as
    /// `DefaultProviderFailureClassifier`.
    pub fn new(provider: Box<dyn Provider>, model_policy: Arc<dyn ProviderModelPolicy>) -> Self {
        let rate_limits = provider.options().reliability.rate_limits;
        let failure_classifier: Arc<dyn ProviderFailureClassifier> =
            Arc::new(DefaultProviderFailureClassifier);
        let rate_limiter = Arc::new(ProviderRateLimiter::new(rate_limits));
        Self {
            provider,
            model_policy,
            failure_classifier,
            rate_limiter,
        }
    }

    /// Install a transport-level decorator that wraps the provider.
    ///
    /// `map` receives the current provider by value and returns the one to
    /// store; every other component is carried over unchanged.
    pub fn map_provider(
        self,
        map: impl FnOnce(Box<dyn Provider>) -> Box<dyn Provider>,
    ) -> Self {
        Self {
            provider: map(self.provider),
            ..self
        }
    }

    /// Replace the failure classifier, keeping every other component.
    pub fn with_failure_classifier(
        self,
        classifier: Arc<dyn ProviderFailureClassifier>,
    ) -> Self {
        Self {
            failure_classifier: classifier,
            ..self
        }
    }
}

impl Clone for ProviderComponents {
    /// Duplicate the bundle: the provider is copied via `clone_boxed`, while
    /// the `Arc`-held collaborators are shared with the original.
    fn clone(&self) -> Self {
        // Exhaustive destructuring: adding a field without handling it here
        // becomes a compile error rather than a silently missed clone.
        let Self {
            provider,
            model_policy,
            failure_classifier,
            rate_limiter,
        } = self;
        Self {
            provider: provider.clone_boxed(),
            model_policy: Arc::clone(model_policy),
            failure_classifier: Arc::clone(failure_classifier),
            rate_limiter: Arc::clone(rate_limiter),
        }
    }
}

/// Owning handle to provider components. This is an executable transport
/// handle supplied by the host, not a persistence format.
///
/// Use `to_spec` to obtain the serializable `ProviderSpec` description
/// (provider kind plus serialized config) of the wrapped provider.
pub struct ProviderHandle {
    /// The wrapped bundle; reachable through `components` / `components_mut`.
    components: ProviderComponents,
}

impl ProviderHandle {
    /// Wrap an already-assembled component bundle.
    pub fn new(components: ProviderComponents) -> Self {
        Self { components }
    }

    /// Borrow the wrapped component bundle.
    pub fn components(&self) -> &ProviderComponents {
        &self.components
    }

    /// Mutably borrow the wrapped component bundle.
    pub fn components_mut(&mut self) -> &mut ProviderComponents {
        &mut self.components
    }

    /// Kind identifier reported by the wrapped provider.
    pub fn kind(&self) -> &'static str {
        self.components.provider.kind()
    }

    /// Variants the model policy reports for `model` (may be empty).
    pub fn supported_variants(&self, model: &str) -> &'static [&'static str] {
        self.components.model_policy.supported_variants(model)
    }

    /// Check that `variant` is one of the variants supported for `model`.
    ///
    /// # Errors
    ///
    /// Returns a human-readable message when the model exposes no variants
    /// at all, or when `variant` is not among the supported ones (the message
    /// then lists the available variants).
    pub fn validate_variant(&self, model: &str, variant: &str) -> Result<(), String> {
        let variants = self.supported_variants(model);
        if variants.is_empty() {
            return Err(format!(
                "Model `{}` on {} does not expose configurable variants.",
                model,
                self.kind()
            ));
        }
        if !variants.contains(&variant) {
            return Err(format!(
                "Unsupported variant `{}` for `{}` on {}. Available: {}",
                variant,
                model,
                self.kind(),
                variants.join(", ")
            ));
        }
        Ok(())
    }

    /// Forwarded from the model policy: whether reported input usage
    /// excludes cached tokens.
    pub fn input_usage_excludes_cached_tokens(&self) -> bool {
        let policy = &self.components.model_policy;
        policy.input_usage_excludes_cached_tokens()
    }

    /// Current options of the wrapped provider.
    pub fn options(&self) -> ProviderOptions {
        self.components.provider.options()
    }

    /// Apply new options.
    ///
    /// The shared rate limiter is reconfigured with the new rate limits
    /// first, then the options are handed to the provider.
    pub fn set_options(&mut self, options: ProviderOptions) {
        let rate_limits = options.reliability.rate_limits.clone();
        self.components.rate_limiter.configure(rate_limits);
        self.components.provider.set_options(options)
    }

    /// Whether the wrapped provider requires streaming.
    pub fn requires_streaming(&self) -> bool {
        self.components.provider.requires_streaming()
    }

    /// Run `request` against the provider, retrying retryable failures.
    ///
    /// Each attempt first waits for admission from the rate limiter and then
    /// calls the provider with a clone of `request`. A failure is passed
    /// through the failure classifier; if it is retryable and attempts
    /// remain, the retry policy picks a delay, a `RetryStatus` event is sent
    /// on `request.stream_events` (when present), and the task sleeps before
    /// trying again.
    ///
    /// # Errors
    ///
    /// Returns the classified failure once it is not retryable or the
    /// configured number of attempts has been used.
    pub async fn complete(
        &mut self,
        request: LlmRequest,
    ) -> Result<LlmResponse, LlmTransportError> {
        let reliability = self.options().reliability;
        let attempts = reliability.retry.attempts();
        let mut attempt = 0;
        loop {
            // Keep the admission permit alive only for the provider call
            // itself. The rate limiter is shared between clones of the
            // components, so holding the permit through the backoff sleep
            // below would keep it checked out while this handle is idle.
            let result = {
                let _permit = self.components.rate_limiter.admit(&request).await;
                self.components.provider.complete(request.clone()).await
            };
            let failure = match result {
                Ok(response) => return Ok(response),
                Err(failure) => self.components.failure_classifier.classify(failure),
            };
            if attempt + 1 >= attempts || !failure.retryable {
                return Err(failure);
            }
            let delay = reliability
                .retry
                .delay_for_attempt(attempt, failure.retry_after);
            tracing::debug!(
                target: "lash_core::provider::reliability",
                provider = self.kind(),
                attempt = attempt + 1,
                max_attempts = attempts,
                delay_ms = delay.as_millis() as u64,
                err = %failure.message,
                "provider call failed with retryable failure; sleeping before retry"
            );
            if let Some(events) = request.stream_events.as_ref() {
                events.send(crate::llm::types::LlmStreamEvent::RetryStatus {
                    // Whole seconds only: a sub-second delay is reported as 0.
                    wait_seconds: delay.as_secs(),
                    attempt: (attempt + 1) as usize,
                    max_attempts: attempts as usize,
                    reason: failure.message.clone(),
                });
            }
            tokio::time::sleep(delay).await;
            attempt += 1;
        }
    }

    /// Describe the wrapped provider as a `ProviderSpec`: its kind plus the
    /// config the provider serializes for itself.
    pub fn to_spec(&self) -> ProviderSpec {
        let kind = self.kind().to_string();
        let config = self.components.provider.serialize_config();
        ProviderSpec { kind, config }
    }

    /// Validate model syntax only.
    ///
    /// Surrounding whitespace is ignored; the remaining name must be
    /// non-empty and free of whitespace. The model policy is not consulted.
    ///
    /// # Errors
    ///
    /// Returns a message describing which of the two rules was violated.
    pub fn validate_model_name(&self, model: &str) -> Result<(), String> {
        let model = model.trim();
        if model.is_empty() {
            return Err("model cannot be empty".to_string());
        }
        if model.chars().any(char::is_whitespace) {
            return Err("model cannot contain whitespace".to_string());
        }
        Ok(())
    }
}

impl std::fmt::Debug for ProviderHandle {
    /// Formats exactly like the wrapped `ProviderComponents`; the handle adds
    /// no wrapper of its own to the output.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Fully-qualified call so the trait being dispatched to is explicit.
        std::fmt::Debug::fmt(&self.components, f)
    }
}

impl Clone for ProviderHandle {
    /// Clone the handle by cloning its component bundle.
    fn clone(&self) -> Self {
        let components = self.components.clone();
        Self { components }
    }
}

impl PartialEq for ProviderHandle {
    /// Two handles are equal when their providers report the same kind and
    /// serialize to the same config. Model policy, failure classifier and
    /// rate limiter do not take part in the comparison.
    fn eq(&self, other: &Self) -> bool {
        if self.kind() != other.kind() {
            return false;
        }
        let own_config = self.components.provider.serialize_config();
        let other_config = other.components.provider.serialize_config();
        own_config == other_config
    }
}

// Marker impl: declares the `PartialEq` comparison above (provider kind plus
// serialized config) to be a full equivalence relation.
impl Eq for ProviderHandle {}

impl Default for ProviderHandle {
    /// Build a handle around a default `UnconfiguredProvider`.
    fn default() -> Self {
        let placeholder = UnconfiguredProvider::default();
        Self {
            components: placeholder.into_components(),
        }
    }
}

/// Placeholder provider used when `SessionPolicy::default()` is
/// constructed without an explicit provider. Every transport-level
/// method errors; calling code MUST replace this before executing a
/// turn. It exists solely so `..Default::default()` shorthand keeps
/// working in host code that always overrides the provider field.
#[derive(Clone, Debug, Default)]
pub struct UnconfiguredProvider {
    /// Options stored by `set_options` and handed back (cloned) by `options`.
    options: ProviderOptions,
}

impl UnconfiguredProvider {
    fn into_components(self) -> ProviderComponents {
        ProviderComponents::new(Box::new(self), Arc::new(StaticModelPolicy::new()))
    }
}

#[async_trait]
impl Provider for UnconfiguredProvider {
    fn kind(&self) -> &'static str {
        "unconfigured"
    }

    fn options(&self) -> ProviderOptions {
        self.options.clone()
    }

    fn set_options(&mut self, options: ProviderOptions) {
        self.options = options;
    }

    fn serialize_config(&self) -> serde_json::Value {
        serde_json::Value::Object(Default::default())
    }

    async fn complete(&mut self, _request: LlmRequest) -> Result<LlmResponse, LlmTransportError> {
        Err(LlmTransportError::new(
            "no provider configured: host must set SessionPolicy.provider before running a turn",
        ))
    }

    fn clone_boxed(&self) -> Box<dyn Provider> {
        Box::new(self.clone())
    }
}