systemprompt-ai 0.1.21

Core AI module for systemprompt.io
Documentation
use crate::error::{AiError, Result};
use crate::models::AiRequestRecordBuilder;
use crate::models::image_generation::{ImageGenerationRequest, ImageGenerationResponse};
use crate::repository::AiRequestRepository;
use systemprompt_identifiers::{FileId, McpExecutionId, SessionId, TraceId, UserId};
use systemprompt_traits::{
    AiFilePersistenceProvider, AiGeneratedFile, ImageGenerationInfo, ImageMetadata,
    InsertAiFileParams,
};
use uuid::Uuid;

pub struct FileLocation<'a> {
    pub path: &'a str,
    pub public_url: &'a str,
}

pub async fn persist_image_generation(
    ai_request_repo: &AiRequestRepository,
    file_provider: &dyn AiFilePersistenceProvider,
    request: &ImageGenerationRequest,
    response: &ImageGenerationResponse,
    location: FileLocation<'_>,
) -> Result<()> {
    persist_ai_request(ai_request_repo, request, response).await?;
    persist_file_record(
        file_provider,
        request,
        response,
        location.path,
        location.public_url,
    )
    .await
}

async fn persist_ai_request(
    ai_request_repo: &AiRequestRepository,
    request: &ImageGenerationRequest,
    response: &ImageGenerationResponse,
) -> Result<()> {
    let user_id = UserId::new(request.user_id.as_deref().unwrap_or("anonymous"));

    let mut builder = AiRequestRecordBuilder::new(&response.request_id, user_id)
        .provider(&response.provider)
        .model(&response.model)
        .cost(response.cost_estimate.map_or(0, |c| c.round() as i64))
        .latency(response.generation_time_ms as i32)
        .completed();

    if let Some(session_id) = &request.session_id {
        builder = builder.session_id(SessionId::new(session_id));
    }

    if let Some(trace_id) = &request.trace_id {
        builder = builder.trace_id(TraceId::new(trace_id));
    }

    if let Some(mcp_execution_id) = &request.mcp_execution_id {
        builder = builder.mcp_execution_id(McpExecutionId::new(mcp_execution_id));
    }

    let record = builder
        .build()
        .map_err(|e| AiError::InvalidInput(e.to_string()))?;

    ai_request_repo
        .insert(&record)
        .await
        .map(|_| ())
        .map_err(|e| AiError::DatabaseError(e.to_string()))
}

async fn persist_file_record(
    file_provider: &dyn AiFilePersistenceProvider,
    request: &ImageGenerationRequest,
    response: &ImageGenerationResponse,
    file_path: &str,
    public_url: &str,
) -> Result<()> {
    let generation_info =
        ImageGenerationInfo::new(&request.prompt, &response.model, &response.provider)
            .with_resolution(response.resolution.as_str())
            .with_aspect_ratio(response.aspect_ratio.as_str())
            .with_generation_time(response.generation_time_ms as i32)
            .with_request_id(&response.request_id);

    let generation_info = match response.cost_estimate {
        Some(cost) => generation_info.with_cost_estimate(cost),
        None => generation_info,
    };

    let image_metadata = ImageMetadata::new().with_generation(generation_info);
    let metadata = serde_json::to_value(image_metadata).map_err(AiError::SerializationError)?;

    let id = Uuid::parse_str(&response.id)
        .map_err(|e| AiError::InvalidInput(format!("Invalid UUID: {e}")))?;

    let params = InsertAiFileParams {
        id,
        path: file_path.to_string(),
        public_url: public_url.to_string(),
        mime_type: response.mime_type.clone(),
        size_bytes: response.file_size_bytes.map(|s| s as i64),
        metadata,
        user_id: request.user_id.as_ref().map(UserId::new),
        session_id: request.session_id.as_ref().map(SessionId::new),
        trace_id: request.trace_id.as_ref().map(TraceId::new),
        context_id: None,
    };

    file_provider
        .insert_file(params)
        .await
        .map_err(|e| AiError::DatabaseError(e.to_string()))
}

pub async fn get_generated_image(
    file_provider: &dyn AiFilePersistenceProvider,
    uuid: &str,
) -> Result<Option<AiGeneratedFile>> {
    file_provider
        .find_by_id(&FileId::new(uuid))
        .await
        .map_err(|e| AiError::DatabaseError(e.to_string()))
}

pub async fn list_user_images(
    file_provider: &dyn AiFilePersistenceProvider,
    user_id: &UserId,
    limit: Option<i64>,
    offset: Option<i64>,
) -> Result<Vec<AiGeneratedFile>> {
    let limit = limit.unwrap_or(50);
    let offset = offset.unwrap_or(0);
    file_provider
        .list_by_user(user_id, limit, offset)
        .await
        .map_err(|e| AiError::DatabaseError(e.to_string()))
}

pub async fn delete_image(
    file_provider: &dyn AiFilePersistenceProvider,
    storage: &crate::services::storage::ImageStorage,
    uuid: &str,
) -> Result<()> {
    let file_id = FileId::new(uuid);
    let file = file_provider
        .find_by_id(&file_id)
        .await
        .map_err(|e| AiError::DatabaseError(e.to_string()))?;

    if let Some(file_record) = file {
        let file_path = std::path::Path::new(&file_record.path);
        storage.delete_image(file_path)?;
        file_provider
            .delete(&file_id)
            .await
            .map_err(|e| AiError::DatabaseError(e.to_string()))?;
    }

    Ok(())
}