use crate::cbor_utils::{cbor_map, map_insert};
use crate::generators::Generator;
use crate::protocol::{Connection, SERVER_CRASHED_MESSAGE, Stream};
use crate::runner::Verbosity;
use ciborium::Value;
use std::cell::RefCell;
use std::collections::{HashMap, HashSet};
use std::rc::Rc;
use std::sync::{Arc, LazyLock};
use crate::generators::value;
/// Marker trait used to check, at compile time, that a type is [`TestCase`].
///
/// In this file the only implementation is for `TestCase`. When
/// `__assert_is_test_case::<T>()` is instantiated with any other `T`, the
/// `on_unimplemented` attribute below replaces the generic trait-bound error
/// with a message about the first parameter of a `#[composite]` generator.
/// NOTE(review): presumably invoked from code generated by `#[composite]` in
/// hegel-macros — confirm.
#[diagnostic::on_unimplemented(
// NOTE: worth checking if edits to this message should also be applied to the similar-but-different
// error message in #[composite] in hegel-macros.
message = "The first parameter in a #[composite] generator must have type TestCase.",
label = "This type does not match `TestCase`."
)]
pub trait __IsTestCase {}
impl __IsTestCase for TestCase {}
/// Compile-time assertion that `T` implements [`__IsTestCase`]; it does
/// nothing at runtime.
pub fn __assert_is_test_case<T: __IsTestCase>() {}
/// Error signalling that the server has ended the current test case and no
/// more data can be drawn.
///
/// It is returned for requests made after the test case was aborted, and when
/// the server reports StopTest/overflow, a closed stream, or a flaky
/// definition/replay.
#[derive(Debug)]
pub struct StopTestError;

impl std::fmt::Display for StopTestError {
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_str("Server ran out of data (StopTest)")
    }
}

impl std::error::Error for StopTestError {}
/// Whether protocol traffic should be logged to stderr.
///
/// This is true when the `HEGEL_PROTOCOL_DEBUG` environment variable is `1` or
/// `true` (case-insensitive). An unset or non-Unicode variable counts as
/// false. The variable is read once, on first access.
static PROTOCOL_DEBUG: LazyLock<bool> = LazyLock::new(|| {
    std::env::var("HEGEL_PROTOCOL_DEBUG")
        .map(|raw| {
            let lowered = raw.to_lowercase();
            lowered == "1" || lowered == "true"
        })
        .unwrap_or(false)
});
/// Panic payload used by [`TestCase::assume`] when its condition is false.
pub(crate) const ASSUME_FAIL_STRING: &str = "__HEGEL_ASSUME_FAIL";
/// Panic payload used when a request fails with [`StopTestError`] (for example
/// in `start_span`, `generate_raw` and `Collection`), i.e. when the server has
/// ended the current test case.
pub(crate) const STOP_TEST_STRING: &str = "__HEGEL_STOP_TEST";
/// State shared by a [`TestCase`] and every handle cloned or derived from it
/// via `child`. It is held behind `Rc<RefCell<_>>` in `TestCase::global`.
pub(crate) struct TestCaseGlobalData {
// NOTE(review): `connection` is read in `TestCase::send_request`
// (`server_has_exited`), so this allow may be stale — confirm before removing.
#[allow(dead_code)]
connection: Arc<Connection>,
/// Stream on which requests for this test case are sent.
stream: Stream,
/// `Verbosity::Debug` turns on request/response logging in `send_request`.
verbosity: Verbosity,
/// Whether this is the last run; draws and notes are only printed then.
is_last_run: bool,
/// Set when the server ends the test case; later requests fail immediately.
test_aborted: bool,
/// How many times each draw name has been recorded so far.
named_draw_counts: HashMap<String, usize>,
/// The `repeatable` flag each draw name was first recorded with.
named_draw_repeatable: HashMap<String, bool>,
/// Display names already handed out in `let name = value;` output.
allocated_display_names: HashSet<String>,
}
/// Per-handle state. Each `TestCase` handle owns its own copy, unlike
/// [`TestCaseGlobalData`], which is shared.
#[derive(Clone)]
pub(crate) struct TestCaseLocalData {
/// Number of currently open spans; named draws are recorded only at depth 0.
span_depth: usize,
/// Number of spaces prefixed to printed draws and notes.
indent: usize,
/// Sink that receives each `let name = value;` line produced by a named draw.
on_draw: Rc<dyn Fn(&str)>,
}
/// Handle to a single test case, through which generators draw data.
///
/// Clones and `child` handles share the same global state (connection, stream,
/// draw bookkeeping). Span depth, indentation and the draw sink are per handle.
pub struct TestCase {
/// State shared with every clone/child of this handle.
global: Rc<RefCell<TestCaseGlobalData>>,
/// State private to this handle.
local: RefCell<TestCaseLocalData>,
}
impl Clone for TestCase {
fn clone(&self) -> Self {
TestCase {
global: self.global.clone(),
local: RefCell::new(self.local.borrow().clone()),
}
}
}
impl std::fmt::Debug for TestCase {
    /// Formats as `TestCase { .. }`; no internal state is printed.
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let mut repr = formatter.debug_struct("TestCase");
        repr.finish_non_exhaustive()
    }
}
impl TestCase {
    /// Creates the root test case for one run.
    ///
    /// On the last run (`is_last_run == true`) named draws are echoed to stderr
    /// as `let name = value;` lines. On earlier runs they are recorded but not
    /// printed.
    pub(crate) fn new(
        connection: Arc<Connection>,
        stream: Stream,
        verbosity: Verbosity,
        is_last_run: bool,
    ) -> Self {
        let on_draw: Rc<dyn Fn(&str)> = if is_last_run {
            Rc::new(|msg| eprintln!("{}", msg))
        } else {
            Rc::new(|_| {})
        };
        TestCase {
            global: Rc::new(RefCell::new(TestCaseGlobalData {
                connection,
                stream,
                verbosity,
                is_last_run,
                test_aborted: false,
                named_draw_counts: HashMap::new(),
                named_draw_repeatable: HashMap::new(),
                allocated_display_names: HashSet::new(),
            })),
            local: RefCell::new(TestCaseLocalData {
                span_depth: 0,
                indent: 0,
                on_draw,
            }),
        }
    }

