adk_server/rest/controllers/
session.rs1use axum::{
2 Extension, Json,
3 extract::{Path, State},
4 http::StatusCode,
5};
6use serde::{Deserialize, Serialize};
7use std::sync::Arc;
8use tracing::{error, info};
9
10#[derive(Clone)]
11pub struct SessionController {
12 session_service: Arc<dyn adk_session::SessionService>,
13}
14
15impl SessionController {
16 pub fn new(session_service: Arc<dyn adk_session::SessionService>) -> Self {
17 Self { session_service }
18 }
19
20 fn session_to_response(session: &dyn adk_session::Session) -> SessionResponse {
22 const MAX_EVENTS: usize = 10_000;
25 let events: Vec<serde_json::Value> = session
26 .events()
27 .all()
28 .into_iter()
29 .take(MAX_EVENTS)
30 .map(|event| serde_json::to_value(event).unwrap_or(serde_json::Value::Null))
31 .collect();
32
33 SessionResponse {
34 id: session.id().to_string(),
35 app_name: session.app_name().to_string(),
36 user_id: session.user_id().to_string(),
37 last_update_time: session.last_update_time().timestamp(),
38 events,
39 state: session.state().all(),
40 }
41 }
42}
43
44fn authorize_user_id(
45 request_context: &Option<adk_core::RequestContext>,
46 user_id: &str,
47) -> Result<String, StatusCode> {
48 match request_context {
49 Some(context) if context.user_id != user_id => Err(StatusCode::FORBIDDEN),
50 Some(context) => Ok(context.user_id.clone()),
51 None => Ok(user_id.to_string()),
52 }
53}
54
55fn effective_user_id(request_context: &Option<adk_core::RequestContext>, user_id: &str) -> String {
56 request_context
57 .as_ref()
58 .map(|context| context.user_id.clone())
59 .unwrap_or_else(|| user_id.to_string())
60}
61
62#[derive(Serialize, Deserialize)]
63pub struct CreateSessionRequest {
64 #[serde(rename = "appName")]
65 pub app_name: String,
66 #[serde(rename = "userId")]
67 pub user_id: String,
68 #[serde(rename = "sessionId", default)]
69 pub session_id: Option<String>,
70}
71
72#[derive(Serialize, Deserialize)]
73#[serde(rename_all = "camelCase")]
74pub struct SessionResponse {
75 pub id: String,
76 pub app_name: String,
77 pub user_id: String,
78 pub last_update_time: i64,
79 pub events: Vec<serde_json::Value>,
80 pub state: std::collections::HashMap<String, serde_json::Value>,
81}
82
83pub async fn create_session(
84 State(controller): State<SessionController>,
85 Extension(request_context): Extension<Option<adk_core::RequestContext>>,
86 Json(req): Json<CreateSessionRequest>,
87) -> Result<Json<SessionResponse>, StatusCode> {
88 let user_id = effective_user_id(&request_context, &req.user_id);
89
90 info!(
91 app_name = %req.app_name,
92 user_id = %user_id,
93 session_id = ?req.session_id,
94 "POST /sessions - Creating session"
95 );
96
97 let session_id = req.session_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
99
100 let session = controller
101 .session_service
102 .create(adk_session::CreateRequest {
103 app_name: req.app_name.clone(),
104 user_id,
105 session_id: Some(session_id),
106 state: std::collections::HashMap::new(),
107 })
108 .await
109 .map_err(|e| {
110 error!(error = %e, "session create failed");
111 StatusCode::INTERNAL_SERVER_ERROR
112 })?;
113
114 let response = SessionController::session_to_response(session.as_ref());
115
116 info!(session_id = %response.id, "Session created successfully");
117
118 Ok(Json(response))
119}
120
121pub async fn get_session(
122 State(controller): State<SessionController>,
123 Extension(request_context): Extension<Option<adk_core::RequestContext>>,
124 Path((app_name, user_id, session_id)): Path<(String, String, String)>,
125) -> Result<Json<SessionResponse>, StatusCode> {
126 let user_id = authorize_user_id(&request_context, &user_id)?;
127
128 let session = controller
129 .session_service
130 .get(adk_session::GetRequest {
131 app_name,
132 user_id,
133 session_id,
134 num_recent_events: None,
135 after: None,
136 })
137 .await
138 .map_err(|e| {
139 error!(error = %e, "session get failed");
140 StatusCode::NOT_FOUND
141 })?;
142
143 Ok(Json(SessionController::session_to_response(session.as_ref())))
144}
145
146pub async fn delete_session(
147 State(controller): State<SessionController>,
148 Extension(request_context): Extension<Option<adk_core::RequestContext>>,
149 Path((app_name, user_id, session_id)): Path<(String, String, String)>,
150) -> Result<StatusCode, StatusCode> {
151 let user_id = authorize_user_id(&request_context, &user_id)?;
152
153 controller
154 .session_service
155 .delete(adk_session::DeleteRequest { app_name, user_id, session_id })
156 .await
157 .map_err(|e| {
158 error!(error = %e, "session delete failed");
159 StatusCode::INTERNAL_SERVER_ERROR
160 })?;
161
162 Ok(StatusCode::NO_CONTENT)
163}
164
/// Upper bound on state entries accepted from a session-creation request body.
const MAX_STATE_ENTRIES: usize = 1000;

