adk_server/rest/controllers/
session.rs1use axum::{
2 Json,
3 extract::{Path, State},
4 http::StatusCode,
5};
6use serde::{Deserialize, Serialize};
7use std::sync::Arc;
8use tracing::info;
9
10#[derive(Clone)]
11pub struct SessionController {
12 session_service: Arc<dyn adk_session::SessionService>,
13}
14
15impl SessionController {
16 pub fn new(session_service: Arc<dyn adk_session::SessionService>) -> Self {
17 Self { session_service }
18 }
19
20 fn session_to_response(session: &dyn adk_session::Session) -> SessionResponse {
22 let events: Vec<serde_json::Value> = session
24 .events()
25 .all()
26 .into_iter()
27 .map(|event| serde_json::to_value(event).unwrap_or(serde_json::Value::Null))
28 .collect();
29
30 SessionResponse {
31 id: session.id().to_string(),
32 app_name: session.app_name().to_string(),
33 user_id: session.user_id().to_string(),
34 last_update_time: session.last_update_time().timestamp(),
35 events,
36 state: session.state().all(),
37 }
38 }
39}
40
41#[derive(Serialize, Deserialize)]
42pub struct CreateSessionRequest {
43 #[serde(rename = "appName")]
44 pub app_name: String,
45 #[serde(rename = "userId")]
46 pub user_id: String,
47 #[serde(rename = "sessionId", default)]
48 pub session_id: Option<String>,
49}
50
51#[derive(Serialize, Deserialize)]
52#[serde(rename_all = "camelCase")]
53pub struct SessionResponse {
54 pub id: String,
55 pub app_name: String,
56 pub user_id: String,
57 pub last_update_time: i64,
58 pub events: Vec<serde_json::Value>,
59 pub state: std::collections::HashMap<String, serde_json::Value>,
60}
61
62pub async fn create_session(
63 State(controller): State<SessionController>,
64 Json(req): Json<CreateSessionRequest>,
65) -> Result<Json<SessionResponse>, StatusCode> {
66 info!(
67 app_name = %req.app_name,
68 user_id = %req.user_id,
69 session_id = ?req.session_id,
70 "POST /sessions - Creating session"
71 );
72
73 let session_id = req.session_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
75
76 let session = controller
77 .session_service
78 .create(adk_session::CreateRequest {
79 app_name: req.app_name.clone(),
80 user_id: req.user_id.clone(),
81 session_id: Some(session_id),
82 state: std::collections::HashMap::new(),
83 })
84 .await
85 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
86
87 let response = SessionController::session_to_response(session.as_ref());
88
89 info!(session_id = %response.id, "Session created successfully");
90
91 Ok(Json(response))
92}
93
94pub async fn get_session(
95 State(controller): State<SessionController>,
96 Path((app_name, user_id, session_id)): Path<(String, String, String)>,
97) -> Result<Json<SessionResponse>, StatusCode> {
98 let session = controller
99 .session_service
100 .get(adk_session::GetRequest {
101 app_name,
102 user_id,
103 session_id,
104 num_recent_events: None,
105 after: None,
106 })
107 .await
108 .map_err(|_| StatusCode::NOT_FOUND)?;
109
110 Ok(Json(SessionController::session_to_response(session.as_ref())))
111}
112
113pub async fn delete_session(
114 State(controller): State<SessionController>,
115 Path((app_name, user_id, session_id)): Path<(String, String, String)>,
116) -> Result<StatusCode, StatusCode> {
117 controller
118 .session_service
119 .delete(adk_session::DeleteRequest { app_name, user_id, session_id })
120 .await
121 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
122
123 Ok(StatusCode::NO_CONTENT)
124}
125
126#[derive(Serialize, Deserialize, Default)]
128pub struct CreateSessionBodyRequest {
129 #[serde(default)]
130 pub state: std::collections::HashMap<String, serde_json::Value>,
131 #[serde(default)]
132 pub events: Vec<serde_json::Value>,
133}
134
135#[derive(Deserialize)]
137pub struct SessionPathParams {
138 pub app_name: String,
139 pub user_id: String,
140 #[serde(default)]
141 pub session_id: Option<String>,
142}
143
144pub async fn create_session_from_path(
148 State(controller): State<SessionController>,
149 Path(params): Path<SessionPathParams>,
150 body: Option<Json<CreateSessionBodyRequest>>,
151) -> Result<Json<SessionResponse>, StatusCode> {
152 let session_id = params.session_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
153
154 let session = controller
155 .session_service
156 .create(adk_session::CreateRequest {
157 app_name: params.app_name.clone(),
158 user_id: params.user_id.clone(),
159 session_id: Some(session_id),
160 state: body.map(|b| b.state.clone()).unwrap_or_default().into_iter().collect(),
161 })
162 .await
163 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
164
165 Ok(Json(SessionController::session_to_response(session.as_ref())))
166}
167
168pub async fn get_session_from_path(
170 State(controller): State<SessionController>,
171 Path(params): Path<SessionPathParams>,
172) -> Result<Json<SessionResponse>, StatusCode> {
173 let session_id = params.session_id.ok_or(StatusCode::BAD_REQUEST)?;
174
175 let session = controller
176 .session_service
177 .get(adk_session::GetRequest {
178 app_name: params.app_name,
179 user_id: params.user_id,
180 session_id,
181 num_recent_events: None,
182 after: None,
183 })
184 .await
185 .map_err(|_| StatusCode::NOT_FOUND)?;
186
187 Ok(Json(SessionController::session_to_response(session.as_ref())))
188}
189
190pub async fn delete_session_from_path(
192 State(controller): State<SessionController>,
193 Path(params): Path<SessionPathParams>,
194) -> Result<StatusCode, StatusCode> {
195 let session_id = params.session_id.ok_or(StatusCode::BAD_REQUEST)?;
196
197 controller
198 .session_service
199 .delete(adk_session::DeleteRequest {
200 app_name: params.app_name,
201 user_id: params.user_id,
202 session_id,
203 })
204 .await
205 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
206
207 Ok(StatusCode::NO_CONTENT)
208}
209
210pub async fn list_sessions(
212 State(controller): State<SessionController>,
213 Path(params): Path<SessionPathParams>,
214) -> Result<Json<Vec<SessionResponse>>, StatusCode> {
215 tracing::info!(
216 "list_sessions called with app_name: {}, user_id: {}",
217 params.app_name,
218 params.user_id
219 );
220
221 let sessions = controller
222 .session_service
223 .list(adk_session::ListRequest {
224 app_name: params.app_name.clone(),
225 user_id: params.user_id.clone(),
226 })
227 .await
228 .map_err(|e| {
229 tracing::error!("Failed to list sessions: {:?}", e);
230 StatusCode::INTERNAL_SERVER_ERROR
231 })?;
232
233 tracing::info!("Found {} sessions", sessions.len());
234
235 let responses: Vec<SessionResponse> =
236 sessions.into_iter().map(|s| SessionController::session_to_response(s.as_ref())).collect();
237
238 Ok(Json(responses))
239}