tandem-server 0.5.8

HTTP server for Tandem engine APIs
use std::collections::HashMap;
use std::time::{Duration, Instant};

use serde_json::Value;
use tandem_channels::config::ChannelSecurityProfile;
use tandem_types::RequestPrincipal;
use tokio::sync::RwLock;

/// Prompt budget per minute that `rate_limit_capacity` falls back to when no
/// positive override is read from the environment.
const DEFAULT_PROMPT_LIMIT_PER_MINUTE: u32 = 10;
/// Decision budget per minute that `rate_limit_capacity` falls back to when no
/// positive override is read from the environment.
const DEFAULT_DECISION_LIMIT_PER_MINUTE: u32 = 30;

/// Identifies the channel/user pair that rate-limit buckets are tracked for.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ChannelRateLimitKey {
    /// Channel name. The key builders in this file store it trimmed and
    /// ASCII-lowercased.
    pub channel: String,
    /// User identifier within the channel. The key builders in this file store
    /// it trimmed with its case preserved.
    pub user_id: String,
}

/// Category of request being rate limited.
///
/// The kind is part of the bucket key and selects which limit applies, so each
/// kind is budgeted independently for the same channel/user pair.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ChannelRateLimitKind {
    /// Uses the prompt limit (default `DEFAULT_PROMPT_LIMIT_PER_MINUTE`).
    Prompt,
    /// Uses the decision limit (default `DEFAULT_DECISION_LIMIT_PER_MINUTE`).
    Decision,
}

/// Outcome of a single rate-limit check.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ChannelRateLimitDecision {
    /// `true` when a token was available and has been consumed for this request.
    pub allowed: bool,
    /// Whole seconds to wait before retrying. The token bucket in this file
    /// reports `0` when the request is allowed and at least `1` when it is not.
    pub retry_after_secs: u64,
}

/// Mutable state of one token bucket.
///
/// Capacity and refill rate are not stored here; they are supplied on every
/// `check` call.
#[derive(Debug)]
struct TokenBucket {
    /// Tokens currently available; fractional because refill is continuous.
    tokens: f64,
    /// Instant at which `tokens` was last brought up to date.
    last_refill: Instant,
}

impl TokenBucket {
    /// Creates a bucket that starts full (`capacity` tokens) with its refill
    /// clock set to `now`.
    fn new(capacity: u32, now: Instant) -> Self {
        Self {
            tokens: f64::from(capacity),
            last_refill: now,
        }
    }

    /// Refills the bucket for the time elapsed since the previous call, then
    /// tries to consume one token.
    ///
    /// * `capacity` - upper bound on stored tokens; a bucket holding more (for
    ///   example after the limit was lowered) is clamped down to it.
    /// * `refill_per_sec` - tokens added per second of elapsed time.
    /// * `now` - timestamp of this request.
    ///
    /// Returns an allowed decision with `retry_after_secs == 0` when a token was
    /// consumed. Otherwise returns a denied decision whose `retry_after_secs` is
    /// the time until one full token is available, rounded up to whole seconds
    /// and never less than 1.
    fn check(
        &mut self,
        capacity: u32,
        refill_per_sec: f64,
        now: Instant,
    ) -> ChannelRateLimitDecision {
        // A caller may have sampled `now` before an earlier call's timestamp was
        // recorded (e.g. while waiting for a lock). Never let the refill clock
        // move backwards: a rewound `last_refill` would make the next call
        // re-credit time that has already been converted into tokens.
        let now = now.max(self.last_refill);
        let elapsed = now.duration_since(self.last_refill).as_secs_f64();
        self.tokens = (self.tokens + elapsed * refill_per_sec).min(f64::from(capacity));
        self.last_refill = now;

        if self.tokens >= 1.0 {
            self.tokens -= 1.0;
            return ChannelRateLimitDecision {
                allowed: true,
                retry_after_secs: 0,
            };
        }

        let missing = 1.0 - self.tokens;
        // Float-to-integer `as` saturates, so a zero refill rate (infinite wait)
        // yields `u64::MAX` rather than wrapping.
        let retry_after_secs = (missing / refill_per_sec).ceil().max(1.0) as u64;
        ChannelRateLimitDecision {
            allowed: false,
            retry_after_secs,
        }
    }
}

/// Async rate limiter that keeps one token bucket per channel/user/kind
/// combination.
///
/// NOTE(review): nothing in this file removes entries from `buckets`, so the map
/// appears to grow with the number of distinct keys seen — confirm whether
/// eviction is handled elsewhere.
#[derive(Debug, Default)]
pub struct ChannelRateLimiter {
    /// Buckets keyed by the `"channel:user_id:kind"` string built in `check`.
    buckets: RwLock<HashMap<String, TokenBucket>>,
}

