next-plaid 1.2.0

CPU-based PLAID implementation for multi-vector search using ndarray
Documentation
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
//! K-means clustering integration using fastkmeans-rs.
//!
//! This module provides functions for computing centroids using the
//! fastkmeans-rs library, which is used during index creation.
//!
//! The implementation follows fast-plaid's approach for automatic K calculation
//! and document sampling.

use ndarray::{Array2, ArrayView2, Axis};
use rand::seq::SliceRandom;
use rand::SeedableRng;
use rand_chacha::ChaCha8Rng;

use crate::error::{Error, Result};
use crate::maxsim;

pub use fastkmeans_rs::{kmeans_double_chunked, FastKMeans, KMeansConfig, KMeansError};

#[cfg(feature = "cuda")]
pub use fastkmeans_rs::FastKMeansCuda;

#[cfg(feature = "metal_gpu")]
pub use fastkmeans_rs::FastKMeansMetal;

/// Configuration for the compute_kmeans function.
///
/// Defaults (see the `Default` impl) mirror fast-plaid's behavior so that
/// indexes built here are comparable with fast-plaid's output.
#[derive(Debug, Clone)]
pub struct ComputeKmeansConfig {
    /// Number of K-means iterations (default: 4)
    pub kmeans_niters: usize,
    /// Maximum number of points to support per centroid (default: 256)
    pub max_points_per_centroid: usize,
    /// Random seed for reproducibility (default: 42).
    /// Used both for document sampling and for k-means initialization.
    pub seed: u64,
    /// Number of samples to use for K-means training.
    /// If None, uses heuristic: min(1 + 16 * sqrt(120 * num_documents), num_documents)
    pub n_samples_kmeans: Option<usize>,
    /// If provided, explicitly sets the number of centroids (K).
    /// If None, K is calculated using heuristic based on dataset size.
    /// In either case K is capped at the total number of sampled tokens.
    pub num_partitions: Option<usize>,
    /// Force CPU execution even when CUDA feature is enabled.
    /// Useful for small batches where GPU initialization overhead exceeds benefits.
    pub force_cpu: bool,
}

impl Default for ComputeKmeansConfig {
    fn default() -> Self {
        Self {
            kmeans_niters: 4,
            max_points_per_centroid: 256,
            seed: 42,
            n_samples_kmeans: None,
            num_partitions: None,
            force_cpu: false,
        }
    }
}

/// Build the default k-means configuration for a given centroid count.
///
/// The values here (4 iterations, seed 42, 256 points per centroid, and
/// the chunk sizes) match fast-plaid's behavior.
pub fn default_config(num_centroids: usize) -> KMeansConfig {
    KMeansConfig {
        seed: 42,
        k: num_centroids,
        tol: 1e-8,
        max_iters: 4,
        verbose: false,
        max_points_per_centroid: Some(256),
        // Chunk sizes control memory usage of the double-chunked kernel.
        chunk_size_data: 51_200,
        chunk_size_centroids: 10_240,
    }
}

/// Compute centroids from a set of embeddings (CPU implementation).
/// Uses kmeans_double_chunked directly to avoid FastKMeans::train() which
/// tries CUDA when the cuda feature is enabled.
fn compute_centroids_cpu(
    embeddings: &ArrayView2<f32>,
    config: KMeansConfig,
) -> Result<Array2<f32>> {
    let result = kmeans_double_chunked(embeddings, &config)
        .map_err(|e| Error::IndexCreation(format!("K-means training failed: {}", e)))?;
    Ok(result.centroids)
}

/// Compute centroids from a set of embeddings.
///
/// # Arguments
///
/// * `embeddings` - The embeddings to cluster, shape `[N, dim]`
/// * `num_centroids` - Number of centroids to compute
/// * `config` - Optional custom k-means configuration; when `None`, the
///   defaults from [`default_config`] are used
/// * `force_cpu` - Force CPU execution even when a GPU backend is available
///   (ignored in this CPU-only build, which always runs on CPU)
///
/// # Returns
///
/// The centroids array of shape `[num_centroids, dim]`
#[cfg(not(any(feature = "cuda", feature = "metal_gpu")))]
pub fn compute_centroids(
    embeddings: &ArrayView2<f32>,
    num_centroids: usize,
    config: Option<KMeansConfig>,
    _force_cpu: bool,
) -> Result<Array2<f32>> {
    let effective = match config {
        Some(custom) => custom,
        None => default_config(num_centroids),
    };
    compute_centroids_cpu(embeddings, effective)
}

