1use crate::Result;
7use forge_core::Forge;
8use std::sync::Arc;
9
/// A rule that a proposed [`Diff`] can be validated against.
#[derive(Clone, Debug)]
pub enum Policy {
    /// Rejects diffs whose modified source uses `unsafe` on a `pub` declaration line.
    NoUnsafeInPublicAPI,

    /// Rejects diffs whose modified source contains fewer test attributes than the original.
    PreserveTests,

    /// Rejects diffs containing a function whose estimated complexity exceeds the given limit.
    MaxComplexity(usize),

    /// A caller-named policy. Validation for it is not implemented:
    /// `Policy::validate` always reports a violation carrying `name`.
    Custom { name: String, description: String },
}
27
28impl Policy {
29 pub async fn validate(&self, forge: &Forge, diff: &Diff) -> Result<PolicyReport> {
31 let mut violations = Vec::new();
32
33 match self {
34 Policy::NoUnsafeInPublicAPI => {
35 if let Some(v) = check_no_unsafe_in_public_api(diff).await? {
36 violations.push(v);
37 }
38 }
39 Policy::PreserveTests => {
40 if let Some(v) = check_preserve_tests(forge, diff).await? {
41 violations.push(v);
42 }
43 }
44 Policy::MaxComplexity(max) => {
45 if let Some(v) = check_max_complexity(forge, *max, diff).await? {
46 violations.push(v);
47 }
48 }
49 Policy::Custom { name, .. } => {
50 violations.push(PolicyViolation {
53 policy: name.clone(),
54 message: "Custom policy validation not yet implemented".to_string(),
55 location: None,
56 });
57 }
58 }
59
60 Ok(PolicyReport {
61 policy: self.clone(),
62 violations: violations.clone(),
63 passed: violations.is_empty(),
64 })
65 }
66}
67
/// Runs policies against diffs using a shared [`Forge`] handle.
///
/// Clones share the same underlying `Forge` through the `Arc`.
#[derive(Clone)]
pub struct PolicyValidator {
    /// Forge instance handed to every policy check.
    forge: Arc<Forge>,
}
74
75impl PolicyValidator {
76 pub fn new(forge: Forge) -> Self {
78 Self {
79 forge: Arc::new(forge),
80 }
81 }
82
83 pub async fn validate(&self, diff: &Diff, policies: &[Policy]) -> Result<PolicyReport> {
85 let mut all_violations = Vec::new();
86
87 for policy in policies {
88 let report = policy.validate(&self.forge, diff).await?;
89 all_violations.extend(report.violations);
90 }
91
92 Ok(PolicyReport {
93 policy: Policy::Custom {
94 name: "All".to_string(),
95 description: "Combined policy check".to_string(),
96 },
97 violations: all_violations.clone(),
98 passed: all_violations.is_empty(),
99 })
100 }
101
102 pub async fn validate_single(&self, policy: &Policy, diff: &Diff) -> Result<PolicyReport> {
104 policy.validate(&self.forge, diff).await
105 }
106}
107
/// Outcome of validating a diff against a policy or a combination of policies.
#[derive(Clone, Debug)]
pub struct PolicyReport {
    /// The policy that was evaluated. Aggregate reports built in this file
    /// use a synthetic `Policy::Custom` named `"All"` or `"Any"`.
    pub policy: Policy,
    /// Violations attributed to this evaluation.
    pub violations: Vec<PolicyViolation>,
    /// Whether the diff satisfied the policy.
    pub passed: bool,
}
118
/// A single reported breach of a policy.
#[derive(Clone, Debug)]
pub struct PolicyViolation {
    /// Name of the policy that was violated.
    pub policy: String,
    /// Human-readable description of the violation.
    pub message: String,
    /// Where the violation occurred, when known.
    pub location: Option<forge_core::types::Location>,
}
129
130impl PolicyViolation {
131 pub fn new(policy: impl Into<String>, message: impl Into<String>) -> Self {
133 Self {
134 policy: policy.into(),
135 message: message.into(),
136 location: None,
137 }
138 }
139
140 pub fn with_location(
142 policy: impl Into<String>,
143 message: impl Into<String>,
144 location: forge_core::types::Location,
145 ) -> Self {
146 Self {
147 policy: policy.into(),
148 message: message.into(),
149 location: Some(location),
150 }
151 }
152}
153
/// Combinator that passes only when every contained policy passes.
#[derive(Clone, Debug)]
pub struct AllPolicies {
    /// Policies evaluated in order by `validate`.
    pub policies: Vec<Policy>,
}
160
161impl AllPolicies {
162 pub fn new(policies: Vec<Policy>) -> Self {
164 Self { policies }
165 }
166
167 pub async fn validate(&self, forge: &Forge, diff: &Diff) -> Result<PolicyReport> {
169 let mut all_violations = Vec::new();
170
171 for policy in &self.policies {
172 let report = policy.validate(forge, diff).await?;
173 all_violations.extend(report.violations);
174 }
175
176 Ok(PolicyReport {
177 policy: Policy::Custom {
178 name: "All".to_string(),
179 description: format!("All {} policies must pass", self.policies.len()),
180 },
181 violations: all_violations.clone(),
182 passed: all_violations.is_empty(),
183 })
184 }
185}
186
/// Combinator that passes when at least one contained policy passes.
#[derive(Clone, Debug)]
pub struct AnyPolicy {
    /// Policies evaluated in order by `validate`.
    pub policies: Vec<Policy>,
}
193
194impl AnyPolicy {
195 pub fn new(policies: Vec<Policy>) -> Self {
197 Self { policies }
198 }
199
200 pub async fn validate(&self, forge: &Forge, diff: &Diff) -> Result<PolicyReport> {
202 let mut all_violations = Vec::new();
203 let mut any_passed = false;
204
205 for policy in &self.policies {
206 let report = policy.validate(forge, diff).await?;
207 if report.passed {
208 any_passed = true;
209 }
210 all_violations.extend(report.violations);
211 }
212
213 Ok(PolicyReport {
214 policy: Policy::Custom {
215 name: "Any".to_string(),
216 description: format!("At least one of {} policies must pass", self.policies.len()),
217 },
218 violations: if any_passed {
219 Vec::new()
220 } else {
221 all_violations.clone()
222 },
223 passed: any_passed,
224 })
225 }
226}
227
/// A proposed change to a single file, carrying both full texts and a list of changes.
#[derive(Clone, Debug)]
pub struct Diff {
    /// Path of the file the diff applies to.
    pub file_path: std::path::PathBuf,
    /// File contents before the change.
    pub original: String,
    /// File contents after the change; the policy checks in this file scan this text.
    pub modified: String,
    /// Individual changes making up the diff. The checks visible in this
    /// file do not read this field.
    pub changes: Vec<DiffChange>,
}
243
/// One change within a [`Diff`].
#[derive(Clone, Debug)]
pub struct DiffChange {
    /// Line the change applies to.
    /// NOTE(review): whether this is 0- or 1-based is not visible here; confirm with the producer.
    pub line: usize,
    /// Presumably the text before the change; confirm against the producer.
    pub original: String,
    /// Presumably the text after the change; confirm against the producer.
    pub modified: String,
    /// Whether the change is an addition, a removal, or a modification.
    pub kind: DiffChangeKind,
}
256
/// The kind of edit a [`DiffChange`] represents.
///
/// The enum is fieldless, so it is `Copy` and `Hash` in addition to being
/// comparable; values can be passed by value and used as map or set keys.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
pub enum DiffChangeKind {
    /// Content was added.
    Added,
    /// Content was removed.
    Removed,
    /// Existing content was changed.
    Modified,
}
267
268async fn check_no_unsafe_in_public_api(diff: &Diff) -> Result<Option<PolicyViolation>> {
272 let mut violations = Vec::new();
274
275 for (line_num, line) in diff.modified.lines().enumerate() {
276 let line_num = line_num + 1;
277 let trimmed = line.trim();
278
279 if trimmed.contains("unsafe") {
281 let is_public_function = trimmed.starts_with("pub ")
283 && (trimmed.contains("fn ") || trimmed.contains("unsafe fn"));
284
285 let is_public_struct = trimmed.starts_with("pub ")
286 && (trimmed.contains("struct ") || trimmed.contains("enum "));
287
288 if is_public_function || is_public_struct {
289 violations.push(PolicyViolation::new(
290 "NoUnsafeInPublicAPI",
291 format!("Unsafe code in public API at line {}", line_num),
292 ));
293 }
294 }
295 }
296
297 Ok(if violations.is_empty() {
298 None
299 } else {
300 Some(PolicyViolation::new(
301 "NoUnsafeInPublicAPI",
302 format!(
303 "Found {} violations of unsafe in public API",
304 violations.len()
305 ),
306 ))
307 })
308}
309
310async fn check_preserve_tests(_forge: &Forge, diff: &Diff) -> Result<Option<PolicyViolation>> {
312 let original_tests = count_tests(&diff.original);
314 let modified_tests = count_tests(&diff.modified);
315
316 if modified_tests < original_tests {
317 Ok(Some(PolicyViolation::new(
318 "PreserveTests",
319 format!(
320 "Test count decreased from {} to {}",
321 original_tests, modified_tests
322 ),
323 )))
324 } else {
325 Ok(None)
326 }
327}
328
329async fn check_max_complexity(
331 _forge: &Forge,
332 max_complexity: usize,
333 diff: &Diff,
334) -> Result<Option<PolicyViolation>> {
335 let violations: Vec<_> = diff
338 .modified
339 .lines()
340 .enumerate()
341 .filter(|(_, line)| line.trim().starts_with("pub fn ") || line.trim().starts_with("fn "))
342 .map(|(line_num, line)| {
343 let rest = if let Some(fn_pos) = line.find("fn ") {
345 &line[fn_pos + 3..]
346 } else {
347 line
348 };
349
350 let complexity = estimate_complexity_from_line(rest);
352
353 if complexity > max_complexity {
354 Some(PolicyViolation::new(
355 "MaxComplexity",
356 format!(
357 "Function at line {} has complexity {}, exceeds max {}",
358 line_num + 1,
359 complexity,
360 max_complexity
361 ),
362 ))
363 } else {
364 None
365 }
366 })
367 .flatten()
368 .collect();
369
370 Ok(if violations.is_empty() {
371 None
372 } else {
373 Some(PolicyViolation::new(
374 "MaxComplexity",
375 format!(
376 "Found {} functions exceeding complexity limit",
377 violations.len()
378 ),
379 ))
380 })
381}
382
/// Estimates the complexity of the code on a single line.
///
/// The estimate starts at 1 and adds one for every `if`, `while`, `for` and
/// `match` keyword and for every `&&` and `||` operator found in `line`.
///
/// Keywords are matched as whole words, so identifiers that merely contain
/// one (such as `diff` or `before`) are not counted, while a keyword followed
/// directly by punctuation (such as `if(`) is.
fn estimate_complexity_from_line(line: &str) -> usize {
    // Split on anything that cannot be part of an identifier so each piece
    // is a complete word.
    let branching_keywords = line
        .split(|c: char| !(c.is_alphanumeric() || c == '_'))
        .filter(|word| matches!(*word, "if" | "while" | "for" | "match"))
        .count();

    let boolean_operators = line.matches("&&").count() + line.matches("||").count();

    1 + branching_keywords + boolean_operators
}
398
/// Counts the lines of `content` that carry a test attribute.
///
/// A line counts when it contains `#[test]`, `#[tokio::test]` or
/// `#[tokio::test(` (the form that takes arguments). Lines that start with
/// `//` after trimming are ignored, so a commented-out test attribute is not
/// counted as a test.
fn count_tests(content: &str) -> usize {
    content
        .lines()
        .map(str::trim)
        // Line and doc comments are not live code.
        .filter(|line| !line.starts_with("//"))
        .filter(|line| {
            line.contains("#[test]")
                || line.contains("#[tokio::test]")
                || line.contains("#[tokio::test(")
        })
        .count()
}
409
#[cfg(test)]
mod tests {
    use super::*;
    use forge_core::Forge;
    use std::path::PathBuf;
    use tempfile::TempDir;

