langdb_core 0.3.2

AI gateway core for the LangDB AI Gateway.
Documentation
use std::collections::HashMap;
use std::sync::Arc;

use crate::model::error::ModelError;
use langdb_open::OpenAISpecModel;
use openai::OpenAIImageGeneration;
use serde::Serialize;
use serde_json::Value;
use tracing::info_span;
use tracing_futures::Instrument;
use valuable::Valuable;

use crate::events::{JsonValue, RecordResult, SPAN_MODEL_CALL};
use crate::model::types::ModelEventType;
use crate::types::engine::{ImageGenerationEngineParams, ImageGenerationModelDefinition};
use crate::types::gateway::{CostCalculator, CreateImageRequest, ImageGenerationModelUsage, Usage};
use crate::types::image::ImagesResponse;
use crate::GatewayResult;

use super::types::ModelEvent;
use super::CredentialsIdent;
use tokio::sync::mpsc::channel;

pub mod langdb_open;
pub mod openai;

/// A backend capable of executing an image-generation request.
///
/// Implementors must be `Send + Sync` so that boxed instances can be shared
/// across async tasks.
#[async_trait::async_trait]
pub trait ImageGenerationModelInstance: Sync + Send {
    /// Executes one image-generation request and returns its [`ImagesResponse`].
    ///
    /// * `request` - the image request to execute.
    /// * `tx` - channel on which [`ModelEvent`]s can be sent while the request
    ///   runs. NOTE(review): a `None` item looks like an end-of-stream marker;
    ///   confirm against the concrete implementations.
    /// * `tags` - key/value tags that accompany the call.
    async fn create_new(
        &self,
        request: &CreateImageRequest,
        tx: tokio::sync::mpsc::Sender<Option<ModelEvent>>,
        tags: HashMap<String, String>,
    ) -> GatewayResult<ImagesResponse>;
}

/// Builds the traced model instance that matches the engine configured in
/// `definition`.
///
/// * `definition` - model definition; its `engine` selects the backend and is
///   then stored in the returned wrapper.
/// * `cost_calculator` - optional calculator handed to the tracing wrapper.
/// * `endpoint` - endpoint passed to the `LangdbOpen` client. It is ignored for
///   the `OpenAi` engine, which uses the endpoint stored in its own parameters.
/// * `provider_name` - provider name passed to the `LangdbOpen` client; unused
///   for the `OpenAi` engine.
///
/// # Errors
///
/// Propagates any error returned by the underlying client constructor.
///
/// # Panics
///
/// Panics if the engine is `LangdbOpen` and `provider_name` is `None`.
fn initialize_image_generation_model_instance(
    definition: ImageGenerationModelDefinition,
    cost_calculator: Option<Arc<Box<dyn CostCalculator>>>,
    endpoint: Option<&str>,
    provider_name: Option<&str>,
) -> Result<Box<dyn ImageGenerationModelInstance>, ModelError> {
    match &definition.engine {
        ImageGenerationEngineParams::OpenAi {
            credentials,
            endpoint,
            ..
        } => {
            // Borrow the credentials in place; the client only needs a reference.
            let inner = OpenAIImageGeneration::new(
                credentials.as_ref(),
                None,
                endpoint.as_ref().map(|s| s.as_str()),
            )?;
            // The borrows of `definition` end with the constructor call, so the
            // owned definition and calculator can be moved instead of cloned.
            Ok(Box::new(TracedImageGenerationModel {
                inner,
                definition,
                cost_calculator,
            }))
        }
        ImageGenerationEngineParams::LangdbOpen { credentials, .. } => {
            let inner = OpenAISpecModel::new(
                credentials.as_ref(),
                endpoint,
                provider_name
                    .expect("provider_name is required for the LangdbOpen image generation engine"),
            )?;
            Ok(Box::new(TracedImageGenerationModel {
                inner,
                definition,
                cost_calculator,
            }))
        }
    }
}

/// Public entry point for creating an image-generation model instance.
///
/// Forwards all arguments unchanged to
/// `initialize_image_generation_model_instance` and returns its result; see
/// that function for the meaning of each argument and its failure modes.
pub async fn initialize_image_generation(
    definition: ImageGenerationModelDefinition,
    cost_calculator: Option<Arc<Box<dyn CostCalculator>>>,
    endpoint: Option<&str>,
    provider_name: Option<&str>,
) -> Result<Box<dyn ImageGenerationModelInstance>, ModelError> {
    initialize_image_generation_model_instance(
        definition,
        cost_calculator,
        endpoint,
        provider_name,
    )
}
/// Wrapper around an image-generation backend that adds a tracing span and
/// cost/usage recording to every call.
pub struct TracedImageGenerationModel<Inner>
where
    Inner: ImageGenerationModelInstance,
{
    /// The wrapped backend that actually performs the request.
    inner: Inner,
    /// Definition of the model this instance was created from.
    definition: ImageGenerationModelDefinition,
    /// Optional calculator used to price finished generations.
    cost_calculator: Option<Arc<Box<dyn CostCalculator>>>,
}

