1use crate::db::models::Dependency;
2use crate::error::{IntentError, Result};
3use sqlx::SqlitePool;
4
5pub async fn check_circular_dependency(
45 pool: &SqlitePool,
46 blocking_task_id: i64,
47 blocked_task_id: i64,
48) -> Result<bool> {
49 if blocking_task_id == blocked_task_id {
51 return Ok(true);
52 }
53
54 let has_cycle: bool = sqlx::query_scalar(
57 r#"
58 WITH RECURSIVE dep_chain(task_id, depth) AS (
59 -- Start from the NEW blocking task
60 SELECT ? as task_id, 0 as depth
61
62 UNION ALL
63
64 -- Follow what each task depends on (its blocking tasks)
65 SELECT d.blocking_task_id, dc.depth + 1
66 FROM dependencies d
67 JOIN dep_chain dc ON d.blocked_task_id = dc.task_id
68 WHERE dc.depth < 100
69 )
70 SELECT COUNT(*) > 0
71 FROM dep_chain
72 WHERE task_id = ?
73 "#,
74 )
75 .bind(blocking_task_id)
76 .bind(blocked_task_id)
77 .fetch_one(pool)
78 .await?;
79
80 Ok(has_cycle)
81}
82
83pub async fn add_dependency(
102 pool: &SqlitePool,
103 blocking_task_id: i64,
104 blocked_task_id: i64,
105) -> Result<Dependency> {
106 let blocking_exists: bool = sqlx::query_scalar("SELECT COUNT(*) > 0 FROM tasks WHERE id = ?")
108 .bind(blocking_task_id)
109 .fetch_one(pool)
110 .await?;
111
112 if !blocking_exists {
113 return Err(IntentError::TaskNotFound(blocking_task_id));
114 }
115
116 let blocked_exists: bool = sqlx::query_scalar("SELECT COUNT(*) > 0 FROM tasks WHERE id = ?")
117 .bind(blocked_task_id)
118 .fetch_one(pool)
119 .await?;
120
121 if !blocked_exists {
122 return Err(IntentError::TaskNotFound(blocked_task_id));
123 }
124
125 if check_circular_dependency(pool, blocking_task_id, blocked_task_id).await? {
127 return Err(IntentError::CircularDependency {
128 blocking_task_id,
129 blocked_task_id,
130 });
131 }
132
133 let result = sqlx::query(
135 r#"
136 INSERT INTO dependencies (blocking_task_id, blocked_task_id)
137 VALUES (?, ?)
138 "#,
139 )
140 .bind(blocking_task_id)
141 .bind(blocked_task_id)
142 .execute(pool)
143 .await?;
144
145 let dependency_id = result.last_insert_rowid();
146
147 let dependency = sqlx::query_as::<_, Dependency>(
149 "SELECT id, blocking_task_id, blocked_task_id, created_at FROM dependencies WHERE id = ?",
150 )
151 .bind(dependency_id)
152 .fetch_one(pool)
153 .await?;
154
155 Ok(dependency)
156}
157
158pub async fn get_blocking_tasks(pool: &SqlitePool, task_id: i64) -> Result<Vec<i64>> {
169 let blocking_ids = sqlx::query_scalar::<_, i64>(
170 r#"
171 SELECT blocking_task_id
172 FROM dependencies
173 WHERE blocked_task_id = ?
174 "#,
175 )
176 .bind(task_id)
177 .fetch_all(pool)
178 .await?;
179
180 Ok(blocking_ids)
181}
182
183pub async fn get_blocked_tasks(pool: &SqlitePool, task_id: i64) -> Result<Vec<i64>> {
194 let blocked_ids = sqlx::query_scalar::<_, i64>(
195 r#"
196 SELECT blocked_task_id
197 FROM dependencies
198 WHERE blocking_task_id = ?
199 "#,
200 )
201 .bind(task_id)
202 .fetch_all(pool)
203 .await?;
204
205 Ok(blocked_ids)
206}
207
208pub async fn get_incomplete_blocking_tasks(
222 pool: &SqlitePool,
223 task_id: i64,
224) -> Result<Option<Vec<i64>>> {
225 let incomplete_blocking: Vec<i64> = sqlx::query_scalar::<_, i64>(
226 r#"
227 SELECT d.blocking_task_id
228 FROM dependencies d
229 JOIN tasks t ON t.id = d.blocking_task_id
230 WHERE d.blocked_task_id = ?
231 AND t.status IN ('todo', 'doing')
232 "#,
233 )
234 .bind(task_id)
235 .fetch_all(pool)
236 .await?;
237
238 if incomplete_blocking.is_empty() {
239 Ok(None)
240 } else {
241 Ok(Some(incomplete_blocking))
242 }
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::{create_pool, run_migrations};
    use tempfile::TempDir;

    /// Creates a fresh, migrated database file inside a temporary directory.
    ///
    /// The `TempDir` is handed back so the caller keeps it alive for the whole test;
    /// dropping it removes the directory and the database with it.
    async fn setup_test_db() -> (TempDir, SqlitePool) {
        let scratch = TempDir::new().unwrap();
        let pool = create_pool(&scratch.path().join("test.db")).await.unwrap();
        run_migrations(&pool).await.unwrap();
        (scratch, pool)
    }

    /// Inserts a task called `name` with status `'todo'` and returns its row id.
    async fn create_test_task(pool: &SqlitePool, name: &str) -> i64 {
        let outcome = sqlx::query("INSERT INTO tasks (name, status) VALUES (?, 'todo')")
            .bind(name)
            .execute(pool)
            .await
            .unwrap();
        outcome.last_insert_rowid()
    }

    /// Writes a dependency row directly, bypassing `add_dependency`'s validation,
    /// so the cycle-detection tests can build arbitrary graph shapes.
    async fn insert_raw_dependency(pool: &SqlitePool, blocking: i64, blocked: i64) {
        sqlx::query("INSERT INTO dependencies (blocking_task_id, blocked_task_id) VALUES (?, ?)")
            .bind(blocking)
            .bind(blocked)
            .execute(pool)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_check_circular_dependency_self() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;

        let verdict = check_circular_dependency(&pool, task_a, task_a).await.unwrap();
        assert!(verdict);
    }

    #[tokio::test]
    async fn test_check_circular_dependency_direct_cycle() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;

        // B blocks A; asking whether A may block B must report a cycle.
        insert_raw_dependency(&pool, task_b, task_a).await;

        let verdict = check_circular_dependency(&pool, task_a, task_b).await.unwrap();
        assert!(verdict);
    }

    #[tokio::test]
    async fn test_check_circular_dependency_transitive_cycle() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;
        let task_c = create_test_task(&pool, "Task C").await;

        // C blocks B blocks A; A blocking C would close the loop.
        insert_raw_dependency(&pool, task_b, task_a).await;
        insert_raw_dependency(&pool, task_c, task_b).await;

        let verdict = check_circular_dependency(&pool, task_a, task_c).await.unwrap();
        assert!(verdict);
    }

    #[tokio::test]
    async fn test_check_circular_dependency_no_cycle() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;
        let task_c = create_test_task(&pool, "Task C").await;

        insert_raw_dependency(&pool, task_b, task_a).await;

        let verdict = check_circular_dependency(&pool, task_a, task_c).await.unwrap();
        assert!(!verdict);
    }

    #[tokio::test]
    async fn test_check_circular_dependency_deep_chain() {
        let (_temp, pool) = setup_test_db().await;

        let mut chain = Vec::new();
        for label in ["Task A", "Task B", "Task C", "Task D", "Task E"] {
            chain.push(create_test_task(&pool, label).await);
        }

        // Each task is blocked by the next one: E -> D -> C -> B -> A.
        for link in chain.windows(2) {
            insert_raw_dependency(&pool, link[1], link[0]).await;
        }

        let verdict = check_circular_dependency(&pool, chain[0], chain[4]).await.unwrap();
        assert!(verdict);
    }

    #[tokio::test]
    async fn test_add_dependency_success() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;

        let created = add_dependency(&pool, task_b, task_a).await.unwrap();

        assert_eq!(created.blocking_task_id, task_b);
        assert_eq!(created.blocked_task_id, task_a);
    }

    #[tokio::test]
    async fn test_add_dependency_circular_error() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;

        add_dependency(&pool, task_b, task_a).await.unwrap();

        let outcome = add_dependency(&pool, task_a, task_b).await;
        assert!(matches!(
            outcome,
            Err(IntentError::CircularDependency { .. })
        ));
    }

    #[tokio::test]
    async fn test_add_dependency_task_not_found() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;

        let outcome = add_dependency(&pool, 9999, task_a).await;
        assert!(matches!(outcome, Err(IntentError::TaskNotFound(9999))));
    }

    #[tokio::test]
    async fn test_get_blocking_tasks() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;
        let task_c = create_test_task(&pool, "Task C").await;

        add_dependency(&pool, task_b, task_a).await.unwrap();
        add_dependency(&pool, task_c, task_a).await.unwrap();

        let blockers = get_blocking_tasks(&pool, task_a).await.unwrap();
        assert_eq!(blockers.len(), 2);
        assert!(blockers.contains(&task_b));
        assert!(blockers.contains(&task_c));
    }

    #[tokio::test]
    async fn test_get_blocked_tasks() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;
        let task_c = create_test_task(&pool, "Task C").await;

        add_dependency(&pool, task_a, task_b).await.unwrap();
        add_dependency(&pool, task_a, task_c).await.unwrap();

        let dependents = get_blocked_tasks(&pool, task_a).await.unwrap();
        assert_eq!(dependents.len(), 2);
        assert!(dependents.contains(&task_b));
        assert!(dependents.contains(&task_c));
    }

    #[tokio::test]
    async fn test_get_incomplete_blocking_tasks_blocked() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;

        add_dependency(&pool, task_b, task_a).await.unwrap();

        let pending = get_incomplete_blocking_tasks(&pool, task_a).await.unwrap();
        assert!(pending.is_some());
        assert_eq!(pending.unwrap(), vec![task_b]);
    }

    #[tokio::test]
    async fn test_get_incomplete_blocking_tasks_not_blocked() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;

        add_dependency(&pool, task_b, task_a).await.unwrap();

        // Once the only blocker is done, nothing should be reported as pending.
        sqlx::query("UPDATE tasks SET status = 'done' WHERE id = ?")
            .bind(task_b)
            .execute(&pool)
            .await
            .unwrap();

        let pending = get_incomplete_blocking_tasks(&pool, task_a).await.unwrap();
        assert!(pending.is_none());
    }
}