cargo_difftests/
test_rerunner_core.rs

1/*
2 *        Copyright (c) 2023-2024 Dinu Blanovschi
3 *
4 *    Licensed under the Apache License, Version 2.0 (the "License");
5 *    you may not use this file except in compliance with the License.
6 *    You may obtain a copy of the License at
7 *
8 *        https://www.apache.org/licenses/LICENSE-2.0
9 *
10 *    Unless required by applicable law or agreed to in writing, software
11 *    distributed under the License is distributed on an "AS IS" BASIS,
12 *    WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 *    See the License for the specific language governing permissions and
14 *    limitations under the License.
15 */
16
17use std::marker::PhantomData;
18
19use cargo_difftests_core::CoreTestDesc;
20
21use crate::{difftest::TestInfo, AnalyzeAllSingleTest, DifftestsResult};
22
/// Progress state of a test rerun, as tracked by
/// `TestRunnerInvocationTestCounts`.
///
/// The value is serialized to JSON and printed on stdout each time the
/// counter writes its state.
#[derive(serde::Serialize, serde::Deserialize)]
pub enum State {
    /// Initial state of a freshly created counter; nothing has been set up.
    None,
    /// The counter has been initialized and tests are being counted.
    Running {
        /// Number of tests counted so far (incremented by `inc`).
        current_test_count: usize,
        /// Total number of tests passed to `initialize_test_counts`.
        total_test_count: usize,
    },
    /// Counting was finished via `test_count_done`.
    Done,
    /// Set by `fail_if_running` when it is called while `Running`.
    Error,
}
33
/// Tracks the progress of a test rerun and prints it on stdout.
///
/// Instances are created by `TestRerunnerInvocation::test_counts`.
pub struct TestRunnerInvocationTestCounts<'invocation> {
    // Current progress state; starts out as `State::None`.
    state: State,
    // Carries the `'invocation` lifetime without storing an actual reference.
    _pd: PhantomData<&'invocation ()>,
}
38
39impl<'invocation> Drop for TestRunnerInvocationTestCounts<'invocation> {
40    fn drop(&mut self) {
41        self.test_count_done().unwrap();
42    }
43}
44
/// Handle for a single started test, returned by
/// `TestRunnerInvocationTestCounts::start_test`.
///
/// It is consumed by either `test_successful` or `test_failed`.
pub struct TestRunnerInvocationTestCountsTestGuard<'invocation, 'counts> {
    // Exclusive borrow of the counter this test reports its outcome to.
    counts: &'counts mut TestRunnerInvocationTestCounts<'invocation>,
    // Name printed in the success / failure marker lines.
    test_name: String,
}
49
50impl<'invocation, 'counts> TestRunnerInvocationTestCountsTestGuard<'invocation, 'counts> {
51    pub fn test_successful(self) -> DifftestsResult<()> {
52        self.counts.inc()?;
53        println!("cargo-difftests-test-successful::{}", self.test_name);
54        Ok(())
55    }
56
57    pub fn test_failed(self) -> DifftestsResult<()> {
58        self.counts.fail_if_running()?;
59        println!("cargo-difftests-test-failed::{}", self.test_name);
60        Ok(())
61    }
62}
63
64impl<'invocation> TestRunnerInvocationTestCounts<'invocation> {
65    pub fn initialize_test_counts(&mut self, total_tests_to_run: usize) -> DifftestsResult<()> {
66        match self.state {
67            State::None => {
68                self.state = State::Running {
69                    current_test_count: 0,
70                    total_test_count: total_tests_to_run,
71                };
72
73                self.write_test_counts()?;
74
75                Ok(())
76            }
77            _ => panic!("test counts already initialized"),
78        }
79    }
80
81    pub fn start_test<'counts>(
82        &'counts mut self,
83        test_name: String,
84    ) -> DifftestsResult<TestRunnerInvocationTestCountsTestGuard<'invocation, 'counts>> {
85        match self.state {
86            State::Running { .. } => {}
87            _ => panic!("test counts not initialized"),
88        }
89
90        println!("cargo-difftests-start-test::{}", test_name);
91
92        Ok(TestRunnerInvocationTestCountsTestGuard {
93            counts: self,
94            test_name,
95        })
96    }
97
98    pub fn inc(&mut self) -> DifftestsResult<()> {
99        match &mut self.state {
100            State::None => {
101                panic!("test counts not initialized");
102            }
103            State::Running {
104                current_test_count,
105                total_test_count,
106            } => {
107                *current_test_count += 1;
108                assert!(*current_test_count <= *total_test_count);
109            }
110            State::Done | State::Error => {
111                panic!("test counts already done");
112            }
113        }
114
115        self.write_test_counts()?;
116
117        Ok(())
118    }
119
120    pub fn test_count_done(&mut self) -> DifftestsResult {
121        match self.state {
122            State::Done => {}
123            State::Running { .. } => {
124                self.state = State::Done;
125                self.write_test_counts()?;
126            }
127            _ => panic!("test counts not initialized"),
128        }
129
130        Ok(())
131    }
132
133    pub fn fail_if_running(&mut self) -> DifftestsResult {
134        match self.state {
135            State::Running { .. } => {
136                self.state = State::Error;
137                self.write_test_counts()?;
138            }
139            _ => {}
140        }
141
142        Ok(())
143    }
144
145    fn write_test_counts(&self) -> DifftestsResult {
146        println!(
147            "cargo-difftests-test-counts::{}",
148            serde_json::to_string(&self.state)?
149        );
150        Ok(())
151    }
152}
153
/// The set of tests handed to a test rerunner.
///
/// The type is (de)serializable; `read_invocation_from_command_line`
/// deserializes it from a JSON file.
#[derive(serde::Serialize, serde::Deserialize)]
pub struct TestRerunnerInvocation {
    // Tests collected by `create_invocation_from`, in iteration order.
    tests: Vec<TestInfo>,
}
158
159impl TestRerunnerInvocation {
160    pub fn create_invocation_from<'a>(
161        iter: impl IntoIterator<Item = &'a AnalyzeAllSingleTest>,
162    ) -> DifftestsResult<Self> {
163        let mut tests = vec![];
164
165        for g in iter {
166            if let Some(difftest) = &g.difftest {
167                tests.push(difftest.test_info()?);
168            } else {
169                // Most likely came from an index.
170                tests.push(g.test_info.clone());
171            }
172        }
173
174        Ok(Self { tests })
175    }
176
177    pub fn is_empty(&self) -> bool {
178        self.tests.is_empty()
179    }
180
181    pub fn tests(&self) -> &[TestInfo] {
182        &self.tests
183    }
184
185    pub fn test_counts(&self) -> TestRunnerInvocationTestCounts {
186        TestRunnerInvocationTestCounts {
187            state: State::None,
188            _pd: PhantomData,
189        }
190    }
191}
192
/// Name of the environment variable whose value
/// `read_invocation_from_command_line` compares against this crate's
/// `CARGO_PKG_VERSION`.
pub const CARGO_DIFFTESTS_VER_NAME: &str = "CARGO_DIFFTESTS_VER";
194
195pub fn read_invocation_from_command_line() -> DifftestsResult<TestRerunnerInvocation> {
196    let v = std::env::var(CARGO_DIFFTESTS_VER_NAME).map_err(|e| {
197        std::io::Error::new(
198            std::io::ErrorKind::InvalidInput,
199            format!("missing env var: {}", e),
200        )
201    })?;
202
203    if v != env!("CARGO_PKG_VERSION") {
204        return Err(std::io::Error::new(
205            std::io::ErrorKind::InvalidInput,
206            format!(
207                "version mismatch: expected {} (our version), got {} (cargo-difftests version)",
208                env!("CARGO_PKG_VERSION"),
209                v
210            ),
211        )
212        .into());
213    }
214
215    let mut args = std::env::args().skip(1);
216
217    let f = args.next().ok_or_else(|| {
218        std::io::Error::new(std::io::ErrorKind::InvalidInput, "missing invocation file")
219    })?;
220
221    let invocation_str = std::fs::read_to_string(f)?;
222    let invocation = serde_json::from_str(&invocation_str)?;
223
224    Ok(invocation)
225}
226
/// Generates the `fn main` of a test rerunner binary.
///
/// The generated `main` reads the invocation via
/// `read_invocation_from_command_line`, returning early with its error if
/// that fails. It then calls the function at the given path with the
/// invocation and returns that function's result wrapped in `Ok`.
#[macro_export]
macro_rules! cargo_difftests_test_rerunner {
    ($impl_fn:path) => {
        fn main() -> $crate::DifftestsResult<impl std::process::Termination> {
            let loaded = $crate::test_rerunner_core::read_invocation_from_command_line()?;

            Ok($impl_fn(loaded))
        }
    };
}
238}