use rust_decimal::Decimal;
use std::collections::HashMap;
/// Per-currency number display settings.
///
/// Holds the precision inferred from observed numbers, optional explicit
/// per-currency overrides, and a flag that controls thousands separators.
/// The derived `Default` yields empty maps and `render_commas == false`.
#[derive(Debug, Clone, Default)]
pub struct DisplayContext {
/// Largest scale observed so far for each currency (maintained by `update`
/// and `update_from`).
precisions: HashMap<String, u32>,
/// Whether formatted numbers get `,` thousands separators in the integer part.
render_commas: bool,
/// Explicit per-currency precisions; `get_precision` consults these before
/// falling back to `precisions`.
fixed_precisions: HashMap<String, u32>,
}
impl DisplayContext {
    /// Creates an empty context: no observed or fixed precisions, commas off.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Records an observed `number` for `currency`.
    ///
    /// The tracked precision for the currency never shrinks: it becomes the
    /// larger of the stored value and the scale of `number`.
    pub fn update(&mut self, number: Decimal, currency: &str) {
        let observed = Self::decimal_precision(number);
        let slot = self.precisions.entry(currency.to_string()).or_default();
        if observed > *slot {
            *slot = observed;
        }
    }

    /// Merges the observed precisions of `other` into `self`.
    ///
    /// For every currency in `other` the larger of the two precisions is kept.
    /// Only the observed precisions are merged; the comma flag and the fixed
    /// precisions of `other` are not read.
    pub fn update_from(&mut self, other: &Self) {
        for (currency, &theirs) in &other.precisions {
            self.precisions
                .entry(currency.clone())
                .and_modify(|ours| *ours = (*ours).max(theirs))
                .or_insert(theirs);
        }
    }

    /// Enables or disables thousands separators in formatted output.
    pub const fn set_render_commas(&mut self, render_commas: bool) {
        self.render_commas = render_commas;
    }

    /// Returns whether formatted output contains thousands separators.
    #[must_use]
    pub const fn render_commas(&self) -> bool {
        self.render_commas
    }

    /// Pins the precision of `currency` to `precision`, replacing any
    /// previously pinned value for that currency.
    pub fn set_fixed_precision(&mut self, currency: &str, precision: u32) {
        let key = currency.to_string();
        self.fixed_precisions.insert(key, precision);
    }

    /// Looks up the precision used to display `currency`.
    ///
    /// A fixed precision wins over an observed one; `None` means the currency
    /// has neither.
    #[must_use]
    pub fn get_precision(&self, currency: &str) -> Option<u32> {
        self.fixed_precisions
            .get(currency)
            .or_else(|| self.precisions.get(currency))
            .copied()
    }

    /// Rounds `number` to the display precision of `currency`.
    ///
    /// The number is returned unchanged when the currency has no precision.
    #[must_use]
    pub fn quantize(&self, number: Decimal, currency: &str) -> Decimal {
        match self.get_precision(currency) {
            Some(dp) => number.round_dp(dp),
            None => number,
        }
    }

    /// Renders `number` as text using the display settings of `currency`.
    ///
    /// With a known precision the number is rounded and then zero-padded to
    /// that many decimal places. Without one, the normalized form of the
    /// number is printed. Thousands separators are added afterwards when
    /// enabled.
    #[must_use]
    pub fn format(&self, number: Decimal, currency: &str) -> String {
        let plain = match self.get_precision(currency) {
            Some(dp) => Self::ensure_decimal_places(&number.round_dp(dp).to_string(), dp),
            None => number.normalize().to_string(),
        };
        if self.render_commas {
            Self::add_commas(&plain)
        } else {
            plain
        }
    }

    /// Renders `number` followed by a single space and the currency code.
    #[must_use]
    pub fn format_amount(&self, number: Decimal, currency: &str) -> String {
        let rendered = self.format(number, currency);
        format!("{rendered} {currency}")
    }

    /// Precision of a single number, taken directly from its scale.
    const fn decimal_precision(number: Decimal) -> u32 {
        number.scale()
    }

    /// Adjusts the textual number `s` so that it shows `dp` decimal places.
    ///
    /// * `dp == 0`: everything from the first `.` onwards is dropped.
    /// * fewer than `dp` decimals: trailing zeros are appended (adding the
    ///   `.` first when `s` has none).
    /// * `dp` or more decimals: `s` is returned as is; nothing is truncated.
    fn ensure_decimal_places(s: &str, dp: u32) -> String {
        let (integer, fraction) = match s.split_once('.') {
            Some((before, after)) => (before, Some(after)),
            None => (s, None),
        };
        if dp == 0 {
            return integer.to_string();
        }

        let wanted = dp as usize;
        let mut padded = s.to_string();
        let missing = match fraction {
            Some(digits) => wanted.saturating_sub(digits.len()),
            None => {
                padded.push('.');
                wanted
            }
        };
        padded.push_str(&"0".repeat(missing));
        padded
    }

