lash-core 0.1.0-alpha.42

Sans-IO turn machine and runtime kernel for the lash agent runtime.
Documentation
use super::support::*;

/// Rate limiter for calls to a single provider.
///
/// All mutable limiter state (active policy, concurrency semaphore and the
/// request/token window buckets) lives behind one [`Mutex`], so the limiter can
/// be used through a shared reference.
#[derive(Debug)]
pub struct ProviderRateLimiter {
    /// Policy, semaphore and window counters, guarded by a single lock.
    state: Mutex<ProviderRateLimiterState>,
}

/// Lock-protected state of a [`ProviderRateLimiter`].
#[derive(Debug)]
struct ProviderRateLimiterState {
    /// Policy currently being enforced; replaced wholesale by `configure`.
    policy: ProviderRateLimitPolicy,
    /// Concurrency gate; `None` when the policy sets no positive `max_concurrency`.
    semaphore: Option<Arc<tokio::sync::Semaphore>>,
    /// Usage counter checked against `requests_per_window` / `request_window_ms`.
    request_bucket: WindowBucket,
    /// Usage counter checked against `tokens_per_window` / `token_window_ms`.
    token_bucket: WindowBucket,
}

/// Usage counter for one rate-limit window.
#[derive(Clone, Debug)]
struct WindowBucket {
    /// Units charged against the limit since the current window opened.
    used: u32,
    /// Instant at which the current window ends; once reached, `bucket_wait`
    /// zeroes `used` and opens a new window.
    reset_at: Instant,
}

impl Default for WindowBucket {
    fn default() -> Self {
        Self {
            used: 0,
            reset_at: Instant::now(),
        }
    }
}

/// Guard returned by [`ProviderRateLimiter::admit`].
///
/// Keep it alive for the duration of the provider call: it owns the concurrency
/// permit (if any), which is handed back to the semaphore when the guard drops.
#[derive(Debug)]
pub struct ProviderRateLimitPermit {
    /// Owned semaphore permit; `None` when no concurrency limit was configured
    /// at admission time. Never read, only held for its drop behavior.
    _concurrency: Option<tokio::sync::OwnedSemaphorePermit>,
}

impl ProviderRateLimiter {
    /// Creates a limiter enforcing `policy`.
    ///
    /// A concurrency semaphore is only created when `policy.max_concurrency` is
    /// a positive number; `None` or `0` leaves concurrency unbounded. Both window
    /// buckets start empty.
    pub fn new(policy: ProviderRateLimitPolicy) -> Self {
        let semaphore = Self::concurrency_semaphore(&policy);
        Self {
            state: Mutex::new(ProviderRateLimiterState {
                policy,
                semaphore,
                request_bucket: WindowBucket::default(),
                token_bucket: WindowBucket::default(),
            }),
        }
    }

    /// Replaces the active policy.
    ///
    /// The semaphore is rebuilt only when `max_concurrency` actually changes.
    /// Permits already handed out belong to the previous semaphore, so they are
    /// not counted against the new limit. Window counters are left untouched;
    /// new window limits apply from the next admission check.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned.
    pub fn configure(&self, policy: ProviderRateLimitPolicy) {
        let mut state = self.state.lock().expect("provider rate limiter lock");
        if state.policy.max_concurrency != policy.max_concurrency {
            state.semaphore = Self::concurrency_semaphore(&policy);
        }
        state.policy = policy;
    }

    /// Waits until `request` may be sent and returns the guard to hold while it
    /// is in flight.
    ///
    /// The concurrency permit is acquired first and is held while waiting for
    /// room in the request/token windows. The request is charged as one request
    /// plus `estimate_request_tokens(request)` tokens.
    ///
    /// # Panics
    ///
    /// Panics if the internal lock is poisoned or the semaphore has been closed.
    pub async fn admit(&self, request: &LlmRequest) -> ProviderRateLimitPermit {
        // The lock guard is a temporary of this statement, so it is released
        // before any await point below.
        let semaphore = self
            .state
            .lock()
            .expect("provider rate limiter lock")
            .semaphore
            .clone();
        let concurrency = match semaphore {
            Some(semaphore) => Some(semaphore.acquire_owned().await.expect("semaphore open")),
            None => None,
        };
        self.wait_for_buckets(1, estimate_request_tokens(request))
            .await;
        ProviderRateLimitPermit {
            _concurrency: concurrency,
        }
    }

