use async_trait::async_trait;
use std::collections::HashSet;
use crate::plugin::{Plugin, PluginCapabilities, ToolGate, ToolGateContext};
/// A tool-gate plugin that restricts the opening turn to a fixed set of tools.
///
/// On the first iteration (`iteration == 0`), only tools that are both in
/// `allowlist` and present in the registry's available tool names are exposed.
/// Later iterations are not narrowed by this gate.
#[derive(Debug, Clone)]
pub struct OpeningGate {
    /// Tool names permitted on the opening turn.
    allowlist: HashSet<String>,
}
impl OpeningGate {
pub fn with_allowlist(allowlist: HashSet<String>) -> Self {
Self { allowlist }
}
pub fn new<I, S>(tools: I) -> Self
where
I: IntoIterator<Item = S>,
S: Into<String>,
{
Self {
allowlist: tools.into_iter().map(Into::into).collect(),
}
}
}
/// Name this gate reports through `Plugin::name`.
const OPENING_GATE_NAME: &str = "opening_gate";

impl Plugin for OpeningGate {
    /// Returns the fixed plugin name `"opening_gate"`.
    fn name(&self) -> &'static str {
        OPENING_GATE_NAME
    }

    /// Reports the capability set built by `PluginCapabilities::tool_gate()`.
    fn capabilities(&self) -> PluginCapabilities {
        PluginCapabilities::tool_gate()
    }
}
#[async_trait]
impl ToolGate for OpeningGate {
    /// Narrows the tool catalog on the opening turn only.
    ///
    /// When `ctx.iteration` is `0`, returns `Some` of the names that appear in
    /// both the configured allowlist and `ctx.available_tool_names`. Names that
    /// the registry does not expose are never synthesized. On every other
    /// iteration, returns `None` so this gate leaves the catalog as-is.
    async fn next_turn_tool_allowlist(&self, ctx: ToolGateContext<'_>) -> Option<HashSet<String>> {
        let opening_turn = ctx.iteration == 0;
        opening_turn.then(|| {
            ctx.available_tool_names
                .iter()
                .copied()
                .filter(|name| self.allowlist.contains(*name))
                .map(|name| name.to_string())
                .collect()
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal gate context for the given iteration and registry contents.
    fn ctx(iteration: usize, tools: &'static [&'static str]) -> ToolGateContext<'static> {
        ToolGateContext {
            iteration,
            messages: &[],
            conversation_id: None,
            available_tool_names: tools,
        }
    }

    #[tokio::test]
    async fn gate_fires_on_iteration_zero_with_intersection() {
        let gate = OpeningGate::new(["frame", "deliver", "ask"]);
        let allowed = gate
            .next_turn_tool_allowlist(ctx(
                0,
                &[
                    "frame",
                    "deliver",
                    "ask",
                    "load_skill",
                    "shell",
                    "file_write",
                ],
            ))
            .await
            .expect("opening turn should narrow");
        assert!(allowed.contains("frame"));
        assert!(allowed.contains("deliver"));
        assert!(allowed.contains("ask"));
        assert!(
            !allowed.contains("load_skill"),
            "opening turn must hide non-framing tools"
        );
        assert!(
            !allowed.contains("shell"),
            "opening turn must hide work tools so the model frames first"
        );
        assert!(
            !allowed.contains("file_write"),
            "opening turn must hide work tools"
        );
    }

    #[tokio::test]
    async fn gate_returns_none_after_first_iteration() {
        let gate = OpeningGate::new(["frame", "shell"]);
        let result = gate
            .next_turn_tool_allowlist(ctx(1, &["frame", "shell"]))
            .await;
        assert!(
            result.is_none(),
            "iteration > 0 must NOT narrow — let the model use the full catalog"
        );
    }

    /// Every non-zero iteration, not just the second turn, must pass through.
    #[tokio::test]
    async fn gate_returns_none_on_any_later_iteration() {
        let gate = OpeningGate::new(["frame"]);
        for iteration in [2, 10, usize::MAX] {
            let result = gate
                .next_turn_tool_allowlist(ctx(iteration, &["frame", "shell"]))
                .await;
            assert!(result.is_none(), "iteration {iteration} must not narrow");
        }
    }

    #[tokio::test]
    async fn allowlist_does_not_synthesize_tools_missing_from_registry() {
        let gate = OpeningGate::new(["frame", "deliver", "ask"]);
        let allowed = gate
            .next_turn_tool_allowlist(ctx(0, &["ask", "deliver", "shell"]))
            .await
            .unwrap();
        assert_eq!(allowed.len(), 2);
        assert!(allowed.contains("ask"));
        assert!(allowed.contains("deliver"));
        assert!(!allowed.contains("frame"));
        assert!(!allowed.contains("shell"));
    }

    /// Pins current behavior: with no overlap, the gate still narrows (to an
    /// empty set) rather than passing through with `None`.
    #[tokio::test]
    async fn opening_turn_without_overlap_narrows_to_empty_set() {
        let gate = OpeningGate::new(["frame"]);
        let allowed = gate
            .next_turn_tool_allowlist(ctx(0, &["shell", "file_write"]))
            .await
            .expect("opening turn always returns Some");
        assert!(allowed.is_empty());

        let allowed = gate
            .next_turn_tool_allowlist(ctx(0, &[]))
            .await
            .expect("opening turn always returns Some, even for an empty registry");
        assert!(allowed.is_empty());
    }

    #[tokio::test]
    async fn duplicate_registry_names_collapse() {
        let gate = OpeningGate::new(["ask"]);
        let allowed = gate
            .next_turn_tool_allowlist(ctx(0, &["ask", "ask", "shell"]))
            .await
            .unwrap();
        assert_eq!(allowed.len(), 1);
        assert!(allowed.contains("ask"));
    }

    #[tokio::test]
    async fn with_allowlist_takes_an_explicit_set() {
        let mut custom = HashSet::new();
        custom.insert("frame".to_string());
        let gate = OpeningGate::with_allowlist(custom);
        let allowed = gate
            .next_turn_tool_allowlist(ctx(0, &["frame", "deliver", "ask"]))
            .await
            .unwrap();
        assert_eq!(allowed.len(), 1);
        assert!(allowed.contains("frame"));
        assert!(!allowed.contains("deliver"));
    }
}