// systemprompt-ai 0.1.19
//
// Core AI module for systemprompt.io
// Documentation
use anyhow::{Result, anyhow};
use std::collections::HashMap;
use std::sync::Arc;
use systemprompt_models::services::AiProviderConfig;

use super::{BoxedImageProvider, GeminiImageProvider, OpenAiImageProvider};

/// Stateless factory for image-generation providers.
///
/// The type carries no data; all functionality lives in associated functions
/// that build providers from an `AiProviderConfig`.
#[derive(Clone, Copy, Debug)]
pub struct ImageProviderFactory;

impl ImageProviderFactory {
    /// Provider names tried, in this order, by [`Self::create_with_fallback`]
    /// when the primary provider cannot generate images.
    const FALLBACK_ORDER: &'static [&'static str] = &["openai", "gemini"];

    /// Builds the image provider registered under `name` from `config`.
    ///
    /// # Errors
    ///
    /// Returns an error when `config.enabled` is `false`, or when `name` is
    /// not one of the known image providers (`"gemini"`, `"openai"`). The
    /// enabled check runs first, so a disabled provider with an unknown name
    /// is reported as disabled.
    pub fn create(name: &str, config: &AiProviderConfig) -> Result<BoxedImageProvider> {
        if !config.enabled {
            return Err(anyhow!("Image provider {name} is disabled"));
        }

        match name {
            "gemini" => Ok(Self::create_gemini(config)),
            "openai" => Ok(Self::create_openai(config)),
            _ => Err(anyhow!("Unknown image provider: {name}")),
        }
    }

    /// Builds the provider `name`, falling back to another configured image
    /// provider when `name` itself has no image-generation support.
    ///
    /// The fallback is only attempted when `name` is rejected by
    /// [`Self::supports_image_generation`]. If `name` is an image-capable
    /// provider that failed to build (for example because it is disabled),
    /// that failure is returned as-is and no substitute is chosen.
    ///
    /// Fallback candidates are looked up in `all_configs` in the order
    /// `"openai"`, then `"gemini"`; the first one that builds successfully is
    /// returned and an info-level event is logged.
    ///
    /// # Errors
    ///
    /// Returns the primary provider's error when `name` supports image
    /// generation but could not be created, or a "No image provider
    /// available" error when no fallback candidate could be created.
    pub fn create_with_fallback(
        name: &str,
        config: &AiProviderConfig,
        all_configs: &HashMap<String, AiProviderConfig>,
    ) -> Result<BoxedImageProvider> {
        let primary_error = match Self::create(name, config) {
            Ok(provider) => return Ok(provider),
            Err(e) => e,
        };

        // An image-capable primary that failed is surfaced to the caller
        // rather than being silently replaced by a different provider.
        if Self::supports_image_generation(name) {
            return Err(primary_error);
        }

        // `create` already rejects disabled configs, so no separate
        // `enabled` check is needed on the candidates.
        let fallback = Self::FALLBACK_ORDER.iter().copied().find_map(|candidate| {
            let candidate_config = all_configs.get(candidate)?;
            let provider = Self::create(candidate, candidate_config).ok()?;
            Some((candidate, provider))
        });

        match fallback {
            Some((fallback_name, provider)) => {
                tracing::info!(
                    primary = %name,
                    fallback = %fallback_name,
                    "Using fallback image provider"
                );
                Ok(provider)
            },
            None => Err(anyhow!(
                "No image provider available (primary: {name} does not support images)"
            )),
        }
    }

    /// Returns `true` when `provider_name` is `"openai"` or `"gemini"`.
    pub fn supports_image_generation(provider_name: &str) -> bool {
        matches!(provider_name, "openai" | "gemini")
    }

    /// Builds a Gemini image provider from `config`.
    ///
    /// Uses `config.endpoint` when one is set, applies `config.models` as the
    /// model definitions, and overrides the default model only when
    /// `config.default_image_model` is non-empty.
    fn create_gemini(config: &AiProviderConfig) -> BoxedImageProvider {
        let base = match config.endpoint.as_ref() {
            Some(endpoint) => {
                GeminiImageProvider::with_endpoint(config.api_key.clone(), endpoint.clone())
            },
            None => GeminiImageProvider::new(config.api_key.clone()),
        };

        let provider = base.with_model_definitions(config.models.clone());

        // An empty string means "no override": keep whatever default the
        // provider was constructed with.
        let default_model = config.default_image_model.as_str();
        let provider = if default_model.is_empty() {
            provider
        } else {
            provider.with_default_model(default_model.to_string())
        };

        Arc::new(provider)
    }

    /// Builds an OpenAI image provider from `config`.
    ///
    /// Uses `config.endpoint` when one is set and overrides the default model
    /// only when `config.default_image_model` is non-empty.
    fn create_openai(config: &AiProviderConfig) -> BoxedImageProvider {
        let base = match config.endpoint.as_ref() {
            Some(endpoint) => {
                OpenAiImageProvider::with_endpoint(config.api_key.clone(), endpoint.clone())
            },
            None => OpenAiImageProvider::new(config.api_key.clone()),
        };

        // An empty string means "no override": keep whatever default the
        // provider was constructed with.
        let default_model = config.default_image_model.as_str();
        let provider = if default_model.is_empty() {
            base
        } else {
            base.with_default_model(default_model.to_string())
        };

        Arc::new(provider)
    }

    /// Builds every enabled, image-capable provider found in `configs`,
    /// keyed by provider name.
    ///
    /// Disabled entries are ignored. Enabled entries whose name is rejected
    /// by [`Self::supports_image_generation`] are skipped with a debug-level
    /// event instead of a warning: the map may legitimately contain providers
    /// with no image backend, and that is not a failure. Any other creation
    /// error is logged as a warning and the provider is left out.
    ///
    /// # Errors
    ///
    /// Returns an error when no provider at all could be created.
    pub fn create_all(
        configs: &HashMap<String, AiProviderConfig>,
    ) -> Result<HashMap<String, BoxedImageProvider>> {
        let mut providers = HashMap::new();

        for (name, config) in configs.iter().filter(|(_, c)| c.enabled) {
            if !Self::supports_image_generation(name) {
                tracing::debug!(
                    provider = %name,
                    "Skipping provider without image generation support"
                );
                continue;
            }

            match Self::create(name, config) {
                Ok(provider) => {
                    providers.insert(name.clone(), provider);
                },
                Err(e) => {
                    tracing::warn!(provider = %name, error = %e, "Failed to create image provider");
                },
            }
        }

        if providers.is_empty() {
            return Err(anyhow!("No image providers could be initialized"));
        }

        Ok(providers)
    }
}