    /// Draws a value from `generator`.
    ///
    /// Draws made while no span is open are recorded under the repeatable name
    /// `draw`. They are displayed as `draw_1`, `draw_2`, … (skipping suffixes
    /// already taken).
    pub fn draw<T: std::fmt::Debug>(&self, generator: impl Generator<T>) -> T {
        self.__draw_named(generator, "draw", true)
    }

    /// Draws a value from `generator` and, if no span is open on this handle,
    /// records it under `name`.
    ///
    /// A repeatable name may be recorded many times and receives a numeric
    /// suffix. A non-repeatable name may be recorded at most once and is shown
    /// verbatim.
    ///
    /// # Panics
    ///
    /// Panics if `name` was first recorded with a different `repeatable` flag,
    /// or if a non-repeatable `name` is recorded more than once.
    pub fn __draw_named<T: std::fmt::Debug>(
        &self,
        generator: impl Generator<T>,
        name: &str,
        repeatable: bool,
    ) -> T {
        let value = generator.do_draw(self);
        // Values drawn while a span is open are not recorded individually.
        if self.local.borrow().span_depth == 0 {
            self.record_named_draw(&value, name, repeatable);
        }
        value
    }

    /// Draws a value from `generator` without recording or printing it.
    pub fn draw_silent<T>(&self, generator: impl Generator<T>) -> T {
        generator.do_draw(self)
    }

    /// Abandons the current test case unless `condition` holds.
    ///
    /// # Panics
    ///
    /// Unwinds with the `ASSUME_FAIL_STRING` payload when `condition` is false.
    pub fn assume(&self, condition: bool) {
        if !condition {
            panic!("{}", ASSUME_FAIL_STRING);
        }
    }

    /// Prints `message` to stderr, indented to this handle's indentation, but
    /// only on the last run.
    pub fn note(&self, message: &str) {
        if self.global.borrow().is_last_run {
            let indent = self.local.borrow().indent;
            eprintln!("{:indent$}{}", "", message, indent = indent);
        }
    }

