Skip to main content

opensession_local_db/
lib.rs

1pub mod git;
2
3use anyhow::{Context, Result};
4use opensession_api_types::SessionSummary;
5use opensession_core::trace::Session;
6use rusqlite::{params, Connection, OptionalExtension};
7use std::path::PathBuf;
8use std::sync::Mutex;
9
10use git::GitContext;
11
/// A session row stored in the local SQLite database (`local_sessions` table).
///
/// Rows originate either from a session file parsed on this machine
/// (`LocalDb::upsert_local_session`) or from a summary pulled from the server
/// (`LocalDb::upsert_remote_session`); which columns are populated depends on
/// that origin.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct LocalSessionRow {
    /// Session identifier (primary key).
    pub id: String,
    /// Path of the session file on disk; only set for locally parsed sessions.
    pub source_path: Option<String>,
    /// One of `local_only`, `remote_only` or `synced`.
    pub sync_status: String,
    /// When the session was marked synced. Written either as SQLite
    /// `datetime('now')` text or as RFC 3339 (state.json migration).
    pub last_synced_at: Option<String>,
    /// Uploader's user id; populated from remote summaries.
    pub user_id: Option<String>,
    /// Uploader's nickname; populated from remote summaries.
    pub nickname: Option<String>,
    /// Owning team; populated from remote summaries.
    pub team_id: Option<String>,
    /// Tool that produced the session.
    pub tool: String,
    /// Agent provider, if known.
    pub agent_provider: Option<String>,
    /// Agent model, if known.
    pub agent_model: Option<String>,
    /// Session title.
    pub title: Option<String>,
    /// Session description.
    pub description: Option<String>,
    /// Tags; comma-joined for locally parsed sessions.
    pub tags: Option<String>,
    /// Creation timestamp (RFC 3339 for locally parsed sessions).
    pub created_at: String,
    /// Server upload timestamp; populated from remote summaries.
    pub uploaded_at: Option<String>,
    /// Number of messages in the session.
    pub message_count: i64,
    /// Number of tasks in the session.
    pub task_count: i64,
    /// Number of events in the session.
    pub event_count: i64,
    /// Session duration in seconds.
    pub duration_seconds: i64,
    /// Total input tokens consumed.
    pub total_input_tokens: i64,
    /// Total output tokens produced.
    pub total_output_tokens: i64,
    /// Git remote captured when the session was ingested locally.
    pub git_remote: Option<String>,
    /// Git branch captured when the session was ingested locally.
    pub git_branch: Option<String>,
    /// Git commit captured when the session was ingested locally.
    pub git_commit: Option<String>,
    /// Repository name captured when the session was ingested locally.
    pub git_repo_name: Option<String>,
    /// Associated pull-request number (not written by the upserts in this module).
    pub pr_number: Option<i64>,
    /// Associated pull-request URL (not written by the upserts in this module).
    pub pr_url: Option<String>,
    /// Working directory taken from the session's `cwd` / `working_directory` attribute.
    pub working_directory: Option<String>,
}
44
/// Filter for `LocalDb::list_sessions`.
///
/// Every field is optional; `None` means "do not filter on this". Set fields
/// are combined with `AND`.
#[derive(Debug, Default, Clone, PartialEq, Eq)]
pub struct LocalSessionFilter {
    /// Exact match on `team_id`.
    pub team_id: Option<String>,
    /// Exact match on `sync_status` (`local_only`, `remote_only`, `synced`).
    pub sync_status: Option<String>,
    /// Exact match on `git_repo_name`.
    pub git_repo_name: Option<String>,
    /// Substring match against title, description, or tags (via SQL `LIKE`).
    pub search: Option<String>,
    /// Exact match on `tool`.
    pub tool: Option<String>,
}
54
55/// Local SQLite database shared by TUI and Daemon.
56/// Thread-safe: wraps the connection in a Mutex so it can be shared via `Arc<LocalDb>`.
57pub struct LocalDb {
58    conn: Mutex<Connection>,
59}
60
61impl LocalDb {
62    /// Open (or create) the local database at the default path.
63    /// `~/.local/share/opensession/local.db`
64    pub fn open() -> Result<Self> {
65        let path = default_db_path()?;
66        Self::open_path(&path)
67    }
68
69    /// Open (or create) the local database at a specific path.
70    pub fn open_path(path: &PathBuf) -> Result<Self> {
71        if let Some(parent) = path.parent() {
72            std::fs::create_dir_all(parent)
73                .with_context(|| format!("create dir for {}", path.display()))?;
74        }
75        let conn = Connection::open(path).with_context(|| format!("open db {}", path.display()))?;
76        conn.execute_batch("PRAGMA journal_mode=WAL;")?;
77        conn.execute_batch("PRAGMA foreign_keys=ON;")?;
78        conn.execute_batch(opensession_api_types::db::LOCAL_SCHEMA)?;
79        Ok(Self {
80            conn: Mutex::new(conn),
81        })
82    }
83
84    fn conn(&self) -> std::sync::MutexGuard<'_, Connection> {
85        self.conn.lock().expect("local db mutex poisoned")
86    }
87
88    // ── Upsert local session (parsed from file) ────────────────────────
89
90    pub fn upsert_local_session(
91        &self,
92        session: &Session,
93        source_path: &str,
94        git: &GitContext,
95    ) -> Result<()> {
96        let title = session.context.title.as_deref();
97        let description = session.context.description.as_deref();
98        let tags = if session.context.tags.is_empty() {
99            None
100        } else {
101            Some(session.context.tags.join(","))
102        };
103        let created_at = session.context.created_at.to_rfc3339();
104        let cwd = session
105            .context
106            .attributes
107            .get("cwd")
108            .or_else(|| session.context.attributes.get("working_directory"))
109            .and_then(|v| v.as_str().map(String::from));
110
111        self.conn().execute(
112            "INSERT INTO local_sessions \
113             (id, source_path, sync_status, tool, agent_provider, agent_model, \
114              title, description, tags, created_at, \
115              message_count, task_count, event_count, duration_seconds, \
116              total_input_tokens, total_output_tokens, \
117              git_remote, git_branch, git_commit, git_repo_name, working_directory) \
118             VALUES (?1,?2,'local_only',?3,?4,?5,?6,?7,?8,?9,?10,?11,?12,?13,?14,?15,?16,?17,?18,?19,?20) \
119             ON CONFLICT(id) DO UPDATE SET \
120              source_path=excluded.source_path, \
121              tool=excluded.tool, agent_provider=excluded.agent_provider, \
122              agent_model=excluded.agent_model, \
123              title=excluded.title, description=excluded.description, \
124              tags=excluded.tags, \
125              message_count=excluded.message_count, task_count=excluded.task_count, \
126              event_count=excluded.event_count, duration_seconds=excluded.duration_seconds, \
127              total_input_tokens=excluded.total_input_tokens, \
128              total_output_tokens=excluded.total_output_tokens, \
129              git_remote=excluded.git_remote, git_branch=excluded.git_branch, \
130              git_commit=excluded.git_commit, git_repo_name=excluded.git_repo_name, \
131              working_directory=excluded.working_directory",
132            params![
133                &session.session_id,
134                source_path,
135                &session.agent.tool,
136                &session.agent.provider,
137                &session.agent.model,
138                title,
139                description,
140                &tags,
141                &created_at,
142                session.stats.message_count as i64,
143                session.stats.task_count as i64,
144                session.stats.event_count as i64,
145                session.stats.duration_seconds as i64,
146                session.stats.total_input_tokens as i64,
147                session.stats.total_output_tokens as i64,
148                &git.remote,
149                &git.branch,
150                &git.commit,
151                &git.repo_name,
152                &cwd,
153            ],
154        )?;
155        Ok(())
156    }
157
158    // ── Upsert remote session (from server sync pull) ──────────────────
159
160    pub fn upsert_remote_session(&self, summary: &SessionSummary) -> Result<()> {
161        self.conn().execute(
162            "INSERT INTO local_sessions \
163             (id, sync_status, user_id, nickname, team_id, tool, \
164              agent_provider, agent_model, title, description, tags, \
165              created_at, uploaded_at, \
166              message_count, task_count, event_count, duration_seconds, \
167              total_input_tokens, total_output_tokens) \
168             VALUES (?1,'remote_only',?2,?3,?4,?5,?6,?7,?8,?9,?10,?11,?12,?13,?14,?15,?16,?17,?18) \
169             ON CONFLICT(id) DO UPDATE SET \
170              nickname=excluded.nickname, \
171              title=excluded.title, description=excluded.description, \
172              tags=excluded.tags, uploaded_at=excluded.uploaded_at, \
173              message_count=excluded.message_count, task_count=excluded.task_count, \
174              event_count=excluded.event_count, duration_seconds=excluded.duration_seconds, \
175              total_input_tokens=excluded.total_input_tokens, \
176              total_output_tokens=excluded.total_output_tokens \
177              WHERE sync_status = 'remote_only'",
178            params![
179                &summary.id,
180                &summary.user_id,
181                &summary.nickname,
182                &summary.team_id,
183                &summary.tool,
184                &summary.agent_provider,
185                &summary.agent_model,
186                &summary.title,
187                &summary.description,
188                &summary.tags,
189                &summary.created_at,
190                &summary.uploaded_at,
191                summary.message_count,
192                summary.task_count,
193                summary.event_count,
194                summary.duration_seconds,
195                summary.total_input_tokens,
196                summary.total_output_tokens,
197            ],
198        )?;
199        Ok(())
200    }
201
202    // ── List sessions ──────────────────────────────────────────────────
203
204    pub fn list_sessions(&self, filter: &LocalSessionFilter) -> Result<Vec<LocalSessionRow>> {
205        let mut where_clauses = vec!["1=1".to_string()];
206        let mut param_values: Vec<Box<dyn rusqlite::types::ToSql>> = Vec::new();
207        let mut idx = 1u32;
208
209        if let Some(ref team_id) = filter.team_id {
210            where_clauses.push(format!("team_id = ?{idx}"));
211            param_values.push(Box::new(team_id.clone()));
212            idx += 1;
213        }
214
215        if let Some(ref sync_status) = filter.sync_status {
216            where_clauses.push(format!("sync_status = ?{idx}"));
217            param_values.push(Box::new(sync_status.clone()));
218            idx += 1;
219        }
220
221        if let Some(ref repo) = filter.git_repo_name {
222            where_clauses.push(format!("git_repo_name = ?{idx}"));
223            param_values.push(Box::new(repo.clone()));
224            idx += 1;
225        }
226
227        if let Some(ref tool) = filter.tool {
228            where_clauses.push(format!("tool = ?{idx}"));
229            param_values.push(Box::new(tool.clone()));
230            idx += 1;
231        }
232
233        if let Some(ref search) = filter.search {
234            let like = format!("%{search}%");
235            where_clauses.push(format!(
236                "(title LIKE ?{i1} OR description LIKE ?{i2} OR tags LIKE ?{i3})",
237                i1 = idx,
238                i2 = idx + 1,
239                i3 = idx + 2,
240            ));
241            param_values.push(Box::new(like.clone()));
242            param_values.push(Box::new(like.clone()));
243            param_values.push(Box::new(like));
244            // idx += 3; // not needed after last use
245        }
246
247        let where_str = where_clauses.join(" AND ");
248        let sql = format!(
249            "SELECT id, source_path, sync_status, last_synced_at, \
250                    user_id, nickname, team_id, tool, agent_provider, agent_model, \
251                    title, description, tags, created_at, uploaded_at, \
252                    message_count, task_count, event_count, duration_seconds, \
253                    total_input_tokens, total_output_tokens, \
254                    git_remote, git_branch, git_commit, git_repo_name, \
255                    pr_number, pr_url, working_directory \
256             FROM local_sessions WHERE {where_str} \
257             ORDER BY created_at DESC"
258        );
259
260        let param_refs: Vec<&dyn rusqlite::types::ToSql> =
261            param_values.iter().map(|p| p.as_ref()).collect();
262        let conn = self.conn();
263        let mut stmt = conn.prepare(&sql)?;
264        let rows = stmt.query_map(param_refs.as_slice(), row_to_local_session)?;
265
266        let mut result = Vec::new();
267        for row in rows {
268            result.push(row?);
269        }
270        Ok(result)
271    }
272
273    // ── Sync cursor ────────────────────────────────────────────────────
274
275    pub fn get_sync_cursor(&self, team_id: &str) -> Result<Option<String>> {
276        let cursor = self
277            .conn()
278            .query_row(
279                "SELECT cursor FROM sync_cursors WHERE team_id = ?1",
280                params![team_id],
281                |row| row.get(0),
282            )
283            .optional()?;
284        Ok(cursor)
285    }
286
287    pub fn set_sync_cursor(&self, team_id: &str, cursor: &str) -> Result<()> {
288        self.conn().execute(
289            "INSERT INTO sync_cursors (team_id, cursor, updated_at) \
290             VALUES (?1, ?2, datetime('now')) \
291             ON CONFLICT(team_id) DO UPDATE SET cursor=excluded.cursor, updated_at=datetime('now')",
292            params![team_id, cursor],
293        )?;
294        Ok(())
295    }
296
297    // ── Upload tracking ────────────────────────────────────────────────
298
299    /// Get sessions that are local_only and need to be uploaded.
300    pub fn pending_uploads(&self, team_id: &str) -> Result<Vec<LocalSessionRow>> {
301        let sql = "SELECT id, source_path, sync_status, last_synced_at, \
302                          user_id, nickname, team_id, tool, agent_provider, agent_model, \
303                          title, description, tags, created_at, uploaded_at, \
304                          message_count, task_count, event_count, duration_seconds, \
305                          total_input_tokens, total_output_tokens, \
306                          git_remote, git_branch, git_commit, git_repo_name, \
307                          pr_number, pr_url, working_directory \
308                   FROM local_sessions WHERE sync_status = 'local_only' AND team_id = ?1 \
309                   ORDER BY created_at ASC";
310        let conn = self.conn();
311        let mut stmt = conn.prepare(sql)?;
312        let rows = stmt.query_map(params![team_id], row_to_local_session)?;
313        let mut result = Vec::new();
314        for row in rows {
315            result.push(row?);
316        }
317        Ok(result)
318    }
319
320    pub fn mark_synced(&self, session_id: &str) -> Result<()> {
321        self.conn().execute(
322            "UPDATE local_sessions SET sync_status = 'synced', last_synced_at = datetime('now') \
323             WHERE id = ?1",
324            params![session_id],
325        )?;
326        Ok(())
327    }
328
329    /// Check if a session was already uploaded (synced or remote_only) since the given modification time.
330    pub fn was_uploaded_after(
331        &self,
332        source_path: &str,
333        modified: &chrono::DateTime<chrono::Utc>,
334    ) -> Result<bool> {
335        let result: Option<String> = self
336            .conn()
337            .query_row(
338                "SELECT last_synced_at FROM local_sessions \
339                 WHERE source_path = ?1 AND sync_status = 'synced' AND last_synced_at IS NOT NULL",
340                params![source_path],
341                |row| row.get(0),
342            )
343            .optional()?;
344
345        if let Some(synced_at) = result {
346            if let Ok(dt) = chrono::DateTime::parse_from_rfc3339(&synced_at) {
347                return Ok(dt >= *modified);
348            }
349        }
350        Ok(false)
351    }
352
353    // ── Body cache ─────────────────────────────────────────────────────
354
355    pub fn cache_body(&self, session_id: &str, body: &[u8]) -> Result<()> {
356        self.conn().execute(
357            "INSERT INTO body_cache (session_id, body, cached_at) \
358             VALUES (?1, ?2, datetime('now')) \
359             ON CONFLICT(session_id) DO UPDATE SET body=excluded.body, cached_at=datetime('now')",
360            params![session_id, body],
361        )?;
362        Ok(())
363    }
364
365    pub fn get_cached_body(&self, session_id: &str) -> Result<Option<Vec<u8>>> {
366        let body = self
367            .conn()
368            .query_row(
369                "SELECT body FROM body_cache WHERE session_id = ?1",
370                params![session_id],
371                |row| row.get(0),
372            )
373            .optional()?;
374        Ok(body)
375    }
376
377    // ── Migration helper ───────────────────────────────────────────────
378
379    /// Migrate entries from the old state.json UploadState into the local DB.
380    /// Marks them as `synced` with no metadata (we only know the file path was uploaded).
381    pub fn migrate_from_state_json(
382        &self,
383        uploaded: &std::collections::HashMap<String, chrono::DateTime<chrono::Utc>>,
384    ) -> Result<usize> {
385        let mut count = 0;
386        for (path, uploaded_at) in uploaded {
387            let exists: bool = self
388                .conn()
389                .query_row(
390                    "SELECT COUNT(*) > 0 FROM local_sessions WHERE source_path = ?1",
391                    params![path],
392                    |row| row.get(0),
393                )
394                .unwrap_or(false);
395
396            if exists {
397                self.conn().execute(
398                    "UPDATE local_sessions SET sync_status = 'synced', last_synced_at = ?1 \
399                     WHERE source_path = ?2 AND sync_status = 'local_only'",
400                    params![uploaded_at.to_rfc3339(), path],
401                )?;
402                count += 1;
403            }
404        }
405        Ok(count)
406    }
407
408    /// Get a list of distinct git repo names present in the DB.
409    pub fn list_repos(&self) -> Result<Vec<String>> {
410        let conn = self.conn();
411        let mut stmt = conn.prepare(
412            "SELECT DISTINCT git_repo_name FROM local_sessions \
413             WHERE git_repo_name IS NOT NULL ORDER BY git_repo_name ASC",
414        )?;
415        let rows = stmt.query_map([], |row| row.get(0))?;
416        let mut result = Vec::new();
417        for row in rows {
418            result.push(row?);
419        }
420        Ok(result)
421    }
422}
423
424fn row_to_local_session(row: &rusqlite::Row) -> rusqlite::Result<LocalSessionRow> {
425    Ok(LocalSessionRow {
426        id: row.get(0)?,
427        source_path: row.get(1)?,
428        sync_status: row.get(2)?,
429        last_synced_at: row.get(3)?,
430        user_id: row.get(4)?,
431        nickname: row.get(5)?,
432        team_id: row.get(6)?,
433        tool: row.get(7)?,
434        agent_provider: row.get(8)?,
435        agent_model: row.get(9)?,
436        title: row.get(10)?,
437        description: row.get(11)?,
438        tags: row.get(12)?,
439        created_at: row.get(13)?,
440        uploaded_at: row.get(14)?,
441        message_count: row.get(15)?,
442        task_count: row.get(16)?,
443        event_count: row.get(17)?,
444        duration_seconds: row.get(18)?,
445        total_input_tokens: row.get(19)?,
446        total_output_tokens: row.get(20)?,
447        git_remote: row.get(21)?,
448        git_branch: row.get(22)?,
449        git_commit: row.get(23)?,
450        git_repo_name: row.get(24)?,
451        pr_number: row.get(25)?,
452        pr_url: row.get(26)?,
453        working_directory: row.get(27)?,
454    })
455}
456
457fn default_db_path() -> Result<PathBuf> {
458    let home = std::env::var("HOME")
459        .or_else(|_| std::env::var("USERPROFILE"))
460        .context("Could not determine home directory")?;
461    Ok(PathBuf::from(home)
462        .join(".local")
463        .join("share")
464        .join("opensession")
465        .join("local.db"))
466}
467
#[cfg(test)]
mod tests {
    use super::*;

