use bytemuck::{Pod, Zeroable};
use num_traits::{Num, One, Zero};
use std::fmt;
use std::fmt::Display;
use std::ops::{
Add, AddAssign, BitAnd, BitAndAssign, BitOr, BitOrAssign, BitXor, BitXorAssign, Div, DivAssign,
Mul, MulAssign, Rem, RemAssign, Shl, ShlAssign, Shr, ShrAssign, Sub, SubAssign,
};
use std::{
ops::{Neg, Not},
str::FromStr,
};
#[cfg(feature = "pyo3")]
use pyo3::prelude::*;
/// Reasons an `i48` could not be produced from text.
#[derive(Debug, PartialEq, Eq)]
pub enum I48Error {
    /// The text was rejected by the underlying `i64` parser.
    ParseError(std::num::ParseIntError),
    /// The text parsed as an `i64` but lies outside the 48-bit signed range.
    OutOfRange,
}

impl std::fmt::Display for I48Error {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::ParseError(inner) => write!(f, "Parse error: {}", inner),
            Self::OutOfRange => f.write_str("Value out of range for i48"),
        }
    }
}

impl std::error::Error for I48Error {
    /// Exposes the wrapped `ParseIntError`; range failures have no cause.
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        if let Self::ParseError(inner) = self {
            Some(inner)
        } else {
            None
        }
    }
}

impl From<std::num::ParseIntError> for I48Error {
    /// Wraps an `i64` parse failure so `?` can be used while parsing.
    fn from(err: std::num::ParseIntError) -> Self {
        Self::ParseError(err)
    }
}
#[allow(non_camel_case_types)]
#[repr(C)]
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
#[cfg_attr(feature = "pyo3", pyclass)]
pub struct i48 {
pub data: [u8; 6],
}
impl i48 {
    /// Smallest value an `i48` can hold, as an `i64` (-2^47).
    const MIN_I64: i64 = -(1 << 47);
    /// Largest value an `i48` can hold, as an `i64` (2^47 - 1).
    const MAX_I64: i64 = (1 << 47) - 1;

    /// Returns the value sign-extended to an `i64`.
    ///
    /// Bit 47 (the top bit of `data[5]`) is the sign bit; when it is set the
    /// upper 16 bits of the result are filled with ones.
    pub const fn to_i64(self) -> i64 {
        let [a, b, c, d, e, f] = self.data;
        let value = i64::from_le_bytes([a, b, c, d, e, f, 0, 0]);
        if value & 0x8000_0000_0000 != 0 {
            value | 0xFFFF_0000_0000_0000u64 as i64
        } else {
            value
        }
    }

