Skip to main content

clark_agent/plugins/
opening_gate.rs

//! `OpeningGate` — narrow the very first LLM call's tool list to a
//! caller-supplied subset.
//!
//! A common product pattern is "the model must frame / plan / clarify on
//! its opening turn before any work tools are available." This gate
//! enforces that on the wire: on iteration 0 it advertises only the tools
//! whose names are in the configured allowlist (intersected with the tools
//! the registry actually has); from iteration 1 on it imposes no
//! narrowing. Pair it with a system-prompt sentence that explains the
//! same contract in prose — a wire-level constraint is harder for a
//! model to ignore than text instructions.
//!
//! The gate ships **no** default allowlist: it is vocabulary-free, so the
//! core stays free of any product's tool names. Callers supply the
//! opening subset via [`OpeningGate::with_allowlist`] (or the
//! [`OpeningGate::new`] convenience constructor). Composes with other
//! [`ToolGate`] plugins through allowlist intersection.

18use async_trait::async_trait;
19use std::collections::HashSet;
20
21use crate::plugin::{Plugin, PluginCapabilities, ToolGate, ToolGateContext};
22
/// Narrows iteration 0's tool advertisement to a caller-supplied subset.
///
/// Composes via allowlist intersection with other `ToolGate` plugins.
/// Build one with [`OpeningGate::with_allowlist`] or [`OpeningGate::new`];
/// there is intentionally no default allowlist.
#[derive(Debug, Clone)]
pub struct OpeningGate {
    /// Tool names permitted on the opening turn. When the gate fires, names
    /// the registry does not actually provide are dropped rather than
    /// advertised.
    allowlist: HashSet<String>,
}
28
29impl OpeningGate {
30    /// Construct a gate that narrows the opening turn to `allowlist`.
31    /// The names are product-specific and supplied by the caller — the
32    /// gate itself knows no tool vocabulary.
33    pub fn with_allowlist(allowlist: HashSet<String>) -> Self {
34        Self { allowlist }
35    }
36
37    /// Convenience constructor from an iterator of tool names.
38    pub fn new<I, S>(tools: I) -> Self
39    where
40        I: IntoIterator<Item = S>,
41        S: Into<String>,
42    {
43        Self {
44            allowlist: tools.into_iter().map(Into::into).collect(),
45        }
46    }
47}
48
49impl Plugin for OpeningGate {
50    fn name(&self) -> &'static str {
51        "opening_gate"
52    }
53    fn capabilities(&self) -> PluginCapabilities {
54        PluginCapabilities::tool_gate()
55    }
56}
57
58#[async_trait]
59impl ToolGate for OpeningGate {
60    async fn next_turn_tool_allowlist(&self, ctx: ToolGateContext<'_>) -> Option<HashSet<String>> {
61        if ctx.iteration != 0 {
62            return None;
63        }
64        // Intersect the configured allowlist with what the registry
65        // actually has. If a tool we want to advertise isn't in the
66        // registry (test fixture, surface that disabled it, etc.),
67        // don't synthesize the name on the wire.
68        let mut allowed = HashSet::new();
69        for tool in ctx.available_tool_names {
70            if self.allowlist.contains(*tool) {
71                allowed.insert((*tool).to_string());
72            }
73        }
74        Some(allowed)
75    }
76}
77
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `ToolGateContext` for `iteration` with an empty transcript,
    /// no conversation id, and `tools` as the registry's tool names.
    fn ctx(iteration: usize, tools: &'static [&'static str]) -> ToolGateContext<'static> {
        ToolGateContext {
            iteration,
            messages: &[],
            conversation_id: None,
            available_tool_names: tools,
        }
    }

    #[tokio::test]
    async fn gate_fires_on_iteration_zero_with_intersection() {
        // A product allowlist of framing tools; work tools must be hidden
        // on the opening turn.
        let gate = OpeningGate::new(["frame", "deliver", "ask"]);
        let allowed = gate
            .next_turn_tool_allowlist(ctx(
                0,
                &[
                    "frame",
                    "deliver",
                    "ask",
                    "load_skill",
                    "shell",
                    "file_write",
                ],
            ))
            .await
            .expect("opening turn should narrow");
        assert_eq!(allowed.len(), 3, "only the three framing tools survive");
        assert!(allowed.contains("frame"));
        assert!(allowed.contains("deliver"));
        assert!(allowed.contains("ask"));
        assert!(
            !allowed.contains("load_skill"),
            "opening turn must hide non-framing tools"
        );
        assert!(
            !allowed.contains("shell"),
            "opening turn must hide work tools so the model frames first"
        );
        assert!(
            !allowed.contains("file_write"),
            "opening turn must hide work tools"
        );
    }

    #[tokio::test]
    async fn gate_returns_none_after_first_iteration() {
        let gate = OpeningGate::new(["frame", "shell"]);
        let result = gate
            .next_turn_tool_allowlist(ctx(1, &["frame", "shell"]))
            .await;
        assert!(
            result.is_none(),
            "iteration > 0 must NOT narrow — let the model use the full catalog"
        );
    }

    #[tokio::test]
    async fn gate_stays_silent_on_all_later_iterations() {
        let gate = OpeningGate::new(["frame"]);
        for iteration in [2_usize, 5, 100] {
            let result = gate
                .next_turn_tool_allowlist(ctx(iteration, &["frame", "shell"]))
                .await;
            assert!(
                result.is_none(),
                "iteration {iteration} must not narrow the catalog"
            );
        }
    }

    #[tokio::test]
    async fn allowlist_does_not_synthesize_tools_missing_from_registry() {
        // The gate would allow `frame`/`deliver`/`ask`, but only `ask`,
        // `deliver`, and `shell` are in the registry — `frame` must not be
        // fabricated on the wire, and `shell` was never allowed.
        let gate = OpeningGate::new(["frame", "deliver", "ask"]);
        let allowed = gate
            .next_turn_tool_allowlist(ctx(0, &["ask", "deliver", "shell"]))
            .await
            .unwrap();
        assert_eq!(allowed.len(), 2);
        assert!(allowed.contains("ask"));
        assert!(allowed.contains("deliver"));
        assert!(!allowed.contains("frame"));
        assert!(!allowed.contains("shell"));
    }

    #[tokio::test]
    async fn empty_intersection_still_narrows_to_an_empty_set() {
        // None of the configured names is registered: the gate still fires
        // on the opening turn, but with nothing left to advertise.
        let gate = OpeningGate::new(["frame", "ask"]);
        let allowed = gate
            .next_turn_tool_allowlist(ctx(0, &["shell", "file_write"]))
            .await
            .expect("opening turn should still narrow");
        assert!(allowed.is_empty());
    }

    #[tokio::test]
    async fn empty_allowlist_hides_every_registered_tool() {
        let gate = OpeningGate::new(Vec::<&str>::new());
        let allowed = gate
            .next_turn_tool_allowlist(ctx(0, &["frame", "shell"]))
            .await
            .expect("opening turn should still narrow");
        assert!(allowed.is_empty());
    }

    #[tokio::test]
    async fn new_accepts_owned_strings() {
        let gate = OpeningGate::new(vec![String::from("frame")]);
        let allowed = gate
            .next_turn_tool_allowlist(ctx(0, &["frame", "shell"]))
            .await
            .unwrap();
        assert_eq!(allowed.len(), 1);
        assert!(allowed.contains("frame"));
    }

    #[tokio::test]
    async fn with_allowlist_takes_an_explicit_set() {
        let mut custom = HashSet::new();
        custom.insert("frame".to_string());
        let gate = OpeningGate::with_allowlist(custom);
        let allowed = gate
            .next_turn_tool_allowlist(ctx(0, &["frame", "deliver", "ask"]))
            .await
            .unwrap();
        assert_eq!(allowed.len(), 1);
        assert!(allowed.contains("frame"));
        assert!(!allowed.contains("deliver"));
    }
}