Skip to main content

forge_agent/
policy.rs

1//! Policy engine - Constraint validation system.
2//!
3//! This module implements policy validation for agent operations, ensuring
4//! that code changes comply with specified constraints.
5
6use crate::Result;
7use forge_core::Forge;
8use std::sync::Arc;
9
/// A constraint that a proposed code change must satisfy.
///
/// Policies define rules that must be satisfied before mutations are applied.
/// Each variant is checked against a [`Diff`] by [`Policy::validate`].
#[derive(Clone, Debug)]
pub enum Policy {
    /// Disallow `unsafe` code in the public API of the modified content.
    NoUnsafeInPublicAPI,

    /// The number of test functions must not decrease.
    PreserveTests,

    /// Maximum (estimated) cyclomatic complexity allowed per function.
    MaxComplexity(usize),

    /// A named, user-described policy.
    ///
    /// There is no validation hook for this variant yet: validating it always
    /// reports a "not yet implemented" violation. The variant is also used as
    /// the label of the combined reports produced by the policy compositions.
    Custom {
        /// Policy name, copied into any violation it produces.
        name: String,
        /// Human-readable description of what the policy checks.
        description: String,
    },
}
27
28impl Policy {
29    /// Validates an edit operation against this policy.
30    pub async fn validate(&self, forge: &Forge, diff: &Diff) -> Result<PolicyReport> {
31        let mut violations = Vec::new();
32
33        match self {
34            Policy::NoUnsafeInPublicAPI => {
35                if let Some(v) = check_no_unsafe_in_public_api(diff).await? {
36                    violations.push(v);
37                }
38            }
39            Policy::PreserveTests => {
40                if let Some(v) = check_preserve_tests(forge, diff).await? {
41                    violations.push(v);
42                }
43            }
44            Policy::MaxComplexity(max) => {
45                if let Some(v) = check_max_complexity(forge, *max, diff).await? {
46                    violations.push(v);
47                }
48            }
49            Policy::Custom { name, .. } => {
50                // Custom policies are not yet implemented
51                // In production, this would use a DSL or plugin system
52                violations.push(PolicyViolation {
53                    policy: name.clone(),
54                    message: "Custom policy validation not yet implemented".to_string(),
55                    location: None,
56                });
57            }
58        }
59
60        Ok(PolicyReport {
61            policy: self.clone(),
62            violations: violations.clone(),
63            passed: violations.is_empty(),
64        })
65    }
66}
67
68/// Policy validator that can check multiple policies.
69#[derive(Clone)]
70pub struct PolicyValidator {
71    /// Forge SDK for graph queries
72    forge: Arc<Forge>,
73}
74
75impl PolicyValidator {
76    /// Creates a new policy validator.
77    pub fn new(forge: Forge) -> Self {
78        Self {
79            forge: Arc::new(forge),
80        }
81    }
82
83    /// Validates a diff against all policies.
84    pub async fn validate(&self, diff: &Diff, policies: &[Policy]) -> Result<PolicyReport> {
85        let mut all_violations = Vec::new();
86
87        for policy in policies {
88            let report = policy.validate(&self.forge, diff).await?;
89            all_violations.extend(report.violations);
90        }
91
92        Ok(PolicyReport {
93            policy: Policy::Custom {
94                name: "All".to_string(),
95                description: "Combined policy check".to_string(),
96            },
97            violations: all_violations.clone(),
98            passed: all_violations.is_empty(),
99        })
100    }
101
102    /// Validates a single policy.
103    pub async fn validate_single(&self, policy: &Policy, diff: &Diff) -> Result<PolicyReport> {
104        policy.validate(&self.forge, diff).await
105    }
106}
107
108/// Result of policy validation.
109#[derive(Clone, Debug)]
110pub struct PolicyReport {
111    /// The policy that was validated
112    pub policy: Policy,
113    /// Any violations found
114    pub violations: Vec<PolicyViolation>,
115    /// Whether validation passed
116    pub passed: bool,
117}
118
119/// A policy violation with location information.
120#[derive(Clone, Debug)]
121pub struct PolicyViolation {
122    /// Policy that was violated
123    pub policy: String,
124    /// Human-readable violation message
125    pub message: String,
126    /// Source location (if applicable)
127    pub location: Option<forge_core::types::Location>,
128}
129
130impl PolicyViolation {
131    /// Creates a new policy violation.
132    pub fn new(policy: impl Into<String>, message: impl Into<String>) -> Self {
133        Self {
134            policy: policy.into(),
135            message: message.into(),
136            location: None,
137        }
138    }
139
140    /// Creates a new policy violation with location.
141    pub fn with_location(
142        policy: impl Into<String>,
143        message: impl Into<String>,
144        location: forge_core::types::Location,
145    ) -> Self {
146        Self {
147            policy: policy.into(),
148            message: message.into(),
149            location: Some(location),
150        }
151    }
152}
153
154/// Policy composition: All policies must pass.
155#[derive(Clone, Debug)]
156pub struct AllPolicies {
157    /// The policies to validate
158    pub policies: Vec<Policy>,
159}
160
161impl AllPolicies {
162    /// Creates a new AllPolicies composition.
163    pub fn new(policies: Vec<Policy>) -> Self {
164        Self { policies }
165    }
166
167    /// Validates all policies.
168    pub async fn validate(&self, forge: &Forge, diff: &Diff) -> Result<PolicyReport> {
169        let mut all_violations = Vec::new();
170
171        for policy in &self.policies {
172            let report = policy.validate(forge, diff).await?;
173            all_violations.extend(report.violations);
174        }
175
176        Ok(PolicyReport {
177            policy: Policy::Custom {
178                name: "All".to_string(),
179                description: format!("All {} policies must pass", self.policies.len()),
180            },
181            violations: all_violations.clone(),
182            passed: all_violations.is_empty(),
183        })
184    }
185}
186
187/// Policy composition: At least one policy must pass.
188#[derive(Clone, Debug)]
189pub struct AnyPolicy {
190    /// The policies to validate
191    pub policies: Vec<Policy>,
192}
193
194impl AnyPolicy {
195    /// Creates a new AnyPolicy composition.
196    pub fn new(policies: Vec<Policy>) -> Self {
197        Self { policies }
198    }
199
200    /// Validates that at least one policy passes.
201    pub async fn validate(&self, forge: &Forge, diff: &Diff) -> Result<PolicyReport> {
202        let mut all_violations = Vec::new();
203        let mut any_passed = false;
204
205        for policy in &self.policies {
206            let report = policy.validate(forge, diff).await?;
207            if report.passed {
208                any_passed = true;
209            }
210            all_violations.extend(report.violations);
211        }
212
213        Ok(PolicyReport {
214            policy: Policy::Custom {
215                name: "Any".to_string(),
216                description: format!("At least one of {} policies must pass", self.policies.len()),
217            },
218            violations: if any_passed {
219                Vec::new()
220            } else {
221                all_violations.clone()
222            },
223            passed: any_passed,
224        })
225    }
226}
227
228/// A diff representing code changes.
229///
230/// This is a simplified representation - in production, this would be
231/// a proper AST diff or line-based diff.
232#[derive(Clone, Debug)]
233pub struct Diff {
234    /// File path
235    pub file_path: std::path::PathBuf,
236    /// Original content
237    pub original: String,
238    /// Modified content
239    pub modified: String,
240    /// Changed lines
241    pub changes: Vec<DiffChange>,
242}
243
244/// A single change in a diff.
245#[derive(Clone, Debug)]
246pub struct DiffChange {
247    /// Line number
248    pub line: usize,
249    /// Original line
250    pub original: String,
251    /// Modified line
252    pub modified: String,
253    /// Change type
254    pub kind: DiffChangeKind,
255}
256
/// The kind of a [`DiffChange`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum DiffChangeKind {
    /// The line was inserted.
    Added,
    /// The line was deleted.
    Removed,
    /// The line's content was changed in place.
    Modified,
}
267
268// Policy validation implementations
269
270/// Checks that no unsafe code appears in public API.
271async fn check_no_unsafe_in_public_api(diff: &Diff) -> Result<Option<PolicyViolation>> {
272    // Parse the modified content for unsafe blocks
273    let mut violations = Vec::new();
274
275    for (line_num, line) in diff.modified.lines().enumerate() {
276        let line_num = line_num + 1;
277        let trimmed = line.trim();
278
279        // Check for unsafe keyword
280        if trimmed.contains("unsafe") {
281            // Check if it's in a public context
282            let is_public_function = trimmed.starts_with("pub ")
283                && (trimmed.contains("fn ") || trimmed.contains("unsafe fn"));
284
285            let is_public_struct = trimmed.starts_with("pub ")
286                && (trimmed.contains("struct ") || trimmed.contains("enum "));
287
288            if is_public_function || is_public_struct {
289                violations.push(PolicyViolation::new(
290                    "NoUnsafeInPublicAPI",
291                    format!("Unsafe code in public API at line {}", line_num),
292                ));
293            }
294        }
295    }
296
297    Ok(if violations.is_empty() {
298        None
299    } else {
300        Some(PolicyViolation::new(
301            "NoUnsafeInPublicAPI",
302            format!(
303                "Found {} violations of unsafe in public API",
304                violations.len()
305            ),
306        ))
307    })
308}
309
310/// Checks that test coverage is preserved.
311async fn check_preserve_tests(_forge: &Forge, diff: &Diff) -> Result<Option<PolicyViolation>> {
312    // Count tests in original and modified
313    let original_tests = count_tests(&diff.original);
314    let modified_tests = count_tests(&diff.modified);
315
316    if modified_tests < original_tests {
317        Ok(Some(PolicyViolation::new(
318            "PreserveTests",
319            format!(
320                "Test count decreased from {} to {}",
321                original_tests, modified_tests
322            ),
323        )))
324    } else {
325        Ok(None)
326    }
327}
328
329/// Checks that cyclomatic complexity is within limit.
330async fn check_max_complexity(
331    _forge: &Forge,
332    max_complexity: usize,
333    diff: &Diff,
334) -> Result<Option<PolicyViolation>> {
335    // For each function in modified content, estimate complexity
336    // Find all functions and check their complexity
337    let violations: Vec<_> = diff
338        .modified
339        .lines()
340        .enumerate()
341        .filter(|(_, line)| line.trim().starts_with("pub fn ") || line.trim().starts_with("fn "))
342        .map(|(line_num, line)| {
343            // Get the rest of the line after "fn name("
344            let rest = if let Some(fn_pos) = line.find("fn ") {
345                &line[fn_pos + 3..]
346            } else {
347                line
348            };
349
350            // Count branching keywords in this function declaration line
351            let complexity = estimate_complexity_from_line(rest);
352
353            if complexity > max_complexity {
354                Some(PolicyViolation::new(
355                    "MaxComplexity",
356                    format!(
357                        "Function at line {} has complexity {}, exceeds max {}",
358                        line_num + 1,
359                        complexity,
360                        max_complexity
361                    ),
362                ))
363            } else {
364                None
365            }
366        })
367        .flatten()
368        .collect();
369
370    Ok(if violations.is_empty() {
371        None
372    } else {
373        Some(PolicyViolation::new(
374            "MaxComplexity",
375            format!(
376                "Found {} functions exceeding complexity limit",
377                violations.len()
378            ),
379        ))
380    })
381}
382
/// Estimates cyclomatic complexity from a snippet of source text.
///
/// Starts at 1 and adds one for each of the following:
/// - a branching keyword (`if `, `while `, `for `, `match `);
/// - a short-circuit operator (`&&`, `||`).
///
/// A keyword only counts when it begins a word, so identifiers that merely end
/// in a keyword (e.g. `motif `, `meanwhile `, `rematch `) are ignored.
///
/// This remains a textual heuristic. Keywords inside string literals or
/// comments are still counted, and closure bars (`|| ...`) or double references
/// (`&&x`) count as operators.
fn estimate_complexity_from_line(line: &str) -> usize {
    const BRANCH_KEYWORDS: [&str; 4] = ["if ", "while ", "for ", "match "];

    let keyword_count: usize = BRANCH_KEYWORDS
        .iter()
        .map(|keyword| count_word_starts(line, keyword))
        .sum();
    let operator_count = line.matches("&&").count() + line.matches("||").count();

    1 + keyword_count + operator_count
}

