#![doc = include_str!("../README.md")]
use rust_decimal::{Decimal, dec};
/// Default VAT rate, expressed in percent, used by `compute_vat` when the
/// caller passes `None` as the rate (it is divided by 100 before use).
// Written as `12.0` rather than `12`: the literal's scale propagates into the
// scale of the computed amounts, so changing its spelling changes outputs.
const VAT_RATE: Decimal = dec!(12.0);
/// A complete VAT breakdown, as produced by [`compute_vat`].
///
/// The amounts are unrounded and related by `gross = net + vat`; round them
/// to the precision you need (for example with `Decimal::round_dp`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Vat {
    /// Amount excluding VAT.
    pub net: Decimal,
    /// VAT charged on `net`.
    pub vat: Decimal,
    /// Amount including VAT.
    pub gross: Decimal,
}
/// The single known amount from which [`compute_vat`] derives the other two.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum VatInput {
    /// Amount excluding VAT.
    Net(Decimal),
    /// Amount including VAT.
    Gross(Decimal),
    /// The VAT amount itself.
    Vat(Decimal),
}
/// Errors returned by [`compute_vat`].
///
/// Each variant carries a human-readable message, which is also what the
/// `Display` implementation prints.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Error {
    /// The supplied amount was negative.
    NegativeInput(String),
    /// The VAT rate was outside the range accepted for the computation.
    RateOutOfBound(String),
}

impl std::fmt::Display for Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::NegativeInput(msg) | Self::RateOutOfBound(msg) => f.write_str(msg),
        }
    }
}

