use crate::core::models::RequestContext;
use crate::utils::error::{GatewayError, Result};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::RwLock;
use tracing::{debug, error, info, warn};
use uuid::Uuid;
/// Kinds of gateway events a webhook can subscribe to.
///
/// When an event is sent, it is matched by equality against each webhook's
/// `WebhookConfig::events` list. The variant's `Debug` name is the key used in
/// `WebhookStats::events_by_type`.
///
/// Variant order is significant for serde formats that encode variants by index, so new
/// variants should be appended rather than inserted.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[derive(Serialize, Deserialize)]
pub enum WebhookEventType {
    RequestStarted,
    RequestCompleted,
    RequestFailed,
    RateLimitExceeded,
    CostThresholdExceeded,
    ProviderHealthChanged,
    CacheEvent,
    BatchCompleted,
    BatchFailed,
    UserCreated,
    UserUpdated,
    ApiKeyCreated,
    ApiKeyRevoked,
    BudgetThresholdReached,
    SecurityAlert,
    /// Application-defined event; two custom events match only when their names are equal.
    Custom(String),
}
/// JSON body posted to a webhook endpoint.
///
/// Field order determines the serialized byte order and therefore the HMAC signature, so
/// fields must not be reordered.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct WebhookPayload {
    /// Which event produced this payload.
    pub event_type: WebhookEventType,
    /// UTC time at which the payload was created.
    pub timestamp: chrono::DateTime<chrono::Utc>,
    /// Context of the gateway request that triggered the event, if any.
    pub request_context: Option<RequestContext>,
    /// Event-specific data (see the `events` helper module).
    pub data: serde_json::Value,
    /// Free-form string metadata attached to the payload.
    pub metadata: HashMap<String, String>,
}
/// Configuration of a single registered webhook endpoint.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct WebhookConfig {
    /// Target URL; `WebhookManager::register_webhook` requires an `http://` or `https://` prefix.
    pub url: String,
    /// Events this webhook subscribes to; only these produce deliveries.
    pub events: Vec<WebhookEventType>,
    /// Extra HTTP headers added to every delivery request.
    pub headers: HashMap<String, String>,
    /// When set, each request carries an `X-Webhook-Signature: sha256=<hex>` header
    /// holding the HMAC-SHA256 of the JSON body keyed with this secret.
    pub secret: Option<String>,
    /// Per-request timeout, in seconds.
    pub timeout_seconds: u64,
    /// A delivery is marked `Failed` once its attempt count reaches this value.
    pub max_retries: u32,
    /// Delay, in seconds, before a failed delivery is retried.
    pub retry_delay_seconds: u64,
    /// Disabled webhooks receive no new deliveries.
    pub enabled: bool,
}
impl Default for WebhookConfig {
fn default() -> Self {
Self {
url: String::new(),
events: vec![],
headers: HashMap::new(),
secret: None,
timeout_seconds: 30,
max_retries: 3,
retry_delay_seconds: 5,
enabled: true,
}
}
}
/// Lifecycle state of a queued webhook delivery.
#[derive(Debug, Clone, PartialEq)]
#[derive(Serialize, Deserialize)]
pub enum WebhookDeliveryStatus {
    /// Queued and not yet attempted.
    Pending,
    /// Received a 2xx response; dropped from the queue by the processor.
    Delivered,
    /// Given up on permanently; left in the queue for inspection.
    Failed,
    /// A previous attempt failed; retried once `next_retry_at` has passed.
    Retrying,
}
/// One payload addressed to one webhook, tracked through its delivery attempts.
#[derive(Debug, Clone)]
#[derive(Serialize, Deserialize)]
pub struct WebhookDelivery {
    /// Random UUID identifying this delivery.
    pub id: String,
    /// Id under which the target webhook was registered.
    pub webhook_id: String,
    /// Payload to post.
    pub payload: WebhookPayload,
    /// Current lifecycle state.
    pub status: WebhookDeliveryStatus,
    /// HTTP status code of the most recent response received, if any.
    pub response_status: Option<u16>,
    /// Body text of the most recent response received, if any.
    pub response_body: Option<String>,
    /// Number of delivery attempts recorded so far.
    pub attempts: u32,
    /// When the delivery was queued.
    pub created_at: chrono::DateTime<chrono::Utc>,
    /// When the most recent attempt was recorded.
    pub last_attempt_at: Option<chrono::DateTime<chrono::Utc>>,
    /// Earliest time a `Retrying` delivery may be attempted again.
    pub next_retry_at: Option<chrono::DateTime<chrono::Utc>>,
}
/// Registry of webhooks plus an in-memory queue of deliveries awaiting processing.
///
/// All mutable state sits behind `Arc<RwLock<_>>`, so clones of a manager share the same
/// registrations, queue and statistics.
pub struct WebhookManager {
    /// Registered webhook configurations, keyed by webhook id.
    webhooks: Arc<RwLock<HashMap<String, WebhookConfig>>>,
    /// HTTP client used for every delivery request.
    client: Client,
    /// Deliveries that have not been removed by queue processing.
    delivery_queue: Arc<RwLock<Vec<WebhookDelivery>>>,
    /// Aggregate counters exposed via `get_stats`.
    stats: Arc<RwLock<WebhookStats>>,
}
/// Aggregate webhook counters returned by `WebhookManager::get_stats`.
#[derive(Debug, Default, Clone)]
#[derive(Serialize, Deserialize)]
pub struct WebhookStats {
    /// Number of sent events that produced at least one delivery.
    pub total_events: u64,
    /// Deliveries that received a 2xx response.
    pub successful_deliveries: u64,
    /// Deliveries that were marked as permanently failed.
    pub failed_deliveries: u64,
    /// Running average of delivery request time, in milliseconds.
    pub avg_delivery_time_ms: f64,
    /// Event counts keyed by the event type's `Debug` name (e.g. `"RequestCompleted"`).
    pub events_by_type: HashMap<String, u64>,
}
impl WebhookManager {
pub fn new() -> Self {
let client = Client::builder()
.timeout(Duration::from_secs(30))
.build()
.expect("Failed to create HTTP client");
Self {
webhooks: Arc::new(RwLock::new(HashMap::new())),
client,
delivery_queue: Arc::new(RwLock::new(Vec::new())),
stats: Arc::new(RwLock::new(WebhookStats::default())),
}
}
pub async fn register_webhook(&self, id: String, config: WebhookConfig) -> Result<()> {
info!("Registering webhook: {} -> {}", id, config.url);
if config.url.is_empty() {
return Err(GatewayError::Validation(
"Webhook URL cannot be empty".to_string(),
));
}
if !config.url.starts_with("http://") && !config.url.starts_with("https://") {
return Err(GatewayError::Validation(
"Webhook URL must be HTTP or HTTPS".to_string(),
));
}
let mut webhooks = self.webhooks.write().await;
webhooks.insert(id, config);
Ok(())
}
pub async fn unregister_webhook(&self, id: &str) -> Result<()> {
info!("Unregistering webhook: {}", id);
let mut webhooks = self.webhooks.write().await;
webhooks.remove(id);
Ok(())
}
pub async fn send_event(
&self,
event_type: WebhookEventType,
data: serde_json::Value,
context: Option<RequestContext>,
) -> Result<()> {
let payload = WebhookPayload {
event_type: event_type.clone(),
timestamp: chrono::Utc::now(),
request_context: context,
data,
metadata: HashMap::new(),
};
let webhooks = self.webhooks.read().await;
let mut deliveries = Vec::new();
for (webhook_id, config) in webhooks.iter() {
if config.enabled && config.events.contains(&event_type) {
let delivery = WebhookDelivery {
id: Uuid::new_v4().to_string(),
webhook_id: webhook_id.clone(),
payload: payload.clone(),
status: WebhookDeliveryStatus::Pending,
response_status: None,
response_body: None,
attempts: 0,
created_at: chrono::Utc::now(),
last_attempt_at: None,
next_retry_at: Some(chrono::Utc::now()),
};
deliveries.push(delivery);
}
}
let delivery_count = deliveries.len();
if !deliveries.is_empty() {
let mut queue = self.delivery_queue.write().await;
queue.extend(deliveries);
let mut stats = self.stats.write().await;
stats.total_events += 1;
*stats
.events_by_type
.entry(format!("{:?}", event_type))
.or_insert(0) += 1;
}
debug!(
"Queued {} webhook deliveries for event: {:?}",
delivery_count, event_type
);
Ok(())
}
pub async fn process_delivery_queue(&self) -> Result<()> {
let mut queue = self.delivery_queue.write().await;
let mut processed_deliveries = Vec::new();
for delivery in queue.iter_mut() {
if delivery.status == WebhookDeliveryStatus::Pending
|| (delivery.status == WebhookDeliveryStatus::Retrying
&& delivery
.next_retry_at
.map_or(false, |t| t <= chrono::Utc::now()))
{
let result = self.deliver_webhook(delivery).await;
processed_deliveries.push(delivery.id.clone());
match result {
Ok(_) => {
delivery.status = WebhookDeliveryStatus::Delivered;
let mut stats = self.stats.write().await;
stats.successful_deliveries += 1;
}
Err(e) => {
delivery.attempts += 1;
delivery.last_attempt_at = Some(chrono::Utc::now());
if delivery.attempts
>= self
.get_webhook_config(&delivery.webhook_id)
.await?
.max_retries
{
delivery.status = WebhookDeliveryStatus::Failed;
let mut stats = self.stats.write().await;
stats.failed_deliveries += 1;
error!("Webhook delivery failed permanently: {}", e);
} else {
delivery.status = WebhookDeliveryStatus::Retrying;
delivery.next_retry_at = Some(
chrono::Utc::now()
+ chrono::Duration::seconds(
self.get_webhook_config(&delivery.webhook_id)
.await?
.retry_delay_seconds
as i64,
),
);
warn!("Webhook delivery failed, will retry: {}", e);
}
}
}
}
}
queue.retain(|d| d.status != WebhookDeliveryStatus::Delivered);
Ok(())
}
async fn deliver_webhook(&self, delivery: &mut WebhookDelivery) -> Result<()> {
let config = self.get_webhook_config(&delivery.webhook_id).await?;
let start_time = std::time::Instant::now();
let mut request = self
.client
.post(&config.url)
.timeout(Duration::from_secs(config.timeout_seconds))
.header("Content-Type", "application/json")
.header("User-Agent", "LiteLLM-Gateway/1.0");
for (key, value) in &config.headers {
request = request.header(key, value);
}
if let Some(secret) = &config.secret {
let signature = self.generate_signature(&delivery.payload, secret)?;
request = request.header("X-Webhook-Signature", signature);
}
let response = request
.json(&delivery.payload)
.send()
.await
.map_err(|e| GatewayError::Network(e.to_string()))?;
let status_code = response.status().as_u16();
let response_body = response.text().await.unwrap_or_default();
delivery.response_status = Some(status_code);
delivery.response_body = Some(response_body.clone());
let delivery_time = start_time.elapsed().as_millis() as f64;
let mut stats = self.stats.write().await;
stats.avg_delivery_time_ms =
(stats.avg_delivery_time_ms * (stats.successful_deliveries as f64) + delivery_time)
/ (stats.successful_deliveries + 1) as f64;
if status_code >= 200 && status_code < 300 {
debug!(
"Webhook delivered successfully: {} -> {}",
delivery.webhook_id, config.url
);
Ok(())
} else {
Err(GatewayError::External(format!(
"Webhook returned status {}: {}",
status_code, response_body
)))
}
}
async fn get_webhook_config(&self, webhook_id: &str) -> Result<WebhookConfig> {
let webhooks = self.webhooks.read().await;
webhooks
.get(webhook_id)
.cloned()
.ok_or_else(|| GatewayError::NotFound(format!("Webhook not found: {}", webhook_id)))
}
fn generate_signature(&self, payload: &WebhookPayload, secret: &str) -> Result<String> {
use hmac::{Hmac, Mac};
use sha2::Sha256;
type HmacSha256 = Hmac<Sha256>;
let payload_json =
serde_json::to_string(payload).map_err(|e| GatewayError::Serialization(e))?;
let mut mac = HmacSha256::new_from_slice(secret.as_bytes())
.map_err(|e| GatewayError::Crypto(e.to_string()))?;
mac.update(payload_json.as_bytes());
let result = mac.finalize();
Ok(format!("sha256={}", hex::encode(result.into_bytes())))
}
pub async fn get_stats(&self) -> WebhookStats {
self.stats.read().await.clone()
}
pub async fn list_webhooks(&self) -> HashMap<String, WebhookConfig> {
self.webhooks.read().await.clone()
}
pub async fn get_delivery_history(&self, limit: Option<usize>) -> Vec<WebhookDelivery> {
let queue = self.delivery_queue.read().await;
let limit = limit.unwrap_or(100);
queue.iter().rev().take(limit).cloned().collect()
}
pub async fn start_delivery_processor(&self) -> Result<()> {
let manager = self.clone();
tokio::spawn(async move {
let mut interval = tokio::time::interval(Duration::from_secs(5));
loop {
interval.tick().await;
if let Err(e) = manager.process_delivery_queue().await {
error!("Error processing webhook delivery queue: {}", e);
}
}
});
info!("Started webhook delivery processor");
Ok(())
}
}
impl Clone for WebhookManager {
fn clone(&self) -> Self {
Self {
webhooks: self.webhooks.clone(),
client: self.client.clone(),
delivery_queue: self.delivery_queue.clone(),
stats: self.stats.clone(),
}
}
}
impl Default for WebhookManager {
fn default() -> Self {
Self::new()
}
}
/// Builders for `(event type, data)` pairs that can be passed to
/// `WebhookManager::send_event`.
pub mod events {
    use super::*;

