Skip to main content

brainwires_mdap/decomposition/
mod.rs

1//! Task Decomposition Module
2//!
3//! Implements task decomposition strategies for MDAP, breaking complex tasks
4//! into minimal subtasks that can be executed by microagents.
5//!
6//! The paper's approach is Maximal Agentic Decomposition (MAD), where
7//! each subtask should be as small as possible (m=1).
8
9pub mod recursive;
10
11// Re-export commonly used types from recursive
12pub use recursive::{BinaryRecursiveDecomposer, SimpleRecursiveDecomposer};
13
14use super::error::{DecompositionError, MdapResult};
15use super::microagent::Subtask;
16
/// Context for task decomposition.
///
/// Carries the settings (working directory, tools, depth limit, optional extra
/// context) and the current recursion depth that a decomposer receives.
#[derive(Clone, Debug)]
pub struct DecomposeContext {
    /// Working directory for file operations
    pub working_directory: String,
    /// Available tools for the agent
    pub available_tools: Vec<String>,
    /// Maximum decomposition depth
    pub max_depth: u32,
    /// Current depth in recursive decomposition
    pub current_depth: u32,
    /// Additional context/constraints
    pub additional_context: Option<String>,
}

impl Default for DecomposeContext {
    /// Working directory `"."`, no tools, `max_depth` of 10, depth 0 and no
    /// additional context.
    fn default() -> Self {
        Self {
            working_directory: ".".to_string(),
            available_tools: Vec::new(),
            max_depth: 10,
            current_depth: 0,
            additional_context: None,
        }
    }
}

impl DecomposeContext {
    /// Create a new context rooted at `working_directory`; every other field
    /// takes its [`Default`] value.
    pub fn new(working_directory: impl Into<String>) -> Self {
        Self {
            working_directory: working_directory.into(),
            ..Self::default()
        }
    }

    /// Replace the list of available tools and return the updated context.
    pub fn with_tools(mut self, tools: Vec<String>) -> Self {
        self.available_tools = tools;
        self
    }

    /// Set the maximum decomposition depth and return the updated context.
    pub fn with_max_depth(mut self, depth: u32) -> Self {
        self.max_depth = depth;
        self
    }

    /// Set the additional context string and return the updated context.
    pub fn with_context(mut self, context: impl Into<String>) -> Self {
        self.additional_context = Some(context.into());
        self
    }

    /// Create a child context: a copy of `self` whose `current_depth` is one
    /// level deeper.
    ///
    /// The depth saturates at `u32::MAX` rather than overflowing, so a child
    /// can never wrap back to depth 0 and appear to be below `max_depth` again.
    pub fn child(&self) -> Self {
        Self {
            current_depth: self.current_depth.saturating_add(1),
            ..self.clone()
        }
    }

    /// Returns `true` once `current_depth` has reached (or passed) `max_depth`.
    pub fn at_max_depth(&self) -> bool {
        self.current_depth >= self.max_depth
    }
}
87
88/// Result of task decomposition
89#[derive(Clone, Debug)]
90pub struct DecompositionResult {
91    /// The subtasks resulting from decomposition
92    pub subtasks: Vec<Subtask>,
93    /// How to combine results from subtasks
94    pub composition_function: CompositionFunction,
95    /// Whether the task is already minimal (cannot decompose further)
96    pub is_minimal: bool,
97    /// Estimated total complexity (sum of subtask complexities)
98    pub total_complexity: f32,
99}
100
101impl DecompositionResult {
102    /// Create a minimal (atomic) result
103    pub fn atomic(subtask: Subtask) -> Self {
104        let complexity = subtask.complexity_estimate;
105        Self {
106            subtasks: vec![subtask],
107            composition_function: CompositionFunction::Identity,
108            is_minimal: true,
109            total_complexity: complexity,
110        }
111    }
112
113    /// Create a result with multiple subtasks
114    pub fn composite(subtasks: Vec<Subtask>, composition: CompositionFunction) -> Self {
115        let total_complexity: f32 = subtasks.iter().map(|s| s.complexity_estimate).sum();
116        Self {
117            subtasks,
118            composition_function: composition,
119            is_minimal: false,
120            total_complexity,
121        }
122    }
123}
124
/// Describes how the results of several subtasks are combined into one.
#[derive(Clone, Debug)]
pub enum CompositionFunction {
    /// Single result, no composition needed
    Identity,
    /// Concatenate all results
    Concatenate,
    /// Merge as a sequence
    Sequence,
    /// Combine into an object with subtask IDs as keys
    ObjectMerge,
    /// Take the last result only
    LastOnly,
    /// Custom composition (described as a prompt)
    Custom(String),
    /// Reduce with an operation
    Reduce {
        /// The reduce operation to apply.
        operation: String,
    },
}

