use super::*;
use crate::{random::unique_integers, VariableID, VariableIDPair};
use anyhow::{bail, Result};
use proptest::prelude::*;
use serde::ser::SerializeTuple;
use std::{collections::HashSet, fmt::Debug, hash::Hash};
/// Polynomial built from coefficient-weighted [`LinearMonomial`] terms: an optional constant
/// term plus degree-1 single-variable terms.
pub type Linear = PolynomialBase<LinearMonomial>;
/// Shape parameters used when generating random sets of [`LinearMonomial`]s.
///
/// Constructed through [`LinearParameters::new`], which rejects `num_terms > max_id + 2`
/// (there are only `max_id + 1` variable terms with IDs in `0..=max_id`, plus one constant).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, getset::CopyGetters)]
#[getset(get_copy = "pub")]
pub struct LinearParameters {
    /// Number of distinct monomials (terms) requested.
    num_terms: usize,
    /// Largest variable ID that may appear in a generated term.
    max_id: VariableID,
}
impl LinearParameters {
    /// Upper bound on `num_terms` for a given `max_id`: the `max_id + 1` variable terms with
    /// IDs in `0..=max_id` plus the constant term.
    ///
    /// Returns `None` when that bound does not fit in `u64` (only possible when `max_id` is
    /// within 1 of `u64::MAX`); in that case every `usize` term count is admissible.
    fn max_num_terms(max_id: VariableID) -> Option<u64> {
        Into::<u64>::into(max_id).checked_add(2)
    }

    /// Creates parameters for `num_terms` distinct terms over variable IDs `0..=max_id`.
    ///
    /// # Errors
    /// Returns an error if `num_terms > max_id + 2`, because there are not enough distinct
    /// linear monomials to fill that many terms.
    pub fn new(num_terms: usize, max_id: VariableID) -> Result<Self> {
        // Compare in u64 space: `usize -> u64` is lossless on all supported (<= 64-bit)
        // targets, whereas `max_id as usize + 2` could truncate or overflow.
        if matches!(Self::max_num_terms(max_id), Some(max) if num_terms as u64 > max) {
            bail!("num_terms({num_terms}) cannot be greater than max_id({max_id}) + 2");
        }
        Ok(Self { num_terms, max_id })
    }

    /// Parameters that use every possible term: all variables `0..=max_id` plus the constant.
    ///
    /// # Panics
    /// Panics if `max_id + 2` does not fit in `usize` (such a parameter set could not be
    /// represented).
    pub fn full(max_id: VariableID) -> Self {
        let num_terms = Self::max_num_terms(max_id)
            .and_then(|n| usize::try_from(n).ok())
            .expect("max_id + 2 must fit in usize");
        Self { num_terms, max_id }
    }

    /// Whether `num_terms` equals the maximum `max_id + 2`, i.e. every possible term is used.
    pub fn is_full(&self) -> bool {
        Self::max_num_terms(self.max_id) == Some(self.num_terms as u64)
    }

    /// Whether no terms are requested.
    pub fn is_empty(&self) -> bool {
        self.num_terms == 0
    }
}
impl Arbitrary for LinearParameters {
    type Parameters = ();
    type Strategy = BoxedStrategy<Self>;

    /// Draws `num_terms` from `0..10`, then `max_id` from the smallest admissible value
    /// (`num_terms - 2`, floored at 0) up to `10` inclusive, so every draw satisfies
    /// `num_terms <= max_id + 2`.
    fn arbitrary_with(_: Self::Parameters) -> Self::Strategy {
        let num_terms_strategy = 0..10_usize;
        num_terms_strategy
            .prop_flat_map(|num_terms| {
                // Smallest max_id that still offers `num_terms` distinct monomials.
                let lowest_max_id = (num_terms as u64).saturating_sub(2);
                (lowest_max_id..=10_u64)
                    .prop_map(move |max_id| LinearParameters::new(num_terms, max_id.into()).unwrap())
            })
            .boxed()
    }
}
impl Default for LinearParameters {
fn default() -> Self {
Self {
num_terms: 3,
max_id: 10.into(),
}
}
}
/// A single term of a linear polynomial: either one variable or the constant term.
///
/// The `Default` value is [`LinearMonomial::Constant`].
#[derive(Debug, Clone, Copy, Default)]
#[derive(PartialEq, Eq, Hash)]
pub enum LinearMonomial {
    /// Degree-1 term made of the variable with this ID.
    Variable(VariableID),
    /// Degree-0 (constant) term.
    #[default]
    Constant,
}
impl LinearMonomial {
    /// Iterates over the variable IDs in this monomial: exactly one for
    /// [`LinearMonomial::Variable`], none for [`LinearMonomial::Constant`].
    pub fn iter(&self) -> Box<dyn Iterator<Item = VariableID>> {
        let maybe_id = match *self {
            LinearMonomial::Variable(id) => Some(id),
            LinearMonomial::Constant => None,
        };
        Box::new(maybe_id.into_iter())
    }
}
impl From<VariableID> for LinearMonomial {
fn from(value: VariableID) -> Self {
LinearMonomial::Variable(value)
}
}
impl From<u64> for LinearMonomial {
    /// Interprets a raw `u64` as a variable ID and wraps it as a degree-1 monomial.
    fn from(raw_id: u64) -> Self {
        let id = VariableID::from(raw_id);
        Self::from(id)
    }
}
impl std::ops::Neg for LinearMonomial {
    type Output = Linear;

