use std::future::Future;
use std::pin::Pin;
use std::sync::Arc;
use opentelemetry_sdk::trace::{InMemorySpanExporter, SpanData};
use thiserror::Error;
/// Raw telemetry for a single session, as handed back by a [`TraceProvider`].
///
/// The enum is `#[non_exhaustive]`: code in other crates that matches on it
/// must include a wildcard arm.
#[derive(Clone, Debug)]
#[non_exhaustive]
pub enum RawSession {
    /// Spans in OpenTelemetry SDK form, collected for `session_id`.
    OtelSpans {
        /// Identifier of the session the spans were collected for.
        session_id: String,
        /// The collected spans.
        spans: Vec<SpanData>,
    },
}
impl RawSession {
    /// Returns the identifier of the session this telemetry was collected for.
    #[must_use]
    pub fn session_id(&self) -> &str {
        // Only one variant exists today, so the destructure is irrefutable.
        let Self::OtelSpans { session_id, .. } = self;
        session_id.as_str()
    }
}
/// Failures a [`TraceProvider`] can report when fetching a session.
///
/// The enum is `#[non_exhaustive]`: code in other crates that matches on it
/// must include a wildcard arm.
#[derive(Debug, Error)]
#[non_exhaustive]
pub enum TraceProviderError {
    /// The requested backend needs a cargo feature this build does not enable.
    #[error("trace backend `{backend}` is not available — enable the `{feature}` cargo feature")]
    FeatureDisabled {
        /// Name of the unavailable backend.
        backend: String,
        /// Cargo feature that would make the backend available.
        feature: String,
    },

    /// No recorded span belongs to the requested session.
    #[error("trace session `{session_id}` not found")]
    SessionNotFound {
        /// The session that was requested.
        session_id: String,
    },

    /// Some of the session's spans have not ended yet.
    #[error("trace session `{session_id}` is still in progress ({open_spans} span(s) not ended)")]
    SessionInProgress {
        /// The session that was requested.
        session_id: String,
        /// How many of the session's spans are considered not ended.
        open_spans: usize,
    },

    /// The backend itself failed. `reason` is a rendered description of the failure.
    #[error("trace backend failure: {reason}")]
    BackendFailure {
        /// Human-readable description of what went wrong.
        reason: String,
    },
}
/// Boxed, `Send` future produced by [`TraceProvider::fetch_session`].
///
/// For the lifetime `'a`, it may borrow both the provider and the requested session id.
pub type TraceProviderFuture<'a> =
    Pin<Box<dyn Future<Output = Result<RawSession, TraceProviderError>> + Send + 'a>>;

/// A backend that can load the raw telemetry recorded for a session.
///
/// `fetch_session` returns a boxed future rather than being an `async fn`.
/// This keeps the trait usable as `dyn TraceProvider`. The `Send + Sync`
/// supertraits let a provider be shared across threads.
pub trait TraceProvider: Send + Sync {
    /// Loads the telemetry recorded for `session_id`.
    ///
    /// # Errors
    ///
    /// Implementations report failures as [`TraceProviderError`].
    fn fetch_session<'a>(&'a self, session_id: &'a str) -> TraceProviderFuture<'a>;
}
/// A [`TraceProvider`] backed by an OpenTelemetry SDK [`InMemorySpanExporter`].
///
/// Spans are grouped into sessions by the value of one span attribute. The
/// attribute is `session.id` unless it is overridden with
/// [`OtelInMemoryTraceProvider::with_session_attribute`].
#[derive(Debug, Clone)]
pub struct OtelInMemoryTraceProvider {
    /// Exporter whose finished spans are searched.
    exporter: InMemorySpanExporter,
    /// Span attribute key whose value identifies a session.
    session_attribute: Arc<str>,
}
impl OtelInMemoryTraceProvider {
    /// Attribute key used by [`Self::new`] to identify a span's session.
    const DEFAULT_SESSION_ATTRIBUTE: &'static str = "session.id";

