// rem_extract/extract/extract_tests.rs

use std::{
    collections::HashSet,
    fs,
    path::{Path, PathBuf},
    process::Command,
    time::{Duration, Instant},
};

use colored::*;
use log::info;
use quote::ToTokens;
use regex::Regex;
use syn::{parse_file, File};

use crate::{
    error::ExtractionError,
    extract::extraction::{extract_method, ExtractionInput},
    test_details::TEST_FILES, // Import Test Files Information from test_details.rs
};

/// One extraction test case, identified by the name of its directory under
/// `testdata`.
///
/// For `input_file = "<name>"` (paths relative to the working directory):
/// - the input program is read from `testdata/<name>/input/src/main.rs`,
/// - the expected result is `testdata/<name>/output/main.rs`,
/// - the extraction result is written to `output/<name>` with an `rs`
///   extension.
pub struct TestFile<'a> {
    /// Name of the test case (a directory under `testdata`).
    pub input_file: &'a str,
    /// Start of the range to extract; forwarded unchanged to the extraction
    /// (presumably a byte offset into the input file — TODO confirm).
    pub start_idx: u32,
    /// End of the range to extract; forwarded unchanged to the extraction
    /// (presumably a byte offset into the input file — TODO confirm).
    pub end_idx: u32,
}

impl<'a> TestFile<'a> {
    /// Creates a test case for `input_file` with the given index range.
    ///
    /// This is a `const fn`, so test tables can be built in `const` or
    /// `static` items.
    pub const fn new(input_file: &'a str, start_idx: u32, end_idx: u32) -> Self {
        TestFile {
            input_file,
            start_idx,
            end_idx,
        }
    }
}

/// Everything needed to run one extraction test.
///
/// Compared to `ExtractionInput` this additionally carries `output_path`:
/// the test harness writes the extracted code there, which the extraction
/// itself never needs.
pub struct TestInput {
    /// Source file the extraction reads from.
    pub file_path: String,
    /// File the extraction result is written to by the test harness.
    pub output_path: String,
    /// Name for the extracted function (forwarded to the extraction).
    pub new_fn_name: String,
    /// Start of the range to extract.
    pub start_idx: u32,
    /// End of the range to extract.
    pub end_idx: u32,
}

53// Helper function to convert a TestFile into an ExtractionInput
54impl From<&TestFile<'_>> for TestInput {
55    fn from(test_file: &TestFile<'_>) -> TestInput {
56        // The file path is constructed as ./testdata/<input_file>/input/src/main.rs
57        let file_path: String = PathBuf::from("testdata")
58            .join(&test_file.input_file)
59            .join("input")
60            .join("src")
61            .join("main.rs")
62            .to_string_lossy()
63            .to_string();
64        let output_path: String = PathBuf::from("output")
65            .join(&test_file.input_file)
66            .with_extension("rs")
67            .to_string_lossy()
68            .to_string();
69
70        TestInput {
71            file_path,
72            output_path,
73            new_fn_name: "fun_name".to_string(),
74            start_idx: test_file.start_idx,
75            end_idx: test_file.end_idx,
76        }
77    }
78}
79
80/// Because the TestInput contains an output path, it also needs to be converted
81/// into an ExtractionInput for the actual extraction process
82impl From<&TestInput> for ExtractionInput {
83    fn from(test_input: &TestInput) -> ExtractionInput {
84        ExtractionInput {
85            file_path: test_input.file_path.clone(),
86            new_fn_name: test_input.new_fn_name.clone(),
87            start_idx: test_input.start_idx,
88            end_idx: test_input.end_idx,
89        }
90    }
91}
92
93/// Strips ANSI color codes from a string using a regex
94/// This is useful for comparing strings with ANSI color codes to strings without
95#[allow(dead_code)]
96fn strip_ansi_codes(s: &str) -> String {
97    let ansi_regex = Regex::new(r"\x1b\[([0-9]{1,2}(;[0-9]{0,2})*)m").unwrap();
98    ansi_regex.replace_all(s, "").to_string()
99}
100
101/// Parse and subsequently compare the ASTs of two files. One file is provided
102/// as a String, with the other being a reference to a file path (of the
103/// expected file)
104#[allow(dead_code)]
105fn parse_and_compare_ast(output_content: &String, expected_file_path: &str) -> Result<bool, ExtractionError> {
106    let expected_content: String = fs::read_to_string(expected_file_path)?;
107
108    let output_ast: File = parse_file(&output_content)?;
109    let expected_ast: File = parse_file(&expected_content)?;
110
111    // Convert both ASTs back into token stres for comparison
112    // FIXME this is sometimes buggy and is convinced that the two files are
113    // different when they are infact the same
114    let output_tokens: String = output_ast.into_token_stream().to_string();
115    let expected_tokens: String = expected_ast.into_token_stream().to_string();
116
117    Ok(output_tokens == expected_tokens)
118}
119
120// Helper function to show differences between two files
121#[allow(dead_code)]
122fn print_file_diff(expected_file_path: &str, output_file_path: &str) -> Result<(), std::io::Error> {
123    let expected_content = fs::read_to_string(expected_file_path)?;
124    let output_content = fs::read_to_string(output_file_path)?;
125
126    if expected_content != output_content {
127        println!("Differences found between expected and output:");
128        for diff in diff::lines(&expected_content, &output_content) {
129            match diff {
130                diff::Result::Left(l) => println!("{}", format!("- {}", l).red()), // Expected but not in output
131                diff::Result::Right(r) => println!("{}", format!("+ {}", r).green()), // In output but not in expected
132                diff::Result::Both(b, _) => println!("{}", format!("  {}", b)), // Same in both
133            }
134        }
135    } else {
136        println!("{}", "No differences found.".green());
137    }
138
139    Ok(())
140}
141
/// Deletes every regular file directly inside `dir`; subdirectories and
/// their contents are left untouched.
///
/// A missing `dir` is treated as already empty, so a fresh checkout without
/// an output directory does not abort the test run.
///
/// # Panics
/// Panics if `dir` exists but cannot be read, or if a file cannot be removed.
#[allow(dead_code)]
fn remove_all_files(dir: &Path) {
    let entries = match fs::read_dir(dir) {
        Ok(entries) => entries,
        Err(e) if e.kind() == std::io::ErrorKind::NotFound => return,
        Err(e) => panic!("failed to read directory {}: {}", dir.display(), e),
    };

    for entry in entries {
        let path = entry
            .unwrap_or_else(|e| panic!("failed to read an entry of {}: {}", dir.display(), e))
            .path();
        if path.is_file() {
            fs::remove_file(&path)
                .unwrap_or_else(|e| panic!("failed to remove {}: {}", path.display(), e));
        }
    }
}

