systemprompt-files 0.1.21

File management module for systemprompt.io
Documentation
mod request;
mod stats;

pub use request::InsertFileRequest;
pub use stats::FileStats;

use std::sync::Arc;

use anyhow::{Context, Result};
use chrono::Utc;
use sqlx::PgPool;
use systemprompt_database::DbPool;
use systemprompt_identifiers::{ContextId, FileId, SessionId, TraceId, UserId};

use crate::models::{File, FileMetadata};

/// Repository for rows of the `files` table.
///
/// Holds two connection pools: one for read queries and one for statements
/// that modify data.
#[derive(Debug, Clone)]
pub struct FileRepository {
    /// Pool used by the read-only queries in this module (`find_*`, `list_*`, `search_*`).
    pub(crate) pool: Arc<PgPool>,
    /// Pool used by the mutating statements in this module (`insert`, `delete`, `update_metadata`).
    write_pool: Arc<PgPool>,
}

impl FileRepository {
    /// Builds a repository from the shared database handle.
    ///
    /// # Errors
    ///
    /// Propagates the failure if either `db.pool_arc()` or `db.write_pool_arc()`
    /// returns an error.
    pub fn new(db: &DbPool) -> Result<Self> {
        Ok(Self {
            pool: db.pool_arc()?,
            write_pool: db.write_pool_arc()?,
        })
    }

    pub async fn insert(&self, request: InsertFileRequest) -> Result<FileId> {
        let id_uuid = uuid::Uuid::parse_str(request.id.as_str())
            .with_context(|| format!("Invalid UUID for file id: {}", request.id.as_str()))?;
        let now = Utc::now();

        let user_id_str = request.user_id.as_ref().map(UserId::as_str);
        let session_id_str = request.session_id.as_ref().map(SessionId::as_str);
        let trace_id_str = request.trace_id.as_ref().map(TraceId::as_str);
        let context_id_str = request.context_id.as_ref().map(ContextId::as_str);

        sqlx::query_as!(
            File,
            r#"
            INSERT INTO files (id, path, public_url, mime_type, size_bytes, ai_content, metadata, user_id, session_id, trace_id, context_id, created_at, updated_at)
            VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $12)
            ON CONFLICT (path) DO UPDATE SET
                public_url = EXCLUDED.public_url,
                mime_type = EXCLUDED.mime_type,
                size_bytes = EXCLUDED.size_bytes,
                ai_content = EXCLUDED.ai_content,
                metadata = EXCLUDED.metadata,
                updated_at = EXCLUDED.updated_at
            RETURNING id, path, public_url, mime_type, size_bytes, ai_content, metadata, user_id as "user_id: UserId", session_id as "session_id: SessionId", trace_id as "trace_id: TraceId", context_id as "context_id: ContextId", created_at, updated_at, deleted_at
            "#,
            id_uuid,
            request.path,
            request.public_url,
            request.mime_type,
            request.size_bytes,
            request.ai_content,
            request.metadata,
            user_id_str,
            session_id_str,
            trace_id_str,
            context_id_str,
            now
        )
        .fetch_one(&*self.write_pool)
        .await
        .with_context(|| {
            format!(
                "Failed to insert file (id: {}, path: {}, url: {})",
                request.id.as_str(),
                request.path,
                request.public_url
            )
        })?;