    /// Builds an `i48` from the low 48 bits of `n`.
    ///
    /// Values outside the 48-bit range are truncated (they wrap), not
    /// rejected; the `checked_*` methods use a range-checked conversion instead.
    pub const fn from_i64(n: i64) -> Self {
        let bytes = n.to_le_bytes();
        Self {
            data: [bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5]],
        }
    }

    /// Converts `n` only if it lies within the `i48` range; otherwise `None`.
    fn checked_from_i64(n: i64) -> Option<Self> {
        if (Self::MIN_I64..=Self::MAX_I64).contains(&n) {
            Some(Self::from_i64(n))
        } else {
            None
        }
    }

    /// Builds an `i48` from six bytes in the host's native byte order.
    pub const fn from_ne_bytes(bytes: [u8; 6]) -> Self {
        if cfg!(target_endian = "big") {
            Self::from_be_bytes(bytes)
        } else {
            Self::from_le_bytes(bytes)
        }
    }

    /// Builds an `i48` from six little-endian bytes (least significant first).
    pub const fn from_le_bytes(bytes: [u8; 6]) -> Self {
        Self { data: bytes }
    }

    /// Builds an `i48` from six big-endian bytes (most significant first).
    pub const fn from_be_bytes(bytes: [u8; 6]) -> Self {
        let [a, b, c, d, e, f] = bytes;
        Self {
            data: [f, e, d, c, b, a],
        }
    }

    /// Checked addition: `None` if the exact sum does not fit in 48 bits.
    pub fn checked_add(self, other: Self) -> Option<Self> {
        self.to_i64()
            .checked_add(other.to_i64())
            .and_then(Self::checked_from_i64)
    }

    /// Checked subtraction: `None` if the exact difference does not fit in 48 bits.
    pub fn checked_sub(self, other: Self) -> Option<Self> {
        self.to_i64()
            .checked_sub(other.to_i64())
            .and_then(Self::checked_from_i64)
    }

    /// Checked multiplication: `None` if the exact product does not fit in 48 bits.
    pub fn checked_mul(self, other: Self) -> Option<Self> {
        // The i64 step can itself overflow (48-bit * 48-bit needs up to 95 bits);
        // both that and a merely out-of-48-bit-range product yield `None`.
        self.to_i64()
            .checked_mul(other.to_i64())
            .and_then(Self::checked_from_i64)
    }

    /// Checked division: `None` if `other` is zero or the quotient does not
    /// fit in 48 bits (the `MIN / -1` case).
    pub fn checked_div(self, other: Self) -> Option<Self> {
        self.to_i64()
            .checked_div(other.to_i64())
            .and_then(Self::checked_from_i64)
    }
}
impl From<i64> for i48 {
fn from(value: i64) -> Self {
i48::from_i64(value)
}
}
impl From<i48> for i64 {
fn from(value: i48) -> Self {
value.to_i64()
}
}
// SAFETY: `i48` is `#[repr(C)]` with a single `[u8; 6]` field, so it has no
// padding and every bit pattern (including all zeros) is a valid value. It is
// also `Copy` and holds no references, which are the remaining conditions
// bytemuck documents for `Zeroable` and `Pod`.
unsafe impl Zeroable for i48 {}
unsafe impl Pod for i48 {}
impl One for i48 {
    /// The multiplicative identity, 1 (little-endian: low byte set).
    fn one() -> Self {
        Self::from_le_bytes([1, 0, 0, 0, 0, 0])
    }
}
impl Zero for i48 {
    /// The additive identity: all six bytes cleared.
    fn zero() -> Self {
        Self::from_le_bytes([0; 6])
    }

    /// True when every stored byte is zero, which is the only encoding of 0.
    fn is_zero(&self) -> bool {
        self.data.iter().all(|&byte| byte == 0)
    }
}
impl Num for i48 {
    type FromStrRadixErr = I48Error;

    /// Parses `str` in base `radix` and accepts the result only if it fits
    /// in 48 bits.
    ///
    /// # Errors
    /// * `I48Error::ParseError` if `i64::from_str_radix` rejects the text.
    /// * `I48Error::OutOfRange` if the text is a valid `i64` outside
    ///   `-2^47..=2^47 - 1`.
    fn from_str_radix(str: &str, radix: u32) -> Result<Self, Self::FromStrRadixErr> {
        let parsed = match i64::from_str_radix(str, radix) {
            Ok(value) => value,
            Err(cause) => return Err(I48Error::ParseError(cause)),
        };
        let in_range = (-140737488355328..=140737488355327).contains(&parsed);
        if in_range {
            Ok(i48::from_i64(parsed))
        } else {
            Err(I48Error::OutOfRange)
        }
    }
}
#[cfg(feature = "pyo3")]
use numpy::Element;
// Lets `i48` be used as a numpy array element type when the `pyo3` feature is
// enabled.
//
// SAFETY(review): `Element` is an `unsafe` trait; the dtype returned below must
// describe `i48`'s real memory layout (6 bytes, `#[repr(C)]`) to numpy.
// Confirm that it does.
// NOTE(review): `numpy::dtype_bound::<T>` appears to be implemented in the
// numpy crate as a call to `T::get_dtype_bound`. If so, this method calls
// itself until the stack overflows. Verify against the numpy crate version in
// use and return a concrete 6-byte dtype instead.
#[cfg(feature = "pyo3")]
unsafe impl Element for i48 {
    // `i48` derives `Copy`, which is presumably why elements are declared copyable.
    const IS_COPY: bool = true;
    fn get_dtype_bound(py: Python<'_>) -> Bound<'_, numpy::PyArrayDescr> {
        numpy::dtype_bound::<i48>(py)
    }
}
impl Add for i48 {
    type Output = Self;

    /// Adds two values, wrapping around on 48-bit overflow.
    fn add(self, other: Self) -> Self {
        let (lhs, rhs) = (self.to_i64(), other.to_i64());
        // The sum is formed in i64; `from_i64` keeps only the low 48 bits.
        Self::from_i64(lhs.wrapping_add(rhs))
    }
}
impl Sub for i48 {
type Output = Self;
fn sub(self, other: Self) -> Self {
let result = self.to_i64().wrapping_sub(other.to_i64());
Self::from_i64(result)
}
}
impl Mul for i48 {
    type Output = Self;