/// Compute centroids from a set of embeddings using CUDA (or CPU if force_cpu is true or CUDA fails).
///
/// Strategy:
/// 1. Skip CUDA entirely when `force_cpu` is set or a previous failure
///    marked CUDA as broken.
/// 2. Otherwise attempt CUDA, catching panics that stub or incompatible
///    CUDA libraries can raise during initialization.
/// 3. On any CUDA error or panic, mark CUDA as broken (so subsequent calls
///    go straight to CPU) and fall back to the CPU implementation.
#[cfg(feature = "cuda")]
pub fn compute_centroids(
    embeddings: &ArrayView2<f32>,
    num_centroids: usize,
    config: Option<KMeansConfig>,
    force_cpu: bool,
) -> Result<Array2<f32>> {
    let config = config.unwrap_or_else(|| default_config(num_centroids));

    // Skip CUDA if force_cpu is set or CUDA has been determined to be broken.
    // `config` is moved directly: this path returns, so no clone is needed.
    if force_cpu || crate::cuda::is_cuda_broken() {
        return compute_centroids_cpu(embeddings, config);
    }

    // Try CUDA first, catching panics from invalid/stub CUDA libraries.
    // The closure clones `config` so the original remains available for the
    // CPU fallback paths below.
    let cuda_result = crate::cuda::catch_cuda_panic(std::panic::AssertUnwindSafe(|| {
        match FastKMeansCuda::with_config(config.clone()) {
            Ok(mut kmeans) => match kmeans.train(embeddings) {
                Ok(()) => kmeans
                    .centroids()
                    .map(|c| c.to_owned())
                    .ok_or_else(|| "CUDA K-means did not produce centroids".to_string()),
                Err(e) => Err(format!("CUDA K-means training failed: {}", e)),
            },
            Err(e) => Err(format!("CUDA K-means init failed: {}", e)),
        }
    }));

    match cuda_result {
        // CUDA succeeded: use its centroids directly.
        Ok(Ok(centroids)) => Ok(centroids),
        // CUDA ran but reported an error: remember that and fall back to CPU.
        Ok(Err(e)) => {
            crate::cuda::mark_cuda_broken();
            eprintln!(
                "[next-plaid] CUDA K-means error: {}. Falling back to CPU.",
                e
            );
            compute_centroids_cpu(embeddings, config)
        }
        // CUDA panicked (stub library / incompatible driver): same fallback.
        Err(_) => {
            crate::cuda::mark_cuda_broken();
            eprintln!(
                "[next-plaid] CUDA library found but missing required symbols (stub or incompatible driver). \
                 K-means will use CPU instead."
            );
            compute_centroids_cpu(embeddings, config)
        }
    }
}

/// Compute centroids from a set of embeddings using Metal GPU (or CPU if force_cpu is true).
///
/// Unlike the CUDA path, there is no automatic fallback on Metal failure:
/// initialization or training errors are surfaced to the caller.
#[cfg(all(feature = "metal_gpu", not(feature = "cuda")))]
pub fn compute_centroids(
    embeddings: &ArrayView2<f32>,
    num_centroids: usize,
    config: Option<KMeansConfig>,
    force_cpu: bool,
) -> Result<Array2<f32>> {
    let config = config.unwrap_or_else(|| default_config(num_centroids));

    // Respect an explicit CPU request before touching the GPU at all.
    if force_cpu {
        return compute_centroids_cpu(embeddings, config);
    }

    let mut trainer = FastKMeansMetal::with_config(config)
        .map_err(|e| Error::IndexCreation(format!("Metal K-means initialization failed: {}", e)))?;

    trainer
        .train(embeddings)
        .map_err(|e| Error::IndexCreation(format!("Metal K-means training failed: {}", e)))?;

    match trainer.centroids() {
        Some(c) => Ok(c.to_owned()),
        None => Err(Error::IndexCreation(
            "Metal K-means did not produce centroids".into(),
        )),
    }
}

