llm_cost_ops_api/ingestion/
webhook.rs

1// Webhook HTTP server for receiving usage data
2
3use axum::{
4    extract::State,
5    http::StatusCode,
6    response::{IntoResponse, Response},
7    routing::{get, post},
8    Json, Router,
9};
10use std::sync::Arc;
11use tower::ServiceBuilder;
12use tower_http::{
13    cors::CorsLayer,
14    trace::{DefaultMakeSpan, TraceLayer},
15};
16use tracing::{error, info, warn};
17
18use llm_cost_ops::Result;
19
20use super::models::{BatchIngestionRequest, IngestionResponse, UsageWebhookPayload};
21use super::traits::{IngestionHandler, RateLimiter};
22
23/// Webhook server state
24#[derive(Clone)]
25pub struct WebhookServerState<H: IngestionHandler> {
26    handler: Arc<H>,
27}
28
29impl<H: IngestionHandler> WebhookServerState<H> {
30    pub fn new(handler: H) -> Self {
31        Self {
32            handler: Arc::new(handler),
33        }
34    }
35}
36
37/// Webhook server state with rate limiting
38#[derive(Clone)]
39pub struct WebhookServerStateWithRateLimit<H: IngestionHandler, R: RateLimiter> {
40    handler: Arc<H>,
41    rate_limiter: Arc<R>,
42}
43
44impl<H: IngestionHandler, R: RateLimiter> WebhookServerStateWithRateLimit<H, R> {
45    pub fn new(handler: H, rate_limiter: R) -> Self {
46        Self {
47            handler: Arc::new(handler),
48            rate_limiter: Arc::new(rate_limiter),
49        }
50    }
51}
52
53/// Create webhook router with all endpoints
54pub fn create_webhook_router<H: IngestionHandler + 'static>(
55    handler: H,
56) -> Router {
57    let state = WebhookServerState::new(handler);
58
59    Router::new()
60        .route("/health", get(health_handler))
61        .route("/v1/usage", post(ingest_single_handler::<H>))
62        .route("/v1/usage/batch", post(ingest_batch_handler::<H>))
63        .with_state(state)
64        .layer(
65            ServiceBuilder::new()
66                .layer(axum::middleware::from_fn(llm_cost_ops::metrics::middleware::metrics_middleware))
67                .layer(
68                    TraceLayer::new_for_http()
69                        .make_span_with(DefaultMakeSpan::new().level(tracing::Level::INFO)),
70                )
71                .layer(CorsLayer::permissive()),
72        )
73}
74
75/// Create webhook router with rate limiting enabled
76pub fn create_webhook_router_with_rate_limit<H: IngestionHandler + 'static, R: RateLimiter + Clone + 'static>(
77    handler: H,
78    rate_limiter: R,
79) -> Router {
80    let state = WebhookServerStateWithRateLimit::new(handler, rate_limiter);
81
82    Router::new()
83        .route("/health", get(health_handler))
84        .route("/v1/usage", post(ingest_single_handler_with_rate_limit))
85        .route("/v1/usage/batch", post(ingest_batch_handler_with_rate_limit))
86        .with_state(state)
87        .layer(
88            ServiceBuilder::new()
89                .layer(axum::middleware::from_fn(llm_cost_ops::metrics::middleware::metrics_middleware))
90                .layer(
91                    TraceLayer::new_for_http()
92                        .make_span_with(DefaultMakeSpan::new().level(tracing::Level::INFO)),
93                )
94                .layer(CorsLayer::permissive()),
95        )
96}
97
98/// Single usage record ingestion endpoint with rate limiting
99async fn ingest_single_handler_with_rate_limit<H: IngestionHandler, R: RateLimiter>(
100    State(state): State<WebhookServerStateWithRateLimit<H, R>>,
101    Json(payload): Json<UsageWebhookPayload>,
102) -> std::result::Result<Json<IngestionResponse>, AppError> {
103    use std::time::Instant;
104    let start = Instant::now();
105    let org_id = payload.organization_id.clone();
106
107    info!(
108        request_id = %payload.request_id,
109        organization_id = %org_id,
110        "Received single usage ingestion request"
111    );
112
113    // Check rate limit
114    match state.rate_limiter.check_rate_limit(&org_id).await {
115        Ok(allowed) => {
116            if !allowed {
117                warn!(
118                    organization_id = %org_id,
119                    "Rate limit exceeded"
120                );
121                let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
122                llm_cost_ops::metrics::collectors::IngestionMetrics::record_failure(&org_id, "rate_limit", duration_ms);
123                return Err(AppError::RateLimitExceeded(org_id));
124            }
125        }
126        Err(e) => {
127            error!(error = %e, "Rate limit check failed, allowing request");
128            // On error, allow the request (fail open for availability)
129        }
130    }
131
132    // Process request
133    match state.handler.handle_single(payload).await {
134        Ok(response) => {
135            let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
136            llm_cost_ops::metrics::collectors::IngestionMetrics::record_success(&org_id, 1, duration_ms);
137            Ok(Json(response))
138        }
139        Err(e) => {
140            error!(error = %e, "Failed to handle ingestion request");
141            let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
142            llm_cost_ops::metrics::collectors::IngestionMetrics::record_failure(&org_id, "processing_error", duration_ms);
143            Err(AppError::InternalError(e.to_string()))
144        }
145    }
146}
147
148/// Batch usage records ingestion endpoint with rate limiting
149async fn ingest_batch_handler_with_rate_limit<H: IngestionHandler, R: RateLimiter>(
150    State(state): State<WebhookServerStateWithRateLimit<H, R>>,
151    Json(request): Json<BatchIngestionRequest>,
152) -> std::result::Result<Json<IngestionResponse>, AppError> {
153    use std::time::Instant;
154    let start = Instant::now();
155    let batch_size = request.records.len();
156
157    info!(
158        batch_id = %request.batch_id,
159        batch_size = batch_size,
160        source = %request.source,
161        "Received batch usage ingestion request"
162    );
163
164    // Record batch size metric
165    llm_cost_ops::metrics::collectors::IngestionMetrics::record_batch_size(batch_size);
166
167    // Extract organization ID from first record
168    let org_id = request.records.first()
169        .map(|r| r.organization_id.clone())
170        .unwrap_or_else(|| "unknown".to_string());
171
172    // Check rate limit
173    match state.rate_limiter.check_rate_limit(&org_id).await {
174        Ok(allowed) => {
175            if !allowed {
176                warn!(
177                    organization_id = %org_id,
178                    "Rate limit exceeded for batch request"
179                );
180                let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
181                llm_cost_ops::metrics::collectors::IngestionMetrics::record_failure(&org_id, "rate_limit", duration_ms);
182                return Err(AppError::RateLimitExceeded(org_id));
183            }
184        }
185        Err(e) => {
186            error!(error = %e, "Rate limit check failed, allowing request");
187        }
188    }
189
190    // Process request
191    match state.handler.handle_batch(request.records).await {
192        Ok(response) => {
193            let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
194            llm_cost_ops::metrics::collectors::IngestionMetrics::record_success(&org_id, response.accepted, duration_ms);
195            if response.rejected > 0 {
196                llm_cost_ops::metrics::collectors::IngestionMetrics::record_rejected(&org_id, response.rejected);
197            }
198            Ok(Json(response))
199        }
200        Err(e) => {
201            error!(error = %e, "Failed to handle batch ingestion request");
202            let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
203            llm_cost_ops::metrics::collectors::IngestionMetrics::record_failure(&org_id, "processing_error", duration_ms);
204            Err(AppError::InternalError(e.to_string()))
205        }
206    }
207}
208
209/// Health check endpoint
210async fn health_handler() -> impl IntoResponse {
211    Json(serde_json::json!({
212        "status": "healthy",
213        "service": "llm-cost-ops-ingestion",
214        "timestamp": chrono::Utc::now().to_rfc3339()
215    }))
216}
217
218/// Single usage record ingestion endpoint
219async fn ingest_single_handler<H: IngestionHandler>(
220    State(state): State<WebhookServerState<H>>,
221    Json(payload): Json<UsageWebhookPayload>,
222) -> std::result::Result<Json<IngestionResponse>, AppError> {
223    use std::time::Instant;
224    let start = Instant::now();
225    let org_id = payload.organization_id.clone();
226
227    info!(
228        request_id = %payload.request_id,
229        organization_id = %org_id,
230        "Received single usage ingestion request"
231    );
232
233    match state.handler.handle_single(payload).await {
234        Ok(response) => {
235            let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
236            llm_cost_ops::metrics::collectors::IngestionMetrics::record_success(&org_id, 1, duration_ms);
237            Ok(Json(response))
238        }
239        Err(e) => {
240            error!(error = %e, "Failed to handle ingestion request");
241            let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
242            llm_cost_ops::metrics::collectors::IngestionMetrics::record_failure(&org_id, "processing_error", duration_ms);
243            Err(AppError::InternalError(e.to_string()))
244        }
245    }
246}
247
248/// Batch usage records ingestion endpoint
249async fn ingest_batch_handler<H: IngestionHandler>(
250    State(state): State<WebhookServerState<H>>,
251    Json(request): Json<BatchIngestionRequest>,
252) -> std::result::Result<Json<IngestionResponse>, AppError> {
253    use std::time::Instant;
254    let start = Instant::now();
255    let batch_size = request.records.len();
256
257    // Extract organization ID from first record
258    let org_id = request.records.first()
259        .map(|r| r.organization_id.clone())
260        .unwrap_or_else(|| "unknown".to_string());
261
262    info!(
263        batch_id = %request.batch_id,
264        batch_size = batch_size,
265        source = %request.source,
266        organization_id = %org_id,
267        "Received batch usage ingestion request"
268    );
269
270    // Record batch size metric
271    llm_cost_ops::metrics::collectors::IngestionMetrics::record_batch_size(batch_size);
272
273    match state.handler.handle_batch(request.records).await {
274        Ok(response) => {
275            let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
276            llm_cost_ops::metrics::collectors::IngestionMetrics::record_success(&org_id, response.accepted, duration_ms);
277            if response.rejected > 0 {
278                llm_cost_ops::metrics::collectors::IngestionMetrics::record_rejected(&org_id, response.rejected);
279            }
280            Ok(Json(response))
281        }
282        Err(e) => {
283            error!(error = %e, "Failed to handle batch ingestion request");
284            let duration_ms = start.elapsed().as_secs_f64() * 1000.0;
285            llm_cost_ops::metrics::collectors::IngestionMetrics::record_failure(&org_id, "processing_error", duration_ms);
286            Err(AppError::InternalError(e.to_string()))
287        }
288    }
289}
290
/// Errors returned by the webhook endpoints.
///
/// Each variant is turned into an HTTP response by this type's
/// `IntoResponse` implementation.
#[derive(Debug)]
pub enum AppError {
    /// The request was invalid; carries a description of the problem.
    /// Not constructed by any handler in this file.
    ValidationError(String),
    /// Request processing failed; carries the underlying error's text.
    InternalError(String),
    /// The rate limiter denied the request; carries the organization id.
    RateLimitExceeded(String),
}