    /// Multiplies, keeping only the low 48 bits of the product (wrapping).
    fn mul(self, other: Self) -> Self {
        // Two 48-bit factors can overflow even i64, so that step must wrap too.
        Self::from_i64(i64::wrapping_mul(self.to_i64(), other.to_i64()))
    }
}
impl Div for i48 {
    type Output = Self;

    /// Truncating division with 48-bit wrap-around (`MIN / -1 == MIN`).
    ///
    /// # Panics
    /// Panics if `other` is zero, as `i64::wrapping_div` does.
    fn div(self, other: Self) -> Self {
        let (dividend, divisor) = (self.to_i64(), other.to_i64());
        Self::from_i64(dividend.wrapping_div(divisor))
    }
}
impl Rem for i48 {
    type Output = Self;

    /// Remainder of truncating division; the result takes the sign of `self`.
    ///
    /// # Panics
    /// Panics if `other` is zero, as `i64::wrapping_rem` does.
    fn rem(self, other: Self) -> Self {
        let remainder = i64::wrapping_rem(self.to_i64(), other.to_i64());
        Self::from_i64(remainder)
    }
}
impl Neg for i48 {
type Output = Self;
fn neg(self) -> Self {
let i64_result = self.to_i64().wrapping_neg();
i48::from_i64(i64_result)
}
}
impl Not for i48 {
type Output = Self;
fn not(self) -> Self {
let i64_result = !self.to_i64();
i48::from_i64(i64_result)
}
}
impl BitAnd for i48 {
    type Output = Self;

    /// Bitwise AND, computed byte by byte over the stored representation.
    fn bitand(self, rhs: Self) -> Self::Output {
        let mut data = self.data;
        for (byte, other) in data.iter_mut().zip(rhs.data) {
            *byte &= other;
        }
        Self { data }
    }
}
impl BitOr for i48 {
    type Output = Self;

    /// Bitwise OR, computed byte by byte over the stored representation.
    fn bitor(self, rhs: Self) -> Self::Output {
        let mut data = self.data;
        for (byte, other) in data.iter_mut().zip(rhs.data) {
            *byte |= other;
        }
        Self { data }
    }
}
impl BitXor for i48 {
    type Output = Self;

    /// Bitwise XOR, computed byte by byte over the stored representation.
    fn bitxor(self, rhs: Self) -> Self::Output {
        let mut data = self.data;
        for (byte, other) in data.iter_mut().zip(rhs.data) {
            *byte ^= other;
        }
        Self { data }
    }
}
impl Shl<u32> for i48 {
    type Output = Self;

    /// Shifts left by `rhs` bits, discarding bits shifted past bit 47.
    ///
    /// Shifting a 1 into bit 47 makes the result negative, as for built-in
    /// integers. A shift of 48 or more clears every bit and yields zero. This
    /// also holds for shifts of 64 or more, which would overflow a raw `i64` shift.
    fn shl(self, rhs: u32) -> Self::Output {
        // `from_i64` keeps only the low 48 bits, which both truncates the
        // shifted-out bits and re-establishes bit 47 as the sign bit.
        Self::from_i64(self.to_i64().checked_shl(rhs).unwrap_or(0))
    }
}
impl Shr<u32> for i48 {
    type Output = Self;

    /// Arithmetic (sign-filling) right shift by `rhs` bits.
    ///
    /// Shifting by 47 or more leaves only copies of the sign bit: `-1` for
    /// negative values and `0` otherwise. This holds for any `rhs`, with no panic.
    fn shr(self, rhs: u32) -> Self::Output {
        // `to_i64` is already sign-extended, so i64's arithmetic shift fills
        // with the correct sign. Clamping to 63 avoids i64 shift overflow and
        // cannot change the result, because bits 47..=63 are all sign copies.
        Self::from_i64(self.to_i64() >> rhs.min(63))
    }
}
impl Display for i48 {
    /// Formats the value as a signed decimal integer.
    ///
    /// Delegates to `i64`'s `Display` using the caller's formatter, so width,
    /// fill, alignment, `+` and zero-padding flags are honored as they are for
    /// built-in integers (e.g. `format!("{:>6}", x)`, `format!("{:+}", x)`).
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        Display::fmt(&self.to_i64(), f)
    }
}
impl FromStr for i48 {
    type Err = I48Error;

