langdb_core 0.3.2

Core library for the LangDB AI Gateway.
Documentation
use crate::usage::{InMemoryStorage, LimitPeriod};
use actix_web::dev::forward_ready;
use actix_web::{
    dev::{Service, ServiceRequest, ServiceResponse, Transform},
    Error,
};
use serde::{Deserialize, Serialize};
use std::future::{ready, Future, Ready};
use std::pin::Pin;
use std::rc::Rc;
use std::sync::Arc;
use tokio::sync::Mutex;

/// Metric name under which each API call is counted; `check_limit` passes it
/// to `InMemoryStorage::increment_and_get_value`.
pub const API_CALLS: &str = "api_calls";
/// Metric name for a per-client-IP call counter.
///
/// NOTE(review): not referenced by the code in this file; presumably used by
/// IP-based limiting elsewhere — confirm.
pub const API_CALLS_BY_IP: &str = "api_calls_by_ip";

/// Optional caps on the number of API calls allowed per time period.
///
/// The rate-limit middleware checks each populated field independently.
/// `None` disables the check for that period. A limit of `n` admits a call
/// while the period's running total is at most `n`. How period windows are
/// bounded and reset is decided by `InMemoryStorage` / `LimitPeriod`, not here.
#[derive(Clone, Debug, Default, Deserialize, Serialize)]
pub struct RateLimiting {
    /// Maximum calls per hour, or `None` for no hourly cap.
    pub hourly: Option<u64>,
    /// Maximum calls per day, or `None` for no daily cap.
    pub daily: Option<u64>,
    /// Maximum calls per month, or `None` for no monthly cap.
    pub monthly: Option<u64>,
}
pub struct RateLimitMiddleware;

impl<S, B> Transform<S, ServiceRequest> for RateLimitMiddleware
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
    S::Future: 'static,
{
    type Response = ServiceResponse<B>;
    type Error = Error;
    type InitError = ();
    type Transform = RateLimitMiddlewareService<S>;
    type Future = Ready<Result<Self::Transform, Self::InitError>>;

    fn new_transform(&self, service: S) -> Self::Future {
        ready(Ok(RateLimitMiddlewareService {
            service: service.into(),
        }))
    }
}

/// Per-service middleware built by [`RateLimitMiddleware`].
///
/// The wrapped inner service sits behind an `Rc` so each request future can
/// hold its own handle to it.
pub struct RateLimitMiddlewareService<S> {
    service: Rc<S>,
}

/// A pinned, heap-allocated, non-`Send` future that resolves to `T`.
///
/// A `Box<dyn Trait>` with no explicit bound defaults the trait-object
/// lifetime to `'static`, so this is exactly
/// `Pin<Box<dyn Future<Output = T> + 'static>>`.
type LocalBoxFuture<T> = Pin<Box<dyn Future<Output = T>>>;

impl<S, B> Service<ServiceRequest> for RateLimitMiddlewareService<S>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error> + 'static,
    S::Future: 'static,
{
    type Response = ServiceResponse<B>;
    type Error = Error;
    type Future = LocalBoxFuture<Result<Self::Response, Self::Error>>;

    forward_ready!(service);

    fn call(&self, req: ServiceRequest) -> Self::Future {
        let service = Rc::clone(&self.service);

        Box::pin(async move {
            let rate_limit_config = req.app_data::<Option<RateLimiting>>().cloned();
            if let Some(Some(rate_limit)) = rate_limit_config {
                let storage = req
                    .app_data::<Arc<Mutex<InMemoryStorage>>>()
                    .unwrap()
                    .clone();

                if let Some(hourly) = rate_limit.hourly {
                    check_limit(storage.clone(), &LimitPeriod::Hour, hourly).await?;
                }
                if let Some(daily) = rate_limit.daily {
                    check_limit(storage.clone(), &LimitPeriod::Day, daily).await?;
                }
                if let Some(monthly) = rate_limit.monthly {
                    check_limit(storage.clone(), &LimitPeriod::Month, monthly).await?;
                }
            }

            service.call(req).await
        })
    }
}

async fn check_limit(
    storage: Arc<Mutex<InMemoryStorage>>,
    period: &LimitPeriod,
    limit: u64,
) -> Result<(), Error> {
    let current_calls = storage
        .lock()
        .await
        .increment_and_get_value(period, "default", API_CALLS, 1.0)
        .await;

    if current_calls > limit as f64 {
        Err(actix_web::error::ErrorTooManyRequests(
            "API call limit exceeded",
        ))
    } else {
        Ok(())
    }
}