154#[allow(dead_code)]
155pub fn test() {
156    // Clear the output directory before running tests
157    let output_dir = PathBuf::from("./output");
158    remove_all_files(&output_dir);
159
160    // Measure total time at the start
161    let overall_start_time: Instant = Instant::now();
162
163    info!("Starting tests...");
164
165    // Initialize counters and time trackers
166    let mut total_tests: i32 = 0;
167    let mut passed_stage_1: i32 = 0;
168    let mut passed_tests: i32 = 0;
169    let mut failed_tests: i32 = 0;
170    let mut total_test_time: Duration = Duration::new(0, 0);
171    let mut min_test_time: Option<Duration> = None;
172    let mut max_test_time: Option<Duration> = None;
173
174    for (index, test_file) in TEST_FILES.iter().enumerate() {
175        let test_start_time: Instant = Instant::now();
176
177        total_tests += 1;
178
179        let input: TestInput = TestInput::from(test_file);
180        // The expected file is now in ./testdata/<testname>/output/src/main.rs
181        let expected_file_path: String = PathBuf::new()
182            .join("testdata")
183            .join(&test_file.input_file)
184            .join("output")
185            .join("main.rs")
186            .to_string_lossy()
187            .to_string();
188        let output_path: String = input.output_path.clone();
189
190        let extraction_input: ExtractionInput = ExtractionInput::from(&input);
191
192        // Call the extraction method and handle errors
193        let extraction_result: Result<(String, _), ExtractionError> = extract_method(extraction_input);
194
195        // Measure time taken for extraction
196        let test_elapsed_time: Duration = test_start_time.elapsed();
197        total_test_time += test_elapsed_time;
198
199        // Update min and max times
200        if let Some(min_time) = min_test_time {
201            if test_elapsed_time < min_time {
202                min_test_time = Some(test_elapsed_time);
203            }
204        } else {
205            min_test_time = Some(test_elapsed_time);
206        }
207
208        if let Some(max_time) = max_test_time {
209            if test_elapsed_time > max_time {
210                max_test_time = Some(test_elapsed_time);
211            }
212        } else {
213            max_test_time = Some(test_elapsed_time);
214        }
215
216        let test_elapsed_time_secs: f64 = test_elapsed_time.as_secs_f64();
217        let test_elapsed_time_str: String = if test_elapsed_time_secs < 1.0 {
218            format!("{:.2}ms", test_elapsed_time_secs * 1000.0)
219        } else {
220            format!("{:.2}s", test_elapsed_time_secs)
221        };
222
223        let test_name: &str = test_file.input_file.trim_end_matches(".rs");
224        let mut extraction_status: String = "FAILED".red().to_string();
225        let mut comparison_status: String = "N/A".to_string(); // Default to not applicable
226
227        if extraction_result.is_ok() {
228            // Unwrap the result to get the output code (as we know that it is
229            // successful)
230            // Also write the output code to the output file for later viewing
231            let (extraction_result, _)  = extraction_result.unwrap();
232            fs::write(&output_path, &extraction_result).unwrap();
233            extraction_status = "PASSED".green().to_string();
234            passed_stage_1 += 1;
235
236            // Compare the output file with the expected file's AST
237            match parse_and_compare_ast(&extraction_result, &expected_file_path) {
238                Ok(is_identical) => {
239                    if is_identical {
240                        comparison_status = "PASSED".green().to_string();
241                        passed_tests += 1;
242                    } else {
243                        comparison_status = "FAILED".red().to_string();
244                        failed_tests += 1;
245                    }
246                }
247                Err(e) => {
248                    comparison_status = format!("Error: {}", e).red().to_string();
249                    failed_tests += 1;
250                }
251            }
252        } else if let Err(e) = extraction_result {
253            extraction_status = format!("FAILED: {}", e).red().to_string();
254            failed_tests += 1;
255        }
256
257        println!("Test {} | {} | {}: {} in {}", index + 1, extraction_status, comparison_status, test_name, test_elapsed_time_str);
258        // Strip ANSI color codes before logging
259        let clean_extraction_status = strip_ansi_codes(&extraction_status);
260        let clean_comparison_status = strip_ansi_codes(&comparison_status);
261
262        info!("Test {} | {} | {}: {} in {}", index + 1, clean_extraction_status, clean_comparison_status, test_name, test_elapsed_time_str);
263
264    }
265
266    // Total elapsed time
267    let total_elapsed_time: Duration = overall_start_time.elapsed();
268    let total_elapsed_time_secs: f64 = total_elapsed_time.as_secs_f64();
269    let total_elapsed_time_str: String = if total_elapsed_time_secs < 1.0 {
270        format!("{:.2}ms", total_elapsed_time_secs * 1000.0)
271    } else {
272        format!("{:.2}s", total_elapsed_time_secs)
273    };
274
275    // Calculate average time per test
276    let average_time_per_test: f64 = if total_tests > 0 {
277        total_test_time.as_secs_f64() / total_tests as f64
278    } else {
279        0.0
280    };
281
282    let average_time_str: String = if average_time_per_test < 1.0 {
283        format!("{:.2}ms", average_time_per_test * 1000.0)
284    } else {
285        format!("{:.2}s", average_time_per_test)
286    };
287
288    // Print overall statistics
289    println!("------------------------------------------------------------------");
290    println!("Total tests run: {}", total_tests);
291    println!("Tests passed stage 1: {}", passed_stage_1);
292    println!("Tests passed: {}", passed_tests);
293    println!("Tests failed: {}", failed_tests);
294    println!("Total time: {}", total_elapsed_time_str);
295    println!("Average time per test: {}", average_time_str);
296
297    // Log overall statistics
298    info!("------------------------------------------------------------------");
299    info!("Total tests run: {}", total_tests);
300    info!("Tests passed stage 1: {}", passed_stage_1);
301    info!("Tests passed: {}", passed_tests);
302    info!("Tests failed: {}", failed_tests);
303    info!("Total time: {}", total_elapsed_time_str);
304    info!("Average time per test: {}", average_time_str);
305
306    if let Some(min_time) = min_test_time {
307        let min_time_secs: f64 = min_time.as_secs_f64();
308        let min_time_str: String = if min_time_secs < 1.0 {
309            format!("{:.2}ms", min_time_secs * 1000.0)
310        } else {
311            format!("{:.2}s", min_time_secs)
312        };
313        println!("Shortest test time: {}", min_time_str);
314        info!("Shortest test time: {}", min_time_str);
315    }
316
317    if let Some(max_time) = max_test_time {
318        let max_time_secs: f64 = max_time.as_secs_f64();
319        let max_time_str: String = if max_time_secs < 1.0 {
320            format!("{:.2}ms", max_time_secs * 1000.0)
321        } else {
322            format!("{:.2}s", max_time_secs)
323        };
324        println!("Longest test time: {}", max_time_str);
325        info!("Longest test time: {}", max_time_str);
326    }
327}
328
329#[allow(dead_code)]
330pub fn test_verbose() {
331    // Clear the output directory before running tests
332    let output_dir = PathBuf::from("./output");
333    remove_all_files(&output_dir);
334
335    // Measure total time at the start
336    let overall_start_time: Instant = Instant::now();
337
338    info!("Starting tests...");
339
340    // Initialize counters and time trackers
341    let mut total_tests: i32 = 0;
342    let mut passed_stage_1: i32 = 0;
343    let mut passed_tests: i32 = 0;
344    let mut failed_tests: i32 = 0;
345    let mut total_test_time: Duration = Duration::new(0, 0);
346    let mut min_test_time: Option<Duration> = None;
347    let mut max_test_time: Option<Duration> = None;
348
349    let allowed_tests: Vec<&'static str> = vec![
350        // "break_loop_nested",
351        // "comments_in_block_expr",
352    ];
353
354    for (index, test_file) in TEST_FILES.iter().enumerate() {
355
356        // Take a snapshot of all files in the environment
357        let snapshot = fs::read_dir("./").unwrap().map(|entry| entry.unwrap().path()).collect::<Vec<PathBuf>>();
358
359        let test_start_time: Instant = Instant::now();
360
361        total_tests += 1;
362
363        let input: TestInput = TestInput::from(test_file);
364        let output_path: String = input.output_path.clone();
365        let expected_file_path: String = PathBuf::new()
366            .join("testdata")
367            .join(&test_file.input_file)
368            .join("output")
369            .join("main.rs")
370            .to_string_lossy()
371            .to_string();
372
373        // Skip tests not in the allowed_tests list
374        // if the allowed_tests list is not empty
375        if !allowed_tests.is_empty() && !allowed_tests.contains(&test_file.input_file) {
376            continue;
377        }
378
379        let extraction_input: ExtractionInput = ExtractionInput::from(&input);
380
381        // Call the extraction method and handle errors
382        let extraction_result: Result<(String, _), ExtractionError> = extract_method(extraction_input);
383
384        // Measure time taken for extraction
385        let test_elapsed_time: Duration = test_start_time.elapsed();
386        total_test_time += test_elapsed_time;
387
388        // Update min and max times
389        if let Some(min_time) = min_test_time {
390            if test_elapsed_time < min_time {
391                min_test_time = Some(test_elapsed_time);
392            }
393        } else {
394            min_test_time = Some(test_elapsed_time);
395        }
396
397        if let Some(max_time) = max_test_time {
398            if test_elapsed_time > max_time {
399                max_test_time = Some(test_elapsed_time);
400            }
401        } else {
402            max_test_time = Some(test_elapsed_time);
403        }
404
405        let test_elapsed_time_secs: f64 = test_elapsed_time.as_secs_f64();
406        let test_elapsed_time_str: String = if test_elapsed_time_secs < 1.0 {
407            format!("{:.2}ms", test_elapsed_time_secs * 1000.0)
408        } else {
409            format!("{:.2}s", test_elapsed_time_secs)
410        };
411
412        let test_name: &str = test_file.input_file.trim_end_matches(".rs");
413        let mut extraction_status: String = "FAILED".red().to_string();
414        let mut comparison_status: String = "N/A".cyan().to_string(); // Default to not applicable
415        let mut compilation_status: String = "N/A".magenta().to_string(); // Default to not applicable
416
417        if extraction_result.is_ok() {
418            // Unwrap the result to get the output code (as we know that it is
419            // successful)
420            // Also write the output code to the output file for later viewing
421            let (extraction_result, _)  = extraction_result.unwrap();
422            fs::write(&output_path, &extraction_result).unwrap();
423            extraction_status = "PASSED".green().to_string();
424            passed_stage_1 += 1;
425
426            // Compare the output file with the expected file's AST
427            match parse_and_compare_ast(&extraction_result, &expected_file_path) {
428                Ok(is_identical) => {
429                    if is_identical {
430                        comparison_status = "PASSED".green().to_string();
431                        passed_tests += 1;
432                    } else {
433                        comparison_status = "FAILED".red().to_string();
434                        failed_tests += 1;
435                    }
436                }
437                Err(e) => {
438                    comparison_status = format!("Error: {}", e).red().to_string();
439                    failed_tests += 1;
440                }
441            }
442
443            // Complilation check using rustc
444            let compile_result = Command::new("rustc")
445                .arg(&output_path)
446                .output();
447
448            match compile_result {
449                Ok(output) => {
450                    if output.status.success() {
451                        compilation_status = "PASSED".green().to_string();
452                    } else {
453                        compilation_status = "FAILED".red().to_string();
454                        // failed_tests += 1;
455                    }
456                }
457                Err(e) => {
458                    compilation_status = format!("Error: {}", e).red().to_string();
459                    // failed_tests += 1;
460                }
461            }
462
463        } else if let Err(e) = extraction_result {
464            extraction_status = format!("FAILED: {}", e).red().to_string();
465            failed_tests += 1;
466        }
467
468        println!(
469            "Test {} | {} | {} | {}: {} in {}",
470            index + 1,
471            extraction_status,
472            comparison_status,
473            compilation_status,
474            test_name,
475            test_elapsed_time_str
476        );
477
478        let clean_extraction_status = strip_ansi_codes(&extraction_status);
479        let clean_comparison_status = strip_ansi_codes(&comparison_status);
480        let clean_compilation_status = strip_ansi_codes(&compilation_status);
481
482        info!(
483            "Test {} | {} | {} | {}: {} in {}",
484            index + 1,
485            clean_extraction_status,
486            clean_comparison_status,
487            clean_compilation_status,
488            test_name,
489            test_elapsed_time_str
490        );
491        // Print differences if the test failed
492        if clean_comparison_status == "FAILED" || clean_extraction_status == "FAILED" {
493            println!("==================================================================");
494            println!("Differences or compilation errors found for test '{}':", test_name);
495            print_file_diff(&expected_file_path, &output_path).unwrap();
496            println!("==================================================================");
497            println!("");
498        }
499
500        // Delete all files created by the test (i.e anything not in the
501        // snapshot)
502        let current_files = fs::read_dir("./").unwrap().map(|entry| entry.unwrap().path()).collect::<Vec<PathBuf>>();
503        for file in current_files {
504            if !snapshot.contains(&file) {
505                fs::remove_file(file).unwrap();
506            }
507        }
508    }
509
510    // Total elapsed time
511    let total_elapsed_time: Duration = overall_start_time.elapsed();
512    let total_elapsed_time_secs: f64 = total_elapsed_time.as_secs_f64();
513    let total_elapsed_time_str: String = if total_elapsed_time_secs < 1.0 {
514        format!("{:.2}ms", total_elapsed_time_secs * 1000.0)
515    } else {
516        format!("{:.2}s", total_elapsed_time_secs)
517    };
518
519    // Calculate average time per test
520    let average_time_per_test: f64 = if total_tests > 0 {
521        total_test_time.as_secs_f64() / total_tests as f64
522    } else {
523        0.0
524    };
525
526    let average_time_str: String = if average_time_per_test < 1.0 {
527        format!("{:.2}ms", average_time_per_test * 1000.0)
528    } else {
529        format!("{:.2}s", average_time_per_test)
530    };
531
532    // Print overall statistics
533    println!("------------------------------------------------------------------");
534    println!("Total tests run: {}", total_tests);
535    println!("Tests passed stage 1: {}", passed_stage_1);
536    println!("Tests passed: {}", passed_tests);
537    println!("Tests failed: {}", failed_tests);
538    println!("Total time: {}", total_elapsed_time_str);
539    println!("Average time per test: {}", average_time_str);
540
541    // Log overall statistics
542    info!("------------------------------------------------------------------");
543    info!("Total tests run: {}", total_tests);
544    info!("Tests passed stage 1: {}", passed_stage_1);
545    info!("Tests passed: {}", passed_tests);
546    info!("Tests failed: {}", failed_tests);
547    info!("Total time: {}", total_elapsed_time_str);
548    info!("Average time per test: {}", average_time_str);
549
550    if let Some(min_time) = min_test_time {
551        let min_time_secs: f64 = min_time.as_secs_f64();
552        let min_time_str: String = if min_time_secs < 1.0 {
553            format!("{:.2}ms", min_time_secs * 1000.0)
554        } else {
555            format!("{:.2}s", min_time_secs)
556        };
557        println!("Shortest test time: {}", min_time_str);
558        info!("Shortest test time: {}", min_time_str);
559    }
560
561    if let Some(max_time) = max_test_time {
562        let max_time_secs: f64 = max_time.as_secs_f64();
563        let max_time_str: String = if max_time_secs < 1.0 {
564            format!("{:.2}ms", max_time_secs * 1000.0)
565        } else {
566            format!("{:.2}s", max_time_secs)
567        };
568        println!("Longest test time: {}", max_time_str);
569        info!("Longest test time: {}", max_time_str);
570    }
571}
572
573#[allow(dead_code)]
574pub fn test_spammy() {
575    // Clear the output directory before running tests
576    let output_dir = PathBuf::from("./output");
577    remove_all_files(&output_dir);
578
579    // Measure total time at the start
580    let overall_start_time: Instant = Instant::now();
581
582    info!("Starting tests...");
583
584    // Initialize counters and time trackers
585    let mut total_tests: i32 = 0;
586    let mut passed_stage_1: i32 = 0;
587    let mut passed_tests: i32 = 0;
588    let mut failed_tests: i32 = 0;
589    let mut total_test_time: Duration = Duration::new(0, 0);
590    let mut min_test_time: Option<Duration> = None;
591    let mut max_test_time: Option<Duration> = None;
592
593    let allowed_tests: Vec<&'static str> = vec![
594        // "break_loop_nested",
595        // "comments_in_block_expr",
596    ];
597
598    for (index, test_file) in TEST_FILES.iter().enumerate() {
599
600        // Take a snapshot of all files in the environment
601        let snapshot = fs::read_dir("./").unwrap().map(|entry| entry.unwrap().path()).collect::<Vec<PathBuf>>();
602
603        let test_start_time: Instant = Instant::now();
604
605        total_tests += 1;
606
607        let input: TestInput = TestInput::from(test_file);
608        let output_path: String = input.output_path.clone();
609        let expected_file_path: String = PathBuf::new()
610            .join("testdata")
611            .join(&test_file.input_file)
612            .join("output")
613            .join("main.rs")
614            .to_string_lossy()
615            .to_string();
616
617        // Skip tests not in the allowed_tests list
618        // if the allowed_tests list is not empty
619        if !allowed_tests.is_empty() && !allowed_tests.contains(&test_file.input_file) {
620            continue;
621        }
622
623        let extraction_input: ExtractionInput = ExtractionInput::from(&input);
624
625        // Call the extraction method and handle errors
626        let extraction_result: Result<(String, _), ExtractionError> = extract_method(extraction_input);
627
628        // Measure time taken for extraction
629        let test_elapsed_time: Duration = test_start_time.elapsed();
630        total_test_time += test_elapsed_time;
631
632        // Update min and max times
633        if let Some(min_time) = min_test_time {
634            if test_elapsed_time < min_time {
635                min_test_time = Some(test_elapsed_time);
636            }
637        } else {
638            min_test_time = Some(test_elapsed_time);
639        }
640
641        if let Some(max_time) = max_test_time {
642            if test_elapsed_time > max_time {
643                max_test_time = Some(test_elapsed_time);
644            }
645        } else {
646            max_test_time = Some(test_elapsed_time);
647        }
648
649        let test_elapsed_time_secs: f64 = test_elapsed_time.as_secs_f64();
650        let test_elapsed_time_str: String = if test_elapsed_time_secs < 1.0 {
651            format!("{:.2}ms", test_elapsed_time_secs * 1000.0)
652        } else {
653            format!("{:.2}s", test_elapsed_time_secs)
654        };
655
656        let test_name: &str = test_file.input_file.trim_end_matches(".rs");
657        let mut extraction_status: String = "FAILED".red().to_string();
658        let mut comparison_status: String = "N/A".cyan().to_string(); // Default to not applicable
659        let mut compilation_status: String = "N/A".magenta().to_string(); // Default to not applicable
660
661        if extraction_result.is_ok() {
662            // Unwrap the result to get the output code (as we know that it is
663            // successful)
664            // Also write the output code to the output file for later viewing
665            let (extraction_result, _)  = extraction_result.unwrap();
666            fs::write(&output_path, &extraction_result).unwrap();
667            extraction_status = "PASSED".green().to_string();
668            passed_stage_1 += 1;
669
670            // Compare the output file with the expected file's AST
671            match parse_and_compare_ast(&extraction_result, &expected_file_path) {
672                Ok(is_identical) => {
673                    if is_identical {
674                        comparison_status = "PASSED".green().to_string();
675                        passed_tests += 1;
676                    } else {
677                        comparison_status = "FAILED".red().to_string();
678                        failed_tests += 1;
679                    }
680                }
681                Err(e) => {
682                    comparison_status = format!("Error: {}", e).red().to_string();
683                    failed_tests += 1;
684                }
685            }
686
687            // Complilation check using rustc
688            let compile_result = Command::new("rustc")
689                .arg(&output_path)
690                .output();
691
692            match compile_result {
693                Ok(output) => {
694                    if output.status.success() {
695                        compilation_status = "PASSED".green().to_string();
696                    } else {
697                        compilation_status = format!("FAILED: {}", String::from_utf8_lossy(&output.stderr)).red().to_string();
698                        // failed_tests += 1;
699                    }
700                }
701                Err(e) => {
702                    compilation_status = format!("Error: {}", e).red().to_string();
703                    // failed_tests += 1;
704                }
705            }
706
707        } else if let Err(e) = extraction_result {
708            extraction_status = format!("FAILED: {}", e).red().to_string();
709            failed_tests += 1;
710        }
711
712        println!(
713            "Test {} | {} | {} | {}: {} in {}",
714            index + 1,
715            extraction_status,
716            comparison_status,
717            compilation_status,
718            test_name,
719            test_elapsed_time_str
720        );
721
722        let clean_extraction_status = strip_ansi_codes(&extraction_status);
723        let clean_comparison_status = strip_ansi_codes(&comparison_status);
724        let clean_compilation_status = strip_ansi_codes(&compilation_status);
725
726        info!(
727            "Test {} | {} | {} | {}: {} in {}",
728            index + 1,
729            clean_extraction_status,
730            clean_comparison_status,
731            clean_compilation_status,
732            test_name,
733            test_elapsed_time_str
734        );
735        // Print differences if the test failed
736        if clean_comparison_status == "FAILED" || clean_extraction_status == "FAILED" {
737            println!("==================================================================");
738            println!("Differences or compilation errors found for test '{}':", test_name);
739            print_file_diff(&expected_file_path, &output_path).unwrap();
740            println!("==================================================================");
741            println!("");
742        }
743
744        // Delete all files created by the test (i.e anything not in the
745        // snapshot)
746        let current_files = fs::read_dir("./").unwrap().map(|entry| entry.unwrap().path()).collect::<Vec<PathBuf>>();
747        for file in current_files {
748            if !snapshot.contains(&file) {
749                fs::remove_file(file).unwrap();
750            }
751        }
752    }
753
754    // Total elapsed time
755    let total_elapsed_time: Duration = overall_start_time.elapsed();
756    let total_elapsed_time_secs: f64 = total_elapsed_time.as_secs_f64();
757    let total_elapsed_time_str: String = if total_elapsed_time_secs < 1.0 {
758        format!("{:.2}ms", total_elapsed_time_secs * 1000.0)
759    } else {
760        format!("{:.2}s", total_elapsed_time_secs)
761    };
762
763    // Calculate average time per test
764    let average_time_per_test: f64 = if total_tests > 0 {
765        total_test_time.as_secs_f64() / total_tests as f64
766    } else {
767        0.0
768    };
769
770    let average_time_str: String = if average_time_per_test < 1.0 {
771        format!("{:.2}ms", average_time_per_test * 1000.0)
772    } else {
773        format!("{:.2}s", average_time_per_test)
774    };
775
776    // Print overall statistics
777    println!("------------------------------------------------------------------");
778    println!("Total tests run: {}", total_tests);
779    println!("Tests passed stage 1: {}", passed_stage_1);
780    println!("Tests passed: {}", passed_tests);
781    println!("Tests failed: {}", failed_tests);
782    println!("Total time: {}", total_elapsed_time_str);
783    println!("Average time per test: {}", average_time_str);
784
785    // Log overall statistics
786    info!("------------------------------------------------------------------");
787    info!("Total tests run: {}", total_tests);
788    info!("Tests passed stage 1: {}", passed_stage_1);
789    info!("Tests passed: {}", passed_tests);
790    info!("Tests failed: {}", failed_tests);
791    info!("Total time: {}", total_elapsed_time_str);
792    info!("Average time per test: {}", average_time_str);
793
794    if let Some(min_time) = min_test_time {
795        let min_time_secs: f64 = min_time.as_secs_f64();
796        let min_time_str: String = if min_time_secs < 1.0 {
797            format!("{:.2}ms", min_time_secs * 1000.0)
798        } else {
799            format!("{:.2}s", min_time_secs)
800        };
801        println!("Shortest test time: {}", min_time_str);
802        info!("Shortest test time: {}", min_time_str);
803    }
804
805    if let Some(max_time) = max_test_time {
806        let max_time_secs: f64 = max_time.as_secs_f64();
807        let max_time_str: String = if max_time_secs < 1.0 {
808            format!("{:.2}ms", max_time_secs * 1000.0)
809        } else {
810            format!("{:.2}s", max_time_secs)
811        };
812        println!("Longest test time: {}", max_time_str);
813        info!("Longest test time: {}", max_time_str);
814    }
815}
816
#[cfg(test)]
mod tests {
    // Unit tests for `parse_and_compare_ast`.
    //
    // Each test compares an in-memory "output" string against an "expected"
    // file on disk, which is exactly the shape of the call made by the test
    // runner above. Only the expected side needs a real file.
    use super::*;
    use std::fs;
    use std::io;
    use tempfile::NamedTempFile;