    /// Builds the concurrency gate for `policy`, or `None` when the policy sets
    /// no positive `max_concurrency`.
    fn concurrency_semaphore(
        policy: &ProviderRateLimitPolicy,
    ) -> Option<Arc<tokio::sync::Semaphore>> {
        policy
            .max_concurrency
            .filter(|limit| *limit > 0)
            .map(|limit| Arc::new(tokio::sync::Semaphore::new(limit)))
    }

    /// Sleeps until both the request window and the token window have room,
    /// then charges `requests` and `tokens` to them.
    ///
    /// The two buckets are charged atomically: each check runs on a scratch copy
    /// and the copies are written back only when *both* admit. Charging them
    /// independently would bill the admitting bucket again on every retry while
    /// the other bucket is still full.
    async fn wait_for_buckets(&self, requests: u32, tokens: u32) {
        loop {
            let wait = {
                let mut state = self.state.lock().expect("provider rate limiter lock");
                let now = Instant::now();
                let mut request_bucket = state.request_bucket.clone();
                let mut token_bucket = state.token_bucket.clone();
                let request_wait = bucket_wait(
                    &mut request_bucket,
                    now,
                    state.policy.requests_per_window,
                    state.policy.request_window_ms,
                    requests,
                );
                let token_wait = bucket_wait(
                    &mut token_bucket,
                    now,
                    state.policy.tokens_per_window,
                    state.policy.token_window_ms,
                    tokens,
                );
                match (request_wait, token_wait) {
                    (None, None) => {
                        // Both windows have room: commit both charges together.
                        state.request_bucket = request_bucket;
                        state.token_bucket = token_bucket;
                        return;
                    }
                    // Both full: wait for whichever window reopens last.
                    (Some(a), Some(b)) => a.max(b),
                    // One full: discard the other's tentative charge and wait.
                    (Some(a), None) | (None, Some(a)) => a,
                }
                // `state` is dropped here, before the sleep below.
            };
            tokio::time::sleep(wait).await;
        }
    }
}

/// Tries to charge `cost` units to `bucket` at time `now`.
///
/// Returns `None` when the charge was accepted (or when `limit` is `None`/`0`,
/// which disables the bucket without touching it). Returns `Some(d)` when the
/// window is full, where `d` is the time left until the window resets; the
/// bucket is not charged in that case.
///
/// `window_ms` defaults to 60 000 ms and is raised to at least 1 ms. A `cost`
/// larger than `limit` is clamped to `limit`, so an oversized request can still
/// be admitted by an empty window instead of waiting forever.
fn bucket_wait(
    bucket: &mut WindowBucket,
    now: Instant,
    limit: Option<u32>,
    window_ms: Option<u64>,
    cost: u32,
) -> Option<Duration> {
    // A missing or zero limit means this dimension is not rate limited.
    let capacity = match limit {
        Some(value) if value > 0 => value,
        _ => return None,
    };

    // Open a new window once the previous one has run out.
    let window_expired = now >= bucket.reset_at;
    if window_expired {
        let window_len = Duration::from_millis(window_ms.unwrap_or(60_000).max(1));
        bucket.used = 0;
        bucket.reset_at = now + window_len;
    }

    let charge = cost.min(capacity);
    let projected = bucket.used.saturating_add(charge);
    if projected > capacity {
        return Some(bucket.reset_at.saturating_duration_since(now));
    }
    bucket.used = projected;
    None
}

/// Rough token estimate for `request`, used to charge the token window.
///
/// Sums the `len()` of the model name and of every text, tool-call input,
/// tool-result and reasoning payload, counts each image block as a flat 256,
/// adds a quarter of every attachment's data length, then divides the total by
/// four. The result is at least 1 and saturates at `u32::MAX`.
fn estimate_request_tokens(request: &LlmRequest) -> u32 {
    let content_len: usize = request
        .messages
        .iter()
        .flat_map(|message| message.blocks.iter())
        .map(|block| match block {
            LlmContentBlock::Text { text, .. } => text.len(),
            LlmContentBlock::ToolCall { input_json, .. } => input_json.len(),
            LlmContentBlock::ToolResult { content, .. } => content.len(),
            LlmContentBlock::Reasoning { text, .. } => text.len(),
            // Images carry no text; a fixed placeholder weight is used instead.
            LlmContentBlock::Image { .. } => 256,
        })
        .sum();

    // NOTE(review): attachment bytes are quartered here and again by the final
    // `/ 4` below — presumably a deliberate discount for binary data; confirm.
    let attachment_len: usize = request
        .attachments
        .iter()
        .map(|attachment| attachment.data.len() / 4)
        .sum();

    let total_len = (request.model.len() + content_len).saturating_add(attachment_len);
    let estimate = (total_len / 4).max(1);
    estimate.try_into().unwrap_or(u32::MAX)
}