mi6_storage_sqlite/
storage.rs

1//! Storage trait implementation for SqliteStorage.
2//!
3//! This module implements the `Storage` trait from mi6-core, providing
4//! the main database operations:
5//! - Event insertion with session updates
6//! - Event querying with filters
7//! - Garbage collection
8//! - Statistics computation
9//! - Transcript position tracking
10
11use std::path::Path;
12use std::time::Duration;
13
14use chrono::Utc;
15use mi6_core::{
16    Event, EventQuery, FilePosition, GitBranchInfo, Order, Session, SessionQuery, Storage,
17    StorageError, StorageStats, StorageStatsQuery,
18};
19use rusqlite::{Connection, params};
20
21use crate::query_builder::QueryBuilder;
22use crate::row_parsing::row_to_event;
23use crate::session;
24use crate::sql;
25
26/// Implement Storage trait for SqliteStorage.
27impl Storage for super::SqliteStorage {
28    fn insert(&self, event: &Event) -> Result<i64, StorageError> {
29        // Use a transaction to ensure atomicity of event insertion and session updates.
30        // If either operation fails, the entire transaction is rolled back, preventing
31        // the database from being left in an inconsistent state where an event exists
32        // but session aggregates are not updated.
33        //
34        // SAFETY: unchecked_transaction() is safe here because rusqlite::Connection is
35        // not Sync, so concurrent access from multiple threads is already prevented.
36        // The "unchecked" aspect refers to not preventing nested transactions at compile
37        // time; SQLite handles nested transaction attempts gracefully at runtime.
38        let tx = self
39            .conn
40            .unchecked_transaction()
41            .map_err(|e| StorageError::Query(Box::new(e)))?;
42
43        // Insert the event
44        tx.execute(
45            sql::INSERT_EVENT,
46            params![
47                event.machine_id,
48                event.timestamp.timestamp_millis(),
49                event.event_type.to_string(),
50                event.session_id,
51                event.framework,
52                event.tool_use_id,
53                event.spawned_agent_id,
54                event.tool_name,
55                event.subagent_type,
56                event.permission_mode,
57                event.transcript_path,
58                event.pid,
59                event.cwd,
60                event.git_branch,
61                event.model,
62                event.tokens_input,
63                event.tokens_output,
64                event.tokens_cache_read,
65                event.tokens_cache_write,
66                event.cost_usd,
67                event.duration_ms,
68                event.payload,
69                event.metadata,
70                event.source,
71                event.is_sidechain,
72            ],
73        )
74        .map_err(|e| StorageError::Query(Box::new(e)))?;
75
76        let event_id = tx.last_insert_rowid();
77
78        // Update the sessions table based on event type
79        session::update_for_event(&tx, event)?;
80
81        // Commit the transaction - on error, tx is dropped and changes are rolled back
82        tx.commit().map_err(|e| StorageError::Query(Box::new(e)))?;
83
84        Ok(event_id)
85    }
86
87    fn query(&self, query: &EventQuery) -> Result<Vec<Event>, StorageError> {
88        let mut qb = QueryBuilder::new(
89            "SELECT id, machine_id, timestamp, event_type, session_id, framework, tool_use_id, spawned_agent_id, tool_name, subagent_type, permission_mode, transcript_path, pid, cwd, git_branch, model, tokens_input, tokens_output, tokens_cache_read, tokens_cache_write, cost_usd, duration_ms, payload, metadata, source, is_sidechain FROM events",
90        );
91
92        // Apply session filter (single or multiple)
93        if let Some(ref session_ids) = query.session_ids {
94            // Empty session_ids means no results (handled by and_in returning false)
95            if !qb.and_in("session_id", session_ids) {
96                return Ok(vec![]);
97            }
98        } else if let Some(ref session_id) = query.session_id {
99            qb.and_eq("session_id", session_id.clone());
100        }
101        if let Some(ref event_type) = query.event_type {
102            qb.and_eq_upper("event_type", event_type.clone());
103        }
104        if let Some(ref permission_mode) = query.permission_mode {
105            qb.and_eq("permission_mode", permission_mode.clone());
106        }
107        if let Some(ref framework) = query.framework {
108            qb.and_eq("framework", framework.clone());
109        }
110        if let Some(after_ts) = query.after_ts {
111            qb.and_gt("timestamp", after_ts.timestamp_millis());
112        }
113        if let Some(before_ts) = query.before_ts {
114            qb.and_lt("timestamp", before_ts.timestamp_millis());
115        }
116        if let Some(after_id) = query.after_id {
117            qb.and_gt("id", after_id);
118        }
119
120        // Filter for API requests only or exclude them
121        if query.api_requests_only {
122            qb.and_eq_upper("event_type", "ApiRequest".to_string());
123        } else if query.exclude_api_requests {
124            qb.and_neq_upper("event_type", "ApiRequest".to_string());
125        }
126
127        // Use EventQuery helper methods for ordering
128        let orders_by_id = query.orders_by_id();
129        let direction = query.effective_direction();
130        let order_clause = match (orders_by_id, direction) {
131            (true, Order::Asc) => "id ASC",
132            (true, Order::Desc) => "id DESC",
133            (false, Order::Asc) => "timestamp ASC",
134            (false, Order::Desc) => "timestamp DESC",
135        };
136        qb.order_by(order_clause);
137
138        // Apply limit if specified
139        if let Some(limit) = query.limit {
140            qb.limit(limit);
141        }
142
143        let (sql, params) = qb.build();
144
145        let mut stmt = self
146            .conn
147            .prepare(&sql)
148            .map_err(|e| StorageError::Query(Box::new(e)))?;
149
150        let events = stmt
151            .query_map(params.as_slice(), row_to_event)
152            .map_err(|e| StorageError::Query(Box::new(e)))?
153            .collect::<Result<Vec<_>, _>>()
154            .map_err(|e| StorageError::Query(Box::new(e)))?;
155
156        Ok(events)
157    }
158
159    fn gc(&self, retention: Duration) -> Result<usize, StorageError> {
160        let chrono_retention =
161            chrono::Duration::from_std(retention).map_err(|e| StorageError::Query(Box::new(e)))?;
162        let cutoff = (Utc::now() - chrono_retention).timestamp_millis();
163
164        // Delete events older than retention period (includes API requests now)
165        let deleted = self
166            .conn
167            .execute("DELETE FROM events WHERE timestamp < ?1", [cutoff])
168            .map_err(|e| StorageError::Query(Box::new(e)))?;
169
170        Ok(deleted)
171    }
172
173    fn count_expired(&self, retention: Duration) -> Result<usize, StorageError> {
174        let chrono_retention =
175            chrono::Duration::from_std(retention).map_err(|e| StorageError::Query(Box::new(e)))?;
176        let cutoff = (Utc::now() - chrono_retention).timestamp_millis();
177
178        let count: usize = self
179            .conn
180            .query_row(
181                "SELECT COUNT(*) FROM events WHERE timestamp < ?1",
182                [cutoff],
183                |row| row.get(0),
184            )
185            .map_err(|e| StorageError::Query(Box::new(e)))?;
186
187        Ok(count)
188    }
189
190    fn count(&self) -> Result<usize, StorageError> {
191        let count: usize = self
192            .conn
193            .query_row("SELECT COUNT(*) FROM events", [], |row| row.get(0))
194            .map_err(|e| StorageError::Query(Box::new(e)))?;
195
196        Ok(count)
197    }
198
199    fn list_sessions(&self, query: &SessionQuery) -> Result<Vec<Session>, StorageError> {
200        session::list(&self.conn, query)
201    }
202
203    fn get_session(&self, session_id: &str) -> Result<Option<Session>, StorageError> {
204        session::get(&self.conn, session_id)
205    }
206
207    fn get_session_by_key(
208        &self,
209        machine_id: &str,
210        session_id: &str,
211    ) -> Result<Option<Session>, StorageError> {
212        session::get_by_key(&self.conn, machine_id, session_id)
213    }
214
215    fn get_session_by_pid(&self, pid: i32) -> Result<Option<Session>, StorageError> {
216        session::get_by_pid(&self.conn, pid)
217    }
218
219    fn update_session_git_info(
220        &self,
221        session_id: &str,
222        git_info: &GitBranchInfo,
223    ) -> Result<bool, StorageError> {
224        session::update_git_info(&self.conn, session_id, git_info)
225    }
226
227    fn update_session_github_repo(
228        &self,
229        session_id: &str,
230        github_repo: &str,
231    ) -> Result<bool, StorageError> {
232        session::update_github_repo(&self.conn, session_id, github_repo)
233    }
234
235    fn update_session_transcript_path(
236        &self,
237        machine_id: &str,
238        session_id: &str,
239        transcript_path: &str,
240    ) -> Result<bool, StorageError> {
241        session::update_transcript_path(&self.conn, machine_id, session_id, transcript_path)
242    }
243
244    fn upsert_session_git_context(
245        &self,
246        session_id: &str,
247        machine_id: &str,
248        framework: &str,
249        timestamp: i64,
250        local_git_dir: Option<&str>,
251        github_repo: Option<&str>,
252    ) -> Result<(), StorageError> {
253        session::upsert_git_context(
254            &self.conn,
255            session_id,
256            machine_id,
257            framework,
258            timestamp,
259            local_git_dir,
260            github_repo,
261        )
262    }
263
264    fn storage_stats(&self, query: &StorageStatsQuery) -> Result<StorageStats, StorageError> {
265        storage_stats(&self.conn, query)
266    }
267
268    fn get_transcript_position(
269        &self,
270        path: &std::path::Path,
271    ) -> Result<Option<FilePosition>, StorageError> {
272        get_transcript_position(&self.conn, path)
273    }
274
275    fn set_transcript_position(
276        &self,
277        path: &std::path::Path,
278        position: &FilePosition,
279    ) -> Result<(), StorageError> {
280        set_transcript_position(&self.conn, path, position)
281    }
282
283    fn event_exists_by_uuid(&self, session_id: &str, uuid: &str) -> Result<bool, StorageError> {
284        event_exists_by_uuid(&self.conn, session_id, uuid)
285    }
286
287    fn query_transcript_positions(&self) -> Result<Vec<(String, FilePosition)>, StorageError> {
288        query_transcript_positions(&self.conn)
289    }
290}
291
292/// Compute storage statistics across all sessions.
293pub(crate) fn storage_stats(
294    conn: &Connection,
295    query: &StorageStatsQuery,
296) -> Result<StorageStats, StorageError> {
297    // v14 schema uses unified columns directly
298    let mut qb = QueryBuilder::new(
299        r"SELECT
300            COUNT(*) as session_count,
301            COUNT(CASE WHEN last_ended_at IS NULL THEN 1 END) as active_count,
302            COALESCE(SUM(tokens_input + tokens_output + tokens_cache_read + tokens_cache_write), 0) as total_tokens,
303            COALESCE(SUM(cost_usd), 0.0) as total_cost,
304            COALESCE(SUM(api_request_count), 0) as total_requests
305        FROM sessions",
306    );
307
308    // Apply filters
309    if query.active_only {
310        qb.and_is_null("last_ended_at");
311    }
312    if let Some(ref framework) = query.framework {
313        qb.and_eq("framework", framework.clone());
314    }
315
316    let (sql, params) = qb.build();
317
318    conn.query_row(&sql, params.as_slice(), |row| {
319        Ok(StorageStats {
320            session_count: row.get::<_, i64>(0)? as u32,
321            active_session_count: row.get::<_, i64>(1)? as u32,
322            total_tokens: row.get(2)?,
323            total_cost_usd: row.get(3)?,
324            total_api_requests: row.get::<_, i64>(4)? as u32,
325        })
326    })
327    .map_err(|e| StorageError::Query(Box::new(e)))
328}
329
330/// Get the last scanned position for a transcript file.
331///
332/// Returns `None` if the file has never been scanned.
333pub(crate) fn get_transcript_position(
334    conn: &Connection,
335    path: &Path,
336) -> Result<Option<FilePosition>, StorageError> {
337    let path_str = path.to_string_lossy();
338
339    let result = conn.query_row(
340        "SELECT byte_offset, line_number, last_uuid FROM transcript_positions WHERE file_path = ?1",
341        [&path_str],
342        |row| {
343            Ok(FilePosition {
344                offset: row.get::<_, i64>(0)? as u64,
345                line_number: row.get::<_, i64>(1)? as u64,
346                last_uuid: row.get(2)?,
347            })
348        },
349    );
350
351    match result {
352        Ok(pos) => Ok(Some(pos)),
353        Err(rusqlite::Error::QueryReturnedNoRows) => Ok(None),
354        Err(e) => Err(StorageError::Query(Box::new(e))),
355    }
356}
357
358/// Set the scanned position for a transcript file.
359///
360/// Uses UPSERT to insert or update the position.
361pub(crate) fn set_transcript_position(
362    conn: &Connection,
363    path: &Path,
364    position: &FilePosition,
365) -> Result<(), StorageError> {
366    let path_str = path.to_string_lossy();
367    let now = std::time::SystemTime::now()
368        .duration_since(std::time::UNIX_EPOCH)
369        .map(|d| d.as_millis() as i64)
370        .unwrap_or(0);
371
372    conn.execute(
373        r"INSERT INTO transcript_positions (file_path, byte_offset, line_number, last_uuid, updated_at)
374          VALUES (?1, ?2, ?3, ?4, ?5)
375          ON CONFLICT (file_path) DO UPDATE SET
376            byte_offset = excluded.byte_offset,
377            line_number = excluded.line_number,
378            last_uuid = excluded.last_uuid,
379            updated_at = excluded.updated_at",
380        params![
381            path_str,
382            position.offset as i64,
383            position.line_number as i64,
384            position.last_uuid,
385            now,
386        ],
387    )
388    .map_err(|e| StorageError::Query(Box::new(e)))?;
389
390    Ok(())
391}
392
393/// Check if an event with the given UUID exists for a session.
394///
395/// Used for deduplication when parsing transcripts.
396pub(crate) fn event_exists_by_uuid(
397    conn: &Connection,
398    session_id: &str,
399    uuid: &str,
400) -> Result<bool, StorageError> {
401    // UUIDs are stored in metadata JSON as {"uuid": "..."}
402    let pattern = format!("%\"uuid\":\"{}\"%", uuid);
403
404    let count: i64 = conn
405        .query_row(
406            "SELECT COUNT(*) FROM events WHERE session_id = ?1 AND metadata LIKE ?2",
407            params![session_id, pattern],
408            |row| row.get(0),
409        )
410        .map_err(|e| StorageError::Query(Box::new(e)))?;
411
412    Ok(count > 0)
413}
414
415/// Query all transcript file positions.
416///
417/// Returns a list of (file_path, position) pairs.
418pub(crate) fn query_transcript_positions(
419    conn: &Connection,
420) -> Result<Vec<(String, FilePosition)>, StorageError> {
421    let mut stmt = conn
422        .prepare(
423            "SELECT file_path, byte_offset, line_number, last_uuid FROM transcript_positions ORDER BY file_path",
424        )
425        .map_err(|e| StorageError::Query(Box::new(e)))?;
426
427    let positions = stmt
428        .query_map([], |row| {
429            let path: String = row.get(0)?;
430            let position = FilePosition {
431                offset: row.get::<_, i64>(1)? as u64,
432                line_number: row.get::<_, i64>(2)? as u64,
433                last_uuid: row.get(3)?,
434            };
435            Ok((path, position))
436        })
437        .map_err(|e| StorageError::Query(Box::new(e)))?
438        .collect::<Result<Vec<_>, _>>()
439        .map_err(|e| StorageError::Query(Box::new(e)))?;
440
441    Ok(positions)
442}