use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use actix_web::Error;
use actix_web::dev::{Service, ServiceRequest, ServiceResponse, Transform, forward_ready};
use actix_web::http::header::HeaderValue;
use chrono::Utc;
use futures::future::{Ready, ok};
use tracing::debug;
use super::client::LangfuseClient;
use super::config::LangfuseConfig;
use super::types::{IngestionEvent, Span, Trace};
/// Request header carrying a caller-supplied Langfuse trace id. When the header
/// is missing or unreadable, the middleware generates an id and writes it back
/// onto the request under this same header name.
pub const TRACE_ID_HEADER: &str = "x-langfuse-trace-id";
/// Request header naming the parent observation of the span the middleware
/// creates for the request.
pub const PARENT_SPAN_ID_HEADER: &str = "x-langfuse-parent-span-id";
/// Request header whose value is attached to the trace as its session id.
pub const SESSION_ID_HEADER: &str = "x-langfuse-session-id";
/// Request header whose value is attached to the trace as its user id.
pub const USER_ID_HEADER: &str = "x-langfuse-user-id";
/// Returns the trace id supplied in [`TRACE_ID_HEADER`], or a freshly generated
/// id when that header is missing, not valid visible ASCII, or blank.
///
/// Blank values are rejected because the id is otherwise used verbatim: every
/// request carrying an empty header would be recorded under the same empty
/// trace id.
fn extract_or_generate_trace_id(req: &ServiceRequest) -> String {
    req.headers()
        .get(TRACE_ID_HEADER)
        .and_then(|v| v.to_str().ok())
        .map(str::trim)
        .filter(|s| !s.is_empty())
        .map(str::to_owned)
        .unwrap_or_else(super::types::generate_id)
}
/// Reads header `name` from `req` as an owned, trimmed string.
///
/// Returns `None` when the header is missing, its value is not valid visible
/// ASCII (`HeaderValue::to_str` fails), or it is blank after trimming. Blank
/// values are dropped so an empty session/user/parent-span header is not
/// attached to the trace as an empty string.
fn extract_header(req: &ServiceRequest, name: &str) -> Option<String> {
    let value = req.headers().get(name)?.to_str().ok()?.trim();
    (!value.is_empty()).then(|| value.to_owned())
}
/// Actix-web middleware factory that records each non-excluded request as a
/// Langfuse trace containing one span.
///
/// Build it with [`LangfuseTracing::new`] or [`LangfuseTracing::from_env`] and
/// adjust it with the builder methods. Cloning is cheap: the Langfuse client is
/// shared behind an `Arc`, so one configured factory can be cloned into each
/// worker's `App` instead of re-creating the client per worker.
#[derive(Clone)]
pub struct LangfuseTracing {
    /// Shared Langfuse client. `None` when client creation failed; requests
    /// then pass through untraced.
    client: Option<Arc<LangfuseClient>>,
    /// Request-body capture flag. Forwarded to the middleware but not used
    /// when building spans.
    include_request_body: bool,
    /// Response-body capture flag. Forwarded to the middleware but not used
    /// when building spans.
    include_response_body: bool,
    /// Path prefixes that are never traced.
    exclude_paths: Vec<String>,
    /// Value recorded under the trace's `service` metadata key.
    service_name: String,
}
impl LangfuseTracing {
    /// Creates the middleware factory from an explicit [`LangfuseConfig`].
    ///
    /// If `LangfuseClient::new` returns an error, the error is logged at `warn`
    /// level and the factory is built without a client. Every request is then
    /// forwarded untraced rather than failing application start-up.
    ///
    /// Defaults:
    /// - request/response body flags are off;
    /// - `/health`, `/metrics`, `/ready` and `/live` are excluded;
    /// - the service name is `"litellm-rs"`.
    #[must_use]
    pub fn new(config: LangfuseConfig) -> Self {
        let client = match LangfuseClient::new(config) {
            Ok(c) => Some(Arc::new(c)),
            Err(e) => {
                tracing::warn!("Failed to create Langfuse client: {}", e);
                None
            }
        };
        Self {
            client,
            include_request_body: false,
            include_response_body: false,
            exclude_paths: vec![
                "/health".to_string(),
                "/metrics".to_string(),
                "/ready".to_string(),
                "/live".to_string(),
            ],
            service_name: "litellm-rs".to_string(),
        }
    }

