brainwires_mdap/decomposition/
mod.rs1pub mod recursive;
10
11pub use recursive::{BinaryRecursiveDecomposer, SimpleRecursiveDecomposer};
13
14use super::error::{DecompositionError, MdapResult};
15use super::microagent::Subtask;
16
/// Environment passed to a [`TaskDecomposer`] for a single decomposition call.
///
/// Carries the working directory, the tools a subtask may use, the recursion
/// budget (`current_depth` / `max_depth`) and optional free-form context.
#[derive(Clone, Debug)]
pub struct DecomposeContext {
    /// Directory the task operates in (defaults to `"."`).
    pub working_directory: String,
    /// Names of tools available to subtasks (empty by default).
    pub available_tools: Vec<String>,
    /// Depth at which [`DecomposeContext::at_max_depth`] starts returning `true` (default 10).
    pub max_depth: u32,
    /// Depth of this context; incremented by [`DecomposeContext::child`] (default 0).
    pub current_depth: u32,
    /// Optional extra context supplied by the caller.
    pub additional_context: Option<String>,
}

impl Default for DecomposeContext {
    /// Context rooted at `"."` with no tools, depth 0 of a 10-level budget.
    fn default() -> Self {
        DecomposeContext {
            working_directory: String::from("."),
            available_tools: vec![],
            max_depth: 10,
            current_depth: 0,
            additional_context: None,
        }
    }
}

impl DecomposeContext {
    /// Create a default context rooted at `working_directory`.
    pub fn new(working_directory: impl Into<String>) -> Self {
        let mut ctx = Self::default();
        ctx.working_directory = working_directory.into();
        ctx
    }

    /// Builder: replace the list of available tools.
    pub fn with_tools(self, tools: Vec<String>) -> Self {
        Self {
            available_tools: tools,
            ..self
        }
    }

    /// Builder: set the maximum recursion depth.
    pub fn with_max_depth(self, depth: u32) -> Self {
        Self {
            max_depth: depth,
            ..self
        }
    }

    /// Builder: attach free-form additional context.
    pub fn with_context(self, context: impl Into<String>) -> Self {
        Self {
            additional_context: Some(context.into()),
            ..self
        }
    }

    /// Return a copy of this context one level deeper in the recursion.
    ///
    /// Every field is copied unchanged except `current_depth`, which is incremented.
    pub fn child(&self) -> Self {
        let mut next = self.clone();
        next.current_depth += 1;
        next
    }

    /// `true` once `current_depth` has reached (or passed) `max_depth`.
    pub fn at_max_depth(&self) -> bool {
        self.max_depth <= self.current_depth
    }
}

88#[derive(Clone, Debug)]
90pub struct DecompositionResult {
91 pub subtasks: Vec<Subtask>,
93 pub composition_function: CompositionFunction,
95 pub is_minimal: bool,
97 pub total_complexity: f32,
99}
100
101impl DecompositionResult {
102 pub fn atomic(subtask: Subtask) -> Self {
104 let complexity = subtask.complexity_estimate;
105 Self {
106 subtasks: vec![subtask],
107 composition_function: CompositionFunction::Identity,
108 is_minimal: true,
109 total_complexity: complexity,
110 }
111 }
112
113 pub fn composite(subtasks: Vec<Subtask>, composition: CompositionFunction) -> Self {
115 let total_complexity: f32 = subtasks.iter().map(|s| s.complexity_estimate).sum();
116 Self {
117 subtasks,
118 composition_function: composition,
119 is_minimal: false,
120 total_complexity,
121 }
122 }
123}
124
/// Rule for combining the results of a decomposition's subtasks.
#[derive(Clone, Debug)]
pub enum CompositionFunction {
    /// Single result, passed through unchanged.
    Identity,
    /// Concatenate all results.
    Concatenate,
    /// Merge results as a sequence.
    Sequence,
    /// Merge results into one object.
    ObjectMerge,
    /// Keep only the last result.
    LastOnly,
    /// Caller-described custom composition.
    Custom(String),
    /// Reduce the results with the named operation.
    Reduce {
        /// Name of the reduction operation.
        operation: String,
    },
}

impl CompositionFunction {
    /// Short human-readable description of this composition rule.
    pub fn description(&self) -> String {
        match self {
            Self::Identity => String::from("identity (single result)"),
            Self::Concatenate => String::from("concatenate all results"),
            Self::Sequence => String::from("merge as sequence"),
            Self::ObjectMerge => String::from("merge into object"),
            Self::LastOnly => String::from("take last result"),
            Self::Custom(text) => format!("custom: {}", text),
            Self::Reduce { operation: op } => format!("reduce with {}", op),
        }
    }
}

