adk_server/rest/controllers/
session.rs1use axum::{
2 Extension, Json,
3 extract::{Path, State},
4 http::StatusCode,
5};
6use serde::{Deserialize, Serialize};
7use std::sync::Arc;
8use tracing::info;
9
10#[derive(Clone)]
11pub struct SessionController {
12 session_service: Arc<dyn adk_session::SessionService>,
13}
14
15impl SessionController {
16 pub fn new(session_service: Arc<dyn adk_session::SessionService>) -> Self {
17 Self { session_service }
18 }
19
20 fn session_to_response(session: &dyn adk_session::Session) -> SessionResponse {
22 const MAX_EVENTS: usize = 10_000;
25 let events: Vec<serde_json::Value> = session
26 .events()
27 .all()
28 .into_iter()
29 .take(MAX_EVENTS)
30 .map(|event| serde_json::to_value(event).unwrap_or(serde_json::Value::Null))
31 .collect();
32
33 SessionResponse {
34 id: session.id().to_string(),
35 app_name: session.app_name().to_string(),
36 user_id: session.user_id().to_string(),
37 last_update_time: session.last_update_time().timestamp(),
38 events,
39 state: session.state().all(),
40 }
41 }
42}
43
44fn authorize_user_id(
45 request_context: &Option<adk_core::RequestContext>,
46 user_id: &str,
47) -> Result<String, StatusCode> {
48 match request_context {
49 Some(context) if context.user_id != user_id => Err(StatusCode::FORBIDDEN),
50 Some(context) => Ok(context.user_id.clone()),
51 None => Ok(user_id.to_string()),
52 }
53}
54
55fn effective_user_id(request_context: &Option<adk_core::RequestContext>, user_id: &str) -> String {
56 request_context
57 .as_ref()
58 .map(|context| context.user_id.clone())
59 .unwrap_or_else(|| user_id.to_string())
60}
61
62#[derive(Serialize, Deserialize)]
63pub struct CreateSessionRequest {
64 #[serde(rename = "appName")]
65 pub app_name: String,
66 #[serde(rename = "userId")]
67 pub user_id: String,
68 #[serde(rename = "sessionId", default)]
69 pub session_id: Option<String>,
70}
71
72#[derive(Serialize, Deserialize)]
73#[serde(rename_all = "camelCase")]
74pub struct SessionResponse {
75 pub id: String,
76 pub app_name: String,
77 pub user_id: String,
78 pub last_update_time: i64,
79 pub events: Vec<serde_json::Value>,
80 pub state: std::collections::HashMap<String, serde_json::Value>,
81}
82
83pub async fn create_session(
84 State(controller): State<SessionController>,
85 Extension(request_context): Extension<Option<adk_core::RequestContext>>,
86 Json(req): Json<CreateSessionRequest>,
87) -> Result<Json<SessionResponse>, StatusCode> {
88 let user_id = effective_user_id(&request_context, &req.user_id);
89
90 info!(
91 app_name = %req.app_name,
92 user_id = %user_id,
93 session_id = ?req.session_id,
94 "POST /sessions - Creating session"
95 );
96
97 let session_id = req.session_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
99
100 let session = controller
101 .session_service
102 .create(adk_session::CreateRequest {
103 app_name: req.app_name.clone(),
104 user_id,
105 session_id: Some(session_id),
106 state: std::collections::HashMap::new(),
107 })
108 .await
109 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
110
111 let response = SessionController::session_to_response(session.as_ref());
112
113 info!(session_id = %response.id, "Session created successfully");
114
115 Ok(Json(response))
116}
117
118pub async fn get_session(
119 State(controller): State<SessionController>,
120 Extension(request_context): Extension<Option<adk_core::RequestContext>>,
121 Path((app_name, user_id, session_id)): Path<(String, String, String)>,
122) -> Result<Json<SessionResponse>, StatusCode> {
123 let user_id = authorize_user_id(&request_context, &user_id)?;
124
125 let session = controller
126 .session_service
127 .get(adk_session::GetRequest {
128 app_name,
129 user_id,
130 session_id,
131 num_recent_events: None,
132 after: None,
133 })
134 .await
135 .map_err(|_| StatusCode::NOT_FOUND)?;
136
137 Ok(Json(SessionController::session_to_response(session.as_ref())))
138}
139
140pub async fn delete_session(
141 State(controller): State<SessionController>,
142 Extension(request_context): Extension<Option<adk_core::RequestContext>>,
143 Path((app_name, user_id, session_id)): Path<(String, String, String)>,
144) -> Result<StatusCode, StatusCode> {
145 let user_id = authorize_user_id(&request_context, &user_id)?;
146
147 controller
148 .session_service
149 .delete(adk_session::DeleteRequest { app_name, user_id, session_id })
150 .await
151 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
152
153 Ok(StatusCode::NO_CONTENT)
154}
155
/// Upper bound on the number of `state` entries kept from a request body.
const MAX_STATE_ENTRIES: usize = 1_000;

