1use anchor_lang::prelude::{
34 error, error_code, ProgramError,
35};
36use std::result::Result as StdResult;
37
/// Errors produced by the `SafeMath` implementations generated in this module.
///
/// NOTE(review): Anchor's `#[error_code]` appears to assign each variant's
/// numeric code from its position in this enum, so append new variants rather
/// than reordering existing ones — confirm against the Anchor version in use.
#[error_code]
pub enum ErrorCode {
    /// Returned by `safe_add` and `safe_pow` when the result exceeds the
    /// integer type's maximum value.
    #[msg("overflow")]
    Overflow,
    /// Returned by `safe_sub` when the result would drop below the integer
    /// type's minimum value.
    #[msg("underflow")]
    Underflow,
    /// Returned by `safe_div` when its `checked_div` call fails (for the
    /// unsigned types used here, that means the divisor was zero).
    #[msg("division by zero")]
    DivisionByZero,
}
48
49pub trait SafeMath {
51 type Output;
52
53 fn safe_add(&self, rhs: Self::Output) -> StdResult<Self::Output, ProgramError>;
54 fn safe_sub(&self, rhs: Self::Output) -> StdResult<Self::Output, ProgramError>;
55 fn safe_div(&self, rhs: Self::Output) -> StdResult<Self::Output, ProgramError>;
56 fn safe_mul(&self, rhs: Self::Output) -> StdResult<Self::Output, ProgramError>;
57 fn safe_pow(&self, exp: u32) -> StdResult<Self::Output, ProgramError>;
58}
59
/// Implements [`SafeMath`] for one or more primitive integer types, e.g.
/// `safe_math!(u64);` or `safe_math!(u8, u16, u32);`.
///
/// Each method delegates to the standard library's `checked_*` method of the
/// same name and converts a `None` into a `ProgramError` built from the
/// matching [`ErrorCode`]:
///
/// - `safe_add` / `safe_mul` / `safe_pow` -> `ErrorCode::Overflow`
/// - `safe_sub` -> `ErrorCode::Underflow`
/// - `safe_div` -> `ErrorCode::DivisionByZero`
///
/// This mapping assumes unsigned types: there, `checked_sub` can only fail by
/// going below zero and `checked_div` only fails on a zero divisor. For signed
/// types the labels would be inaccurate (e.g. `MIN / -1` overflows).
///
/// The expansion refers to `SafeMath`, `ErrorCode`, `ProgramError`,
/// `StdResult` and `error!` unqualified, so they must be in scope at the
/// invocation site.
macro_rules! safe_math {
    ($($type:ident),+ $(,)?) => {
        $(
            impl SafeMath for $type {
                type Output = $type;

                fn safe_add(&self, rhs: Self::Output) -> StdResult<Self::Output, ProgramError> {
                    self.checked_add(rhs)
                        .ok_or_else(|| error!(ErrorCode::Overflow).into())
                }

                fn safe_sub(&self, rhs: Self::Output) -> StdResult<Self::Output, ProgramError> {
                    self.checked_sub(rhs)
                        .ok_or_else(|| error!(ErrorCode::Underflow).into())
                }

                fn safe_mul(&self, rhs: Self::Output) -> StdResult<Self::Output, ProgramError> {
                    // A product can only exceed MAX for unsigned operands, so this
                    // is an overflow (previously mislabelled as Underflow).
                    self.checked_mul(rhs)
                        .ok_or_else(|| error!(ErrorCode::Overflow).into())
                }

                fn safe_div(&self, rhs: Self::Output) -> StdResult<Self::Output, ProgramError> {
                    self.checked_div(rhs)
                        .ok_or_else(|| error!(ErrorCode::DivisionByZero).into())
                }

                fn safe_pow(&self, exp: u32) -> StdResult<Self::Output, ProgramError> {
                    self.checked_pow(exp)
                        .ok_or_else(|| error!(ErrorCode::Overflow).into())
                }
            }
        )+
    };
}
103
104safe_math!(u128);
105safe_math!(u64);
106safe_math!(u32);
107safe_math!(u16);
108safe_math!(u8);