use std::collections::{HashMap, HashSet};
use super::DependencyGraph;
/// Read-only view of transition rules between named node types.
///
/// Node types are identified by `&str` names. Implementors must be
/// `Send + Sync` so a rule set can be shared across threads.
pub trait Rules: Send + Sync {
    /// Returns the node types permitted to follow `node_type`.
    fn successors(&self, node_type: &str) -> Vec<&str>;

    /// Returns the node types a path may begin with.
    fn roots(&self) -> Vec<&str>;

    /// Reports whether `node_type` is marked as terminal.
    fn is_terminal(&self, node_type: &str) -> bool;

    /// Reports whether this rule set should be treated as empty.
    fn is_empty(&self) -> bool;

    /// Optional parameter variants for `node_type`: a parameter key paired
    /// with its candidate values.
    ///
    /// The provided implementation reports no variants for any node type;
    /// implementors override it when they carry variant data.
    fn param_variants(&self, node_type: &str) -> Option<(&str, &[String])> {
        // The default ignores its argument entirely.
        let _ = node_type;
        None
    }
}
/// Builder-style set of transition rules between named node types.
///
/// Records which node types may follow which (successor rules), which node
/// types are roots, which are terminals, optional per-node-type parameter
/// variants, and optional per-edge confidence values.
///
/// Construct with `NodeRules::new()` and chain the consuming `add_*` methods.
/// Query methods that return several node types return them sorted, so the
/// result does not depend on `HashSet` iteration order (which is randomized
/// per process).
#[derive(Debug, Clone, Default)]
pub struct NodeRules {
    /// Successor rules: source node type -> node types allowed to follow it.
    successors: HashMap<String, HashSet<String>>,
    /// Node types registered as roots.
    roots: HashSet<String>,
    /// Node types registered as terminals.
    terminals: HashSet<String>,
    /// node type -> (parameter key, candidate values).
    param_variants: HashMap<String, (String, Vec<String>)>,
    /// Recorded confidence per `(from, to)` edge. Values added through
    /// `add_rule_with_confidence` are clamped to `[0.0, 1.0]`.
    edge_confidence: HashMap<(String, String), f64>,
}

impl NodeRules {
    /// Creates an empty rule set; identical to `NodeRules::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Permits `to` to follow `from`.
    pub fn add_rule(mut self, from: &str, to: &str) -> Self {
        self.successors
            .entry(from.to_string())
            .or_default()
            .insert(to.to_string());
        self
    }

    /// Permits every node type in `tos` to follow `from`.
    ///
    /// `from` receives a successor entry even when `tos` is empty, so it is
    /// then reported by `has_node_type` and makes `is_empty` return false.
    pub fn add_rules(mut self, from: &str, tos: &[&str]) -> Self {
        self.successors
            .entry(from.to_string())
            .or_default()
            .extend(tos.iter().map(|&to| to.to_string()));
        self
    }

    /// Registers `node_type` as a root.
    pub fn add_root(mut self, node_type: &str) -> Self {
        self.roots.insert(node_type.to_string());
        self
    }

    /// Registers every entry of `node_types` as a root.
    pub fn add_roots(mut self, node_types: &[&str]) -> Self {
        self.roots
            .extend(node_types.iter().map(|&node_type| node_type.to_string()));
        self
    }

    /// Registers `node_type` as terminal.
    pub fn add_terminal(mut self, node_type: &str) -> Self {
        self.terminals.insert(node_type.to_string());
        self
    }

    /// Registers every entry of `node_types` as terminal.
    pub fn add_terminals(mut self, node_types: &[&str]) -> Self {
        self.terminals
            .extend(node_types.iter().map(|&node_type| node_type.to_string()));
        self
    }

    /// Registers parameter variants for `node_type`: a parameter `key` and its
    /// candidate `values`. Replaces any variants previously registered for
    /// the same `node_type`.
    pub fn add_param_variants(mut self, node_type: &str, key: &str, values: &[&str]) -> Self {
        let values: Vec<String> = values.iter().map(|&value| value.to_string()).collect();
        self.param_variants
            .insert(node_type.to_string(), (key.to_string(), values));
        self
    }

