//! Session REST controllers (`adk_server/rest/controllers/session.rs`).

1use axum::{
2    Extension, Json,
3    extract::{Path, State},
4    http::StatusCode,
5};
6use serde::{Deserialize, Serialize};
7use std::sync::Arc;
8use tracing::{error, info};
9
10#[derive(Clone)]
11pub struct SessionController {
12    session_service: Arc<dyn adk_session::SessionService>,
13}
14
15impl SessionController {
16    pub fn new(session_service: Arc<dyn adk_session::SessionService>) -> Self {
17        Self { session_service }
18    }
19
20    /// Helper function to convert a session to SessionResponse with actual events and state
21    fn session_to_response(session: &dyn adk_session::Session) -> SessionResponse {
22        // Convert events to JSON values, capping at a reasonable limit to prevent
23        // uncontrolled allocation from very large session histories.
24        const MAX_EVENTS: usize = 10_000;
25        let events: Vec<serde_json::Value> = session
26            .events()
27            .all()
28            .into_iter()
29            .take(MAX_EVENTS)
30            .map(|event| serde_json::to_value(event).unwrap_or(serde_json::Value::Null))
31            .collect();
32
33        SessionResponse {
34            id: session.id().to_string(),
35            app_name: session.app_name().to_string(),
36            user_id: session.user_id().to_string(),
37            last_update_time: session.last_update_time().timestamp(),
38            events,
39            state: session.state().all(),
40        }
41    }
42}
43
44fn authorize_user_id(
45    request_context: &Option<adk_core::RequestContext>,
46    user_id: &str,
47) -> Result<String, StatusCode> {
48    match request_context {
49        Some(context) if context.user_id != user_id => Err(StatusCode::FORBIDDEN),
50        Some(context) => Ok(context.user_id.clone()),
51        None => Ok(user_id.to_string()),
52    }
53}
54
55fn effective_user_id(request_context: &Option<adk_core::RequestContext>, user_id: &str) -> String {
56    request_context
57        .as_ref()
58        .map(|context| context.user_id.clone())
59        .unwrap_or_else(|| user_id.to_string())
60}
61
62#[derive(Serialize, Deserialize)]
63pub struct CreateSessionRequest {
64    #[serde(rename = "appName")]
65    pub app_name: String,
66    #[serde(rename = "userId")]
67    pub user_id: String,
68    #[serde(rename = "sessionId", default)]
69    pub session_id: Option<String>,
70}
71
72#[derive(Serialize, Deserialize)]
73#[serde(rename_all = "camelCase")]
74pub struct SessionResponse {
75    pub id: String,
76    pub app_name: String,
77    pub user_id: String,
78    pub last_update_time: i64,
79    pub events: Vec<serde_json::Value>,
80    pub state: std::collections::HashMap<String, serde_json::Value>,
81}
82
83pub async fn create_session(
84    State(controller): State<SessionController>,
85    Extension(request_context): Extension<Option<adk_core::RequestContext>>,
86    Json(req): Json<CreateSessionRequest>,
87) -> Result<Json<SessionResponse>, StatusCode> {
88    let user_id = effective_user_id(&request_context, &req.user_id);
89
90    info!(
91        app_name = %req.app_name,
92        user_id = %user_id,
93        session_id = ?req.session_id,
94        "POST /sessions - Creating session"
95    );
96
97    // Generate session ID if not provided
98    let session_id = req.session_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
99
100    let session = controller
101        .session_service
102        .create(adk_session::CreateRequest {
103            app_name: req.app_name.clone(),
104            user_id,
105            session_id: Some(session_id),
106            state: std::collections::HashMap::new(),
107        })
108        .await
109        .map_err(|e| {
110            error!(error = %e, "session create failed");
111            StatusCode::INTERNAL_SERVER_ERROR
112        })?;
113
114    let response = SessionController::session_to_response(session.as_ref());
115
116    info!(session_id = %response.id, "Session created successfully");
117
118    Ok(Json(response))
119}
120
121pub async fn get_session(
122    State(controller): State<SessionController>,
123    Extension(request_context): Extension<Option<adk_core::RequestContext>>,
124    Path((app_name, user_id, session_id)): Path<(String, String, String)>,
125) -> Result<Json<SessionResponse>, StatusCode> {
126    let user_id = authorize_user_id(&request_context, &user_id)?;
127
128    let session = controller
129        .session_service
130        .get(adk_session::GetRequest {
131            app_name,
132            user_id,
133            session_id,
134            num_recent_events: None,
135            after: None,
136        })
137        .await
138        .map_err(|e| {
139            error!(error = %e, "session get failed");
140            StatusCode::NOT_FOUND
141        })?;
142
143    Ok(Json(SessionController::session_to_response(session.as_ref())))
144}
145
146pub async fn delete_session(
147    State(controller): State<SessionController>,
148    Extension(request_context): Extension<Option<adk_core::RequestContext>>,
149    Path((app_name, user_id, session_id)): Path<(String, String, String)>,
150) -> Result<StatusCode, StatusCode> {
151    let user_id = authorize_user_id(&request_context, &user_id)?;
152
153    controller
154        .session_service
155        .delete(adk_session::DeleteRequest { app_name, user_id, session_id })
156        .await
157        .map_err(|e| {
158            error!(error = %e, "session delete failed");
159            StatusCode::INTERNAL_SERVER_ERROR
160        })?;
161
162    Ok(StatusCode::NO_CONTENT)
163}
164
/// Maximum number of state entries kept from a create-session request body.
/// Entries beyond this limit are dropped rather than rejected, bounding how much
/// user-provided state is stored.
const MAX_STATE_ENTRIES: usize = 1_000;

