Skip to main content

stakpak_api/
storage.rs

//! Session storage abstraction.
//!
//! Provides a unified interface for session and checkpoint management,
//! with implementations backed by the Stakpak API and by local SQLite storage.
5
6use async_trait::async_trait;
7use chrono::{DateTime, Utc};
8use serde::{Deserialize, Serialize};
9use stakpak_shared::models::integrations::openai::ChatMessage;
10use uuid::Uuid;
11
12// Re-export implementations
13pub use crate::local::storage::LocalStorage;
14pub use crate::stakpak::storage::StakpakStorage;
15
16// =============================================================================
17// SessionStorage Trait
18// =============================================================================
19
20/// Unified session storage trait
21///
22/// Abstracts session and checkpoint operations for both
23/// Stakpak API and local SQLite storage backends.
24#[async_trait]
25pub trait SessionStorage: Send + Sync {
26    /// Describe the concrete backend serving this storage implementation.
27    fn backend_info(&self) -> BackendInfo;
28
29    // =========================================================================
30    // Session Operations
31    // =========================================================================
32
33    /// List all sessions
34    async fn list_sessions(
35        &self,
36        query: &ListSessionsQuery,
37    ) -> Result<ListSessionsResult, StorageError>;
38
39    /// Get a session by ID (includes active checkpoint)
40    async fn get_session(&self, session_id: Uuid) -> Result<Session, StorageError>;
41
42    /// Create a new session with initial checkpoint
43    async fn create_session(
44        &self,
45        request: &CreateSessionRequest,
46    ) -> Result<CreateSessionResult, StorageError>;
47
48    /// Update session metadata (title, visibility)
49    async fn update_session(
50        &self,
51        session_id: Uuid,
52        request: &UpdateSessionRequest,
53    ) -> Result<Session, StorageError>;
54
55    /// Delete a session
56    async fn delete_session(&self, session_id: Uuid) -> Result<(), StorageError>;
57
58    // =========================================================================
59    // Checkpoint Operations
60    // =========================================================================
61
62    /// List checkpoints for a session
63    async fn list_checkpoints(
64        &self,
65        session_id: Uuid,
66        query: &ListCheckpointsQuery,
67    ) -> Result<ListCheckpointsResult, StorageError>;
68
69    /// Get a checkpoint by ID
70    async fn get_checkpoint(&self, checkpoint_id: Uuid) -> Result<Checkpoint, StorageError>;
71
72    /// Create a new checkpoint for a session
73    async fn create_checkpoint(
74        &self,
75        session_id: Uuid,
76        request: &CreateCheckpointRequest,
77    ) -> Result<Checkpoint, StorageError>;
78
79    // =========================================================================
80    // Convenience Methods (with default implementations)
81    // =========================================================================
82
83    /// Get the latest/active checkpoint for a session
84    async fn get_active_checkpoint(&self, session_id: Uuid) -> Result<Checkpoint, StorageError> {
85        let session = self.get_session(session_id).await?;
86        session
87            .active_checkpoint
88            .ok_or(StorageError::NotFound("No active checkpoint".to_string()))
89    }
90
91    /// Get session stats (optional - returns default if not supported)
92    async fn get_session_stats(&self, _session_id: Uuid) -> Result<SessionStats, StorageError> {
93        Ok(SessionStats::default())
94    }
95}
96
97/// Box wrapper for dynamic dispatch
98pub type BoxedSessionStorage = Box<dyn SessionStorage>;
99
100// =============================================================================
101// Backend Descriptor
102// =============================================================================
103
104#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
105#[serde(rename_all = "kebab-case")]
106pub enum BackendKind {
107    StakpakApi,
108    Local,
109}
110
111#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
112pub struct BackendInfo {
113    pub kind: BackendKind,
114    #[serde(skip_serializing_if = "Option::is_none")]
115    pub profile: Option<String>,
116    #[serde(skip_serializing_if = "Option::is_none")]
117    pub endpoint: Option<String>,
118    #[serde(skip_serializing_if = "Option::is_none")]
119    pub store_path: Option<String>,
120}
121
122impl BackendInfo {
123    pub fn stakpak_api(profile: Option<String>, endpoint: impl Into<String>) -> Self {
124        Self {
125            kind: BackendKind::StakpakApi,
126            profile,
127            endpoint: Some(endpoint.into()),
128            store_path: None,
129        }
130    }
131
132    pub fn local(store_path: impl Into<String>) -> Self {
133        Self {
134            kind: BackendKind::Local,
135            profile: None,
136            endpoint: None,
137            store_path: Some(store_path.into()),
138        }
139    }
140}
141
142// =============================================================================
143// Error Types
144// =============================================================================
145
/// Errors produced by storage operations.
///
/// Every variant carries a detail message. The `Display` output is
/// `"<category>: <message>"`.
#[derive(Clone, Debug, PartialEq)]
pub enum StorageError {
    /// The requested resource does not exist.
    NotFound(String),
    /// The request was malformed or not acceptable.
    InvalidRequest(String),
    /// Authentication or authorization failed.
    Unauthorized(String),
    /// A rate limit was exceeded.
    RateLimited(String),
    /// An internal storage failure; also the target of `From<String>`/`From<&str>`.
    Internal(String),
    /// The backend could not be reached.
    Connection(String),
}