    /// Opens a `Forge` in a fresh temporary directory.
    ///
    /// The `TempDir` is returned alongside so the directory outlives the test body.
    async fn open_forge() -> (TempDir, Forge) {
        let dir = TempDir::new().unwrap();
        let forge = Forge::open(dir.path()).await.unwrap();
        (dir, forge)
    }

    /// Builds a `Diff` for `test.rs` with the given texts and no change list.
    fn diff_for(original: &str, modified: &str) -> Diff {
        Diff {
            file_path: PathBuf::from("test.rs"),
            original: original.to_string(),
            modified: modified.to_string(),
            changes: vec![],
        }
    }

    #[tokio::test]
    async fn test_policy_no_unsafe_in_public_api() {
        let (_dir, forge) = open_forge().await;
        let diff = diff_for("fn safe() {}", "pub unsafe fn dangerous() {}");

        let report = Policy::NoUnsafeInPublicAPI
            .validate(&forge, &diff)
            .await
            .unwrap();

        assert!(!report.passed);
        assert_eq!(report.violations.len(), 1);
    }

    #[tokio::test]
    async fn test_policy_preserve_tests() {
        let (_dir, forge) = open_forge().await;
        let diff = diff_for(
            "#[test]\nfn test_one() {}\n#[test]\nfn test_two() {}",
            "#[test]\nfn test_one() {}",
        );

        let report = Policy::PreserveTests
            .validate(&forge, &diff)
            .await
            .unwrap();

        assert!(!report.passed);
        assert_eq!(report.violations.len(), 1);
    }

