1use anyhow::{Context, Result};
2use rusqlite::{Connection, params};
3
/// Lifecycle state of a node in the persisted issue graph.
///
/// Persisted in the `graph_nodes.state` column as the snake_case name produced
/// by the `Display` impl and read back through the `FromStr` impl.
/// `get_active_session` treats `Merged` and `Failed` as terminal; the other
/// three states count as still active.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum NodeState {
    /// Stored as `"pending"`; the state `insert_node` callers typically start from (see tests).
    Pending,
    /// Stored as `"in_flight"`.
    InFlight,
    /// Stored as `"awaiting_merge"`.
    AwaitingMerge,
    /// Stored as `"merged"`; terminal for `get_active_session`.
    Merged,
    /// Stored as `"failed"`; terminal for `get_active_session`.
    Failed,
}
13
14impl std::fmt::Display for NodeState {
15 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
16 f.write_str(match self {
17 Self::Pending => "pending",
18 Self::InFlight => "in_flight",
19 Self::AwaitingMerge => "awaiting_merge",
20 Self::Merged => "merged",
21 Self::Failed => "failed",
22 })
23 }
24}
25
26impl std::str::FromStr for NodeState {
27 type Err = anyhow::Error;
28
29 fn from_str(s: &str) -> Result<Self, Self::Err> {
30 match s {
31 "pending" => Ok(Self::Pending),
32 "in_flight" => Ok(Self::InFlight),
33 "awaiting_merge" => Ok(Self::AwaitingMerge),
34 "merged" => Ok(Self::Merged),
35 "failed" => Ok(Self::Failed),
36 other => anyhow::bail!("unknown node state: {other}"),
37 }
38 }
39}
40
/// One row of the `graph_nodes` table: an issue tracked within a session.
///
/// `update_node_state`, `update_node_pr` and `update_node_run_id` locate a
/// row by the pair (`issue_number`, `session_id`).
#[derive(Debug, Clone)]
pub struct GraphNodeRow {
    /// Issue number; together with `session_id` it selects the row in the update helpers.
    pub issue_number: u32,
    /// Session the node belongs to.
    /// NOTE(review): `insert_node` writes its own `session_id` argument, not
    /// this field, so callers should keep the two equal.
    pub session_id: String,
    /// Current lifecycle state, persisted via its `Display` name.
    pub state: NodeState,
    /// Pull-request number; `None` is stored as SQL NULL. Can be set later with `update_node_pr`.
    pub pr_number: Option<u32>,
    /// Run identifier; `None` is stored as SQL NULL. Can be set later with `update_node_run_id`.
    pub run_id: Option<String>,
    /// Title text stored with the node.
    pub title: String,
    /// Area label (e.g. `"test"` in this file's tests); not validated here.
    pub area: String,
    /// Predicted file paths; persisted as a JSON array string in the `predicted_files` column.
    pub predicted_files: Vec<String>,
    /// Stored as-is; presumably marks issues that include a schema migration (confirm with callers).
    pub has_migration: bool,
    /// Complexity label (e.g. `"full"` in this file's tests); not validated here.
    pub complexity: String,
    /// Optional target repository; `None` is stored as SQL NULL.
    pub target_repo: Option<String>,
}
56
57pub fn insert_node(conn: &Connection, session_id: &str, node: &GraphNodeRow) -> Result<()> {
58 let files_json =
59 serde_json::to_string(&node.predicted_files).context("serializing predicted_files")?;
60 conn.execute(
61 "INSERT OR REPLACE INTO graph_nodes \
62 (issue_number, session_id, state, pr_number, run_id, title, area, \
63 predicted_files, has_migration, complexity, target_repo) \
64 VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11)",
65 params![
66 node.issue_number,
67 session_id,
68 node.state.to_string(),
69 node.pr_number,
70 node.run_id,
71 node.title,
72 node.area,
73 files_json,
74 node.has_migration,
75 node.complexity,
76 node.target_repo,
77 ],
78 )
79 .context("inserting graph node")?;
80 Ok(())
81}
82
83pub fn update_node_state(
84 conn: &Connection,
85 session_id: &str,
86 issue_number: u32,
87 state: NodeState,
88) -> Result<()> {
89 conn.execute(
90 "UPDATE graph_nodes SET state = ?1 WHERE issue_number = ?2 AND session_id = ?3",
91 params![state.to_string(), issue_number, session_id],
92 )
93 .context("updating graph node state")?;
94 Ok(())
95}
96
97pub fn update_node_pr(
98 conn: &Connection,
99 session_id: &str,
100 issue_number: u32,
101 pr_number: u32,
102) -> Result<()> {
103 conn.execute(
104 "UPDATE graph_nodes SET pr_number = ?1 WHERE issue_number = ?2 AND session_id = ?3",
105 params![pr_number, issue_number, session_id],
106 )
107 .context("updating graph node PR")?;
108 Ok(())
109}
110
111pub fn update_node_run_id(
112 conn: &Connection,
113 session_id: &str,
114 issue_number: u32,
115 run_id: &str,
116) -> Result<()> {
117 conn.execute(
118 "UPDATE graph_nodes SET run_id = ?1 WHERE issue_number = ?2 AND session_id = ?3",
119 params![run_id, issue_number, session_id],
120 )
121 .context("updating graph node run_id")?;
122 Ok(())
123}
124
125pub fn insert_edge(
126 conn: &Connection,
127 session_id: &str,
128 from_issue: u32,
129 to_issue: u32,
130) -> Result<()> {
131 conn.execute(
132 "INSERT OR IGNORE INTO graph_edges (session_id, from_issue, to_issue) \
133 VALUES (?1, ?2, ?3)",
134 params![session_id, from_issue, to_issue],
135 )
136 .context("inserting graph edge")?;
137 Ok(())
138}
139
140pub fn get_nodes(conn: &Connection, session_id: &str) -> Result<Vec<GraphNodeRow>> {
141 let mut stmt = conn
142 .prepare(
143 "SELECT issue_number, session_id, state, pr_number, run_id, title, area, \
144 predicted_files, has_migration, complexity, target_repo \
145 FROM graph_nodes WHERE session_id = ?1 ORDER BY issue_number",
146 )
147 .context("preparing get_nodes")?;
148
149 let rows = stmt
150 .query_map(params![session_id], |row| {
151 let state_str: String = row.get(2)?;
152 let files_json: String = row.get(7)?;
153 Ok(GraphNodeRow {
154 issue_number: row.get(0)?,
155 session_id: row.get(1)?,
156 state: state_str.parse().map_err(|_| {
157 rusqlite::Error::InvalidColumnType(
158 2,
159 "state".to_string(),
160 rusqlite::types::Type::Text,
161 )
162 })?,
163 pr_number: row.get(3)?,
164 run_id: row.get(4)?,
165 title: row.get(5)?,
166 area: row.get(6)?,
167 predicted_files: serde_json::from_str(&files_json).unwrap_or_default(),
168 has_migration: row.get(8)?,
169 complexity: row.get(9)?,
170 target_repo: row.get(10)?,
171 })
172 })
173 .context("querying graph nodes")?;
174
175 rows.collect::<std::result::Result<Vec<_>, _>>().context("collecting graph nodes")
176}
177
178pub fn get_edges(conn: &Connection, session_id: &str) -> Result<Vec<(u32, u32)>> {
179 let mut stmt = conn
180 .prepare(
181 "SELECT from_issue, to_issue FROM graph_edges \
182 WHERE session_id = ?1 ORDER BY from_issue, to_issue",
183 )
184 .context("preparing get_edges")?;
185
186 let rows = stmt
187 .query_map(params![session_id], |row| Ok((row.get(0)?, row.get(1)?)))
188 .context("querying graph edges")?;
189
190 rows.collect::<std::result::Result<Vec<_>, _>>().context("collecting graph edges")
191}
192
193pub fn delete_session(conn: &Connection, session_id: &str) -> Result<()> {
194 conn.execute("DELETE FROM graph_edges WHERE session_id = ?1", params![session_id])
195 .context("deleting graph edges")?;
196 conn.execute("DELETE FROM graph_nodes WHERE session_id = ?1", params![session_id])
197 .context("deleting graph nodes")?;
198 Ok(())
199}
200
201pub fn get_active_session(conn: &Connection) -> Result<Option<String>> {
206 let mut stmt = conn
207 .prepare(
208 "SELECT DISTINCT session_id FROM graph_nodes \
209 WHERE state NOT IN ('merged', 'failed') \
210 LIMIT 1",
211 )
212 .context("preparing get_active_session")?;
213
214 let mut rows = stmt.query_map([], |row| row.get(0)).context("querying active session")?;
215 match rows.next() {
216 Some(row) => Ok(Some(row.context("reading session_id")?)),
217 None => Ok(None),
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db;

    /// Opens a fresh in-memory database through `db::open_in_memory`.
    fn test_db() -> Connection {
        db::open_in_memory().unwrap()
    }

    /// Builds a `Pending` node for `issue` in `session` with fixed placeholder fields.
    fn sample_node(issue: u32, session: &str) -> GraphNodeRow {
        GraphNodeRow {
            issue_number: issue,
            session_id: session.to_string(),
            state: NodeState::Pending,
            pr_number: None,
            run_id: None,
            title: format!("Issue #{issue}"),
            area: "test".to_string(),
            predicted_files: vec!["src/main.rs".to_string()],
            has_migration: false,
            complexity: "full".to_string(),
            target_repo: None,
        }
    }

    #[test]
    fn insert_and_get_nodes() {
        let conn = test_db();
        let node = sample_node(1, "sess1");
        insert_node(&conn, "sess1", &node).unwrap();

        let nodes = get_nodes(&conn, "sess1").unwrap();
        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].issue_number, 1);
        assert_eq!(nodes[0].state, NodeState::Pending);
        assert_eq!(nodes[0].predicted_files, vec!["src/main.rs"]);
    }

    #[test]
    fn insert_and_get_edges() {
        let conn = test_db();
        insert_node(&conn, "sess1", &sample_node(1, "sess1")).unwrap();
        insert_node(&conn, "sess1", &sample_node(2, "sess1")).unwrap();
        insert_edge(&conn, "sess1", 2, 1).unwrap();

        let edges = get_edges(&conn, "sess1").unwrap();
        assert_eq!(edges, vec![(2, 1)]);
    }

    #[test]
    fn update_state_persists() {
        let conn = test_db();
        insert_node(&conn, "sess1", &sample_node(1, "sess1")).unwrap();
        update_node_state(&conn, "sess1", 1, NodeState::InFlight).unwrap();

        let nodes = get_nodes(&conn, "sess1").unwrap();
        assert_eq!(nodes[0].state, NodeState::InFlight);
    }

    #[test]
    fn update_pr_persists() {
        let conn = test_db();
        insert_node(&conn, "sess1", &sample_node(1, "sess1")).unwrap();
        update_node_pr(&conn, "sess1", 1, 42).unwrap();

        let nodes = get_nodes(&conn, "sess1").unwrap();
        assert_eq!(nodes[0].pr_number, Some(42));
    }

    #[test]
    fn update_run_id_persists() {
        let conn = test_db();
        insert_node(&conn, "sess1", &sample_node(1, "sess1")).unwrap();
        update_node_run_id(&conn, "sess1", 1, "abc123").unwrap();

        let nodes = get_nodes(&conn, "sess1").unwrap();
        assert_eq!(nodes[0].run_id.as_deref(), Some("abc123"));
    }

    #[test]
    fn session_isolation() {
        let conn = test_db();
        insert_node(&conn, "sess1", &sample_node(1, "sess1")).unwrap();
        insert_node(&conn, "sess2", &sample_node(2, "sess2")).unwrap();

        assert_eq!(get_nodes(&conn, "sess1").unwrap().len(), 1);
        assert_eq!(get_nodes(&conn, "sess2").unwrap().len(), 1);
        assert_eq!(get_nodes(&conn, "sess3").unwrap().len(), 0);
    }

    #[test]
    fn delete_session_removes_all() {
        let conn = test_db();
        insert_node(&conn, "sess1", &sample_node(1, "sess1")).unwrap();
        insert_node(&conn, "sess1", &sample_node(2, "sess1")).unwrap();
        insert_edge(&conn, "sess1", 2, 1).unwrap();

        delete_session(&conn, "sess1").unwrap();
        assert!(get_nodes(&conn, "sess1").unwrap().is_empty());
        assert!(get_edges(&conn, "sess1").unwrap().is_empty());
    }

    #[test]
    fn delete_session_leaves_other_sessions_and_no_open_transaction() {
        let conn = test_db();
        insert_node(&conn, "sess1", &sample_node(1, "sess1")).unwrap();
        insert_node(&conn, "sess1", &sample_node(2, "sess1")).unwrap();
        insert_edge(&conn, "sess1", 2, 1).unwrap();
        insert_node(&conn, "sess2", &sample_node(3, "sess2")).unwrap();
        insert_node(&conn, "sess2", &sample_node(4, "sess2")).unwrap();
        insert_edge(&conn, "sess2", 4, 3).unwrap();

        delete_session(&conn, "sess1").unwrap();

        assert!(conn.is_autocommit());
        assert!(get_nodes(&conn, "sess1").unwrap().is_empty());
        assert_eq!(get_nodes(&conn, "sess2").unwrap().len(), 2);
        assert_eq!(get_edges(&conn, "sess2").unwrap(), vec![(4, 3)]);
    }

    #[test]
    fn delete_missing_session_is_ok() {
        let conn = test_db();
        delete_session(&conn, "missing").unwrap();
        assert!(conn.is_autocommit());
    }

    #[test]
    fn get_active_session_finds_non_terminal() {
        let conn = test_db();
        insert_node(&conn, "sess1", &sample_node(1, "sess1")).unwrap();
        assert_eq!(get_active_session(&conn).unwrap().as_deref(), Some("sess1"));
    }

    #[test]
    fn get_active_session_skips_all_terminal() {
        let conn = test_db();
        insert_node(&conn, "sess1", &sample_node(1, "sess1")).unwrap();
        update_node_state(&conn, "sess1", 1, NodeState::Merged).unwrap();
        assert!(get_active_session(&conn).unwrap().is_none());
    }

    #[test]
    fn get_active_session_skips_failed() {
        let conn = test_db();
        insert_node(&conn, "sess1", &sample_node(1, "sess1")).unwrap();
        update_node_state(&conn, "sess1", 1, NodeState::Failed).unwrap();
        assert!(get_active_session(&conn).unwrap().is_none());
    }

    #[test]
    fn get_active_session_with_mixed_states() {
        let conn = test_db();
        insert_node(&conn, "sess1", &sample_node(1, "sess1")).unwrap();
        insert_node(&conn, "sess1", &sample_node(2, "sess1")).unwrap();
        update_node_state(&conn, "sess1", 1, NodeState::Merged).unwrap();
        update_node_state(&conn, "sess1", 2, NodeState::AwaitingMerge).unwrap();
        assert_eq!(get_active_session(&conn).unwrap().as_deref(), Some("sess1"));
    }

    #[test]
    fn duplicate_edge_is_idempotent() {
        let conn = test_db();
        insert_node(&conn, "sess1", &sample_node(1, "sess1")).unwrap();
        insert_node(&conn, "sess1", &sample_node(2, "sess1")).unwrap();
        insert_edge(&conn, "sess1", 2, 1).unwrap();
        insert_edge(&conn, "sess1", 2, 1).unwrap();

        let edges = get_edges(&conn, "sess1").unwrap();
        assert_eq!(edges.len(), 1);
    }

    #[test]
    fn get_edges_sorted_by_from_then_to() {
        let conn = test_db();
        for issue in [1, 2, 3] {
            insert_node(&conn, "sess1", &sample_node(issue, "sess1")).unwrap();
        }
        insert_edge(&conn, "sess1", 3, 2).unwrap();
        insert_edge(&conn, "sess1", 2, 1).unwrap();
        insert_edge(&conn, "sess1", 3, 1).unwrap();

        assert_eq!(get_edges(&conn, "sess1").unwrap(), vec![(2, 1), (3, 1), (3, 2)]);
    }

    #[test]
    fn get_nodes_orders_by_issue_number() {
        let conn = test_db();
        for issue in [3, 1, 2] {
            insert_node(&conn, "sess1", &sample_node(issue, "sess1")).unwrap();
        }

        let issues: Vec<u32> = get_nodes(&conn, "sess1")
            .unwrap()
            .iter()
            .map(|n| n.issue_number)
            .collect();
        assert_eq!(issues, vec![1, 2, 3]);
    }

    #[test]
    fn node_state_display_roundtrip() {
        let states = [
            NodeState::Pending,
            NodeState::InFlight,
            NodeState::AwaitingMerge,
            NodeState::Merged,
            NodeState::Failed,
        ];
        for state in states {
            let s = state.to_string();
            let parsed: NodeState = s.parse().unwrap();
            assert_eq!(state, parsed);
        }
    }

    #[test]
    fn node_state_rejects_unknown_names() {
        assert!("unknown".parse::<NodeState>().is_err());
        assert!("Pending".parse::<NodeState>().is_err());
        assert!("".parse::<NodeState>().is_err());
    }

    #[test]
    fn upsert_overwrites_existing_node() {
        let conn = test_db();
        let mut node = sample_node(1, "sess1");
        insert_node(&conn, "sess1", &node).unwrap();

        node.title = "Updated title".to_string();
        node.state = NodeState::InFlight;
        insert_node(&conn, "sess1", &node).unwrap();

        let nodes = get_nodes(&conn, "sess1").unwrap();
        assert_eq!(nodes.len(), 1);
        assert_eq!(nodes[0].title, "Updated title");
        assert_eq!(nodes[0].state, NodeState::InFlight);
    }

    #[test]
    fn optional_fields_roundtrip() {
        let conn = test_db();
        let mut node = sample_node(1, "sess1");
        node.pr_number = Some(7);
        node.run_id = Some("run-1".to_string());
        node.target_repo = Some("org/repo".to_string());
        node.has_migration = true;
        insert_node(&conn, "sess1", &node).unwrap();

        let nodes = get_nodes(&conn, "sess1").unwrap();
        let got = &nodes[0];
        assert_eq!(got.pr_number, Some(7));
        assert_eq!(got.run_id.as_deref(), Some("run-1"));
        assert_eq!(got.target_repo.as_deref(), Some("org/repo"));
        assert!(got.has_migration);
        assert_eq!(got.area, "test");
        assert_eq!(got.complexity, "full");
    }

    #[test]
    fn predicted_files_roundtrip_empty() {
        let conn = test_db();
        let mut node = sample_node(1, "sess1");
        node.predicted_files = vec![];
        insert_node(&conn, "sess1", &node).unwrap();

        let nodes = get_nodes(&conn, "sess1").unwrap();
        assert!(nodes[0].predicted_files.is_empty());
    }
}