Skip to main content

oxirs_embed/
projection_layer.rs

//! Linear embedding projection layer (dimensionality reduction or expansion).
//!
//! `ProjectionLayer` wraps a weight matrix and bias vector together with an
//! optional activation function.  It can reduce high-dimensional embeddings
//! to a smaller space (e.g. 768 → 128) or expand them to a larger space.

/// Initialisation methods for the projection weights.
///
/// Derives `Copy` (the largest variant is a single `u64`), plus `Eq` and
/// `Hash` so the method can be used as a map/set key or compared exactly.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum InitMethod {
    /// Initialise all weights and biases to zero.
    Zeros,
    /// Initialise as identity-like (min-dim diagonal is 1, rest 0).
    Identity,
    /// Pseudo-random initialisation using the given seed (Xavier-style scaling).
    Random(u64),
}
17
/// Activation function applied element-wise to the projected output.
///
/// Derives `Copy` (fieldless variants), plus `Eq` and `Hash` so the
/// choice of activation can be compared exactly or used as a key.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ActivationFn {
    /// Rectified linear unit: max(0, x).
    ReLU,
    /// Hyperbolic tangent.
    Tanh,
    /// Logistic sigmoid: 1 / (1 + exp(−x)).
    Sigmoid,
    /// No activation (linear / identity).
    None,
}
30
/// The weight matrix and bias for a linear projection.
#[derive(Debug, Clone, PartialEq)]
pub struct ProjectionMatrix {
    /// Number of input dimensions.
    pub input_dim: usize,
    /// Number of output dimensions.
    pub output_dim: usize,
    /// Row-major weight matrix: `weights[out][in]`.
    pub weights: Vec<Vec<f64>>,
    /// Bias vector of length `output_dim`.
    pub bias: Vec<f64>,
}

impl ProjectionMatrix {
    /// All weights and biases zero.
    fn new_zeros(input_dim: usize, output_dim: usize) -> Self {
        Self {
            input_dim,
            output_dim,
            weights: vec![vec![0.0; input_dim]; output_dim],
            bias: vec![0.0; output_dim],
        }
    }

    /// Identity-like weights: ones on the min-dim diagonal, zeros elsewhere.
    fn new_identity(input_dim: usize, output_dim: usize) -> Self {
        let diag = input_dim.min(output_dim);
        let mut weights = vec![vec![0.0; input_dim]; output_dim];
        for i in 0..diag {
            weights[i][i] = 1.0;
        }
        Self {
            input_dim,
            output_dim,
            weights,
            bias: vec![0.0; output_dim],
        }
    }