    /// Parses a base-10 integer and accepts it only if it fits in 48 bits.
    ///
    /// # Errors
    /// * `I48Error::ParseError` if `s` is not a valid `i64` literal.
    /// * `I48Error::OutOfRange` if the value lies outside `-2^47..=2^47 - 1`.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        const LOWEST: i64 = -140737488355328;
        const HIGHEST: i64 = 140737488355327;
        match s.parse::<i64>()? {
            value @ LOWEST..=HIGHEST => Ok(i48::from_i64(value)),
            _ => Err(I48Error::OutOfRange),
        }
    }
}
/// Implements compound-assignment operators (`+=`, `<<=`, ...) for `i48`
/// with an `i48` right-hand side.
///
/// Each operator runs on the sign-extended operands widened to `i128`, and the
/// low 48 bits of the result are stored back into `self`. Widening rules out
/// intermediate overflow (even `2^47 * 2^47` fits), so the arithmetic and
/// bitwise variants wrap the same way as the binary operators. Division or
/// remainder by zero still panics. Shift amounts outside `0..128` follow
/// `i128`'s shift-overflow behavior.
macro_rules! implement_ops_assign {
    ($($trait_path:path { $($function_name:ident),* }),*) => {
        $(
            impl $trait_path for i48 {
                $(
                    fn $function_name(&mut self, other: Self) {
                        let mut wide = i128::from(self.to_i64());
                        wide.$function_name(i128::from(other.to_i64()));
                        // Truncation is intended: `from_i64` keeps only the
                        // low 48 bits, i.e. the wrapped result.
                        *self = Self::from_i64(wide as i64);
                    }
                )*
            }
        )*
    };
}
/// Generates compound-assignment impls involving a borrowed `&i48`.
///
/// For each `Trait { method }` pair, this emits two impls:
///
/// * `Trait<&i48> for i48`: `a op= &b` forwards to the by-value
///   `Trait<i48> for i48` impl, so it behaves exactly like `a op= *b`.
/// * `Trait for &i48`: kept only so existing code that names this impl still
///   compiles. The receiver is a shared reference, so there is no `i48` it
///   could legitimately modify, and the operation is a deliberate no-op.
macro_rules! implement_ops_assign_ref {
    ($($trait_name:ident { $($function_name:ident),* }),* $(,)?) => {
        $(
            impl<'a> $trait_name<&'a i48> for i48 {
                $(
                    fn $function_name(&mut self, other: &'a i48) {
                        <i48 as $trait_name>::$function_name(self, *other);
                    }
                )*
            }

            impl $trait_name for &i48 {
                $(
                    // Nothing can be mutated through `&i48`; intentionally empty.
                    fn $function_name(&mut self, _other: Self) {}
                )*
            }
        )*
    };
}
// Instantiate `implement_ops_assign!` for every compound-assignment operator,
// grouped as shifts, bitwise operators, then arithmetic operators.
implement_ops_assign!(
    ShlAssign { shl_assign },
    ShrAssign { shr_assign },
    BitAndAssign { bitand_assign },
    BitOrAssign { bitor_assign },
    BitXorAssign { bitxor_assign },
    AddAssign { add_assign },
    SubAssign { sub_assign },
    MulAssign { mul_assign },
    DivAssign { div_assign },
    RemAssign { rem_assign }
);
// Instantiate `implement_ops_assign_ref!` for every compound-assignment operator,
// grouped as shifts, bitwise operators, then arithmetic operators.
implement_ops_assign_ref!(
    ShlAssign { shl_assign },
    ShrAssign { shr_assign },
    BitAndAssign { bitand_assign },
    BitOrAssign { bitor_assign },
    BitXorAssign { bitxor_assign },
    AddAssign { add_assign },
    SubAssign { sub_assign },
    MulAssign { mul_assign },
    DivAssign { div_assign },
    RemAssign { rem_assign }
);
/// Unit tests for `i48`: byte conversions, operators, parsing and formatting.
#[cfg(test)]
mod i48_tests {
    use super::*;

    /// Largest value representable in 48 bits (2^47 - 1).
    const MAX: i64 = 140737488355327;
    /// Smallest value representable in 48 bits (-2^47).
    const MIN: i64 = -140737488355328;