/// Compute centroids from document embeddings.
///
/// This function flattens the document embeddings before clustering,
/// as k-means operates on individual token embeddings.
///
/// # Arguments
///
/// * `documents` - List of document embeddings, each of shape `[num_tokens, dim]`
/// * `num_centroids` - Number of centroids to compute
/// * `config` - Optional custom k-means configuration
/// * `force_cpu` - Force CPU execution even when CUDA is available
///
/// # Returns
///
/// The centroids array of shape `[num_centroids, dim]`
pub fn compute_centroids_from_documents(
    documents: &[Array2<f32>],
    num_centroids: usize,
    config: Option<KMeansConfig>,
    force_cpu: bool,
) -> Result<Array2<f32>> {
    if documents.is_empty() {
        return Err(Error::IndexCreation("No documents provided".into()));
    }

    let dim = documents[0].ncols();
    let total_tokens: usize = documents.iter().map(|d| d.nrows()).sum();

    // Flatten all documents into a single array
    let mut flat = Array2::<f32>::zeros((total_tokens, dim));
    let mut offset = 0;

    for doc in documents {
        let n = doc.nrows();
        flat.slice_mut(ndarray::s![offset..offset + n, ..])
            .assign(doc);
        offset += n;
    }

    compute_centroids(&flat.view(), num_centroids, config, force_cpu)
}

/// Assign embeddings to their nearest centroids.
///
/// Delegates to `maxsim::assign_to_centroids`, which computes distances
/// directly rather than relying on a trained k-means model — useful when
/// the centroids were pre-computed or loaded from disk.
///
/// # Arguments
///
/// * `embeddings` - The embeddings to assign, shape `[N, dim]`
/// * `centroids` - The centroids, shape `[K, dim]`
///
/// # Returns
///
/// Vector of centroid indices, one per embedding
pub fn assign_to_centroids(embeddings: &ArrayView2<f32>, centroids: &Array2<f32>) -> Vec<usize> {
    let centroid_view = centroids.view();
    maxsim::assign_to_centroids(embeddings, &centroid_view)
}

