llm-sdk-rs 0.3.0

A Rust library that enables the development of applications that can interact with different language models through a unified interface.
Documentation
use crate::{
    LanguageModelInput, LanguageModelResult, LanguageModelStream, ModelResponse, ModelUsage,
    PartialModelResponse,
};
use futures::StreamExt;
use opentelemetry::trace::Status;
use std::time::Instant;
use tracing::{info_span, Span};
use tracing_futures::Instrument;
use tracing_opentelemetry::OpenTelemetrySpanExt;

/// Bookkeeping for a single traced language-model call.
///
/// Holds the tracing span for the call, the sampling parameters copied from
/// the request, and the usage/cost/latency figures gathered while the call
/// runs. The gathered values are written to the span as attributes by
/// `on_end`.
pub struct LmSpan {
    /// Span that every attribute and status of this call is recorded on.
    span: Span,

    /// Moment this value was constructed; reference point for latency figures.
    start_time: Instant,

    // Request parameters, copied from the input when the span is created.
    max_tokens: Option<u32>,
    temperature: Option<f64>,
    top_p: Option<f64>,
    top_k: Option<i32>,
    presence_penalty: Option<f64>,
    frequency_penalty: Option<f64>,
    seed: Option<i64>,

    /// Token usage reported by the response, or summed over stream partials.
    usage: Option<ModelUsage>,
    /// Cost reported by the response, or summed over stream partials.
    cost: Option<f64>,
    /// Seconds between `start_time` and the first stream partial carrying a delta.
    time_to_first_token: Option<f64>,
}

impl LmSpan {
    pub fn new(provider: &str, model_id: &str, method: &str, input: &LanguageModelInput) -> Self {
        let span = if method == "stream" {
            info_span!("llm_sdk.stream")
        } else {
            info_span!("llm_sdk.generate")
        };
        span.set_attribute("gen_ai.operation.name", "generate_content");
        span.set_attribute("gen_ai.provider.name", provider.to_string());
        span.set_attribute("gen_ai.request.model", model_id.to_string());
        span.set_attribute("llm_sdk.method", method.to_string());

        Self {
            span,
            usage: None,
            cost: None,
            start_time: Instant::now(),
            time_to_first_token: None,
            max_tokens: input.max_tokens,
            temperature: input.temperature,
            top_p: input.top_p,
            top_k: input.top_k,
            presence_penalty: input.presence_penalty,
            frequency_penalty: input.frequency_penalty,
            seed: input.seed,
        }
    }

    fn span(&self) -> Span {
        self.span.clone()
    }

    pub async fn instrument_future<F>(&self, future: F) -> F::Output
    where
        F: std::future::Future,
    {
        future.instrument(self.span()).await
    }

    pub fn on_response(&mut self, response: &ModelResponse) {
        if let Some(usage) = &response.usage {
            self.usage = Some(usage.clone());
        }
        if let Some(cost) = response.cost {
            self.cost = Some(cost);
        }
    }

    pub fn on_stream_partial(&mut self, partial: &PartialModelResponse) {
        if let Some(usage) = &partial.usage {
            self.usage
                .get_or_insert_with(ModelUsage::default)
                .add(usage);
        }
        if let Some(cost) = partial.cost {
            *self.cost.get_or_insert(0.0) += cost;
        }
        if partial.delta.is_some() && self.time_to_first_token.is_none() {
            self.time_to_first_token = Some(self.elapsed_seconds());
        }
    }

