// langdb_core 0.3.2
//
// AI gateway core for the LangDB AI Gateway.
// Documentation
use std::collections::HashMap;

use crate::events::SPAN_OPENAI;
use crate::model::error::ModelError;
use async_openai::config::Config;
use async_openai::{config::OpenAIConfig, Client};

use crate::model::types::ModelEventType;
use crate::{
    error::GatewayError,
    model::{
        openai::openai_client,
        types::{ImageGenerationFinishEvent, ModelEvent},
        CredentialsIdent,
    },
    types::{
        credentials::ApiKeyCredentials,
        gateway::{CreateImageRequest, ImageQuality, ImageResponseFormat, ImageSize, ImageStyle},
        image::ImagesResponse,
    },
    GatewayResult,
};

use super::ImageGenerationModelInstance;
use crate::model::JsonValue;
use secrecy::ExposeSecret;
use serde::Deserialize;
use tracing::field;
use valuable::Valuable;

/// Error envelope parsed from the body of a non-success OpenAI HTTP response.
///
/// The JSON body is expected to carry a single top-level `error` object.
#[derive(Debug, Clone)]
#[derive(Deserialize)]
pub struct OpenAIReqwestError {
    /// The nested `error` object describing the failure.
    pub error: ApiError,
}

/// Contents of the `error` object in an OpenAI API error response.
///
/// Only `message` is required; every other field may be absent or `null`.
#[derive(Debug, Clone)]
#[derive(Deserialize)]
pub struct ApiError {
    /// Human-readable failure description (surfaced in gateway errors).
    pub message: String,
    /// Optional `type` field of the error object (raw identifier `r#type`).
    pub r#type: Option<String>,
    /// Optional `param` field; per OpenAI's error format, the offending request parameter.
    pub param: Option<String>,
    /// Optional `code` field of the error object.
    pub code: Option<String>,
}

/// Image-generation model instance that talks to the OpenAI images API.
///
/// Within this module the `async_openai` client is read for its configuration
/// (API base URL and API key); the HTTP call itself is issued with `reqwest`.
#[derive(Clone)]
pub struct OpenAIImageGeneration {
    /// Configured OpenAI client (source of the API base and key).
    client: Client<OpenAIConfig>,
    /// Whose credentials are in use; copied into emitted finish events.
    credentials_ident: CredentialsIdent,
}

impl OpenAIImageGeneration {
    pub fn new(
        credentials: Option<&ApiKeyCredentials>,
        client: Option<Client<OpenAIConfig>>,
        endpoint: Option<&str>,
    ) -> Result<Self, ModelError> {
        Ok(OpenAIImageGeneration {
            credentials_ident: credentials
                .map(|_c| CredentialsIdent::Own)
                .unwrap_or(CredentialsIdent::Langdb),
            client: client.unwrap_or(openai_client(credentials, endpoint)?),
        })
    }

    fn generate_event(
        &self,
        model_name: &str,
        quality: Option<&ImageQuality>,
        size: Option<&ImageSize>,
        count_of_images: u8,
        steps: u8,
    ) -> ImageGenerationFinishEvent {
        ImageGenerationFinishEvent {
            model_name: model_name.to_string(),
            quality: quality
                .map(|q| q.to_string())
                .unwrap_or("standard".to_string()),
            size: size.cloned().unwrap_or(ImageSize::Size1024x1024),
            count_of_images,
            steps,
            credentials_ident: self.credentials_ident.clone(),
        }
    }

    fn map_size(
        &self,
        size: Option<&ImageSize>,
    ) -> Option<Result<async_openai::types::ImageSize, GatewayError>> {
        size.map(|s| match s {
            crate::types::gateway::ImageSize::Size256x256 => {
                Ok(async_openai::types::ImageSize::S256x256)
            }
            crate::types::gateway::ImageSize::Size512x512 => {
                Ok(async_openai::types::ImageSize::S512x512)
            }
            crate::types::gateway::ImageSize::Size1024x1024 => {
                Ok(async_openai::types::ImageSize::S1024x1024)
            }
            crate::types::gateway::ImageSize::Size1792x1024 => {
                Ok(async_openai::types::ImageSize::S1792x1024)
            }
            crate::types::gateway::ImageSize::Size1024x1792 => {
                Ok(async_openai::types::ImageSize::S1024x1792)
            }
            crate::types::gateway::ImageSize::Other((width, height)) => Err(
                GatewayError::CustomError(format!("Unsupported image size: {width}x{height}")),
            ),
        })
    }

