lago_api/routes/
sessions.rs1use std::collections::HashMap;
2use std::sync::Arc;
3
4use axum::Json;
5use axum::extract::{Path, State};
6use serde::{Deserialize, Serialize};
7
8use lago_core::event::{EventEnvelope, EventPayload};
9use lago_core::id::{BranchId, EventId, SessionId};
10use lago_core::session::{Session, SessionConfig};
11
12use crate::error::ApiError;
13use crate::state::AppState;
14
15#[derive(Deserialize, Serialize)]
18pub struct CreateSessionRequest {
19 pub name: String,
20 #[serde(default)]
21 pub model: Option<String>,
22 #[serde(default)]
23 pub params: Option<HashMap<String, String>>,
24}
25
26#[derive(Serialize, Deserialize)]
27pub struct CreateSessionResponse {
28 pub session_id: String,
29 pub branch_id: String,
30}
31
32#[derive(Serialize, Deserialize)]
33pub struct SessionResponse {
34 pub session_id: String,
35 pub name: String,
36 pub model: String,
37 pub created_at: u64,
38 pub branches: Vec<String>,
39}
40
41impl From<&Session> for SessionResponse {
42 fn from(s: &Session) -> Self {
43 Self {
44 session_id: s.session_id.to_string(),
45 name: s.config.name.clone(),
46 model: s.config.model.clone(),
47 created_at: s.created_at,
48 branches: s.branches.iter().map(|b| b.to_string()).collect(),
49 }
50 }
51}
52
53pub async fn create_session(
57 State(state): State<Arc<AppState>>,
58 Json(body): Json<CreateSessionRequest>,
59) -> Result<(axum::http::StatusCode, Json<CreateSessionResponse>), ApiError> {
60 let session_id = SessionId::new();
61 let branch_id = BranchId::from_string("main");
62
63 let config = SessionConfig {
64 name: body.name.clone(),
65 model: body.model.unwrap_or_default(),
66 params: body.params.unwrap_or_default(),
67 };
68
69 let session = Session {
70 session_id: session_id.clone(),
71 config: config.clone(),
72 created_at: EventEnvelope::now_micros(),
73 branches: vec![branch_id.clone()],
74 };
75
76 state.journal.put_session(session).await?;
77
78 let event = EventEnvelope {
80 event_id: EventId::new(),
81 session_id: session_id.clone(),
82 branch_id: branch_id.clone(),
83 run_id: None,
84 seq: 0,
85 timestamp: EventEnvelope::now_micros(),
86 parent_id: None,
87 payload: EventPayload::SessionCreated {
88 name: body.name,
89 config: serde_json::to_value(&config).unwrap_or_default(),
90 },
91 metadata: HashMap::new(),
92 };
93
94 state.journal.append(event).await?;
95
96 Ok((
97 axum::http::StatusCode::CREATED,
98 Json(CreateSessionResponse {
99 session_id: session_id.to_string(),
100 branch_id: branch_id.to_string(),
101 }),
102 ))
103}
104
105pub async fn list_sessions(
107 State(state): State<Arc<AppState>>,
108) -> Result<Json<Vec<SessionResponse>>, ApiError> {
109 let sessions = state.journal.list_sessions().await?;
110 let responses: Vec<SessionResponse> = sessions.iter().map(SessionResponse::from).collect();
111 Ok(Json(responses))
112}
113
114pub async fn get_session(
116 State(state): State<Arc<AppState>>,
117 Path(id): Path<String>,
118) -> Result<Json<SessionResponse>, ApiError> {
119 let session_id = SessionId::from_string(id.clone());
120 let session = state
121 .journal
122 .get_session(&session_id)
123 .await?
124 .ok_or_else(|| ApiError::NotFound(format!("session not found: {id}")))?;
125 Ok(Json(SessionResponse::from(&session)))
126}