use pounce_common::types::Number;
/// One instruction of an [`FbbtTape`].
///
/// For every variant other than `Const`, `Var` and `Opaque`, the `usize`
/// fields are tape slot indices: they name the result of another op on the
/// same tape (see [`FbbtOp::operand_indices`]). In a well-formed tape each
/// such index refers to an earlier op (see [`FbbtTape::first_invalid_slot`]).
#[derive(Debug, Clone, PartialEq)]
pub enum FbbtOp {
    /// A literal constant value.
    Const(Number),
    /// A leaf referring to a variable by index. This index is not a tape slot
    /// and is not validated against the tape.
    /// NOTE(review): presumably an index into the problem's variables — confirm.
    Var(usize),
    /// Addition of two slots.
    Add(usize, usize),
    /// Subtraction of two slots.
    Sub(usize, usize),
    /// Multiplication of two slots.
    Mul(usize, usize),
    /// Division of two slots.
    Div(usize, usize),
    /// A slot raised to a literal non-negative integer exponent; only the
    /// first field is a slot.
    PowInt(usize, u32),
    /// Negation of a slot.
    Neg(usize),
    /// Square root of a slot.
    Sqrt(usize),
    /// Exponential of a slot.
    Exp(usize),
    /// Natural logarithm of a slot.
    Ln(usize),
    /// Absolute value of a slot.
    Abs(usize),
    /// Sine of a slot.
    Sin(usize),
    /// Cosine of a slot.
    Cos(usize),
    /// A node that exposes no operands.
    /// NOTE(review): presumably an operation bounds cannot be propagated
    /// through — confirm.
    Opaque,
}
impl FbbtOp {
    /// Returns the tape slots this op reads, in operand order.
    ///
    /// - `Const`, `Var` and `Opaque` read no slots (a `Var` index names a
    ///   variable, not a slot).
    /// - Unary ops read one slot. This includes `PowInt`, whose `u32`
    ///   exponent is a literal rather than a slot.
    /// - `Add`, `Sub`, `Mul` and `Div` read two slots.
    pub fn operand_indices(&self) -> ArrayVec2 {
        match self {
            Self::Add(lhs, rhs)
            | Self::Sub(lhs, rhs)
            | Self::Mul(lhs, rhs)
            | Self::Div(lhs, rhs) => ArrayVec2::two(*lhs, *rhs),
            Self::PowInt(base, _exponent) => ArrayVec2::one(*base),
            Self::Neg(arg)
            | Self::Sqrt(arg)
            | Self::Exp(arg)
            | Self::Ln(arg)
            | Self::Abs(arg)
            | Self::Sin(arg)
            | Self::Cos(arg) => ArrayVec2::one(*arg),
            Self::Const(_) | Self::Var(_) | Self::Opaque => ArrayVec2::new(),
        }
    }
}
/// Fixed-capacity, inline list holding at most two `usize` values.
///
/// This is the return type of [`FbbtOp::operand_indices`], so listing an op's
/// operands never allocates. Only the first `len` entries of `data` are
/// meaningful.
#[derive(Debug, Clone, Copy)]
pub struct ArrayVec2 {
    /// Inline storage; entries at positions `len..` are unused.
    data: [usize; 2],
    /// Number of live entries in `data` (0, 1 or 2).
    len: u8,
}

impl PartialEq for ArrayVec2 {
    /// Two lists are equal when their live entries are equal.
    ///
    /// Only the live prefix is compared, so whatever sits in unused storage
    /// never affects equality.
    fn eq(&self, other: &Self) -> bool {
        self.data[..usize::from(self.len)] == other.data[..usize::from(other.len)]
    }
}

impl Eq for ArrayVec2 {}
impl ArrayVec2 {
    /// Creates an empty list.
    pub fn new() -> Self {
        Self { data: [0; 2], len: 0 }
    }

    /// Creates a list containing exactly `a`.
    pub fn one(a: usize) -> Self {
        Self { len: 1, data: [a, 0] }
    }

    /// Creates a list containing `a` followed by `b`.
    pub fn two(a: usize, b: usize) -> Self {
        Self { len: 2, data: [a, b] }
    }

    /// Borrows the live entries as a slice, in the order they were supplied.
    pub fn as_slice(&self) -> &[usize] {
        let live = self.len();
        &self.data[..live]
    }

    /// Returns the number of live entries (at most two).
    pub fn len(&self) -> usize {
        usize::from(self.len)
    }

    /// Returns `true` when the list holds no entries.
    pub fn is_empty(&self) -> bool {
        self.len == 0
    }
}
impl Default for ArrayVec2 {
fn default() -> Self {
Self::new()
}
}
/// A sequence of [`FbbtOp`]s where slot `i` is the result of `ops[i]`.
///
/// In a well-formed tape every op reads only slots at lower indices. Because
/// `ops` is a public field, nothing enforces this at construction; use
/// [`FbbtTape::first_invalid_slot`] to check it.
///
/// `PartialEq` is derived so whole tapes can be compared, not just their
/// `ops` vectors.
#[derive(Debug, Clone, Default, PartialEq)]
pub struct FbbtTape {
    /// The ops, indexed by the slot each one produces.
    pub ops: Vec<FbbtOp>,
}
impl FbbtTape {
    /// Creates a tape with no ops.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns `true` if the tape contains no ops.
    pub fn is_empty(&self) -> bool {
        self.ops.is_empty()
    }

    /// Returns the number of ops, which is also the number of slots.
    pub fn len(&self) -> usize {
        self.ops.len()
    }

