systemprompt-ai 0.1.21

Core AI module for systemprompt.io
Documentation
use crate::error::{AiError, Result};
use crate::models::image_generation::{ImageGenerationRequest, ImageGenerationResponse};
use crate::repository::AiRequestRepository;
use crate::services::providers::image_provider_trait::{
    BoxedImageProvider, ImageProviderCapabilities,
};
use crate::services::storage::{ImageStorage, StorageConfig};
use std::collections::HashMap;
use std::sync::Arc;
use systemprompt_database::DbPool;
use systemprompt_identifiers::UserId;
use systemprompt_traits::{AiGeneratedFile, DynAiFilePersistenceProvider};
use tracing::error;
use uuid::Uuid;

use super::image_persistence;

pub struct ImageService {
    providers: HashMap<String, BoxedImageProvider>,
    storage: Arc<ImageStorage>,
    file_provider: DynAiFilePersistenceProvider,
    ai_request_repo: AiRequestRepository,
    default_provider: Option<String>,
}

impl std::fmt::Debug for ImageService {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ImageService")
            .field("providers", &format!("{} providers", self.providers.len()))
            .field("storage", &self.storage)
            .field("file_provider", &"AiFilePersistenceProvider")
            .field("ai_request_repo", &"AiRequestRepository")
            .field("default_provider", &self.default_provider)
            .finish()
    }
}

impl ImageService {
    pub fn new(
        db_pool: &DbPool,
        storage_config: StorageConfig,
        file_provider: DynAiFilePersistenceProvider,
    ) -> Result<Self> {
        let storage = Arc::new(ImageStorage::new(storage_config)?);
        let ai_request_repo = AiRequestRepository::new(db_pool)?;

        Ok(Self {
            providers: HashMap::new(),
            storage,
            file_provider,
            ai_request_repo,
            default_provider: None,
        })
    }

    pub fn with_providers(
        db_pool: &DbPool,
        storage_config: StorageConfig,
        file_provider: DynAiFilePersistenceProvider,
        providers: HashMap<String, BoxedImageProvider>,
        default_provider: Option<String>,
    ) -> Result<Self> {
        let storage = Arc::new(ImageStorage::new(storage_config)?);
        let ai_request_repo = AiRequestRepository::new(db_pool)?;

        Ok(Self {
            providers,
            storage,
            file_provider,
            ai_request_repo,
            default_provider,
        })
    }

    pub fn register_provider(&mut self, provider: BoxedImageProvider) {
        let name = provider.name().to_string();
        self.providers.insert(name, provider);
    }

    pub fn set_default_provider(&mut self, provider_name: String) {
        self.default_provider = Some(provider_name);
    }

    pub fn get_provider(&self, name: &str) -> Option<&BoxedImageProvider> {
        self.providers.get(name)
    }

    pub fn list_providers(&self) -> Vec<String> {
        self.providers.keys().cloned().collect()
    }

    pub fn default_provider_name(&self) -> Option<&str> {
        self.default_provider.as_deref()
    }

    pub fn get_default_provider(&self) -> Option<&BoxedImageProvider> {
        self.default_provider
            .as_ref()
            .and_then(|name| self.providers.get(name))
    }

    pub fn default_provider_capabilities(&self) -> Option<ImageProviderCapabilities> {
        self.get_default_provider().map(|p| p.capabilities())
    }

    pub async fn generate_image(
        &self,
        mut request: ImageGenerationRequest,
    ) -> Result<ImageGenerationResponse> {
        let provider_name = if let Some(model) = &request.model {
            self.find_provider_for_model(model)?
        } else if let Some(default) = &self.default_provider {
            default.clone()
        } else {
            return Err(AiError::ConfigurationError(
                "No model specified and no default provider configured".to_string(),
            ));
        };

        let provider =
            self.providers
                .get(&provider_name)
                .ok_or_else(|| AiError::ProviderError {
                    provider: provider_name.clone(),
                    message: "Provider not found".to_string(),
                })?;

        if request.trace_id.is_none() {
            request.trace_id = Some(Uuid::new_v4().to_string());
        }

        let generation_result = provider.generate_image(&request).await;

        let mut response = match generation_result {
            Ok(resp) => resp,
            Err(e) => {
                error!(
                    error = %e,
                    provider = %provider_name,
                    model = ?request.model,
                    prompt_preview = %request.prompt.chars().take(200).collect::<String>(),
                    prompt_length = request.prompt.len(),
                    resolution = %request.resolution.as_str(),
                    aspect_ratio = %request.aspect_ratio.as_str(),
                    trace_id = ?request.trace_id,
                    user_id = ?request.user_id,
                    session_id = ?request.session_id,
                    reference_images_count = request.reference_images.len(),
                    "Image generation failed - full request context logged for debugging"
                );
                return Err(e);
            },
        };
        response.cost_estimate = Some(provider.capabilities().cost_per_image_cents);

        let (file_path, public_url) = self
            .storage
            .save_base64_image(&response.image_data, &response.mime_type)?;

        response.file_path = Some(file_path.to_string_lossy().to_string());
        response.public_url = Some(public_url.clone());
        response.file_size_bytes = Some(response.image_data.len());

        image_persistence::persist_image_generation(
            &self.ai_request_repo,
            self.file_provider.as_ref(),
            &request,
            &response,
            image_persistence::FileLocation {
                path: &file_path.to_string_lossy(),
                public_url: &public_url,
            },
        )
        .await?;

        Ok(response)
    }

    pub async fn generate_batch(
        &self,
        requests: Vec<ImageGenerationRequest>,
    ) -> Result<Vec<ImageGenerationResponse>> {
        let mut responses = Vec::new();

        for request in requests {
            match self.generate_image(request).await {
                Ok(response) => responses.push(response),
                Err(e) => {
                    return Err(e);
                },
            }
        }

        Ok(responses)
    }

    pub async fn get_generated_image(&self, uuid: &str) -> Result<Option<AiGeneratedFile>> {
        image_persistence::get_generated_image(self.file_provider.as_ref(), uuid).await
    }

    pub async fn list_user_images(
        &self,
        user_id: &UserId,
        limit: Option<i64>,
        offset: Option<i64>,
    ) -> Result<Vec<AiGeneratedFile>> {
        image_persistence::list_user_images(self.file_provider.as_ref(), user_id, limit, offset)
            .await
    }

    pub async fn delete_image(&self, uuid: &str) -> Result<()> {
        image_persistence::delete_image(self.file_provider.as_ref(), &self.storage, uuid).await
    }

    fn find_provider_for_model(&self, model: &str) -> Result<String> {
        for (name, provider) in &self.providers {
            if provider.supports_model(model) {
                return Ok(name.clone());
            }
        }

        Err(AiError::ProviderError {
            provider: "unknown".to_string(),
            message: format!("No provider found for model: {model}"),
        })
    }
}