    /// Creates a handle that shares this test case's global state.
    ///
    /// The new handle has no open spans, the same draw sink, and `extra_indent`
    /// more spaces of indentation.
    pub(crate) fn child(&self, extra_indent: usize) -> Self {
        let local = self.local.borrow();
        TestCase {
            global: self.global.clone(),
            local: RefCell::new(TestCaseLocalData {
                span_depth: 0,
                indent: local.indent + extra_indent,
                on_draw: local.on_draw.clone(),
            }),
        }
    }

    /// Records a draw of `value` under `name` and sends a
    /// `let <display_name> = <value:?>;` line to this handle's draw sink.
    ///
    /// Repeatable names get a `_N` suffix. `N` starts at the name's draw count
    /// and is bumped past any display name already allocated. Non-repeatable
    /// names are used as-is.
    fn record_named_draw<T: std::fmt::Debug>(&self, value: &T, name: &str, repeatable: bool) {
        let mut global = self.global.borrow_mut();

        // The first recording of a name fixes its `repeatable` flag. Insert only
        // then, rather than re-allocating the key and overwriting the entry on
        // every draw.
        match global.named_draw_repeatable.get(name) {
            Some(&prev) if prev != repeatable => {
                panic!(
                    "__draw_named: name {:?} used with inconsistent repeatable flag (was {}, now {}). \
                     If you have not called __draw_named deliberately yourself, this is likely a bug in \
                     hegel. Please file a bug report at https://github.com/hegeldev/hegel-rust/issues",
                    name, prev, repeatable
                );
            }
            Some(_) => {}
            None => {
                global
                    .named_draw_repeatable
                    .insert(name.to_string(), repeatable);
            }
        }

        let count = global
            .named_draw_counts
            .entry(name.to_string())
            .or_insert(0);
        *count += 1;
        let current_count = *count;
        if !repeatable && current_count > 1 {
            panic!(
                "__draw_named: name {:?} used more than once but repeatable is false. \
                 This is almost certainly a bug in hegel - please report it at https://github.com/hegeldev/hegel-rust/issues",
                name
            );
        }

        let display_name = if repeatable {
            let mut suffix = current_count;
            loop {
                let candidate = format!("{}_{}", name, suffix);
                if global.allocated_display_names.insert(candidate.clone()) {
                    break candidate;
                }
                suffix += 1;
            }
        } else {
            let plain = name.to_string();
            global.allocated_display_names.insert(plain.clone());
            plain
        };
        // Release the shared state before formatting and emitting the line.
        drop(global);

        let local = self.local.borrow();
        let indent = local.indent;
        (local.on_draw)(&format!(
            "{:indent$}let {} = {:?};",
            "",
            display_name,
            value,
            indent = indent
        ));
    }

    /// Opens a span labelled `label` on the server and increments this
    /// handle's span depth.
    ///
    /// # Panics
    ///
    /// Unwinds with `STOP_TEST_STRING` if the request fails with
    /// [`StopTestError`]. The span depth is restored before unwinding.
    #[doc(hidden)]
    pub fn start_span(&self, label: u64) {
        self.local.borrow_mut().span_depth += 1;
        if let Err(StopTestError) = self.send_request("start_span", &cbor_map! {"label" => label}) {
            let mut local = self.local.borrow_mut();
            assert!(local.span_depth > 0);
            local.span_depth -= 1;
            drop(local);
            panic!("{}", STOP_TEST_STRING);
        }
    }

    /// Closes the innermost open span, passing `discard` to the server.
    ///
    /// A [`StopTestError`] from the request is ignored.
    ///
    /// # Panics
    ///
    /// Panics if no span is open on this handle.
    #[doc(hidden)]
    pub fn stop_span(&self, discard: bool) {
        {
            let mut local = self.local.borrow_mut();
            assert!(local.span_depth > 0);
            local.span_depth -= 1;
        }
        let _ = self.send_request("stop_span", &cbor_map! {"discard" => discard});
    }