    /// Negating a monomial yields the single-term polynomial `-1 * self`.
    fn neg(self) -> Linear {
        let minus_one = crate::coeff!(-1.0);
        Linear::single_term(self, minus_one)
    }
}
impl serde::Serialize for LinearMonomial {
    /// Encodes a variable term as a one-element tuple holding its raw ID (JSON: `[42]`) and
    /// the constant term as an empty tuple (JSON: `[]`).
    fn serialize<S>(&self, serializer: S) -> std::result::Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let raw_id = match self {
            LinearMonomial::Variable(id) => Some(id.into_inner()),
            LinearMonomial::Constant => None,
        };
        // Tuple length is 1 when there is an ID to write, otherwise 0.
        let mut tuple = serializer.serialize_tuple(usize::from(raw_id.is_some()))?;
        if let Some(raw_id) = raw_id {
            tuple.serialize_element(&raw_id)?;
        }
        tuple.end()
    }
}
impl<'de> serde::Deserialize<'de> for LinearMonomial {
    /// Accepts a bare variable ID (`42`), a one-element array (`[42]`) for a variable term,
    /// or an empty array (`[]`) for the constant term.
    ///
    /// Bare integers are accepted whether the format reports them as unsigned or as
    /// non-negative signed values. Some self-describing formats (e.g. TOML) only produce `i64`.
    ///
    /// # Errors
    /// Fails on negative integers, arrays with two or more elements, and any other input
    /// shape. Uses `deserialize_any`, so the format must be self-describing.
    fn deserialize<D>(deserializer: D) -> std::result::Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        struct LinearMonomialVisitor;
        impl<'de> serde::de::Visitor<'de> for LinearMonomialVisitor {
            type Value = LinearMonomial;
            fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
                formatter.write_str("a variable ID (u64) or an array of 0 or 1 variable IDs")
            }
            fn visit_u64<E>(self, value: u64) -> std::result::Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                Ok(LinearMonomial::Variable(value.into()))
            }
            fn visit_i64<E>(self, value: i64) -> std::result::Result<Self::Value, E>
            where
                E: serde::de::Error,
            {
                // Formats that surface integers as i64 land here; only non-negative
                // values are valid variable IDs.
                match u64::try_from(value) {
                    Ok(id) => Ok(LinearMonomial::Variable(id.into())),
                    Err(_) => Err(E::invalid_value(
                        serde::de::Unexpected::Signed(value),
                        &self,
                    )),
                }
            }
            fn visit_seq<A>(self, mut seq: A) -> std::result::Result<Self::Value, A::Error>
            where
                A: serde::de::SeqAccess<'de>,
            {
                // Read up to two elements: a second element means the array is too long.
                let first = seq.next_element::<u64>()?;
                let second = seq.next_element::<u64>()?;
                match (first, second) {
                    (Some(id), None) => Ok(LinearMonomial::Variable(id.into())),
                    (None, None) => Ok(LinearMonomial::Constant),
                    _ => Err(serde::de::Error::custom("expected array of length 0 or 1")),
                }
            }
        }
        deserializer.deserialize_any(LinearMonomialVisitor)
    }
}
impl Monomial for LinearMonomial {
    type Parameters = LinearParameters;

    /// `1` for a variable term, `0` for the constant term.
    fn degree(&self) -> Degree {
        if let LinearMonomial::Variable(_) = self {
            1.into()
        } else {
            0.into()
        }
    }

    /// Linear monomials never exceed degree 1.
    fn max_degree() -> Degree {
        1.into()
    }

    /// The variable ID of a variable term; `None` for the constant term.
    fn as_linear(&self) -> Option<VariableID> {
        if let LinearMonomial::Variable(id) = *self {
            Some(id)
        } else {
            None
        }
    }

    /// A linear monomial is never a product of two variables.
    fn as_quadratic(&self) -> Option<VariableIDPair> {
        None
    }

    /// No-op: a linear monomial contains no powers to reduce, so it never changes.
    fn reduce_binary_power(&mut self, _binaries: &VariableIDSet) -> bool {
        false
    }

    /// Zero or one variable ID, mirroring [`Monomial::as_linear`].
    fn ids(&self) -> Box<dyn Iterator<Item = VariableID>> {
        Box::new(self.as_linear().into_iter())
    }