impl std::fmt::Display for AppError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            AppError::ValidationError(msg) => write!(f, "validation error: {msg}"),
            AppError::InternalError(msg) => write!(f, "internal error: {msg}"),
            AppError::RateLimitExceeded(org) => {
                write!(f, "rate limit exceeded for organization {org}")
            }
        }
    }
}

// Lets `AppError` be used with `?`, `Box<dyn Error>` and error-reporting
// helpers like any other error type.
impl std::error::Error for AppError {}
298
299impl IntoResponse for AppError {
300    fn into_response(self) -> Response {
301        let (status, error_message, org_id) = match self {
302            AppError::ValidationError(msg) => (StatusCode::BAD_REQUEST, msg, None),
303            AppError::InternalError(msg) => (StatusCode::INTERNAL_SERVER_ERROR, msg, None),
304            AppError::RateLimitExceeded(org) => (
305                StatusCode::TOO_MANY_REQUESTS,
306                "Rate limit exceeded".to_string(),
307                Some(org),
308            ),
309        };
310
311        let mut body_json = serde_json::json!({
312            "error": error_message,
313            "timestamp": chrono::Utc::now().to_rfc3339()
314        });
315
316        if let Some(org) = org_id {
317            body_json["organization_id"] = serde_json::json!(org);
318        }
319
320        let body = Json(body_json);
321
322        let mut response = (status, body).into_response();
323
324        // Add rate limit headers for rate limit errors
325        if status == StatusCode::TOO_MANY_REQUESTS {
326            response.headers_mut().insert(
327                "Retry-After",
328                "60".parse().unwrap(),
329            );
330            response.headers_mut().insert(
331                "X-RateLimit-Limit",
332                "1000".parse().unwrap(),
333            );
334            response.headers_mut().insert(
335                "X-RateLimit-Remaining",
336                "0".parse().unwrap(),
337            );
338        }
339
340        response
341    }
342}
343
344/// Start webhook server
345pub async fn start_webhook_server<H: IngestionHandler + 'static>(
346    bind_addr: &str,
347    handler: H,
348) -> Result<()> {
349    info!(bind_addr = %bind_addr, "Starting webhook server");
350
351    let app = create_webhook_router(handler);
352
353    let listener = tokio::net::TcpListener::bind(bind_addr).await?;
354
355    info!(
356        addr = %listener.local_addr()?,
357        "Webhook server listening"
358    );
359
360    axum::serve(listener, app).await?;
361
362    Ok(())
363}
364
#[cfg(test)]
mod tests {
    use super::super::models::IngestionStatus;
    use super::*;
    use async_trait::async_trait;
    use chrono::Utc;
    use uuid::Uuid;

