adk_server/rest/controllers/
debug.rs1use crate::ServerConfig;
2use axum::{
3 Extension, Json,
4 extract::{Path, State},
5 http::StatusCode,
6};
7use serde::Serialize;
8use std::collections::HashMap;
9
10#[derive(Clone)]
11pub struct DebugController {
12 config: ServerConfig,
13}
14
15impl DebugController {
16 pub fn new(config: ServerConfig) -> Self {
17 Self { config }
18 }
19}
20
21#[derive(Serialize)]
22pub struct GraphResponse {
23 #[serde(rename = "dotSrc")]
24 pub dot_src: String,
25}
26
27fn authorize_path_user_id(
28 request_context: &Option<adk_core::RequestContext>,
29 user_id: &str,
30) -> Result<(), StatusCode> {
31 match request_context {
32 Some(context) if context.user_id != user_id => Err(StatusCode::FORBIDDEN),
33 _ => Ok(()),
34 }
35}
36
37pub async fn get_trace_by_event_id(
39 State(controller): State<DebugController>,
40 Extension(request_context): Extension<Option<adk_core::RequestContext>>,
41 Path(event_id): Path<String>,
42) -> Result<Json<HashMap<String, String>>, StatusCode> {
43 if request_context.is_some() && !controller.config.security.expose_admin_debug {
44 return Err(StatusCode::NOT_FOUND);
45 }
46
47 if let Some(exporter) = &controller.config.span_exporter {
48 if let Some(attributes) = exporter.get_trace_by_event_id(&event_id) {
50 return Ok(Json(attributes));
51 }
52
53 let trace_dict = exporter.get_trace_dict();
55 for (_, attributes) in trace_dict.iter() {
56 if attributes.values().any(|v| v == &event_id) {
58 return Ok(Json(attributes.clone()));
59 }
60 }
61 }
62
63 Err(StatusCode::NOT_FOUND)
64}
65
66fn convert_to_span_data(attributes: &HashMap<String, String>) -> serde_json::Value {
69 let start_time: u64 = attributes.get("start_time").and_then(|s| s.parse().ok()).unwrap_or(0);
70 let end_time: u64 = attributes.get("end_time").and_then(|s| s.parse().ok()).unwrap_or(0);
71
72 let mut obj = serde_json::json!({
74 "name": attributes.get("span_name").map_or("unknown", |v| v.as_str()),
75 "span_id": attributes.get("span_id").map_or("", |v| v.as_str()),
76 "trace_id": attributes.get("trace_id").map_or("", |v| v.as_str()),
77 "start_time": start_time,
78 "end_time": end_time,
79 "attributes": attributes,
80 "invoc_id": attributes.get("gcp.vertex.agent.invocation_id").map_or("", |v| v.as_str())
81 });
82
83 if let Some(llm_req) = attributes.get("gcp.vertex.agent.llm_request") {
85 obj["gcp.vertex.agent.llm_request"] = serde_json::Value::String(llm_req.clone());
86 }
87 if let Some(llm_resp) = attributes.get("gcp.vertex.agent.llm_response") {
88 obj["gcp.vertex.agent.llm_response"] = serde_json::Value::String(llm_resp.clone());
89 }
90
91 obj
92}
93
94pub async fn get_session_traces(
96 State(controller): State<DebugController>,
97 Extension(request_context): Extension<Option<adk_core::RequestContext>>,
98 Path(session_id): Path<String>,
99) -> Result<Json<Vec<serde_json::Value>>, StatusCode> {
100 if let Some(exporter) = &controller.config.span_exporter {
101 let traces = exporter.get_session_trace(&session_id);
102 if let Some(context) = request_context.as_ref() {
103 for trace in &traces {
104 if let Some(owner) = trace.get("adk.user_id") {
105 if owner != &context.user_id {
106 return Err(StatusCode::FORBIDDEN);
107 }
108 }
109 }
110 }
111 let span_data: Vec<serde_json::Value> = traces.iter().map(convert_to_span_data).collect();
112 return Ok(Json(span_data));
113 }
114
115 Ok(Json(Vec::new()))
116}
117
118pub async fn get_graph(
119 State(_controller): State<DebugController>,
120 Extension(request_context): Extension<Option<adk_core::RequestContext>>,
121 Path((_app_name, user_id, _session_id, _event_id)): Path<(String, String, String, String)>,
122) -> Result<Json<GraphResponse>, StatusCode> {
123 authorize_path_user_id(&request_context, &user_id)?;
124
125 let dot_src = "digraph G { Agent -> User [label=\"response\"]; }".to_string();
127 Ok(Json(GraphResponse { dot_src }))
128}
129
130pub async fn get_eval_sets(
132 State(_controller): State<DebugController>,
133 Extension(_request_context): Extension<Option<adk_core::RequestContext>>,
134 Path(_app_name): Path<String>,
135) -> Result<Json<Vec<serde_json::Value>>, StatusCode> {
136 Ok(Json(Vec::new()))
138}
139
140pub async fn get_event(
142 State(controller): State<DebugController>,
143 Extension(request_context): Extension<Option<adk_core::RequestContext>>,
144 Path((app_name, user_id, session_id, event_id)): Path<(String, String, String, String)>,
145) -> Result<Json<serde_json::Value>, StatusCode> {
146 authorize_path_user_id(&request_context, &user_id)?;
147
148 if let Some(exporter) = &controller.config.span_exporter {
150 let traces = exporter.get_session_trace(&session_id);
151
152 for attrs in traces {
154 if let Some(stored_event_id) = attrs.get("gcp.vertex.agent.event_id") {
155 if stored_event_id == &event_id {
156 let invocation_id =
158 attrs.get("gcp.vertex.agent.invocation_id").cloned().unwrap_or_default();
159
160 return Ok(Json(serde_json::json!({
161 "id": event_id,
162 "invocationId": invocation_id,
163 "appName": app_name,
164 "sessionId": session_id,
165 "attributes": attrs,
166 "gcp.vertex.agent.llm_request": attrs.get("gcp.vertex.agent.llm_request"),
167 "gcp.vertex.agent.llm_response": attrs.get("gcp.vertex.agent.llm_response")
168 })));
169 }
170 }
171 }
172 }
173
174 Ok(Json(serde_json::json!({
176 "id": event_id,
177 "invocationId": "",
178 "appName": app_name,
179 "sessionId": session_id
180 })))
181}