/// Upper bound on the number of `events` kept from a request body.
const MAX_BODY_EVENTS: usize = 10_000;
162
163fn deserialize_bounded_state<'de, D>(
164 deserializer: D,
165) -> Result<std::collections::HashMap<String, serde_json::Value>, D::Error>
166where
167 D: serde::Deserializer<'de>,
168{
169 let full: std::collections::HashMap<String, serde_json::Value> =
170 serde::Deserialize::deserialize(deserializer)?;
171 if full.len() <= MAX_STATE_ENTRIES {
172 Ok(full)
173 } else {
174 Ok(full.into_iter().take(MAX_STATE_ENTRIES).collect())
175 }
176}
177
178fn deserialize_bounded_events<'de, D>(deserializer: D) -> Result<Vec<serde_json::Value>, D::Error>
179where
180 D: serde::Deserializer<'de>,
181{
182 let full: Vec<serde_json::Value> = serde::Deserialize::deserialize(deserializer)?;
183 if full.len() <= MAX_BODY_EVENTS {
184 Ok(full)
185 } else {
186 Ok(full.into_iter().take(MAX_BODY_EVENTS).collect())
187 }
188}
189
190#[derive(Serialize, Deserialize, Default)]
192pub struct CreateSessionBodyRequest {
193 #[serde(default, deserialize_with = "deserialize_bounded_state")]
194 pub state: std::collections::HashMap<String, serde_json::Value>,
195 #[serde(default, deserialize_with = "deserialize_bounded_events")]
196 pub events: Vec<serde_json::Value>,
197}
198
199#[derive(Deserialize)]
201pub struct SessionPathParams {
202 pub app_name: String,
203 pub user_id: String,
204 #[serde(default)]
205 pub session_id: Option<String>,
206}
207
208pub async fn create_session_from_path(
212 State(controller): State<SessionController>,
213 Extension(request_context): Extension<Option<adk_core::RequestContext>>,
214 Path(params): Path<SessionPathParams>,
215 body: Option<Json<CreateSessionBodyRequest>>,
216) -> Result<Json<SessionResponse>, StatusCode> {
217 let user_id = authorize_user_id(&request_context, ¶ms.user_id)?;
218 let session_id = params.session_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
219
220 let session = controller
221 .session_service
222 .create(adk_session::CreateRequest {
223 app_name: params.app_name.clone(),
224 user_id,
225 session_id: Some(session_id),
226 state: match body {
227 Some(b) => {
228 let s = b.0.state;
229 if s.len() > MAX_STATE_ENTRIES {
230 s.into_iter().take(MAX_STATE_ENTRIES).collect()
231 } else {
232 s
233 }
234 }
235 None => std::collections::HashMap::new(),
236 },
237 })
238 .await
239 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
240
241 Ok(Json(SessionController::session_to_response(session.as_ref())))
242}
243
244pub async fn get_session_from_path(
246 State(controller): State<SessionController>,
247 Extension(request_context): Extension<Option<adk_core::RequestContext>>,
248 Path(params): Path<SessionPathParams>,
249) -> Result<Json<SessionResponse>, StatusCode> {
250 let session_id = params.session_id.ok_or(StatusCode::BAD_REQUEST)?;
251 let user_id = authorize_user_id(&request_context, ¶ms.user_id)?;
252
253 let session = controller
254 .session_service
255 .get(adk_session::GetRequest {
256 app_name: params.app_name,
257 user_id,
258 session_id,
259 num_recent_events: None,
260 after: None,
261 })
262 .await
263 .map_err(|_| StatusCode::NOT_FOUND)?;
264
265 Ok(Json(SessionController::session_to_response(session.as_ref())))
266}
267
268pub async fn delete_session_from_path(
270 State(controller): State<SessionController>,
271 Extension(request_context): Extension<Option<adk_core::RequestContext>>,
272 Path(params): Path<SessionPathParams>,
273) -> Result<StatusCode, StatusCode> {
274 let session_id = params.session_id.ok_or(StatusCode::BAD_REQUEST)?;
275 let user_id = authorize_user_id(&request_context, ¶ms.user_id)?;
276
277 controller
278 .session_service
279 .delete(adk_session::DeleteRequest { app_name: params.app_name, user_id, session_id })
280 .await
281 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
282
283 Ok(StatusCode::NO_CONTENT)
284}
285
286pub async fn list_sessions(
288 State(controller): State<SessionController>,
289 Extension(request_context): Extension<Option<adk_core::RequestContext>>,
290 Path(params): Path<SessionPathParams>,
291) -> Result<Json<Vec<SessionResponse>>, StatusCode> {
292 let user_id = authorize_user_id(&request_context, ¶ms.user_id)?;
293
294 tracing::info!("list_sessions called with app_name: {}, user_id: {}", params.app_name, user_id);
295
296 let sessions = controller
297 .session_service
298 .list(adk_session::ListRequest {
299 app_name: params.app_name.clone(),
300 user_id,
301 limit: None,
302 offset: None,
303 })
304 .await
305 .map_err(|e| {
306 tracing::error!("Failed to list sessions: {:?}", e);
307 StatusCode::INTERNAL_SERVER_ERROR
308 })?;
309
310 tracing::info!("Found {} sessions", sessions.len());
311
312 let responses: Vec<SessionResponse> =
313 sessions.into_iter().map(|s| SessionController::session_to_response(s.as_ref())).collect();
314
315 Ok(Json(responses))
316}