solvr 0.2.0

Advanced computing library for real-world problem solving - optimization, differential equations, interpolation, statistics, and more
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
//! Generic information theory implementations.
use crate::DType;

use crate::stats::helpers::extract_scalar;
use crate::stats::validate_stats_dtype;
use numr::error::{Error, Result};
use numr::ops::TensorOps;
use numr::runtime::{Runtime, RuntimeClient};
use numr::tensor::Tensor;

/// Generic implementation of Shannon entropy.
pub fn entropy_impl<R, C>(client: &C, pk: &Tensor<R>, base: Option<f64>) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + RuntimeClient<R>,
{
    validate_stats_dtype(pk.dtype())?;

    let pk_contig = pk.contiguous()?;
    let n = pk_contig.numel();
    if n == 0 {
        return Err(Error::InvalidArgument {
            arg: "pk",
            reason: "empty distribution".to_string(),
        });
    }

    // H = -Σ p * log(p), treating 0*log(0) = 0
    let epsilon = Tensor::<R>::full_scalar(pk_contig.shape(), pk.dtype(), 1e-300, client.device());
    let pk_safe = client.maximum(&pk_contig, &epsilon)?;
    let log_pk = client.log(&pk_safe)?;

    let p_log_p = client.mul(&pk_contig, &log_pk)?;

    let all_dims: Vec<usize> = (0..p_log_p.ndim()).collect();
    let sum = extract_scalar(&client.sum(&p_log_p, &all_dims, false)?)?;
    let mut h = -sum;

    if let Some(b) = base {
        h /= b.ln();
    }

    Ok(Tensor::<R>::full_scalar(
        &[],
        pk.dtype(),
        h,
        client.device(),
    ))
}

/// Generic implementation of differential entropy via k-NN spacing estimator.
///
/// For 1-D data, uses the sorted-spacing approach: for each point, the k-th
/// nearest neighbor distance is computed from the sorted array as
/// `max(x[i+k] - x[i], x[i] - x[i-k])`, avoiding O(n²) pairwise distances.
pub fn differential_entropy_impl<R, C>(client: &C, x: &Tensor<R>, k: usize) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + RuntimeClient<R>,
{
    validate_stats_dtype(x.dtype())?;

    let x_contig = x.contiguous()?;
    let n = x_contig.numel();

    if n < k + 1 {
        return Err(Error::InvalidArgument {
            arg: "x",
            reason: format!("need at least {} samples for k={}", k + 1, k),
        });
    }
    if k == 0 {
        return Err(Error::InvalidArgument {
            arg: "k",
            reason: "k must be at least 1".to_string(),
        });
    }

    let dtype = x.dtype();
    let device = client.device();

    // Sort data on device
    let sorted = client.sort(&x_contig, 0, false)?;

    // For 1-D sorted data, the k-th nearest neighbor of sorted[i] is:
    //   min distance from sorted[i-k..i-1] or sorted[i+1..i+k]
    // But for the KL estimator, we use: rho_i = distance to k-th NN
    // In sorted 1-D data, the k-th NN distance for sorted[i] is:
    //   Consider sorted[i-k] and sorted[i+k] (if they exist), pick the k-th closest
    //
    // Simplified approach: use spacing. For sorted data, compute
    //   forward spacing: sorted[i+k] - sorted[i] for i in [0, n-k)
    //   backward spacing: sorted[i] - sorted[i-k] for i in [k, n)
    // Then rho_i = min(forward[i], backward[i]) where applicable, but actually
    // for the Kozachenko-Leonenko estimator, we just need:
    //   rho_i = 2 * (distance to k-th NN)
    //
    // For simplicity and correctness, compute shifted differences on device:
    let head = sorted.narrow(0, 0, n - k)?; // sorted[0..n-k]
    let tail = sorted.narrow(0, k, n - k)?; // sorted[k..n]
    let spacings = client.sub(&tail, &head)?; // sorted[i+k] - sorted[i], length n-k

    // For the estimator we need log(2 * rho_i) for each point
    // Use the spacing as an approximation of 2*rho for interior points
    // (standard approach for 1-D KL estimator with sorted data)
    let epsilon = Tensor::<R>::full_scalar(spacings.shape(), dtype, 1e-300, device);
    let safe_spacings = client.maximum(&spacings, &epsilon)?;
    let log_spacings = client.log(&safe_spacings)?;

    let all_dims: Vec<usize> = (0..log_spacings.ndim()).collect();
    let log_sum = extract_scalar(&client.sum(&log_spacings, &all_dims, false)?)?;

    let n_eff = (n - k) as f64;
    let digamma_n = {
        use crate::stats::continuous::special::digamma;
        digamma(n as f64)
    };
    let digamma_k = {
        use crate::stats::continuous::special::digamma;
        digamma(k as f64)
    };

    let h = digamma_n - digamma_k + log_sum / n_eff;

    Ok(Tensor::<R>::full_scalar(&[], dtype, h, device))
}

