1use crate::cbor_utils::{cbor_map, map_insert};
2use crate::generators::Generator;
3use crate::protocol::{Channel, Connection, SERVER_CRASHED_MESSAGE};
4use crate::runner::Verbosity;
5use ciborium::Value;
6use std::cell::RefCell;
7use std::rc::Rc;
8use std::sync::{Arc, LazyLock};
9
10use crate::generators::value;
11
12#[diagnostic::on_unimplemented(
19 message = "The first parameter in a #[composite] generator must have type TestCase.",
22 label = "This type does not match `TestCase`."
23)]
24pub trait __IsTestCase {}
25impl __IsTestCase for TestCase {}
26pub fn __assert_is_test_case<T: __IsTestCase>() {}
27
/// Signals that the current test case must stop because the server reported
/// StopTest (or an equivalent condition) — see `TestCase::send_request`.
#[derive(Debug)]
pub struct StopTestError;

impl std::error::Error for StopTestError {}

impl std::fmt::Display for StopTestError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("Server ran out of data (StopTest)")
    }
}
37
/// Whether raw protocol requests/responses are traced to stderr.
///
/// Enabled when the `HEGEL_PROTOCOL_DEBUG` environment variable is `1` or `true`
/// (case-insensitive). The variable is read once, on first access; an unset or
/// non-Unicode value means "off".
static PROTOCOL_DEBUG: LazyLock<bool> = LazyLock::new(|| {
    std::env::var("HEGEL_PROTOCOL_DEBUG")
        .map(|raw| {
            let lowered = raw.to_lowercase();
            lowered == "1" || lowered == "true"
        })
        .unwrap_or(false)
});

/// Panic payload raised by `TestCase::assume` when an assumption fails.
pub(crate) const ASSUME_FAIL_STRING: &str = "__HEGEL_ASSUME_FAIL";

