intent_engine/
dependencies.rs

1use crate::db::models::Dependency;
2use crate::error::{IntentError, Result};
3use sqlx::SqlitePool;
4
5/// Check if adding a dependency would create a circular dependency.
6///
7/// This function implements a depth-first search using SQLite's recursive CTE
8/// to detect cycles in the dependency graph.
9///
10/// # Algorithm
11///
12/// To check if we can add "blocked_task depends on blocking_task":
13/// 1. Start from blocking_task (the new prerequisite)
14/// 2. Traverse what blocking_task depends on (its blocking tasks)
15/// 3. If we ever reach blocked_task, adding this dependency would create a cycle
16///
17/// # Example
18///
19/// Existing: A depends on B (stored as: blocking=B, blocked=A)
20/// Trying to add: B depends on A (would be: blocking=A, blocked=B)
21///
22/// Check: Does A depend on B?
23/// - Start from A (new blocking task)
24/// - Find what A depends on: B
25/// - We reached B (new blocked task) → Cycle detected!
26///
27/// # Performance
28///
29/// - Time complexity: O(V + E) where V is tasks and E is dependencies
30/// - Expected: <10ms for graphs with 10,000 tasks
31/// - Depth limit: 100 levels to prevent infinite loops
32///
33/// # Arguments
34///
35/// * `pool` - Database connection pool
36/// * `blocking_task_id` - ID of the task that must be completed first
37/// * `blocked_task_id` - ID of the task that depends on the blocking task
38///
39/// # Returns
40///
41/// - `Ok(true)` if adding this dependency would create a cycle
42/// - `Ok(false)` if the dependency is safe to add
43/// - `Err` if database query fails
44pub async fn check_circular_dependency(
45    pool: &SqlitePool,
46    blocking_task_id: i64,
47    blocked_task_id: i64,
48) -> Result<bool> {
49    // Self-dependency is always circular (but should be prevented by DB constraint)
50    if blocking_task_id == blocked_task_id {
51        return Ok(true);
52    }
53
54    // Check if blocking_task already (transitively) depends on blocked_task
55    // If yes, adding "blocked depends on blocking" would create a cycle
56    let has_cycle: bool = sqlx::query_scalar(
57        r#"
58        WITH RECURSIVE dep_chain(task_id, depth) AS (
59            -- Start from the NEW blocking task
60            SELECT ? as task_id, 0 as depth
61
62            UNION ALL
63
64            -- Follow what each task depends on (its blocking tasks)
65            SELECT d.blocking_task_id, dc.depth + 1
66            FROM dependencies d
67            JOIN dep_chain dc ON d.blocked_task_id = dc.task_id
68            WHERE dc.depth < 100
69        )
70        SELECT COUNT(*) > 0
71        FROM dep_chain
72        WHERE task_id = ?
73        "#,
74    )
75    .bind(blocking_task_id)
76    .bind(blocked_task_id)
77    .fetch_one(pool)
78    .await?;
79
80    Ok(has_cycle)
81}
82
83/// Add a dependency between two tasks after checking for circular dependencies.
84///
85/// This is the safe way to add dependencies. It will:
86/// 1. Verify both tasks exist
87/// 2. Check for circular dependencies
88/// 3. Add the dependency if safe
89///
90/// # Arguments
91///
92/// * `pool` - Database connection pool
93/// * `blocking_task_id` - ID of the task that must be completed first
94/// * `blocked_task_id` - ID of the task that depends on the blocking task
95///
96/// # Returns
97///
98/// - `Ok(Dependency)` if the dependency was added successfully
99/// - `Err(IntentError::CircularDependency)` if adding would create a cycle
100/// - `Err(IntentError::TaskNotFound)` if either task doesn't exist
101pub async fn add_dependency(
102    pool: &SqlitePool,
103    blocking_task_id: i64,
104    blocked_task_id: i64,
105) -> Result<Dependency> {
106    // Verify both tasks exist
107    let blocking_exists: bool = sqlx::query_scalar("SELECT COUNT(*) > 0 FROM tasks WHERE id = ?")
108        .bind(blocking_task_id)
109        .fetch_one(pool)
110        .await?;
111
112    if !blocking_exists {
113        return Err(IntentError::TaskNotFound(blocking_task_id));
114    }
115
116    let blocked_exists: bool = sqlx::query_scalar("SELECT COUNT(*) > 0 FROM tasks WHERE id = ?")
117        .bind(blocked_task_id)
118        .fetch_one(pool)
119        .await?;
120
121    if !blocked_exists {
122        return Err(IntentError::TaskNotFound(blocked_task_id));
123    }
124
125    // Check for circular dependency
126    if check_circular_dependency(pool, blocking_task_id, blocked_task_id).await? {
127        return Err(IntentError::CircularDependency {
128            blocking_task_id,
129            blocked_task_id,
130        });
131    }
132
133    // Add the dependency
134    let result = sqlx::query(
135        r#"
136        INSERT INTO dependencies (blocking_task_id, blocked_task_id)
137        VALUES (?, ?)
138        "#,
139    )
140    .bind(blocking_task_id)
141    .bind(blocked_task_id)
142    .execute(pool)
143    .await?;
144
145    let dependency_id = result.last_insert_rowid();
146
147    // Fetch the created dependency
148    let dependency = sqlx::query_as::<_, Dependency>(
149        "SELECT id, blocking_task_id, blocked_task_id, created_at FROM dependencies WHERE id = ?",
150    )
151    .bind(dependency_id)
152    .fetch_one(pool)
153    .await?;
154
155    Ok(dependency)
156}
157
158/// Get all tasks that block a given task (dependencies)
159///
160/// # Arguments
161///
162/// * `pool` - Database connection pool
163/// * `task_id` - ID of the task to check
164///
165/// # Returns
166///
167/// Vector of task IDs that must be completed before the given task can start
168pub async fn get_blocking_tasks(pool: &SqlitePool, task_id: i64) -> Result<Vec<i64>> {
169    let blocking_ids = sqlx::query_scalar::<_, i64>(
170        r#"
171        SELECT blocking_task_id
172        FROM dependencies
173        WHERE blocked_task_id = ?
174        "#,
175    )
176    .bind(task_id)
177    .fetch_all(pool)
178    .await?;
179
180    Ok(blocking_ids)
181}
182
183/// Get all tasks that are blocked by a given task
184///
185/// # Arguments
186///
187/// * `pool` - Database connection pool
188/// * `task_id` - ID of the task to check
189///
190/// # Returns
191///
192/// Vector of task IDs that depend on the given task
193pub async fn get_blocked_tasks(pool: &SqlitePool, task_id: i64) -> Result<Vec<i64>> {
194    let blocked_ids = sqlx::query_scalar::<_, i64>(
195        r#"
196        SELECT blocked_task_id
197        FROM dependencies
198        WHERE blocking_task_id = ?
199        "#,
200    )
201    .bind(task_id)
202    .fetch_all(pool)
203    .await?;
204
205    Ok(blocked_ids)
206}
207
208/// Check if a task is blocked by any incomplete tasks
209///
210/// A task is blocked if any of its blocking tasks are not in 'done' status.
211///
212/// # Arguments
213///
214/// * `pool` - Database connection pool
215/// * `task_id` - ID of the task to check
216///
217/// # Returns
218///
219/// - `Ok(Some(Vec<i64>))` with IDs of incomplete blocking tasks if blocked
220/// - `Ok(None)` if task is not blocked and can be started
221pub async fn get_incomplete_blocking_tasks(
222    pool: &SqlitePool,
223    task_id: i64,
224) -> Result<Option<Vec<i64>>> {
225    let incomplete_blocking: Vec<i64> = sqlx::query_scalar::<_, i64>(
226        r#"
227        SELECT d.blocking_task_id
228        FROM dependencies d
229        JOIN tasks t ON t.id = d.blocking_task_id
230        WHERE d.blocked_task_id = ?
231          AND t.status IN ('todo', 'doing')
232        "#,
233    )
234    .bind(task_id)
235    .fetch_all(pool)
236    .await?;
237
238    if incomplete_blocking.is_empty() {
239        Ok(None)
240    } else {
241        Ok(Some(incomplete_blocking))
242    }
243}
244
#[cfg(test)]
mod tests {
    use super::*;
    use crate::db::{create_pool, run_migrations};
    use tempfile::TempDir;