    /// Returns the index of the first op that reads a slot at or after its
    /// own position (a self- or forward reference).
    ///
    /// Returns `None` if every op reads only earlier slots. `Var` indices are
    /// not slots, so they are not checked here.
    pub fn first_invalid_slot(&self) -> Option<usize> {
        self.ops.iter().enumerate().position(|(slot, op)| {
            op.operand_indices()
                .as_slice()
                .iter()
                .any(|&src| src >= slot)
        })
    }
}
/// Supplies [`FbbtTape`]s for constraint and objective expressions.
///
/// Both methods default to `None`, so an implementor overrides only the
/// expressions it can describe. `None` means no tape is available for that
/// expression.
pub trait ExpressionProvider {
    /// Returns the tape for constraint `_i`, or `None` if none is available.
    ///
    /// NOTE(review): `_i` is presumably a zero-based constraint index — confirm
    /// against callers.
    fn constraint_expression(&self, _i: usize) -> Option<FbbtTape> {
        None
    }

    /// Returns the tape for the objective, or `None` if none is available.
    fn objective_expression(&self) -> Option<FbbtTape> {
        None
    }
}
/// Unit tests for the tape types, operand listing, validation, and the
/// default `ExpressionProvider` behavior.
#[cfg(test)]
mod tests {
    use super::*;

    fn const_tape(c: Number) -> FbbtTape {
        FbbtTape {
            ops: vec![FbbtOp::Const(c)],
        }
    }

    #[test]
    fn operand_indices_match_op_arity() {
        assert!(FbbtOp::Const(1.0).operand_indices().is_empty());
        assert!(FbbtOp::Var(0).operand_indices().is_empty());
        assert!(FbbtOp::Opaque.operand_indices().is_empty());
        assert_eq!(FbbtOp::Neg(3).operand_indices().as_slice(), &[3]);
        assert_eq!(FbbtOp::PowInt(2, 4).operand_indices().as_slice(), &[2]);
        assert_eq!(FbbtOp::Add(1, 2).operand_indices().as_slice(), &[1, 2]);
    }

    #[test]
    fn operand_indices_cover_every_unary_and_binary_op() {
        let unary = [
            FbbtOp::Neg(5),
            FbbtOp::Sqrt(5),
            FbbtOp::Exp(5),
            FbbtOp::Ln(5),
            FbbtOp::Abs(5),
            FbbtOp::Sin(5),
            FbbtOp::Cos(5),
            FbbtOp::PowInt(5, 2),
        ];
        for op in &unary {
            let idx = op.operand_indices();
            assert_eq!(idx.len(), 1, "{:?}", op);
            assert_eq!(idx.as_slice(), &[5], "{:?}", op);
        }
        let binary = [
            FbbtOp::Add(0, 7),
            FbbtOp::Sub(0, 7),
            FbbtOp::Mul(0, 7),
            FbbtOp::Div(0, 7),
        ];
        for op in &binary {
            let idx = op.operand_indices();
            assert_eq!(idx.len(), 2, "{:?}", op);
            // Operand order must be preserved, not just membership.
            assert_eq!(idx.as_slice(), &[0, 7], "{:?}", op);
        }
    }

    #[test]
    fn validate_well_formed_tape() {
        let tape = FbbtTape {
            ops: vec![
                FbbtOp::Var(0),
                FbbtOp::Var(1),
                FbbtOp::Add(0, 1),
                FbbtOp::Const(2.0),
                FbbtOp::Mul(2, 3),
            ],
        };
        assert_eq!(tape.first_invalid_slot(), None);
    }

    #[test]
    fn validate_empty_tape() {
        let tape = FbbtTape::new();
        assert!(tape.is_empty());
        assert_eq!(tape.len(), 0);
        assert_eq!(tape.first_invalid_slot(), None);
    }

    #[test]
    fn validate_catches_forward_reference() {
        let tape = FbbtTape {
            ops: vec![FbbtOp::Neg(1), FbbtOp::Const(0.0)],
        };
        assert_eq!(tape.first_invalid_slot(), Some(0));
    }

    #[test]
    fn validate_catches_self_reference() {
        let tape = FbbtTape {
            ops: vec![FbbtOp::Neg(0)],
        };
        assert_eq!(tape.first_invalid_slot(), Some(0));
    }

    #[test]
    fn validate_reports_first_offending_slot() {
        // Slot 2 forward-references slot 3 and slot 4 self-references. The
        // earliest bad slot (2, not 0 or 4) must be the one reported.
        let tape = FbbtTape {
            ops: vec![
                FbbtOp::Var(0),
                FbbtOp::Const(1.0),
                FbbtOp::Add(0, 3),
                FbbtOp::Mul(0, 1),
                FbbtOp::Sub(4, 0),
            ],
        };
        assert_eq!(tape.len(), 5);
        assert_eq!(tape.first_invalid_slot(), Some(2));
    }

    #[test]
    fn validate_ignores_var_indices() {
        // A `Var` index names a variable, not a slot, so a large value is fine.
        let tape = FbbtTape {
            ops: vec![FbbtOp::Var(100)],
        };
        assert_eq!(tape.first_invalid_slot(), None);
    }

    #[test]
    fn default_trait_returns_none() {
        struct NoExpr;
        impl ExpressionProvider for NoExpr {}
        let p = NoExpr;
        assert!(p.constraint_expression(0).is_none());
        assert!(p.objective_expression().is_none());
    }

    #[test]
    fn custom_provider_returns_tape() {
        struct Always(Number);
        impl ExpressionProvider for Always {
            fn constraint_expression(&self, _i: usize) -> Option<FbbtTape> {
                Some(const_tape(self.0))
            }
        }
        let p = Always(3.5);
        let t = p.constraint_expression(7).unwrap();
        assert_eq!(t.ops, vec![FbbtOp::Const(3.5)]);
    }
}