    /// Builds a monomial from at most one ID; two or more IDs cannot form a linear term.
    fn from_ids(mut ids: impl Iterator<Item = VariableID>) -> Option<Self> {
        let first = ids.next();
        let second = ids.next();
        match (first, second) {
            (None, None) => Some(LinearMonomial::Constant),
            (Some(id), None) => Some(LinearMonomial::Variable(id)),
            (_, Some(_)) => None,
        }
    }

    /// Substitutes the value from `state` for a variable term that `state` assigns, returning
    /// the constant monomial together with that value as the multiplier. Otherwise returns
    /// `self` unchanged with multiplier `1.0`.
    fn partial_evaluate(self, state: &State) -> (Self, f64) {
        let assigned = match self {
            LinearMonomial::Variable(id) => state.entries.get(&id.into_inner()).copied(),
            LinearMonomial::Constant => None,
        };
        match assigned {
            Some(value) => (Self::default(), value),
            None => (self, 1.0),
        }
    }

    /// Strategy producing sets of exactly `p.num_terms` distinct monomials with variable IDs
    /// in `0..=p.max_id`, optionally including the constant term.
    fn arbitrary_uniques(p: LinearParameters) -> BoxedStrategy<FnvHashSet<Self>> {
        if p.is_empty() {
            return Just(HashSet::default()).boxed();
        }
        let max_id = p.max_id.into();
        if p.is_full() {
            // Only one set is possible: every variable plus the constant term.
            let everything: FnvHashSet<Self> = (0..=max_id)
                .map(|id| LinearMonomial::Variable(id.into()))
                .chain(std::iter::once(LinearMonomial::Constant))
                .collect();
            return Just(everything).boxed();
        }
        let wanted = p.num_terms;
        bool::arbitrary()
            .prop_flat_map(move |with_constant| {
                if with_constant {
                    // One slot is taken by the constant term.
                    unique_integers(0, max_id, wanted - 1)
                        .prop_map(|ids| {
                            ids.into_iter()
                                .map(|id| LinearMonomial::Variable(id.into()))
                                .chain(std::iter::once(LinearMonomial::Constant))
                                .collect()
                        })
                        .boxed()
                } else {
                    unique_integers(0, max_id, wanted)
                        .prop_map(|ids| {
                            ids.into_iter()
                                .map(|id| LinearMonomial::Variable(id.into()))
                                .collect()
                        })
                        .boxed()
                }
            })
            .boxed()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Serializes `value` to JSON and checks the exact text, then checks it deserializes back
    /// to `value`.
    fn assert_json_roundtrip(value: LinearMonomial, expected_json: &str) {
        let json = serde_json::to_string(&value).unwrap();
        assert_eq!(json, expected_json);
        let back: LinearMonomial = serde_json::from_str(&json).unwrap();
        assert_eq!(back, value);
    }

    #[test]
    fn test_unacceptable_parameter() {
        assert!(LinearParameters::new(5, 2.into()).is_err());
        assert!(LinearParameters::new(5, 3.into()).is_ok());
    }

    #[test]
    fn test_full_and_empty_parameters() {
        let full = LinearParameters::full(3.into());
        assert_eq!(full.num_terms(), 5);
        assert!(full.is_full());
        assert!(!full.is_empty());

        let partial = LinearParameters::new(4, 3.into()).unwrap();
        assert!(!partial.is_full());

        let empty = LinearParameters::new(0, 0.into()).unwrap();
        assert!(empty.is_empty());
        assert!(!empty.is_full());
    }

    proptest! {
        #[test]
        fn test_linear(
            (p, monomials) in LinearParameters::arbitrary()
                .prop_flat_map(|p| {
                    LinearMonomial::arbitrary_uniques(p)
                        .prop_map(move |monomials| (p, monomials))
                }),
        ) {
            prop_assert_eq!(monomials.len(), p.num_terms);
            for monomial in monomials {
                match monomial {
                    LinearMonomial::Variable(id) => {
                        prop_assert!(id <= p.max_id);
                    }
                    LinearMonomial::Constant => {}
                }
            }
        }
    }

    #[test]
    fn test_linear_monomial_serde() {
        assert_json_roundtrip(LinearMonomial::Variable(42.into()), "[42]");
        assert_json_roundtrip(LinearMonomial::Variable(123.into()), "[123]");
        assert_json_roundtrip(LinearMonomial::Constant, "[]");

        // A bare integer is accepted as shorthand for a single variable term.
        let bare: LinearMonomial = serde_json::from_str("42").unwrap();
        assert_eq!(bare, LinearMonomial::Variable(42.into()));
    }

    #[test]
    fn test_linear_monomial_deserialize_invalid() {
        for input in ["[1, 2]", "[1, 2, 3]", "-1", "[-1]", "\"1\""] {
            let result: Result<LinearMonomial, _> = serde_json::from_str(input);
            assert!(result.is_err(), "expected error for {input}");
        }
    }
}