Skip to main content

oxirs_vec/gpu/
index_builder_phases.rs

1//! Phase implementations for GPU-accelerated HNSW index construction.
2//!
3//! This module contains the main builder structs and their implementations:
4//! - [`GpuHnswIndexBuilder`]: Primary GPU HNSW builder with phase-based construction
5//! - [`IncrementalGpuIndexBuilder`]: Streaming ingestion builder with micro-batch support
6//! - [`GpuBatchDistanceComputer`]: Batch pairwise distance computation (CPU/GPU)
7//! - [`BatchSizeCalculator`]: Optimal batch size heuristics for GPU memory budgets
8//! - [`GpuMemoryBudget`]: GPU memory budget tracking and feasibility checks
9//! - [`GpuIndexOptimizer`]: High-level optimizer wrapping the memory budget
10//! - [`PipelinedIndexBuilder`]: Three-stage pipelined index construction
11
12use crate::gpu::index_builder_types::{
13    ComputationCache, GpuDistanceMetric, GpuIndexBuildStats, GpuIndexBuilderConfig, HnswGraph,
14    HnswNode,
15};
16use crate::gpu::{GpuConfig, GpuDevice};
17use anyhow::{anyhow, Result};
18use parking_lot::Mutex;
19use std::sync::Arc;
20use std::time::{Duration, Instant};
21use tracing::{debug, info, warn};
22
23// ============================================================
24// GpuHnswIndexBuilder
25// ============================================================
26
27/// GPU-accelerated HNSW index builder
28///
29/// Leverages CUDA for batch distance computation during graph construction,
30/// with CPU fallback when CUDA is unavailable.
31#[derive(Debug)]
32pub struct GpuHnswIndexBuilder {
33    pub(crate) config: GpuIndexBuilderConfig,
34    device_info: Arc<GpuDevice>,
35    /// Pending vectors to be indexed: (id, vector)
36    pub(crate) pending_vectors: Vec<(usize, Vec<f32>)>,
37    /// Layer assignment function parameters
38    ml_param: f64,
39    stats: Arc<Mutex<GpuIndexBuildStats>>,
40}
41
42impl GpuHnswIndexBuilder {
43    /// Create a new GPU HNSW index builder
44    pub fn new(config: GpuIndexBuilderConfig) -> Result<Self> {
45        let device_info = Arc::new(GpuDevice::get_device_info(config.gpu_device_id)?);
46        let ml_param = 1.0 / (config.m as f64).ln();
47
48        info!(
49            "GPU HNSW builder initialized on device {} ({})",
50            config.gpu_device_id, device_info.name
51        );
52
53        Ok(Self {
54            config,
55            device_info,
56            pending_vectors: Vec::new(),
57            ml_param,
58            stats: Arc::new(Mutex::new(GpuIndexBuildStats::default())),
59        })
60    }
61
62    /// Create a builder with a custom GPU config
63    pub fn with_gpu_config(gpu_config: GpuConfig) -> Result<Self> {
64        let builder_config = GpuIndexBuilderConfig {
65            gpu_device_id: gpu_config.device_id,
66            num_streams: gpu_config.stream_count,
67            ..GpuIndexBuilderConfig::default()
68        };
69        Self::new(builder_config)
70    }
71
72    /// Add a vector to be indexed
73    pub fn add_vector(&mut self, id: usize, vector: Vec<f32>) -> Result<()> {
74        if vector.is_empty() {
75            return Err(anyhow!("Cannot add empty vector"));
76        }
77        if !self.pending_vectors.is_empty() {
78            let expected_dim = self.pending_vectors[0].1.len();
79            if vector.len() != expected_dim {
80                return Err(anyhow!(
81                    "Vector dimension {} != expected {}",
82                    vector.len(),
83                    expected_dim
84                ));
85            }
86        }
87        self.pending_vectors.push((id, vector));
88        Ok(())
89    }
90
91    /// Build the HNSW graph from all added vectors
92    ///
93    /// Uses GPU for distance matrix computation in batches, then assembles
94    /// the HNSW graph on CPU.
95    pub fn build(&mut self) -> Result<HnswGraph> {
96        if self.pending_vectors.is_empty() {
97            return Err(anyhow!("No vectors to build index from"));
98        }
99
100        let build_start = Instant::now();
101        let num_vectors = self.pending_vectors.len();
102        let dim = self.pending_vectors[0].1.len();
103
104        info!(
105            "Building GPU HNSW index: {} vectors, dim={}, M={}, ef_construction={}",
106            num_vectors, dim, self.config.m, self.config.ef_construction
107        );
108
109        // Phase 1: Assign layers to vectors using probabilistic formula
110        let layer_assignments = self.assign_layers(num_vectors);
111
112        // Phase 2: Initialize nodes
113        let mut nodes: Vec<HnswNode> = self
114            .pending_vectors
115            .iter()
116            .enumerate()
117            .map(|(idx, (id, vec))| {
118                let max_layer = layer_assignments[idx];
119                let neighbors = vec![Vec::new(); max_layer + 1];
120                HnswNode {
121                    id: *id,
122                    vector: vec.clone(),
123                    neighbors,
124                    max_layer,
125                }
126            })
127            .collect();
128
129        let entry_point = 0;
130        let mut current_max_layer = nodes[0].max_layer;
131
132        // Phase 3: Insert vectors one by one using GPU-accelerated search
133        let mut stats = self.stats.lock();
134        let transfer_start = Instant::now();
135
136        // Simulate GPU transfer time (in real CUDA build this would transfer to device)
137        let _ = self.simulate_gpu_transfer(dim, num_vectors);
138        stats.transfer_time_ms = transfer_start.elapsed().as_millis() as u64;
139        drop(stats);
140
141        let gpu_compute_start = Instant::now();
142
143        // Build graph by inserting vectors into the graph layer by layer
144        for insert_idx in 1..num_vectors {
145            let insert_max_layer = nodes[insert_idx].max_layer;
146
147            // Find entry point and greedy descend from top layers
148            let mut current_entry = entry_point;
149
150            // Update current_max_layer if needed
151            if insert_max_layer > current_max_layer {
152                current_max_layer = insert_max_layer;
153            }
154
155            // For each layer from top to insert_max_layer+1, greedy search
156            for layer in (insert_max_layer + 1..=current_max_layer).rev() {
157                current_entry =
158                    self.greedy_search_layer(&nodes, insert_idx, current_entry, layer, 1);
159            }
160
161            // For each layer from insert_max_layer down to 0, perform ef_construction search
162            for layer in (0..=insert_max_layer).rev() {
163                let ef = if layer == 0 {
164                    self.config.ef_construction
165                } else {
166                    self.config.ef_construction / 2
167                };
168
169                let candidates = self.search_layer_ef(&nodes, insert_idx, current_entry, layer, ef);
170
171                // Select M best neighbors using heuristic
172                let m_for_layer = if layer == 0 {
173                    self.config.m * 2
174                } else {
175                    self.config.m
176                };
177
178                let selected = self.select_neighbors_heuristic(
179                    &nodes,
180                    insert_idx,
181                    &candidates,
182                    m_for_layer,
183                    layer,
184                );
185
186                // Add bidirectional connections
187                if layer < nodes[insert_idx].neighbors.len() {
188                    nodes[insert_idx].neighbors[layer] = selected.clone();
189                }
190
191                for &neighbor_id in &selected {
192                    if layer < nodes[neighbor_id].neighbors.len()
193                        && !nodes[neighbor_id].neighbors[layer].contains(&insert_idx)
194                    {
195                        nodes[neighbor_id].neighbors[layer].push(insert_idx);
196
197                        // Prune if exceeds M
198                        let max_m = m_for_layer;
199                        if nodes[neighbor_id].neighbors[layer].len() > max_m {
200                            let pruned = self.prune_neighbors(&nodes, neighbor_id, layer, max_m);
201                            nodes[neighbor_id].neighbors[layer] = pruned;
202                        }
203                    }
204                }
205
206                // Update entry point for next layer
207                if !candidates.is_empty() {
208                    current_entry = candidates[0].1;
209                }
210            }
211        }
212
213        let gpu_compute_ms = gpu_compute_start.elapsed().as_millis() as u64;
214        let graph_assembly_start = Instant::now();
215
216        // Phase 4: Finalize graph
217        let total_build_time = build_start.elapsed().as_millis() as u64;
218        let throughput = if total_build_time > 0 {
219            num_vectors as f64 * 1000.0 / total_build_time as f64
220        } else {
221            f64::INFINITY
222        };
223
224        let final_stats = GpuIndexBuildStats {
225            vectors_indexed: num_vectors,
226            build_time_ms: total_build_time,
227            gpu_compute_time_ms: gpu_compute_ms,
228            transfer_time_ms: self.stats.lock().transfer_time_ms,
229            graph_assembly_time_ms: graph_assembly_start.elapsed().as_millis() as u64,
230            batches_processed: (num_vectors + self.config.batch_size - 1) / self.config.batch_size,
231            peak_gpu_memory_bytes: dim * num_vectors * 4, // f32 per element
232            gpu_utilization_pct: 85.0,                    // Simulated
233            throughput_vps: throughput,
234            used_mixed_precision: self.config.mixed_precision,
235            used_tensor_cores: self.config.tensor_cores,
236        };
237
238        info!(
239            "GPU HNSW build complete: {} vectors in {}ms ({:.1} vps)",
240            num_vectors, total_build_time, throughput
241        );
242
243        let graph = HnswGraph {
244            nodes,
245            entry_point,
246            max_layer: current_max_layer,
247            config: self.config.clone(),
248            stats: final_stats,
249        };
250
251        // Clear pending vectors
252        self.pending_vectors.clear();
253        Ok(graph)
254    }
255
256    /// Get current build statistics
257    pub fn get_stats(&self) -> GpuIndexBuildStats {
258        self.stats.lock().clone()
259    }
260
261    /// Get device information
262    pub fn device_info(&self) -> &GpuDevice {
263        &self.device_info
264    }
265
266    // --- Private implementation methods ---
267
268    /// Assign HNSW layers to vectors using the exponential decay formula
269    pub(crate) fn assign_layers(&self, num_vectors: usize) -> Vec<usize> {
270        // Use deterministic layer assignment based on vector index
271        (0..num_vectors)
272            .map(|i| {
273                // Pseudo-random layer assignment using simple hash
274                let r = self.pseudo_random_01(i as u64);
275                let layer = (-r.ln() * self.ml_param).floor() as usize;
276                layer.min(self.config.num_layers.saturating_sub(1))
277            })
278            .collect()
279    }
280
281    /// Simple pseudo-random float in (0, 1) based on seed
282    fn pseudo_random_01(&self, seed: u64) -> f64 {
283        let a = seed
284            .wrapping_mul(6364136223846793005)
285            .wrapping_add(1442695040888963407);
286        let b = a >> 33;
287        // Map to (0, 1) avoiding 0
288        (b as f64 + 1.0) / (u32::MAX as f64 + 2.0)
289    }
290
291    /// Greedy search at a specific layer for a single best candidate
292    fn greedy_search_layer(
293        &self,
294        nodes: &[HnswNode],
295        query_idx: usize,
296        entry: usize,
297        layer: usize,
298        _ef: usize,
299    ) -> usize {
300        let query_vec = &nodes[query_idx].vector;
301        let mut current = entry;
302        let mut current_dist = self.layer_distance(query_vec, &nodes[current].vector);
303
304        loop {
305            let mut improved = false;
306            if layer >= nodes[current].neighbors.len() {
307                break;
308            }
309            for &neighbor_id in &nodes[current].neighbors[layer] {
310                if neighbor_id >= nodes.len() {
311                    continue;
312                }
313                let d = self.layer_distance(query_vec, &nodes[neighbor_id].vector);
314                if d < current_dist {
315                    current_dist = d;
316                    current = neighbor_id;
317                    improved = true;
318                }
319            }
320            if !improved {
321                break;
322            }
323        }
324        current
325    }
326
327    /// Beam search at a specific layer returning candidates sorted by distance
328    fn search_layer_ef(
329        &self,
330        nodes: &[HnswNode],
331        query_idx: usize,
332        entry: usize,
333        layer: usize,
334        ef: usize,
335    ) -> Vec<(f32, usize)> {
336        let query_vec = &nodes[query_idx].vector;
337        let entry_dist = self.layer_distance(query_vec, &nodes[entry].vector);
338
339        let mut candidates: Vec<(f32, usize)> = vec![(entry_dist, entry)];
340        let mut w: Vec<(f32, usize)> = vec![(entry_dist, entry)];
341        let mut visited = std::collections::HashSet::new();
342        visited.insert(entry);
343        visited.insert(query_idx); // Don't include self
344
345        let mut c_idx = 0;
346        while c_idx < candidates.len() {
347            let (c_dist, c_node) = candidates[c_idx];
348            c_idx += 1;
349
350            let w_max = w.iter().map(|x| x.0).fold(f32::NEG_INFINITY, f32::max);
351
352            if c_dist > w_max && w.len() >= ef {
353                break;
354            }
355
356            if layer >= nodes[c_node].neighbors.len() {
357                continue;
358            }
359
360            for &neighbor_id in &nodes[c_node].neighbors[layer] {
361                if neighbor_id >= nodes.len() || visited.contains(&neighbor_id) {
362                    continue;
363                }
364                visited.insert(neighbor_id);
365                let neighbor_dist = self.layer_distance(query_vec, &nodes[neighbor_id].vector);
366
367                let w_max_inner = w.iter().map(|x| x.0).fold(f32::NEG_INFINITY, f32::max);
368
369                if neighbor_dist < w_max_inner || w.len() < ef {
370                    candidates.push((neighbor_dist, neighbor_id));
371                    w.push((neighbor_dist, neighbor_id));
372                    if w.len() > ef {
373                        if let Some(max_pos) = w
374                            .iter()
375                            .enumerate()
376                            .max_by(|a, b| {
377                                a.1 .0
378                                    .partial_cmp(&b.1 .0)
379                                    .unwrap_or(std::cmp::Ordering::Equal)
380                            })
381                            .map(|(i, _)| i)
382                        {
383                            w.remove(max_pos);
384                        }
385                    }
386                }
387            }
388        }
389
390        w.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
391        w
392    }
393
394    /// Select M best neighbors using the heuristic algorithm
395    fn select_neighbors_heuristic(
396        &self,
397        nodes: &[HnswNode],
398        query_idx: usize,
399        candidates: &[(f32, usize)],
400        m: usize,
401        _layer: usize,
402    ) -> Vec<usize> {
403        if candidates.is_empty() {
404            return Vec::new();
405        }
406
407        let query_vec = &nodes[query_idx].vector;
408        let mut result: Vec<usize> = Vec::with_capacity(m);
409        let mut working: Vec<(f32, usize)> = candidates.to_vec();
410        working.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap_or(std::cmp::Ordering::Equal));
411
412        for (_, candidate_id) in &working {
413            if result.len() >= m {
414                break;
415            }
416            let candidate_dist = self.layer_distance(query_vec, &nodes[*candidate_id].vector);
417
418            // Check if this candidate is closer to query than to any result so far
419            let keep = result.iter().all(|&res_id| {
420                let dist_to_result =
421                    self.layer_distance(&nodes[*candidate_id].vector, &nodes[res_id].vector);
422                candidate_dist <= dist_to_result
423            });
424
425            if keep {
426                result.push(*candidate_id);
427            }
428        }
429
430        // Fill remaining slots if heuristic is too aggressive
431        if result.len() < m.min(candidates.len()) {
432            for (_, candidate_id) in &working {
433                if result.len() >= m {
434                    break;
435                }
436                if !result.contains(candidate_id) {
437                    result.push(*candidate_id);
438                }
439            }
440        }
441
442        result
443    }
444
445    /// Prune neighbor list to max_m using heuristic
446    fn prune_neighbors(
447        &self,
448        nodes: &[HnswNode],
449        node_idx: usize,
450        layer: usize,
451        max_m: usize,
452    ) -> Vec<usize> {
453        let current_neighbors: Vec<(f32, usize)> = nodes[node_idx].neighbors[layer]
454            .iter()
455            .map(|&n_id| {
456                let dist = self.layer_distance(&nodes[node_idx].vector, &nodes[n_id].vector);
457                (dist, n_id)
458            })
459            .collect();
460
461        self.select_neighbors_heuristic(nodes, node_idx, &current_neighbors, max_m, layer)
462    }
463
464    /// Compute distance between two vectors for layer search
465    fn layer_distance(&self, a: &[f32], b: &[f32]) -> f32 {
466        match self.config.distance_metric {
467            GpuDistanceMetric::Cosine | GpuDistanceMetric::CosineF16 => {
468                let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
469                let norm_a: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
470                let norm_b: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
471                if norm_a < 1e-9 || norm_b < 1e-9 {
472                    1.0
473                } else {
474                    1.0 - dot / (norm_a * norm_b)
475                }
476            }
477            GpuDistanceMetric::Euclidean | GpuDistanceMetric::EuclideanF16 => a
478                .iter()
479                .zip(b.iter())
480                .map(|(x, y)| (x - y).powi(2))
481                .sum::<f32>()
482                .sqrt(),
483            GpuDistanceMetric::InnerProduct => {
484                let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
485                -dot
486            }
487        }
488    }
489
490    /// Simulate GPU memory transfer overhead (CPU fallback)
491    fn simulate_gpu_transfer(&self, dim: usize, num_vectors: usize) -> Duration {
492        let bytes = dim * num_vectors * 4; // f32 bytes
493        debug!(
494            "GPU transfer simulation: {} bytes ({} vectors x {} dims x 4 bytes)",
495            bytes, num_vectors, dim
496        );
497        // Simulate ~10 GB/s PCIe bandwidth
498        let transfer_ns = (bytes as f64 / 10e9 * 1e9) as u64;
499        Duration::from_nanos(transfer_ns.min(10_000_000)) // Cap at 10ms for testing
500    }
501}
502
503// ============================================================
504// IncrementalGpuIndexBuilder
505// ============================================================
506
507/// Incremental GPU index builder for streaming ingestion
508///
509/// Supports adding vectors in micro-batches and triggering GPU
510/// rebalancing operations on the HNSW graph.
511#[derive(Debug)]
512pub struct IncrementalGpuIndexBuilder {
513    inner: GpuHnswIndexBuilder,
514    /// Accumulated micro-batch
515    micro_batch: Vec<(usize, Vec<f32>)>,
516    /// Trigger rebalance when micro_batch exceeds this size
517    micro_batch_threshold: usize,
518    /// Total vectors committed to graph
519    total_committed: usize,
520    /// Optional existing graph to extend
521    base_graph: Option<HnswGraph>,
522}
523
524impl IncrementalGpuIndexBuilder {
525    /// Create a new incremental builder
526    pub fn new(config: GpuIndexBuilderConfig, micro_batch_threshold: usize) -> Result<Self> {
527        Ok(Self {
528            inner: GpuHnswIndexBuilder::new(config)?,
529            micro_batch: Vec::new(),
530            micro_batch_threshold,
531            total_committed: 0,
532            base_graph: None,
533        })
534    }
535
536    /// Add a vector to the incremental builder
537    pub fn add_vector(&mut self, id: usize, vector: Vec<f32>) -> Result<()> {
538        self.micro_batch.push((id, vector));
539        if self.micro_batch.len() >= self.micro_batch_threshold {
540            self.flush_micro_batch()?;
541        }
542        Ok(())
543    }
544
545    /// Flush any pending micro-batch and build/update the graph
546    pub fn flush_micro_batch(&mut self) -> Result<()> {
547        if self.micro_batch.is_empty() {
548            return Ok(());
549        }
550        let batch = std::mem::take(&mut self.micro_batch);
551        for (id, vec) in batch {
552            self.inner.add_vector(id, vec)?;
553        }
554        self.total_committed += self.inner.pending_vectors.len();
555        info!(
556            "Flushing micro-batch, total committed: {}",
557            self.total_committed
558        );
559        Ok(())
560    }
561
562    /// Build the final graph
563    pub fn build(mut self) -> Result<HnswGraph> {
564        self.flush_micro_batch()?;
565        self.inner.build()
566    }
567
568    /// Get count of vectors in the current micro-batch
569    pub fn pending_count(&self) -> usize {
570        self.micro_batch.len()
571    }
572
573    /// Get total vectors committed so far
574    pub fn total_committed(&self) -> usize {
575        self.total_committed
576    }
577}
578
579// ============================================================
580// GpuBatchDistanceComputer
581// ============================================================
582
583/// GPU-accelerated batch distance computation
584///
585/// Computes pairwise distances between query vectors and database vectors
586/// using GPU kernels with optional mixed-precision support.
587#[derive(Debug)]
588pub struct GpuBatchDistanceComputer {
589    config: GpuIndexBuilderConfig,
590    /// Cache of recent computations: key = (query_dim, db_size)
591    #[allow(dead_code)]
592    computation_cache: ComputationCache,
593}
594
595impl GpuBatchDistanceComputer {
596    /// Create a new batch distance computer
597    pub fn new(config: GpuIndexBuilderConfig) -> Result<Self> {
598        Ok(Self {
599            config,
600            computation_cache: Arc::new(parking_lot::RwLock::new(std::collections::HashMap::new())),
601        })
602    }
603
604    /// Compute distances between queries and database vectors
605    ///
606    /// Returns a matrix of distances: `result[q][d] = distance(queries[q], database[d])`
607    pub fn compute_distances(
608        &self,
609        queries: &[Vec<f32>],
610        database: &[Vec<f32>],
611    ) -> Result<Vec<Vec<f32>>> {
612        if queries.is_empty() || database.is_empty() {
613            return Ok(Vec::new());
614        }
615
616        let q_dim = queries[0].len();
617        let db_dim = database[0].len();
618        if q_dim != db_dim {
619            return Err(anyhow!(
620                "Query dimension {} != database dimension {}",
621                q_dim,
622                db_dim
623            ));
624        }
625
626        // In a real CUDA build, this would dispatch to GPU kernels
627        // For CPU fallback, compute directly
628        warn!("GPU distance computation running in CPU fallback mode");
629        self.compute_distances_cpu(queries, database)
630    }
631
632    /// CPU fallback for distance computation
633    fn compute_distances_cpu(
634        &self,
635        queries: &[Vec<f32>],
636        database: &[Vec<f32>],
637    ) -> Result<Vec<Vec<f32>>> {
638        let metric = self.config.distance_metric;
639        queries
640            .iter()
641            .map(|q| {
642                database
643                    .iter()
644                    .map(|d| Self::compute_single_distance(metric, q, d))
645                    .collect::<Result<Vec<f32>>>()
646            })
647            .collect()
648    }
649
650    fn compute_single_distance(metric: GpuDistanceMetric, a: &[f32], b: &[f32]) -> Result<f32> {
651        if a.len() != b.len() {
652            return Err(anyhow!("Dimension mismatch: {} != {}", a.len(), b.len()));
653        }
654        let dist = match metric {
655            GpuDistanceMetric::Cosine | GpuDistanceMetric::CosineF16 => {
656                let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
657                let na: f32 = a.iter().map(|x| x * x).sum::<f32>().sqrt();
658                let nb: f32 = b.iter().map(|x| x * x).sum::<f32>().sqrt();
659                if na < 1e-9 || nb < 1e-9 {
660                    1.0
661                } else {
662                    1.0 - dot / (na * nb)
663                }
664            }
665            GpuDistanceMetric::Euclidean | GpuDistanceMetric::EuclideanF16 => a
666                .iter()
667                .zip(b.iter())
668                .map(|(x, y)| (x - y).powi(2))
669                .sum::<f32>()
670                .sqrt(),
671            GpuDistanceMetric::InnerProduct => {
672                let dot: f32 = a.iter().zip(b.iter()).map(|(x, y)| x * y).sum();
673                -dot
674            }
675        };
676        Ok(dist)
677    }
678}
679
680// ============================================================
681// GPU Index Optimizer
682// ============================================================
683
/// Batch-size heuristics for index construction, derived from the vector
/// dimensionality and the amount of GPU memory (in MB) that may be used.
#[derive(Debug, Clone)]
pub struct BatchSizeCalculator;

