use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use async_trait::async_trait;
use log::{debug, info, warn};
use serde::{Deserialize, Serialize};
use terraphim_rolegraph::RoleGraph;
/// Placeholder paragraph extraction.
///
/// The `_automata` argument is not consulted: the function returns the first
/// `max_results` whitespace-separated tokens of `text`, each as an owned
/// `String`. It never returns `Err`.
fn extract_paragraphs_from_automata(
    _automata: &MockAutomata,
    text: &str,
    max_results: u32,
) -> Result<Vec<String>, String> {
    let limit = max_results as usize;
    let mut tokens = Vec::new();
    for token in text.split_whitespace() {
        // Stop as soon as the requested number of tokens has been gathered.
        if tokens.len() == limit {
            break;
        }
        tokens.push(String::from(token));
    }
    Ok(tokens)
}
/// Placeholder connectivity check between terms.
///
/// The `_automata` argument is not consulted. With fewer than two terms the
/// result is `Ok(true)`. Otherwise only the first two terms are examined:
/// they count as connected when, after lower-casing, they share at least one
/// character. It never returns `Err`.
fn is_all_terms_connected_by_path(
    _automata: &MockAutomata,
    terms: &[&str],
) -> Result<bool, String> {
    let (left, right) = match terms {
        [a, b, ..] => (a.to_lowercase(), b.to_lowercase()),
        _ => return Ok(true),
    };
    let shares_character = left.chars().any(|ch| right.contains(ch));
    Ok(shares_character)
}
use crate::{Automata, MockAutomata};
use crate::{Task, TaskComplexity, TaskDecompositionError, TaskDecompositionResult, TaskId};
/// Strategy selector carried in `DecompositionConfig::strategy` and echoed
/// back in `DecompositionMetadata::strategy_used`.
///
/// NOTE(review): the decomposer visible in this file records the chosen
/// variant (cache key and result metadata) but does not branch on it. The
/// per-variant notes below are inferred from the names only — confirm.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum DecompositionStrategy {
/// Presumably: split along knowledge-graph concepts.
KnowledgeGraphBased,
/// Presumably: split according to task complexity.
ComplexityBased,
/// Presumably: split according to roles.
RoleBased,
/// Presumably: a combination of the other strategies. This is the variant
/// used by `DecompositionConfig::default()`.
Hybrid,
}
/// Tunable parameters passed to `TaskDecomposer::decompose_task`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DecompositionConfig {
/// Presumably the maximum recursion depth of a decomposition.
/// NOTE(review): not read by the decomposer visible in this file.
pub max_depth: u32,
/// Complexity assigned to every subtask generated by
/// `KnowledgeGraphTaskDecomposer`.
pub min_subtask_complexity: TaskComplexity,
/// Upper bound on the number of subtasks generated for one task.
pub max_subtasks_per_task: u32,
/// Strategy label; copied into the result metadata.
pub strategy: DecompositionStrategy,
/// Forwarded to the connectivity analysis, which currently ignores it.
pub similarity_threshold: f64,
/// When `false`, no dependencies between subtasks are generated.
pub preserve_dependencies: bool,
/// NOTE(review): not read by the decomposer visible in this file.
pub optimize_for_parallelism: bool,
}
impl Default for DecompositionConfig {
    /// Builds the default configuration: depth 3, simple subtasks, at most
    /// 10 subtasks per task, the hybrid strategy, a 0.7 similarity threshold,
    /// and both dependency preservation and parallelism optimisation enabled.
    fn default() -> Self {
        const DEFAULT_MAX_DEPTH: u32 = 3;
        const DEFAULT_MAX_SUBTASKS: u32 = 10;
        const DEFAULT_SIMILARITY_THRESHOLD: f64 = 0.7;

        Self {
            strategy: DecompositionStrategy::Hybrid,
            min_subtask_complexity: TaskComplexity::Simple,
            max_depth: DEFAULT_MAX_DEPTH,
            max_subtasks_per_task: DEFAULT_MAX_SUBTASKS,
            similarity_threshold: DEFAULT_SIMILARITY_THRESHOLD,
            preserve_dependencies: true,
            optimize_for_parallelism: true,
        }
    }
}
/// Outcome of decomposing one task.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DecompositionResult {
/// Id of the task that was decomposed.
pub original_task: TaskId,
/// The resulting subtasks. When `KnowledgeGraphTaskDecomposer` finds fewer
/// than two concept groups this holds a single clone of the original task.
pub subtasks: Vec<Task>,
/// Maps a subtask id to the ids of the subtasks it depends on.
pub dependencies: HashMap<TaskId, Vec<TaskId>>,
/// Descriptive information about how the result was produced.
pub metadata: DecompositionMetadata,
}
/// Descriptive information attached to a `DecompositionResult`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DecompositionMetadata {
/// Copy of the strategy from the configuration used for the run.
pub strategy_used: DecompositionStrategy,
/// `KnowledgeGraphTaskDecomposer` sets this to 0 when the task is returned
/// undecomposed and to 1 when subtasks were generated.
pub depth: u32,
/// Number of entries in `DecompositionResult::subtasks`.
pub subtask_count: u32,
/// Concept terms extracted from the task during the run.
pub concepts_analyzed: Vec<String>,
/// Always empty in results produced by the decomposer in this file.
pub roles_identified: Vec<String>,
/// Quality estimate; `validate_decomposition` rejects values below 0.5.
pub confidence_score: f64,
/// Estimate derived from the dependency map; 1.0 when there are no
/// dependencies.
pub parallelism_factor: f64,
}
/// Asynchronous interface for breaking a task into subtasks.
#[async_trait]
pub trait TaskDecomposer: Send + Sync {
/// Decomposes `task` according to `config` and returns the subtasks,
/// their dependencies and descriptive metadata.
async fn decompose_task(
&self,
task: &Task,
config: &DecompositionConfig,
) -> TaskDecompositionResult<DecompositionResult>;
/// Estimates the complexity class of `task`.
async fn analyze_complexity(&self, task: &Task) -> TaskDecompositionResult<TaskComplexity>;
/// Reports whether `result` is an acceptable decomposition.
async fn validate_decomposition(
&self,
result: &DecompositionResult,
) -> TaskDecompositionResult<bool>;
}
/// `TaskDecomposer` implementation that groups the concepts found in a task
/// and creates one subtask per concept group.
pub struct KnowledgeGraphTaskDecomposer {
// Handed to the module-level extraction and connectivity helpers.
automata: Arc<Automata>,
// Queried to expand extracted concepts with document tags and ids.
role_graph: Arc<RoleGraph>,
// Previously computed decompositions, keyed by a string built from the
// task id and the configuration.
cache: Arc<tokio::sync::RwLock<HashMap<String, DecompositionResult>>>,
}
impl KnowledgeGraphTaskDecomposer {
/// Creates a decomposer over the given automata and role graph, starting
/// with an empty result cache.
pub fn new(automata: Arc<Automata>, role_graph: Arc<RoleGraph>) -> Self {
    let cache = Arc::new(tokio::sync::RwLock::new(HashMap::new()));
    Self {
        automata,
        role_graph,
        cache,
    }
}
/// Collects the lower-cased concept terms for `task`.
///
/// The task description and its context keywords are joined into one text,
/// passed to `extract_paragraphs_from_automata` (limited to 10 results) and
/// split into lower-cased words. Every word that the role graph reports as
/// connected and that matches at least one node is then expanded with the
/// tags and ids (longer than two bytes, lower-cased) of up to two documents
/// found for that word.
///
/// Returns the combined terms sorted and without duplicates.
///
/// # Errors
///
/// Returns `TaskDecompositionError::KnowledgeGraphError` when the extraction
/// step fails; the failure is also logged at `warn` level.
async fn extract_task_concepts(&self, task: &Task) -> TaskDecompositionResult<Vec<String>> {
    let text = format!(
        "{} {}",
        task.description,
        task.knowledge_context.keywords.join(" ")
    );

    let paragraphs =
        extract_paragraphs_from_automata(&self.automata, &text, 10).map_err(|e| {
            warn!(
                "Failed to extract concepts from task {}: {}",
                task.task_id, e
            );
            TaskDecompositionError::KnowledgeGraphError(format!(
                "Concept extraction failed: {}",
                e
            ))
        })?;

    let base_concepts: HashSet<String> = paragraphs
        .iter()
        .flat_map(|paragraph| paragraph.split_whitespace())
        .map(str::to_lowercase)
        .collect();

    let mut related_concepts = HashSet::new();
    for concept in &base_concepts {
        if !self.role_graph.is_all_terms_connected_by_path(concept) {
            continue;
        }
        // The document lookup below depends only on the term, not on any
        // particular node id, so it is done once per concept; a matching
        // node is merely a precondition.
        // NOTE(review): if the lookup was meant to be per node id, the
        // role-graph query needs to change — confirm the intent.
        let node_ids = self.role_graph.find_matching_node_ids(concept);
        if node_ids.iter().next().is_none() {
            continue;
        }
        let documents = self.role_graph.find_document_ids_for_term(concept);
        for doc_id in documents.iter().take(2) {
            if let Some(document) = self.role_graph.get_document(doc_id) {
                for tag in &document.tags {
                    if tag.len() > 2 {
                        related_concepts.insert(tag.to_lowercase());
                    }
                }
                if doc_id.len() > 2 {
                    related_concepts.insert(doc_id.to_lowercase());
                }
            }
        }
    }

    let mut concepts: Vec<String> = base_concepts
        .into_iter()
        .chain(related_concepts)
        .collect();
    // Sorting makes the output deterministic despite the hash-set sources;
    // dedup removes terms present in both sets.
    concepts.sort_unstable();
    concepts.dedup();

    debug!(
        "Extracted {} concepts from task {}",
        concepts.len(),
        task.task_id
    );
    Ok(concepts)
}
/// Partitions `concepts` into groups of mutually reachable terms.
///
/// Concepts are visited in order. Each unassigned concept seeds a group and
/// pulls in every later unassigned concept that the connectivity helper
/// reports as connected to the seed. Groups that end up with a single member
/// are discarded. `_threshold` is accepted but not used.
async fn analyze_connectivity(
    &self,
    concepts: &[String],
    _threshold: f64,
) -> TaskDecompositionResult<Vec<Vec<String>>> {
    let mut assigned = HashSet::new();
    let mut concept_groups = Vec::new();

    for seed in concepts {
        // `insert` returns false when the concept already belongs to a group.
        if !assigned.insert(seed.clone()) {
            continue;
        }

        let mut members = vec![seed.clone()];
        for candidate in concepts {
            if assigned.contains(candidate) {
                continue;
            }
            match is_all_terms_connected_by_path(&self.automata, &[seed, candidate]) {
                Ok(true) => {
                    members.push(candidate.clone());
                    assigned.insert(candidate.clone());
                }
                Ok(false) => {}
                Err(e) => {
                    debug!(
                        "Connectivity check failed for {} -> {}: {}",
                        seed, candidate, e
                    );
                }
            }
        }

        if members.len() > 1 {
            concept_groups.push(members);
        }
    }

    debug!("Found {} concept groups", concept_groups.len());
    Ok(concept_groups)
}
/// Builds one subtask per concept group, up to
/// `config.max_subtasks_per_task`.
///
/// Each subtask gets the id `<parent id>_<1-based index>`, the complexity
/// `config.min_subtask_complexity` and the parent's priority. It inherits
/// the parent's knowledge context except that `concepts` and `keywords` are
/// replaced by the group. It also inherits the parent's quality and security
/// constraints, its parent goal, and an equal share (one over the total
/// number of groups) of its estimated effort.
///
/// # Errors
///
/// Propagates any error returned by `Task::add_constraint`.
async fn generate_subtasks_from_concepts(
    &self,
    _original_task: &Task,
    concept_groups: &[Vec<String>],
    config: &DecompositionConfig,
) -> TaskDecompositionResult<Vec<Task>> {
    use crate::TaskConstraintType;

    let parent = _original_task;
    let limit = config.max_subtasks_per_task as usize;
    let effort_fraction = 1.0 / concept_groups.len() as f64;
    let mut subtasks = Vec::new();

    for (index, group) in concept_groups.iter().enumerate().take(limit) {
        let subtask_id = format!("{}_{}", parent.task_id, index + 1);
        let description = format!(
            "Subtask of '{}' focusing on: {}",
            parent.description,
            group.join(", ")
        );
        let mut subtask = Task::new(
            subtask_id,
            description,
            config.min_subtask_complexity.clone(),
            parent.priority,
        );

        let inherited = &parent.knowledge_context;
        let context = &mut subtask.knowledge_context;
        context.domains = inherited.domains.clone();
        context.concepts = group.clone();
        context.relationships = inherited.relationships.clone();
        context.keywords = group.clone();
        context.input_types = inherited.input_types.clone();
        context.output_types = inherited.output_types.clone();
        context.similarity_thresholds = inherited.similarity_thresholds.clone();

        // Only quality and security constraints carry over to subtasks.
        let carried_over = parent.constraints.iter().filter(|constraint| {
            matches!(
                constraint.constraint_type,
                TaskConstraintType::Quality | TaskConstraintType::Security
            )
        });
        for constraint in carried_over {
            subtask.add_constraint(constraint.clone())?;
        }

        subtask.parent_goal = parent.parent_goal.clone();
        subtask.estimated_effort = parent.estimated_effort.mul_f64(effort_fraction);
        subtasks.push(subtask);
    }

    info!(
        "Generated {} subtasks for task {}",
        subtasks.len(),
        parent.task_id
    );
    Ok(subtasks)
}
/// Derives the dependencies between `subtasks`.
///
/// A subtask can only depend on subtasks that precede it in `subtasks`; it
/// depends on a predecessor when `check_concept_dependency` reports a
/// connection between their concept lists. Subtasks without dependencies get
/// no entry in the returned map. When `config.preserve_dependencies` is
/// `false` the map is empty. `_original_task` is not used.
///
/// # Errors
///
/// Propagates any error returned by `check_concept_dependency`.
async fn generate_subtask_dependencies(
    &self,
    subtasks: &[Task],
    _original_task: &Task,
    config: &DecompositionConfig,
) -> TaskDecompositionResult<HashMap<TaskId, Vec<TaskId>>> {
    let mut dependencies = HashMap::new();
    if !config.preserve_dependencies {
        return Ok(dependencies);
    }

    for (i, subtask) in subtasks.iter().enumerate() {
        let mut deps = Vec::new();
        // Only earlier subtasks are eligible prerequisites, so the check is
        // skipped entirely for later ones instead of discarding its result.
        for earlier in &subtasks[..i] {
            let has_dependency = self
                .check_concept_dependency(
                    &subtask.knowledge_context.concepts,
                    &earlier.knowledge_context.concepts,
                )
                .await?;
            if has_dependency {
                deps.push(earlier.task_id.clone());
            }
        }
        if !deps.is_empty() {
            dependencies.insert(subtask.task_id.clone(), deps);
        }
    }

    debug!("Generated {} dependency relationships", dependencies.len());
    Ok(dependencies)
}
/// Reports whether any concept in `dependent_concepts` is connected to any
/// concept in `prerequisite_concepts`.
///
/// Pairs whose connectivity check fails are skipped. This function itself
/// always returns `Ok`.
async fn check_concept_dependency(
    &self,
    dependent_concepts: &[String],
    prerequisite_concepts: &[String],
) -> TaskDecompositionResult<bool> {
    let connected = dependent_concepts.iter().any(|dependent| {
        prerequisite_concepts.iter().any(|prerequisite| {
            matches!(
                is_all_terms_connected_by_path(&self.automata, &[prerequisite, dependent]),
                Ok(true)
            )
        })
    });
    Ok(connected)
}
/// Scores a decomposition in the range `0.0..=1.0`.
///
/// The score is a weighted sum of three parts:
///
/// * coverage (weight 0.4): fraction of the original task's concepts that
///   appear in some subtask, or 1.0 when the original task has no concepts;
/// * balance (weight 0.3): `1 / (1 + variance)` of the concept-group sizes;
/// * complexity fit (weight 0.3): 1.0 or 0.5 for tasks that require
///   decomposition, depending on whether more than one subtask was
///   produced; otherwise 1.0 for at most two subtasks and 0.7 beyond that.
///
/// An empty `concept_groups` slice is treated as having zero variance, so
/// the balance part is 1.0 instead of the NaN that a 0/0 mean would produce.
fn calculate_confidence_score(
    &self,
    original_task: &Task,
    subtasks: &[Task],
    concept_groups: &[Vec<String>],
) -> f64 {
    // Borrowed sets are enough for counting the overlap.
    let original_concepts: HashSet<&String> =
        original_task.knowledge_context.concepts.iter().collect();
    let subtask_concepts: HashSet<&String> = subtasks
        .iter()
        .flat_map(|subtask| subtask.knowledge_context.concepts.iter())
        .collect();
    let coverage = if original_concepts.is_empty() {
        1.0
    } else {
        subtask_concepts.intersection(&original_concepts).count() as f64
            / original_concepts.len() as f64
    };

    let balance_score = if concept_groups.is_empty() {
        // No groups: dividing by the group count would yield NaN, which
        // `clamp` would pass through unchanged.
        1.0
    } else {
        let group_count = concept_groups.len() as f64;
        let mean_size = concept_groups.iter().map(Vec::len).sum::<usize>() as f64 / group_count;
        let variance = concept_groups
            .iter()
            .map(|group| (group.len() as f64 - mean_size).powi(2))
            .sum::<f64>()
            / group_count;
        1.0 / (1.0 + variance)
    };

    let complexity_score = if original_task.complexity.requires_decomposition() {
        if subtasks.len() > 1 { 1.0 } else { 0.5 }
    } else if subtasks.len() <= 2 {
        1.0
    } else {
        0.7
    };

    let score = coverage * 0.4 + balance_score * 0.3 + complexity_score * 0.3;
    score.clamp(0.0, 1.0)
}
fn calculate_parallelism_factor(&self, dependencies: &HashMap<TaskId, Vec<TaskId>>) -> f64 {
if dependencies.is_empty() {
return 1.0; }
let total_tasks = dependencies.keys().len();
let independent_tasks = dependencies.values().filter(|deps| deps.is_empty()).count();
if total_tasks == 0 {
1.0
} else {
independent_tasks as f64 / total_tasks as f64
}
}
}
#[async_trait]
impl TaskDecomposer for KnowledgeGraphTaskDecomposer {
/// Decomposes `task` into subtasks grouped by connected concepts.
///
/// Steps: look up the cache; extract the task's concepts; group connected
/// concepts; if fewer than two groups result, return the task itself as the
/// only subtask (depth 0, confidence 0.8, not cached); otherwise generate
/// one subtask per group, derive their dependencies, score the result,
/// store it in the cache and return it.
///
/// The cache key is built from the task id and the complete configuration,
/// so that a call with different limits or flags is not served a result
/// computed under another configuration.
///
/// # Errors
///
/// Returns `TaskDecompositionError::DecompositionFailed` when no concepts
/// can be extracted, and propagates errors from the extraction, grouping,
/// subtask-generation and dependency steps.
async fn decompose_task(
    &self,
    task: &Task,
    config: &DecompositionConfig,
) -> TaskDecompositionResult<DecompositionResult> {
    info!("Starting decomposition of task: {}", task.task_id);

    let cache_key = format!("{}_{:?}", task.task_id, config);
    {
        // Scope the read guard so it is released before any further work.
        let cache = self.cache.read().await;
        if let Some(cached_result) = cache.get(&cache_key) {
            debug!("Using cached decomposition for task {}", task.task_id);
            return Ok(cached_result.clone());
        }
    }

    let concepts = self.extract_task_concepts(task).await?;
    if concepts.is_empty() {
        return Err(TaskDecompositionError::DecompositionFailed(
            task.task_id.clone(),
            "No concepts could be extracted from task".to_string(),
        ));
    }

    let concept_groups = self
        .analyze_connectivity(&concepts, config.similarity_threshold)
        .await?;

    // Zero or one group means there is nothing to split.
    if concept_groups.len() <= 1 {
        return Ok(DecompositionResult {
            original_task: task.task_id.clone(),
            subtasks: vec![task.clone()],
            dependencies: HashMap::new(),
            metadata: DecompositionMetadata {
                strategy_used: config.strategy.clone(),
                depth: 0,
                subtask_count: 1,
                concepts_analyzed: concepts,
                roles_identified: Vec::new(),
                confidence_score: 0.8,
                parallelism_factor: 1.0,
            },
        });
    }

    let subtasks = self
        .generate_subtasks_from_concepts(task, &concept_groups, config)
        .await?;
    let dependencies = self
        .generate_subtask_dependencies(&subtasks, task, config)
        .await?;
    let confidence_score = self.calculate_confidence_score(task, &subtasks, &concept_groups);
    let parallelism_factor = self.calculate_parallelism_factor(&dependencies);

    let subtask_count = subtasks.len() as u32;
    let result = DecompositionResult {
        original_task: task.task_id.clone(),
        subtasks,
        dependencies,
        metadata: DecompositionMetadata {
            strategy_used: config.strategy.clone(),
            depth: 1,
            subtask_count,
            concepts_analyzed: concepts,
            roles_identified: Vec::new(),
            confidence_score,
            parallelism_factor,
        },
    };

    {
        let mut cache = self.cache.write().await;
        cache.insert(cache_key, result.clone());
    }

    info!(
        "Completed decomposition of task {} into {} subtasks",
        task.task_id,
        result.subtasks.len()
    );
    Ok(result)
}
/// Classifies `task` by the number of concepts extracted from it: up to 2 is
/// simple, up to 5 moderate, up to 10 complex, anything above very complex.
///
/// # Errors
///
/// Propagates any error from the concept extraction.
async fn analyze_complexity(&self, task: &Task) -> TaskDecompositionResult<TaskComplexity> {
    let concept_count = self.extract_task_concepts(task).await?.len();

    let complexity = if concept_count <= 2 {
        TaskComplexity::Simple
    } else if concept_count <= 5 {
        TaskComplexity::Moderate
    } else if concept_count <= 10 {
        TaskComplexity::Complex
    } else {
        TaskComplexity::VeryComplex
    };

    debug!(
        "Analyzed complexity for task {}: {:?} (based on {} concepts)",
        task.task_id,
        complexity,
        concept_count
    );
    Ok(complexity)
}
/// Accepts a decomposition when it has at least one subtask, its dependency
/// map contains no cycle reachable from any subtask, and its confidence
/// score is not below 0.5. This function always returns `Ok`.
async fn validate_decomposition(
    &self,
    result: &DecompositionResult,
) -> TaskDecompositionResult<bool> {
    if result.subtasks.is_empty() {
        return Ok(false);
    }

    // The traversal state is shared across all starting points.
    let mut visited = HashSet::new();
    let mut rec_stack = HashSet::new();
    let has_cycle = result.subtasks.iter().any(|subtask| {
        has_circular_dependency(
            &subtask.task_id,
            &result.dependencies,
            &mut visited,
            &mut rec_stack,
        )
    });

    let below_threshold = result.metadata.confidence_score < 0.5;
    Ok(!(has_cycle || below_threshold))
}
}
// Empty inherent impl block: it adds no items to the type.
impl KnowledgeGraphTaskDecomposer {}
/// Depth-first search for a dependency cycle reachable from `task_id`.
///
/// `visited` collects every task the search has entered; `rec_stack` holds
/// the tasks on the current search path. Returns `true` as soon as a
/// dependency leads back to a task on the current path. In that case
/// `rec_stack` is left as it was when the cycle was found. On a `false`
/// return `task_id` has been removed from `rec_stack` again.
fn has_circular_dependency(
    task_id: &str,
    dependencies: &HashMap<TaskId, Vec<TaskId>>,
    visited: &mut HashSet<String>,
    rec_stack: &mut HashSet<String>,
) -> bool {
    let current = task_id.to_owned();
    visited.insert(current.clone());
    rec_stack.insert(current);

    // A task without an entry simply has no dependencies to follow.
    for dep in dependencies.get(task_id).into_iter().flatten() {
        let closes_cycle = if visited.contains(dep) {
            rec_stack.contains(dep)
        } else {
            has_circular_dependency(dep, dependencies, visited, rec_stack)
        };
        if closes_cycle {
            return true;
        }
    }

    rec_stack.remove(task_id);
    false
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use crate::decomposition::Automata;
    use terraphim_rolegraph::RoleGraph;

    /// Default automata wrapped for sharing.
    fn create_test_automata() -> Arc<Automata> {
        Arc::new(Automata::default())
    }

    /// Role graph for the `test_role` role, built from the local example
    /// thesaurus.
    async fn create_test_role_graph() -> Arc<RoleGraph> {
        use terraphim_automata::{AutomataPath, load_thesaurus};
        use terraphim_types::RoleName;

        let role_name = RoleName::new("test_role");
        let thesaurus = load_thesaurus(&AutomataPath::local_example())
            .await
            .unwrap();
        Arc::new(RoleGraph::new(role_name, thesaurus).await.unwrap())
    }

    /// Decomposer wired to the shared test fixtures.
    async fn build_decomposer() -> KnowledgeGraphTaskDecomposer {
        let automata = create_test_automata();
        let role_graph = create_test_role_graph().await;
        KnowledgeGraphTaskDecomposer::new(automata, role_graph)
    }

    /// Simple-complexity task with priority 1.
    fn simple_task(id: &str, description: &str) -> Task {
        Task::new(
            id.to_string(),
            description.to_string(),
            TaskComplexity::Simple,
            1,
        )
    }

    #[tokio::test]
    async fn test_task_decomposer_creation() {
        let decomposer = build_decomposer().await;
        assert!(decomposer.cache.read().await.is_empty());
    }

    #[tokio::test]
    async fn test_simple_task_decomposition() {
        let decomposer = build_decomposer().await;
        let task = simple_task("test_task", "Simple test task");
        let config = DecompositionConfig::default();

        let result = decomposer.decompose_task(&task, &config).await;
        assert!(result.is_ok());

        let decomposition = result.unwrap();
        assert_eq!(decomposition.original_task, "test_task");
        assert!(!decomposition.subtasks.is_empty());
    }

    #[tokio::test]
    async fn test_complexity_analysis() {
        let decomposer = build_decomposer().await;
        let task = simple_task("simple", "Simple task");

        let result = decomposer.analyze_complexity(&task).await;
        assert!(result.is_ok());
    }

    #[test]
    fn test_decomposition_config_defaults() {
        let config = DecompositionConfig::default();
        assert_eq!(config.max_depth, 3);
        assert_eq!(config.min_subtask_complexity, TaskComplexity::Simple);
        assert_eq!(config.max_subtasks_per_task, 10);
        assert_eq!(config.strategy, DecompositionStrategy::Hybrid);
        assert_eq!(config.similarity_threshold, 0.7);
        assert!(config.preserve_dependencies);
        assert!(config.optimize_for_parallelism);
    }
}