/// Compute K-means centroids from document embeddings.
///
/// This function implements the same logic as fast-plaid's `compute_kmeans`:
/// 1. Samples documents using heuristic: `min(1 + 16 * sqrt(120 * num_documents), num_documents)`
/// 2. Concatenates all token embeddings from sampled documents
/// 3. Calculates K (num_partitions) using: `2^floor(log2(16 * sqrt(estimated_total_tokens)))`
/// 4. Runs k-means clustering
/// 5. Normalizes the resulting centroids
///
/// The k-means backend is chosen at compile time via feature flags
/// (CPU-only, CUDA with CPU fallback, or Metal), selected by the
/// `#[cfg]`-gated `let centroids = ...` arms below.
///
/// # Arguments
///
/// * `documents_embeddings` - List of document embeddings, each of shape `[num_tokens, dim]`
/// * `config` - Configuration for k-means computation
///
/// # Returns
///
/// Normalized centroids array of shape `[K, dim]`
///
/// # Errors
///
/// Returns `Error::IndexCreation` if no documents are provided, if the
/// effective K is 0, or if the chosen backend fails to train.
pub fn compute_kmeans(
    documents_embeddings: &[Array2<f32>],
    config: &ComputeKmeansConfig,
) -> Result<Array2<f32>> {
    if documents_embeddings.is_empty() {
        return Err(Error::IndexCreation("No documents provided".into()));
    }

    let num_documents = documents_embeddings.len();
    // Embedding dimension is taken from the first document; NOTE(review):
    // assumes all documents share this dimension — a mismatch would panic
    // in the `assign` below.
    let dim = documents_embeddings[0].ncols();

    // Calculate n_samples_kmeans using fast-plaid's heuristic
    let n_samples_kmeans = config.n_samples_kmeans.unwrap_or_else(|| {
        (1.0 + 16.0 * (120.0 * num_documents as f64).sqrt()).min(num_documents as f64) as usize
    });
    // Re-clamp: an explicitly configured sample count may exceed the corpus size.
    let n_samples_kmeans = n_samples_kmeans.min(num_documents);

    // Deterministic document sampling: shuffle indices with a seeded RNG and
    // keep the first n_samples_kmeans.
    let mut rng = ChaCha8Rng::seed_from_u64(config.seed);
    let mut indices: Vec<usize> = (0..num_documents).collect();
    indices.shuffle(&mut rng);
    indices.truncate(n_samples_kmeans);
    let sampled_indices = indices;

    // Calculate total tokens in sampled documents
    let total_sample_tokens: usize = sampled_indices
        .iter()
        .map(|&i| documents_embeddings[i].nrows())
        .sum();

    // Concatenate all embeddings from sampled documents
    let mut samples_tensor = Array2::<f32>::zeros((total_sample_tokens, dim));
    let mut current_offset = 0;

    for &i in &sampled_indices {
        let tensor_slice = &documents_embeddings[i];
        let length = tensor_slice.nrows();
        samples_tensor
            .slice_mut(ndarray::s![current_offset..current_offset + length, ..])
            .assign(tensor_slice);
        current_offset += length;
    }

    // Calculate num_partitions using fast-plaid's heuristic if not provided
    let num_partitions = config.num_partitions.unwrap_or_else(|| {
        // Calculate based on density of sample relative to whole dataset
        let avg_tokens_per_doc = total_sample_tokens as f64 / n_samples_kmeans as f64;
        let estimated_total_tokens = avg_tokens_per_doc * num_documents as f64;
        2usize.pow((16.0 * estimated_total_tokens.sqrt()).log2().floor() as u32)
    });

    // The actual K that will be used: cannot exceed the number of sampled
    // tokens (k-means needs at least one point per centroid).
    let actual_k = num_partitions.min(total_sample_tokens);

    if actual_k == 0 {
        return Err(Error::IndexCreation("Cannot compute 0 centroids".into()));
    }

    // Build k-means config
    let kmeans_config = KMeansConfig {
        k: actual_k,
        max_iters: config.kmeans_niters,
        tol: 1e-8,
        seed: config.seed,
        max_points_per_centroid: Some(config.max_points_per_centroid),
        chunk_size_data: 51_200,
        chunk_size_centroids: 10_240,
        verbose: false,
    };

    // Run k-means (CPU implementation)
    #[cfg(not(any(feature = "cuda", feature = "metal_gpu")))]
    let centroids = {
        let mut kmeans = FastKMeans::with_config(kmeans_config);
        kmeans
            .train(&samples_tensor.view())
            .map_err(|e| Error::IndexCreation(format!("K-means training failed: {}", e)))?;

        kmeans
            .centroids()
            .ok_or_else(|| Error::IndexCreation("K-means did not produce centroids".into()))?
            .to_owned()
    };

    // Run k-means (CUDA with automatic CPU fallback, catching panics)
    #[cfg(feature = "cuda")]
    let centroids = if config.force_cpu || crate::cuda::is_cuda_broken() {
        // Use CPU if force_cpu is set or CUDA has been determined to be broken
        // Use kmeans_double_chunked directly to avoid FastKMeans::train() which
        // tries CUDA when the cuda feature is enabled
        let result = kmeans_double_chunked(&samples_tensor.view(), &kmeans_config)
            .map_err(|e| Error::IndexCreation(format!("K-means training failed: {}", e)))?;
        result.centroids
    } else {
        // Try CUDA, catching panics from invalid/stub CUDA libraries
        let samples_view = samples_tensor.view();
        let cuda_result = crate::cuda::catch_cuda_panic(std::panic::AssertUnwindSafe(|| {
            match FastKMeansCuda::with_config(kmeans_config.clone()) {
                Ok(mut kmeans) => match kmeans.train(&samples_view) {
                    Ok(()) => kmeans.centroids().map(|c| c.to_owned()),
                    Err(_) => None,
                },
                Err(_) => None,
            }
        }));

        match cuda_result {
            Ok(Some(c)) => c,
            Ok(None) | Err(_) => {
                // Mark CUDA as broken to prevent subsequent attempts
                crate::cuda::mark_cuda_broken();
                // `Err` means the closure panicked (stub library / incompatible
                // driver); `Ok(None)` means CUDA ran but produced no centroids.
                if cuda_result.is_err() {
                    eprintln!("[next-plaid] CUDA library found but missing required symbols (stub or incompatible driver). \
                               K-means will use CPU instead.");
                } else {
                    eprintln!(
                        "[next-plaid] CUDA K-means did not produce centroids. Falling back to CPU."
                    );
                }
                // Use kmeans_double_chunked directly to avoid FastKMeans::train() which
                // tries CUDA when the cuda feature is enabled
                let result = kmeans_double_chunked(&samples_tensor.view(), &kmeans_config)
                    .map_err(|e| Error::IndexCreation(format!("K-means training failed: {}", e)))?;
                result.centroids
            }
        }
    };

    // Run k-means (Metal GPU with CPU fallback when force_cpu is true)
    #[cfg(all(feature = "metal_gpu", not(feature = "cuda")))]
    let centroids = if config.force_cpu {
        let mut kmeans = FastKMeans::with_config(kmeans_config);
        kmeans
            .train(&samples_tensor.view())
            .map_err(|e| Error::IndexCreation(format!("K-means training failed: {}", e)))?;

        kmeans
            .centroids()
            .ok_or_else(|| Error::IndexCreation("K-means did not produce centroids".into()))?
            .to_owned()
    } else {
        let mut kmeans = FastKMeansMetal::with_config(kmeans_config).map_err(|e| {
            Error::IndexCreation(format!("Metal K-means initialization failed: {}", e))
        })?;
        kmeans
            .train(&samples_tensor.view())
            .map_err(|e| Error::IndexCreation(format!("Metal K-means training failed: {}", e)))?;

        kmeans
            .centroids()
            .ok_or_else(|| Error::IndexCreation("Metal K-means did not produce centroids".into()))?
            .to_owned()
    };

    // Normalize centroids (fast-plaid does F.normalize(centroids, dim=-1))
    // The 1e-12 floor prevents division by zero for a degenerate all-zero row.
    let mut normalized = centroids.clone();
    for mut row in normalized.axis_iter_mut(Axis(0)) {
        let norm = row.dot(&row).sqrt().max(1e-12);
        row /= norm;
    }

    Ok(normalized)
}

