use anyhow::{bail, Result};
use tensorlogic_ir::{EinsumGraph, TLExpr};
use crate::compile::compile_expr;
use crate::context::{CompileState, CompilerContext};
/// Compiles an AllDifferent constraint over `variables`.
///
/// Every variable name is wrapped as an argument-less predicate
/// (`TLExpr::pred(name, vec![])`). For each unordered pair `(i, j)` with
/// `i < j` the term `negate(Eq(v_i, v_j))` is produced, the terms are joined
/// left-to-right with `TLExpr::and`, and the resulting expression is handed to
/// [`compile_expr`].
///
/// A single variable has no pairs to compare: in that case a tensor named
/// `"const_1.0"` is registered in `graph` and returned with empty axes
/// (presumably the "always satisfied" constant — confirm against the
/// graph's constant-naming convention).
///
/// # Errors
///
/// Fails when `variables` is empty, and forwards any error returned by
/// [`compile_expr`].
pub(crate) fn compile_all_different(
    variables: &[String],
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    match variables.len() {
        0 => bail!("AllDifferent constraint requires at least one variable"),
        1 => {
            // Nothing to compare against: emit the constant tensor directly.
            let tensor_idx = graph.add_tensor("const_1.0");
            return Ok(CompileState {
                tensor_idx,
                axes: String::new(),
            });
        }
        _ => {}
    }

    // Walk the pairs in the order (0,1), (0,2), ..., (1,2), ... so the
    // left-nested conjunction has the same shape regardless of how it is built.
    let all_pairs_differ = variables
        .iter()
        .enumerate()
        .flat_map(|(pos, lhs_name)| {
            variables[pos + 1..].iter().map(move |rhs_name| {
                let lhs = Box::new(TLExpr::pred(lhs_name, vec![]));
                let rhs = Box::new(TLExpr::pred(rhs_name, vec![]));
                TLExpr::negate(TLExpr::Eq(lhs, rhs))
            })
        })
        .reduce(TLExpr::and)
        .expect("two or more variables always yield at least one pair");

    compile_expr(&all_pairs_differ, ctx, graph)
}
/// Compiles a GlobalCardinality constraint.
///
/// For each entry of `values`, an occurrence count is built as the sum
/// (`TLExpr::Add`, left-nested) of `Eq(variable, value)` terms over all
/// `variables`, where each variable is an argument-less predicate. The count
/// is then bounded:
///
/// * `count >= min_occurrences[i]` is emitted only when the minimum is
///   greater than zero;
/// * `count <= max_occurrences[i]` is emitted only when the maximum is
///   smaller than the number of variables.
///
/// All emitted bounds are joined with `TLExpr::and` and passed to
/// [`compile_expr`]. When no bound is emitted at all, a tensor named
/// `"const_1.0"` is registered in `graph` and returned with empty axes.
///
/// # Errors
///
/// Fails when `variables` is empty, when `values`, `min_occurrences` and
/// `max_occurrences` differ in length, or when some minimum exceeds its
/// maximum. Errors from [`compile_expr`] are forwarded.
pub(crate) fn compile_global_cardinality(
    variables: &[String],
    values: &[TLExpr],
    min_occurrences: &[usize],
    max_occurrences: &[usize],
    ctx: &mut CompilerContext,
    graph: &mut EinsumGraph,
) -> Result<CompileState> {
    if variables.is_empty() {
        bail!("GlobalCardinality constraint requires at least one variable");
    }

    let lengths_agree =
        values.len() == min_occurrences.len() && values.len() == max_occurrences.len();
    if !lengths_agree {
        bail!(
            "GlobalCardinality: values, min_occurrences, and max_occurrences must have same length"
        );
    }

    // Report the first pair of bounds that is inverted.
    let inverted = min_occurrences
        .iter()
        .zip(max_occurrences)
        .position(|(lo, hi)| lo > hi);
    if let Some(at) = inverted {
        bail!(
            "GlobalCardinality: min_occurrences[{}] ({}) > max_occurrences[{}] ({})",
            at,
            min_occurrences[at],
            at,
            max_occurrences[at]
        );
    }

    let var_count = variables.len();
    let bounds = min_occurrences.iter().zip(max_occurrences.iter());

    // One optional constraint per value; values whose bounds are both
    // trivially satisfied contribute nothing.
    let combined = values
        .iter()
        .zip(bounds)
        .filter_map(|(value_expr, (&min, &max))| {
            let count = variables
                .iter()
                .map(|name| {
                    TLExpr::Eq(
                        Box::new(TLExpr::pred(name, vec![])),
                        Box::new(value_expr.clone()),
                    )
                })
                .reduce(|sum, term| TLExpr::Add(Box::new(sum), Box::new(term)))
                .expect("variables was checked to be non-empty");

            let lower = (min > 0).then(|| {
                TLExpr::Gte(
                    Box::new(count.clone()),
                    Box::new(TLExpr::Constant(min as f64)),
                )
            });
            let upper = (max < var_count)
                .then(|| TLExpr::Lte(Box::new(count), Box::new(TLExpr::Constant(max as f64))));

            match (lower, upper) {
                (Some(lo), Some(hi)) => Some(TLExpr::and(lo, hi)),
                // At most one side is present here; keep whichever exists.
                (lo, hi) => lo.or(hi),
            }
        })
        .reduce(TLExpr::and);

    match combined {
        Some(constraint) => compile_expr(&constraint, ctx, graph),
        None => {
            // No value produced a bound: emit the constant tensor directly.
            let tensor_idx = graph.add_tensor("const_1.0");
            Ok(CompileState {
                tensor_idx,
                axes: String::new(),
            })
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an empty compiler context and graph for a single test.
    fn fresh() -> (CompilerContext, EinsumGraph) {
        (CompilerContext::new(), EinsumGraph::new())
    }

    /// Turns string literals into the owned variable names the compilers take.
    fn names(raw: &[&str]) -> Vec<String> {
        raw.iter().map(|name| name.to_string()).collect()
    }

    /// A lone variable compiles to a state with no axes.
    #[test]
    fn test_all_different_single_variable() {
        let (mut ctx, mut graph) = fresh();
        let state = compile_all_different(&names(&["x"]), &mut ctx, &mut graph).unwrap();
        assert!(state.axes.is_empty());
    }

    /// Two variables compile successfully and populate the graph.
    #[test]
    fn test_all_different_two_variables() {
        let (mut ctx, mut graph) = fresh();
        compile_all_different(&names(&["x", "y"]), &mut ctx, &mut graph).unwrap();
        assert!(!graph.tensors.is_empty());
    }

    /// Three variables compile successfully and populate the graph.
    #[test]
    fn test_all_different_three_variables() {
        let (mut ctx, mut graph) = fresh();
        compile_all_different(&names(&["x", "y", "z"]), &mut ctx, &mut graph).unwrap();
        assert!(!graph.tensors.is_empty());
    }

    /// An empty variable list is rejected.
    #[test]
    fn test_all_different_empty_fails() {
        let (mut ctx, mut graph) = fresh();
        let no_variables: Vec<String> = Vec::new();
        let outcome = compile_all_different(&no_variables, &mut ctx, &mut graph);
        assert!(outcome.is_err());
    }

    /// Non-trivial bounds on two values compile and populate the graph.
    #[test]
    fn test_global_cardinality_simple() {
        let (mut ctx, mut graph) = fresh();
        let variables = names(&["x", "y", "z"]);
        let values = [TLExpr::Constant(1.0), TLExpr::Constant(2.0)];
        let min_occurrences = [1, 1];
        let max_occurrences = [2, 2];
        compile_global_cardinality(
            &variables,
            &values,
            &min_occurrences,
            &max_occurrences,
            &mut ctx,
            &mut graph,
        )
        .unwrap();
        assert!(!graph.tensors.is_empty());
    }

    /// Bounds of `0..=len` emit no constraint, so the result has no axes.
    #[test]
    fn test_global_cardinality_no_constraints() {
        let (mut ctx, mut graph) = fresh();
        let variables = names(&["x", "y"]);
        let values = [TLExpr::Constant(1.0)];
        let min_occurrences = [0];
        let max_occurrences = [2];
        let state = compile_global_cardinality(
            &variables,
            &values,
            &min_occurrences,
            &max_occurrences,
            &mut ctx,
            &mut graph,
        )
        .unwrap();
        assert!(state.axes.is_empty());
    }

    /// A minimum above its maximum is rejected.
    #[test]
    fn test_global_cardinality_invalid_bounds() {
        let (mut ctx, mut graph) = fresh();
        let variables = names(&["x"]);
        let values = [TLExpr::Constant(1.0)];
        let min_occurrences = [2];
        let max_occurrences = [1];
        let outcome = compile_global_cardinality(
            &variables,
            &values,
            &min_occurrences,
            &max_occurrences,
            &mut ctx,
            &mut graph,
        );
        assert!(outcome.is_err());
    }

    /// `values` and the bound slices must all have the same length.
    #[test]
    fn test_global_cardinality_mismatched_lengths() {
        let (mut ctx, mut graph) = fresh();
        let variables = names(&["x"]);
        let values = [TLExpr::Constant(1.0), TLExpr::Constant(2.0)];
        let min_occurrences = [1];
        let max_occurrences = [1, 1];
        let outcome = compile_global_cardinality(
            &variables,
            &values,
            &min_occurrences,
            &max_occurrences,
            &mut ctx,
            &mut graph,
        );
        assert!(outcome.is_err());
    }

    /// An empty variable list is rejected.
    #[test]
    fn test_global_cardinality_empty_variables_fails() {
        let (mut ctx, mut graph) = fresh();
        let no_variables: Vec<String> = Vec::new();
        let values = [TLExpr::Constant(1.0)];
        let min_occurrences = [0];
        let max_occurrences = [1];
        let outcome = compile_global_cardinality(
            &no_variables,
            &values,
            &min_occurrences,
            &max_occurrences,
            &mut ctx,
            &mut graph,
        );
        assert!(outcome.is_err());
    }
}