use std::collections::{HashMap, HashSet};
use synaptic_core::ToolDefinition;
/// Per-call state handed to every [`ToolFilter`] when it decides which tools to keep.
#[derive(Debug, Clone, Default)]
pub struct FilterContext {
    /// Turn counter supplied by the caller; compared against each threshold's
    /// `min_turns` by [`StateMachineFilter`].
    pub turn_count: usize,
    /// Name of the most recently used tool, if any; looked up in
    /// [`StateMachineFilter`]'s after-tool rules.
    pub last_tool: Option<String>,
    /// Free-form extra data for custom filters. No filter defined in this
    /// module reads it.
    pub metadata: HashMap<String, serde_json::Value>,
}
/// A pass that decides which tools are offered, given the current [`FilterContext`].
///
/// The `Send + Sync` bound allows a filter to be shared across threads.
/// The implementations in this module only ever remove tools and keep the
/// relative order of the survivors. The trait itself does not enforce that.
pub trait ToolFilter: Send + Sync {
    /// Consumes the candidate list `tools` and returns the tools that remain
    /// available under `context`.
    fn filter(&self, tools: Vec<ToolDefinition>, context: &FilterContext) -> Vec<ToolDefinition>;
}
/// Keeps only the tools whose name appears in an explicit allow list.
///
/// Tools not named in the list are dropped, so an empty list hides every tool.
#[derive(Debug, Clone)]
pub struct AllowListFilter {
    /// Names of the tools permitted to pass through.
    allowed: HashSet<String>,
}

impl AllowListFilter {
    /// Builds a filter that admits exactly the named tools.
    ///
    /// Duplicate names are collapsed because they are stored in a set.
    #[must_use]
    pub fn new(allowed: impl IntoIterator<Item = impl Into<String>>) -> Self {
        Self {
            allowed: allowed.into_iter().map(|name| name.into()).collect(),
        }
    }
}
impl ToolFilter for AllowListFilter {
    /// Drops every tool whose name is not on the allow list, preserving the
    /// order of the tools that remain. The context is ignored.
    fn filter(&self, tools: Vec<ToolDefinition>, _context: &FilterContext) -> Vec<ToolDefinition> {
        let mut kept = tools;
        kept.retain(|tool| self.allowed.contains(&tool.name));
        kept
    }
}
/// Removes every tool whose name appears in an explicit deny list.
///
/// Tools not named in the list pass through, so an empty list hides nothing.
#[derive(Debug, Clone)]
pub struct DenyListFilter {
    /// Names of the tools that must never be offered.
    denied: HashSet<String>,
}

impl DenyListFilter {
    /// Builds a filter that blocks exactly the named tools.
    ///
    /// Duplicate names are collapsed because they are stored in a set.
    #[must_use]
    pub fn new(denied: impl IntoIterator<Item = impl Into<String>>) -> Self {
        Self {
            denied: denied.into_iter().map(|name| name.into()).collect(),
        }
    }
}
impl ToolFilter for DenyListFilter {
    /// Drops every tool whose name is on the deny list, preserving the order
    /// of the tools that remain. The context is ignored.
    fn filter(&self, tools: Vec<ToolDefinition>, _context: &FilterContext) -> Vec<ToolDefinition> {
        let mut survivors = tools;
        survivors.retain(|tool| !self.denied.contains(&tool.name));
        survivors
    }
}
/// Exposes tools according to simple conversation-state rules.
///
/// Both kinds of rule are applied by its `ToolFilter::filter` implementation:
///
/// * **After-tool rules** ([`after_tool`](Self::after_tool)): when the
///   context's `last_tool` has a registered rule, only the tools listed for it
///   stay available.
/// * **Turn thresholds** ([`turn_threshold`](Self::turn_threshold)): every tool
///   named by any threshold is hidden until the context's `turn_count` is at
///   least the `min_turns` of one of the thresholds that name it. Tools named
///   by no threshold are unaffected.
#[derive(Debug, Clone, Default)]
pub struct StateMachineFilter {
    /// Tool name -> tools allowed immediately after that tool was used.
    after_tool_rules: HashMap<String, HashSet<String>>,
    /// Turn gates, in registration order.
    turn_thresholds: Vec<TurnThreshold>,
}

