use crate::algebra::mv::Mv;
use crate::algebra::ops;
use crate::algebra::signature::Signature;
use crate::governance::geoit::Geoit;
use crate::governance::governance::Governance;
use crate::governance::phase;
use crate::governance::predicate::Predicate;
use crate::governance::profile::GeneratorProfile;
use crate::governance::reading::ReadingRules;
use std::sync::Arc;
/// A named rule that combines input geoits of given classes into a single geoit of
/// `output_class`.
///
/// Consumed by [`apply_transform`].
#[derive(Clone, Debug)]
pub struct TransformRule {
    /// Rule name; recorded as `rule_name` in the [`ProofTerm::Derived`] of every result.
    pub name: String,
    /// Expected class index of each input, by position. Its length is the number of
    /// inputs `apply_transform` requires.
    pub input_classes: Vec<usize>,
    /// Class index given to the result; also used to index `Governance::geom_classes`
    /// and to look up a matching governance construction.
    pub output_class: usize,
    /// Operation applied to the inputs' multivectors.
    pub operation: TransformOp,
    /// How the result's readings are obtained.
    pub reading_derivation: ReadingDerivation,
}
/// Multivector operation performed by a [`TransformRule`].
///
/// See [`TransformOp::apply`] for how inputs map onto operands and
/// [`TransformOp::arity`] for how many inputs each variant reads.
#[derive(Clone, Debug)]
pub enum TransformOp {
    /// `ops::sandwich(inputs[0], inputs[1], sig)`.
    Sandwich,
    /// `inverse::sandwich_normalized(inputs[0], inputs[1], sig)`, falling back to
    /// `Mv::default()` when that call yields no value.
    SandwichNormalized,
    /// `ops::outer(inputs[0], inputs[1], sig)`.
    Outer,
    /// `ops::geometric(inputs[0], inputs[1], sig)`.
    Geometric,
    /// `ops::regressive(inputs[0], inputs[1], sig)`.
    Regressive,
    /// `ops::left_contract(inputs[0], inputs[1], sig)`.
    LeftContract,
    /// `ops::grade_project(inputs[0], k)` for the carried grade `k`.
    GradeProject(u8),
    /// `ops::dual(inputs[0], sig)`.
    Dual,
    /// `ops::reverse(inputs[0])`.
    Reverse,
}
/// How [`apply_transform`] obtains the readings of a derived geoit.
#[derive(Clone, Debug)]
pub enum ReadingDerivation {
    /// Clone the readings of the input at this index.
    PassThrough(usize),
    /// Derive fresh readings for the rule's output class: probe the first governance
    /// construction registered for that class, and fall back to
    /// `ReadingRules::derive_from_grade_mask` when there is no such construction or
    /// probing returns an error.
    Rederive,
}
/// Provenance record stored in a geoit's `proof` field.
#[derive(Clone, Debug)]
pub enum ProofTerm {
    /// Leaf proof carrying a class index. Not constructed in this module; presumably
    /// produced when a geoit is checked directly against its class (TODO confirm).
    Checked { class_index: usize },
    /// Proof of a geoit produced by [`apply_transform`]: the applied rule's name and
    /// clones of the inputs' proofs, in input order.
    Derived {
        rule_name: String,
        // NOTE(review): `Vec` already provides the heap indirection a recursive type
        // needs, so the inner `Box` is redundant; kept because the field is public.
        input_proofs: Vec<Box<ProofTerm>>,
    },
}
impl TransformOp {
pub fn apply(&self, inputs: &[&Mv], sig: &Signature) -> Mv {
match self {
TransformOp::Sandwich => ops::sandwich(inputs[0], inputs[1], sig),
TransformOp::SandwichNormalized => {
crate::algebra::inverse::sandwich_normalized(inputs[0], inputs[1], sig)
.unwrap_or_default()
}
TransformOp::Outer => ops::outer(inputs[0], inputs[1], sig),
TransformOp::Geometric => ops::geometric(inputs[0], inputs[1], sig),
TransformOp::Regressive => ops::regressive(inputs[0], inputs[1], sig),
TransformOp::LeftContract => ops::left_contract(inputs[0], inputs[1], sig),
TransformOp::GradeProject(k) => ops::grade_project(inputs[0], *k),
TransformOp::Dual => ops::dual(inputs[0], sig),
TransformOp::Reverse => ops::reverse(inputs[0]),
}
}
pub fn arity(&self) -> usize {
match self {
TransformOp::Sandwich
| TransformOp::SandwichNormalized
| TransformOp::Outer
| TransformOp::Geometric
| TransformOp::Regressive
| TransformOp::LeftContract => 2,
TransformOp::GradeProject(_) | TransformOp::Dual | TransformOp::Reverse => 1,
}
}
}
/// Applies `rule` to `inputs`, producing a new [`Geoit`] governed by `gov`.
///
/// The inputs are validated positionally against `rule.input_classes`: each must carry
/// the expected class index and must be satisfied. The rule's operation is then applied
/// to the inputs' multivectors. Phase, generator profile, readings and predicate are
/// computed for `rule.output_class`. The result records a [`ProofTerm::Derived`] built
/// from the rule name and clones of the inputs' proofs.
///
/// # Errors
///
/// * [`TransformError::ArityMismatch`] if `inputs.len()` differs from
///   `rule.input_classes.len()`, if there are fewer inputs than `rule.operation`
///   consumes, or if `rule.reading_derivation` is `PassThrough(k)` with `k` not a valid
///   input index (then `expected` is `k + 1`).
/// * [`TransformError::ClassMismatch`] if an input's class index is not the expected one.
/// * [`TransformError::InputNotSatisfied`] if an input is not satisfied.
///
/// # Panics
///
/// Panics if `rule.output_class` is not a valid index into `gov.geom_classes`.
pub fn apply_transform(
    rule: &TransformRule,
    inputs: &[&Geoit],
    gov: Arc<Governance>,
) -> Result<Geoit, TransformError> {
    if inputs.len() != rule.input_classes.len() {
        return Err(TransformError::ArityMismatch {
            expected: rule.input_classes.len(),
            got: inputs.len(),
        });
    }
    for (i, (&expected_class, geoit)) in rule.input_classes.iter().zip(inputs.iter()).enumerate() {
        let got = geoit.predicate().class_index;
        if got != expected_class {
            return Err(TransformError::ClassMismatch {
                input_index: i,
                expected: expected_class,
                got,
            });
        }
        if !geoit.is_satisfied() {
            return Err(TransformError::InputNotSatisfied { input_index: i });
        }
    }

    // The checks below reject malformed rules that previously panicked. They run after the
    // per-input checks so every error the original code reported is reported unchanged.
    //
    // `TransformOp::apply` indexes its operand slice directly, so a rule declaring fewer
    // inputs than its operation consumes must be rejected here.
    let arity = rule.operation.arity();
    if inputs.len() < arity {
        return Err(TransformError::ArityMismatch {
            expected: arity,
            got: inputs.len(),
        });
    }
    // `PassThrough(k)` copies readings from input `k`, so at least `k + 1` inputs are needed.
    if let ReadingDerivation::PassThrough(k) = &rule.reading_derivation {
        if *k >= inputs.len() {
            return Err(TransformError::ArityMismatch {
                expected: *k + 1,
                got: inputs.len(),
            });
        }
    }

    let input_mvs: Vec<&Mv> = inputs.iter().map(|g| g.mv()).collect();
    let result_mv = rule.operation.apply(&input_mvs, &gov.sig);
    let ph = phase::compute_phase(&result_mv, &gov.sig);
    let profile = GeneratorProfile::compute(&result_mv, &gov.sig);

    let out_class = &gov.geom_classes[rule.output_class];
    let readings = match &rule.reading_derivation {
        ReadingDerivation::PassThrough(k) => inputs[*k].readings.clone(),
        ReadingDerivation::Rederive => {
            // Prefer probing the construction registered for the output class. Fall back to
            // grade-mask derivation when none exists or probing fails (deliberate best effort).
            match gov
                .constructions
                .iter()
                .find(|c| c.class_index == rule.output_class)
            {
                Some(constr) => {
                    ReadingRules::derive_from_probing(constr, &gov.sig, &gov.derived_gens)
                        .unwrap_or_else(|_| {
                            ReadingRules::derive_from_grade_mask(out_class, 0, &gov.sig)
                        })
                }
                None => ReadingRules::derive_from_grade_mask(out_class, 0, &gov.sig),
            }
        }
    };

    // NOTE(review): every equation and inequality flag is set to `true` here rather than
    // re-evaluated on `result_mv`. This appears to trust the rule to preserve the output
    // class's constraints; the trailing `false` is interpreted by `Predicate::new` (confirm).
    let pred = Predicate::new(
        rule.output_class,
        &[],
        &[],
        vec![true; out_class.equations.len()],
        vec![true; out_class.inequalities.len()],
        false,
    );
    let proof = ProofTerm::Derived {
        rule_name: rule.name.clone(),
        input_proofs: inputs.iter().map(|g| Box::new(g.proof.clone())).collect(),
    };

    Ok(Geoit {
        mv: result_mv,
        governance: gov,
        predicate: pred,
        phase: ph,
        readings,
        profile,
        proof,
    })
}
/// Reasons [`apply_transform`] can refuse to apply a rule.
#[derive(Debug, Clone)]
pub enum TransformError {
    /// The rule was given the wrong number of inputs.
    ArityMismatch {
        /// Number of inputs required.
        expected: usize,
        /// Number of inputs supplied.
        got: usize,
    },
    /// The input at `input_index` belongs to a different class than the rule expects.
    ClassMismatch {
        /// Position of the offending input.
        input_index: usize,
        /// Class index the rule expects at that position.
        expected: usize,
        /// Class index the input actually carries.
        got: usize,
    },
    /// The input at `input_index` is not satisfied.
    InputNotSatisfied {
        /// Position of the offending input.
        input_index: usize,
    },
}

