adk_server/rest/controllers/runtime.rs

1use crate::ServerConfig;
2use axum::{
3    Json,
4    extract::{Path, State},
5    http::StatusCode,
6    response::sse::{Event, KeepAlive, Sse},
7};
8use futures::stream::{self, Stream};
9use serde::{Deserialize, Serialize};
10use std::convert::Infallible;
11use tracing::{Instrument, info};
12
/// Serde default for the `streaming` field of the `/run_sse` request body.
///
/// Streaming is opt-out: a request that omits `streaming` is treated as `true`.
const fn default_streaming_true() -> bool {
    true
}
16
17#[derive(Clone)]
18pub struct RuntimeController {
19    config: ServerConfig,
20}
21
22impl RuntimeController {
23    pub fn new(config: ServerConfig) -> Self {
24        Self { config }
25    }
26}
27
28#[derive(Serialize, Deserialize)]
29pub struct RunRequest {
30    pub new_message: String,
31}
32
33/// Request format for /run_sse (adk-go compatible)
34#[derive(Serialize, Deserialize, Debug)]
35#[serde(rename_all = "camelCase")]
36pub struct RunSseRequest {
37    pub app_name: String,
38    pub user_id: String,
39    pub session_id: String,
40    pub new_message: NewMessage,
41    #[serde(default = "default_streaming_true")]
42    pub streaming: bool,
43    #[serde(default)]
44    pub state_delta: Option<serde_json::Value>,
45}
46
47#[derive(Serialize, Deserialize, Debug)]
48pub struct NewMessage {
49    pub role: String,
50    pub parts: Vec<MessagePart>,
51}
52
53#[derive(Serialize, Deserialize, Debug)]
54pub struct MessagePart {
55    #[serde(default)]
56    pub text: Option<String>,
57    #[serde(default, rename = "inlineData")]
58    pub inline_data: Option<InlineData>,
59}
60
61#[derive(Serialize, Deserialize, Debug)]
62#[serde(rename_all = "camelCase")]
63pub struct InlineData {
64    pub display_name: Option<String>,
65    pub data: String,
66    pub mime_type: String,
67}
68
69pub async fn run_sse(
70    State(controller): State<RuntimeController>,
71    Path((app_name, user_id, session_id)): Path<(String, String, String)>,
72    Json(req): Json<RunRequest>,
73) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, StatusCode> {
74    let span = tracing::info_span!("run_sse", session_id = %session_id, app_name = %app_name, user_id = %user_id);
75
76    async move {
77        // Validate session exists
78        controller
79            .config
80            .session_service
81            .get(adk_session::GetRequest {
82                app_name: app_name.clone(),
83                user_id: user_id.clone(),
84                session_id: session_id.clone(),
85                num_recent_events: None,
86                after: None,
87            })
88            .await
89            .map_err(|_| StatusCode::NOT_FOUND)?;
90
91        // Load agent
92        let agent = controller
93            .config
94            .agent_loader
95            .load_agent(&app_name)
96            .await
97            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
98
99        // Create runner
100        let runner = adk_runner::Runner::new(adk_runner::RunnerConfig {
101            app_name: app_name.clone(),
102            agent,
103            session_service: controller.config.session_service.clone(),
104            artifact_service: controller.config.artifact_service.clone(),
105            memory_service: None,
106            run_config: None,
107        })
108        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
109
110        // Run agent
111        let event_stream = runner
112            .run(user_id, session_id, adk_core::Content::new("user").with_text(&req.new_message))
113            .await
114            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
115
116        // Convert to SSE stream
117        let sse_stream = stream::unfold(event_stream, move |mut stream| async move {
118            use futures::StreamExt;
119            match stream.next().await {
120                Some(Ok(event)) => {
121                    let json = serde_json::to_string(&event).ok()?;
122                    Some((Ok(Event::default().data(json)), stream))
123                }
124                _ => None,
125            }
126        });
127
128        Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
129    }
130    .instrument(span)
131    .await
132}
133
134/// POST /run_sse - adk-go compatible endpoint
135/// Accepts JSON body with appName, userId, sessionId, newMessage
136pub async fn run_sse_compat(
137    State(controller): State<RuntimeController>,
138    Json(req): Json<RunSseRequest>,
139) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, StatusCode> {
140    let app_name = req.app_name;
141    let user_id = req.user_id;
142    let session_id = req.session_id;
143
144    info!(
145        app_name = %app_name,
146        user_id = %user_id,
147        session_id = %session_id,
148        "POST /run_sse request received"
149    );
150
151    // Extract text from message parts
152    let message_text = req
153        .new_message
154        .parts
155        .iter()
156        .filter_map(|p| p.text.as_ref())
157        .cloned()
158        .collect::<Vec<_>>()
159        .join(" ");
160
161    // Validate session exists or create it
162    let session_result = controller
163        .config
164        .session_service
165        .get(adk_session::GetRequest {
166            app_name: app_name.clone(),
167            user_id: user_id.clone(),
168            session_id: session_id.clone(),
169            num_recent_events: None,
170            after: None,
171        })
172        .await;
173
174    // If session doesn't exist, create it
175    if session_result.is_err() {
176        controller
177            .config
178            .session_service
179            .create(adk_session::CreateRequest {
180                app_name: app_name.clone(),
181                user_id: user_id.clone(),
182                session_id: Some(session_id.clone()),
183                state: std::collections::HashMap::new(),
184            })
185            .await
186            .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
187    }
188
189    // Load agent
190    let agent = controller
191        .config
192        .agent_loader
193        .load_agent(&app_name)
194        .await
195        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
196
197    // Create runner with streaming config from request
198    let streaming_mode =
199        if req.streaming { adk_core::StreamingMode::SSE } else { adk_core::StreamingMode::None };
200
201    let runner = adk_runner::Runner::new(adk_runner::RunnerConfig {
202        app_name,
203        agent,
204        session_service: controller.config.session_service.clone(),
205        artifact_service: controller.config.artifact_service.clone(),
206        memory_service: None,
207        run_config: Some(adk_core::RunConfig { streaming_mode }),
208    })
209    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
210
211    // Run agent
212    let event_stream = runner
213        .run(user_id, session_id, adk_core::Content::new("user").with_text(&message_text))
214        .await
215        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
216
217    // Convert to SSE stream
218    let sse_stream = stream::unfold(event_stream, move |mut stream| async move {
219        use futures::StreamExt;
220        match stream.next().await {
221            Some(Ok(event)) => {
222                let json = serde_json::to_string(&event).ok()?;
223                Some((Ok(Event::default().data(json)), stream))
224            }
225            _ => None,
226        }
227    });
228
229    Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
230}