        Ok(request.id)
    }

    /// Persists an existing [`File`] value by converting it into an
    /// [`InsertFileRequest`] and delegating to [`FileRepository::insert`].
    ///
    /// Optional fields (`size_bytes`, `user_id`, `session_id`, `trace_id`,
    /// `context_id`) are only set on the request when present on `file`.
    ///
    /// # Errors
    ///
    /// Returns whatever error [`FileRepository::insert`] returns.
    pub async fn insert_file(&self, file: &File) -> Result<FileId> {
        // The id is moved straight into the request; it is not needed afterwards,
        // so no clone is required.
        let mut request = InsertFileRequest::new(
            FileId::new(file.id.to_string()),
            file.path.clone(),
            file.public_url.clone(),
            file.mime_type.clone(),
        )
        .with_ai_content(file.ai_content)
        .with_metadata(file.metadata.clone());

        if let Some(size) = file.size_bytes {
            request = request.with_size(size);
        }

        if let Some(user_id) = &file.user_id {
            request = request.with_user_id(user_id.clone());
        }

        if let Some(session_id) = &file.session_id {
            request = request.with_session_id(session_id.clone());
        }

        if let Some(trace_id) = &file.trace_id {
            request = request.with_trace_id(trace_id.clone());
        }

        if let Some(context_id) = &file.context_id {
            request = request.with_context_id(context_id.clone());
        }

        self.insert(request).await
    }

    /// Looks up a single non-deleted file (`deleted_at IS NULL`) by its id.
    ///
    /// # Returns
    ///
    /// `Ok(Some(file))` when a matching row exists, `Ok(None)` otherwise.
    ///
    /// # Errors
    ///
    /// Returns an error if `id` is not a valid UUID or if the query fails.
    pub async fn find_by_id(&self, id: &FileId) -> Result<Option<File>> {
        let id_uuid = uuid::Uuid::parse_str(id.as_str()).context("Invalid UUID for file id")?;

        sqlx::query_as!(
            File,
            r#"
            SELECT id, path, public_url, mime_type, size_bytes, ai_content, metadata, user_id as "user_id: UserId", session_id as "session_id: SessionId", trace_id as "trace_id: TraceId", context_id as "context_id: ContextId", created_at, updated_at, deleted_at
            FROM files
            WHERE id = $1 AND deleted_at IS NULL
            "#,
            id_uuid
        )
        .fetch_optional(&*self.pool)
        .await
        // Build the message lazily so the success path does not allocate.
        .with_context(|| format!("Failed to find file by id: {id}"))
    }

    /// Looks up a single non-deleted file (`deleted_at IS NULL`) by exact path.
    ///
    /// # Returns
    ///
    /// `Ok(Some(file))` when a matching row exists, `Ok(None)` otherwise.
    ///
    /// # Errors
    ///
    /// Returns an error if the query fails.
    pub async fn find_by_path(&self, path: &str) -> Result<Option<File>> {
        sqlx::query_as!(
            File,
            r#"
            SELECT id, path, public_url, mime_type, size_bytes, ai_content, metadata, user_id as "user_id: UserId", session_id as "session_id: SessionId", trace_id as "trace_id: TraceId", context_id as "context_id: ContextId", created_at, updated_at, deleted_at
            FROM files
            WHERE path = $1 AND deleted_at IS NULL
            "#,
            path
        )
        .fetch_optional(&*self.pool)
        .await
        // Build the message lazily so the success path does not allocate.
        .with_context(|| format!("Failed to find file by path: {path}"))
    }

    /// Returns one page of a user's non-deleted files, newest first.
    ///
    /// Rows are ordered by `created_at DESC`; `limit` and `offset` are passed
    /// through to the SQL `LIMIT` / `OFFSET` clauses unchanged.
    ///
    /// # Errors
    ///
    /// Returns an error if the query fails.
    pub async fn list_by_user(
        &self,
        user_id: &UserId,
        limit: i64,
        offset: i64,
    ) -> Result<Vec<File>> {
        let user_id_str = user_id.as_str();
        sqlx::query_as!(
            File,
            r#"
            SELECT id, path, public_url, mime_type, size_bytes, ai_content, metadata, user_id as "user_id: UserId", session_id as "session_id: SessionId", trace_id as "trace_id: TraceId", context_id as "context_id: ContextId", created_at, updated_at, deleted_at
            FROM files
            WHERE user_id = $1 AND deleted_at IS NULL
            ORDER BY created_at DESC
            LIMIT $2 OFFSET $3
            "#,
            user_id_str,
            limit,
            offset
        )
        .fetch_all(&*self.pool)
        .await
        // Build the message lazily so the success path does not allocate.
        .with_context(|| format!("Failed to list files for user: {user_id}"))
    }

    pub async fn list_all(&self, limit: i64, offset: i64) -> Result<Vec<File>> {
        sqlx::query_as!(
            File,
            r#"
            SELECT id, path, public_url, mime_type, size_bytes, ai_content, metadata, user_id as "user_id: UserId", session_id as "session_id: SessionId", trace_id as "trace_id: TraceId", context_id as "context_id: ContextId", created_at, updated_at, deleted_at
            FROM files
            WHERE deleted_at IS NULL
            ORDER BY created_at DESC
            LIMIT $1 OFFSET $2
            "#,
            limit,
            offset
        )
        .fetch_all(&*self.pool)
        .await
        .context("Failed to list all files")
    }

    /// Permanently removes the row with the given id from `files`.
    ///
    /// This issues a SQL `DELETE`; it does not set `deleted_at`. The number of
    /// affected rows is not inspected, so deleting an id that does not exist
    /// still returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns an error if `id` is not a valid UUID or if the statement fails.
    pub async fn delete(&self, id: &FileId) -> Result<()> {
        let id_uuid = uuid::Uuid::parse_str(id.as_str()).context("Invalid UUID for file id")?;

        sqlx::query!(
            r#"
            DELETE FROM files
            WHERE id = $1
            "#,
            id_uuid
        )
        .execute(&*self.write_pool)
        .await
        // Build the message lazily so the success path does not allocate.
        .with_context(|| format!("Failed to delete file: {id}"))?;

        Ok(())
    }

    /// Replaces the `metadata` column of the given file and bumps `updated_at`
    /// to the current UTC time.
    ///
    /// The statement matches on `id` only (it does not filter on `deleted_at`),
    /// and the number of affected rows is not inspected, so an id that does not
    /// exist still returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Returns an error if `id` is not a valid UUID, if `metadata` cannot be
    /// serialized to JSON, or if the statement fails.
    pub async fn update_metadata(&self, id: &FileId, metadata: &FileMetadata) -> Result<()> {
        let id_uuid = uuid::Uuid::parse_str(id.as_str()).context("Invalid UUID for file id")?;
        let metadata_json = serde_json::to_value(metadata)
            .with_context(|| format!("Failed to serialize metadata for file: {id}"))?;
        let now = Utc::now();

        sqlx::query!(
            r#"
            UPDATE files
            SET metadata = $1, updated_at = $2
            WHERE id = $3
            "#,
            metadata_json,
            now,
            id_uuid
        )
        .execute(&*self.write_pool)
        .await
        // Build the message lazily so the success path does not allocate.
        .with_context(|| format!("Failed to update metadata for file: {id}"))?;

        Ok(())
    }

    /// Case-insensitive substring search over `path` for non-deleted files.
    ///
    /// `query` is matched literally: the `ILIKE` metacharacters `%` and `_`,
    /// and the escape character `\`, are escaped before the pattern is built,
    /// so a query such as `"a_b"` only matches paths containing `a_b`.
    /// Results are ordered by `created_at DESC` and capped by `limit`.
    ///
    /// # Errors
    ///
    /// Returns an error if the query fails.
    pub async fn search_by_path(&self, query: &str, limit: i64) -> Result<Vec<File>> {
        // Backslash is the default escape character for LIKE/ILIKE in
        // PostgreSQL, so prefixing each metacharacter with it makes the
        // user-supplied text match verbatim.
        let mut pattern = String::with_capacity(query.len() + 2);
        pattern.push('%');
        for ch in query.chars() {
            if matches!(ch, '\\' | '%' | '_') {
                pattern.push('\\');
            }
            pattern.push(ch);
        }
        pattern.push('%');

        sqlx::query_as!(
            File,
            r#"
            SELECT id, path, public_url, mime_type, size_bytes, ai_content, metadata,
                   user_id as "user_id: UserId", session_id as "session_id: SessionId",
                   trace_id as "trace_id: TraceId", context_id as "context_id: ContextId",
                   created_at, updated_at, deleted_at
            FROM files
            WHERE path ILIKE $1 AND deleted_at IS NULL
            ORDER BY created_at DESC
            LIMIT $2
            "#,
            pattern,
            limit
        )
        .fetch_all(&*self.pool)
        .await
        // Build the message lazily so the success path does not allocate.
        .with_context(|| format!("Failed to search files by path: {query}"))
    }
}