langdb_core 0.3.2

Core crate of the LangDB AI Gateway.
Documentation
use std::collections::HashMap;

use async_openai::types::{CreateEmbeddingResponse, Embedding, EmbeddingUsage};
use tracing::field;
use tracing::Span;
use tracing_futures::Instrument;
use valuable::Valuable;

use crate::model::types::LLMStartEvent;
use crate::{
    create_model_span,
    events::{JsonValue, SPAN_GEMINI},
    model::{
        embeddings::EmbeddingsModelInstance,
        error::ModelError,
        gemini::{
            client::Client,
            model::gemini_client,
            types::{Part, PartWithThought},
        },
        types::{LLMFinishEvent, ModelEvent, ModelEventType, ModelFinishReason},
        CredentialsIdent,
    },
    types::{
        credentials::ApiKeyCredentials,
        embed::EmbeddingResult,
        gateway::{CompletionModelUsage, CreateEmbeddingRequest, Input},
    },
    GatewayResult,
};

/// Builds the target string handed to `create_model_span!` for Gemini spans.
///
/// `target!()` yields the base path; `target!("sub")` appends `::sub` to it.
/// Both forms expand to `&'static str` literals.
macro_rules! target {
    ($subtgt:literal) => {
        concat!("langdb::user_tracing::models::gemini::", $subtgt)
    };
    () => {
        "langdb::user_tracing::models::gemini"
    };
}

/// Embeddings model that talks to the Gemini API.
pub struct GeminiEmbeddings {
    /// Whether the caller supplied its own credentials or LangDB's are used;
    /// copied into every `LlmStop` event this model emits.
    credentials_ident: CredentialsIdent,
    /// Client used for the `embeddings` and `count_tokens` calls.
    client: Client,
}

impl GeminiEmbeddings {
    /// Creates a Gemini embeddings model.
    ///
    /// When `credentials` is `Some`, usage is attributed to the caller's own
    /// credentials ([`CredentialsIdent::Own`]); otherwise to LangDB's
    /// ([`CredentialsIdent::Langdb`]).
    ///
    /// # Errors
    /// Propagates any [`ModelError`] returned while building the client.
    pub fn new(credentials: Option<&ApiKeyCredentials>) -> Result<Self, ModelError> {
        let client = gemini_client(credentials)?;
        let credentials_ident = if credentials.is_some() {
            CredentialsIdent::Own
        } else {
            CredentialsIdent::Langdb
        };
        Ok(GeminiEmbeddings {
            client,
            credentials_ident,
        })
    }

    /// Runs one embedding call against Gemini and reports it.
    ///
    /// Sends an `LlmStart` event, calls `embeddings` and then `count_tokens`
    /// (usage is taken from the token count), sends an `LlmStop` event, and
    /// records `raw_usage` / `usage` on the current span.
    ///
    /// # Errors
    /// Returns an error if serializing the request or usage fails, or if
    /// either Gemini call fails; in the latter case no `LlmStop` is sent.
    async fn execute(
        &self,
        embedding_request: crate::model::gemini::types::CreateEmbeddingRequest,
        token_count_request: crate::model::gemini::types::CountTokensRequest,
        model_name: &str,
        outer_tx: &tokio::sync::mpsc::Sender<Option<ModelEvent>>,
    ) -> GatewayResult<EmbeddingResult> {
        let span = Span::current();

        // Awaited like the `LlmStop` send below: `try_send` would silently drop
        // the start event when the channel is momentarily full, leaving the stop
        // event unpaired. A closed receiver is still deliberately ignored.
        let _ = outer_tx
            .send(Some(ModelEvent::new(
                &span,
                ModelEventType::LlmStart(LLMStartEvent {
                    provider_name: SPAN_GEMINI.to_string(),
                    model_name: model_name.to_string(),
                    input: serde_json::to_string(&embedding_request)?,
                }),
            )))
            .await;

        let response = self
            .client
            .embeddings(model_name, embedding_request)
            .await?;

        let tokens_count = self
            .client
            .count_tokens(model_name, token_count_request)
            .await?;

        let _ = outer_tx
            .send(Some(ModelEvent::new(
                &span,
                ModelEventType::LlmStop(LLMFinishEvent {
                    provider_name: SPAN_GEMINI.to_string(),
                    model_name: model_name.to_string(),
                    output: None,
                    usage: Some(Self::map_usage(&tokens_count)),
                    finish_reason: ModelFinishReason::Stop,
                    tool_calls: vec![],
                    credentials_ident: self.credentials_ident.clone(),
                }),
            )))
            .await;

        span.record(
            "raw_usage",
            JsonValue(&serde_json::to_value(&tokens_count)?).as_value(),
        );
        span.record(
            "usage",
            JsonValue(&serde_json::to_value(Self::map_usage(&tokens_count))?).as_value(),
        );

        let total_tokens = tokens_count.total_tokens as u32;
        Ok(EmbeddingResult::Float(CreateEmbeddingResponse {
            object: "list".to_string(),
            // Exactly one embedding (index 0) is produced per call.
            data: vec![Embedding {
                object: "embedding".to_string(),
                embedding: response.embedding.values,
                index: 0,
            }],
            model: model_name.to_string(),
            usage: EmbeddingUsage {
                prompt_tokens: total_tokens,
                total_tokens,
            },
        }))
    }

    /// Converts a Gemini token-count response into gateway usage.
    ///
    /// An embedding call produces no output, so every counted token is
    /// reported as input and `output_tokens` is pinned to 0.
    fn map_usage(usage: &crate::model::gemini::types::CountTokensResponse) -> CompletionModelUsage {
        let total = usage.total_tokens as u32;
        CompletionModelUsage {
            input_tokens: total,
            output_tokens: 0,
            total_tokens: total,
            ..Default::default()
        }
    }
}

#[async_trait::async_trait]
impl EmbeddingsModelInstance for GeminiEmbeddings {
    async fn embed(
        &self,
        request: &CreateEmbeddingRequest,
        outer_tx: tokio::sync::mpsc::Sender<Option<ModelEvent>>,
        tags: HashMap<String, String>,
    ) -> GatewayResult<EmbeddingResult> {
        let contents = match &request.input {
            Input::String(s) => vec![Part::Text(s.clone())],
            Input::Array(vec) => vec.iter().map(|s| Part::Text(s.clone())).collect(),
        };

        let embedding_request = crate::model::gemini::types::CreateEmbeddingRequest {
            content: crate::model::gemini::types::ContentPart {
                parts: contents.clone(),
            },
            task_type: None,
            title: None,
            output_dimensionality: request.dimensions,
        };

        let token_count_request = crate::model::gemini::types::CountTokensRequest {
            contents: crate::model::gemini::types::Content::user_with_multiple_parts(
                contents
                    .iter()
                    .map(|c| PartWithThought::from(c.clone()))
                    .collect(),
            ),
        };

        let span = create_model_span!(
            SPAN_GEMINI,
            target!("embedding"),
            tags,
            0,
            input = serde_json::to_string(&embedding_request)?
        );

        self.execute(
            embedding_request,
            token_count_request,
            &request.model,
            &outer_tx,
        )
        .instrument(span.clone())
        .await
    }
}