//! rustyml/math.rs — mathematical utilities: statistics, distance metrics, impurity measures, and t-SNE helpers.
1use ahash::AHashMap;
2use ndarray::{Array1, ArrayBase, Data, Ix1};
3
4const EULER_GAMMA: f64 = 0.57721566490153286060651209008240243104215933593992;
5
6/// Calculates the total sum of squares (SST).
7///
8/// SST measures the total variability in the data as the sum of squared
9/// differences between each value and the mean of all values.
10///
11/// # Parameters
12///
13/// - `values` - Observed values stored in a 1D array
14///
15/// # Returns
16///
17/// - `f64` - Total sum of squares for the provided values
18///
19/// # Examples
20/// ```rust
21/// use rustyml::math::sum_of_square_total;
22/// use ndarray::array;
23///
24/// let values = array![1.0, 2.0, 3.0];
25/// let sst = sum_of_square_total(&values);
26/// // Mean is 2.0, so SST = (1-2)^2 + (2-2)^2 + (3-2)^2 = 1 + 0 + 1 = 2.0
27/// assert!((sst - 2.0).abs() < 1e-5);
28/// ```
29#[inline]
30pub fn sum_of_square_total<S>(values: &ArrayBase<S, Ix1>) -> f64
31where
32 S: Data<Elem = f64>,
33{
34 // Handle empty array case
35 if values.is_empty() {
36 return 0.0;
37 }
38 // Calculate the mean
39 let mean = values.mean().unwrap();
40 // Fully vectorized computation
41 values.mapv(|x| (x - mean).powi(2)).sum()
42}
43
44/// Calculates the sum of squared errors (SSE).
45///
46/// SSE measures the total squared difference between predicted values and actual labels.
47///
48/// # Parameters
49///
50/// - `predicted` - Predicted values vector
51/// - `actual` - Actual values vector
52///
53/// # Returns
54///
55/// - `f64` - Sum of squared errors computed as sum((predicted_i - actual_i)^2)
56///
57/// # Examples
58/// ```rust
59/// use rustyml::math::sum_of_squared_errors;
60/// use ndarray::array;
61///
62/// let predicted = array![2.0, 3.0];
63/// let actual = array![1.0, 3.0];
64/// let sse = sum_of_squared_errors(&predicted, &actual);
65/// // (2-1)^2 + (3-3)^2 = 1 + 0 = 1
66/// assert!((sse - 1.0).abs() < 1e-6);
67/// ```
68#[inline]
69pub fn sum_of_squared_errors<S1, S2>(
70 predicted: &ArrayBase<S1, Ix1>,
71 actual: &ArrayBase<S2, Ix1>,
72) -> f64
73where
74 S1: Data<Elem = f64>,
75 S2: Data<Elem = f64>,
76{
77 let sum: f64 = predicted
78 .iter()
79 .zip(actual.iter())
80 .map(|(p, a)| (p - a).powi(2))
81 .sum();
82
83 sum
84}
85
86/// Calculates the mean squared error (variance) of a set of values.
87///
88/// The variance is the average of the squared differences between each value
89/// and the mean of all values.
90///
91/// # Parameters
92///
93/// - `y` - Values for which to calculate the variance
94///
95/// # Returns
96///
97/// - `f64` - Variance of the input values (0.0 when the array is empty)
98///
99/// # Examples
100/// ```rust
101/// use ndarray::array;
102/// use rustyml::math::variance;
103///
104/// let values = array![1.0, 2.0, 3.0];
105/// let mse = variance(&values);
106/// // Mean is 2.0, so variance = ((1-2)^2 + (2-2)^2 + (3-2)^2) / 3 = (1 + 0 + 1) / 3 ~= 0.66667
107/// assert!((mse - 0.6666667).abs() < 1e-6);
108/// ```
109#[inline]
110pub fn variance<S>(y: &ArrayBase<S, Ix1>) -> f64
111where
112 S: Data<Elem = f64>,
113{
114 let n = y.len();
115
116 // Return 0.0 for empty arrays
117 if n == 0 {
118 return 0.0;
119 }
120
121 // Calculate mean using ndarray's mean method
122 // Handle potential NaN case when all values are NaN
123 let mean = match y.mean() {
124 Some(m) if m.is_finite() => m,
125 _ => return 0.0, // Return 0.0 if mean is NaN or can't be calculated
126 };
127
128 // Use fold for potentially better performance than map/sum
129 // This computes the sum of squared differences in one pass
130 let sum_squared_diff = y.fold(0.0, |acc, &val| {
131 if val.is_finite() {
132 let diff = val - mean;
133 acc + diff * diff
134 } else {
135 acc // Skip NaN or infinite values
136 }
137 });
138
139 // Compute variance (MSE)
140 sum_squared_diff / n as f64
141}
142
/// Computes the logistic sigmoid for a scalar input.
///
/// The sigmoid maps any real number into the open interval (0, 1); inputs
/// far outside the numerically useful range are clipped to 0 or 1 to keep
/// the computation stable.
///
/// # Parameters
///
/// - `z` - Input value to transform
///
/// # Returns
///
/// - `f64` - Sigmoid output in the range (0, 1)
///
/// # Examples
/// ```rust
/// use rustyml::math::sigmoid;
///
/// let value = sigmoid(0.0);
/// // sigmoid(0) = 0.5
/// assert!((value - 0.5).abs() < 1e-6);
/// ```
#[inline]
pub fn sigmoid(z: f64) -> f64 {
    // Beyond +/-500 the result saturates to 1.0 / 0.0 anyway; short-circuit
    // to avoid exponentiating extreme magnitudes.
    const CLIP: f64 = 500.0;

    match z {
        v if v > CLIP => 1.0,
        v if v < -CLIP => 0.0,
        // 1 / (1 + e^-z) for the ordinary range.
        v => (1.0 + (-v).exp()).recip(),
    }
}
181
182/// Calculates the logistic regression loss (log loss).
183///
184/// This computes the average cross-entropy loss by applying the sigmoid
185/// to raw logits before evaluating the log-likelihood.
186///
187/// # Parameters
188///
189/// - `logits` - Raw model outputs (logits before sigmoid)
190/// - `actual_labels` - Binary labels (0 or 1)
191///
192/// # Returns
193///
194/// - `f64` - Average logistic regression loss
195///
196/// # Examples
197/// ```rust
198/// use rustyml::math::logistic_loss;
199/// use ndarray::array;
200///
201/// let logits = array![0.0, 2.0, -1.0];
202/// let actual_labels = array![0.0, 1.0, 0.0];
203/// let loss = logistic_loss(&logits, &actual_labels);
204/// // Expected average loss is approximately 0.37778
205/// assert!((loss - 0.37778).abs() < 1e-5);
206/// ```
207#[inline]
208pub fn logistic_loss<S1, S2>(logits: &ArrayBase<S1, Ix1>, actual_labels: &ArrayBase<S2, Ix1>) -> f64
209where
210 S1: Data<Elem = f64>,
211 S2: Data<Elem = f64>,
212{
213 // Using a vectorized approach to calculate log loss
214 let n = logits.len() as f64;
215
216 // Calculate total loss using zip to iterate through both arrays simultaneously
217 let total_loss = logits
218 .iter()
219 .zip(actual_labels.iter())
220 .map(|(&x, &y)| {
221 // Numerically stable way to calculate log loss:
222 // max(0, x) - x*y + log(1 + exp(-|x|))
223 x.max(0.0) - x * y + (1.0 + (-x.abs()).exp()).ln()
224 })
225 .sum::<f64>();
226
227 total_loss / n
228}
229
230/// Calculates the squared Euclidean distance between two vectors.
231///
232/// # Parameters
233///
234/// - `x1` - First vector
235/// - `x2` - Second vector
236///
237/// # Returns
238///
239/// - `f64` - Squared Euclidean distance between the two vectors
240///
241/// # Examples
242/// ```rust
243/// use ndarray::array;
244/// use rustyml::math::squared_euclidean_distance_row;
245///
246/// let v1 = array![1.0, 2.0, 3.0];
247/// let v2 = array![4.0, 5.0, 6.0];
248/// let dist = squared_euclidean_distance_row(&v1, &v2);
249/// // (4-1)^2 + (5-2)^2 + (6-3)^2 = 9 + 9 + 9 = 27
250/// assert!((dist - 27.0).abs() < 1e-10);
251/// ```
252#[inline]
253pub fn squared_euclidean_distance_row<S1, S2>(
254 x1: &ArrayBase<S1, Ix1>,
255 x2: &ArrayBase<S2, Ix1>,
256) -> f64
257where
258 S1: Data<Elem = f64>,
259 S2: Data<Elem = f64>,
260{
261 // Calculate the difference between the two vectors
262 let diff = x1 - x2;
263
264 // Calculate the sum of squares (fully vectorized)
265 diff.mapv(|x| x * x).sum()
266}
267
268/// Calculates the Manhattan (L1) distance between two vectors.
269///
270/// # Parameters
271///
272/// - `x1` - First vector
273/// - `x2` - Second vector
274///
275/// # Returns
276///
277/// - `f64` - Manhattan distance between the two vectors
278///
279/// # Examples
280/// ```rust
281/// use ndarray::array;
282/// use rustyml::math::manhattan_distance_row;
283///
284/// let v1 = array![1.0, 2.0];
285/// let v2 = array![4.0, 6.0];
286/// let distance = manhattan_distance_row(&v1, &v2);
287/// // |1-4| + |2-6| = 3 + 4 = 7
288/// assert!((distance - 7.0).abs() < 1e-6);
289/// ```
290#[inline]
291pub fn manhattan_distance_row<S1, S2>(x1: &ArrayBase<S1, Ix1>, x2: &ArrayBase<S2, Ix1>) -> f64
292where
293 S1: Data<Elem = f64>,
294 S2: Data<Elem = f64>,
295{
296 // Calculate the difference between the two vectors
297 let diff = x1 - x2;
298
299 // Calculate the sum of absolute differences (fully vectorized)
300 diff.mapv(|x| x.abs()).sum()
301}
302
303/// Calculates the Minkowski distance between two vectors.
304///
305/// Computes the p-norm of the difference between two 1D arrays.
306///
307/// # Parameters
308///
309/// - `x1` - First vector
310/// - `x2` - Second vector
311/// - `p` - Order of the norm (must be at least 1.0)
312///
313/// # Returns
314///
315/// - `f64` - Minkowski distance between the two vectors
316///
317/// # Examples
318/// ```rust
319/// use ndarray::array;
320/// use rustyml::math::minkowski_distance_row;
321///
322/// let v1 = array![1.0, 2.0];
323/// let v2 = array![4.0, 6.0];
324/// let distance = minkowski_distance_row(&v1, &v2, 3.0);
325/// // Expected distance is approximately 4.497
326/// assert!((distance - 4.497).abs() < 1e-3);
327/// ```
328#[inline]
329pub fn minkowski_distance_row<S1, S2>(
330 x1: &ArrayBase<S1, Ix1>,
331 x2: &ArrayBase<S2, Ix1>,
332 p: f64,
333) -> f64
334where
335 S1: Data<Elem = f64>,
336 S2: Data<Elem = f64>,
337{
338 // Calculate the difference between the two vectors
339 let diff = x1 - x2;
340
341 // Calculate the sum of absolute differences raised to power p,
342 // then take the p-th root of the sum
343 let sum: f64 = diff.mapv(|x| x.abs().powf(p)).sum();
344 sum.powf(1.0 / p)
345}
346
347/// Calculates the Gini impurity of a label set.
348///
349/// Gini impurity measures how frequently a randomly chosen element would be
350/// mislabeled if it were randomly labeled according to the distribution of labels.
351///
352/// # Parameters
353///
354/// - `y` - Class labels stored in a 1D array
355///
356/// # Returns
357///
358/// - `f64` - Gini impurity in the range \[0.0, 1.0\]
359///
360/// # Examples
361/// ```rust
362/// use ndarray::array;
363/// use rustyml::math::gini;
364///
365/// let labels = array![0.0, 0.0, 1.0, 1.0];
366/// let gini_val = gini(&labels);
367/// // For two classes with equal frequency, Gini = 1 - (0.5^2 + 0.5^2) = 0.5
368/// assert!((gini_val - 0.5).abs() < 1e-6);
369/// ```
370#[inline]
371pub fn gini<S>(y: &ArrayBase<S, Ix1>) -> f64
372where
373 S: Data<Elem = f64>,
374{
375 let total_samples = y.len() as f64;
376 if total_samples == 0.0 {
377 return 0.0;
378 }
379
380 // Pre-allocate capacity for the HashMap to avoid frequent reallocations
381 // A capacity of 10 is reasonable for most classification problems
382 let mut class_counts = AHashMap::with_capacity(10);
383
384 // Process all elements in the array with fold operation
385 y.fold((), |_, &value| {
386 // Handle NaN values - they should be treated as invalid input
387 if value.is_nan() {
388 return; // Skip NaN values
389 }
390
391 // Convert float to integer representation with 3 decimal places precision
392 let key = (value * 1000.0).round() as i64;
393 *class_counts.entry(key).or_insert(0) += 1;
394 });
395
396 // If all values were NaN, treat as empty dataset
397 if class_counts.is_empty() {
398 return 0.0;
399 }
400
401 // Calculate Gini impurity more efficiently
402 let mut sum_squared_proportions = 0.0;
403 for &count in class_counts.values() {
404 let p = count as f64 / total_samples;
405 sum_squared_proportions += p * p;
406 }
407
408 1.0 - sum_squared_proportions
409}
410
411/// Calculates the entropy of a label set.
412///
413/// Entropy quantifies the impurity or randomness in a dataset and is used
414/// by decision tree algorithms to evaluate split quality.
415///
416/// # Parameters
417///
418/// - `y` - Class labels stored in a 1D array
419///
420/// # Returns
421///
422/// - `f64` - Entropy value of the dataset (0.0 for homogeneous data)
423///
424/// # Examples
425/// ```rust
426/// use ndarray::array;
427/// use rustyml::math::entropy;
428///
429/// let labels = array![0.0, 1.0, 1.0, 0.0];
430/// let ent = entropy(&labels);
431/// // For two classes with equal frequency, entropy = 1.0
432/// assert!((ent - 1.0).abs() < 1e-6);
433/// ```
434#[inline]
435pub fn entropy<S>(y: &ArrayBase<S, Ix1>) -> f64
436where
437 S: Data<Elem = f64>,
438{
439 let total_samples = y.len() as f64;
440 if total_samples == 0.0 {
441 return 0.0;
442 }
443
444 // Pre-allocate capacity for the HashMap to avoid frequent reallocations
445 // A capacity of 10 is reasonable for most classification problems
446 let mut class_counts = AHashMap::with_capacity(10);
447
448 // Use fold operation instead of manual iteration for potential compiler optimizations
449 y.fold((), |_, &value| {
450 // Handle NaN values - they should be treated as invalid input
451 if value.is_nan() {
452 return; // Skip NaN values
453 }
454
455 // Convert float to integer representation with 3 decimal places precision
456 let key = (value * 1000.0).round() as i64;
457 *class_counts.entry(key).or_insert(0) += 1;
458 });
459
460 // If all values were NaN, treat as empty dataset
461 if class_counts.is_empty() {
462 return 0.0;
463 }
464
465 // Calculate entropy more efficiently with direct loop
466 let mut entropy = 0.0;
467 for &count in class_counts.values() {
468 let p = count as f64 / total_samples;
469 // Safeguard against log2(0), although this shouldn't happen in this context
470 if p > 0.0 {
471 entropy -= p * p.log2();
472 }
473 }
474
475 entropy
476}
477
478/// Calculates the information gain when splitting a dataset.
479///
480/// Information gain measures the reduction in entropy achieved by dividing a
481/// dataset into child nodes, guiding feature selection in decision trees.
482///
483/// # Parameters
484///
485/// - `y` - Class labels in the parent node
486/// - `left_y` - Class labels in the left child node
487/// - `right_y` - Class labels in the right child node
488///
489/// # Returns
490///
491/// - `f64` - Information gain for the proposed split
492///
493/// # Examples
494/// ```rust
495/// use ndarray::array;
496/// use rustyml::math::information_gain;
497///
498/// let parent = array![0.0, 0.0, 1.0, 1.0];
499/// let left = array![0.0, 0.0];
500/// let right = array![1.0, 1.0];
501/// let ig = information_gain(&parent, &left, &right);
502/// // Entropy(parent)=1.0, Entropy(left)=Entropy(right)=0, so IG = 1.0
503/// assert!((ig - 1.0).abs() < 1e-6);
504/// ```
505#[inline]
506pub fn information_gain<S1, S2, S3>(
507 y: &ArrayBase<S1, Ix1>,
508 left_y: &ArrayBase<S2, Ix1>,
509 right_y: &ArrayBase<S3, Ix1>,
510) -> f64
511where
512 S1: Data<Elem = f64>,
513 S2: Data<Elem = f64>,
514 S3: Data<Elem = f64>,
515{
516 // Calculate sample counts once
517 let n = y.len() as f64;
518
519 // Early return for edge cases
520 if n == 0.0 {
521 return 0.0;
522 }
523
524 let n_left = left_y.len() as f64;
525 let n_right = right_y.len() as f64;
526
527 // Check for invalid split ratios - if child counts don't match parent, return 0
528 if (n_left + n_right - n).abs() > 1e-10 {
529 return 0.0;
530 }
531
532 // Calculate entropy values
533 let e = entropy(y);
534
535 // If parent node is already pure, no information gain is possible
536 if e.abs() < f64::EPSILON {
537 return 0.0;
538 }
539
540 let e_left = entropy(left_y);
541 let e_right = entropy(right_y);
542
543 // Calculate the weighted average entropy of children
544 let weighted_child_entropy = (n_left / n) * e_left + (n_right / n) * e_right;
545
546 // Information gain = parent entropy - weighted sum of child entropies
547 e - weighted_child_entropy
548}
549
550/// Calculates the gain ratio for a dataset split.
551///
552/// Gain ratio normalizes information gain by the entropy of the split to reduce
553/// bias toward features with many distinct values.
554///
555/// # Parameters
556///
557/// - `y` - Class labels in the parent node
558/// - `left_y` - Class labels in the left child node
559/// - `right_y` - Class labels in the right child node
560///
561/// # Returns
562///
563/// - `f64` - Gain ratio value for the proposed split
564///
565/// # Examples
566/// ```rust
567/// use ndarray::array;
568/// use rustyml::math::gain_ratio;
569///
570/// let parent = array![0.0, 0.0, 1.0, 1.0];
571/// let left = array![0.0, 0.0];
572/// let right = array![1.0, 1.0];
573/// let gr = gain_ratio(&parent, &left, &right);
574/// // With equal splits, gain ratio should be 1.0
575/// assert!((gr - 1.0).abs() < 1e-6);
576/// ```
577#[inline]
578pub fn gain_ratio<S1, S2, S3>(
579 y: &ArrayBase<S1, Ix1>,
580 left_y: &ArrayBase<S2, Ix1>,
581 right_y: &ArrayBase<S3, Ix1>,
582) -> f64
583where
584 S1: Data<Elem = f64>,
585 S2: Data<Elem = f64>,
586 S3: Data<Elem = f64>,
587{
588 // Early return if parent is empty or either split is empty
589 if y.is_empty() || left_y.is_empty() || right_y.is_empty() {
590 return 0.0;
591 }
592
593 // Calculate sample counts once to avoid redundant computations
594 let n = y.len() as f64;
595 let n_left = left_y.len() as f64;
596 let n_right = right_y.len() as f64;
597
598 // Calculate information gain
599 let info_gain = information_gain(y, left_y, right_y);
600
601 // If information gain is negligible, return 0 to avoid unnecessary calculations
602 if info_gain < f64::EPSILON {
603 return 0.0;
604 }
605
606 // Calculate the proportions for split information
607 let p_left = n_left / n;
608 let p_right = n_right / n;
609
610 // Calculate split information, which measures the potential information of the split
611 // Handle edge cases where one of the proportions is 0
612 let mut split_info = 0.0;
613 if p_left > 0.0 {
614 split_info -= p_left * p_left.log2();
615 }
616 if p_right > 0.0 {
617 split_info -= p_right * p_right.log2();
618 }
619
620 // Avoid division by zero
621 if split_info < f64::EPSILON {
622 0.0
623 } else {
624 info_gain / split_info
625 }
626}
627
628/// Calculates the population standard deviation of a set of values.
629///
630/// # Parameters
631///
632/// - `values` - Values to measure dispersion
633///
634/// # Returns
635///
636/// - `f64` - Population standard deviation (0.0 when the array is empty)
637///
638/// # Examples
639/// ```rust
640/// use ndarray::array;
641/// use rustyml::math::standard_deviation;
642///
643/// let values = array![1.0, 2.0, 3.0];
644/// let std_dev = standard_deviation(&values);
645/// // Population standard deviation for [1,2,3] is approximately 0.8165
646/// assert!((std_dev - 0.8165).abs() < 1e-4);
647/// ```
648#[inline]
649pub fn standard_deviation<S>(values: &ArrayBase<S, Ix1>) -> f64
650where
651 S: Data<Elem = f64>,
652{
653 let n = values.len();
654
655 // Return 0.0 for empty arrays
656 if n == 0 {
657 return 0.0;
658 }
659
660 // Use built-in methods when available for better performance
661 // calculate variance and then take the square root
662
663 // First calculate the mean efficiently
664 let mean = values.mean().unwrap(); // Safe since we've validated input
665
666 // Calculate variance in one pass
667 let variance = values.fold(0.0, |acc, &x| {
668 let diff = x - mean;
669 acc + diff * diff
670 }) / n as f64;
671
672 // Take the square root for standard deviation
673 variance.sqrt()
674}
675
676/// Calculates the average path length adjustment factor for isolation trees.
677///
678/// This is the correction factor `c(n)` used in isolation forests to normalize
679/// path lengths based on the expected height of a binary search tree.
680///
681/// # Parameters
682///
683/// - `n` - Number of samples in the isolation tree node (must be greater than 0)
684///
685/// # Returns
686///
687/// - `f64` - Adjustment factor for path length normalization:
688/// - 0.0 for `n <= 1`
689/// - 1.0 for `n == 2`
690/// - Computed correction factor for larger `n`
691///
692/// # Examples
693/// ```rust
694/// use rustyml::math::average_path_length_factor;
695///
696/// let factor_small = average_path_length_factor(10);
697/// let factor_large = average_path_length_factor(1000);
698/// assert_eq!(average_path_length_factor(0), 0.0);
699/// assert_eq!(average_path_length_factor(1), 0.0);
700/// assert_eq!(average_path_length_factor(2), 1.0);
701/// assert!(factor_small > 0.0);
702/// assert!(factor_large > factor_small);
703/// ```
704#[inline]
705pub fn average_path_length_factor(n: usize) -> f64 {
706 if n <= 1 {
707 return 0.0;
708 }
709 if n == 2 {
710 return 1.0;
711 }
712
713 let h_n_minus_1 = if n > 50 {
714 ((n - 1) as f64).ln() + EULER_GAMMA
715 } else {
716 (1..n).map(|i| 1.0 / i as f64).sum::<f64>()
717 };
718
719 2.0 * h_n_minus_1 - 2.0 * (n - 1) as f64 / n as f64
720}
721
722/// Finds the sigma value that matches a target perplexity for distance-derived probabilities.
723///
724/// Uses binary search to tune the precision parameter so the resulting probability
725/// distribution has the desired perplexity.
726///
727/// # Parameters
728///
729/// - `distances` - Squared Euclidean distances from a point to all others
730/// - `target_perplexity` - Desired perplexity controlling neighborhood size
731///
732/// # Returns
733///
734/// - `(Array1<f64>, f64)` - Probability distribution and the sigma value that achieves the target perplexity
735///
736/// # Examples
737/// ```rust
738/// use ndarray::array;
739/// use rustyml::math::binary_search_sigma;
740///
741/// let distances = array![0.0, 1.0, 4.0, 9.0, 16.0];
742/// let target_perplexity = 2.0;
743/// let (probabilities, sigma) = binary_search_sigma(&distances, target_perplexity);
744/// // The function returns probabilities and sigma that achieve the target perplexity
745/// assert_eq!(probabilities.len(), 5);
746/// assert!(sigma > 0.0);
747/// ```
748pub fn binary_search_sigma<S>(
749 distances: &ArrayBase<S, Ix1>,
750 target_perplexity: f64,
751) -> (Array1<f64>, f64)
752where
753 S: Data<Elem = f64>,
754{
755 let tol = 1e-5;
756 let mut sigma_min: f64 = 1e-20;
757 let mut sigma_max: f64 = 1e20;
758 let mut sigma: f64 = 1.0;
759 let n = distances.len();
760 let mut p = Array1::<f64>::zeros(n);
761
762 for _ in 0..50 {
763 for (j, &d) in distances.iter().enumerate() {
764 p[j] = if d == 0.0 {
765 0.0
766 } else {
767 (-d / (2.0 * sigma * sigma)).exp()
768 };
769 }
770
771 let sum_p = p.sum();
772 let epsilon = 1e-12;
773
774 if sum_p < epsilon {
775 // If sum is too small, use uniform distribution
776 p.fill(1.0 / n as f64);
777 } else {
778 p.mapv_inplace(|v| v / sum_p);
779 }
780
781 let h: f64 = p
782 .iter()
783 .map(|&v| if v > 1e-10 { -v * v.ln() } else { 0.0 })
784 .sum();
785 let current_perplexity = h.exp();
786 let diff = current_perplexity - target_perplexity;
787 if diff.abs() < tol {
788 break;
789 }
790 if diff > 0.0 {
791 sigma_min = sigma;
792 if sigma_max.is_infinite() {
793 sigma *= 2.0;
794 } else {
795 sigma = (sigma + sigma_max) / 2.0;
796 }
797 } else {
798 sigma_max = sigma;
799 sigma = (sigma + sigma_min) / 2.0;
800 }
801 }
802 (p, sigma)
803}