dwctl 8.40.0

The Doubleword Control Layer - A self-hostable observability and analytics platform for LLM applications
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
//! Error enrichment middleware for AI proxy requests.
//!
//! This module provides middleware that intercepts error responses from the onwards
//! routing layer and enriches them with helpful context. The primary use case is
//! transforming generic 403 errors (insufficient credits) into detailed responses
//! that explain the issue and suggest remediation.
//!
//! ## Architecture
//!
//! The middleware sits in the response path after onwards but before outlet:
//! ```text
//! Request → onwards (routing) → error_enrichment (this) → outlet (logging) → Response
//! ```
//!
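//! ## Error Cases Enriched
//!
//! 1. **403 Forbidden - Insufficient Credits**: User's balance <= 0 for paid models
//!    - Re-issued as a 402 Payment Required showing the current balance
//! 2. **403 Forbidden - Model Access Denied**: User is not a member of a group with access to the requested model
//!    - Shows which model was requested
//!
//! Model access is checked first. The balance check runs only when access is granted or
//! cannot be determined (e.g. the body names no model). Any other 403, and every non-403
//! response, passes through unchanged.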

use crate::{
    db::errors::DbError,
    db::handlers::{Credits, api_keys::ApiKeys},
    errors::Error,
    types::UserId,
};
use axum::{
    body::Body,
    extract::State,
    http::{Request, Response, StatusCode},
    middleware::Next,
    response::IntoResponse,
};
use rust_decimal::Decimal;
use serde::Deserialize;
use sqlx::PgPool;
use tracing::{debug, instrument};

/// Minimal view of a JSON request body, used only to read its `model` field.
///
/// Serde ignores every other field (there is no `deny_unknown_fields`), so full
/// chat/completions payloads deserialize fine. A body that is not JSON, or that has
/// no string `model` field, fails to deserialize.
#[derive(Debug, Deserialize)]
struct ChatRequest {
    /// Model name requested by the client.
    model: String,
}

/// Middleware that enriches error responses from the AI proxy with helpful context
///
/// Currently handles:
/// - 403 Forbidden errors (likely insufficient credits) → enriched with balance
/// - 403 Forbidden errors (likely model access denied) → enriched with model name
#[instrument(name = "dwctl.error_enrichment", skip_all, fields(http.request.method = %request.method(), url.path = %request.uri().path(), url.query = request.uri().query().unwrap_or("")))]
pub async fn error_enrichment_middleware(State(pool): State<PgPool>, request: Request<Body>, next: Next) -> Response<Body> {
    // Extract API key from request headers before passing to onwards
    let api_key = request
        .headers()
        .get("authorization")
        .and_then(|h| h.to_str().ok())
        .and_then(|auth| auth.strip_prefix("Bearer ").or_else(|| auth.strip_prefix("bearer ")))
        .map(|token| token.trim().to_string());

    // Extract request body to get model name for potential error enrichment
    let (parts, body) = request.into_parts();
    let bytes = match axum::body::to_bytes(body, usize::MAX).await {
        Ok(b) => b,
        Err(_) => {
            // If we can't read the body, just pass through
            let reconstructed = Request::from_parts(parts, Body::empty());
            return next.run(reconstructed).await;
        }
    };

    let model_name = serde_json::from_slice::<ChatRequest>(&bytes).ok().map(|req| req.model);

    // Reconstruct the request with the body
    let reconstructed = Request::from_parts(parts, Body::from(bytes));

    // Let the request proceed through onwards
    let response = next.run(reconstructed).await;

    // Only enrich 403 errors when we have an API key
    // Note: This middleware is applied only to the onwards router (AI proxy paths),
    // so no path filtering is needed here
    if response.status() == StatusCode::FORBIDDEN
        && let Some(key) = api_key
    {
        debug!("Intercepted 403 response on AI proxy path, attempting enrichment");

        // First check if it's a model access issue
        if let Some(model) = model_name
            && let Ok(user_id) = get_user_id_of_api_key(pool.clone(), &key).await
            && let Ok(has_access) = check_user_has_model_access(pool.clone(), user_id, &model).await
            && !has_access
        {
            return Error::ModelAccessDenied {
                model_name: model.clone(),
                message: format!(
                    "You do not have access to '{}'. Please contact your administrator to request access.",
                    model
                ),
            }
            .into_response();
        }

        // If access is OK but we have a 403, check balance - onwards excludes keys with balance <= 0
        if let Ok(balance) = get_balance_of_api_key(pool.clone(), &key).await
            && balance <= Decimal::ZERO
        {
            return Error::InsufficientCredits {
                current_balance: balance,
                message: "Account balance too low. Please add credits to continue.".to_string(),
            }
            .into_response();
        }
    }

    response
}

/// Resolves the ID of the user who owns the API key whose secret is `api_key`.
///
/// # Errors
///
/// Returns a [`DbError`] in any of these cases:
/// - a connection cannot be acquired;
/// - the lookup query fails;
/// - no user is associated with the secret.
#[instrument(skip_all, name = "dwctl.get_user_id_of_api_key")]
pub async fn get_user_id_of_api_key(pool: PgPool, api_key: &str) -> Result<UserId, DbError> {
    let mut connection = pool.acquire().await?;
    let mut repo = ApiKeys::new(&mut connection);
    match repo.get_user_id_by_secret(api_key).await? {
        Some(owner_id) => Ok(owner_id),
        None => Err(anyhow::anyhow!("API key not found or associated user doesn't exist").into()),
    }
}

#[instrument(skip_all, name = "dwctl.get_balance_of_api_key")]
pub async fn get_balance_of_api_key(pool: PgPool, api_key: &str) -> Result<Decimal, DbError> {
    // Look up user_id from API key
    let user_id = get_user_id_of_api_key(pool.clone(), api_key).await?;

    debug!("Found user_id for API key: {}", user_id);

    // Query user's current balance
    let mut conn = pool.acquire().await?;
    let mut credits_repo = Credits::new(&mut conn);
    credits_repo.get_user_balance(user_id).await
}

/// Reports whether `user_id` may use the deployment aliased `model_alias`.
///
/// Access is granted when a non-deleted deployment with that alias is linked (via
/// `deployment_groups`) to at least one group the user belongs to (via `user_groups`).
/// Every user other than the all-zero UUID is also treated as a member of the all-zero
/// UUID group. That group is presumably a built-in "everyone" group; confirm against the
/// group seed data.
///
/// Returns `Ok(false)`, not an error, when no non-deleted deployment has that alias.
///
/// # Errors
///
/// Returns a [`DbError`] if a connection cannot be acquired or the query fails.
#[instrument(skip_all, name = "dwctl.check_user_has_model_access")]
pub async fn check_user_has_model_access(pool: PgPool, user_id: UserId, model_alias: &str) -> Result<bool, DbError> {
    let mut conn = pool.acquire().await?;

    // Query to check if user has access to this deployment through group membership.
    // `query_scalar!` checks this SQL at compile time. The `"has_access!"` alias marks the
    // EXISTS result as non-nullable, so it decodes to a plain `bool`.
    // `$2` (user_id) is used twice: once for the user's own groups, and once to exclude
    // the all-zero UUID user from the implicit all-zero group.
    let result = sqlx::query_scalar!(
        r#"
        SELECT EXISTS(
            SELECT 1
            FROM deployed_models d
            JOIN deployment_groups dg ON dg.deployment_id = d.id
            WHERE d.alias = $1
              AND d.deleted = false
              AND dg.group_id IN (
                  SELECT ug.group_id FROM user_groups ug WHERE ug.user_id = $2
                  UNION
                  SELECT '00000000-0000-0000-0000-000000000000'::uuid
                  WHERE $2 != '00000000-0000-0000-0000-000000000000'
              )
        ) as "has_access!"
        "#,
        model_alias,
        user_id
    )
    .fetch_one(&mut *conn)
    .await?;

    Ok(result)
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::models::users::Role;
    use crate::db::{
        handlers::{Credits, Repository as _, api_keys::ApiKeys},
        models::{
            api_keys::{ApiKeyCreateDBRequest, ApiKeyPurpose},
            credits::{CreditTransactionCreateDBRequest, CreditTransactionType},
        },
    };
    use crate::test::utils::{
        add_deployment_to_group, add_user_to_group, create_test_endpoint, create_test_group, create_test_model, create_test_user,
    };
    use rust_decimal::Decimal;

    /// Chat completions route used by the AI-proxy tests.
    const CHAT_COMPLETIONS_PATH: &str = "/ai/v1/chat/completions";

    /// Creates a realtime API key owned (and created) by `user_id` and returns its secret.
    ///
    /// The pool connection is released before returning.
    async fn create_api_key_secret(pool: &PgPool, user_id: UserId) -> String {
        let mut conn = pool.acquire().await.unwrap();
        let mut api_keys_repo = ApiKeys::new(&mut conn);
        let api_key = api_keys_repo
            .create(&ApiKeyCreateDBRequest {
                user_id,
                name: "Test Key".to_string(),
                description: None,
                purpose: ApiKeyPurpose::Realtime,
                requests_per_second: None,
                burst_size: None,
                created_by: user_id,
            })
            .await
            .unwrap();
        api_key.secret.to_string()
    }

    /// Inserts a credit transaction for `user_id` with a fresh random `source_id`.
    ///
    /// The pool connection is released before returning.
    async fn record_credit_transaction(
        pool: &PgPool,
        user_id: UserId,
        transaction_type: CreditTransactionType,
        amount: Decimal,
        description: &str,
    ) {
        let mut conn = pool.acquire().await.unwrap();
        let mut credits_repo = Credits::new(&mut conn);
        credits_repo
            .create_transaction(&CreditTransactionCreateDBRequest {
                user_id,
                transaction_type,
                amount,
                source_id: uuid::Uuid::new_v4().to_string(),
                description: Some(description.to_string()),
                fusillade_batch_id: None,
                api_key_id: None,
            })
            .await
            .unwrap();
    }

    /// Builds a router whose chat completions handler always answers with `status` and `body`.
    ///
    /// The handler stands in for onwards and is wrapped in the error enrichment middleware.
    fn chat_completions_router(pool: &PgPool, status: StatusCode, body: &'static str) -> axum::Router {
        axum::Router::new()
            .route(
                CHAT_COMPLETIONS_PATH,
                axum::routing::post(move || async move {
                    axum::response::Response::builder()
                        .status(status)
                        .body(axum::body::Body::from(body))
                        .unwrap()
                }),
            )
            .layer(axum::middleware::from_fn_with_state(
                pool.clone(),
                crate::error_enrichment::error_enrichment_middleware,
            ))
    }

    /// Integration test: a 403 is enriched in two ways.
    ///
    /// - When the user cannot reach the requested model, it becomes "model access denied".
    /// - Once access is granted but the balance is non-positive, it becomes a 402
    ///   "insufficient credits" response.
    #[sqlx::test]
    #[test_log::test]
    async fn test_error_enrichment_middleware_enriches_403_with_balance(pool: PgPool) {
        let user = create_test_user(&pool, Role::StandardUser).await;
        let secret = create_api_key_secret(&pool, user.id).await;

        // Give user some credits
        record_credit_transaction(
            &pool,
            user.id,
            CreditTransactionType::AdminGrant,
            Decimal::new(5000, 2),
            "Initial credits",
        )
        .await;

        // Simulate onwards rejecting the request with a bare 403.
        let router = chat_completions_router(&pool, StatusCode::FORBIDDEN, "Forbidden");
        let server = axum_test::TestServer::new(router).expect("Failed to create test server");

        // No deployment is aliased `test-model`, so the user has no access to it.
        let response = server
            .post(CHAT_COMPLETIONS_PATH)
            .add_header("authorization", &format!("Bearer {}", secret))
            .json(&serde_json::json!({
                "model": "test-model",
                "messages": [{"role": "user", "content": "Hello"}]
            }))
            .await;

        // Positive balance but no model access: enriched as access denied, still 403.
        assert_eq!(response.status_code().as_u16(), 403);
        let body = response.text();
        assert!(body.contains("do not have access to 'test-model'"));

        // Deduct more than was granted so the balance becomes non-positive.
        record_credit_transaction(&pool, user.id, CreditTransactionType::Usage, Decimal::new(10000, 2), "Usage").await;

        // Grant the user access to `authorized-model` so the access check passes
        // and the balance check runs.
        let endpoint_id = create_test_endpoint(&pool, "test-endpoint", user.id).await;
        let deployment_id = create_test_model(&pool, "authorized-model-name", "authorized-model", endpoint_id, user.id).await;
        let group = create_test_group(&pool).await;
        add_user_to_group(&pool, user.id, group.id).await;
        add_deployment_to_group(&pool, deployment_id, group.id, user.id).await;

        let response = server
            .post(CHAT_COMPLETIONS_PATH)
            .add_header("authorization", &format!("Bearer {}", secret))
            .json(&serde_json::json!({
                "model": "authorized-model",
                "messages": [{"role": "user", "content": "Hello"}]
            }))
            .await;

        // Access granted but balance too low: enriched as insufficient credits (402).
        assert_eq!(response.status_code().as_u16(), 402);
        assert!(response.text().contains("balance too low"));
    }

    /// Integration test: a 403 unrelated to credits or model access passes through untouched.
    #[sqlx::test]
    #[test_log::test]
    async fn test_error_enrichment_middleware_passes_through_legitimate_403(pool: PgPool) {
        // The user has model access and a positive balance, so neither enrichment applies.
        let user = create_test_user(&pool, Role::StandardUser).await;
        let group = create_test_group(&pool).await;
        add_user_to_group(&pool, user.id, group.id).await;

        let secret = create_api_key_secret(&pool, user.id).await;
        record_credit_transaction(
            &pool,
            user.id,
            CreditTransactionType::AdminGrant,
            Decimal::new(5000, 2),
            "Initial credits",
        )
        .await;

        // Create a deployment with 'authorized-model' alias and grant access to the group
        let endpoint_id = create_test_endpoint(&pool, "test-endpoint", user.id).await;
        let deployment_id = create_test_model(&pool, "authorized-model-name", "authorized-model", endpoint_id, user.id).await;
        add_deployment_to_group(&pool, deployment_id, group.id, user.id).await;

        // Simulate a 403 returned for a different reason (e.g. rate limiting).
        let router = chat_completions_router(&pool, StatusCode::FORBIDDEN, "Rate limit exceeded");
        let server = axum_test::TestServer::new(router).expect("Failed to create test server");

        let response = server
            .post(CHAT_COMPLETIONS_PATH)
            .add_header("authorization", &format!("Bearer {}", secret))
            .json(&serde_json::json!({
                "model": "authorized-model",
                "messages": [{"role": "user", "content": "Hello"}]
            }))
            .await;

        // Should pass through the original 403 without enrichment
        assert_eq!(response.status_code().as_u16(), 403);
        let body = response.text();
        assert!(body.contains("Rate limit exceeded"));
        assert!(!body.contains("balance"));
        assert!(!body.contains("access"));
    }

    /// Integration test: a 403 from a non-AI route, sent without an API key, is left untouched.
    ///
    /// The middleware itself does no path filtering; it relies on being mounted only on the
    /// AI proxy router. This test passes because no bearer key is sent.
    #[sqlx::test]
    #[test_log::test]
    async fn test_error_enrichment_middleware_ignores_non_ai_paths(pool: PgPool) {
        let router = axum::Router::new()
            .route(
                "/admin/api/v1/users",
                axum::routing::get(|| async {
                    // Return 403 on admin path
                    axum::response::Response::builder()
                        .status(StatusCode::FORBIDDEN)
                        .body(axum::body::Body::from("Admin Forbidden"))
                        .unwrap()
                }),
            )
            .layer(axum::middleware::from_fn_with_state(
                pool.clone(),
                crate::error_enrichment::error_enrichment_middleware,
            ));

        let server = axum_test::TestServer::new(router).expect("Failed to create test server");
        let response = server.get("/admin/api/v1/users").await;

        // Should remain 403, not enriched
        assert_eq!(response.status_code().as_u16(), 403);
        assert_eq!(response.text(), "Admin Forbidden");
    }

    /// Integration test: non-403 responses are never enriched, even with an API key.
    #[sqlx::test]
    #[test_log::test]
    async fn test_error_enrichment_middleware_ignores_non_403_errors(pool: PgPool) {
        let router = chat_completions_router(&pool, StatusCode::NOT_FOUND, "Not Found");

        let server = axum_test::TestServer::new(router).expect("Failed to create test server");
        let response = server
            .post(CHAT_COMPLETIONS_PATH)
            .add_header("authorization", "Bearer test-key")
            .json(&serde_json::json!({"model": "test"}))
            .await;

        // Should remain 404, not enriched
        assert_eq!(response.status_code().as_u16(), 404);
        assert_eq!(response.text(), "Not Found");
    }

    /// Integration test: a 403 for a request without an `Authorization` header passes through.
    #[sqlx::test]
    #[test_log::test]
    async fn test_error_enrichment_middleware_without_auth_header(pool: PgPool) {
        let router = chat_completions_router(&pool, StatusCode::FORBIDDEN, "No Auth");

        let server = axum_test::TestServer::new(router).expect("Failed to create test server");
        let response = server
            .post(CHAT_COMPLETIONS_PATH)
            .json(&serde_json::json!({"model": "test"}))
            .await;

        // Should pass through original 403 without enrichment
        assert_eq!(response.status_code().as_u16(), 403);
        assert_eq!(response.text(), "No Auth");
    }
}