    /// Creates a temporary file containing `content`, to serve as the
    /// "expected output" file passed to `parse_and_compare_ast`.
    ///
    /// The file is removed when the returned `NamedTempFile` is dropped, so the
    /// caller must keep the handle alive for as long as it uses the path.
    ///
    /// # Errors
    /// Returns any I/O error from creating or writing the temporary file.
    fn expected_file_with(content: &str) -> io::Result<NamedTempFile> {
        let file: NamedTempFile = NamedTempFile::new()?;
        fs::write(file.path(), content)?;
        Ok(file)
    }

    /// Returns the path of `file` as a `&str`, the form `parse_and_compare_ast`
    /// takes for the expected-file argument.
    ///
    /// # Panics
    /// Panics if the temporary path is not valid UTF-8. System temp
    /// directories are expected to be UTF-8, so this would indicate a broken
    /// test environment rather than a bug in the code under test.
    fn path_str(file: &NamedTempFile) -> &str {
        file.path()
            .to_str()
            .expect("temporary file path should be valid UTF-8")
    }

    #[test]
    fn test_parse_and_compare_ast_identical() -> Result<(), ExtractionError> {
        // The same source is used as the in-memory output and as the
        // contents of the expected file.
        let content: &str = "fn example() -> i32 { 42 }";
        let expected_file: NamedTempFile = expected_file_with(content)?;

        let result: bool = parse_and_compare_ast(&content.to_string(), path_str(&expected_file))?;

        assert!(result, "The ASTs should be identical");

        Ok(())
    }

