adk_server/rest/controllers/
runtime.rs

1use crate::ServerConfig;
2use axum::{
3    extract::{Path, State},
4    http::StatusCode,
5    response::sse::{Event, KeepAlive, Sse},
6    Json,
7};
8use futures::stream::{self, Stream};
9use serde::{Deserialize, Serialize};
10use std::convert::Infallible;
11use tracing::{error, info};
12
13#[derive(Clone)]
14pub struct RuntimeController {
15    config: ServerConfig,
16}
17
18impl RuntimeController {
19    pub fn new(config: ServerConfig) -> Self {
20        Self { config }
21    }
22}
23
24#[derive(Serialize, Deserialize)]
25pub struct RunRequest {
26    pub new_message: String,
27}
28
29/// Request format for /run_sse (adk-go compatible)
30#[derive(Serialize, Deserialize, Debug)]
31#[serde(rename_all = "camelCase")]
32pub struct RunSseRequest {
33    pub app_name: String,
34    pub user_id: String,
35    pub session_id: String,
36    pub new_message: NewMessage,
37    #[serde(default)]
38    pub streaming: bool,
39    #[serde(default)]
40    pub state_delta: Option<serde_json::Value>,
41}
42
43#[derive(Serialize, Deserialize, Debug)]
44pub struct NewMessage {
45    pub role: String,
46    pub parts: Vec<MessagePart>,
47}
48
49#[derive(Serialize, Deserialize, Debug)]
50pub struct MessagePart {
51    #[serde(default)]
52    pub text: Option<String>,
53    #[serde(default, rename = "inlineData")]
54    pub inline_data: Option<InlineData>,
55}
56
57#[derive(Serialize, Deserialize, Debug)]
58#[serde(rename_all = "camelCase")]
59pub struct InlineData {
60    pub display_name: Option<String>,
61    pub data: String,
62    pub mime_type: String,
63}
64
65pub async fn run_sse(
66    State(controller): State<RuntimeController>,
67    Path((app_name, user_id, session_id)): Path<(String, String, String)>,
68    Json(req): Json<RunRequest>,
69) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, StatusCode> {
70    // Validate session exists
71    controller
72        .config
73        .session_service
74        .get(adk_session::GetRequest {
75            app_name: app_name.clone(),
76            user_id: user_id.clone(),
77            session_id: session_id.clone(),
78            num_recent_events: None,
79            after: None,
80        })
81        .await
82        .map_err(|_| StatusCode::NOT_FOUND)?;
83
84    // Load agent
85    let agent = controller
86        .config
87        .agent_loader
88        .load_agent(&app_name)
89        .await
90        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
91
92    // Create runner
93    let runner = adk_runner::Runner::new(adk_runner::RunnerConfig {
94        app_name,
95        agent,
96        session_service: controller.config.session_service.clone(),
97        artifact_service: controller.config.artifact_service.clone(),
98        memory_service: None,
99    })
100    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
101
102    // Run agent
103    let event_stream = runner
104        .run(user_id, session_id, adk_core::Content::new("user").with_text(&req.new_message))
105        .await
106        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
107
108    // Convert to SSE stream
109    let sse_stream = stream::unfold(event_stream, |mut stream| async move {
110        use futures::StreamExt;
111        match stream.next().await {
112            Some(Ok(event)) => {
113                let json = serde_json::to_string(&event).ok()?;
114                Some((Ok(Event::default().data(json)), stream))
115            }
116            _ => None,
117        }
118    });
119
120    Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
121}
122
123/// POST /run_sse - adk-go compatible endpoint
124/// Accepts JSON body with appName, userId, sessionId, newMessage
125pub async fn run_sse_compat(
126    State(controller): State<RuntimeController>,
127    Json(req): Json<RunSseRequest>,
128) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, StatusCode> {
129    let app_name = req.app_name;
130    let user_id = req.user_id;
131    let session_id = req.session_id;
132
133    info!(
134        app_name = %app_name,
135        user_id = %user_id,
136        session_id = %session_id,
137        "POST /run_sse request received"
138    );
139
140    // Extract text from message parts
141    let message_text = req
142        .new_message
143        .parts
144        .iter()
145        .filter_map(|p| p.text.as_ref())
146        .cloned()
147        .collect::<Vec<_>>()
148        .join(" ");
149
150    // Validate session exists
151    controller
152        .config
153        .session_service
154        .get(adk_session::GetRequest {
155            app_name: app_name.clone(),
156            user_id: user_id.clone(),
157            session_id: session_id.clone(),
158            num_recent_events: None,
159            after: None,
160        })
161        .await
162        .map_err(|e| {
163            error!(error = ?e, "Session not found");
164            StatusCode::NOT_FOUND
165        })?;
166
167    // Load agent
168    let agent = controller
169        .config
170        .agent_loader
171        .load_agent(&app_name)
172        .await
173        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
174
175    // Create runner
176    let runner = adk_runner::Runner::new(adk_runner::RunnerConfig {
177        app_name,
178        agent,
179        session_service: controller.config.session_service.clone(),
180        artifact_service: controller.config.artifact_service.clone(),
181        memory_service: None,
182    })
183    .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
184
185    // Run agent
186    let event_stream = runner
187        .run(user_id, session_id, adk_core::Content::new("user").with_text(&message_text))
188        .await
189        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
190
191    // Convert to SSE stream
192    let sse_stream = stream::unfold(event_stream, |mut stream| async move {
193        use futures::StreamExt;
194        match stream.next().await {
195            Some(Ok(event)) => {
196                let json = serde_json::to_string(&event).ok()?;
197                Some((Ok(Event::default().data(json)), stream))
198            }
199            _ => None,
200        }
201    });
202
203    Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
204}