1use crate::{AgentError, Result};
7use forge_core::Forge;
8use std::sync::Arc;
9
10#[derive(Clone, Debug)]
14pub enum Policy {
15 NoUnsafeInPublicAPI,
17
18 PreserveTests,
20
21 MaxComplexity(usize),
23
24 Custom { name: String, description: String },
26}
27
28impl Policy {
29 pub async fn validate(&self, forge: &Forge, diff: &Diff) -> Result<PolicyReport> {
31 let mut violations = Vec::new();
32
33 match self {
34 Policy::NoUnsafeInPublicAPI => {
35 if let Some(v) = check_no_unsafe_in_public_api(diff).await? {
36 violations.push(v);
37 }
38 }
39 Policy::PreserveTests => {
40 if let Some(v) = check_preserve_tests(forge, diff).await? {
41 violations.push(v);
42 }
43 }
44 Policy::MaxComplexity(max) => {
45 if let Some(v) = check_max_complexity(forge, *max, diff).await? {
46 violations.push(v);
47 }
48 }
49 Policy::Custom { name, .. } => {
50 violations.push(PolicyViolation {
53 policy: name.clone(),
54 message: "Custom policy validation not yet implemented".to_string(),
55 location: None,
56 });
57 }
58 }
59
60 Ok(PolicyReport {
61 policy: self.clone(),
62 violations: violations.clone(),
63 passed: violations.is_empty(),
64 })
65 }
66}
67
68#[derive(Clone)]
70pub struct PolicyValidator {
71 forge: Arc<Forge>,
73}
74
75impl PolicyValidator {
76 pub fn new(forge: Forge) -> Self {
78 Self {
79 forge: Arc::new(forge),
80 }
81 }
82
83 pub async fn validate(&self, diff: &Diff, policies: &[Policy]) -> Result<PolicyReport> {
85 let mut all_violations = Vec::new();
86
87 for policy in policies {
88 let report = policy.validate(&self.forge, diff).await?;
89 all_violations.extend(report.violations);
90 }
91
92 Ok(PolicyReport {
93 policy: Policy::Custom {
94 name: "All".to_string(),
95 description: "Combined policy check".to_string(),
96 },
97 violations: all_violations.clone(),
98 passed: all_violations.is_empty(),
99 })
100 }
101
102 pub async fn validate_single(&self, policy: &Policy, diff: &Diff) -> Result<PolicyReport> {
104 policy.validate(&self.forge, diff).await
105 }
106}
107
108#[derive(Clone, Debug)]
110pub struct PolicyReport {
111 pub policy: Policy,
113 pub violations: Vec<PolicyViolation>,
115 pub passed: bool,
117}
118
119#[derive(Clone, Debug)]
121pub struct PolicyViolation {
122 pub policy: String,
124 pub message: String,
126 pub location: Option<forge_core::types::Location>,
128}
129
130impl PolicyViolation {
131 pub fn new(policy: impl Into<String>, message: impl Into<String>) -> Self {
133 Self {
134 policy: policy.into(),
135 message: message.into(),
136 location: None,
137 }
138 }
139
140 pub fn with_location(
142 policy: impl Into<String>,
143 message: impl Into<String>,
144 location: forge_core::types::Location,
145 ) -> Self {
146 Self {
147 policy: policy.into(),
148 message: message.into(),
149 location: Some(location),
150 }
151 }
152}
153
154#[derive(Clone, Debug)]
156pub struct AllPolicies {
157 pub policies: Vec<Policy>,
159}
160
161impl AllPolicies {
162 pub fn new(policies: Vec<Policy>) -> Self {
164 Self { policies }
165 }
166
167 pub async fn validate(&self, forge: &Forge, diff: &Diff) -> Result<PolicyReport> {
169 let mut all_violations = Vec::new();
170
171 for policy in &self.policies {
172 let report = policy.validate(forge, diff).await?;
173 all_violations.extend(report.violations);
174 }
175
176 Ok(PolicyReport {
177 policy: Policy::Custom {
178 name: "All".to_string(),
179 description: format!("All {} policies must pass", self.policies.len()),
180 },
181 violations: all_violations.clone(),
182 passed: all_violations.is_empty(),
183 })
184 }
185}
186
187#[derive(Clone, Debug)]
189pub struct AnyPolicy {
190 pub policies: Vec<Policy>,
192}
193
194impl AnyPolicy {
195 pub fn new(policies: Vec<Policy>) -> Self {
197 Self { policies }
198 }
199
200 pub async fn validate(&self, forge: &Forge, diff: &Diff) -> Result<PolicyReport> {
202 let mut all_violations = Vec::new();
203 let mut any_passed = false;
204
205 for policy in &self.policies {
206 let report = policy.validate(forge, diff).await?;
207 if report.passed {
208 any_passed = true;
209 }
210 all_violations.extend(report.violations);
211 }
212
213 Ok(PolicyReport {
214 policy: Policy::Custom {
215 name: "Any".to_string(),
216 description: format!("At least one of {} policies must pass", self.policies.len()),
217 },
218 violations: if any_passed {
219 Vec::new()
220 } else {
221 all_violations.clone()
222 },
223 passed: any_passed,
224 })
225 }
226}
227
/// A proposed edit to one file, carrying its text before and after the change.
#[derive(Debug, Clone)]
pub struct Diff {
    /// Path of the file being changed.
    pub file_path: std::path::PathBuf,
    /// Text of the file before the change.
    pub original: String,
    /// Text of the file after the change.
    pub modified: String,
    /// Individual line-level changes making up the edit.
    pub changes: Vec<DiffChange>,
}

