use crate::interceptor::{
AfterResponseContext, BeforeRequestContext, ErrorContext, Interceptor, StreamChunkContext,
StreamEndContext,
};
use crate::Result;
use opentelemetry::{
trace::{SpanKind, Tracer},
KeyValue,
};
use opentelemetry_langfuse::LangfuseContext;
use opentelemetry_semantic_conventions::attribute::{
GEN_AI_OPERATION_NAME, GEN_AI_REQUEST_MAX_TOKENS, GEN_AI_REQUEST_MODEL,
GEN_AI_REQUEST_TEMPERATURE, GEN_AI_RESPONSE_ID, GEN_AI_SYSTEM, GEN_AI_USAGE_INPUT_TOKENS,
GEN_AI_USAGE_OUTPUT_TOKENS,
};
use serde_json::Value;
use std::sync::{Arc, Mutex};
use tracing::{debug, error, info};
/// Per-request interceptor state holding the span that was opened in
/// `before_request` and is closed by whichever completion hook fires
/// (`after_response`, `on_stream_end`, or `on_error`).
///
/// The span type defaults to `opentelemetry::global::BoxedSpan` but is
/// generic so a concrete tracer's span type can be used directly.
pub struct LangfuseState<S = opentelemetry::global::BoxedSpan> {
    // `Mutex<Option<S>>` lets the completion hooks `take()` the span through
    // a shared reference; `None` means no span is currently in flight.
    pub(crate) span: Mutex<Option<S>>,
}
/// Fresh state with no span in flight.
///
/// Written by hand (rather than `#[derive(Default)]`) so that no
/// `S: Default` bound is required: `Mutex<Option<S>>::default()` is
/// `Mutex::new(None)` for any `S`.
impl<S> Default for LangfuseState<S> {
    fn default() -> Self {
        Self {
            span: Mutex::default(),
        }
    }
}
/// Configuration for [`LangfuseInterceptor`].
#[derive(Debug, Clone)]
pub struct LangfuseConfig {
    // When true, the interceptor emits `tracing` debug/info/error logs about
    // span lifecycle events. Defaults from the `LANGFUSE_DEBUG` env var.
    pub debug: bool,
}
/// Builds the configuration from the process environment.
///
/// `LANGFUSE_DEBUG` enables debug logging when set to `true`, `1`, or `yes`
/// (case-insensitive, surrounding whitespace ignored). Any other value — or
/// an unset/non-Unicode variable — leaves debug off.
///
/// Previously this used `str::parse::<bool>`, which only accepts the exact
/// lowercase strings `"true"`/`"false"`, so common spellings like `TRUE` or
/// `1` were silently treated as disabled. The truthy set below is backward
/// compatible with the old behavior for `"true"`/`"false"`.
impl Default for LangfuseConfig {
    fn default() -> Self {
        let debug = std::env::var("LANGFUSE_DEBUG")
            .map(|v| matches!(v.trim().to_ascii_lowercase().as_str(), "true" | "1" | "yes"))
            .unwrap_or(false);
        Self { debug }
    }
}
impl LangfuseConfig {
    /// Creates a configuration from the environment; alias for
    /// [`LangfuseConfig::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a copy of this configuration with debug logging set to
    /// `debug` (builder style).
    #[must_use]
    pub fn with_debug(self, debug: bool) -> Self {
        Self { debug }
    }
}
/// Client interceptor that mirrors each request/response (including streams
/// and errors) into an OpenTelemetry span annotated with GenAI
/// semantic-convention attributes plus Langfuse session/user/tag context.
pub struct LangfuseInterceptor<T: Tracer + Send + Sync> {
    // Controls debug logging of span lifecycle events.
    config: LangfuseConfig,
    // The tracer used to start spans; shared so the interceptor is cheap to
    // clone behind an `Arc`.
    tracer: Arc<T>,
    // Mutable Langfuse trace context (session id, user id, tags, metadata)
    // merged into every span's attributes.
    context: Arc<LangfuseContext>,
}
impl<T: Tracer + Send + Sync> LangfuseInterceptor<T>
where
    T::Span: Send + Sync + 'static,
{
    /// Creates an interceptor using `tracer` to start spans, with a fresh
    /// (empty) [`LangfuseContext`]. Logs an info line when debug is on.
    pub fn new(tracer: T, config: LangfuseConfig) -> Self {
        if config.debug {
            info!("Langfuse interceptor initialized");
        }
        Self {
            config,
            tracer: Arc::new(tracer),
            context: Arc::new(LangfuseContext::new()),
        }
    }
    /// Sets the Langfuse session id attached to subsequent spans.
    pub fn set_session_id(&self, session_id: impl Into<String>) {
        self.context.set_session_id(session_id);
    }
    /// Sets the Langfuse user id attached to subsequent spans.
    pub fn set_user_id(&self, user_id: impl Into<String>) {
        self.context.set_user_id(user_id);
    }
    /// Adds several tags to the Langfuse context.
    pub fn add_tags(&self, tags: Vec<String>) {
        self.context.add_tags(tags);
    }
    /// Adds a single tag to the Langfuse context.
    pub fn add_tag(&self, tag: impl Into<String>) {
        self.context.add_tag(tag);
    }
    /// Replaces the free-form metadata attached to subsequent spans.
    pub fn set_metadata(&self, metadata: serde_json::Value) {
        self.context.set_metadata(metadata);
    }
    /// Clears session, user, tags, and metadata from the context.
    pub fn clear_context(&self) {
        self.context.clear();
    }
    /// Shared handle to the underlying context, e.g. to set fields from
    /// another owner of the interceptor.
    pub fn context(&self) -> &Arc<LangfuseContext> {
        &self.context
    }
    /// Parses a JSON string into a `serde_json::Value`.
    ///
    /// Despite the name this is a generic JSON parse — the trait impl below
    /// also uses it to parse *response* bodies.
    fn extract_request_params(request_json: &str) -> serde_json::Result<Value> {
        serde_json::from_str(request_json)
    }
}
#[async_trait::async_trait]
impl<T: Tracer + Send + Sync> Interceptor<LangfuseState<T::Span>> for LangfuseInterceptor<T>
where
    T::Span: Send + Sync + 'static,
{
    /// Starts a client-kind span for the outgoing request — annotated with
    /// GenAI semantic-convention attributes, the Langfuse context, and the
    /// request parameters/messages — and stores it in the request state for
    /// the matching completion hook to end.
    async fn before_request(
        &self,
        ctx: &mut BeforeRequestContext<'_, LangfuseState<T::Span>>,
    ) -> Result<()> {
        let tracer = self.tracer.as_ref();
        // Baseline attributes. NOTE(review): GEN_AI_SYSTEM is hard-coded to
        // "openai" — confirm this is intended if other providers are routed
        // through this interceptor.
        let mut attributes = vec![
            KeyValue::new(GEN_AI_SYSTEM, "openai"),
            KeyValue::new(GEN_AI_OPERATION_NAME, ctx.operation.to_string()),
            KeyValue::new(GEN_AI_REQUEST_MODEL, ctx.model.to_string()),
        ];
        // Merge session/user/tags/metadata from the Langfuse context.
        attributes.extend(self.context.get_attributes());
        // Best effort: if the request body is not valid JSON, skip the
        // detailed attributes rather than failing the request.
        if let Ok(params) = Self::extract_request_params(ctx.request_json) {
            if let Some(temperature) = params
                .get("temperature")
                .and_then(serde_json::Value::as_f64)
            {
                attributes.push(KeyValue::new(GEN_AI_REQUEST_TEMPERATURE, temperature));
            }
            if let Some(max_tokens) = params.get("max_tokens").and_then(serde_json::Value::as_i64) {
                attributes.push(KeyValue::new(GEN_AI_REQUEST_MAX_TOKENS, max_tokens));
            }
            // Record each chat message as gen_ai.prompt.<i>.role/.content.
            // NOTE(review): prompt content is captured verbatim — confirm
            // that is acceptable for privacy/PII requirements. Non-string
            // content (e.g. multi-part arrays) is recorded as "".
            if let Some(messages) = params.get("messages").and_then(serde_json::Value::as_array) {
                for (i, message) in messages.iter().enumerate() {
                    if let Some(obj) = message.as_object() {
                        let role = obj
                            .get("role")
                            .and_then(serde_json::Value::as_str)
                            .unwrap_or("unknown")
                            .to_string();
                        let content = obj
                            .get("content")
                            .and_then(serde_json::Value::as_str)
                            .unwrap_or("")
                            .to_string();
                        attributes.push(KeyValue::new(format!("gen_ai.prompt.{i}.role"), role));
                        attributes
                            .push(KeyValue::new(format!("gen_ai.prompt.{i}.content"), content));
                    }
                }
            }
        }
        let span = tracer
            .span_builder(ctx.operation.to_string())
            .with_kind(SpanKind::Client)
            .with_attributes(attributes)
            .start(tracer);
        // Panics if the mutex is poisoned (a prior hook panicked while
        // holding it). Any span left from a previous request is silently
        // replaced here — presumably it ends on drop; TODO confirm the drop
        // semantics of the concrete span type.
        *ctx.state.span.lock().unwrap() = Some(span);
        if self.config.debug {
            debug!("Started Langfuse span for operation: {}", ctx.operation);
        }
        Ok(())
    }

    /// Ends the span opened by `before_request`, adding duration, token
    /// usage, the response id, and completion messages. A missing span
    /// (e.g. `before_request` never ran) is a silent no-op.
    async fn after_response(
        &self,
        ctx: &AfterResponseContext<'_, LangfuseState<T::Span>>,
    ) -> Result<()> {
        use opentelemetry::trace::Span;
        // `take()` moves the span out so error/stream hooks cannot end it
        // a second time.
        let Some(mut span) = ctx.state.span.lock().unwrap().take() else {
            if self.config.debug {
                debug!("No span found in state for operation: {}", ctx.operation);
            }
            return Ok(());
        };
        // u128 millis -> i64; truncation only matters for absurd durations.
        #[allow(clippy::cast_possible_truncation)]
        {
            span.set_attribute(KeyValue::new(
                "duration_ms",
                ctx.duration.as_millis() as i64,
            ));
        }
        if let Some(input_tokens) = ctx.input_tokens {
            span.set_attribute(KeyValue::new(GEN_AI_USAGE_INPUT_TOKENS, input_tokens));
        }
        if let Some(output_tokens) = ctx.output_tokens {
            span.set_attribute(KeyValue::new(GEN_AI_USAGE_OUTPUT_TOKENS, output_tokens));
        }
        // Best-effort parse of the *response* body (the helper is a generic
        // JSON parse despite its request-oriented name).
        if let Ok(response) = Self::extract_request_params(ctx.response_json) {
            if let Some(id) = response.get("id").and_then(serde_json::Value::as_str) {
                span.set_attribute(KeyValue::new(GEN_AI_RESPONSE_ID, id.to_string()));
            }
            // Mirror each choice's message as gen_ai.completion.<i>.role/
            // .content; completion text is recorded verbatim (see privacy
            // note on before_request).
            if let Some(choices) = response
                .get("choices")
                .and_then(serde_json::Value::as_array)
            {
                for (i, choice) in choices.iter().enumerate() {
                    if let Some(message) = choice.get("message") {
                        if let Some(role) = message.get("role").and_then(serde_json::Value::as_str)
                        {
                            span.set_attribute(KeyValue::new(
                                format!("gen_ai.completion.{i}.role"),
                                role.to_string(),
                            ));
                        }
                        if let Some(content) =
                            message.get("content").and_then(serde_json::Value::as_str)
                        {
                            span.set_attribute(KeyValue::new(
                                format!("gen_ai.completion.{i}.content"),
                                content.to_string(),
                            ));
                        }
                    }
                }
            }
        }
        span.end();
        if self.config.debug {
            debug!("Completed Langfuse span for operation: {}", ctx.operation);
        }
        Ok(())
    }

    /// Individual stream chunks are not recorded; only the stream-end
    /// summary is (see `on_stream_end`).
    async fn on_stream_chunk(
        &self,
        _ctx: &StreamChunkContext<'_, LangfuseState<T::Span>>,
    ) -> Result<()> {
        Ok(())
    }

    /// Ends the span for a streaming request with chunk-count, duration, and
    /// token-usage attributes. A missing span is a silent no-op.
    async fn on_stream_end(
        &self,
        ctx: &StreamEndContext<'_, LangfuseState<T::Span>>,
    ) -> Result<()> {
        use opentelemetry::trace::Span;
        let Some(mut span) = ctx.state.span.lock().unwrap().take() else {
            if self.config.debug {
                debug!(
                    "No span found in state for stream operation: {}",
                    ctx.operation
                );
            }
            return Ok(());
        };
        // Casts to i64 for the attribute API; truncation/wrap only matters
        // for implausibly large counts or durations.
        #[allow(clippy::cast_possible_truncation, clippy::cast_possible_wrap)]
        {
            span.set_attribute(KeyValue::new(
                "stream.total_chunks",
                ctx.total_chunks as i64,
            ));
            span.set_attribute(KeyValue::new(
                "stream.duration_ms",
                ctx.duration.as_millis() as i64,
            ));
        }
        if let Some(input_tokens) = ctx.input_tokens {
            span.set_attribute(KeyValue::new(GEN_AI_USAGE_INPUT_TOKENS, input_tokens));
        }
        if let Some(output_tokens) = ctx.output_tokens {
            span.set_attribute(KeyValue::new(GEN_AI_USAGE_OUTPUT_TOKENS, output_tokens));
        }
        span.end();
        if self.config.debug {
            info!(
                "Completed streaming span for operation: {} with {} chunks",
                ctx.operation, ctx.total_chunks
            );
        }
        Ok(())
    }

    /// Marks the span as errored (status + error.type/error.message
    /// attributes) and ends it. Missing state or span is a silent no-op.
    async fn on_error(&self, ctx: &ErrorContext<'_, LangfuseState<T::Span>>) {
        use opentelemetry::trace::{Span, Status};
        // Unlike the other hooks, the state itself is optional here — the
        // error may have occurred before any state was created.
        let Some(state) = ctx.state else {
            if self.config.debug {
                debug!(
                    "No state available for error in operation: {}",
                    ctx.operation
                );
            }
            return;
        };
        let Some(mut span) = state.span.lock().unwrap().take() else {
            if self.config.debug {
                debug!(
                    "No span found in state for error in operation: {}",
                    ctx.operation
                );
            }
            return;
        };
        span.set_status(Status::error(ctx.error.to_string()));
        // Debug representation as a rough error classifier; the Display
        // form carries the human-readable message.
        span.set_attribute(KeyValue::new("error.type", format!("{:?}", ctx.error)));
        span.set_attribute(KeyValue::new("error.message", ctx.error.to_string()));
        if let Some(model) = ctx.model {
            span.set_attribute(KeyValue::new(GEN_AI_REQUEST_MODEL, model.to_string()));
        }
        span.end();
        if self.config.debug {
            error!(
                "Recorded error for operation {}: {}",
                ctx.operation, ctx.error
            );
        }
    }
}
/// Blanket forwarding impl so an `Arc<LangfuseInterceptor<T>>` can be
/// registered directly as an interceptor while the caller keeps a handle
/// for mutating the Langfuse context. Every method delegates to the inner
/// interceptor unchanged.
#[async_trait::async_trait]
impl<T: Tracer + Send + Sync> Interceptor<LangfuseState<T::Span>> for Arc<LangfuseInterceptor<T>>
where
    T::Span: Send + Sync + 'static,
{
    async fn before_request(
        &self,
        ctx: &mut BeforeRequestContext<'_, LangfuseState<T::Span>>,
    ) -> Result<()> {
        (**self).before_request(ctx).await
    }
    async fn after_response(
        &self,
        ctx: &AfterResponseContext<'_, LangfuseState<T::Span>>,
    ) -> Result<()> {
        (**self).after_response(ctx).await
    }
    async fn on_stream_chunk(
        &self,
        ctx: &StreamChunkContext<'_, LangfuseState<T::Span>>,
    ) -> Result<()> {
        (**self).on_stream_chunk(ctx).await
    }
    async fn on_stream_end(
        &self,
        ctx: &StreamEndContext<'_, LangfuseState<T::Span>>,
    ) -> Result<()> {
        (**self).on_stream_end(ctx).await
    }
    async fn on_error(&self, ctx: &ErrorContext<'_, LangfuseState<T::Span>>) {
        (**self).on_error(ctx).await;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use opentelemetry::trace::noop::NoopTracer;

    /// `LangfuseConfig::default()` must honor the `LANGFUSE_DEBUG` env var.
    ///
    /// NOTE(review): Rust runs tests in parallel by default, so any other
    /// test reading `LANGFUSE_DEBUG` concurrently can race with this one;
    /// consider a shared lock or `--test-threads=1` if flakes appear.
    #[test]
    fn test_config_from_env() {
        std::env::set_var("LANGFUSE_DEBUG", "true");
        let config = LangfuseConfig::default();
        // Clean up *before* asserting so a failed assertion cannot leak the
        // variable into the rest of the test binary (previously remove_var
        // was unreachable on failure).
        std::env::remove_var("LANGFUSE_DEBUG");
        assert!(config.debug);
    }

    /// Construction with a no-op tracer must not panic (smoke test for the
    /// debug-logging path in `LangfuseInterceptor::new`).
    #[test]
    fn test_interceptor_creation() {
        let tracer = NoopTracer::new();
        let config = LangfuseConfig::new().with_debug(true);
        let _interceptor = LangfuseInterceptor::new(tracer, config);
    }
}