1use scirs2_core::ndarray_ext::{Array1, Array2};
5#[allow(unused_imports)]
6use scirs2_core::random::{Random, Rng};
7
8pub fn xavier_init<R>(
10 shape: (usize, usize),
11 fan_in: usize,
12 fan_out: usize,
13 rng: &mut Random<R>,
14) -> Array2<f64>
15where
16 R: scirs2_core::random::RngCore,
17{
18 let limit = (6.0 / (fan_in + fan_out) as f64).sqrt();
19 let scale = 2.0 * limit;
20 Array2::from_shape_fn(shape, |_| rng.random_f64() * scale - limit)
21}
22
23pub fn batch_xavier_init(
25 shapes: &[(usize, usize)],
26 fan_in: usize,
27 fan_out: usize,
28 rng: &mut Random,
29) -> Vec<Array2<f64>> {
30 let limit = (6.0 / (fan_in + fan_out) as f64).sqrt();
31 let scale = 2.0 * limit;
32
33 shapes
34 .iter()
35 .map(|&shape| Array2::from_shape_fn(shape, |_| rng.random_f64() * scale - limit))
36 .collect()
37}
38
39pub fn uniform_init(shape: (usize, usize), low: f64, high: f64, rng: &mut Random) -> Array2<f64> {
41 Array2::from_shape_fn(shape, |_| rng.random_f64() * (high - low) + low)
42}
43
44pub fn normal_init(shape: (usize, usize), mean: f64, std: f64, rng: &mut Random) -> Array2<f64> {
46 Array2::from_shape_fn(shape, |_| {
47 let u1 = rng.random_f64();
49 let u2 = rng.random_f64();
50 let z0 = (-2.0 * u1.ln()).sqrt() * (2.0 * std::f64::consts::PI * u2).cos();
51 mean + std * z0
52 })
53}
54
55pub fn normalize_embeddings(embeddings: &mut Array2<f64>) {
57 for mut row in embeddings.rows_mut() {
58 let norm = row.dot(&row).sqrt();
59 if norm > 1e-10 {
60 row /= norm;
61 }
62 }
63}
64
65pub fn normalize_vector(vector: &mut Array1<f64>) {
67 let norm = vector.dot(vector).sqrt();
68 if norm > 1e-10 {
69 *vector /= norm;
70 }
71}
72
73pub fn l2_distance(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
75 scirs2_core::ndarray_ext::Zip::from(a)
77 .and(b)
78 .fold(0.0, |acc, &a_val, &b_val| {
79 let diff = a_val - b_val;
80 acc + diff * diff
81 })
82 .sqrt()
83}
84
85pub fn l1_distance(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
87 scirs2_core::ndarray_ext::Zip::from(a)
88 .and(b)
89 .fold(0.0, |acc, &a_val, &b_val| acc + (a_val - b_val).abs())
90}
91
92pub fn cosine_similarity(a: &Array1<f64>, b: &Array1<f64>) -> f64 {
94 let (dot_product, norm_a_sq, norm_b_sq) = scirs2_core::ndarray_ext::Zip::from(a).and(b).fold(
95 (0.0, 0.0, 0.0),
96 |(dot, norm_a, norm_b), &a_val, &b_val| {
97 (
98 dot + a_val * b_val,
99 norm_a + a_val * a_val,
100 norm_b + b_val * b_val,
101 )
102 },
103 );
104
105 let norm_product = (norm_a_sq * norm_b_sq).sqrt();
106 if norm_product > 1e-10 {
107 dot_product / norm_product
108 } else {
109 0.0
110 }
111}
112
113pub fn batch_l2_distances(vectors_a: &[Array1<f64>], vectors_b: &[Array1<f64>]) -> Vec<f64> {
115 let mut distances = Vec::with_capacity(vectors_a.len() * vectors_b.len());
117
118 for a in vectors_a {
119 for b in vectors_b {
120 distances.push(l2_distance(a, b));
121 }
122 }
123
124 distances
125}
126
127pub fn pairwise_distances(vectors: &[Array1<f64>]) -> Array2<f64> {
129 let n = vectors.len();
130 let mut distances = Array2::zeros((n, n));
131
132 for i in 0..n {
133 for j in (i + 1)..n {
134 let dist = l2_distance(&vectors[i], &vectors[j]);
135 distances[[i, j]] = dist;
136 distances[[j, i]] = dist; }
138 }
139
140 distances
141}
142
143pub fn cosine_similarity_f32(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
146 let (dot_product, norm_a_sq, norm_b_sq) = scirs2_core::ndarray_ext::Zip::from(a).and(b).fold(
147 (0.0_f32, 0.0_f32, 0.0_f32),
148 |(dot, norm_a, norm_b), &a_val, &b_val| {
149 (
150 dot + a_val * b_val,
151 norm_a + a_val * a_val,
152 norm_b + b_val * b_val,
153 )
154 },
155 );
156
157 let norm_product = (norm_a_sq * norm_b_sq).sqrt();
158 if norm_product > 1e-10 {
159 dot_product / norm_product
160 } else {
161 0.0
162 }
163}
164
165pub fn l2_distance_f32(a: &Array1<f32>, b: &Array1<f32>) -> f32 {
167 scirs2_core::ndarray_ext::Zip::from(a)
168 .and(b)
169 .fold(0.0_f32, |acc, &a_val, &b_val| {
170 let diff = a_val - b_val;
171 acc + diff * diff
172 })
173 .sqrt()
174}
175
176pub fn clamp_embeddings(embeddings: &mut Array2<f64>, max_norm: f64) {
178 for mut row in embeddings.rows_mut() {
179 let norm = row.dot(&row).sqrt();
180 if norm > max_norm {
181 row *= max_norm / norm;
182 }
183 }
184}
185
186pub fn gradient_update(
188 embeddings: &mut Array2<f64>,
189 gradients: &Array2<f64>,
190 learning_rate: f64,
191 l2_reg: f64,
192) {
193 scirs2_core::ndarray_ext::Zip::from(embeddings)
195 .and(gradients)
196 .for_each(|embed, &grad| {
197 *embed = *embed - learning_rate * (grad + l2_reg * *embed);
198 });
199}
200
201pub fn batch_gradient_update(
203 embeddings: &mut [Array2<f64>],
204 gradients: &[Array2<f64>],
205 learning_rate: f64,
206 l2_reg: f64,
207) {
208 for (embedding, gradient) in embeddings.iter_mut().zip(gradients.iter()) {
209 gradient_update(embedding, gradient, learning_rate, l2_reg);
210 }
211}
212
213pub fn gradient_update_single(
215 embedding: &mut Array1<f64>,
216 gradient: &Array1<f64>,
217 learning_rate: f64,
218 l2_reg: f64,
219) {
220 *embedding = embedding.clone() - learning_rate * (gradient + l2_reg * &*embedding);
221}
222
/// Logistic sigmoid `1 / (1 + e^(-x))`, mapping any real to (0, 1).
pub fn sigmoid(x: f64) -> f64 {
    let neg_exp = (-x).exp();
    (1.0 + neg_exp).recip()
}
227
/// Rectified linear unit: passes positive inputs through, clips the rest
/// to 0 (IEEE `max` semantics, so NaN input yields 0.0).
pub fn relu(x: f64) -> f64 {
    0.0_f64.max(x)
}
232
/// Hyperbolic tangent, mapping any real to (-1, 1). Thin wrapper over
/// [`f64::tanh`] kept for a uniform activation-function API.
pub fn tanh(x: f64) -> f64 {
    f64::tanh(x)
}
237
/// Hinge-style margin ranking loss: `max(0, margin + neg - pos)`.
/// Zero once the positive score beats the negative by at least `margin`.
pub fn margin_loss(positive_score: f64, negative_score: f64, margin: f64) -> f64 {
    let violation = margin + negative_score - positive_score;
    if violation > 0.0 {
        violation
    } else {
        0.0
    }
}
242
/// Logistic loss `ln(1 + e^(-label * score))` (softplus of `-label*score`).
///
/// Fix: the previous `(1.0 + z.exp()).ln()` overflows to `+inf` once
/// `z = -label * score` exceeds ~709 (where `exp` overflows), and loses
/// precision for very small `e^z` (1.0 + tiny rounds to 1.0). This uses
/// the standard stable softplus split:
///   z > 0:  z + ln(1 + e^(-z))   (via `ln_1p`)
///   z <= 0: ln(1 + e^z)          (via `ln_1p`)
/// which is finite and accurate for all inputs and agrees with the naive
/// formula wherever that formula was already well-conditioned.
pub fn logistic_loss(score: f64, label: f64) -> f64 {
    let z = -label * score;
    if z > 0.0 {
        z + (-z).exp().ln_1p()
    } else {
        z.exp().ln_1p()
    }
}
247
248pub fn shuffle_batch<T>(batch: &mut [T], rng: &mut Random) {
250 if batch.len() <= 1 {
252 return;
253 }
254
255 for i in (1..batch.len()).rev() {
256 let j = rng.random_range(0..i + 1);
257 if i != j {
258 batch.swap(i, j);
259 }
260 }
261}
262
263pub fn shuffle_multiple_batches<T: Clone>(batches: &mut [Vec<T>], rng: &mut Random) {
265 for batch in batches.iter_mut() {
266 shuffle_batch(batch, rng);
267 }
268}
269
270pub fn sample_without_replacement<T: Clone>(
272 data: &[T],
273 sample_size: usize,
274 rng: &mut Random,
275) -> Vec<T> {
276 if sample_size >= data.len() {
277 return data.to_vec();
278 }
279
280 let mut indices: Vec<usize> = (0..data.len()).collect();
281 shuffle_batch(&mut indices, rng);
282
283 indices[..sample_size]
284 .iter()
285 .map(|&i| data[i].clone())
286 .collect()
287}
288
/// Split `data` into owned batches of up to `batch_size` elements; the
/// final batch may be shorter. Panics if `batch_size` is 0 (from
/// `slice::chunks`). An empty input yields an empty vector.
pub fn create_batches<T: Clone>(data: &[T], batch_size: usize) -> Vec<Vec<T>> {
    data.chunks(batch_size).map(<[T]>::to_vec).collect()
}
297
/// Lazily yield non-overlapping borrowed batches of up to `batch_size`
/// elements each (the final batch may be shorter). Allocation-free
/// counterpart of [`create_batches`]. Panics if `batch_size` is 0
/// (from `slice::chunks`).
pub fn create_batch_refs<T>(data: &[T], batch_size: usize) -> impl Iterator<Item = &[T]> {
    data.chunks(batch_size)
}
302
303pub fn ndarray_to_vector(array: &Array1<f64>) -> crate::Vector {
305 let mut values = Vec::with_capacity(array.len());
306 values.extend(array.iter().map(|&x| x as f32));
307 crate::Vector::new(values)
308}
309
310pub fn vector_to_ndarray(vector: &crate::Vector) -> Array1<f64> {
312 let mut values = Vec::with_capacity(vector.values.len());
313 values.extend(vector.values.iter().map(|&x| x as f64));
314 Array1::from_vec(values)
315}
316
317pub fn batch_ndarray_to_vectors(arrays: &[Array1<f64>]) -> Vec<crate::Vector> {
319 arrays.iter().map(ndarray_to_vector).collect()
320}
321
/// Learning-rate schedule strategies, evaluated per epoch by [`Self::get_lr`].
pub enum LearningRateSchedule {
    /// Fixed learning rate for every epoch.
    Constant(f64),
    /// `initial_lr * decay_rate^(epoch / decay_steps)` (continuous decay).
    ExponentialDecay {
        initial_lr: f64,
        decay_rate: f64,
        decay_steps: usize,
    },
    /// Multiply by `factor` once every `step_size` epochs (discrete drops).
    StepDecay {
        initial_lr: f64,
        step_size: usize,
        factor: f64,
    },
    /// Polynomial interpolation from `initial_lr` down to `final_lr` over
    /// `decay_steps` epochs; stays at `final_lr` afterwards.
    PolynomialDecay {
        initial_lr: f64,
        final_lr: f64,
        decay_steps: usize,
        power: f64,
    },
}

