1use yscv_tensor::Tensor;
7
8use crate::ModelError;
9
/// Small deterministic xorshift64 generator (shift triple 13, 7, 17) used to
/// produce reproducible initial weights from a `u64` seed.
struct Rng(u64);

impl Rng {
    /// Creates a generator from `seed`.
    ///
    /// The seed is first passed through the SplitMix64 output function. Using
    /// a raw seed with few significant bits (1, 42, ...) as xorshift state
    /// leaves the high bits of the first output at zero, which `uniform` turns
    /// into an exact `0.0`; mixing spreads the seed over all 64 bits first.
    ///
    /// xorshift can never leave the all-zero state, so a mixed value of zero
    /// is replaced by a fixed non-zero constant.
    fn new(seed: u64) -> Self {
        let mut z = seed.wrapping_add(0x9E37_79B9_7F4A_7C15);
        z = (z ^ (z >> 30)).wrapping_mul(0xBF58_476D_1CE4_E5B9);
        z = (z ^ (z >> 27)).wrapping_mul(0x94D0_49BB_1331_11EB);
        z ^= z >> 31;
        Self(if z == 0 { 0xDEAD_BEEF } else { z })
    }

    /// Advances the state by one xorshift step and returns the new state.
    fn next_u64(&mut self) -> u64 {
        let mut x = self.0;
        x ^= x << 13;
        x ^= x >> 7;
        x ^= x << 17;
        self.0 = x;
        x
    }

    /// Returns a value in `[0, 1)` built from the top 24 bits of the next
    /// output, so every result is an exact multiple of `2^-24`.
    fn uniform(&mut self) -> f32 {
        let top_bits = self.next_u64() >> 40;
        top_bits as f32 / (1u64 << 24) as f32
    }

    /// Returns `lo + (hi - lo) * u` where `u` is drawn from `uniform`.
    fn uniform_range(&mut self, lo: f32, hi: f32) -> f32 {
        lo + (hi - lo) * self.uniform()
    }

    /// Returns a standard-normal sample via the Box-Muller transform.
    ///
    /// Consumes two uniform draws per call and keeps only the cosine branch.
    fn normal(&mut self) -> f32 {
        // Clamp away from zero so that `ln` stays finite.
        let u1 = self.uniform().max(1e-10);
        let u2 = self.uniform();
        let radius = (-2.0 * u1.ln()).sqrt();
        let angle = 2.0 * std::f32::consts::PI * u2;
        radius * angle.cos()
    }
}
44
45pub fn kaiming_uniform(shape: Vec<usize>, fan_in: usize, seed: u64) -> Result<Tensor, ModelError> {
49 let bound = (6.0 / fan_in as f32).sqrt();
50 let n: usize = shape.iter().product();
51 let mut rng = Rng::new(seed);
52 let data: Vec<f32> = (0..n).map(|_| rng.uniform_range(-bound, bound)).collect();
53 Ok(Tensor::from_vec(shape, data)?)
54}
55
56pub fn kaiming_normal(shape: Vec<usize>, fan_in: usize, seed: u64) -> Result<Tensor, ModelError> {
60 let std = (2.0 / fan_in as f32).sqrt();
61 let n: usize = shape.iter().product();
62 let mut rng = Rng::new(seed);
63 let data: Vec<f32> = (0..n).map(|_| rng.normal() * std).collect();
64 Ok(Tensor::from_vec(shape, data)?)
65}
66
67pub fn xavier_uniform(
71 shape: Vec<usize>,
72 fan_in: usize,
73 fan_out: usize,
74 seed: u64,
75) -> Result<Tensor, ModelError> {
76 let bound = (6.0 / (fan_in + fan_out) as f32).sqrt();
77 let n: usize = shape.iter().product();
78 let mut rng = Rng::new(seed);
79 let data: Vec<f32> = (0..n).map(|_| rng.uniform_range(-bound, bound)).collect();
80 Ok(Tensor::from_vec(shape, data)?)
81}
82
83pub fn xavier_normal(
87 shape: Vec<usize>,
88 fan_in: usize,
89 fan_out: usize,
90 seed: u64,
91) -> Result<Tensor, ModelError> {
92 let std = (2.0 / (fan_in + fan_out) as f32).sqrt();
93 let n: usize = shape.iter().product();
94 let mut rng = Rng::new(seed);
95 let data: Vec<f32> = (0..n).map(|_| rng.normal() * std).collect();
96 Ok(Tensor::from_vec(shape, data)?)
97}
98
99pub fn orthogonal(rows: usize, cols: usize, seed: u64) -> Result<Tensor, ModelError> {
103 let n = rows.max(cols);
104 let mut rng = Rng::new(seed);
105
106 let mut mat: Vec<Vec<f32>> = (0..n)
108 .map(|_| (0..n).map(|_| rng.normal()).collect())
109 .collect();
110
111 for i in 0..n {
113 let norm: f32 = (0..n).map(|r| mat[r][i] * mat[r][i]).sum::<f32>().sqrt();
115 if norm > 1e-10 {
116 for r in 0..n {
117 mat[r][i] /= norm;
118 }
119 }
120 for j in (i + 1)..n {
122 let dot: f32 = (0..n).map(|r| mat[r][i] * mat[r][j]).sum();
123 for r in 0..n {
124 mat[r][j] -= dot * mat[r][i];
125 }
126 }
127 }
128
129 let mut data = Vec::with_capacity(rows * cols);
131 for r in 0..rows {
132 for c in 0..cols {
133 data.push(mat[r][c]);
134 }
135 }
136 Ok(Tensor::from_vec(vec![rows, cols], data)?)
137}
138
139pub fn constant(shape: Vec<usize>, value: f32) -> Result<Tensor, ModelError> {
141 Ok(Tensor::filled(shape, value)?)
142}
143
#[cfg(test)]
mod tests {
    use super::*;

