// systemprompt-ai 0.1.18
//
// Core AI module for systemprompt.io.
use crate::error::{AiError, Result};
use crate::models::image_generation::{
    AspectRatio, ImageGenerationRequest, ImageGenerationResponse, ImageResolution,
    NewImageGenerationResponse,
};
use crate::services::providers::image_provider_trait::{ImageProvider, ImageProviderCapabilities};
use async_trait::async_trait;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::time::Instant;

/// Image-generation provider that talks to an OpenAI-compatible HTTP API.
///
/// Holds a reusable HTTP client, the API key sent as a bearer token, the base URL that
/// request paths are appended to, and the model used when a request does not name one.
pub struct OpenAiImageProvider {
    client: Client,
    api_key: String,
    endpoint: String,
    default_model: String,
}

// Hand-written instead of derived so the API key never ends up in logs or panic messages.
impl std::fmt::Debug for OpenAiImageProvider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAiImageProvider")
            .field("client", &self.client)
            .field("api_key", &"<redacted>")
            .field("endpoint", &self.endpoint)
            .field("default_model", &self.default_model)
            .finish()
    }
}

impl OpenAiImageProvider {
    /// Base URL used by [`OpenAiImageProvider::new`].
    const DEFAULT_ENDPOINT: &'static str = "https://api.openai.com/v1";
    /// Model used when neither the request nor `with_default_model` picks one.
    const DEFAULT_MODEL: &'static str = "gpt-image-1";
    /// Whole-request timeout applied to the HTTP client.
    const REQUEST_TIMEOUT: std::time::Duration = std::time::Duration::from_secs(120);

    /// Creates a provider pointed at the public OpenAI API with a 120 s request timeout.
    pub fn new(api_key: String) -> Self {
        // If the configured builder cannot be built, fall back to a default client.
        // NOTE(review): the fallback client has no timeout configured here.
        let client = Client::builder()
            .timeout(Self::REQUEST_TIMEOUT)
            .build()
            .unwrap_or_else(|_| Client::new());

        Self {
            client,
            api_key,
            endpoint: Self::DEFAULT_ENDPOINT.to_string(),
            default_model: Self::DEFAULT_MODEL.to_string(),
        }
    }

    /// Creates a provider that sends requests to `endpoint` (no trailing slash expected,
    /// since request paths are appended as `"{endpoint}/..."`).
    pub fn with_endpoint(api_key: String, endpoint: String) -> Self {
        Self {
            endpoint,
            ..Self::new(api_key)
        }
    }

    /// Replaces the model used for requests that do not specify one.
    pub fn with_default_model(mut self, model: String) -> Self {
        self.default_model = model;
        self
    }

    /// Maps an aspect ratio to the size string sent for DALL-E style models.
    ///
    /// Returns the square size for `Square`, the tall size for the portrait ratios, and the
    /// wide size for every other ratio. These three values match the DALL-E 3 size options.
    const fn map_size(aspect_ratio: &AspectRatio) -> &'static str {
        match aspect_ratio {
            AspectRatio::Square => "1024x1024",
            AspectRatio::Portrait916 | AspectRatio::Portrait34 => "1024x1792",
            _ => "1792x1024",
        }
    }
}

#[derive(Debug, Serialize)]
struct DalleRequest {
    model: String,
    prompt: String,
    size: String,
    quality: String,
    n: u32,
    response_format: String,
}

#[derive(Debug, Deserialize)]
struct DalleResponse {
    data: Vec<DalleImageData>,
}

#[derive(Debug, Deserialize)]
struct DalleImageData {
    b64_json: Option<String>,
}

