use modkit_macros::domain_model;
/// Fixed-point scale for multipliers: each per-side product `tokens * mult`
/// is divided by this value (rounding up) in `credits_micro_checked`, so a
/// multiplier equal to `DIVISOR` charges one unit per token.
#[allow(dead_code)]
pub const DIVISOR: u64 = 10_u64.pow(6);

/// Largest token count (inclusive) accepted for either the input or the
/// output side; `credits_micro_checked` rejects larger values with
/// `TokensOverflow`.
#[allow(dead_code)]
pub const MAX_TOKENS: u64 = 10_u64.pow(7);

/// Largest multiplier (inclusive) accepted for either side;
/// `credits_micro_checked` rejects larger values with `MultiplierOverflow`.
/// Note `MAX_TOKENS * MAX_MULT` = 10^17, well below `u64::MAX` (~1.8 × 10^19).
#[allow(dead_code)]
pub const MAX_MULT: u64 = 10_u64.pow(10);
#[domain_model]
#[allow(dead_code, clippy::enum_variant_names)]
#[derive(Debug, thiserror::Error)]
pub enum CreditOverflowError {
#[error("tokens {0} exceed MAX_TOKENS {MAX_TOKENS}")]
TokensOverflow(u64),
#[error("multiplier {0} exceeds MAX_MULT {MAX_MULT}")]
MultiplierOverflow(u64),
#[error("arithmetic overflow in checked_mul")]
ArithmeticOverflow,
}
/// Returns `⌈a / b⌉`, i.e. integer division rounded toward positive infinity.
///
/// Computed as `a / b`, plus one when `a % b` is non-zero. Unlike the
/// `(a + b - 1) / b` formulation, this never forms an intermediate larger than
/// `a`, so it is correct for every `a` up to `u64::MAX` (e.g. `⌈u64::MAX / 2⌉`
/// is `2^63`, not an overflow).
///
/// # Errors
///
/// Returns [`CreditOverflowError::ArithmeticOverflow`] only if the final
/// round-up increment overflows. That cannot happen, because a non-zero
/// remainder implies `b >= 2`, which bounds the quotient by `u64::MAX / 2`.
/// The `Result` is kept so callers can continue to propagate with `?`.
///
/// # Panics
///
/// Panics in debug builds (via `debug_assert!`) when `b == 0`. In release
/// builds a zero divisor returns `Ok(0)` instead of dividing.
#[allow(dead_code, clippy::integer_division)]
pub fn ceil_div_checked(a: u64, b: u64) -> Result<u64, CreditOverflowError> {
    debug_assert!(b != 0, "ceil_div_checked: divisor must be non-zero");
    if b == 0 {
        // Release-build fallback for a zero divisor, preserved from the original contract.
        return Ok(0);
    }
    let quotient = a / b;
    // `bool -> u64` is 0 or 1: round up only when the division was inexact.
    let round_up = u64::from(a % b != 0);
    quotient
        .checked_add(round_up)
        .ok_or(CreditOverflowError::ArithmeticOverflow)
}
#[allow(dead_code)]
pub fn credits_micro_checked(
input_tokens: u64,
output_tokens: u64,
input_mult: u64,
output_mult: u64,
) -> Result<i64, CreditOverflowError> {
if input_tokens > MAX_TOKENS {
return Err(CreditOverflowError::TokensOverflow(input_tokens));
}
if output_tokens > MAX_TOKENS {
return Err(CreditOverflowError::TokensOverflow(output_tokens));
}
if input_mult > MAX_MULT {
return Err(CreditOverflowError::MultiplierOverflow(input_mult));
}
if output_mult > MAX_MULT {
return Err(CreditOverflowError::MultiplierOverflow(output_mult));
}
let input_product = input_tokens
.checked_mul(input_mult)
.ok_or(CreditOverflowError::ArithmeticOverflow)?;
let output_product = output_tokens
.checked_mul(output_mult)
.ok_or(CreditOverflowError::ArithmeticOverflow)?;
let input_credits = ceil_div_checked(input_product, DIVISOR)?;
let output_credits = ceil_div_checked(output_product, DIVISOR)?;
let total = input_credits
.checked_add(output_credits)
.ok_or(CreditOverflowError::ArithmeticOverflow)?;
#[allow(clippy::cast_possible_wrap)]
Ok(total as i64)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn normal_computation() {
        // 1000 tokens at 1x plus 500 tokens at 3x (multipliers scaled by DIVISOR).
        let result = credits_micro_checked(1000, 500, 1_000_000, 3_000_000).unwrap();
        assert_eq!(result, 2500);
    }

    #[test]
    fn zero_tokens() {
        assert_eq!(
            credits_micro_checked(0, 0, 1_000_000, 3_000_000).unwrap(),
            0
        );
    }

    #[test]
    fn rounding_each_component_ceil_div_independently() {
        // Each side rounds 1/DIVISOR up to 1 on its own, so the total is 2, not 1.
        let result = credits_micro_checked(1, 1, 1, 1).unwrap();
        assert_eq!(result, 2);
    }

    #[test]
    fn overflow_tokens_exceeds_max() {
        let result = credits_micro_checked(MAX_TOKENS + 1, 0, 1, 1);
        assert!(matches!(
            result,
            Err(CreditOverflowError::TokensOverflow(_))
        ));
    }

    #[test]
    fn overflow_output_tokens_exceeds_max() {
        let result = credits_micro_checked(0, MAX_TOKENS + 1, 1, 1);
        assert!(matches!(
            result,
            Err(CreditOverflowError::TokensOverflow(v)) if v == MAX_TOKENS + 1
        ));
    }

    #[test]
    fn overflow_mult_exceeds_max() {
        let result = credits_micro_checked(1, 0, MAX_MULT + 1, 1);
        assert!(matches!(
            result,
            Err(CreditOverflowError::MultiplierOverflow(_))
        ));
    }

    #[test]
    fn overflow_output_mult_exceeds_max() {
        let result = credits_micro_checked(0, 1, 1, MAX_MULT + 1);
        assert!(matches!(
            result,
            Err(CreditOverflowError::MultiplierOverflow(v)) if v == MAX_MULT + 1
        ));
    }

    #[test]
    fn token_limits_checked_before_multiplier_limits() {
        let result = credits_micro_checked(MAX_TOKENS + 1, 0, MAX_MULT + 1, 1);
        assert!(matches!(
            result,
            Err(CreditOverflowError::TokensOverflow(_))
        ));
    }

    #[test]
    fn max_bounds_no_overflow() {
        let result = credits_micro_checked(MAX_TOKENS, MAX_TOKENS, MAX_MULT, MAX_MULT).unwrap();
        assert_eq!(result, 200_000_000_000);
    }

    #[test]
    fn ceil_div_checked_rounds_up() {
        assert_eq!(ceil_div_checked(0, 5).unwrap(), 0);
        assert_eq!(ceil_div_checked(10, 5).unwrap(), 2);
        assert_eq!(ceil_div_checked(11, 5).unwrap(), 3);
        assert_eq!(ceil_div_checked(1, DIVISOR).unwrap(), 1);
    }
}