langdb_core 0.3.2

AI gateway Core for LangDB AI Gateway.
Documentation
use super::CredentialsIdent;
use crate::events::{JsonValue, SPAN_BEDROCK};
use crate::model::bedrock::bedrock_client;
use crate::model::error::ModelError;
use crate::model::types::{
    LLMFinishEvent, LLMStartEvent, ModelEvent, ModelEventType, ModelFinishReason,
};
use crate::types::credentials::BedrockCredentials;
use crate::types::embed::EmbeddingResult;
use crate::types::gateway::{CompletionModelUsage, CreateEmbeddingRequest, Input};
use crate::{create_model_span, GatewayResult};
use async_openai::types::{CreateEmbeddingResponse, Embedding, EmbeddingUsage};
use aws_sdk_bedrockruntime::Client;
use aws_smithy_types::Blob;
use serde::Deserialize;
use serde::Serialize;
use std::collections::HashMap;
use tracing::field;
use tracing::Span;
use tracing_futures::Instrument;
use valuable::Valuable;

use super::EmbeddingsModelInstance;

/// JSON response body decoded by `map_response` for every provider that is not
/// Cohere (i.e. `BedrockEmbeddingProvider::Other`).
#[derive(Debug, Deserialize)]
pub struct AmazonTitanEmbeddingResponse {
    /// The single embedding vector contained in the response.
    embedding: Vec<f32>,
    /// Token count taken from the response; `map_response` reports it as both
    /// prompt and total tokens. The JSON key may be spelled either
    /// `input_text_token_count` or `inputTextTokenCount`.
    #[serde(alias = "inputTextTokenCount")]
    input_text_token_count: u32,
}

/// JSON response body decoded by `map_response` for the Cohere provider.
#[derive(Debug, Deserialize)]
pub struct CohereEmbeddingResponse {
    /// One embedding vector per entry; `map_response` assigns each vector its
    /// position in this list as the output index.
    embeddings: Vec<Vec<f32>>,
}

/// JSON request body sent to Cohere embedding models.
///
/// `texts` and `images` are each omitted from the serialized JSON when empty.
#[derive(Debug, Serialize)]
pub struct CohereEmbeddingRequest {
    /// Text inputs to embed; skipped during serialization when empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    texts: Vec<String>,
    /// Image inputs; skipped during serialization when empty. The request
    /// builder in this file always leaves this empty.
    #[serde(skip_serializing_if = "Vec::is_empty")]
    images: Vec<String>,
    /// Value serialized under the `input_type` key.
    input_type: CohereEmbeddingInputType,
}

/// Value of the `input_type` field in a [`CohereEmbeddingRequest`].
///
/// Variants serialize in `snake_case` (for example `SearchDocument` becomes
/// `"search_document"`). The default is [`CohereEmbeddingInputType::SearchDocument`].
#[derive(Debug, Serialize, Default)]
#[serde(rename_all = "snake_case")]
pub enum CohereEmbeddingInputType {
    /// Serialized as `"search_document"`; this is the default variant.
    #[default]
    SearchDocument,
    /// Serialized as `"search_query"`.
    SearchQuery,
    /// Serialized as `"classification"`.
    Classification,
    /// Serialized as `"clustering"`.
    Clustering,
    /// Serialized as `"image"`.
    Image,
}

/// JSON request body sent for every provider that is not Cohere
/// (i.e. `BedrockEmbeddingProvider::Other`).
#[derive(Debug, Serialize)]
pub struct AmazonTitanEmbeddingRequest {
    /// The text to embed; serialized under the key `inputText`.
    #[serde(rename = "inputText")]
    input_text: String,
}

/// Selects which request/response schema is used for a Bedrock embedding model.
///
/// Produced by `match_provider` from the model name.
pub enum BedrockEmbeddingProvider {
    /// The model name's vendor prefix is exactly `cohere`.
    Cohere,
    /// Any other model; carries the vendor prefix extracted from the model name.
    /// Requests and responses for this variant use the Amazon Titan structs.
    Other(String),
}