    #[tokio::test]
    async fn test_policy_max_complexity() {
        let (_dir, forge) = open_forge().await;
        let diff = diff_for("", "pub fn complex() { if x { if y { if z {} } } }");

        let report = Policy::MaxComplexity(3)
            .validate(&forge, &diff)
            .await
            .unwrap();

        assert!(!report.passed);
    }

    #[tokio::test]
    async fn test_all_policies() {
        let (_dir, forge) = open_forge().await;
        let diff = diff_for("", "pub fn safe() {}");

        let all = AllPolicies::new(vec![Policy::NoUnsafeInPublicAPI, Policy::PreserveTests]);
        let report = all.validate(&forge, &diff).await.unwrap();

        assert!(report.passed);
    }

    #[tokio::test]
    async fn test_any_policy() {
        let (_dir, forge) = open_forge().await;
        let diff = diff_for("", "pub unsafe fn dangerous() {}");

        let any = AnyPolicy::new(vec![
            Policy::NoUnsafeInPublicAPI,
            Policy::Custom {
                name: "AlwaysPass".to_string(),
                description: "Always passes".to_string(),
            },
        ]);
        let report = any.validate(&forge, &diff).await.unwrap();

        // Custom policies currently always report a violation, so neither
        // alternative passes.
        assert!(!report.passed);
    }

    #[tokio::test]
    async fn test_count_tests() {
        let content = r#"
            #[test]
            fn test_one() {}

            #[test]
            fn test_two() {}

            #[tokio::test]
            async fn test_three() {}
        "#;

        assert_eq!(count_tests(content), 3);
    }
}