//! Cached model support for `langdb_core` 0.3.2, the core library of the LangDB AI Gateway.
//! On a cache hit, a [`CachedModel`] replays previously recorded model events and the cached
//! final response instead of calling an upstream provider.
use crate::error::GatewayError;
use crate::events::{JsonValue, SPAN_CACHE};
use crate::model::types::{ModelEvent, ModelEventType, ModelFinishReason};
use crate::model::ModelInstance;
use crate::types::gateway::{ChatCompletionMessage, ChatCompletionMessageWithFinishReason};
use crate::types::threads::Message;
use crate::{create_model_span, GatewayResult};
use async_trait::async_trait;
use serde_json::Value;
use std::collections::HashMap;
use tokio::sync::mpsc;
use tracing::field;
use tracing_futures::Instrument;
use valuable::Valuable;

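/// Expands to the tracing target string for cached-response spans, optionally with a
/// sub-target suffix (e.g. `target!("chat")`).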
macro_rules! target {
    () => {
        "langdb::user_tracing::models::cached_response"
    };
    ($subtgt:literal) => {
        concat!("langdb::user_tracing::models::cached_response::", $subtgt)
    };
}

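/// A [`ModelInstance`] backed by a cache hit: instead of calling an upstream provider, it
/// replays previously recorded model events and returns the cached final response.
///
/// A minimal usage sketch (not compiled; `cached_events` and `cached_response` are assumed
/// to have been recovered from the gateway's cache layer):
///
/// ```ignore
/// let model = CachedModel::new(cached_events, Some(cached_response));
/// let (tx, mut rx) = tokio::sync::mpsc::channel(32);
/// // In real use the receiver is typically drained concurrently with the call.
/// let response = model
///     .invoke(Default::default(), tx, vec![], Default::default())
///     .await?;
/// while let Some(Some(event)) = rx.recv().await {
///     // forward the replayed events to the caller
/// }
/// ```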
#[derive(Debug)]
pub struct CachedModel {
    events: Vec<ModelEvent>,
    response: Option<ChatCompletionMessage>,
}

impl CachedModel {
    pub fn new(events: Vec<ModelEvent>, response: Option<ChatCompletionMessage>) -> Self {
        Self { events, response }
    }

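    /// Replays the recorded events over `tx`, flagging usage on `LlmStop` events as served
    /// from cache, and terminates the stream with a `None` sentinel.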
    async fn inner_stream(
        &self,
        tx: tokio::sync::mpsc::Sender<Option<ModelEvent>>,
    ) -> GatewayResult<()> {
        for event in &self.events {
            if let ModelEventType::LlmStop(e) = &event.event {
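                // Clone and patch the usage so consumers can tell the tokens were served from cache.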
                let mut u = e.usage.clone();
                if let Some(u) = u.as_mut() {
                    u.is_cache_used = true;
                }
                let mut event_type = e.clone();
                event_type.usage = u;

                let mut ev = event.clone();
                ev.event = ModelEventType::LlmStop(event_type);
                tx.send(Some(ev)).await?;
                continue;
            }
            tx.send(Some(event.clone())).await?;
        }
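        // A trailing `None` signals end-of-stream to the receiver.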
        tx.send(None).await?;
        Ok(())
    }

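    /// Replays the cached events and then returns the stored final response, or an error if
    /// no response was cached.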
    async fn invoke_inner(
        &self,
        tx: tokio::sync::mpsc::Sender<Option<ModelEvent>>,
    ) -> GatewayResult<ChatCompletionMessageWithFinishReason> {
        // Replay the cached events exactly as in the streaming path; `inner_stream` patches
        // the usage on the stop event and sends the trailing `None` sentinel.
        self.inner_stream(tx).await?;

        if let Some(response) = &self.response {
            return Ok(ChatCompletionMessageWithFinishReason::new(
                response.clone(),
                ModelFinishReason::Stop,
            ));
        }

        Err(GatewayError::CustomError(
            "Cached model response is None".to_string(),
        ))
    }
}

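// Serves cached results through the standard `ModelInstance` interface; both paths record a
// tracing span marked as a cache HIT.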
#[async_trait]
impl ModelInstance for CachedModel {
    async fn stream(
        &self,
        _input_vars: HashMap<String, Value>,
        tx: mpsc::Sender<Option<ModelEvent>>,
        _previous_messages: Vec<Message>,
        tags: HashMap<String, String>,
    ) -> GatewayResult<()> {
        let span = create_model_span!(SPAN_CACHE, target!("chat"), &tags, 0, cache_state = "HIT");

        self.inner_stream(tx).instrument(span).await
    }

    async fn invoke(
        &self,
        _input_vars: HashMap<String, Value>,
        tx: tokio::sync::mpsc::Sender<Option<ModelEvent>>,
        _previous_messages: Vec<Message>,
        tags: HashMap<String, String>,
    ) -> GatewayResult<ChatCompletionMessageWithFinishReason> {
        let span = create_model_span!(SPAN_CACHE, target!("chat"), &tags, 0, cache_state = "HIT");

        self.invoke_inner(tx).instrument(span).await
    }
}