/// A single turn gate: `add_tools` become available once
/// `turn_count >= min_turns`.
#[derive(Debug, Clone)]
struct TurnThreshold {
    min_turns: usize,
    add_tools: HashSet<String>,
}

impl StateMachineFilter {
    /// Creates a filter with no rules.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Restricts the tools offered right after `tool_name` was used to
    /// `allowed_next`.
    ///
    /// Registering a second rule for the same `tool_name` replaces the earlier
    /// one; the sets are not merged.
    ///
    /// This is a by-value builder method: the returned filter carries the new
    /// rule, so discarding it loses the rule (hence `#[must_use]`).
    #[must_use]
    pub fn after_tool(
        mut self,
        tool_name: impl Into<String>,
        allowed_next: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        self.after_tool_rules.insert(
            tool_name.into(),
            allowed_next.into_iter().map(|name| name.into()).collect(),
        );
        self
    }

    /// Hides `add_tools` until the turn count reaches `min_turns`.
    ///
    /// Thresholds accumulate. A tool named by several thresholds becomes
    /// available as soon as any one of them is met.
    #[must_use]
    pub fn turn_threshold(
        mut self,
        min_turns: usize,
        add_tools: impl IntoIterator<Item = impl Into<String>>,
    ) -> Self {
        self.turn_thresholds.push(TurnThreshold {
            min_turns,
            add_tools: add_tools.into_iter().map(|name| name.into()).collect(),
        });
        self
    }
}
impl ToolFilter for StateMachineFilter {
    /// Applies the after-tool rule for `context.last_tool` (if one exists),
    /// then hides every turn-gated tool whose thresholds are all still above
    /// `context.turn_count`. The surviving tools keep their original order.
    fn filter(&self, tools: Vec<ToolDefinition>, context: &FilterContext) -> Vec<ToolDefinition> {
        let mut remaining = tools;

        // Sequencing: only the successors registered for the last tool survive.
        let successors = context
            .last_tool
            .as_ref()
            .and_then(|last| self.after_tool_rules.get(last));
        if let Some(next_allowed) = successors {
            remaining.retain(|tool| next_allowed.contains(&tool.name));
        }

        // Gating: a gated tool is unlocked if ANY threshold naming it is met.
        let mut unlocked: HashMap<&str, bool> = HashMap::new();
        for gate in &self.turn_thresholds {
            let reached = context.turn_count >= gate.min_turns;
            for name in &gate.add_tools {
                *unlocked.entry(name.as_str()).or_insert(false) |= reached;
            }
        }

        // Tools that no threshold mentions are never gated.
        if !unlocked.is_empty() {
            remaining.retain(|tool| {
                unlocked
                    .get(tool.name.as_str())
                    .copied()
                    .unwrap_or(true)
            });
        }

        remaining
    }
}
/// Chains several filters, feeding each one the output of the previous one.
///
/// Filters run in vector order. An empty composite returns its input unchanged.
pub struct CompositeFilter(pub Vec<Box<dyn ToolFilter>>);

impl CompositeFilter {
    /// Wraps `filters` so they run in the given order.
    pub fn new(filters: Vec<Box<dyn ToolFilter>>) -> Self {
        CompositeFilter(filters)
    }
}

impl ToolFilter for CompositeFilter {
    /// Threads `tools` through every inner filter, passing the same `context`
    /// to each.
    fn filter(&self, tools: Vec<ToolDefinition>, context: &FilterContext) -> Vec<ToolDefinition> {
        self.0
            .iter()
            .fold(tools, |remaining, stage| stage.filter(remaining, context))
    }
}