    /// Data for a `RequestStarted` event: model, provider, request/user ids and the time
    /// the data was built.
    pub fn request_started(
        model: &str,
        provider: &str,
        context: RequestContext,
    ) -> (WebhookEventType, serde_json::Value) {
        let data = serde_json::json!({
            "model": model,
            "provider": provider,
            "request_id": context.request_id,
            "user_id": context.user_id,
            "timestamp": chrono::Utc::now()
        });
        (WebhookEventType::RequestStarted, data)
    }

    /// Data for a `RequestCompleted` event: usage, cost and latency alongside the request
    /// identifiers.
    pub fn request_completed(
        model: &str,
        provider: &str,
        tokens_used: u32,
        cost: f64,
        latency_ms: u64,
        context: RequestContext,
    ) -> (WebhookEventType, serde_json::Value) {
        let data = serde_json::json!({
            "model": model,
            "provider": provider,
            "tokens_used": tokens_used,
            "cost": cost,
            "latency_ms": latency_ms,
            "request_id": context.request_id,
            "user_id": context.user_id,
            "timestamp": chrono::Utc::now()
        });
        (WebhookEventType::RequestCompleted, data)
    }

    /// Data for a `RequestFailed` event, carrying the error message.
    pub fn request_failed(
        model: &str,
        provider: &str,
        error: &str,
        context: RequestContext,
    ) -> (WebhookEventType, serde_json::Value) {
        let data = serde_json::json!({
            "model": model,
            "provider": provider,
            "error": error,
            "request_id": context.request_id,
            "user_id": context.user_id,
            "timestamp": chrono::Utc::now()
        });
        (WebhookEventType::RequestFailed, data)
    }