    #[test]
    fn test_parse_and_compare_ast_different() -> Result<(), ExtractionError> {
        // The output and expected sources differ only in a literal value,
        // which must still produce a different AST.
        let output_content: &str = "fn example() -> i32 { 42 }";
        let expected_content: &str = "fn example() -> i32 { 43 }";
        let expected_file: NamedTempFile = expected_file_with(expected_content)?;

        let result: bool =
            parse_and_compare_ast(&output_content.to_string(), path_str(&expected_file))?;

        assert!(!result, "The ASTs should be different");

        Ok(())
    }

    #[test]
    fn test_parse_and_compare_ast_file_not_found() -> Result<(), ExtractionError> {
        // Build the missing path inside a fresh, empty temporary directory so
        // it is guaranteed not to exist. A bare relative path would depend on
        // the current working directory and could accidentally exist.
        let dir: tempfile::TempDir = tempfile::tempdir()?;
        let missing_path = dir.path().join("non_existent_file.rs");
        let non_existent_file: &str = missing_path
            .to_str()
            .expect("temporary directory path should be valid UTF-8");

        let result: Result<bool, ExtractionError> =
            parse_and_compare_ast(&"".to_string(), non_existent_file);

        assert!(result.is_err(), "The function should return an error for non-existent files");

        Ok(())
    }

