entrenar/generative/code_gan/
latent.rs1use rand::Rng;
4
5#[derive(Debug, Clone, PartialEq)]
7pub struct LatentCode {
8 pub vector: Vec<f32>,
10}
11
12impl LatentCode {
13 #[must_use]
15 pub fn new(vector: Vec<f32>) -> Self {
16 Self { vector }
17 }
18
19 pub fn sample<R: Rng>(rng: &mut R, dim: usize) -> Self {
21 contract_pre_sample!();
22 let vector: Vec<f32> = (0..dim)
23 .map(|_| {
24 let u1: f64 = rng.random::<f64>().max(1e-10);
26 let u2: f64 = rng.random::<f64>();
27 ((-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()) as f32
28 })
29 .collect();
30 Self { vector }
31 }
32
33 #[must_use]
35 pub fn dim(&self) -> usize {
36 self.vector.len()
37 }
38
39 #[must_use]
41 pub fn lerp(&self, other: &Self, t: f32) -> Self {
42 assert_eq!(self.dim(), other.dim(), "Latent dimensions must match");
43 let vector =
44 self.vector.iter().zip(&other.vector).map(|(a, b)| a * (1.0 - t) + b * t).collect();
45 Self { vector }
46 }
47
48 #[must_use]
50 pub fn slerp(&self, other: &Self, t: f32) -> Self {
51 assert_eq!(self.dim(), other.dim(), "Latent dimensions must match");
52
53 let norm_self: f32 = self.vector.iter().map(|x| x * x).sum::<f32>().sqrt();
54 let norm_other: f32 = other.vector.iter().map(|x| x * x).sum::<f32>().sqrt();
55
56 if norm_self < 1e-10 || norm_other < 1e-10 {
58 return self.lerp(other, t);
59 }
60
61 let dot: f32 = self.vector.iter().zip(&other.vector).map(|(a, b)| a * b).sum();
62
63 let cos_omega = (dot / (norm_self * norm_other)).clamp(-1.0, 1.0);
64 let omega = cos_omega.acos();
65
66 if omega.abs() < 1e-6 {
67 return self.lerp(other, t);
68 }
69
70 let sin_omega = omega.sin();
71 let factor_self = ((1.0 - t) * omega).sin() / sin_omega;
72 let factor_other = (t * omega).sin() / sin_omega;
73
74 let vector = self
75 .vector
76 .iter()
77 .zip(&other.vector)
78 .map(|(a, b)| a * factor_self + b * factor_other)
79 .collect();
80
81 Self { vector }
82 }
83
84 #[must_use]
86 pub fn norm(&self) -> f32 {
87 self.vector.iter().map(|x| x * x).sum::<f32>().sqrt()
88 }
89
90 #[must_use]
92 pub fn normalize(&self) -> Self {
93 let n = self.norm();
94 if n < 1e-10 {
95 return self.clone();
96 }
97 let vector = self.vector.iter().map(|x| x / n).collect();
98 Self { vector }
99 }
100}
101
#[cfg(test)]
mod tests {
    //! Unit and property tests for `LatentCode`.

    use super::*;
    use proptest::prelude::*;

    /// Absolute tolerance used for `f32` comparisons in the unit tests.
    const EPS: f32 = 1e-5;

    #[test]
    fn test_latent_code_creation() {
        let code = LatentCode::new(vec![1.0, 2.0, 3.0]);
        assert_eq!(code.dim(), 3);
        assert_eq!(code.vector, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_latent_code_sample() {
        use rand::SeedableRng;
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let code = LatentCode::sample(&mut rng, 128);
        assert_eq!(code.dim(), 128);
        // `u1` is clamped away from zero inside `sample`, so no coordinate may
        // be infinite or NaN.
        assert!(code.vector.iter().all(|v| v.is_finite()));
    }

    #[test]
    fn test_latent_code_lerp() {
        let z1 = LatentCode::new(vec![0.0, 0.0]);
        let z2 = LatentCode::new(vec![1.0, 1.0]);

        let mid = z1.lerp(&z2, 0.5);
        assert!((mid.vector[0] - 0.5).abs() < EPS);
        assert!((mid.vector[1] - 0.5).abs() < EPS);

        let start = z1.lerp(&z2, 0.0);
        assert!((start.vector[0] - 0.0).abs() < EPS);

        let end = z1.lerp(&z2, 1.0);
        assert!((end.vector[0] - 1.0).abs() < EPS);
    }

    #[test]
    fn test_latent_code_slerp() {
        let z1 = LatentCode::new(vec![1.0, 0.0]);
        let z2 = LatentCode::new(vec![0.0, 1.0]);

        // Halfway along the quarter circle both coordinates equal sqrt(1/2).
        // Plain lerp would give 0.5 here, so this distinguishes the two.
        let mid = z1.slerp(&z2, 0.5);
        let expected = std::f32::consts::FRAC_1_SQRT_2;
        assert!((mid.vector[0] - expected).abs() < EPS);
        assert!((mid.vector[1] - expected).abs() < EPS);

        // The endpoints must be reproduced.
        let start = z1.slerp(&z2, 0.0);
        assert!((start.vector[0] - 1.0).abs() < EPS);
        assert!(start.vector[1].abs() < EPS);

        let end = z1.slerp(&z2, 1.0);
        assert!(end.vector[0].abs() < EPS);
        assert!((end.vector[1] - 1.0).abs() < EPS);
    }

    #[test]
    fn test_slerp_zero_vector_falls_back_to_lerp() {
        let zero = LatentCode::new(vec![0.0, 0.0]);
        let v = LatentCode::new(vec![2.0, 4.0]);

        let mid = zero.slerp(&v, 0.5);
        assert!((mid.vector[0] - 1.0).abs() < EPS);
        assert!((mid.vector[1] - 2.0).abs() < EPS);
    }

    #[test]
    #[should_panic(expected = "Latent dimensions must match")]
    fn test_lerp_dimension_mismatch_panics() {
        let a = LatentCode::new(vec![1.0, 2.0]);
        let b = LatentCode::new(vec![1.0, 2.0, 3.0]);
        let _ = a.lerp(&b, 0.5);
    }

    #[test]
    #[should_panic(expected = "Latent dimensions must match")]
    fn test_slerp_dimension_mismatch_panics() {
        let a = LatentCode::new(vec![1.0, 2.0]);
        let b = LatentCode::new(vec![1.0, 2.0, 3.0]);
        let _ = a.slerp(&b, 0.5);
    }

    #[test]
    fn test_latent_code_norm() {
        let code = LatentCode::new(vec![3.0, 4.0]);
        assert!((code.norm() - 5.0).abs() < EPS);
    }

    #[test]
    fn test_latent_code_normalize() {
        let code = LatentCode::new(vec![3.0, 4.0]);
        let normalized = code.normalize();
        assert!((normalized.norm() - 1.0).abs() < EPS);
    }

    #[test]
    fn test_normalize_zero_vector_is_unchanged() {
        let zero = LatentCode::new(vec![0.0, 0.0, 0.0]);
        assert_eq!(zero.normalize(), zero);
    }

    #[test]
    fn test_slerp_maintains_norm() {
        let z1 = LatentCode::new(vec![1.0, 0.0, 0.0]).normalize();
        let z2 = LatentCode::new(vec![0.0, 1.0, 0.0]).normalize();

        for i in 0..=10 {
            let t = i as f32 / 10.0;
            let z = z1.slerp(&z2, t);
            // Between unit vectors, slerp must stay on the unit sphere.
            assert!((z.norm() - 1.0).abs() < EPS);
        }
    }

    proptest! {
        #[test]
        fn test_latent_lerp_bounds(t in 0.0f32..=1.0) {
            let z1 = LatentCode::new(vec![0.0, 0.0, 0.0]);
            let z2 = LatentCode::new(vec![1.0, 1.0, 1.0]);

            let result = z1.lerp(&z2, t);

            for v in &result.vector {
                prop_assert!(*v >= 0.0 && *v <= 1.0);
            }
        }

        #[test]
        fn test_latent_norm_non_negative(values in prop::collection::vec(-10.0f32..10.0, 1..100)) {
            let code = LatentCode::new(values);
            prop_assert!(code.norm() >= 0.0);
        }

        #[test]
        fn test_normalize_unit_length(values in prop::collection::vec(-10.0f32..10.0, 1..100)) {
            let code = LatentCode::new(values);
            if code.norm() > 1e-6 {
                let normalized = code.normalize();
                prop_assert!((normalized.norm() - 1.0).abs() < 1e-5);
            }
        }
    }
}