1use parking_lot::Mutex;
50use serde_json::json;
51use std::collections::{BTreeMap, HashMap};
52use std::fmt;
53use std::io;
54use std::sync::Arc;
55
56use crate::request::Method;
57
/// Thresholds and output settings for a coverage run.
///
/// Construct with [`CoverageConfig::new`] (or [`Default`]) and refine with the
/// chained setters; the threshold setters clamp their argument into `0.0..=1.0`.
#[derive(Debug, Clone)]
pub struct CoverageConfig {
    /// Minimum line-coverage ratio (default `0.80`).
    pub line_threshold: f64,
    /// Minimum branch-coverage ratio (default `0.70`).
    pub branch_threshold: f64,
    /// Minimum endpoint-coverage ratio (default `0.90`).
    pub endpoint_threshold: f64,
    /// Default `true`; cleared by [`CoverageConfig::no_fail`].
    /// NOTE(review): nothing in this file reads this flag — presumably a runner
    /// uses it to decide whether a missed threshold fails the build; confirm.
    pub fail_on_threshold: bool,
    /// Report formats to produce (default JSON and HTML).
    pub output_formats: Vec<OutputFormat>,
    /// Directory reports are meant to be written to (default `"target/coverage"`).
    pub output_dir: String,
}
74
75impl Default for CoverageConfig {
76 fn default() -> Self {
77 Self {
78 line_threshold: 0.80,
79 branch_threshold: 0.70,
80 endpoint_threshold: 0.90,
81 fail_on_threshold: true,
82 output_formats: vec![OutputFormat::Json, OutputFormat::Html],
83 output_dir: "target/coverage".into(),
84 }
85 }
86}
87
88impl CoverageConfig {
89 #[must_use]
91 pub fn new() -> Self {
92 Self::default()
93 }
94
95 #[must_use]
97 pub fn line_threshold(mut self, threshold: f64) -> Self {
98 self.line_threshold = threshold.clamp(0.0, 1.0);
99 self
100 }
101
102 #[must_use]
104 pub fn branch_threshold(mut self, threshold: f64) -> Self {
105 self.branch_threshold = threshold.clamp(0.0, 1.0);
106 self
107 }
108
109 #[must_use]
111 pub fn endpoint_threshold(mut self, threshold: f64) -> Self {
112 self.endpoint_threshold = threshold.clamp(0.0, 1.0);
113 self
114 }
115
116 #[must_use]
118 pub fn no_fail(mut self) -> Self {
119 self.fail_on_threshold = false;
120 self
121 }
122
123 #[must_use]
125 pub fn output_formats(mut self, formats: Vec<OutputFormat>) -> Self {
126 self.output_formats = formats;
127 self
128 }
129
130 #[must_use]
132 pub fn output_dir(mut self, dir: impl Into<String>) -> Self {
133 self.output_dir = dir.into();
134 self
135 }
136}
137
/// A report format that a coverage run can emit.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum OutputFormat {
    /// Machine-readable JSON (see `CoverageReport::to_json`).
    Json,
    /// Standalone HTML page (see `CoverageReport::to_html`).
    Html,
    /// SVG coverage badge (see `CoverageReport::to_badge`).
    Badge,
    /// LCOV tracefile.
    /// NOTE(review): no LCOV renderer is visible in this file — confirm where it lives.
    Lcov,
}
150
/// Records endpoint hits and branch outcomes during a test run.
///
/// The state sits behind `Arc<Mutex<..>>`, so the tracker can be used from
/// multiple threads and every clone shares (and mutates) the same data.
#[derive(Debug, Clone)]
pub struct CoverageTracker {
    inner: Arc<Mutex<CoverageTrackerInner>>,
}
158
/// Mutable state guarded by the tracker's mutex.
#[derive(Debug, Default)]
struct CoverageTrackerInner {
    // Endpoints declared via `register_endpoint(s)`. Duplicates are possible;
    // `report()` collapses them through its `BTreeMap`.
    registered_endpoints: Vec<(Method, String)>,
    // Counters per (method, path), filled by `record_hit` — registered or not.
    endpoint_hits: HashMap<(Method, String), EndpointHits>,
    // Outcome counters per branch id, filled by `record_branch`.
    branches: HashMap<String, BranchHits>,
}
168
/// Call counters for a single endpoint.
#[derive(Debug, Clone, Default)]
pub struct EndpointHits {
    /// Every recorded call, whatever its status code.
    pub total_calls: u64,
    /// Calls that returned a 2xx status.
    pub success_count: u64,
    /// Calls that returned a 4xx status.
    pub client_error_count: u64,
    /// Calls that returned a 5xx status.
    pub server_error_count: u64,
    /// Number of calls per exact status code (includes 1xx/3xx, which have
    /// no dedicated bucket above).
    pub status_codes: HashMap<u16, u64>,
}
183
/// Outcome counters for one named branch; it counts as covered once both
/// counters are non-zero.
#[derive(Debug, Clone, Default)]
pub struct BranchHits {
    /// Times the branch condition was recorded as `true`.
    pub taken_count: u64,
    /// Times the branch condition was recorded as `false`.
    pub not_taken_count: u64,
}
192
193impl CoverageTracker {
194 #[must_use]
196 pub fn new() -> Self {
197 Self {
198 inner: Arc::new(Mutex::new(CoverageTrackerInner::default())),
199 }
200 }
201
202 pub fn register_endpoint(&self, method: Method, path: impl Into<String>) {
204 let mut inner = self.inner.lock();
205 inner.registered_endpoints.push((method, path.into()));
206 }
207
208 pub fn register_endpoints<'a>(&self, endpoints: impl IntoIterator<Item = (Method, &'a str)>) {
210 let mut inner = self.inner.lock();
211 for (method, path) in endpoints {
212 inner.registered_endpoints.push((method, path.to_string()));
213 }
214 }
215
216 pub fn record_hit(&self, method: Method, path: &str, status_code: u16) {
218 let mut inner = self.inner.lock();
219
220 let key = (method, path.to_string());
221 let hits = inner.endpoint_hits.entry(key).or_default();
222
223 hits.total_calls += 1;
224 *hits.status_codes.entry(status_code).or_insert(0) += 1;
225
226 match status_code {
227 200..=299 => hits.success_count += 1,
228 400..=499 => hits.client_error_count += 1,
229 500..=599 => hits.server_error_count += 1,
230 _ => {}
231 }
232 }
233
234 pub fn record_branch(&self, branch_id: impl Into<String>, taken: bool) {
236 let mut inner = self.inner.lock();
237
238 let branch = inner.branches.entry(branch_id.into()).or_default();
239 if taken {
240 branch.taken_count += 1;
241 } else {
242 branch.not_taken_count += 1;
243 }
244 }
245
246 #[must_use]
248 pub fn report(&self) -> CoverageReport {
249 let inner = self.inner.lock();
250
251 let mut endpoints = BTreeMap::new();
252 for (method, path) in &inner.registered_endpoints {
253 let key = (*method, path.clone());
254 let hits = inner.endpoint_hits.get(&key).cloned().unwrap_or_default();
255 endpoints.insert((method.as_str().to_string(), path.clone()), hits);
256 }
257
258 for ((method, path), hits) in &inner.endpoint_hits {
260 let key = (method.as_str().to_string(), path.clone());
261 endpoints.entry(key).or_insert_with(|| hits.clone());
262 }
263
264 let branches = inner.branches.clone();
265
266 CoverageReport {
267 endpoints,
268 branches,
269 }
270 }
271
272 pub fn reset(&self) {
274 let mut inner = self.inner.lock();
275 inner.endpoint_hits.clear();
276 inner.branches.clear();
277 }
278}
279
280impl Default for CoverageTracker {
281 fn default() -> Self {
282 Self::new()
283 }
284}
285
/// Point-in-time snapshot of coverage data, produced by `CoverageTracker::report`.
#[derive(Debug, Clone)]
pub struct CoverageReport {
    /// Counters keyed by `(method, path)`, with the method in its string form.
    /// Contains every registered endpoint plus any endpoint hit without being
    /// registered. Iteration is in sorted key order.
    pub endpoints: BTreeMap<(String, String), EndpointHits>,
    /// Branch outcome counters keyed by branch id.
    pub branches: HashMap<String, BranchHits>,
}
294
295impl CoverageReport {
296 #[must_use]
298 #[allow(clippy::cast_precision_loss)]
299 pub fn endpoint_coverage(&self) -> f64 {
300 if self.endpoints.is_empty() {
301 return 1.0;
302 }
303
304 let covered = self
305 .endpoints
306 .values()
307 .filter(|h| h.total_calls > 0)
308 .count();
309
310 covered as f64 / self.endpoints.len() as f64
311 }
312
313 #[must_use]
315 #[allow(clippy::cast_precision_loss)]
316 pub fn branch_coverage(&self) -> f64 {
317 if self.branches.is_empty() {
318 return 1.0;
319 }
320
321 let fully_covered = self
322 .branches
323 .values()
324 .filter(|b| b.taken_count > 0 && b.not_taken_count > 0)
325 .count();
326
327 fully_covered as f64 / self.branches.len() as f64
328 }
329
330 #[must_use]
332 pub fn untested_endpoints(&self) -> Vec<(&str, &str)> {
333 self.endpoints
334 .iter()
335 .filter(|(_, hits)| hits.total_calls == 0)
336 .map(|((method, path), _)| (method.as_str(), path.as_str()))
337 .collect()
338 }
339
340 #[must_use]
342 pub fn untested_error_paths(&self) -> Vec<(&str, &str)> {
343 self.endpoints
344 .iter()
345 .filter(|(_, hits)| {
346 hits.total_calls > 0 && hits.client_error_count == 0 && hits.server_error_count == 0
347 })
348 .map(|((method, path), _)| (method.as_str(), path.as_str()))
349 .collect()
350 }
351
352 pub fn assert_threshold(&self, threshold: f64) {
358 let coverage = self.endpoint_coverage();
359 if coverage < threshold {
360 let untested = self.untested_endpoints();
361 panic!(
362 "Endpoint coverage {:.1}% is below threshold {:.1}%.\n\
363 Untested endpoints ({}):\n{}",
364 coverage * 100.0,
365 threshold * 100.0,
366 untested.len(),
367 untested
368 .iter()
369 .map(|(m, p)| format!(" - {} {}", m, p))
370 .collect::<Vec<_>>()
371 .join("\n")
372 );
373 }
374 }
375
376 pub fn write_json(&self, path: &str) -> io::Result<()> {
382 let json = self.to_json();
383 std::fs::write(path, json)
384 }
385
386 #[must_use]
388 pub fn to_json(&self) -> String {
389 let tested_endpoints = self
390 .endpoints
391 .values()
392 .filter(|h| h.total_calls > 0)
393 .count();
394
395 let endpoints: Vec<_> = self
396 .endpoints
397 .iter()
398 .map(|((method, path), hits)| {
399 json!({
400 "method": method,
401 "path": path,
402 "calls": hits.total_calls,
403 "success": hits.success_count,
404 "client_errors": hits.client_error_count,
405 "server_errors": hits.server_error_count,
406 })
407 })
408 .collect();
409
410 let doc = json!({
411 "summary": {
412 "endpoint_coverage": self.endpoint_coverage(),
413 "branch_coverage": self.branch_coverage(),
414 "total_endpoints": self.endpoints.len(),
415 "tested_endpoints": tested_endpoints,
416 },
417 "endpoints": endpoints,
418 });
419
420 serde_json::to_string_pretty(&doc)
421 .expect("serializing coverage report to JSON should never fail")
422 }
423
424 pub fn write_html(&self, path: &str) -> io::Result<()> {
430 let html = self.to_html();
431 std::fs::write(path, html)
432 }
433
434 #[must_use]
436 #[allow(clippy::too_many_lines)]
437 pub fn to_html(&self) -> String {
438 let coverage_pct = self.endpoint_coverage() * 100.0;
439 let coverage_class = if coverage_pct >= 80.0 {
440 "good"
441 } else if coverage_pct >= 60.0 {
442 "warning"
443 } else {
444 "poor"
445 };
446
447 let mut html = format!(
448 r#"<!DOCTYPE html>
449<html lang="en">
450<head>
451 <meta charset="UTF-8">
452 <meta name="viewport" content="width=device-width, initial-scale=1.0">
453 <title>fastapi_rust Coverage Report</title>
454 <style>
455 body {{ font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Roboto, sans-serif; margin: 20px; background: #f5f5f5; }}
456 .container {{ max-width: 1200px; margin: 0 auto; background: white; padding: 20px; border-radius: 8px; box-shadow: 0 2px 4px rgba(0,0,0,0.1); }}
457 h1 {{ color: #333; border-bottom: 2px solid #007bff; padding-bottom: 10px; }}
458 .summary {{ display: flex; gap: 20px; margin-bottom: 30px; }}
459 .metric {{ flex: 1; padding: 20px; border-radius: 8px; text-align: center; }}
460 .metric.good {{ background: #d4edda; color: #155724; }}
461 .metric.warning {{ background: #fff3cd; color: #856404; }}
462 .metric.poor {{ background: #f8d7da; color: #721c24; }}
463 .metric h2 {{ margin: 0 0 10px 0; font-size: 2.5em; }}
464 .metric p {{ margin: 0; font-size: 0.9em; }}
465 table {{ width: 100%; border-collapse: collapse; }}
466 th, td {{ padding: 12px; text-align: left; border-bottom: 1px solid #ddd; }}
467 th {{ background: #f8f9fa; font-weight: 600; }}
468 tr:hover {{ background: #f5f5f5; }}
469 .method {{ font-family: monospace; padding: 2px 6px; border-radius: 4px; font-weight: 600; }}
470 .method.GET {{ background: #28a745; color: white; }}
471 .method.POST {{ background: #ffc107; color: black; }}
472 .method.PUT {{ background: #17a2b8; color: white; }}
473 .method.DELETE {{ background: #dc3545; color: white; }}
474 .method.PATCH {{ background: #6f42c1; color: white; }}
475 .untested {{ color: #dc3545; font-weight: 600; }}
476 .path {{ font-family: monospace; }}
477 .count {{ text-align: right; }}
478 </style>
479</head>
480<body>
481 <div class="container">
482 <h1>fastapi_rust Coverage Report</h1>
483
484 <div class="summary">
485 <div class="metric {coverage_class}">
486 <h2>{coverage_pct:.1}%</h2>
487 <p>Endpoint Coverage</p>
488 </div>
489 <div class="metric">
490 <h2>{}</h2>
491 <p>Total Endpoints</p>
492 </div>
493 <div class="metric">
494 <h2>{}</h2>
495 <p>Tested Endpoints</p>
496 </div>
497 </div>
498
499 <h2>Endpoint Details</h2>
500 <table>
501 <thead>
502 <tr>
503 <th>Method</th>
504 <th>Path</th>
505 <th class="count">Calls</th>
506 <th class="count">Success</th>
507 <th class="count">4xx</th>
508 <th class="count">5xx</th>
509 </tr>
510 </thead>
511 <tbody>
512"#,
513 self.endpoints.len(),
514 self.endpoints
515 .values()
516 .filter(|h| h.total_calls > 0)
517 .count()
518 );
519
520 for ((method, path), hits) in &self.endpoints {
521 let tested_class = if hits.total_calls == 0 {
522 " class=\"untested\""
523 } else {
524 ""
525 };
526 let method_escaped = escape_html(method);
527 let path_escaped = escape_html(path);
528 html.push_str(&format!(
529 r#" <tr{tested_class}>
530 <td><span class="method {method_escaped}">{method_escaped}</span></td>
531 <td class="path">{path_escaped}</td>
532 <td class="count">{}</td>
533 <td class="count">{}</td>
534 <td class="count">{}</td>
535 <td class="count">{}</td>
536 </tr>
537"#,
538 hits.total_calls,
539 hits.success_count,
540 hits.client_error_count,
541 hits.server_error_count
542 ));
543 }
544
545 html.push_str(
546 r" </tbody>
547 </table>
548 </div>
549</body>
550</html>",
551 );
552
553 html
554 }
555
556 #[must_use]
558 pub fn to_badge(&self) -> String {
559 let coverage_pct = self.endpoint_coverage() * 100.0;
560 let color = if coverage_pct >= 80.0 {
561 "4c1"
562 } else if coverage_pct >= 60.0 {
563 "dfb317"
564 } else {
565 "e05d44"
566 };
567
568 let mut svg = String::new();
570 svg.push_str(r#"<svg xmlns="http://www.w3.org/2000/svg" width="106" height="20">"#);
571 svg.push_str("\n <linearGradient id=\"b\" x2=\"0\" y2=\"100%\">");
572 svg.push_str("\n <stop offset=\"0\" stop-color=\"#bbb\" stop-opacity=\".1\"/>");
573 svg.push_str("\n <stop offset=\"1\" stop-opacity=\".1\"/>");
574 svg.push_str("\n </linearGradient>");
575 svg.push_str(
576 "\n <mask id=\"a\"><rect width=\"106\" height=\"20\" rx=\"3\" fill=\"#fff\"/></mask>",
577 );
578 svg.push_str("\n <g mask=\"url(#a)\">");
579 svg.push_str("\n <rect width=\"61\" height=\"20\" fill=\"#555\"/>");
580 svg.push_str(&format!(
581 "\n <rect x=\"61\" width=\"45\" height=\"20\" fill=\"#{color}\"/>"
582 ));
583 svg.push_str("\n <rect width=\"106\" height=\"20\" fill=\"url(#b)\"/>");
584 svg.push_str("\n </g>");
585 svg.push_str("\n <g fill=\"#fff\" text-anchor=\"middle\" font-family=\"DejaVu Sans,Verdana,Geneva,sans-serif\" font-size=\"11\">");
586 svg.push_str(
587 "\n <text x=\"31.5\" y=\"15\" fill=\"#010101\" fill-opacity=\".3\">coverage</text>",
588 );
589 svg.push_str("\n <text x=\"31.5\" y=\"14\" fill=\"#fff\">coverage</text>");
590 svg.push_str(&format!("\n <text x=\"82.5\" y=\"15\" fill=\"#010101\" fill-opacity=\".3\">{coverage_pct:.0}%</text>"));
591 svg.push_str(&format!(
592 "\n <text x=\"82.5\" y=\"14\" fill=\"#fff\">{coverage_pct:.0}%</text>"
593 ));
594 svg.push_str("\n </g>");
595 svg.push_str("\n</svg>");
596
597 svg
598 }
599
600 pub fn write_badge(&self, path: &str) -> io::Result<()> {
606 std::fs::write(path, self.to_badge())
607 }
608}
609
610impl fmt::Display for CoverageReport {
611 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
612 writeln!(f, "Coverage Report")?;
613 writeln!(f, "===============")?;
614 writeln!(f)?;
615 writeln!(
616 f,
617 "Endpoint Coverage: {:.1}%",
618 self.endpoint_coverage() * 100.0
619 )?;
620 writeln!(
621 f,
622 "Branch Coverage: {:.1}%",
623 self.branch_coverage() * 100.0
624 )?;
625 writeln!(f)?;
626
627 let untested = self.untested_endpoints();
628 if !untested.is_empty() {
629 writeln!(f, "Untested Endpoints ({}):", untested.len())?;
630 for (method, path) in untested {
631 writeln!(f, " - {} {}", method, path)?;
632 }
633 }
634
635 let untested_errors = self.untested_error_paths();
636 if !untested_errors.is_empty() {
637 writeln!(f)?;
638 writeln!(f, "Missing Error Path Tests ({}):", untested_errors.len())?;
639 for (method, path) in untested_errors {
640 writeln!(f, " - {} {}", method, path)?;
641 }
642 }
643
644 Ok(())
645 }
646}
647
/// Records one branch outcome on a coverage tracker.
///
/// Expands to `$tracker.record_branch($branch_id, $taken)`, so it works with
/// any value exposing a `record_branch(id, bool)` method, such as
/// `CoverageTracker`. An optional trailing comma is accepted, as with std
/// macros like `assert_eq!`.
///
/// ```ignore
/// record_branch!(tracker, "auth::token_present", token.is_some());
/// ```
#[macro_export]
macro_rules! record_branch {
    ($tracker:expr, $branch_id:expr, $taken:expr $(,)?) => {
        $tracker.record_branch($branch_id, $taken)
    };
}
671
/// Escapes the five HTML-significant characters so `s` can be embedded in
/// element text or in a double- or single-quoted attribute value.
///
/// The replacements must be the entity spellings below; replacing a character
/// with itself (as the previous version did) performs no escaping at all and
/// lets a crafted route path inject markup into the HTML report.
fn escape_html(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for c in s.chars() {
        match c {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            '"' => out.push_str("&quot;"),
            '\'' => out.push_str("&#39;"),
            _ => out.push(c),
        }
    }
    out
}
687
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tracker_basic() {
        let tracker = CoverageTracker::new();

        // Three registered endpoints; only two are exercised below.
        tracker.register_endpoint(Method::Get, "/users");
        tracker.register_endpoint(Method::Post, "/users");
        tracker.register_endpoint(Method::Get, "/users/{id}");

        tracker.record_hit(Method::Get, "/users", 200);
        tracker.record_hit(Method::Get, "/users", 200);
        tracker.record_hit(Method::Post, "/users", 201);
        tracker.record_hit(Method::Post, "/users", 400);

        let report = tracker.report();

        assert_eq!(report.endpoints.len(), 3);
        assert!((report.endpoint_coverage() - 2.0 / 3.0).abs() < 0.001);

        let untested = report.untested_endpoints();
        assert_eq!(untested.len(), 1);
        assert_eq!(untested[0], ("GET", "/users/{id}"));
    }

    #[test]
    fn test_tracker_error_paths() {
        let tracker = CoverageTracker::new();

        tracker.register_endpoint(Method::Get, "/users");
        tracker.register_endpoint(Method::Post, "/users");

        // GET only ever succeeds, so its error path is untested.
        tracker.record_hit(Method::Get, "/users", 200);

        // POST sees both a success and a client error.
        tracker.record_hit(Method::Post, "/users", 201);
        tracker.record_hit(Method::Post, "/users", 400);

        let report = tracker.report();
        let untested_errors = report.untested_error_paths();

        assert_eq!(untested_errors.len(), 1);
        assert_eq!(untested_errors[0], ("GET", "/users"));
    }

    #[test]
    fn test_branch_coverage() {
        let tracker = CoverageTracker::new();

        // "auth" is seen both ways: fully covered.
        tracker.record_branch("auth", true);
        tracker.record_branch("auth", false);

        // "admin" is only ever taken: half of the branch pair is missing.
        tracker.record_branch("admin", true);

        let report = tracker.report();

        assert_eq!(report.branches.len(), 2);
        assert!((report.branch_coverage() - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_report_json() {
        let tracker = CoverageTracker::new();
        tracker.register_endpoint(Method::Get, "/test");
        tracker.record_hit(Method::Get, "/test", 200);

        let report = tracker.report();
        let json = report.to_json();

        assert!(json.contains("\"endpoint_coverage\""));
        assert!(json.contains("\"/test\""));
    }

    #[test]
    fn test_report_json_escapes_special_characters() {
        let tracker = CoverageTracker::new();
        let path = "/te\"st\\path";
        tracker.register_endpoint(Method::Get, path);
        tracker.record_hit(Method::Get, path, 200);

        let report = tracker.report();
        let json = report.to_json();
        let parsed: serde_json::Value =
            serde_json::from_str(&json).expect("generated JSON must be valid");

        assert_eq!(parsed["endpoints"][0]["path"], path);
    }

    #[test]
    fn test_report_html() {
        let tracker = CoverageTracker::new();
        tracker.register_endpoint(Method::Get, "/test");

        let report = tracker.report();
        let html = report.to_html();

        assert!(html.contains("<!DOCTYPE html>"));
        assert!(html.contains("Coverage Report"));
        assert!(html.contains("/test"));
    }

    #[test]
    fn test_report_badge() {
        let tracker = CoverageTracker::new();
        tracker.register_endpoint(Method::Get, "/test");
        tracker.record_hit(Method::Get, "/test", 200);

        let report = tracker.report();
        let badge = report.to_badge();

        assert!(badge.contains("<svg"));
        assert!(badge.contains("coverage"));
        assert!(badge.contains("100%"));
    }

    #[test]
    fn test_config_builder() {
        let config = CoverageConfig::new()
            .line_threshold(0.90)
            .branch_threshold(0.85)
            .endpoint_threshold(0.95)
            .no_fail()
            .output_dir("custom/path");

        assert!((config.line_threshold - 0.90).abs() < 0.001);
        assert!((config.branch_threshold - 0.85).abs() < 0.001);
        assert!((config.endpoint_threshold - 0.95).abs() < 0.001);
        assert!(!config.fail_on_threshold);
        assert_eq!(config.output_dir, "custom/path");
    }

    #[test]
    fn test_threshold_clamp() {
        let config = CoverageConfig::new()
            .line_threshold(1.5)
            .branch_threshold(-0.5);

        assert!((config.line_threshold - 1.0).abs() < 0.001);
        assert!((config.branch_threshold - 0.0).abs() < 0.001);
    }

    #[test]
    #[should_panic(expected = "coverage")]
    fn test_assert_threshold_panics() {
        let tracker = CoverageTracker::new();
        tracker.register_endpoint(Method::Get, "/a");
        tracker.register_endpoint(Method::Get, "/b");
        tracker.record_hit(Method::Get, "/a", 200);

        // 50% coverage is below the 90% gate.
        let report = tracker.report();
        report.assert_threshold(0.90);
    }

    #[test]
    #[allow(clippy::float_cmp)]
    fn test_reset() {
        let tracker = CoverageTracker::new();
        tracker.register_endpoint(Method::Get, "/test");
        tracker.record_hit(Method::Get, "/test", 200);

        let report1 = tracker.report();
        assert_eq!(report1.endpoint_coverage(), 1.0);

        tracker.reset();

        // Registrations survive a reset; only the counters are cleared.
        let report2 = tracker.report();
        assert_eq!(report2.endpoints.len(), 1);
        let hits = report2
            .endpoints
            .get(&("GET".to_string(), "/test".to_string()))
            .unwrap();
        assert_eq!(hits.total_calls, 0);
    }

    #[test]
    fn test_unregistered_hits_are_reported() {
        let tracker = CoverageTracker::new();
        tracker.register_endpoint(Method::Get, "/known");
        tracker.record_hit(Method::Get, "/unknown", 200);

        // The unregistered route still shows up, so it counts toward coverage.
        let report = tracker.report();
        assert_eq!(report.endpoints.len(), 2);
        let hits = report
            .endpoints
            .get(&("GET".to_string(), "/unknown".to_string()))
            .expect("hit on an unregistered endpoint must be reported");
        assert_eq!(hits.total_calls, 1);
        assert!((report.endpoint_coverage() - 0.5).abs() < 0.001);
    }

    #[test]
    fn test_status_code_buckets() {
        let tracker = CoverageTracker::new();
        tracker.register_endpoint(Method::Get, "/r");
        for status in [101, 204, 302, 404, 503] {
            tracker.record_hit(Method::Get, "/r", status);
        }

        // 1xx and 3xx only count toward the totals, not toward any bucket.
        let report = tracker.report();
        let hits = &report.endpoints[&("GET".to_string(), "/r".to_string())];
        assert_eq!(hits.total_calls, 5);
        assert_eq!(hits.success_count, 1);
        assert_eq!(hits.client_error_count, 1);
        assert_eq!(hits.server_error_count, 1);
        assert_eq!(hits.status_codes.get(&302), Some(&1));
        assert_eq!(hits.status_codes.len(), 5);
    }

    #[test]
    fn test_register_endpoints_bulk() {
        let tracker = CoverageTracker::new();
        tracker.register_endpoints([(Method::Get, "/b"), (Method::Get, "/a")]);

        // The report is sorted by (method, path), independent of registration order.
        let report = tracker.report();
        assert_eq!(
            report.untested_endpoints(),
            vec![("GET", "/a"), ("GET", "/b")]
        );
    }
}