// adk_server/rest/controllers/session.rs

use std::collections::HashMap;
use std::sync::Arc;

use axum::{
    extract::{Path, State},
    http::StatusCode,
    Json,
};
use serde::{Deserialize, Serialize};
use tracing::{error, info, warn};

10#[derive(Clone)]
11pub struct SessionController {
12    session_service: Arc<dyn adk_session::SessionService>,
13}
14
15impl SessionController {
16    pub fn new(session_service: Arc<dyn adk_session::SessionService>) -> Self {
17        Self { session_service }
18    }
19}
20
21#[derive(Serialize, Deserialize)]
22pub struct CreateSessionRequest {
23    #[serde(rename = "appName")]
24    pub app_name: String,
25    #[serde(rename = "userId")]
26    pub user_id: String,
27    #[serde(rename = "sessionId", default)]
28    pub session_id: Option<String>,
29}
30
31#[derive(Serialize, Deserialize)]
32#[serde(rename_all = "camelCase")]
33pub struct SessionResponse {
34    pub id: String,
35    pub app_name: String,
36    pub user_id: String,
37    pub last_update_time: i64,
38    pub events: Vec<serde_json::Value>,
39    pub state: std::collections::HashMap<String, serde_json::Value>,
40}
41
42pub async fn create_session(
43    State(controller): State<SessionController>,
44    Json(req): Json<CreateSessionRequest>,
45) -> Result<Json<SessionResponse>, StatusCode> {
46    info!(
47        app_name = %req.app_name,
48        user_id = %req.user_id,
49        session_id = ?req.session_id,
50        "POST /sessions - Creating session"
51    );
52
53    // Generate session ID if not provided
54    let session_id = req.session_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
55
56    let session = controller
57        .session_service
58        .create(adk_session::CreateRequest {
59            app_name: req.app_name.clone(),
60            user_id: req.user_id.clone(),
61            session_id: Some(session_id),
62            state: std::collections::HashMap::new(),
63        })
64        .await
65        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
66
67    let response = SessionResponse {
68        id: session.id().to_string(),
69        app_name: session.app_name().to_string(),
70        user_id: session.user_id().to_string(),
71        last_update_time: session.last_update_time().timestamp(),
72        events: vec![],
73        state: std::collections::HashMap::new(),
74    };
75
76    info!(session_id = %response.id, "Session created successfully");
77
78    Ok(Json(response))
79}
80
81pub async fn get_session(
82    State(controller): State<SessionController>,
83    Path((app_name, user_id, session_id)): Path<(String, String, String)>,
84) -> Result<Json<SessionResponse>, StatusCode> {
85    let session = controller
86        .session_service
87        .get(adk_session::GetRequest {
88            app_name,
89            user_id,
90            session_id,
91            num_recent_events: None,
92            after: None,
93        })
94        .await
95        .map_err(|_| StatusCode::NOT_FOUND)?;
96
97    Ok(Json(SessionResponse {
98        id: session.id().to_string(),
99        app_name: session.app_name().to_string(),
100        user_id: session.user_id().to_string(),
101        last_update_time: session.last_update_time().timestamp(),
102        events: vec![],
103        state: std::collections::HashMap::new(),
104    }))
105}
106
107pub async fn delete_session(
108    State(controller): State<SessionController>,
109    Path((app_name, user_id, session_id)): Path<(String, String, String)>,
110) -> Result<StatusCode, StatusCode> {
111    controller
112        .session_service
113        .delete(adk_session::DeleteRequest { app_name, user_id, session_id })
114        .await
115        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
116
117    Ok(StatusCode::NO_CONTENT)
118}
119
120/// Request body for creating session (optional, can be empty)
121#[derive(Serialize, Deserialize, Default)]
122pub struct CreateSessionBodyRequest {
123    #[serde(default)]
124    pub state: std::collections::HashMap<String, serde_json::Value>,
125    #[serde(default)]
126    pub events: Vec<serde_json::Value>,
127}
128
129/// Path parameters for session routes
130#[derive(Deserialize)]
131pub struct SessionPathParams {
132    pub app_name: String,
133    pub user_id: String,
134    #[serde(default)]
135    pub session_id: Option<String>,
136}
137
138/// Create session from URL path parameters (adk-go compatible)
139/// POST /apps/{app_name}/users/{user_id}/sessions
140/// POST /apps/{app_name}/users/{user_id}/sessions/{session_id}
141pub async fn create_session_from_path(
142    State(controller): State<SessionController>,
143    Path(params): Path<SessionPathParams>,
144    body: Option<Json<CreateSessionBodyRequest>>,
145) -> Result<Json<SessionResponse>, StatusCode> {
146    let session_id = params.session_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
147
148    let session = controller
149        .session_service
150        .create(adk_session::CreateRequest {
151            app_name: params.app_name.clone(),
152            user_id: params.user_id.clone(),
153            session_id: Some(session_id),
154            state: body.map(|b| b.state.clone()).unwrap_or_default().into_iter().collect(),
155        })
156        .await
157        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
158
159    Ok(Json(SessionResponse {
160        id: session.id().to_string(),
161        app_name: session.app_name().to_string(),
162        user_id: session.user_id().to_string(),
163        last_update_time: session.last_update_time().timestamp(),
164        events: vec![],
165        state: std::collections::HashMap::new(),
166    }))
167}
168
169/// Get session from URL path parameters (adk-go compatible)
170pub async fn get_session_from_path(
171    State(controller): State<SessionController>,
172    Path(params): Path<SessionPathParams>,
173) -> Result<Json<SessionResponse>, StatusCode> {
174    let session_id = params.session_id.ok_or(StatusCode::BAD_REQUEST)?;
175
176    let session = controller
177        .session_service
178        .get(adk_session::GetRequest {
179            app_name: params.app_name,
180            user_id: params.user_id,
181            session_id,
182            num_recent_events: None,
183            after: None,
184        })
185        .await
186        .map_err(|_| StatusCode::NOT_FOUND)?;
187
188    Ok(Json(SessionResponse {
189        id: session.id().to_string(),
190        app_name: session.app_name().to_string(),
191        user_id: session.user_id().to_string(),
192        last_update_time: session.last_update_time().timestamp(),
193        events: vec![],
194        state: std::collections::HashMap::new(),
195    }))
196}
197
198/// Delete session from URL path parameters (adk-go compatible)
199pub async fn delete_session_from_path(
200    State(controller): State<SessionController>,
201    Path(params): Path<SessionPathParams>,
202) -> Result<StatusCode, StatusCode> {
203    let session_id = params.session_id.ok_or(StatusCode::BAD_REQUEST)?;
204
205    controller
206        .session_service
207        .delete(adk_session::DeleteRequest {
208            app_name: params.app_name,
209            user_id: params.user_id,
210            session_id,
211        })
212        .await
213        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
214
215    Ok(StatusCode::NO_CONTENT)
216}
217
218/// List sessions for a user (adk-go compatible)
219pub async fn list_sessions(
220    State(controller): State<SessionController>,
221    Path(params): Path<SessionPathParams>,
222) -> Result<Json<Vec<SessionResponse>>, StatusCode> {
223    let sessions = controller
224        .session_service
225        .list(adk_session::ListRequest { app_name: params.app_name, user_id: params.user_id })
226        .await
227        .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
228
229    let responses: Vec<SessionResponse> = sessions
230        .into_iter()
231        .map(|s| SessionResponse {
232            id: s.id().to_string(),
233            app_name: s.app_name().to_string(),
234            user_id: s.user_id().to_string(),
235            last_update_time: s.last_update_time().timestamp(),
236            events: vec![],
237            state: std::collections::HashMap::new(),
238        })
239        .collect();
240
241    Ok(Json(responses))
242}