use snafu::ResultExt;
use svod_dtype::{DType, DeviceSpec, ScalarDType};
use svod_ir::{ConstValue, UOp, shape::Shape, shape::to_vec_usize};
use crate::{Error, Result, Tensor, UOpSnafu};
use super::state;
impl Tensor {
    /// Creates a tensor of the given `shape` filled with random values.
    ///
    /// Shorthand for [`Tensor::rand_with`] using `DType::Float32` and
    /// `DeviceSpec::Cpu`.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by [`Tensor::rand_with`].
    pub fn rand(shape: &[usize]) -> Result<Tensor> {
        Tensor::rand_with(shape, DType::Float32, DeviceSpec::Cpu)
    }

    /// Creates a tensor of the given `shape` and floating-point `dtype` filled
    /// with random values derived from the RNG state of `device`.
    ///
    /// A shape with zero elements short-circuits to `Tensor::zeros` and does
    /// not request a counter from the RNG state.
    ///
    /// # Errors
    ///
    /// Returns `Error::SymbolicShapeUnsupported` when `dtype` has no scalar
    /// type or when that scalar type is not a float. Errors from the
    /// underlying tensor operations are propagated unchanged.
    pub fn rand_with(shape: &[usize], dtype: DType, device: DeviceSpec) -> Result<Tensor> {
        let Some(scalar) = dtype.scalar() else {
            return Err(Error::SymbolicShapeUnsupported {
                operation: format!("Tensor::rand: non-scalar dtype {dtype:?}"),
            });
        };
        if !scalar.is_float() {
            return Err(Error::SymbolicShapeUnsupported {
                operation: format!(
                    "Tensor::rand: float dtype required, got {scalar:?}; use Tensor::randint for integers"
                ),
            });
        }

        let element_count = shape.iter().product::<usize>();
        if element_count == 0 {
            return Tensor::zeros(shape, dtype);
        }

        // Number of 32-bit words needed to cover every output element's bytes.
        let word_count = (element_count * scalar.bytes()).div_ceil(4);
        let (seed, counter) = state::next_counter(&device, word_count as u64);
        let raw_words = random_bits(&seed, counter, word_count)?;
        bits_to_rand(&raw_words, shape, dtype)
    }
}
/// Produces `num` random `UInt32` words from `seed` and `counter_val`.
///
/// The 64-bit counter is split into its low and high 32-bit halves and run
/// through [`threefry_random_bits`] with `seed` as the key; that output is
/// then used as the key for a second pass over the counters `0..half` and
/// `half..2*half`, where `half = ceil(num / 2)`. The result is trimmed to
/// the first `num` entries.
///
/// # Errors
///
/// Propagates any error from the tensor operations involved.
fn random_bits(seed: &Tensor, counter_val: u64, num: usize) -> Result<Tensor> {
    let word_dt = DType::Scalar(ScalarDType::UInt32);

    // Split the 64-bit counter into two 32-bit halves.
    let counter_lo = (counter_val & 0xFFFF_FFFF) as u32;
    let counter_hi = (counter_val >> 32) as u32;
    let counter_lo_t = Tensor::full(&[1], counter_lo, word_dt.clone())?;
    let counter_hi_t = Tensor::full(&[1], counter_hi, word_dt.clone())?;
    let derived_key = threefry_random_bits(seed, &counter_lo_t, &counter_hi_t)?;

    // Each threefry lane yields two words, so only half as many lanes are needed.
    let lanes = num.div_ceil(2);
    let first_counts = Tensor::arange(0, Some(lanes as i64), None)?.cast(word_dt.clone())?;
    let lane_offset = Tensor::full(&[lanes], lanes as u32, word_dt)?;
    let second_counts = first_counts.try_add(&lane_offset)?;

    threefry_random_bits(&derived_key, &first_counts, &second_counts)?.try_shrink([(0usize, num)])
}
/// Runs the `UOp::threefry` operation over paired 32-bit counters.
///
/// `counts0` and `counts1` are widened to `UInt64` and packed as
/// `(counts1 << 32) | counts0`. Elements 0 and 1 of `key` are packed the same
/// way (`(key[1] << 32) | key[0]`) and broadcast to the shape of `counts0`.
/// The 64-bit result is split back into its low and high 32-bit halves, which
/// are concatenated along axis 0 (low halves first) as `UInt32`.
///
/// NOTE(review): assumes `key` holds at least two elements along its first
/// axis and that `counts1` has the same shape as `counts0`; confirm against
/// callers.
///
/// # Errors
///
/// Propagates errors from the tensor operations, and wraps failures of
/// `to_vec_usize` and `UOp::threefry` with `UOpSnafu`.
pub(crate) fn threefry_random_bits(key: &Tensor, counts0: &Tensor, counts1: &Tensor) -> Result<Tensor> {
    let word_dt = DType::Scalar(ScalarDType::UInt32);
    let wide_dt = DType::Scalar(ScalarDType::UInt64);

    let lane_shape: Shape = counts0.shape()?;
    let lane_dims = to_vec_usize(&lane_shape).context(UOpSnafu)?;

    // Pack the two 32-bit counters into one 64-bit value per lane.
    let by_32 = Tensor::full(&lane_dims, 32u32, wide_dt.clone())?;
    let counter_lo = counts0.cast(wide_dt.clone())?;
    let counter_hi = counts1.cast(wide_dt.clone())?;
    let packed_counter = counter_hi.try_shl(&by_32)?.try_bitor(&counter_lo)?;

    // Pack the two key words likewise, then broadcast over all lanes.
    let key_by_32 = Tensor::full(&[1], 32u32, wide_dt.clone())?;
    let key_lo = key.try_shrink([(0usize, 1usize)])?.cast(wide_dt.clone())?;
    let key_hi = key.try_shrink([(1usize, 2usize)])?.cast(wide_dt.clone())?;
    let packed_key = key_hi
        .try_shl(&key_by_32)?
        .try_bitor(&key_lo)?
        .broadcast_to(&lane_shape)?;

    let mixed = Tensor::from_lazy(
        UOp::threefry(packed_counter.uop().clone(), packed_key.uop().clone()).context(UOpSnafu)?,
    );

    // Split each 64-bit output into two 32-bit words.
    let low_mask = Tensor::full(&lane_dims, 0xFFFF_FFFFu64, wide_dt)?;
    let low_words = mixed.try_bitand(&low_mask)?.cast(word_dt.clone())?;
    let high_words = mixed.try_shr(&by_32)?.try_bitand(&low_mask)?.cast(word_dt)?;

    Tensor::cat(&[&low_words, &high_words], 0)
}
/// Converts raw random words into floating-point values of `dtype`, shaped as
/// `shape`.
///
/// `bits` is bitcast to the unsigned integer type matching `dtype`, shifted
/// right so that only mantissa-width bits remain, and OR-ed with the bit
/// pattern from [`one_bits_for`] (the encoding of `1.0`). Reinterpreting that
/// as `dtype` and subtracting `1.0` gives the final values, which are trimmed
/// to the element count of `shape` and reshaped.
///
/// # Panics
///
/// Panics if `dtype` has no scalar type; `rand_with` validates this before
/// calling.
///
/// # Errors
///
/// Propagates errors from the tensor operations, and wraps a failure of
/// `to_vec_usize` with `UOpSnafu`.
fn bits_to_rand(bits: &Tensor, shape: &[usize], dtype: DType) -> Result<Tensor> {
    let scalar = dtype.scalar().expect("scalar dtype validated by rand_with");
    let mantissa_bits = scalar.finfo().1;
    let carrier_dt = DType::Scalar(scalar.float_to_uint());
    let width_bits = (scalar.bytes() * 8) as u32;
    // Bits to discard so that only mantissa-width random bits survive.
    let discard = width_bits - mantissa_bits;

    let raw = bits.bitcast(carrier_dt.clone())?;
    let raw_dims = to_vec_usize(&raw.shape()?).context(UOpSnafu)?;

    let discard_t = Tensor::full(&raw_dims, ConstValue::UInt(discard as u64), carrier_dt.clone())?;
    let mantissa = raw.try_shr(&discard_t)?;

    // Overlay the exponent/sign bits of 1.0 on top of the random mantissa.
    let one_pattern_t = Tensor::full(&raw_dims, ConstValue::UInt(one_bits_for(scalar)), carrier_dt)?;
    let one_to_two = mantissa.try_bitor(&one_pattern_t)?.bitcast(dtype.clone())?;

    let unit = Tensor::full(&raw_dims, ConstValue::Float(1.0), dtype)?;
    let zero_to_one = one_to_two.try_sub(&unit)?;

    // The bitcast may yield more elements than requested; keep only what `shape` needs.
    let wanted = shape.iter().product::<usize>();
    let flat = zero_to_one.try_shrink([(0usize, wanted)])?;
    let target_dims = shape.iter().map(|&dim| dim as isize).collect::<Vec<isize>>();
    flat.try_reshape(&target_dims)
}
/// Returns the raw bit pattern of `1.0` for the floating-point scalar type `s`.
///
/// Each value is the format's biased exponent for `1.0` shifted left past the
/// mantissa field (sign bit zero, mantissa zero).
///
/// # Panics
///
/// Panics when `s` is not one of `Float16`, `BFloat16`, `Float32`, `Float64`.
fn one_bits_for(s: ScalarDType) -> u64 {
    match s {
        // bias 15, 10 mantissa bits -> 0x3C00
        ScalarDType::Float16 => 15 << 10,
        // bias 127, 7 mantissa bits -> 0x3F80
        ScalarDType::BFloat16 => 127 << 7,
        // bias 127, 23 mantissa bits -> 0x3F80_0000
        ScalarDType::Float32 => 127 << 23,
        // bias 1023, 52 mantissa bits -> 0x3FF0_0000_0000_0000
        ScalarDType::Float64 => 1023 << 52,
        _ => panic!("one_bits_for: non-float dtype {s:?}"),
    }
}