multitype 0.21.1

Arithmetic type traits.
Documentation
// Copyright 2025-2026 Gabriel Bjørnager Jensen.
//
// SPDX: MIT OR Apache-2.0

//! The [`CarryingMulAdd`] trait.

mod test;

use multitype::Integral;

/// Implementation detail for `carrying_mul_add`.
///
/// The implementations in this module compute
/// `self * mul + add + carry` at twice the width of
/// `Self`, so the operation itself cannot overflow.
pub(crate) trait CarryingMulAdd<Other = Self>: Integral {
	/// Performs a carrying multiply-add.
	///
	/// Returns `(low, high)`: the least-significant half
	/// of the double-width result as `Self::Unsigned`,
	/// and the most-significant half as `Self` (signed
	/// for signed types, so negative results yield a
	/// negative `high`).
	#[must_use]
	fn carrying_mul_add(
		self,
		mul:   Other,
		add:   Other,
		carry: Other,
	) -> (Self::Unsigned, Self);
}

/// Implements [`CarryingMulAdd`] for each listed
/// primitive by widening it to the paired `Wide` type.
///
/// All operands are converted to `Wide`, the multiply-
/// add is evaluated there, and the result is split into
/// a low half (truncated to `Self::Unsigned`) and a high
/// half (shifted down by `Self::BITS`).
macro_rules! impl_carrying_mul_add {
	{ $($T:ty { Wide = $W:ty$(,)? }),*$(,)? } => {$(
		impl ::multitype::CarryingMulAdd for $T {
			#[inline(always)]
			#[track_caller]
			fn carrying_mul_add(
				self,
				mul:   Self,
				add:   Self,
				carry: Self,
			) -> (Self::Unsigned, Self) {
				// Provided `Wide` is twice as wide as `Self`,
				// neither the product nor the sum overflows.
				let product: $W = (self as $W) * (mul as $W);
				let sum:     $W = product + (add as $W) + (carry as $W);

				// Truncating keeps the low half; the right
				// shift (arithmetic when `Wide` is signed)
				// isolates the high half.
				(sum as Self::Unsigned, (sum >> Self::BITS) as Self)
			}
		}
	)*};
}

// Fixed-width primitives, each paired with the
// primitive of twice its width.
impl_carrying_mul_add! {
	u8  { Wide = u16  },
	i8  { Wide = i16  },

	u16 { Wide = u32  },
	i16 { Wide = i32  },

	u32 { Wide = u64  },
	i32 { Wide = i64  },

	u64 { Wide = u128 },
	i64 { Wide = i128 },
}

// Pointer-sized primitives: the double-width type is
// chosen from the target's pointer width.

#[cfg(target_pointer_width = "16")]
impl_carrying_mul_add! { usize { Wide = u32 }, isize { Wide = i32 } }

#[cfg(target_pointer_width = "32")]
impl_carrying_mul_add! { usize { Wide = u64 }, isize { Wide = i64 } }

#[cfg(target_pointer_width = "64")]
impl_carrying_mul_add! { usize { Wide = u128 }, isize { Wide = i128 } }

