Skip to main content

oximo_expr/
handle.rs

1use std::cell::RefCell;
2
3use crate::arena::{ExprArena, ExprId, ExprNode, ParamId, VarId};
4
5/// Lightweight handle to a node in an [`ExprArena`].
6///
7/// Carries a borrow of the arena (wrapped in `RefCell` so operator overloads
8/// can push new nodes during arithmetic). `Expr` is `Copy`, so users freely
9/// reuse a variable handle in many constraints.
10#[derive(Copy, Clone)]
11pub struct Expr<'a> {
12    pub id: ExprId,
13    pub arena: &'a RefCell<ExprArena>,
14}
15
16impl std::fmt::Debug for Expr<'_> {
17    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
18        f.debug_struct("Expr").field("id", &self.id).finish()
19    }
20}
21
22impl<'a> Expr<'a> {
23    #[inline]
24    pub fn new(id: ExprId, arena: &'a RefCell<ExprArena>) -> Self {
25        Self { id, arena }
26    }
27
28    pub fn constant(arena: &'a RefCell<ExprArena>, v: f64) -> Self {
29        let id = arena.borrow_mut().constant(v);
30        Self::new(id, arena)
31    }
32
33    pub fn from_var(arena: &'a RefCell<ExprArena>, v: VarId) -> Self {
34        let id = arena.borrow_mut().var(v);
35        Self::new(id, arena)
36    }
37
38    /// If this handle is a bare variable, return its [`VarId`].
39    /// `None` for compound expressions (sums, products, constants, ...).
40    pub fn var_id(self) -> Option<VarId> {
41        match self.arena.borrow().get(self.id) {
42            ExprNode::Var(id) => Some(*id),
43            _ => None,
44        }
45    }
46
47    /// If this handle is a bare parameter, return its [`ParamId`].
48    /// `None` for compound expressions.
49    pub fn param_id(self) -> Option<ParamId> {
50        match self.arena.borrow().get(self.id) {
51            ExprNode::Param(id) => Some(*id),
52            _ => None,
53        }
54    }
55
56    /// Re-bind the parameter this handle references to `value`. Takes effect on
57    /// the next extraction/evaluation, which read the value straight from the
58    /// arena.
59    ///
60    /// # Panics
61    /// Panics if this handle is not a bare parameter (see [`Self::param_id`]).
62    pub fn set_param_value(self, value: f64) {
63        let id = self.param_id().expect("set_param_value expects a bare parameter handle");
64        self.arena.borrow_mut().set_param_value(id, value);
65    }
66
67    pub fn pow(self, exponent: Self) -> Self {
68        let id = self.arena.borrow_mut().push(ExprNode::Pow(self.id, exponent.id));
69        Self::new(id, self.arena)
70    }
71
72    pub fn powi(self, n: i32) -> Self {
73        let id = {
74            let mut a = self.arena.borrow_mut();
75            let exp_id = a.constant(f64::from(n));
76            a.push(ExprNode::Pow(self.id, exp_id))
77        };
78        Self::new(id, self.arena)
79    }
80
81    pub fn powf(self, n: f64) -> Self {
82        let id = {
83            let mut a = self.arena.borrow_mut();
84            let exp_id = a.constant(n);
85            a.push(ExprNode::Pow(self.id, exp_id))
86        };
87        Self::new(id, self.arena)
88    }
89
90    pub fn sin(self) -> Self {
91        let id = self.arena.borrow_mut().push(ExprNode::Sin(self.id));
92        Self::new(id, self.arena)
93    }
94
95    pub fn cos(self) -> Self {
96        let id = self.arena.borrow_mut().push(ExprNode::Cos(self.id));
97        Self::new(id, self.arena)
98    }
99
100    pub fn exp(self) -> Self {
101        let id = self.arena.borrow_mut().push(ExprNode::Exp(self.id));
102        Self::new(id, self.arena)
103    }
104
105    pub fn log(self) -> Self {
106        let id = self.arena.borrow_mut().push(ExprNode::Log(self.id));
107        Self::new(id, self.arena)
108    }
109
110    pub fn abs(self) -> Self {
111        let id = self.arena.borrow_mut().push(ExprNode::Abs(self.id));
112        Self::new(id, self.arena)
113    }
114}
115
#[cfg(test)]
mod tests {
    use std::cell::RefCell;

    use super::Expr;
    use crate::arena::ExprArena;

    /// Setting a value through a parameter handle must be visible when the
    /// arena is queried directly afterwards.
    #[test]
    fn set_param_value_rebinds_through_handle() {
        let arena = RefCell::new(ExprArena::new());
        let (pid, node) = {
            let mut a = arena.borrow_mut();
            let pid = a.new_param(0.05);
            let node = a.param(pid);
            (pid, node)
        };

        let handle = Expr::new(node, &arena);
        handle.set_param_value(0.2);

        let stored = arena.borrow().param_value(pid);
        assert!((stored - 0.2).abs() < f64::EPSILON);
    }

    /// A handle that is not a bare parameter must be rejected with a panic.
    #[test]
    #[should_panic(expected = "bare parameter handle")]
    fn set_param_value_panics_on_non_param() {
        let arena = RefCell::new(ExprArena::new());
        Expr::constant(&arena, 1.0).set_param_value(3.0);
    }
}