use proc_macro2::TokenStream;
use quote::quote;
use crate::codegen::CodeGen;
impl CodeGen {
pub(crate) fn dedup_nonrecursive(&mut self) -> TokenStream {
if self.config.is_datalog_batch() {
quote! { .consolidate() }
} else {
threshold_nonzero()
}
}
pub(crate) fn dedup_recursive(&mut self) -> TokenStream {
if self.config.is_datalog_batch() {
self.features.mark_threshold_total();
quote! {
.threshold_semigroup(move |_, _, old| old.is_none().then_some(SEMIRING_ONE))
}
} else {
threshold_nonzero()
}
}
pub(crate) fn dedup_antijoin(&mut self) -> TokenStream {
if self.config.is_datalog_batch() {
quote! {}
} else {
threshold_nonzero()
}
}
}
/// Emits a `.threshold(...)` call whose closure maps every strictly positive
/// weight to `SEMIRING_ONE` and every other weight to `0`.
///
/// `SEMIRING_ONE` is not defined by this module. It must be in scope wherever
/// the generated tokens are compiled.
fn threshold_nonzero() -> TokenStream {
    // Interpolating a `TokenStream` splices its token trees in directly, so the
    // emitted tokens match writing the closure inline.
    let clamp_weight = quote! { |_, w| if *w > 0 { SEMIRING_ONE } else { 0 } };
    quote! { .threshold(#clamp_weight) }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common::{Config, ExecutionMode};
    use crate::parser::Program;

    /// Builds a `CodeGen` over `Program::default()` whose config differs from
    /// `Config::default()` only in its execution `mode`.
    fn codegen_with_mode(mode: ExecutionMode) -> CodeGen {
        CodeGen::new(
            Config {
                mode,
                ..Config::default()
            },
            Program::default(),
        )
    }

    #[test]
    fn datalog_batch_emits_expected_variants_and_marks_threshold_total() {
        let mut codegen = codegen_with_mode(ExecutionMode::DatalogBatch);

        // Non-recursive dedup consolidates and leaves the feature flag alone.
        let non_rec = codegen.dedup_nonrecursive().to_string();
        let emits_consolidate = non_rec.contains("consolidate");
        assert!(
            emits_consolidate,
            "batch non-recursive must emit consolidate(), got: {non_rec}"
        );
        let threshold_total_before = codegen.features().threshold_total();
        assert!(!threshold_total_before, "threshold_total must start unset");

        // Recursive dedup switches to threshold_semigroup and flags the feature.
        let rec = codegen.dedup_recursive().to_string();
        let emits_semigroup = rec.contains("threshold_semigroup");
        assert!(
            emits_semigroup,
            "batch recursive must emit threshold_semigroup(...), got: {rec}"
        );
        let threshold_total_after = codegen.features().threshold_total();
        assert!(
            threshold_total_after,
            "dedup_recursive under batch must mark threshold_total"
        );

        // Antijoin dedup emits nothing at all in batch mode.
        let anti = codegen.dedup_antijoin().to_string();
        let is_noop = anti.trim().is_empty();
        assert!(is_noop, "batch antijoin dedup is a no-op, got: `{anti}`");
    }

    #[test]
    fn datalog_inc_emits_threshold_uniformly() {
        let mut codegen = codegen_with_mode(ExecutionMode::DatalogInc);

        // Generate all three suffixes up front, in the same order as before.
        let emitted = [
            ("dedup_nonrecursive", codegen.dedup_nonrecursive()),
            ("dedup_recursive", codegen.dedup_recursive()),
            ("dedup_antijoin", codegen.dedup_antijoin()),
        ];

        for (name, stream) in emitted {
            let tokens = stream.to_string();
            let is_plain_threshold = tokens.contains("threshold")
                && !tokens.contains("threshold_semigroup")
                && !tokens.contains("consolidate");
            assert!(
                is_plain_threshold,
                "{name} under incremental mode must emit plain threshold(...), got: {tokens}"
            );
        }
    }
}