use std::sync::atomic::{AtomicBool, Ordering};
use std::sync::{Arc, Mutex};
use ciborium::Value;
use crate::backend::{DataSource, DataSourceError, TestCaseResult};
use crate::native::core::{ChoiceNode, ManyState, NativeTestCase, Span, StopTest};
use crate::native::schema;
/// State shared between a `NativeDataSource` and whoever holds the matching
/// `NativeTestCaseHandle`: the test case itself plus its recorded outcome.
pub struct NativeTestCaseInner {
/// The test case that every data-source operation reads from and mutates.
pub ntc: NativeTestCase,
/// Result stored by `mark_complete`. Starts as `None`, and becomes `None`
/// again once `NativeDataSource::take_outcome` has removed it.
pub outcome: Option<TestCaseResult>,
}
/// Reference-counted, mutex-guarded handle to a `NativeTestCaseInner`.
/// `NativeDataSource::new` hands one out alongside the data source so the
/// caller can inspect the same state the data source mutates.
pub type NativeTestCaseHandle = Arc<Mutex<NativeTestCaseInner>>;
/// `DataSource` implementation that forwards each operation to a
/// `NativeTestCase` held behind a shared, mutex-guarded handle.
pub struct NativeDataSource {
// Shared state; a second handle to the same state is returned by `new`.
inner: NativeTestCaseHandle,
// Set the first time an operation routed through `with_ntc` reports
// `StopTest`; once set, later operations fail immediately without touching
// the test case.
aborted: AtomicBool,
}
impl NativeDataSource {
pub fn new(ntc: NativeTestCase) -> (Self, NativeTestCaseHandle) {
let inner = Arc::new(Mutex::new(NativeTestCaseInner { ntc, outcome: None }));
let handle = Arc::clone(&inner);
(
NativeDataSource {
inner,
aborted: AtomicBool::new(false),
},
handle,
)
}
pub fn take_nodes(handle: &NativeTestCaseHandle) -> Vec<ChoiceNode> {
handle
.lock()
.unwrap_or_else(|e| e.into_inner())
.ntc
.nodes
.clone()
}
pub fn take_spans(handle: &NativeTestCaseHandle) -> Vec<Span> {
handle
.lock()
.unwrap_or_else(|e| e.into_inner())
.ntc
.spans
.clone()
.into_vec()
}
pub fn take_outcome(handle: &NativeTestCaseHandle) -> TestCaseResult {
handle
.lock()
.unwrap_or_else(|e| e.into_inner())
.outcome
.take()
.expect("mark_complete must be called for every test case")
}
#[cfg(test)]
pub(crate) fn test_aborted(&self) -> bool {
self.aborted.load(Ordering::Relaxed)
}
fn with_ntc<R>(
&self,
f: impl FnOnce(&mut NativeTestCase) -> Result<R, StopTest>,
) -> Result<R, DataSourceError> {
if self.aborted.load(Ordering::Relaxed) {
return Err(DataSourceError::StopTest);
}
let mut inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
f(&mut inner.ntc).map_err(|_stop| {
self.aborted.store(true, Ordering::Relaxed);
DataSourceError::StopTest
})
}
}
impl DataSource for NativeDataSource {
fn generate(&self, schema: &Value) -> Result<Value, DataSourceError> {
self.with_ntc(|ntc| schema::interpret_schema(ntc, schema))
}
fn start_span(&self, label: u64) -> Result<(), DataSourceError> {
self.with_ntc(|ntc| {
ntc.start_span(label);
Ok(())
})
}
fn stop_span(&self, discard: bool) -> Result<(), DataSourceError> {
self.with_ntc(|ntc| {
ntc.stop_span(discard);
Ok(())
})
}
fn new_collection(&self, min_size: u64, max_size: Option<u64>) -> Result<i64, DataSourceError> {
self.with_ntc(|ntc| {
let state = ManyState::new(min_size as usize, max_size.map(|n| n as usize));
Ok(ntc.new_collection(state))
})
}
fn collection_more(&self, collection_id: i64) -> Result<bool, DataSourceError> {
self.with_ntc(|ntc| {
let mut state = ntc
.collections
.remove(&collection_id)
.expect("collection_more: unknown collection_id");
let result = schema::many_more(ntc, &mut state);
ntc.collections.insert(collection_id, state);
result
})
}
fn collection_reject(
&self,
collection_id: i64,
_why: Option<&str>,
) -> Result<(), DataSourceError> {
self.with_ntc(|ntc| {
let mut state = ntc
.collections
.remove(&collection_id)
.expect("collection_reject: unknown collection_id");
let result = schema::many_reject(ntc, &mut state);
ntc.collections.insert(collection_id, state);
result
})
}
fn new_pool(&self) -> Result<i128, DataSourceError> {
self.with_ntc(|ntc| {
let pool_id = ntc.variable_pools.len() as i128;
ntc.variable_pools
.push(crate::native::core::NativeVariables::new());
Ok(pool_id)
})
}
fn pool_add(&self, pool_id: i128) -> Result<i128, DataSourceError> {
self.with_ntc(|ntc| Ok(ntc.variable_pools[pool_id as usize].next()))
}
fn pool_generate(&self, pool_id: i128, consume: bool) -> Result<i128, DataSourceError> {
self.with_ntc(|ntc| {
let pool_idx = pool_id as usize;
let active = ntc.variable_pools[pool_idx].active();
if active.is_empty() {
return Err(StopTest);
}
let n = active.len() as i128;
let k = ntc.draw_integer(0, n - 1)?;
let variable_id = active[(n - 1 - k) as usize];
if consume {
ntc.variable_pools[pool_idx].consume(variable_id);
}
Ok(variable_id)
})
}
fn target_observation(&self, _score: f64, _label: &str) {
todo!(
"tc.target() is not yet supported by the native backend; \
Phase::Target will land in a follow-up PR"
);
}
fn mark_complete(&self, result: &TestCaseResult) {
let mut inner = self.inner.lock().unwrap_or_else(|e| e.into_inner());
inner.outcome = Some(result.clone());
}
}
// Unit tests for this module live in a separate file (see the `#[path]`
// attribute below) and are compiled only for test builds.
#[cfg(test)]
#[path = "../../tests/embedded/native/data_source_tests.rs"]
mod tests;