1use async_trait::async_trait;
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use stakpak_shared::models::integrations::openai::ChatMessage;
10use uuid::Uuid;
11
12pub use crate::local::storage::LocalStorage;
14pub use crate::stakpak::storage::StakpakStorage;
15
16#[async_trait]
25pub trait SessionStorage: Send + Sync {
26 fn backend_info(&self) -> BackendInfo;
28
29 async fn list_sessions(
35 &self,
36 query: &ListSessionsQuery,
37 ) -> Result<ListSessionsResult, StorageError>;
38
39 async fn get_session(&self, session_id: Uuid) -> Result<Session, StorageError>;
41
42 async fn create_session(
44 &self,
45 request: &CreateSessionRequest,
46 ) -> Result<CreateSessionResult, StorageError>;
47
48 async fn update_session(
50 &self,
51 session_id: Uuid,
52 request: &UpdateSessionRequest,
53 ) -> Result<Session, StorageError>;
54
55 async fn delete_session(&self, session_id: Uuid) -> Result<(), StorageError>;
57
58 async fn list_checkpoints(
64 &self,
65 session_id: Uuid,
66 query: &ListCheckpointsQuery,
67 ) -> Result<ListCheckpointsResult, StorageError>;
68
69 async fn get_checkpoint(&self, checkpoint_id: Uuid) -> Result<Checkpoint, StorageError>;
71
72 async fn create_checkpoint(
74 &self,
75 session_id: Uuid,
76 request: &CreateCheckpointRequest,
77 ) -> Result<Checkpoint, StorageError>;
78
79 async fn get_active_checkpoint(&self, session_id: Uuid) -> Result<Checkpoint, StorageError> {
85 let session = self.get_session(session_id).await?;
86 session
87 .active_checkpoint
88 .ok_or(StorageError::NotFound("No active checkpoint".to_string()))
89 }
90
91 async fn get_session_stats(&self, _session_id: Uuid) -> Result<SessionStats, StorageError> {
93 Ok(SessionStats::default())
94 }
95}
96
97pub type BoxedSessionStorage = Box<dyn SessionStorage>;
99
100#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
105#[serde(rename_all = "kebab-case")]
106pub enum BackendKind {
107 StakpakApi,
108 Local,
109}
110
111#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
112pub struct BackendInfo {
113 pub kind: BackendKind,
114 #[serde(skip_serializing_if = "Option::is_none")]
115 pub profile: Option<String>,
116 #[serde(skip_serializing_if = "Option::is_none")]
117 pub endpoint: Option<String>,
118 #[serde(skip_serializing_if = "Option::is_none")]
119 pub store_path: Option<String>,
120}
121
122impl BackendInfo {
123 pub fn stakpak_api(profile: Option<String>, endpoint: impl Into<String>) -> Self {
124 Self {
125 kind: BackendKind::StakpakApi,
126 profile,
127 endpoint: Some(endpoint.into()),
128 store_path: None,
129 }
130 }
131
132 pub fn local(store_path: impl Into<String>) -> Self {
133 Self {
134 kind: BackendKind::Local,
135 profile: None,
136 endpoint: None,
137 store_path: Some(store_path.into()),
138 }
139 }
140}
141
/// Error type returned by every `SessionStorage` operation.
///
/// Each variant carries a human-readable detail message, which `Display`
/// prefixes with a short category label.
#[derive(Clone, Debug, PartialEq)]
pub enum StorageError {
    /// The requested item could not be found.
    NotFound(String),
    /// The request was rejected as invalid.
    InvalidRequest(String),
    /// The caller is not authorized for the operation.
    Unauthorized(String),
    /// The request was rate limited.
    RateLimited(String),
    /// Catch-all internal failure. Also produced by `From<String>` and `From<&str>`.
    Internal(String),
    /// A connection-level failure.
    Connection(String),
}

impl std::fmt::Display for StorageError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NotFound(detail) => write!(f, "Not found: {}", detail),
            Self::InvalidRequest(detail) => write!(f, "Invalid request: {}", detail),
            Self::Unauthorized(detail) => write!(f, "Unauthorized: {}", detail),
            Self::RateLimited(detail) => write!(f, "Rate limited: {}", detail),
            Self::Internal(detail) => write!(f, "Internal error: {}", detail),
            Self::Connection(detail) => write!(f, "Connection error: {}", detail),
        }
    }
}

