Skip to main content

adk_server/rest/controllers/
runtime.rs

1use crate::ServerConfig;
2use crate::auth_bridge::{RequestContextError, RequestContextExtractor};
3use crate::ui_protocol::{
4    SUPPORTED_UI_PROTOCOLS, UI_PROTOCOL_CAPABILITIES, normalize_runtime_ui_protocol,
5};
6use adk_core::RequestContext;
7use axum::{
8    Json,
9    extract::{Path, State},
10    http::{HeaderMap, StatusCode},
11    response::sse::{Event, KeepAlive, Sse},
12};
13use base64::{Engine as _, engine::general_purpose::STANDARD as BASE64_STANDARD};
14use futures::{
15    StreamExt,
16    stream::{self, Stream},
17};
18use serde::{Deserialize, Serialize};
19use serde_json::json;
20use std::convert::Infallible;
21use tracing::{Instrument, info, warn};
22
/// Serde default for `RunSseRequest::streaming`: streaming stays enabled
/// unless the client explicitly sends `"streaming": false`.
fn default_streaming_true() -> bool {
    // Must be a named function so it can be referenced from
    // `#[serde(default = "...")]`.
    true
}

/// HTTP header a client may set to choose the UI protocol profile. A valid
/// (visible-ASCII) value takes precedence over `ui_protocol` in the body.
const UI_PROTOCOL_HEADER: &str = "x-adk-ui-protocol";
28
29#[derive(Clone)]
30pub struct RuntimeController {
31    config: ServerConfig,
32}
33
34impl RuntimeController {
35    pub fn new(config: ServerConfig) -> Self {
36        Self { config }
37    }
38}
39
40/// Attachment structure for the legacy /run endpoint
41#[derive(Serialize, Deserialize, Debug)]
42pub struct Attachment {
43    pub name: String,
44    #[serde(rename = "type")]
45    pub mime_type: String,
46    pub base64: String,
47}
48
49#[derive(Serialize, Deserialize)]
50pub struct RunRequest {
51    pub new_message: String,
52    #[serde(default, alias = "uiProtocol")]
53    pub ui_protocol: Option<String>,
54    #[serde(default)]
55    pub attachments: Vec<Attachment>,
56}
57
58/// Request format for /run_sse (adk-go compatible)
59#[derive(Serialize, Deserialize, Debug)]
60#[serde(rename_all = "camelCase")]
61pub struct RunSseRequest {
62    pub app_name: String,
63    pub user_id: String,
64    pub session_id: String,
65    pub new_message: NewMessage,
66    #[serde(default = "default_streaming_true")]
67    pub streaming: bool,
68    #[serde(default)]
69    pub state_delta: Option<serde_json::Value>,
70    #[serde(default, alias = "ui_protocol")]
71    pub ui_protocol: Option<String>,
72}
73
74#[derive(Serialize, Deserialize, Debug)]
75pub struct NewMessage {
76    pub role: String,
77    pub parts: Vec<MessagePart>,
78}
79
80#[derive(Serialize, Deserialize, Debug)]
81pub struct MessagePart {
82    #[serde(default)]
83    pub text: Option<String>,
84    #[serde(default, rename = "inlineData")]
85    pub inline_data: Option<InlineData>,
86}
87
88#[derive(Serialize, Deserialize, Debug)]
89#[serde(rename_all = "camelCase")]
90pub struct InlineData {
91    pub display_name: Option<String>,
92    pub data: String,
93    pub mime_type: String,
94}
95
/// UI protocol profile that decides how runtime events are framed on the SSE stream.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum UiProfile {
    /// `adk_ui`: events are emitted unwrapped; used when no profile is requested.
    AdkUi,
    /// `a2ui`
    A2ui,
    /// `ag_ui`
    AgUi,
    /// `mcp_apps`
    McpApps,
}

