use serde::{Deserialize, Serialize};
/// Selects how the action dependency graph is obtained.
///
/// Serialized as an internally tagged enum: the variant is chosen by a `type`
/// field whose value is the snake_case variant name (`"static"`, `"custom"` or
/// `"llm"`), and the variant's own fields sit alongside that tag.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "type", rename_all = "snake_case")]
pub enum DependencyGraphType {
    /// Graph produced by a built-in, named planning pattern.
    Static {
        /// Name of the built-in pattern; defaults to `"code_search"` when omitted.
        /// `DependencyGraphConfig::to_core_graph` recognises `"code_search"` and
        /// `"file_exploration"`; any other value silently falls back to the
        /// code-search pattern.
        #[serde(default = "default_pattern")]
        pattern: String,
    },
    /// Graph spelled out explicitly in configuration.
    Custom {
        /// Edges between actions, each going from `from` to `to`.
        edges: Vec<DependencyEdgeConfig>,
        /// Action names handed to the graph builder as start nodes.
        start: Vec<String>,
        /// Action names handed to the graph builder as terminal nodes.
        terminal: Vec<String>,
        /// Per-action parameter variants; empty when omitted.
        #[serde(default)]
        param_variants: Vec<ParamVariantConfig>,
    },
    /// No graph is built from configuration: `DependencyGraphConfig::to_core_graph`
    /// returns `None` for this variant.
    // NOTE(review): the name suggests the graph is produced by an LLM elsewhere — confirm.
    Llm {},
}
impl Default for DependencyGraphType {
fn default() -> Self {
Self::Static {
pattern: default_pattern(),
}
}
}
/// Serde default for the `pattern` field of the static graph variant:
/// the built-in `"code_search"` pattern.
fn default_pattern() -> String {
    String::from("code_search")
}
/// One edge of a custom dependency graph, going from action `from` to action `to`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DependencyEdgeConfig {
    /// Name of the action the edge starts at.
    pub from: String,
    /// Name of the action the edge points to.
    pub to: String,
    /// Confidence passed to the graph builder for this edge; defaults to `0.9`
    /// when omitted.
    // NOTE(review): presumably expected in [0.0, 1.0]; no range check is done here.
    #[serde(default = "default_confidence")]
    pub confidence: f64,
}
/// Serde default for `DependencyEdgeConfig::confidence` when an edge omits it.
fn default_confidence() -> f64 {
    const DEFAULT_EDGE_CONFIDENCE: f64 = 0.9;
    DEFAULT_EDGE_CONFIDENCE
}
/// Parameter variants for one action in a custom dependency graph.
///
/// Each entry is forwarded as `(action, key, values)` to the graph builder's
/// `param_variants` method.
// NOTE(review): presumably lists the alternative values that may be tried for
// parameter `key` of `action` — confirm against the core graph builder.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParamVariantConfig {
    /// Name of the action the variants apply to.
    pub action: String,
    /// Parameter key the variant values are associated with.
    pub key: String,
    /// Candidate values for `key`.
    pub values: Vec<String>,
}
/// Top-level dependency-graph configuration.
///
/// `graph_type` is flattened, so the `type` tag and the selected variant's
/// fields appear directly at this struct's level rather than nested under a
/// `graph_type` key. The derived default is the static `"code_search"` pattern.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct DependencyGraphConfig {
    /// Which kind of graph to build, plus that kind's settings.
    #[serde(flatten)]
    pub graph_type: DependencyGraphType,
}
impl DependencyGraphConfig {
    /// Converts this configuration into a core `DependencyGraph`.
    ///
    /// * `Static`: runs a `StaticDependencyPlanner` preloaded with the named
    ///   pattern, using the fixed task string `"task"`. `"file_exploration"`
    ///   selects the file-exploration pattern. Every other value, including
    ///   unrecognised names, silently falls back to the code-search pattern.
    ///   A planner error is discarded and reported as `None`.
    /// * `Custom`: assembles the graph directly from the configured start and
    ///   terminal nodes, edges and parameter variants, and always returns `Some`.
    /// * `Llm`: returns `None`, because this method builds no graph for it.
    ///
    /// `available_actions` is forwarded to the planner (`Static`) or to the
    /// graph builder (`Custom`).
    pub fn to_core_graph(
        &self,
        available_actions: &[String],
    ) -> Option<swarm_engine_core::exploration::DependencyGraph> {
        use swarm_engine_core::exploration::{
            DependencyGraph, DependencyPlanner, StaticDependencyPlanner,
        };

        match &self.graph_type {
            DependencyGraphType::Static { pattern } => {
                let base = StaticDependencyPlanner::new();
                // Only "file_exploration" gets its own pattern; "code_search" and any
                // unknown name share the code-search pattern.
                let planner = if pattern.as_str() == "file_exploration" {
                    base.with_file_exploration_pattern()
                } else {
                    base.with_code_search_pattern()
                };
                planner.plan("task", available_actions).ok()
            }
            DependencyGraphType::Custom {
                edges,
                start,
                terminal,
                param_variants,
            } => {
                let seeded = DependencyGraph::builder()
                    .available_actions(available_actions.iter().cloned())
                    .start_nodes(start.iter().cloned())
                    .terminal_nodes(terminal.iter().cloned());
                let with_edges = edges.iter().fold(seeded, |acc, e| {
                    acc.edge(&e.from, &e.to, e.confidence)
                });
                let complete = param_variants.iter().fold(with_edges, |acc, variant| {
                    acc.param_variants(&variant.action, &variant.key, variant.values.iter().cloned())
                });
                Some(complete.build())
            }
            DependencyGraphType::Llm {} => None,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Shared action list for the conversion tests.
    fn sample_actions() -> Vec<String> {
        vec!["grep".to_string(), "read".to_string(), "done".to_string()]
    }

    #[test]
    fn test_dependency_graph_type_default() {
        let graph_type = DependencyGraphType::default();
        match graph_type {
            DependencyGraphType::Static { pattern } => {
                assert_eq!(pattern, "code_search");
            }
            _ => panic!("Expected Static variant"),
        }
    }

    #[test]
    fn test_dependency_edge_config_default_confidence() {
        let toml_str = r#"
            from = "grep"
            to = "read"
        "#;
        let edge: DependencyEdgeConfig = toml::from_str(toml_str).unwrap();
        assert_eq!(edge.from, "grep");
        assert_eq!(edge.to, "read");
        assert!((edge.confidence - 0.9).abs() < 0.001);
    }

    #[test]
    fn test_dependency_graph_config_static() {
        let config = DependencyGraphConfig::default();
        let actions = sample_actions();
        let graph = config.to_core_graph(&actions);
        assert!(graph.is_some());
    }

    /// Unknown pattern names must fall back to the code-search pattern rather
    /// than failing to produce a graph.
    #[test]
    fn test_unknown_static_pattern_falls_back_to_code_search() {
        let config = DependencyGraphConfig {
            graph_type: DependencyGraphType::Static {
                pattern: "no_such_pattern".to_string(),
            },
        };
        assert!(config.to_core_graph(&sample_actions()).is_some());
    }

    /// The `Llm` variant builds nothing from configuration.
    #[test]
    fn test_llm_graph_type_yields_no_graph() {
        let config = DependencyGraphConfig {
            graph_type: DependencyGraphType::Llm {},
        };
        assert!(config.to_core_graph(&sample_actions()).is_none());
    }

    /// The flattened, internally tagged form selects `static` and fills in the
    /// default pattern.
    #[test]
    fn test_static_config_from_toml_uses_default_pattern() {
        let config: DependencyGraphConfig = toml::from_str(r#"type = "static""#).unwrap();
        match config.graph_type {
            DependencyGraphType::Static { pattern } => assert_eq!(pattern, "code_search"),
            other => panic!("Expected Static variant, got {:?}", other),
        }
    }

    #[test]
    fn test_llm_config_from_toml() {
        let config: DependencyGraphConfig = toml::from_str(r#"type = "llm""#).unwrap();
        assert!(matches!(config.graph_type, DependencyGraphType::Llm {}));
    }

    /// A custom graph parses its edges (with defaulted confidence) and
    /// defaults `param_variants` to empty.
    #[test]
    fn test_custom_config_from_toml() {
        let toml_str = r#"
            type = "custom"
            start = ["grep"]
            terminal = ["done"]

            [[edges]]
            from = "grep"
            to = "read"

            [[edges]]
            from = "read"
            to = "done"
            confidence = 0.5
        "#;
        let config: DependencyGraphConfig = toml::from_str(toml_str).unwrap();
        match config.graph_type {
            DependencyGraphType::Custom {
                edges,
                start,
                terminal,
                param_variants,
            } => {
                assert_eq!(start, vec!["grep".to_string()]);
                assert_eq!(terminal, vec!["done".to_string()]);
                assert!(param_variants.is_empty());
                assert_eq!(edges.len(), 2);
                assert_eq!(edges[0].from, "grep");
                assert_eq!(edges[1].to, "done");
                assert!((edges[0].confidence - 0.9).abs() < 0.001);
                assert!((edges[1].confidence - 0.5).abs() < 0.001);
            }
            other => panic!("Expected Custom variant, got {:?}", other),
        }
    }
}