impl LearningRateSchedule {
    /// Learning rate for the given (0-based) epoch under this schedule.
    pub fn get_lr(&self, epoch: usize) -> f64 {
        use LearningRateSchedule::*;
        match *self {
            Constant(lr) => lr,
            ExponentialDecay {
                initial_lr,
                decay_rate,
                decay_steps,
            } => {
                // Fractional progress through one decay period.
                let progress = epoch as f64 / decay_steps as f64;
                initial_lr * decay_rate.powf(progress)
            }
            StepDecay {
                initial_lr,
                step_size,
                factor,
            } => {
                // Integer division: number of completed step periods.
                let drops = (epoch / step_size) as f64;
                initial_lr * factor.powf(drops)
            }
            PolynomialDecay {
                initial_lr,
                final_lr,
                decay_steps,
                power,
            } => {
                if epoch >= decay_steps {
                    final_lr
                } else {
                    // Fraction of the decay window still remaining,
                    // raised to `power`.
                    let remaining = 1.0 - epoch as f64 / decay_steps as f64;
                    final_lr + (initial_lr - final_lr) * remaining.powf(power)
                }
            }
        }
    }
}
378
/// Tracks validation loss across epochs and signals when training should
/// stop because it has not improved for `patience` consecutive updates.
pub struct EarlyStopping {
    // Consecutive non-improving updates tolerated before stopping.
    patience: usize,
    // Minimum decrease in loss that counts as an improvement.
    min_delta: f64,
    // Lowest loss observed so far (starts at +inf).
    best_loss: f64,
    // Consecutive updates without sufficient improvement.
    wait_count: usize,
    // Latched to true once the stopping criterion fires.
    stopped: bool,
}