    /// Create a fresh database inside its own temporary directory.
    ///
    /// The `TempDir` guard is returned with the database so the directory is
    /// deleted when the test ends instead of being leaked on disk.
    fn test_db() -> (tempfile::TempDir, LocalDb) {
        let dir = tempfile::tempdir().unwrap();
        let db = LocalDb::open_path(&dir.path().join("test.db")).unwrap();
        (dir, db)
    }

    /// A minimal remote summary for `id` in `team_id` with the given title.
    fn summary(id: &str, team_id: &str, title: &str) -> SessionSummary {
        SessionSummary {
            id: id.to_string(),
            user_id: None,
            nickname: None,
            team_id: team_id.to_string(),
            tool: "claude-code".to_string(),
            agent_provider: None,
            agent_model: None,
            title: Some(title.to_string()),
            description: None,
            tags: None,
            created_at: "2024-01-01T00:00:00Z".to_string(),
            uploaded_at: "2024-01-01T01:00:00Z".to_string(),
            message_count: 5,
            task_count: 0,
            event_count: 10,
            duration_seconds: 60,
            total_input_tokens: 100,
            total_output_tokens: 50,
        }
    }

    #[test]
    fn test_open_and_schema() {
        let (_dir, _db) = test_db();
    }

    #[test]
    fn test_sync_cursor() {
        let (_dir, db) = test_db();
        assert_eq!(db.get_sync_cursor("team1").unwrap(), None);
        db.set_sync_cursor("team1", "2024-01-01T00:00:00Z").unwrap();
        assert_eq!(
            db.get_sync_cursor("team1").unwrap(),
            Some("2024-01-01T00:00:00Z".to_string())
        );
        // Update
        db.set_sync_cursor("team1", "2024-06-01T00:00:00Z").unwrap();
        assert_eq!(
            db.get_sync_cursor("team1").unwrap(),
            Some("2024-06-01T00:00:00Z".to_string())
        );
    }