impl std::fmt::Display for TransformError {
    /// Renders a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // All payload fields are `usize`, so matching on `*self` copies them out.
        match *self {
            Self::InputNotSatisfied { input_index } => {
                write!(f, "input {} governance not satisfied", input_index)
            }
            Self::ClassMismatch {
                input_index,
                expected,
                got,
            } => {
                let (idx, want, have) = (input_index, expected, got);
                write!(
                    f,
                    "input {} class mismatch: expected class {}, got {}",
                    idx, want, have
                )
            }
            Self::ArityMismatch { expected, got } => {
                let (want, have) = (expected, got);
                write!(
                    f,
                    "transform arity mismatch: expected {} inputs, got {}",
                    want, have
                )
            }
        }
    }
}

impl std::error::Error for TransformError {}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::scalar::{Rat, Scalar};

    /// Every variant reports how many inputs `apply` reads.
    #[test]
    fn transform_op_arity() {
        let binary = [
            TransformOp::Sandwich,
            TransformOp::SandwichNormalized,
            TransformOp::Outer,
            TransformOp::Geometric,
            TransformOp::Regressive,
            TransformOp::LeftContract,
        ];
        for op in binary.iter() {
            assert_eq!(op.arity(), 2, "{:?}", op);
        }
        let unary = [
            TransformOp::GradeProject(1),
            TransformOp::Dual,
            TransformOp::Reverse,
        ];
        for op in unary.iter() {
            assert_eq!(op.arity(), 1, "{:?}", op);
        }
    }

    /// The product of generators 0 and 1 has a unit coefficient on blade 0b11.
    #[test]
    fn transform_op_apply_geometric() {
        let sig = Signature::new(0, 0, 3).unwrap();
        let a = Mv::generator(0);
        let b = Mv::generator(1);
        let result = TransformOp::Geometric.apply(&[&a, &b], &sig);
        assert_eq!(result.coefficient(0b11), Scalar::from(1i64));
    }

    /// Reversing the 0b11 blade flips its sign.
    #[test]
    fn transform_op_apply_reverse() {
        let sig = Signature::new(0, 0, 3).unwrap();
        let a = Mv::from_rat_terms(&[(0b11, Rat::from(1))]);
        let result = TransformOp::Reverse.apply(&[&a], &sig);
        assert_eq!(result.coefficient(0b11), Scalar::from(-1i64));
    }

    /// A binary operation given a single operand must panic rather than fabricate one.
    #[test]
    #[should_panic]
    fn transform_op_apply_binary_with_one_input_panics() {
        let sig = Signature::new(0, 0, 3).unwrap();
        let a = Mv::generator(0);
        let _ = TransformOp::Geometric.apply(&[&a], &sig);
    }
}