langdb_core 0.3.2

Core library for the LangDB AI Gateway.
Documentation
use std::collections::HashMap;

use crate::events::JsonValue;
use crate::executor::context::ExecutorContext;
use crate::model::DefaultModelMetadataFactory;
use crate::routing::interceptor::rate_limiter::InMemoryRateLimiterService;
use crate::routing::RoutingStrategy;
use crate::types::gateway::ChatCompletionRequestWithTools;
use crate::types::gateway::CompletionModelUsage;
use crate::types::gateway::Extra;
use crate::types::guardrails::service::GuardrailsEvaluator;
use crate::usage::InMemoryStorage;
use actix_web::{web, HttpRequest, HttpResponse};
use bytes::Bytes;
use std::sync::Arc;
use tokio::sync::Mutex;
use valuable::Valuable;

use crate::handler::AvailableModels;
use crate::handler::CallbackHandlerFn;
use crate::model::ModelMetadataFactory;
use crate::types::gateway::{
    ChatCompletionChunk, ChatCompletionChunkChoice, ChatCompletionDelta, ChatCompletionUsage,
    CostCalculator,
};
use crate::GatewayApiError;
use tracing::Span;
use tracing_futures::Instrument;

use super::can_execute_llm_for_request;

use crate::executor::chat_completion::routed_executor::RoutedExecutor;

/// A single streamed chat event, as consumed by [`map_sso_event`].
///
/// All three parts are optional; `map_sso_event` decides which chunks to emit
/// based on which of them are present.
pub type SSOChatEvent = (
    // Incremental delta (content / role / tool calls) carried by this event, if any.
    Option<ChatCompletionDelta>,
    // Model usage reported alongside this event, if any.
    Option<CompletionModelUsage>,
    // Finish reason, present when the stream is finishing.
    Option<String>,
);

/// Actix handler for chat completion requests.
///
/// Checks via `can_execute_llm_for_request` that the request may invoke an LLM and
/// opens an `api_invoke` tracing span, recording `extra.user` when the request carries
/// one. It then assembles an [`ExecutorContext`] from the injected services and
/// delegates the request to a [`RoutedExecutor`] instrumented with that span.
///
/// # Errors
///
/// Returns a [`GatewayApiError`] in any of these cases:
/// - the request is not allowed to execute;
/// - the `user` field cannot be converted to JSON;
/// - the executor context cannot be built;
/// - execution itself fails.
pub async fn create_chat_completion(
    request: web::Json<ChatCompletionRequestWithTools<RoutingStrategy>>,
    callback_handler: web::Data<CallbackHandlerFn>,
    req: HttpRequest,
    provided_models: web::Data<AvailableModels>,
    cost_calculator: web::Data<Box<dyn CostCalculator>>,
    evaluator_service: web::Data<Box<dyn GuardrailsEvaluator>>,
) -> Result<HttpResponse, GatewayApiError> {
    can_execute_llm_for_request(&req).await?;

    // If the subscriber disables this span, `or_current` falls back to the current span.
    // The fields are declared empty so they can be recorded later; only `user` is
    // recorded in this function.
    let span = Span::or_current(tracing::info_span!(
        target: "langdb::user_tracing::api_invoke",
        "api_invoke",
        request = tracing::field::Empty,
        response = tracing::field::Empty,
        error = tracing::field::Empty,
        thread_id = tracing::field::Empty,
        message_id = tracing::field::Empty,
        user = tracing::field::Empty,
    ));

    if let Some(Extra {
        user: Some(user), ..
    }) = &request.extra
    {
        // `to_value` accepts a reference, so the user object does not need to be cloned.
        span.record("user", JsonValue(&serde_json::to_value(user)?).as_value());
    }

    // Optional storage looked up as `Arc<Mutex<InMemoryStorage>>` app data;
    // `None` when nothing of that type is registered.
    let memory_storage = req.app_data::<Arc<Mutex<InMemoryStorage>>>().cloned();

    // NOTE(review): a new rate limiter is constructed for every request, so whatever
    // state it holds is not shared between requests — confirm this is intended.
    let rate_limiter_service = InMemoryRateLimiterService::new();

    let executor_context = ExecutorContext::new(
        callback_handler.get_ref().clone(),
        cost_calculator.into_inner(),
        Arc::new(
            Box::new(DefaultModelMetadataFactory::new(&provided_models.0))
                as Box<dyn ModelMetadataFactory>,
        ),
        &req,
        HashMap::new(),
        // The `Data` wrapper is not used afterwards, so take its `Arc` directly.
        evaluator_service.into_inner(),
        Arc::new(rate_limiter_service),
    )?;

    // The request body is not needed past this point. Move it into the executor
    // rather than deep-cloning the whole request.
    let executor = RoutedExecutor::new(request.into_inner());
    executor
        .execute(&executor_context, memory_storage, None)
        .instrument(span)
        .await
}