    /// Sends `command` together with the entries of `payload` and returns the
    /// server's response.
    ///
    /// `payload` is expected to be a CBOR map. A non-map payload contributes no
    /// entries. Requests and responses are logged to stderr when
    /// `HEGEL_PROTOCOL_DEBUG` is set or verbosity is `Debug`.
    ///
    /// # Errors
    ///
    /// Returns [`StopTestError`] when the test case was already aborted, or
    /// when the error mentions overflow/StopTest/a closed stream or a
    /// FlakyStrategyDefinition/FlakyReplay. In those latter cases the test case
    /// is marked aborted.
    ///
    /// # Panics
    ///
    /// Panics with `SERVER_CRASHED_MESSAGE` when the connection reports that
    /// the server has exited. Panics with a generic message on any other
    /// communication error.
    pub(crate) fn send_request(
        &self,
        command: &str,
        payload: &Value,
    ) -> Result<Value, StopTestError> {
        let mut global = self.global.borrow_mut();
        if global.test_aborted {
            return Err(StopTestError);
        }
        let debug = *PROTOCOL_DEBUG || global.verbosity == Verbosity::Debug;

        let mut entries = vec![(
            Value::Text("command".to_string()),
            Value::Text(command.to_string()),
        )];
        if let Value::Map(map) = payload {
            entries.extend(map.iter().cloned());
        }
        let request = Value::Map(entries);
        if debug {
            eprintln!("REQUEST: {:?}", request);
        }
        let result = global.stream.request_cbor(&request);
        // Every error path below re-borrows `self.global`.
        drop(global);

        let error = match result {
            Ok(response) => {
                if debug {
                    eprintln!("RESPONSE: {:?}", response);
                }
                return Ok(response);
            }
            Err(error) => error,
        };

        let error_msg = error.to_string();
        let is_stop_test = ["overflow", "StopTest", "stream is closed"]
            .iter()
            .any(|marker| error_msg.contains(*marker));
        if is_stop_test {
            if debug {
                eprintln!("RESPONSE: StopTest/overflow");
            }
            self.mark_aborted();
            Err(StopTestError)
        } else if ["FlakyStrategyDefinition", "FlakyReplay"]
            .iter()
            .any(|marker| error_msg.contains(*marker))
        {
            self.mark_aborted();
            Err(StopTestError)
        } else if self.global.borrow().connection.server_has_exited() {
            panic!("{}", SERVER_CRASHED_MESSAGE);
        } else {
            panic!("Failed to communicate with Hegel: {}", error);
        }
    }

    /// Closes the stream and flags the test case as aborted, so every later
    /// `send_request` returns [`StopTestError`] without contacting the server.
    fn mark_aborted(&self) {
        let mut global = self.global.borrow_mut();
        global.stream.mark_closed();
        global.test_aborted = true;
    }

    /// Whether a previous request aborted this test case.
    pub(crate) fn test_aborted(&self) -> bool {
        self.global.borrow().test_aborted
    }

