1use std::fmt;
4use std::rc::Rc;
5
6use super::node::Node;
7use super::node_base::NodeBase;
8use super::constant::ConstantScalar;
9
/// Kind of a scalar variable: continuous or integer.
///
/// The enum is fieldless, so it is cheap to copy and compare. The derives let
/// callers inspect, compare, and store kinds without matching by hand.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum VariableKind {
    /// Continuous variable (queried via `VariableScalar::is_continuous`).
    VarContinuous,
    /// Integer variable (queried via `VariableScalar::is_integer`).
    VarInteger,
}
15
16pub struct VariableScalar {
18 name: String,
19 kind: VariableKind,
20}
21
22impl VariableScalar {
23
24 pub fn is_continuous(&self) -> bool {
26 match self.kind {
27 VariableKind::VarContinuous => true,
28 _ => false,
29 }
30 }
31
32 pub fn is_integer(&self) -> bool {
34 match self.kind {
35 VariableKind::VarInteger => true,
36 _ => false,
37 }
38 }
39
40 pub fn name(&self) -> &str { self.name.as_ref() }
42
43 pub fn new(name: &str, kind: VariableKind) -> Node {
45 Node::VariableScalar(Rc::new(
46 Self {
47 name: name.to_string(),
48 kind: kind,
49 }
50 ))
51 }
52
53 pub fn new_continuous(name: &str) -> Node {
55 VariableScalar::new(name, VariableKind::VarContinuous)
56 }
57
58 pub fn new_integer(name: &str) -> Node {
60 VariableScalar::new(name, VariableKind::VarInteger)
61 }
62}
63
64impl NodeBase for VariableScalar {
65
66 fn partial(&self, arg: &Node) -> Node {
67 match arg {
68 Node::VariableScalar(x) => {
69 if self as *const VariableScalar == x.as_ref() {
70 ConstantScalar::new(1.)
71 }
72 else {
73 ConstantScalar::new(0.)
74 }
75 }
76 _ => ConstantScalar::new(0.)
77 }
78 }
79}
80
81impl<'a> fmt::Display for VariableScalar {
82 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
83 write!(f, "{}", self.name)
84 }
85}
86
87#[cfg(test)]
88mod tests {
89
90 use crate::model::node::Node;
91 use crate::model::node_base::NodeBase;
92 use crate::model::node_std::NodeStd;
93 use crate::model::node_diff::NodeDiff;
94 use crate::model::variable::VariableScalar;
95
96 #[test]
97 fn var_construction() {
98
99 let x = VariableScalar::new_continuous("x");
100 assert_eq!(x.name(), "x");
101 match x {
102 Node::VariableScalar(xx) => {
103 assert!(xx.is_continuous());
104 assert!(!xx.is_integer());
105 },
106 _ => panic!("construction failed"),
107 }
108
109 let y = VariableScalar::new_integer("y");
110 assert_eq!(y.name(), "y");
111 match y {
112 Node::VariableScalar(yy) => {
113 assert!(yy.is_integer());
114 assert!(!yy.is_continuous());
115 },
116 _ => panic!("construction failed"),
117 }
118 }
119
120 #[test]
121 fn var_partial() {
122
123 let x = VariableScalar::new_continuous("x");
124 let y = VariableScalar::new_continuous("y");
125
126 let z1 = x.partial(&x);
127 assert!(z1.is_constant_with_value(1.));
128
129 let z2 = x.partial(&y);
130 assert!(z2.is_constant_with_value(0.));
131 }
132
133 #[test]
134 fn var_derivative() {
135
136 let x = VariableScalar::new_continuous("x");
137 let y = VariableScalar::new_continuous("y");
138
139 let z1 = x.derivative(&y);
140 assert!(z1.is_constant_with_value(0.));
141
142 let z2 = x.derivative(&x);
143 assert!(z2.is_constant_with_value(1.));
144 }
145
146 #[test]
147 fn var_std_properties() {
148
149 let x = VariableScalar::new_integer("x");
150 let p = x.std_properties();
151 assert!(p.affine);
152 assert_eq!(p.b, 0.);
153 assert_eq!(p.a.len(), 1);
154 assert_eq!(*p.a.get(&x).unwrap(), 1.);
155 }
156}