    /// Permits `to` to follow `from` and records the edge's `confidence`,
    /// clamped to `[0.0, 1.0]`.
    ///
    /// A later call for the same edge overwrites the stored value. A NaN
    /// `confidence` is stored as NaN, because `f64::clamp` passes NaN through.
    pub fn add_rule_with_confidence(self, from: &str, to: &str, confidence: f64) -> Self {
        // Reuse `add_rule` so the successor bookkeeping lives in one place.
        let mut rules = self.add_rule(from, to);
        rules.edge_confidence.insert(
            (from.to_string(), to.to_string()),
            confidence.clamp(0.0, 1.0),
        );
        rules
    }

    /// Node types permitted to follow `node_type`, in ascending order.
    /// Returns an empty vector when `node_type` has no successor entry.
    pub fn successors(&self, node_type: &str) -> Vec<&str> {
        self.successors
            .get(node_type)
            .map(Self::sorted_strs)
            .unwrap_or_default()
    }

    /// All root node types, in ascending order.
    pub fn roots(&self) -> Vec<&str> {
        Self::sorted_strs(&self.roots)
    }

    /// All terminal node types, in ascending order.
    pub fn terminals(&self) -> Vec<&str> {
        Self::sorted_strs(&self.terminals)
    }

    /// Whether a successor rule permits `to` to follow `from`.
    pub fn can_transition(&self, from: &str, to: &str) -> bool {
        matches!(self.successors.get(from), Some(set) if set.contains(to))
    }

    /// Whether `node_type` is a successor-rule source, a root, or a terminal.
    ///
    /// A node type that appears only as the target of a rule is not counted.
    pub fn has_node_type(&self, node_type: &str) -> bool {
        self.successors.contains_key(node_type)
            || self.roots.contains(node_type)
            || self.terminals.contains(node_type)
    }

    /// Whether `node_type` is registered as terminal.
    pub fn is_terminal(&self, node_type: &str) -> bool {
        self.terminals.contains(node_type)
    }

    /// Whether `node_type` is registered as a root.
    pub fn is_root(&self, node_type: &str) -> bool {
        self.roots.contains(node_type)
    }

    /// Recorded confidence for the `from -> to` edge.
    ///
    /// Returns `None` when no confidence was recorded, even if the transition
    /// itself is permitted (for example, it was added with `add_rule`).
    pub fn get_confidence(&self, from: &str, to: &str) -> Option<f64> {
        // Keys are owned `(String, String)` tuples, so an owned key is built
        // for the lookup.
        self.edge_confidence
            .get(&(from.to_string(), to.to_string()))
            .copied()
    }

    /// Maps each edge target to the highest confidence recorded on any edge
    /// leading into it.
    ///
    /// Every entry starts at `0.0`, so negative or NaN confidences never
    /// raise an entry above `0.0`.
    pub fn confidence_map(&self) -> HashMap<String, f64> {
        let mut best: HashMap<String, f64> = HashMap::new();
        for ((_, to), &confidence) in &self.edge_confidence {
            let slot = best.entry(to.clone()).or_insert(0.0);
            if confidence > *slot {
                *slot = confidence;
            }
        }
        best
    }

    /// True when there are no successor rules and no roots.
    ///
    /// Terminals, parameter variants and confidences are not considered.
    pub fn is_empty(&self) -> bool {
        self.successors.is_empty() && self.roots.is_empty()
    }

    /// Small fixture for tests. Roots: `grep` and `glob`. Rules:
    /// `grep -> {read, summary}` and `glob -> grep`. Terminals: `read` and
    /// `summary`.
    #[cfg(test)]
    pub fn for_testing() -> Self {
        Self::new()
            .add_roots(&["grep", "glob"])
            .add_rules("grep", &["read", "summary"])
            .add_rule("glob", "grep")
            .add_terminals(&["read", "summary"])
    }

