use crate::server::AppState;
use axum::{
extract::{Path, Query, State},
http::StatusCode,
response::Json,
};
use chrono::{DateTime, Local, Utc};
use otelite_core::api::{
ErrorResponse, SessionContextGrowth, SessionDiagnoseResponse, SessionInteraction,
SessionListResponse, SessionSummary,
};
use otelite_core::query::{Operator, QueryPredicate, QueryValue};
use otelite_core::storage::QueryParams;
use otelite_core::telemetry::trace::StatusCode as SpanStatusCode;
use otelite_core::telemetry::{extract_ttft_secs, GenAiSpanInfo, Span};
use serde::Deserialize;
use std::collections::{BTreeSet, HashMap, HashSet};
/// Picks the span that represents the LLM call within a group of spans.
///
/// A span qualifies when at least one of its attribute keys starts with
/// `gen_ai.`. The first qualifying span that has no parent wins; when no
/// parentless span qualifies, the first qualifying span in slice order is
/// used instead. Returns `None` if no span carries a `gen_ai.` attribute.
fn root_llm_span(spans: &[Span]) -> Option<&Span> {
    /// True when any attribute key of `span` is in the `gen_ai.` namespace.
    fn has_genai_attrs(span: &Span) -> bool {
        span.attributes
            .keys()
            .any(|key| key.starts_with("gen_ai."))
    }

    let parentless = spans
        .iter()
        .find(|span| span.parent_span_id.is_none() && has_genai_attrs(span));

    match parentless {
        Some(span) => Some(span),
        // No top-level candidate: fall back to any span with GenAI attributes.
        None => spans.iter().find(|span| has_genai_attrs(span)),
    }
}
/// Builds a per-interaction diagnostic report for a single session.
///
/// Loads up to 10,000 spans whose `session.id` equals `session_id`, groups
/// them by trace, orders the traces by their earliest span start time, and
/// turns each trace that contains a GenAI span (see `root_llm_span`) into one
/// `SessionInteraction`. Traces without a GenAI span are skipped.
///
/// Notes on the produced fields:
/// - `index` is the 1-based position of the interaction within the returned
///   list, so it is contiguous and never exceeds `total_interactions`.
/// - `time` is formatted as `%H:%M:%S` in the server's local time zone, while
///   `start_time` / `end_time` are UTC and are both derived from interaction
///   *start* timestamps (first and last interaction respectively).
/// - `is_stall` is set for errored interactions that have a TTFT value and
///   whose duration exceeds 30,000 ms.
/// - `body_length` and `prompt_id` are only looked up for errored interactions.
/// - `models` is de-duplicated and sorted so the response is deterministic.
/// - `context_growth` is only present when at least two interactions report
///   input tokens.
///
/// # Errors
///
/// - `500 Internal Server Error` when the span query fails.
/// - `404 Not Found` when the session has no spans, or none of its traces
///   contains a GenAI span.
pub async fn get_session_diagnose(
    State(state): State<AppState>,
    Path(session_id): Path<String>,
) -> Result<Json<SessionDiagnoseResponse>, (StatusCode, Json<ErrorResponse>)> {
    let query = QueryParams {
        predicates: vec![QueryPredicate {
            field: "session.id".to_string(),
            operator: Operator::Equal,
            value: QueryValue::String(session_id.clone()),
        }],
        limit: Some(10_000),
        ..Default::default()
    };
    let all_spans = state.storage.query_spans(&query).await.map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse::internal_error(e.to_string())),
        )
    })?;
    if all_spans.is_empty() {
        return Err((
            StatusCode::NOT_FOUND,
            Json(ErrorResponse::not_found(format!("session {}", session_id))),
        ));
    }

    // Group spans by trace; each trace is treated as one candidate interaction.
    let mut by_trace: HashMap<String, Vec<Span>> = HashMap::new();
    for span in all_spans {
        by_trace
            .entry(span.trace_id.clone())
            .or_default()
            .push(span);
    }

    // Chronological order by earliest span start. The key scans every span of
    // a trace, so compute it once per trace rather than once per comparison.
    let mut trace_groups: Vec<(String, Vec<Span>)> = by_trace.into_iter().collect();
    trace_groups
        .sort_by_cached_key(|(_, spans)| spans.iter().map(|s| s.start_time).min().unwrap_or(0));

    let mut interactions: Vec<SessionInteraction> = Vec::new();
    for (trace_id, spans) in &trace_groups {
        let root = match root_llm_span(spans) {
            Some(s) => s,
            None => continue,
        };
        let genai = GenAiSpanInfo::from_attributes(&root.attributes);
        let ttft = extract_ttft_secs(&root.attributes);
        let duration_ms = (root.end_time - root.start_time) / 1_000_000;
        let is_error = root.status.code == SpanStatusCode::Error;
        let is_stall = is_error && ttft.is_some() && duration_ms > 30_000;
        let dt = DateTime::<Utc>::from_timestamp_nanos(root.start_time);
        let time_str = dt.with_timezone(&Local).format("%H:%M:%S").to_string();

        // Extra context is only gathered for failed interactions. The log
        // lookup is best-effort: a failed query, a missing log record, or an
        // unparsable `body_length` attribute all yield `None`.
        let (body_length, prompt_id) = if is_error {
            let log_params = QueryParams {
                trace_id: Some(trace_id.clone()),
                search_text: Some("api_request_body".to_string()),
                limit: Some(1),
                ..Default::default()
            };
            let log_body_len = state
                .storage
                .query_logs(&log_params)
                .await
                .ok()
                .and_then(|logs| logs.into_iter().next())
                .and_then(|log| {
                    log.attributes
                        .get("body_length")
                        .and_then(|v| v.parse::<u64>().ok())
                });
            let pid = root.attributes.get("prompt.id").cloned();
            (log_body_len, pid)
        } else {
            (None, None)
        };

        interactions.push(SessionInteraction {
            // Number by emitted interactions, not by trace position, so that
            // skipped traces do not leave gaps in the sequence.
            index: interactions.len() + 1,
            time: time_str,
            model: genai
                .model
                .clone()
                .or_else(|| root.attributes.get("gen_ai.request.model").cloned()),
            input_tokens: genai.input_tokens,
            output_tokens: genai.output_tokens,
            cache_read_tokens: genai.cache_read_tokens,
            cache_creation_tokens: genai.cache_creation_tokens,
            ttft_secs: ttft,
            duration_ms,
            is_error,
            is_stall,
            response_id: genai.response_id.clone(),
            trace_id: trace_id.clone(),
            start_time_ns: root.start_time,
            body_length,
            prompt_id,
        });
    }
    if interactions.is_empty() {
        return Err((
            StatusCode::NOT_FOUND,
            Json(ErrorResponse::not_found(format!(
                "GenAI spans for session {}",
                session_id
            ))),
        ));
    }

    // De-duplicate through a set, then sort: set iteration order is
    // unspecified and would otherwise make the response vary between calls.
    let mut models: Vec<String> = interactions
        .iter()
        .filter_map(|i| i.model.clone())
        .collect::<HashSet<_>>()
        .into_iter()
        .collect();
    models.sort_unstable();

    let first_ts = interactions.first().map(|i| i.start_time_ns).unwrap_or(0);
    let last_ts = interactions.last().map(|i| i.start_time_ns).unwrap_or(0);
    let start_time = DateTime::<Utc>::from_timestamp_nanos(first_ts)
        .format("%Y-%m-%dT%H:%M:%SZ")
        .to_string();
    let end_time = DateTime::<Utc>::from_timestamp_nanos(last_ts)
        .format("%Y-%m-%dT%H:%M:%SZ")
        .to_string();
    let error_count = interactions.iter().filter(|i| i.is_error).count();
    let stall_count = interactions.iter().filter(|i| i.is_stall).count();

    // Context growth needs at least two data points to be meaningful.
    let input_series: Vec<u64> = interactions.iter().filter_map(|i| i.input_tokens).collect();
    let context_growth = match (
        input_series.first(),
        input_series.last(),
        input_series.iter().max(),
    ) {
        (Some(&first_tokens), Some(&last_tokens), Some(&peak_tokens))
            if input_series.len() >= 2 =>
        {
            Some(SessionContextGrowth {
                first_tokens,
                last_tokens,
                peak_tokens,
                interaction_count: interactions.len(),
            })
        }
        _ => None,
    };

    Ok(Json(SessionDiagnoseResponse {
        session_id,
        models,
        start_time,
        end_time,
        total_interactions: interactions.len(),
        error_count,
        stall_count,
        interactions,
        context_growth,
    }))
}
/// Query-string parameters accepted by `list_sessions`.
///
/// Every field is optional; an absent field deserializes to `None`.
#[derive(Debug, Deserialize, Default)]
pub struct SessionListQuery {
/// Lower time bound, forwarded unchanged to the span query.
/// NOTE(review): presumably Unix nanoseconds, like span timestamps — confirm.
pub start_time: Option<i64>,
/// Upper time bound, forwarded unchanged to the span query.
/// NOTE(review): presumably Unix nanoseconds, like span timestamps — confirm.
pub end_time: Option<i64>,
/// Maximum number of session summaries to return; `list_sessions` uses 200
/// when this is absent.
pub limit: Option<usize>,
}
/// Lists sessions seen in the requested time window, most recent first.
///
/// Queries up to 20,000 spans bounded by `start_time` / `end_time`, groups
/// them by their `session.id` attribute and then by trace, and summarizes
/// every session that has at least one trace containing a GenAI span (see
/// `root_llm_span`). Spans with a missing or empty `session.id` are ignored.
///
/// Summaries are ordered by `last_seen_ns` descending, with `session_id`
/// ascending as a tie-break so the ordering (and therefore which sessions
/// survive truncation) is stable across calls. The list is then truncated to
/// `limit` entries (200 when not supplied); `total` reports the number of
/// sessions found before truncation.
///
/// Token totals use saturating addition: the values come from ingested
/// telemetry, and an overflow must not panic or wrap the handler's result.
///
/// # Errors
///
/// Returns `500 Internal Server Error` when the span query fails.
pub async fn list_sessions(
    State(state): State<AppState>,
    Query(params): Query<SessionListQuery>,
) -> Result<Json<SessionListResponse>, (StatusCode, Json<ErrorResponse>)> {
    let limit = params.limit.unwrap_or(200);
    let query = QueryParams {
        start_time: params.start_time,
        end_time: params.end_time,
        limit: Some(20_000),
        ..Default::default()
    };
    let all_spans = state.storage.query_spans(&query).await.map_err(|e| {
        (
            StatusCode::INTERNAL_SERVER_ERROR,
            Json(ErrorResponse::internal_error(e.to_string())),
        )
    })?;

    // session id -> trace id -> spans of that trace.
    let mut by_session: HashMap<String, HashMap<String, Vec<Span>>> = HashMap::new();
    for span in all_spans {
        let sid = match span.attributes.get("session.id") {
            Some(s) if !s.is_empty() => s.clone(),
            _ => continue,
        };
        by_session
            .entry(sid)
            .or_default()
            .entry(span.trace_id.clone())
            .or_default()
            .push(span);
    }

    let mut summaries: Vec<SessionSummary> = Vec::with_capacity(by_session.len());
    for (session_id, by_trace) in by_session {
        let mut models: BTreeSet<String> = BTreeSet::new();
        let mut interaction_count = 0usize;
        let mut total_input: u64 = 0;
        let mut total_output: u64 = 0;
        let mut error_count = 0usize;
        let mut first_seen_ns = i64::MAX;
        let mut last_seen_ns = i64::MIN;

        for spans in by_trace.into_values() {
            let root = match root_llm_span(&spans) {
                Some(s) => s,
                None => continue,
            };
            interaction_count += 1;

            let genai = GenAiSpanInfo::from_attributes(&root.attributes);
            if let Some(m) = genai.model.as_deref().or_else(|| {
                root.attributes
                    .get("gen_ai.request.model")
                    .map(|s| s.as_str())
            }) {
                models.insert(m.to_string());
            }

            // Saturate instead of overflowing on absurd ingested token counts.
            total_input = total_input.saturating_add(genai.input_tokens.unwrap_or(0));
            total_output = total_output.saturating_add(genai.output_tokens.unwrap_or(0));

            if root.status.code == SpanStatusCode::Error {
                error_count += 1;
            }
            first_seen_ns = first_seen_ns.min(root.start_time);
            last_seen_ns = last_seen_ns.max(root.start_time);
        }

        // Sessions without any GenAI interaction are not listed.
        if interaction_count == 0 {
            continue;
        }
        summaries.push(SessionSummary {
            session_id,
            models: models.into_iter().collect(),
            interaction_count,
            total_input_tokens: total_input,
            total_output_tokens: total_output,
            error_count,
            first_seen_ns,
            last_seen_ns,
        });
    }

    // Newest first; break ties on the id because the map iteration order that
    // produced `summaries` is unspecified.
    summaries.sort_by(|a, b| {
        b.last_seen_ns
            .cmp(&a.last_seen_ns)
            .then_with(|| a.session_id.cmp(&b.session_id))
    });
    let total = summaries.len();
    summaries.truncate(limit);
    Ok(Json(SessionListResponse {
        sessions: summaries,
        total,
    }))
}