/// Counts non-overlapping occurrences of `needle` in `text` that are not
/// preceded by an identifier character (letter, digit, or `_`).
fn count_word_starts(text: &str, needle: &str) -> usize {
    text.match_indices(needle)
        .filter(|&(start, _)| {
            text[..start]
                .chars()
                .next_back()
                .map_or(true, |prev| !(prev.is_alphanumeric() || prev == '_'))
        })
        .count()
}
398
399/// Counts test functions in content.
400fn count_tests(content: &str) -> usize {
401    content
402        .lines()
403        .filter(|line| {
404            let trimmed = line.trim();
405            trimmed.contains("#[test]") || trimmed.contains("#[tokio::test]")
406        })
407        .count()
408}
409
#[cfg(test)]
mod tests {
    use super::*;
    use forge_core::Forge;
    use std::path::PathBuf;
    use tempfile::TempDir;

    /// Opens a Forge in a fresh temporary directory.
    ///
    /// The `TempDir` is returned so the caller keeps the directory alive for
    /// the duration of the test.
    async fn open_forge() -> (TempDir, Forge) {
        let temp_dir = TempDir::new().unwrap();
        let forge = Forge::open(temp_dir.path()).await.unwrap();
        (temp_dir, forge)
    }

    /// Builds a `Diff` for `test.rs` with no line-level changes.
    fn make_diff(original: &str, modified: &str) -> Diff {
        Diff {
            file_path: PathBuf::from("test.rs"),
            original: original.to_string(),
            modified: modified.to_string(),
            changes: vec![],
        }
    }