    /// Create a migrated database inside a fresh temporary directory.
    ///
    /// The `TempDir` is returned alongside the pool so it stays alive for the
    /// whole test; the directory is removed when it is dropped.
    async fn setup_test_db() -> (TempDir, SqlitePool) {
        let temp_dir = TempDir::new().unwrap();
        let db_path = temp_dir.path().join("test.db");
        let pool = create_pool(&db_path).await.unwrap();
        run_migrations(&pool).await.unwrap();
        (temp_dir, pool)
    }

    /// Insert a task with status `'todo'` and return its row id.
    async fn create_test_task(pool: &SqlitePool, name: &str) -> i64 {
        sqlx::query("INSERT INTO tasks (name, status) VALUES (?, 'todo')")
            .bind(name)
            .execute(pool)
            .await
            .unwrap()
            .last_insert_rowid()
    }

    /// Insert a raw dependency row meaning "`blocked` depends on `blocking`".
    ///
    /// This bypasses `add_dependency`'s existence and cycle checks so tests can
    /// build arbitrary graphs before probing them.
    async fn insert_raw_dependency(pool: &SqlitePool, blocking: i64, blocked: i64) {
        sqlx::query("INSERT INTO dependencies (blocking_task_id, blocked_task_id) VALUES (?, ?)")
            .bind(blocking)
            .bind(blocked)
            .execute(pool)
            .await
            .unwrap();
    }