impl EarlyStopping {
    /// Create a monitor that stops once more than `patience` consecutive
    /// updates fail to improve the best loss by at least `min_delta`.
    pub fn new(patience: usize, min_delta: f64) -> Self {
        EarlyStopping {
            patience,
            min_delta,
            best_loss: f64::INFINITY,
            wait_count: 0,
            stopped: false,
        }
    }

    /// Record one epoch's loss. Returns `true` once training should stop;
    /// the flag latches and never resets.
    pub fn update(&mut self, current_loss: f64) -> bool {
        let improved = current_loss < self.best_loss - self.min_delta;
        if improved {
            self.best_loss = current_loss;
            self.wait_count = 0;
        } else {
            self.wait_count += 1;
            self.stopped = self.stopped || self.wait_count > self.patience;
        }

        self.stopped
    }

    /// Whether the stopping criterion has fired.
    pub fn should_stop(&self) -> bool {
        self.stopped
    }

    /// Lowest loss seen so far (`+inf` before the first update).
    pub fn best_loss(&self) -> f64 {
        self.best_loss
    }
}
425
#[cfg(test)]
mod tests {
    use super::*;
    use scirs2_core::ndarray_ext::Array1;

    // Spot-checks L2 / L1 / cosine on a fixed pair of 3-vectors.
    #[test]
    fn test_distance_functions() {
        let a = Array1::from_vec(vec![1.0, 2.0, 3.0]);
        let b = Array1::from_vec(vec![4.0, 5.0, 6.0]);

        // sqrt(27) = 5.196152...
        let l2_dist = l2_distance(&a, &b);
        assert!((l2_dist - 5.196152422706632).abs() < 1e-10);

        // |3| + |3| + |3| = 9.
        let l1_dist = l1_distance(&a, &b);
        assert!((l1_dist - 9.0).abs() < 1e-10);

        // Same orthant, not parallel: similarity strictly in (0, 1).
        let cos_sim = cosine_similarity(&a, &b);
        assert!(cos_sim > 0.0 && cos_sim < 1.0);
    }

