// agent_tools_interface/core/scope.rs

//! Scope enforcement for ATI.
//!
//! Scopes are carried inside JWT claims as a space-delimited `scope` string.
//! This module provides the matching logic: exact matches, trailing-wildcard
//! patterns (`tool:github__*`), and filtering of tool lists by scope.
6
7use crate::core::manifest::{Provider, Tool};
8use thiserror::Error;
9
/// Check if a name matches a pattern with optional wildcard suffix.
///
/// Supports:
/// - Exact match: `"foo"` matches `"foo"`
/// - Wildcard suffix: `"foo*"` matches `"foobar"` (and `"foo"` itself)
/// - Global wildcard: `"*"` matches everything
///
/// Only a trailing `*` is special; a `*` anywhere else must match literally.
pub fn matches_wildcard(name: &str, pattern: &str) -> bool {
    // Stripping the trailing `*` from the bare `*` pattern leaves an empty
    // prefix, which every name starts with, so the global wildcard needs no
    // dedicated branch.
    name == pattern
        || matches!(pattern.strip_suffix('*'), Some(prefix) if name.starts_with(prefix))
}
30
31#[derive(Error, Debug)]
32pub enum ScopeError {
33    #[error("Scopes have expired (expired at {0})")]
34    Expired(u64),
35    #[error("Access denied: '{0}' is not in your scopes")]
36    AccessDenied(String),
37}
38
39/// Scope configuration — constructed from JWT claims or programmatically.
40#[derive(Debug, Clone)]
41pub struct ScopeConfig {
42    /// Parsed scope strings (e.g. ["tool:web_search", "tool:github__*", "help"]).
43    pub scopes: Vec<String>,
44    /// Agent identity (from JWT `sub` claim).
45    pub sub: String,
46    /// Expiry timestamp (from JWT `exp` claim). 0 = no expiry.
47    pub expires_at: u64,
48    /// Per-tool rate limits parsed from JWT claims.
49    pub rate_config: Option<crate::core::rate::RateConfig>,
50}
51
52impl ScopeConfig {
53    /// Build a ScopeConfig from JWT TokenClaims.
54    pub fn from_jwt(claims: &crate::core::jwt::TokenClaims) -> Self {
55        let rate_config = claims.ati.as_ref().and_then(|ns| {
56            if ns.rate.is_empty() {
57                None
58            } else {
59                crate::core::rate::parse_rate_config(&ns.rate).ok()
60            }
61        });
62        ScopeConfig {
63            scopes: claims.scopes(),
64            sub: claims.sub.clone(),
65            expires_at: claims.exp,
66            rate_config,
67        }
68    }
69
70    /// Create an unrestricted scope config (for dev mode / no JWT set).
71    pub fn unrestricted() -> Self {
72        ScopeConfig {
73            scopes: vec!["*".to_string()],
74            sub: "dev".to_string(),
75            expires_at: 0,
76            rate_config: None,
77        }
78    }
79
80    /// Check if the scopes have expired.
81    pub fn is_expired(&self) -> bool {
82        if self.expires_at == 0 {
83            return false;
84        }
85        let now = std::time::SystemTime::now()
86            .duration_since(std::time::UNIX_EPOCH)
87            .unwrap_or_default()
88            .as_secs();
89        now > self.expires_at
90    }
91
92    /// Check if a specific tool scope is allowed.
93    ///
94    /// Supports:
95    /// - Exact match: `"tool:web_search"` matches `"tool:web_search"`
96    /// - Wildcard suffix: `"tool:github__*"` matches `"tool:github__search_repos"`
97    /// - Global wildcard: `"*"` matches everything
98    /// - Empty tool scope: always allowed (tool has no scope requirement)
99    pub fn is_allowed(&self, tool_scope: &str) -> bool {
100        if self.is_expired() {
101            return false;
102        }
103        // Empty scope on tool means always allowed
104        if tool_scope.is_empty() {
105            return true;
106        }
107        // Check each scope pattern
108        for scope in &self.scopes {
109            if matches_wildcard(tool_scope, scope) {
110                return true;
111            }
112        }
113        false
114    }
115
116    /// Check access for a tool, returning an error if denied.
117    pub fn check_access(&self, tool_name: &str, tool_scope: &str) -> Result<(), ScopeError> {
118        if self.is_expired() {
119            return Err(ScopeError::Expired(self.expires_at));
120        }
121        if !self.is_allowed(tool_scope) {
122            return Err(ScopeError::AccessDenied(tool_name.to_string()));
123        }
124        Ok(())
125    }
126
127    /// Get time remaining until expiry, in seconds. Returns None if no expiry.
128    pub fn time_remaining(&self) -> Option<u64> {
129        if self.expires_at == 0 {
130            return None;
131        }
132        let now = std::time::SystemTime::now()
133            .duration_since(std::time::UNIX_EPOCH)
134            .unwrap_or_default()
135            .as_secs();
136        if now >= self.expires_at {
137            Some(0)
138        } else {
139            Some(self.expires_at - now)
140        }
141    }
142
143    /// Number of tool scopes (those starting with "tool:").
144    pub fn tool_scope_count(&self) -> usize {
145        self.scopes
146            .iter()
147            .filter(|s| s.starts_with("tool:"))
148            .count()
149    }
150
151    /// Number of skill scopes (those starting with "skill:").
152    pub fn skill_scope_count(&self) -> usize {
153        self.scopes
154            .iter()
155            .filter(|s| s.starts_with("skill:"))
156            .count()
157    }
158
159    /// Check if help is enabled.
160    pub fn help_enabled(&self) -> bool {
161        self.is_wildcard() || self.scopes.iter().any(|s| s == "help")
162    }
163
164    /// Check if this is an unrestricted (wildcard) scope.
165    pub fn is_wildcard(&self) -> bool {
166        self.scopes.iter().any(|s| s == "*")
167    }
168}
169
170/// Filter a list of tools to only those allowed by the scope config.
171pub fn filter_tools_by_scope<'a>(
172    tools: Vec<(&'a Provider, &'a Tool)>,
173    scopes: &ScopeConfig,
174) -> Vec<(&'a Provider, &'a Tool)> {
175    if scopes.is_wildcard() {
176        return tools;
177    }
178
179    tools
180        .into_iter()
181        .filter(|(_, tool)| match &tool.scope {
182            Some(scope) => scopes.is_allowed(scope),
183            None => true, // No scope required = always allowed
184        })
185        .collect()
186}
187
188#[cfg(test)]
189mod tests {
190    use super::*;
191
192    fn make_scopes(scopes: &[&str]) -> ScopeConfig {
193        ScopeConfig {
194            scopes: scopes.iter().map(|s| s.to_string()).collect(),
195            sub: "test-agent".into(),
196            expires_at: 0,
197            rate_config: None,
198        }
199    }
200
201    #[test]
202    fn test_exact_match() {
203        let config = make_scopes(&["tool:web_search", "tool:web_fetch"]);
204        assert!(config.is_allowed("tool:web_search"));
205        assert!(config.is_allowed("tool:web_fetch"));
206        assert!(!config.is_allowed("tool:patent_search"));
207    }
208
209    #[test]
210    fn test_wildcard_suffix() {
211        let config = make_scopes(&["tool:github__*"]);
212        assert!(config.is_allowed("tool:github__search_repos"));
213        assert!(config.is_allowed("tool:github__create_issue"));
214        assert!(!config.is_allowed("tool:linear__list_issues"));
215    }
216
217    #[test]
218    fn test_global_wildcard() {
219        let config = make_scopes(&["*"]);
220        assert!(config.is_allowed("tool:anything"));
221        assert!(config.is_allowed("help"));
222        assert!(config.is_allowed("skill:whatever"));
223    }
224
225    #[test]
226    fn test_empty_scope_always_allowed() {
227        let config = make_scopes(&["tool:web_search"]);
228        assert!(config.is_allowed(""));
229    }
230
231    #[test]
232    fn test_expired_denies_all() {
233        let config = ScopeConfig {
234            scopes: vec!["tool:web_search".into()],
235            sub: "test".into(),
236            expires_at: 1,
237            rate_config: None,
238        };
239        assert!(config.is_expired());
240        assert!(!config.is_allowed("tool:web_search"));
241    }
242
243    #[test]
244    fn test_zero_expiry_means_no_expiry() {
245        let config = ScopeConfig {
246            scopes: vec!["tool:web_search".into()],
247            sub: "test".into(),
248            expires_at: 0,
249            rate_config: None,
250        };
251        assert!(!config.is_expired());
252        assert!(config.is_allowed("tool:web_search"));
253    }
254
255    #[test]
256    fn test_check_access_denied() {
257        let config = make_scopes(&["tool:web_search"]);
258        let result = config.check_access("patent_search", "tool:patent_search");
259        assert!(result.is_err());
260    }
261
262    #[test]
263    fn test_check_access_expired() {
264        let config = ScopeConfig {
265            scopes: vec!["tool:web_search".into()],
266            sub: "test".into(),
267            expires_at: 1,
268            rate_config: None,
269        };
270        let result = config.check_access("web_search", "tool:web_search");
271        assert!(result.is_err());
272    }
273
274    #[test]
275    fn test_help_enabled() {
276        assert!(make_scopes(&["tool:web_search", "help"]).help_enabled());
277        assert!(!make_scopes(&["tool:web_search"]).help_enabled());
278        assert!(make_scopes(&["*"]).help_enabled());
279    }
280
281    #[test]
282    fn test_scope_counts() {
283        let config = make_scopes(&["tool:a", "tool:b", "skill:c", "help"]);
284        assert_eq!(config.tool_scope_count(), 2);
285        assert_eq!(config.skill_scope_count(), 1);
286    }
287
288    #[test]
289    fn test_unrestricted() {
290        let config = ScopeConfig::unrestricted();
291        assert!(config.is_wildcard());
292        assert!(config.is_allowed("anything"));
293        assert!(config.help_enabled());
294    }
295
296    #[test]
297    fn test_mixed_patterns() {
298        let config = make_scopes(&["tool:web_search", "tool:github__*", "skill:research-*"]);
299        assert!(config.is_allowed("tool:web_search"));
300        assert!(config.is_allowed("tool:github__search_repos"));
301        assert!(config.is_allowed("skill:research-general"));
302        assert!(!config.is_allowed("tool:linear__list_issues"));
303        assert!(!config.is_allowed("skill:patent-analysis"));
304    }
305}