use svod_dtype::DType;
use svod_ir::ConstValue;
use crate::{Error, Result, Tensor};
/// One full turn in radians (2π); scales the first uniform draw in the Box–Muller step of
/// `Tensor::randn`.
const TWO_PI: f64 = std::f64::consts::TAU;
/// Product of every dimension after the first (the "fan-in" of a weight shaped
/// `[out, in, ...]`).
///
/// Empty and 1-D shapes have no trailing dimensions and yield 1. The result is clamped to
/// at least 1, so callers can divide by it even when a trailing dimension is 0.
fn fan_in(shape: &[usize]) -> usize {
    let trailing: &[usize] = shape.get(1..).unwrap_or(&[]);
    let fan = trailing.iter().fold(1usize, |acc, &dim| acc * dim);
    fan.max(1)
}
impl Tensor {
pub fn uniform(shape: &[usize], low: f64, high: f64) -> Result<Tensor> {
Self::uniform_with_dtype(shape, low, high, DType::Float32)
}
pub fn uniform_with_dtype(shape: &[usize], low: f64, high: f64, dtype: DType) -> Result<Tensor> {
if low >= high {
return Err(Error::ParamRange {
op: "Tensor::uniform",
param: "low/high",
value: format!("low={low}, high={high}"),
constraint: "low < high",
});
}
let u = Tensor::rand(shape)?;
let scale = u.broadcast_scalar(ConstValue::Float(high - low))?;
let scaled = u.try_mul(&scale)?.cast(dtype)?;
let offset = scaled.broadcast_scalar(ConstValue::Float(low))?;
scaled.try_add(&offset)
}
pub fn randn(shape: &[usize]) -> Result<Tensor> {
let mut combined_shape: Vec<usize> = Vec::with_capacity(shape.len() + 1);
combined_shape.push(2);
combined_shape.extend_from_slice(shape);
let src = Tensor::rand(&combined_shape)?;
let mut shrink_u1: Vec<Option<(isize, isize)>> = Vec::with_capacity(combined_shape.len());
shrink_u1.push(Some((0, 1)));
shrink_u1.extend(std::iter::repeat_n(None, shape.len()));
let mut shrink_u2: Vec<Option<(isize, isize)>> = Vec::with_capacity(combined_shape.len());
shrink_u2.push(Some((1, 2)));
shrink_u2.extend(std::iter::repeat_n(None, shape.len()));
let target_shape: Vec<isize> = shape.iter().map(|&d| d as isize).collect();
let u1 = src.try_shrink(shrink_u1)?.try_reshape(&target_shape)?;
let u2 = src.try_shrink(shrink_u2)?.try_reshape(&target_shape)?;
let two_pi = u1.broadcast_scalar(ConstValue::Float(TWO_PI))?;
let theta = u1.try_mul(&two_pi)?.cos()?;
let one = u2.broadcast_scalar(ConstValue::Float(1.0))?;
let neg_two = u2.broadcast_scalar(ConstValue::Float(-2.0))?;
let r = one.try_sub(&u2)?.try_log()?.try_mul(&neg_two)?.try_sqrt()?;
theta.try_mul(&r)
}
pub fn normal(shape: &[usize], mean: f64, std: f64) -> Result<Tensor> {
if std < 0.0 {
return Err(Error::ParamRange {
op: "Tensor::normal",
param: "std",
value: format!("{std}"),
constraint: ">= 0",
});
}
let z = Tensor::randn(shape)?;
let std_t = z.broadcast_scalar(ConstValue::Float(std))?;
let mean_t = z.broadcast_scalar(ConstValue::Float(mean))?;
z.try_mul(&std_t)?.try_add(&mean_t)
}
pub fn randint(shape: &[usize], low: i64, high: i64) -> Result<Tensor> {
if low >= high {
return Err(Error::ParamRange {
op: "Tensor::randint",
param: "low/high",
value: format!("low={low}, high={high}"),
constraint: "low < high",
});
}
let scaled = Tensor::rand(shape)?;
let range = scaled.broadcast_scalar(ConstValue::Float((high - low) as f64))?;
let truncated = scaled.try_mul(&range)?.cast(DType::Int32)?;
let offset = truncated.broadcast_scalar(ConstValue::Int(low))?;
truncated.try_add(&offset)
}
pub fn scaled_uniform(shape: &[usize]) -> Result<Tensor> {
let numel: usize = shape.iter().copied().product::<usize>().max(1);
let scale = (numel as f64).powf(-0.5);
let u = Tensor::uniform(shape, -1.0, 1.0)?;
let scale_t = u.broadcast_scalar(ConstValue::Float(scale))?;
u.try_mul(&scale_t)
}
pub fn glorot_uniform(shape: &[usize]) -> Result<Tensor> {
Self::glorot_uniform_with_dtype(shape, DType::Float32)
}
pub fn glorot_uniform_with_dtype(shape: &[usize], dtype: DType) -> Result<Tensor> {
if shape.is_empty() {
return Err(Error::ParamRange {
op: "Tensor::glorot_uniform",
param: "shape",
value: "[]".to_string(),
constraint: "at least 1D",
});
}
let fan_in_v = fan_in(shape);
let fan_out_v = shape[0];
let bound = (6.0 / (fan_out_v + fan_in_v) as f64).sqrt();
Self::uniform_with_dtype(shape, -bound, bound, dtype)
}
pub fn kaiming_uniform(shape: &[usize], a: f64) -> Result<Tensor> {
Self::kaiming_uniform_with_dtype(shape, a, DType::Float32)
}
pub fn kaiming_uniform_with_dtype(shape: &[usize], a: f64, dtype: DType) -> Result<Tensor> {
if shape.is_empty() {
return Err(Error::ParamRange {
op: "Tensor::kaiming_uniform",
param: "shape",
value: "[]".to_string(),
constraint: "at least 1D",
});
}
let bound = (6.0 / ((1.0 + a * a) * fan_in(shape) as f64)).sqrt();
Self::uniform_with_dtype(shape, -bound, bound, dtype)
}
pub fn kaiming_normal(shape: &[usize], a: f64) -> Result<Tensor> {
if shape.is_empty() {
return Err(Error::ParamRange {
op: "Tensor::kaiming_normal",
param: "shape",
value: "[]".to_string(),
constraint: "at least 1D",
});
}
let std = (2.0 / ((1.0 + a * a) * fan_in(shape) as f64)).sqrt();
let z = Tensor::randn(shape)?;
let std_t = z.broadcast_scalar(ConstValue::Float(std))?;
z.try_mul(&std_t)
}
}