    /// Borrows the members of `set` as `&str`, sorted ascending, so callers
    /// see a deterministic order rather than `HashSet` iteration order.
    fn sorted_strs(set: &HashSet<String>) -> Vec<&str> {
        let mut names: Vec<&str> = set.iter().map(String::as_str).collect();
        names.sort_unstable();
        names
    }
}
impl Rules for NodeRules {
fn successors(&self, node_type: &str) -> Vec<&str> {
self.successors(node_type)
}
fn roots(&self) -> Vec<&str> {
self.roots()
}
fn is_terminal(&self, node_type: &str) -> bool {
self.is_terminal(node_type)
}
fn is_empty(&self) -> bool {
self.is_empty()
}
fn param_variants(&self, node_type: &str) -> Option<(&str, &[String])> {
self.param_variants
.get(node_type)
.map(|(key, values)| (key.as_str(), values.as_slice()))
}
}
impl From<DependencyGraph> for NodeRules {
fn from(graph: DependencyGraph) -> Self {
let mut rules = NodeRules::new();
for start in graph.start_actions() {
rules.roots.insert(start);
}
for terminal in graph.terminal_actions() {
rules.terminals.insert(terminal);
}
for edge in graph.edges() {
rules
.successors
.entry(edge.from.clone())
.or_default()
.insert(edge.to.clone());
rules
.edge_confidence
.insert((edge.from.clone(), edge.to.clone()), edge.confidence);
}
for (action, (key, values)) in graph.all_param_variants() {
rules
.param_variants
.insert(action.clone(), (key.clone(), values.clone()));
}
rules
}
}
impl From<&DependencyGraph> for NodeRules {
    /// Builds rules from a borrowed dependency graph.
    ///
    /// The conversion maps the graph as follows:
    /// - the graph's start actions become roots;
    /// - its terminal actions become terminals;
    /// - every edge becomes a successor rule. Its confidence is recorded and
    ///   clamped to `[0.0, 1.0]`, matching `NodeRules::add_rule_with_confidence`;
    /// - per-action parameter variants are copied over.
    fn from(graph: &DependencyGraph) -> Self {
        let mut rules = NodeRules::new();
        rules.roots.extend(graph.start_actions());
        rules.terminals.extend(graph.terminal_actions());
        for edge in graph.edges() {
            rules
                .successors
                .entry(edge.from.clone())
                .or_default()
                .insert(edge.to.clone());
            // Enforce the same [0, 1] invariant as the builder API, so that
            // `get_confidence` / `confidence_map` see one consistent range
            // regardless of how the rules were built.
            rules.edge_confidence.insert(
                (edge.from.clone(), edge.to.clone()),
                edge.confidence.clamp(0.0, 1.0),
            );
        }
        for (action, (key, values)) in graph.all_param_variants() {
            rules
                .param_variants
                .insert(action.clone(), (key.clone(), values.clone()));
        }
        rules
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::exploration::DependencyGraphBuilder;

    #[test]
    fn test_node_rules_basic() {
        let rules = NodeRules::new()
            .add_roots(&["grep", "glob"])
            .add_rules("grep", &["read", "summary", "grep"])
            .add_rules("read", &["analyze", "extract"])
            .add_rule("summary", "report")
            .add_terminals(&["report", "extract"]);
        let roots = rules.roots();
        assert!(roots.contains(&"grep"));
        assert!(roots.contains(&"glob"));
        let grep_successors = rules.successors("grep");
        assert_eq!(grep_successors.len(), 3);
        assert!(grep_successors.contains(&"read"));
        assert!(grep_successors.contains(&"summary"));
        assert!(rules.can_transition("grep", "read"));
        assert!(!rules.can_transition("grep", "report"));
        assert!(rules.can_transition("summary", "report"));
        assert!(rules.is_terminal("report"));
        assert!(rules.is_terminal("extract"));
        assert!(!rules.is_terminal("grep"));
    }

    #[test]
    fn test_node_rules_empty() {
        let rules = NodeRules::new();
        assert!(rules.is_empty());
        assert!(rules.successors("anything").is_empty());
        assert!(rules.roots().is_empty());
    }

    #[test]
    fn test_node_rules_has_node_type() {
        let rules = NodeRules::new()
            .add_root("start")
            .add_rule("middle", "end")
            .add_terminal("end");
        assert!(rules.has_node_type("start"));
        assert!(rules.has_node_type("middle"));
        assert!(rules.has_node_type("end"));
        assert!(!rules.has_node_type("unknown"));
    }

    #[test]
    fn test_for_testing_fixture() {
        let rules = NodeRules::for_testing();
        assert!(rules.is_root("grep"));
        assert!(rules.is_root("glob"));
        assert!(rules.can_transition("glob", "grep"));
        assert!(rules.can_transition("grep", "read"));
        assert!(rules.is_terminal("summary"));
    }

    #[test]
    fn test_confidence_is_clamped_and_aggregated() {
        let rules = NodeRules::new()
            .add_rule_with_confidence("a", "b", 1.5)
            .add_rule_with_confidence("c", "b", 0.25)
            .add_rule_with_confidence("a", "c", -0.5);
        assert!(rules.can_transition("a", "b"));
        assert_eq!(rules.get_confidence("a", "b"), Some(1.0));
        assert_eq!(rules.get_confidence("c", "b"), Some(0.25));
        assert_eq!(rules.get_confidence("a", "c"), Some(0.0));
        assert_eq!(rules.get_confidence("b", "a"), None);
        let map = rules.confidence_map();
        assert_eq!(map.len(), 2);
        assert_eq!(map.get("b"), Some(&1.0));
        assert_eq!(map.get("c"), Some(&0.0));
    }

    #[test]
    fn test_param_variants_via_trait() {
        let rules = NodeRules::new().add_param_variants("read", "path", &["a.rs", "b.rs"]);
        let dyn_rules: &dyn Rules = &rules;
        let (key, values) = dyn_rules
            .param_variants("read")
            .expect("variants were registered for `read`");
        assert_eq!(key, "path");
        assert_eq!(values.to_vec(), vec!["a.rs".to_string(), "b.rs".to_string()]);
        assert!(dyn_rules.param_variants("grep").is_none());
    }

    #[test]
    fn test_from_dependency_graph() {
        let graph = DependencyGraphBuilder::new()
            .task("Find auth function")
            .available_actions(["Grep", "List", "Read"])
            .edge("Grep", "Read", 0.95)
            .edge("List", "Grep", 0.60)
            .edge("List", "Read", 0.40)
            .start_nodes(["Grep", "List"])
            .terminal_node("Read")
            .build();
        let rules: NodeRules = graph.into();
        assert!(rules.is_root("Grep"));
        assert!(rules.is_root("List"));
        assert!(!rules.is_root("Read"));
        assert!(rules.is_terminal("Read"));
        assert!(!rules.is_terminal("Grep"));
        assert!(rules.can_transition("Grep", "Read"));
        assert!(rules.can_transition("List", "Grep"));
        assert!(rules.can_transition("List", "Read"));
        assert!(!rules.can_transition("Read", "Grep"));
        // Edge confidences are carried over and stay within [0, 1].
        let confidence = rules
            .get_confidence("Grep", "Read")
            .expect("edge confidence should be carried over from the graph");
        assert!((0.0..=1.0).contains(&confidence));
        assert_eq!(rules.get_confidence("Read", "Grep"), None);
    }

    #[test]
    fn test_from_dependency_graph_ref() {
        let graph = DependencyGraphBuilder::new()
            .edge("A", "B", 0.9)
            .start_node("A")
            .terminal_node("B")
            .build();
        let rules: NodeRules = (&graph).into();
        assert!(rules.is_root("A"));
        assert!(rules.is_terminal("B"));
        assert!(rules.can_transition("A", "B"));
        assert!(rules.get_confidence("A", "B").is_some());
    }
}