/// Serializable view of an [`ImageGenerationModelDefinition`] that is attached
/// to tracing spans.
#[derive(Clone, Serialize)]
struct TracedImageGenerationModelDefinition {
    /// Name taken from the definition's `name` field.
    pub name: String,
    /// Provider name taken from the definition's `db_model`.
    pub provider_name: String,
    /// Name reported by the engine's `engine_name()`.
    pub engine_name: String,
    /// Prompt name; the `From` conversion in this file always sets it to `None`.
    pub prompt_name: Option<String>,
    /// The full model definition, including engine parameters.
    pub model_params: ImageGenerationModelDefinition,
    /// Inference model name taken from the definition's `db_model`.
    pub model_name: String,
}

impl TracedImageGenerationModelDefinition {
    /// Serializes a copy of this definition to JSON with the engine
    /// credentials removed.
    ///
    /// `self` is left untouched; only the serialized copy is scrubbed.
    ///
    /// # Errors
    ///
    /// Returns an error if JSON serialization fails.
    pub fn sanitize_json(&self) -> GatewayResult<Value> {
        let mut scrubbed = self.clone();

        // Blank out the credentials of whichever engine is configured.
        match &mut scrubbed.model_params.engine {
            ImageGenerationEngineParams::OpenAi { credentials, .. } => {
                *credentials = None;
            }
            ImageGenerationEngineParams::LangdbOpen { credentials, .. } => {
                *credentials = None;
            }
        }

        Ok(serde_json::to_value(&scrubbed)?)
    }

    /// Reports whose credentials the call runs with: `Own` when the engine
    /// carries credentials, `Langdb` when it has none.
    pub fn get_credentials_owner(&self) -> CredentialsIdent {
        let has_credentials = match &self.model_params.engine {
            ImageGenerationEngineParams::OpenAi { credentials, .. } => credentials.is_some(),
            ImageGenerationEngineParams::LangdbOpen { credentials, .. } => credentials.is_some(),
        };

        if has_credentials {
            CredentialsIdent::Own
        } else {
            CredentialsIdent::Langdb
        }
    }
}

impl From<ImageGenerationModelDefinition> for TracedImageGenerationModelDefinition {
    /// Builds the traced view from a model definition.
    ///
    /// The name fields are copied out of the definition, `prompt_name` is set
    /// to `None`, and the definition itself is stored in `model_params`.
    fn from(value: ImageGenerationModelDefinition) -> Self {
        // Copy the derived fields first so the owned definition can then be
        // moved into `model_params` instead of being deep-cloned.
        let model_name = value.db_model.inference_model_name.clone();
        let name = value.name.clone();
        let provider_name = value.db_model.provider_name.clone();
        let engine_name = value.engine.engine_name().to_string();

        Self {
            model_name,
            name,
            provider_name,
            engine_name,
            prompt_name: None,
            model_params: value,
        }
    }
}

#[async_trait::async_trait]
impl<Inner: ImageGenerationModelInstance> ImageGenerationModelInstance
    for TracedImageGenerationModel<Inner>
{
    async fn create_new(
        &self,
        request: &CreateImageRequest,
        outer_tx: tokio::sync::mpsc::Sender<Option<ModelEvent>>,
        tags: HashMap<String, String>,
    ) -> GatewayResult<ImagesResponse> {
        let traced_model: TracedImageGenerationModelDefinition = self.definition.clone().into();
        let credentials_ident = traced_model.get_credentials_owner();
        let model = traced_model.sanitize_json()?;
        let model_str = serde_json::to_string(&model)?;
        let provider_name = self.definition.db_model.provider_name.clone();
        let request_str = serde_json::to_string(request)?;

        let (tx, mut rx) = channel::<Option<ModelEvent>>(outer_tx.max_capacity());
        let span = info_span!(
            target: "langdb::user_tracing::models", SPAN_MODEL_CALL,
            input = &request_str,
            model = model_str,
            provider_name = provider_name,
            output = tracing::field::Empty,
            error = tracing::field::Empty,
            credentials_identifier = credentials_ident.to_string(),
            cost = tracing::field::Empty,
            usage = tracing::field::Empty,
            tags = JsonValue(&serde_json::to_value(tags.clone())?).as_value(),
        );

        let cost_calculator = self.cost_calculator.clone();
        let price = self.definition.db_model.price.clone();
        tokio::spawn(
            async move {
                while let Some(Some(msg)) = rx.recv().await {
                    if let Some(cost_calculator) = cost_calculator.as_ref() {
                        if let ModelEventType::ImageGenerationFinish(generation_finish_event) =
                            &msg.event
                        {
                            let s = tracing::Span::current();
                            let u = ImageGenerationModelUsage {
                                quality: generation_finish_event.quality.clone(),
                                size: generation_finish_event.size.clone().into(),
                                images_count: generation_finish_event.count_of_images,
                                steps_count: generation_finish_event.steps,
                            };
                            match cost_calculator
                                .calculate_cost(
                                    &price,
                                    &Usage::ImageGenerationModelUsage(u.clone()),
                                    &credentials_ident,
                                )
                                .await
                            {
                                Ok(c) => {
                                    s.record("cost", serde_json::to_string(&c).unwrap());
                                }
                                Err(e) => {
                                    tracing::error!("Error calculating cost: {:?}", e);
                                }
                            };

                            s.record("usage", serde_json::to_string(&u).unwrap());
                        }
                    }

                    outer_tx.send(Some(msg)).await.unwrap();
                }
            }
            .instrument(span.clone()),
        );

        async {
            let result = self.inner.create_new(request, tx, tags).await;
            let _ = result.as_ref().map(|r| r.data.len()).record();

            result
        }
        .instrument(span)
        .await
    }
}