impl CarryingMulAdd for u128 {
	/// Computes `self * mul + add + carry` exactly as a
	/// 256-bit value, returned as `(low, high)` 128-bit
	/// halves.
	///
	/// The result always fits: with every operand at
	/// `u128::MAX` it is exactly `2^256 - 1`.
	#[inline]
	fn carrying_mul_add(
		self,
		mul:   Self,
		add:   Self,
		carry: Self,
	) -> (Self::Unsigned, Self) {
		// Width, in bits, of each 64-bit limb.
		const HALF: u32 = u128::BITS / 2;

		// Splits `value` into its least- and most-signi-
		// ficant 64-bit halves, each zero-extended to
		// `u128`.
		#[inline]
		#[must_use]
		const fn u128_to_parts(value: u128) -> (u128, u128) {
			let low  = value & u64::MAX as u128;
			let high = value >> u64::BITS;

			(low, high)
		}

		// Schoolbook multiplication on 64-bit limbs.
		//
		// Any 128-bit integer `i` can be written as its
		// least-significant half `i0` plus its most-signi-
		// ficant half `i1` scaled by `2 ^ 64`:
		//
		// ```
		// i = i0 + i1 * 2 ^ 64
		//   = i0 + (i1 << 64)
		// ```
		//
		// The full 256-bit product of two such integers
		// `a` and `b` therefore expands into four partial
		// products, each of which fits in 128 bits:
		//
		// ```
		// a * b = (a0 + (a1 << 64)) * (b0 + (b1 << 64))
		//       = (a0 * b0)
		//         + ((a0 * b1) << 64)
		//         + ((a1 * b0) << 64)
		//         + ((a1 * b1) << 128)
		// ```
		//
		// (Karatsuba's algorithm would need only three
		// multiplications, but is not used here.)
		//
		// The product, the addend and the carry are accu-
		// mulated into a 256-bit value represented by the
		// two 128-bit halves `low` and `high`.

		let lhs = u128_to_parts(self);
		let rhs = u128_to_parts(mul);

		// Partial products. None of these can overflow, as
		// every factor has at most 64 significant bits.
		let tmp0 = lhs.0.wrapping_mul(rhs.0);
		let tmp1 = lhs.0.wrapping_mul(rhs.1);
		let tmp2 = lhs.1.wrapping_mul(rhs.0);
		let tmp3 = lhs.1.wrapping_mul(rhs.1);

		// Least-significant half: `tmp0`, the lower 64 bits
		// of each cross term (shifted into the upper 64
		// bits of `low`), the addend and the carry. Every
		// addition may carry out into `high`, so each
		// carry flag is kept.
		let (low, carry0) = tmp0.overflowing_add(tmp1 << HALF);
		let (low, carry1) = low.overflowing_add(tmp2 << HALF);
		let (low, carry2) = low.overflowing_add(add);
		let (low, carry3) = low.overflowing_add(carry);

		// Most-significant half: `tmp3`, the upper 64 bits
		// of each cross term, and the four saved carries.
		// As the full result fits in 256 bits, none of
		// these additions can overflow.
		let high = tmp3
			.wrapping_add(tmp1 >> HALF)
			.wrapping_add(tmp2 >> HALF)
			.wrapping_add(u128::from(carry0))
			.wrapping_add(u128::from(carry1))
			.wrapping_add(u128::from(carry2))
			.wrapping_add(u128::from(carry3));

		(low, high)
	}
}

impl CarryingMulAdd for i128 {
	/// Computes `self * mul + add + carry` exactly as a
	/// signed 256-bit value, returned as `(low, high)`:
	/// `low` is the least-significant 128 bits (as
	/// unsigned) and `high` the most-significant 128
	/// bits (as signed, so a negative result has a neg-
	/// ative `high`).
	///
	/// This agrees with the macro-generated signed impl-
	/// ementations (e.g. `i64` widened to `i128`), whose
	/// high half comes from an arithmetic right shift.
	#[inline]
	fn carrying_mul_add(
		self,
		mul:   Self,
		add:   Self,
		carry: Self,
	) -> (Self::Unsigned, Self) {
		// Reinterpreting a negative two's complement value
		// `x` as `u128` yields `x + 2^128`. For same-width
		// arithmetic that difference vanishes, but a
		// *widening* multiplication keeps part of it.
		// Writing `sa`/`sb` for the sign bits of `a`/`b`
		// and `au`/`bu` for their unsigned reinterpre-
		// tations:
		//
		// ```
		// au * bu == a * b + ((sa * bu + sb * au) << 128)  (mod 2^256)
		// ```
		//
		// Hence the unsigned product's high half must have
		// `bu` subtracted when `a` is negative, and `au`
		// subtracted when `b` is negative. Likewise, a
		// negative addend sign-extends to all-ones in the
		// high half, i.e. contributes `-1` there.
		//
		// The low half is unaffected by these corrections.

		let lhs = self as u128;
		let rhs = mul as u128;

		let (low, high) = CarryingMulAdd::carrying_mul_add(
			lhs,
			rhs,
			add   as u128,
			carry as u128,
		);

		let mut high = high;

		if self < 0 {
			high = high.wrapping_sub(rhs);
		}

		if mul < 0 {
			high = high.wrapping_sub(lhs);
		}

		high = high.wrapping_sub(u128::from(add   < 0));
		high = high.wrapping_sub(u128::from(carry < 0));

		(low, high as Self)
	}
}