use car_memgine::MemgineEngine;
use crate::types::*;
/// Chooses the reasoning actions to run for `problem`, in execution order.
///
/// Queries `engine.find_skill` with the text `"reason <class> <problem>"`. The
/// final argument is 20, presumably a result limit (TODO confirm against
/// `MemgineEngine::find_skill`). A skill is kept only when all of these hold:
/// its name starts with `reason:`, its `code` deserializes (JSON) into an
/// [`ActionConfig`], and that config lists `problem_class` in `applicable_to`.
///
/// For each [`ActionKind`], only the highest-scoring config survives. The
/// survivors are then ordered by `topological_sort`. If no skill qualifies,
/// the static `fallback_plan` for the class is returned instead.
pub fn select_actions(
    engine: &MemgineEngine,
    problem_class: ProblemClass,
    problem: &str,
) -> Vec<(ActionKind, ActionConfig)> {
    let query = format!("reason {} {}", problem_class, problem);
    let found = engine.find_skill("", "", &query, 20);
    let mut candidates: Vec<(ActionKind, ActionConfig, f64)> = found
        .iter()
        .filter_map(|(meta, score)| {
            if !meta.name.starts_with("reason:") {
                return None;
            }
            // Skills whose body is not a valid ActionConfig are skipped silently.
            let config: ActionConfig = serde_json::from_str(&meta.code).ok()?;
            if !config.applicable_to.contains(&problem_class) {
                return None;
            }
            Some((config.kind, config, *score))
        })
        .collect();
    if candidates.is_empty() {
        return fallback_plan(problem_class);
    }
    // Highest score first. NaN is ranked below every real score so the
    // comparator is a genuine total order: `sort_by` may panic (Rust >= 1.81),
    // or yield an arbitrary order, when handed one that is not.
    let rank = |score: f64| if score.is_nan() { f64::NEG_INFINITY } else { score };
    candidates.sort_by(|a, b| rank(b.2).total_cmp(&rank(a.2)));
    // The sort is stable and `retain` keeps the first hit, so each kind keeps
    // its best-scoring config.
    let mut seen = std::collections::HashSet::new();
    candidates.retain(|(kind, _, _)| seen.insert(*kind));
    topological_sort(candidates)
}
/// Orders actions so that each one comes after its prerequisites.
///
/// Actions are first sorted by ascending `priority`. The sort is stable, so
/// equal priorities keep their incoming order. Repeated passes then move every
/// action whose prerequisites have all been emitted into the result. An action
/// can be unlocked by one emitted earlier in the same pass.
///
/// Some actions never get their prerequisites satisfied: either a prerequisite
/// is absent from `actions`, or there is a cycle. These are appended at the
/// end in their priority order, so no action is dropped. Their dependency
/// order is not guaranteed. The scores in the input are ignored.
fn topological_sort(
    actions: Vec<(ActionKind, ActionConfig, f64)>,
) -> Vec<(ActionKind, ActionConfig)> {
    let mut pending: Vec<(ActionKind, ActionConfig)> = actions
        .into_iter()
        .map(|(kind, config, _score)| (kind, config))
        .collect();
    pending.sort_by_key(|(_, config)| config.priority);

    let mut result: Vec<(ActionKind, ActionConfig)> = Vec::with_capacity(pending.len());
    let mut resolved: std::collections::HashSet<ActionKind> = std::collections::HashSet::new();

    while !pending.is_empty() {
        let before = pending.len();
        let mut deferred = Vec::new();
        for (kind, config) in pending {
            if config.prerequisites.iter().all(|dep| resolved.contains(dep)) {
                resolved.insert(kind);
                result.push((kind, config));
            } else {
                deferred.push((kind, config));
            }
        }
        // A pass that emits nothing can never make progress later: the
        // remaining prerequisites are missing or cyclic, so stop rescanning.
        let stalled = deferred.len() == before;
        pending = deferred;
        if stalled {
            break;
        }
    }

    result.extend(pending);
    result
}
fn fallback_plan(problem_class: ProblemClass) -> Vec<(ActionKind, ActionConfig)> {
use ProblemClass::*;
let actions = match problem_class {
BugFix => vec![
ActionKind::Locate,
ActionKind::Diagnose,
ActionKind::GenerateFix,
ActionKind::VerifyFix,
ActionKind::Explain,
],
Refactor => vec![
ActionKind::Locate,
ActionKind::Diagnose,
ActionKind::GenerateFix,
ActionKind::Explain,
],
Performance => vec![
ActionKind::Locate,
ActionKind::Diagnose,
ActionKind::GenerateFix,
ActionKind::Explain,
],
NewFeature | TestWriting => vec![
ActionKind::Locate,
ActionKind::GenerateFix,
ActionKind::Explain,
],
Architecture | Explanation => vec![ActionKind::Explain],
Unknown => vec![ActionKind::Diagnose, ActionKind::Explain],
};
actions
.into_iter()
.map(|kind| {
let config = ActionConfig {
kind,
applicable_to: vec![problem_class],
prerequisites: vec![],
prompt_template: String::new(), priority: 0,
};
(kind, config)
})
.collect()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A BugFix-only config for `kind` with the given prerequisites and
    /// priority 0; override fields with struct-update syntax as needed.
    fn bugfix_config(kind: ActionKind, prerequisites: Vec<ActionKind>) -> ActionConfig {
        ActionConfig {
            kind,
            applicable_to: vec![ProblemClass::BugFix],
            prerequisites,
            prompt_template: String::new(),
            priority: 0,
        }
    }

    /// Extracts just the action kinds of a plan, in order.
    fn kinds(plan: &[(ActionKind, ActionConfig)]) -> Vec<ActionKind> {
        plan.iter().map(|(kind, _)| *kind).collect()
    }

    #[test]
    fn topological_sort_respects_deps() {
        let actions = vec![
            (
                ActionKind::Explain,
                ActionConfig {
                    priority: 50,
                    ..bugfix_config(ActionKind::Explain, vec![ActionKind::Diagnose])
                },
                0.5,
            ),
            (
                ActionKind::Diagnose,
                ActionConfig {
                    priority: 20,
                    ..bugfix_config(ActionKind::Diagnose, vec![ActionKind::Locate])
                },
                0.8,
            ),
            (
                ActionKind::Locate,
                ActionConfig {
                    priority: 10,
                    ..bugfix_config(ActionKind::Locate, vec![])
                },
                0.7,
            ),
        ];
        let sorted = topological_sort(actions);
        assert_eq!(
            kinds(&sorted),
            vec![ActionKind::Locate, ActionKind::Diagnose, ActionKind::Explain]
        );
    }

    #[test]
    fn topological_sort_overrides_priority_for_deps() {
        // Diagnose has the better (lower) priority but must wait for Locate.
        let actions = vec![
            (
                ActionKind::Diagnose,
                ActionConfig {
                    priority: 5,
                    ..bugfix_config(ActionKind::Diagnose, vec![ActionKind::Locate])
                },
                0.9,
            ),
            (
                ActionKind::Locate,
                ActionConfig {
                    priority: 10,
                    ..bugfix_config(ActionKind::Locate, vec![])
                },
                0.1,
            ),
        ];
        let sorted = topological_sort(actions);
        assert_eq!(kinds(&sorted), vec![ActionKind::Locate, ActionKind::Diagnose]);
    }

    #[test]
    fn topological_sort_keeps_unresolvable_actions_last() {
        // VerifyFix is never supplied, so Explain can never be resolved.
        let actions = vec![
            (
                ActionKind::Explain,
                ActionConfig {
                    priority: 5,
                    ..bugfix_config(ActionKind::Explain, vec![ActionKind::VerifyFix])
                },
                0.5,
            ),
            (
                ActionKind::Locate,
                ActionConfig {
                    priority: 10,
                    ..bugfix_config(ActionKind::Locate, vec![])
                },
                0.5,
            ),
        ];
        let sorted = topological_sort(actions);
        assert_eq!(kinds(&sorted), vec![ActionKind::Locate, ActionKind::Explain]);
    }

    #[test]
    fn fallback_plan_for_bug() {
        let plan = fallback_plan(ProblemClass::BugFix);
        assert_eq!(
            kinds(&plan),
            vec![
                ActionKind::Locate,
                ActionKind::Diagnose,
                ActionKind::GenerateFix,
                ActionKind::VerifyFix,
                ActionKind::Explain,
            ]
        );
        assert!(plan.iter().all(|(kind, config)| {
            config.kind == *kind
                && config.prerequisites.is_empty()
                && config.applicable_to == vec![ProblemClass::BugFix]
        }));
    }

    #[test]
    fn fallback_plan_for_explanation() {
        let plan = fallback_plan(ProblemClass::Explanation);
        assert_eq!(kinds(&plan), vec![ActionKind::Explain]);
    }
}