    /// Kaiming-uniform output has the requested shape and stays within the bound.
    #[test]
    fn kaiming_uniform_shape_and_bounds() {
        let tensor = kaiming_uniform(vec![64, 32], 32, 42).unwrap();
        assert_eq!(tensor.shape(), &[64, 32]);

        let bound = f32::sqrt(6.0 / 32.0);
        for &v in tensor.data() {
            assert!(v >= -bound && v <= bound, "value {v} out of bounds");
        }
    }

    /// Kaiming-normal output has the requested shape and a sample mean near zero.
    #[test]
    fn kaiming_normal_shape() {
        let tensor = kaiming_normal(vec![128, 64], 64, 42).unwrap();
        assert_eq!(tensor.shape(), &[128, 64]);

        let samples = tensor.data();
        let mean: f32 = samples.iter().sum::<f32>() / samples.len() as f32;
        assert!(mean.abs() < 0.1, "mean {mean} too far from 0");
    }

    /// Xavier-uniform output has the requested shape and stays within the bound.
    #[test]
    fn xavier_uniform_shape_and_bounds() {
        let tensor = xavier_uniform(vec![100, 50], 100, 50, 42).unwrap();
        assert_eq!(tensor.shape(), &[100, 50]);

        let bound = f32::sqrt(6.0 / 150.0);
        for &v in tensor.data() {
            assert!(v >= -bound && v <= bound);
        }
    }

    /// Xavier-normal output has the requested shape and a sample mean near zero.
    #[test]
    fn xavier_normal_shape() {
        let tensor = xavier_normal(vec![100, 50], 100, 50, 42).unwrap();
        assert_eq!(tensor.shape(), &[100, 50]);

        let samples = tensor.data();
        let mean: f32 = samples.iter().sum::<f32>() / samples.len() as f32;
        assert!(mean.abs() < 0.1);
    }

    /// For a 4x4 result, column 0 has unit length and is orthogonal to column 1.
    #[test]
    fn orthogonal_produces_orthonormal_columns() {
        let tensor = orthogonal(4, 4, 42).unwrap();
        let flat = tensor.data();

        // Row-major layout: each chunk of four values is one matrix row.
        let norm: f32 = flat
            .chunks_exact(4)
            .map(|row| row[0] * row[0])
            .sum::<f32>()
            .sqrt();
        assert!((norm - 1.0).abs() < 1e-4, "col 0 norm = {norm}");

        let dot: f32 = flat.chunks_exact(4).map(|row| row[0] * row[1]).sum();
        assert!(dot.abs() < 1e-4, "dot = {dot}");
    }

    /// `constant` fills every element with the requested value.
    #[test]
    fn constant_fills_with_value() {
        let tensor = constant(vec![2, 3], 7.0).unwrap();
        assert_eq!(tensor.data(), &[7.0; 6]);
    }
}