/// Identifies which decomposition approach a [`TaskDecomposer`] implements.
#[derive(Clone, Debug)]
pub enum DecompositionStrategy {
    /// Recursive decomposition bounded by `max_depth` levels.
    BinaryRecursive {
        /// Maximum recursion depth.
        max_depth: u32,
    },
    /// Simple recursive decomposition bounded by `max_depth` levels.
    Simple {
        /// Maximum recursion depth.
        max_depth: u32,
    },
    /// Line-by-line step splitting (reported by `SequentialDecomposer`).
    Sequential,
    /// Code-operation-oriented decomposition (implemented elsewhere, if at all).
    CodeOperations,
    /// Model-driven decomposition.
    AIDriven {
        /// NOTE(review): presumably the voting margin used when choosing between
        /// candidate decompositions — confirm against the implementation.
        discriminator_k: u32,
    },
    /// No decomposition (reported by `AtomicDecomposer`).
    None,
}

impl Default for DecompositionStrategy {
    /// Binary recursive decomposition limited to 10 levels.
    fn default() -> Self {
        Self::BinaryRecursive { max_depth: 10 }
    }
}

/// A strategy object that splits a task description into subtasks.
///
/// The `Send + Sync` supertrait bound lets a decomposer be shared between threads.
/// The trait is made object-safe for `async fn` through `async_trait`.
#[async_trait::async_trait]
pub trait TaskDecomposer: Send + Sync {
    /// Decompose `task` into a [`DecompositionResult`].
    ///
    /// `context` supplies the working directory, available tools, optional extra
    /// context and the current/maximum recursion depth.
    ///
    /// # Errors
    /// Implementation-defined; returns an `MdapResult` error on failure.
    async fn decompose(
        &self,
        task: &str,
        context: &DecomposeContext,
    ) -> MdapResult<DecompositionResult>;

    /// Implementation-defined check for whether `task` is already minimal.
    ///
    /// For example, `SequentialDecomposer` treats a single line shorter than
    /// 200 bytes as minimal, and `AtomicDecomposer` always returns `true`.
    fn is_minimal(&self, task: &str) -> bool;

    /// The [`DecompositionStrategy`] this decomposer implements.
    fn strategy(&self) -> DecompositionStrategy;
}

/// Decomposer that turns a multi-line, step-numbered task into a chain of steps.
pub struct SequentialDecomposer {
    /// Upper bound on the number of steps emitted by one decomposition.
    max_steps: u32,
}

impl SequentialDecomposer {
    /// Create a decomposer that emits at most `max_steps` steps.
    pub fn new(max_steps: u32) -> Self {
        SequentialDecomposer { max_steps }
    }
}

impl Default for SequentialDecomposer {
    /// A decomposer limited to 20 steps.
    fn default() -> Self {
        SequentialDecomposer { max_steps: 20 }
    }
}

