use super::*;
/// Links the row at `new_row_idx` into the NSW graph of the index at `idx_pos`.
///
/// Rows whose indexed cell is not a vector, or is a zero-length vector, only
/// get their per-row slots padded and are never connected to other nodes.
/// The first connectable row becomes the graph entry point. Every later row
/// is walked greedily from the entry point down to its own level, then beam
/// searched and connected on each layer from `min(level, entry_level)` to 0.
/// A row assigned a level above the current entry level takes over as entry.
///
/// # Panics
///
/// Panics if the index at `idx_pos` is not an NSW index.
pub(crate) fn nsw_insert_at(table: &mut Table, idx_pos: usize, new_row_idx: usize) {
    let col_pos = table.indices[idx_pos].column_position;

    // Non-vector cells are treated like empty vectors: pad the slots, link nothing.
    let dim = match &table.rows[new_row_idx].values[col_pos] {
        Value::Vector(v) => v.len(),
        Value::Sq8Vector(q) => q.bytes.len(),
        Value::HalfVector(h) => h.dim(),
        _ => 0,
    };
    if dim == 0 {
        ensure_node_slot(table, idx_pos, new_row_idx, 0);
        return;
    }

    let level = nsw_assign_level(new_row_idx);
    ensure_node_slot(table, idx_pos, new_row_idx, level);

    // Snapshot the graph state as it was before this row, record the row's
    // level, and claim the entry point if the graph had none.
    let (prior_entry, entry_level, m) = {
        let IndexKind::Nsw(graph) = &mut table.indices[idx_pos].kind else {
            unreachable!("nsw_insert_at on a non-NSW index")
        };
        let snapshot = (graph.entry, graph.entry_level, graph.m);
        *graph
            .levels
            .get_mut(new_row_idx)
            .expect("levels slot padded by ensure_node_slot") = level;
        if snapshot.0.is_none() {
            graph.entry = Some(new_row_idx);
            graph.entry_level = level;
        }
        snapshot
    };
    let Some(entry) = prior_entry else {
        return;
    };

    // Distances during construction are always taken against an f32 copy of
    // the new row's vector.
    let query = match &table.rows[new_row_idx].values[col_pos] {
        Value::Vector(v) => v.clone(),
        Value::Sq8Vector(q) => quantize::dequantize(q),
        Value::HalfVector(h) => h.to_f32_vec(),
        _ => return,
    };

    // Greedy descent through the layers above the new row's own level.
    let mut current = entry;
    let mut current_d = vec_l2_sq(table, col_pos, current, &query);
    let mut upper = entry_level;
    while upper > level {
        (current, current_d) =
            greedy_layer_walk(table, idx_pos, upper, current, current_d, &query);
        upper -= 1;
    }

    // Beam search and connect on every layer the new row participates in.
    let ef = (m * 2).max(8);
    let top = level.min(entry_level);
    for layer in (0..=top).rev() {
        // Layer 0 is allowed twice as many links as the upper layers.
        let cap = match layer {
            0 => m * 2,
            _ => m,
        };
        let mut found = layer_beam_search(
            table,
            idx_pos,
            layer,
            current,
            current_d,
            &query,
            ef,
            NswMetric::L2,
        );
        found.retain(|&(_, node)| node != new_row_idx);
        // The closest hit seeds the search on the next layer down.
        if let Some(&(nearest_d, nearest)) = found.first() {
            current = nearest;
            current_d = nearest_d;
        }
        let peers = select_neighbours_heuristic(&found, cap, table, col_pos);
        connect_at_layer(table, idx_pos, layer, new_row_idx, &peers);
    }

    if let IndexKind::Nsw(graph) = &mut table.indices[idx_pos].kind
        && level > entry_level
    {
        graph.entry = Some(new_row_idx);
        graph.entry_level = level;
    }
}
/// Pads the NSW graph so that row `new_row_idx` is addressable.
///
/// Guarantees that layers `0..=level` exist, that `levels` has an entry for
/// the row (new entries are 0), and that every existing layer has an
/// adjacency slot for the row (new slots are empty).
///
/// # Panics
///
/// Panics if the index at `idx_pos` is not an NSW index.
fn ensure_node_slot(table: &mut Table, idx_pos: usize, new_row_idx: usize, level: u8) {
    let IndexKind::Nsw(graph) = &mut table.indices[idx_pos].kind else {
        unreachable!("ensure_node_slot on a BTree index");
    };
    // Each range is empty when the structure is already long enough.
    for _ in graph.layers.len()..=usize::from(level) {
        graph.layers.push(PersistentVec::new());
    }
    for _ in graph.levels.len()..=new_row_idx {
        graph.levels.push_mut(0);
    }
    for adjacency in &mut graph.layers {
        for _ in adjacency.len()..=new_row_idx {
            adjacency.push_mut(Vec::new());
        }
    }
}
/// Hill-climbs on one layer towards `query` using squared L2 distance.
///
/// Starting from `current` (whose distance to the query is `current_d`),
/// repeatedly moves to the neighbour that is strictly closer than the current
/// node, stopping when no neighbour improves. Returns the final node and its
/// distance. If the index is not an NSW index the inputs are returned as-is.
fn greedy_layer_walk(
    table: &Table,
    idx_pos: usize,
    layer: u8,
    mut current: usize,
    mut current_d: f32,
    query: &[f32],
) -> (usize, f32) {
    let IndexKind::Nsw(graph) = &table.indices[idx_pos].kind else {
        return (current, current_d);
    };
    let col_pos = table.indices[idx_pos].column_position;
    // A missing layer behaves like a layer where every node has no neighbours.
    let adjacency = graph.layers.get(usize::from(layer));
    loop {
        let neighbours: &[u32] = adjacency
            .and_then(|nodes| nodes.get(current))
            .map_or(&[][..], Vec::as_slice);
        // Strict `<` keeps the earliest neighbour on ties.
        let (best, best_d) =
            neighbours
                .iter()
                .fold((current, current_d), |(best, best_d), &n| {
                    let candidate = n as usize;
                    let d = vec_l2_sq(table, col_pos, candidate, query);
                    if d < best_d {
                        (candidate, d)
                    } else {
                        (best, best_d)
                    }
                });
        if best == current {
            return (current, current_d);
        }
        current = best;
        current_d = best_d;
    }
}
/// Best-first beam search over a single layer of the NSW graph.
///
/// Starts at `entry_node` and keeps expanding the closest unexpanded
/// candidate until that candidate is further away than the worst of the `ef`
/// best results collected so far. Returns `(distance, row)` pairs sorted by
/// ascending distance under `metric`.
///
/// * `entry_d` is used as the entry distance only for [`NswMetric::L2`]; for
///   every other metric the entry distance is recomputed.
/// * The entry node is always part of the result set, even when its distance
///   is not finite. Neighbours with a non-finite distance are marked visited
///   and skipped.
/// * Neighbour ids that are not valid row positions are ignored.
///
/// Returns an empty vector when the index at `idx_pos` is not an NSW index.
// The search state is passed explicitly by every caller; bundling it into a
// struct would not make the call sites clearer.
#[allow(clippy::too_many_arguments)]
fn layer_beam_search(
    table: &Table,
    idx_pos: usize,
    layer: u8,
    entry_node: usize,
    entry_d: f32,
    query: &[f32],
    ef: usize,
    metric: NswMetric,
) -> Vec<(f32, usize)> {
    let g = match &table.indices[idx_pos].kind {
        IndexKind::Nsw(g) => g,
        IndexKind::BTree(_)
        | IndexKind::Brin { .. }
        | IndexKind::Gin(_)
        | IndexKind::GinTrgm(_)
        | IndexKind::GinFulltext(_) => return Vec::new(),
    };
    let col_pos = table.indices[idx_pos].column_position;
    let d0 = if matches!(metric, NswMetric::L2) {
        entry_d
    } else {
        cell_to_query_metric_distance(table, col_pos, entry_node, query, metric)
    };
    let row_count = table.rows.len();
    let mut visited: Vec<bool> = alloc::vec![false; row_count];
    if entry_node < row_count {
        visited[entry_node] = true;
    }
    // Each row is pushed at most once (plus the entry node), so the heaps can
    // never hold more than `row_count + 1` items. Bounding the reservation
    // keeps a very large caller-supplied `ef` from triggering a huge
    // up-front allocation.
    let heap_cap = ef.min(row_count).saturating_add(1);
    let mut candidates: alloc::collections::BinaryHeap<NodeClosest> =
        alloc::collections::BinaryHeap::with_capacity(heap_cap);
    let mut results: alloc::collections::BinaryHeap<NodeFurthest> =
        alloc::collections::BinaryHeap::with_capacity(heap_cap);
    candidates.push(NodeClosest {
        dist: d0,
        node: entry_node,
    });
    results.push(NodeFurthest {
        dist: d0,
        node: entry_node,
    });
    while let Some(cur) = candidates.pop() {
        // Stop once the best remaining candidate cannot improve a full beam.
        let worst = results.peek().map_or(f32::INFINITY, |c| c.dist);
        if cur.dist > worst && results.len() >= ef {
            break;
        }
        let neighbours: &[u32] = g
            .layers
            .get(layer as usize)
            .and_then(|layer_v| layer_v.get(cur.node))
            .map_or(&[][..], Vec::as_slice);
        for &n in neighbours {
            let n = n as usize;
            if n >= row_count || visited[n] {
                continue;
            }
            visited[n] = true;
            let dn = cell_to_query_metric_distance(table, col_pos, n, query, metric);
            if !dn.is_finite() {
                continue;
            }
            let worst = results.peek().map_or(f32::INFINITY, |c| c.dist);
            if results.len() < ef || dn < worst {
                results.push(NodeFurthest { dist: dn, node: n });
                // Keep only the `ef` closest results.
                if results.len() > ef {
                    results.pop();
                }
                candidates.push(NodeClosest { dist: dn, node: n });
            }
        }
    }
    let mut out: Vec<(f32, usize)> = results.into_iter().map(|c| (c.dist, c.node)).collect();
    // `total_cmp` is a genuine total order, so a NaN entry distance cannot
    // hand `sort_by` an inconsistent comparator; NaN simply sorts last.
    out.sort_by(|a, b| a.0.total_cmp(&b.0));
    out
}
/// Heap entry ordered so that a max-heap pops the node with the *smallest*
/// distance first.
///
/// Ordering uses `f32::total_cmp` on the distance and breaks ties on the node
/// id (lower id pops first), so `Ord`, `PartialOrd`, `PartialEq` and `Eq`
/// agree with one another and form a total order even when a distance is NaN
/// (positive NaN counts as the largest distance and therefore pops last).
#[derive(Debug, Clone, Copy)]
struct NodeClosest {
    /// Distance from the query to `node`.
    dist: f32,
    /// Row position of the graph node.
    node: usize,
}