/// Maximum number of events kept from a create-session request body.
/// Events beyond this limit are dropped rather than rejected.
const MAX_BODY_EVENTS: usize = 10_000;
171
172fn deserialize_bounded_state<'de, D>(
173    deserializer: D,
174) -> Result<std::collections::HashMap<String, serde_json::Value>, D::Error>
175where
176    D: serde::Deserializer<'de>,
177{
178    let full: std::collections::HashMap<String, serde_json::Value> =
179        serde::Deserialize::deserialize(deserializer)?;
180    if full.len() <= MAX_STATE_ENTRIES {
181        Ok(full)
182    } else {
183        Ok(full.into_iter().take(MAX_STATE_ENTRIES).collect())
184    }
185}
186
187fn deserialize_bounded_events<'de, D>(deserializer: D) -> Result<Vec<serde_json::Value>, D::Error>
188where
189    D: serde::Deserializer<'de>,
190{
191    let full: Vec<serde_json::Value> = serde::Deserialize::deserialize(deserializer)?;
192    if full.len() <= MAX_BODY_EVENTS {
193        Ok(full)
194    } else {
195        Ok(full.into_iter().take(MAX_BODY_EVENTS).collect())
196    }
197}
198
199/// Request body for creating session (optional, can be empty)
200#[derive(Serialize, Deserialize, Default)]
201pub struct CreateSessionBodyRequest {
202    #[serde(default, deserialize_with = "deserialize_bounded_state")]
203    pub state: std::collections::HashMap<String, serde_json::Value>,
204    #[serde(default, deserialize_with = "deserialize_bounded_events")]
205    pub events: Vec<serde_json::Value>,
206}
207
208/// Path parameters for session routes
209#[derive(Deserialize)]
210pub struct SessionPathParams {
211    pub app_name: String,
212    pub user_id: String,
213    #[serde(default)]
214    pub session_id: Option<String>,
215}
216
217/// Create session from URL path parameters (adk-go compatible)
218/// POST /apps/{app_name}/users/{user_id}/sessions
219/// POST /apps/{app_name}/users/{user_id}/sessions/{session_id}
220pub async fn create_session_from_path(
221    State(controller): State<SessionController>,
222    Extension(request_context): Extension<Option<adk_core::RequestContext>>,
223    Path(params): Path<SessionPathParams>,
224    body: Option<Json<CreateSessionBodyRequest>>,
225) -> Result<Json<SessionResponse>, StatusCode> {
226    let user_id = authorize_user_id(&request_context, &params.user_id)?;
227    let session_id = params.session_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
228
229    let session = controller
230        .session_service
231        .create(adk_session::CreateRequest {
232            app_name: params.app_name.clone(),
233            user_id,
234            session_id: Some(session_id),
235            state: match body {
236                Some(b) => {
237                    let s = b.0.state;
238                    if s.len() > MAX_STATE_ENTRIES {
239                        s.into_iter().take(MAX_STATE_ENTRIES).collect()
240                    } else {
241                        s
242                    }
243                }
244                None => std::collections::HashMap::new(),
245            },
246        })
247        .await
248        .map_err(|e| {
249            error!(error = %e, "session create from path failed");
250            StatusCode::INTERNAL_SERVER_ERROR
251        })?;
252
253    Ok(Json(SessionController::session_to_response(session.as_ref())))
254}
255
256/// Get session from URL path parameters (adk-go compatible)
257pub async fn get_session_from_path(
258    State(controller): State<SessionController>,
259    Extension(request_context): Extension<Option<adk_core::RequestContext>>,
260    Path(params): Path<SessionPathParams>,
261) -> Result<Json<SessionResponse>, StatusCode> {
262    let session_id = params.session_id.ok_or(StatusCode::BAD_REQUEST)?;
263    let user_id = authorize_user_id(&request_context, &params.user_id)?;
264
265    let session = controller
266        .session_service
267        .get(adk_session::GetRequest {
268            app_name: params.app_name,
269            user_id,
270            session_id,
271            num_recent_events: None,
272            after: None,
273        })
274        .await
275        .map_err(|e| {
276            error!(error = %e, "session get from path failed");
277            StatusCode::NOT_FOUND
278        })?;
279
280    Ok(Json(SessionController::session_to_response(session.as_ref())))
281}
282
283/// Delete session from URL path parameters (adk-go compatible)
284pub async fn delete_session_from_path(
285    State(controller): State<SessionController>,
286    Extension(request_context): Extension<Option<adk_core::RequestContext>>,
287    Path(params): Path<SessionPathParams>,
288) -> Result<StatusCode, StatusCode> {
289    let session_id = params.session_id.ok_or(StatusCode::BAD_REQUEST)?;
290    let user_id = authorize_user_id(&request_context, &params.user_id)?;
291
292    controller
293        .session_service
294        .delete(adk_session::DeleteRequest { app_name: params.app_name, user_id, session_id })
295        .await
296        .map_err(|e| {
297            error!(error = %e, "session delete from path failed");
298            StatusCode::INTERNAL_SERVER_ERROR
299        })?;
300
301    Ok(StatusCode::NO_CONTENT)
302}
303
304/// List sessions for a user (adk-go compatible)
305pub async fn list_sessions(
306    State(controller): State<SessionController>,
307    Extension(request_context): Extension<Option<adk_core::RequestContext>>,
308    Path(params): Path<SessionPathParams>,
309) -> Result<Json<Vec<SessionResponse>>, StatusCode> {
310    let user_id = authorize_user_id(&request_context, &params.user_id)?;
311
312    tracing::info!("list_sessions called with app_name: {}, user_id: {}", params.app_name, user_id);
313
314    let sessions = controller
315        .session_service
316        .list(adk_session::ListRequest {
317            app_name: params.app_name.clone(),
318            user_id,
319            limit: None,
320            offset: None,
321        })
322        .await
323        .map_err(|e| {
324            error!(error = %e, "session list failed");
325            StatusCode::INTERNAL_SERVER_ERROR
326        })?;
327
328    tracing::info!("Found {} sessions", sessions.len());
329
330    let responses: Vec<SessionResponse> =
331        sessions.into_iter().map(|s| SessionController::session_to_response(s.as_ref())).collect();
332
333    Ok(Json(responses))
334}