Skip to main content

entrenar/generative/code_gan/
latent.rs

//! Latent code representation for GAN latent space operations.
2
3use rand::Rng;
4
/// A single point in the GAN latent space, held as a dense `f32` vector.
#[derive(Clone, Debug, PartialEq)]
pub struct LatentCode {
    /// Components of the latent vector; its length is the latent dimension.
    pub vector: Vec<f32>,
}
11
12impl LatentCode {
13    /// Create a new latent code from a vector
14    #[must_use]
15    pub fn new(vector: Vec<f32>) -> Self {
16        Self { vector }
17    }
18
19    /// Sample from standard normal distribution using Box-Muller transform
20    pub fn sample<R: Rng>(rng: &mut R, dim: usize) -> Self {
21        contract_pre_sample!();
22        let vector: Vec<f32> = (0..dim)
23            .map(|_| {
24                // Box-Muller transform for standard normal
25                let u1: f64 = rng.random::<f64>().max(1e-10);
26                let u2: f64 = rng.random::<f64>();
27                ((-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos()) as f32
28            })
29            .collect();
30        Self { vector }
31    }
32
33    /// Dimension of the latent code
34    #[must_use]
35    pub fn dim(&self) -> usize {
36        self.vector.len()
37    }
38
39    /// Linear interpolation between two latent codes
40    #[must_use]
41    pub fn lerp(&self, other: &Self, t: f32) -> Self {
42        assert_eq!(self.dim(), other.dim(), "Latent dimensions must match");
43        let vector =
44            self.vector.iter().zip(&other.vector).map(|(a, b)| a * (1.0 - t) + b * t).collect();
45        Self { vector }
46    }
47
48    /// Spherical linear interpolation between two latent codes
49    #[must_use]
50    pub fn slerp(&self, other: &Self, t: f32) -> Self {
51        assert_eq!(self.dim(), other.dim(), "Latent dimensions must match");
52
53        let norm_self: f32 = self.vector.iter().map(|x| x * x).sum::<f32>().sqrt();
54        let norm_other: f32 = other.vector.iter().map(|x| x * x).sum::<f32>().sqrt();
55
56        // Fall back to lerp if either vector has near-zero norm
57        if norm_self < 1e-10 || norm_other < 1e-10 {
58            return self.lerp(other, t);
59        }
60
61        let dot: f32 = self.vector.iter().zip(&other.vector).map(|(a, b)| a * b).sum();
62
63        let cos_omega = (dot / (norm_self * norm_other)).clamp(-1.0, 1.0);
64        let omega = cos_omega.acos();
65
66        if omega.abs() < 1e-6 {
67            return self.lerp(other, t);
68        }
69
70        let sin_omega = omega.sin();
71        let factor_self = ((1.0 - t) * omega).sin() / sin_omega;
72        let factor_other = (t * omega).sin() / sin_omega;
73
74        let vector = self
75            .vector
76            .iter()
77            .zip(&other.vector)
78            .map(|(a, b)| a * factor_self + b * factor_other)
79            .collect();
80
81        Self { vector }
82    }
83
84    /// Compute L2 norm
85    #[must_use]
86    pub fn norm(&self) -> f32 {
87        self.vector.iter().map(|x| x * x).sum::<f32>().sqrt()
88    }
89
90    /// Normalize to unit length
91    #[must_use]
92    pub fn normalize(&self) -> Self {
93        let n = self.norm();
94        if n < 1e-10 {
95            return self.clone();
96        }
97        let vector = self.vector.iter().map(|x| x / n).collect();
98        Self { vector }
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use proptest::prelude::*;

    #[test]
    fn test_latent_code_creation() {
        let code = LatentCode::new(vec![1.0, 2.0, 3.0]);
        assert_eq!(code.dim(), 3);
        assert_eq!(code.vector, vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_latent_code_sample() {
        use rand::SeedableRng;
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let code = LatentCode::sample(&mut rng, 128);
        assert_eq!(code.dim(), 128);
        assert!(code.vector.iter().all(|x| x.is_finite()));
    }

    #[test]
    fn test_latent_code_sample_is_deterministic_for_seed() {
        use rand::SeedableRng;
        let a = LatentCode::sample(&mut rand::rngs::StdRng::seed_from_u64(7), 64);
        let b = LatentCode::sample(&mut rand::rngs::StdRng::seed_from_u64(7), 64);
        assert_eq!(a, b);
    }

    #[test]
    fn test_latent_code_sample_is_standard_normal() {
        use rand::SeedableRng;
        let mut rng = rand::rngs::StdRng::seed_from_u64(42);
        let code = LatentCode::sample(&mut rng, 10_000);
        let n = code.dim() as f32;
        let mean = code.vector.iter().sum::<f32>() / n;
        let var = code.vector.iter().map(|x| (x - mean).powi(2)).sum::<f32>() / n;
        // Both tolerances are several standard errors wide for 10k samples.
        assert!(mean.abs() < 0.05, "mean = {mean}");
        assert!((var - 1.0).abs() < 0.1, "variance = {var}");
    }

    #[test]
    fn test_latent_code_lerp() {
        let z1 = LatentCode::new(vec![0.0, 0.0]);
        let z2 = LatentCode::new(vec![1.0, 1.0]);

        let mid = z1.lerp(&z2, 0.5);
        assert!((mid.vector[0] - 0.5).abs() < 1e-6);
        assert!((mid.vector[1] - 0.5).abs() < 1e-6);

        let start = z1.lerp(&z2, 0.0);
        assert!((start.vector[0] - 0.0).abs() < 1e-6);

        let end = z1.lerp(&z2, 1.0);
        assert!((end.vector[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    #[should_panic(expected = "Latent dimensions must match")]
    fn test_lerp_dimension_mismatch_panics() {
        let _ = LatentCode::new(vec![0.0, 0.0]).lerp(&LatentCode::new(vec![0.0; 3]), 0.5);
    }

    #[test]
    fn test_latent_code_slerp() {
        let z1 = LatentCode::new(vec![1.0, 0.0]);
        let z2 = LatentCode::new(vec![0.0, 1.0]);

        // Halfway along the quarter circle both components equal cos(pi / 4).
        let mid = z1.slerp(&z2, 0.5);
        let expected = std::f32::consts::FRAC_1_SQRT_2;
        assert!((mid.vector[0] - expected).abs() < 1e-5);
        assert!((mid.vector[1] - expected).abs() < 1e-5);
    }

    #[test]
    fn test_slerp_endpoints() {
        let z1 = LatentCode::new(vec![1.0, 0.0]);
        let z2 = LatentCode::new(vec![0.0, 1.0]);

        let start = z1.slerp(&z2, 0.0);
        for (got, want) in start.vector.iter().zip(&z1.vector) {
            assert!((got - want).abs() < 1e-6);
        }

        let end = z1.slerp(&z2, 1.0);
        for (got, want) in end.vector.iter().zip(&z2.vector) {
            assert!((got - want).abs() < 1e-6);
        }
    }

    #[test]
    fn test_slerp_zero_norm_falls_back_to_lerp() {
        let zero = LatentCode::new(vec![0.0, 0.0]);
        let z = LatentCode::new(vec![2.0, 4.0]);
        assert_eq!(zero.slerp(&z, 0.5), zero.lerp(&z, 0.5));
    }

    #[test]
    fn test_slerp_parallel_falls_back_to_lerp() {
        let z1 = LatentCode::new(vec![1.0, 0.0]);
        let z2 = LatentCode::new(vec![2.0, 0.0]);
        let mid = z1.slerp(&z2, 0.5);
        assert!((mid.vector[0] - 1.5).abs() < 1e-6);
        assert!(mid.vector[1].abs() < 1e-6);
    }

    #[test]
    #[should_panic(expected = "Latent dimensions must match")]
    fn test_slerp_dimension_mismatch_panics() {
        let _ = LatentCode::new(vec![1.0, 0.0]).slerp(&LatentCode::new(vec![1.0; 3]), 0.5);
    }

    #[test]
    fn test_latent_code_norm() {
        let code = LatentCode::new(vec![3.0, 4.0]);
        assert!((code.norm() - 5.0).abs() < 1e-6);
    }

    #[test]
    fn test_latent_code_normalize() {
        let code = LatentCode::new(vec![3.0, 4.0]);
        let normalized = code.normalize();
        assert!((normalized.norm() - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_normalize_zero_vector_is_unchanged() {
        let zero = LatentCode::new(vec![0.0, 0.0, 0.0]);
        assert_eq!(zero.normalize(), zero);
    }

    #[test]
    fn test_slerp_maintains_norm() {
        let z1 = LatentCode::new(vec![1.0, 0.0, 0.0]).normalize();
        let z2 = LatentCode::new(vec![0.0, 1.0, 0.0]).normalize();

        for i in 0..=10 {
            let t = i as f32 / 10.0;
            let z = z1.slerp(&z2, t);
            // For unit inputs the SLERP path stays on the unit sphere.
            assert!((z.norm() - 1.0).abs() < 1e-5, "t = {t}, norm = {}", z.norm());
        }
    }

    proptest! {
        #[test]
        fn test_latent_lerp_bounds(t in 0.0f32..=1.0) {
            let z1 = LatentCode::new(vec![0.0, 0.0, 0.0]);
            let z2 = LatentCode::new(vec![1.0, 1.0, 1.0]);

            let result = z1.lerp(&z2, t);

            for v in &result.vector {
                prop_assert!(*v >= 0.0 && *v <= 1.0);
            }
        }

        #[test]
        fn test_latent_norm_non_negative(values in prop::collection::vec(-10.0f32..10.0, 1..100)) {
            let code = LatentCode::new(values);
            prop_assert!(code.norm() >= 0.0);
        }

        #[test]
        fn test_normalize_unit_length(values in prop::collection::vec(-10.0f32..10.0, 1..100)) {
            let code = LatentCode::new(values);
            if code.norm() > 1e-6 {
                let normalized = code.normalize();
                prop_assert!((normalized.norm() - 1.0).abs() < 1e-5);
            }
        }
    }
}