    #[test]
    fn test_arithmetic_operations() {
        let a = i48::from_i64(100);
        let b = i48::from_i64(50);
        assert_eq!((a + b).to_i64(), 150);
        assert_eq!((a - b).to_i64(), 50);
        assert_eq!((a * b).to_i64(), 5000);
        assert_eq!((a / b).to_i64(), 2);
        assert_eq!((a % b).to_i64(), 0);
    }

    #[test]
    fn test_bitwise_operations() {
        let a = i48::from_i64(0b101010);
        let b = i48::from_i64(0b110011);
        assert_eq!((a & b).to_i64(), 0b100010);
        assert_eq!((a | b).to_i64(), 0b111011);
        assert_eq!((a ^ b).to_i64(), 0b011001);
        assert_eq!((a << 2).to_i64(), 0b10101000);
        assert_eq!((a >> 2).to_i64(), 0b1010);
    }

    #[test]
    fn test_unary_operations() {
        let a = i48::from_i64(100);
        assert_eq!((-a).to_i64(), -100);
        assert_eq!((!a).to_i64(), -101);
    }

    #[test]
    fn test_from_i64() {
        assert_eq!(i48::from_i64(0).to_i64(), 0);
        assert_eq!(i48::from_i64(MAX).to_i64(), MAX);
        assert_eq!(i48::from_i64(MIN).to_i64(), MIN);
    }

    #[test]
    fn test_from_bytes() {
        assert_eq!(
            i48::from_ne_bytes([0x01, 0x02, 0x03, 0x04, 0x05, 0x06]).to_i64(),
            if cfg!(target_endian = "big") {
                0x010203040506
            } else {
                0x060504030201
            }
        );
        assert_eq!(
            i48::from_le_bytes([0x01, 0x02, 0x03, 0x04, 0x05, 0x06]).to_i64(),
            0x060504030201
        );
        assert_eq!(
            i48::from_be_bytes([0x01, 0x02, 0x03, 0x04, 0x05, 0x06]).to_i64(),
            0x010203040506
        );
    }

    #[test]
    fn test_to_i64() {
        // Explicit little-endian input keeps this test independent of host byte
        // order; with `from_ne_bytes` these arrays mean other values on big-endian.
        let a = i48::from_le_bytes([0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0x7F]);
        assert_eq!(a.to_i64(), MAX);
        let b = i48::from_le_bytes([0x00, 0x00, 0x00, 0x00, 0x00, 0x80]);
        assert_eq!(b.to_i64(), MIN);
    }

    #[test]
    fn test_zero_and_one() {
        assert_eq!(i48::zero().to_i64(), 0);
        assert_eq!(i48::one().to_i64(), 1);
    }

    #[test]
    fn test_from_str() {
        assert_eq!(i48::from_str("100").unwrap().to_i64(), 100);
        assert_eq!(i48::from_str("-100").unwrap().to_i64(), -100);
        assert_eq!(i48::from_str("140737488355327").unwrap().to_i64(), MAX);
        assert_eq!(i48::from_str("-140737488355328").unwrap().to_i64(), MIN);
        assert_eq!(
            i48::from_str("140737488355328").unwrap_err(),
            I48Error::OutOfRange
        );
        assert_eq!(
            i48::from_str("-140737488355329").unwrap_err(),
            I48Error::OutOfRange
        );
    }

    #[test]
    fn test_display() {
        assert_eq!(format!("{}", i48::from_i64(100)), "100");
        assert_eq!(format!("{}", i48::from_i64(-100)), "-100");
    }

    #[test]
    fn test_wrapping_behavior() {
        let max = i48::from_i64(MAX);
        assert_eq!((max + i48::one()).to_i64(), MIN);
        let min = i48::from_i64(MIN);
        assert_eq!((min - i48::one()).to_i64(), MAX);
    }

    #[test]
    fn test_shift_operations() {
        let a = i48::from_i64(0b1);
        assert_eq!((a << 47).to_i64(), MIN);
        assert_eq!((a << 48).to_i64(), 0);

        let b = i48::from_i64(-1);
        assert_eq!((b >> 1).to_i64(), -1);
        assert_eq!((b >> 47).to_i64(), -1);
        assert_eq!((b >> 48).to_i64(), -1);

        let c = i48::from_i64(0x7FFFFFFFFFFF);
        assert_eq!((c << 1).to_i64(), -2);
        assert_eq!((c << 2).to_i64(), -4);
        assert_eq!((c << 3).to_i64(), -8);

        let d = i48::from_i64(-0x800000000000);
        assert_eq!((d >> 1).to_i64(), -70368744177664);
    }
}