1use serde::{Deserialize, Serialize};
7use std::fmt;
8use std::ops::{Add, Mul, Neg};
9
10use crate::error::{Result, TernaryError};
11use crate::trit::Trit;
12
/// Smallest value a `Tryte3` can hold: every trit at -1, i.e. `-TRYTE3_MAX`.
pub const TRYTE3_MIN: i32 = -TRYTE3_MAX;
/// Largest value a `Tryte3` can hold: `(3^3 - 1) / 2`, i.e. every trit at +1.
pub const TRYTE3_MAX: i32 = (3 * 3 * 3 - 1) / 2;
17
18#[derive(Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
53pub struct Tryte3(u8);
54
55impl Tryte3 {
56 pub fn from_value(value: i32) -> Result<Self> {
78 if !(TRYTE3_MIN..=TRYTE3_MAX).contains(&value) {
79 return Err(TernaryError::InvalidTryteValue(value));
80 }
81
82 let trits = Self::value_to_trits(value);
83 Ok(Self::from_trits(trits))
84 }
85
86 #[must_use]
102 pub fn from_trits(trits: [Trit; 3]) -> Self {
103 let encoded = Self::encode_trit(trits[0])
104 | (Self::encode_trit(trits[1]) << 2)
105 | (Self::encode_trit(trits[2]) << 4);
106 Self(encoded)
107 }
108
109 #[must_use]
111 pub fn value(self) -> i32 {
112 let trits = self.to_trits();
113 trits[0].value() as i32 + trits[1].value() as i32 * 3 + trits[2].value() as i32 * 9
114 }
115
116 #[must_use]
122 pub fn to_trits(self) -> [Trit; 3] {
123 [
124 Self::decode_trit(self.0 & 0b11),
125 Self::decode_trit((self.0 >> 2) & 0b11),
126 Self::decode_trit((self.0 >> 4) & 0b11),
127 ]
128 }
129
130 #[must_use]
140 pub fn get_trit(self, index: usize) -> Trit {
141 assert!(index < 3, "trit index out of bounds");
142 Self::decode_trit((self.0 >> (index * 2)) & 0b11)
143 }
144
145 #[must_use]
147 pub const fn zero() -> Self {
148 Self(0b01_01_01)
150 }
151
152 #[must_use]
154 pub fn is_zero(self) -> bool {
155 self.value() == 0
156 }
157
158 #[must_use]
160 pub const fn raw(self) -> u8 {
161 self.0
162 }
163
164 fn value_to_trits(mut value: i32) -> [Trit; 3] {
166 let mut trits = [Trit::Z; 3];
167
168 for trit in &mut trits {
169 if value == 0 {
170 *trit = Trit::Z;
171 continue;
172 }
173
174 let mut rem = value % 3;
175 value /= 3;
176
177 if rem == 2 {
179 rem = -1;
180 value += 1;
181 } else if rem == -2 {
182 rem = 1;
183 value -= 1;
184 }
185
186 *trit = match rem {
187 -1 => Trit::N,
188 0 => Trit::Z,
189 1 => Trit::P,
190 _ => unreachable!(),
191 };
192 }
193
194 trits
195 }
196
197 fn encode_trit(trit: Trit) -> u8 {
199 match trit {
200 Trit::N => 0,
201 Trit::Z => 1,
202 Trit::P => 2,
203 }
204 }
205
206 fn decode_trit(bits: u8) -> Trit {
208 match bits & 0b11 {
209 0 => Trit::N,
210 1 | 3 => Trit::Z, 2 => Trit::P,
212 _ => unreachable!(),
213 }
214 }
215}
216
217impl Default for Tryte3 {
218 fn default() -> Self {
219 Self::zero()
220 }
221}
222
223impl Neg for Tryte3 {
224 type Output = Self;
225
226 fn neg(self) -> Self::Output {
227 let trits = self.to_trits();
228 Self::from_trits([-trits[0], -trits[1], -trits[2]])
229 }
230}
231
232impl Add for Tryte3 {
233 type Output = (Self, Trit);
234
235 fn add(self, other: Self) -> Self::Output {
237 let a = self.to_trits();
238 let b = other.to_trits();
239 let mut result = [Trit::Z; 3];
240 let mut carry = Trit::Z;
241
242 for i in 0..3 {
243 let (sum1, carry1) = a[i].add_with_carry(b[i]);
245 let (sum2, carry2) = sum1.add_with_carry(carry);
246
247 result[i] = sum2;
248 let (carry_sum, _) = carry1.add_with_carry(carry2);
250 carry = carry_sum;
251 }
252
253 (Self::from_trits(result), carry)
254 }
255}
256
257impl Mul for Tryte3 {
258 type Output = (Self, Self);
259
260 fn mul(self, other: Self) -> Self::Output {
264 let product = self.value() * other.value();
265
266 let low_val = ((product % 27) + 27 + 13) % 27 - 13;
268 let high_val = (product - low_val) / 27;
269
270 (
271 Self::from_value(low_val).unwrap_or_else(|_| Self::zero()),
272 Self::from_value(high_val.clamp(TRYTE3_MIN, TRYTE3_MAX))
273 .unwrap_or_else(|_| Self::zero()),
274 )
275 }
276}
277
278impl fmt::Debug for Tryte3 {
279 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
280 let trits = self.to_trits();
281 write!(
282 f,
283 "Tryte3({}{}{} = {})",
284 trits[2],
285 trits[1],
286 trits[0],
287 self.value()
288 )
289 }
290}
291
292impl fmt::Display for Tryte3 {
293 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
294 write!(f, "{}", self.value())
295 }
296}
297
298impl TryFrom<i32> for Tryte3 {
299 type Error = TernaryError;
300
301 fn try_from(value: i32) -> Result<Self> {
302 Self::from_value(value)
303 }
304}
305
306impl From<Tryte3> for i32 {
307 fn from(tryte: Tryte3) -> Self {
308 tryte.value()
309 }
310}
311
#[cfg(test)]
mod tests {
    use super::*;

