Skip to main content

oven_cli/db/
graph.rs

1use anyhow::{Context, Result};
2use rusqlite::{Connection, params};
3
/// State of a node in the dependency graph (stored as text in `SQLite`).
///
/// Each variant is persisted as its `Display` text (`pending`, `in_flight`,
/// `awaiting_merge`, `merged`, `failed`) and parsed back with `FromStr`.
/// `Merged` and `Failed` are treated as terminal by `get_active_session`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NodeState {
    Pending,
    InFlight,
    AwaitingMerge,
    /// Terminal: a session whose nodes are all `Merged`/`Failed` is not reported as active.
    Merged,
    /// Terminal: a session whose nodes are all `Merged`/`Failed` is not reported as active.
    Failed,
}
13
14impl std::fmt::Display for NodeState {
15    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16        f.write_str(match self {
17            Self::Pending => "pending",
18            Self::InFlight => "in_flight",
19            Self::AwaitingMerge => "awaiting_merge",
20            Self::Merged => "merged",
21            Self::Failed => "failed",
22        })
23    }
24}
25
26impl std::str::FromStr for NodeState {
27    type Err = anyhow::Error;
28
29    fn from_str(s: &str) -> Result<Self, Self::Err> {
30        match s {
31            "pending" => Ok(Self::Pending),
32            "in_flight" => Ok(Self::InFlight),
33            "awaiting_merge" => Ok(Self::AwaitingMerge),
34            "merged" => Ok(Self::Merged),
35            "failed" => Ok(Self::Failed),
36            other => anyhow::bail!("unknown node state: {other}"),
37        }
38    }
39}
40
/// A row from the `graph_nodes` table.
#[derive(Debug, Clone)]
pub struct GraphNodeRow {
    /// Issue number; together with `session_id` it is what the `update_node_*`
    /// functions match on.
    pub issue_number: u32,
    /// Session the row belongs to. Note: `insert_node` stores its own
    /// `session_id` argument and ignores this field.
    pub session_id: String,
    /// Current lifecycle state, persisted as its `Display` text.
    pub state: NodeState,
    /// PR number, if one has been recorded (see `update_node_pr`).
    pub pr_number: Option<u32>,
    /// Run identifier, if one has been recorded (see `update_node_run_id`).
    pub run_id: Option<String>,
    pub title: String,
    /// Free-form area label (meaning defined by callers — TODO confirm).
    pub area: String,
    /// Stored in the `predicted_files` column as a JSON array of strings;
    /// `get_nodes` reads malformed JSON back as an empty list.
    pub predicted_files: Vec<String>,
    pub has_migration: bool,
    /// Free-form complexity label (tests use `"full"`; valid values not enforced here).
    pub complexity: String,
    pub target_repo: Option<String>,
}
56
57pub fn insert_node(conn: &Connection, session_id: &str, node: &GraphNodeRow) -> Result<()> {
58    let files_json =
59        serde_json::to_string(&node.predicted_files).context("serializing predicted_files")?;
60    conn.execute(
61        "INSERT OR REPLACE INTO graph_nodes \
62         (issue_number, session_id, state, pr_number, run_id, title, area, \
63          predicted_files, has_migration, complexity, target_repo) \
64         VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)",
65        params![
66            node.issue_number,
67            session_id,
68            node.state.to_string(),
69            node.pr_number,
70            node.run_id,
71            node.title,
72            node.area,
73            files_json,
74            node.has_migration,
75            node.complexity,
76            node.target_repo,
77        ],
78    )
79    .context("inserting graph node")?;
80    Ok(())
81}
82
83pub fn update_node_state(
84    conn: &Connection,
85    session_id: &str,
86    issue_number: u32,
87    state: NodeState,
88) -> Result<()> {
89    conn.execute(
90        "UPDATE graph_nodes SET state = ?1 WHERE issue_number = ?2 AND session_id = ?3",
91        params![state.to_string(), issue_number, session_id],
92    )
93    .context("updating graph node state")?;
94    Ok(())
95}
96
97pub fn update_node_pr(
98    conn: &Connection,
99    session_id: &str,
100    issue_number: u32,
101    pr_number: u32,
102) -> Result<()> {
103    conn.execute(
104        "UPDATE graph_nodes SET pr_number = ?1 WHERE issue_number = ?2 AND session_id = ?3",
105        params![pr_number, issue_number, session_id],
106    )
107    .context("updating graph node PR")?;
108    Ok(())
109}
110
111pub fn update_node_run_id(
112    conn: &Connection,
113    session_id: &str,
114    issue_number: u32,
115    run_id: &str,
116) -> Result<()> {
117    conn.execute(
118        "UPDATE graph_nodes SET run_id = ?1 WHERE issue_number = ?2 AND session_id = ?3",
119        params![run_id, issue_number, session_id],
120    )
121    .context("updating graph node run_id")?;
122    Ok(())
123}
124
125pub fn insert_edge(
126    conn: &Connection,
127    session_id: &str,
128    from_issue: u32,
129    to_issue: u32,
130) -> Result<()> {
131    conn.execute(
132        "INSERT OR IGNORE INTO graph_edges (session_id, from_issue, to_issue) \
133         VALUES (?1, ?2, ?3)",
134        params![session_id, from_issue, to_issue],
135    )
136    .context("inserting graph edge")?;
137    Ok(())
138}
139
140pub fn get_nodes(conn: &Connection, session_id: &str) -> Result<Vec<GraphNodeRow>> {
141    let mut stmt = conn
142        .prepare(
143            "SELECT issue_number, session_id, state, pr_number, run_id, title, area, \
144             predicted_files, has_migration, complexity, target_repo \
145             FROM graph_nodes WHERE session_id = ?1 ORDER BY issue_number",
146        )
147        .context("preparing get_nodes")?;
148
149    let rows = stmt
150        .query_map(params![session_id], |row| {
151            let state_str: String = row.get(2)?;
152            let files_json: String = row.get(7)?;
153            Ok(GraphNodeRow {
154                issue_number: row.get(0)?,
155                session_id: row.get(1)?,
156                state: state_str.parse().map_err(|_| {
157                    rusqlite::Error::InvalidColumnType(
158                        2,
159                        "state".to_string(),
160                        rusqlite::types::Type::Text,
161                    )
162                })?,
163                pr_number: row.get(3)?,
164                run_id: row.get(4)?,
165                title: row.get(5)?,
166                area: row.get(6)?,
167                predicted_files: serde_json::from_str(&files_json).unwrap_or_default(),
168                has_migration: row.get(8)?,
169                complexity: row.get(9)?,
170                target_repo: row.get(10)?,
171            })
172        })
173        .context("querying graph nodes")?;
174
175    rows.collect::<std::result::Result<Vec<_>, _>>().context("collecting graph nodes")
176}
177
178pub fn get_edges(conn: &Connection, session_id: &str) -> Result<Vec<(u32, u32)>> {
179    let mut stmt = conn
180        .prepare(
181            "SELECT from_issue, to_issue FROM graph_edges \
182             WHERE session_id = ?1 ORDER BY from_issue, to_issue",
183        )
184        .context("preparing get_edges")?;
185
186    let rows = stmt
187        .query_map(params![session_id], |row| Ok((row.get(0)?, row.get(1)?)))
188        .context("querying graph edges")?;
189
190    rows.collect::<std::result::Result<Vec<_>, _>>().context("collecting graph edges")
191}
192
193pub fn delete_session(conn: &Connection, session_id: &str) -> Result<()> {
194    conn.execute("DELETE FROM graph_edges WHERE session_id = ?1", params![session_id])
195        .context("deleting graph edges")?;
196    conn.execute("DELETE FROM graph_nodes WHERE session_id = ?1", params![session_id])
197        .context("deleting graph nodes")?;
198    Ok(())
199}
200
201/// Find a session that has at least one non-terminal node.
202///
203/// When multiple active sessions exist, the returned session is arbitrary.
204/// In practice there is at most one active session at a time.
205pub fn get_active_session(conn: &Connection) -> Result<Option<String>> {
206    let mut stmt = conn
207        .prepare(
208            "SELECT DISTINCT session_id FROM graph_nodes \
209             WHERE state NOT IN ('merged', 'failed') \
210             LIMIT 1",
211        )
212        .context("preparing get_active_session")?;
213
214    let mut rows = stmt.query_map([], |row| row.get(0)).context("querying active session")?;
215    match rows.next() {
216        Some(row) => Ok(Some(row.context("reading session_id")?)),
217        None => Ok(None),
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db;

    /// Fresh in-memory database with the project schema applied.
    fn test_db() -> Connection {
        db::open_in_memory().unwrap()
    }

    /// A minimal `Pending` node for `issue` in `session`.
    fn sample_node(issue: u32, session: &str) -> GraphNodeRow {
        GraphNodeRow {
            issue_number: issue,
            session_id: session.to_string(),
            state: NodeState::Pending,
            pr_number: None,
            run_id: None,
            title: format!("Issue #{issue}"),
            area: "test".to_string(),
            predicted_files: vec!["src/main.rs".to_string()],
            has_migration: false,
            complexity: "full".to_string(),
            target_repo: None,
        }
    }

    #[test]
    fn insert_and_get_nodes() {
        let conn = test_db();
        let node = sample_node(1, "sess1");
        insert_node(&conn, "sess1", &node).unwrap();

        let nodes = get_nodes(&conn, "sess1").unwrap();
        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].issue_number, 1);
        assert_eq!(nodes[0].state, NodeState::Pending);
        assert_eq!(nodes[0].predicted_files, vec!["src/main.rs"]);
    }

    #[test]
    fn insert_and_get_edges() {
        let conn = test_db();
        insert_node(&conn, "sess1", &sample_node(1, "sess1")).unwrap();
        insert_node(&conn, "sess1", &sample_node(2, "sess1")).unwrap();
        insert_edge(&conn, "sess1", 2, 1).unwrap();

        let edges = get_edges(&conn, "sess1").unwrap();
        assert_eq!(edges, vec![(2, 1)]);
    }

    #[test]
    fn update_state_persists() {
        let conn = test_db();
        insert_node(&conn, "sess1", &sample_node(1, "sess1")).unwrap();
        update_node_state(&conn, "sess1", 1, NodeState::InFlight).unwrap();

        let nodes = get_nodes(&conn, "sess1").unwrap();
        assert_eq!(nodes[0].state, NodeState::InFlight);
    }

    #[test]
    fn update_pr_persists() {
        let conn = test_db();
        insert_node(&conn, "sess1", &sample_node(1, "sess1")).unwrap();
        update_node_pr(&conn, "sess1", 1, 42).unwrap();

        let nodes = get_nodes(&conn, "sess1").unwrap();
        assert_eq!(nodes[0].pr_number, Some(42));
    }

    #[test]
    fn update_run_id_persists() {
        let conn = test_db();
        insert_node(&conn, "sess1", &sample_node(1, "sess1")).unwrap();
        update_node_run_id(&conn, "sess1", 1, "abc123").unwrap();

        let nodes = get_nodes(&conn, "sess1").unwrap();
        assert_eq!(nodes[0].run_id.as_deref(), Some("abc123"));
    }

    #[test]
    fn update_on_missing_node_is_noop() {
        let conn = test_db();
        update_node_state(&conn, "sess1", 99, NodeState::Merged).unwrap();
        update_node_pr(&conn, "sess1", 99, 7).unwrap();
        update_node_run_id(&conn, "sess1", 99, "nope").unwrap();
        assert!(get_nodes(&conn, "sess1").unwrap().is_empty());
    }

    #[test]
    fn session_isolation() {
        let conn = test_db();
        insert_node(&conn, "sess1", &sample_node(1, "sess1")).unwrap();
        insert_node(&conn, "sess2", &sample_node(2, "sess2")).unwrap();

        assert_eq!(get_nodes(&conn, "sess1").unwrap().len(), 1);
        assert_eq!(get_nodes(&conn, "sess2").unwrap().len(), 1);
        assert_eq!(get_nodes(&conn, "sess3").unwrap().len(), 0);
    }

    #[test]
    fn delete_session_removes_all() {
        let conn = test_db();
        insert_node(&conn, "sess1", &sample_node(1, "sess1")).unwrap();
        insert_node(&conn, "sess1", &sample_node(2, "sess1")).unwrap();
        insert_edge(&conn, "sess1", 2, 1).unwrap();

        delete_session(&conn, "sess1").unwrap();
        assert!(get_nodes(&conn, "sess1").unwrap().is_empty());
        assert!(get_edges(&conn, "sess1").unwrap().is_empty());
    }

    #[test]
    fn delete_session_leaves_other_sessions_intact() {
        let conn = test_db();
        insert_node(&conn, "sess1", &sample_node(1, "sess1")).unwrap();
        insert_node(&conn, "sess1", &sample_node(2, "sess1")).unwrap();
        insert_edge(&conn, "sess1", 2, 1).unwrap();
        insert_node(&conn, "sess2", &sample_node(3, "sess2")).unwrap();
        insert_node(&conn, "sess2", &sample_node(4, "sess2")).unwrap();
        insert_edge(&conn, "sess2", 4, 3).unwrap();

        delete_session(&conn, "sess1").unwrap();
        assert_eq!(get_nodes(&conn, "sess2").unwrap().len(), 2);
        assert_eq!(get_edges(&conn, "sess2").unwrap(), vec![(4, 3)]);
    }

    #[test]
    fn get_active_session_finds_non_terminal() {
        let conn = test_db();
        insert_node(&conn, "sess1", &sample_node(1, "sess1")).unwrap();
        assert_eq!(get_active_session(&conn).unwrap().as_deref(), Some("sess1"));
    }

    #[test]
    fn get_active_session_empty_db_is_none() {
        let conn = test_db();
        assert!(get_active_session(&conn).unwrap().is_none());
    }

    #[test]
    fn get_active_session_skips_all_terminal() {
        let conn = test_db();
        insert_node(&conn, "sess1", &sample_node(1, "sess1")).unwrap();
        update_node_state(&conn, "sess1", 1, NodeState::Merged).unwrap();
        assert!(get_active_session(&conn).unwrap().is_none());
    }

    #[test]
    fn get_active_session_treats_failed_as_terminal() {
        let conn = test_db();
        insert_node(&conn, "sess1", &sample_node(1, "sess1")).unwrap();
        update_node_state(&conn, "sess1", 1, NodeState::Failed).unwrap();
        assert!(get_active_session(&conn).unwrap().is_none());
    }

    #[test]
    fn get_active_session_treats_awaiting_merge_as_active() {
        let conn = test_db();
        insert_node(&conn, "sess1", &sample_node(1, "sess1")).unwrap();
        update_node_state(&conn, "sess1", 1, NodeState::AwaitingMerge).unwrap();
        assert_eq!(get_active_session(&conn).unwrap().as_deref(), Some("sess1"));
    }

    #[test]
    fn duplicate_edge_is_idempotent() {
        let conn = test_db();
        insert_node(&conn, "sess1", &sample_node(1, "sess1")).unwrap();
        insert_node(&conn, "sess1", &sample_node(2, "sess1")).unwrap();
        insert_edge(&conn, "sess1", 2, 1).unwrap();
        insert_edge(&conn, "sess1", 2, 1).unwrap(); // no error

        let edges = get_edges(&conn, "sess1").unwrap();
        assert_eq!(edges.len(), 1);
    }

    #[test]
    fn node_state_display_roundtrip() {
        let states = [
            NodeState::Pending,
            NodeState::InFlight,
            NodeState::AwaitingMerge,
            NodeState::Merged,
            NodeState::Failed,
        ];
        for state in states {
            let s = state.to_string();
            let parsed: NodeState = s.parse().unwrap();
            assert_eq!(state, parsed);
        }
    }

    #[test]
    fn node_state_rejects_unknown_text() {
        assert!("bogus".parse::<NodeState>().is_err());
        assert!("".parse::<NodeState>().is_err());
        // Parsing is case-sensitive.
        assert!("Pending".parse::<NodeState>().is_err());
    }

    #[test]
    fn upsert_overwrites_existing_node() {
        let conn = test_db();
        let mut node = sample_node(1, "sess1");
        insert_node(&conn, "sess1", &node).unwrap();

        node.title = "Updated title".to_string();
        node.state = NodeState::InFlight;
        insert_node(&conn, "sess1", &node).unwrap();

        let nodes = get_nodes(&conn, "sess1").unwrap();
        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].title, "Updated title");
        assert_eq!(nodes[0].state, NodeState::InFlight);
    }

    #[test]
    fn predicted_files_roundtrip_empty() {
        let conn = test_db();
        let mut node = sample_node(1, "sess1");
        node.predicted_files = vec![];
        insert_node(&conn, "sess1", &node).unwrap();

        let nodes = get_nodes(&conn, "sess1").unwrap();
        assert!(nodes[0].predicted_files.is_empty());
    }
}