pub fn map_sso_event(
    delta: Result<SSOChatEvent, GatewayApiError>,
    model_name: String,
) -> Result<Bytes, GatewayApiError> {
    let model_name = model_name.clone();
    let chunks = match delta {
        Ok((None, usage, Some(finish_reason))) => {
            let mut chunks = vec![];
            chunks.push(ChatCompletionChunk {
                id: uuid::Uuid::new_v4().to_string(),
                object: "chat.completion.chunk".to_string(),
                created: chrono::Utc::now().timestamp(),
                model: model_name.clone(),
                choices: vec![ChatCompletionChunkChoice {
                    index: 0,
                    delta: ChatCompletionDelta {
                        content: None,
                        role: None,
                        tool_calls: None,
                    },
                    finish_reason: Some(finish_reason.clone()),
                    logprobs: None,
                }],
                usage: None,
            });

            if let Some(u) = &usage {
                chunks.push(ChatCompletionChunk {
                    id: uuid::Uuid::new_v4().to_string(),
                    object: "chat.completion.chunk".to_string(),
                    created: chrono::Utc::now().timestamp(),
                    model: model_name.clone(),
                    choices: vec![],
                    usage: Some(ChatCompletionUsage {
                        prompt_tokens: u.input_tokens as i32,
                        completion_tokens: u.output_tokens as i32,
                        total_tokens: u.total_tokens as i32,
                        prompt_tokens_details: u.prompt_tokens_details.clone(),
                        completion_tokens_details: u.completion_tokens_details.clone(),
                        cost: 0.0,
                    }),
                });
            }

            Ok(chunks)
        }
        Ok((delta, _, finish_reason)) => {
            let chunk = ChatCompletionChunk {
                id: uuid::Uuid::new_v4().to_string(),
                object: "chat.completion.chunk".to_string(),
                created: chrono::Utc::now().timestamp(),
                model: model_name.clone(),
                choices: delta.as_ref().map_or(vec![], |d| {
                    vec![ChatCompletionChunkChoice {
                        index: 0,
                        delta: d.clone(),
                        finish_reason,
                        logprobs: None,
                    }]
                }),
                usage: None,
            };

            Ok(vec![chunk])
        }
        Err(e) => Err(e),
    };

    let mut result_combined = String::new();
    match chunks {
        Ok(chunks) => {
            for c in chunks {
                let json_str = serde_json::to_string(&c).unwrap_or_else(|e| {
                    format!("{{\"error\": \"Failed to serialize chunk: {e}\"}}")
                });

                result_combined.push_str(&format!("data: {json_str}\n\n"));
            }
        }
        Err(e) => {
            let result = serde_json::to_string(&HashMap::from([("error", e.to_string())]))
                .unwrap_or_else(|e| format!("{{\"error\": \"Failed to serialize chunk: {e}\"}}"));

            result_combined.push_str(&format!("data: {result}\n\n"));
        }
    }

    Ok(Bytes::from(result_combined))
}