lash-core 0.1.0-alpha.39

Sans-IO turn machine and runtime kernel for the lash agent runtime.
Documentation
use super::support::*;
use crate::LlmTerminalReason;

/// A configured LLM backend: its identity, host-config serialization, its
/// generation options, and the request transport.
#[async_trait]
pub trait Provider: Send + Sync + std::fmt::Debug {
    fn kind(&self) -> &'static str;

    fn options(&self) -> ProviderOptions;
    fn set_options(&mut self, options: ProviderOptions);

    /// Emit the provider-specific JSON body used by [`ProviderSpec`]. The
    /// object must NOT contain a `type` field — [`ProviderSpec::Serialize`]
    /// layers that on top.
    fn serialize_config(&self) -> serde_json::Value;

    async fn complete(&mut self, request: LlmRequest) -> Result<LlmResponse, LlmTransportError>;

    fn requires_streaming(&self) -> bool {
        false
    }

    fn clone_boxed(&self) -> Box<dyn Provider>;
}

pub trait ProviderModelPolicy: Send + Sync + std::fmt::Debug {
    fn supported_variants(&self, model: &str) -> &'static [&'static str];

    fn input_usage_excludes_cached_tokens(&self) -> bool {
        false
    }
}

pub trait ProviderFailureClassifier: Send + Sync + std::fmt::Debug {
    fn classify(&self, failure: ProviderFailure) -> ProviderFailure;
}

#[derive(Clone, Debug, Default)]
pub struct DefaultProviderFailureClassifier;

impl ProviderFailureClassifier for DefaultProviderFailureClassifier {
    fn classify(&self, mut failure: ProviderFailure) -> ProviderFailure {
        if let Some(status) = failure.status.or_else(|| {
            failure
                .code
                .as_deref()
                .and_then(|code| code.parse::<u16>().ok())
        }) {
            failure.status = Some(status);
            if failure.kind == ProviderFailureKind::Unknown {
                failure.kind = ProviderFailureKind::Http;
            }
            failure.retryable = matches!(status, 408 | 409 | 425 | 429 | 500 | 502 | 503 | 504);
            if matches!(status, 401 | 403) {
                failure.kind = ProviderFailureKind::Auth;
            } else if matches!(status, 400 | 413 | 422) {
                failure.kind = ProviderFailureKind::Validation;
            }
        } else if matches!(
            failure.kind,
            ProviderFailureKind::Transport | ProviderFailureKind::Timeout
        ) {
            failure.retryable = true;
        }

        let haystack = format!(
            "{}\n{}\n{}",
            failure.code.as_deref().unwrap_or_default(),
            failure.message,
            failure.raw.as_deref().unwrap_or_default()
        )
        .to_ascii_lowercase();
        if is_context_overflow_text(&haystack) {
            failure.kind = ProviderFailureKind::Validation;
            failure.retryable = false;
            failure.terminal_reason = LlmTerminalReason::ContextOverflow;
        }
        if haystack.contains("insufficient_quota")
            || haystack.contains("usage_limit_reached")
            || haystack.contains("usage_not_included")
            || haystack.contains("quota")
        {
            failure.kind = ProviderFailureKind::Quota;
            failure.retryable = false;
        }
        if haystack.contains("content_filter")
            || haystack.contains("prohibited_content")
            || haystack.contains("safety")
            || haystack.contains("sensitive")
        {
            failure.terminal_reason = LlmTerminalReason::ContentFilter;
        }
        if haystack.contains("model_not_found")
            || haystack.contains("unsupported model")
            || haystack.contains("does not exist")
        {
            failure.kind = ProviderFailureKind::Unsupported;
            failure.retryable = false;
        }
        failure
    }
}

pub fn is_context_overflow_text(haystack: &str) -> bool {
    let lower = haystack.to_ascii_lowercase();
    if lower.contains("rate limit")
        || lower.contains("rate_limit")
        || lower.contains("ratelimit")
        || lower.contains("throttle")
        || lower.contains("throttling")
        || lower.contains("too many requests")
        || lower.contains("tokens per minute")
        || lower.contains("tpm")
        || lower.contains("quota")
    {
        return false;
    }

    lower.contains("context_length_exceeded")
        || lower.contains("context_length")
        || lower.contains("context length")
        || lower.contains("maximum context")
        || lower.contains("max context")
        || lower.contains("context window")
        || lower.contains("context window exceeds limit")
        || lower.contains("exceeds the context window")
        || lower.contains("prompt is too long")
        || lower.contains("prompt too long")
        || lower.contains("request_too_large")
        || lower.contains("input token count") && lower.contains("exceeds the maximum")
        || lower.contains("maximum prompt length is")
        || lower.contains("reduce the length of the messages")
        || lower.contains("maximum context length is")
        || lower.contains("model's context length")
        || lower.contains("models context length")
        || lower.contains("exceeds the available context size")
        || lower.contains("greater than the context length")
        || lower.contains("exceeded model token limit")
        || lower.contains("too large for model with")
        || lower.contains("model_context_window_exceeded")
        || lower.contains("too many tokens")
        || lower.contains("exceeds the maximum number of tokens")
        || lower.contains("exceeds maximum number of tokens")
        || lower.contains("request too large")
        || lower.contains("input is too long")
        || lower.contains("token limit exceeded")
        || lower.contains("reduce the length of the messages")
        || lower.contains("reduce the length of your prompt")
}