systemprompt-ai 0.1.21

Core AI module for systemprompt.io
Documentation
use crate::error::RepositoryError;
use crate::models::{AiRequest, AiRequestRecord, RequestStatus};
use systemprompt_identifiers::{AiRequestId, SessionId, TraceId};

use super::{AiRequestRepository, CreateAiRequest};

#[derive(Debug)]
pub struct UpdateCompletionParams {
    pub id: AiRequestId,
    pub tokens_used: i32,
    pub input_tokens: i32,
    pub output_tokens: i32,
    pub cost_microdollars: i64,
    pub latency_ms: i32,
}

impl AiRequestRepository {
    #[must_use = "this returns a Result that should not be ignored"]
    pub async fn create(&self, request: CreateAiRequest<'_>) -> Result<AiRequest, RepositoryError> {
        let id = AiRequestId::generate();
        let logical_request_id = AiRequestId::generate();

        sqlx::query_as!(
            AiRequest,
            r#"
            INSERT INTO ai_requests (
                id, request_id, user_id, session_id, trace_id, provider, model,
                temperature, max_tokens, status, created_at
            )
            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, CURRENT_TIMESTAMP)
            RETURNING id, request_id, user_id, session_id, task_id, context_id, trace_id,
                      provider, model, temperature, top_p, max_tokens, tokens_used,
                      input_tokens, output_tokens, cost_microdollars, latency_ms, cache_hit,
                      cache_read_tokens, cache_creation_tokens, is_streaming, status,
                      error_message, created_at, updated_at, completed_at
            "#,
            id.as_str(),
            logical_request_id.as_str(),
            request.user_id.as_str(),
            request.session_id.map(SessionId::as_str),
            request.trace_id.map(TraceId::as_str),
            request.provider,
            request.model,
            request.temperature,
            request.max_tokens,
            RequestStatus::Pending.as_str()
        )
        .fetch_one(self.write_pool())
        .await
        .map_err(RepositoryError::from)
    }

    #[must_use = "this returns a Result that should not be ignored"]
    pub async fn update_completion(
        &self,
        params: UpdateCompletionParams,
    ) -> Result<AiRequest, RepositoryError> {
        sqlx::query_as!(
            AiRequest,
            r#"
            UPDATE ai_requests
            SET tokens_used = $1, input_tokens = $2, output_tokens = $3,
                cost_microdollars = $4, latency_ms = $5, status = $6,
                completed_at = CURRENT_TIMESTAMP, updated_at = CURRENT_TIMESTAMP
            WHERE id = $7
            RETURNING id, request_id, user_id, session_id, task_id, context_id, trace_id,
                      provider, model, temperature, top_p, max_tokens, tokens_used,
                      input_tokens, output_tokens, cost_microdollars, latency_ms, cache_hit,
                      cache_read_tokens, cache_creation_tokens, is_streaming, status,
                      error_message, created_at, updated_at, completed_at
            "#,
            params.tokens_used,
            params.input_tokens,
            params.output_tokens,
            params.cost_microdollars,
            params.latency_ms,
            RequestStatus::Completed.as_str(),
            params.id.as_str()
        )
        .fetch_one(self.write_pool())
        .await
        .map_err(RepositoryError::from)
    }

    #[must_use = "this returns a Result that should not be ignored"]
    pub async fn update_error(
        &self,
        id: &AiRequestId,
        error_message: &str,
    ) -> Result<AiRequest, RepositoryError> {
        sqlx::query_as!(
            AiRequest,
            r#"
            UPDATE ai_requests
            SET status = $1, error_message = $2, completed_at = CURRENT_TIMESTAMP,
                updated_at = CURRENT_TIMESTAMP
            WHERE id = $3
            RETURNING id, request_id, user_id, session_id, task_id, context_id, trace_id,
                      provider, model, temperature, top_p, max_tokens, tokens_used,
                      input_tokens, output_tokens, cost_microdollars, latency_ms, cache_hit,
                      cache_read_tokens, cache_creation_tokens, is_streaming, status,
                      error_message, created_at, updated_at, completed_at
            "#,
            RequestStatus::Failed.as_str(),
            error_message,
            id.as_str()
        )
        .fetch_one(self.write_pool())
        .await
        .map_err(RepositoryError::from)
    }

    #[must_use = "this returns a Result that should not be ignored"]
    pub async fn insert(&self, record: &AiRequestRecord) -> Result<AiRequestId, RepositoryError> {
        use systemprompt_identifiers::{ContextId, McpExecutionId, SessionId, TaskId, TraceId};

        let id = AiRequestId::generate();
        let status = record.status.as_str();

        let use_completed_at = matches!(
            record.status,
            RequestStatus::Completed | RequestStatus::Failed
        );

        sqlx::query!(
            r#"
            INSERT INTO ai_requests (
                id, request_id, user_id, session_id, task_id, context_id, trace_id,
                mcp_execution_id, provider, model, max_tokens, tokens_used, input_tokens, output_tokens,
                cache_hit, cache_read_tokens, cache_creation_tokens, is_streaming,
                cost_microdollars, latency_ms, status, error_message,
                created_at, updated_at, completed_at
            )
            VALUES (
                $1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13, $14,
                $15, $16, $17, $18, $19, $20, $21, $22,
                CURRENT_TIMESTAMP, CURRENT_TIMESTAMP,
                CASE WHEN $23 THEN CURRENT_TIMESTAMP ELSE NULL END
            )
            "#,
            id.as_str(),
            record.request_id,
            record.user_id.as_str(),
            record.session_id.as_ref().map(SessionId::as_str),
            record.task_id.as_ref().map(TaskId::as_str),
            record.context_id.as_ref().map(ContextId::as_str),
            record.trace_id.as_ref().map(TraceId::as_str),
            record.mcp_execution_id.as_ref().map(McpExecutionId::as_str),
            record.provider,
            record.model,
            record.max_tokens,
            record.tokens.tokens_used,
            record.tokens.input_tokens,
            record.tokens.output_tokens,
            record.cache.hit,
            record.cache.read_tokens,
            record.cache.creation_tokens,
            record.is_streaming,
            record.cost_microdollars,
            record.latency_ms,
            status,
            record.error_message.as_deref(),
            use_completed_at
        )
        .execute(self.write_pool())
        .await?;
        Ok(id)
    }
}