dwctl 8.38.2

The Doubleword Control Layer - A self-hostable observability and analytics platform for LLM applications
use chrono::{DateTime, Utc};
use sqlx::PgConnection;
use tracing::instrument;
use uuid::Uuid;

use crate::db::errors::Result;

/// Data-access wrapper for the `batch_capacity_reservations` table.
///
/// Holds a mutable borrow of a Postgres connection for the lifetime `'c`;
/// every query issued by the methods on this type runs on that connection.
pub struct BatchCapacityReservations<'c> {
    /// Connection that all queries of this type are executed on.
    db: &'c mut PgConnection,
}

impl<'c> BatchCapacityReservations<'c> {
    pub fn new(db: &'c mut PgConnection) -> Self {
        Self { db }
    }

    #[instrument(skip(self, model_ids), fields(count = model_ids.len()), err)]
    pub async fn sum_active_by_model_window(&mut self, model_ids: &[Uuid], completion_window: &str) -> Result<Vec<(Uuid, i64)>> {
        if model_ids.is_empty() {
            return Ok(Vec::new());
        }

        let rows = sqlx::query!(
            r#"
            SELECT model_id,
                   COALESCE(SUM(reserved_requests), 0)::BIGINT AS reserved
            FROM batch_capacity_reservations
            WHERE model_id = ANY($1)
              AND completion_window = $2
              AND released_at IS NULL
              AND expires_at > now()
            GROUP BY model_id
            "#,
            model_ids,
            completion_window
        )
        .fetch_all(&mut *self.db)
        .await?;

        Ok(rows.into_iter().map(|r| (r.model_id, r.reserved.unwrap_or(0))).collect())
    }

    #[instrument(skip(self, rows), fields(count = rows.len()), err)]
    pub async fn insert_reservations(&mut self, rows: &[(Uuid, &str, i64, DateTime<Utc>)]) -> Result<Vec<Uuid>> {
        if rows.is_empty() {
            return Ok(vec![]);
        }

        let model_ids: Vec<Uuid> = rows.iter().map(|(id, _, _, _)| *id).collect();
        let windows: Vec<&str> = rows.iter().map(|(_, w, _, _)| *w).collect();
        let counts: Vec<i64> = rows.iter().map(|(_, _, c, _)| *c).collect();
        let expires_ats: Vec<DateTime<Utc>> = rows.iter().map(|(_, _, _, e)| *e).collect();

        let ids = sqlx::query_scalar!(
            r#"
            INSERT INTO batch_capacity_reservations
                (model_id, completion_window, reserved_requests, expires_at)
            SELECT * FROM UNNEST($1::uuid[], $2::text[], $3::bigint[], $4::timestamptz[])
            RETURNING id
            "#,
            &model_ids,
            &windows as &[&str],
            &counts,
            &expires_ats as &[DateTime<Utc>],
        )
        .fetch_all(&mut *self.db)
        .await?;

        Ok(ids)
    }

    #[instrument(skip(self, ids), fields(count = ids.len()), err)]
    pub async fn release_reservations(&mut self, ids: &[Uuid]) -> Result<()> {
        if ids.is_empty() {
            return Ok(());
        }

        sqlx::query!(
            r#"
            UPDATE batch_capacity_reservations
            SET released_at = now()
            WHERE id = ANY($1)
            "#,
            ids
        )
        .execute(&mut *self.db)
        .await?;
        Ok(())
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use crate::api::models::users::Role;
    use crate::test::utils::{create_test_endpoint, create_test_model, create_test_user};
    use chrono::{Duration, Utc};
    use sqlx::PgPool;
    use std::collections::HashMap;
    use uuid::Uuid;

    /// Creates a user, an endpoint owned by that user, and two models on the
    /// endpoint; returns the two model ids.
    async fn setup_models(pool: &PgPool) -> (Uuid, Uuid) {
        let owner = create_test_user(pool, Role::StandardUser).await;

        let endpoint_name = format!("test-{}", Uuid::new_v4());
        let endpoint_id = create_test_endpoint(pool, &endpoint_name, owner.id).await;

        let alias_a = format!("alias-a-{}", Uuid::new_v4());
        let first = create_test_model(pool, "model-a", &alias_a, endpoint_id, owner.id).await;

        let alias_b = format!("alias-b-{}", Uuid::new_v4());
        let second = create_test_model(pool, "model-b", &alias_b, endpoint_id, owner.id).await;

        (first, second)
    }

    /// Looks up the summed value reported for `model_id`, treating a model
    /// that is missing from `rows` as zero.
    fn total_for(rows: &[(Uuid, i64)], model_id: Uuid) -> i64 {
        rows.iter()
            .find(|(id, _)| *id == model_id)
            .map_or(0, |(_, total)| *total)
    }

    /// Two live reservations on different models are summed independently.
    #[sqlx::test]
    #[test_log::test]
    async fn test_insert_and_sum_active_reservations(pool: PgPool) {
        let (model_a, model_b) = setup_models(&pool).await;

        let in_ten_minutes = Utc::now() + Duration::minutes(10);

        let mut conn = pool.acquire().await.unwrap();
        let mut repo = BatchCapacityReservations::new(&mut conn);

        let batch = [(model_a, "24h", 10, in_ten_minutes), (model_b, "24h", 20, in_ten_minutes)];
        let inserted = repo.insert_reservations(&batch).await.unwrap();

        assert_eq!(inserted.len(), 2);

        let totals: HashMap<Uuid, i64> = repo
            .sum_active_by_model_window(&[model_a, model_b], "24h")
            .await
            .unwrap()
            .into_iter()
            .collect();

        assert_eq!(totals.get(&model_a).copied().unwrap_or(0), 10);
        assert_eq!(totals.get(&model_b).copied().unwrap_or(0), 20);
    }

    /// A reservation that has been released no longer contributes to the sum.
    #[sqlx::test]
    #[test_log::test]
    async fn test_release_reservations_excluded_from_sum(pool: PgPool) {
        let (model_a, _) = setup_models(&pool).await;

        let in_ten_minutes = Utc::now() + Duration::minutes(10);

        let mut conn = pool.acquire().await.unwrap();
        let mut repo = BatchCapacityReservations::new(&mut conn);

        let inserted = repo.insert_reservations(&[(model_a, "24h", 15, in_ten_minutes)]).await.unwrap();

        repo.release_reservations(&inserted).await.unwrap();

        let rows = repo.sum_active_by_model_window(&[model_a], "24h").await.unwrap();

        assert_eq!(total_for(&rows, model_a), 0);
    }

    /// A reservation whose expiry is already in the past is ignored by the sum.
    #[sqlx::test]
    #[test_log::test]
    async fn test_expired_reservations_excluded_from_sum(pool: PgPool) {
        let (model_a, _) = setup_models(&pool).await;

        let a_minute_ago = Utc::now() - Duration::minutes(1);

        let mut conn = pool.acquire().await.unwrap();
        let mut repo = BatchCapacityReservations::new(&mut conn);

        repo.insert_reservations(&[(model_a, "24h", 25, a_minute_ago)]).await.unwrap();

        let rows = repo.sum_active_by_model_window(&[model_a], "24h").await.unwrap();

        assert_eq!(total_for(&rows, model_a), 0);
    }
}