//! langdb_core 0.3.2
//!
//! AI gateway core for the LangDB AI Gateway. See the crate documentation
//! for details.
use std::collections::HashMap;

use async_openai::{
    config::{AzureConfig, Config, OpenAIConfig},
    types::EmbeddingUsage,
    Client,
};
use tracing::field;
use tracing::Span;
use tracing_futures::Instrument;
use valuable::Valuable;

use crate::{
    create_model_span,
    error::GatewayError,
    events::{JsonValue, SPAN_OPENAI},
    model::{
        embeddings::EmbeddingsModelInstance,
        error::{AuthorizationError, ModelError},
        openai::{azure_openai_client, openai_client},
        types::{LLMFinishEvent, LLMStartEvent, ModelEvent, ModelEventType, ModelFinishReason},
        CredentialsIdent,
    },
    types::{
        credentials::ApiKeyCredentials,
        embed::EmbeddingResult,
        gateway::{CompletionModelUsage, CreateEmbeddingRequest, EncodingFormat, Input},
    },
    GatewayResult,
};

/// Builds the tracing target string for OpenAI model spans.
///
/// `target!()` expands to the base target `"langdb::user_tracing::models::openai"`;
/// `target!("sub")` appends `"::sub"` at compile time via `concat!`.
macro_rules! target {
    () => {
        "langdb::user_tracing::models::openai"
    };
    ($subtgt:literal) => {
        concat!("langdb::user_tracing::models::openai::", $subtgt)
    };
}

/// Embeddings model backed by the OpenAI (or Azure OpenAI) embeddings API.
///
/// `C` selects the client configuration: `OpenAIConfig` (the default, see
/// [`OpenAIEmbeddings::new`]) or `AzureConfig` (see `new_azure`).
pub struct OpenAIEmbeddings<C: Config = OpenAIConfig> {
    // Underlying async-openai HTTP client.
    client: Client<C>,
    // Records whether the caller supplied their own credentials (`Own`) or
    // the LangDB defaults are in use (`Langdb`); forwarded on finish events.
    credentials_ident: CredentialsIdent,
}

impl OpenAIEmbeddings<OpenAIConfig> {
    /// Creates an embeddings instance for the public OpenAI API.
    ///
    /// If `client` is provided it is used as-is; otherwise a client is built
    /// from `credentials` and the optional `endpoint` override.
    ///
    /// # Errors
    /// Returns a `ModelError` when no client is supplied and one cannot be
    /// constructed from the given credentials.
    pub fn new(
        credentials: Option<&ApiKeyCredentials>,
        client: Option<Client<OpenAIConfig>>,
        endpoint: Option<&str>,
    ) -> Result<Self, ModelError> {
        // Only build a client when the caller did not supply one. The previous
        // `client.unwrap_or(openai_client(..)?)` evaluated `openai_client`
        // eagerly, so it could fail (e.g. on a missing API key) even when a
        // ready-made client was passed in.
        let client = match client {
            Some(client) => client,
            None => openai_client(credentials, endpoint)?,
        };
        Ok(OpenAIEmbeddings {
            credentials_ident: credentials
                .map(|_| CredentialsIdent::Own)
                .unwrap_or(CredentialsIdent::Langdb),
            client,
        })
    }
}

impl OpenAIEmbeddings<AzureConfig> {
    /// Creates an embeddings instance for an Azure OpenAI deployment.
    ///
    /// The API key comes from `credentials` when given, otherwise from the
    /// `LANGDB_OPENAI_API_KEY` environment variable.
    ///
    /// # Errors
    /// Returns an authorization error when no API key is available.
    pub fn new_azure(
        credentials: Option<&ApiKeyCredentials>,
        endpoint: &str,
        deployment_id: &str,
    ) -> Result<Self, ModelError> {
        let api_key = match credentials {
            Some(c) => c.api_key.clone(),
            None => std::env::var("LANGDB_OPENAI_API_KEY")
                .map_err(|_| AuthorizationError::InvalidApiKey)?,
        };
        // Caller-supplied credentials are billed as their own; otherwise fall
        // back to the LangDB identity.
        let credentials_ident = if credentials.is_some() {
            CredentialsIdent::Own
        } else {
            CredentialsIdent::Langdb
        };
        Ok(OpenAIEmbeddings {
            credentials_ident,
            client: azure_openai_client(api_key, endpoint, deployment_id),
        })
    }
}