    /// Handler stub that accepts every record it is given.
    #[derive(Clone)]
    struct MockHandler;

    #[async_trait]
    impl IngestionHandler for MockHandler {
        async fn handle_single(
            &self,
            payload: UsageWebhookPayload,
        ) -> Result<IngestionResponse> {
            Ok(IngestionResponse {
                request_id: payload.request_id,
                status: IngestionStatus::Success,
                accepted: 1,
                rejected: 0,
                errors: vec![],
                processed_at: Utc::now(),
            })
        }

        async fn handle_batch(
            &self,
            payloads: Vec<UsageWebhookPayload>,
        ) -> Result<IngestionResponse> {
            Ok(IngestionResponse {
                request_id: Uuid::new_v4(),
                status: IngestionStatus::Success,
                accepted: payloads.len(),
                rejected: 0,
                errors: vec![],
                processed_at: Utc::now(),
            })
        }

        fn name(&self) -> &str {
            "mock_handler"
        }

        async fn health_check(&self) -> Result<bool> {
            Ok(true)
        }
    }

    #[tokio::test]
    async fn test_create_router() {
        // Constructing the router is the whole test: it fails only if
        // construction panics, so no constant assertion is needed.
        let _router = create_webhook_router(MockHandler);
    }

    #[test]
    fn rate_limit_error_maps_to_429_with_headers() {
        let response = AppError::RateLimitExceeded("org-1".to_string()).into_response();

        assert_eq!(response.status(), StatusCode::TOO_MANY_REQUESTS);

        let header = |name: &str| {
            response
                .headers()
                .get(name)
                .and_then(|value| value.to_str().ok())
        };
        assert_eq!(header("Retry-After"), Some("60"));
        assert_eq!(header("X-RateLimit-Limit"), Some("1000"));
        assert_eq!(header("X-RateLimit-Remaining"), Some("0"));
    }

    #[test]
    fn other_errors_map_to_expected_status_without_rate_limit_headers() {
        let validation = AppError::ValidationError("bad".to_string()).into_response();
        assert_eq!(validation.status(), StatusCode::BAD_REQUEST);
        assert!(validation.headers().get("Retry-After").is_none());

        let internal = AppError::InternalError("boom".to_string()).into_response();
        assert_eq!(internal.status(), StatusCode::INTERNAL_SERVER_ERROR);
        assert!(internal.headers().get("Retry-After").is_none());
    }
}
421}