Skip to main content

agent_core_runtime/permissions/
batch.rs

1//! Batch permission request handling.
2//!
3//! When multiple tools run in parallel, their permission requests are
4//! collected into a batch and presented to the user together, avoiding
5//! the deadlock issue with sequential permission prompts.
6
7use super::{Grant, GrantTarget, PermissionLevel, PermissionRequest};
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet};
10use std::path::PathBuf;
11
12/// A batch of permission requests from parallel tool executions.
13///
14/// Batching permission requests allows the UI to present multiple
15/// requests together, letting the user make informed decisions about
16/// granting access to related resources.
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct BatchPermissionRequest {
19    /// Unique identifier for this batch.
20    pub batch_id: String,
21    /// The individual permission requests in this batch.
22    pub requests: Vec<PermissionRequest>,
23    /// Suggested grants that would cover all requests.
24    pub suggested_grants: Vec<Grant>,
25}
26
27impl BatchPermissionRequest {
28    /// Creates a new batch permission request.
29    pub fn new(batch_id: impl Into<String>, requests: Vec<PermissionRequest>) -> Self {
30        let batch_id = batch_id.into();
31        let suggested_grants = compute_suggested_grants(&requests);
32        Self {
33            batch_id,
34            requests,
35            suggested_grants,
36        }
37    }
38
39    /// Returns the number of requests in this batch.
40    pub fn len(&self) -> usize {
41        self.requests.len()
42    }
43
44    /// Returns true if the batch has no requests.
45    pub fn is_empty(&self) -> bool {
46        self.requests.is_empty()
47    }
48
49    /// Returns the unique request IDs in this batch.
50    pub fn request_ids(&self) -> Vec<&str> {
51        self.requests.iter().map(|r| r.id.as_str()).collect()
52    }
53}
54
55/// Response to a batch permission request.
56#[derive(Debug, Clone, Serialize, Deserialize)]
57pub struct BatchPermissionResponse {
58    /// The batch ID this response is for.
59    pub batch_id: String,
60    /// Grants that were approved by the user.
61    pub approved_grants: Vec<Grant>,
62    /// Request IDs that were explicitly denied.
63    pub denied_requests: HashSet<String>,
64    /// Request IDs that were auto-approved (already had permission).
65    pub auto_approved: HashSet<String>,
66}
67
68impl BatchPermissionResponse {
69    /// Creates a response where all requests were granted.
70    pub fn all_granted(batch_id: impl Into<String>, grants: Vec<Grant>) -> Self {
71        Self {
72            batch_id: batch_id.into(),
73            approved_grants: grants,
74            denied_requests: HashSet::new(),
75            auto_approved: HashSet::new(),
76        }
77    }
78
79    /// Creates a response where all requests were denied.
80    pub fn all_denied(batch_id: impl Into<String>, request_ids: impl IntoIterator<Item = String>) -> Self {
81        Self {
82            batch_id: batch_id.into(),
83            approved_grants: Vec::new(),
84            denied_requests: request_ids.into_iter().collect(),
85            auto_approved: HashSet::new(),
86        }
87    }
88
89    /// Creates a response with auto-approved requests.
90    pub fn with_auto_approved(
91        batch_id: impl Into<String>,
92        auto_approved: impl IntoIterator<Item = String>,
93    ) -> Self {
94        Self {
95            batch_id: batch_id.into(),
96            approved_grants: Vec::new(),
97            denied_requests: HashSet::new(),
98            auto_approved: auto_approved.into_iter().collect(),
99        }
100    }
101
102    /// Checks if a specific request was granted (either explicitly or auto-approved).
103    ///
104    /// # Conflict Resolution
105    /// If a request_id appears in both `auto_approved` and `denied_requests`,
106    /// this is treated as a malformed response. A warning is logged and the
107    /// request is denied (safe default).
108    pub fn is_granted(&self, request_id: &str, request: &PermissionRequest) -> bool {
109        // Validate: request cannot be both auto-approved and denied
110        let in_auto_approved = self.auto_approved.contains(request_id);
111        let in_denied = self.denied_requests.contains(request_id);
112
113        if in_auto_approved && in_denied {
114            tracing::warn!(
115                request_id,
116                batch_id = %self.batch_id,
117                "Request appears in both auto_approved and denied_requests, treating as denied"
118            );
119            return false;
120        }
121
122        // Check if auto-approved
123        if in_auto_approved {
124            return true;
125        }
126
127        // Check if denied
128        if in_denied {
129            return false;
130        }
131
132        // Check if any approved grant satisfies this request
133        self.approved_grants.iter().any(|grant| grant.satisfies(request))
134    }
135
136    /// Returns whether any requests were denied.
137    pub fn has_denials(&self) -> bool {
138        !self.denied_requests.is_empty()
139    }
140
141    /// Returns the number of approved grants.
142    pub fn approved_count(&self) -> usize {
143        self.approved_grants.len() + self.auto_approved.len()
144    }
145}
146
147/// Computes suggested grants that would cover all requests.
148///
149/// This function analyzes the requests and suggests a minimal set of
150/// grants that would satisfy all of them. It groups requests by common
151/// parent directories and suggests recursive grants where appropriate.
152pub fn compute_suggested_grants(requests: &[PermissionRequest]) -> Vec<Grant> {
153    if requests.is_empty() {
154        return Vec::new();
155    }
156
157    let mut grants = Vec::new();
158
159    // Group requests by target type
160    let mut path_requests: Vec<&PermissionRequest> = Vec::new();
161    let mut domain_requests: Vec<&PermissionRequest> = Vec::new();
162    let mut command_requests: Vec<&PermissionRequest> = Vec::new();
163
164    for req in requests {
165        match &req.target {
166            GrantTarget::Path { .. } => path_requests.push(req),
167            GrantTarget::Domain { .. } => domain_requests.push(req),
168            GrantTarget::Command { .. } => command_requests.push(req),
169        }
170    }
171
172    // Compute grants for each target type
173    grants.extend(compute_path_grants(&path_requests));
174    grants.extend(compute_domain_grants(&domain_requests));
175    grants.extend(compute_command_grants(&command_requests));
176
177    grants
178}
179
180/// Computes suggested grants for path-based requests.
181fn compute_path_grants(requests: &[&PermissionRequest]) -> Vec<Grant> {
182    if requests.is_empty() {
183        return Vec::new();
184    }
185
186    // Group paths by their parent directories and track max level needed
187    let mut dir_groups: HashMap<PathBuf, (PermissionLevel, Vec<PathBuf>)> = HashMap::new();
188
189    for req in requests {
190        if let GrantTarget::Path { path, .. } = &req.target {
191            let parent = path.parent().unwrap_or(path).to_path_buf();
192            let entry = dir_groups
193                .entry(parent)
194                .or_insert((PermissionLevel::None, Vec::new()));
195            entry.0 = std::cmp::max(entry.0, req.required_level);
196            entry.1.push(path.clone());
197        }
198    }
199
200    // Try to find common ancestors for related directories
201    let merged_groups = merge_related_directories(dir_groups);
202
203    // Create grants for each group
204    merged_groups
205        .into_iter()
206        .map(|(dir, (level, paths))| {
207            // If multiple paths share the same parent, make it recursive
208            let recursive = paths.len() > 1;
209            Grant::new(GrantTarget::path(dir, recursive), level)
210        })
211        .collect()
212}
213
/// Folds each directory group into an enclosing (ancestor) group when one exists.
///
/// Each group maps a directory to `(level, paths)`. A group whose directory
/// lies inside another group's directory is merged into that group: its paths
/// are appended and the level becomes the maximum of the two. Unrelated
/// directories are kept as separate groups to avoid over-granting.
///
/// Groups are processed in ascending path order. `Path` ordering compares
/// component by component, so an ancestor always sorts before its
/// descendants. This makes the result independent of `HashMap` iteration
/// order.
///
/// A group keyed by the empty path is never used as a merge target.
/// `Path::parent` returns `Some("")` for a bare relative file name, and
/// `Path::starts_with("")` is true for every path. Merging into such a group
/// would therefore collapse unrelated requests into one grant.
///
/// Generic over the level type so any `Ord + Copy` level (such as
/// `PermissionLevel`) can be used.
fn merge_related_directories<L: Ord + Copy>(
    groups: HashMap<PathBuf, (L, Vec<PathBuf>)>,
) -> HashMap<PathBuf, (L, Vec<PathBuf>)> {
    if groups.len() <= 1 {
        return groups;
    }

    // Ancestors sort before descendants, so every potential merge target is
    // already in `result` by the time its descendants are visited.
    let mut ordered: Vec<(PathBuf, (L, Vec<PathBuf>))> = groups.into_iter().collect();
    ordered.sort_unstable_by(|(a, _), (b, _)| a.cmp(b));

    let mut result: HashMap<PathBuf, (L, Vec<PathBuf>)> = HashMap::with_capacity(ordered.len());

    for (dir, (level, paths)) in ordered {
        let ancestor = result.iter_mut().find(|(existing_dir, _)| {
            !existing_dir.as_os_str().is_empty() && dir.starts_with(existing_dir)
        });

        match ancestor {
            Some((_, (existing_level, existing_paths))) => {
                existing_paths.extend(paths);
                *existing_level = std::cmp::max(*existing_level, level);
            }
            None => {
                result.insert(dir, (level, paths));
            }
        }
    }

    result
}
256
/// Finds the deepest path that is an ancestor-or-self of both `path1` and `path2`.
///
/// Only ancestors at most `max_depth` levels above *each* path are considered.
/// A `max_depth` of 0 compares the two paths themselves. Candidates are taken
/// from `path1` upward, so the first hit is the deepest shared ancestor in range.
/// Returns `None` when no such ancestor exists within range of both paths.
// Only referenced by a common-ancestor merge heuristic that is not currently
// enabled; allowed so the helper can stay without a dead-code warning.
#[allow(dead_code)]
fn find_common_ancestor(
    path1: &std::path::Path,
    path2: &std::path::Path,
    max_depth: usize,
) -> Option<PathBuf> {
    // `ancestors()` yields the path itself first, hence the `+ 1`. Saturate so
    // `usize::MAX` means "unbounded" instead of overflowing.
    let window = max_depth.saturating_add(1);
    let candidates2: Vec<&std::path::Path> = path2.ancestors().take(window).collect();

    path1
        .ancestors()
        .take(window)
        .find(|candidate| candidates2.contains(candidate))
        .map(std::path::Path::to_path_buf)
}
272
273/// Computes suggested grants for domain-based requests.
274fn compute_domain_grants(requests: &[&PermissionRequest]) -> Vec<Grant> {
275    if requests.is_empty() {
276        return Vec::new();
277    }
278
279    // Group by base domain and track max level
280    let mut domain_levels: HashMap<String, PermissionLevel> = HashMap::new();
281
282    for req in requests {
283        if let GrantTarget::Domain { pattern } = &req.target {
284            let base_domain = extract_base_domain(pattern);
285            let entry = domain_levels.entry(base_domain).or_insert(PermissionLevel::None);
286            *entry = std::cmp::max(*entry, req.required_level);
287        }
288    }
289
290    // Create grants - if multiple subdomains, suggest wildcard
291    domain_levels
292        .into_iter()
293        .map(|(domain, level)| {
294            // Could enhance to detect multiple subdomains and suggest *.domain
295            Grant::new(GrantTarget::domain(domain), level)
296        })
297        .collect()
298}
299
/// Returns `pattern` with one leading `*.` wildcard label removed, if present.
///
/// Only a single `*.` prefix is stripped; otherwise the pattern is returned as-is.
fn extract_base_domain(pattern: &str) -> String {
    match pattern.strip_prefix("*.") {
        Some(base) => base.to_owned(),
        None => pattern.to_owned(),
    }
}
305
306/// Computes suggested grants for command-based requests.
307fn compute_command_grants(requests: &[&PermissionRequest]) -> Vec<Grant> {
308    if requests.is_empty() {
309        return Vec::new();
310    }
311
312    // Group by command prefix (first word) and track max level
313    let mut cmd_groups: HashMap<String, (PermissionLevel, Vec<String>)> = HashMap::new();
314
315    for req in requests {
316        if let GrantTarget::Command { pattern } = &req.target {
317            let prefix = extract_command_prefix(pattern);
318            let entry = cmd_groups
319                .entry(prefix)
320                .or_insert((PermissionLevel::None, Vec::new()));
321            entry.0 = std::cmp::max(entry.0, req.required_level);
322            entry.1.push(pattern.clone());
323        }
324    }
325
326    // Create grants - if multiple commands with same prefix, suggest wildcard
327    cmd_groups
328        .into_iter()
329        .map(|(prefix, (level, commands))| {
330            let pattern = if commands.len() > 1 {
331                format!("{} *", prefix)
332            } else {
333                commands.into_iter().next().unwrap_or(prefix)
334            };
335            Grant::new(GrantTarget::command(pattern), level)
336        })
337        .collect()
338}
339
/// Returns the first whitespace-separated word of `command`.
///
/// When `command` is empty or consists only of whitespace, it is returned unchanged.
fn extract_command_prefix(command: &str) -> String {
    let first_word = command.split_whitespace().next();
    first_word.map_or_else(|| command.to_owned(), str::to_owned)
}
348
349/// User actions for responding to a batch permission request.
350#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
351pub enum BatchAction {
352    /// Approve all requests using the suggested grants.
353    AllowAll,
354    /// Approve selected requests only.
355    AllowSelected,
356    /// Deny all requests.
357    DenyAll,
358}
359
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_batch_request_creation() {
        let requests = vec![
            PermissionRequest::file_read("1", "/project/src/main.rs"),
            PermissionRequest::file_read("2", "/project/src/lib.rs"),
        ];

        let batch = BatchPermissionRequest::new("batch-1", requests);

        assert_eq!(batch.batch_id, "batch-1");
        assert_eq!(batch.len(), 2);
        assert!(!batch.is_empty());
        assert_eq!(batch.request_ids(), vec!["1", "2"]);
        assert!(!batch.suggested_grants.is_empty());
    }

    #[test]
    fn test_empty_batch() {
        let batch = BatchPermissionRequest::new("batch-empty", Vec::new());

        assert!(batch.is_empty());
        assert_eq!(batch.len(), 0);
        assert!(batch.request_ids().is_empty());
        assert!(batch.suggested_grants.is_empty());
    }

    #[test]
    fn test_batch_response_all_granted() {
        let grants = vec![Grant::read_path("/project/src", true)];
        let response = BatchPermissionResponse::all_granted("batch-1", grants);

        let request = PermissionRequest::file_read("1", "/project/src/main.rs");
        assert!(response.is_granted("1", &request));
        assert!(!response.has_denials());
    }

    #[test]
    fn test_batch_response_without_matching_grant_is_not_granted() {
        let response = BatchPermissionResponse::all_granted("batch-1", Vec::new());

        let request = PermissionRequest::file_read("1", "/project/src/main.rs");
        assert!(!response.is_granted("1", &request));
    }

    #[test]
    fn test_batch_response_all_denied() {
        let response =
            BatchPermissionResponse::all_denied("batch-1", vec!["1".to_string(), "2".to_string()]);

        let request = PermissionRequest::file_read("1", "/project/src/main.rs");
        assert!(!response.is_granted("1", &request));
        assert!(response.has_denials());
    }

    #[test]
    fn test_batch_response_auto_approved() {
        let response =
            BatchPermissionResponse::with_auto_approved("batch-1", vec!["1".to_string()]);

        let request = PermissionRequest::file_read("1", "/project/src/main.rs");
        assert!(response.is_granted("1", &request));
    }

    #[test]
    fn test_approved_count_counts_grants_and_auto_approvals() {
        let granted = BatchPermissionResponse::all_granted(
            "batch-1",
            vec![Grant::read_path("/project/src", true)],
        );
        assert_eq!(granted.approved_count(), 1);

        let auto = BatchPermissionResponse::with_auto_approved(
            "batch-2",
            vec!["1".to_string(), "2".to_string()],
        );
        assert_eq!(auto.approved_count(), 2);
        assert!(!auto.has_denials());
    }

    #[test]
    fn test_compute_suggested_grants_single_path() {
        let requests = vec![PermissionRequest::file_read("1", "/project/src/main.rs")];

        let grants = compute_suggested_grants(&requests);

        assert_eq!(grants.len(), 1);
        assert_eq!(grants[0].level, PermissionLevel::Read);
        if let GrantTarget::Path { path, recursive } = &grants[0].target {
            assert_eq!(path.to_str().unwrap(), "/project/src");
            assert!(!*recursive, "a single file should not yield a recursive grant");
        } else {
            panic!("Expected path target");
        }
    }

    #[test]
    fn test_compute_suggested_grants_multiple_same_dir() {
        let requests = vec![
            PermissionRequest::file_read("1", "/project/src/main.rs"),
            PermissionRequest::file_read("2", "/project/src/lib.rs"),
        ];

        let grants = compute_suggested_grants(&requests);

        // Should suggest a single grant for the parent directory
        assert_eq!(grants.len(), 1);
        if let GrantTarget::Path { path, recursive } = &grants[0].target {
            assert_eq!(path.to_str().unwrap(), "/project/src");
            assert!(*recursive); // Multiple files means recursive
        } else {
            panic!("Expected path target");
        }
    }

    #[test]
    fn test_compute_suggested_grants_different_levels() {
        let requests = vec![
            PermissionRequest::file_read("1", "/project/src/main.rs"),
            PermissionRequest::file_write("2", "/project/src/lib.rs"),
        ];

        let grants = compute_suggested_grants(&requests);

        // Should use the highest level needed
        assert_eq!(grants.len(), 1);
        assert_eq!(grants[0].level, PermissionLevel::Write);
    }

    #[test]
    fn test_compute_suggested_grants_mixed_targets() {
        let requests = vec![
            PermissionRequest::file_read("1", "/project/src/main.rs"),
            PermissionRequest::command_execute("2", "git status"),
        ];

        let grants = compute_suggested_grants(&requests);

        // Separate grants for path and command, path grants listed first
        assert_eq!(grants.len(), 2);
        assert!(matches!(grants[0].target, GrantTarget::Path { .. }));
        assert!(matches!(grants[1].target, GrantTarget::Command { .. }));
    }

    #[test]
    fn test_compute_suggested_grants_commands() {
        let requests = vec![
            PermissionRequest::command_execute("1", "git status"),
            PermissionRequest::command_execute("2", "git commit -m 'msg'"),
        ];

        let grants = compute_suggested_grants(&requests);

        // Two distinct `git` commands should be folded into "git *"
        assert_eq!(grants.len(), 1);
        if let GrantTarget::Command { pattern } = &grants[0].target {
            assert_eq!(pattern.as_str(), "git *");
        } else {
            panic!("Expected command target");
        }
    }

    #[test]
    fn test_is_granted_conflict_resolution() {
        // Create a malformed response where the same ID is in both sets
        let response = BatchPermissionResponse {
            batch_id: "batch-1".to_string(),
            approved_grants: Vec::new(),
            denied_requests: ["conflict-id".to_string()].into_iter().collect(),
            auto_approved: ["conflict-id".to_string()].into_iter().collect(),
        };

        let request = PermissionRequest::file_read("conflict-id", "/project/src/main.rs");

        // Should be denied (safe default) when in both sets
        assert!(
            !response.is_granted("conflict-id", &request),
            "Conflicting request should be denied as safe default"
        );
    }
}
497}