oxirs_embed/models/common.rs

//! Common utilities and functions used across embedding models

use scirs2_core::ndarray_ext::{Array1, Array2};
#[allow(unused_imports)]
use scirs2_core::random::{Random, Rng};

/// Initialize embeddings with Xavier/Glorot initialization (optimized)
pub fn xavier_init<R>(
    shape: (usize, usize),
    fan_in: usize,
    fan_out: usize,
    rng: &mut Random<R>,
) -> Array2<f64>
where
    R: scirs2_core::random::RngCore,
{
    let limit = (6.0 / (fan_in + fan_out) as f64).sqrt();
    let scale = 2.0 * limit;
    Array2::from_shape_fn(shape, |_| rng.random_f64() * scale - limit)
}

/// Batch Xavier initialization for multiple layers (memory efficient)
pub fn batch_xavier_init(
    shapes: &[(usize, usize)],
    fan_in: usize,
    fan_out: usize,
    rng: &mut Random,
) -> Vec<Array2<f64>> {
    let limit = (6.0 / (fan_in + fan_out) as f64).sqrt();
    let scale = 2.0 * limit;

    shapes
        .iter()
        .map(|&shape| Array2::from_shape_fn(shape, |_| rng.random_f64() * scale - limit))
        .collect()
}

/// Initialize embeddings with uniform distribution
pub fn uniform_init(shape: (usize, usize), low: f64, high: f64, rng: &mut Random) -> Array2<f64> {
    Array2::from_shape_fn(shape, |_| rng.random_f64() * (high - low) + low)
}

/// Initialize embeddings with normal distribution
pub fn normal_init(shape: (usize, usize), mean: f64, std: f64, rng: &mut Random) -> Array2<f64> {
    Array2::from_shape_fn(shape, |_| {
        // Box-Muller transform for normal distribution;
        // clamp u1 away from zero so ln(u1) stays finite
        let u1 = rng.random_f64().max(f64::MIN_POSITIVE);
        let u2 = rng.random_f64();
        let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
        mean + std * z0
    })
}

/// Normalize embeddings to unit length (L2 normalization)
pub fn normalize_embeddings(embeddings: &mut Array2<f64>) {
    for mut row in embeddings.rows_mut() {
        let norm = row.dot(&row).sqrt();
        if norm > 1e-10 {
            row /= norm;
        }
    }
}

/// Normalize a single embedding vector
pub fn normalize_vector(vector: &mut Array1<f64>) {
    let norm = vector.dot(vector).sqrt();
    if norm > 1e-10 {
        *vector /= norm;
    }
}

/// Compute L2 distance between two vectors (optimized)
pub fn l2_distance(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
    // Use scirs2-core's zip for better performance
    scirs2_core::ndarray_ext::Zip::from(a)
        .and(b)
        .fold(0.0, |acc, &a_val, &b_val| {
            let diff = a_val - b_val;
            acc + diff * diff
        })
        .sqrt()
}

/// Compute L1 distance between two vectors (optimized)
pub fn l1_distance(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
    scirs2_core::ndarray_ext::Zip::from(a)
        .and(b)
        .fold(0.0, |acc, &a_val, &b_val| acc + (a_val - b_val).abs())
}

/// Compute cosine similarity between two vectors (optimized)
pub fn cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
    let (dot_product, norm_a_sq, norm_b_sq) = scirs2_core::ndarray_ext::Zip::from(a).and(b).fold(
        (0.0, 0.0, 0.0),
        |(dot, norm_a, norm_b), &a_val, &b_val| {
            (
                dot + a_val * b_val,
                norm_a + a_val * a_val,
                norm_b + b_val * b_val,
            )
        },
    );

    let norm_product = (norm_a_sq * norm_b_sq).sqrt();
    if norm_product > 1e-10 {
        dot_product / norm_product
    } else {
        0.0
    }
}

/// Batch L2 distance computation: all pairwise distances between two sets of
/// vectors, returned as a flat vector with `vectors_a` as the outer (row) index
pub fn batch_l2_distances(vectors_a: &[Array1<f64>], vectors_b: &[Array1<f64>]) -> Vec<f64> {
    let mut distances = Vec::with_capacity(vectors_a.len() * vectors_b.len());

    for a in vectors_a {
        for b in vectors_b {
            distances.push(l2_distance(a, b));
        }
    }

    distances
}

/// Efficient pairwise distance matrix computation
pub fn pairwise_distances(vectors: &[Array1<f64>]) -> Array2<f64> {
    let n = vectors.len();
    let mut distances = Array2::zeros((n, n));

    for i in 0..n {
        for j in (i + 1)..n {
            let dist = l2_distance(&vectors[i], &vectors[j]);
            distances[[i, j]] = dist;
            distances[[j, i]] = dist; // Matrix is symmetric
        }
    }

    distances
}

// F32 versions for transformer training compatibility
/// Compute cosine similarity between two f32 vectors (optimized)
pub fn cosine_similarity_f32(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
    let (dot_product, norm_a_sq, norm_b_sq) = scirs2_core::ndarray_ext::Zip::from(a).and(b).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(dot, norm_a, norm_b), &a_val, &b_val| {
            (
                dot + a_val * b_val,
                norm_a + a_val * a_val,
                norm_b + b_val * b_val,
            )
        },
    );

    let norm_product = (norm_a_sq * norm_b_sq).sqrt();
    if norm_product > 1e-10 {
        dot_product / norm_product
    } else {
        0.0
    }
}

/// Compute L2 distance between two f32 vectors (optimized)
pub fn l2_distance_f32(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
    scirs2_core::ndarray_ext::Zip::from(a)
        .and(b)
        .fold(0.0_f32, |acc, &a_val, &b_val| {
            let diff = a_val - b_val;
            acc + diff * diff
        })
        .sqrt()
}

/// Clamp embeddings to a maximum norm
pub fn clamp_embeddings(embeddings: &mut Array2<f64>, max_norm: f64) {
    for mut row in embeddings.rows_mut() {
        let norm = row.dot(&row).sqrt();
        if norm > max_norm {
            row *= max_norm / norm;
        }
    }
}

/// Apply gradient descent update with L2 regularization (optimized)
pub fn gradient_update(
    embeddings: &mut Array2<f64>,
    gradients: &Array2<f64>,
    learning_rate: f64,
    l2_reg: f64,
) {
    // Vectorized in-place update to avoid temporary allocations
    scirs2_core::ndarray_ext::Zip::from(embeddings)
        .and(gradients)
        .for_each(|embed, &grad| {
            *embed = *embed - learning_rate * (grad + l2_reg * *embed);
        });
}

/// Batch gradient update for multiple embedding matrices
pub fn batch_gradient_update(
    embeddings: &mut [Array2<f64>],
    gradients: &[Array2<f64>],
    learning_rate: f64,
    l2_reg: f64,
) {
    for (embedding, gradient) in embeddings.iter_mut().zip(gradients.iter()) {
        gradient_update(embedding, gradient, learning_rate, l2_reg);
    }
}

/// Apply gradient descent update for a single embedding
pub fn gradient_update_single(
    embedding: &mut Array1<f64>,
    gradient: &Array1<f64>,
    learning_rate: f64,
    l2_reg: f64,
) {
    *embedding = embedding.clone() - learning_rate * (gradient + l2_reg * &*embedding);
}

/// Sigmoid activation function
pub fn sigmoid(x: f64) -> f64 {
    1.0 / (1.0 + (-x).exp())
}

/// ReLU activation function
pub fn relu(x: f64) -> f64 {
    x.max(0.0)
}

/// Tanh activation function
pub fn tanh(x: f64) -> f64 {
    x.tanh()
}

/// Compute margin-based ranking loss
pub fn margin_loss(positive_score: f64, negative_score: f64, margin: f64) -> f64 {
    (margin + negative_score - positive_score).max(0.0)
}

/// Compute logistic loss
pub fn logistic_loss(score: f64, label: f64) -> f64 {
    (1.0 + (-label * score).exp()).ln()
}

/// Batch shuffle utility (optimized for performance)
pub fn shuffle_batch<T>(batch: &mut [T], rng: &mut Random) {
    // Fisher-Yates shuffle with early termination for small batches
    if batch.len() <= 1 {
        return;
    }

    for i in (1..batch.len()).rev() {
        let j = rng.random_range(0..i + 1);
        if i != j {
            batch.swap(i, j);
        }
    }
}

/// High-performance batch shuffling for multiple arrays
pub fn shuffle_multiple_batches<T: Clone>(batches: &mut [Vec<T>], rng: &mut Random) {
    for batch in batches.iter_mut() {
        shuffle_batch(batch, rng);
    }
}

/// Optimized random sampling without replacement
pub fn sample_without_replacement<T: Clone>(
    data: &[T],
    sample_size: usize,
    rng: &mut Random,
) -> Vec<T> {
    if sample_size >= data.len() {
        return data.to_vec();
    }

    let mut indices: Vec<usize> = (0..data.len()).collect();
    shuffle_batch(&mut indices, rng);

    indices[..sample_size]
        .iter()
        .map(|&i| data[i].clone())
        .collect()
}

/// Create batches from data (optimized to avoid unnecessary cloning)
pub fn create_batches<T: Clone>(data: &[T], batch_size: usize) -> Vec<Vec<T>> {
    let mut batches = Vec::with_capacity((data.len() + batch_size - 1) / batch_size);
    for chunk in data.chunks(batch_size) {
        batches.push(chunk.to_vec());
    }
    batches
}

/// Create batch references (zero-copy alternative)
pub fn create_batch_refs<T>(data: &[T], batch_size: usize) -> impl Iterator<Item = &[T]> {
    data.chunks(batch_size)
}

/// Convert ndarray to Vector (optimized with pre-allocation)
pub fn ndarray_to_vector(array: &Array1<f64>) -> crate::Vector {
    let mut values = Vec::with_capacity(array.len());
    values.extend(array.iter().map(|&x| x as f32));
    crate::Vector::new(values)
}

/// Convert Vector to ndarray (optimized with pre-allocation)
pub fn vector_to_ndarray(vector: &crate::Vector) -> Array1<f64> {
    let mut values = Vec::with_capacity(vector.values.len());
    values.extend(vector.values.iter().map(|&x| x as f64));
    Array1::from_vec(values)
}

/// Batch convert multiple ndarrays to vectors (SIMD-friendly)
pub fn batch_ndarray_to_vectors(arrays: &[Array1<f64>]) -> Vec<crate::Vector> {
    arrays.iter().map(ndarray_to_vector).collect()
}

/// Learning rate scheduling
pub enum LearningRateSchedule {
    /// Constant learning rate
    Constant(f64),
    /// Exponential decay: lr * decay_rate^(epoch / decay_steps)
    ExponentialDecay {
        initial_lr: f64,
        decay_rate: f64,
        decay_steps: usize,
    },
    /// Step decay: lr * factor^(epoch / step_size)
    StepDecay {
        initial_lr: f64,
        step_size: usize,
        factor: f64,
    },
    /// Polynomial decay
    PolynomialDecay {
        initial_lr: f64,
        final_lr: f64,
        decay_steps: usize,
        power: f64,
    },
}

impl LearningRateSchedule {
    /// Get learning rate for a given epoch
    pub fn get_lr(&self, epoch: usize) -> f64 {
        match self {
            LearningRateSchedule::Constant(lr) => *lr,
            LearningRateSchedule::ExponentialDecay {
                initial_lr,
                decay_rate,
                decay_steps,
            } => initial_lr * decay_rate.powf(epoch as f64 / *decay_steps as f64),
            LearningRateSchedule::StepDecay {
                initial_lr,
                step_size,
                factor,
            } => initial_lr * factor.powf((epoch / step_size) as f64),
            LearningRateSchedule::PolynomialDecay {
                initial_lr,
                final_lr,
                decay_steps,
                power,
            } => {
                if epoch >= *decay_steps {
                    *final_lr
                } else {
                    let decay_factor = (1.0 - epoch as f64 / *decay_steps as f64).powf(*power);
                    final_lr + (initial_lr - final_lr) * decay_factor
                }
            }
        }
    }
}

/// Early stopping utility
pub struct EarlyStopping {
    patience: usize,
    min_delta: f64,
    best_loss: f64,
    wait_count: usize,
    stopped: bool,
}

impl EarlyStopping {
    /// Create new early stopping monitor
    pub fn new(patience: usize, min_delta: f64) -> Self {
        Self {
            patience,
            min_delta,
            best_loss: f64::INFINITY,
            wait_count: 0,
            stopped: false,
        }
    }

    /// Update with current loss and check if training should stop
    pub fn update(&mut self, current_loss: f64) -> bool {
        if current_loss < self.best_loss - self.min_delta {
            self.best_loss = current_loss;
            self.wait_count = 0;
        } else {
            self.wait_count += 1;
            if self.wait_count > self.patience {
                self.stopped = true;
            }
        }

        self.stopped
    }

    /// Check if training should stop
    pub fn should_stop(&self) -> bool {
        self.stopped
    }

    /// Get best loss seen so far
    pub fn best_loss(&self) -> f64 {
        self.best_loss
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::Array1;
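
    // Illustrative sanity check for the gradient helpers: with a zero gradient a
    // single update only applies L2 decay, i.e. x <- x - lr * l2 * x; with no
    // regularization the single-vector variant is a plain SGD step.
    #[test]
    fn test_gradient_update_l2_decay() {
        use scirs2_core::ndarray_ext::Array2;

        let mut embeddings = Array2::from_shape_fn((2, 3), |_| 1.0);
        let gradients = Array2::zeros((2, 3));
        gradient_update(&mut embeddings, &gradients, 0.1, 0.5);
        for &value in embeddings.iter() {
            // 1.0 - 0.1 * (0.0 + 0.5 * 1.0) = 0.95
            assert!((value - 0.95).abs() < 1e-10);
        }

        let mut embedding = Array1::from_vec(vec![2.0, -2.0]);
        let gradient = Array1::from_vec(vec![1.0, -1.0]);
        gradient_update_single(&mut embedding, &gradient, 0.1, 0.0);
        assert!((embedding[0] - 1.9).abs() < 1e-10);
        assert!((embedding[1] + 1.9).abs() < 1e-10);
    }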

    #[test]
    fn test_distance_functions() {
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Array1::from_vec(vec![4.0, 5.0, 6.0]);

        let l2_dist = l2_distance(&a, &b);
        assert!((l2_dist - 5.196152422706632).abs() < 1e-10);

        let l1_dist = l1_distance(&a, &b);
        assert!((l1_dist - 9.0).abs() < 1e-10);

        let cos_sim = cosine_similarity(&a, &b);
        assert!(cos_sim > 0.0 && cos_sim < 1.0);
    }
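
    // Illustrative checks for the f32 helpers and the pairwise distance matrix,
    // using only functions defined in this module.
    #[test]
    fn test_f32_helpers_and_pairwise_distances() {
        // Orthogonal unit vectors: cosine similarity ~0, L2 distance sqrt(2)
        let a = Array1::from_vec(vec![1.0_f32, 0.0, 0.0]);
        let b = Array1::from_vec(vec![0.0_f32, 1.0, 0.0]);
        assert!(cosine_similarity_f32(&a, &b).abs() < 1e-6);
        assert!((l2_distance_f32(&a, &b) - 2.0_f32.sqrt()).abs() < 1e-6);

        // A 3-4-5 triangle gives an off-diagonal distance of 5 in the symmetric matrix
        let vectors = vec![
            Array1::from_vec(vec![0.0, 0.0]),
            Array1::from_vec(vec![3.0, 4.0]),
        ];
        let dists = pairwise_distances(&vectors);
        assert!(dists[[0, 0]].abs() < 1e-10);
        assert!((dists[[0, 1]] - 5.0).abs() < 1e-10);
        assert!((dists[[1, 0]] - 5.0).abs() < 1e-10);

        // The flat batch version returns all cross distances (here a 2 x 2 grid)
        let flat = batch_l2_distances(&vectors, &vectors);
        assert_eq!(flat.len(), 4);
        assert!((flat[1] - 5.0).abs() < 1e-10);
    }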

    #[test]
    fn test_normalization() {
        let mut vec = Array1::from_vec(vec![3.0, 4.0]);
        normalize_vector(&mut vec);
        let norm = vec.dot(&vec).sqrt();
        assert!((norm - 1.0).abs() < 1e-10);
    }
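
    // Illustrative checks for the matrix-level helpers: after normalization every
    // row has unit norm, and clamping caps each row norm at max_norm.
    #[test]
    fn test_normalize_and_clamp_embeddings() {
        use scirs2_core::ndarray_ext::Array2;

        let mut embeddings = Array2::from_shape_fn((3, 4), |(i, j)| (i + j) as f64 + 1.0);
        normalize_embeddings(&mut embeddings);
        for row in embeddings.rows() {
            let norm = row.dot(&row).sqrt();
            assert!((norm - 1.0).abs() < 1e-10);
        }

        let mut clamped = Array2::from_shape_fn((2, 2), |_| 10.0);
        clamp_embeddings(&mut clamped, 1.0);
        for row in clamped.rows() {
            assert!(row.dot(&row).sqrt() <= 1.0 + 1e-10);
        }
    }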

    #[test]
    fn test_learning_rate_schedule() {
        let schedule = LearningRateSchedule::ExponentialDecay {
            initial_lr: 0.1,
            decay_rate: 0.9,
            decay_steps: 10,
        };

        let lr0 = schedule.get_lr(0);
        let lr10 = schedule.get_lr(10);
        let lr20 = schedule.get_lr(20);

        assert!((lr0 - 0.1).abs() < 1e-10);
        assert!(lr10 < lr0);
        assert!(lr20 < lr10);
    }
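
    // Illustrative checks for the remaining schedules: step decay halves the rate
    // every `step_size` epochs (with factor 0.5), and polynomial decay ends exactly
    // at `final_lr` once `decay_steps` is reached.
    #[test]
    fn test_step_and_polynomial_decay() {
        let step = LearningRateSchedule::StepDecay {
            initial_lr: 0.1,
            step_size: 10,
            factor: 0.5,
        };
        assert!((step.get_lr(0) - 0.1).abs() < 1e-10);
        assert!((step.get_lr(9) - 0.1).abs() < 1e-10); // still in the first step
        assert!((step.get_lr(10) - 0.05).abs() < 1e-10);
        assert!((step.get_lr(20) - 0.025).abs() < 1e-10);

        let poly = LearningRateSchedule::PolynomialDecay {
            initial_lr: 0.1,
            final_lr: 0.01,
            decay_steps: 100,
            power: 1.0,
        };
        assert!((poly.get_lr(0) - 0.1).abs() < 1e-10);
        assert!((poly.get_lr(100) - 0.01).abs() < 1e-10);
        let mid = poly.get_lr(50);
        assert!(mid < poly.get_lr(0) && mid > poly.get_lr(100));
    }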

    #[test]
    fn test_early_stopping() {
        let mut early_stop = EarlyStopping::new(3, 0.01);

        assert!(!early_stop.update(1.0));
        assert!(!early_stop.update(0.5));
        assert!(!early_stop.update(0.51));
        assert!(!early_stop.update(0.52));
        assert!(!early_stop.update(0.53));
        assert!(early_stop.update(0.54)); // Should stop now
    }
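
    // Illustrative checks for the scalar activations, the ranking/logistic losses,
    // and the batching helpers.
    #[test]
    fn test_activations_losses_and_batching() {
        assert!((sigmoid(0.0) - 0.5).abs() < 1e-10);
        assert!(relu(-2.0).abs() < 1e-10);
        assert!((relu(3.0) - 3.0).abs() < 1e-10);
        assert!(tanh(0.0).abs() < 1e-10);

        // Margin loss vanishes once the positive score beats the negative by the margin
        assert!(margin_loss(2.0, 0.5, 1.0).abs() < 1e-10);
        assert!((margin_loss(0.5, 0.5, 1.0) - 1.0).abs() < 1e-10);

        // Logistic loss of a zero score is ln(2) regardless of the label
        assert!((logistic_loss(0.0, 1.0) - 2.0_f64.ln()).abs() < 1e-10);

        // 7 items with batch size 3 -> batches of sizes 3, 3, 1
        let data: Vec<usize> = (0..7).collect();
        let batches = create_batches(&data, 3);
        assert_eq!(batches.len(), 3);
        assert_eq!(batches[2], vec![6usize]);
        assert_eq!(create_batch_refs(&data, 3).count(), 3);
    }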
}