Skip to main content

kaizen/store/sqlite/
session_read.rs

1use super::rows::*;
2use super::*;
3
4impl Store {
5    pub fn list_sessions(&self, workspace: &str) -> Result<Vec<SessionRecord>> {
6        Ok(self
7            .list_sessions_page(workspace, 0, i64::MAX as usize, SessionFilter::default())?
8            .rows)
9    }
10
11    pub fn list_sessions_page(
12        &self,
13        workspace: &str,
14        offset: usize,
15        limit: usize,
16        filter: SessionFilter,
17    ) -> Result<SessionPage> {
18        let (where_sql, args) = session_filter_sql(workspace, &filter);
19        let total = self.query_session_page_count(&where_sql, &args)?;
20        let rows = self.query_session_page_rows(&where_sql, &args, offset, limit)?;
21        let next = offset.saturating_add(rows.len());
22        Ok(SessionPage {
23            rows,
24            total,
25            next_offset: (next < total).then_some(next),
26        })
27    }
28
29    pub(super) fn query_session_page_count(
30        &self,
31        where_sql: &str,
32        args: &[Value],
33    ) -> Result<usize> {
34        let sql = format!("SELECT COUNT(*) FROM sessions {where_sql}");
35        let total: i64 = self
36            .conn
37            .query_row(&sql, params_from_iter(args.iter()), |r| r.get(0))?;
38        Ok(total as usize)
39    }
40
41    pub(super) fn query_session_page_rows(
42        &self,
43        where_sql: &str,
44        args: &[Value],
45        offset: usize,
46        limit: usize,
47    ) -> Result<Vec<SessionRecord>> {
48        let sql = format!(
49            "{SESSION_SELECT} {where_sql} ORDER BY started_at_ms DESC, id ASC LIMIT ? OFFSET ?"
50        );
51        let mut values = args.to_vec();
52        values.push(Value::Integer(limit.min(i64::MAX as usize) as i64));
53        values.push(Value::Integer(offset.min(i64::MAX as usize) as i64));
54        let mut stmt = self.conn.prepare(&sql)?;
55        let rows = stmt.query_map(params_from_iter(values.iter()), session_row)?;
56        rows.map(|r| r.map_err(anyhow::Error::from)).collect()
57    }
58
59    pub fn list_sessions_started_after(
60        &self,
61        workspace: &str,
62        after_started_at_ms: u64,
63    ) -> Result<Vec<SessionRecord>> {
64        let mut stmt = self.conn.prepare(
65            "SELECT id, agent, model, workspace, started_at_ms, ended_at_ms, status, trace_path,
66                    start_commit, end_commit, branch, dirty_start, dirty_end, repo_binding_source,
67                    prompt_fingerprint, parent_session_id, agent_version, os, arch,
68                    repo_file_count, repo_total_loc
69             FROM sessions
70             WHERE workspace = ?1 AND started_at_ms > ?2
71             ORDER BY started_at_ms DESC, id ASC",
72        )?;
73        let rows = stmt.query_map(params![workspace, after_started_at_ms as i64], session_row)?;
74        rows.map(|r| r.map_err(anyhow::Error::from)).collect()
75    }
76
77    pub fn session_statuses(&self, ids: &[String]) -> Result<Vec<SessionStatusRow>> {
78        if ids.is_empty() {
79            return Ok(Vec::new());
80        }
81        let placeholders = ids.iter().map(|_| "?").collect::<Vec<_>>().join(",");
82        let sql =
83            format!("SELECT id, status, ended_at_ms FROM sessions WHERE id IN ({placeholders})");
84        let mut stmt = self.conn.prepare(&sql)?;
85        let params: Vec<&dyn rusqlite::ToSql> =
86            ids.iter().map(|s| s as &dyn rusqlite::ToSql).collect();
87        let rows = stmt.query_map(params.as_slice(), |r| {
88            let status: String = r.get(1)?;
89            Ok(SessionStatusRow {
90                id: r.get(0)?,
91                status: status_from_str(&status),
92                ended_at_ms: r.get::<_, Option<i64>>(2)?.map(|v| v as u64),
93            })
94        })?;
95        rows.map(|r| r.map_err(anyhow::Error::from)).collect()
96    }
97
98    pub(super) fn running_session_ids(&self) -> Result<Vec<String>> {
99        let mut stmt = self
100            .conn
101            .prepare("SELECT id FROM sessions WHERE status != 'Done' ORDER BY started_at_ms ASC")?;
102        let rows = stmt.query_map([], |r| r.get::<_, String>(0))?;
103        rows.map(|r| r.map_err(anyhow::Error::from)).collect()
104    }
105
106    pub fn get_session(&self, id: &str) -> Result<Option<SessionRecord>> {
107        let mut stmt = self.conn.prepare(
108            "SELECT id, agent, model, workspace, started_at_ms, ended_at_ms, status, trace_path,
109                    start_commit, end_commit, branch, dirty_start, dirty_end, repo_binding_source,
110                    prompt_fingerprint, parent_session_id, agent_version, os, arch,
111                    repo_file_count, repo_total_loc
112             FROM sessions WHERE id = ?1",
113        )?;
114        let mut rows = stmt.query_map(params![id], |row| {
115            Ok((
116                row.get::<_, String>(0)?,
117                row.get::<_, String>(1)?,
118                row.get::<_, Option<String>>(2)?,
119                row.get::<_, String>(3)?,
120                row.get::<_, i64>(4)?,
121                row.get::<_, Option<i64>>(5)?,
122                row.get::<_, String>(6)?,
123                row.get::<_, String>(7)?,
124                row.get::<_, Option<String>>(8)?,
125                row.get::<_, Option<String>>(9)?,
126                row.get::<_, Option<String>>(10)?,
127                row.get::<_, Option<i64>>(11)?,
128                row.get::<_, Option<i64>>(12)?,
129                row.get::<_, String>(13)?,
130                row.get::<_, Option<String>>(14)?,
131                row.get::<_, Option<String>>(15)?,
132                row.get::<_, Option<String>>(16)?,
133                row.get::<_, Option<String>>(17)?,
134                row.get::<_, Option<String>>(18)?,
135                row.get::<_, Option<i64>>(19)?,
136                row.get::<_, Option<i64>>(20)?,
137            ))
138        })?;
139
140        if let Some(row) = rows.next() {
141            let (
142                id,
143                agent,
144                model,
145                workspace,
146                started,
147                ended,
148                status_str,
149                trace,
150                start_commit,
151                end_commit,
152                branch,
153                dirty_start,
154                dirty_end,
155                source,
156                prompt_fingerprint,
157                parent_session_id,
158                agent_version,
159                os,
160                arch,
161                repo_file_count,
162                repo_total_loc,
163            ) = row?;
164            Ok(Some(SessionRecord {
165                id,
166                agent,
167                model,
168                workspace,
169                started_at_ms: started as u64,
170                ended_at_ms: ended.map(|v| v as u64),
171                status: status_from_str(&status_str),
172                trace_path: trace,
173                start_commit,
174                end_commit,
175                branch,
176                dirty_start: dirty_start.map(i64_to_bool),
177                dirty_end: dirty_end.map(i64_to_bool),
178                repo_binding_source: empty_to_none(source),
179                prompt_fingerprint,
180                parent_session_id,
181                agent_version,
182                os,
183                arch,
184                repo_file_count: repo_file_count.map(|v| v as u32),
185                repo_total_loc: repo_total_loc.map(|v| v as u64),
186            }))
187        } else {
188            Ok(None)
189        }
190    }
191}