    /// Simple LCG-based pseudo-random weight initialisation (Xavier uniform).
    fn new_random(input_dim: usize, output_dim: usize, seed: u64) -> Self {
        // Xavier uniform bound: sqrt(6 / (fan_in + fan_out)).
        let limit = (6.0_f64 / (input_dim + output_dim) as f64).sqrt();
        let mut state = seed.wrapping_add(1);
        // Deterministic generator: one LCG step per weight, mapped to
        // [0, 1) via the top 53 bits, then scaled to [-limit, limit].
        let mut next_weight = || {
            state = state
                .wrapping_mul(6_364_136_223_846_793_005)
                .wrapping_add(1_442_695_040_888_963_407);
            let unit = (state >> 11) as f64 / (1u64 << 53) as f64;
            (unit * 2.0 - 1.0) * limit
        };
        // Row-major fill order matches the generator's step sequence.
        let weights: Vec<Vec<f64>> = (0..output_dim)
            .map(|_| (0..input_dim).map(|_| next_weight()).collect())
            .collect();
        Self {
            input_dim,
            output_dim,
            weights,
            bias: vec![0.0; output_dim],
        }
    }
}
94
95// ──────────────────────────────────────────────────────────────────────────────
96// apply_activation
97// ──────────────────────────────────────────────────────────────────────────────
98
99/// Apply `act` element-wise to `val`.
100fn apply_activation(val: f64, act: &ActivationFn) -> f64 {
101    match act {
102        ActivationFn::ReLU => val.max(0.0),
103        ActivationFn::Tanh => val.tanh(),
104        ActivationFn::Sigmoid => 1.0 / (1.0 + (-val).exp()),
105        ActivationFn::None => val,
106    }
107}
108
109// ──────────────────────────────────────────────────────────────────────────────
110// ProjectionLayer
111// ──────────────────────────────────────────────────────────────────────────────
112
113/// A linear projection layer: `output = activation(W·input + b)`.
114#[derive(Debug, Clone)]
115pub struct ProjectionLayer {
116    matrix: ProjectionMatrix,
117    activation: Option<ActivationFn>,
118}
119
120impl ProjectionLayer {
121    /// Create a new projection layer with the given initialisation method.
122    pub fn new(input_dim: usize, output_dim: usize, init: InitMethod) -> Self {
123        let matrix = match init {
124            InitMethod::Zeros => ProjectionMatrix::new_zeros(input_dim, output_dim),
125            InitMethod::Identity => ProjectionMatrix::new_identity(input_dim, output_dim),
126            InitMethod::Random(seed) => ProjectionMatrix::new_random(input_dim, output_dim, seed),
127        };
128        Self {
129            matrix,
130            activation: None,
131        }
132    }
133
134    /// Attach an activation function (builder pattern).
135    pub fn with_activation(mut self, activation: ActivationFn) -> Self {
136        self.activation = Some(activation);
137        self
138    }
139
140    /// Project a single input vector.
141    ///
142    /// Returns an error string if `input.len() != self.input_dim()`.
143    pub fn project(&self, input: &[f64]) -> Vec<f64> {
144        debug_assert_eq!(
145            input.len(),
146            self.matrix.input_dim,
147            "input dimension mismatch"
148        );
149        let mut output = Vec::with_capacity(self.matrix.output_dim);
150        for (i, row) in self.matrix.weights.iter().enumerate() {
151            let mut sum = self.matrix.bias[i];
152            for (w, x) in row.iter().zip(input.iter()) {
153                sum += w * x;
154            }
155            let activated = if let Some(act) = &self.activation {
156                apply_activation(sum, act)
157            } else {
158                sum
159            };
160            output.push(activated);
161        }
162        output
163    }
164
165    /// Project a batch of input vectors.
166    pub fn project_batch(&self, inputs: &[Vec<f64>]) -> Vec<Vec<f64>> {
167        inputs.iter().map(|inp| self.project(inp)).collect()
168    }
169
170    /// Return the input dimensionality.
171    pub fn input_dim(&self) -> usize {
172        self.matrix.input_dim
173    }
174
175    /// Return the output dimensionality.
176    pub fn output_dim(&self) -> usize {
177        self.matrix.output_dim
178    }
179
180    /// Replace the weight matrix.
181    ///
182    /// Returns `Err` if dimensions do not match `(output_dim × input_dim)`.
183    pub fn set_weights(&mut self, weights: Vec<Vec<f64>>) -> Result<(), String> {
184        if weights.len() != self.matrix.output_dim {
185            return Err(format!(
186                "expected {} output rows, got {}",
187                self.matrix.output_dim,
188                weights.len()
189            ));
190        }
191        for (i, row) in weights.iter().enumerate() {
192            if row.len() != self.matrix.input_dim {
193                return Err(format!(
194                    "row {} has {} columns, expected {}",
195                    i,
196                    row.len(),
197                    self.matrix.input_dim
198                ));
199            }
200        }
201        self.matrix.weights = weights;
202        Ok(())
203    }
204
205    /// Replace the bias vector.
206    ///
207    /// Returns `Err` if `bias.len() != output_dim`.
208    pub fn set_bias(&mut self, bias: Vec<f64>) -> Result<(), String> {
209        if bias.len() != self.matrix.output_dim {
210            return Err(format!(
211                "expected bias length {}, got {}",
212                self.matrix.output_dim,
213                bias.len()
214            ));
215        }
216        self.matrix.bias = bias;
217        Ok(())
218    }
219
220    /// Return the total number of learnable parameters (weights + biases).
221    pub fn parameter_count(&self) -> usize {
222        self.matrix.input_dim * self.matrix.output_dim + self.matrix.output_dim
223    }
224}
225
// ─────────────────────────────────────────────────────────────────────────────
// Tests
// ─────────────────────────────────────────────────────────────────────────────

#[cfg(test)]
mod tests {
    use super::*;

    // ── apply_activation ─────────────────────────────────────────────────────

    #[test]
    fn test_activation_relu_positive() {
        assert_eq!(apply_activation(2.5, &ActivationFn::ReLU), 2.5);
    }

    #[test]
    fn test_activation_relu_negative() {
        assert_eq!(apply_activation(-3.0, &ActivationFn::ReLU), 0.0);
    }

    #[test]
    fn test_activation_relu_zero() {
        assert_eq!(apply_activation(0.0, &ActivationFn::ReLU), 0.0);
    }

    #[test]
    fn test_activation_tanh() {
        let at_zero = apply_activation(0.0, &ActivationFn::Tanh);
        assert!(at_zero.abs() < 1e-10);
        let at_one = apply_activation(1.0, &ActivationFn::Tanh);
        assert!((at_one - 1.0_f64.tanh()).abs() < 1e-10);
    }