#[async_trait]
impl ImageProvider for OpenAiImageProvider {
    fn name(&self) -> &'static str {
        "openai-image"
    }

    fn capabilities(&self) -> ImageProviderCapabilities {
        ImageProviderCapabilities {
            supported_resolutions: vec![ImageResolution::OneK],
            supported_aspect_ratios: vec![
                AspectRatio::Square,
                AspectRatio::Landscape169,
                AspectRatio::Portrait916,
            ],
            supports_batch: false,
            supports_image_editing: true,
            supports_search_grounding: false,
            max_prompt_length: 4000,
            cost_per_image_cents: 4.0,
        }
    }

    fn supported_models(&self) -> Vec<String> {
        vec![
            "gpt-image-1".to_string(),
            "gpt-image-1-mini".to_string(),
            "dall-e-3".to_string(),
            "dall-e-2".to_string(),
        ]
    }

    fn default_model(&self) -> &str {
        &self.default_model
    }

    async fn generate_image(
        &self,
        request: &ImageGenerationRequest,
    ) -> Result<ImageGenerationResponse> {
        let start = Instant::now();

        if request.prompt.len() > self.capabilities().max_prompt_length {
            return Err(AiError::ProviderError {
                provider: self.name().to_string(),
                message: format!(
                    "Prompt length {} exceeds maximum {}",
                    request.prompt.len(),
                    self.capabilities().max_prompt_length
                ),
            });
        }

        if !self.supports_resolution(&request.resolution) {
            return Err(AiError::ProviderError {
                provider: self.name().to_string(),
                message: format!("Resolution {} not supported", request.resolution.as_str()),
            });
        }

        if !self.supports_aspect_ratio(&request.aspect_ratio) {
            return Err(AiError::ProviderError {
                provider: self.name().to_string(),
                message: format!(
                    "Aspect ratio {} not supported",
                    request.aspect_ratio.as_str()
                ),
            });
        }

        let model = request
            .model
            .as_deref()
            .unwrap_or_else(|| self.default_model());

        if !self.supports_model(model) {
            return Err(AiError::ProviderError {
                provider: self.name().to_string(),
                message: format!("Model {model} not supported"),
            });
        }

        let dalle_request = DalleRequest {
            model: model.to_string(),
            prompt: request.prompt.clone(),
            size: Self::map_size(&request.aspect_ratio).to_string(),
            quality: "standard".to_string(),
            n: 1,
            response_format: "b64_json".to_string(),
        };

        let url = format!("{}/images/generations", self.endpoint);

        let response = self
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", &self.api_key))
            .header("Content-Type", "application/json")
            .json(&dalle_request)
            .send()
            .await
            .map_err(|e| AiError::ProviderError {
                provider: self.name().to_string(),
                message: format!("HTTP request failed: {e}"),
            })?;

        if !response.status().is_success() {
            let status = response.status();
            let error_body = response
                .text()
                .await
                .unwrap_or_else(|e| format!("<error reading response: {}>", e));
            return Err(AiError::ProviderError {
                provider: self.name().to_string(),
                message: format!("API returned status {status}: {error_body}"),
            });
        }

        let dalle_response: DalleResponse =
            response.json().await.map_err(|e| AiError::ProviderError {
                provider: self.name().to_string(),
                message: format!("Failed to parse response: {e}"),
            })?;

        let image_data = dalle_response
            .data
            .first()
            .and_then(|d| d.b64_json.clone())
            .ok_or_else(|| AiError::ProviderError {
                provider: self.name().to_string(),
                message: "No image data in response".to_string(),
            })?;

        let generation_time_ms = start.elapsed().as_millis() as u64;

        Ok(ImageGenerationResponse::new(NewImageGenerationResponse {
            provider: self.name().to_string(),
            model: model.to_string(),
            image_data,
            mime_type: "image/png".to_string(),
            resolution: request.resolution,
            aspect_ratio: request.aspect_ratio,
            generation_time_ms,
        }))
    }

    async fn generate_batch(
        &self,
        requests: &[ImageGenerationRequest],
    ) -> Result<Vec<ImageGenerationResponse>> {
        let mut responses = Vec::new();
        for request in requests {
            responses.push(self.generate_image(request).await?);
        }
        Ok(responses)
    }
}