impl CompositionFunction {
    /// Returns a short human-readable description of this composition function.
    pub fn description(&self) -> String {
        // Variants carrying data need formatting and return directly; the rest
        // map to fixed text that is converted to an owned string once.
        let fixed_text = match self {
            Self::Custom(desc) => return format!("custom: {}", desc),
            Self::Reduce { operation } => return format!("reduce with {}", operation),
            Self::Identity => "identity (single result)",
            Self::Concatenate => "concatenate all results",
            Self::Sequence => "merge as sequence",
            Self::ObjectMerge => "merge into object",
            Self::LastOnly => "take last result",
        };
        fixed_text.to_owned()
    }
}
161
/// The strategies a decomposer can report via `TaskDecomposer::strategy`.
#[derive(Clone, Debug)]
pub enum DecompositionStrategy {
    /// Binary recursive decomposition (paper's approach for multiplication)
    BinaryRecursive {
        /// Maximum recursion depth.
        max_depth: u32,
    },
    /// Simple text-based decomposition (for testing)
    Simple {
        /// Maximum recursion depth.
        max_depth: u32,
    },
    /// Sequential step-by-step decomposition
    Sequential,
    /// Domain-specific decomposition for code operations
    CodeOperations,
    /// AI-driven decomposition with discriminator voting
    AIDriven {
        /// Number of discriminator votes (k).
        discriminator_k: u32,
    },
    /// No decomposition (execute as single task)
    None,
}

impl Default for DecompositionStrategy {
    /// The default strategy is binary recursive decomposition with a depth limit of 10.
    fn default() -> Self {
        Self::BinaryRecursive { max_depth: 10 }
    }
}
193
/// Trait for task decomposers.
///
/// Implementors must be `Send + Sync`. The trait is declared through
/// `async_trait` so that `decompose` can be an `async fn`.
#[async_trait::async_trait]
pub trait TaskDecomposer: Send + Sync {
    /// Decompose a task into subtasks.
    ///
    /// `task` is the task text and `context` carries the decomposition settings
    /// and current depth. On success the returned [`DecompositionResult`] holds
    /// the subtasks and the way their results are to be combined; failures are
    /// reported through `MdapResult`.
    async fn decompose(
        &self,
        task: &str,
        context: &DecomposeContext,
    ) -> MdapResult<DecompositionResult>;

    /// Check if a task is already minimal (cannot decompose further)
    fn is_minimal(&self, task: &str) -> bool;

    /// Get the decomposition strategy this decomposer implements
    fn strategy(&self) -> DecompositionStrategy;
}
210
/// Decomposer that turns a task into an ordered list of steps, capped at a
/// configurable number of steps.
pub struct SequentialDecomposer {
    max_steps: u32,
}

impl SequentialDecomposer {
    /// Build a sequential decomposer whose step limit is `max_steps`.
    pub fn new(max_steps: u32) -> Self {
        SequentialDecomposer { max_steps }
    }
}