/// Upper bound on events accepted from a session-creation request body.
const MAX_BODY_EVENTS: usize = 10000;
171
172fn deserialize_bounded_state<'de, D>(
173 deserializer: D,
174) -> Result<std::collections::HashMap<String, serde_json::Value>, D::Error>
175where
176 D: serde::Deserializer<'de>,
177{
178 let full: std::collections::HashMap<String, serde_json::Value> =
179 serde::Deserialize::deserialize(deserializer)?;
180 if full.len() <= MAX_STATE_ENTRIES {
181 Ok(full)
182 } else {
183 Ok(full.into_iter().take(MAX_STATE_ENTRIES).collect())
184 }
185}
186
187fn deserialize_bounded_events<'de, D>(deserializer: D) -> Result<Vec<serde_json::Value>, D::Error>
188where
189 D: serde::Deserializer<'de>,
190{
191 let full: Vec<serde_json::Value> = serde::Deserialize::deserialize(deserializer)?;
192 if full.len() <= MAX_BODY_EVENTS {
193 Ok(full)
194 } else {
195 Ok(full.into_iter().take(MAX_BODY_EVENTS).collect())
196 }
197}
198
199#[derive(Serialize, Deserialize, Default)]
201pub struct CreateSessionBodyRequest {
202 #[serde(default, deserialize_with = "deserialize_bounded_state")]
203 pub state: std::collections::HashMap<String, serde_json::Value>,
204 #[serde(default, deserialize_with = "deserialize_bounded_events")]
205 pub events: Vec<serde_json::Value>,
206}
207
208#[derive(Deserialize)]
210pub struct SessionPathParams {
211 pub app_name: String,
212 pub user_id: String,
213 #[serde(default)]
214 pub session_id: Option<String>,
215}
216
217pub async fn create_session_from_path(
221 State(controller): State<SessionController>,
222 Extension(request_context): Extension<Option<adk_core::RequestContext>>,
223 Path(params): Path<SessionPathParams>,
224 body: Option<Json<CreateSessionBodyRequest>>,
225) -> Result<Json<SessionResponse>, StatusCode> {
226 let user_id = authorize_user_id(&request_context, ¶ms.user_id)?;
227 let session_id = params.session_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
228
229 let session = controller
230 .session_service
231 .create(adk_session::CreateRequest {
232 app_name: params.app_name.clone(),
233 user_id,
234 session_id: Some(session_id),
235 state: match body {
236 Some(b) => {
237 let s = b.0.state;
238 if s.len() > MAX_STATE_ENTRIES {
239 s.into_iter().take(MAX_STATE_ENTRIES).collect()
240 } else {
241 s
242 }
243 }
244 None => std::collections::HashMap::new(),
245 },
246 })
247 .await
248 .map_err(|e| {
249 error!(error = %e, "session create from path failed");
250 StatusCode::INTERNAL_SERVER_ERROR
251 })?;
252
253 Ok(Json(SessionController::session_to_response(session.as_ref())))
254}
255
256pub async fn get_session_from_path(
258 State(controller): State<SessionController>,
259 Extension(request_context): Extension<Option<adk_core::RequestContext>>,
260 Path(params): Path<SessionPathParams>,
261) -> Result<Json<SessionResponse>, StatusCode> {
262 let session_id = params.session_id.ok_or(StatusCode::BAD_REQUEST)?;
263 let user_id = authorize_user_id(&request_context, ¶ms.user_id)?;
264
265 let session = controller
266 .session_service
267 .get(adk_session::GetRequest {
268 app_name: params.app_name,
269 user_id,
270 session_id,
271 num_recent_events: None,
272 after: None,
273 })
274 .await
275 .map_err(|e| {
276 error!(error = %e, "session get from path failed");
277 StatusCode::NOT_FOUND
278 })?;
279
280 Ok(Json(SessionController::session_to_response(session.as_ref())))
281}
282
283pub async fn delete_session_from_path(
285 State(controller): State<SessionController>,
286 Extension(request_context): Extension<Option<adk_core::RequestContext>>,
287 Path(params): Path<SessionPathParams>,
288) -> Result<StatusCode, StatusCode> {
289 let session_id = params.session_id.ok_or(StatusCode::BAD_REQUEST)?;
290 let user_id = authorize_user_id(&request_context, ¶ms.user_id)?;
291
292 controller
293 .session_service
294 .delete(adk_session::DeleteRequest { app_name: params.app_name, user_id, session_id })
295 .await
296 .map_err(|e| {
297 error!(error = %e, "session delete from path failed");
298 StatusCode::INTERNAL_SERVER_ERROR
299 })?;
300
301 Ok(StatusCode::NO_CONTENT)
302}
303
304pub async fn list_sessions(
306 State(controller): State<SessionController>,
307 Extension(request_context): Extension<Option<adk_core::RequestContext>>,
308 Path(params): Path<SessionPathParams>,
309) -> Result<Json<Vec<SessionResponse>>, StatusCode> {
310 let user_id = authorize_user_id(&request_context, ¶ms.user_id)?;
311
312 tracing::info!("list_sessions called with app_name: {}, user_id: {}", params.app_name, user_id);
313
314 let sessions = controller
315 .session_service
316 .list(adk_session::ListRequest {
317 app_name: params.app_name.clone(),
318 user_id,
319 limit: None,
320 offset: None,
321 })
322 .await
323 .map_err(|e| {
324 error!(error = %e, "session list failed");
325 StatusCode::INTERNAL_SERVER_ERROR
326 })?;
327
328 tracing::info!("Found {} sessions", sessions.len());
329
330 let responses: Vec<SessionResponse> =
331 sessions.into_iter().map(|s| SessionController::session_to_response(s.as_ref())).collect();
332
333 Ok(Json(responses))
334}