    // A 3-4-5 vector normalizes to unit length.
    #[test]
    fn test_normalization() {
        let mut vec = Array1::from_vec(vec![3.0, 4.0]);
        normalize_vector(&mut vec);
        let norm = vec.dot(&vec).sqrt();
        assert!((norm - 1.0).abs() < 1e-10);
    }

    // Exponential decay is monotonically decreasing from initial_lr.
    #[test]
    fn test_learning_rate_schedule() {
        let schedule = LearningRateSchedule::ExponentialDecay {
            initial_lr: 0.1,
            decay_rate: 0.9,
            decay_steps: 10,
        };

        let lr0 = schedule.get_lr(0);
        let lr10 = schedule.get_lr(10);
        let lr20 = schedule.get_lr(20);

        assert!((lr0 - 0.1).abs() < 1e-10);
        assert!(lr10 < lr0);
        assert!(lr20 < lr10);
    }

    // With patience 3, the stop flag fires on the 4th consecutive
    // non-improving update after the best loss of 0.5.
    #[test]
    fn test_early_stopping() {
        let mut early_stop = EarlyStopping::new(3, 0.01);

        assert!(!early_stop.update(1.0));
        assert!(!early_stop.update(0.5));
        assert!(!early_stop.update(0.51));
        assert!(!early_stop.update(0.52));
        assert!(!early_stop.update(0.53));
        assert!(early_stop.update(0.54)); }
}