    #[test]
    fn test_body_cache() {
        let (_dir, db) = test_db();
        assert_eq!(db.get_cached_body("s1").unwrap(), None);
        db.cache_body("s1", b"hello world").unwrap();
        assert_eq!(
            db.get_cached_body("s1").unwrap(),
            Some(b"hello world".to_vec())
        );
    }

    #[test]
    fn test_upsert_remote_session() {
        let (_dir, db) = test_db();
        let remote = SessionSummary {
            user_id: Some("u1".to_string()),
            nickname: Some("alice".to_string()),
            ..summary("remote-1", "t1", "Test session")
        };
        db.upsert_remote_session(&remote).unwrap();

        let sessions = db.list_sessions(&LocalSessionFilter::default()).unwrap();
        assert_eq!(sessions.len(), 1);
        assert_eq!(sessions[0].id, "remote-1");
        assert_eq!(sessions[0].sync_status, "remote_only");
        assert_eq!(sessions[0].nickname, Some("alice".to_string()));
    }

    #[test]
    fn test_list_filter_by_team() {
        let (_dir, db) = test_db();
        db.upsert_remote_session(&summary("s1", "t1", "Session 1"))
            .unwrap();

        let filter = LocalSessionFilter {
            team_id: Some("t1".to_string()),
            ..Default::default()
        };
        assert_eq!(db.list_sessions(&filter).unwrap().len(), 1);

        let filter = LocalSessionFilter {
            team_id: Some("t999".to_string()),
            ..Default::default()
        };
        assert_eq!(db.list_sessions(&filter).unwrap().len(), 0);
    }

