adk_server/rest/controllers/
runtime.rs1use crate::ServerConfig;
2use axum::{
3 Json,
4 extract::{Path, State},
5 http::StatusCode,
6 response::sse::{Event, KeepAlive, Sse},
7};
8use futures::stream::{self, Stream};
9use serde::{Deserialize, Serialize};
10use std::convert::Infallible;
11use tracing::{Instrument, info};
12
/// Serde default for [`RunSseRequest::streaming`].
///
/// A `/run_sse` request body that omits `streaming` is treated as a
/// streaming request.
fn default_streaming_true() -> bool {
    true
}
16
17#[derive(Clone)]
18pub struct RuntimeController {
19 config: ServerConfig,
20}
21
22impl RuntimeController {
23 pub fn new(config: ServerConfig) -> Self {
24 Self { config }
25 }
26}
27
28#[derive(Serialize, Deserialize)]
29pub struct RunRequest {
30 pub new_message: String,
31}
32
33#[derive(Serialize, Deserialize, Debug)]
35#[serde(rename_all = "camelCase")]
36pub struct RunSseRequest {
37 pub app_name: String,
38 pub user_id: String,
39 pub session_id: String,
40 pub new_message: NewMessage,
41 #[serde(default = "default_streaming_true")]
42 pub streaming: bool,
43 #[serde(default)]
44 pub state_delta: Option<serde_json::Value>,
45}
46
47#[derive(Serialize, Deserialize, Debug)]
48pub struct NewMessage {
49 pub role: String,
50 pub parts: Vec<MessagePart>,
51}
52
53#[derive(Serialize, Deserialize, Debug)]
54pub struct MessagePart {
55 #[serde(default)]
56 pub text: Option<String>,
57 #[serde(default, rename = "inlineData")]
58 pub inline_data: Option<InlineData>,
59}
60
61#[derive(Serialize, Deserialize, Debug)]
62#[serde(rename_all = "camelCase")]
63pub struct InlineData {
64 pub display_name: Option<String>,
65 pub data: String,
66 pub mime_type: String,
67}
68
69pub async fn run_sse(
70 State(controller): State<RuntimeController>,
71 Path((app_name, user_id, session_id)): Path<(String, String, String)>,
72 Json(req): Json<RunRequest>,
73) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, StatusCode> {
74 let span = tracing::info_span!("run_sse", session_id = %session_id, app_name = %app_name, user_id = %user_id);
75
76 async move {
77 controller
79 .config
80 .session_service
81 .get(adk_session::GetRequest {
82 app_name: app_name.clone(),
83 user_id: user_id.clone(),
84 session_id: session_id.clone(),
85 num_recent_events: None,
86 after: None,
87 })
88 .await
89 .map_err(|_| StatusCode::NOT_FOUND)?;
90
91 let agent = controller
93 .config
94 .agent_loader
95 .load_agent(&app_name)
96 .await
97 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
98
99 let runner = adk_runner::Runner::new(adk_runner::RunnerConfig {
101 app_name: app_name.clone(),
102 agent,
103 session_service: controller.config.session_service.clone(),
104 artifact_service: controller.config.artifact_service.clone(),
105 memory_service: None,
106 run_config: None,
107 })
108 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
109
110 let event_stream = runner
112 .run(user_id, session_id, adk_core::Content::new("user").with_text(&req.new_message))
113 .await
114 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
115
116 let sse_stream = stream::unfold(event_stream, move |mut stream| async move {
118 use futures::StreamExt;
119 match stream.next().await {
120 Some(Ok(event)) => {
121 let json = serde_json::to_string(&event).ok()?;
122 Some((Ok(Event::default().data(json)), stream))
123 }
124 _ => None,
125 }
126 });
127
128 Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
129 }
130 .instrument(span)
131 .await
132}
133
134pub async fn run_sse_compat(
137 State(controller): State<RuntimeController>,
138 Json(req): Json<RunSseRequest>,
139) -> Result<Sse<impl Stream<Item = Result<Event, Infallible>>>, StatusCode> {
140 let app_name = req.app_name;
141 let user_id = req.user_id;
142 let session_id = req.session_id;
143
144 info!(
145 app_name = %app_name,
146 user_id = %user_id,
147 session_id = %session_id,
148 "POST /run_sse request received"
149 );
150
151 let message_text = req
153 .new_message
154 .parts
155 .iter()
156 .filter_map(|p| p.text.as_ref())
157 .cloned()
158 .collect::<Vec<_>>()
159 .join(" ");
160
161 let session_result = controller
163 .config
164 .session_service
165 .get(adk_session::GetRequest {
166 app_name: app_name.clone(),
167 user_id: user_id.clone(),
168 session_id: session_id.clone(),
169 num_recent_events: None,
170 after: None,
171 })
172 .await;
173
174 if session_result.is_err() {
176 controller
177 .config
178 .session_service
179 .create(adk_session::CreateRequest {
180 app_name: app_name.clone(),
181 user_id: user_id.clone(),
182 session_id: Some(session_id.clone()),
183 state: std::collections::HashMap::new(),
184 })
185 .await
186 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
187 }
188
189 let agent = controller
191 .config
192 .agent_loader
193 .load_agent(&app_name)
194 .await
195 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
196
197 let streaming_mode =
199 if req.streaming { adk_core::StreamingMode::SSE } else { adk_core::StreamingMode::None };
200
201 let runner = adk_runner::Runner::new(adk_runner::RunnerConfig {
202 app_name,
203 agent,
204 session_service: controller.config.session_service.clone(),
205 artifact_service: controller.config.artifact_service.clone(),
206 memory_service: None,
207 run_config: Some(adk_core::RunConfig { streaming_mode }),
208 })
209 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
210
211 let event_stream = runner
213 .run(user_id, session_id, adk_core::Content::new("user").with_text(&message_text))
214 .await
215 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
216
217 let sse_stream = stream::unfold(event_stream, move |mut stream| async move {
219 use futures::StreamExt;
220 match stream.next().await {
221 Some(Ok(event)) => {
222 let json = serde_json::to_string(&event).ok()?;
223 Some((Ok(Event::default().data(json)), stream))
224 }
225 _ => None,
226 }
227 });
228
229 Ok(Sse::new(sse_stream).keep_alive(KeepAlive::default()))
230}