1use nom::branch::alt;
6use nom::character::complete::alpha1;
7use nom::character::complete::multispace1;
8use nom::combinator::all_consuming;
9use nom::multi::many0;
10use nom::sequence::preceded;
11use nom::sequence::tuple;
12use std::collections::HashMap;
13use std::collections::HashSet;
14use std::ops::Add;
15use std::ops::Mul;
16
17use nom::bytes::complete::tag;
18use nom::character::complete::i32;
19use nom::number::complete::float;
20use nom::sequence::{separated_pair};
21use nom::IResult;
22
/// A quantity: a scalar magnitude together with the exponent of each named unit.
///
/// Two `Unit`s are equal only when both the magnitude and the full exponent map match.
#[derive(Debug, Clone, PartialEq)]
pub struct Unit {
    /// Scalar magnitude of the quantity.
    x: f32,
    /// Exponent per unit name, e.g. `"meters" -> 2` for `meters^2`.
    units: HashMap<String, i32>,
}
28
29impl Add for Unit {
30 type Output = Result<Self, String>;
31
32 fn add(self, other: Self) -> Result<Self, String> {
33 if self.units == other.units {
34 Ok(Self {
35 x: self.x + other.x,
36 units: self.units,
37 })
38 } else {
39 Err("Units don't match".to_string())
40 }
41 }
42}
43
44impl Mul for Unit {
45 type Output = Self;
46
47 fn mul(self, rhs: Self) -> Self {
48 let new_x = self.x * rhs.x;
49 let new_units = combine(self.units, rhs.units);
50 Self {
51 x: new_x,
52 units: new_units,
53 }
54 }
55}
56
/// Merges two unit-exponent maps, as needed when multiplying quantities.
///
/// Exponents for the same unit name are summed. Any unit whose resulting exponent is zero
/// is omitted from the result, including zero entries already present in either input.
/// For example, `{m: 2}` combined with `{m: -2}` yields an empty (dimensionless) map.
fn combine(units1: HashMap<String, i32>, units2: HashMap<String, i32>) -> HashMap<String, i32> {
    // Reuse the first map's storage and fold the second into it; this avoids building key
    // sets and allocating a fresh `String` for every lookup.
    let mut result = units1;
    for (name, exponent) in units2 {
        *result.entry(name).or_insert(0) += exponent;
    }
    // Cancelled (or explicitly zero) exponents carry no dimension, so drop them.
    result.retain(|_, exponent| *exponent != 0);
    result
}
71
72fn parse_unit_and_exp(input: &str) -> IResult<&str, (&str, i32)> {
89 separated_pair(alpha1, tag("^"), i32)(input)
90}
91
92fn parse_unit_and_default_exp(input: &str) -> IResult<&str, (&str, i32)> {
93 let (remaining, unit_name) = alpha1(input)?;
94 Ok((remaining, (unit_name, 1)))
95}
96
97fn parse_unit_and_maybe_exp(input: &str) -> IResult<&str, (&str, i32)> {
98 alt((parse_unit_and_exp, parse_unit_and_default_exp))(input)
99}
100
101fn parse_full_expression(input: &str) -> IResult<&str, (f32, Vec<(&str, i32)>)> {
102 all_consuming(tuple((
103 float,
104 many0(preceded(multispace1, parse_unit_and_maybe_exp)),
105 )))(input)
106}
107
108fn build_unit(input: (f32, Vec<(&str, i32)>)) -> Unit {
109 Unit {
110 x: input.0,
111 units: input
112 .1
113 .into_iter()
114 .map(|(x, y)| (String::from(x), y))
115 .collect::<HashMap<_, _>>(),
116 }
117}
118
119pub fn u(input: &str) -> Result<Unit, Box<dyn std::error::Error + '_>> {
120 let (_, unpacked) = parse_full_expression(input)?;
121 Ok(build_unit(unpacked))
122}
123
124
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an owned exponent map from borrowed `(name, exponent)` pairs.
    fn dims(pairs: &[(&str, i32)]) -> HashMap<String, i32> {
        pairs
            .iter()
            .map(|&(name, exponent)| (name.to_string(), exponent))
            .collect()
    }

    /// Shorthand for constructing a [`Unit`] literal.
    fn unit(x: f32, pairs: &[(&str, i32)]) -> Unit {
        Unit { x, units: dims(pairs) }
    }

    #[test]
    fn basic_add() {
        let sum = unit(1., &[("m", 1)]) + unit(2., &[("m", 1)]);
        assert_eq!(sum, Ok(unit(3., &[("m", 1)])));
    }

    #[test]
    fn combine_works() {
        let merged = combine(dims(&[("m", 2), ("s", -2)]), dims(&[("m", 1), ("s", 1)]));
        assert_eq!(merged, dims(&[("m", 3), ("s", -1)]));
    }

    #[test]
    fn unit_and_exp_works() {
        let (rest, got) = parse_unit_and_exp("meters^2").unwrap();
        assert_eq!(got, ("meters", 2));
        assert_eq!(rest, "");
    }

    #[test]
    fn unit_and_exp1_works() {
        let (rest, got) = parse_unit_and_default_exp("meters").unwrap();
        assert_eq!(got, ("meters", 1));
        assert_eq!(rest, "");
    }

    #[test]
    fn unit_and_exp_maybe_works() {
        for (text, expected) in [("meters", ("meters", 1)), ("meters^-4", ("meters", -4))] {
            let (rest, got) = parse_unit_and_maybe_exp(text).unwrap();
            assert_eq!(got, expected);
            assert_eq!(rest, "");
        }
    }

    #[test]
    fn parse_full_expression_works() {
        let (rest, got) = parse_full_expression("5 meters^2 seconds").unwrap();
        assert_eq!(got, (5.0, vec![("meters", 2), ("seconds", 1)]));
        assert_eq!(rest, "");
    }

    #[test]
    fn parse_and_build() {
        let (rest, got) = parse_full_expression("5 meters^2 seconds^-1").unwrap();
        assert_eq!(rest, "");
        assert_eq!(build_unit(got), unit(5., &[("meters", 2), ("seconds", -1)]));
    }

    #[test]
    fn no_remaining() {
        // Trailing whitespace is left unconsumed, so `all_consuming` must reject it.
        assert!(parse_full_expression("5 meters^2 seconds^-1 ").is_err());
    }

    #[test]
    fn parse_and_convert() {
        let parsed = u("5 meters^2 seconds^-1").unwrap();
        assert_eq!(parsed, unit(5., &[("meters", 2), ("seconds", -1)]));
    }

    #[test]
    fn basic_mul() {
        let lhs = u("5 meters^2 seconds^-1").unwrap();
        let rhs = u("3 meters^-1 kg").unwrap();
        assert_eq!(lhs * rhs, u("15 kg^1 meters seconds^-1").unwrap());
    }

    #[test]
    fn cancel() {
        let lhs = u("-5 meters^2 seconds^-1").unwrap();
        let rhs = u("4.1 meters^-2").unwrap();
        assert_eq!(lhs * rhs, u("-20.5 seconds^-1").unwrap());
    }

    #[test]
    fn test_defaults() {
        // Left-to-right product: ((5 m s kg * 2 m^-1) * 7 s^-1) * 3 kg^-1 == 210.
        let product = ["5 m s kg", "2 m^-1", "7 s^-1", "3 kg^-1"]
            .iter()
            .map(|text| u(text).unwrap())
            .reduce(|acc, next| acc * next)
            .unwrap();
        assert_eq!(product, u("210").unwrap());
    }
}