    #[tokio::test]
    async fn test_policy_no_unsafe_in_public_api() {
        let (_temp_dir, forge) = open_forge().await;
        let diff = make_diff("fn safe() {}", "pub unsafe fn dangerous() {}");

        let report = Policy::NoUnsafeInPublicAPI.validate(&forge, &diff).await.unwrap();

        assert!(!report.passed);
        assert_eq!(report.violations.len(), 1);
    }

    #[tokio::test]
    async fn test_policy_preserve_tests() {
        let (_temp_dir, forge) = open_forge().await;
        let diff = make_diff(
            "#[test]\nfn test_one() {}\n#[test]\nfn test_two() {}",
            "#[test]\nfn test_one() {}",
        );

        let report = Policy::PreserveTests.validate(&forge, &diff).await.unwrap();

        assert!(!report.passed);
        assert_eq!(report.violations.len(), 1);
    }

    #[tokio::test]
    async fn test_policy_max_complexity() {
        let (_temp_dir, forge) = open_forge().await;
        let diff = make_diff("", "pub fn complex() { if x { if y { if z {} } } }");

        let report = Policy::MaxComplexity(3).validate(&forge, &diff).await.unwrap();

        assert!(!report.passed);
    }

    #[tokio::test]
    async fn test_custom_policy_reports_violation() {
        let (_temp_dir, forge) = open_forge().await;
        let diff = make_diff("", "pub fn safe() {}");
        let policy = Policy::Custom {
            name: "Unimplemented".to_string(),
            description: "Custom policies are not yet implemented".to_string(),
        };

        let report = policy.validate(&forge, &diff).await.unwrap();

        // Custom policies have no validation hook yet, so they never pass.
        assert!(!report.passed);
        assert_eq!(report.violations.len(), 1);
        assert_eq!(report.violations[0].policy, "Unimplemented");
    }

