//! litellm-rs 0.1.1
//!
//! A high-performance AI Gateway written in Rust, providing OpenAI-compatible APIs with intelligent routing, load balancing, and enterprise features.
//! Stability AI provider implementation
//!
//! This module provides Stability AI API integration for image generation.

use super::{BaseProvider, ModelPricing, Provider, ProviderError, ProviderType};
use crate::config::ProviderConfig;
use crate::core::models::{RequestContext, openai::*};
use crate::utils::error::Result;
use async_trait::async_trait;
use serde_json::json;
use std::collections::HashMap;
use tracing::{debug, info};

/// Stability AI provider implementation
///
/// Wraps [`BaseProvider`] with Stability-specific endpoints (engines list,
/// text-to-image generation) and a static per-image pricing table.
#[derive(Debug, Clone)]
pub struct StabilityAIProvider {
    /// Base provider functionality (HTTP client, credentials, base URL)
    base: BaseProvider,
    /// Model pricing cache, keyed by model id; populated once at construction
    pricing_cache: HashMap<String, ModelPricing>,
}

impl StabilityAIProvider {
    /// Create a new Stability AI provider.
    ///
    /// Uses the configured `base_url` when present, otherwise falls back to
    /// the public Stability AI endpoint.
    ///
    /// # Errors
    ///
    /// Returns an error if the underlying [`BaseProvider`] cannot be
    /// constructed from `config`.
    pub async fn new(config: &ProviderConfig) -> Result<Self> {
        let base = BaseProvider::new(config)?;

        let base_url = config
            .base_url
            .clone()
            .unwrap_or_else(|| "https://api.stability.ai".to_string());

        let provider = Self {
            // Override only the base URL; keep everything else from `base`.
            base: BaseProvider { base_url, ..base },
            pricing_cache: Self::initialize_pricing_cache(),
        };

        info!(
            "Stability AI provider '{}' initialized successfully",
            config.name
        );
        Ok(provider)
    }

    /// Initialize pricing cache with Stability AI model prices.
    ///
    /// Stability AI bills per generated image, not per token, so
    /// `input_cost_per_1k` is always zero and `output_cost_per_1k` holds the
    /// per-image price in USD.
    fn initialize_pricing_cache() -> HashMap<String, ModelPricing> {
        // (model id, USD per image) — single table instead of repeating the
        // ModelPricing construction once per model.
        const PER_IMAGE_PRICES: [(&str, f64); 4] = [
            ("stable-diffusion-xl-1024-v1-0", 0.04),
            ("stable-diffusion-v1-6", 0.02),
            ("stable-diffusion-512-v2-1", 0.02),
            ("stable-diffusion-xl-beta-v2-2-2", 0.08),
        ];

        PER_IMAGE_PRICES
            .iter()
            .map(|&(model, price)| {
                (
                    model.to_string(),
                    ModelPricing {
                        model: model.to_string(),
                        input_cost_per_1k: 0.0, // image generation has no input tokens
                        output_cost_per_1k: price,
                        currency: "USD".to_string(),
                        updated_at: chrono::Utc::now(),
                    },
                )
            })
            .collect()
    }

    /// Convert a Stability AI generation response to the OpenAI format.
    ///
    /// Stability AI returns images as base64 strings inside an `artifacts`
    /// array; each artifact with a `base64` field becomes an [`ImageObject`]
    /// with `b64_json` set (artifacts without one are skipped).
    ///
    /// # Errors
    ///
    /// Returns [`ProviderError::Parsing`] when the response has no
    /// `artifacts` array, or when none of the artifacts yields an image.
    fn convert_stability_response_to_openai(
        &self,
        stability_response: serde_json::Value,
    ) -> Result<ImageGenerationResponse> {
        let artifacts = stability_response
            .get("artifacts")
            .and_then(|a| a.as_array())
            .ok_or_else(|| ProviderError::Parsing("No artifacts in response".to_string()))?;

        let data: Vec<ImageObject> = artifacts
            .iter()
            .filter_map(|artifact| {
                let base64 = artifact.get("base64")?.as_str()?;
                Some(ImageObject {
                    url: None,
                    b64_json: Some(base64.to_string()),
                })
            })
            .collect();

        if data.is_empty() {
            return Err(ProviderError::Parsing("No valid images in response".to_string()).into());
        }

        Ok(ImageGenerationResponse {
            created: chrono::Utc::now().timestamp() as u64,
            data,
        })
    }
}

