use super::manager::BudgetManager;
use super::types::BudgetScope;
use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform, forward_ready};
use actix_web::web;
use futures::future::{Ready, ok};
use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use tracing::{debug, warn};
pub struct BudgetMiddleware {
manager: Arc<BudgetManager>,
scope_extractor: ScopeExtractor,
cost_estimator: CostEstimator,
}
type ScopeExtractor = Arc<dyn Fn(&ServiceRequest) -> Option<BudgetScope> + Send + Sync>;
type CostEstimator = Arc<dyn Fn(&ServiceRequest) -> f64 + Send + Sync>;
impl BudgetMiddleware {
pub fn new(manager: Arc<BudgetManager>) -> Self {
Self {
manager,
scope_extractor: Arc::new(default_scope_extractor),
cost_estimator: Arc::new(default_cost_estimator),
}
}
pub fn with_scope_extractor<F>(mut self, extractor: F) -> Self
where
F: Fn(&ServiceRequest) -> Option<BudgetScope> + Send + Sync + 'static,
{
self.scope_extractor = Arc::new(extractor);
self
}
pub fn with_cost_estimator<F>(mut self, estimator: F) -> Self
where
F: Fn(&ServiceRequest) -> f64 + Send + Sync + 'static,
{
self.cost_estimator = Arc::new(estimator);
self
}
}
impl Clone for BudgetMiddleware {
fn clone(&self) -> Self {
Self {
manager: Arc::clone(&self.manager),
scope_extractor: Arc::clone(&self.scope_extractor),
cost_estimator: Arc::clone(&self.cost_estimator),
}
}
}
fn default_scope_extractor(req: &ServiceRequest) -> Option<BudgetScope> {
if let Some(user_id) = req.headers().get("X-User-ID")
&& let Ok(user_id_str) = user_id.to_str()
{
return Some(BudgetScope::User(user_id_str.to_string()));
}
if let Some(auth_header) = req.headers().get("Authorization")
&& let Ok(auth_str) = auth_header.to_str()
&& let Some(api_key) = auth_str.strip_prefix("Bearer ")
{
let key_id = if api_key.len() > 16 {
&api_key[..16]
} else {
api_key
};
return Some(BudgetScope::ApiKey(key_id.to_string()));
}
if let Some(team_id) = req.headers().get("X-Team-ID")
&& let Ok(team_id_str) = team_id.to_str()
{
return Some(BudgetScope::Team(team_id_str.to_string()));
}
Some(BudgetScope::Global)
}
fn default_cost_estimator(_req: &ServiceRequest) -> f64 {
0.001
}
impl<S, B> Transform<S, ServiceRequest> for BudgetMiddleware
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
S::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<B>;
type Error = actix_web::Error;
type InitError = ();
type Transform = BudgetMiddlewareService<S>;
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
ok(BudgetMiddlewareService {
service,
manager: Arc::clone(&self.manager),
scope_extractor: Arc::clone(&self.scope_extractor),
cost_estimator: Arc::clone(&self.cost_estimator),
})
}
}
pub struct BudgetMiddlewareService<S> {
service: S,
manager: Arc<BudgetManager>,
scope_extractor: ScopeExtractor,
cost_estimator: CostEstimator,
}
impl<S, B> Service<ServiceRequest> for BudgetMiddlewareService<S>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
S::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<B>;
type Error = actix_web::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
forward_ready!(service);
fn call(&self, req: ServiceRequest) -> Self::Future {
let _manager = Arc::clone(&self.manager);
let _scope_extractor = Arc::clone(&self.scope_extractor);
let _cost_estimator = Arc::clone(&self.cost_estimator);
let fut = self.service.call(req);
Box::pin(async move {
let res = fut.await?;
Ok(res)
})
}
}
pub struct BudgetCheckMiddleware {
manager: Arc<BudgetManager>,
}
impl BudgetCheckMiddleware {
pub fn new(manager: Arc<BudgetManager>) -> Self {
Self { manager }
}
}
impl Clone for BudgetCheckMiddleware {
fn clone(&self) -> Self {
Self {
manager: Arc::clone(&self.manager),
}
}
}
impl<S, B> Transform<S, ServiceRequest> for BudgetCheckMiddleware
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
S::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<B>;
type Error = actix_web::Error;
type InitError = ();
type Transform = BudgetCheckMiddlewareService<S>;
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
ok(BudgetCheckMiddlewareService {
service,
manager: Arc::clone(&self.manager),
})
}
}
pub struct BudgetCheckMiddlewareService<S> {
service: S,
manager: Arc<BudgetManager>,
}
impl<S, B> Service<ServiceRequest> for BudgetCheckMiddlewareService<S>
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = actix_web::Error> + 'static,
S::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<B>;
type Error = actix_web::Error;
type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;
forward_ready!(service);
fn call(&self, req: ServiceRequest) -> Self::Future {
let manager = Arc::clone(&self.manager);
let scope = extract_scope_from_request(&req);
let path = req.path().to_string();
let method = req.method().to_string();
let fut = self.service.call(req);
Box::pin(async move {
if !manager.is_enabled().await {
let res = fut.await?;
return Ok(res);
}
if let Some(scope) = scope {
let estimated_cost = 0.01; let check_result = manager.check_spend(&scope, estimated_cost).await;
if !check_result.allowed {
warn!(
"Budget exceeded for {} {} (scope: {}): ${:.2} / ${:.2}",
method, path, scope, check_result.current_spend, check_result.max_budget
);
return Err(actix_web::error::ErrorTooManyRequests(
serde_json::json!({
"error": {
"type": "budget_exceeded",
"message": "Budget limit exceeded",
"code": 429,
"details": {
"scope": scope.to_string(),
"current_spend": check_result.current_spend,
"max_budget": check_result.max_budget,
"remaining": check_result.remaining,
"usage_percentage": check_result.usage_percentage
}
}
})
.to_string(),
));
}
debug!(
"Budget check passed for {} {} (scope: {}): ${:.2} remaining",
method, path, scope, check_result.remaining
);
}
let res = fut.await?;
Ok(res)
})
}
}
fn extract_scope_from_request(req: &ServiceRequest) -> Option<BudgetScope> {
if let Some(scope_header) = req.headers().get("X-Budget-Scope")
&& let Ok(scope_str) = scope_header.to_str()
&& let Some(scope) = BudgetScope::from_key(scope_str)
{
return Some(scope);
}
if let Some(user_id) = req.headers().get("X-User-ID")
&& let Ok(user_id_str) = user_id.to_str()
&& !user_id_str.is_empty()
{
return Some(BudgetScope::User(user_id_str.to_string()));
}
if let Some(team_id) = req.headers().get("X-Team-ID")
&& let Ok(team_id_str) = team_id.to_str()
&& !team_id_str.is_empty()
{
return Some(BudgetScope::Team(team_id_str.to_string()));
}
if let Some(auth_header) = req.headers().get("Authorization")
&& let Ok(auth_str) = auth_header.to_str()
&& let Some(api_key) = auth_str.strip_prefix("Bearer ")
&& !api_key.is_empty()
{
let key_id = if api_key.len() > 16 {
format!("{}...", &api_key[..16])
} else {
api_key.to_string()
};
return Some(BudgetScope::ApiKey(key_id));
}
Some(BudgetScope::Global)
}
#[derive(Clone)]
pub struct BudgetRecorder {
manager: Arc<BudgetManager>,
}
impl BudgetRecorder {
pub fn new(manager: Arc<BudgetManager>) -> Self {
Self { manager }
}
pub async fn record_spend(&self, scope: &BudgetScope, amount: f64) {
if let Some(result) = self.manager.record_spend(scope, amount).await {
debug!(
"Recorded spend ${:.4} for {}: ${:.2} / ${:.2}",
amount, scope, result.current_spend, result.max_budget
);
if result.should_alert_soft_limit {
warn!(
"Budget soft limit reached for {}: ${:.2} / ${:.2} ({:.1}%)",
scope,
result.current_spend,
result.max_budget,
(result.current_spend / result.max_budget) * 100.0
);
}
if result.should_alert_exceeded {
warn!(
"Budget exceeded for {}: ${:.2} / ${:.2} ({:.1}%)",
scope,
result.current_spend,
result.max_budget,
(result.current_spend / result.max_budget) * 100.0
);
}
}
}
pub async fn record_request_spend(
&self,
user_id: Option<&str>,
team_id: Option<&str>,
api_key: Option<&str>,
model: Option<&str>,
provider: Option<&str>,
amount: f64,
) {
if let Some(user) = user_id {
let scope = BudgetScope::User(user.to_string());
self.record_spend(&scope, amount).await;
}
if let Some(team) = team_id {
let scope = BudgetScope::Team(team.to_string());
self.record_spend(&scope, amount).await;
}
if let Some(key) = api_key {
let key_id = if key.len() > 16 {
format!("{}...", &key[..16])
} else {
key.to_string()
};
let scope = BudgetScope::ApiKey(key_id);
self.record_spend(&scope, amount).await;
}
if let Some(model_name) = model {
let scope = BudgetScope::Model(model_name.to_string());
self.record_spend(&scope, amount).await;
}
if let Some(provider_name) = provider {
let scope = BudgetScope::Provider(provider_name.to_string());
self.record_spend(&scope, amount).await;
}
self.record_spend(&BudgetScope::Global, amount).await;
}
}
pub trait BudgetRecorderExt {
fn budget_recorder(&self) -> Option<web::Data<BudgetRecorder>>;
}
#[cfg(test)]
mod tests {
use super::*;
use crate::core::budget::types::BudgetConfig;
#[tokio::test]
async fn test_budget_recorder() {
let manager = Arc::new(BudgetManager::new());
manager
.create_budget(BudgetScope::Global, BudgetConfig::new("Global", 100.0))
.await
.unwrap();
let recorder = BudgetRecorder::new(Arc::clone(&manager));
recorder.record_spend(&BudgetScope::Global, 10.0).await;
assert_eq!(manager.get_current_spend(&BudgetScope::Global), 10.0);
}
#[tokio::test]
async fn test_budget_recorder_multiple_scopes() {
let manager = Arc::new(BudgetManager::new());
manager
.create_budget(BudgetScope::Global, BudgetConfig::new("Global", 1000.0))
.await
.unwrap();
manager
.create_budget(
BudgetScope::User("user-1".to_string()),
BudgetConfig::new("User 1", 100.0),
)
.await
.unwrap();
manager
.create_budget(
BudgetScope::Model("gpt-4".to_string()),
BudgetConfig::new("GPT-4", 500.0),
)
.await
.unwrap();
let recorder = BudgetRecorder::new(Arc::clone(&manager));
recorder
.record_request_spend(Some("user-1"), None, None, Some("gpt-4"), None, 5.0)
.await;
assert_eq!(manager.get_current_spend(&BudgetScope::Global), 5.0);
assert_eq!(
manager.get_current_spend(&BudgetScope::User("user-1".to_string())),
5.0
);
assert_eq!(
manager.get_current_spend(&BudgetScope::Model("gpt-4".to_string())),
5.0
);
}
#[test]
fn test_scope_from_key() {
assert_eq!(
BudgetScope::from_key("user:test-user"),
Some(BudgetScope::User("test-user".to_string()))
);
assert_eq!(
BudgetScope::from_key("team:team-1"),
Some(BudgetScope::Team("team-1".to_string()))
);
assert_eq!(BudgetScope::from_key("global"), Some(BudgetScope::Global));
assert_eq!(BudgetScope::from_key("invalid"), None);
}
#[test]
fn test_default_scope_extractor() {
}
#[tokio::test]
async fn test_budget_check_middleware_creation() {
let manager = Arc::new(BudgetManager::new());
let middleware = BudgetCheckMiddleware::new(Arc::clone(&manager));
let _cloned = middleware.clone();
}
#[tokio::test]
async fn test_budget_middleware_creation() {
let manager = Arc::new(BudgetManager::new());
let middleware = BudgetMiddleware::new(Arc::clone(&manager));
let _cloned = middleware.clone();
let _custom = BudgetMiddleware::new(Arc::clone(&manager))
.with_scope_extractor(|_req| Some(BudgetScope::Global))
.with_cost_estimator(|_req| 0.5);
}
}