    /// Builds a provider over `exporter` that matches sessions on the
    /// `session.id` span attribute.
    #[must_use]
    pub fn new(exporter: InMemorySpanExporter) -> Self {
        let session_attribute: Arc<str> = Arc::from(Self::DEFAULT_SESSION_ATTRIBUTE);
        Self {
            exporter,
            session_attribute,
        }
    }

    /// Returns this provider with `key` as the span attribute that identifies
    /// a session. The previously configured key is replaced.
    #[must_use]
    pub fn with_session_attribute(self, key: impl Into<String>) -> Self {
        let key: String = key.into();
        Self {
            session_attribute: Arc::from(key),
            ..self
        }
    }

    /// The span attribute key currently used to identify a session.
    #[must_use]
    pub fn session_attribute(&self) -> &str {
        self.session_attribute.as_ref()
    }

    /// Borrows the underlying in-memory exporter.
    #[must_use]
    pub fn exporter(&self) -> &InMemorySpanExporter {
        &self.exporter
    }
}
impl TraceProvider for OtelInMemoryTraceProvider {
    /// Gathers every finished span whose configured session attribute equals
    /// `session_id`. The spans keep the exporter's order.
    ///
    /// Attribute values are compared through `Value::as_str`, so non-string
    /// values are matched by their string rendering.
    ///
    /// # Errors
    ///
    /// * [`TraceProviderError::BackendFailure`] if the exporter's finished
    ///   spans cannot be read.
    /// * [`TraceProviderError::SessionNotFound`] if no span carries the id.
    /// * [`TraceProviderError::SessionInProgress`] if any matching span has an
    ///   `end_time` that is not later than its `start_time`.
    fn fetch_session<'a>(&'a self, session_id: &'a str) -> TraceProviderFuture<'a> {
        Box::pin(async move {
            let mut spans = match self.exporter.get_finished_spans() {
                Ok(spans) => spans,
                Err(err) => {
                    return Err(TraceProviderError::BackendFailure {
                        reason: format!("in-memory exporter lock: {err}"),
                    });
                }
            };

            // Drop every span that is not tagged with the requested session.
            let attribute_key: &str = &self.session_attribute;
            spans.retain(|span| {
                span.attributes.iter().any(|kv| {
                    kv.key.as_str() == attribute_key && kv.value.as_str() == session_id
                })
            });

            if spans.is_empty() {
                return Err(TraceProviderError::SessionNotFound {
                    session_id: session_id.to_owned(),
                });
            }

            // NOTE(review): a span whose end_time is not after its start_time
            // is treated as never ended. Presumably this is because the SDK
            // seeds end_time with start_time until the span ends. A span that
            // genuinely took zero clock ticks is also counted here; confirm
            // that is acceptable.
            let open_spans = spans
                .iter()
                .filter(|span| span.end_time <= span.start_time)
                .count();

            let session_id = session_id.to_owned();
            if open_spans == 0 {
                Ok(RawSession::OtelSpans { session_id, spans })
            } else {
                Err(TraceProviderError::SessionInProgress {
                    session_id,
                    open_spans,
                })
            }
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use opentelemetry::trace::{
        SpanContext, SpanId, SpanKind, Status, TraceFlags, TraceId, TraceState,
    };
    use opentelemetry::{InstrumentationScope, KeyValue};
    use opentelemetry_sdk::trace::{SpanEvents, SpanExporter, SpanLinks};
    use std::borrow::Cow;
    use std::time::{Duration, SystemTime};

    /// Builds a minimal `SpanData` named `name` that carries `attrs`.
    ///
    /// With `complete == false` the span's `end_time` equals its
    /// `start_time`. `OtelInMemoryTraceProvider` counts such a span as not
    /// yet ended.
    fn make_span(name: &str, attrs: Vec<KeyValue>, complete: bool) -> SpanData {
        let start = SystemTime::now();
        let end = if complete {
            start + Duration::from_millis(1)
        } else {
            start
        };
        SpanData {
            span_context: SpanContext::new(
                TraceId::from(1_u128),
                SpanId::from(1_u64),
                TraceFlags::default(),
                false,
                TraceState::default(),
            ),
            parent_span_id: SpanId::INVALID,
            parent_span_is_remote: false,
            span_kind: SpanKind::Internal,
            name: Cow::Owned(name.to_string()),
            start_time: start,
            end_time: end,
            attributes: attrs,
            dropped_attributes_count: 0,
            events: SpanEvents::default(),
            links: SpanLinks::default(),
            status: Status::Unset,
            instrumentation_scope: InstrumentationScope::builder("test").build(),
        }
    }

    #[test]
    fn raw_session_reports_session_id() {
        let s = RawSession::OtelSpans {
            session_id: "abc".into(),
            spans: vec![],
        };
        assert_eq!(s.session_id(), "abc");
    }

    #[test]
    fn trace_provider_error_display_includes_fields() {
        let err = TraceProviderError::SessionNotFound {
            session_id: "sid".into(),
        };
        assert!(format!("{err}").contains("sid"));
        let err = TraceProviderError::SessionInProgress {
            session_id: "sid".into(),
            open_spans: 2,
        };
        let rendered = format!("{err}");
        assert!(rendered.contains("sid"));
        assert!(rendered.contains('2'));
    }

    #[test]
    fn new_defaults_to_session_id_attribute() {
        let provider = OtelInMemoryTraceProvider::new(InMemorySpanExporter::default());
        assert_eq!(provider.session_attribute(), "session.id");
    }

    #[tokio::test]
    async fn fetch_session_not_found_when_no_spans_match() {
        let exporter = InMemorySpanExporter::default();
        let provider = OtelInMemoryTraceProvider::new(exporter);
        let err = provider
            .fetch_session("missing")
            .await
            .expect_err("empty exporter");
        assert!(matches!(err, TraceProviderError::SessionNotFound { .. }));
    }

    #[tokio::test]
    async fn fetch_session_uses_configured_attribute_key() {
        let exporter = InMemorySpanExporter::default();
        let provider =
            OtelInMemoryTraceProvider::new(exporter.clone()).with_session_attribute("custom.sid");
        assert_eq!(provider.session_attribute(), "custom.sid");
        let span = make_span("root", vec![KeyValue::new("custom.sid", "S1")], true);
        exporter.export(vec![span]).await.unwrap();
        let raw = provider.fetch_session("S1").await.unwrap();
        match raw {
            RawSession::OtelSpans { session_id, spans } => {
                assert_eq!(session_id, "S1");
                assert_eq!(spans.len(), 1);
            }
        }
    }

    #[tokio::test]
    async fn fetch_session_ignores_spans_from_other_sessions() {
        let exporter = InMemorySpanExporter::default();
        let provider = OtelInMemoryTraceProvider::new(exporter.clone());
        exporter
            .export(vec![
                make_span("mine", vec![KeyValue::new("session.id", "S1")], true),
                make_span("theirs", vec![KeyValue::new("session.id", "S2")], true),
                make_span("untagged", vec![], true),
            ])
            .await
            .unwrap();
        let raw = provider.fetch_session("S1").await.unwrap();
        match raw {
            RawSession::OtelSpans { session_id, spans } => {
                assert_eq!(session_id, "S1");
                assert_eq!(spans.len(), 1);
                assert_eq!(spans[0].name, "mine");
            }
        }
    }

    #[tokio::test]
    async fn fetch_session_reports_spans_that_have_not_ended() {
        let exporter = InMemorySpanExporter::default();
        let provider = OtelInMemoryTraceProvider::new(exporter.clone());
        exporter
            .export(vec![
                make_span("root", vec![KeyValue::new("session.id", "S1")], true),
                make_span("child", vec![KeyValue::new("session.id", "S1")], false),
            ])
            .await
            .unwrap();
        let err = provider
            .fetch_session("S1")
            .await
            .expect_err("one span has not ended");
        match err {
            TraceProviderError::SessionInProgress {
                session_id,
                open_spans,
            } => {
                assert_eq!(session_id, "S1");
                assert_eq!(open_spans, 1);
            }
            other => panic!("expected SessionInProgress, got {other:?}"),
        }
    }
}