impl std::fmt::Display for StorageError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NotFound(msg) => write!(f, "Not found: {}", msg),
            Self::InvalidRequest(msg) => write!(f, "Invalid request: {}", msg),
            Self::Unauthorized(msg) => write!(f, "Unauthorized: {}", msg),
            Self::RateLimited(msg) => write!(f, "Rate limited: {}", msg),
            Self::Internal(msg) => write!(f, "Internal error: {}", msg),
            Self::Connection(msg) => write!(f, "Connection error: {}", msg),
        }
    }
}

impl std::error::Error for StorageError {}

/// Plain string errors are classified as [`StorageError::Internal`].
impl From<String> for StorageError {
    fn from(message: String) -> Self {
        Self::Internal(message)
    }
}

/// Borrowed string errors are copied and classified as [`StorageError::Internal`].
impl From<&str> for StorageError {
    fn from(message: &str) -> Self {
        Self::from(message.to_owned())
    }
}
189
190// =============================================================================
191// Session Types
192// =============================================================================
193
194/// Session visibility
195#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
196#[serde(rename_all = "UPPERCASE")]
197pub enum SessionVisibility {
198    #[default]
199    Private,
200    Public,
201}
202
203impl std::fmt::Display for SessionVisibility {
204    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
205        match self {
206            SessionVisibility::Private => write!(f, "PRIVATE"),
207            SessionVisibility::Public => write!(f, "PUBLIC"),
208        }
209    }
210}
211
212/// Session status
213#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Default)]
214#[serde(rename_all = "UPPERCASE")]
215pub enum SessionStatus {
216    #[default]
217    Active,
218    Deleted,
219}
220
221impl std::fmt::Display for SessionStatus {
222    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
223        match self {
224            SessionStatus::Active => write!(f, "ACTIVE"),
225            SessionStatus::Deleted => write!(f, "DELETED"),
226        }
227    }
228}
229
230/// Full session with optional active checkpoint
231#[derive(Debug, Clone, Serialize, Deserialize)]
232pub struct Session {
233    pub id: Uuid,
234    pub title: String,
235    pub visibility: SessionVisibility,
236    pub status: SessionStatus,
237    pub cwd: Option<String>,
238    pub created_at: DateTime<Utc>,
239    pub updated_at: DateTime<Utc>,
240    pub active_checkpoint: Option<Checkpoint>,
241}
242
243/// Session summary for list responses
244#[derive(Debug, Clone, Serialize, Deserialize)]
245pub struct SessionSummary {
246    pub id: Uuid,
247    pub title: String,
248    pub visibility: SessionVisibility,
249    pub status: SessionStatus,
250    pub cwd: Option<String>,
251    pub created_at: DateTime<Utc>,
252    pub updated_at: DateTime<Utc>,
253    pub message_count: u32,
254    pub active_checkpoint_id: Option<Uuid>,
255    pub last_message_at: Option<DateTime<Utc>>,
256}
257
258// =============================================================================
259// Checkpoint Types
260// =============================================================================
261
262/// Full checkpoint with state
263#[derive(Debug, Clone, Serialize, Deserialize)]
264pub struct Checkpoint {
265    pub id: Uuid,
266    pub session_id: Uuid,
267    pub parent_id: Option<Uuid>,
268    pub state: CheckpointState,
269    pub created_at: DateTime<Utc>,
270    pub updated_at: DateTime<Utc>,
271}
272
273/// Checkpoint summary for list responses
274#[derive(Debug, Clone, Serialize, Deserialize)]
275pub struct CheckpointSummary {
276    pub id: Uuid,
277    pub session_id: Uuid,
278    pub parent_id: Option<Uuid>,
279    pub message_count: u32,
280    pub created_at: DateTime<Utc>,
281    pub updated_at: DateTime<Utc>,
282}
283
284/// Checkpoint state containing messages
285#[derive(Debug, Clone, Default, Serialize, Deserialize)]
286pub struct CheckpointState {
287    #[serde(default)]
288    pub messages: Vec<ChatMessage>,
289    /// Optional metadata for context trimming state, etc.
290    #[serde(default, skip_serializing_if = "Option::is_none")]
291    pub metadata: Option<serde_json::Value>,
292}
293
294// =============================================================================
295// Request Types
296// =============================================================================
297
298/// Request to create a session with initial checkpoint
299#[derive(Debug, Clone, Serialize)]
300pub struct CreateSessionRequest {
301    pub title: String,
302    pub visibility: SessionVisibility,
303    pub cwd: Option<String>,
304    pub initial_state: CheckpointState,
305}
306
307impl CreateSessionRequest {
308    pub fn new(title: impl Into<String>, messages: Vec<ChatMessage>) -> Self {
309        Self {
310            title: title.into(),
311            visibility: SessionVisibility::Private,
312            cwd: None,
313            initial_state: CheckpointState {
314                messages,
315                metadata: None,
316            },
317        }
318    }
319
320    pub fn with_visibility(mut self, visibility: SessionVisibility) -> Self {
321        self.visibility = visibility;
322        self
323    }
324
325    pub fn with_cwd(mut self, cwd: impl Into<String>) -> Self {
326        self.cwd = Some(cwd.into());
327        self
328    }
329}
330
331/// Request to update a session
332#[derive(Debug, Clone, Default, Serialize)]
333pub struct UpdateSessionRequest {
334    pub title: Option<String>,
335    pub visibility: Option<SessionVisibility>,
336}
337
338impl UpdateSessionRequest {
339    pub fn new() -> Self {
340        Self::default()
341    }
342
343    pub fn with_title(mut self, title: impl Into<String>) -> Self {
344        self.title = Some(title.into());
345        self
346    }
347
348    pub fn with_visibility(mut self, visibility: SessionVisibility) -> Self {
349        self.visibility = Some(visibility);
350        self
351    }
352}
353
354/// Request to create a checkpoint
355#[derive(Debug, Clone, Serialize)]
356pub struct CreateCheckpointRequest {
357    pub state: CheckpointState,
358    pub parent_id: Option<Uuid>,
359}
360
361impl CreateCheckpointRequest {
362    pub fn new(messages: Vec<ChatMessage>) -> Self {
363        Self {
364            state: CheckpointState {
365                messages,
366                metadata: None,
367            },
368            parent_id: None,
369        }
370    }
371
372    pub fn with_parent(mut self, parent_id: Uuid) -> Self {
373        self.parent_id = Some(parent_id);
374        self
375    }
376
377    pub fn with_metadata(mut self, metadata: serde_json::Value) -> Self {
378        self.state.metadata = Some(metadata);
379        self
380    }
381}
382
383/// Query parameters for listing sessions
384#[derive(Debug, Clone, Default, Serialize)]
385pub struct ListSessionsQuery {
386    pub limit: Option<u32>,
387    pub offset: Option<u32>,
388    pub search: Option<String>,
389    pub status: Option<SessionStatus>,
390    pub visibility: Option<SessionVisibility>,
391}
392
393impl ListSessionsQuery {
394    pub fn new() -> Self {
395        Self::default()
396    }
397
398    pub fn with_limit(mut self, limit: u32) -> Self {
399        self.limit = Some(limit);
400        self
401    }
402
403    pub fn with_offset(mut self, offset: u32) -> Self {
404        self.offset = Some(offset);
405        self
406    }
407
408    pub fn with_search(mut self, search: impl Into<String>) -> Self {
409        self.search = Some(search.into());
410        self
411    }
412}
413
414/// Query parameters for listing checkpoints
415#[derive(Debug, Clone, Default, Serialize)]
416pub struct ListCheckpointsQuery {
417    pub limit: Option<u32>,
418    pub offset: Option<u32>,
419    pub include_state: Option<bool>,
420}
421
422impl ListCheckpointsQuery {
423    pub fn new() -> Self {
424        Self::default()
425    }
426
427    pub fn with_limit(mut self, limit: u32) -> Self {
428        self.limit = Some(limit);
429        self
430    }
431
432    pub fn with_state(mut self) -> Self {
433        self.include_state = Some(true);
434        self
435    }
436}
437
438// =============================================================================
439// Response Types
440// =============================================================================
441
442/// Result of creating a session
443#[derive(Debug, Clone)]
444pub struct CreateSessionResult {
445    pub session_id: Uuid,
446    pub checkpoint: Checkpoint,
447}
448
449/// Result of listing sessions
450#[derive(Debug, Clone)]
451pub struct ListSessionsResult {
452    pub sessions: Vec<SessionSummary>,
453    pub total: Option<u32>,
454}
455
456/// Result of listing checkpoints
457#[derive(Debug, Clone)]
458pub struct ListCheckpointsResult {
459    pub checkpoints: Vec<CheckpointSummary>,
460    pub total: Option<u32>,
461}
462
463// =============================================================================
464// Stats Types
465// =============================================================================
466
467/// Session statistics
468#[derive(Debug, Clone, Default, Serialize, Deserialize)]
469pub struct SessionStats {
470    pub total_sessions: u32,
471    pub total_tool_calls: u32,
472    pub successful_tool_calls: u32,
473    pub failed_tool_calls: u32,
474    pub aborted_tool_calls: u32,
475    pub sessions_with_activity: u32,
476    pub total_time_saved_seconds: Option<u32>,
477    pub tools_usage: Vec<ToolUsageStats>,
478}
479
480/// Tool usage statistics
481#[derive(Debug, Clone, Serialize, Deserialize)]
482pub struct ToolUsageStats {
483    pub tool_name: String,
484    pub display_name: String,
485    pub usage_counts: ToolUsageCounts,
486    pub time_saved_per_call: Option<f64>,
487    pub time_saved_seconds: Option<u32>,
488}
489
490/// Tool usage counts
491#[derive(Debug, Clone, Serialize, Deserialize)]
492pub struct ToolUsageCounts {
493    pub total: u32,
494    pub successful: u32,
495    pub failed: u32,
496    pub aborted: u32,
497}