Skip to main content

ati/core/
scope.rs

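//! Scope enforcement for ATI.
//!
//! Scopes are carried inside JWT claims as a space-delimited `scope` string.
//! This module provides the matching logic (exact matches, trailing-wildcard
//! patterns such as `tool:github:*`, and legacy underscore aliases such as
//! `tool:github_*` for colon-namespaced tool scopes), expiry checks, and
//! filtering of tool lists by scope.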
6
7use crate::core::manifest::{Provider, Tool};
8use thiserror::Error;
9
/// Test whether `name` satisfies `pattern`.
///
/// Matching rules:
/// - Exact match: `"foo"` matches `"foo"`
/// - Wildcard suffix: `"foo*"` matches `"foobar"` (and `"foo"` itself)
/// - Global wildcard: `"*"` matches everything
///
/// Only a single trailing `*` is special; a `*` anywhere else in the pattern is
/// compared literally.
pub fn matches_wildcard(name: &str, pattern: &str) -> bool {
    if name == pattern {
        return true;
    }
    // A trailing `*` turns the rest of the pattern into a required prefix.
    // The bare `*` pattern leaves an empty prefix, which every name starts with.
    pattern
        .strip_suffix('*')
        .is_some_and(|prefix| name.starts_with(prefix))
}
30
/// Map a canonical, colon-namespaced tool scope onto its legacy underscore form.
///
/// `tool:github:search_repositories` becomes `tool:github_search_repositories`.
/// Only the first `:` after the `tool:` prefix is replaced. Returns `None` when
/// the scope does not start with `tool:` or has no provider namespace.
fn legacy_tool_scope_alias(tool_scope: &str) -> Option<String> {
    let (provider, tool) = tool_scope.strip_prefix("tool:")?.split_once(':')?;
    Some(format!("tool:{provider}_{tool}"))
}
43
/// Reasons a scope check can refuse access, as returned by `ScopeConfig::check_access`.
#[derive(Error, Debug)]
pub enum ScopeError {
    /// The scope config is past its expiry. Carries the `expires_at` timestamp
    /// (seconds since the Unix epoch) of the config that was checked.
    #[error("Scopes have expired (expired at {0})")]
    Expired(u64),
    /// No granted scope pattern covers the tool's required scope. Carries the
    /// tool *name* (not the scope string) so the message is user-readable.
    #[error("Access denied: '{0}' is not in your scopes")]
    AccessDenied(String),
}
51
/// Scope configuration — constructed from JWT claims or programmatically.
///
/// Holds the granted scope patterns together with the identity, expiry and
/// optional rate limits they were issued with.
#[derive(Debug, Clone)]
pub struct ScopeConfig {
    /// Parsed scope strings (e.g. ["tool:web_search", "tool:github:*", "help"]).
    /// Each entry is a pattern: an exact scope, a trailing-`*` prefix wildcard,
    /// or the bare `*` global wildcard.
    pub scopes: Vec<String>,
    /// Agent identity (from JWT `sub` claim).
    pub sub: String,
    /// Expiry timestamp in seconds since the Unix epoch (from JWT `exp` claim).
    /// 0 = no expiry.
    pub expires_at: u64,
    /// Per-tool rate limits parsed from JWT claims. `from_jwt` leaves this `None`
    /// when the claim is absent, empty, or fails to parse.
    pub rate_config: Option<crate::core::rate::RateConfig>,
}
64
65impl ScopeConfig {
66    /// Build a ScopeConfig from JWT TokenClaims.
67    pub fn from_jwt(claims: &crate::core::jwt::TokenClaims) -> Self {
68        let rate_config = claims.ati.as_ref().and_then(|ns| {
69            if ns.rate.is_empty() {
70                None
71            } else {
72                crate::core::rate::parse_rate_config(&ns.rate).ok()
73            }
74        });
75        ScopeConfig {
76            scopes: claims.scopes(),
77            sub: claims.sub.clone(),
78            expires_at: claims.exp,
79            rate_config,
80        }
81    }
82
83    /// Create an unrestricted scope config (for dev mode / no JWT set).
84    pub fn unrestricted() -> Self {
85        ScopeConfig {
86            scopes: vec!["*".to_string()],
87            sub: "dev".to_string(),
88            expires_at: 0,
89            rate_config: None,
90        }
91    }
92
93    /// Check if the scopes have expired.
94    pub fn is_expired(&self) -> bool {
95        if self.expires_at == 0 {
96            return false;
97        }
98        let now = std::time::SystemTime::now()
99            .duration_since(std::time::UNIX_EPOCH)
100            .unwrap_or_default()
101            .as_secs();
102        now > self.expires_at
103    }
104
105    /// Check if a specific tool scope is allowed.
106    ///
107    /// Supports:
108    /// - Exact match: `"tool:web_search"` matches `"tool:web_search"`
109    /// - Wildcard suffix: `"tool:github:*"` matches `"tool:github:search_repos"`
110    /// - Global wildcard: `"*"` matches everything
111    /// - Empty tool scope: always allowed (tool has no scope requirement)
112    /// - Legacy alias match for colon-namespaced tools, e.g.
113    ///   `tool:github_search_repositories` or `tool:github_*`
114    pub fn is_allowed(&self, tool_scope: &str) -> bool {
115        if self.is_expired() {
116            return false;
117        }
118        // Empty scope on tool means always allowed
119        if tool_scope.is_empty() {
120            return true;
121        }
122        let legacy_alias = legacy_tool_scope_alias(tool_scope);
123        // Check each scope pattern
124        for scope in &self.scopes {
125            if matches_wildcard(tool_scope, scope)
126                || legacy_alias
127                    .as_deref()
128                    .is_some_and(|alias| matches_wildcard(alias, scope))
129            {
130                return true;
131            }
132        }
133        false
134    }
135
136    /// Check access for a tool, returning an error if denied.
137    pub fn check_access(&self, tool_name: &str, tool_scope: &str) -> Result<(), ScopeError> {
138        if self.is_expired() {
139            return Err(ScopeError::Expired(self.expires_at));
140        }
141        if !self.is_allowed(tool_scope) {
142            return Err(ScopeError::AccessDenied(tool_name.to_string()));
143        }
144        Ok(())
145    }
146
147    /// Get time remaining until expiry, in seconds. Returns None if no expiry.
148    pub fn time_remaining(&self) -> Option<u64> {
149        if self.expires_at == 0 {
150            return None;
151        }
152        let now = std::time::SystemTime::now()
153            .duration_since(std::time::UNIX_EPOCH)
154            .unwrap_or_default()
155            .as_secs();
156        if now >= self.expires_at {
157            Some(0)
158        } else {
159            Some(self.expires_at - now)
160        }
161    }
162
163    /// Number of tool scopes (those starting with "tool:").
164    pub fn tool_scope_count(&self) -> usize {
165        self.scopes
166            .iter()
167            .filter(|s| s.starts_with("tool:"))
168            .count()
169    }
170
171    /// Number of skill scopes (those starting with "skill:").
172    pub fn skill_scope_count(&self) -> usize {
173        self.scopes
174            .iter()
175            .filter(|s| s.starts_with("skill:"))
176            .count()
177    }
178
179    /// Check if help is enabled.
180    pub fn help_enabled(&self) -> bool {
181        self.is_wildcard() || self.scopes.iter().any(|s| s == "help")
182    }
183
184    /// Check if this is an unrestricted (wildcard) scope.
185    pub fn is_wildcard(&self) -> bool {
186        self.scopes.iter().any(|s| s == "*")
187    }
188}
189
190/// Filter a list of tools to only those allowed by the scope config.
191pub fn filter_tools_by_scope<'a>(
192    tools: Vec<(&'a Provider, &'a Tool)>,
193    scopes: &ScopeConfig,
194) -> Vec<(&'a Provider, &'a Tool)> {
195    if scopes.is_wildcard() {
196        return tools;
197    }
198
199    tools
200        .into_iter()
201        .filter(|(_, tool)| match &tool.scope {
202            Some(scope) => scopes.is_allowed(scope),
203            None => true, // No scope required = always allowed
204        })
205        .collect()
206}
207
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a non-expiring config granting exactly `scopes`.
    fn make_scopes(scopes: &[&str]) -> ScopeConfig {
        ScopeConfig {
            scopes: scopes.iter().map(|s| s.to_string()).collect(),
            sub: "test-agent".into(),
            expires_at: 0,
            rate_config: None,
        }
    }

    /// Build a config whose expiry (1 second after the epoch) is long past.
    fn make_expired(scopes: &[&str]) -> ScopeConfig {
        ScopeConfig {
            expires_at: 1,
            ..make_scopes(scopes)
        }
    }

    #[test]
    fn test_matches_wildcard() {
        assert!(matches_wildcard("foo", "foo"));
        assert!(matches_wildcard("foobar", "foo*"));
        assert!(matches_wildcard("foo", "foo*"));
        assert!(!matches_wildcard("fo", "foo*"));
        assert!(matches_wildcard("anything", "*"));
        assert!(!matches_wildcard("foo", "bar"));
        assert!(!matches_wildcard("foobar", "foo"));
        // Only a trailing `*` is special.
        assert!(!matches_wildcard("fooXbar", "foo*bar"));
    }

    #[test]
    fn test_legacy_tool_scope_alias() {
        assert_eq!(
            legacy_tool_scope_alias("tool:github:search_repositories").as_deref(),
            Some("tool:github_search_repositories")
        );
        assert_eq!(legacy_tool_scope_alias("tool:a:b:c").as_deref(), Some("tool:a_b:c"));
        assert_eq!(legacy_tool_scope_alias("tool:web_search"), None);
        assert_eq!(legacy_tool_scope_alias("help"), None);
    }

    #[test]
    fn test_exact_match() {
        let config = make_scopes(&["tool:web_search", "tool:web_fetch"]);
        assert!(config.is_allowed("tool:web_search"));
        assert!(config.is_allowed("tool:web_fetch"));
        assert!(!config.is_allowed("tool:patent_search"));
    }

    #[test]
    fn test_wildcard_suffix() {
        let config = make_scopes(&["tool:github:*"]);
        assert!(config.is_allowed("tool:github:search_repos"));
        assert!(config.is_allowed("tool:github:create_issue"));
        assert!(!config.is_allowed("tool:linear:list_issues"));
    }

    #[test]
    fn test_legacy_underscore_scope_matches_canonical_tool_scope() {
        let config = make_scopes(&["tool:test_api_get_data"]);
        assert!(config.is_allowed("tool:test_api:get_data"));
    }

    #[test]
    fn test_legacy_underscore_wildcard_matches_canonical_tool_scope() {
        let config = make_scopes(&["tool:github_*"]);
        assert!(config.is_allowed("tool:github:search_repos"));
        assert!(config.is_allowed("tool:github:create_issue"));
        assert!(!config.is_allowed("tool:linear:list_issues"));
    }

    #[test]
    fn test_legacy_alias_only_applies_to_tool_scopes() {
        let config = make_scopes(&["skill:research_general"]);
        assert!(!config.is_allowed("skill:research:general"));
    }

    #[test]
    fn test_global_wildcard() {
        let config = make_scopes(&["*"]);
        assert!(config.is_allowed("tool:anything"));
        assert!(config.is_allowed("help"));
        assert!(config.is_allowed("skill:whatever"));
    }

    #[test]
    fn test_empty_scope_always_allowed() {
        let config = make_scopes(&["tool:web_search"]);
        assert!(config.is_allowed(""));
    }

    #[test]
    fn test_expired_denies_all() {
        let config = make_expired(&["tool:web_search"]);
        assert!(config.is_expired());
        assert!(!config.is_allowed("tool:web_search"));
        assert!(!config.is_allowed(""));
    }

    #[test]
    fn test_zero_expiry_means_no_expiry() {
        let config = make_scopes(&["tool:web_search"]);
        assert!(!config.is_expired());
        assert!(config.is_allowed("tool:web_search"));
    }

    #[test]
    fn test_far_future_expiry_is_not_expired() {
        let config = ScopeConfig {
            expires_at: u64::MAX,
            ..make_scopes(&["tool:web_search"])
        };
        assert!(!config.is_expired());
        assert!(config.time_remaining().is_some_and(|secs| secs > 0));
    }

    #[test]
    fn test_time_remaining() {
        assert_eq!(make_scopes(&["help"]).time_remaining(), None);
        assert_eq!(make_expired(&["help"]).time_remaining(), Some(0));
    }

    #[test]
    fn test_check_access_allowed() {
        let config = make_scopes(&["tool:web_search"]);
        assert!(config.check_access("web_search", "tool:web_search").is_ok());
    }

    #[test]
    fn test_check_access_denied() {
        let config = make_scopes(&["tool:web_search"]);
        let result = config.check_access("patent_search", "tool:patent_search");
        assert!(
            matches!(result, Err(ScopeError::AccessDenied(name)) if name == "patent_search")
        );
    }

    #[test]
    fn test_check_access_expired() {
        let config = make_expired(&["tool:web_search"]);
        let result = config.check_access("web_search", "tool:web_search");
        // Expiry must win over the (otherwise granted) scope.
        assert!(matches!(result, Err(ScopeError::Expired(1))));
    }

    #[test]
    fn test_help_enabled() {
        assert!(make_scopes(&["tool:web_search", "help"]).help_enabled());
        assert!(!make_scopes(&["tool:web_search"]).help_enabled());
        assert!(make_scopes(&["*"]).help_enabled());
    }

    #[test]
    fn test_scope_counts() {
        let config = make_scopes(&["tool:a", "tool:b", "skill:c", "help"]);
        assert_eq!(config.tool_scope_count(), 2);
        assert_eq!(config.skill_scope_count(), 1);
    }

    #[test]
    fn test_unrestricted() {
        let config = ScopeConfig::unrestricted();
        assert!(config.is_wildcard());
        assert!(config.is_allowed("anything"));
        assert!(config.help_enabled());
        assert!(!config.is_expired());
        assert_eq!(config.time_remaining(), None);
    }

    #[test]
    fn test_mixed_patterns() {
        let config = make_scopes(&["tool:web_search", "tool:github:*", "skill:research-*"]);
        assert!(config.is_allowed("tool:web_search"));
        assert!(config.is_allowed("tool:github:search_repos"));
        assert!(config.is_allowed("skill:research-general"));
        assert!(!config.is_allowed("tool:linear:list_issues"));
        assert!(!config.is_allowed("skill:patent-analysis"));
    }
}