// hegel/test_case.rs

1use crate::cbor_utils::{cbor_map, map_insert};
2use crate::generators::Generator;
3use crate::protocol::{Channel, Connection, SERVER_CRASHED_MESSAGE};
4use crate::runner::Verbosity;
5use ciborium::Value;
6use std::cell::RefCell;
7use std::rc::Rc;
8use std::sync::{Arc, LazyLock};
9
10use crate::generators::value;
11
12// We use the __IsTestCase trait internally to provide nice error messages for misuses of #[composite].
13// It should not be used by users.
14//
15// The idea is #[composite] calls __assert_is_test_case(<first param>), which errors with our on_unimplemented
16// message iff the first param does not have type TestCase.
17
18#[diagnostic::on_unimplemented(
19    // NOTE: worth checking if edits to this message should also be applied to the similar-but-different
20    // error message in #[composite] in hegel-macros.
21    message = "The first parameter in a #[composite] generator must have type TestCase.",
22    label = "This type does not match `TestCase`."
23)]
24pub trait __IsTestCase {}
25impl __IsTestCase for TestCase {}
26pub fn __assert_is_test_case<T: __IsTestCase>() {}
27
/// Signals that the server has stopped the current test case.
///
/// This is most commonly because it ran out of data (StopTest / overflow). It is also used
/// when the server aborts the case for a flaky strategy or replay.
#[derive(Debug)]
pub struct StopTestError;

impl std::error::Error for StopTestError {}

impl std::fmt::Display for StopTestError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("Server ran out of data (StopTest)")
    }
}
37
/// Whether every protocol request and response is echoed to stderr.
///
/// Enabled when `HEGEL_PROTOCOL_DEBUG` is `1` or `true` (case-insensitive). An unset variable,
/// a non-Unicode value, or any other value leaves it off. The variable is read once, on first
/// access.
static PROTOCOL_DEBUG: LazyLock<bool> = LazyLock::new(|| {
    std::env::var("HEGEL_PROTOCOL_DEBUG").is_ok_and(|raw| {
        let normalized = raw.to_lowercase();
        normalized == "1" || normalized == "true"
    })
});
47
/// Panic payload raised by `TestCase::assume` when its condition is false.
pub(crate) const ASSUME_FAIL_STRING: &str = "__HEGEL_ASSUME_FAIL";