    /// Data for a `CostThresholdExceeded` event for `user_id` over `period`.
    pub fn cost_threshold_exceeded(
        user_id: &str,
        current_cost: f64,
        threshold: f64,
        period: &str,
    ) -> (WebhookEventType, serde_json::Value) {
        let data = serde_json::json!({
            "user_id": user_id,
            "current_cost": current_cost,
            "threshold": threshold,
            "period": period,
            "timestamp": chrono::Utc::now()
        });
        (WebhookEventType::CostThresholdExceeded, data)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_webhook_manager_creation() {
        let manager = WebhookManager::new();
        let webhooks = manager.list_webhooks().await;
        assert!(webhooks.is_empty());
    }

    #[tokio::test]
    async fn test_webhook_registration() {
        let manager = WebhookManager::new();
        let config = WebhookConfig {
            url: "https://example.com/webhook".to_string(),
            events: vec![WebhookEventType::RequestCompleted],
            ..Default::default()
        };
        manager
            .register_webhook("test".to_string(), config)
            .await
            .unwrap();
        let webhooks = manager.list_webhooks().await;
        assert_eq!(webhooks.len(), 1);
        assert!(webhooks.contains_key("test"));
    }

    /// Empty and non-HTTP(S) URLs are rejected and nothing is stored.
    #[tokio::test]
    async fn test_webhook_registration_rejects_invalid_urls() {
        let manager = WebhookManager::new();
        for url in ["", "ftp://example.com/webhook", "example.com/webhook"] {
            let config = WebhookConfig {
                url: url.to_string(),
                ..Default::default()
            };
            let result = manager.register_webhook("bad".to_string(), config).await;
            assert!(
                matches!(result, Err(GatewayError::Validation(_))),
                "url {:?} should be rejected",
                url
            );
        }
        assert!(manager.list_webhooks().await.is_empty());
    }

    #[tokio::test]
    async fn test_webhook_unregistration() {
        let manager = WebhookManager::new();
        let config = WebhookConfig {
            url: "https://example.com/webhook".to_string(),
            ..Default::default()
        };
        manager
            .register_webhook("test".to_string(), config)
            .await
            .unwrap();
        manager.unregister_webhook("test").await.unwrap();
        assert!(manager.list_webhooks().await.is_empty());
        // Unknown ids are ignored rather than reported as errors.
        manager.unregister_webhook("missing").await.unwrap();
    }

    /// Only enabled webhooks subscribed to the event get a delivery; events with no
    /// recipients do not count towards the stats.
    #[tokio::test]
    async fn test_send_event_queues_only_subscribed_enabled_webhooks() {
        let manager = WebhookManager::new();
        let subscribe = |url: &str, event: WebhookEventType, enabled: bool| WebhookConfig {
            url: url.to_string(),
            events: vec![event],
            enabled,
            ..Default::default()
        };
        manager
            .register_webhook(
                "completed".to_string(),
                subscribe("https://example.com/a", WebhookEventType::RequestCompleted, true),
            )
            .await
            .unwrap();
        manager
            .register_webhook(
                "failed".to_string(),
                subscribe("https://example.com/b", WebhookEventType::RequestFailed, true),
            )
            .await
            .unwrap();
        manager
            .register_webhook(
                "disabled".to_string(),
                subscribe("https://example.com/c", WebhookEventType::RequestCompleted, false),
            )
            .await
            .unwrap();

        manager
            .send_event(
                WebhookEventType::RequestCompleted,
                serde_json::json!({"k": 1}),
                None,
            )
            .await
            .unwrap();
        manager
            .send_event(WebhookEventType::RequestStarted, serde_json::json!({}), None)
            .await
            .unwrap();

        let history = manager.get_delivery_history(None).await;
        assert_eq!(history.len(), 1);
        assert_eq!(history[0].webhook_id, "completed");
        assert_eq!(history[0].status, WebhookDeliveryStatus::Pending);
        assert_eq!(history[0].attempts, 0);

        let stats = manager.get_stats().await;
        assert_eq!(stats.total_events, 1);
        assert_eq!(stats.events_by_type.get("RequestCompleted"), Some(&1));
        assert!(!stats.events_by_type.contains_key("RequestStarted"));
    }

    #[test]
    fn test_webhook_event_types() {
        let event = WebhookEventType::RequestStarted;
        assert_eq!(event, WebhookEventType::RequestStarted);
        let custom_event = WebhookEventType::Custom("my_event".to_string());
        assert_eq!(
            custom_event,
            WebhookEventType::Custom("my_event".to_string())
        );
    }

    #[test]
    fn test_webhook_payload_serialization() {
        let payload = WebhookPayload {
            event_type: WebhookEventType::RequestCompleted,
            timestamp: chrono::Utc::now(),
            request_context: None,
            data: serde_json::json!({"test": "data"}),
            metadata: HashMap::new(),
        };
        let serialized = serde_json::to_string(&payload).unwrap();
        assert!(serialized.contains("RequestCompleted"));
        assert!(serialized.contains("test"));
    }
}