    /// Creates the middleware factory from `LangfuseConfig::from_env()`.
    #[must_use]
    pub fn from_env() -> Self {
        Self::new(LangfuseConfig::from_env())
    }

    /// Sets the request-body capture flag.
    ///
    /// Note: the middleware does not currently read this flag when building
    /// spans.
    #[must_use]
    pub fn include_request_body(mut self, include: bool) -> Self {
        self.include_request_body = include;
        self
    }

    /// Sets the response-body capture flag.
    ///
    /// Note: the middleware does not currently act on this flag when building
    /// spans.
    #[must_use]
    pub fn include_response_body(mut self, include: bool) -> Self {
        self.include_response_body = include;
        self
    }

    /// Replaces the entire list of excluded path prefixes, defaults included.
    #[must_use]
    pub fn exclude_paths(mut self, paths: Vec<String>) -> Self {
        self.exclude_paths = paths;
        self
    }

    /// Appends one excluded path prefix to the current list.
    #[must_use]
    pub fn exclude_path(mut self, path: impl Into<String>) -> Self {
        self.exclude_paths.push(path.into());
        self
    }

    /// Sets the value recorded under the trace's `service` metadata key.
    #[must_use]
    pub fn service_name(mut self, name: impl Into<String>) -> Self {
        self.service_name = name.into();
        self
    }

