Skip to main content

stakpak_api/
storage.rs

1//! Session Storage abstraction
2//!
3//! Provides a unified interface for session and checkpoint management
4//! with implementations for both Stakpak API and local SQLite storage.
5
6use async_trait::async_trait;
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use stakpak_shared::models::integrations::openai::ChatMessage;
10use uuid::Uuid;
11
12// Re-export implementations
13pub use crate::local::storage::LocalStorage;
14pub use crate::stakpak::storage::StakpakStorage;
15
16// =============================================================================
17// SessionStorage Trait
18// =============================================================================
19
20/// Unified session storage trait
21///
22/// Abstracts session and checkpoint operations for both
23/// Stakpak API and local SQLite storage backends.
24#[async_trait]
25pub trait SessionStorage: Send + Sync {
26    // =========================================================================
27    // Session Operations
28    // =========================================================================
29
30    /// List all sessions
31    async fn list_sessions(
32        &self,
33        query: &ListSessionsQuery,
34    ) -> Result<ListSessionsResult, StorageError>;
35
36    /// Get a session by ID (includes active checkpoint)
37    async fn get_session(&self, session_id: Uuid) -> Result<Session, StorageError>;
38
39    /// Create a new session with initial checkpoint
40    async fn create_session(
41        &self,
42        request: &CreateSessionRequest,
43    ) -> Result<CreateSessionResult, StorageError>;
44
45    /// Update session metadata (title, visibility)
46    async fn update_session(
47        &self,
48        session_id: Uuid,
49        request: &UpdateSessionRequest,
50    ) -> Result<Session, StorageError>;
51
52    /// Delete a session
53    async fn delete_session(&self, session_id: Uuid) -> Result<(), StorageError>;
54
55    // =========================================================================
56    // Checkpoint Operations
57    // =========================================================================
58
59    /// List checkpoints for a session
60    async fn list_checkpoints(
61        &self,
62        session_id: Uuid,
63        query: &ListCheckpointsQuery,
64    ) -> Result<ListCheckpointsResult, StorageError>;
65
66    /// Get a checkpoint by ID
67    async fn get_checkpoint(&self, checkpoint_id: Uuid) -> Result<Checkpoint, StorageError>;
68
69    /// Create a new checkpoint for a session
70    async fn create_checkpoint(
71        &self,
72        session_id: Uuid,
73        request: &CreateCheckpointRequest,
74    ) -> Result<Checkpoint, StorageError>;
75
76    // =========================================================================
77    // Convenience Methods (with default implementations)
78    // =========================================================================
79
80    /// Get the latest/active checkpoint for a session
81    async fn get_active_checkpoint(&self, session_id: Uuid) -> Result<Checkpoint, StorageError> {
82        let session = self.get_session(session_id).await?;
83        session
84            .active_checkpoint
85            .ok_or(StorageError::NotFound("No active checkpoint".to_string()))
86    }
87
88    /// Get session stats (optional - returns default if not supported)
89    async fn get_session_stats(&self, _session_id: Uuid) -> Result<SessionStats, StorageError> {
90        Ok(SessionStats::default())
91    }
92}
93
94/// Box wrapper for dynamic dispatch
95pub type BoxedSessionStorage = Box<dyn SessionStorage>;
96
97// =============================================================================
98// Error Types
99// =============================================================================
100
/// Storage operation errors.
///
/// Every variant carries a human-readable message. The `Display` output is a
/// fixed per-variant prefix, then `": "`, then that message.
///
/// All payloads are `String`, so equality is total and `Eq` is derived
/// alongside `PartialEq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StorageError {
    /// Resource not found
    NotFound(String),
    /// Invalid request
    InvalidRequest(String),
    /// Authentication/authorization error
    Unauthorized(String),
    /// Rate limit exceeded
    RateLimited(String),
    /// Internal storage error
    Internal(String),
    /// Connection error
    Connection(String),
}

impl std::fmt::Display for StorageError {
    /// Writes `"<prefix>: <message>"` for the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Resolve the variant-specific prefix first so the output shape is
        // written in exactly one place.
        let (prefix, message) = match self {
            StorageError::NotFound(message) => ("Not found", message),
            StorageError::InvalidRequest(message) => ("Invalid request", message),
            StorageError::Unauthorized(message) => ("Unauthorized", message),
            StorageError::RateLimited(message) => ("Rate limited", message),
            StorageError::Internal(message) => ("Internal error", message),
            StorageError::Connection(message) => ("Connection error", message),
        };
        write!(f, "{}: {}", prefix, message)
    }
}

impl std::error::Error for StorageError {}