// No underlying cause is ever wrapped, so the default `source()` (None) is correct.
impl std::error::Error for Error {}
pub fn compute_vat(input: VatInput, rate: Option<Decimal>) -> Result<Vat, Error> {
let vat_rate = rate.unwrap_or(VAT_RATE) / dec!(100.0);
if vat_rate < dec!(0.0) {
return Err(Error::RateOutOfBound(
"VAT rate cannot be negative".to_string(),
));
}
if vat_rate > dec!(1.0) {
return Err(Error::RateOutOfBound(
"VAT rate cannot exceed 100%".to_string(),
));
}
match input {
VatInput::Net(net) => {
if net < dec!(0.0) {
return Err(Error::NegativeInput(
"Net amount cannot be negative".to_string(),
));
}
let vat = net * vat_rate;
Ok(Vat {
net,
vat,
gross: net + vat,
})
}
VatInput::Gross(gross) => {
if gross < dec!(0.0) {
return Err(Error::NegativeInput(
"Gross amount cannot be negative".to_string(),
));
}
let net = gross / (dec!(1.0) + vat_rate);
Ok(Vat {
net,
vat: gross - net,
gross,
})
}
VatInput::Vat(vat) => {
if vat < dec!(0.0) {
return Err(Error::NegativeInput(
"VAT amount cannot be negative".to_string(),
));
}
let net = vat / vat_rate;
Ok(Vat {
net,
vat,
gross: net + vat,
})
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal::{Decimal, dec};

    /// Rounds to 2 decimal places so comparisons ignore the extra precision
    /// that inexact `Decimal` divisions produce.
    fn round_to_2(value: Decimal) -> Decimal {
        value.round_dp(2)
    }

    /// Asserts the net, VAT and gross amounts of `result`, each rounded to
    /// 2 decimal places.
    fn assert_breakdown(result: &Vat, net: Decimal, vat: Decimal, gross: Decimal) {
        assert_eq!(round_to_2(result.net), net, "net amount");
        assert_eq!(round_to_2(result.vat), vat, "VAT amount");
        assert_eq!(round_to_2(result.gross), gross, "gross amount");
    }

    #[test]
    fn test_compute_vat_net_default_rate() {
        let result = compute_vat(VatInput::Net(dec!(1000.0)), None).unwrap();
        assert_breakdown(&result, dec!(1000.0), dec!(120.0), dec!(1120.0));
    }

    #[test]
    fn test_compute_vat_gross_default_rate() {
        let result = compute_vat(VatInput::Gross(dec!(1120.0)), None).unwrap();
        assert_breakdown(&result, dec!(1000.0), dec!(120.0), dec!(1120.0));
    }

    #[test]
    fn test_compute_vat_vat_default_rate() {
        let result = compute_vat(VatInput::Vat(dec!(120.0)), None).unwrap();
        assert_breakdown(&result, dec!(1000.0), dec!(120.0), dec!(1120.0));
    }

    #[test]
    fn test_compute_vat_net_custom_rate() {
        let result = compute_vat(VatInput::Net(dec!(1000.0)), Some(dec!(10.0))).unwrap();
        assert_breakdown(&result, dec!(1000.0), dec!(100.0), dec!(1100.0));
    }

    #[test]
    fn test_compute_vat_gross_custom_rate() {
        let result = compute_vat(VatInput::Gross(dec!(1100.0)), Some(dec!(10.0))).unwrap();
        assert_breakdown(&result, dec!(1000.0), dec!(100.0), dec!(1100.0));
    }

    #[test]
    fn test_compute_vat_vat_custom_rate() {
        let result = compute_vat(VatInput::Vat(dec!(100.0)), Some(dec!(10.0))).unwrap();
        assert_breakdown(&result, dec!(1000.0), dec!(100.0), dec!(1100.0));
    }

    #[test]
    fn test_compute_vat_net_incorrect() {
        let result = compute_vat(VatInput::Net(dec!(1000.0)), None).unwrap();
        assert_ne!(round_to_2(result.vat), dec!(100.0));
        assert_ne!(round_to_2(result.gross), dec!(1000.0));
    }

    #[test]
    fn test_compute_vat_gross_incorrect() {
        let result = compute_vat(VatInput::Gross(dec!(1120.0)), None).unwrap();
        assert_ne!(round_to_2(result.net), dec!(1120.0));
        assert_ne!(round_to_2(result.vat), dec!(100.0));
    }

    #[test]
    fn test_compute_vat_vat_incorrect() {
        let result = compute_vat(VatInput::Vat(dec!(120.0)), None).unwrap();
        assert_ne!(round_to_2(result.net), dec!(120.0));
        assert_ne!(round_to_2(result.gross), dec!(100.0));
    }

    #[test]
    fn test_compute_vat_zero_net() {
        let result = compute_vat(VatInput::Net(dec!(0.0)), None).unwrap();
        assert_breakdown(&result, dec!(0.0), dec!(0.0), dec!(0.0));
    }

    #[test]
    fn test_compute_vat_zero_gross() {
        let result = compute_vat(VatInput::Gross(dec!(0.0)), None).unwrap();
        assert_breakdown(&result, dec!(0.0), dec!(0.0), dec!(0.0));
    }

    #[test]
    fn test_compute_vat_zero_vat() {
        let result = compute_vat(VatInput::Vat(dec!(0.0)), None).unwrap();
        assert_breakdown(&result, dec!(0.0), dec!(0.0), dec!(0.0));
    }

    #[test]
    fn test_compute_vat_negative_net() {
        let result = compute_vat(VatInput::Net(dec!(-1000.0)), None);
        assert!(matches!(result, Err(Error::NegativeInput(_))));
    }

    #[test]
    fn test_compute_vat_negative_gross() {
        let result = compute_vat(VatInput::Gross(dec!(-1120.0)), None);
        assert!(matches!(result, Err(Error::NegativeInput(_))));
    }

    #[test]
    fn test_compute_vat_negative_vat() {
        let result = compute_vat(VatInput::Vat(dec!(-120.0)), None);
        assert!(matches!(result, Err(Error::NegativeInput(_))));
    }

    #[test]
    fn test_compute_vat_negative_rate() {
        let result = compute_vat(VatInput::Net(dec!(1000.0)), Some(dec!(-12.0)));
        assert!(matches!(result, Err(Error::RateOutOfBound(_))));
    }

    #[test]
    fn test_compute_vat_excessive_rate() {
        let result = compute_vat(VatInput::Net(dec!(1000.0)), Some(dec!(150.0)));
        assert!(matches!(result, Err(Error::RateOutOfBound(_))));
    }

    #[test]
    fn test_compute_vat_rate_just_above_max() {
        let result = compute_vat(VatInput::Net(dec!(1000.0)), Some(dec!(100.01)));
        assert!(matches!(result, Err(Error::RateOutOfBound(_))));
    }

    #[test]
    fn test_compute_vat_max_rate_accepted() {
        let net = compute_vat(VatInput::Net(dec!(1000.0)), Some(dec!(100.0))).unwrap();
        assert_breakdown(&net, dec!(1000.0), dec!(1000.0), dec!(2000.0));

        let gross = compute_vat(VatInput::Gross(dec!(2000.0)), Some(dec!(100.0))).unwrap();
        assert_breakdown(&gross, dec!(1000.0), dec!(1000.0), dec!(2000.0));
    }

    #[test]
    fn test_compute_vat_zero_rate() {
        let result = compute_vat(VatInput::Net(dec!(1000.0)), Some(dec!(0.0))).unwrap();
        assert_breakdown(&result, dec!(1000.0), dec!(0.0), dec!(1000.0));
    }

    #[test]
    fn test_compute_vat_gross_zero_rate() {
        let result = compute_vat(VatInput::Gross(dec!(1000.0)), Some(dec!(0.0))).unwrap();
        assert_breakdown(&result, dec!(1000.0), dec!(0.0), dec!(1000.0));
    }
}