impl ChannelRateLimiter {
    /// Checks, and on success consumes, one unit of `key`'s budget for `kind`.
    ///
    /// The per-minute limit comes from `rate_limit_capacity(kind, profile)`. It
    /// serves both as the bucket capacity (burst size) and, divided by 60, as the
    /// per-second refill rate.
    ///
    /// Buckets are keyed by `channel:user_id:kind`, with channel and user id
    /// trimmed and ASCII-lowercased, so ids that differ only in ASCII case or
    /// surrounding whitespace share one bucket. A key seen for the first time
    /// starts with a full bucket.
    pub async fn check(
        &self,
        key: &ChannelRateLimitKey,
        kind: ChannelRateLimitKind,
        profile: ChannelSecurityProfile,
    ) -> ChannelRateLimitDecision {
        let capacity = rate_limit_capacity(kind, profile);
        let refill_per_sec = f64::from(capacity) / 60.0;
        // Build the key before locking to keep the critical section short.
        let bucket_key = format!(
            "{}:{}:{}",
            key.channel.trim().to_ascii_lowercase(),
            key.user_id.trim().to_ascii_lowercase(),
            kind.as_str()
        );
        let mut guard = self.buckets.write().await;
        // Sample the clock only once the write lock is held. A timestamp taken
        // before waiting on the lock can be older than one already recorded by a
        // task that acquired the lock first, which would skew refill accounting.
        let now = Instant::now();
        guard
            .entry(bucket_key)
            .or_insert_with(|| TokenBucket::new(capacity, now))
            .check(capacity, refill_per_sec, now)
    }
}

impl ChannelRateLimitKind {
    fn as_str(self) -> &'static str {
        match self {
            Self::Prompt => "prompt",
            Self::Decision => "decision",
        }
    }
}

/// Builds a rate-limit key from session metadata.
///
/// The channel is read from `channel`, falling back to `source_platform`. The
/// user id is read from `user_id`, then `surface_user_id`, then `sender_id`.
/// For each, the first field holding a non-blank string wins. A field that is
/// missing, not a string, or blank after trimming is skipped, so it cannot mask
/// a usable fallback field.
///
/// Returns `None` when `metadata` is `None` or when no usable channel or user id
/// is found. The channel is trimmed and ASCII-lowercased; the user id is trimmed
/// with its case preserved.
pub fn channel_rate_limit_key_from_session_metadata(
    metadata: Option<&Value>,
) -> Option<ChannelRateLimitKey> {
    // Trimmed value of the first field in `names` that holds a non-blank string.
    fn first_non_blank<'a>(metadata: &'a Value, names: &[&str]) -> Option<&'a str> {
        names.iter().find_map(|name| {
            metadata
                .get(*name)
                .and_then(Value::as_str)
                .map(str::trim)
                .filter(|text| !text.is_empty())
        })
    }

    let metadata = metadata?;
    let channel = first_non_blank(metadata, &["channel", "source_platform"])?;
    let user_id = first_non_blank(metadata, &["user_id", "surface_user_id", "sender_id"])?;
    Some(ChannelRateLimitKey {
        channel: channel.to_ascii_lowercase(),
        user_id: user_id.to_string(),
    })
}

/// Builds a rate-limit key from a principal whose `actor_id` has the form
/// `channel:<channel>:<user_id>`.
///
/// Any text after a third `:` is ignored. Returns `None` when `actor_id` is
/// absent, does not start with the `channel` segment, has fewer than three
/// segments, or has a blank channel or user id. The channel is trimmed and
/// ASCII-lowercased; the user id is trimmed with its case preserved.
pub fn channel_rate_limit_key_from_principal(
    principal: &RequestPrincipal,
) -> Option<ChannelRateLimitKey> {
    let mut segments = principal.actor_id.as_deref()?.splitn(4, ':');
    let prefix = segments.next();
    let channel = segments.next().map(str::trim);
    let user_id = segments.next().map(str::trim);
    match (prefix, channel, user_id) {
        (Some("channel"), Some(channel), Some(user_id))
            if !(channel.is_empty() || user_id.is_empty()) =>
        {
            Some(ChannelRateLimitKey {
                channel: channel.to_ascii_lowercase(),
                user_id: user_id.to_owned(),
            })
        }
        _ => None,
    }
}