/// Generic implementation of KL divergence.
pub fn kl_divergence_impl<R, C>(
    client: &C,
    pk: &Tensor<R>,
    qk: &Tensor<R>,
    base: Option<f64>,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + RuntimeClient<R>,
{
    validate_stats_dtype(pk.dtype())?;
    validate_stats_dtype(qk.dtype())?;

    if pk.numel() != qk.numel() {
        return Err(Error::InvalidArgument {
            arg: "pk/qk",
            reason: "distributions must have equal length".to_string(),
        });
    }

    let pk_contig = pk.contiguous()?;
    let qk_contig = qk.contiguous()?;

    // D_KL = Σ p * log(p/q) = Σ p * (log(p) - log(q))
    let epsilon = Tensor::<R>::full_scalar(pk_contig.shape(), pk.dtype(), 1e-300, client.device());
    let pk_safe = client.maximum(&pk_contig, &epsilon)?;
    let qk_safe = client.maximum(&qk_contig, &epsilon)?;

    let log_pk = client.log(&pk_safe)?;
    let log_qk = client.log(&qk_safe)?;
    let log_ratio = client.sub(&log_pk, &log_qk)?;

    let terms = client.mul(&pk_contig, &log_ratio)?;

    let all_dims: Vec<usize> = (0..terms.ndim()).collect();
    let sum = extract_scalar(&client.sum(&terms, &all_dims, false)?)?;

    let mut kl = sum;
    if let Some(b) = base {
        kl /= b.ln();
    }

    Ok(Tensor::<R>::full_scalar(
        &[],
        pk.dtype(),
        kl,
        client.device(),
    ))
}

/// Generic implementation of mutual information via histogram binning.
///
/// Computes bin indices and joint/marginal histograms entirely on device
/// using tensor operations.
pub fn mutual_information_impl<R, C>(
    client: &C,
    x: &Tensor<R>,
    y: &Tensor<R>,
    bins: usize,
    base: Option<f64>,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + RuntimeClient<R>,
{
    validate_stats_dtype(x.dtype())?;
    validate_stats_dtype(y.dtype())?;

    let n = x.numel();
    if n != y.numel() {
        return Err(Error::InvalidArgument {
            arg: "x/y",
            reason: "must have equal length".to_string(),
        });
    }
    if n == 0 || bins == 0 {
        return Err(Error::InvalidArgument {
            arg: "bins",
            reason: "need positive bins and non-empty data".to_string(),
        });
    }

    let dtype = x.dtype();
    let device = client.device();
    let x_contig = x.contiguous()?;
    let y_contig = y.contiguous()?;

    // Compute min/max on device (single scalar transfers for range computation)
    let all_dims: Vec<usize> = (0..x_contig.ndim()).collect();
    let x_min = extract_scalar(&client.min(&x_contig, &all_dims, false)?)?;
    let x_max = extract_scalar(&client.max(&x_contig, &all_dims, false)?)?;
    let y_min = extract_scalar(&client.min(&y_contig, &all_dims, false)?)?;
    let y_max = extract_scalar(&client.max(&y_contig, &all_dims, false)?)?;

    let x_range = if (x_max - x_min).abs() < 1e-15 {
        1.0
    } else {
        x_max - x_min
    };
    let y_range = if (y_max - y_min).abs() < 1e-15 {
        1.0
    } else {
        y_max - y_min
    };

    // Compute bin indices on device: bin_i = clamp(round((x - min) / range * (bins-1)), 0, bins-1)
    let bins_f = (bins - 1) as f64;
    let x_min_t = Tensor::<R>::full_scalar(x_contig.shape(), dtype, x_min, device);
    let x_shifted = client.sub(&x_contig, &x_min_t)?;
    let x_scale_t = Tensor::<R>::full_scalar(x_contig.shape(), dtype, bins_f / x_range, device);
    let x_scaled = client.mul(&x_shifted, &x_scale_t)?;
    let x_rounded = client.round(&x_scaled)?;
    let x_bins = client.clamp(&x_rounded, 0.0, bins_f)?;

    let y_min_t = Tensor::<R>::full_scalar(y_contig.shape(), dtype, y_min, device);
    let y_shifted = client.sub(&y_contig, &y_min_t)?;
    let y_scale_t = Tensor::<R>::full_scalar(y_contig.shape(), dtype, bins_f / y_range, device);
    let y_scaled = client.mul(&y_shifted, &y_scale_t)?;
    let y_rounded = client.round(&y_scaled)?;
    let y_bins = client.clamp(&y_rounded, 0.0, bins_f)?;

    // Flatten to 1-D joint index: joint_idx = x_bin * bins + y_bin
    let bins_scale_t = Tensor::<R>::full_scalar(x_bins.shape(), dtype, bins as f64, device);
    let x_bins_scaled = client.mul(&x_bins, &bins_scale_t)?;
    let joint_idx_f = client.add(&x_bins_scaled, &y_bins)?;
    let joint_idx = client.cast(&joint_idx_f, numr::dtype::DType::I64)?;

    // Build joint histogram via scatter_reduce (sum ones at joint indices)
    let ones = Tensor::<R>::full_scalar(&[n], dtype, 1.0, device);
    let joint_zeros = Tensor::<R>::full_scalar(&[bins * bins], dtype, 0.0, device);
    let joint_hist = client.scatter_reduce(
        &joint_zeros,
        0,
        &joint_idx,
        &ones,
        numr::ops::ScatterReduceOp::Sum,
        true,
    )?;

    // Normalize to joint probability
    let n_t = Tensor::<R>::full_scalar(&[1], dtype, n as f64, device);
    let pxy = client.div(&joint_hist, &n_t)?; // [bins*bins]

    // Marginals: reshape to [bins, bins], sum along axes
    let pxy_2d = pxy.reshape(&[bins, bins])?;
    let px = client.sum(&pxy_2d, &[1], false)?; // [bins]
    let py = client.sum(&pxy_2d, &[0], false)?; // [bins]

    // I(X;Y) = Σ p(x,y) * log(p(x,y) / (p(x)*p(y)))
    // Compute outer product p(x) * p(y) → [bins, bins]
    let px_col = px.reshape(&[bins, 1])?;
    let py_row = py.reshape(&[1, bins])?;
    let px_py = client.mul(&px_col, &py_row)?; // [bins, bins]
    let px_py_flat = px_py.reshape(&[bins * bins])?;

    // Compute log(pxy / (px*py)) where both > 0
    let epsilon = Tensor::<R>::full_scalar(&[bins * bins], dtype, 1e-300, device);
    let pxy_safe = client.maximum(&pxy, &epsilon)?;
    let pxpy_safe = client.maximum(&px_py_flat, &epsilon)?;
    let ratio = client.div(&pxy_safe, &pxpy_safe)?;
    let log_ratio = client.log(&ratio)?;

    // p(x,y) * log(p(x,y) / (p(x)*p(y)))
    let terms = client.mul(&pxy, &log_ratio)?;

    let all_dims_joint: Vec<usize> = (0..terms.ndim()).collect();
    let mut mi = extract_scalar(&client.sum(&terms, &all_dims_joint, false)?)?;

    if let Some(b) = base {
        mi /= b.ln();
    }

    // MI can be slightly negative due to floating point; clamp to 0
    mi = mi.max(0.0);

    Ok(Tensor::<R>::full_scalar(&[], dtype, mi, device))
}

/// Generic implementation of cross-entropy.
///
/// H(p, q) = -Σ p(x) log(q(x))
pub fn cross_entropy_impl<R, C>(
    client: &C,
    pk: &Tensor<R>,
    qk: &Tensor<R>,
    base: Option<f64>,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + RuntimeClient<R>,
{
    validate_stats_dtype(pk.dtype())?;
    validate_stats_dtype(qk.dtype())?;

    if pk.numel() != qk.numel() {
        return Err(Error::InvalidArgument {
            arg: "pk/qk",
            reason: "distributions must have equal length".to_string(),
        });
    }

    let pk_contig = pk.contiguous()?;
    let qk_contig = qk.contiguous()?;

    // H(p, q) = -Σ p * log(q)
    let epsilon = Tensor::<R>::full_scalar(qk_contig.shape(), qk.dtype(), 1e-300, client.device());
    let qk_safe = client.maximum(&qk_contig, &epsilon)?;
    let log_qk = client.log(&qk_safe)?;

    let terms = client.mul(&pk_contig, &log_qk)?;

    let all_dims: Vec<usize> = (0..terms.ndim()).collect();
    let sum = extract_scalar(&client.sum(&terms, &all_dims, false)?)?;
    let mut h = -sum;

    if let Some(b) = base {
        h /= b.ln();
    }

    Ok(Tensor::<R>::full_scalar(
        &[],
        pk.dtype(),
        h,
        client.device(),
    ))
}

/// Generic implementation of negative log-likelihood loss.
///
/// NLL = -(1/N) Σ_i log_probs[i, targets[i]]
///
/// `log_probs` must be a 2-D tensor and `targets` a 1-D tensor whose length
/// equals the first dimension of `log_probs`; any other layout is rejected
/// with `Error::InvalidArgument`. The targets are cast to `I64` and used as
/// gather indices along dimension 1. The result is a scalar (shape `[]`)
/// tensor with the dtype of `log_probs`.
pub fn nll_loss_impl<R, C>(
    client: &C,
    log_probs: &Tensor<R>,
    targets: &Tensor<R>,
) -> Result<Tensor<R>>
where
    R: Runtime<DType = DType>,
    C: TensorOps<R> + RuntimeClient<R>,
{
    validate_stats_dtype(log_probs.dtype())?;

    // Shape guard for the log-probability matrix.
    let lp_shape = log_probs.shape();
    let rank = lp_shape.len();
    if rank != 2 {
        return Err(Error::InvalidArgument {
            arg: "log_probs",
            reason: format!("expected 2-D tensor [N, C], got {}-D", rank),
        });
    }

    // Shape guard for the targets: one index per row of `log_probs`.
    let batch = lp_shape[0];
    let targets_match = targets.ndim() == 1 && targets.shape()[0] == batch;
    if !targets_match {
        return Err(Error::InvalidArgument {
            arg: "targets",
            reason: format!(
                "expected 1-D tensor of length {}, got shape {:?}",
                batch,
                targets.shape()
            ),
        });
    }

    let out_dtype = log_probs.dtype();
    let device = client.device();

    // Index column of shape [N, 1] in I64, as consumed by the gather below.
    let index_col = client
        .cast(targets, numr::dtype::DType::I64)?
        .reshape(&[batch, 1])?;

    // Pick log_probs[i, targets[i]] for every row → [N, 1].
    let picked = client.gather(log_probs, 1, &index_col)?;

    // Flip the sign, reduce over every dimension, then average over the rows.
    let negated = client.neg(&picked)?;
    let reduce_dims: Vec<usize> = (0..negated.ndim()).collect();
    let total = extract_scalar(&client.sum(&negated, &reduce_dims, false)?)?;
    let mean_loss = total / batch as f64;

    Ok(Tensor::<R>::full_scalar(&[], out_dtype, mean_loss, device))
}