    /// Sends `mark_complete` as a request and then closes the stream.
    ///
    /// Errors from both steps are ignored. Unlike `send_request`, this does not
    /// check whether the test case was aborted.
    pub(crate) fn send_mark_complete(&self, mark_complete: &Value) {
        let mut global = self.global.borrow_mut();
        let _ = global.stream.request_cbor(mark_complete);
        let _ = global.stream.close();
    }
}
/// Asks the server for a value matching `schema` and returns it as raw CBOR.
///
/// # Panics
///
/// Unwinds with `STOP_TEST_STRING` if the request fails with [`StopTestError`].
#[doc(hidden)]
pub fn generate_raw(tc: &TestCase, schema: &Value) -> Value {
    let request = cbor_map! {"schema" => schema.clone()};
    tc.send_request("generate", &request)
        .unwrap_or_else(|_| panic!("{}", STOP_TEST_STRING))
}
/// Generates a value matching `schema` and deserializes it into `T`.
///
/// # Panics
///
/// Panics as [`generate_raw`] and [`deserialize_value`] do.
#[doc(hidden)]
pub fn generate_from_schema<T: serde::de::DeserializeOwned>(tc: &TestCase, schema: &Value) -> T {
    let raw = generate_raw(tc, schema);
    deserialize_value::<T>(raw)
}
/// Converts a raw CBOR value into `T` via `HegelValue`.
///
/// # Panics
///
/// Panics with the deserialization error and the offending raw value if the
/// conversion fails.
pub fn deserialize_value<T: serde::de::DeserializeOwned>(raw: Value) -> T {
    // `raw` is kept so it can be shown in the panic message on failure.
    let converted = value::HegelValue::from(raw.clone());
    match value::from_hegel_value(converted) {
        Ok(decoded) => decoded,
        Err(err) => panic!("Failed to deserialize value: {}\nValue: {:?}", err, raw),
    }
}
/// Client-side handle to a collection created on the server. Whether another
/// element should be produced is asked one step at a time via
/// [`Collection::more`].
pub struct Collection<'a> {
/// Test case through which requests are sent.
tc: &'a TestCase,
/// Minimum size, sent in the `new_collection` request.
min_size: usize,
/// Optional maximum size, sent in the `new_collection` request when present.
max_size: Option<usize>,
/// Server-assigned id; `None` until the collection is lazily created.
collection_id: Option<i64>,
/// Set once `more` returned false or hit StopTest; then `more` returns false
/// and `reject` does nothing.
finished: bool,
}
impl<'a> Collection<'a> {
pub fn new(tc: &'a TestCase, min_size: usize, max_size: Option<usize>) -> Self {
Collection {
tc,
min_size,
max_size,
collection_id: None,
finished: false,
}
}
fn ensure_initialized(&mut self) -> i64 {
if self.collection_id.is_none() {
let mut payload = cbor_map! {
"min_size" => self.min_size as u64
};
if let Some(max) = self.max_size {
map_insert(&mut payload, "max_size", max as u64); }
let response = match self.tc.send_request("new_collection", &payload) {
Ok(v) => v,
Err(StopTestError) => {
panic!("{}", STOP_TEST_STRING); }
};
let id = match response {
Value::Integer(i) => {
let n: i128 = i.into();
n as i64
}
_ => panic!(
"Expected integer response from new_collection, got {:?}",
response
),
};
self.collection_id = Some(id);
}
self.collection_id.unwrap()
}
pub fn more(&mut self) -> bool {
if self.finished {
return false; }
let collection_id = self.ensure_initialized();
let response = match self.tc.send_request(
"collection_more",
&cbor_map! { "collection_id" => collection_id },
) {
Ok(v) => v,
Err(StopTestError) => {
self.finished = true;
panic!("{}", STOP_TEST_STRING);
}
};
let result = match response {
Value::Bool(b) => b,
_ => panic!("Expected bool from collection_more, got {:?}", response), };
if !result {
self.finished = true;
}
result
}
pub fn reject(&mut self, why: Option<&str>) {
if self.finished {
return;
}
let collection_id = self.ensure_initialized();
let mut payload = cbor_map! {
"collection_id" => collection_id
};
if let Some(reason) = why {
map_insert(&mut payload, "why", reason.to_string());
}
let _ = self.tc.send_request("collection_reject", &payload); }
}
/// Numeric labels identifying kinds of generator spans.
///
/// NOTE(review): presumably passed as the `label` argument of
/// `TestCase::start_span`, which sends it to the server. If so, these values
/// are part of the wire protocol and should not be renumbered without
/// checking the server side — confirm.
#[doc(hidden)]
pub mod labels {
pub const LIST: u64 = 1;
pub const LIST_ELEMENT: u64 = 2;
pub const SET: u64 = 3;
pub const SET_ELEMENT: u64 = 4;
pub const MAP: u64 = 5;
pub const MAP_ENTRY: u64 = 6;
pub const TUPLE: u64 = 7;
pub const ONE_OF: u64 = 8;
pub const OPTIONAL: u64 = 9;
pub const FIXED_DICT: u64 = 10;
pub const FLAT_MAP: u64 = 11;
pub const FILTER: u64 = 12;
pub const MAPPED: u64 = 13;
pub const SAMPLED_FROM: u64 = 14;
pub const ENUM_VARIANT: u64 = 15;
}