systemprompt-ai 0.1.19

Core AI module for systemprompt.io
Documentation
use crate::error::{AiError, Result};
use crate::models::image_generation::{
    AspectRatio, ImageGenerationRequest, ImageGenerationResponse, ImageResolution,
    NewImageGenerationResponse,
};
use crate::models::providers::gemini::GeminiResponse;
use crate::services::providers::image_provider_trait::{ImageProvider, ImageProviderCapabilities};
use async_trait::async_trait;
use reqwest::Client;
use std::collections::HashMap;
use std::time::Instant;
use systemprompt_models::services::ModelDefinition;
use tracing::error;

use super::gemini_images_helpers::{build_image_request, extract_image_from_response};

#[derive(Debug)]
pub struct GeminiImageProvider {
    client: Client,
    api_key: String,
    endpoint: String,
    default_model: String,
    model_definitions: HashMap<String, ModelDefinition>,
}

impl GeminiImageProvider {
    pub fn new(api_key: String) -> Self {
        let client = Client::builder()
            .timeout(std::time::Duration::from_secs(300))
            .connect_timeout(std::time::Duration::from_secs(30))
            .build()
            .unwrap_or_else(|e| {
                error!(error = %e, "Failed to build HTTP client for GeminiImageProvider, using default");
                Client::new()
            });

        Self {
            client,
            api_key,
            endpoint: "https://generativelanguage.googleapis.com/v1beta".to_string(),
            default_model: "gemini-2.5-flash-image".to_string(),
            model_definitions: HashMap::new(),
        }
    }

    pub fn with_endpoint(api_key: String, endpoint: String) -> Self {
        let mut provider = Self::new(api_key);
        provider.endpoint = endpoint;
        provider
    }

    pub fn with_default_model(mut self, model: String) -> Self {
        self.default_model = model;
        self
    }

    pub fn with_model_definitions(mut self, models: HashMap<String, ModelDefinition>) -> Self {
        self.model_definitions = models;
        self
    }
}

#[async_trait]
impl ImageProvider for GeminiImageProvider {
    fn name(&self) -> &'static str {
        "gemini-image"
    }

    fn capabilities(&self) -> ImageProviderCapabilities {
        ImageProviderCapabilities {
            supported_resolutions: vec![
                ImageResolution::OneK,
                ImageResolution::TwoK,
                ImageResolution::FourK,
            ],
            supported_aspect_ratios: vec![
                AspectRatio::Square,
                AspectRatio::Landscape169,
                AspectRatio::Portrait916,
                AspectRatio::Landscape43,
                AspectRatio::Portrait34,
                AspectRatio::UltraWide,
            ],
            supports_batch: true,
            supports_image_editing: true,
            supports_search_grounding: true,
            max_prompt_length: 8000,
            cost_per_image_cents: 4.0,
        }
    }

    fn supported_models(&self) -> Vec<String> {
        vec![
            "gemini-3.1-flash-image-preview".to_string(),
            "gemini-3-pro-image-preview".to_string(),
            "gemini-2.5-flash-image".to_string(),
        ]
    }

    fn default_model(&self) -> &str {
        &self.default_model
    }

    async fn generate_image(
        &self,
        request: &ImageGenerationRequest,
    ) -> Result<ImageGenerationResponse> {
        let start = Instant::now();

        if request.prompt.len() > self.capabilities().max_prompt_length {
            return Err(AiError::ProviderError {
                provider: self.name().to_string(),
                message: format!(
                    "Prompt length {} exceeds maximum {}",
                    request.prompt.len(),
                    self.capabilities().max_prompt_length
                ),
            });
        }

        if !self.supports_resolution(&request.resolution) {
            return Err(AiError::ProviderError {
                provider: self.name().to_string(),
                message: format!("Resolution {} not supported", request.resolution.as_str()),
            });
        }

        if !self.supports_aspect_ratio(&request.aspect_ratio) {
            return Err(AiError::ProviderError {
                provider: self.name().to_string(),
                message: format!(
                    "Aspect ratio {} not supported",
                    request.aspect_ratio.as_str()
                ),
            });
        }

        let model = request
            .model
            .as_deref()
            .unwrap_or_else(|| self.default_model());

        if !self.supports_model(model) {
            return Err(AiError::ProviderError {
                provider: self.name().to_string(),
                message: format!("Model {model} not supported"),
            });
        }

        let gemini_request = build_image_request(request, model, &self.model_definitions);

        let url = format!("{}/models/{}:generateContent", self.endpoint, model);

        let response = self
            .client
            .post(&url)
            .header("x-goog-api-key", &self.api_key)
            .json(&gemini_request)
            .send()
            .await
            .map_err(|e| AiError::ProviderError {
                provider: self.name().to_string(),
                message: format!("HTTP request failed: {e}"),
            })?;

        if !response.status().is_success() {
            let status = response.status();
            let error_body = response
                .text()
                .await
                .unwrap_or_else(|e| format!("<error reading response: {}>", e));
            return Err(AiError::ProviderError {
                provider: self.name().to_string(),
                message: format!("API returned status {status}: {error_body}"),
            });
        }

        let gemini_response: GeminiResponse =
            response.json().await.map_err(|e| AiError::ProviderError {
                provider: self.name().to_string(),
                message: format!("Failed to parse response: {e}"),
            })?;

        let (image_data, mime_type) = extract_image_from_response(&gemini_response)?;

        let generation_time_ms = start.elapsed().as_millis() as u64;

        Ok(ImageGenerationResponse::new(NewImageGenerationResponse {
            provider: self.name().to_string(),
            model: model.to_string(),
            image_data,
            mime_type,
            resolution: request.resolution,
            aspect_ratio: request.aspect_ratio,
            generation_time_ms,
        }))
    }

    async fn generate_batch(
        &self,
        requests: &[ImageGenerationRequest],
    ) -> Result<Vec<ImageGenerationResponse>> {
        let mut responses = Vec::new();
        for request in requests {
            responses.push(self.generate_image(request).await?);
        }
        Ok(responses)
    }
}