impl Default for SequentialDecomposer {
    /// A sequential decomposer with a step limit of 20.
    fn default() -> Self {
        SequentialDecomposer { max_steps: 20 }
    }
}
228
229#[async_trait::async_trait]
230impl TaskDecomposer for SequentialDecomposer {
231    async fn decompose(
232        &self,
233        task: &str,
234        context: &DecomposeContext,
235    ) -> MdapResult<DecompositionResult> {
236        // Simple heuristic: if task has numbered steps, extract them
237        let lines: Vec<&str> = task.lines().collect();
238        let mut subtasks = Vec::new();
239
240        for (i, line) in lines.iter().enumerate() {
241            let trimmed = line.trim();
242            if trimmed.is_empty() {
243                continue;
244            }
245
246            // Check if line starts with a number
247            let is_numbered = trimmed
248                .chars()
249                .next()
250                .map(|c| c.is_ascii_digit())
251                .unwrap_or(false);
252
253            if is_numbered || subtasks.is_empty() {
254                let subtask = Subtask::new(
255                    format!("step_{}", i + 1),
256                    trimmed.to_string(),
257                    serde_json::json!({
258                        "step": i + 1,
259                        "context": context.additional_context
260                    }),
261                )
262                .with_complexity(1.0 / lines.len() as f32);
263
264                subtasks.push(subtask);
265            }
266
267            if subtasks.len() >= self.max_steps as usize {
268                break;
269            }
270        }
271
272        if subtasks.is_empty() {
273            // Treat as single task
274            let subtask = Subtask::atomic(task);
275            return Ok(DecompositionResult::atomic(subtask));
276        }
277
278        // Add dependencies (each step depends on previous)
279        for i in 1..subtasks.len() {
280            let prev_id = subtasks[i - 1].id.clone();
281            subtasks[i].depends_on.push(prev_id);
282        }
283
284        Ok(DecompositionResult::composite(
285            subtasks,
286            CompositionFunction::Sequence,
287        ))
288    }
289
290    fn is_minimal(&self, task: &str) -> bool {
291        // Consider minimal if single line and short
292        !task.contains('\n') && task.len() < 200
293    }
294
295    fn strategy(&self) -> DecompositionStrategy {
296        DecompositionStrategy::Sequential
297    }
298}
299
300/// No-op decomposer that treats everything as atomic
301pub struct AtomicDecomposer;
302
303#[async_trait::async_trait]
304impl TaskDecomposer for AtomicDecomposer {
305    async fn decompose(
306        &self,
307        task: &str,
308        _context: &DecomposeContext,
309    ) -> MdapResult<DecompositionResult> {
310        Ok(DecompositionResult::atomic(Subtask::atomic(task)))
311    }
312
313    fn is_minimal(&self, _task: &str) -> bool {
314        true
315    }
316
317    fn strategy(&self) -> DecompositionStrategy {
318        DecompositionStrategy::None
319    }
320}
321
322/// Validate decomposition result
323pub fn validate_decomposition(result: &DecompositionResult) -> MdapResult<()> {
324    if result.subtasks.is_empty() {
325        return Err(DecompositionError::EmptyResult(
326            "Decomposition produced no subtasks".to_string(),
327        )
328        .into());
329    }
330
331    // Check for circular dependencies
332    let mut visited = std::collections::HashSet::new();
333    for subtask in &result.subtasks {
334        visited.insert(subtask.id.clone());
335    }
336
337    for subtask in &result.subtasks {
338        for dep in &subtask.depends_on {
339            if !visited.contains(dep) {
340                return Err(DecompositionError::InvalidDependency {
341                    subtask: subtask.id.clone(),
342                    dependency: dep.clone(),
343                }
344                .into());
345            }
346        }
347    }
348
349    Ok(())
350}
351
352/// Order subtasks by dependencies (topological sort)
353pub fn topological_sort(subtasks: &[Subtask]) -> MdapResult<Vec<Subtask>> {
354    use std::collections::{HashMap, VecDeque};
355
356    let mut in_degree: HashMap<String, usize> = HashMap::new();
357    let mut graph: HashMap<String, Vec<String>> = HashMap::new();
358
359    // Initialize
360    for subtask in subtasks {
361        in_degree.insert(subtask.id.clone(), subtask.depends_on.len());
362        graph.insert(subtask.id.clone(), Vec::new());
363    }
364
365    // Build reverse graph (who depends on whom)
366    for subtask in subtasks {
367        for dep in &subtask.depends_on {
368            if let Some(dependents) = graph.get_mut(dep) {
369                dependents.push(subtask.id.clone());
370            }
371        }
372    }
373
374    // Find all subtasks with no dependencies
375    let mut queue: VecDeque<String> = in_degree
376        .iter()
377        .filter(|(_, deg)| **deg == 0)
378        .map(|(id, _)| id.clone())
379        .collect();
380
381    let mut result = Vec::new();
382    let subtask_map: HashMap<_, _> = subtasks.iter().map(|s| (s.id.clone(), s.clone())).collect();
383
384    while let Some(id) = queue.pop_front() {
385        if let Some(subtask) = subtask_map.get(&id) {
386            result.push(subtask.clone());
387        }
388
389        if let Some(dependents) = graph.get(&id) {
390            for dependent in dependents {
391                if let Some(deg) = in_degree.get_mut(dependent) {
392                    *deg -= 1;
393                    if *deg == 0 {
394                        queue.push_back(dependent.clone());
395                    }
396                }
397            }
398        }
399    }
400
401    if result.len() != subtasks.len() {
402        return Err(DecompositionError::CircularDependency(
403            "Circular dependency detected in subtasks".to_string(),
404        )
405        .into());
406    }
407
408    Ok(result)
409}
410
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a subtask with the given id and, if any, the given dependency ids.
    fn node(id: &str, deps: &[&str]) -> Subtask {
        let subtask = Subtask::new(
            id.to_string(),
            format!("Task {}", id.to_uppercase()),
            serde_json::Value::Null,
        );
        if deps.is_empty() {
            subtask
        } else {
            subtask.depends_on(deps.iter().map(|dep| dep.to_string()).collect::<Vec<String>>())
        }
    }

    /// Position of the subtask with `id` in `sorted`; panics if it is absent.
    fn position(sorted: &[Subtask], id: &str) -> usize {
        sorted.iter().position(|s| s.id == id).unwrap()
    }

    #[test]
    fn test_decompose_context() {
        let tools = vec!["read".to_string(), "write".to_string()];
        let ctx = DecomposeContext::new("/home/user/project")
            .with_tools(tools)
            .with_max_depth(5);

        assert_eq!(ctx.working_directory, "/home/user/project");
        assert_eq!(ctx.available_tools.len(), 2);
        assert_eq!(ctx.max_depth, 5);
    }

    #[test]
    fn test_context_child() {
        let child = DecomposeContext::new("/home").with_max_depth(5).child();

        assert_eq!(child.current_depth, 1);
        assert_eq!(child.max_depth, 5);
    }

    #[test]
    fn test_decomposition_result_atomic() {
        let result = DecompositionResult::atomic(Subtask::atomic("Test"));

        assert!(result.is_minimal);
        assert_eq!(result.subtasks.len(), 1);
    }

    #[test]
    fn test_topological_sort_simple() {
        let chain = vec![node("a", &[]), node("b", &["a"]), node("c", &["b"])];

        let sorted = topological_sort(&chain).unwrap();
        assert_eq!(sorted[0].id, "a");
        assert_eq!(sorted[1].id, "b");
        assert_eq!(sorted[2].id, "c");
    }

    #[test]
    fn test_topological_sort_parallel() {
        let fan_in = vec![node("a", &[]), node("b", &[]), node("c", &["a", "b"])];

        let sorted = topological_sort(&fan_in).unwrap();
        // a and b should come before c
        let c_pos = position(&sorted, "c");
        assert!(position(&sorted, "a") < c_pos);
        assert!(position(&sorted, "b") < c_pos);
    }

    #[test]
    fn test_topological_sort_circular() {
        let cycle = vec![node("a", &["c"]), node("b", &["a"]), node("c", &["b"])];

        assert!(topological_sort(&cycle).is_err());
    }

    #[tokio::test]
    async fn test_atomic_decomposer() {
        let ctx = DecomposeContext::default();
        let result = AtomicDecomposer.decompose("Test task", &ctx).await.unwrap();

        assert!(result.is_minimal);
        assert_eq!(result.subtasks.len(), 1);
    }

    #[tokio::test]
    async fn test_sequential_decomposer() {
        let ctx = DecomposeContext::default();
        let task = "1. First step\n2. Second step\n3. Third step";
        let result = SequentialDecomposer::new(10)
            .decompose(task, &ctx)
            .await
            .unwrap();

        assert_eq!(result.subtasks.len(), 3);
        assert!(!result.is_minimal);
    }

    #[test]
    fn test_validate_decomposition_valid() {
        let result = DecompositionResult::composite(
            vec![node("a", &[]), node("b", &["a"])],
            CompositionFunction::Sequence,
        );

        assert!(validate_decomposition(&result).is_ok());
    }

    #[test]
    fn test_validate_decomposition_invalid_dep() {
        let result = DecompositionResult::composite(
            vec![node("a", &["nonexistent"])],
            CompositionFunction::Sequence,
        );

        assert!(validate_decomposition(&result).is_err());
    }
}