systemprompt-ai 0.1.21

Core AI module for systemprompt.io
Documentation
use crate::error::RepositoryError;
use crate::models::{AiRequest, ProviderUsage, UserAiUsage};
use chrono::Utc;
use systemprompt_identifiers::{AiRequestId, SessionId, UserId};

use super::AiRequestRepository;

impl AiRequestRepository {
    #[must_use = "this returns a Result that should not be ignored"]
    pub async fn get_by_id(&self, id: &AiRequestId) -> Result<Option<AiRequest>, RepositoryError> {
        sqlx::query_as!(
            AiRequest,
            r#"
            SELECT id, request_id, user_id, session_id, task_id, context_id, trace_id,
                   provider, model, temperature, top_p, max_tokens, tokens_used,
                   input_tokens, output_tokens, cost_microdollars, latency_ms, cache_hit,
                   cache_read_tokens, cache_creation_tokens, is_streaming, status,
                   error_message, created_at, updated_at, completed_at
            FROM ai_requests
            WHERE id = $1
            "#,
            id.as_str()
        )
        .fetch_optional(self.pool())
        .await
        .map_err(RepositoryError::from)
    }

    pub async fn get_provider_usage(
        &self,
        days: i32,
    ) -> Result<Vec<ProviderUsage>, RepositoryError> {
        let cutoff = Utc::now() - chrono::Duration::days(i64::from(days));
        sqlx::query_as!(
            ProviderUsage,
            r#"
            SELECT
                provider,
                model,
                COUNT(*)::bigint as "request_count!",
                COALESCE(SUM(tokens_used), 0)::bigint as "total_tokens!",
                COALESCE(SUM(cost_microdollars), 0)::float8 / 1000000.0 as "total_cost!",
                AVG(latency_ms)::bigint as "avg_latency_ms"
            FROM ai_requests
            WHERE created_at > $1 AND status = 'completed'
            GROUP BY provider, model
            ORDER BY COUNT(*) DESC
            "#,
            cutoff
        )
        .fetch_all(self.pool())
        .await
        .map_err(RepositoryError::from)
    }

    pub async fn get_user_usage(&self, user_id: &UserId) -> Result<UserAiUsage, RepositoryError> {
        sqlx::query_as!(
            UserAiUsage,
            r#"
            SELECT
                user_id as "user_id!: UserId",
                COUNT(*)::bigint as "request_count!",
                COALESCE(SUM(tokens_used), 0)::bigint as "total_tokens!",
                COALESCE(SUM(cost_microdollars), 0)::float8 / 1000000.0 as "total_cost!",
                AVG(tokens_used)::float8 as "avg_tokens_per_request"
            FROM ai_requests
            WHERE user_id = $1
            GROUP BY user_id
            "#,
            user_id.as_str()
        )
        .fetch_one(self.pool())
        .await
        .map_err(RepositoryError::from)
    }

    pub async fn get_session_usage(
        &self,
        session_id: &SessionId,
    ) -> Result<UserAiUsage, RepositoryError> {
        sqlx::query_as!(
            UserAiUsage,
            r#"
            SELECT
                user_id as "user_id!: UserId",
                COUNT(*)::bigint as "request_count!",
                COALESCE(SUM(tokens_used), 0)::bigint as "total_tokens!",
                COALESCE(SUM(cost_microdollars), 0)::float8 / 1000000.0 as "total_cost!",
                AVG(tokens_used)::float8 as "avg_tokens_per_request"
            FROM ai_requests
            WHERE session_id = $1
            GROUP BY user_id
            "#,
            session_id.as_str()
        )
        .fetch_one(self.pool())
        .await
        .map_err(RepositoryError::from)
    }
}