adk_server/rest/controllers/
runtime.rs1use crate::ServerConfig;
2use axum::{
3 extract::{Path, State},
4 http::StatusCode,
5 response::sse::{Event, KeepAlive, Sse},
6 Json,
7};
8use futures::stream::{self, Stream};
9use serde::{Deserialize, Serialize};
10use std::convert::Infallible;
11use tracing::{error, info};
12
13#[derive(Clone)]
14pub struct RuntimeController {
15 config: ServerConfig,
16}
17
18impl RuntimeController {
19 pub fn new(config: ServerConfig) -> Self {
20 Self { config }
21 }
22}
23
24#[derive(Serialize, Deserialize)]
25pub struct RunRequest {
26 pub new_message: String,
27}
28
29#[derive(Serialize, Deserialize, Debug)]
31#[serde(rename_all = "camelCase")]
32pub struct RunSseRequest {
33 pub app_name: String,
34 pub user_id: String,
35 pub session_id: String,
36 pub new_message: NewMessage,
37 #[serde(default)]
38 pub streaming: bool,
39 #[serde(default)]
40 pub state_delta: Option<serde_json::Value>,
41}
42
43#[derive(Serialize, Deserialize, Debug)]
44pub struct NewMessage {
45 pub role: String,
46 pub parts: Vec<MessagePart>,
47}
48
49#[derive(Serialize, Deserialize, Debug)]
50pub struct MessagePart {
51 #[serde(default)]
52 pub text: Option<String>,
53 #[serde(default, rename = "inlineData")]
54 pub inline_data: Option<InlineData>,
55}
56
57#[derive(Serialize, Deserialize, Debug)]
58#[serde(rename_all = "camelCase")]
59pub struct InlineData {
60 pub display_name: Option<String>,
61 pub data: String,
62 pub mime_type: String,
63}
64
65pub async fn run_sse(
66 State(controller): State<RuntimeController>,
67 Path((app_name, user_id, session_id)): Path<(String, String, String)>,
68 Json(req): Json<RunRequest>,
69) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, StatusCode> {
70 controller
72 .config
73 .session_service
74 .get(adk_session::GetRequest {
75 app_name: app_name.clone(),
76 user_id: user_id.clone(),
77 session_id: session_id.clone(),
78 num_recent_events: None,
79 after: None,
80 })
81 .await
82 .map_err(|_| StatusCode::NOT_FOUND)?;
83
84 let agent = controller
86 .config
87 .agent_loader
88 .load_agent(&app_name)
89 .await
90 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
91
92 let runner = adk_runner::Runner::new(adk_runner::RunnerConfig {
94 app_name,
95 agent,
96 session_service: controller.config.session_service.clone(),
97 artifact_service: controller.config.artifact_service.clone(),
98 memory_service: None,
99 })
100 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
101
102 let event_stream = runner
104 .run(user_id, session_id, adk_core::Content::new("user").with_text(&req.new_message))
105 .await
106 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
107
108 let sse_stream = stream::unfold(event_stream, |mut stream| async move {
110 use futures::StreamExt;
111 match stream.next().await {
112 Some(Ok(event)) => {
113 let json = serde_json::to_string(&event).ok()?;
114 Some((Ok(Event::default().data(json)), stream))
115 }
116 _ => None,
117 }
118 });
119
120 Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
121}
122
123pub async fn run_sse_compat(
126 State(controller): State<RuntimeController>,
127 Json(req): Json<RunSseRequest>,
128) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, StatusCode> {
129 let app_name = req.app_name;
130 let user_id = req.user_id;
131 let session_id = req.session_id;
132
133 info!(
134 app_name = %app_name,
135 user_id = %user_id,
136 session_id = %session_id,
137 "POST /run_sse request received"
138 );
139
140 let message_text = req
142 .new_message
143 .parts
144 .iter()
145 .filter_map(|p| p.text.as_ref())
146 .cloned()
147 .collect::<Vec<_>>()
148 .join(" ");
149
150 controller
152 .config
153 .session_service
154 .get(adk_session::GetRequest {
155 app_name: app_name.clone(),
156 user_id: user_id.clone(),
157 session_id: session_id.clone(),
158 num_recent_events: None,
159 after: None,
160 })
161 .await
162 .map_err(|e| {
163 error!(error = ?e, "Session not found");
164 StatusCode::NOT_FOUND
165 })?;
166
167 let agent = controller
169 .config
170 .agent_loader
171 .load_agent(&app_name)
172 .await
173 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
174
175 let runner = adk_runner::Runner::new(adk_runner::RunnerConfig {
177 app_name,
178 agent,
179 session_service: controller.config.session_service.clone(),
180 artifact_service: controller.config.artifact_service.clone(),
181 memory_service: None,
182 })
183 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
184
185 let event_stream = runner
187 .run(user_id, session_id, adk_core::Content::new("user").with_text(&message_text))
188 .await
189 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
190
191 let sse_stream = stream::unfold(event_stream, |mut stream| async move {
193 use futures::StreamExt;
194 match stream.next().await {
195 Some(Ok(event)) => {
196 let json = serde_json::to_string(&event).ok()?;
197 Some((Ok(Event::default().data(json)), stream))
198 }
199 _ => None,
200 }
201 });
202
203 Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
204}