    #[test]
    fn test_list_filter_by_repo() {
        let (_dir, db) = test_db();
        // Remote summaries carry no git context, so no repo filter can match.
        db.upsert_remote_session(&summary("s1", "t1", "Session 1"))
            .unwrap();

        let filter = LocalSessionFilter {
            git_repo_name: Some("some-repo".to_string()),
            ..Default::default()
        };
        assert_eq!(db.list_sessions(&filter).unwrap().len(), 0);
        assert!(db.list_repos().unwrap().is_empty());
    }

    #[test]
    fn test_list_search() {
        let (_dir, db) = test_db();
        db.upsert_remote_session(&summary("s1", "t1", "Session 1"))
            .unwrap();

        let hit = LocalSessionFilter {
            search: Some("Session".to_string()),
            ..Default::default()
        };
        assert_eq!(db.list_sessions(&hit).unwrap().len(), 1);

        let miss = LocalSessionFilter {
            search: Some("missing".to_string()),
            ..Default::default()
        };
        assert_eq!(db.list_sessions(&miss).unwrap().len(), 0);
    }

    #[test]
    fn test_mark_synced_and_status_filter() {
        let (_dir, db) = test_db();
        db.upsert_remote_session(&summary("s1", "t1", "Session 1"))
            .unwrap();
        db.mark_synced("s1").unwrap();

        let synced = LocalSessionFilter {
            sync_status: Some("synced".to_string()),
            ..Default::default()
        };
        let rows = db.list_sessions(&synced).unwrap();
        assert_eq!(rows.len(), 1);
        assert!(rows[0].last_synced_at.is_some());

        let remote_only = LocalSessionFilter {
            sync_status: Some("remote_only".to_string()),
            ..Default::default()
        };
        assert_eq!(db.list_sessions(&remote_only).unwrap().len(), 0);
    }

    #[test]
    fn test_remote_upsert_does_not_overwrite_synced_row() {
        let (_dir, db) = test_db();
        db.upsert_remote_session(&summary("s1", "t1", "Original"))
            .unwrap();
        db.mark_synced("s1").unwrap();
        db.upsert_remote_session(&summary("s1", "t1", "Changed"))
            .unwrap();

        let rows = db.list_sessions(&LocalSessionFilter::default()).unwrap();
        assert_eq!(rows.len(), 1);
        assert_eq!(rows[0].sync_status, "synced");
        assert_eq!(rows[0].title, Some("Original".to_string()));
    }
}