/// Panic payload raised when the server stops the test case (overflow / StopTest).
///
/// It is kept distinct from `ASSUME_FAIL_STRING` so callers can tell a user-initiated assumption
/// failure apart from the server running out of data.
pub(crate) const STOP_TEST_STRING: &str = "__HEGEL_STOP_TEST";
54
55pub(crate) struct TestCaseGlobalData {
56    #[allow(dead_code)]
57    connection: Arc<Connection>,
58    channel: Channel,
59    verbosity: Verbosity,
60    is_last_run: bool,
61    test_aborted: bool,
62}
63
/// Per-handle state: each `TestCase` (clones and children included) owns its own copy.
#[derive(Clone)]
pub(crate) struct TestCaseLocalData {
    /// Nesting depth of currently open spans; draws are only recorded at depth 0.
    span_depth: usize,
    /// Number of top-level draws recorded so far; numbers the `Draw N` lines.
    draw_count: usize,
    /// Column offset applied to draw and note output.
    indent: usize,
    /// Sink for formatted draw lines (stderr on the last run, a no-op otherwise).
    on_draw: Rc<dyn Fn(&str)>,
}
71
72/// A handle to the current test case.
73///
74/// This is passed to `#[hegel::test]` functions and provides methods
75/// for drawing values, making assumptions, and recording notes.
76///
77/// # Example
78///
79/// ```no_run
80/// use hegel::generators as gs;
81///
82/// #[hegel::test]
83/// fn my_test(tc: hegel::TestCase) {
84///     let x: i32 = tc.draw(gs::integers());
85///     tc.assume(x > 0);
86///     tc.note(&format!("x = {}", x));
87/// }
88/// ```
89pub struct TestCase {
90    global: Rc<RefCell<TestCaseGlobalData>>,
91    local: RefCell<TestCaseLocalData>,
92}
93
94impl Clone for TestCase {
95    fn clone(&self) -> Self {
96        TestCase {
97            global: self.global.clone(),
98            local: RefCell::new(self.local.borrow().clone()),
99        }
100    }
101}
102
103impl std::fmt::Debug for TestCase {
104    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105        f.debug_struct("TestCase").finish_non_exhaustive()
106    }
107}
108
109impl TestCase {
110    pub(crate) fn new(
111        connection: Arc<Connection>,
112        channel: Channel,
113        verbosity: Verbosity,
114        is_last_run: bool,
115    ) -> Self {
116        let on_draw: Rc<dyn Fn(&str)> = if is_last_run {
117            Rc::new(|msg| eprintln!("{}", msg))
118        } else {
119            Rc::new(|_| {})
120        };
121        TestCase {
122            global: Rc::new(RefCell::new(TestCaseGlobalData {
123                connection,
124                channel,
125                verbosity,
126                is_last_run,
127                test_aborted: false,
128            })),
129            local: RefCell::new(TestCaseLocalData {
130                span_depth: 0,
131                draw_count: 0,
132                indent: 0,
133                on_draw,
134            }),
135        }
136    }
137
138    /// Draw a value from a generator.
139    ///
140    /// # Example
141    ///
142    /// ```no_run
143    /// use hegel::generators as gs;
144    ///
145    /// #[hegel::test]
146    /// fn my_test(tc: hegel::TestCase) {
147    ///     let x: i32 = tc.draw(gs::integers());
148    ///     let s: String = tc.draw(gs::text());
149    /// }
150    /// ```
151    pub fn draw<T: std::fmt::Debug>(&self, generator: impl Generator<T>) -> T {
152        let value = generator.do_draw(self);
153        if self.local.borrow().span_depth == 0 {
154            self.record_draw(&value);
155        }
156        value
157    }
158
159    /// Draw a value from a generator without recording it in the output.
160    ///
161    /// Unlike [`draw`](Self::draw), this does not require `T: Debug` and
162    /// will not print the value in the failing-test summary.
163    pub fn draw_silent<T>(&self, generator: impl Generator<T>) -> T {
164        generator.do_draw(self)
165    }
166
167    /// Assume a condition is true. If false, reject the current test input.
168    ///
169    /// # Example
170    ///
171    /// ```no_run
172    /// use hegel::generators as gs;
173    ///
174    /// #[hegel::test]
175    /// fn my_test(tc: hegel::TestCase) {
176    ///     let age: u32 = tc.draw(gs::integers());
177    ///     tc.assume(age >= 18);
178    /// }
179    /// ```
180    pub fn assume(&self, condition: bool) {
181        if !condition {
182            panic!("{}", ASSUME_FAIL_STRING);
183        }
184    }
185
186    /// Note a message which will be displayed with the reported failing test case.
187    ///
188    /// Only prints during the final replay of a failing test case.
189    ///
190    /// # Example
191    ///
192    /// ```no_run
193    /// use hegel::generators as gs;
194    ///
195    /// #[hegel::test]
196    /// fn my_test(tc: hegel::TestCase) {
197    ///     let x: i32 = tc.draw(gs::integers());
198    ///     tc.note(&format!("Generated x = {}", x));
199    /// }
200    /// ```
201    pub fn note(&self, message: &str) {
202        if self.global.borrow().is_last_run {
203            let indent = self.local.borrow().indent;
204            eprintln!("{:indent$}{}", "", message, indent = indent);
205        }
206    }
207
208    pub(crate) fn child(&self, extra_indent: usize) -> Self {
209        let local = self.local.borrow();
210        TestCase {
211            global: self.global.clone(),
212            local: RefCell::new(TestCaseLocalData {
213                span_depth: 0,
214                draw_count: 0,
215                indent: local.indent + extra_indent,
216                on_draw: local.on_draw.clone(),
217            }),
218        }
219    }
220
221    fn record_draw<T: std::fmt::Debug>(&self, value: &T) {
222        let mut local = self.local.borrow_mut();
223        local.draw_count += 1;
224        let count = local.draw_count;
225        let indent = local.indent;
226        (local.on_draw)(&format!(
227            "{:indent$}Draw {}: {:?}",
228            "",
229            count,
230            value,
231            indent = indent
232        ));
233    }
234
235    #[doc(hidden)]
236    pub fn start_span(&self, label: u64) {
237        self.local.borrow_mut().span_depth += 1;
238        if let Err(StopTestError) = self.send_request("start_span", &cbor_map! {"label" => label}) {
239            let mut local = self.local.borrow_mut();
240            assert!(local.span_depth > 0);
241            local.span_depth -= 1;
242            drop(local);
243            panic!("{}", STOP_TEST_STRING);
244        }
245    }
246
247    #[doc(hidden)]
248    pub fn stop_span(&self, discard: bool) {
249        {
250            let mut local = self.local.borrow_mut();
251            assert!(local.span_depth > 0);
252            local.span_depth -= 1;
253        }
254        let _ = self.send_request("stop_span", &cbor_map! {"discard" => discard});
255    }
256
257    /// Returns Err(StopTestError) if the server sends an overflow error.
258    pub(crate) fn send_request(
259        &self,
260        command: &str,
261        payload: &Value,
262    ) -> Result<Value, StopTestError> {
263        let mut global = self.global.borrow_mut();
264
265        // If a previous request already triggered overflow/StopTest, the server
266        // has closed this channel. Don't send another request—it would block.
267        // (The channel-level closed check is also enforced, but this gives a
268        // clean StopTestError instead of an io::Error.)
269        if global.test_aborted {
270            return Err(StopTestError);
271        }
272        let debug = *PROTOCOL_DEBUG || global.verbosity == Verbosity::Debug;
273
274        let mut entries = vec![(
275            Value::Text("command".to_string()),
276            Value::Text(command.to_string()),
277        )];
278
279        if let Value::Map(map) = payload {
280            for (k, v) in map {
281                entries.push((k.clone(), v.clone()));
282            }
283        }
284
285        let request = Value::Map(entries);
286
287        if debug {
288            eprintln!("REQUEST: {:?}", request);
289        }
290
291        let result = global.channel.request_cbor(&request);
292        drop(global);
293
294        match result {
295            Ok(response) => {
296                if debug {
297                    eprintln!("RESPONSE: {:?}", response);
298                }
299                Ok(response)
300            }
301            Err(e) => {
302                let error_msg = e.to_string();
303                if error_msg.contains("overflow")
304                    || error_msg.contains("StopTest")
305                    || error_msg.contains("channel is closed")
306                {
307                    if debug {
308                        eprintln!("RESPONSE: StopTest/overflow");
309                    }
310                    let mut global = self.global.borrow_mut();
311                    global.channel.mark_closed();
312                    global.test_aborted = true;
313                    drop(global);
314                    Err(StopTestError)
315                } else if error_msg.contains("FlakyStrategyDefinition")
316                    || error_msg.contains("FlakyReplay")
317                {
318                    // Abort the test case; the server will report the flaky
319                    // error in the test_done results, which runner.rs handles.
320                    let mut global = self.global.borrow_mut();
321                    global.channel.mark_closed();
322                    global.test_aborted = true;
323                    drop(global);
324                    Err(StopTestError)
325                } else if self.global.borrow().connection.server_has_exited() {
326                    panic!("{}", SERVER_CRASHED_MESSAGE);
327                } else {
328                    panic!("Failed to communicate with Hegel: {}", e);
329                }
330            }
331        }
332    }
333
334    // --- Methods for runner access ---
335
336    pub(crate) fn test_aborted(&self) -> bool {
337        self.global.borrow().test_aborted
338    }
339
340    pub(crate) fn send_mark_complete(&self, mark_complete: &Value) {
341        let mut global = self.global.borrow_mut();
342        let _ = global.channel.request_cbor(mark_complete);
343        let _ = global.channel.close();
344    }
345}
346
347/// Send a schema to the server and return the raw CBOR response.
348#[doc(hidden)]
349pub fn generate_raw(tc: &TestCase, schema: &Value) -> Value {
350    match tc.send_request("generate", &cbor_map! {"schema" => schema.clone()}) {
351        Ok(v) => v,
352        Err(StopTestError) => {
353            panic!("{}", STOP_TEST_STRING);
354        }
355    }
356}
357
358#[doc(hidden)]
359pub fn generate_from_schema<T: serde::de::DeserializeOwned>(tc: &TestCase, schema: &Value) -> T {
360    deserialize_value(generate_raw(tc, schema))
361}
362
363/// Deserialize a raw CBOR value into a Rust type.
364///
365/// This is a public helper for use by derived generators (proc macros)
366/// that need to deserialize individual field values from CBOR.
367pub fn deserialize_value<T: serde::de::DeserializeOwned>(raw: Value) -> T {
368    let hv = value::HegelValue::from(raw.clone());
369    value::from_hegel_value(hv).unwrap_or_else(|e| {
370        panic!("Failed to deserialize value: {}\nValue: {:?}", e, raw);
371    })
372}
373
374/// Uses the hegel server to determine collection sizing.
375///
376/// The server-side `many` object is created lazily on the first call to
377/// [`more()`](Collection::more).
378pub struct Collection<'a> {
379    tc: &'a TestCase,
380    base_name: String,
381    min_size: usize,
382    max_size: Option<usize>,
383    server_name: Option<String>,
384    finished: bool,
385}
386
387impl<'a> Collection<'a> {
388    /// Create a new server-managed collection.
389    pub fn new(tc: &'a TestCase, name: &str, min_size: usize, max_size: Option<usize>) -> Self {
390        Collection {
391            tc,
392            base_name: name.to_string(),
393            min_size,
394            max_size,
395            server_name: None,
396            finished: false,
397        }
398    }
399
400    fn ensure_initialized(&mut self) -> &str {
401        if self.server_name.is_none() {
402            let mut payload = cbor_map! {
403                "name" => self.base_name.as_str(),
404                "min_size" => self.min_size as u64
405            };
406            if let Some(max) = self.max_size {
407                map_insert(&mut payload, "max_size", max as u64);
408            }
409            let response = match self.tc.send_request("new_collection", &payload) {
410                Ok(v) => v,
411                Err(StopTestError) => {
412                    panic!("{}", STOP_TEST_STRING);
413                }
414            };
415            let name = match response {
416                Value::Text(s) => s,
417                _ => panic!(
418                    "Expected text response from new_collection, got {:?}",
419                    response
420                ),
421            };
422            self.server_name = Some(name);
423        }
424        self.server_name.as_ref().unwrap()
425    }
426
427    /// Ask the server whether to produce another element.
428    pub fn more(&mut self) -> bool {
429        if self.finished {
430            return false;
431        }
432        let server_name = self.ensure_initialized().to_string();
433        let response = match self.tc.send_request(
434            "collection_more",
435            &cbor_map! { "collection" => server_name.as_str() },
436        ) {
437            Ok(v) => v,
438            Err(StopTestError) => {
439                self.finished = true;
440                panic!("{}", STOP_TEST_STRING);
441            }
442        };
443        let result = match response {
444            Value::Bool(b) => b,
445            _ => panic!("Expected bool from collection_more, got {:?}", response),
446        };
447        if !result {
448            self.finished = true;
449        }
450        result
451    }
452
453    /// Reject the last element (don't count it towards the size budget).
454    pub fn reject(&mut self, why: Option<&str>) {
455        if self.finished {
456            return;
457        }
458        let server_name = self.ensure_initialized().to_string();
459        let mut payload = cbor_map! {
460            "collection" => server_name.as_str()
461        };
462        if let Some(reason) = why {
463            map_insert(&mut payload, "why", reason.to_string());
464        }
465        let _ = self.tc.send_request("collection_reject", &payload);
466    }
467}
468
/// Numeric labels identifying generator kinds.
///
/// These appear to be the values passed as `label` to `TestCase::start_span`; the call sites
/// are not shown here.
#[doc(hidden)]
pub mod labels {
    // Collections and their elements.
    pub const LIST: u64 = 1;
    pub const LIST_ELEMENT: u64 = 2;
    pub const SET: u64 = 3;
    pub const SET_ELEMENT: u64 = 4;
    pub const MAP: u64 = 5;
    pub const MAP_ENTRY: u64 = 6;

    // Structural combinators.
    pub const TUPLE: u64 = 7;
    pub const ONE_OF: u64 = 8;
    pub const OPTIONAL: u64 = 9;
    pub const FIXED_DICT: u64 = 10;

    // Transforming combinators.
    pub const FLAT_MAP: u64 = 11;
    pub const FILTER: u64 = 12;
    pub const MAPPED: u64 = 13;

    // Choice among fixed alternatives.
    pub const SAMPLED_FROM: u64 = 14;
    pub const ENUM_VARIANT: u64 = 15;
}