use std::sync::Arc;
use svod_dtype::DType;
use svod_ir::{ConstValue, Op, SInt, UOp};
use crate::error::{Result, VariableOutOfRangeSnafu};
use snafu::ensure;
/// A named symbolic integer variable with inclusive lower and upper bounds.
///
/// This is a thin handle around a `DefineVar` UOp. Cloning only clones the inner
/// `Arc`, so it is cheap.
#[derive(Clone)]
pub struct Variable {
    /// The `DefineVar` UOp carrying the variable's name and bounds.
    uop: Arc<UOp>,
}
impl Variable {
    /// Creates a symbolic variable named `name` whose value may be any integer in the
    /// inclusive range `min_val..=max_val`.
    ///
    /// # Panics
    ///
    /// Panics if `min_val > max_val`.
    pub fn new(name: &str, min_val: i64, max_val: i64) -> Self {
        assert!(min_val <= max_val, "Variable '{name}': min_val ({min_val}) > max_val ({max_val})");
        Self { uop: UOp::define_var(name.to_string(), min_val, max_val) }
    }

    /// Binds this variable to the concrete value `val`, producing a [`BoundVariable`].
    ///
    /// The resulting UOp is obtained by calling `bind` on this variable's UOp with an
    /// `Index`-typed integer constant holding `val`.
    ///
    /// # Errors
    ///
    /// Returns the error built from `VariableOutOfRangeSnafu` when `val` lies outside the
    /// inclusive range reported by [`Variable::bounds`].
    pub fn bind(&self, val: i64) -> Result<BoundVariable> {
        let (min, max) = self.bounds();
        // Both ends are inclusive, matching the `min_val <= max_val` contract of `new`.
        ensure!(
            (min..=max).contains(&val),
            VariableOutOfRangeSnafu { name: self.name().to_string(), val, min, max }
        );
        let val_uop = UOp::const_(DType::Index, ConstValue::Int(val));
        let bind_uop = self.uop.bind(val_uop);
        Ok(BoundVariable { var: self.clone(), value: val, uop: bind_uop })
    }

    /// Returns the variable's name as stored in its `DefineVar` UOp.
    pub fn name(&self) -> &str {
        match self.uop.op() {
            Op::DefineVar { name, .. } => name.as_str(),
            // `new` is the only constructor visible here and always builds a DefineVar.
            _ => unreachable!("Variable always wraps DefineVar"),
        }
    }

    /// Returns the inclusive `(min, max)` bounds stored in the `DefineVar` UOp.
    pub fn bounds(&self) -> (i64, i64) {
        match self.uop.op() {
            Op::DefineVar { min_val, max_val, .. } => (*min_val, *max_val),
            _ => unreachable!("Variable always wraps DefineVar"),
        }
    }

    /// Returns the (unbound) variable as a symbolic [`SInt`] sharing the same UOp.
    pub fn as_sint(&self) -> SInt {
        SInt::Symbolic(self.uop.clone())
    }

    /// Returns a reference to the underlying `DefineVar` UOp.
    pub fn uop(&self) -> &Arc<UOp> {
        &self.uop
    }
}
impl std::fmt::Debug for Variable {
    /// Formats as `Variable("name", min, max)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (lo, hi) = self.bounds();
        let name = self.name();
        write!(f, "Variable({name:?}, {lo}, {hi})")
    }
}
/// A [`Variable`] paired with a concrete value and the UOp representing that binding.
///
/// Instances are produced by [`Variable::bind`], which checks the value against the
/// variable's bounds.
#[derive(Clone)]
pub struct BoundVariable {
    /// The variable that was bound.
    var: Variable,
    /// The concrete value assigned to `var`.
    value: i64,
    /// The UOp produced by binding `var`'s UOp to a constant holding `value`.
    uop: Arc<UOp>,
}
impl BoundVariable {
    /// Returns the binding as a symbolic [`SInt`] that shares the bound UOp.
    pub fn as_sint(&self) -> SInt {
        let bound = Arc::clone(&self.uop);
        SInt::Symbolic(bound)
    }

    /// Returns the concrete value this variable is bound to.
    pub fn value(&self) -> i64 {
        self.value
    }

    /// Returns the underlying (unbound) variable.
    pub fn variable(&self) -> &Variable {
        &self.var
    }

    /// Consumes the binding and returns the variable together with its value.
    pub fn unbind(self) -> (Variable, i64) {
        let Self { var, value, .. } = self;
        (var, value)
    }

    /// Returns the `(name, value)` pair describing this binding.
    pub fn as_var_val(&self) -> (&str, i64) {
        let name = self.var.name();
        (name, self.value)
    }

    /// Returns a reference to the UOp representing this binding.
    pub fn uop(&self) -> &Arc<UOp> {
        &self.uop
    }
}
impl std::fmt::Debug for BoundVariable {
    /// Formats as `BoundVariable("name" = value)`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let (name, value) = self.as_var_val();
        write!(f, "BoundVariable({name:?} = {value})")
    }
}
impl From<BoundVariable> for SInt {
fn from(bv: BoundVariable) -> SInt {
bv.as_sint()
}
}
impl From<&BoundVariable> for SInt {
fn from(bv: &BoundVariable) -> SInt {
bv.as_sint()
}
}