#![allow(clippy::use_self)] use serde::{Deserialize, Serialize};
use smallvec::SmallVec;
use tch::{Device, Kind, Tensor};
use thiserror::Error;
/// Strategy used to fill a newly created tensor with initial values.
#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
pub enum Initializer {
    /// Every element zero.
    Zeros,
    /// Every element set to the given constant.
    Constant(f64),
    /// Samples from a uniform distribution whose variance is set by [`VarianceScale`].
    Uniform(VarianceScale),
    /// Samples from a normal distribution whose variance is set by [`VarianceScale`].
    Normal(VarianceScale),
    /// A (semi-)orthogonal matrix over the first dimension vs. the flattened rest.
    Orthogonal,
}
impl Default for Initializer {
fn default() -> Self {
Self::Uniform(VarianceScale::FanAvg)
}
}
/// How the variance of a random initializer is scaled from the tensor shape.
#[derive(Debug, Copy, Clone, PartialEq, Serialize, Deserialize)]
pub enum VarianceScale {
    /// Fixed variance, independent of the shape.
    Constant(f64),
    /// Variance `1 / fan_in`.
    FanIn,
    /// Variance `1 / fan_out`.
    FanOut,
    /// Variance `2 / (fan_in + fan_out)`.
    FanAvg,
}
impl Default for VarianceScale {
fn default() -> Self {
Self::FanIn
}
}
impl VarianceScale {
    /// Resolve the variance to use for a tensor of the given `shape`.
    ///
    /// When `fan_in` / `fan_out` are `Some`, they override the values
    /// derived from `shape`; otherwise the derived fans are used.
    fn variance(self, shape: &[usize], fan_in: Option<usize>, fan_out: Option<usize>) -> f64 {
        let (derived_in, derived_out) = calculate_fan_in_and_fan_out(shape);
        let n_in = fan_in.unwrap_or(derived_in) as f64;
        let n_out = fan_out.unwrap_or(derived_out) as f64;
        match self {
            Self::Constant(value) => value,
            Self::FanIn => 1.0 / n_in,
            Self::FanOut => 1.0 / n_out,
            Self::FanAvg => 2.0 / (n_in + n_out),
        }
    }
}
/// Compute `(fan_in, fan_out)` for a weight tensor shape.
///
/// Follows the usual convention: dimension 0 is the number of output
/// feature maps, dimension 1 the number of input feature maps, and any
/// trailing dimensions form the receptive field. Missing dimensions
/// default to 1, so 0-D and 1-D shapes are accepted.
fn calculate_fan_in_and_fan_out(shape: &[usize]) -> (usize, usize) {
    let out_fmaps = shape.first().copied().unwrap_or(1);
    let in_fmaps = shape.get(1).copied().unwrap_or(1);
    // Product of all trailing (spatial) dimensions; the empty product is 1.
    let receptive_field: usize = shape.get(2..).map_or(1, |rest| rest.iter().product());
    (in_fmaps * receptive_field, out_fmaps * receptive_field)
}
impl Initializer {
    /// Start building a tensor of the given `shape` using this initializer.
    ///
    /// Options (gain, fan overrides, kind, device, gradient tracking) are
    /// configured on the returned [`TensorBuilder`]; call
    /// [`TensorBuilder::build`] to create the tensor.
    #[must_use]
    #[inline]
    pub const fn tensor<'a>(&'a self, shape: &'a [usize]) -> TensorBuilder<'a> {
        TensorBuilder::new(self, shape)
    }
}
/// Builder configuring how a tensor is created from an [`Initializer`].
#[derive(Debug, Copy, Clone, PartialEq)]
pub struct TensorBuilder<'a> {
    // Initialization strategy to apply on `build`.
    initializer: &'a Initializer,
    // Target shape as per-dimension element counts.
    shape: &'a [usize],
    // Multiplier applied to the sampled values (default 1.0).
    gain: f64,
    // Explicit fan-in override; derived from `shape` when `None`.
    fan_in: Option<usize>,
    // Explicit fan-out override; derived from `shape` when `None`.
    fan_out: Option<usize>,
    // Whether the built tensor tracks gradients (default true).
    requires_grad: bool,
    // Element type (default `Kind::Float`).
    kind: Kind,
    // Target device (default CPU).
    device: Device,
}
impl<'a> TensorBuilder<'a> {
#[must_use]
#[inline]
pub const fn new(initializer: &'a Initializer, shape: &'a [usize]) -> Self {
Self {
initializer,
shape,
gain: 1.0,
fan_in: None,
fan_out: None,
requires_grad: true,
kind: Kind::Float,
device: Device::Cpu,
}
}
pub fn build(&self) -> Tensor {
let options = (self.kind, self.device);
let shape_i64: SmallVec<[i64; 8]> =
self.shape.iter().map(|&d| d.try_into().unwrap()).collect();
let tensor = match &self.initializer {
Initializer::Zeros => Tensor::zeros(&shape_i64, options),
Initializer::Constant(v) => Tensor::full(&shape_i64, *v, options),
Initializer::Uniform(scaling) => {
let lim = self.gain
* (3.0 * scaling.variance(self.shape, self.fan_in, self.fan_out)).sqrt();
Tensor::empty(&shape_i64, options).uniform_(-lim, lim)
}
Initializer::Normal(scaling) => {
let mean = 0.0;
let stddev = self.gain
* scaling
.variance(self.shape, self.fan_in, self.fan_out)
.sqrt();
Tensor::empty(&shape_i64, options).normal_(mean, stddev)
}
Initializer::Orthogonal => init_orthogonal(&shape_i64, self.gain, options),
};
tensor.set_requires_grad(self.requires_grad)
}
#[must_use]
#[inline]
pub const fn gain(mut self, gain: f64) -> Self {
self.gain = gain;
self
}
#[must_use]
#[inline]
pub const fn fan_in(mut self, fan_in: usize) -> Self {
self.fan_in = Some(fan_in);
self
}
#[must_use]
#[inline]
pub const fn fan_out(mut self, fan_out: usize) -> Self {
self.fan_out = Some(fan_out);
self
}
#[must_use]
#[inline]
pub const fn requires_grad(mut self, requires_grad: bool) -> Self {
self.requires_grad = requires_grad;
self
}
#[inline]
pub const fn kind(mut self, kind: Kind) -> Result<Self, InitializeTensorError> {
use Kind::*;
match kind {
Half | Float | Double | ComplexHalf | ComplexFloat | ComplexDouble | BFloat16 => {}
_ => return Err(InitializeTensorError::InvalidKind(kind)),
}
self.kind = kind;
Ok(self)
}
#[must_use]
#[inline]
pub const fn device(mut self, device: Device) -> Self {
self.device = device;
self
}
}
/// Error raised while configuring a [`TensorBuilder`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Error)]
pub enum InitializeTensorError {
    /// The requested element kind is neither float nor complex.
    #[error("unsupported kind {0:?}; expected a float or complex type")]
    InvalidKind(Kind),
}
/// Create a tensor whose `[shape[0], prod(shape[1..])]` flattening is
/// (semi-)orthogonal, scaled by `gain`.
///
/// Uses a QR decomposition of a standard-normal matrix with a sign
/// correction on the diagonal of R; this mirrors the scheme used by
/// PyTorch's `nn.init.orthogonal_` (NOTE(review): written for real
/// kinds — `sign()` semantics for complex inputs should be confirmed).
///
/// # Panics
/// If `shape` has fewer than two dimensions.
fn init_orthogonal(shape: &[i64], gain: f64, options: (Kind, Device)) -> Tensor {
    assert!(
        shape.len() >= 2,
        "tensor for orthogonal init must be at least 2D",
    );
    // The sampling and QR steps are setup only; keep them out of autograd.
    let _no_grad = tch::no_grad_guard();
    let num_rows = shape[0];
    let num_cols: i64 = shape[1..].iter().product();
    let mut flattened = Tensor::empty(&[num_rows, num_cols], options).normal_(0.0, 1.0);
    // QR yields orthonormal columns only when rows >= cols, so work on the
    // transpose of a wide matrix and undo the transpose afterwards.
    if num_rows < num_cols {
        let _ = flattened.t_();
    }
    let (mut q, r) = Tensor::linalg_qr(&flattened, "reduced");
    // Force the diagonal of R positive so the decomposition is unique.
    let d = r.diag(0);
    let ph = d.sign();
    q *= ph;
    if num_rows < num_cols {
        let _ = q.t_();
    }
    // Scaling is skipped for the exact default 1.0; float_cmp is intended here.
    #[allow(clippy::float_cmp)]
    if gain != 1.0 {
        q *= gain;
    }
    q = q.reshape(shape);
    // Copy into a fresh tensor so the result is freshly allocated with
    // exactly the requested options rather than a view of QR internals.
    let mut out = Tensor::empty(shape, options);
    out.copy_(&q);
    out
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn zeros() {
        let a = Initializer::Zeros.tensor(&[5]).build();
        assert_eq!(a, Tensor::zeros(&[5], (Kind::Float, Device::Cpu)));
    }

    #[test]
    fn constant() {
        let a = Initializer::Constant(2.0).tensor(&[5]).build();
        assert_eq!(a, Tensor::full(&[5], 2.0, (Kind::Float, Device::Cpu)));
    }

    #[test]
    fn orthogonal_is_orthogonal() {
        // For a square matrix, A @ A^T should be (approximately) identity.
        let n = 5;
        let a = Initializer::Orthogonal.tensor(&[n, n]).build();
        assert!(a.matmul(&a.tr()).allclose(
            &Tensor::eye(n as i64, (Kind::Float, Device::Cpu)),
            1e-4,
            1e-4,
            false
        ));
    }

    #[test]
    fn shape() {
        let a = Initializer::default().tensor(&[2, 3]).build();
        assert_eq!(a.size(), [2, 3]);
    }

    // The following bound tests are probabilistic: with variance v the
    // uniform limit is sqrt(3 * v), and the max of 100 draws is almost
    // surely within [0.75 * lim, lim].
    #[test]
    fn gain() {
        // Constant variance 1/3 makes the uniform limit equal the gain.
        let a = Initializer::Uniform(VarianceScale::Constant(1.0 / 3.0))
            .tensor(&[100])
            .gain(0.1)
            .build();
        let max = f32::from(a.max());
        assert!(max <= 0.1, "{max:?}");
        assert!(max >= 0.075, "{max:?}");
    }

    #[test]
    fn fan_in_default() {
        // Shape [1, 100] derives fan_in = 100 -> lim = sqrt(3/100) ~= 0.173.
        let a = Initializer::Uniform(VarianceScale::FanIn)
            .tensor(&[1, 100])
            .build();
        let max = f32::from(a.max());
        assert!(max <= 0.174, "{max:?}");
        assert!(max >= 0.173 * 0.75, "{max:?}");
    }

    #[test]
    fn fan_in() {
        // Explicit fan_in = 1 overrides the derived 100 -> lim = sqrt(3).
        let a = Initializer::Uniform(VarianceScale::FanIn)
            .tensor(&[1, 100])
            .fan_in(1)
            .build();
        let max = f32::from(a.max());
        assert!(max <= 1.74, "{max:?}");
        assert!(max >= 1.73 * 0.75, "{max:?}");
    }

    #[test]
    fn fan_out_default() {
        // Shape [100, 1] derives fan_out = 100 -> lim = sqrt(3/100) ~= 0.173.
        let a = Initializer::Uniform(VarianceScale::FanOut)
            .tensor(&[100, 1])
            .build();
        let max = f32::from(a.max());
        assert!(max <= 0.174, "{max:?}");
        assert!(max >= 0.173 * 0.75, "{max:?}");
    }

    #[test]
    fn fan_out() {
        // Explicit fan_out = 1 overrides the derived 100 -> lim = sqrt(3).
        let a = Initializer::Uniform(VarianceScale::FanOut)
            .tensor(&[100, 1])
            .fan_out(1)
            .build();
        let max = f32::from(a.max());
        assert!(max <= 1.74, "{max:?}");
        assert!(max >= 1.73 * 0.75, "{max:?}");
    }

    #[test]
    fn requires_grad_default() {
        let a = Initializer::default().tensor(&[2]).build();
        assert!(a.requires_grad());
    }

    #[test]
    fn requires_grad_true() {
        let a = Initializer::default()
            .tensor(&[2])
            .requires_grad(true)
            .build();
        assert!(a.requires_grad());
    }

    #[test]
    fn requires_grad_false() {
        let a = Initializer::default()
            .tensor(&[2])
            .requires_grad(false)
            .build();
        assert!(!a.requires_grad());
    }

    #[test]
    fn kind_default_float() {
        let a = Initializer::default().tensor(&[2]).build();
        assert_eq!(a.kind(), Kind::Float);
    }

    #[test]
    fn kind_double() {
        let a = Initializer::default()
            .tensor(&[2])
            .kind(Kind::Double)
            .unwrap()
            .build();
        assert_eq!(a.kind(), Kind::Double);
    }

    #[test]
    fn kind_complex() {
        let a = Initializer::default()
            .tensor(&[2])
            .kind(Kind::ComplexFloat)
            .unwrap()
            .build();
        assert_eq!(a.kind(), Kind::ComplexFloat);
    }

    #[test]
    fn kind_int_error() {
        // Integer kinds are rejected at configuration time.
        assert!(Initializer::default().tensor(&[2]).kind(Kind::Int).is_err());
    }

    #[test]
    fn device_default_cpu() {
        let a = Initializer::default().tensor(&[2]).build();
        assert_eq!(a.device(), Device::Cpu);
    }

    #[test]
    fn device_cuda_if_available() {
        let device = Device::cuda_if_available();
        let a = Initializer::default().tensor(&[2]).device(device).build();
        assert_eq!(a.device(), device);
    }
}