    /// Overwrite a task's status directly.
    async fn set_task_status(pool: &SqlitePool, task_id: i64, status: &str) {
        sqlx::query("UPDATE tasks SET status = ? WHERE id = ?")
            .bind(status)
            .bind(task_id)
            .execute(pool)
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn test_check_circular_dependency_self() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;

        // A task depending on itself is always circular.
        let is_circular = check_circular_dependency(&pool, task_a, task_a)
            .await
            .unwrap();
        assert!(is_circular);
    }

    #[tokio::test]
    async fn test_check_circular_dependency_direct_cycle() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;

        // Existing: A depends on B.
        insert_raw_dependency(&pool, task_b, task_a).await;

        // Proposed: B depends on A, which would close A -> B -> A.
        let is_circular = check_circular_dependency(&pool, task_a, task_b)
            .await
            .unwrap();
        assert!(is_circular);
    }

    #[tokio::test]
    async fn test_check_circular_dependency_transitive_cycle() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;
        let task_c = create_test_task(&pool, "Task C").await;

        // Existing: A depends on B, B depends on C.
        insert_raw_dependency(&pool, task_b, task_a).await;
        insert_raw_dependency(&pool, task_c, task_b).await;

        // Proposed: C depends on A, which would close A -> B -> C -> A.
        let is_circular = check_circular_dependency(&pool, task_a, task_c)
            .await
            .unwrap();
        assert!(is_circular);
    }

    #[tokio::test]
    async fn test_check_circular_dependency_no_cycle() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;
        let task_c = create_test_task(&pool, "Task C").await;

        // Existing: A depends on B.
        insert_raw_dependency(&pool, task_b, task_a).await;

        // Proposed: C depends on A. A never reaches C, so no cycle.
        let is_circular = check_circular_dependency(&pool, task_a, task_c)
            .await
            .unwrap();
        assert!(!is_circular);
    }

    #[tokio::test]
    async fn test_check_circular_dependency_deep_chain() {
        let (_temp, pool) = setup_test_db().await;
        let mut chain = Vec::new();
        for name in ["Task A", "Task B", "Task C", "Task D", "Task E"] {
            chain.push(create_test_task(&pool, name).await);
        }

        // Existing: each task depends on the next one (A -> B -> C -> D -> E).
        for pair in chain.windows(2) {
            insert_raw_dependency(&pool, pair[1], pair[0]).await;
        }

        // Proposed: E depends on A, which would close the long cycle.
        let (first, last) = (chain[0], chain[chain.len() - 1]);
        let is_circular = check_circular_dependency(&pool, first, last)
            .await
            .unwrap();
        assert!(is_circular);
    }

    #[tokio::test]
    async fn test_check_circular_dependency_diamond() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;
        let task_c = create_test_task(&pool, "Task C").await;
        let task_d = create_test_task(&pool, "Task D").await;
        let task_e = create_test_task(&pool, "Task E").await;

        // Diamond: A depends on B and C; both B and C depend on D.
        insert_raw_dependency(&pool, task_b, task_a).await;
        insert_raw_dependency(&pool, task_c, task_a).await;
        insert_raw_dependency(&pool, task_d, task_b).await;
        insert_raw_dependency(&pool, task_d, task_c).await;

        // E depending on A is safe: nothing below A reaches E.
        let unrelated = check_circular_dependency(&pool, task_a, task_e)
            .await
            .unwrap();
        assert!(!unrelated);

        // D depending on A would close A -> B -> D -> A (and A -> C -> D -> A).
        let closes_diamond = check_circular_dependency(&pool, task_a, task_d)
            .await
            .unwrap();
        assert!(closes_diamond);
    }

    #[tokio::test]
    async fn test_add_dependency_success() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;

        let dep = add_dependency(&pool, task_b, task_a).await.unwrap();

        assert_eq!(dep.blocking_task_id, task_b);
        assert_eq!(dep.blocked_task_id, task_a);
    }

    #[tokio::test]
    async fn test_add_dependency_circular_error() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;

        // A depends on B.
        add_dependency(&pool, task_b, task_a).await.unwrap();

        // B depending on A would be circular.
        let result = add_dependency(&pool, task_a, task_b).await;
        assert!(matches!(
            result,
            Err(IntentError::CircularDependency { .. })
        ));
    }

    #[tokio::test]
    async fn test_add_dependency_self_is_circular() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;

        let result = add_dependency(&pool, task_a, task_a).await;
        assert!(matches!(
            result,
            Err(IntentError::CircularDependency { .. })
        ));
    }

    #[tokio::test]
    async fn test_add_dependency_task_not_found() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;

        // Missing blocking task.
        let result = add_dependency(&pool, 9999, task_a).await;
        assert!(matches!(result, Err(IntentError::TaskNotFound(9999))));
    }

    #[tokio::test]
    async fn test_add_dependency_blocked_task_not_found() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;

        // Missing blocked task: exercises the second existence check.
        let result = add_dependency(&pool, task_a, 9999).await;
        assert!(matches!(result, Err(IntentError::TaskNotFound(9999))));
    }

    #[tokio::test]
    async fn test_get_blocking_tasks() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;
        let task_c = create_test_task(&pool, "Task C").await;

        // A depends on B and C.
        add_dependency(&pool, task_b, task_a).await.unwrap();
        add_dependency(&pool, task_c, task_a).await.unwrap();

        let blocking = get_blocking_tasks(&pool, task_a).await.unwrap();
        assert_eq!(blocking.len(), 2);
        assert!(blocking.contains(&task_b));
        assert!(blocking.contains(&task_c));
    }

    #[tokio::test]
    async fn test_get_blocked_tasks() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;
        let task_c = create_test_task(&pool, "Task C").await;

        // B and C depend on A.
        add_dependency(&pool, task_a, task_b).await.unwrap();
        add_dependency(&pool, task_a, task_c).await.unwrap();

        let blocked = get_blocked_tasks(&pool, task_a).await.unwrap();
        assert_eq!(blocked.len(), 2);
        assert!(blocked.contains(&task_b));
        assert!(blocked.contains(&task_c));
    }

    #[tokio::test]
    async fn test_get_incomplete_blocking_tasks_blocked() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;

        // A depends on B, and B is still 'todo'.
        add_dependency(&pool, task_b, task_a).await.unwrap();

        let incomplete = get_incomplete_blocking_tasks(&pool, task_a).await.unwrap();
        assert_eq!(incomplete, Some(vec![task_b]));
    }

    #[tokio::test]
    async fn test_get_incomplete_blocking_tasks_doing_still_blocks() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;

        // A depends on B, and B is in progress.
        add_dependency(&pool, task_b, task_a).await.unwrap();
        set_task_status(&pool, task_b, "doing").await;

        let incomplete = get_incomplete_blocking_tasks(&pool, task_a).await.unwrap();
        assert_eq!(incomplete, Some(vec![task_b]));
    }

    #[tokio::test]
    async fn test_get_incomplete_blocking_tasks_not_blocked() {
        let (_temp, pool) = setup_test_db().await;
        let task_a = create_test_task(&pool, "Task A").await;
        let task_b = create_test_task(&pool, "Task B").await;

        // A depends on B, but B is done.
        add_dependency(&pool, task_b, task_a).await.unwrap();
        set_task_status(&pool, task_b, "done").await;

        let incomplete = get_incomplete_blocking_tasks(&pool, task_a).await.unwrap();
        assert!(incomplete.is_none());
    }
}