impl std::error::Error for StorageError {}

impl From<String> for StorageError {
    /// Wraps an arbitrary message as [`StorageError::Internal`].
    fn from(message: String) -> Self {
        Self::Internal(message)
    }
}

impl From<&str> for StorageError {
    /// Copies the message and wraps it as [`StorageError::Internal`].
    fn from(message: &str) -> Self {
        Self::from(message.to_owned())
    }
}
189
190#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
196#[serde(rename_all = "UPPERCASE")]
197pub enum SessionVisibility {
198 #[default]
199 Private,
200 Public,
201}
202
203impl std::fmt::Display for SessionVisibility {
204 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205 match self {
206 SessionVisibility::Private => write!(f, "PRIVATE"),
207 SessionVisibility::Public => write!(f, "PUBLIC"),
208 }
209 }
210}
211
212#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
214#[serde(rename_all = "UPPERCASE")]
215pub enum SessionStatus {
216 #[default]
217 Active,
218 Deleted,
219}
220
221impl std::fmt::Display for SessionStatus {
222 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
223 match self {
224 SessionStatus::Active => write!(f, "ACTIVE"),
225 SessionStatus::Deleted => write!(f, "DELETED"),
226 }
227 }
228}
229
230#[derive(Debug, Clone, Serialize, Deserialize)]
232pub struct Session {
233 pub id: Uuid,
234 pub title: String,
235 pub visibility: SessionVisibility,
236 pub status: SessionStatus,
237 pub cwd: Option<String>,
238 pub created_at: DateTime<Utc>,
239 pub updated_at: DateTime<Utc>,
240 pub active_checkpoint: Option<Checkpoint>,
241}
242
243#[derive(Debug, Clone, Serialize, Deserialize)]
245pub struct SessionSummary {
246 pub id: Uuid,
247 pub title: String,
248 pub visibility: SessionVisibility,
249 pub status: SessionStatus,
250 pub cwd: Option<String>,
251 pub created_at: DateTime<Utc>,
252 pub updated_at: DateTime<Utc>,
253 pub message_count: u32,
254 pub active_checkpoint_id: Option<Uuid>,
255 pub last_message_at: Option<DateTime<Utc>>,
256}
257
258#[derive(Debug, Clone, Serialize, Deserialize)]
264pub struct Checkpoint {
265 pub id: Uuid,
266 pub session_id: Uuid,
267 pub parent_id: Option<Uuid>,
268 pub state: CheckpointState,
269 pub created_at: DateTime<Utc>,
270 pub updated_at: DateTime<Utc>,
271}
272
273#[derive(Debug, Clone, Serialize, Deserialize)]
275pub struct CheckpointSummary {
276 pub id: Uuid,
277 pub session_id: Uuid,
278 pub parent_id: Option<Uuid>,
279 pub message_count: u32,
280 pub created_at: DateTime<Utc>,
281 pub updated_at: DateTime<Utc>,
282}
283
284#[derive(Debug, Clone, Default, Serialize, Deserialize)]
286pub struct CheckpointState {
287 #[serde(default)]
288 pub messages: Vec<ChatMessage>,
289 #[serde(default, skip_serializing_if = "Option::is_none")]
291 pub metadata: Option<serde_json::Value>,
292}
293
294#[derive(Debug, Clone, Serialize)]
300pub struct CreateSessionRequest {
301 pub title: String,
302 pub visibility: SessionVisibility,
303 pub cwd: Option<String>,
304 pub initial_state: CheckpointState,
305}
306
307impl CreateSessionRequest {
308 pub fn new(title: impl Into<String>, messages: Vec<ChatMessage>) -> Self {
309 Self {
310 title: title.into(),
311 visibility: SessionVisibility::Private,
312 cwd: None,
313 initial_state: CheckpointState {
314 messages,
315 metadata: None,
316 },
317 }
318 }
319
320 pub fn with_visibility(mut self, visibility: SessionVisibility) -> Self {
321 self.visibility = visibility;
322 self
323 }
324
325 pub fn with_cwd(mut self, cwd: impl Into<String>) -> Self {
326 self.cwd = Some(cwd.into());
327 self
328 }
329}
330
331#[derive(Debug, Clone, Default, Serialize)]
333pub struct UpdateSessionRequest {
334 pub title: Option<String>,
335 pub visibility: Option<SessionVisibility>,
336}
337
338impl UpdateSessionRequest {
339 pub fn new() -> Self {
340 Self::default()
341 }
342
343 pub fn with_title(mut self, title: impl Into<String>) -> Self {
344 self.title = Some(title.into());
345 self
346 }
347
348 pub fn with_visibility(mut self, visibility: SessionVisibility) -> Self {
349 self.visibility = Some(visibility);
350 self
351 }
352}
353
354#[derive(Debug, Clone, Serialize)]
356pub struct CreateCheckpointRequest {
357 pub state: CheckpointState,
358 pub parent_id: Option<Uuid>,
359}
360
361impl CreateCheckpointRequest {
362 pub fn new(messages: Vec<ChatMessage>) -> Self {
363 Self {
364 state: CheckpointState {
365 messages,
366 metadata: None,
367 },
368 parent_id: None,
369 }
370 }
371
372 pub fn with_parent(mut self, parent_id: Uuid) -> Self {
373 self.parent_id = Some(parent_id);
374 self
375 }
376
377 pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
378 self.state.metadata = Some(metadata);
379 self
380 }
381}
382
383#[derive(Debug, Clone, Default, Serialize)]
385pub struct ListSessionsQuery {
386 pub limit: Option<u32>,
387 pub offset: Option<u32>,
388 pub search: Option<String>,
389 pub status: Option<SessionStatus>,
390 pub visibility: Option<SessionVisibility>,
391}
392
393impl ListSessionsQuery {
394 pub fn new() -> Self {
395 Self::default()
396 }
397
398 pub fn with_limit(mut self, limit: u32) -> Self {
399 self.limit = Some(limit);
400 self
401 }
402
403 pub fn with_offset(mut self, offset: u32) -> Self {
404 self.offset = Some(offset);
405 self
406 }
407
408 pub fn with_search(mut self, search: impl Into<String>) -> Self {
409 self.search = Some(search.into());
410 self
411 }
412}
413
414#[derive(Debug, Clone, Default, Serialize)]
416pub struct ListCheckpointsQuery {
417 pub limit: Option<u32>,
418 pub offset: Option<u32>,
419 pub include_state: Option<bool>,
420}
421
422impl ListCheckpointsQuery {
423 pub fn new() -> Self {
424 Self::default()
425 }
426
427 pub fn with_limit(mut self, limit: u32) -> Self {
428 self.limit = Some(limit);
429 self
430 }
431
432 pub fn with_state(mut self) -> Self {
433 self.include_state = Some(true);
434 self
435 }
436}
437
438#[derive(Debug, Clone)]
444pub struct CreateSessionResult {
445 pub session_id: Uuid,
446 pub checkpoint: Checkpoint,
447}
448
449#[derive(Debug, Clone)]
451pub struct ListSessionsResult {
452 pub sessions: Vec<SessionSummary>,
453 pub total: Option<u32>,
454}
455
456#[derive(Debug, Clone)]
458pub struct ListCheckpointsResult {
459 pub checkpoints: Vec<CheckpointSummary>,
460 pub total: Option<u32>,
461}
462
463#[derive(Debug, Clone, Default, Serialize, Deserialize)]
469pub struct SessionStats {
470 pub total_sessions: u32,
471 pub total_tool_calls: u32,
472 pub successful_tool_calls: u32,
473 pub failed_tool_calls: u32,
474 pub aborted_tool_calls: u32,
475 pub sessions_with_activity: u32,
476 pub total_time_saved_seconds: Option<u32>,
477 pub tools_usage: Vec<ToolUsageStats>,
478}
479
480#[derive(Debug, Clone, Serialize, Deserialize)]
482pub struct ToolUsageStats {
483 pub tool_name: String,
484 pub display_name: String,
485 pub usage_counts: ToolUsageCounts,
486 pub time_saved_per_call: Option<f64>,
487 pub time_saved_seconds: Option<u32>,
488}
489
490#[derive(Debug, Clone, Serialize, Deserialize)]
492pub struct ToolUsageCounts {
493 pub total: u32,
494 pub successful: u32,
495 pub failed: u32,
496 pub aborted: u32,
497}