    #[test]
    fn test_activation_sigmoid_at_zero() {
        let mid = apply_activation(0.0, &ActivationFn::Sigmoid);
        assert!((mid - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_activation_sigmoid_large_positive() {
        let hi = apply_activation(100.0, &ActivationFn::Sigmoid);
        assert!((hi - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_activation_sigmoid_large_negative() {
        assert!(apply_activation(-100.0, &ActivationFn::Sigmoid) < 1e-6);
    }

    #[test]
    #[allow(clippy::approx_constant)]
    fn test_activation_none_is_identity() {
        assert_eq!(apply_activation(3.14, &ActivationFn::None), 3.14);
        assert_eq!(apply_activation(-7.0, &ActivationFn::None), -7.0);
    }

    // ── ProjectionLayer construction ─────────────────────────────────────────

    #[test]
    fn test_new_zeros() {
        let layer = ProjectionLayer::new(4, 2, InitMethod::Zeros);
        assert_eq!((layer.input_dim(), layer.output_dim()), (4, 2));
        // A zero matrix maps everything to the zero vector.
        assert_eq!(layer.project(&[1.0, 2.0, 3.0, 4.0]), vec![0.0, 0.0]);
    }

    #[test]
    fn test_new_identity_square() {
        let layer = ProjectionLayer::new(3, 3, InitMethod::Identity);
        assert_eq!(layer.project(&[1.0, 2.0, 3.0]), vec![1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_new_identity_reduce_dim() {
        let layer = ProjectionLayer::new(4, 2, InitMethod::Identity);
        let projected = layer.project(&[5.0, 7.0, 9.0, 11.0]);
        // Only the diagonal entries copy through, i.e. the first 2 inputs.
        assert!((projected[0] - 5.0).abs() < 1e-10);
        assert!((projected[1] - 7.0).abs() < 1e-10);
    }

    #[test]
    fn test_new_random_produces_output() {
        let layer = ProjectionLayer::new(8, 4, InitMethod::Random(42));
        let projected = layer.project(&vec![1.0; 8]);
        assert_eq!(projected.len(), 4);
    }

    #[test]
    fn test_new_random_different_seeds_differ() {
        let first = ProjectionLayer::new(4, 2, InitMethod::Random(1));
        let second = ProjectionLayer::new(4, 2, InitMethod::Random(2));
        let input = [1.0, 1.0, 1.0, 1.0];
        assert_ne!(first.project(&input), second.project(&input));
    }

    #[test]
    fn test_new_random_same_seed_same_output() {
        let first = ProjectionLayer::new(4, 2, InitMethod::Random(99));
        let second = ProjectionLayer::new(4, 2, InitMethod::Random(99));
        let input = [1.0, 0.5, -0.5, -1.0];
        assert_eq!(first.project(&input), second.project(&input));
    }

    // ── parameter_count ──────────────────────────────────────────────────────

    #[test]
    fn test_parameter_count() {
        // 10 * 5 weights + 5 biases = 55
        let layer = ProjectionLayer::new(10, 5, InitMethod::Zeros);
        assert_eq!(layer.parameter_count(), 55);
    }

    #[test]
    fn test_parameter_count_large() {
        let layer = ProjectionLayer::new(768, 128, InitMethod::Zeros);
        assert_eq!(layer.parameter_count(), 768 * 128 + 128);
    }

    // ── set_weights ──────────────────────────────────────────────────────────

    #[test]
    fn test_set_weights_valid() {
        let mut layer = ProjectionLayer::new(3, 2, InitMethod::Zeros);
        let new_weights = vec![vec![1.0, 2.0, 3.0], vec![4.0, 5.0, 6.0]];
        assert!(layer.set_weights(new_weights).is_ok());
        let projected = layer.project(&[1.0, 1.0, 1.0]);
        assert!((projected[0] - 6.0).abs() < 1e-10); // 1+2+3
        assert!((projected[1] - 15.0).abs() < 1e-10); // 4+5+6
    }

    #[test]
    fn test_set_weights_wrong_row_count() {
        let mut layer = ProjectionLayer::new(3, 2, InitMethod::Zeros);
        assert!(layer.set_weights(vec![vec![1.0, 2.0, 3.0]]).is_err());
    }

    #[test]
    fn test_set_weights_wrong_col_count() {
        let mut layer = ProjectionLayer::new(3, 2, InitMethod::Zeros);
        assert!(layer
            .set_weights(vec![vec![1.0, 2.0], vec![3.0, 4.0]])
            .is_err());
    }

    // ── set_bias ─────────────────────────────────────────────────────────────

    #[test]
    fn test_set_bias_valid() {
        let mut layer = ProjectionLayer::new(2, 2, InitMethod::Identity);
        assert!(layer.set_bias(vec![10.0, 20.0]).is_ok());
        let projected = layer.project(&[1.0, 2.0]);
        assert!((projected[0] - 11.0).abs() < 1e-10);
        assert!((projected[1] - 22.0).abs() < 1e-10);
    }

    #[test]
    fn test_set_bias_wrong_length() {
        let mut layer = ProjectionLayer::new(2, 2, InitMethod::Zeros);
        assert!(layer.set_bias(vec![1.0, 2.0, 3.0]).is_err());
    }

    // ── with_activation ──────────────────────────────────────────────────────

    #[test]
    fn test_relu_activation_clips_negative() {
        let mut layer =
            ProjectionLayer::new(2, 2, InitMethod::Identity).with_activation(ActivationFn::ReLU);
        assert!(layer.set_bias(vec![-5.0, -5.0]).is_ok());
        // 1 - 5 = -4, clipped to 0 by ReLU.
        assert_eq!(layer.project(&[1.0, 1.0]), vec![0.0, 0.0]);
    }

    #[test]
    fn test_tanh_activation_bounds() {
        let layer =
            ProjectionLayer::new(1, 1, InitMethod::Identity).with_activation(ActivationFn::Tanh);
        // tanh saturates to ~1.0 for large positive inputs.
        let projected = layer.project(&[100.0]);
        assert!((projected[0] - 1.0).abs() < 1e-6);
    }

    #[test]
    fn test_sigmoid_activation_bounds() {
        let layer =
            ProjectionLayer::new(1, 1, InitMethod::Identity).with_activation(ActivationFn::Sigmoid);
        let at_zero = layer.project(&[0.0]);
        assert!((at_zero[0] - 0.5).abs() < 1e-10);
    }

    #[test]
    fn test_no_activation_is_linear() {
        let layer =
            ProjectionLayer::new(1, 1, InitMethod::Identity).with_activation(ActivationFn::None);
        let projected = layer.project(&[42.0]);
        assert!((projected[0] - 42.0).abs() < 1e-10);
    }

    // ── project_batch ────────────────────────────────────────────────────────

    #[test]
    fn test_project_batch_empty() {
        let layer = ProjectionLayer::new(3, 2, InitMethod::Zeros);
        assert!(layer.project_batch(&[]).is_empty());
    }

    #[test]
    fn test_project_batch_multiple() {
        let layer = ProjectionLayer::new(2, 2, InitMethod::Identity);
        let batch = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let results = layer.project_batch(&batch);
        assert_eq!(results.len(), 2);
        assert_eq!(results[0], vec![1.0, 2.0]);
        assert_eq!(results[1], vec![3.0, 4.0]);
    }

    #[test]
    fn test_project_batch_consistency() {
        let layer = ProjectionLayer::new(4, 2, InitMethod::Random(7));
        let input = vec![1.0, 0.5, -0.5, -1.0];
        let single = layer.project(&input);
        let batched = layer.project_batch(&[input]);
        assert_eq!(batched[0], single);
    }

    // ── dimensionality ───────────────────────────────────────────────────────

    #[test]
    fn test_reduce_dim_768_to_128() {
        let layer = ProjectionLayer::new(768, 128, InitMethod::Zeros);
        assert_eq!(layer.input_dim(), 768);
        assert_eq!(layer.output_dim(), 128);
        assert_eq!(layer.project(&vec![0.0; 768]).len(), 128);
    }

    #[test]
    fn test_expand_dim() {
        let layer = ProjectionLayer::new(32, 256, InitMethod::Identity);
        assert_eq!(layer.output_dim(), 256);
        assert_eq!(layer.project(&vec![1.0; 32]).len(), 256);
    }

    // ── edge cases ───────────────────────────────────────────────────────────

    #[test]
    fn test_single_dim_projection() {
        let mut layer = ProjectionLayer::new(1, 1, InitMethod::Zeros);
        assert!(layer.set_weights(vec![vec![3.0]]).is_ok());
        assert!(layer.set_bias(vec![1.0]).is_ok());
        let projected = layer.project(&[2.0]);
        assert!((projected[0] - 7.0).abs() < 1e-10); // 3*2 + 1
    }

    #[test]
    fn test_zero_input_with_bias() {
        let mut layer = ProjectionLayer::new(3, 2, InitMethod::Zeros);
        assert!(layer.set_bias(vec![1.0, 2.0]).is_ok());
        // With zero weights, the bias alone determines the output.
        assert_eq!(layer.project(&[0.0, 0.0, 0.0]), vec![1.0, 2.0]);
    }

    #[test]
    fn test_init_method_equality() {
        assert_eq!(InitMethod::Zeros, InitMethod::Zeros);
        assert_eq!(InitMethod::Random(42), InitMethod::Random(42));
        assert_ne!(InitMethod::Random(1), InitMethod::Random(2));
    }

    #[test]
    fn test_activation_fn_equality() {
        assert_eq!(ActivationFn::ReLU, ActivationFn::ReLU);
        assert_ne!(ActivationFn::ReLU, ActivationFn::Tanh);
    }
}