impl UiProfile {
    /// Canonical wire name of the profile, used in event envelopes and logs.
    fn as_str(self) -> &'static str {
        match self {
            UiProfile::AdkUi => "adk_ui",
            UiProfile::A2ui => "a2ui",
            UiProfile::AgUi => "ag_ui",
            UiProfile::McpApps => "mcp_apps",
        }
    }
}
114
115type RuntimeError = (StatusCode, String);
116
117fn parse_ui_profile(raw: &str) -> Option<UiProfile> {
118    match normalize_runtime_ui_protocol(raw)? {
119        "adk_ui" => Some(UiProfile::AdkUi),
120        "a2ui" => Some(UiProfile::A2ui),
121        "ag_ui" => Some(UiProfile::AgUi),
122        "mcp_apps" => Some(UiProfile::McpApps),
123        _ => None,
124    }
125}
126
127fn resolve_ui_profile(
128    headers: &HeaderMap,
129    body_ui_protocol: Option<&str>,
130) -> Result<UiProfile, RuntimeError> {
131    let header_value = headers.get(UI_PROTOCOL_HEADER).and_then(|v| v.to_str().ok());
132    let candidate = header_value.or(body_ui_protocol);
133
134    let Some(raw) = candidate else {
135        return Ok(UiProfile::AdkUi);
136    };
137
138    parse_ui_profile(raw).ok_or_else(|| {
139        let supported = SUPPORTED_UI_PROTOCOLS.join(", ");
140        warn!(
141            requested = %raw,
142            header = %UI_PROTOCOL_HEADER,
143            "unsupported ui protocol requested"
144        );
145        (
146            StatusCode::BAD_REQUEST,
147            format!("Unsupported ui protocol '{}'. Supported profiles: {}", raw, supported),
148        )
149    })
150}
151
152fn serialize_runtime_event(event: &adk_core::Event, profile: UiProfile) -> Option<String> {
153    if profile == UiProfile::AdkUi {
154        return serde_json::to_string(event).ok();
155    }
156
157    serde_json::to_string(&json!({
158        "ui_protocol": profile.as_str(),
159        "event": event
160    }))
161    .ok()
162}
163
164fn log_profile_deprecation(profile: UiProfile) {
165    if profile != UiProfile::AdkUi {
166        return;
167    }
168    let Some(spec) = UI_PROTOCOL_CAPABILITIES
169        .iter()
170        .find(|capability| capability.protocol == profile.as_str())
171        .and_then(|capability| capability.deprecation)
172    else {
173        return;
174    };
175
176    warn!(
177        protocol = %profile.as_str(),
178        stage = %spec.stage,
179        announced_on = %spec.announced_on,
180        sunset_target_on = ?spec.sunset_target_on,
181        replacements = ?spec.replacement_protocols,
182        "legacy ui protocol profile selected"
183    );
184}
185
186/// Build Content from message text and attachments
187fn build_content_with_attachments(
188    text: &str,
189    attachments: &[Attachment],
190) -> Result<adk_core::Content, RuntimeError> {
191    let mut content = adk_core::Content::new("user");
192
193    // Add the text part
194    content.parts.push(adk_core::Part::Text { text: text.to_string() });
195
196    // Add attachment parts
197    for attachment in attachments {
198        match BASE64_STANDARD.decode(&attachment.base64) {
199            Ok(data) => {
200                if data.len() > adk_core::MAX_INLINE_DATA_SIZE {
201                    return Err((
202                        StatusCode::PAYLOAD_TOO_LARGE,
203                        format!(
204                            "Attachment '{}' exceeds max inline size of {} bytes",
205                            attachment.name,
206                            adk_core::MAX_INLINE_DATA_SIZE
207                        ),
208                    ));
209                }
210                content.parts.push(adk_core::Part::InlineData {
211                    mime_type: attachment.mime_type.clone(),
212                    data,
213                });
214            }
215            Err(e) => {
216                return Err((
217                    StatusCode::BAD_REQUEST,
218                    format!("Invalid base64 data for attachment '{}': {}", attachment.name, e),
219                ));
220            }
221        }
222    }
223
224    Ok(content)
225}
226
227/// Build Content from message parts (for /run_sse endpoint)
228fn build_content_from_parts(parts: &[MessagePart]) -> Result<adk_core::Content, RuntimeError> {
229    let mut content = adk_core::Content::new("user");
230
231    for part in parts {
232        // Add text part if present
233        if let Some(text) = &part.text {
234            content.parts.push(adk_core::Part::Text { text: text.clone() });
235        }
236
237        // Add inline data part if present
238        if let Some(inline_data) = &part.inline_data {
239            match BASE64_STANDARD.decode(&inline_data.data) {
240                Ok(data) => {
241                    if data.len() > adk_core::MAX_INLINE_DATA_SIZE {
242                        return Err((
243                            StatusCode::PAYLOAD_TOO_LARGE,
244                            format!(
245                                "inline_data exceeds max inline size of {} bytes",
246                                adk_core::MAX_INLINE_DATA_SIZE
247                            ),
248                        ));
249                    }
250                    content.parts.push(adk_core::Part::InlineData {
251                        mime_type: inline_data.mime_type.clone(),
252                        data,
253                    });
254                }
255                Err(e) => {
256                    return Err((
257                        StatusCode::BAD_REQUEST,
258                        format!("Invalid base64 data in inline_data: {}", e),
259                    ));
260                }
261            }
262        }
263    }
264
265    Ok(content)
266}
267
268/// Extract [`RequestContext`] from the configured extractor, if present.
269///
270/// Constructs minimal HTTP request [`Parts`] from the provided headers so the
271/// extractor can inspect `Authorization` and other headers. Returns `None`
272/// when no extractor is configured (fall-through to existing behavior).
273async fn extract_request_context(
274    extractor: Option<&dyn RequestContextExtractor>,
275    headers: &HeaderMap,
276) -> Result<Option<RequestContext>, RuntimeError> {
277    let Some(extractor) = extractor else {
278        return Ok(None);
279    };
280
281    // Build minimal Parts from the headers
282    let mut builder = axum::http::Request::builder();
283    for (name, value) in headers {
284        builder = builder.header(name, value);
285    }
286    let (parts, _) = builder
287        .body(())
288        .map_err(|e| {
289            (StatusCode::INTERNAL_SERVER_ERROR, format!("failed to build request parts: {e}"))
290        })?
291        .into_parts();
292
293    match extractor.extract(&parts).await {
294        Ok(ctx) => Ok(Some(ctx)),
295        Err(RequestContextError::MissingAuth) => {
296            Err((StatusCode::UNAUTHORIZED, "missing authorization".to_string()))
297        }
298        Err(RequestContextError::InvalidToken(msg)) => {
299            Err((StatusCode::UNAUTHORIZED, format!("invalid token: {msg}")))
300        }
301        Err(RequestContextError::ExtractionFailed(msg)) => {
302            Err((StatusCode::INTERNAL_SERVER_ERROR, format!("auth extraction failed: {msg}")))
303        }
304    }
305}
306
307pub async fn run_sse(
308    State(controller): State<RuntimeController>,
309    Path((app_name, user_id, session_id)): Path<(String, String, String)>,
310    headers: HeaderMap,
311    Json(req): Json<RunRequest>,
312) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, RuntimeError> {
313    let ui_profile = resolve_ui_profile(&headers, req.ui_protocol.as_deref())?;
314    let span = tracing::info_span!("run_sse", session_id = %session_id, app_name = %app_name, user_id = %user_id);
315
316    async move {
317        log_profile_deprecation(ui_profile);
318        info!(
319            ui_protocol = %ui_profile.as_str(),
320            "resolved ui protocol profile for runtime request"
321        );
322
323        // Extract request context from auth middleware bridge if configured.
324        // This returns Err (401/500) when the extractor is present but auth
325        // fails, ensuring authorization checks are never bypassed.
326        let request_context = extract_request_context(
327            controller.config.request_context_extractor.as_deref(),
328            &headers,
329        )
330        .await?;
331
332        // Explicit authenticated user override: when an auth extractor is
333        // configured and succeeds, the authenticated user_id takes precedence
334        // over the path parameter. This prevents callers from impersonating
335        // other users via the URL while keeping the path param as a fallback
336        // for unauthenticated deployments (no extractor configured).
337        let effective_user_id = request_context.as_ref().map_or(user_id, |rc| rc.user_id.clone());
338
339        // Validate session exists
340        controller
341            .config
342            .session_service
343            .get(adk_session::GetRequest {
344                app_name: app_name.clone(),
345                user_id: effective_user_id.clone(),
346                session_id: session_id.clone(),
347                num_recent_events: None,
348                after: None,
349            })
350            .await
351            .map_err(|_| (StatusCode::NOT_FOUND, "session not found".to_string()))?;
352
353        // Load agent
354        let agent =
355            controller.config.agent_loader.load_agent(&app_name).await.map_err(|_| {
356                (StatusCode::INTERNAL_SERVER_ERROR, "failed to load agent".to_string())
357            })?;
358
359        // Create runner
360        let runner = adk_runner::Runner::new(adk_runner::RunnerConfig {
361            app_name: app_name.clone(),
362            agent,
363            session_service: controller.config.session_service.clone(),
364            artifact_service: controller.config.artifact_service.clone(),
365            memory_service: controller.config.memory_service.clone(),
366            plugin_manager: None,
367            run_config: None,
368            compaction_config: controller.config.compaction_config.clone(),
369            context_cache_config: controller.config.context_cache_config.clone(),
370            cache_capable: controller.config.cache_capable.clone(),
371            request_context,
372            cancellation_token: None,
373        })
374        .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "failed to create runner".to_string()))?;
375
376        // Build content with attachments
377        let content = build_content_with_attachments(&req.new_message, &req.attachments)?;
378
379        // Log attachment info
380        if !req.attachments.is_empty() {
381            info!(attachment_count = req.attachments.len(), "processing request with attachments");
382        }
383
384        // Run agent
385        let event_stream = runner
386            .run(effective_user_id, session_id, content)
387            .await
388            .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "failed to run agent".to_string()))?;
389
390        // Convert to SSE stream
391        let selected_profile = ui_profile;
392        let sse_stream = stream::unfold(event_stream, move |mut stream| async move {
393            match stream.next().await {
394                Some(Ok(event)) => {
395                    let json = serialize_runtime_event(&event, selected_profile)?;
396                    Some((Ok(Event::default().data(json)), stream))
397                }
398                _ => None,
399            }
400        });
401
402        Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
403    }
404    .instrument(span)
405    .await
406}
407
408/// POST /run_sse - adk-go compatible endpoint
409/// Accepts JSON body with appName, userId, sessionId, newMessage
410pub async fn run_sse_compat(
411    State(controller): State<RuntimeController>,
412    headers: HeaderMap,
413    Json(req): Json<RunSseRequest>,
414) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, RuntimeError> {
415    let ui_profile = resolve_ui_profile(&headers, req.ui_protocol.as_deref())?;
416    let app_name = req.app_name;
417    let user_id = req.user_id;
418    let session_id = req.session_id;
419
420    info!(
421        app_name = %app_name,
422        user_id = %user_id,
423        session_id = %session_id,
424        ui_protocol = %ui_profile.as_str(),
425        "POST /run_sse request received"
426    );
427    log_profile_deprecation(ui_profile);
428
429    // Extract request context from auth middleware bridge if configured.
430    // This returns Err (401/500) when the extractor is present but auth
431    // fails, ensuring authorization checks are never bypassed.
432    let request_context =
433        extract_request_context(controller.config.request_context_extractor.as_deref(), &headers)
434            .await?;
435
436    // Explicit authenticated user override: when an auth extractor is
437    // configured and succeeds, the authenticated user_id takes precedence
438    // over the request body value. This prevents callers from impersonating
439    // other users via the JSON payload while keeping the body param as a
440    // fallback for unauthenticated deployments (no extractor configured).
441    let effective_user_id = request_context.as_ref().map_or(user_id, |rc| rc.user_id.clone());
442
443    // Build content from message parts (includes both text and inline_data)
444    let content = build_content_from_parts(&req.new_message.parts)?;
445
446    // Log part info
447    let text_parts: Vec<_> = req.new_message.parts.iter().filter(|p| p.text.is_some()).collect();
448    let data_parts: Vec<_> =
449        req.new_message.parts.iter().filter(|p| p.inline_data.is_some()).collect();
450    if !data_parts.is_empty() {
451        info!(
452            text_parts = text_parts.len(),
453            inline_data_parts = data_parts.len(),
454            "processing request with inline data"
455        );
456    }
457
458    // Validate session exists or create it
459    let session_result = controller
460        .config
461        .session_service
462        .get(adk_session::GetRequest {
463            app_name: app_name.clone(),
464            user_id: effective_user_id.clone(),
465            session_id: session_id.clone(),
466            num_recent_events: None,
467            after: None,
468        })
469        .await;
470
471    // If session doesn't exist, create it
472    if session_result.is_err() {
473        controller
474            .config
475            .session_service
476            .create(adk_session::CreateRequest {
477                app_name: app_name.clone(),
478                user_id: effective_user_id.clone(),
479                session_id: Some(session_id.clone()),
480                state: std::collections::HashMap::new(),
481            })
482            .await
483            .map_err(|_| {
484                (StatusCode::INTERNAL_SERVER_ERROR, "failed to create session".to_string())
485            })?;
486    }
487
488    // Load agent
489    let agent = controller
490        .config
491        .agent_loader
492        .load_agent(&app_name)
493        .await
494        .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "failed to load agent".to_string()))?;
495
496    // Create runner with streaming config from request
497    let streaming_mode =
498        if req.streaming { adk_core::StreamingMode::SSE } else { adk_core::StreamingMode::None };
499
500    let runner = adk_runner::Runner::new(adk_runner::RunnerConfig {
501        app_name,
502        agent,
503        session_service: controller.config.session_service.clone(),
504        artifact_service: controller.config.artifact_service.clone(),
505        memory_service: controller.config.memory_service.clone(),
506        plugin_manager: None,
507        run_config: Some(adk_core::RunConfig { streaming_mode, ..adk_core::RunConfig::default() }),
508        compaction_config: controller.config.compaction_config.clone(),
509        context_cache_config: controller.config.context_cache_config.clone(),
510        cache_capable: controller.config.cache_capable.clone(),
511        request_context,
512        cancellation_token: None,
513    })
514    .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "failed to create runner".to_string()))?;
515
516    // Run agent with full content (text + inline data)
517    let event_stream = runner
518        .run(effective_user_id, session_id, content)
519        .await
520        .map_err(|_| (StatusCode::INTERNAL_SERVER_ERROR, "failed to run agent".to_string()))?;
521
522    // Convert to SSE stream
523    let selected_profile = ui_profile;
524    let sse_stream = stream::unfold(event_stream, move |mut stream| async move {
525        match stream.next().await {
526            Some(Ok(event)) => {
527                let json = serialize_runtime_event(&event, selected_profile)?;
528                Some((Ok(Event::default().data(json)), stream))
529            }
530            _ => None,
531        }
532    });
533
534    Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
535}