/// One line-level change inside a [`Diff`].
#[derive(Debug, Clone)]
pub struct DiffChange {
    /// Line the change applies to. NOTE(review): whether this is 0- or 1-based is
    /// not established here — TODO confirm against the producer.
    pub line: usize,
    /// Line text before the change.
    pub original: String,
    /// Line text after the change.
    pub modified: String,
    /// What kind of change this is.
    pub kind: DiffChangeKind,
}

/// The kind of a [`DiffChange`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DiffChangeKind {
    /// The line was introduced by the change.
    Added,
    /// The line was deleted by the change.
    Removed,
    /// The line exists on both sides but its text differs.
    Modified,
}
267
268async fn check_no_unsafe_in_public_api(diff: &Diff) -> Result<Option<PolicyViolation>> {
272 let mut violations = Vec::new();
274
275 for (line_num, line) in diff.modified.lines().enumerate() {
276 let line_num = line_num + 1;
277 let trimmed = line.trim();
278
279 if trimmed.contains("unsafe") {
281 let is_public_function = trimmed.starts_with("pub ")
283 && (trimmed.contains("fn ") || trimmed.contains("unsafe fn"));
284
285 let is_public_struct = trimmed.starts_with("pub ")
286 && (trimmed.contains("struct ") || trimmed.contains("enum "));
287
288 if is_public_function || is_public_struct {
289 violations.push(PolicyViolation::new(
290 "NoUnsafeInPublicAPI",
291 format!("Unsafe code in public API at line {}", line_num),
292 ));
293 }
294 }
295 }
296
297 Ok(if violations.is_empty() {
298 None
299 } else {
300 Some(PolicyViolation::new(
301 "NoUnsafeInPublicAPI",
302 format!(
303 "Found {} violations of unsafe in public API",
304 violations.len()
305 ),
306 ))
307 })
308}
309
310async fn check_preserve_tests(_forge: &Forge, diff: &Diff) -> Result<Option<PolicyViolation>> {
312 let original_tests = count_tests(&diff.original);
314 let modified_tests = count_tests(&diff.modified);
315
316 if modified_tests < original_tests {
317 Ok(Some(PolicyViolation::new(
318 "PreserveTests",
319 format!(
320 "Test count decreased from {} to {}",
321 original_tests, modified_tests
322 ),
323 )))
324 } else {
325 Ok(None)
326 }
327}
328
329async fn check_max_complexity(
331 _forge: &Forge,
332 max_complexity: usize,
333 diff: &Diff,
334) -> Result<Option<PolicyViolation>> {
335 let violations: Vec<_> = diff
338 .modified
339 .lines()
340 .enumerate()
341 .filter(|(_, line)| line.trim().starts_with("pub fn ") || line.trim().starts_with("fn "))
342 .map(|(line_num, line)| {
343 let rest = if let Some(fn_pos) = line.find("fn ") {
345 &line[fn_pos + 3..]
346 } else {
347 line
348 };
349
350 let complexity = estimate_complexity_from_line(rest);
352
353 if complexity > max_complexity {
354 Some(PolicyViolation::new(
355 "MaxComplexity",
356 format!(
357 "Function at line {} has complexity {}, exceeds max {}",
358 line_num + 1,
359 complexity,
360 max_complexity
361 ),
362 ))
363 } else {
364 None
365 }
366 })
367 .flatten()
368 .collect();
369
370 Ok(if violations.is_empty() {
371 None
372 } else {
373 Some(PolicyViolation::new(
374 "MaxComplexity",
375 format!(
376 "Found {} functions exceeding complexity limit",
377 violations.len()
378 ),
379 ))
380 })
381}
382
/// Estimates the complexity contributed by a single line of code.
///
/// Starts from a base of 1 and adds one for every `if`, `while`, `for` and `match`
/// keyword, and one for every `&&` / `||` operator. A keyword counts only when it
/// starts on a word boundary and is followed by whitespace, so identifiers such as
/// `motif` or `rematch` are not miscounted.
///
/// This is a textual heuristic:
/// - keywords inside string literals or comments are still counted;
/// - `||` also matches empty closure parameter lists.
fn estimate_complexity_from_line(line: &str) -> usize {
    /// Counts occurrences of `keyword` on a word boundary that are followed by whitespace.
    fn count_keyword(line: &str, keyword: &str) -> usize {
        let is_ident_char = |c: char| c.is_alphanumeric() || c == '_';
        line.match_indices(keyword)
            .filter(|&(pos, _)| {
                let starts_word = line[..pos]
                    .chars()
                    .next_back()
                    .map_or(true, |c| !is_ident_char(c));
                let followed_by_space = line[pos + keyword.len()..]
                    .chars()
                    .next()
                    .map_or(false, char::is_whitespace);
                starts_word && followed_by_space
            })
            .count()
    }

    let keywords: usize = ["if", "while", "for", "match"]
        .iter()
        .map(|keyword| count_keyword(line, keyword))
        .sum();
    let operators = line.matches("&&").count() + line.matches("||").count();

    1 + keywords + operators
}
398
/// Counts test functions in `content` by their attribute lines.
///
/// A line counts when, after trimming leading whitespace, it starts with one of:
/// - `#[test]`;
/// - a `#[tokio::test]` attribute, including parameterised forms such as
///   `#[tokio::test(flavor = "multi_thread")]`.
///
/// Anchoring at the start of the line means commented-out attributes (`// #[test]`)
/// and attribute text embedded in string literals are not counted.
fn count_tests(content: &str) -> usize {
    /// True when `line` (already left-trimmed) begins with a recognised test attribute.
    fn is_test_attribute(line: &str) -> bool {
        line.starts_with("#[test]")
            || line
                .strip_prefix("#[tokio::test")
                .map_or(false, |rest| rest.starts_with(']') || rest.starts_with('('))
    }

    content
        .lines()
        .map(str::trim_start)
        .filter(|line| is_test_attribute(line))
        .count()
}
409
/// Estimates the complexity of a sequence of source lines.
///
/// Starts at 1 and adds one for each line that, once trimmed, contains at least one
/// of `if `, `while `, `for `, `match `, `&& ` or `|| `. A line counts at most once,
/// however many such markers it holds.
fn estimate_complexity(lines: &[&str]) -> usize {
    // `"else if "` needs no entry of its own: any line containing it also contains `"if "`.
    const BRANCH_MARKERS: [&str; 6] = ["if ", "while ", "for ", "match ", "&& ", "|| "];

    let branching_lines = lines
        .iter()
        .map(|line| line.trim())
        .filter(|code| BRANCH_MARKERS.iter().any(|marker| code.contains(marker)))
        .count();

    1 + branching_lines
}
434
#[cfg(test)]
mod tests {
    use super::*;
    use std::path::PathBuf;
    use tempfile::TempDir;

