adk_server/rest/controllers/
session.rs1use axum::{
2 extract::{Path, State},
3 http::StatusCode,
4 Json,
5};
6use serde::{Deserialize, Serialize};
7use std::sync::Arc;
8use tracing::info;
9
10#[derive(Clone)]
11pub struct SessionController {
12 session_service: Arc<dyn adk_session::SessionService>,
13}
14
15impl SessionController {
16 pub fn new(session_service: Arc<dyn adk_session::SessionService>) -> Self {
17 Self { session_service }
18 }
19}
20
21#[derive(Serialize, Deserialize)]
22pub struct CreateSessionRequest {
23 #[serde(rename = "appName")]
24 pub app_name: String,
25 #[serde(rename = "userId")]
26 pub user_id: String,
27 #[serde(rename = "sessionId", default)]
28 pub session_id: Option<String>,
29}
30
31#[derive(Serialize, Deserialize)]
32#[serde(rename_all = "camelCase")]
33pub struct SessionResponse {
34 pub id: String,
35 pub app_name: String,
36 pub user_id: String,
37 pub last_update_time: i64,
38 pub events: Vec<serde_json::Value>,
39 pub state: std::collections::HashMap<String, serde_json::Value>,
40}
41
42pub async fn create_session(
43 State(controller): State<SessionController>,
44 Json(req): Json<CreateSessionRequest>,
45) -> Result<Json<SessionResponse>, StatusCode> {
46 info!(
47 app_name = %req.app_name,
48 user_id = %req.user_id,
49 session_id = ?req.session_id,
50 "POST /sessions - Creating session"
51 );
52
53 let session_id = req.session_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
55
56 let session = controller
57 .session_service
58 .create(adk_session::CreateRequest {
59 app_name: req.app_name.clone(),
60 user_id: req.user_id.clone(),
61 session_id: Some(session_id),
62 state: std::collections::HashMap::new(),
63 })
64 .await
65 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
66
67 let response = SessionResponse {
68 id: session.id().to_string(),
69 app_name: session.app_name().to_string(),
70 user_id: session.user_id().to_string(),
71 last_update_time: session.last_update_time().timestamp(),
72 events: vec![],
73 state: std::collections::HashMap::new(),
74 };
75
76 info!(session_id = %response.id, "Session created successfully");
77
78 Ok(Json(response))
79}
80
81pub async fn get_session(
82 State(controller): State<SessionController>,
83 Path((app_name, user_id, session_id)): Path<(String, String, String)>,
84) -> Result<Json<SessionResponse>, StatusCode> {
85 let session = controller
86 .session_service
87 .get(adk_session::GetRequest {
88 app_name,
89 user_id,
90 session_id,
91 num_recent_events: None,
92 after: None,
93 })
94 .await
95 .map_err(|_| StatusCode::NOT_FOUND)?;
96
97 Ok(Json(SessionResponse {
98 id: session.id().to_string(),
99 app_name: session.app_name().to_string(),
100 user_id: session.user_id().to_string(),
101 last_update_time: session.last_update_time().timestamp(),
102 events: vec![],
103 state: std::collections::HashMap::new(),
104 }))
105}
106
107pub async fn delete_session(
108 State(controller): State<SessionController>,
109 Path((app_name, user_id, session_id)): Path<(String, String, String)>,
110) -> Result<StatusCode, StatusCode> {
111 controller
112 .session_service
113 .delete(adk_session::DeleteRequest { app_name, user_id, session_id })
114 .await
115 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
116
117 Ok(StatusCode::NO_CONTENT)
118}
119
120#[derive(Serialize, Deserialize, Default)]
122pub struct CreateSessionBodyRequest {
123 #[serde(default)]
124 pub state: std::collections::HashMap<String, serde_json::Value>,
125 #[serde(default)]
126 pub events: Vec<serde_json::Value>,
127}
128
129#[derive(Deserialize)]
131pub struct SessionPathParams {
132 pub app_name: String,
133 pub user_id: String,
134 #[serde(default)]
135 pub session_id: Option<String>,
136}
137
138pub async fn create_session_from_path(
142 State(controller): State<SessionController>,
143 Path(params): Path<SessionPathParams>,
144 body: Option<Json<CreateSessionBodyRequest>>,
145) -> Result<Json<SessionResponse>, StatusCode> {
146 let session_id = params.session_id.unwrap_or_else(|| uuid::Uuid::new_v4().to_string());
147
148 let session = controller
149 .session_service
150 .create(adk_session::CreateRequest {
151 app_name: params.app_name.clone(),
152 user_id: params.user_id.clone(),
153 session_id: Some(session_id),
154 state: body.map(|b| b.state.clone()).unwrap_or_default().into_iter().collect(),
155 })
156 .await
157 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
158
159 Ok(Json(SessionResponse {
160 id: session.id().to_string(),
161 app_name: session.app_name().to_string(),
162 user_id: session.user_id().to_string(),
163 last_update_time: session.last_update_time().timestamp(),
164 events: vec![],
165 state: std::collections::HashMap::new(),
166 }))
167}
168
169pub async fn get_session_from_path(
171 State(controller): State<SessionController>,
172 Path(params): Path<SessionPathParams>,
173) -> Result<Json<SessionResponse>, StatusCode> {
174 let session_id = params.session_id.ok_or(StatusCode::BAD_REQUEST)?;
175
176 let session = controller
177 .session_service
178 .get(adk_session::GetRequest {
179 app_name: params.app_name,
180 user_id: params.user_id,
181 session_id,
182 num_recent_events: None,
183 after: None,
184 })
185 .await
186 .map_err(|_| StatusCode::NOT_FOUND)?;
187
188 Ok(Json(SessionResponse {
189 id: session.id().to_string(),
190 app_name: session.app_name().to_string(),
191 user_id: session.user_id().to_string(),
192 last_update_time: session.last_update_time().timestamp(),
193 events: vec![],
194 state: std::collections::HashMap::new(),
195 }))
196}
197
198pub async fn delete_session_from_path(
200 State(controller): State<SessionController>,
201 Path(params): Path<SessionPathParams>,
202) -> Result<StatusCode, StatusCode> {
203 let session_id = params.session_id.ok_or(StatusCode::BAD_REQUEST)?;
204
205 controller
206 .session_service
207 .delete(adk_session::DeleteRequest {
208 app_name: params.app_name,
209 user_id: params.user_id,
210 session_id,
211 })
212 .await
213 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
214
215 Ok(StatusCode::NO_CONTENT)
216}
217
218pub async fn list_sessions(
220 State(controller): State<SessionController>,
221 Path(params): Path<SessionPathParams>,
222) -> Result<Json<Vec<SessionResponse>>, StatusCode> {
223 let sessions = controller
224 .session_service
225 .list(adk_session::ListRequest { app_name: params.app_name, user_id: params.user_id })
226 .await
227 .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?;
228
229 let responses: Vec<SessionResponse> = sessions
230 .into_iter()
231 .map(|s| SessionResponse {
232 id: s.id().to_string(),
233 app_name: s.app_name().to_string(),
234 user_id: s.user_id().to_string(),
235 last_update_time: s.last_update_time().timestamp(),
236 events: vec![],
237 state: std::collections::HashMap::new(),
238 })
239 .collect();
240
241 Ok(Json(responses))
242}