#[async_trait]
impl Provider for StabilityAIProvider {
    /// Provider name as configured.
    fn name(&self) -> &str {
        &self.base.name
    }

    fn provider_type(&self) -> ProviderType {
        ProviderType::Custom("stability_ai".to_string())
    }

    /// A model is supported when it is explicitly configured or its id looks
    /// like a Stable Diffusion model.
    async fn supports_model(&self, model: &str) -> bool {
        self.base.is_model_supported(model) || model.contains("stable-diffusion")
    }

    async fn supports_images(&self) -> bool {
        true // Stability AI specializes in image generation
    }

    async fn supports_embeddings(&self) -> bool {
        false
    }

    async fn supports_streaming(&self) -> bool {
        false
    }

    /// List available engines from the Stability AI API.
    ///
    /// Falls back to a hard-coded list of known models when the API returns a
    /// non-success status, so callers always get a usable list.
    ///
    /// # Errors
    ///
    /// Returns [`ProviderError::Network`] when the request itself fails, or a
    /// parsing error when the response body is not valid JSON.
    async fn list_models(&self) -> Result<Vec<Model>> {
        let url = format!("{}/v1/engines/list", self.base.base_url);

        let response = self
            .base
            .client
            .get(&url)
            .header("Authorization", format!("Bearer {}", self.base.api_key))
            .send()
            .await
            .map_err(|e| ProviderError::Network(e.to_string()))?;

        if !response.status().is_success() {
            // Fallback to known models
            let known_models = vec![
                "stable-diffusion-xl-1024-v1-0",
                "stable-diffusion-v1-6",
                "stable-diffusion-512-v2-1",
                "stable-diffusion-xl-beta-v2-2-2",
            ];

            let models = known_models
                .into_iter()
                .map(|model| Model {
                    id: model.to_string(),
                    object: "model".to_string(),
                    created: chrono::Utc::now().timestamp() as u64,
                    owned_by: "stability-ai".to_string(),
                })
                .collect();

            return Ok(models);
        }

        let engines_response: serde_json::Value = self.base.parse_json_response(response).await?;

        // `as_array().into_iter().flatten()` yields nothing when the response
        // is not an array, without the eager `unwrap_or(&vec![])` allocation.
        let models = engines_response
            .as_array()
            .into_iter()
            .flatten()
            .filter_map(|engine| {
                Some(Model {
                    id: engine.get("id")?.as_str()?.to_string(),
                    object: "model".to_string(),
                    created: chrono::Utc::now().timestamp() as u64,
                    owned_by: "stability-ai".to_string(),
                })
            })
            .collect();

        Ok(models)
    }

    /// Probe the account endpoint to verify credentials and connectivity.
    async fn health_check(&self) -> Result<()> {
        debug!("Performing Stability AI health check");

        let url = format!("{}/v1/user/account", self.base.base_url);

        let response = self
            .base
            .client
            .get(&url)
            .header("Authorization", format!("Bearer {}", self.base.api_key))
            .send()
            .await
            .map_err(|e| ProviderError::Network(e.to_string()))?;

        if response.status().is_success() {
            Ok(())
        } else {
            Err(
                ProviderError::Unknown(format!("Health check failed: {}", response.status()))
                    .into(),
            )
        }
    }

    /// Not supported: Stability AI is image-generation only.
    async fn chat_completion(
        &self,
        _request: ChatCompletionRequest,
        _context: RequestContext,
    ) -> Result<ChatCompletionResponse> {
        Err(ProviderError::InvalidRequest(
            "Chat completion not supported by Stability AI".to_string(),
        )
        .into())
    }

    /// Not supported: Stability AI is image-generation only.
    async fn completion(
        &self,
        _request: CompletionRequest,
        _context: RequestContext,
    ) -> Result<CompletionResponse> {
        Err(ProviderError::InvalidRequest(
            "Text completion not supported by Stability AI".to_string(),
        )
        .into())
    }

