// adk_server/rest/controllers/session.rs

1use axum::{
2    Json,
3    extract::{Path, State},
4    http::StatusCode,
5};
6use serde::{Deserialize, Serialize};
7use std::sync::Arc;
8use tracing::info;
9
10#[derive(Clone)]
11pub struct SessionController {
12    session_service: Arc<dyn adk_session::SessionService>,
13}
14
15impl SessionController {
16    pub fn new(session_service: Arc<dyn adk_session::SessionService>) -> Self {
17        Self { session_service }
18    }
19
20    /// Helper function to convert a session to SessionResponse with actual events and state
21    fn session_to_response(session: &dyn adk_session::Session) -> SessionResponse {
22        // Convert events to JSON values
23        let events: Vec<serde_json::Value> = session
24            .events()
25            .all()
26            .into_iter()
27            .map(|event| serde_json::to_value(event).unwrap_or(serde_json::Value::Null))
28            .collect();
29
30        SessionResponse {
31            id: session.id().to_string(),
32            app_name: session.app_name().to_string(),
33            user_id: session.user_id().to_string(),
34            last_update_time: session.last_update_time().timestamp(),
35            events,
36            state: session.state().all(),
37        }
38    }
39}
40
41#[derive(Serialize, Deserialize)]
42pub struct CreateSessionRequest {
43    #[serde(rename = "appName")]
44    pub app_name: String,
45    #[serde(rename = "userId")]
46    pub user_id: String,
47    #[serde(rename = "sessionId", default)]
48    pub session_id: Option<String>,
49}
50
51#[derive(Serialize, Deserialize)]
52#[serde(rename_all = "camelCase")]
53pub struct SessionResponse {
54    pub id: String,
55    pub app_name: String,
56    pub user_id: String,
57    pub last_update_time: i64,
58    pub events: Vec<serde_json::Value>,
59    pub state: std::collections::HashMap<String, serde_json::Value>,
60}
61
62pub async fn create_session(
63    State(controller): State<SessionController>,
64    Json(req): Json<CreateSessionRequest>,
65) -> Result<Json<SessionResponse>, StatusCode> {
66    info!(
67        app_name = %req.app_name,
68        user_id = %req.user_id,
69        session_id = ?req.session_id,
70        "POST /sessions - Creating session"
71    );
72
73    // Generate session ID if not provided
74    let session_id = req.session_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
75
76    let session = controller
77        .session_service
78        .create(adk_session::CreateRequest {
79            app_name: req.app_name.clone(),
80            user_id: req.user_id.clone(),
81            session_id: Some(session_id),
82            state: std::collections::HashMap::new(),
83        })
84        .await
85        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
86
87    let response = SessionController::session_to_response(session.as_ref());
88
89    info!(session_id = %response.id, "Session created successfully");
90
91    Ok(Json(response))
92}
93
94pub async fn get_session(
95    State(controller): State<SessionController>,
96    Path((app_name, user_id, session_id)): Path<(String, String, String)>,
97) -> Result<Json<SessionResponse>, StatusCode> {
98    let session = controller
99        .session_service
100        .get(adk_session::GetRequest {
101            app_name,
102            user_id,
103            session_id,
104            num_recent_events: None,
105            after: None,
106        })
107        .await
108        .map_err(|_| StatusCode::NOT_FOUND)?;
109
110    Ok(Json(SessionController::session_to_response(session.as_ref())))
111}
112
113pub async fn delete_session(
114    State(controller): State<SessionController>,
115    Path((app_name, user_id, session_id)): Path<(String, String, String)>,
116) -> Result<StatusCode, StatusCode> {
117    controller
118        .session_service
119        .delete(adk_session::DeleteRequest { app_name, user_id, session_id })
120        .await
121        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
122
123    Ok(StatusCode::NO_CONTENT)
124}
125
126/// Request body for creating session (optional, can be empty)
127#[derive(Serialize, Deserialize, Default)]
128pub struct CreateSessionBodyRequest {
129    #[serde(default)]
130    pub state: std::collections::HashMap<String, serde_json::Value>,
131    #[serde(default)]
132    pub events: Vec<serde_json::Value>,
133}
134
135/// Path parameters for session routes
136#[derive(Deserialize)]
137pub struct SessionPathParams {
138    pub app_name: String,
139    pub user_id: String,
140    #[serde(default)]
141    pub session_id: Option<String>,
142}
143
144/// Create session from URL path parameters (adk-go compatible)
145/// POST /apps/{app_name}/users/{user_id}/sessions
146/// POST /apps/{app_name}/users/{user_id}/sessions/{session_id}
147pub async fn create_session_from_path(
148    State(controller): State<SessionController>,
149    Path(params): Path<SessionPathParams>,
150    body: Option<Json<CreateSessionBodyRequest>>,
151) -> Result<Json<SessionResponse>, StatusCode> {
152    let session_id = params.session_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
153
154    let session = controller
155        .session_service
156        .create(adk_session::CreateRequest {
157            app_name: params.app_name.clone(),
158            user_id: params.user_id.clone(),
159            session_id: Some(session_id),
160            state: body.map(|b| b.state.clone()).unwrap_or_default().into_iter().collect(),
161        })
162        .await
163        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
164
165    Ok(Json(SessionController::session_to_response(session.as_ref())))
166}
167
168/// Get session from URL path parameters (adk-go compatible)
169pub async fn get_session_from_path(
170    State(controller): State<SessionController>,
171    Path(params): Path<SessionPathParams>,
172) -> Result<Json<SessionResponse>, StatusCode> {
173    let session_id = params.session_id.ok_or(StatusCode::BAD_REQUEST)?;
174
175    let session = controller
176        .session_service
177        .get(adk_session::GetRequest {
178            app_name: params.app_name,
179            user_id: params.user_id,
180            session_id,
181            num_recent_events: None,
182            after: None,
183        })
184        .await
185        .map_err(|_| StatusCode::NOT_FOUND)?;
186
187    Ok(Json(SessionController::session_to_response(session.as_ref())))
188}
189
190/// Delete session from URL path parameters (adk-go compatible)
191pub async fn delete_session_from_path(
192    State(controller): State<SessionController>,
193    Path(params): Path<SessionPathParams>,
194) -> Result<StatusCode, StatusCode> {
195    let session_id = params.session_id.ok_or(StatusCode::BAD_REQUEST)?;
196
197    controller
198        .session_service
199        .delete(adk_session::DeleteRequest {
200            app_name: params.app_name,
201            user_id: params.user_id,
202            session_id,
203        })
204        .await
205        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
206
207    Ok(StatusCode::NO_CONTENT)
208}
209
210/// List sessions for a user (adk-go compatible)
211pub async fn list_sessions(
212    State(controller): State<SessionController>,
213    Path(params): Path<SessionPathParams>,
214) -> Result<Json<Vec<SessionResponse>>, StatusCode> {
215    tracing::info!(
216        "list_sessions called with app_name: {}, user_id: {}",
217        params.app_name,
218        params.user_id
219    );
220
221    let sessions = controller
222        .session_service
223        .list(adk_session::ListRequest {
224            app_name: params.app_name.clone(),
225            user_id: params.user_id.clone(),
226        })
227        .await
228        .map_err(|e| {
229            tracing::error!("Failed to list sessions: {:?}", e);
230            StatusCode::INTERNAL_SERVER_ERROR
231        })?;
232
233    tracing::info!("Found {} sessions", sessions.len());
234
235    let responses: Vec<SessionResponse> =
236        sessions.into_iter().map(|s| SessionController::session_to_response(s.as_ref())).collect();
237
238    Ok(Json(responses))
239}