1use std::collections::HashMap;
2use std::sync::Arc;
3
4use axum::Json;
5use axum::extract::{Path, Query, State};
6use serde::{Deserialize, Serialize};
7
8use lago_core::event::{EventEnvelope, EventPayload};
9use lago_core::id::{BranchId, EventId, SessionId};
10use lago_core::session::{Session, SessionConfig};
11
12use crate::error::ApiError;
13use crate::state::AppState;
14
15#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
20#[serde(rename_all = "snake_case")]
21pub enum SessionType {
22 Vault,
24 Agent,
26 SiteAssets,
28 SiteContent,
30 Default,
32}
33
34impl SessionType {
35 pub fn from_name(name: &str) -> Self {
37 if name.starts_with("vault:") {
38 Self::Vault
39 } else if name.starts_with("agent:") {
40 Self::Agent
41 } else if name.starts_with("site-assets:") {
42 Self::SiteAssets
43 } else if name.starts_with("site-content:") {
44 Self::SiteContent
45 } else {
46 Self::Default
47 }
48 }
49
50 fn as_str(self) -> &'static str {
52 match self {
53 Self::Vault => "vault",
54 Self::Agent => "agent",
55 Self::SiteAssets => "site_assets",
56 Self::SiteContent => "site_content",
57 Self::Default => "default",
58 }
59 }
60
61 fn from_query(s: &str) -> Option<Self> {
63 match s.to_ascii_lowercase().as_str() {
64 "vault" => Some(Self::Vault),
65 "agent" => Some(Self::Agent),
66 "site_assets" | "site-assets" => Some(Self::SiteAssets),
67 "site_content" | "site-content" => Some(Self::SiteContent),
68 "default" => Some(Self::Default),
69 _ => None,
70 }
71 }
72}
73
74impl std::fmt::Display for SessionType {
75 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
76 f.write_str(self.as_str())
77 }
78}
79
80fn validate_session_name(name: &str) -> Result<SessionType, ApiError> {
88 let session_type = SessionType::from_name(name);
89
90 match session_type {
91 SessionType::Agent => {
92 let agent_id = name.strip_prefix("agent:").unwrap_or("");
93 if agent_id.is_empty() {
94 return Err(ApiError::BadRequest(
95 "agent: session requires a non-empty agent ID (e.g. agent:my-agent)".into(),
96 ));
97 }
98 }
99 SessionType::Vault => {
100 let user_id = name.strip_prefix("vault:").unwrap_or("");
101 if user_id.is_empty() {
102 return Err(ApiError::BadRequest(
103 "vault: session requires a non-empty user ID (e.g. vault:user_123)".into(),
104 ));
105 }
106 }
107 SessionType::SiteAssets => {
108 let scope = name.strip_prefix("site-assets:").unwrap_or("");
109 if scope.is_empty() {
110 return Err(ApiError::BadRequest(
111 "site-assets: session requires a non-empty scope (e.g. site-assets:public)"
112 .into(),
113 ));
114 }
115 }
116 SessionType::SiteContent => {
117 let scope = name.strip_prefix("site-content:").unwrap_or("");
118 if scope.is_empty() {
119 return Err(ApiError::BadRequest(
120 "site-content: session requires a non-empty scope (e.g. site-content:public)"
121 .into(),
122 ));
123 }
124 }
125 SessionType::Default => {}
126 }
127
128 Ok(session_type)
129}
130
131#[derive(Deserialize, Serialize)]
134pub struct CreateSessionRequest {
135 pub name: String,
136 #[serde(default)]
137 pub model: Option<String>,
138 #[serde(default)]
139 pub params: Option<HashMap<String, String>>,
140}
141
142#[derive(Serialize, Deserialize)]
143pub struct CreateSessionResponse {
144 pub session_id: String,
145 pub branch_id: String,
146 pub session_type: SessionType,
147}
148
149#[derive(Serialize, Deserialize)]
150pub struct SessionResponse {
151 pub session_id: String,
152 pub name: String,
153 pub model: String,
154 pub created_at: u64,
155 pub branches: Vec<String>,
156 pub session_type: SessionType,
157}
158
159impl From<&Session> for SessionResponse {
160 fn from(s: &Session) -> Self {
161 Self {
162 session_id: s.session_id.to_string(),
163 name: s.config.name.clone(),
164 model: s.config.model.clone(),
165 created_at: s.created_at,
166 branches: s.branches.iter().map(|b| b.to_string()).collect(),
167 session_type: SessionType::from_name(&s.config.name),
168 }
169 }
170}
171
172#[derive(Deserialize)]
174pub struct ListSessionsQuery {
175 #[serde(rename = "type")]
177 pub session_type: Option<String>,
178}
179
180pub async fn create_session(
184 State(state): State<Arc<AppState>>,
185 Json(body): Json<CreateSessionRequest>,
186) -> Result<(axum::http::StatusCode, Json<CreateSessionResponse>), ApiError> {
187 let session_type = validate_session_name(&body.name)?;
189
190 let session_id = SessionId::new();
191 let branch_id = BranchId::from_string("main");
192
193 let config = SessionConfig {
194 name: body.name.clone(),
195 model: body.model.unwrap_or_default(),
196 params: body.params.unwrap_or_default(),
197 };
198
199 let session = Session {
200 session_id: session_id.clone(),
201 config: config.clone(),
202 created_at: EventEnvelope::now_micros(),
203 branches: vec![branch_id.clone()],
204 };
205
206 state.journal.put_session(session).await?;
207
208 let event = EventEnvelope {
210 event_id: EventId::new(),
211 session_id: session_id.clone(),
212 branch_id: branch_id.clone(),
213 run_id: None,
214 seq: 0,
215 timestamp: EventEnvelope::now_micros(),
216 parent_id: None,
217 payload: EventPayload::SessionCreated {
218 name: body.name,
219 config: serde_json::to_value(&config).unwrap_or_default(),
220 },
221 metadata: HashMap::new(),
222 schema_version: 1,
223 };
224
225 state.journal.append(event).await?;
226
227 Ok((
228 axum::http::StatusCode::CREATED,
229 Json(CreateSessionResponse {
230 session_id: session_id.to_string(),
231 branch_id: branch_id.to_string(),
232 session_type,
233 }),
234 ))
235}
236
237pub async fn list_sessions(
242 State(state): State<Arc<AppState>>,
243 Query(query): Query<ListSessionsQuery>,
244) -> Result<Json<Vec<SessionResponse>>, ApiError> {
245 let sessions = state.journal.list_sessions().await?;
246
247 let type_filter = match &query.session_type {
249 Some(t) => {
250 let st = SessionType::from_query(t).ok_or_else(|| {
251 ApiError::BadRequest(format!(
252 "unknown session type: {t}. Valid types: vault, agent, site_assets, site_content, default"
253 ))
254 })?;
255 Some(st)
256 }
257 None => None,
258 };
259
260 let responses: Vec<SessionResponse> = sessions
261 .iter()
262 .filter(|s| match type_filter {
263 Some(t) => SessionType::from_name(&s.config.name) == t,
264 None => true,
265 })
266 .map(SessionResponse::from)
267 .collect();
268
269 Ok(Json(responses))
270}
271
272pub async fn get_session(
274 State(state): State<Arc<AppState>>,
275 Path(id): Path<String>,
276) -> Result<Json<SessionResponse>, ApiError> {
277 let session_id = SessionId::from_string(id.clone());
278 let session = state
279 .journal
280 .get_session(&session_id)
281 .await?
282 .ok_or_else(|| ApiError::NotFound(format!("session not found: {id}")))?;
283 Ok(Json(SessionResponse::from(&session)))
284}
285
286pub async fn upsert_session(
295 State(state): State<Arc<AppState>>,
296 Path(id): Path<String>,
297 Json(mut session): Json<Session>,
298) -> Result<axum::http::StatusCode, ApiError> {
299 session.session_id = SessionId::from_string(id);
301 state.journal.put_session(session).await?;
302 Ok(axum::http::StatusCode::NO_CONTENT)
303}
304
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn session_type_from_name() {
        let cases = [
            ("vault:user_123", SessionType::Vault),
            ("agent:my-agent", SessionType::Agent),
            ("site-assets:public", SessionType::SiteAssets),
            ("site-content:public", SessionType::SiteContent),
            ("my-custom-session", SessionType::Default),
            ("", SessionType::Default),
        ];
        for (name, expected) in cases {
            assert_eq!(SessionType::from_name(name), expected, "name = {name:?}");
        }
    }

    #[test]
    fn session_type_from_query() {
        let cases = [
            ("vault", Some(SessionType::Vault)),
            ("agent", Some(SessionType::Agent)),
            ("site_assets", Some(SessionType::SiteAssets)),
            ("site-assets", Some(SessionType::SiteAssets)),
            ("site_content", Some(SessionType::SiteContent)),
            ("site-content", Some(SessionType::SiteContent)),
            ("default", Some(SessionType::Default)),
            ("AGENT", Some(SessionType::Agent)),
            ("unknown", None),
        ];
        for (raw, expected) in cases {
            assert_eq!(SessionType::from_query(raw), expected, "query = {raw:?}");
        }
    }

    #[test]
    fn validate_session_name_valid() {
        let cases = [
            ("agent:my-agent", SessionType::Agent),
            ("vault:user_123", SessionType::Vault),
            ("site-content:public", SessionType::SiteContent),
            ("site-assets:images", SessionType::SiteAssets),
            ("my-session", SessionType::Default),
            ("", SessionType::Default),
        ];
        for (name, expected) in cases {
            assert_eq!(
                validate_session_name(name).unwrap(),
                expected,
                "name = {name:?}"
            );
        }
    }

    #[test]
    fn validate_session_name_empty_prefix() {
        for bare_prefix in ["agent:", "vault:", "site-content:", "site-assets:"] {
            assert!(
                validate_session_name(bare_prefix).is_err(),
                "expected rejection for {bare_prefix:?}"
            );
        }
    }

    #[test]
    fn session_type_display() {
        let cases = [
            (SessionType::Vault, "vault"),
            (SessionType::Agent, "agent"),
            (SessionType::SiteAssets, "site_assets"),
            (SessionType::SiteContent, "site_content"),
            (SessionType::Default, "default"),
        ];
        for (kind, label) in cases {
            assert_eq!(kind.to_string(), label);
        }
    }

    #[test]
    fn session_type_serde_roundtrip() {
        let encoded = serde_json::to_string(&SessionType::Agent).unwrap();
        assert_eq!(encoded, r#""agent""#);

        let decoded: SessionType = serde_json::from_str(&encoded).unwrap();
        assert_eq!(decoded, SessionType::Agent);
    }

    #[test]
    fn create_session_response_includes_type() {
        let response = CreateSessionResponse {
            session_id: "s1".into(),
            branch_id: "main".into(),
            session_type: SessionType::Agent,
        };
        let value = serde_json::to_value(&response).unwrap();
        assert_eq!(value["session_type"], "agent");
    }

    #[test]
    fn session_response_from_session() {
        let config = SessionConfig {
            name: "agent:test-bot".to_string(),
            model: "mock".to_string(),
            params: HashMap::new(),
        };
        let session = Session {
            session_id: SessionId::from_string("s1".to_string()),
            config,
            created_at: 12345,
            branches: vec![BranchId::from_string("main".to_string())],
        };

        let view = SessionResponse::from(&session);
        assert_eq!(view.session_type, SessionType::Agent);
        assert_eq!(view.name, "agent:test-bot");
    }
}