/// Returns the number of centroids (num_partitions) that would be computed
/// for the given documents using fast-plaid's heuristic.
///
/// This is useful for pre-computing the expected number of centroids.
/// Note that it estimates sampled token counts from the corpus-wide average
/// rather than actually sampling, so it may differ slightly from the K that
/// [`compute_kmeans`] ends up using.
pub fn estimate_num_partitions(documents_embeddings: &[Array2<f32>]) -> usize {
    if documents_embeddings.is_empty() {
        return 0;
    }

    let num_documents = documents_embeddings.len();
    let docs_f = num_documents as f64;

    // Sample-count heuristic: min(1 + 16 * sqrt(120 * N), N).
    let n_samples_kmeans = ((1.0 + 16.0 * (120.0 * docs_f).sqrt()).min(docs_f)) as usize;

    // Average tokens per document across the whole corpus.
    let token_total: usize = documents_embeddings.iter().map(|d| d.nrows()).sum();
    let avg_tokens_per_doc = token_total as f64 / docs_f;

    // Partition heuristic: 2^floor(log2(16 * sqrt(estimated_total_tokens))).
    let sampled_count = n_samples_kmeans.min(num_documents);
    let estimated_total_tokens = avg_tokens_per_doc * docs_f;
    let heuristic_k = 2usize.pow((16.0 * estimated_total_tokens.sqrt()).log2().floor() as u32);

    // K can never exceed the (estimated) number of sampled tokens.
    let estimated_sample_tokens = (avg_tokens_per_doc * sampled_count as f64) as usize;
    heuristic_k.min(estimated_sample_tokens)
}