impl From<String> for StorageError {
    /// Wraps an owned message as [`StorageError::Internal`].
    fn from(s: String) -> Self {
        StorageError::Internal(s)
    }
}

impl From<&str> for StorageError {
    /// Copies a borrowed message into [`StorageError::Internal`].
    fn from(s: &str) -> Self {
        StorageError::Internal(s.to_string())
    }
}
144
145// =============================================================================
146// Session Types
147// =============================================================================
148
149/// Session visibility
150#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
151#[serde(rename_all = "UPPERCASE")]
152pub enum SessionVisibility {
153    #[default]
154    Private,
155    Public,
156}
157
158impl std::fmt::Display for SessionVisibility {
159    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
160        match self {
161            SessionVisibility::Private => write!(f, "PRIVATE"),
162            SessionVisibility::Public => write!(f, "PUBLIC"),
163        }
164    }
165}
166
167/// Session status
168#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
169#[serde(rename_all = "UPPERCASE")]
170pub enum SessionStatus {
171    #[default]
172    Active,
173    Deleted,
174}
175
176impl std::fmt::Display for SessionStatus {
177    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
178        match self {
179            SessionStatus::Active => write!(f, "ACTIVE"),
180            SessionStatus::Deleted => write!(f, "DELETED"),
181        }
182    }
183}
184
185/// Full session with optional active checkpoint
186#[derive(Debug, Clone, Serialize, Deserialize)]
187pub struct Session {
188    pub id: Uuid,
189    pub title: String,
190    pub visibility: SessionVisibility,
191    pub status: SessionStatus,
192    pub cwd: Option<String>,
193    pub created_at: DateTime<Utc>,
194    pub updated_at: DateTime<Utc>,
195    pub active_checkpoint: Option<Checkpoint>,
196}
197
198/// Session summary for list responses
199#[derive(Debug, Clone, Serialize, Deserialize)]
200pub struct SessionSummary {
201    pub id: Uuid,
202    pub title: String,
203    pub visibility: SessionVisibility,
204    pub status: SessionStatus,
205    pub cwd: Option<String>,
206    pub created_at: DateTime<Utc>,
207    pub updated_at: DateTime<Utc>,
208    pub message_count: u32,
209    pub active_checkpoint_id: Option<Uuid>,
210    pub last_message_at: Option<DateTime<Utc>>,
211}
212
213// =============================================================================
214// Checkpoint Types
215// =============================================================================
216
217/// Full checkpoint with state
218#[derive(Debug, Clone, Serialize, Deserialize)]
219pub struct Checkpoint {
220    pub id: Uuid,
221    pub session_id: Uuid,
222    pub parent_id: Option<Uuid>,
223    pub state: CheckpointState,
224    pub created_at: DateTime<Utc>,
225    pub updated_at: DateTime<Utc>,
226}
227
228/// Checkpoint summary for list responses
229#[derive(Debug, Clone, Serialize, Deserialize)]
230pub struct CheckpointSummary {
231    pub id: Uuid,
232    pub session_id: Uuid,
233    pub parent_id: Option<Uuid>,
234    pub message_count: u32,
235    pub created_at: DateTime<Utc>,
236    pub updated_at: DateTime<Utc>,
237}
238
239/// Checkpoint state containing messages
240#[derive(Debug, Clone, Default, Serialize, Deserialize)]
241pub struct CheckpointState {
242    #[serde(default)]
243    pub messages: Vec<ChatMessage>,
244    /// Optional metadata for context trimming state, etc.
245    #[serde(default, skip_serializing_if = "Option::is_none")]
246    pub metadata: Option<serde_json::Value>,
247}
248
249// =============================================================================
250// Request Types
251// =============================================================================
252
253/// Request to create a session with initial checkpoint
254#[derive(Debug, Clone, Serialize)]
255pub struct CreateSessionRequest {
256    pub title: String,
257    pub visibility: SessionVisibility,
258    pub cwd: Option<String>,
259    pub initial_state: CheckpointState,
260}
261
262impl CreateSessionRequest {
263    pub fn new(title: impl Into<String>, messages: Vec<ChatMessage>) -> Self {
264        Self {
265            title: title.into(),
266            visibility: SessionVisibility::Private,
267            cwd: None,
268            initial_state: CheckpointState {
269                messages,
270                metadata: None,
271            },
272        }
273    }
274
275    pub fn with_visibility(mut self, visibility: SessionVisibility) -> Self {
276        self.visibility = visibility;
277        self
278    }
279
280    pub fn with_cwd(mut self, cwd: impl Into<String>) -> Self {
281        self.cwd = Some(cwd.into());
282        self
283    }
284}
285
286/// Request to update a session
287#[derive(Debug, Clone, Default, Serialize)]
288pub struct UpdateSessionRequest {
289    pub title: Option<String>,
290    pub visibility: Option<SessionVisibility>,
291}
292
293impl UpdateSessionRequest {
294    pub fn new() -> Self {
295        Self::default()
296    }
297
298    pub fn with_title(mut self, title: impl Into<String>) -> Self {
299        self.title = Some(title.into());
300        self
301    }
302
303    pub fn with_visibility(mut self, visibility: SessionVisibility) -> Self {
304        self.visibility = Some(visibility);
305        self
306    }
307}
308
309/// Request to create a checkpoint
310#[derive(Debug, Clone, Serialize)]
311pub struct CreateCheckpointRequest {
312    pub state: CheckpointState,
313    pub parent_id: Option<Uuid>,
314}
315
316impl CreateCheckpointRequest {
317    pub fn new(messages: Vec<ChatMessage>) -> Self {
318        Self {
319            state: CheckpointState {
320                messages,
321                metadata: None,
322            },
323            parent_id: None,
324        }
325    }
326
327    pub fn with_parent(mut self, parent_id: Uuid) -> Self {
328        self.parent_id = Some(parent_id);
329        self
330    }
331
332    pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
333        self.state.metadata = Some(metadata);
334        self
335    }
336}
337
338/// Query parameters for listing sessions
339#[derive(Debug, Clone, Default, Serialize)]
340pub struct ListSessionsQuery {
341    pub limit: Option<u32>,
342    pub offset: Option<u32>,
343    pub search: Option<String>,
344    pub status: Option<SessionStatus>,
345    pub visibility: Option<SessionVisibility>,
346}
347
348impl ListSessionsQuery {
349    pub fn new() -> Self {
350        Self::default()
351    }
352
353    pub fn with_limit(mut self, limit: u32) -> Self {
354        self.limit = Some(limit);
355        self
356    }
357
358    pub fn with_offset(mut self, offset: u32) -> Self {
359        self.offset = Some(offset);
360        self
361    }
362
363    pub fn with_search(mut self, search: impl Into<String>) -> Self {
364        self.search = Some(search.into());
365        self
366    }
367}
368
369/// Query parameters for listing checkpoints
370#[derive(Debug, Clone, Default, Serialize)]
371pub struct ListCheckpointsQuery {
372    pub limit: Option<u32>,
373    pub offset: Option<u32>,
374    pub include_state: Option<bool>,
375}
376
377impl ListCheckpointsQuery {
378    pub fn new() -> Self {
379        Self::default()
380    }
381
382    pub fn with_limit(mut self, limit: u32) -> Self {
383        self.limit = Some(limit);
384        self
385    }
386
387    pub fn with_state(mut self) -> Self {
388        self.include_state = Some(true);
389        self
390    }
391}
392
393// =============================================================================
394// Response Types
395// =============================================================================
396
397/// Result of creating a session
398#[derive(Debug, Clone)]
399pub struct CreateSessionResult {
400    pub session_id: Uuid,
401    pub checkpoint: Checkpoint,
402}
403
404/// Result of listing sessions
405#[derive(Debug, Clone)]
406pub struct ListSessionsResult {
407    pub sessions: Vec<SessionSummary>,
408    pub total: Option<u32>,
409}
410
411/// Result of listing checkpoints
412#[derive(Debug, Clone)]
413pub struct ListCheckpointsResult {
414    pub checkpoints: Vec<CheckpointSummary>,
415    pub total: Option<u32>,
416}
417
418// =============================================================================
419// Stats Types
420// =============================================================================
421
422/// Session statistics
423#[derive(Debug, Clone, Default, Serialize, Deserialize)]
424pub struct SessionStats {
425    pub total_sessions: u32,
426    pub total_tool_calls: u32,
427    pub successful_tool_calls: u32,
428    pub failed_tool_calls: u32,
429    pub aborted_tool_calls: u32,
430    pub sessions_with_activity: u32,
431    pub total_time_saved_seconds: Option<u32>,
432    pub tools_usage: Vec<ToolUsageStats>,
433}
434
435/// Tool usage statistics
436#[derive(Debug, Clone, Serialize, Deserialize)]
437pub struct ToolUsageStats {
438    pub tool_name: String,
439    pub display_name: String,
440    pub usage_counts: ToolUsageCounts,
441    pub time_saved_per_call: Option<f64>,
442    pub time_saved_seconds: Option<u32>,
443}
444
445/// Tool usage counts
446#[derive(Debug, Clone, Serialize, Deserialize)]
447pub struct ToolUsageCounts {
448    pub total: u32,
449    pub successful: u32,
450    pub failed: u32,
451    pub aborted: u32,
452}