/// Embeddings model instance backed by the AWS Bedrock runtime.
pub struct BedrockEmbeddings {
    /// Bedrock runtime client used to invoke the model.
    pub client: Client,
    /// Records whether caller-supplied credentials (`Own`) or none (`Langdb`)
    /// were passed to [`BedrockEmbeddings::new`]; attached to the finish event.
    pub credentials_ident: CredentialsIdent,
}

/// Expands to the target string literal used for Bedrock model spans.
///
/// * `target!()` expands to `"langdb::user_tracing::models::bedrock"`.
/// * `target!("sub")` expands, via `concat!`, to
///   `"langdb::user_tracing::models::bedrock::sub"`.
///
/// The argument must be a literal so the result stays a compile-time string.
macro_rules! target {
    () => {
        "langdb::user_tracing::models::bedrock"
    };
    ($subtgt:literal) => {
        concat!("langdb::user_tracing::models::bedrock::", $subtgt)
    };
}

impl BedrockEmbeddings {
    /// Creates an embeddings instance whose client comes from `bedrock_client`.
    ///
    /// `credentials_ident` is set to `CredentialsIdent::Own` when `credentials`
    /// is `Some`, and to `CredentialsIdent::Langdb` when it is `None`.
    ///
    /// # Errors
    /// Propagates any error returned by `bedrock_client`.
    pub async fn new(credentials: Option<&BedrockCredentials>) -> Result<Self, ModelError> {
        let client = bedrock_client(credentials).await?;
        let credentials_ident = match credentials {
            Some(_) => CredentialsIdent::Own,
            None => CredentialsIdent::Langdb,
        };
        Ok(Self {
            client,
            credentials_ident,
        })
    }

    /// Invokes the Bedrock model for `request` and converts the reply into an
    /// [`EmbeddingResult`].
    ///
    /// Side effects, in order:
    /// 1. an `LlmStart` event is offered to `outer_tx` with `try_send`
    ///    (a full or closed channel is ignored);
    /// 2. the model is invoked;
    /// 3. an `LlmStop` event carrying the usage is sent with `send().await`
    ///    (a send failure is ignored);
    /// 4. `raw_usage` and `usage` are recorded on the current span.
    ///
    /// # Errors
    /// Fails if the request or usage cannot be serialized, if the request body
    /// cannot be built, if the Bedrock call fails (wrapped in
    /// `ModelError::CustomError`), or if the response cannot be decoded.
    async fn execute(
        &self,
        request: &CreateEmbeddingRequest,
        outer_tx: &tokio::sync::mpsc::Sender<Option<ModelEvent>>,
    ) -> GatewayResult<EmbeddingResult> {
        let span = Span::current();

        // Best-effort start notification: never blocks and never fails the call.
        let start_event = ModelEvent::new(
            &span,
            ModelEventType::LlmStart(LLMStartEvent {
                provider_name: "bedrock".to_string(),
                model_name: request.model.clone(),
                input: serde_json::to_string(request)?,
            }),
        );
        let _ = outer_tx.try_send(Some(start_event));

        // The provider decides both the request body layout and how the
        // response is decoded afterwards.
        let provider = match_provider(&request.model);
        let (body, estimated_tokens) = generate_invoke_model_input(&request.input, &provider)?;

        let output = self
            .client
            .invoke_model()
            .model_id(request.model.clone())
            .body(body)
            .send()
            .await
            .map_err(|err| ModelError::CustomError(err.to_string()))?;

        let response = map_response(output.body, &request.model, &provider, estimated_tokens)?;

        // Build the finish event up front so the usage borrow ends before the await.
        let stop_event = {
            let usage = response.usage();
            ModelEvent::new(
                &span,
                ModelEventType::LlmStop(LLMFinishEvent {
                    provider_name: SPAN_BEDROCK.to_string(),
                    model_name: request.model.clone(),
                    output: None,
                    usage: Some(CompletionModelUsage {
                        input_tokens: usage.prompt_tokens,
                        output_tokens: 0,
                        total_tokens: usage.total_tokens,
                        ..Default::default()
                    }),
                    finish_reason: ModelFinishReason::Stop,
                    tool_calls: vec![],
                    credentials_ident: self.credentials_ident.clone(),
                }),
            )
        };
        let _ = outer_tx.send(Some(stop_event)).await;

        let raw_usage = serde_json::to_value(response.usage())?;
        span.record("raw_usage", JsonValue(&raw_usage).as_value());

        let mapped_usage = serde_json::to_value(Self::map_usage(response.usage()))?;
        span.record("usage", JsonValue(&mapped_usage).as_value());

        Ok(response)
    }

