1use crate::ServerConfig;
2use crate::auth_bridge::{RequestContextError, RequestContextExtractor};
3use crate::ui_protocol::{
4 SUPPORTED_UI_PROTOCOLS, UI_PROTOCOL_CAPABILITIES, normalize_runtime_ui_protocol,
5};
6use adk_core::RequestContext;
7use axum::{
8 Json,
9 extract::{Path, State},
10 http::{HeaderMap, StatusCode},
11 response::sse::{Event, KeepAlive, Sse},
12};
13use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD};
14use futures::{
15 StreamExt,
16 stream::{self, Stream},
17};
18use serde::{Deserialize, Serialize};
19use serde_json::json;
20use std::convert::Infallible;
21use tracing::{Instrument, info, warn};
22
/// Serde default for `RunSseRequest::streaming`.
///
/// When a client omits the field, the compatibility endpoint runs the agent in
/// SSE streaming mode.
fn default_streaming_true() -> bool {
    true
}
26
/// Request header that selects the UI protocol profile; takes precedence over
/// the `ui_protocol` field in the JSON body.
const UI_PROTOCOL_HEADER: &str = "x-adk-ui-protocol";
28
29#[derive(Clone)]
30pub struct RuntimeController {
31 config: ServerConfig,
32}
33
34impl RuntimeController {
35 pub fn new(config: ServerConfig) -> Self {
36 Self { config }
37 }
38}
39
40#[derive(Serialize, Deserialize, Debug)]
42pub struct Attachment {
43 pub name: String,
44 #[serde(rename = "type")]
45 pub mime_type: String,
46 pub base64: String,
47}
48
49#[derive(Serialize, Deserialize)]
50pub struct RunRequest {
51 pub new_message: String,
52 #[serde(default, alias = "uiProtocol")]
53 pub ui_protocol: Option<String>,
54 #[serde(default)]
55 pub attachments: Vec<Attachment>,
56}
57
58#[derive(Serialize, Deserialize, Debug)]
60#[serde(rename_all = "camelCase")]
61pub struct RunSseRequest {
62 pub app_name: String,
63 pub user_id: String,
64 pub session_id: String,
65 pub new_message: NewMessage,
66 #[serde(default = "default_streaming_true")]
67 pub streaming: bool,
68 #[serde(default)]
69 pub state_delta: Option<serde_json::Value>,
70 #[serde(default, alias = "ui_protocol")]
71 pub ui_protocol: Option<String>,
72}
73
74#[derive(Serialize, Deserialize, Debug)]
75pub struct NewMessage {
76 pub role: String,
77 pub parts: Vec<MessagePart>,
78}
79
80#[derive(Serialize, Deserialize, Debug)]
81pub struct MessagePart {
82 #[serde(default)]
83 pub text: Option<String>,
84 #[serde(default, rename = "inlineData")]
85 pub inline_data: Option<InlineData>,
86}
87
88#[derive(Serialize, Deserialize, Debug)]
89#[serde(rename_all = "camelCase")]
90pub struct InlineData {
91 pub display_name: Option<String>,
92 pub data: String,
93 pub mime_type: String,
94}
95
/// UI protocol profile chosen for a runtime request.
///
/// It decides how each agent event is framed on the SSE stream. `AdkUi` is
/// used when the client does not ask for anything else.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum UiProfile {
    /// Wire name `adk_ui`.
    AdkUi,
    /// Wire name `a2ui`.
    A2ui,
    /// Wire name `ag_ui`.
    AgUi,
    /// Wire name `mcp_apps`.
    McpApps,
}

