1use crate::db::models::Dependency;
2use crate::error::{IntentError, Result};
3use sqlx::SqlitePool;
4
5pub async fn check_circular_dependency(
45 pool: &SqlitePool,
46 blocking_task_id: i64,
47 blocked_task_id: i64,
48) -> Result<bool> {
49 if blocking_task_id == blocked_task_id {
51 return Ok(true);
52 }
53
54 let has_cycle: bool = sqlx::query_scalar(
57 r#"
58 WITH RECURSIVE dep_chain(task_id, depth) AS (
59 -- Start from the NEW blocking task
60 SELECT ? as task_id, 0 as depth
61
62 UNION ALL
63
64 -- Follow what each task depends on (its blocking tasks)
65 SELECT d.blocking_task_id, dc.depth + 1
66 FROM dependencies d
67 JOIN dep_chain dc ON d.blocked_task_id = dc.task_id
68 WHERE dc.depth < 100
69 )
70 SELECT COUNT(*) > 0
71 FROM dep_chain
72 WHERE task_id = ?
73 "#,
74 )
75 .bind(blocking_task_id)
76 .bind(blocked_task_id)
77 .fetch_one(pool)
78 .await?;
79
80 Ok(has_cycle)
81}
82
83pub async fn add_dependency(
102 pool: &SqlitePool,
103 blocking_task_id: i64,
104 blocked_task_id: i64,
105) -> Result<Dependency> {
106 let blocking_exists: bool =
108 sqlx::query_scalar::<_, bool>("SELECT COUNT(*) > 0 FROM tasks WHERE id = ?")
109 .bind(blocking_task_id)
110 .fetch_one(pool)
111 .await?;
112
113 if !blocking_exists {
114 return Err(IntentError::TaskNotFound(blocking_task_id));
115 }
116
117 let blocked_exists: bool =
118 sqlx::query_scalar::<_, bool>("SELECT COUNT(*) > 0 FROM tasks WHERE id = ?")
119 .bind(blocked_task_id)
120 .fetch_one(pool)
121 .await?;
122
123 if !blocked_exists {
124 return Err(IntentError::TaskNotFound(blocked_task_id));
125 }
126
127 if check_circular_dependency(pool, blocking_task_id, blocked_task_id).await? {
129 return Err(IntentError::CircularDependency {
130 blocking_task_id,
131 blocked_task_id,
132 });
133 }
134
135 let result = sqlx::query(
137 r#"
138 INSERT INTO dependencies (blocking_task_id, blocked_task_id)
139 VALUES (?, ?)
140 "#,
141 )
142 .bind(blocking_task_id)
143 .bind(blocked_task_id)
144 .execute(pool)
145 .await?;
146
147 let dependency_id = result.last_insert_rowid();
148
149 let dependency = sqlx::query_as::<_, Dependency>(
151 "SELECT id, blocking_task_id, blocked_task_id, created_at FROM dependencies WHERE id = ?",
152 )
153 .bind(dependency_id)
154 .fetch_one(pool)
155 .await?;
156
157 Ok(dependency)
158}
159
160pub async fn get_incomplete_blocking_tasks(
174 pool: &SqlitePool,
175 task_id: i64,
176) -> Result<Option<Vec<i64>>> {
177 let incomplete_blocking: Vec<i64> = sqlx::query_scalar::<_, i64>(
178 r#"
179 SELECT d.blocking_task_id
180 FROM dependencies d
181 JOIN tasks t ON t.id = d.blocking_task_id
182 WHERE d.blocked_task_id = ?
183 AND t.status IN ('todo', 'doing')
184 "#,
185 )
186 .bind(task_id)
187 .fetch_all(pool)
188 .await?;
189
190 if incomplete_blocking.is_empty() {
191 Ok(None)
192 } else {
193 Ok(Some(incomplete_blocking))
194 }
195}
196
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::{create_pool, run_migrations};
    use tempfile::TempDir;

    /// Creates a migrated SQLite database inside a fresh temporary directory.
    ///
    /// The `TempDir` is returned with the pool so the directory outlives the
    /// test body that holds it.
    async fn setup_test_db() -> (TempDir, SqlitePool) {
        let dir = TempDir::new().unwrap();
        let pool = create_pool(&dir.path().join("test.db")).await.unwrap();
        run_migrations(&pool).await.unwrap();
        (dir, pool)
    }

    /// Inserts a task called `name` with status `'todo'` and returns its row id.
    async fn create_test_task(pool: &SqlitePool, name: &str) -> i64 {
        let outcome = sqlx::query("INSERT INTO tasks (name, status) VALUES (?, 'todo')")
            .bind(name)
            .execute(pool)
            .await
            .unwrap();
        outcome.last_insert_rowid()
    }

    /// Writes a raw `dependencies` row, bypassing `add_dependency`'s checks,
    /// so a test can set up a graph shape directly.
    async fn insert_dependency(pool: &SqlitePool, blocking: i64, blocked: i64) {
        sqlx::query("INSERT INTO dependencies (blocking_task_id, blocked_task_id) VALUES (?, ?)")
            .bind(blocking)
            .bind(blocked)
            .execute(pool)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_check_circular_dependency_self() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;

        assert!(check_circular_dependency(&pool, task_a, task_a).await.unwrap());
    }

    #[tokio::test]
    async fn test_check_circular_dependency_direct_cycle() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;

        // Existing edge: B blocks A.
        insert_dependency(&pool, task_b, task_a).await;

        // Making A block B would close the loop.
        assert!(check_circular_dependency(&pool, task_a, task_b).await.unwrap());
    }

    #[tokio::test]
    async fn test_check_circular_dependency_transitive_cycle() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;
        let task_c = create_test_task(&pool, "Task C").await;

        // Existing chain: C blocks B, B blocks A.
        insert_dependency(&pool, task_b, task_a).await;
        insert_dependency(&pool, task_c, task_b).await;

        assert!(check_circular_dependency(&pool, task_a, task_c).await.unwrap());
    }

    #[tokio::test]
    async fn test_check_circular_dependency_no_cycle() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let _task_b = create_test_task(&pool, "Task B").await;
        let task_c = create_test_task(&pool, "Task C").await;

        insert_dependency(&pool, _task_b, task_a).await;

        // C is unrelated to A, so A blocking C is safe.
        assert!(!check_circular_dependency(&pool, task_a, task_c).await.unwrap());
    }

    #[tokio::test]
    async fn test_check_circular_dependency_deep_chain() {
        let (_temp, pool) = setup_test_db().await;

        let mut ids = Vec::new();
        for name in ["Task A", "Task B", "Task C", "Task D", "Task E"] {
            ids.push(create_test_task(&pool, name).await);
        }

        // Each later task blocks the one before it: B->A, C->B, D->C, E->D.
        for pair in ids.windows(2) {
            insert_dependency(&pool, pair[1], pair[0]).await;
        }

        let (first, last) = (ids[0], ids[4]);
        assert!(check_circular_dependency(&pool, first, last).await.unwrap());
    }

    #[tokio::test]
    async fn test_add_dependency_success() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;

        let dep = add_dependency(&pool, task_b, task_a).await.unwrap();

        assert_eq!(dep.blocking_task_id, task_b);
        assert_eq!(dep.blocked_task_id, task_a);
    }

    #[tokio::test]
    async fn test_add_dependency_circular_error() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;

        add_dependency(&pool, task_b, task_a).await.unwrap();

        let reversed = add_dependency(&pool, task_a, task_b).await;
        assert!(matches!(
            reversed,
            Err(IntentError::CircularDependency { .. })
        ));
    }

    #[tokio::test]
    async fn test_add_dependency_task_not_found() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;

        let outcome = add_dependency(&pool, 9999, task_a).await;
        assert!(matches!(outcome, Err(IntentError::TaskNotFound(9999))));
    }

    #[tokio::test]
    async fn test_get_incomplete_blocking_tasks_blocked() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;

        add_dependency(&pool, task_b, task_a).await.unwrap();

        let incomplete = get_incomplete_blocking_tasks(&pool, task_a).await.unwrap();
        assert_eq!(incomplete, Some(vec![task_b]));
    }

    #[tokio::test]
    async fn test_get_incomplete_blocking_tasks_not_blocked() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;

        add_dependency(&pool, task_b, task_a).await.unwrap();

        // Finishing the blocker should release the blocked task.
        sqlx::query("UPDATE tasks SET status = 'done' WHERE id = ?")
            .bind(task_b)
            .execute(&pool)
            .await
            .unwrap();

        assert!(get_incomplete_blocking_tasks(&pool, task_a)
            .await
            .unwrap()
            .is_none());
    }
}