    /// Test-only mirror of the middleware's path filter.
    ///
    /// A path is traced unless it starts with any excluded entry. This is a
    /// plain string-prefix match, so `/health` also excludes `/healthz`.
    #[cfg(test)]
    fn should_trace(&self, path: &str) -> bool {
        !self.exclude_paths.iter().any(|p| path.starts_with(p))
    }
}
impl<S, B> Transform<S, ServiceRequest> for LangfuseTracing
where
S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
S::Future: 'static,
B: 'static,
{
type Response = ServiceResponse<B>;
type Error = Error;
type InitError = ();
type Transform = LangfuseTracingMiddleware<S>;
type Future = Ready<Result<Self::Transform, Self::InitError>>;
fn new_transform(&self, service: S) -> Self::Future {
ok(LangfuseTracingMiddleware {
service,
client: self.client.clone(),
_include_request_body: self.include_request_body,
include_response_body: self.include_response_body,
exclude_paths: self.exclude_paths.clone(),
service_name: self.service_name.clone(),
})
}
}
/// Middleware instance produced by `LangfuseTracing`'s `Transform` impl.
///
/// Wraps the next service in the chain and carries a copy of the factory's
/// settings.
pub struct LangfuseTracingMiddleware<S> {
    /// The wrapped inner service that every request is forwarded to.
    service: S,
    /// Shared Langfuse client; when `None`, `call` forwards every request
    /// without tracing.
    client: Option<Arc<LangfuseClient>>,
    /// Copied from the factory but never read: request bodies are not captured.
    _include_request_body: bool,
    /// Copied from the factory; `call` reads it but does not act on it.
    include_response_body: bool,
    /// Path prefixes for which `call` skips tracing.
    exclude_paths: Vec<String>,
    /// Recorded as the trace's `service` metadata value.
    service_name: String,
}
impl<S> LangfuseTracingMiddleware<S> {
    /// Returns `true` unless `path` begins with one of the excluded entries.
    ///
    /// Matching is a plain string-prefix test, not path-segment aware.
    fn should_trace(&self, path: &str) -> bool {
        self.exclude_paths
            .iter()
            .all(|prefix| !path.starts_with(prefix.as_str()))
    }
}
impl<S, B> Service<ServiceRequest> for LangfuseTracingMiddleware<S>
where
    S: Service<ServiceRequest, Response = ServiceResponse<B>, Error = Error>,
    S::Future: 'static,
    B: 'static,
{
    type Response = ServiceResponse<B>;
    type Error = Error;
    type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>>>>;

    forward_ready!(service);

    /// Traces a single request.
    ///
    /// If the path is excluded, or no client is configured, the request is
    /// forwarded untouched. Otherwise the middleware:
    /// 1. resolves the trace id (incoming header or a new id) and writes it
    ///    back onto the request;
    /// 2. builds `trace-create` and `span-create` events describing the request;
    /// 3. awaits the inner service, then appends a `span-update` event with the
    ///    status code, duration and level;
    /// 4. sends the batch from a detached `tokio` task, so ingestion neither
    ///    delays nor fails the response.
    fn call(&self, mut req: ServiceRequest) -> Self::Future {
        let path = req.path().to_string();

        // A single match covers both "path excluded" and "tracing disabled",
        // and yields the client without a second `is_none` check.
        let client = match &self.client {
            Some(client) if self.should_trace(&path) => Arc::clone(client),
            _ => return Box::pin(self.service.call(req)),
        };

        let method = req.method().to_string();
        let start_time = Utc::now();
        let start_instant = std::time::Instant::now();

        let trace_id = extract_or_generate_trace_id(&req);
        let parent_span_id = extract_header(&req, PARENT_SPAN_ID_HEADER);
        let session_id = extract_header(&req, SESSION_ID_HEADER);
        let user_id = extract_header(&req, USER_ID_HEADER);
        let uri = req.uri().to_string();
        let query = req.query_string().to_string();
        let content_type = extract_header(&req, "content-type");

        // Write the effective trace id (including a freshly generated one) onto
        // the request so downstream handlers see the same id.
        if let Ok(header_value) = HeaderValue::from_str(&trace_id) {
            req.headers_mut().insert(
                actix_web::http::header::HeaderName::from_static(TRACE_ID_HEADER),
                header_value,
            );
        }

        // Response-body capture is not implemented. The read presumably exists
        // so the field is not reported as dead code.
        let _include_response = self.include_response_body;

        let name = format!("{} {}", method, path);
        let span_id = super::types::generate_id();
        let mut span = Span::new(&trace_id)
            .name(name.clone())
            .input(serde_json::json!({
                "method": method,
                "path": path,
                "uri": uri,
                "query": if query.is_empty() { None } else { Some(query) },
                "content_type": content_type,
            }));
        span.id = span_id.clone();
        span.start_time = Some(start_time);
        span.parent_observation_id = parent_span_id;

        let mut trace = Trace::with_id(&trace_id)
            .name(name)
            .metadata("service", serde_json::json!(self.service_name))
            .metadata("http.method", serde_json::json!(method))
            .metadata("http.path", serde_json::json!(path));
        if let Some(uid) = &user_id {
            trace = trace.user_id(uid);
        }
        if let Some(sid) = &session_id {
            trace = trace.session_id(sid);
        }

        // Request-side events are known now; the completion event is appended
        // once the inner service has produced a result.
        let mut batch = super::types::IngestionBatch::new();
        batch.add(IngestionEvent::trace_create(trace));
        batch.add(IngestionEvent::span_create(span));

        let fut = self.service.call(req);
        Box::pin(async move {
            let result = fut.await;
            // Saturate instead of silently truncating the u128 millisecond count.
            let duration_ms =
                u64::try_from(start_instant.elapsed().as_millis()).unwrap_or(u64::MAX);
            let end_time = Utc::now();

            // actix-web renders an `Err` through `ResponseError`, whose
            // `status_code()` is normally the status the client receives, so
            // report that instead of assuming 500.
            let status_code = match &result {
                Ok(res) => res.status().as_u16(),
                Err(err) => {
                    actix_web::ResponseError::status_code(err.as_response_error()).as_u16()
                }
            };
            let level = level_for_status(status_code);

            let mut completed_span = Span::new(&trace_id)
                .output(serde_json::json!({
                    "status_code": status_code,
                    "duration_ms": duration_ms,
                }))
                .level(level);
            completed_span.id = span_id;
            completed_span.end_time = Some(end_time);
            batch.add(IngestionEvent::span_update(completed_span));

            // Fire-and-forget: ingestion failures are only logged.
            tokio::spawn(async move {
                if let Err(e) = client.ingest(batch).await {
                    tracing::warn!("Failed to send Langfuse events: {}", e);
                }
            });

            debug!(
                "Langfuse: Traced {} {} -> {} ({}ms)",
                method, path, status_code, duration_ms
            );
            result
        })
    }
}