    /// Opens a `Forge` in a fresh temporary directory.
    ///
    /// The `TempDir` is returned with the forge to keep the directory alive while the
    /// forge is in use. Bind it to a named variable (not `_`) at the call site.
    async fn open_forge() -> (TempDir, Forge) {
        let temp_dir = TempDir::new().unwrap();
        let forge = Forge::open(temp_dir.path()).await.unwrap();
        (temp_dir, forge)
    }

    /// Builds a diff of `test.rs` with no per-line change records.
    fn make_diff(original: &str, modified: &str) -> Diff {
        Diff {
            file_path: PathBuf::from("test.rs"),
            original: original.to_string(),
            modified: modified.to_string(),
            changes: vec![],
        }
    }

    #[tokio::test]
    async fn test_policy_no_unsafe_in_public_api() {
        let (_temp_dir, forge) = open_forge().await;
        let diff = make_diff("fn safe() {}", "pub unsafe fn dangerous() {}");

        let report = Policy::NoUnsafeInPublicAPI
            .validate(&forge, &diff)
            .await
            .unwrap();

        assert!(!report.passed);
        assert_eq!(report.violations.len(), 1);
    }

    #[tokio::test]
    async fn test_policy_preserve_tests() {
        let (_temp_dir, forge) = open_forge().await;
        let diff = make_diff(
            "#[test]\nfn test_one() {}\n#[test]\nfn test_two() {}",
            "#[test]\nfn test_one() {}",
        );

        let report = Policy::PreserveTests.validate(&forge, &diff).await.unwrap();

        assert!(!report.passed);
        assert_eq!(report.violations.len(), 1);
    }

