use crate::ServerConfig;
use axum::{
Extension, Json,
extract::{Path, State},
http::StatusCode,
};
use serde::Serialize;
use std::collections::HashMap;
/// State shared by the debug HTTP handlers in this file.
///
/// Handlers receive it through axum's `State` extractor and read the span
/// exporter and security settings from the wrapped configuration.
#[derive(Clone)]
pub struct DebugController {
/// Server configuration the handlers consult (`span_exporter`, `security`).
config: ServerConfig,
}
impl DebugController {
pub fn new(config: ServerConfig) -> Self {
Self { config }
}
}
/// Success payload type of the `get_graph` handler.
///
/// Serialized as a JSON object with a single `dotSrc` key.
#[derive(Serialize)]
pub struct GraphResponse {
/// Graph source text, emitted under the JSON key `dotSrc`.
/// NOTE(review): the name suggests Graphviz DOT source — confirm; nothing in
/// this file constructs the value yet.
#[serde(rename = "dotSrc")]
pub dot_src: String,
}
/// Verifies that the user id taken from the request path belongs to the caller.
///
/// # Errors
///
/// Returns `StatusCode::FORBIDDEN` when a request context is present and its
/// `user_id` differs from `user_id`. Without a request context the check
/// always succeeds.
fn authorize_path_user_id(
    request_context: &Option<adk_core::RequestContext>,
    user_id: &str,
) -> Result<(), StatusCode> {
    if let Some(caller) = request_context {
        if caller.user_id != user_id {
            return Err(StatusCode::FORBIDDEN);
        }
    }
    Ok(())
}
/// Looks up the span attributes associated with `event_id`.
///
/// The exporter's direct lookup is tried first. If that misses, every entry of
/// the exporter's trace dictionary is scanned and the first one having any
/// attribute value equal to `event_id` is returned.
///
/// # Errors
///
/// Returns `StatusCode::NOT_FOUND` when:
/// - a request context is present and `security.expose_admin_debug` is off,
/// - no span exporter is configured, or
/// - neither lookup finds a match.
pub async fn get_trace_by_event_id(
    State(controller): State<DebugController>,
    Extension(request_context): Extension<Option<adk_core::RequestContext>>,
    Path(event_id): Path<String>,
) -> Result<Json<HashMap<String, String>>, StatusCode> {
    // Callers that carry a request context only see this endpoint when admin
    // debugging is explicitly exposed.
    let admin_debug_hidden = !controller.config.security.expose_admin_debug;
    if admin_debug_hidden && request_context.is_some() {
        return Err(StatusCode::NOT_FOUND);
    }

    let Some(exporter) = &controller.config.span_exporter else {
        return Err(StatusCode::NOT_FOUND);
    };

    if let Some(direct_hit) = exporter.get_trace_by_event_id(&event_id) {
        return Ok(Json(direct_hit));
    }

    // Fallback: accept any span that mentions the id in one of its values.
    let all_traces = exporter.get_trace_dict();
    let scanned = all_traces
        .iter()
        .find(|(_, span_attrs)| span_attrs.values().any(|value| value == &event_id))
        .map(|(_, span_attrs)| span_attrs.clone());

    match scanned {
        Some(span_attrs) => Ok(Json(span_attrs)),
        None => Err(StatusCode::NOT_FOUND),
    }
}
/// Converts one span's flat attribute map into the JSON object returned by the
/// session-trace endpoint.
///
/// The object contains `name`, `span_id`, `trace_id`, `start_time`, `end_time`,
/// `invoc_id` and the full `attributes` map. Missing string attributes fall
/// back to `""` (`"unknown"` for the name); missing or unparsable timestamps
/// become `0`. The LLM request/response attributes are additionally copied to
/// the top level when present.
fn convert_to_span_data(attributes: &HashMap<String, String>) -> serde_json::Value {
    /// Returns the attribute stored under `key`, or `fallback` when absent.
    fn lookup<'a>(attrs: &'a HashMap<String, String>, key: &str, fallback: &'a str) -> &'a str {
        attrs.get(key).map(String::as_str).unwrap_or(fallback)
    }

    // Timestamps are stored as decimal strings; anything else counts as 0.
    let timestamp = |key: &str| -> u64 {
        attributes.get(key).and_then(|raw| raw.parse::<u64>().ok()).unwrap_or(0)
    };
    let started_at = timestamp("start_time");
    let ended_at = timestamp("end_time");

    let mut span = serde_json::json!({
        "name": lookup(attributes, "span_name", "unknown"),
        "span_id": lookup(attributes, "span_id", ""),
        "trace_id": lookup(attributes, "trace_id", ""),
        "start_time": started_at,
        "end_time": ended_at,
        "attributes": attributes,
        "invoc_id": lookup(attributes, "gcp.vertex.agent.invocation_id", "")
    });

    // Surface the LLM payloads at the top level, request first.
    for passthrough in ["gcp.vertex.agent.llm_request", "gcp.vertex.agent.llm_response"] {
        if let Some(payload) = attributes.get(passthrough) {
            span[passthrough] = serde_json::Value::String(payload.clone());
        }
    }

    span
}
/// Returns every span recorded for `session_id`, converted with
/// `convert_to_span_data`.
///
/// When no span exporter is configured the response is an empty list.
///
/// # Errors
///
/// Returns `StatusCode::FORBIDDEN` when a request context is present and any
/// span of the session carries an `adk.user_id` attribute that differs from
/// the caller's `user_id`. Spans without that attribute are not checked.
pub async fn get_session_traces(
    State(controller): State<DebugController>,
    Extension(request_context): Extension<Option<adk_core::RequestContext>>,
    Path(session_id): Path<String>,
) -> Result<Json<Vec<serde_json::Value>>, StatusCode> {
    let Some(exporter) = &controller.config.span_exporter else {
        return Ok(Json(Vec::new()));
    };

    let session_spans = exporter.get_session_trace(&session_id);

    if let Some(caller) = request_context.as_ref() {
        let has_foreign_owner = session_spans
            .iter()
            .filter_map(|span| span.get("adk.user_id"))
            .any(|owner| owner != &caller.user_id);
        if has_foreign_owner {
            return Err(StatusCode::FORBIDDEN);
        }
    }

    let converted = session_spans.iter().map(convert_to_span_data).collect::<Vec<serde_json::Value>>();
    Ok(Json(converted))
}
pub async fn get_graph(
State(_controller): State<DebugController>,
Extension(request_context): Extension<Option<adk_core::RequestContext>>,
Path((_app_name, user_id, _session_id, _event_id)): Path<(String, String, String, String)>,
) -> Result<Json<GraphResponse>, (StatusCode, Json<serde_json::Value>)> {
authorize_path_user_id(&request_context, &user_id)
.map_err(|s| (s, Json(serde_json::json!({"error": "forbidden"}))))?;
Err((
StatusCode::NOT_IMPLEMENTED,
Json(serde_json::json!({"error": "graph generation is not yet implemented"})),
))
}
pub async fn get_eval_sets(
State(_controller): State<DebugController>,
Extension(_request_context): Extension<Option<adk_core::RequestContext>>,
Path(_app_name): Path<String>,
) -> Result<Json<Vec<serde_json::Value>>, (StatusCode, Json<serde_json::Value>)> {
Err((
StatusCode::NOT_IMPLEMENTED,
Json(serde_json::json!({"error": "eval sets are not yet implemented"})),
))
}
/// Returns the recorded span attributes for one event of a session.
///
/// The session's spans are fetched from the configured exporter and the first
/// span whose `gcp.vertex.agent.event_id` attribute equals `event_id` is
/// returned as a JSON object with `id`, `invocationId`, `appName`,
/// `sessionId`, the full `attributes` map and the LLM request/response
/// attributes (`null` when absent).
///
/// # Errors
///
/// - `StatusCode::FORBIDDEN` when a request context is present and either the
///   path `user_id` differs from the caller's, or any span of the session
///   carries an `adk.user_id` attribute that differs from the caller's.
/// - `StatusCode::NOT_FOUND` when no span exporter is configured or no span of
///   the session matches `event_id`.
pub async fn get_event(
    State(controller): State<DebugController>,
    Extension(request_context): Extension<Option<adk_core::RequestContext>>,
    Path((app_name, user_id, session_id, event_id)): Path<(String, String, String, String)>,
) -> Result<Json<serde_json::Value>, StatusCode> {
    authorize_path_user_id(&request_context, &user_id)?;

    let Some(exporter) = &controller.config.span_exporter else {
        return Err(StatusCode::NOT_FOUND);
    };
    let traces = exporter.get_session_trace(&session_id);

    // The path user id only proves who the caller claims to be; spans are
    // looked up by session id alone. Apply the same ownership rule as
    // `get_session_traces` so a caller cannot read another user's session by
    // pairing their own user id with a foreign session id.
    if let Some(context) = request_context.as_ref() {
        let has_foreign_owner = traces
            .iter()
            .filter_map(|trace| trace.get("adk.user_id"))
            .any(|owner| owner != &context.user_id);
        if has_foreign_owner {
            return Err(StatusCode::FORBIDDEN);
        }
    }

    let Some(attrs) = traces
        .into_iter()
        .find(|attrs| attrs.get("gcp.vertex.agent.event_id") == Some(&event_id))
    else {
        return Err(StatusCode::NOT_FOUND);
    };

    let invocation_id = attrs.get("gcp.vertex.agent.invocation_id").cloned().unwrap_or_default();
    Ok(Json(serde_json::json!({
        "id": event_id,
        "invocationId": invocation_id,
        "appName": app_name,
        "sessionId": session_id,
        "attributes": attrs,
        "gcp.vertex.agent.llm_request": attrs.get("gcp.vertex.agent.llm_request"),
        "gcp.vertex.agent.llm_response": attrs.get("gcp.vertex.agent.llm_response")
    })))
}