    fn map_quality(
        &self,
        quality: Option<&ImageQuality>,
    ) -> Option<async_openai::types::ImageQuality> {
        quality.map(|q| match q {
            crate::types::gateway::ImageQuality::SD => async_openai::types::ImageQuality::Standard,
            crate::types::gateway::ImageQuality::HD => async_openai::types::ImageQuality::HD,
        })
    }
}

#[async_trait::async_trait]
impl ImageGenerationModelInstance for OpenAIImageGeneration {
    /// Generates images by POSTing `request` to `{api_base}/images/generations`.
    ///
    /// On success an `ImageGenerationFinish` event is sent on `tx` and the decoded
    /// [`ImagesResponse`] is returned. Transport, decode and API errors are recorded
    /// on the span's `error` field.
    ///
    /// # Errors
    /// * the request cannot be serialized, or the model name cannot be mapped;
    /// * `request.size` is an unsupported custom size;
    /// * the HTTP call fails or the success body cannot be decoded;
    /// * the API returns a non-success status (message taken from its error body,
    ///   or the decode error if that body is not in the expected shape).
    async fn create_new(
        &self,
        request: &CreateImageRequest,
        tx: tokio::sync::mpsc::Sender<Option<ModelEvent>>,
        tags: HashMap<String, String>,
    ) -> GatewayResult<ImagesResponse> {
        let input = serde_json::to_string(request)?;
        let call_span = tracing::info_span!(
            target: "langdb::user_tracing::models::openai::image_generation",
            SPAN_OPENAI,
            input = input,
            output = field::Empty,
            error = field::Empty,
            usage = field::Empty,
            ttft = field::Empty,
            tags = JsonValue(&serde_json::to_value(&tags).unwrap_or_default()).as_value()
        );

        let size = self.map_size(request.size.as_ref()).transpose()?;
        let quality = self.map_quality(request.quality.as_ref());

        // Deserialize from a JSON string value rather than wrapping the name in quotes
        // with `format!`, which yields invalid JSON for names containing `"` or `\`.
        let model = serde_json::from_value(serde_json::Value::String(request.model.clone()))?;

        let openai_request = async_openai::types::CreateImageRequest {
            prompt: request.prompt.clone(),
            n: request.n,
            size,
            response_format: request.response_format.as_ref().map(|f| match f {
                ImageResponseFormat::Url => async_openai::types::ImageResponseFormat::Url,
                ImageResponseFormat::B64Json => async_openai::types::ImageResponseFormat::B64Json,
            }),
            user: request.user.clone(),
            model: Some(model),
            quality,
            style: request.style.as_ref().map(|s| match s {
                ImageStyle::Vivid => async_openai::types::ImageStyle::Vivid,
                ImageStyle::Natural => async_openai::types::ImageStyle::Natural,
            }),
            moderation: request.moderation.as_ref().map(|m| match m {
                crate::types::gateway::ImageModeration::Auto => {
                    async_openai::types::ImageModeration::Auto
                }
                crate::types::gateway::ImageModeration::Low => {
                    async_openai::types::ImageModeration::Low
                }
            }),
        };

        let api_base = self.client.config().api_base().to_string();
        let api_key: String = self.client.config().api_key().expose_secret().to_string();

        // NOTE(review): a fresh reqwest client (and connection pool) is built per call;
        // consider storing one on the struct if this path is hot.
        let response = reqwest::Client::new()
            .post(format!("{api_base}/images/generations"))
            .header("Authorization", format!("Bearer {api_key}"))
            .json(&openai_request)
            .send()
            .await
            .map_err(|e| {
                call_span.record("error", e.to_string());
                e
            })?;

        if !response.status().is_success() {
            let api_error: OpenAIReqwestError = response.json().await.map_err(|e| {
                call_span.record("error", e.to_string());
                GatewayError::CustomError(format!("Failed to generate image: {e}"))
            })?;
            call_span.record("error", api_error.error.message.clone());
            return Err(GatewayError::CustomError(format!(
                "Failed to generate image: {}",
                api_error.error.message
            )));
        }

        let result = response.json::<ImagesResponse>().await.map_err(|e| {
            call_span.record("error", e.to_string());
            e
        })?;

        let event = self.generate_event(
            &request.model,
            request.quality.as_ref(),
            request.size.as_ref(),
            request.n.unwrap_or(1),
            1,
        );

        // The images were already produced; a closed event channel must not turn
        // that into a panic, so log it and still return the result.
        if let Err(e) = tx
            .send(Some(ModelEvent::new(
                &call_span,
                ModelEventType::ImageGenerationFinish(event),
            )))
            .await
        {
            tracing::warn!("Failed to emit image generation finish event: {e}");
        }

        Ok(result)
    }
}