lash-core 0.1.0-alpha.37

Sans-IO turn machine and runtime kernel for the lash agent runtime.
Documentation
use super::support::*;

pub const DEFAULT_REQUEST_TIMEOUT_MS: u64 = 300_000;
pub const DEFAULT_CHUNK_TIMEOUT_MS: u64 = 120_000;

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub struct LlmTimeouts {
    pub request_timeout: Option<Duration>,
    pub chunk_timeout: Duration,
}

impl Default for LlmTimeouts {
    fn default() -> Self {
        Self {
            request_timeout: Some(Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MS)),
            chunk_timeout: Duration::from_millis(DEFAULT_CHUNK_TIMEOUT_MS),
        }
    }
}

#[derive(Clone, Copy, Debug, PartialEq, Eq)]
pub enum RequestTimeout {
    Disabled,
    Millis(u64),
}

impl Serialize for RequestTimeout {
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        match self {
            Self::Disabled => serializer.serialize_bool(false),
            Self::Millis(value) => serializer.serialize_u64(*value),
        }
    }
}

impl<'de> Deserialize<'de> for RequestTimeout {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct RequestTimeoutVisitor;

        impl Visitor<'_> for RequestTimeoutVisitor {
            type Value = RequestTimeout;

            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                formatter.write_str("a positive timeout in milliseconds or false")
            }

            fn visit_bool<E>(self, value: bool) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                if value {
                    return Err(E::custom("timeout must be a positive integer or false"));
                }
                Ok(RequestTimeout::Disabled)
            }

            fn visit_u64<E>(self, value: u64) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                if value == 0 {
                    return Err(E::custom("timeout must be greater than 0"));
                }
                Ok(RequestTimeout::Millis(value))
            }

            fn visit_i64<E>(self, value: i64) -> Result<Self::Value, E>
            where
                E: de::Error,
            {
                if value <= 0 {
                    return Err(E::custom("timeout must be greater than 0"));
                }
                Ok(RequestTimeout::Millis(value as u64))
            }
        }

        deserializer.deserialize_any(RequestTimeoutVisitor)
    }
}

/// Prompt-cache lifetime hint. Providers translate this into their own
/// wire dialect (Anthropic `cache_control` TTL, OpenRouter-Claude
/// `cache_control` markers via Chat Completions, OpenAI Responses
/// `prompt_cache_key` / `prompt_cache_retention`). Providers without a
/// cache-control concept (Google, Codex) read the value but emit nothing
/// for it.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "kebab-case")]
pub enum CacheRetention {
    /// Do not emit any cache_control markers.
    None,
    /// Default Anthropic ephemeral window (5 minutes).
    #[default]
    Short,
    /// Extend to a 1-hour TTL where the API supports it.
    Long,
}

impl CacheRetention {
    pub fn is_default(&self) -> bool {
        matches!(self, CacheRetention::Short)
    }
}

#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ProviderOptions {
    #[serde(default)]
    pub reliability: ProviderReliability,
    #[serde(default, skip_serializing_if = "ProviderThinkingPolicy::is_default")]
    pub thinking: ProviderThinkingPolicy,
    /// Per-request output-token cap. `None` lets each provider apply its
    /// own default. Providers translate to their wire-specific field
    /// (`max_tokens`, `max_output_tokens`, `maxOutputTokens`, …).
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_output_tokens: Option<u64>,
    /// Prompt-cache lifetime hint; see [`CacheRetention`].
    #[serde(default, skip_serializing_if = "CacheRetention::is_default")]
    pub cache_retention: CacheRetention,
}

impl ProviderOptions {
    pub fn is_default(&self) -> bool {
        self.reliability == ProviderReliability::default_llm()
            && self.thinking == ProviderThinkingPolicy::default()
            && self.max_output_tokens.is_none()
            && self.cache_retention.is_default()
    }

    pub fn llm_timeouts(&self) -> LlmTimeouts {
        self.reliability.timeouts.llm_timeouts()
    }
}

#[derive(Clone, Debug, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(deny_unknown_fields)]
pub struct ProviderThinkingPolicy {
    #[serde(default)]
    pub expose: bool,
}

impl ProviderThinkingPolicy {
    pub fn is_default(&self) -> bool {
        !self.expose
    }
}

#[derive(Clone, Debug, PartialEq, Eq)]
pub struct ResolvedGenerationPolicy<TThinking> {
    pub max_output_tokens: u64,
    pub cache_retention: CacheRetention,
    pub expose_thinking: bool,
    pub thinking: TThinking,
}