/// Maps an HTTP status code to a Langfuse observation level.
///
/// 5xx maps to `Error`, 4xx to `Warning`, and everything else to `Default`.
fn level_for_status(status: u16) -> super::types::Level {
    if status >= 500 {
        super::types::Level::Error
    } else if status >= 400 {
        super::types::Level::Warning
    } else {
        super::types::Level::Default
    }
}
/// Convenience accessors for the Langfuse correlation headers on a request.
///
/// The `actix_web::HttpRequest` implementation in this module returns each
/// header value as an owned string. It returns `None` when the header is
/// absent or not valid visible ASCII. For requests that passed through the
/// tracing middleware with tracing active, `trace_id` is expected to reflect
/// the id the middleware wrote back, including a generated one.
pub trait LangfuseRequestExt {
    /// Value of the `x-langfuse-trace-id` header, if readable.
    fn trace_id(&self) -> Option<String>;
    /// Value of the `x-langfuse-session-id` header, if readable.
    fn session_id(&self) -> Option<String>;
    /// Value of the `x-langfuse-user-id` header, if readable.
    fn user_id(&self) -> Option<String>;
}
impl LangfuseRequestExt for actix_web::HttpRequest {
    /// Returns the trace-id header value; `None` if absent or not visible ASCII.
    fn trace_id(&self) -> Option<String> {
        let raw = self.headers().get(TRACE_ID_HEADER)?;
        raw.to_str().ok().map(str::to_owned)
    }

    /// Returns the session-id header value; `None` if absent or not visible ASCII.
    fn session_id(&self) -> Option<String> {
        let raw = self.headers().get(SESSION_ID_HEADER)?;
        raw.to_str().ok().map(str::to_owned)
    }

    /// Returns the user-id header value; `None` if absent or not visible ASCII.
    fn user_id(&self) -> Option<String> {
        let raw = self.headers().get(USER_ID_HEADER)?;
        raw.to_str().ok().map(str::to_owned)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Enabled config with dummy credentials and every field populated.
    fn test_config() -> LangfuseConfig {
        LangfuseConfig {
            public_key: Some("pk-test".to_string()),
            secret_key: Some("sk-test".to_string()),
            host: "https://cloud.langfuse.com".to_string(),
            enabled: true,
            batch_size: 10,
            flush_interval_ms: 1000,
            debug: true,
            release: None,
        }
    }

    #[test]
    fn test_middleware_creation() {
        let tracing_mw = LangfuseTracing::new(test_config());
        assert!(tracing_mw.client.is_some());
    }

    #[test]
    fn test_middleware_from_env() {
        // Only verifies that building from the ambient environment does not panic.
        let _tracing_mw = LangfuseTracing::from_env();
    }

    #[test]
    fn test_middleware_builder() {
        let tracing_mw = LangfuseTracing::new(test_config())
            .include_request_body(true)
            .include_response_body(true)
            .exclude_path("/api/internal")
            .service_name("my-service");

        assert!(tracing_mw.include_request_body);
        assert!(tracing_mw.include_response_body);
        assert!(tracing_mw.exclude_paths.iter().any(|p| p == "/api/internal"));
        assert_eq!(tracing_mw.service_name, "my-service");
    }

    #[test]
    fn test_should_trace() {
        let tracing_mw = LangfuseTracing::new(test_config())
            .exclude_paths(vec!["/health".to_string(), "/metrics".to_string()]);

        for excluded in ["/health", "/health/live", "/metrics"] {
            assert!(!tracing_mw.should_trace(excluded), "{} should be excluded", excluded);
        }
        for traced in ["/api/chat", "/v1/completions"] {
            assert!(tracing_mw.should_trace(traced), "{} should be traced", traced);
        }
    }

    #[test]
    fn test_default_exclude_paths() {
        let tracing_mw = LangfuseTracing::new(test_config());
        for excluded in ["/health", "/metrics", "/ready", "/live"] {
            assert!(!tracing_mw.should_trace(excluded), "{} should be excluded", excluded);
        }
    }

    #[test]
    fn test_header_constants() {
        let expected = [
            (TRACE_ID_HEADER, "x-langfuse-trace-id"),
            (PARENT_SPAN_ID_HEADER, "x-langfuse-parent-span-id"),
            (SESSION_ID_HEADER, "x-langfuse-session-id"),
            (USER_ID_HEADER, "x-langfuse-user-id"),
        ];
        for (actual, wanted) in expected {
            assert_eq!(actual, wanted);
        }
    }

    #[test]
    fn test_disabled_config() {
        let disabled = LangfuseConfig {
            enabled: false,
            ..Default::default()
        };
        assert!(LangfuseTracing::new(disabled).client.is_none());
    }

    #[test]
    fn test_missing_credentials() {
        assert!(LangfuseTracing::new(LangfuseConfig::default()).client.is_none());
    }
}