impl UiProfile {
    /// Canonical wire name of the profile, as used in logs and event envelopes.
    fn as_str(self) -> &'static str {
        match self {
            UiProfile::AdkUi => "adk_ui",
            UiProfile::A2ui => "a2ui",
            UiProfile::AgUi => "ag_ui",
            UiProfile::McpApps => "mcp_apps",
        }
    }
}
114
115type RuntimeError = (StatusCode, String);
116
117fn parse_ui_profile(raw: &str) -> Option<UiProfile> {
118 match normalize_runtime_ui_protocol(raw)? {
119 "adk_ui" => Some(UiProfile::AdkUi),
120 "a2ui" => Some(UiProfile::A2ui),
121 "ag_ui" => Some(UiProfile::AgUi),
122 "mcp_apps" => Some(UiProfile::McpApps),
123 _ => None,
124 }
125}
126
127fn resolve_ui_profile(
128 headers: &HeaderMap,
129 body_ui_protocol: Option<&str>,
130) -> Result<UiProfile, RuntimeError> {
131 let header_value = headers.get(UI_PROTOCOL_HEADER).and_then(|v| v.to_str().ok());
132 let candidate = header_value.or(body_ui_protocol);
133
134 let Some(raw) = candidate else {
135 return Ok(UiProfile::AdkUi);
136 };
137
138 parse_ui_profile(raw).ok_or_else(|| {
139 let supported = SUPPORTED_UI_PROTOCOLS.join(", ");
140 warn!(
141 requested = %raw,
142 header = %UI_PROTOCOL_HEADER,
143 "unsupported ui protocol requested"
144 );
145 (
146 StatusCode::BAD_REQUEST,
147 format!("Unsupported ui protocol '{}'. Supported profiles: {}", raw, supported),
148 )
149 })
150}
151
152fn serialize_runtime_event(event: &adk_core::Event, profile: UiProfile) -> Option<String> {
153 if profile == UiProfile::AdkUi {
154 return serde_json::to_string(event).ok();
155 }
156
157 serde_json::to_string(&json!({
158 "ui_protocol": profile.as_str(),
159 "event": event
160 }))
161 .ok()
162}
163
164fn log_profile_deprecation(profile: UiProfile) {
165 if profile != UiProfile::AdkUi {
166 return;
167 }
168 let Some(spec) = UI_PROTOCOL_CAPABILITIES
169 .iter()
170 .find(|capability| capability.protocol == profile.as_str())
171 .and_then(|capability| capability.deprecation)
172 else {
173 return;
174 };
175
176 warn!(
177 protocol = %profile.as_str(),
178 stage = %spec.stage,
179 announced_on = %spec.announced_on,
180 sunset_target_on = ?spec.sunset_target_on,
181 replacements = ?spec.replacement_protocols,
182 "legacy ui protocol profile selected"
183 );
184}
185
186fn build_content_with_attachments(
188 text: &str,
189 attachments: &[Attachment],
190) -> Result<adk_core::Content, RuntimeError> {
191 let mut content = adk_core::Content::new("user");
192
193 content.parts.push(adk_core::Part::Text { text: text.to_string() });
195
196 for attachment in attachments {
198 match BASE64_STANDARD.decode(&attachment.base64) {
199 Ok(data) => {
200 if data.len() > adk_core::MAX_INLINE_DATA_SIZE {
201 return Err((
202 StatusCode::PAYLOAD_TOO_LARGE,
203 format!(
204 "Attachment '{}' exceeds max inline size of {} bytes",
205 attachment.name,
206 adk_core::MAX_INLINE_DATA_SIZE
207 ),
208 ));
209 }
210 content.parts.push(adk_core::Part::InlineData {
211 mime_type: attachment.mime_type.clone(),
212 data,
213 });
214 }
215 Err(e) => {
216 return Err((
217 StatusCode::BAD_REQUEST,
218 format!("Invalid base64 data for attachment '{}': {}", attachment.name, e),
219 ));
220 }
221 }
222 }
223
224 Ok(content)
225}
226
227fn build_content_from_parts(parts: &[MessagePart]) -> Result<adk_core::Content, RuntimeError> {
229 let mut content = adk_core::Content::new("user");
230
231 for part in parts {
232 if let Some(text) = &part.text {
234 content.parts.push(adk_core::Part::Text { text: text.clone() });
235 }
236
237 if let Some(inline_data) = &part.inline_data {
239 match BASE64_STANDARD.decode(&inline_data.data) {
240 Ok(data) => {
241 if data.len() > adk_core::MAX_INLINE_DATA_SIZE {
242 return Err((
243 StatusCode::PAYLOAD_TOO_LARGE,
244 format!(
245 "inline_data exceeds max inline size of {} bytes",
246 adk_core::MAX_INLINE_DATA_SIZE
247 ),
248 ));
249 }
250 content.parts.push(adk_core::Part::InlineData {
251 mime_type: inline_data.mime_type.clone(),
252 data,
253 });
254 }
255 Err(e) => {
256 return Err((
257 StatusCode::BAD_REQUEST,
258 format!("Invalid base64 data in inline_data: {}", e),
259 ));
260 }
261 }
262 }
263 }
264
265 Ok(content)
266}
267
268async fn extract_request_context(
274 extractor: Option<&dyn RequestContextExtractor>,
275 headers: &HeaderMap,
276) -> Result<Option<RequestContext>, RuntimeError> {
277 let Some(extractor) = extractor else {
278 return Ok(None);
279 };
280
281 let mut builder = axum::http::Request::builder();
283 for (name, value) in headers {
284 builder = builder.header(name, value);
285 }
286 let (parts, _) = builder
287 .body(())
288 .map_err(|e| {
289 (StatusCode::INTERNAL_SERVER_ERROR, format!("failed to build request parts: {e}"))
290 })?
291 .into_parts();
292
293 match extractor.extract(&parts).await {
294 Ok(ctx) => Ok(Some(ctx)),
295 Err(RequestContextError::MissingAuth) => {
296 Err((StatusCode::UNAUTHORIZED, "missing authorization".to_string()))
297 }
298 Err(RequestContextError::InvalidToken(msg)) => {
299 Err((StatusCode::UNAUTHORIZED, format!("invalid token: {msg}")))
300 }
301 Err(RequestContextError::ExtractionFailed(msg)) => {
302 Err((StatusCode::INTERNAL_SERVER_ERROR, format!("auth extraction failed: {msg}")))
303 }
304 }
305}
306
307pub async fn run_sse(
308 State(controller): State<RuntimeController>,
309 Path((app_name, user_id, session_id)): Path<(String, String, String)>,
310 headers: HeaderMap,
311 Json(req): Json<RunRequest>,
312) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, RuntimeError> {
313 let ui_profile = resolve_ui_profile(&headers, req.ui_protocol.as_deref())?;
314 let span = tracing::info_span!("run_sse", session_id = %session_id, app_name = %app_name, user_id = %user_id);
315
316 async move {
317 log_profile_deprecation(ui_profile);
318 info!(
319 ui_protocol = %ui_profile.as_str(),
320 "resolved ui protocol profile for runtime request"
321 );
322
323 let request_context = extract_request_context(
327 controller.config.request_context_extractor.as_deref(),
328 &headers,
329 )
330 .await?;
331
332 let effective_user_id = request_context.as_ref().map_or(user_id, |rc| rc.user_id.clone());
338
339 controller
341 .config
342 .session_service
343 .get(adk_session::GetRequest {
344 app_name: app_name.clone(),
345 user_id: effective_user_id.clone(),
346 session_id: session_id.clone(),
347 num_recent_events: None,
348 after: None,
349 })
350 .await
351 .map_err(|_| (StatusCode::NOT_FOUND, "session not found".to_string()))?;
352
353 let agent =
355 controller.config.agent_loader.load_agent(&app_name).await.map_err(|_| {
356 (StatusCode::INTERNAL_SERVER_ERROR, "failed to load agent".to_string())
357 })?;
358
359 let runner = adk_runner::Runner::new(adk_runner::RunnerConfig {
361 app_name: app_name.clone(),
362 agent,
363 session_service: controller.config.session_service.clone(),
364 artifact_service: controller.config.artifact_service.clone(),
365 memory_service: controller.config.memory_service.clone(),
366 plugin_manager: None,
367 run_config: None,
368 compaction_config: controller.config.compaction_config.clone(),
369 context_cache_config: controller.config.context_cache_config.clone(),
370 cache_capable: controller.config.cache_capable.clone(),
371 request_context,
372 cancellation_token: None,
373 })
374 .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "failed to create runner".to_string()))?;
375
376 let content = build_content_with_attachments(&req.new_message, &req.attachments)?;
378
379 if !req.attachments.is_empty() {
381 info!(attachment_count = req.attachments.len(), "processing request with attachments");
382 }
383
384 let event_stream = runner
386 .run(effective_user_id, session_id, content)
387 .await
388 .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "failed to run agent".to_string()))?;
389
390 let selected_profile = ui_profile;
392 let sse_stream = stream::unfold(event_stream, move |mut stream| async move {
393 match stream.next().await {
394 Some(Ok(event)) => {
395 let json = serialize_runtime_event(&event, selected_profile)?;
396 Some((Ok(Event::default().data(json)), stream))
397 }
398 _ => None,
399 }
400 });
401
402 Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
403 }
404 .instrument(span)
405 .await
406}
407
408pub async fn run_sse_compat(
411 State(controller): State<RuntimeController>,
412 headers: HeaderMap,
413 Json(req): Json<RunSseRequest>,
414) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, RuntimeError> {
415 let ui_profile = resolve_ui_profile(&headers, req.ui_protocol.as_deref())?;
416 let app_name = req.app_name;
417 let user_id = req.user_id;
418 let session_id = req.session_id;
419
420 info!(
421 app_name = %app_name,
422 user_id = %user_id,
423 session_id = %session_id,
424 ui_protocol = %ui_profile.as_str(),
425 "POST /run_sse request received"
426 );
427 log_profile_deprecation(ui_profile);
428
429 let request_context =
433 extract_request_context(controller.config.request_context_extractor.as_deref(), &headers)
434 .await?;
435
436 let effective_user_id = request_context.as_ref().map_or(user_id, |rc| rc.user_id.clone());
442
443 let content = build_content_from_parts(&req.new_message.parts)?;
445
446 let text_parts: Vec<_> = req.new_message.parts.iter().filter(|p| p.text.is_some()).collect();
448 let data_parts: Vec<_> =
449 req.new_message.parts.iter().filter(|p| p.inline_data.is_some()).collect();
450 if !data_parts.is_empty() {
451 info!(
452 text_parts = text_parts.len(),
453 inline_data_parts = data_parts.len(),
454 "processing request with inline data"
455 );
456 }
457
458 let session_result = controller
460 .config
461 .session_service
462 .get(adk_session::GetRequest {
463 app_name: app_name.clone(),
464 user_id: effective_user_id.clone(),
465 session_id: session_id.clone(),
466 num_recent_events: None,
467 after: None,
468 })
469 .await;
470
471 if session_result.is_err() {
473 controller
474 .config
475 .session_service
476 .create(adk_session::CreateRequest {
477 app_name: app_name.clone(),
478 user_id: effective_user_id.clone(),
479 session_id: Some(session_id.clone()),
480 state: std::collections::HashMap::new(),
481 })
482 .await
483 .map_err(|_| {
484 (StatusCode::INTERNAL_SERVER_ERROR, "failed to create session".to_string())
485 })?;
486 }
487
488 let agent = controller
490 .config
491 .agent_loader
492 .load_agent(&app_name)
493 .await
494 .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "failed to load agent".to_string()))?;
495
496 let streaming_mode =
498 if req.streaming { adk_core::StreamingMode::SSE } else { adk_core::StreamingMode::None };
499
500 let runner = adk_runner::Runner::new(adk_runner::RunnerConfig {
501 app_name,
502 agent,
503 session_service: controller.config.session_service.clone(),
504 artifact_service: controller.config.artifact_service.clone(),
505 memory_service: controller.config.memory_service.clone(),
506 plugin_manager: None,
507 run_config: Some(adk_core::RunConfig { streaming_mode, ..adk_core::RunConfig::default() }),
508 compaction_config: controller.config.compaction_config.clone(),
509 context_cache_config: controller.config.context_cache_config.clone(),
510 cache_capable: controller.config.cache_capable.clone(),
511 request_context,
512 cancellation_token: None,
513 })
514 .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "failed to create runner".to_string()))?;
515
516 let event_stream = runner
518 .run(effective_user_id, session_id, content)
519 .await
520 .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "failed to run agent".to_string()))?;
521
522 let selected_profile = ui_profile;
524 let sse_stream = stream::unfold(event_stream, move |mut stream| async move {
525 match stream.next().await {
526 Some(Ok(event)) => {
527 let json = serialize_runtime_event(&event, selected_profile)?;
528 Some((Ok(Event::default().data(json)), stream))
529 }
530 _ => None,
531 }
532 });
533
534 Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
535}