//! Mathematical utility functions for rustyml (math.rs): regression error
//! metrics, distance measures, impurity/entropy criteria, and helpers for
//! isolation forests and t-SNE.
use ahash::AHashMap;
use ndarray::{Array1, ArrayBase, Data, Ix1};
/// Euler–Mascheroni constant γ, used to approximate the harmonic number H(n) ≈ ln(n) + γ.
const EULER_GAMMA: f64 = 0.57721566490153286060651209008240243104215933593992;
5
6/// Calculates the total sum of squares (SST).
7///
8/// SST measures the total variability in the data as the sum of squared
9/// differences between each value and the mean of all values.
10///
11/// # Parameters
12///
13/// - `values` - Observed values stored in a 1D array
14///
15/// # Returns
16///
17/// - `f64` - Total sum of squares for the provided values
18///
19/// # Examples
20/// ```rust
21/// use rustyml::math::sum_of_square_total;
22/// use ndarray::array;
23///
24/// let values = array![1.0, 2.0, 3.0];
25/// let sst = sum_of_square_total(&values);
26/// // Mean is 2.0, so SST = (1-2)^2 + (2-2)^2 + (3-2)^2 = 1 + 0 + 1 = 2.0
27/// assert!((sst - 2.0).abs() < 1e-5);
28/// ```
29#[inline]
30pub fn sum_of_square_total<S>(values: &ArrayBase<S, Ix1>) -> f64
31where
32    S: Data<Elem = f64>,
33{
34    // Handle empty array case
35    if values.is_empty() {
36        return 0.0;
37    }
38    // Calculate the mean
39    let mean = values.mean().unwrap();
40    // Fully vectorized computation
41    values.mapv(|x| (x - mean).powi(2)).sum()
42}
43
44/// Calculates the sum of squared errors (SSE).
45///
46/// SSE measures the total squared difference between predicted values and actual labels.
47///
48/// # Parameters
49///
50/// - `predicted` - Predicted values vector
51/// - `actual` - Actual values vector
52///
53/// # Returns
54///
55/// - `f64` - Sum of squared errors computed as sum((predicted_i - actual_i)^2)
56///
57/// # Examples
58/// ```rust
59/// use rustyml::math::sum_of_squared_errors;
60/// use ndarray::array;
61///
62/// let predicted = array![2.0, 3.0];
63/// let actual = array![1.0, 3.0];
64/// let sse = sum_of_squared_errors(&predicted, &actual);
65/// // (2-1)^2 + (3-3)^2 = 1 + 0 = 1
66/// assert!((sse - 1.0).abs() < 1e-6);
67/// ```
68#[inline]
69pub fn sum_of_squared_errors<S1, S2>(
70    predicted: &ArrayBase<S1, Ix1>,
71    actual: &ArrayBase<S2, Ix1>,
72) -> f64
73where
74    S1: Data<Elem = f64>,
75    S2: Data<Elem = f64>,
76{
77    let sum: f64 = predicted
78        .iter()
79        .zip(actual.iter())
80        .map(|(p, a)| (p - a).powi(2))
81        .sum();
82
83    sum
84}
85
86/// Calculates the mean squared error (variance) of a set of values.
87///
88/// The variance is the average of the squared differences between each value
89/// and the mean of all values.
90///
91/// # Parameters
92///
93/// - `y` - Values for which to calculate the variance
94///
95/// # Returns
96///
97/// - `f64` - Variance of the input values (0.0 when the array is empty)
98///
99/// # Examples
100/// ```rust
101/// use ndarray::array;
102/// use rustyml::math::variance;
103///
104/// let values = array![1.0, 2.0, 3.0];
105/// let mse = variance(&values);
106/// // Mean is 2.0, so variance = ((1-2)^2 + (2-2)^2 + (3-2)^2) / 3 = (1 + 0 + 1) / 3 ~= 0.66667
107/// assert!((mse - 0.6666667).abs() < 1e-6);
108/// ```
109#[inline]
110pub fn variance<S>(y: &ArrayBase<S, Ix1>) -> f64
111where
112    S: Data<Elem = f64>,
113{
114    let n = y.len();
115
116    // Return 0.0 for empty arrays
117    if n == 0 {
118        return 0.0;
119    }
120
121    // Calculate mean using ndarray's mean method
122    // Handle potential NaN case when all values are NaN
123    let mean = match y.mean() {
124        Some(m) if m.is_finite() => m,
125        _ => return 0.0, // Return 0.0 if mean is NaN or can't be calculated
126    };
127
128    // Use fold for potentially better performance than map/sum
129    // This computes the sum of squared differences in one pass
130    let sum_squared_diff = y.fold(0.0, |acc, &val| {
131        if val.is_finite() {
132            let diff = val - mean;
133            acc + diff * diff
134        } else {
135            acc // Skip NaN or infinite values
136        }
137    });
138
139    // Compute variance (MSE)
140    sum_squared_diff / n as f64
141}
142
/// Computes the logistic sigmoid for a scalar input.
///
/// Maps any real number into the open interval (0, 1), clipping extreme
/// inputs to the asymptotes for numerical stability.
///
/// # Parameters
///
/// - `z` - Input value to transform
///
/// # Returns
///
/// - `f64` - Sigmoid output in the range (0, 1)
///
/// # Examples
/// ```rust
/// use rustyml::math::sigmoid;
///
/// let value = sigmoid(0.0);
/// // sigmoid(0) = 0.5
/// assert!((value - 0.5).abs() < 1e-6);
/// ```
#[inline]
pub fn sigmoid(z: f64) -> f64 {
    // Beyond ±500 the exponential would over/underflow and the result is
    // indistinguishable from the asymptote, so clip to 1.0 / 0.0 directly.
    const CLIP: f64 = 500.0;

    match z {
        z if z > CLIP => 1.0,
        z if z < -CLIP => 0.0,
        _ => 1.0 / (1.0 + (-z).exp()),
    }
}
181
182/// Calculates the logistic regression loss (log loss).
183///
184/// This computes the average cross-entropy loss by applying the sigmoid
185/// to raw logits before evaluating the log-likelihood.
186///
187/// # Parameters
188///
189/// - `logits` - Raw model outputs (logits before sigmoid)
190/// - `actual_labels` - Binary labels (0 or 1)
191///
192/// # Returns
193///
194/// - `f64` - Average logistic regression loss
195///
196/// # Examples
197/// ```rust
198/// use rustyml::math::logistic_loss;
199/// use ndarray::array;
200///
201/// let logits = array![0.0, 2.0, -1.0];
202/// let actual_labels = array![0.0, 1.0, 0.0];
203/// let loss = logistic_loss(&logits, &actual_labels);
204/// // Expected average loss is approximately 0.37778
205/// assert!((loss - 0.37778).abs() < 1e-5);
206/// ```
207#[inline]
208pub fn logistic_loss<S1, S2>(logits: &ArrayBase<S1, Ix1>, actual_labels: &ArrayBase<S2, Ix1>) -> f64
209where
210    S1: Data<Elem = f64>,
211    S2: Data<Elem = f64>,
212{
213    // Using a vectorized approach to calculate log loss
214    let n = logits.len() as f64;
215
216    // Calculate total loss using zip to iterate through both arrays simultaneously
217    let total_loss = logits
218        .iter()
219        .zip(actual_labels.iter())
220        .map(|(&x, &y)| {
221            // Numerically stable way to calculate log loss:
222            // max(0, x) - x*y + log(1 + exp(-|x|))
223            x.max(0.0) - x * y + (1.0 + (-x.abs()).exp()).ln()
224        })
225        .sum::<f64>();
226
227    total_loss / n
228}
229
230/// Calculates the squared Euclidean distance between two vectors.
231///
232/// # Parameters
233///
234/// - `x1` - First vector
235/// - `x2` - Second vector
236///
237/// # Returns
238///
239/// - `f64` - Squared Euclidean distance between the two vectors
240///
241/// # Examples
242/// ```rust
243/// use ndarray::array;
244/// use rustyml::math::squared_euclidean_distance_row;
245///
246/// let v1 = array![1.0, 2.0, 3.0];
247/// let v2 = array![4.0, 5.0, 6.0];
248/// let dist = squared_euclidean_distance_row(&v1, &v2);
249/// // (4-1)^2 + (5-2)^2 + (6-3)^2 = 9 + 9 + 9 = 27
250/// assert!((dist - 27.0).abs() < 1e-10);
251/// ```
252#[inline]
253pub fn squared_euclidean_distance_row<S1, S2>(
254    x1: &ArrayBase<S1, Ix1>,
255    x2: &ArrayBase<S2, Ix1>,
256) -> f64
257where
258    S1: Data<Elem = f64>,
259    S2: Data<Elem = f64>,
260{
261    // Calculate the difference between the two vectors
262    let diff = x1 - x2;
263
264    // Calculate the sum of squares (fully vectorized)
265    diff.mapv(|x| x * x).sum()
266}
267
268/// Calculates the Manhattan (L1) distance between two vectors.
269///
270/// # Parameters
271///
272/// - `x1` - First vector
273/// - `x2` - Second vector
274///
275/// # Returns
276///
277/// - `f64` - Manhattan distance between the two vectors
278///
279/// # Examples
280/// ```rust
281/// use ndarray::array;
282/// use rustyml::math::manhattan_distance_row;
283///
284/// let v1 = array![1.0, 2.0];
285/// let v2 = array![4.0, 6.0];
286/// let distance = manhattan_distance_row(&v1, &v2);
287/// // |1-4| + |2-6| = 3 + 4 = 7
288/// assert!((distance - 7.0).abs() < 1e-6);
289/// ```
290#[inline]
291pub fn manhattan_distance_row<S1, S2>(x1: &ArrayBase<S1, Ix1>, x2: &ArrayBase<S2, Ix1>) -> f64
292where
293    S1: Data<Elem = f64>,
294    S2: Data<Elem = f64>,
295{
296    // Calculate the difference between the two vectors
297    let diff = x1 - x2;
298
299    // Calculate the sum of absolute differences (fully vectorized)
300    diff.mapv(|x| x.abs()).sum()
301}
302
303/// Calculates the Minkowski distance between two vectors.
304///
305/// Computes the p-norm of the difference between two 1D arrays.
306///
307/// # Parameters
308///
309/// - `x1` - First vector
310/// - `x2` - Second vector
311/// - `p` - Order of the norm (must be at least 1.0)
312///
313/// # Returns
314///
315/// - `f64` - Minkowski distance between the two vectors
316///
317/// # Examples
318/// ```rust
319/// use ndarray::array;
320/// use rustyml::math::minkowski_distance_row;
321///
322/// let v1 = array![1.0, 2.0];
323/// let v2 = array![4.0, 6.0];
324/// let distance = minkowski_distance_row(&v1, &v2, 3.0);
325/// // Expected distance is approximately 4.497
326/// assert!((distance - 4.497).abs() < 1e-3);
327/// ```
328#[inline]
329pub fn minkowski_distance_row<S1, S2>(
330    x1: &ArrayBase<S1, Ix1>,
331    x2: &ArrayBase<S2, Ix1>,
332    p: f64,
333) -> f64
334where
335    S1: Data<Elem = f64>,
336    S2: Data<Elem = f64>,
337{
338    // Calculate the difference between the two vectors
339    let diff = x1 - x2;
340
341    // Calculate the sum of absolute differences raised to power p,
342    // then take the p-th root of the sum
343    let sum: f64 = diff.mapv(|x| x.abs().powf(p)).sum();
344    sum.powf(1.0 / p)
345}
346
347/// Calculates the Gini impurity of a label set.
348///
349/// Gini impurity measures how frequently a randomly chosen element would be
350/// mislabeled if it were randomly labeled according to the distribution of labels.
351///
352/// # Parameters
353///
354/// - `y` - Class labels stored in a 1D array
355///
356/// # Returns
357///
358/// - `f64` - Gini impurity in the range \[0.0, 1.0\]
359///
360/// # Examples
361/// ```rust
362/// use ndarray::array;
363/// use rustyml::math::gini;
364///
365/// let labels = array![0.0, 0.0, 1.0, 1.0];
366/// let gini_val = gini(&labels);
367/// // For two classes with equal frequency, Gini = 1 - (0.5^2 + 0.5^2) = 0.5
368/// assert!((gini_val - 0.5).abs() < 1e-6);
369/// ```
370#[inline]
371pub fn gini<S>(y: &ArrayBase<S, Ix1>) -> f64
372where
373    S: Data<Elem = f64>,
374{
375    let total_samples = y.len() as f64;
376    if total_samples == 0.0 {
377        return 0.0;
378    }
379
380    // Pre-allocate capacity for the HashMap to avoid frequent reallocations
381    // A capacity of 10 is reasonable for most classification problems
382    let mut class_counts = AHashMap::with_capacity(10);
383
384    // Process all elements in the array with fold operation
385    y.fold((), |_, &value| {
386        // Handle NaN values - they should be treated as invalid input
387        if value.is_nan() {
388            return; // Skip NaN values
389        }
390
391        // Convert float to integer representation with 3 decimal places precision
392        let key = (value * 1000.0).round() as i64;
393        *class_counts.entry(key).or_insert(0) += 1;
394    });
395
396    // If all values were NaN, treat as empty dataset
397    if class_counts.is_empty() {
398        return 0.0;
399    }
400
401    // Calculate Gini impurity more efficiently
402    let mut sum_squared_proportions = 0.0;
403    for &count in class_counts.values() {
404        let p = count as f64 / total_samples;
405        sum_squared_proportions += p * p;
406    }
407
408    1.0 - sum_squared_proportions
409}
410
411/// Calculates the entropy of a label set.
412///
413/// Entropy quantifies the impurity or randomness in a dataset and is used
414/// by decision tree algorithms to evaluate split quality.
415///
416/// # Parameters
417///
418/// - `y` - Class labels stored in a 1D array
419///
420/// # Returns
421///
422/// - `f64` - Entropy value of the dataset (0.0 for homogeneous data)
423///
424/// # Examples
425/// ```rust
426/// use ndarray::array;
427/// use rustyml::math::entropy;
428///
429/// let labels = array![0.0, 1.0, 1.0, 0.0];
430/// let ent = entropy(&labels);
431/// // For two classes with equal frequency, entropy = 1.0
432/// assert!((ent - 1.0).abs() < 1e-6);
433/// ```
434#[inline]
435pub fn entropy<S>(y: &ArrayBase<S, Ix1>) -> f64
436where
437    S: Data<Elem = f64>,
438{
439    let total_samples = y.len() as f64;
440    if total_samples == 0.0 {
441        return 0.0;
442    }
443
444    // Pre-allocate capacity for the HashMap to avoid frequent reallocations
445    // A capacity of 10 is reasonable for most classification problems
446    let mut class_counts = AHashMap::with_capacity(10);
447
448    // Use fold operation instead of manual iteration for potential compiler optimizations
449    y.fold((), |_, &value| {
450        // Handle NaN values - they should be treated as invalid input
451        if value.is_nan() {
452            return; // Skip NaN values
453        }
454
455        // Convert float to integer representation with 3 decimal places precision
456        let key = (value * 1000.0).round() as i64;
457        *class_counts.entry(key).or_insert(0) += 1;
458    });
459
460    // If all values were NaN, treat as empty dataset
461    if class_counts.is_empty() {
462        return 0.0;
463    }
464
465    // Calculate entropy more efficiently with direct loop
466    let mut entropy = 0.0;
467    for &count in class_counts.values() {
468        let p = count as f64 / total_samples;
469        // Safeguard against log2(0), although this shouldn't happen in this context
470        if p > 0.0 {
471            entropy -= p * p.log2();
472        }
473    }
474
475    entropy
476}
477
478/// Calculates the information gain when splitting a dataset.
479///
480/// Information gain measures the reduction in entropy achieved by dividing a
481/// dataset into child nodes, guiding feature selection in decision trees.
482///
483/// # Parameters
484///
485/// - `y` - Class labels in the parent node
486/// - `left_y` - Class labels in the left child node
487/// - `right_y` - Class labels in the right child node
488///
489/// # Returns
490///
491/// - `f64` - Information gain for the proposed split
492///
493/// # Examples
494/// ```rust
495/// use ndarray::array;
496/// use rustyml::math::information_gain;
497///
498/// let parent = array![0.0, 0.0, 1.0, 1.0];
499/// let left = array![0.0, 0.0];
500/// let right = array![1.0, 1.0];
501/// let ig = information_gain(&parent, &left, &right);
502/// // Entropy(parent)=1.0, Entropy(left)=Entropy(right)=0, so IG = 1.0
503/// assert!((ig - 1.0).abs() < 1e-6);
504/// ```
505#[inline]
506pub fn information_gain<S1, S2, S3>(
507    y: &ArrayBase<S1, Ix1>,
508    left_y: &ArrayBase<S2, Ix1>,
509    right_y: &ArrayBase<S3, Ix1>,
510) -> f64
511where
512    S1: Data<Elem = f64>,
513    S2: Data<Elem = f64>,
514    S3: Data<Elem = f64>,
515{
516    // Calculate sample counts once
517    let n = y.len() as f64;
518
519    // Early return for edge cases
520    if n == 0.0 {
521        return 0.0;
522    }
523
524    let n_left = left_y.len() as f64;
525    let n_right = right_y.len() as f64;
526
527    // Check for invalid split ratios - if child counts don't match parent, return 0
528    if (n_left + n_right - n).abs() > 1e-10 {
529        return 0.0;
530    }
531
532    // Calculate entropy values
533    let e = entropy(y);
534
535    // If parent node is already pure, no information gain is possible
536    if e.abs() < f64::EPSILON {
537        return 0.0;
538    }
539
540    let e_left = entropy(left_y);
541    let e_right = entropy(right_y);
542
543    // Calculate the weighted average entropy of children
544    let weighted_child_entropy = (n_left / n) * e_left + (n_right / n) * e_right;
545
546    // Information gain = parent entropy - weighted sum of child entropies
547    e - weighted_child_entropy
548}
549
550/// Calculates the gain ratio for a dataset split.
551///
552/// Gain ratio normalizes information gain by the entropy of the split to reduce
553/// bias toward features with many distinct values.
554///
555/// # Parameters
556///
557/// - `y` - Class labels in the parent node
558/// - `left_y` - Class labels in the left child node
559/// - `right_y` - Class labels in the right child node
560///
561/// # Returns
562///
563/// - `f64` - Gain ratio value for the proposed split
564///
565/// # Examples
566/// ```rust
567/// use ndarray::array;
568/// use rustyml::math::gain_ratio;
569///
570/// let parent = array![0.0, 0.0, 1.0, 1.0];
571/// let left = array![0.0, 0.0];
572/// let right = array![1.0, 1.0];
573/// let gr = gain_ratio(&parent, &left, &right);
574/// // With equal splits, gain ratio should be 1.0
575/// assert!((gr - 1.0).abs() < 1e-6);
576/// ```
577#[inline]
578pub fn gain_ratio<S1, S2, S3>(
579    y: &ArrayBase<S1, Ix1>,
580    left_y: &ArrayBase<S2, Ix1>,
581    right_y: &ArrayBase<S3, Ix1>,
582) -> f64
583where
584    S1: Data<Elem = f64>,
585    S2: Data<Elem = f64>,
586    S3: Data<Elem = f64>,
587{
588    // Early return if parent is empty or either split is empty
589    if y.is_empty() || left_y.is_empty() || right_y.is_empty() {
590        return 0.0;
591    }
592
593    // Calculate sample counts once to avoid redundant computations
594    let n = y.len() as f64;
595    let n_left = left_y.len() as f64;
596    let n_right = right_y.len() as f64;
597
598    // Calculate information gain
599    let info_gain = information_gain(y, left_y, right_y);
600
601    // If information gain is negligible, return 0 to avoid unnecessary calculations
602    if info_gain < f64::EPSILON {
603        return 0.0;
604    }
605
606    // Calculate the proportions for split information
607    let p_left = n_left / n;
608    let p_right = n_right / n;
609
610    // Calculate split information, which measures the potential information of the split
611    // Handle edge cases where one of the proportions is 0
612    let mut split_info = 0.0;
613    if p_left > 0.0 {
614        split_info -= p_left * p_left.log2();
615    }
616    if p_right > 0.0 {
617        split_info -= p_right * p_right.log2();
618    }
619
620    // Avoid division by zero
621    if split_info < f64::EPSILON {
622        0.0
623    } else {
624        info_gain / split_info
625    }
626}
627
628/// Calculates the population standard deviation of a set of values.
629///
630/// # Parameters
631///
632/// - `values` - Values to measure dispersion
633///
634/// # Returns
635///
636/// - `f64` - Population standard deviation (0.0 when the array is empty)
637///
638/// # Examples
639/// ```rust
640/// use ndarray::array;
641/// use rustyml::math::standard_deviation;
642///
643/// let values = array![1.0, 2.0, 3.0];
644/// let std_dev = standard_deviation(&values);
645/// // Population standard deviation for [1,2,3] is approximately 0.8165
646/// assert!((std_dev - 0.8165).abs() < 1e-4);
647/// ```
648#[inline]
649pub fn standard_deviation<S>(values: &ArrayBase<S, Ix1>) -> f64
650where
651    S: Data<Elem = f64>,
652{
653    let n = values.len();
654
655    // Return 0.0 for empty arrays
656    if n == 0 {
657        return 0.0;
658    }
659
660    // Use built-in methods when available for better performance
661    // calculate variance and then take the square root
662
663    // First calculate the mean efficiently
664    let mean = values.mean().unwrap(); // Safe since we've validated input
665
666    // Calculate variance in one pass
667    let variance = values.fold(0.0, |acc, &x| {
668        let diff = x - mean;
669        acc + diff * diff
670    }) / n as f64;
671
672    // Take the square root for standard deviation
673    variance.sqrt()
674}
675
676/// Calculates the average path length adjustment factor for isolation trees.
677///
678/// This is the correction factor `c(n)` used in isolation forests to normalize
679/// path lengths based on the expected height of a binary search tree.
680///
681/// # Parameters
682///
683/// - `n` - Number of samples in the isolation tree node (must be greater than 0)
684///
685/// # Returns
686///
687/// - `f64` - Adjustment factor for path length normalization:
688///   - 0.0 for `n <= 1`
689///   - 1.0 for `n == 2`
690///   - Computed correction factor for larger `n`
691///
692/// # Examples
693/// ```rust
694/// use rustyml::math::average_path_length_factor;
695///
696/// let factor_small = average_path_length_factor(10);
697/// let factor_large = average_path_length_factor(1000);
698/// assert_eq!(average_path_length_factor(0), 0.0);
699/// assert_eq!(average_path_length_factor(1), 0.0);
700/// assert_eq!(average_path_length_factor(2), 1.0);
701/// assert!(factor_small > 0.0);
702/// assert!(factor_large > factor_small);
703/// ```
704#[inline]
705pub fn average_path_length_factor(n: usize) -> f64 {
706    if n <= 1 {
707        return 0.0;
708    }
709    if n == 2 {
710        return 1.0;
711    }
712
713    let h_n_minus_1 = if n > 50 {
714        ((n - 1) as f64).ln() + EULER_GAMMA
715    } else {
716        (1..n).map(|i| 1.0 / i as f64).sum::<f64>()
717    };
718
719    2.0 * h_n_minus_1 - 2.0 * (n - 1) as f64 / n as f64
720}
721
722/// Finds the sigma value that matches a target perplexity for distance-derived probabilities.
723///
724/// Uses binary search to tune the precision parameter so the resulting probability
725/// distribution has the desired perplexity.
726///
727/// # Parameters
728///
729/// - `distances` - Squared Euclidean distances from a point to all others
730/// - `target_perplexity` - Desired perplexity controlling neighborhood size
731///
732/// # Returns
733///
734/// - `(Array1<f64>, f64)` - Probability distribution and the sigma value that achieves the target perplexity
735///
736/// # Examples
737/// ```rust
738/// use ndarray::array;
739/// use rustyml::math::binary_search_sigma;
740///
741/// let distances = array![0.0, 1.0, 4.0, 9.0, 16.0];
742/// let target_perplexity = 2.0;
743/// let (probabilities, sigma) = binary_search_sigma(&distances, target_perplexity);
744/// // The function returns probabilities and sigma that achieve the target perplexity
745/// assert_eq!(probabilities.len(), 5);
746/// assert!(sigma > 0.0);
747/// ```
748pub fn binary_search_sigma<S>(
749    distances: &ArrayBase<S, Ix1>,
750    target_perplexity: f64,
751) -> (Array1<f64>, f64)
752where
753    S: Data<Elem = f64>,
754{
755    let tol = 1e-5;
756    let mut sigma_min: f64 = 1e-20;
757    let mut sigma_max: f64 = 1e20;
758    let mut sigma: f64 = 1.0;
759    let n = distances.len();
760    let mut p = Array1::<f64>::zeros(n);
761
762    for _ in 0..50 {
763        for (j, &d) in distances.iter().enumerate() {
764            p[j] = if d == 0.0 {
765                0.0
766            } else {
767                (-d / (2.0 * sigma * sigma)).exp()
768            };
769        }
770
771        let sum_p = p.sum();
772        let epsilon = 1e-12;
773
774        if sum_p < epsilon {
775            // If sum is too small, use uniform distribution
776            p.fill(1.0 / n as f64);
777        } else {
778            p.mapv_inplace(|v| v / sum_p);
779        }
780
781        let h: f64 = p
782            .iter()
783            .map(|&v| if v > 1e-10 { -v * v.ln() } else { 0.0 })
784            .sum();
785        let current_perplexity = h.exp();
786        let diff = current_perplexity - target_perplexity;
787        if diff.abs() < tol {
788            break;
789        }
790        if diff > 0.0 {
791            sigma_min = sigma;
792            if sigma_max.is_infinite() {
793                sigma *= 2.0;
794            } else {
795                sigma = (sigma + sigma_max) / 2.0;
796            }
797        } else {
798            sigma_max = sigma;
799            sigma = (sigma + sigma_min) / 2.0;
800        }
801    }
802    (p, sigma)
803}