#[cfg(test)]
mod tests {
    use super::*;
    use ndarray_rand::rand_distr::Uniform;
    use ndarray_rand::RandomExt;

    // Smoke test: clustering random data yields a [K, dim] centroid matrix.
    #[test]
    fn test_compute_centroids() {
        let data: Array2<f32> = Array2::random((500, 32), Uniform::new(-1.0f32, 1.0));
        let centroids = compute_centroids(&data.view(), 10, None, false).unwrap();

        assert_eq!(centroids.nrows(), 10);
        assert_eq!(centroids.ncols(), 32);
    }

    // Flattening multiple documents before clustering preserves dim and K.
    #[test]
    fn test_compute_centroids_from_documents() {
        let docs: Vec<Array2<f32>> = (0..10)
            .map(|_| Array2::random((50, 16), Uniform::new(-1.0f32, 1.0)))
            .collect();

        let centroids = compute_centroids_from_documents(&docs, 8, None, false).unwrap();

        assert_eq!(centroids.nrows(), 8);
        assert_eq!(centroids.ncols(), 16);
    }

    // Every embedding must be assigned a centroid index in [0, K).
    #[test]
    fn test_assign_to_centroids() {
        let data: Array2<f32> = Array2::random((100, 16), Uniform::new(-1.0f32, 1.0));
        let centroids = compute_centroids(&data.view(), 5, None, false).unwrap();

        let assignments = assign_to_centroids(&data.view(), &centroids);

        assert_eq!(assignments.len(), 100);
        for &label in &assignments {
            assert!(label < 5);
        }
    }

    // Full pipeline with heuristic K: centroids exist and are L2-normalized.
    #[test]
    fn test_compute_kmeans() {
        // Create synthetic documents
        let docs: Vec<Array2<f32>> = (0..100)
            .map(|_| Array2::random((50, 32), Uniform::new(-1.0f32, 1.0)))
            .collect();

        let config = ComputeKmeansConfig::default();
        let centroids = compute_kmeans(&docs, &config).unwrap();

        // Check that centroids are produced
        assert!(centroids.nrows() > 0);
        assert_eq!(centroids.ncols(), 32);

        // Check that centroids are normalized (unit vectors)
        for row in centroids.axis_iter(Axis(0)) {
            let norm = row.dot(&row).sqrt();
            assert!(
                (norm - 1.0).abs() < 1e-5,
                "Centroid not normalized: norm={}",
                norm
            );
        }
    }

    // Explicit num_partitions overrides the heuristic exactly.
    #[test]
    fn test_compute_kmeans_with_explicit_k() {
        let docs: Vec<Array2<f32>> = (0..50)
            .map(|_| Array2::random((30, 16), Uniform::new(-1.0f32, 1.0)))
            .collect();

        let config = ComputeKmeansConfig {
            num_partitions: Some(16),
            ..Default::default()
        };
        let centroids = compute_kmeans(&docs, &config).unwrap();

        assert_eq!(centroids.nrows(), 16);
        assert_eq!(centroids.ncols(), 16);
    }

    // The heuristic must be positive and monotone in dataset size.
    #[test]
    fn test_estimate_num_partitions() {
        // Test that the heuristic produces reasonable values
        let small_docs: Vec<Array2<f32>> = (0..10)
            .map(|_| Array2::random((20, 16), Uniform::new(-1.0f32, 1.0)))
            .collect();

        let k_small = estimate_num_partitions(&small_docs);
        assert!(k_small > 0);

        let large_docs: Vec<Array2<f32>> = (0..1000)
            .map(|_| Array2::random((50, 16), Uniform::new(-1.0f32, 1.0)))
            .collect();

        let k_large = estimate_num_partitions(&large_docs);
        assert!(
            k_large > k_small,
            "Larger dataset should have more partitions"
        );
    }
}