    /// Every value a tryte can hold, in ascending order.
    fn all_values() -> std::ops::RangeInclusive<i32> {
        TRYTE3_MIN..=TRYTE3_MAX
    }

    /// Builds a tryte from a value the test knows is in range.
    fn tryte(v: i32) -> Tryte3 {
        Tryte3::from_value(v).expect("valid value")
    }

    /// The three trit digits, used to enumerate every trit pattern.
    const DIGITS: [Trit; 3] = [Trit::N, Trit::Z, Trit::P];

    #[test]
    fn test_tryte_range() {
        for v in all_values() {
            assert_eq!(tryte(v).value(), v);
        }

        assert!(Tryte3::from_value(TRYTE3_MIN - 1).is_err());
        assert!(Tryte3::from_value(TRYTE3_MAX + 1).is_err());
        assert!(Tryte3::from_value(i32::MIN).is_err());
        assert!(Tryte3::from_value(i32::MAX).is_err());
    }

    #[test]
    fn test_tryte_zero() {
        let z = Tryte3::zero();
        assert_eq!(z.value(), 0);
        assert!(z.is_zero());
        assert_eq!(z.to_trits(), [Trit::Z; 3]);
        assert_eq!(Tryte3::default(), z);
        assert_eq!(tryte(0), z);
        assert!(all_values().filter(|&v| v != 0).all(|v| !tryte(v).is_zero()));
    }

    #[test]
    fn test_tryte_trits_roundtrip() {
        for v in all_values() {
            let t = tryte(v);
            assert_eq!(Tryte3::from_trits(t.to_trits()), t);
        }

        // All 27 trit patterns are canonical encodings of distinct in-range values.
        for &high in &DIGITS {
            for &mid in &DIGITS {
                for &low in &DIGITS {
                    let t = Tryte3::from_trits([low, mid, high]);
                    assert_eq!(t.to_trits(), [low, mid, high]);
                    assert_eq!(tryte(t.value()), t);
                }
            }
        }
    }

    #[test]
    fn test_tryte_negation() {
        for v in all_values() {
            let t = tryte(v);
            assert_eq!((-t).value(), -v);
            assert_eq!(-(-t), t);
        }
    }

    #[test]
    fn test_tryte_addition() {
        // Exhaustive: every pair of tryte values.
        for a in all_values() {
            for b in all_values() {
                let (sum, carry) = tryte(a) + tryte(b);
                let total = sum.value() + carry.value() as i32 * 27;
                assert_eq!(total, a + b, "{} + {}", a, b);
            }
        }

        // Overflow past TRYTE3_MAX wraps into the low tryte and produces a positive carry.
        let (sum, carry) = tryte(13) + tryte(1);
        assert_eq!(sum.value(), -13);
        assert_eq!(carry, Trit::P);
    }

    #[test]
    fn test_tryte_multiplication() {
        // Exhaustive: every pair of tryte values.
        for a in all_values() {
            for b in all_values() {
                let (low, high) = tryte(a) * tryte(b);
                let total = low.value() + high.value() * 27;
                assert_eq!(total, a * b, "{} * {}", a, b);
            }
        }

        // 100 = -8 + 4 * 27
        let (low, high) = tryte(10) * tryte(10);
        assert_eq!((low.value(), high.value()), (-8, 4));
    }

    #[test]
    fn test_tryte_get_trit() {
        let t = Tryte3::from_trits([Trit::N, Trit::Z, Trit::P]);
        assert_eq!(t.get_trit(0), Trit::N);
        assert_eq!(t.get_trit(1), Trit::Z);
        assert_eq!(t.get_trit(2), Trit::P);
    }

    #[test]
    #[should_panic(expected = "trit index out of bounds")]
    fn test_tryte_get_trit_out_of_bounds() {
        let _ = Tryte3::zero().get_trit(3);
    }

    #[test]
    fn test_tryte_specific_values() {
        // 5 = -1 + (-1 * 3) + (1 * 9)
        assert_eq!(tryte(5).to_trits(), [Trit::N, Trit::N, Trit::P]);
        // -5 = 1 + (1 * 3) + (-1 * 9)
        assert_eq!(tryte(-5).to_trits(), [Trit::P, Trit::P, Trit::N]);
        assert_eq!(tryte(TRYTE3_MAX).to_trits(), [Trit::P; 3]);
        assert_eq!(tryte(TRYTE3_MIN).to_trits(), [Trit::N; 3]);
    }

    #[test]
    fn test_tryte_conversions() {
        for v in all_values() {
            let t = Tryte3::try_from(v).expect("valid value");
            assert_eq!(i32::from(t), v);
        }
        assert!(Tryte3::try_from(TRYTE3_MAX + 1).is_err());
        assert_eq!(format!("{}", tryte(-7)), "-7");
    }
}