use crate::core::budget::{
BudgetStatus, Currency, ModelLimitConfig, ProviderLimitConfig, ResetPeriod, UnifiedBudgetLimits,
};
use crate::server::routes::ApiResponse;
use actix_web::{HttpResponse, Result as ActixResult, web};
use serde::{Deserialize, Serialize};
use std::sync::Arc;
use tracing::{error, info, warn};
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SetProviderBudgetRequest {
pub provider: String,
pub max_budget: f64,
#[serde(default = "default_reset_period")]
pub reset_period: ResetPeriod,
#[serde(default = "default_soft_limit_percentage")]
pub soft_limit_percentage: f64,
#[serde(default)]
pub currency: Currency,
#[serde(default = "default_enabled")]
pub enabled: bool,
}
fn default_reset_period() -> ResetPeriod {
ResetPeriod::Monthly
}
fn default_soft_limit_percentage() -> f64 {
0.8
}
fn default_enabled() -> bool {
true
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SetModelBudgetRequest {
pub model: String,
pub max_budget: f64,
#[serde(default = "default_reset_period")]
pub reset_period: ResetPeriod,
#[serde(default = "default_soft_limit_percentage")]
pub soft_limit_percentage: f64,
#[serde(default)]
pub currency: Currency,
#[serde(default = "default_enabled")]
pub enabled: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderBudgetResponse {
pub provider: String,
pub max_budget: f64,
pub current_spend: f64,
pub remaining: f64,
pub usage_percentage: f64,
pub status: BudgetStatus,
pub reset_period: ResetPeriod,
pub currency: Currency,
pub enabled: bool,
pub request_count: u64,
pub last_reset_at: Option<chrono::DateTime<chrono::Utc>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelBudgetResponse {
pub model: String,
pub max_budget: f64,
pub current_spend: f64,
pub remaining: f64,
pub usage_percentage: f64,
pub status: BudgetStatus,
pub reset_period: ResetPeriod,
pub currency: Currency,
pub enabled: bool,
pub request_count: u64,
pub last_reset_at: Option<chrono::DateTime<chrono::Utc>>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ListProviderBudgetsResponse {
pub providers: Vec<ProviderBudgetResponse>,
pub total: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ListModelBudgetsResponse {
pub models: Vec<ModelBudgetResponse>,
pub total: usize,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeleteBudgetResponse {
pub success: bool,
pub message: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BudgetSummaryResponse {
pub total_provider_budgets: usize,
pub total_model_budgets: usize,
pub exceeded_providers: Vec<String>,
pub warning_providers: Vec<String>,
pub total_provider_allocated: f64,
pub total_provider_spent: f64,
pub total_model_allocated: f64,
pub total_model_spent: f64,
}
pub async fn set_provider_budget(
budget_limits: web::Data<Arc<UnifiedBudgetLimits>>,
request: web::Json<SetProviderBudgetRequest>,
) -> ActixResult<HttpResponse> {
if request.provider.trim().is_empty() {
return Ok(HttpResponse::BadRequest().json(ApiResponse::<()>::error(
"Provider name cannot be empty".to_string(),
)));
}
if request.max_budget <= 0.0 {
return Ok(HttpResponse::BadRequest().json(ApiResponse::<()>::error(
"max_budget must be greater than 0".to_string(),
)));
}
if request.soft_limit_percentage < 0.0 || request.soft_limit_percentage > 1.0 {
return Ok(HttpResponse::BadRequest().json(ApiResponse::<()>::error(
"soft_limit_percentage must be between 0.0 and 1.0".to_string(),
)));
}
let config = ProviderLimitConfig {
max_budget: request.max_budget,
reset_period: request.reset_period,
soft_limit_percentage: request.soft_limit_percentage,
currency: request.currency,
enabled: request.enabled,
};
budget_limits
.providers
.set_provider_limit(&request.provider, config);
info!(
"Set provider budget for '{}': ${:.2} ({})",
request.provider, request.max_budget, request.reset_period
);
match budget_limits
.providers
.get_provider_usage(&request.provider)
{
Some(usage) => {
let response = ProviderBudgetResponse {
provider: usage.provider_name,
max_budget: usage.max_budget,
current_spend: usage.current_spend,
remaining: usage.remaining,
usage_percentage: usage.usage_percentage,
status: usage.status,
reset_period: usage.reset_period,
currency: request.currency,
enabled: request.enabled,
request_count: usage.request_count,
last_reset_at: usage.last_reset_at,
};
Ok(HttpResponse::Created().json(ApiResponse::success(response)))
}
None => {
error!("Failed to retrieve provider budget after creation");
Ok(
HttpResponse::InternalServerError().json(ApiResponse::<()>::error(
"Failed to retrieve budget".to_string(),
)),
)
}
}
}
pub async fn list_provider_budgets(
budget_limits: web::Data<Arc<UnifiedBudgetLimits>>,
) -> ActixResult<HttpResponse> {
let usage_list = budget_limits.providers.list_provider_usage();
let providers: Vec<ProviderBudgetResponse> = usage_list
.into_iter()
.map(|usage| {
let budgets = budget_limits.providers.list_provider_budgets();
let budget = budgets
.iter()
.find(|b| b.provider_name == usage.provider_name);
ProviderBudgetResponse {
provider: usage.provider_name,
max_budget: usage.max_budget,
current_spend: usage.current_spend,
remaining: usage.remaining,
usage_percentage: usage.usage_percentage,
status: usage.status,
reset_period: usage.reset_period,
currency: budget.map(|b| b.currency).unwrap_or_default(),
enabled: budget.map(|b| b.enabled).unwrap_or(true),
request_count: usage.request_count,
last_reset_at: usage.last_reset_at,
}
})
.collect();
let response = ListProviderBudgetsResponse {
total: providers.len(),
providers,
};
Ok(HttpResponse::Ok().json(ApiResponse::success(response)))
}
pub async fn get_provider_budget(
budget_limits: web::Data<Arc<UnifiedBudgetLimits>>,
path: web::Path<String>,
) -> ActixResult<HttpResponse> {
let provider_name = path.into_inner();
match budget_limits.providers.get_provider_usage(&provider_name) {
Some(usage) => {
let budgets = budget_limits.providers.list_provider_budgets();
let budget = budgets.iter().find(|b| b.provider_name == provider_name);
let response = ProviderBudgetResponse {
provider: usage.provider_name,
max_budget: usage.max_budget,
current_spend: usage.current_spend,
remaining: usage.remaining,
usage_percentage: usage.usage_percentage,
status: usage.status,
reset_period: usage.reset_period,
currency: budget.map(|b| b.currency).unwrap_or_default(),
enabled: budget.map(|b| b.enabled).unwrap_or(true),
request_count: usage.request_count,
last_reset_at: usage.last_reset_at,
};
Ok(HttpResponse::Ok().json(ApiResponse::success(response)))
}
None => {
warn!("Provider budget not found: {}", provider_name);
Ok(
HttpResponse::NotFound().json(ApiResponse::<()>::error(format!(
"Provider budget not found: {}",
provider_name
))),
)
}
}
}
pub async fn delete_provider_budget(
budget_limits: web::Data<Arc<UnifiedBudgetLimits>>,
path: web::Path<String>,
) -> ActixResult<HttpResponse> {
let provider_name = path.into_inner();
if budget_limits
.providers
.remove_provider_limit(&provider_name)
{
info!("Removed provider budget for '{}'", provider_name);
let response = DeleteBudgetResponse {
success: true,
message: format!("Provider budget '{}' removed successfully", provider_name),
};
Ok(HttpResponse::Ok().json(ApiResponse::success(response)))
} else {
warn!("Provider budget not found for deletion: {}", provider_name);
Ok(
HttpResponse::NotFound().json(ApiResponse::<()>::error(format!(
"Provider budget not found: {}",
provider_name
))),
)
}
}
pub async fn reset_provider_budget(
budget_limits: web::Data<Arc<UnifiedBudgetLimits>>,
path: web::Path<String>,
) -> ActixResult<HttpResponse> {
let provider_name = path.into_inner();
if budget_limits
.providers
.reset_provider_budget(&provider_name)
{
info!("Reset provider budget for '{}'", provider_name);
match budget_limits.providers.get_provider_usage(&provider_name) {
Some(usage) => {
let budgets = budget_limits.providers.list_provider_budgets();
let budget = budgets.iter().find(|b| b.provider_name == provider_name);
let response = ProviderBudgetResponse {
provider: usage.provider_name,
max_budget: usage.max_budget,
current_spend: usage.current_spend,
remaining: usage.remaining,
usage_percentage: usage.usage_percentage,
status: usage.status,
reset_period: usage.reset_period,
currency: budget.map(|b| b.currency).unwrap_or_default(),
enabled: budget.map(|b| b.enabled).unwrap_or(true),
request_count: usage.request_count,
last_reset_at: usage.last_reset_at,
};
Ok(HttpResponse::Ok().json(ApiResponse::success(response)))
}
None => Ok(
HttpResponse::InternalServerError().json(ApiResponse::<()>::error(
"Failed to retrieve budget after reset".to_string(),
)),
),
}
} else {
warn!("Provider budget not found for reset: {}", provider_name);
Ok(
HttpResponse::NotFound().json(ApiResponse::<()>::error(format!(
"Provider budget not found: {}",
provider_name
))),
)
}
}
pub async fn set_model_budget(
budget_limits: web::Data<Arc<UnifiedBudgetLimits>>,
request: web::Json<SetModelBudgetRequest>,
) -> ActixResult<HttpResponse> {
if request.model.trim().is_empty() {
return Ok(HttpResponse::BadRequest().json(ApiResponse::<()>::error(
"Model name cannot be empty".to_string(),
)));
}
if request.max_budget <= 0.0 {
return Ok(HttpResponse::BadRequest().json(ApiResponse::<()>::error(
"max_budget must be greater than 0".to_string(),
)));
}
if request.soft_limit_percentage < 0.0 || request.soft_limit_percentage > 1.0 {
return Ok(HttpResponse::BadRequest().json(ApiResponse::<()>::error(
"soft_limit_percentage must be between 0.0 and 1.0".to_string(),
)));
}
let config = ModelLimitConfig {
max_budget: request.max_budget,
reset_period: request.reset_period,
soft_limit_percentage: request.soft_limit_percentage,
currency: request.currency,
enabled: request.enabled,
};
budget_limits.models.set_model_limit(&request.model, config);
info!(
"Set model budget for '{}': ${:.2} ({})",
request.model, request.max_budget, request.reset_period
);
match budget_limits.models.get_model_usage(&request.model) {
Some(usage) => {
let response = ModelBudgetResponse {
model: usage.model_name,
max_budget: usage.max_budget,
current_spend: usage.current_spend,
remaining: usage.remaining,
usage_percentage: usage.usage_percentage,
status: usage.status,
reset_period: usage.reset_period,
currency: request.currency,
enabled: request.enabled,
request_count: usage.request_count,
last_reset_at: usage.last_reset_at,
};
Ok(HttpResponse::Created().json(ApiResponse::success(response)))
}
None => {
error!("Failed to retrieve model budget after creation");
Ok(
HttpResponse::InternalServerError().json(ApiResponse::<()>::error(
"Failed to retrieve budget".to_string(),
)),
)
}
}
}
pub async fn list_model_budgets(
budget_limits: web::Data<Arc<UnifiedBudgetLimits>>,
) -> ActixResult<HttpResponse> {
let usage_list = budget_limits.models.list_model_usage();
let models: Vec<ModelBudgetResponse> = usage_list
.into_iter()
.map(|usage| {
let budgets = budget_limits.models.list_model_budgets();
let budget = budgets.iter().find(|b| b.model_name == usage.model_name);
ModelBudgetResponse {
model: usage.model_name,
max_budget: usage.max_budget,
current_spend: usage.current_spend,
remaining: usage.remaining,
usage_percentage: usage.usage_percentage,
status: usage.status,
reset_period: usage.reset_period,
currency: budget.map(|b| b.currency).unwrap_or_default(),
enabled: budget.map(|b| b.enabled).unwrap_or(true),
request_count: usage.request_count,
last_reset_at: usage.last_reset_at,
}
})
.collect();
let response = ListModelBudgetsResponse {
total: models.len(),
models,
};
Ok(HttpResponse::Ok().json(ApiResponse::success(response)))
}
pub async fn get_model_budget(
budget_limits: web::Data<Arc<UnifiedBudgetLimits>>,
path: web::Path<String>,
) -> ActixResult<HttpResponse> {
let model_name = path.into_inner();
match budget_limits.models.get_model_usage(&model_name) {
Some(usage) => {
let budgets = budget_limits.models.list_model_budgets();
let budget = budgets.iter().find(|b| b.model_name == model_name);
let response = ModelBudgetResponse {
model: usage.model_name,
max_budget: usage.max_budget,
current_spend: usage.current_spend,
remaining: usage.remaining,
usage_percentage: usage.usage_percentage,
status: usage.status,
reset_period: usage.reset_period,
currency: budget.map(|b| b.currency).unwrap_or_default(),
enabled: budget.map(|b| b.enabled).unwrap_or(true),
request_count: usage.request_count,
last_reset_at: usage.last_reset_at,
};
Ok(HttpResponse::Ok().json(ApiResponse::success(response)))
}
None => {
warn!("Model budget not found: {}", model_name);
Ok(
HttpResponse::NotFound().json(ApiResponse::<()>::error(format!(
"Model budget not found: {}",
model_name
))),
)
}
}
}
pub async fn delete_model_budget(
budget_limits: web::Data<Arc<UnifiedBudgetLimits>>,
path: web::Path<String>,
) -> ActixResult<HttpResponse> {
let model_name = path.into_inner();
if budget_limits.models.remove_model_limit(&model_name) {
info!("Removed model budget for '{}'", model_name);
let response = DeleteBudgetResponse {
success: true,
message: format!("Model budget '{}' removed successfully", model_name),
};
Ok(HttpResponse::Ok().json(ApiResponse::success(response)))
} else {
warn!("Model budget not found for deletion: {}", model_name);
Ok(
HttpResponse::NotFound().json(ApiResponse::<()>::error(format!(
"Model budget not found: {}",
model_name
))),
)
}
}
pub async fn reset_model_budget(
budget_limits: web::Data<Arc<UnifiedBudgetLimits>>,
path: web::Path<String>,
) -> ActixResult<HttpResponse> {
let model_name = path.into_inner();
if budget_limits.models.reset_model_budget(&model_name) {
info!("Reset model budget for '{}'", model_name);
match budget_limits.models.get_model_usage(&model_name) {
Some(usage) => {
let budgets = budget_limits.models.list_model_budgets();
let budget = budgets.iter().find(|b| b.model_name == model_name);
let response = ModelBudgetResponse {
model: usage.model_name,
max_budget: usage.max_budget,
current_spend: usage.current_spend,
remaining: usage.remaining,
usage_percentage: usage.usage_percentage,
status: usage.status,
reset_period: usage.reset_period,
currency: budget.map(|b| b.currency).unwrap_or_default(),
enabled: budget.map(|b| b.enabled).unwrap_or(true),
request_count: usage.request_count,
last_reset_at: usage.last_reset_at,
};
Ok(HttpResponse::Ok().json(ApiResponse::success(response)))
}
None => Ok(
HttpResponse::InternalServerError().json(ApiResponse::<()>::error(
"Failed to retrieve budget after reset".to_string(),
)),
),
}
} else {
warn!("Model budget not found for reset: {}", model_name);
Ok(
HttpResponse::NotFound().json(ApiResponse::<()>::error(format!(
"Model budget not found: {}",
model_name
))),
)
}
}
pub async fn get_budget_summary(
budget_limits: web::Data<Arc<UnifiedBudgetLimits>>,
) -> ActixResult<HttpResponse> {
let provider_usage = budget_limits.providers.list_provider_usage();
let model_usage = budget_limits.models.list_model_usage();
let exceeded_providers: Vec<String> = provider_usage
.iter()
.filter(|u| u.status == BudgetStatus::Exceeded)
.map(|u| u.provider_name.clone())
.collect();
let warning_providers: Vec<String> = provider_usage
.iter()
.filter(|u| u.status == BudgetStatus::Warning)
.map(|u| u.provider_name.clone())
.collect();
let total_provider_allocated: f64 = provider_usage.iter().map(|u| u.max_budget).sum();
let total_provider_spent: f64 = provider_usage.iter().map(|u| u.current_spend).sum();
let total_model_allocated: f64 = model_usage.iter().map(|u| u.max_budget).sum();
let total_model_spent: f64 = model_usage.iter().map(|u| u.current_spend).sum();
let response = BudgetSummaryResponse {
total_provider_budgets: provider_usage.len(),
total_model_budgets: model_usage.len(),
exceeded_providers,
warning_providers,
total_provider_allocated,
total_provider_spent,
total_model_allocated,
total_model_spent,
};
Ok(HttpResponse::Ok().json(ApiResponse::success(response)))
}
pub fn configure_budget_routes(cfg: &mut web::ServiceConfig) {
cfg.service(
web::scope("/v1/budget")
.route("/providers", web::post().to(set_provider_budget))
.route("/providers", web::get().to(list_provider_budgets))
.route("/providers/{name}", web::get().to(get_provider_budget))
.route(
"/providers/{name}",
web::delete().to(delete_provider_budget),
)
.route(
"/providers/{name}/reset",
web::post().to(reset_provider_budget),
)
.route("/models", web::post().to(set_model_budget))
.route("/models", web::get().to(list_model_budgets))
.route("/models/{name}", web::get().to(get_model_budget))
.route("/models/{name}", web::delete().to(delete_model_budget))
.route("/models/{name}/reset", web::post().to(reset_model_budget))
.route("/summary", web::get().to(get_budget_summary)),
);
}
#[cfg(test)]
mod tests {
use super::*;
use actix_web::{App, test};
#[actix_web::test]
async fn test_set_provider_budget() {
let budget_limits = Arc::new(UnifiedBudgetLimits::new());
let app = test::init_service(
App::new()
.app_data(web::Data::new(budget_limits))
.configure(configure_budget_routes),
)
.await;
let request = SetProviderBudgetRequest {
provider: "openai".to_string(),
max_budget: 1000.0,
reset_period: ResetPeriod::Monthly,
soft_limit_percentage: 0.8,
currency: Currency::USD,
enabled: true,
};
let req = test::TestRequest::post()
.uri("/v1/budget/providers")
.set_json(&request)
.to_request();
let resp: actix_web::dev::ServiceResponse = test::call_service(&app, req).await;
assert!(resp.status().is_success());
}
#[actix_web::test]
async fn test_set_provider_budget_validation() {
let budget_limits = Arc::new(UnifiedBudgetLimits::new());
let app = test::init_service(
App::new()
.app_data(web::Data::new(budget_limits))
.configure(configure_budget_routes),
)
.await;
let request = SetProviderBudgetRequest {
provider: "".to_string(),
max_budget: 1000.0,
reset_period: ResetPeriod::Monthly,
soft_limit_percentage: 0.8,
currency: Currency::USD,
enabled: true,
};
let req = test::TestRequest::post()
.uri("/v1/budget/providers")
.set_json(&request)
.to_request();
let resp: actix_web::dev::ServiceResponse = test::call_service(&app, req).await;
assert_eq!(resp.status(), 400);
let request = SetProviderBudgetRequest {
provider: "openai".to_string(),
max_budget: -100.0,
reset_period: ResetPeriod::Monthly,
soft_limit_percentage: 0.8,
currency: Currency::USD,
enabled: true,
};
let req = test::TestRequest::post()
.uri("/v1/budget/providers")
.set_json(&request)
.to_request();
let resp: actix_web::dev::ServiceResponse = test::call_service(&app, req).await;
assert_eq!(resp.status(), 400);
}
#[actix_web::test]
async fn test_list_provider_budgets() {
let budget_limits = Arc::new(UnifiedBudgetLimits::new());
budget_limits.providers.set_provider_limit(
"openai",
ProviderLimitConfig::new(1000.0, ResetPeriod::Monthly),
);
budget_limits.providers.set_provider_limit(
"anthropic",
ProviderLimitConfig::new(500.0, ResetPeriod::Monthly),
);
let app = test::init_service(
App::new()
.app_data(web::Data::new(budget_limits))
.configure(configure_budget_routes),
)
.await;
let req = test::TestRequest::get()
.uri("/v1/budget/providers")
.to_request();
let resp: actix_web::dev::ServiceResponse = test::call_service(&app, req).await;
assert!(resp.status().is_success());
}
#[actix_web::test]
async fn test_get_provider_budget() {
let budget_limits = Arc::new(UnifiedBudgetLimits::new());
budget_limits.providers.set_provider_limit(
"openai",
ProviderLimitConfig::new(1000.0, ResetPeriod::Monthly),
);
let app = test::init_service(
App::new()
.app_data(web::Data::new(budget_limits))
.configure(configure_budget_routes),
)
.await;
let req = test::TestRequest::get()
.uri("/v1/budget/providers/openai")
.to_request();
let resp: actix_web::dev::ServiceResponse = test::call_service(&app, req).await;
assert!(resp.status().is_success());
}
#[actix_web::test]
async fn test_get_provider_budget_not_found() {
let budget_limits = Arc::new(UnifiedBudgetLimits::new());
let app = test::init_service(
App::new()
.app_data(web::Data::new(budget_limits))
.configure(configure_budget_routes),
)
.await;
let req = test::TestRequest::get()
.uri("/v1/budget/providers/nonexistent")
.to_request();
let resp: actix_web::dev::ServiceResponse = test::call_service(&app, req).await;
assert_eq!(resp.status(), 404);
}
#[actix_web::test]
async fn test_delete_provider_budget() {
let budget_limits = Arc::new(UnifiedBudgetLimits::new());
budget_limits.providers.set_provider_limit(
"openai",
ProviderLimitConfig::new(1000.0, ResetPeriod::Monthly),
);
let app = test::init_service(
App::new()
.app_data(web::Data::new(budget_limits))
.configure(configure_budget_routes),
)
.await;
let req = test::TestRequest::delete()
.uri("/v1/budget/providers/openai")
.to_request();
let resp: actix_web::dev::ServiceResponse = test::call_service(&app, req).await;
assert!(resp.status().is_success());
}
#[actix_web::test]
async fn test_set_model_budget() {
let budget_limits = Arc::new(UnifiedBudgetLimits::new());
let app = test::init_service(
App::new()
.app_data(web::Data::new(budget_limits))
.configure(configure_budget_routes),
)
.await;
let request = SetModelBudgetRequest {
model: "gpt-4".to_string(),
max_budget: 500.0,
reset_period: ResetPeriod::Monthly,
soft_limit_percentage: 0.8,
currency: Currency::USD,
enabled: true,
};
let req = test::TestRequest::post()
.uri("/v1/budget/models")
.set_json(&request)
.to_request();
let resp: actix_web::dev::ServiceResponse = test::call_service(&app, req).await;
assert!(resp.status().is_success());
}
#[actix_web::test]
async fn test_list_model_budgets() {
let budget_limits = Arc::new(UnifiedBudgetLimits::new());
budget_limits
.models
.set_model_limit("gpt-4", ModelLimitConfig::new(500.0, ResetPeriod::Monthly));
let app = test::init_service(
App::new()
.app_data(web::Data::new(budget_limits))
.configure(configure_budget_routes),
)
.await;
let req = test::TestRequest::get()
.uri("/v1/budget/models")
.to_request();
let resp: actix_web::dev::ServiceResponse = test::call_service(&app, req).await;
assert!(resp.status().is_success());
}
#[actix_web::test]
async fn test_get_budget_summary() {
let budget_limits = Arc::new(UnifiedBudgetLimits::new());
budget_limits.providers.set_provider_limit(
"openai",
ProviderLimitConfig::new(1000.0, ResetPeriod::Monthly),
);
budget_limits
.models
.set_model_limit("gpt-4", ModelLimitConfig::new(500.0, ResetPeriod::Monthly));
let app = test::init_service(
App::new()
.app_data(web::Data::new(budget_limits))
.configure(configure_budget_routes),
)
.await;
let req = test::TestRequest::get()
.uri("/v1/budget/summary")
.to_request();
let resp: actix_web::dev::ServiceResponse = test::call_service(&app, req).await;
assert!(resp.status().is_success());
}
}