    /// Converts embedding usage into the gateway's usage type: prompt tokens
    /// become input tokens, total tokens are copied, and every other field
    /// keeps its `Default` value.
    fn map_usage(usage: &EmbeddingUsage) -> CompletionModelUsage {
        CompletionModelUsage {
            input_tokens: usage.prompt_tokens,
            total_tokens: usage.total_tokens,
            ..Default::default()
        }
    }
}

#[async_trait::async_trait]
impl EmbeddingsModelInstance for BedrockEmbeddings {
    /// Runs an embedding request inside a dedicated model span.
    ///
    /// A span is created with `create_model_span!` from `SPAN_BEDROCK`, the
    /// `target!("embedding")` target string, the supplied `tags`, and the
    /// JSON-serialized request as its `input` field. `execute` is then awaited
    /// while instrumented with that span, and its result is returned unchanged.
    ///
    /// NOTE(review): the meaning of the literal `0` passed to
    /// `create_model_span!` is defined by that macro — confirm there.
    ///
    /// # Errors
    /// Fails if the request cannot be serialized to JSON, or with whatever
    /// error `execute` returns.
    async fn embed(
        &self,
        request: &CreateEmbeddingRequest,
        outer_tx: tokio::sync::mpsc::Sender<Option<ModelEvent>>,
        tags: HashMap<String, String>,
    ) -> GatewayResult<EmbeddingResult> {
        let span = create_model_span!(
            SPAN_BEDROCK,
            target!("embedding"),
            tags,
            0,
            input = serde_json::to_string(&request)?
        );

        self.execute(request, &outer_tx)
            .instrument(span.clone())
            .await
    }
}

/// Determines which request/response schema applies to a Bedrock model name.
///
/// Only the text after the last `/` in `model_name` is inspected (the whole
/// name when it contains no `/`). The part of that text before its first `.`
/// is the vendor prefix (the whole text when it contains no `.`).
///
/// Returns [`BedrockEmbeddingProvider::Cohere`] when the prefix is exactly
/// `cohere`; otherwise returns [`BedrockEmbeddingProvider::Other`] carrying
/// the prefix.
fn match_provider(model_name: &str) -> BedrockEmbeddingProvider {
    // Drop any leading path-like segments, keeping only the final one.
    let model_id = model_name
        .rsplit_once('/')
        .map_or(model_name, |(_, last_segment)| last_segment);
    // The vendor is everything before the first dot of the model id.
    let vendor = model_id
        .split_once('.')
        .map_or(model_id, |(prefix, _)| prefix);

    match vendor {
        "cohere" => BedrockEmbeddingProvider::Cohere,
        other => BedrockEmbeddingProvider::Other(other.to_string()),
    }
}

