1use crate::db::models::Dependency;
2use crate::error::{IntentError, Result};
3use sqlx::SqlitePool;
4
5pub async fn check_circular_dependency(
45 pool: &SqlitePool,
46 blocking_task_id: i64,
47 blocked_task_id: i64,
48) -> Result<bool> {
49 if blocking_task_id == blocked_task_id {
51 return Ok(true);
52 }
53
54 let has_cycle: bool = sqlx::query_scalar(
57 r#"
58 WITH RECURSIVE dep_chain(task_id, depth) AS (
59 -- Start from the NEW blocking task
60 SELECT ? as task_id, 0 as depth
61
62 UNION ALL
63
64 -- Follow what each task depends on (its blocking tasks)
65 SELECT d.blocking_task_id, dc.depth + 1
66 FROM dependencies d
67 JOIN dep_chain dc ON d.blocked_task_id = dc.task_id
68 WHERE dc.depth < 100
69 )
70 SELECT COUNT(*) > 0
71 FROM dep_chain
72 WHERE task_id = ?
73 "#,
74 )
75 .bind(blocking_task_id)
76 .bind(blocked_task_id)
77 .fetch_one(pool)
78 .await?;
79
80 Ok(has_cycle)
81}
82
83pub async fn add_dependency(
102 pool: &SqlitePool,
103 blocking_task_id: i64,
104 blocked_task_id: i64,
105) -> Result<Dependency> {
106 let blocking_exists: bool = sqlx::query_scalar("SELECT COUNT(*) > 0 FROM tasks WHERE id = ?")
108 .bind(blocking_task_id)
109 .fetch_one(pool)
110 .await?;
111
112 if !blocking_exists {
113 return Err(IntentError::TaskNotFound(blocking_task_id));
114 }
115
116 let blocked_exists: bool = sqlx::query_scalar("SELECT COUNT(*) > 0 FROM tasks WHERE id = ?")
117 .bind(blocked_task_id)
118 .fetch_one(pool)
119 .await?;
120
121 if !blocked_exists {
122 return Err(IntentError::TaskNotFound(blocked_task_id));
123 }
124
125 if check_circular_dependency(pool, blocking_task_id, blocked_task_id).await? {
127 return Err(IntentError::CircularDependency {
128 blocking_task_id,
129 blocked_task_id,
130 });
131 }
132
133 let result = sqlx::query(
135 r#"
136 INSERT INTO dependencies (blocking_task_id, blocked_task_id)
137 VALUES (?, ?)
138 "#,
139 )
140 .bind(blocking_task_id)
141 .bind(blocked_task_id)
142 .execute(pool)
143 .await?;
144
145 let dependency_id = result.last_insert_rowid();
146
147 let dependency = sqlx::query_as::<_, Dependency>(
149 "SELECT id, blocking_task_id, blocked_task_id, created_at FROM dependencies WHERE id = ?",
150 )
151 .bind(dependency_id)
152 .fetch_one(pool)
153 .await?;
154
155 Ok(dependency)
156}
157
158pub async fn get_incomplete_blocking_tasks(
172 pool: &SqlitePool,
173 task_id: i64,
174) -> Result<Option<Vec<i64>>> {
175 let incomplete_blocking: Vec<i64> = sqlx::query_scalar::<_, i64>(
176 r#"
177 SELECT d.blocking_task_id
178 FROM dependencies d
179 JOIN tasks t ON t.id = d.blocking_task_id
180 WHERE d.blocked_task_id = ?
181 AND t.status IN ('todo', 'doing')
182 "#,
183 )
184 .bind(task_id)
185 .fetch_all(pool)
186 .await?;
187
188 if incomplete_blocking.is_empty() {
189 Ok(None)
190 } else {
191 Ok(Some(incomplete_blocking))
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::{create_pool, run_migrations};
    use tempfile::TempDir;

    /// Creates a fresh, migrated database in a temporary directory.
    ///
    /// The `TempDir` is returned so the caller keeps it alive for the test's
    /// duration.
    async fn setup_test_db() -> (TempDir, SqlitePool) {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let pool = create_pool(&db_path).await.unwrap();
        run_migrations(&pool).await.unwrap();
        (temp_dir, pool)
    }

    /// Inserts a task with status `'todo'` and returns its id.
    async fn create_test_task(pool: &SqlitePool, name: &str) -> i64 {
        sqlx::query("INSERT INTO tasks (name, status) VALUES (?, 'todo')")
            .bind(name)
            .execute(pool)
            .await
            .unwrap()
            .last_insert_rowid()
    }

    /// Inserts a dependency row directly, bypassing `add_dependency`'s
    /// validation. Tests use it to set up a graph and then probe
    /// `check_circular_dependency`.
    async fn insert_raw_dependency(pool: &SqlitePool, blocking_task_id: i64, blocked_task_id: i64) {
        sqlx::query("INSERT INTO dependencies (blocking_task_id, blocked_task_id) VALUES (?, ?)")
            .bind(blocking_task_id)
            .bind(blocked_task_id)
            .execute(pool)
            .await
            .unwrap();
    }

    /// Overwrites a task's status.
    async fn set_task_status(pool: &SqlitePool, task_id: i64, status: &str) {
        sqlx::query("UPDATE tasks SET status = ? WHERE id = ?")
            .bind(status)
            .bind(task_id)
            .execute(pool)
            .await
            .unwrap();
    }

    /// Counts all rows in the dependencies table.
    async fn count_dependencies(pool: &SqlitePool) -> i64 {
        sqlx::query_scalar::<_, i64>("SELECT COUNT(*) FROM dependencies")
            .fetch_one(pool)
            .await
            .unwrap()
    }

    #[tokio::test]
    async fn test_check_circular_dependency_self() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;

        let is_circular = check_circular_dependency(&pool, task_a, task_a)
            .await
            .unwrap();
        assert!(is_circular);
    }

    #[tokio::test]
    async fn test_check_circular_dependency_direct_cycle() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;

        // B blocks A, so "A blocks B" would close a loop.
        insert_raw_dependency(&pool, task_b, task_a).await;

        let is_circular = check_circular_dependency(&pool, task_a, task_b)
            .await
            .unwrap();
        assert!(is_circular);
    }

    #[tokio::test]
    async fn test_check_circular_dependency_transitive_cycle() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;
        let task_c = create_test_task(&pool, "Task C").await;

        // C blocks B blocks A, so "A blocks C" would close a loop.
        insert_raw_dependency(&pool, task_b, task_a).await;
        insert_raw_dependency(&pool, task_c, task_b).await;

        let is_circular = check_circular_dependency(&pool, task_a, task_c)
            .await
            .unwrap();
        assert!(is_circular);
    }

    #[tokio::test]
    async fn test_check_circular_dependency_no_cycle() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;
        let task_c = create_test_task(&pool, "Task C").await;

        insert_raw_dependency(&pool, task_b, task_a).await;

        let is_circular = check_circular_dependency(&pool, task_a, task_c)
            .await
            .unwrap();
        assert!(!is_circular);
    }

    #[tokio::test]
    async fn test_check_circular_dependency_deep_chain() {
        let (_temp, pool) = setup_test_db().await;

        let mut chain = Vec::new();
        for name in ["Task A", "Task B", "Task C", "Task D", "Task E"] {
            chain.push(create_test_task(&pool, name).await);
        }

        // Each later task blocks the one before it: B blocks A, C blocks B, ...
        for pair in chain.windows(2) {
            insert_raw_dependency(&pool, pair[1], pair[0]).await;
        }

        let is_circular = check_circular_dependency(&pool, chain[0], chain[4])
            .await
            .unwrap();
        assert!(is_circular);
    }

    #[tokio::test]
    async fn test_check_circular_dependency_diamond() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;
        let task_c = create_test_task(&pool, "Task C").await;
        let task_d = create_test_task(&pool, "Task D").await;
        let task_e = create_test_task(&pool, "Task E").await;

        // D blocks both B and C, which both block A (two paths from D to A).
        insert_raw_dependency(&pool, task_d, task_b).await;
        insert_raw_dependency(&pool, task_d, task_c).await;
        insert_raw_dependency(&pool, task_b, task_a).await;
        insert_raw_dependency(&pool, task_c, task_a).await;

        // A blocking an unrelated task is fine.
        let unrelated = check_circular_dependency(&pool, task_a, task_e)
            .await
            .unwrap();
        assert!(!unrelated);

        // A blocking its transitive blocker D would close a loop.
        let back_edge = check_circular_dependency(&pool, task_a, task_d)
            .await
            .unwrap();
        assert!(back_edge);
    }

    #[tokio::test]
    async fn test_add_dependency_success() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;

        let dep = add_dependency(&pool, task_b, task_a).await.unwrap();

        assert_eq!(dep.blocking_task_id, task_b);
        assert_eq!(dep.blocked_task_id, task_a);
    }

    #[tokio::test]
    async fn test_add_dependency_circular_error() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;

        add_dependency(&pool, task_b, task_a).await.unwrap();

        let result = add_dependency(&pool, task_a, task_b).await;
        assert!(matches!(
            result,
            Err(IntentError::CircularDependency { .. })
        ));

        // The rejected edge must not have been stored.
        assert_eq!(count_dependencies(&pool).await, 1);
    }

    #[tokio::test]
    async fn test_add_dependency_self_is_circular() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;

        let result = add_dependency(&pool, task_a, task_a).await;
        assert!(matches!(
            result,
            Err(IntentError::CircularDependency { .. })
        ));
        assert_eq!(count_dependencies(&pool).await, 0);
    }

    #[tokio::test]
    async fn test_add_dependency_task_not_found() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;

        let result = add_dependency(&pool, 9999, task_a).await;
        assert!(matches!(result, Err(IntentError::TaskNotFound(9999))));
    }

    #[tokio::test]
    async fn test_add_dependency_blocked_task_not_found() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;

        let result = add_dependency(&pool, task_a, 9999).await;
        assert!(matches!(result, Err(IntentError::TaskNotFound(9999))));
    }

    #[tokio::test]
    async fn test_get_incomplete_blocking_tasks_blocked() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;

        add_dependency(&pool, task_b, task_a).await.unwrap();

        let incomplete = get_incomplete_blocking_tasks(&pool, task_a).await.unwrap();
        assert!(incomplete.is_some());
        assert_eq!(incomplete.unwrap(), vec![task_b]);
    }

    #[tokio::test]
    async fn test_get_incomplete_blocking_tasks_doing_counts_as_incomplete() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;

        add_dependency(&pool, task_b, task_a).await.unwrap();
        set_task_status(&pool, task_b, "doing").await;

        let incomplete = get_incomplete_blocking_tasks(&pool, task_a).await.unwrap();
        assert_eq!(incomplete, Some(vec![task_b]));
    }

    #[tokio::test]
    async fn test_get_incomplete_blocking_tasks_not_blocked() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;

        add_dependency(&pool, task_b, task_a).await.unwrap();
        set_task_status(&pool, task_b, "done").await;

        let incomplete = get_incomplete_blocking_tasks(&pool, task_a).await.unwrap();
        assert!(incomplete.is_none());
    }
}