    /// Inserts `,` between groups of three characters in the integer part
    /// of `s`.
    ///
    /// A single leading `-` is kept in front of the grouped digits, and
    /// everything from the first `.` onwards is copied through untouched.
    fn add_commas(s: &str) -> String {
        let (whole, fraction) = s.find('.').map_or((s, ""), |dot| s.split_at(dot));
        let (sign, digits) = whole
            .strip_prefix('-')
            .map_or(("", whole), |rest| ("-", rest));

        let digit_count = digits.chars().count();
        let mut grouped = String::with_capacity(s.len() + digit_count / 3);
        grouped.push_str(sign);
        for (idx, ch) in digits.chars().enumerate() {
            // A separator goes in front of every character that still has a
            // whole number of three-character groups to its right.
            if idx > 0 && (digit_count - idx) % 3 == 0 {
                grouped.push(',');
            }
            grouped.push(ch);
        }
        grouped.push_str(fraction);
        grouped
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use rust_decimal_macros::dec;

    /// The observed precision grows to the largest scale seen and never
    /// shrinks; unseen currencies have no precision.
    #[test]
    fn test_update_and_get_precision() {
        let mut context = DisplayContext::new();
        let steps = [(dec!(100), 0), (dec!(50.25), 2), (dec!(1), 2)];
        for (number, expected) in steps {
            context.update(number, "USD");
            assert_eq!(context.get_precision("USD"), Some(expected));
        }
        assert_eq!(context.get_precision("EUR"), None);
    }

    /// Numbers are padded to the precision learned for the currency.
    #[test]
    fn test_format_with_precision() {
        let mut context = DisplayContext::new();
        for seen in [dec!(100), dec!(50.25)] {
            context.update(seen, "USD");
        }
        let cases = [
            (dec!(100), "100.00"),
            (dec!(50.25), "50.25"),
            (dec!(7.5), "7.50"),
        ];
        for (number, expected) in cases {
            assert_eq!(context.format(number, "USD"), expected);
        }
    }

    /// Without any known precision the number is printed as is.
    #[test]
    fn test_format_unknown_currency() {
        let context = DisplayContext::new();
        for (number, expected) in [(dec!(100), "100"), (dec!(50.25), "50.25")] {
            assert_eq!(context.format(number, "EUR"), expected);
        }
    }

    /// A fixed precision takes priority over the observed one.
    #[test]
    fn test_fixed_precision_override() {
        let mut context = DisplayContext::new();
        for seen in [dec!(100), dec!(50.25)] {
            context.update(seen, "USD");
        }
        assert_eq!(context.get_precision("USD"), Some(2));

        context.set_fixed_precision("USD", 4);
        assert_eq!(context.get_precision("USD"), Some(4));
        assert_eq!(context.format(dec!(100), "USD"), "100.0000");
    }

    /// Thousands separators appear once comma rendering is switched on.
    #[test]
    fn test_render_commas() {
        let mut context = DisplayContext::new();
        context.set_render_commas(true);
        context.update(dec!(1234567.89), "USD");
        for (number, expected) in [(dec!(1234567.89), "1,234,567.89"), (dec!(1000), "1,000.00")] {
            assert_eq!(context.format(number, "USD"), expected);
        }
    }

    /// Direct checks of the grouping helper, including sign and fraction.
    #[test]
    fn test_add_commas() {
        let cases = [
            ("1234567", "1,234,567"),
            ("1234567.89", "1,234,567.89"),
            ("-1234567.89", "-1,234,567.89"),
            ("123", "123"),
            ("1", "1"),
        ];
        for (input, expected) in cases {
            assert_eq!(DisplayContext::add_commas(input), expected);
        }
    }

    /// Merging keeps the larger precision and picks up new currencies.
    #[test]
    fn test_update_from() {
        let mut target = DisplayContext::new();
        target.update(dec!(100), "USD");

        let mut source = DisplayContext::new();
        source.update(dec!(50.25), "USD");
        source.update(dec!(1.5), "EUR");

        target.update_from(&source);
        for (currency, expected) in [("USD", 2), ("EUR", 1)] {
            assert_eq!(target.get_precision(currency), Some(expected));
        }
    }

    /// The amount form appends the currency code after a space.
    #[test]
    fn test_format_amount() {
        let mut context = DisplayContext::new();
        context.update(dec!(50.25), "USD");
        assert_eq!(context.format_amount(dec!(100), "USD"), "100.00 USD");
    }
}