/// Builds the JSON request body for `provider` and an input-token estimate.
///
/// * **Cohere** — the input becomes the `texts` list (a single string becomes
///   a one-element list), `images` is empty and `input_type` is the default.
///   The token estimate is the combined byte length of all texts divided by 3
///   (integer division).
/// * **Other** — the input becomes `inputText`; an array is joined with `"\n"`
///   into one string. The estimate is `0`, because `map_response` takes the
///   token count from the response body for this provider.
///
/// # Errors
/// Fails if the body cannot be serialized, or with
/// `ModelError::CannotCalculateInputTokens` if the estimate does not fit in `u32`.
fn generate_invoke_model_input(
    input: &Input,
    provider: &BedrockEmbeddingProvider,
) -> Result<(Blob, u32), ModelError> {
    match provider {
        BedrockEmbeddingProvider::Cohere => {
            // Normalise both input shapes to a list of texts first.
            let texts: Vec<String> = match input {
                Input::String(text) => vec![text.clone()],
                Input::Array(items) => items.clone(),
            };
            let total_bytes: usize = texts.iter().map(String::len).sum();

            let payload = serde_json::to_string(&CohereEmbeddingRequest {
                texts,
                images: Vec::new(),
                input_type: Default::default(),
            })?;
            let estimated_tokens = u32::try_from(total_bytes / 3)
                .map_err(|_| ModelError::CannotCalculateInputTokens)?;

            Ok((Blob::new(payload), estimated_tokens))
        }
        BedrockEmbeddingProvider::Other(_) => {
            let input_text = match input {
                Input::String(text) => text.clone(),
                Input::Array(items) => items.join("\n"),
            };
            let payload = serde_json::to_string(&AmazonTitanEmbeddingRequest { input_text })?;

            Ok((Blob::new(payload), 0))
        }
    }
}

/// Decodes a Bedrock response body into an OpenAI-style embedding response.
///
/// * **Cohere** — the body is parsed as [`CohereEmbeddingResponse`]; each
///   vector becomes one [`Embedding`] whose `index` is its position in the
///   list. Usage is reported from the caller-supplied `input_tokens`.
/// * **Other** — the body is parsed as [`AmazonTitanEmbeddingResponse`]; the
///   single vector is returned at index `0`. Usage comes from the response's
///   own token count and `input_tokens` is ignored.
///
/// In both cases `total_tokens` equals `prompt_tokens`, the result is wrapped
/// in `EmbeddingResult::Float`, and `model` is echoed back in the response.
///
/// # Errors
/// Fails if the body is not valid JSON for the expected schema.
fn map_response(
    response: Blob,
    model: &str,
    provider: &BedrockEmbeddingProvider,
    input_tokens: u32,
) -> Result<EmbeddingResult, ModelError> {
    let bytes = response.into_inner();
    match provider {
        BedrockEmbeddingProvider::Cohere => {
            let response: CohereEmbeddingResponse = serde_json::from_slice(&bytes)?;
            // Zipping with a `u32` counter yields the index in its final type
            // (no truncating cast), and the owned vectors are moved into the
            // output rather than cloned.
            let data = (0u32..)
                .zip(response.embeddings)
                .map(|(index, embedding)| Embedding {
                    object: "embedding".to_string(),
                    embedding,
                    index,
                })
                .collect();
            Ok(EmbeddingResult::Float(CreateEmbeddingResponse {
                data,
                model: model.to_string(),
                object: "list".to_string(),
                usage: EmbeddingUsage {
                    prompt_tokens: input_tokens,
                    total_tokens: input_tokens,
                },
            }))
        }
        BedrockEmbeddingProvider::Other(_) => {
            let response: AmazonTitanEmbeddingResponse = serde_json::from_slice(&bytes)?;
            Ok(EmbeddingResult::Float(CreateEmbeddingResponse {
                data: vec![Embedding {
                    object: "embedding".to_string(),
                    embedding: response.embedding,
                    index: 0,
                }],
                model: model.to_string(),
                object: "list".to_string(),
                usage: EmbeddingUsage {
                    prompt_tokens: response.input_text_token_count,
                    total_tokens: response.input_text_token_count,
                },
            }))
        }
    }
}