/// Resolves the per-minute limit for `kind` under `profile`.
///
/// Lookup order, first positive integer wins:
/// 1. the profile-specific variable `<BASE>_<PROFILE>`, where `<PROFILE>` is
///    `OPERATOR`, `TRUSTED_TEAM` or `PUBLIC_DEMO`;
/// 2. the base variable `TANDEM_CHANNEL_PROMPT_RATE_LIMIT_PER_MINUTE` or
///    `TANDEM_CHANNEL_DECISION_RATE_LIMIT_PER_MINUTE`;
/// 3. the built-in default for the kind.
///
/// The environment is read on every call.
pub fn rate_limit_capacity(kind: ChannelRateLimitKind, profile: ChannelSecurityProfile) -> u32 {
    let (base_env_name, fallback) = match kind {
        ChannelRateLimitKind::Prompt => (
            "TANDEM_CHANNEL_PROMPT_RATE_LIMIT_PER_MINUTE",
            DEFAULT_PROMPT_LIMIT_PER_MINUTE,
        ),
        ChannelRateLimitKind::Decision => (
            "TANDEM_CHANNEL_DECISION_RATE_LIMIT_PER_MINUTE",
            DEFAULT_DECISION_LIMIT_PER_MINUTE,
        ),
    };
    let profile_suffix = match profile {
        ChannelSecurityProfile::Operator => "OPERATOR",
        ChannelSecurityProfile::TrustedTeam => "TRUSTED_TEAM",
        ChannelSecurityProfile::PublicDemo => "PUBLIC_DEMO",
    };
    let profile_env_name = format!("{}_{}", base_env_name, profile_suffix);

    // The most specific override takes precedence over the base variable.
    if let Some(limit) = read_positive_u32_env(&profile_env_name) {
        return limit;
    }
    read_positive_u32_env(base_env_name).unwrap_or(fallback)
}

/// Converts a decision's retry hint into a `Duration` of at least one second.
pub fn retry_after_duration(decision: ChannelRateLimitDecision) -> Duration {
    // A zero hint is raised to one second so the result is never a zero wait.
    let secs = std::cmp::max(decision.retry_after_secs, 1);
    Duration::from_secs(secs)
}

/// Reads environment variable `name` as a strictly positive `u32`.
///
/// Surrounding whitespace (for example a trailing newline from an env file or
/// secret mount) is ignored. Returns `None` when the variable is unset, is not
/// valid Unicode, does not parse as a `u32`, or is `0`.
fn read_positive_u32_env(name: &str) -> Option<u32> {
    let raw = std::env::var(name).ok()?;
    raw.trim()
        .parse::<u32>()
        .ok()
        .filter(|limit| *limit > 0)
}

#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Issues one prompt check for `key` under the public-demo profile.
    async fn check_prompt(
        limiter: &ChannelRateLimiter,
        key: &ChannelRateLimitKey,
    ) -> ChannelRateLimitDecision {
        limiter
            .check(
                key,
                ChannelRateLimitKind::Prompt,
                ChannelSecurityProfile::PublicDemo,
            )
            .await
    }

    #[test]
    fn extracts_channel_rate_key_from_session_metadata() {
        let metadata = json!({
            "channel": "Telegram",
            "user_id": "42"
        });
        let key = channel_rate_limit_key_from_session_metadata(Some(&metadata)).unwrap();
        // The channel is lowercased; the user id is passed through.
        assert_eq!(key.channel, "telegram");
        assert_eq!(key.user_id, "42");
    }

    #[test]
    fn extracts_channel_rate_key_from_principal() {
        let principal = RequestPrincipal {
            actor_id: Some("channel:slack:U123".to_string()),
            source: "channel:slack".to_string(),
        };
        let key = channel_rate_limit_key_from_principal(&principal).unwrap();
        assert_eq!(key.channel, "slack");
        assert_eq!(key.user_id, "U123");
    }

    // Relies on the default prompt limit, i.e. no rate-limit override set in the
    // test environment.
    #[tokio::test]
    async fn eleventh_prompt_is_limited_by_default() {
        let limiter = ChannelRateLimiter::default();
        let key = ChannelRateLimitKey {
            channel: "telegram".to_string(),
            user_id: "42".to_string(),
        };
        // A fresh bucket starts full, so exactly the default number of prompts pass.
        for _ in 0..DEFAULT_PROMPT_LIMIT_PER_MINUTE {
            let accepted = check_prompt(&limiter, &key).await;
            assert!(accepted.allowed);
        }
        let rejected = check_prompt(&limiter, &key).await;
        assert!(!rejected.allowed);
        assert!(rejected.retry_after_secs > 0);
    }
}