    /// Not supported: Stability AI is image-generation only.
    async fn embedding(
        &self,
        _request: EmbeddingRequest,
        _context: RequestContext,
    ) -> Result<EmbeddingResponse> {
        Err(
            ProviderError::InvalidRequest("Embeddings not supported by Stability AI".to_string())
                .into(),
        )
    }

    /// Generate images via the Stability AI text-to-image endpoint.
    ///
    /// Translates the OpenAI-style request into Stability's body format:
    /// `prompt` becomes a weighted text prompt, `n` maps to `samples`, and a
    /// `WIDTHxHEIGHT` size string maps to `width`/`height` (ignored if it
    /// does not parse). Defaults to the SDXL 1.0 engine when no model is set.
    ///
    /// # Errors
    ///
    /// Maps HTTP failures to the matching [`ProviderError`] variant
    /// (401 → Authentication, 429 → RateLimit, 404 → ModelNotFound,
    /// 400 → InvalidRequest) and parsing failures to a parsing error.
    async fn image_generation(
        &self,
        request: ImageGenerationRequest,
        _context: RequestContext,
    ) -> Result<ImageGenerationResponse> {
        debug!(
            "Stability AI image generation for model: {:?}",
            request.model
        );

        let mut body = json!({
            "text_prompts": [
                {
                    "text": request.prompt,
                    "weight": 1.0
                }
            ]
        });

        // Add optional parameters
        if let Some(n) = request.n {
            body["samples"] = json!(n);
        }

        // Parse "WIDTHxHEIGHT" size parameter; silently skipped when malformed.
        if let Some(size) = &request.size {
            if let Some((width, height)) = size.split_once('x') {
                if let (Ok(w), Ok(h)) = (width.parse::<u32>(), height.parse::<u32>()) {
                    body["width"] = json!(w);
                    body["height"] = json!(h);
                }
            }
        }

        // Set default generation parameters
        body["cfg_scale"] = json!(7);
        body["steps"] = json!(30);

        let default_model = "stable-diffusion-xl-1024-v1-0".to_string();
        let model = request.model.as_ref().unwrap_or(&default_model);
        let url = format!(
            "{}/v1/generation/{}/text-to-image",
            self.base.base_url, model
        );

        let response = self
            .base
            .client
            .post(&url)
            .header("Authorization", format!("Bearer {}", self.base.api_key))
            .header("Content-Type", "application/json")
            .header("Accept", "application/json")
            .json(&body)
            .send()
            .await
            .map_err(|e| ProviderError::Network(e.to_string()))?;

        if !response.status().is_success() {
            let status = response.status();
            let error_text = response.text().await.unwrap_or_default();

            return Err(match status.as_u16() {
                401 => ProviderError::Authentication(error_text),
                429 => ProviderError::RateLimit(error_text),
                404 => ProviderError::ModelNotFound(error_text),
                400 => ProviderError::InvalidRequest(error_text),
                _ => ProviderError::Unknown(format!("HTTP {}: {}", status, error_text)),
            }
            .into());
        }

        let stability_response: serde_json::Value = self.base.parse_json_response(response).await?;
        self.convert_stability_response_to_openai(stability_response)
    }

    /// Look up cached pricing for `model`, defaulting to $0.03 per image for
    /// unknown models.
    async fn get_model_pricing(&self, model: &str) -> Result<ModelPricing> {
        Ok(self
            .pricing_cache
            .get(model)
            .cloned()
            .unwrap_or_else(|| ModelPricing {
                model: model.to_string(),
                input_cost_per_1k: 0.0,
                output_cost_per_1k: 0.03, // Default $0.03 per image
                currency: "USD".to_string(),
                updated_at: chrono::Utc::now(),
            }))
    }

    /// Calculate the cost of a generation request.
    ///
    /// For image generation `output_tokens` represents the number of images,
    /// and `output_cost_per_1k` holds the per-image price, so the cost is a
    /// simple product; `_input_tokens` is unused.
    async fn calculate_cost(
        &self,
        model: &str,
        _input_tokens: u32,
        output_tokens: u32,
    ) -> Result<f64> {
        let pricing = self.get_model_pricing(model).await?;

        // For image generation, output_tokens represents number of images
        let cost = (output_tokens as f64) * pricing.output_cost_per_1k;

        Ok(cost)
    }
}