pub fn resolve_generation_policy<TThinking>(
    generation: &crate::GenerationOptions,
    options: &ProviderOptions,
    provider_default_max_output_tokens: u64,
    thinking: TThinking,
) -> ResolvedGenerationPolicy<TThinking> {
    let max_output_tokens = generation
        .output_token_cap_u64()
        .or(options.max_output_tokens)
        .unwrap_or(provider_default_max_output_tokens);
    ResolvedGenerationPolicy {
        max_output_tokens,
        cache_retention: options.cache_retention,
        expose_thinking: options.thinking.expose,
        thinking,
    }
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct ProviderReliability {
    #[serde(default)]
    pub timeouts: ProviderTimeoutPolicy,
    #[serde(default)]
    pub retry: ProviderRetryPolicy,
    #[serde(default)]
    pub rate_limits: ProviderRateLimitPolicy,
}

impl ProviderReliability {
    pub fn default_llm() -> Self {
        Self {
            timeouts: ProviderTimeoutPolicy::default(),
            retry: ProviderRetryPolicy::default(),
            rate_limits: ProviderRateLimitPolicy::default(),
        }
    }

    pub fn codex() -> Self {
        Self {
            retry: ProviderRetryPolicy {
                max_attempts: 4,
                base_delay_ms: 1_000,
                max_delay_ms: 4_000,
                jitter_ms: 0,
                retry_after_cap_ms: Some(60_000),
                enabled: true,
            },
            ..Self::default_llm()
        }
    }

    pub fn disabled() -> Self {
        Self {
            retry: ProviderRetryPolicy::disabled(),
            rate_limits: ProviderRateLimitPolicy::default(),
            timeouts: ProviderTimeoutPolicy::default(),
        }
    }

    pub fn builder() -> ProviderReliabilityBuilder {
        ProviderReliabilityBuilder {
            reliability: Self::default_llm(),
        }
    }
}

impl Default for ProviderReliability {
    fn default() -> Self {
        Self::default_llm()
    }
}

#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq, Eq)]
pub struct ProviderTimeoutPolicy {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub request_timeout: Option<RequestTimeout>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub chunk_timeout: Option<u64>,
}

impl ProviderTimeoutPolicy {
    pub fn llm_timeouts(&self) -> LlmTimeouts {
        let request_timeout = match self.request_timeout {
            Some(RequestTimeout::Disabled) => None,
            Some(RequestTimeout::Millis(ms)) => Some(Duration::from_millis(ms)),
            None => Some(Duration::from_millis(DEFAULT_REQUEST_TIMEOUT_MS)),
        };
        let chunk_timeout_ms = self
            .chunk_timeout
            .filter(|value| *value > 0)
            .unwrap_or(DEFAULT_CHUNK_TIMEOUT_MS);
        LlmTimeouts {
            request_timeout,
            chunk_timeout: Duration::from_millis(chunk_timeout_ms),
        }
    }
}

#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct ProviderRetryPolicy {
    pub enabled: bool,
    pub max_attempts: u32,
    pub base_delay_ms: u64,
    pub max_delay_ms: u64,
    pub jitter_ms: u64,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub retry_after_cap_ms: Option<u64>,
}

impl Default for ProviderRetryPolicy {
    fn default() -> Self {
        Self {
            enabled: true,
            max_attempts: 4,
            base_delay_ms: 2_000,
            max_delay_ms: 10_000,
            jitter_ms: 0,
            retry_after_cap_ms: Some(60_000),
        }
    }
}

impl ProviderRetryPolicy {
    pub fn disabled() -> Self {
        Self {
            enabled: false,
            max_attempts: 1,
            base_delay_ms: 0,
            max_delay_ms: 0,
            jitter_ms: 0,
            retry_after_cap_ms: None,
        }
    }

    pub(crate) fn attempts(&self) -> u32 {
        if self.enabled {
            self.max_attempts.max(1)
        } else {
            1
        }
    }

    pub(crate) fn delay_for_attempt(
        &self,
        retry_index: u32,
        retry_after: Option<Duration>,
    ) -> Duration {
        if let Some(retry_after) = retry_after {
            return self
                .retry_after_cap_ms
                .map(Duration::from_millis)
                .map(|cap| retry_after.min(cap))
                .unwrap_or(retry_after);
        }
        let multiplier = 1u64.checked_shl(retry_index).unwrap_or(u64::MAX);
        let delay_ms = self
            .base_delay_ms
            .saturating_mul(multiplier)
            .min(self.max_delay_ms);
        Duration::from_millis(delay_ms.saturating_add(self.jitter_ms))
    }
}

#[derive(Clone, Debug, Serialize, Deserialize, Default, PartialEq, Eq)]
pub struct ProviderRateLimitPolicy {
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_concurrency: Option<usize>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub requests_per_window: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub request_window_ms: Option<u64>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub tokens_per_window: Option<u32>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub token_window_ms: Option<u64>,
}

pub struct ProviderReliabilityBuilder {
    reliability: ProviderReliability,
}

impl ProviderReliabilityBuilder {
    pub fn request_timeout(mut self, timeout: Option<RequestTimeout>) -> Self {
        self.reliability.timeouts.request_timeout = timeout;
        self
    }

    pub fn stream_chunk_timeout_ms(mut self, timeout_ms: Option<u64>) -> Self {
        self.reliability.timeouts.chunk_timeout = timeout_ms;
        self
    }

    pub fn max_attempts(mut self, attempts: u32) -> Self {
        self.reliability.retry.max_attempts = attempts.max(1);
        self
    }

    pub fn base_delay_ms(mut self, delay_ms: u64) -> Self {
        self.reliability.retry.base_delay_ms = delay_ms;
        self
    }

    pub fn max_delay_ms(mut self, delay_ms: u64) -> Self {
        self.reliability.retry.max_delay_ms = delay_ms;
        self
    }

    pub fn retry_after_cap_ms(mut self, cap_ms: Option<u64>) -> Self {
        self.reliability.retry.retry_after_cap_ms = cap_ms;
        self
    }

    pub fn max_concurrency(mut self, value: Option<usize>) -> Self {
        self.reliability.rate_limits.max_concurrency = value;
        self
    }

    pub fn requests_per_window(mut self, requests: Option<u32>, window_ms: Option<u64>) -> Self {
        self.reliability.rate_limits.requests_per_window = requests;
        self.reliability.rate_limits.request_window_ms = window_ms;
        self
    }

    pub fn tokens_per_window(mut self, tokens: Option<u32>, window_ms: Option<u64>) -> Self {
        self.reliability.rate_limits.tokens_per_window = tokens;
        self.reliability.rate_limits.token_window_ms = window_ms;
        self
    }

    pub fn build(self) -> ProviderReliability {
        self.reliability
    }
}