// cairo-native 0.9.0-rc.3
//
// A compiler that converts Cairo's Sierra IR code to MLIR and executes it.
//! Wide square operation.
//!
//! This module provides the [`WideSquare`] trait which enables squaring operations
//! that return a result type with double the bit width of the input type.
//! This is particularly useful when you need to square a number without
//! worrying about overflow, as the result type can hold the full range of possible values.
//!
//! # Examples
//!
//! ```
//! use core::num::traits::WideSquare;
//!
//! // Squaring a `u8` value to get a `u16` result
//! let a: u8 = 200;
//! let result: u16 = a.wide_square();
//! assert!(result == 40000);
//!
//! // Squaring a `u128` value to get a `u256` result
//! let x: u128 = 0xffffffffffffffffffffffffffffffff; // max u128
//! let wide_result: u256 = x.wide_square(); // No overflow occurs
//! assert!(wide_result == 0xfffffffffffffffffffffffffffffffe00000000000000000000000000000001);
//! ```
//!
//! # Available Implementations
//!
//! The trait is implemented for the following type pairs:
//! - `i8` → `i16`
//! - `i16` → `i32`
//! - `i32` → `i64`
//! - `i64` → `i128`
//! - `u8` → `u16`
//! - `u16` → `u32`
//! - `u32` → `u64`
//! - `u64` → `u128`
//! - `u128` → `u256`
//! - `u256` → `u512`

use crate::num::traits::WideMul;

/// A trait for a type that can be squared to produce a wider type.
///
/// This trait enables squaring operations where the result type has double
/// the bit width of the input type, preventing overflow in cases where the
/// result would exceed the input type's maximum value.
///
/// # Examples
///
/// ```
/// use core::num::traits::WideSquare;
///
/// let a: u8 = 16;
/// let result: u16 = a.wide_square();
/// assert!(result == 256);
/// ```
pub trait WideSquare<T> {
    /// The type of the result of the square.
    ///
    /// Implementations choose a type wide enough to hold the square of any `T`
    /// (for the implementations in this module, double the bit width of `T`).
    type Target;
    /// Calculates `self * self` and returns it as the wider [`WideSquare::Target`] type,
    /// so the full square is returned without truncation.
    fn wide_square(self: T) -> Self::Target;
}

mod wide_mul_based {
    /// Blanket `WideSquare` implementation for any copyable `T` that already has a
    /// `WideMul<T, T>` implementation.
    ///
    /// Squaring is just the wide product of the value with itself, so the result type
    /// is whatever the underlying `WideMul` produces.
    pub impl TWideSquare<
        T, impl WideMulImpl: super::WideMul<T, T>, +Copy<T>,
    > of super::WideSquare<T> {
        type Target = WideMulImpl::Target;
        fn wide_square(self: T) -> Self::Target {
            // `T: Copy`, so the same value can be supplied as both operands.
            WideMulImpl::wide_mul(self, self)
        }
    }
}

// Fixed-width integers up to 128 bits square via their existing `WideMul`
// implementation, through the generic `wide_mul_based::TWideSquare` impl.

// Unsigned integers.
impl WideSquareU8 = wide_mul_based::TWideSquare<u8>;
impl WideSquareU16 = wide_mul_based::TWideSquare<u16>;
impl WideSquareU32 = wide_mul_based::TWideSquare<u32>;
impl WideSquareU64 = wide_mul_based::TWideSquare<u64>;
impl WideSquareU128 = wide_mul_based::TWideSquare<u128>;
// Signed integers.
impl WideSquareI8 = wide_mul_based::TWideSquare<i8>;
impl WideSquareI16 = wide_mul_based::TWideSquare<i16>;
impl WideSquareI32 = wide_mul_based::TWideSquare<i32>;
impl WideSquareI64 = wide_mul_based::TWideSquare<i64>;
/// `u256` squares into `u512` using the dedicated limb-wise routine in `inner`,
/// not the generic `WideMul`-based implementation used for narrower integers.
impl WideSquareU256 of WideSquare<u256> {
    type Target = crate::integer::u512;
    fn wide_square(self: u256) -> Self::Target {
        let squared = inner::u256_wide_square(self);
        squared
    }
}

#[feature("bounded-int-utils")]
mod inner {
    // Helpers backing the `u256` -> `u512` wide square.
    use crate::integer::{u128_add_with_bounded_int_carry, u512};
    use crate::internal::bounded_int;
    use crate::num::traits::{WideMul, WideSquare, WrappingAdd};

    /// Computes the full 512-bit square of `value`.
    ///
    /// Writing `value = low + high * 2^128`, the square is
    /// `low^2 + 2 * low * high * 2^128 + high^2 * 2^256`.
    /// The terms are accumulated into four 128-bit limbs (`limb0` least significant).
    /// Carries are tracked explicitly with bounded-int carry values, which are summed
    /// before being folded into the next limb.
    pub fn u256_wide_square(value: u256) -> u512 {
        // low^2 fills limb0 and seeds limb1.
        let u256 { high: limb1, low: limb0 } = value.low.wide_square();
        // The cross term low * high (weight 2^128): its low half lands in limb1 and its
        // high half in limb2.
        let u256 { high: limb2, low: limb1_part } = value.low.wide_mul(value.high);
        // The cross term appears twice in the square, so its low half is added to limb1
        // twice, keeping each carry separately.
        let (limb1, limb1_overflow0) = u128_add_with_bounded_int_carry(limb1, limb1_part);
        let (limb1, limb1_overflow1) = u128_add_with_bounded_int_carry(limb1, limb1_part);
        // Likewise, the high half of the cross term is doubled into limb2.
        let (limb2, limb2_overflow0) = u128_add_with_bounded_int_carry(limb2, limb2);
        // high^2 (weight 2^256): its low half is added into limb2 and its high half
        // seeds limb3.
        let u256 { high: limb3, low: limb2_part } = value.high.wide_square();
        let (limb2, limb2_overflow1) = u128_add_with_bounded_int_carry(limb2, limb2_part);
        // Packing together the overflow bits, making a cheaper addition into limb2.
        let limb1_overflow = bounded_int::add(limb1_overflow0, limb1_overflow1);
        let (limb2, limb2_overflow2) = u128_add_with_bounded_int_carry(
            limb2, bounded_int::upcast(limb1_overflow),
        );
        // Packing together the overflow bits, making a cheaper addition into limb3.
        let limb2_overflow = bounded_int::add(limb2_overflow0, limb2_overflow1);
        let limb2_overflow = bounded_int::add(limb2_overflow, limb2_overflow2);
        // No overflow since no limb4: the square of a u256 is below 2^512, so the
        // final limb3 value always fits and a wrapping add is exact here.
        let limb3 = limb3.wrapping_add(bounded_int::upcast(limb2_overflow));
        u512 { limb0, limb1, limb2, limb3 }
    }
}