    #[tokio::test]
    async fn test_all_policies() {
        let (_temp_dir, forge) = open_forge().await;
        let diff = make_diff("", "pub fn safe() {}");

        let all = AllPolicies::new(vec![Policy::NoUnsafeInPublicAPI, Policy::PreserveTests]);
        let report = all.validate(&forge, &diff).await.unwrap();

        assert!(report.passed);
    }

    #[tokio::test]
    async fn test_any_policy() {
        let (_temp_dir, forge) = open_forge().await;
        let diff = make_diff("", "pub unsafe fn dangerous() {}");

        // Both members fail. The unsafe check flags `dangerous`, and Custom
        // policies always report an "unimplemented" violation. With no passing
        // member, the Any composition fails and reports both violations.
        let policies = vec![
            Policy::NoUnsafeInPublicAPI,
            Policy::Custom {
                name: "Unimplemented".to_string(),
                description: "Custom policies are not yet implemented".to_string(),
            },
        ];

        let report = AnyPolicy::new(policies).validate(&forge, &diff).await.unwrap();

        assert!(!report.passed);
        assert_eq!(report.violations.len(), 2);
    }

    #[tokio::test]
    async fn test_any_policy_passes_when_one_member_passes() {
        let (_temp_dir, forge) = open_forge().await;
        let diff = make_diff("", "pub fn safe() {}");

        // The unsafe check passes on safe code, which is enough for Any even
        // though the Custom member fails.
        let policies = vec![
            Policy::NoUnsafeInPublicAPI,
            Policy::Custom {
                name: "Unimplemented".to_string(),
                description: "Custom policies are not yet implemented".to_string(),
            },
        ];

        let report = AnyPolicy::new(policies).validate(&forge, &diff).await.unwrap();

        assert!(report.passed);
        assert!(report.violations.is_empty());
    }

    #[test]
    fn test_count_tests() {
        let content = r#"
            #[test]
            fn test_one() {}

            #[test]
            fn test_two() {}

            #[tokio::test]
            async fn test_three() {}
        "#;

        assert_eq!(count_tests(content), 3);
    }
}