229#[async_trait::async_trait]
230impl TaskDecomposer for SequentialDecomposer {
231 async fn decompose(
232 &self,
233 task: &str,
234 context: &DecomposeContext,
235 ) -> MdapResult<DecompositionResult> {
236 let lines: Vec<&str> = task.lines().collect();
238 let mut subtasks = Vec::new();
239
240 for (i, line) in lines.iter().enumerate() {
241 let trimmed = line.trim();
242 if trimmed.is_empty() {
243 continue;
244 }
245
246 let is_numbered = trimmed
248 .chars()
249 .next()
250 .map(|c| c.is_ascii_digit())
251 .unwrap_or(false);
252
253 if is_numbered || subtasks.is_empty() {
254 let subtask = Subtask::new(
255 format!("step_{}", i + 1),
256 trimmed.to_string(),
257 serde_json::json!({
258 "step": i + 1,
259 "context": context.additional_context
260 }),
261 )
262 .with_complexity(1.0 / lines.len() as f32);
263
264 subtasks.push(subtask);
265 }
266
267 if subtasks.len() >= self.max_steps as usize {
268 break;
269 }
270 }
271
272 if subtasks.is_empty() {
273 let subtask = Subtask::atomic(task);
275 return Ok(DecompositionResult::atomic(subtask));
276 }
277
278 for i in 1..subtasks.len() {
280 let prev_id = subtasks[i - 1].id.clone();
281 subtasks[i].depends_on.push(prev_id);
282 }
283
284 Ok(DecompositionResult::composite(
285 subtasks,
286 CompositionFunction::Sequence,
287 ))
288 }
289
290 fn is_minimal(&self, task: &str) -> bool {
291 !task.contains('\n') && task.len() < 200
293 }
294
295 fn strategy(&self) -> DecompositionStrategy {
296 DecompositionStrategy::Sequential
297 }
298}
299
300pub struct AtomicDecomposer;
302
303#[async_trait::async_trait]
304impl TaskDecomposer for AtomicDecomposer {
305 async fn decompose(
306 &self,
307 task: &str,
308 _context: &DecomposeContext,
309 ) -> MdapResult<DecompositionResult> {
310 Ok(DecompositionResult::atomic(Subtask::atomic(task)))
311 }
312
313 fn is_minimal(&self, _task: &str) -> bool {
314 true
315 }
316
317 fn strategy(&self) -> DecompositionStrategy {
318 DecompositionStrategy::None
319 }
320}
321
322pub fn validate_decomposition(result: &DecompositionResult) -> MdapResult<()> {
324 if result.subtasks.is_empty() {
325 return Err(DecompositionError::EmptyResult(
326 "Decomposition produced no subtasks".to_string(),
327 )
328 .into());
329 }
330
331 let mut visited = std::collections::HashSet::new();
333 for subtask in &result.subtasks {
334 visited.insert(subtask.id.clone());
335 }
336
337 for subtask in &result.subtasks {
338 for dep in &subtask.depends_on {
339 if !visited.contains(dep) {
340 return Err(DecompositionError::InvalidDependency {
341 subtask: subtask.id.clone(),
342 dependency: dep.clone(),
343 }
344 .into());
345 }
346 }
347 }
348
349 Ok(())
350}
351
352pub fn topological_sort(subtasks: &[Subtask]) -> MdapResult<Vec<Subtask>> {
354 use std::collections::{HashMap, VecDeque};
355
356 let mut in_degree: HashMap<String, usize> = HashMap::new();
357 let mut graph: HashMap<String, Vec<String>> = HashMap::new();
358
359 for subtask in subtasks {
361 in_degree.insert(subtask.id.clone(), subtask.depends_on.len());
362 graph.insert(subtask.id.clone(), Vec::new());
363 }
364
365 for subtask in subtasks {
367 for dep in &subtask.depends_on {
368 if let Some(dependents) = graph.get_mut(dep) {
369 dependents.push(subtask.id.clone());
370 }
371 }
372 }
373
374 let mut queue: VecDeque<String> = in_degree
376 .iter()
377 .filter(|(_, deg)| **deg == 0)
378 .map(|(id, _)| id.clone())
379 .collect();
380
381 let mut result = Vec::new();
382 let subtask_map: HashMap<_, _> = subtasks.iter().map(|s| (s.id.clone(), s.clone())).collect();
383
384 while let Some(id) = queue.pop_front() {
385 if let Some(subtask) = subtask_map.get(&id) {
386 result.push(subtask.clone());
387 }
388
389 if let Some(dependents) = graph.get(&id) {
390 for dependent in dependents {
391 if let Some(deg) = in_degree.get_mut(dependent) {
392 *deg -= 1;
393 if *deg == 0 {
394 queue.push_back(dependent.clone());
395 }
396 }
397 }
398 }
399 }
400
401 if result.len() != subtasks.len() {
402 return Err(DecompositionError::CircularDependency(
403 "Circular dependency detected in subtasks".to_string(),
404 )
405 .into());
406 }
407
408 Ok(result)
409}
410
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_decompose_context() {
        let ctx = DecomposeContext::new("/home/user/project")
            .with_tools(vec!["read".to_string(), "write".to_string()])
            .with_max_depth(5);

        assert_eq!(ctx.working_directory, "/home/user/project");
        assert_eq!(ctx.available_tools.len(), 2);
        assert_eq!(ctx.max_depth, 5);
    }

    #[test]
    fn test_context_child() {
        let parent = DecomposeContext::new("/home").with_max_depth(5);
        let child = parent.child();

        assert_eq!(child.current_depth, 1);
        assert_eq!(child.max_depth, 5);
    }

    #[test]
    fn test_decomposition_result_atomic() {
        let subtask = Subtask::atomic("Test");
        let result = DecompositionResult::atomic(subtask);

        assert!(result.is_minimal);
        assert_eq!(result.subtasks.len(), 1);
    }

    #[test]
    fn test_composition_function_description() {
        assert_eq!(
            CompositionFunction::Identity.description(),
            "identity (single result)"
        );
        assert_eq!(
            CompositionFunction::Reduce {
                operation: "sum".to_string()
            }
            .description(),
            "reduce with sum"
        );
    }

    #[test]
    fn test_topological_sort_simple() {
        let subtasks = vec![
            Subtask::new("a", "Task A", serde_json::Value::Null),
            Subtask::new("b", "Task B", serde_json::Value::Null).depends_on(vec!["a".to_string()]),
            Subtask::new("c", "Task C", serde_json::Value::Null).depends_on(vec!["b".to_string()]),
        ];

        let sorted = topological_sort(&subtasks).unwrap();
        assert_eq!(sorted[0].id, "a");
        assert_eq!(sorted[1].id, "b");
        assert_eq!(sorted[2].id, "c");
    }

    #[test]
    fn test_topological_sort_parallel() {
        let subtasks = vec![
            Subtask::new("a", "Task A", serde_json::Value::Null),
            Subtask::new("b", "Task B", serde_json::Value::Null),
            Subtask::new("c", "Task C", serde_json::Value::Null)
                .depends_on(vec!["a".to_string(), "b".to_string()]),
        ];

        let sorted = topological_sort(&subtasks).unwrap();
        let c_pos = sorted.iter().position(|s| s.id == "c").unwrap();
        let a_pos = sorted.iter().position(|s| s.id == "a").unwrap();
        let b_pos = sorted.iter().position(|s| s.id == "b").unwrap();
        assert!(a_pos < c_pos);
        assert!(b_pos < c_pos);
    }

    #[test]
    fn test_topological_sort_circular() {
        let subtasks = vec![
            Subtask::new("a", "Task A", serde_json::Value::Null).depends_on(vec!["c".to_string()]),
            Subtask::new("b", "Task B", serde_json::Value::Null).depends_on(vec!["a".to_string()]),
            Subtask::new("c", "Task C", serde_json::Value::Null).depends_on(vec!["b".to_string()]),
        ];

        let result = topological_sort(&subtasks);
        assert!(result.is_err());
    }

    #[test]
    fn test_topological_sort_unknown_dependency() {
        // A dependency on an id that does not exist can never be satisfied.
        let subtasks = vec![Subtask::new("a", "Task A", serde_json::Value::Null)
            .depends_on(vec!["missing".to_string()])];

        assert!(topological_sort(&subtasks).is_err());
    }

    #[tokio::test]
    async fn test_atomic_decomposer() {
        let decomposer = AtomicDecomposer;
        let result = decomposer
            .decompose("Test task", &DecomposeContext::default())
            .await
            .unwrap();

        assert!(result.is_minimal);
        assert_eq!(result.subtasks.len(), 1);
    }

    #[tokio::test]
    async fn test_sequential_decomposer() {
        let decomposer = SequentialDecomposer::new(10);
        let task = "1. First step\n2. Second step\n3. Third step";
        let result = decomposer
            .decompose(task, &DecomposeContext::default())
            .await
            .unwrap();

        assert_eq!(result.subtasks.len(), 3);
        assert!(!result.is_minimal);
    }

    #[tokio::test]
    async fn test_sequential_decomposer_chains_dependencies() {
        let decomposer = SequentialDecomposer::new(10);
        let result = decomposer
            .decompose("1. A\n2. B", &DecomposeContext::default())
            .await
            .unwrap();

        assert_eq!(result.subtasks.len(), 2);
        assert!(result.subtasks[0].depends_on.is_empty());
        assert_eq!(
            result.subtasks[1].depends_on,
            vec![result.subtasks[0].id.clone()]
        );
    }

    #[tokio::test]
    async fn test_sequential_decomposer_empty_task_is_atomic() {
        let decomposer = SequentialDecomposer::default();
        let result = decomposer
            .decompose("", &DecomposeContext::default())
            .await
            .unwrap();

        assert!(result.is_minimal);
        assert_eq!(result.subtasks.len(), 1);
    }

    #[test]
    fn test_validate_decomposition_valid() {
        let result = DecompositionResult::composite(
            vec![
                Subtask::new("a", "Task A", serde_json::Value::Null),
                Subtask::new("b", "Task B", serde_json::Value::Null)
                    .depends_on(vec!["a".to_string()]),
            ],
            CompositionFunction::Sequence,
        );

        assert!(validate_decomposition(&result).is_ok());
    }

    #[test]
    fn test_validate_decomposition_invalid_dep() {
        let result = DecompositionResult::composite(
            vec![
                Subtask::new("a", "Task A", serde_json::Value::Null)
                    .depends_on(vec!["nonexistent".to_string()]),
            ],
            CompositionFunction::Sequence,
        );

        assert!(validate_decomposition(&result).is_err());
    }

    #[test]
    fn test_validate_decomposition_empty() {
        let result = DecompositionResult::composite(Vec::new(), CompositionFunction::Sequence);

        assert!(validate_decomposition(&result).is_err());
    }
}