impl PartialEq for NodeClosest {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so equality can never disagree with ordering.
        self.cmp(other) == core::cmp::Ordering::Equal
    }
}

impl Eq for NodeClosest {}

impl PartialOrd for NodeClosest {
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for NodeClosest {
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        // Reversed operands: the smaller distance is the "greater" element.
        other
            .dist
            .total_cmp(&self.dist)
            .then_with(|| other.node.cmp(&self.node))
    }
}
/// Heap entry ordered so that a max-heap pops the node with the *largest*
/// distance first.
///
/// Ordering uses `f32::total_cmp` on the distance and breaks ties on the node
/// id (higher id pops first), so `Ord`, `PartialOrd`, `PartialEq` and `Eq`
/// agree with one another and form a total order even when a distance is NaN
/// (positive NaN counts as the largest distance and therefore pops first).
#[derive(Debug, Clone, Copy)]
struct NodeFurthest {
    /// Distance from the query to `node`.
    dist: f32,
    /// Row position of the graph node.
    node: usize,
}

impl PartialEq for NodeFurthest {
    fn eq(&self, other: &Self) -> bool {
        // Defined through `cmp` so equality can never disagree with ordering.
        self.cmp(other) == core::cmp::Ordering::Equal
    }
}

impl Eq for NodeFurthest {}