impl BatchSizeCalculator {
    /// Number of f32 vectors of dimension `vector_dim` that fit into 75 % of
    /// `gpu_memory_mb` (the remaining 25 % is left for overhead).
    ///
    /// The result is limited to the range `1..=65536`. A zero dimension
    /// yields the fixed default of 1024.
    pub fn calculate_batch_size(vector_dim: usize, gpu_memory_mb: u64) -> usize {
        match vector_dim {
            // No per-vector size to divide by; fall back to a fixed default.
            0 => 1024,
            dim => {
                // Keep a quarter of the memory free for working buffers.
                let usable_bytes = (gpu_memory_mb as f64 * 1024.0 * 1024.0 * 0.75) as u64;
                let bytes_per_vector: u64 = (dim as u64) * 4; // f32
                let fitting = usable_bytes / bytes_per_vector;
                // Upper cap first, then the lower bound of one vector.
                (fitting.min(65536) as usize).max(1)
            }
        }
    }

    /// Batch size for f32 vectors when a pairwise distance matrix of the
    /// batch must fit into memory as well.
    ///
    /// With batch size `B` the model is `4*dim*B + 4*B² <= 0.70 * memory`;
    /// the positive root of `4B² + 4*dim*B - budget = 0` is rounded down and
    /// limited to `1..=65536`. A zero dimension yields the fixed default of 512.
    pub fn optimal_batch_for_float32(dim: usize, memory_mb: u64) -> usize {
        if dim == 0 {
            return 512;
        }

        let budget = memory_mb as f64 * 1024.0 * 1024.0 * 0.70;

        // Coefficients of the quadratic in B.
        let quad_a = 4.0f64;
        let quad_b = 4.0 * dim as f64;
        let quad_c = -budget;

        let discriminant = quad_b * quad_b - 4.0 * quad_a * quad_c;
        if discriminant < 0.0 {
            return 1;
        }

        let positive_root = (-quad_b + discriminant.sqrt()) / (2.0 * quad_a);
        (positive_root.floor() as usize).clamp(1, 65536)
    }
}
731
/// GPU memory budget tracker for index construction.
///
/// All quantities are expressed in MB (1 MB = 1024 * 1024 bytes here).
#[derive(Debug, Clone)]
pub struct GpuMemoryBudget {
    /// Total GPU memory in MB
    pub total_mb: u64,
    /// Memory reserved for runtime/OS overhead in MB
    pub reserved_mb: u64,
    /// Memory available for index construction in MB
    pub available_mb: u64,
}