    #[test]
    fn test_parse_and_compare_ast_empty_files() -> Result<(), ExtractionError> {
        // An empty output string compared against an empty expected file.
        let expected_file: NamedTempFile = expected_file_with("")?;

        let result: bool = parse_and_compare_ast(&"".to_string(), path_str(&expected_file))?;

        assert!(result, "The ASTs for empty files should be identical");

        Ok(())
    }

    #[test]
    fn test_parse_and_compare_ast_invalid_content() -> Result<(), ExtractionError> {
        // Syntactically invalid Rust (missing parameter list and closing
        // brace), used for both the output string and the expected file.
        let invalid_content: &str = "fn example { 42 ";
        let expected_file: NamedTempFile = expected_file_with(invalid_content)?;

        let result: Result<bool, ExtractionError> =
            parse_and_compare_ast(&invalid_content.to_string(), path_str(&expected_file));

        assert!(result.is_err(), "The function should return an error for invalid content");

        Ok(())
    }

    #[test]
    fn test_parse_and_compare_ast_different_formatting() -> Result<(), ExtractionError> {
        // Logically identical sources that differ only in whitespace/layout.
        let output_content: &str = "fn example() -> i32 { 42 }";
        let expected_content: &str = "fn example() -> i32 {\n    42\n}";
        let expected_file: NamedTempFile = expected_file_with(expected_content)?;

        let result: bool =
            parse_and_compare_ast(&output_content.to_string(), path_str(&expected_file))?;

        // Formatting is expected not to affect the AST comparison.
        assert!(result, "The ASTs should be identical despite different formatting");

        Ok(())
    }
}