    pub fn on_error(&mut self, error: &(dyn std::error::Error + 'static)) {
        self.span
            .set_attribute("exception.message", error.to_string());
        self.span.set_status(Status::error(error.to_string()));
    }

    pub fn on_end(&mut self) {
        if let Some(usage) = &self.usage {
            self.span
                .set_attribute("gen_ai.usage.input_tokens", i64::from(usage.input_tokens));
            self.span
                .set_attribute("gen_ai.usage.output_tokens", i64::from(usage.output_tokens));
        }

        if let Some(cost) = self.cost {
            self.span.set_attribute("llm_sdk.cost", cost);
        }

        if let Some(time_to_first_token) = self.time_to_first_token {
            self.span
                .set_attribute("gen_ai.server.time_to_first_token", time_to_first_token);
        }

        if let Some(max_tokens) = self.max_tokens {
            self.span
                .set_attribute("gen_ai.request.max_tokens", i64::from(max_tokens));
        }
        if let Some(temperature) = self.temperature {
            self.span
                .set_attribute("gen_ai.request.temperature", temperature);
        }
        if let Some(top_p) = self.top_p {
            self.span.set_attribute("gen_ai.request.top_p", top_p);
        }
        if let Some(top_k) = self.top_k {
            self.span
                .set_attribute("gen_ai.request.top_k", i64::from(top_k));
        }
        if let Some(presence_penalty) = self.presence_penalty {
            self.span
                .set_attribute("gen_ai.request.presence_penalty", presence_penalty);
        }
        if let Some(frequency_penalty) = self.frequency_penalty {
            self.span
                .set_attribute("gen_ai.request.frequency_penalty", frequency_penalty);
        }
        if let Some(seed) = self.seed {
            self.span.set_attribute("gen_ai.request.seed", seed);
        }
    }

    fn elapsed_seconds(&self) -> f64 {
        self.start_time.elapsed().as_secs_f64()
    }
}

impl Drop for LmSpan {
    fn drop(&mut self) {
        self.on_end();
    }
}

/// Runs a non-streaming generation inside an `LmSpan`.
///
/// `f` is invoked with `input` and the future it returns is awaited under the
/// span. A successful response has its usage and cost recorded; a failure is
/// recorded as an error on the span. The result of `f` is returned unchanged.
pub async fn trace_generate<F, Fut>(
    provider: &str,
    model_id: &str,
    input: LanguageModelInput,
    f: F,
) -> LanguageModelResult<ModelResponse>
where
    F: FnOnce(LanguageModelInput) -> Fut,
    Fut: std::future::Future<Output = LanguageModelResult<ModelResponse>>,
{
    let mut lm_span = LmSpan::new(provider, model_id, "generate", &input);

    let pending = f(input);
    let outcome = lm_span.instrument_future(pending).await;

    // Inspect by reference so the result can still be handed back to the caller.
    match outcome.as_ref() {
        Err(error) => lm_span.on_error(error),
        Ok(response) => lm_span.on_response(response),
    }
    lm_span.on_end();

    outcome
}

/// Runs a streaming generation inside an `LmSpan`.
///
/// `f` is invoked with `input` and the future it returns is awaited under the
/// span. If that fails, the error is recorded on the span and returned. On
/// success the returned stream is wrapped: every partial is fed to the
/// `LmSpan` before being yielded, and an error item is recorded on the span
/// and then ends the wrapped stream with that error. The `LmSpan` is owned by
/// the wrapped stream.
pub async fn trace_stream<F, Fut>(
    provider: &str,
    model_id: &str,
    input: LanguageModelInput,
    f: F,
) -> LanguageModelResult<LanguageModelStream>
where
    F: FnOnce(LanguageModelInput) -> Fut,
    Fut: std::future::Future<Output = LanguageModelResult<LanguageModelStream>>,
{
    let mut lm_span = LmSpan::new(provider, model_id, "stream", &input);
    let opened = lm_span.instrument_future(f(input)).await;

    // Failure to open the stream: record it and bail out early.
    let mut upstream = match opened {
        Ok(stream) => stream,
        Err(error) => {
            lm_span.on_error(&error);
            lm_span.on_end();
            return Err(error);
        }
    };

    // Take a span handle before the `LmSpan` itself is moved into the stream.
    let stream_span = lm_span.span();
    let traced = async_stream::try_stream! {
        let mut tracker = lm_span;

        while let Some(item) = upstream.next().await {
            match item {
                Ok(partial) => {
                    tracker.on_stream_partial(&partial);
                    yield partial;
                }
                Err(err) => {
                    tracker.on_error(&err);
                    Err(err)?;
                }
            }
        }
    }
    .instrument(stream_span);

    Ok(LanguageModelStream::from_stream(traced))
}