impl<C: Config> OpenAIEmbeddings<C> {
    async fn execute(
        &self,
        embedding_request: async_openai::types::CreateEmbeddingRequest,
        encoding_format: &EncodingFormat,
        model_name: &str,
        outer_tx: &tokio::sync::mpsc::Sender<Option<ModelEvent>>,
    ) -> GatewayResult<EmbeddingResult> {
        let span = Span::current();
        let _ = outer_tx.try_send(Some(ModelEvent::new(
            &span,
            ModelEventType::LlmStart(LLMStartEvent {
                provider_name: "openai".to_string(),
                model_name: model_name.to_string(),
                input: serde_json::to_string(&embedding_request)?,
            }),
        )));

        let response: EmbeddingResult = match encoding_format {
            EncodingFormat::Float => self
                .client
                .embeddings()
                .create(embedding_request)
                .await
                .map(|r| r.into())
                .map_err(ModelError::OpenAIApi)?,
            EncodingFormat::Base64 => self
                .client
                .embeddings()
                .create_base64(embedding_request)
                .await
                .map(|r| r.into())
                .map_err(ModelError::OpenAIApi)?,
        };

        outer_tx
            .try_send(Some(ModelEvent::new(
                &span,
                ModelEventType::LlmStop(LLMFinishEvent {
                    provider_name: SPAN_OPENAI.to_string(),
                    model_name: model_name.to_string(),
                    output: None,
                    usage: Some(CompletionModelUsage {
                        input_tokens: response.usage().prompt_tokens,
                        output_tokens: 0,
                        total_tokens: response.usage().total_tokens,
                        ..Default::default()
                    }),
                    finish_reason: ModelFinishReason::Stop,
                    tool_calls: vec![],
                    credentials_ident: self.credentials_ident.clone(),
                }),
            )))
            .map_err(|e| GatewayError::CustomError(e.to_string()))?;

        span.record(
            "raw_usage",
            JsonValue(&serde_json::to_value(response.usage()).unwrap()).as_value(),
        );
        span.record(
            "usage",
            JsonValue(&serde_json::to_value(Self::map_usage(response.usage())).unwrap()).as_value(),
        );

        Ok(response)
    }

    fn map_usage(usage: &EmbeddingUsage) -> CompletionModelUsage {
        CompletionModelUsage {
            input_tokens: usage.prompt_tokens,
            total_tokens: usage.total_tokens,
            ..Default::default()
        }
    }
}

#[async_trait::async_trait]
impl<C: Config + std::marker::Sync + std::marker::Send> EmbeddingsModelInstance
    for OpenAIEmbeddings<C>
{
    /// Translates the gateway request into an `async-openai` embedding
    /// request, opens a tracing span for the call, and delegates to
    /// [`OpenAIEmbeddings::execute`].
    async fn embed(
        &self,
        request: &CreateEmbeddingRequest,
        outer_tx: tokio::sync::mpsc::Sender<Option<ModelEvent>>,
        tags: HashMap<String, String>,
    ) -> GatewayResult<EmbeddingResult> {
        // Map the gateway enums onto their async-openai counterparts.
        let input = match &request.input {
            Input::String(s) => s.into(),
            Input::Array(vec) => vec.into(),
        };
        let encoding_format = match &request.encoding_format {
            EncodingFormat::Float => async_openai::types::EncodingFormat::Float,
            EncodingFormat::Base64 => async_openai::types::EncodingFormat::Base64,
        };
        let openai_request = async_openai::types::CreateEmbeddingRequest {
            model: request.model.clone(),
            input,
            user: request.user.clone(),
            // NOTE(review): `as u32` silently truncates oversized values —
            // assumed fine for embedding dimension counts; confirm upstream.
            dimensions: request.dimensions.map(|d| d as u32),
            encoding_format: Some(encoding_format),
        };

        let span = create_model_span!(
            SPAN_OPENAI,
            target!("embedding"),
            tags,
            0,
            input = serde_json::to_string(&openai_request)?
        );

        self.execute(
            openai_request,
            &request.encoding_format,
            &request.model,
            &outer_tx,
        )
        .instrument(span.clone())
        .await
    }
}