impl GpuMemoryBudget {
    /// Create a new memory budget.
    ///
    /// `reserved_mb` should cover GPU runtime, kernels, and OS overhead.
    /// The available amount is `total_mb - reserved_mb`, or zero when the
    /// reservation exceeds the total.
    pub fn new(total_mb: u64, reserved_mb: u64) -> Self {
        Self {
            total_mb,
            reserved_mb,
            available_mb: total_mb.saturating_sub(reserved_mb),
        }
    }

    /// Returns `true` if a batch of `batch_size` f32 vectors of dimension `dim`
    /// fits within the available memory budget.
    ///
    /// The comparison is carried out in 128-bit arithmetic, so extreme
    /// arguments cannot overflow: a request too large to represent simply
    /// does not fit.
    pub fn can_fit_batch(&self, batch_size: usize, dim: usize) -> bool {
        // dim * 4 stays below 2^66; only the final product can exceed u128.
        let needed_bytes = (dim as u128 * 4).saturating_mul(batch_size as u128);
        // available_mb * 2^20 stays below 2^84, well inside u128.
        let available_bytes = u128::from(self.available_mb) * 1024 * 1024;
        needed_bytes <= available_bytes
    }

    /// Bytes required for a single f32 vector of the given dimension.
    ///
    /// Saturates at `u64::MAX` instead of overflowing for absurdly large `dim`.
    pub fn bytes_per_vector(&self, dim: usize) -> u64 {
        (dim as u64).saturating_mul(4) // f32 = 4 bytes
    }
}
769
770/// Optimises GPU memory usage during index construction by computing
771/// ideal batch sizes and checking memory feasibility.
772#[derive(Debug, Clone)]
773pub struct GpuIndexOptimizer {
774    budget: GpuMemoryBudget,
775}
776
777impl GpuIndexOptimizer {
778    /// Create an optimizer with the given total and reserved GPU memory (MB).
779    pub fn new(total_mb: u64, reserved_mb: u64) -> Self {
780        Self {
781            budget: GpuMemoryBudget::new(total_mb, reserved_mb),
782        }
783    }
784
785    /// Return a reference to the underlying memory budget.
786    pub fn memory_budget(&self) -> &GpuMemoryBudget {
787        &self.budget
788    }
789
790    /// Recommend a batch size for index construction given the vector dimension.
791    pub fn recommend_batch_size(&self, vector_dim: usize) -> usize {
792        BatchSizeCalculator::calculate_batch_size(vector_dim, self.budget.available_mb)
793    }
794
795    /// Check whether a specific batch fits within the available budget.
796    pub fn batch_fits(&self, batch_size: usize, vector_dim: usize) -> bool {
797        self.budget.can_fit_batch(batch_size, vector_dim)
798    }
799}
800
801// ============================================================
802// Pipelined Index Builder
803// ============================================================
804
/// Output of the CPU preparation stage: vector data packed into one flat
/// buffer, ready to be handed to the compute stage.
#[derive(Debug)]
pub struct PreparedBatch {
    /// Flattened row-major f32 values of all vectors in the batch.
    pub data: Vec<f32>,
    /// How many vectors `data` holds.
    pub num_vectors: usize,
    /// Number of components per vector.
    pub dim: usize,
    /// Moment at which preparation finished.
    pub prepared_at: std::time::Instant,
}
818
/// Output of the compute stage: the prepared data together with one distance
/// value per vector.
#[derive(Debug)]
pub struct ComputedBatch {
    /// One value per vector; in the simplified pipeline this is the vector's L2 norm.
    pub distances: Vec<f32>,
    /// How many vectors the batch holds.
    pub num_vectors: usize,
    /// Number of components per vector.
    pub dim: usize,
    /// Packed vector data, passed along unchanged for graph assembly.
    pub data: Vec<f32>,
    /// Moment at which the computation finished.
    pub computed_at: std::time::Instant,
}
833
/// Output of the finalisation stage: a neighbor list for every vector of the
/// batch, ready to be merged into a graph.
#[derive(Debug)]
pub struct IndexedBatch {
    /// Neighbor indices chosen for each vector, one inner list per vector.
    pub neighbor_ids: Vec<Vec<usize>>,
    /// How many vectors were indexed in this batch.
    pub num_vectors: usize,
    /// Moment at which finalisation finished.
    pub finalized_at: std::time::Instant,
}
845
846/// Overlaps CPU preparation work with simulated GPU compute to build an index
847/// in a three-stage pipeline: prepare → compute → finalize.
848///
849/// In a real CUDA build each stage would run on separate CUDA streams so that
850/// the CPU can prepare the next batch while the GPU processes the current one.
851#[derive(Debug, Clone)]
852pub struct PipelinedIndexBuilder;
853
854impl PipelinedIndexBuilder {
855    /// Stage A: CPU preparation — pack and normalise vectors.
856    pub fn stage_a_prepare(vectors: &[f32]) -> PreparedBatch {
857        let dim = vectors.len();
858        // Normalise to unit length (L2 norm)
859        let norm: f32 = vectors.iter().map(|x| x * x).sum::<f32>().sqrt();
860        let data: Vec<f32> = if norm > 1e-9 {
861            vectors.iter().map(|x| x / norm).collect()
862        } else {
863            vectors.to_vec()
864        };
865        PreparedBatch {
866            num_vectors: 1,
867            dim,
868            data,
869            prepared_at: std::time::Instant::now(),
870        }
871    }
872
873    /// Stage B: GPU compute — compute distances (CPU fallback: L2 norms).
874    pub fn stage_b_compute(batch: PreparedBatch) -> ComputedBatch {
875        // Compute L2 norm of each vector as a proxy distance to origin
876        let distances: Vec<f32> = (0..batch.num_vectors)
877            .map(|i| {
878                let start = i * batch.dim;
879                let end = start + batch.dim;
880                let slice = &batch.data[start.min(batch.data.len())..end.min(batch.data.len())];
881                slice.iter().map(|x| x * x).sum::<f32>().sqrt()
882            })
883            .collect();
884        ComputedBatch {
885            distances,
886            num_vectors: batch.num_vectors,
887            dim: batch.dim,
888            data: batch.data,
889            computed_at: std::time::Instant::now(),
890        }
891    }
892
893    /// Stage C: finalise — select neighbours and produce the indexed batch.
894    pub fn stage_c_finalize(batch: ComputedBatch) -> IndexedBatch {
895        // Sort vectors by their distance-to-origin as a simple neighbor heuristic
896        let mut indexed: Vec<(usize, f32)> = batch.distances.iter().copied().enumerate().collect();
897        indexed.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap_or(std::cmp::Ordering::Equal));
898
899        // Each vector gets the top-min(16, n) nearest indices as neighbours
900        let max_neighbors = 16_usize.min(batch.num_vectors);
901        let neighbor_ids: Vec<Vec<usize>> = (0..batch.num_vectors)
902            .map(|_| {
903                indexed
904                    .iter()
905                    .take(max_neighbors)
906                    .map(|(id, _)| *id)
907                    .collect()
908            })
909            .collect();
910
911        IndexedBatch {
912            neighbor_ids,
913            num_vectors: batch.num_vectors,
914            finalized_at: std::time::Instant::now(),
915        }
916    }
917}