    #[tokio::test]
    async fn test_policy_max_complexity() {
        let (_temp_dir, forge) = open_forge().await;
        let diff = make_diff("", "pub fn complex() { if x { if y { if z {} } } }");

        let report = Policy::MaxComplexity(3)
            .validate(&forge, &diff)
            .await
            .unwrap();

        assert!(!report.passed);
    }

    #[tokio::test]
    async fn test_all_policies() {
        let (_temp_dir, forge) = open_forge().await;
        let diff = make_diff("", "pub fn safe() {}");

        let all = AllPolicies::new(vec![Policy::NoUnsafeInPublicAPI, Policy::PreserveTests]);
        let report = all.validate(&forge, &diff).await.unwrap();

        assert!(report.passed);
    }

    #[tokio::test]
    async fn test_any_policy() {
        let (_temp_dir, forge) = open_forge().await;
        let diff = make_diff("", "pub unsafe fn dangerous() {}");

        let any = AnyPolicy::new(vec![
            Policy::NoUnsafeInPublicAPI,
            Policy::Custom {
                name: "AlwaysPass".to_string(),
                description: "Always passes".to_string(),
            },
        ]);
        let report = any.validate(&forge, &diff).await.unwrap();

        // Despite its name, `AlwaysPass` fails: custom policies are unimplemented and
        // always report a violation. The unsafe check fails too, so no alternative
        // passes and the disjunction fails.
        assert!(!report.passed);
    }

    #[test]
    fn test_count_tests() {
        let content = r#"
        #[test]
        fn test_one() {}

        #[test]
        fn test_two() {}

        #[tokio::test]
        async fn test_three() {}
        "#;

        assert_eq!(count_tests(content), 3);
    }

    #[test]
    fn test_estimate_complexity() {
        let lines = vec!["fn simple() {", "    let x = 1;", "}"];

        assert_eq!(estimate_complexity(&lines), 1);
    }

    #[test]
    fn test_estimate_complexity_with_branches() {
        let lines = vec![
            "fn complex() {",
            "    if x {",
            "        if y {",
            "        }",
            "    }",
            "}",
        ];

        assert_eq!(estimate_complexity(&lines), 3);
    }
}