/// Panic payload raised when the server signals StopTest for the current test case.
pub(crate) const STOP_TEST_STRING: &str = "__HEGEL_STOP_TEST";
54
55pub(crate) struct TestCaseGlobalData {
56 #[allow(dead_code)]
57 connection: Arc<Connection>,
58 channel: Channel,
59 verbosity: Verbosity,
60 is_last_run: bool,
61 test_aborted: bool,
62}
63
/// Per-handle state of a `TestCase`.
///
/// Unlike the shared global data, each clone or child handle gets its own copy,
/// so draw numbering and span nesting are tracked independently per handle.
#[derive(Clone)]
pub(crate) struct TestCaseLocalData {
    /// Sink for "Draw N: ..." lines: prints to stderr on the last run, no-op otherwise.
    on_draw: Rc<dyn Fn(&str)>,
    /// Number of leading spaces applied to draw and note output.
    indent: usize,
    /// Number of top-level draws recorded so far by this handle.
    draw_count: usize,
    /// Nesting depth of currently open spans; draws are only recorded at depth 0.
    span_depth: usize,
}
71
72pub struct TestCase {
90 global: Rc<RefCell<TestCaseGlobalData>>,
91 local: RefCell<TestCaseLocalData>,
92}
93
94impl Clone for TestCase {
95 fn clone(&self) -> Self {
96 TestCase {
97 global: self.global.clone(),
98 local: RefCell::new(self.local.borrow().clone()),
99 }
100 }
101}
102
103impl std::fmt::Debug for TestCase {
104 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
105 f.debug_struct("TestCase").finish_non_exhaustive()
106 }
107}
108
109impl TestCase {
110 pub(crate) fn new(
111 connection: Arc<Connection>,
112 channel: Channel,
113 verbosity: Verbosity,
114 is_last_run: bool,
115 ) -> Self {
116 let on_draw: Rc<dyn Fn(&str)> = if is_last_run {
117 Rc::new(|msg| eprintln!("{}", msg))
118 } else {
119 Rc::new(|_| {})
120 };
121 TestCase {
122 global: Rc::new(RefCell::new(TestCaseGlobalData {
123 connection,
124 channel,
125 verbosity,
126 is_last_run,
127 test_aborted: false,
128 })),
129 local: RefCell::new(TestCaseLocalData {
130 span_depth: 0,
131 draw_count: 0,
132 indent: 0,
133 on_draw,
134 }),
135 }
136 }
137
138 pub fn draw<T: std::fmt::Debug>(&self, generator: impl Generator<T>) -> T {
152 let value = generator.do_draw(self);
153 if self.local.borrow().span_depth == 0 {
154 self.record_draw(&value);
155 }
156 value
157 }
158
159 pub fn draw_silent<T>(&self, generator: impl Generator<T>) -> T {
164 generator.do_draw(self)
165 }
166
167 pub fn assume(&self, condition: bool) {
181 if !condition {
182 panic!("{}", ASSUME_FAIL_STRING);
183 }
184 }
185
186 pub fn note(&self, message: &str) {
202 if self.global.borrow().is_last_run {
203 let indent = self.local.borrow().indent;
204 eprintln!("{:indent$}{}", "", message, indent = indent);
205 }
206 }
207
208 pub(crate) fn child(&self, extra_indent: usize) -> Self {
209 let local = self.local.borrow();
210 TestCase {
211 global: self.global.clone(),
212 local: RefCell::new(TestCaseLocalData {
213 span_depth: 0,
214 draw_count: 0,
215 indent: local.indent + extra_indent,
216 on_draw: local.on_draw.clone(),
217 }),
218 }
219 }
220
221 fn record_draw<T: std::fmt::Debug>(&self, value: &T) {
222 let mut local = self.local.borrow_mut();
223 local.draw_count += 1;
224 let count = local.draw_count;
225 let indent = local.indent;
226 (local.on_draw)(&format!(
227 "{:indent$}Draw {}: {:?}",
228 "",
229 count,
230 value,
231 indent = indent
232 ));
233 }
234
235 #[doc(hidden)]
236 pub fn start_span(&self, label: u64) {
237 self.local.borrow_mut().span_depth += 1;
238 if let Err(StopTestError) = self.send_request("start_span", &cbor_map! {"label" => label}) {
239 let mut local = self.local.borrow_mut();
240 assert!(local.span_depth > 0);
241 local.span_depth -= 1;
242 drop(local);
243 panic!("{}", STOP_TEST_STRING);
244 }
245 }
246
247 #[doc(hidden)]
248 pub fn stop_span(&self, discard: bool) {
249 {
250 let mut local = self.local.borrow_mut();
251 assert!(local.span_depth > 0);
252 local.span_depth -= 1;
253 }
254 let _ = self.send_request("stop_span", &cbor_map! {"discard" => discard});
255 }
256
257 pub(crate) fn send_request(
259 &self,
260 command: &str,
261 payload: &Value,
262 ) -> Result<Value, StopTestError> {
263 let mut global = self.global.borrow_mut();
264
265 if global.test_aborted {
270 return Err(StopTestError);
271 }
272 let debug = *PROTOCOL_DEBUG || global.verbosity == Verbosity::Debug;
273
274 let mut entries = vec![(
275 Value::Text("command".to_string()),
276 Value::Text(command.to_string()),
277 )];
278
279 if let Value::Map(map) = payload {
280 for (k, v) in map {
281 entries.push((k.clone(), v.clone()));
282 }
283 }
284
285 let request = Value::Map(entries);
286
287 if debug {
288 eprintln!("REQUEST: {:?}", request);
289 }
290
291 let result = global.channel.request_cbor(&request);
292 drop(global);
293
294 match result {
295 Ok(response) => {
296 if debug {
297 eprintln!("RESPONSE: {:?}", response);
298 }
299 Ok(response)
300 }
301 Err(e) => {
302 let error_msg = e.to_string();
303 if error_msg.contains("overflow")
304 || error_msg.contains("StopTest")
305 || error_msg.contains("channel is closed")
306 {
307 if debug {
308 eprintln!("RESPONSE: StopTest/overflow");
309 }
310 let mut global = self.global.borrow_mut();
311 global.channel.mark_closed();
312 global.test_aborted = true;
313 drop(global);
314 Err(StopTestError)
315 } else if error_msg.contains("FlakyStrategyDefinition")
316 || error_msg.contains("FlakyReplay")
317 {
318 let mut global = self.global.borrow_mut();
321 global.channel.mark_closed();
322 global.test_aborted = true;
323 drop(global);
324 Err(StopTestError)
325 } else if self.global.borrow().connection.server_has_exited() {
326 panic!("{}", SERVER_CRASHED_MESSAGE);
327 } else {
328 panic!("Failed to communicate with Hegel: {}", e);
329 }
330 }
331 }
332 }
333
334 pub(crate) fn test_aborted(&self) -> bool {
337 self.global.borrow().test_aborted
338 }
339
340 pub(crate) fn send_mark_complete(&self, mark_complete: &Value) {
341 let mut global = self.global.borrow_mut();
342 let _ = global.channel.request_cbor(mark_complete);
343 let _ = global.channel.close();
344 }
345}
346
347#[doc(hidden)]
349pub fn generate_raw(tc: &TestCase, schema: &Value) -> Value {
350 match tc.send_request("generate", &cbor_map! {"schema" => schema.clone()}) {
351 Ok(v) => v,
352 Err(StopTestError) => {
353 panic!("{}", STOP_TEST_STRING);
354 }
355 }
356}
357
358#[doc(hidden)]
359pub fn generate_from_schema<T: serde::de::DeserializeOwned>(tc: &TestCase, schema: &Value) -> T {
360 deserialize_value(generate_raw(tc, schema))
361}
362
363pub fn deserialize_value<T: serde::de::DeserializeOwned>(raw: Value) -> T {
368 let hv = value::HegelValue::from(raw.clone());
369 value::from_hegel_value(hv).unwrap_or_else(|e| {
370 panic!("Failed to deserialize value: {}\nValue: {:?}", e, raw);
371 })
372}
373
374pub struct Collection<'a> {
379 tc: &'a TestCase,
380 base_name: String,
381 min_size: usize,
382 max_size: Option<usize>,
383 server_name: Option<String>,
384 finished: bool,
385}
386
387impl<'a> Collection<'a> {
388 pub fn new(tc: &'a TestCase, name: &str, min_size: usize, max_size: Option<usize>) -> Self {
390 Collection {
391 tc,
392 base_name: name.to_string(),
393 min_size,
394 max_size,
395 server_name: None,
396 finished: false,
397 }
398 }
399
400 fn ensure_initialized(&mut self) -> &str {
401 if self.server_name.is_none() {
402 let mut payload = cbor_map! {
403 "name" => self.base_name.as_str(),
404 "min_size" => self.min_size as u64
405 };
406 if let Some(max) = self.max_size {
407 map_insert(&mut payload, "max_size", max as u64);
408 }
409 let response = match self.tc.send_request("new_collection", &payload) {
410 Ok(v) => v,
411 Err(StopTestError) => {
412 panic!("{}", STOP_TEST_STRING);
413 }
414 };
415 let name = match response {
416 Value::Text(s) => s,
417 _ => panic!(
418 "Expected text response from new_collection, got {:?}",
419 response
420 ),
421 };
422 self.server_name = Some(name);
423 }
424 self.server_name.as_ref().unwrap()
425 }
426
427 pub fn more(&mut self) -> bool {
429 if self.finished {
430 return false;
431 }
432 let server_name = self.ensure_initialized().to_string();
433 let response = match self.tc.send_request(
434 "collection_more",
435 &cbor_map! { "collection" => server_name.as_str() },
436 ) {
437 Ok(v) => v,
438 Err(StopTestError) => {
439 self.finished = true;
440 panic!("{}", STOP_TEST_STRING);
441 }
442 };
443 let result = match response {
444 Value::Bool(b) => b,
445 _ => panic!("Expected bool from collection_more, got {:?}", response),
446 };
447 if !result {
448 self.finished = true;
449 }
450 result
451 }
452
453 pub fn reject(&mut self, why: Option<&str>) {
455 if self.finished {
456 return;
457 }
458 let server_name = self.ensure_initialized().to_string();
459 let mut payload = cbor_map! {
460 "collection" => server_name.as_str()
461 };
462 if let Some(reason) = why {
463 map_insert(&mut payload, "why", reason.to_string());
464 }
465 let _ = self.tc.send_request("collection_reject", &payload);
466 }
467}
468
/// Numeric span labels passed to `TestCase::start_span`, which forwards them to
/// the server in the `start_span` request.
#[doc(hidden)]
pub mod labels {
    // Sequence-like collections and their elements.
    pub const LIST: u64 = 1;
    pub const LIST_ELEMENT: u64 = 2;
    pub const SET: u64 = 3;
    pub const SET_ELEMENT: u64 = 4;
    pub const MAP: u64 = 5;
    pub const MAP_ENTRY: u64 = 6;

    // Fixed-shape and choice combinators.
    pub const TUPLE: u64 = 7;
    pub const ONE_OF: u64 = 8;
    pub const OPTIONAL: u64 = 9;
    pub const FIXED_DICT: u64 = 10;

    // Derived / transformed generators.
    pub const FLAT_MAP: u64 = 11;
    pub const FILTER: u64 = 12;
    pub const MAPPED: u64 = 13;
    pub const SAMPLED_FROM: u64 = 14;
    pub const ENUM_VARIANT: u64 = 15;
}