use nalgebra::{DMatrix, DVector};
use crate::context::error::OxiflowError;
/// A dynamically-typed value stored in a context.
///
/// Each variant wraps a plain `f64`/`bool` or an nalgebra container of `f64`.
/// Use the `as_*` accessors to borrow the payload with a type check, and the
/// `is_*` predicates to test the variant.
///
/// Marked `#[non_exhaustive]`: code in other crates must include a wildcard
/// arm when matching on it, so new variants can be added later.
///
/// NOTE(review): with the `serde` feature enabled, the derived impls encode
/// the variant; binary formats that use the variant *index* make the
/// declaration order below part of the wire format — do not reorder.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq)]
#[cfg_attr(feature = "serde", derive(serde::Serialize, serde::Deserialize))]
pub enum ContextValue {
    /// A single floating-point number.
    Scalar(f64),
    /// A boolean flag.
    Boolean(bool),
    /// A dynamically-sized column vector of `f64`.
    Vector(DVector<f64>),
    /// A dynamically-sized `f64` matrix.
    Matrix(DMatrix<f64>),
    /// A field of scalars stored as a `DVector<f64>`.
    ///
    /// Presumably one entry per spatial point/cell — TODO confirm with callers.
    ScalarField(DVector<f64>),
    /// A field of vectors stored as a `DMatrix<f64>`.
    ///
    /// NOTE(review): the row/column convention (points × components or the
    /// reverse) is not established here — confirm with callers.
    VectorField(DMatrix<f64>),
}
impl ContextValue {
    /// Returns the identifier of this variant: `"Scalar"`, `"Boolean"`,
    /// `"Vector"`, `"Matrix"`, `"ScalarField"` or `"VectorField"`.
    ///
    /// The `as_*` accessors report this string as the `actual` field of
    /// [`OxiflowError::TypeMismatch`] when called on the wrong variant.
    #[must_use]
    pub fn variant_name(&self) -> &'static str {
        // No wildcard arm on purpose: `#[non_exhaustive]` only restricts
        // *other* crates, so inside this crate the match is checked for
        // exhaustiveness and adding a variant becomes a compile error here
        // instead of silently reporting a placeholder name.
        match self {
            Self::Scalar(_) => "Scalar",
            Self::Boolean(_) => "Boolean",
            Self::Vector(_) => "Vector",
            Self::Matrix(_) => "Matrix",
            Self::ScalarField(_) => "ScalarField",
            Self::VectorField(_) => "VectorField",
        }
    }

    /// Builds the error an `as_*` accessor returns when `self` is not the
    /// `expected` variant.
    fn type_mismatch(&self, expected: &'static str) -> OxiflowError {
        OxiflowError::TypeMismatch {
            expected,
            actual: self.variant_name(),
        }
    }

    /// Returns the wrapped number if this is [`ContextValue::Scalar`].
    ///
    /// # Errors
    ///
    /// Returns [`OxiflowError::TypeMismatch`] with `expected: "Scalar"` for
    /// any other variant.
    pub fn as_scalar(&self) -> Result<f64, OxiflowError> {
        match self {
            Self::Scalar(v) => Ok(*v),
            other => Err(other.type_mismatch("Scalar")),
        }
    }

    /// Returns the wrapped flag if this is [`ContextValue::Boolean`].
    ///
    /// # Errors
    ///
    /// Returns [`OxiflowError::TypeMismatch`] with `expected: "Boolean"` for
    /// any other variant.
    pub fn as_bool(&self) -> Result<bool, OxiflowError> {
        match self {
            Self::Boolean(v) => Ok(*v),
            other => Err(other.type_mismatch("Boolean")),
        }
    }

    /// Borrows the wrapped vector if this is [`ContextValue::Vector`].
    ///
    /// # Errors
    ///
    /// Returns [`OxiflowError::TypeMismatch`] with `expected: "Vector"` for
    /// any other variant (including [`ContextValue::ScalarField`], which has
    /// the same payload type).
    pub fn as_vector(&self) -> Result<&DVector<f64>, OxiflowError> {
        match self {
            Self::Vector(v) => Ok(v),
            other => Err(other.type_mismatch("Vector")),
        }
    }

    /// Borrows the wrapped matrix if this is [`ContextValue::Matrix`].
    ///
    /// # Errors
    ///
    /// Returns [`OxiflowError::TypeMismatch`] with `expected: "Matrix"` for
    /// any other variant (including [`ContextValue::VectorField`], which has
    /// the same payload type).
    pub fn as_matrix(&self) -> Result<&DMatrix<f64>, OxiflowError> {
        match self {
            Self::Matrix(m) => Ok(m),
            other => Err(other.type_mismatch("Matrix")),
        }
    }

    /// Borrows the wrapped vector if this is [`ContextValue::ScalarField`].
    ///
    /// # Errors
    ///
    /// Returns [`OxiflowError::TypeMismatch`] with `expected: "ScalarField"`
    /// for any other variant.
    pub fn as_scalar_field(&self) -> Result<&DVector<f64>, OxiflowError> {
        match self {
            Self::ScalarField(v) => Ok(v),
            other => Err(other.type_mismatch("ScalarField")),
        }
    }

    /// Borrows the wrapped matrix if this is [`ContextValue::VectorField`].
    ///
    /// # Errors
    ///
    /// Returns [`OxiflowError::TypeMismatch`] with `expected: "VectorField"`
    /// for any other variant.
    pub fn as_vector_field(&self) -> Result<&DMatrix<f64>, OxiflowError> {
        match self {
            Self::VectorField(m) => Ok(m),
            other => Err(other.type_mismatch("VectorField")),
        }
    }

    /// Returns `true` if this is [`ContextValue::Scalar`].
    #[must_use]
    pub fn is_scalar(&self) -> bool {
        matches!(self, Self::Scalar(_))
    }

    /// Returns `true` if this is [`ContextValue::Boolean`].
    #[must_use]
    pub fn is_bool(&self) -> bool {
        matches!(self, Self::Boolean(_))
    }

    /// Returns `true` if this is [`ContextValue::Vector`].
    #[must_use]
    pub fn is_vector(&self) -> bool {
        matches!(self, Self::Vector(_))
    }

    /// Returns `true` if this is [`ContextValue::Matrix`].
    #[must_use]
    pub fn is_matrix(&self) -> bool {
        matches!(self, Self::Matrix(_))
    }

    /// Returns `true` if this is [`ContextValue::ScalarField`].
    #[must_use]
    pub fn is_scalar_field(&self) -> bool {
        matches!(self, Self::ScalarField(_))
    }

    /// Returns `true` if this is [`ContextValue::VectorField`].
    #[must_use]
    pub fn is_vector_field(&self) -> bool {
        matches!(self, Self::VectorField(_))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use nalgebra::{DMatrix, DVector};

    /// Payload of the `Scalar` fixture. Deliberately not 3.14: clippy's
    /// deny-by-default `approx_constant` lint rejects literals approximating
    /// `std::f64::consts::PI`, which broke `cargo clippy --tests`.
    const SCALAR_VALUE: f64 = 2.5;

    fn scalar() -> ContextValue {
        ContextValue::Scalar(SCALAR_VALUE)
    }

    fn boolean() -> ContextValue {
        ContextValue::Boolean(true)
    }

    fn vector() -> ContextValue {
        ContextValue::Vector(DVector::from_vec(vec![1.0, 2.0, 3.0]))
    }

    fn matrix() -> ContextValue {
        ContextValue::Matrix(DMatrix::from_element(2, 2, 1.0))
    }

    fn scalar_field() -> ContextValue {
        ContextValue::ScalarField(DVector::from_vec(vec![0.1, 0.2]))
    }

    fn vector_field() -> ContextValue {
        ContextValue::VectorField(DMatrix::from_element(3, 2, 0.5))
    }

    /// One fixture per variant, in declaration order.
    fn all_variants() -> Vec<ContextValue> {
        vec![
            scalar(),
            boolean(),
            vector(),
            matrix(),
            scalar_field(),
            vector_field(),
        ]
    }

    #[test]
    fn variant_names_are_correct() {
        assert_eq!(scalar().variant_name(), "Scalar");
        assert_eq!(boolean().variant_name(), "Boolean");
        assert_eq!(vector().variant_name(), "Vector");
        assert_eq!(matrix().variant_name(), "Matrix");
        assert_eq!(scalar_field().variant_name(), "ScalarField");
        assert_eq!(vector_field().variant_name(), "VectorField");
    }

    #[test]
    fn as_scalar_on_scalar_returns_value() {
        assert_eq!(scalar().as_scalar().unwrap(), SCALAR_VALUE);
    }

    #[test]
    fn as_scalar_on_wrong_variant_returns_type_mismatch() {
        for v in [
            boolean(),
            vector(),
            matrix(),
            scalar_field(),
            vector_field(),
        ] {
            let err = v.as_scalar().unwrap_err();
            assert!(
                matches!(
                    err,
                    OxiflowError::TypeMismatch {
                        expected: "Scalar",
                        ..
                    }
                ),
                "expected TypeMismatch for {:?}",
                v
            );
        }
    }

    #[test]
    fn type_mismatch_reports_actual_variant_name() {
        // The `actual` field must name the variant that was really stored.
        for v in all_variants().into_iter().filter(|v| !v.is_scalar()) {
            let err = v.as_scalar().unwrap_err();
            assert!(
                matches!(
                    err,
                    OxiflowError::TypeMismatch { actual, .. } if actual == v.variant_name()
                ),
                "wrong `actual` in TypeMismatch for {:?}",
                v
            );
        }
    }

    #[test]
    fn as_bool_on_boolean_returns_value() {
        assert!(boolean().as_bool().unwrap());
        assert!(!ContextValue::Boolean(false).as_bool().unwrap());
    }

    #[test]
    fn as_bool_on_wrong_variant_returns_type_mismatch() {
        for v in [scalar(), vector(), matrix(), scalar_field(), vector_field()] {
            let err = v.as_bool().unwrap_err();
            assert!(matches!(
                err,
                OxiflowError::TypeMismatch {
                    expected: "Boolean",
                    ..
                }
            ));
        }
    }

    #[test]
    fn as_vector_on_vector_returns_reference() {
        let v = vector();
        let inner = v.as_vector().unwrap();
        assert_eq!(inner.len(), 3);
        assert_eq!(inner[0], 1.0);
    }

    #[test]
    fn as_vector_on_wrong_variant_returns_type_mismatch() {
        for v in [
            scalar(),
            boolean(),
            matrix(),
            scalar_field(),
            vector_field(),
        ] {
            let err = v.as_vector().unwrap_err();
            assert!(matches!(
                err,
                OxiflowError::TypeMismatch {
                    expected: "Vector",
                    ..
                }
            ));
        }
    }

    #[test]
    fn as_matrix_on_matrix_returns_reference() {
        let m = matrix();
        let inner = m.as_matrix().unwrap();
        assert_eq!(inner.shape(), (2, 2));
    }

    #[test]
    fn as_matrix_on_wrong_variant_returns_type_mismatch() {
        for v in [
            scalar(),
            boolean(),
            vector(),
            scalar_field(),
            vector_field(),
        ] {
            let err = v.as_matrix().unwrap_err();
            assert!(matches!(
                err,
                OxiflowError::TypeMismatch {
                    expected: "Matrix",
                    ..
                }
            ));
        }
    }

    #[test]
    fn as_scalar_field_on_scalar_field_returns_reference() {
        let sf = scalar_field();
        let inner = sf.as_scalar_field().unwrap();
        assert_eq!(inner.len(), 2);
        assert!((inner[0] - 0.1).abs() < 1e-12);
    }

    #[test]
    fn as_scalar_field_on_wrong_variant_returns_type_mismatch() {
        for v in [scalar(), boolean(), vector(), matrix(), vector_field()] {
            let err = v.as_scalar_field().unwrap_err();
            assert!(matches!(
                err,
                OxiflowError::TypeMismatch {
                    expected: "ScalarField",
                    ..
                }
            ));
        }
    }

    #[test]
    fn as_vector_field_on_vector_field_returns_reference() {
        let vf = vector_field();
        let inner = vf.as_vector_field().unwrap();
        assert_eq!(inner.shape(), (3, 2));
    }

    #[test]
    fn as_vector_field_on_wrong_variant_returns_type_mismatch() {
        for v in [scalar(), boolean(), vector(), matrix(), scalar_field()] {
            let err = v.as_vector_field().unwrap_err();
            assert!(matches!(
                err,
                OxiflowError::TypeMismatch {
                    expected: "VectorField",
                    ..
                }
            ));
        }
    }

    #[test]
    fn is_predicates_return_true_for_correct_variant() {
        assert!(scalar().is_scalar());
        assert!(boolean().is_bool());
        assert!(vector().is_vector());
        assert!(matrix().is_matrix());
        assert!(scalar_field().is_scalar_field());
        assert!(vector_field().is_vector_field());
    }

    #[test]
    fn exactly_one_is_predicate_holds_per_variant() {
        // Guards against predicates overlapping, e.g. `is_vector` also
        // matching `ScalarField`, which shares its payload type.
        for v in all_variants() {
            let hits = [
                v.is_scalar(),
                v.is_bool(),
                v.is_vector(),
                v.is_matrix(),
                v.is_scalar_field(),
                v.is_vector_field(),
            ]
            .iter()
            .filter(|&&hit| hit)
            .count();
            assert_eq!(hits, 1, "expected exactly one predicate for {:?}", v);
        }
    }

    #[test]
    fn is_scalar_returns_false_for_other_variants() {
        for v in [
            boolean(),
            vector(),
            matrix(),
            scalar_field(),
            vector_field(),
        ] {
            assert!(!v.is_scalar(), "expected false for {:?}", v);
        }
    }

    #[test]
    fn is_scalar_field_returns_false_for_scalar() {
        assert!(!scalar().is_scalar_field());
    }

    #[test]
    fn clone_preserves_equality() {
        for v in all_variants() {
            assert_eq!(v.clone(), v);
        }
    }

    #[test]
    fn distinct_variants_are_not_equal() {
        assert_ne!(scalar(), boolean());
        assert_ne!(vector(), scalar_field());
        assert_ne!(matrix(), vector_field());
    }

    #[test]
    fn scalar_values_compared_by_content() {
        assert_eq!(ContextValue::Scalar(1.0), ContextValue::Scalar(1.0));
        assert_ne!(ContextValue::Scalar(1.0), ContextValue::Scalar(2.0));
    }

    #[test]
    fn debug_is_non_empty_for_all_variants() {
        for v in all_variants() {
            assert!(!format!("{:?}", v).is_empty());
        }
    }
}