impl PartialOrd for NodeFurthest {
    fn partial_cmp(&self, other: &Self) -> Option<core::cmp::Ordering> {
        Some(self.cmp(other))
    }
}

impl Ord for NodeFurthest {
    fn cmp(&self, other: &Self) -> core::cmp::Ordering {
        self.dist
            .total_cmp(&other.dist)
            .then_with(|| self.node.cmp(&other.node))
    }
}
/// Picks up to `m` neighbours from `candidates`, skipping any candidate that
/// is closer to an already chosen neighbour than it is to the query.
///
/// `candidates` holds `(distance_to_query, row)` pairs and is visited in the
/// order given. Rows whose indexed cell is not a vector value are ignored.
/// Returns the chosen rows in selection order.
fn select_neighbours_heuristic(
    candidates: &[(f32, usize)],
    m: usize,
    table: &Table,
    col_pos: usize,
) -> Vec<usize> {
    let mut kept: Vec<usize> = Vec::with_capacity(m);
    for &(dist_to_query, node) in candidates {
        if kept.len() >= m {
            break;
        }
        let has_vector = matches!(
            table.rows.get(node).and_then(|r| r.values.get(col_pos)),
            Some(Value::Vector(_) | Value::Sq8Vector(_) | Value::HalfVector(_))
        );
        if !has_vector {
            continue;
        }
        // A candidate is redundant when an already kept neighbour sits closer
        // to it than the query does.
        let shadowed = kept
            .iter()
            .any(|&peer| cell_l2_sq(table, col_pos, node, peer) < dist_to_query);
        if !shadowed {
            kept.push(node);
        }
    }
    kept
}
/// Wires `new_row_idx` to `peers` on `layer`, in both directions.
///
/// The new node's adjacency list on the layer is replaced by `peers`. Each
/// peer that holds a vector value then gets a back-edge to the new node
/// (unless it already has one). If that pushes the peer over the layer's
/// link cap, the peer's list is re-selected from its current neighbours,
/// ordered by distance to the peer, via [`select_neighbours_heuristic`].
///
/// Does nothing when the index at `idx_pos` is not an NSW index.
///
/// # Panics
///
/// Panics if `new_row_idx` or any peer does not fit in a `u32`, or if a
/// peer is not a valid row position.
fn connect_at_layer(
    table: &mut Table,
    idx_pos: usize,
    layer: u8,
    new_row_idx: usize,
    peers: &[usize],
) {
    let col_pos = table.indices[idx_pos].column_position;
    let cap = match &table.indices[idx_pos].kind {
        IndexKind::Nsw(g) => g.cap_for_layer(layer),
        IndexKind::BTree(_)
        | IndexKind::Brin { .. }
        | IndexKind::Gin(_)
        | IndexKind::GinTrgm(_)
        | IndexKind::GinFulltext(_) => return,
    };
    let layer_pos = usize::from(layer);
    let new_row_u32 = u32::try_from(new_row_idx).expect("row index fits in u32");

    // Forward edges: new node -> peers.
    if let IndexKind::Nsw(g) = &mut table.indices[idx_pos].kind
        && let Some(slot) = g.layers[layer_pos].get_mut(new_row_idx)
    {
        *slot = peers
            .iter()
            .map(|&p| u32::try_from(p).expect("row index fits in u32"))
            .collect();
    }

    for &peer in peers {
        if !matches!(
            &table.rows[peer].values[col_pos],
            Value::Vector(_) | Value::Sq8Vector(_) | Value::HalfVector(_)
        ) {
            continue;
        }

        // Back edge: peer -> new node.
        if let IndexKind::Nsw(g) = &mut table.indices[idx_pos].kind
            && let Some(slot) = g.layers[layer_pos].get_mut(peer)
            && !slot.contains(&new_row_u32)
        {
            slot.push(new_row_u32);
        }

        // Snapshot the peer's neighbours only when it exceeds the cap. The
        // snapshot releases the index borrow so distances can be computed
        // against `table` below.
        let current_peers: Vec<usize> = match &table.indices[idx_pos].kind {
            IndexKind::Nsw(g) => {
                let adjacency = &g.layers[layer_pos][peer];
                if adjacency.len() <= cap {
                    continue;
                }
                adjacency.iter().map(|&n| n as usize).collect()
            }
            IndexKind::BTree(_)
            | IndexKind::Brin { .. }
            | IndexKind::Gin(_)
            | IndexKind::GinTrgm(_)
            | IndexKind::GinFulltext(_) => continue,
        };

        let mut tagged: Vec<(f32, usize)> = current_peers
            .iter()
            .map(|&p| (cell_l2_sq(table, col_pos, peer, p), p))
            .collect();
        // `total_cmp` is a total order even if a distance is NaN, so the
        // comparator handed to `sort_by` is always consistent.
        tagged.sort_by(|a, b| a.0.total_cmp(&b.0));
        let kept = select_neighbours_heuristic(&tagged, cap, table, col_pos);
        if let IndexKind::Nsw(g) = &mut table.indices[idx_pos].kind
            && let Some(slot) = g.layers[layer_pos].get_mut(peer)
        {
            *slot = kept
                .into_iter()
                .map(|p| u32::try_from(p).expect("row index fits in u32"))
                .collect();
        }
    }
}
/// Squared L2 distance between the vector stored at (`row`, `col_pos`) and
/// `query`.
///
/// Returns `f32::INFINITY` when the row or column does not exist, the cell is
/// not a vector value, or its dimension differs from `query.len()`.
fn vec_l2_sq(table: &Table, col_pos: usize, row: usize, query: &[f32]) -> f32 {
    let Some(cell) = table.rows.get(row).and_then(|r| r.values.get(col_pos)) else {
        return f32::INFINITY;
    };
    let dim = query.len();
    match cell {
        Value::Vector(v) if v.len() == dim => l2_distance_sq(v, query),
        Value::Sq8Vector(q) if q.bytes.len() == dim => {
            quantize::sq8_l2_distance_sq_asymmetric(q, query)
        }
        Value::HalfVector(h) if h.dim() == dim => {
            halfvec::half_l2_distance_sq_asymmetric(h, query)
        }
        _ => f32::INFINITY,
    }
}
/// Squared L2 distance between the vectors stored in column `col_pos` of
/// `row_a` and `row_b`.
///
/// Returns `f32::INFINITY` when either cell is missing, the two cells are
/// not the same kind of vector value, or their dimensions differ.
fn cell_l2_sq(table: &Table, col_pos: usize, row_a: usize, row_b: usize) -> f32 {
    let cell_a = table.rows.get(row_a).and_then(|r| r.values.get(col_pos));
    let cell_b = table.rows.get(row_b).and_then(|r| r.values.get(col_pos));
    match cell_a.zip(cell_b) {
        Some((Value::Vector(a), Value::Vector(b))) if a.len() == b.len() => {
            l2_distance_sq(a, b)
        }
        Some((Value::Sq8Vector(a), Value::Sq8Vector(b))) if a.bytes.len() == b.bytes.len() => {
            quantize::sq8_l2_distance_sq(a, b)
        }
        Some((Value::HalfVector(a), Value::HalfVector(b))) if a.dim() == b.dim() => {
            halfvec::half_l2_distance_sq(a, b)
        }
        _ => f32::INFINITY,
    }
}
/// Distance under `metric` between the vector stored at (`row`, `col_pos`)
/// and `query`.
///
/// Each vector encoding is dispatched to its own distance kernel. Returns
/// `f32::INFINITY` when the cell is missing, is not a vector value, or has a
/// dimension different from `query.len()`.
fn cell_to_query_metric_distance(
    table: &Table,
    col_pos: usize,
    row: usize,
    query: &[f32],
    metric: NswMetric,
) -> f32 {
    let Some(cell) = table.rows.get(row).and_then(|r| r.values.get(col_pos)) else {
        return f32::INFINITY;
    };
    let dim = query.len();
    match cell {
        Value::Vector(v) if v.len() == dim => metric_distance(metric, v, query),
        Value::Sq8Vector(q) if q.bytes.len() == dim => match metric {
            NswMetric::L2 => quantize::sq8_l2_distance_sq_asymmetric(q, query),
            NswMetric::InnerProduct => quantize::sq8_inner_product_asymmetric(q, query),
            NswMetric::Cosine => quantize::sq8_cosine_distance_asymmetric(q, query),
        },
        Value::HalfVector(h) if h.dim() == dim => match metric {
            NswMetric::L2 => halfvec::half_l2_distance_sq_asymmetric(h, query),
            NswMetric::InnerProduct => halfvec::half_inner_product_asymmetric(h, query),
            NswMetric::Cosine => halfvec::half_cosine_distance_asymmetric(h, query),
        },
        _ => f32::INFINITY,
    }
}
/// Distance metric used to rank rows during an NSW search.
///
/// Every metric is expressed as a value where smaller means closer, because
/// search results are sorted by ascending distance.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum NswMetric {
/// Squared Euclidean distance.
L2,
/// Negated inner product for plain `f32` vectors, so larger products rank first.
InnerProduct,
/// `1 - cosine similarity` for plain `f32` vectors; infinite when either
/// vector has zero norm.
Cosine,
}
/// Approximate k-nearest-neighbour search over the NSW index at `idx_pos`.
///
/// Walks greedily from the graph entry point down to layer 1 (always using
/// squared L2 distance), then runs a beam search of width `ef` on layer 0
/// under `metric`. Returns at most `k` `(distance, row)` pairs sorted by
/// ascending distance.
///
/// `ef` is raised to at least `k`. For SQ8-encoded columns it is raised to at
/// least `k * SQ8_RERANK_OVER_FETCH` and the beam results are re-ranked with
/// [`sq8_rerank`] before truncation.
///
/// Returns an empty vector when `k` is 0, when the index is not an NSW index,
/// or when the graph has no entry point.
pub(crate) fn nsw_search(
    table: &Table,
    idx_pos: usize,
    query: &[f32],
    k: usize,
    ef: usize,
    metric: NswMetric,
) -> Vec<(f32, usize)> {
    // Nothing can be returned for k == 0, so skip the graph traversal.
    if k == 0 {
        return Vec::new();
    }
    let (entry, entry_level) = match &table.indices[idx_pos].kind {
        IndexKind::Nsw(g) => (g.entry, g.entry_level),
        IndexKind::BTree(_)
        | IndexKind::Brin { .. }
        | IndexKind::Gin(_)
        | IndexKind::GinTrgm(_)
        | IndexKind::GinFulltext(_) => return Vec::new(),
    };
    let Some(entry) = entry else {
        return Vec::new();
    };
    let col_pos = table.indices[idx_pos].column_position;
    let sq8 = matches!(
        table.schema.columns.get(col_pos).map(|c| c.ty),
        Some(DataType::Vector {
            encoding: VecEncoding::Sq8,
            ..
        })
    );
    let ef = if sq8 {
        // Saturate instead of overflowing for very large caller-supplied `k`.
        ef.max(k).max(k.saturating_mul(SQ8_RERANK_OVER_FETCH))
    } else {
        ef.max(k)
    };
    let mut current = entry;
    let mut current_d = vec_l2_sq(table, col_pos, entry, query);
    for layer in (1..=entry_level).rev() {
        (current, current_d) = greedy_layer_walk(table, idx_pos, layer, current, current_d, query);
    }
    let mut results = layer_beam_search(table, idx_pos, 0, current, current_d, query, ef, metric);
    if sq8 {
        results = sq8_rerank(table, col_pos, &results, query, metric);
    }
    results.truncate(k);
    results
}
/// Re-scores beam-search candidates against dequantized SQ8 vectors.
///
/// For each `(distance, row)` candidate:
/// * a missing cell drops the candidate;
/// * a cell that is not SQ8-encoded keeps its incoming distance;
/// * an SQ8 cell is dequantized and its distance to `query` recomputed with
///   [`metric_distance`]; it is dropped if the dimensions do not match.
///
/// Returns the surviving candidates sorted by ascending distance.
fn sq8_rerank(
    table: &Table,
    col_pos: usize,
    candidates: &[(f32, usize)],
    query: &[f32],
    metric: NswMetric,
) -> Vec<(f32, usize)> {
    let mut out: Vec<(f32, usize)> = candidates
        .iter()
        .filter_map(|&(adc_d, row)| {
            let cell = table.rows.get(row).and_then(|r| r.values.get(col_pos))?;
            let Value::Sq8Vector(q) = cell else {
                return Some((adc_d, row));
            };
            let deq = quantize::dequantize(q);
            if deq.len() != query.len() {
                return None;
            }
            Some((metric_distance(metric, &deq, query), row))
        })
        .collect();
    // Recomputed distances can be NaN (for example with a NaN in `query`).
    // `total_cmp` keeps the comparator a total order; NaN sorts last.
    out.sort_by(|a, b| a.0.total_cmp(&b.0));
    out
}
/// Multiplier applied to `k` in `nsw_search` to widen the beam for
/// SQ8-encoded columns, so that re-ranking has extra candidates to choose from.
const SQ8_RERANK_OVER_FETCH: usize = 3;
/// Distance between two `f32` vectors under `metric`; smaller is closer.
///
/// * `L2` — squared Euclidean distance.
/// * `InnerProduct` — the negated inner product.
/// * `Cosine` — `1 - dot / (|a| * |b|)`, or `f32::INFINITY` when either
///   squared norm is exactly zero.
fn metric_distance(metric: NswMetric, a: &[f32], b: &[f32]) -> f32 {
    match metric {
        NswMetric::L2 => l2_distance_sq(a, b),
        NswMetric::InnerProduct => -inner_product_f32(a, b),
        NswMetric::Cosine => {
            let (dot, norm_sq_a, norm_sq_b) = cosine_dot_norms_f32(a, b);
            if norm_sq_a == 0.0 || norm_sq_b == 0.0 {
                f32::INFINITY
            } else {
                let denom = sqrt_newton_f32(norm_sq_a) * sqrt_newton_f32(norm_sq_b);
                1.0 - dot / denom
            }
        }
    }
}
/// Inner product of `a` and `b`.
///
/// On aarch64 the NEON kernel is used when both slices have the same length
/// and that length is a non-zero multiple of four; every other case uses the
/// scalar implementation.
#[doc(hidden)]
#[inline]
pub fn inner_product_f32(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_arch = "aarch64")]
    {
        let len = a.len();
        let simd_eligible = len == b.len() && len >= 4 && len.is_multiple_of(4);
        if simd_eligible {
            // SAFETY: both slices hold exactly `len` elements and `len` is a
            // multiple of four, so every 4-lane load in the kernel stays in
            // bounds.
            // NOTE(review): relies on NEON being available on every aarch64
            // target this crate is built for — confirm.
            return unsafe { inner_product_neon(a, b) };
        }
    }
    inner_product_scalar(a, b)
}
/// Portable inner product of `a` and `b`, accumulated left to right in `f32`.
///
/// If the slices differ in length, only the common prefix contributes.
pub(crate) fn inner_product_scalar(a: &[f32], b: &[f32]) -> f32 {
    a.iter()
        .zip(b)
        .fold(0.0_f32, |acc, (&x, &y)| acc + x * y)
}
/// NEON inner product of `a` and `b`.
///
/// Processes eight elements per iteration into two independent accumulators,
/// then any remaining group of four, and finally sums all lanes. Elements
/// past the last full group of four in `a` are not read.
///
/// # Safety
///
/// `b` must contain at least as many elements as the number of elements read
/// from `a` (the largest multiple of four not exceeding `a.len()`), because
/// `b` is loaded through raw pointers at the same offsets as `a`. The CPU
/// must support NEON.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(clippy::many_single_char_names)] pub(crate) unsafe fn inner_product_neon(a: &[f32], b: &[f32]) -> f32 {
use core::arch::aarch64::{
float32x4_t, vaddq_f32, vaddvq_f32, vdupq_n_f32, vfmaq_f32, vld1q_f32,
};
// SAFETY: loads are bounded by `a.len()`; the caller guarantees `b` is at
// least as long (see `# Safety`).
unsafe {
let zero: float32x4_t = vdupq_n_f32(0.0);
let mut acc0 = zero;
let mut acc1 = zero;
let n = a.len();
let mut i = 0usize;
// Main loop: two fused multiply-adds per iteration, one per accumulator.
while i + 8 <= n {
let av0 = vld1q_f32(a.as_ptr().add(i));
let bv0 = vld1q_f32(b.as_ptr().add(i));
acc0 = vfmaq_f32(acc0, av0, bv0);
let av1 = vld1q_f32(a.as_ptr().add(i + 4));
let bv1 = vld1q_f32(b.as_ptr().add(i + 4));
acc1 = vfmaq_f32(acc1, av1, bv1);
i += 8;
}
// Tail: at most one remaining group of four.
while i + 4 <= n {
let av = vld1q_f32(a.as_ptr().add(i));
let bv = vld1q_f32(b.as_ptr().add(i));
acc0 = vfmaq_f32(acc0, av, bv);
i += 4;
}
// Combine both accumulators, then reduce the four lanes to one scalar.
vaddvq_f32(vaddq_f32(acc0, acc1))
}
}
/// Computes `(a·b, |a|², |b|²)` in a single pass over both slices.
///
/// On aarch64 the NEON kernel is used when both slices have the same length
/// and that length is a non-zero multiple of four; every other case uses the
/// scalar implementation.
#[doc(hidden)]
#[inline]
pub fn cosine_dot_norms_f32(a: &[f32], b: &[f32]) -> (f32, f32, f32) {
    #[cfg(target_arch = "aarch64")]
    {
        let len = a.len();
        let simd_eligible = len == b.len() && len >= 4 && len.is_multiple_of(4);
        if simd_eligible {
            // SAFETY: both slices hold exactly `len` elements and `len` is a
            // multiple of four, so every 4-lane load in the kernel stays in
            // bounds.
            // NOTE(review): relies on NEON being available on every aarch64
            // target this crate is built for — confirm.
            return unsafe { cosine_dot_norms_neon(a, b) };
        }
    }
    cosine_dot_norms_scalar(a, b)
}
/// Portable single-pass computation of `(a·b, |a|², |b|²)`.
///
/// If the slices differ in length, only the common prefix contributes to all
/// three sums.
pub(crate) fn cosine_dot_norms_scalar(a: &[f32], b: &[f32]) -> (f32, f32, f32) {
    a.iter().zip(b).fold(
        (0.0_f32, 0.0_f32, 0.0_f32),
        |(dot, norm_sq_a, norm_sq_b), (&x, &y)| {
            (dot + x * y, norm_sq_a + x * x, norm_sq_b + y * y)
        },
    )
}
/// NEON computation of `(a·b, |a|², |b|²)` in a single pass.
///
/// Processes four elements per iteration into three accumulators and sums
/// the lanes of each at the end. Elements past the last full group of four
/// in `a` are not read.
///
/// # Safety
///
/// `b` must contain at least as many elements as the number of elements read
/// from `a` (the largest multiple of four not exceeding `a.len()`), because
/// `b` is loaded through raw pointers at the same offsets as `a`. The CPU
/// must support NEON.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(clippy::many_single_char_names, clippy::similar_names)]
pub(crate) unsafe fn cosine_dot_norms_neon(a: &[f32], b: &[f32]) -> (f32, f32, f32) {
use core::arch::aarch64::{float32x4_t, vaddvq_f32, vdupq_n_f32, vfmaq_f32, vld1q_f32};
// SAFETY: loads are bounded by `a.len()`; the caller guarantees `b` is at
// least as long (see `# Safety`).
unsafe {
let zero: float32x4_t = vdupq_n_f32(0.0);
let mut acc_dot = zero;
let mut acc_na = zero;
let mut acc_nb = zero;
let n = a.len();
let mut i = 0usize;
while i + 4 <= n {
let av = vld1q_f32(a.as_ptr().add(i));
let bv = vld1q_f32(b.as_ptr().add(i));
// Each loaded pair feeds all three running sums.
acc_dot = vfmaq_f32(acc_dot, av, bv);
acc_na = vfmaq_f32(acc_na, av, av);
acc_nb = vfmaq_f32(acc_nb, bv, bv);
i += 4;
}
// Reduce the four lanes of each accumulator to a scalar.
(vaddvq_f32(acc_dot), vaddvq_f32(acc_na), vaddvq_f32(acc_nb))
}
}
/// Square root of `x` by Newton's method, using only `core` float operations.
///
/// Returns `0.0` for zero and negative inputs, and returns NaN and `+inf`
/// unchanged. For every other input the result is accurate to within a few
/// ULPs across the whole `f32` range.
///
/// The iteration is seeded by halving the exponent through the bit pattern
/// rather than starting from `x` itself. Starting from `x` only halves the
/// estimate on each step, so a fixed small number of steps is far from
/// converged for large or tiny inputs (for example, ten steps from `1e6`
/// yield about `1296` instead of `1000`).
fn sqrt_newton_f32(x: f32) -> f32 {
    /// Upper bound on refinement steps; a safety net for subnormal inputs,
    /// where the bit-pattern seed is poor. Normal inputs stop after 3-4.
    const MAX_STEPS: usize = 32;

    if x <= 0.0 {
        return 0.0;
    }
    if !x.is_finite() {
        // NaN stays NaN; sqrt(+inf) is +inf.
        return x;
    }
    // Halving the biased exponent gives a seed within a few percent of the
    // true root for normal floats. `0x1fc0_0000` restores half of the bias.
    let seed = f32::from_bits((x.to_bits() >> 1) + 0x1fc0_0000);
    // After one Newton step the estimate is at or above the root (AM-GM), so
    // later steps decrease monotonically until rounding stops them.
    let mut g = 0.5 * (seed + x / seed);
    for _ in 0..MAX_STEPS {
        let next = 0.5 * (g + x / g);
        if next >= g {
            break;
        }
        g = next;
    }
    g
}
/// Squared Euclidean distance between `a` and `b`.
///
/// On aarch64 the NEON kernel is used when both slices have the same length
/// and that length is a non-zero multiple of four; every other case uses the
/// scalar implementation.
#[inline]
pub(crate) fn l2_distance_sq(a: &[f32], b: &[f32]) -> f32 {
    #[cfg(target_arch = "aarch64")]
    {
        let len = a.len();
        let simd_eligible = len == b.len() && len >= 4 && len.is_multiple_of(4);
        if simd_eligible {
            // SAFETY: both slices hold exactly `len` elements and `len` is a
            // multiple of four, so every 4-lane load in the kernel stays in
            // bounds.
            // NOTE(review): relies on NEON being available on every aarch64
            // target this crate is built for — confirm.
            return unsafe { l2_distance_sq_neon(a, b) };
        }
    }
    l2_distance_sq_scalar(a, b)
}
/// Portable squared Euclidean distance, accumulated left to right in `f32`.
///
/// If the slices differ in length, only the common prefix contributes.
pub(crate) fn l2_distance_sq_scalar(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b).fold(0.0_f32, |total, (&x, &y)| {
        let delta = x - y;
        total + delta * delta
    })
}
/// NEON squared Euclidean distance between `a` and `b`.
///
/// Processes eight elements per iteration into two independent accumulators,
/// then any remaining group of four, and finally sums all lanes. Elements
/// past the last full group of four in `a` are not read.
///
/// # Safety
///
/// `b` must contain at least as many elements as the number of elements read
/// from `a` (the largest multiple of four not exceeding `a.len()`), because
/// `b` is loaded through raw pointers at the same offsets as `a`. The CPU
/// must support NEON.
#[cfg(target_arch = "aarch64")]
#[target_feature(enable = "neon")]
#[allow(clippy::many_single_char_names)] pub(crate) unsafe fn l2_distance_sq_neon(a: &[f32], b: &[f32]) -> f32 {
use core::arch::aarch64::{
float32x4_t, vaddq_f32, vaddvq_f32, vdupq_n_f32, vfmaq_f32, vld1q_f32, vsubq_f32,
};
// SAFETY: loads are bounded by `a.len()`; the caller guarantees `b` is at
// least as long (see `# Safety`).
unsafe {
let zero: float32x4_t = vdupq_n_f32(0.0);
let mut acc0 = zero;
let mut acc1 = zero;
let n = a.len();
let mut i = 0usize;
// Main loop: two squared-difference accumulations per iteration.
while i + 8 <= n {
let d0 = vsubq_f32(vld1q_f32(a.as_ptr().add(i)), vld1q_f32(b.as_ptr().add(i)));
acc0 = vfmaq_f32(acc0, d0, d0);
let d1 = vsubq_f32(
vld1q_f32(a.as_ptr().add(i + 4)),
vld1q_f32(b.as_ptr().add(i + 4)),
);
acc1 = vfmaq_f32(acc1, d1, d1);
i += 8;
}
// Tail: at most one remaining group of four.
while i + 4 <= n {
let d = vsubq_f32(vld1q_f32(a.as_ptr().add(i)), vld1q_f32(b.as_ptr().add(i)));
acc0 = vfmaq_f32(acc0, d, d);
i += 4;
}
// Combine both accumulators, then reduce the four lanes to one scalar.
vaddvq_f32(vaddq_f32(acc0, acc1))
}
}
/// Returns the row positions of up to `k` approximate nearest neighbours of
/// `query` using the NSW index named `idx_name`, closest first.
///
/// The beam width is `max(2 * k, NSW_DEFAULT_M)`. Returns an empty vector
/// when no index with that name exists on `table`.
pub fn nsw_query(
    table: &Table,
    idx_name: &str,
    query: &[f32],
    k: usize,
    metric: NswMetric,
) -> Vec<usize> {
    let Some(idx_pos) = table.indices.iter().position(|i| i.name == idx_name) else {
        return Vec::new();
    };
    // Saturate so a very large `k` cannot overflow the beam-width computation.
    let ef = k.saturating_mul(2).max(NSW_DEFAULT_M);
    nsw_search(table, idx_pos, query, k, ef, metric)
        .into_iter()
        .take(k)
        .map(|(_, idx)| idx)
        .collect()
}
/// Finds the first NSW index defined on the column at `column_position`.
///
/// Returns `None` when the column has no NSW index.
pub fn nsw_index_on(table: &Table, column_position: usize) -> Option<&Index> {
    table.indices.iter().find(|index| {
        let is_nsw = matches!(index.kind, IndexKind::Nsw(_));
        is_nsw && index.column_position == column_position
    })
}