// adk_server/rest/controllers/session.rs

1use axum::{
2    Extension, Json,
3    extract::{Path, State},
4    http::StatusCode,
5};
6use serde::{Deserialize, Serialize};
7use std::sync::Arc;
8use tracing::info;
9
10#[derive(Clone)]
11pub struct SessionController {
12    session_service: Arc<dyn adk_session::SessionService>,
13}
14
15impl SessionController {
16    pub fn new(session_service: Arc<dyn adk_session::SessionService>) -> Self {
17        Self { session_service }
18    }
19
20    /// Helper function to convert a session to SessionResponse with actual events and state
21    fn session_to_response(session: &dyn adk_session::Session) -> SessionResponse {
22        // Convert events to JSON values, capping at a reasonable limit to prevent
23        // uncontrolled allocation from very large session histories.
24        const MAX_EVENTS: usize = 10_000;
25        let events: Vec<serde_json::Value> = session
26            .events()
27            .all()
28            .into_iter()
29            .take(MAX_EVENTS)
30            .map(|event| serde_json::to_value(event).unwrap_or(serde_json::Value::Null))
31            .collect();
32
33        SessionResponse {
34            id: session.id().to_string(),
35            app_name: session.app_name().to_string(),
36            user_id: session.user_id().to_string(),
37            last_update_time: session.last_update_time().timestamp(),
38            events,
39            state: session.state().all(),
40        }
41    }
42}
43
44fn authorize_user_id(
45    request_context: &Option<adk_core::RequestContext>,
46    user_id: &str,
47) -> Result<String, StatusCode> {
48    match request_context {
49        Some(context) if context.user_id != user_id => Err(StatusCode::FORBIDDEN),
50        Some(context) => Ok(context.user_id.clone()),
51        None => Ok(user_id.to_string()),
52    }
53}
54
55fn effective_user_id(request_context: &Option<adk_core::RequestContext>, user_id: &str) -> String {
56    request_context
57        .as_ref()
58        .map(|context| context.user_id.clone())
59        .unwrap_or_else(|| user_id.to_string())
60}
61
62#[derive(Serialize, Deserialize)]
63pub struct CreateSessionRequest {
64    #[serde(rename = "appName")]
65    pub app_name: String,
66    #[serde(rename = "userId")]
67    pub user_id: String,
68    #[serde(rename = "sessionId", default)]
69    pub session_id: Option<String>,
70}
71
72#[derive(Serialize, Deserialize)]
73#[serde(rename_all = "camelCase")]
74pub struct SessionResponse {
75    pub id: String,
76    pub app_name: String,
77    pub user_id: String,
78    pub last_update_time: i64,
79    pub events: Vec<serde_json::Value>,
80    pub state: std::collections::HashMap<String, serde_json::Value>,
81}
82
83pub async fn create_session(
84    State(controller): State<SessionController>,
85    Extension(request_context): Extension<Option<adk_core::RequestContext>>,
86    Json(req): Json<CreateSessionRequest>,
87) -> Result<Json<SessionResponse>, StatusCode> {
88    let user_id = effective_user_id(&request_context, &req.user_id);
89
90    info!(
91        app_name = %req.app_name,
92        user_id = %user_id,
93        session_id = ?req.session_id,
94        "POST /sessions - Creating session"
95    );
96
97    // Generate session ID if not provided
98    let session_id = req.session_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
99
100    let session = controller
101        .session_service
102        .create(adk_session::CreateRequest {
103            app_name: req.app_name.clone(),
104            user_id,
105            session_id: Some(session_id),
106            state: std::collections::HashMap::new(),
107        })
108        .await
109        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
110
111    let response = SessionController::session_to_response(session.as_ref());
112
113    info!(session_id = %response.id, "Session created successfully");
114
115    Ok(Json(response))
116}
117
118pub async fn get_session(
119    State(controller): State<SessionController>,
120    Extension(request_context): Extension<Option<adk_core::RequestContext>>,
121    Path((app_name, user_id, session_id)): Path<(String, String, String)>,
122) -> Result<Json<SessionResponse>, StatusCode> {
123    let user_id = authorize_user_id(&request_context, &user_id)?;
124
125    let session = controller
126        .session_service
127        .get(adk_session::GetRequest {
128            app_name,
129            user_id,
130            session_id,
131            num_recent_events: None,
132            after: None,
133        })
134        .await
135        .map_err(|_| StatusCode::NOT_FOUND)?;
136
137    Ok(Json(SessionController::session_to_response(session.as_ref())))
138}
139
140pub async fn delete_session(
141    State(controller): State<SessionController>,
142    Extension(request_context): Extension<Option<adk_core::RequestContext>>,
143    Path((app_name, user_id, session_id)): Path<(String, String, String)>,
144) -> Result<StatusCode, StatusCode> {
145    let user_id = authorize_user_id(&request_context, &user_id)?;
146
147    controller
148        .session_service
149        .delete(adk_session::DeleteRequest { app_name, user_id, session_id })
150        .await
151        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
152
153    Ok(StatusCode::NO_CONTENT)
154}
155
156/// Maximum number of state entries accepted in a create-session request body.
157/// Prevents uncontrolled allocation from user-provided input.
158const MAX_STATE_ENTRIES: usize = 1_000;
159
160/// Maximum number of events accepted in a create-session request body.
161const MAX_BODY_EVENTS: usize = 10_000;
162
163fn deserialize_bounded_state<'de, D>(
164    deserializer: D,
165) -> Result<std::collections::HashMap<String, serde_json::Value>, D::Error>
166where
167    D: serde::Deserializer<'de>,
168{
169    let full: std::collections::HashMap<String, serde_json::Value> =
170        serde::Deserialize::deserialize(deserializer)?;
171    if full.len() <= MAX_STATE_ENTRIES {
172        Ok(full)
173    } else {
174        Ok(full.into_iter().take(MAX_STATE_ENTRIES).collect())
175    }
176}
177
178fn deserialize_bounded_events<'de, D>(deserializer: D) -> Result<Vec<serde_json::Value>, D::Error>
179where
180    D: serde::Deserializer<'de>,
181{
182    let full: Vec<serde_json::Value> = serde::Deserialize::deserialize(deserializer)?;
183    if full.len() <= MAX_BODY_EVENTS {
184        Ok(full)
185    } else {
186        Ok(full.into_iter().take(MAX_BODY_EVENTS).collect())
187    }
188}
189
190/// Request body for creating session (optional, can be empty)
191#[derive(Serialize, Deserialize, Default)]
192pub struct CreateSessionBodyRequest {
193    #[serde(default, deserialize_with = "deserialize_bounded_state")]
194    pub state: std::collections::HashMap<String, serde_json::Value>,
195    #[serde(default, deserialize_with = "deserialize_bounded_events")]
196    pub events: Vec<serde_json::Value>,
197}
198
199/// Path parameters for session routes
200#[derive(Deserialize)]
201pub struct SessionPathParams {
202    pub app_name: String,
203    pub user_id: String,
204    #[serde(default)]
205    pub session_id: Option<String>,
206}
207
208/// Create session from URL path parameters (adk-go compatible)
209/// POST /apps/{app_name}/users/{user_id}/sessions
210/// POST /apps/{app_name}/users/{user_id}/sessions/{session_id}
211pub async fn create_session_from_path(
212    State(controller): State<SessionController>,
213    Extension(request_context): Extension<Option<adk_core::RequestContext>>,
214    Path(params): Path<SessionPathParams>,
215    body: Option<Json<CreateSessionBodyRequest>>,
216) -> Result<Json<SessionResponse>, StatusCode> {
217    let user_id = authorize_user_id(&request_context, &params.user_id)?;
218    let session_id = params.session_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
219
220    let session = controller
221        .session_service
222        .create(adk_session::CreateRequest {
223            app_name: params.app_name.clone(),
224            user_id,
225            session_id: Some(session_id),
226            state: match body {
227                Some(b) => {
228                    let s = b.0.state;
229                    if s.len() > MAX_STATE_ENTRIES {
230                        s.into_iter().take(MAX_STATE_ENTRIES).collect()
231                    } else {
232                        s
233                    }
234                }
235                None => std::collections::HashMap::new(),
236            },
237        })
238        .await
239        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
240
241    Ok(Json(SessionController::session_to_response(session.as_ref())))
242}
243
244/// Get session from URL path parameters (adk-go compatible)
245pub async fn get_session_from_path(
246    State(controller): State<SessionController>,
247    Extension(request_context): Extension<Option<adk_core::RequestContext>>,
248    Path(params): Path<SessionPathParams>,
249) -> Result<Json<SessionResponse>, StatusCode> {
250    let session_id = params.session_id.ok_or(StatusCode::BAD_REQUEST)?;
251    let user_id = authorize_user_id(&request_context, &params.user_id)?;
252
253    let session = controller
254        .session_service
255        .get(adk_session::GetRequest {
256            app_name: params.app_name,
257            user_id,
258            session_id,
259            num_recent_events: None,
260            after: None,
261        })
262        .await
263        .map_err(|_| StatusCode::NOT_FOUND)?;
264
265    Ok(Json(SessionController::session_to_response(session.as_ref())))
266}
267
268/// Delete session from URL path parameters (adk-go compatible)
269pub async fn delete_session_from_path(
270    State(controller): State<SessionController>,
271    Extension(request_context): Extension<Option<adk_core::RequestContext>>,
272    Path(params): Path<SessionPathParams>,
273) -> Result<StatusCode, StatusCode> {
274    let session_id = params.session_id.ok_or(StatusCode::BAD_REQUEST)?;
275    let user_id = authorize_user_id(&request_context, &params.user_id)?;
276
277    controller
278        .session_service
279        .delete(adk_session::DeleteRequest { app_name: params.app_name, user_id, session_id })
280        .await
281        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
282
283    Ok(StatusCode::NO_CONTENT)
284}
285
286/// List sessions for a user (adk-go compatible)
287pub async fn list_sessions(
288    State(controller): State<SessionController>,
289    Extension(request_context): Extension<Option<adk_core::RequestContext>>,
290    Path(params): Path<SessionPathParams>,
291) -> Result<Json<Vec<SessionResponse>>, StatusCode> {
292    let user_id = authorize_user_id(&request_context, &params.user_id)?;
293
294    tracing::info!("list_sessions called with app_name: {}, user_id: {}", params.app_name, user_id);
295
296    let sessions = controller
297        .session_service
298        .list(adk_session::ListRequest {
299            app_name: params.app_name.clone(),
300            user_id,
301            limit: None,
302            offset: None,
303        })
304        .await
305        .map_err(|e| {
306            tracing::error!("Failed to list sessions: {:?}", e);
307            StatusCode::INTERNAL_SERVER_ERROR
308        })?;
309
310    tracing::info!("Found {} sessions", sessions.len());
311
312    let responses: Vec<SessionResponse> =
313        sessions.into_iter().map(|s| SessionController::session_to_response(s.as_ref())).collect();
314
315    Ok(Json(responses))
316}