use super::linear_model::LinearModel;
use super::simd_search;
use anyhow::Result;
use std::fmt;
/// Occupancy ratio (`num_keys / capacity`) at or above which a node stops accepting inserts
/// (`insert` returns `Ok(false)`) and `should_split` reports `true`. `insert_batch` applies the
/// same threshold to the density the node would have after the whole batch.
const MAX_DENSITY: f64 = 0.95;
// Not referenced anywhere in this file, hence the `allow`. Presumably a lower occupancy bound
// reserved for future merge/shrink logic — NOTE(review): confirm it is still needed.
#[allow(dead_code)]
const MIN_DENSITY: f64 = 0.3;
/// A learned-index leaf: keys live in a slot array that may contain empty slots (gaps), and a
/// linear model predicts the slot of a key so lookups can start near the right place.
#[derive(Debug, Clone)]
pub struct GappedNode {
/// Key slots; `None` marks an empty slot. The insert logic is intended to keep occupied slots
/// in non-decreasing key order.
keys: Vec<Option<i64>>,
/// Value slots, parallel to `keys`: every write and shift updates both vectors together.
values: Vec<Option<Vec<u8>>>,
/// Predicts a slot index for a key; (re)trained on `(key, slot)` pairs by `retrain`.
model: LinearModel,
/// Extra-capacity fraction used to size the node: `new` allocates
/// `ceil(expected_keys * (1 + expansion_factor))` slots (at least 4).
expansion_factor: f64,
/// Number of occupied slots.
num_keys: usize,
/// Search-radius hint for the model's prediction error. It is set to `capacity / 4` by `new`
/// and to the model's measured max error (at least 4) by `retrain`. Searches use at least 16.
max_error_bound: usize,
}
impl GappedNode {
    /// Creates an empty node sized for `expected_keys` entries.
    ///
    /// The node gets `ceil(expected_keys * (1 + expansion_factor))` slots, but never fewer
    /// than 4. The spare slots are the gaps that absorb later inserts. The model starts
    /// untrained, and the search-radius hint starts at a quarter of the capacity.
    #[must_use]
    pub fn new(expected_keys: usize, expansion_factor: f64) -> Self {
        let capacity = ((expected_keys as f64 * (1.0 + expansion_factor)).ceil() as usize).max(4);
        Self {
            keys: vec![None; capacity],
            values: vec![None; capacity],
            model: LinearModel::new(),
            expansion_factor,
            num_keys: 0,
            max_error_bound: capacity / 4,
        }
    }

    /// Inserts `key` with `value`, keeping occupied slots ordered by key.
    ///
    /// Duplicate keys are allowed: an equal key is stored as an extra entry and does not
    /// overwrite the existing one.
    ///
    /// Returns `Ok(false)` without touching the node when its density is already at or above
    /// `MAX_DENSITY` (the caller should split). Otherwise returns `Ok(true)`.
    ///
    /// # Errors
    ///
    /// Returns an error if no slot can be found for the key. This indicates a broken internal
    /// invariant, because the density check guarantees at least one free slot.
    pub fn insert(&mut self, key: i64, value: Vec<u8>) -> Result<bool> {
        if self.density() >= MAX_DENSITY {
            return Ok(false);
        }
        self.place(key, value)?;
        Ok(true)
    }

    /// Inserts every entry of `entries`, in ascending key order.
    ///
    /// Unlike [`GappedNode::insert`], the density check is applied to the density the node
    /// *would* have after the whole batch. If that reaches `MAX_DENSITY`, nothing is inserted
    /// and `Ok(false)` is returned. This stricter check guarantees that a free slot exists for
    /// every entry. An empty batch returns `Ok(true)`.
    ///
    /// # Errors
    ///
    /// Returns an error if an entry cannot be placed, which indicates a broken internal
    /// invariant. Entries placed before the failing one remain in the node.
    pub fn insert_batch(&mut self, entries: &[(i64, Vec<u8>)]) -> Result<bool> {
        if entries.is_empty() {
            return Ok(true);
        }
        let density_after = (self.num_keys + entries.len()) as f64 / self.keys.len() as f64;
        if density_after >= MAX_DENSITY {
            return Ok(false);
        }
        // The node stores owned values, so copying the batch once is unavoidable.
        let mut sorted_entries = entries.to_vec();
        sorted_entries.sort_unstable_by_key(|(k, _)| *k);
        for (key, value) in sorted_entries {
            self.place(key, value)?;
        }
        Ok(true)
    }

    /// Returns a copy of the value stored under `key`, or `None` if the key is absent.
    ///
    /// With duplicate keys, the value of whichever matching slot the search finds is returned.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is kept for API stability.
    pub fn get(&self, key: i64) -> Result<Option<Vec<u8>>> {
        if self.num_keys == 0 {
            return Ok(None);
        }
        let predicted_pos = self.predict_slot(key);
        Ok(self
            .exponential_search(key, predicted_pos)
            .and_then(|pos| self.values[pos].clone()))
    }

    /// The model's predicted slot for `key`, clamped into the slot array.
    fn predict_slot(&self, key: i64) -> usize {
        self.model.predict(key).min(self.keys.len() - 1)
    }

    /// Stores `key`/`value` in its ordered slot, without any density check.
    ///
    /// It fills an empty slot directly, or shifts neighbours aside when the slot is occupied.
    fn place(&mut self, key: i64, value: Vec<u8>) -> Result<()> {
        let pos = self.find_insert_position(key);
        match self.keys.get(pos).copied() {
            None => Err(anyhow::anyhow!("Insert position out of bounds")),
            Some(None) => {
                self.keys[pos] = Some(key);
                self.values[pos] = Some(value);
                self.num_keys += 1;
                Ok(())
            }
            Some(Some(_)) => self.shift_and_insert(pos, key, value),
        }
    }

    /// Returns the slot at which `key` should be stored. An empty node always answers 0.
    fn find_insert_position(&self, key: i64) -> usize {
        if self.num_keys == 0 {
            return 0;
        }
        let (start, end) = self.search_window(key, self.predict_slot(key));
        self.binary_search_gap(start, end, key)
    }

    /// Returns the index of a slot holding `key`, or `None`.
    ///
    /// It narrows the search to the model-guided window around `predicted_pos`, then scans
    /// that window.
    fn exponential_search(&self, key: i64, predicted_pos: usize) -> Option<usize> {
        let (start, end) = self.search_window(key, predicted_pos);
        self.binary_search_exact(start, end, key)
    }

    /// Grows a window `[start, end)` around `predicted_pos` until it brackets `key`.
    ///
    /// The radius starts at 1 and doubles each round. The window brackets `key` when its first
    /// and last slots are both occupied and their keys surround `key`. It falls back to the
    /// whole slot array in three cases: the radius reaches the error bound (`max_error_bound`,
    /// at least 16), the window already spans every slot, or the radius exceeds the capacity.
    fn search_window(&self, key: i64, predicted_pos: usize) -> (usize, usize) {
        let len = self.keys.len();
        let max_radius = self.max_error_bound.max(16);
        let mut radius = 1;
        loop {
            let start = predicted_pos.saturating_sub(radius);
            let end = (predicted_pos + radius).min(len);
            let brackets_key = matches!(
                (self.get_key_at(start), self.get_key_at(end.saturating_sub(1))),
                (Some(lo), Some(hi)) if lo <= key && key <= hi
            );
            if brackets_key {
                return (start, end);
            }
            if radius >= max_radius || (start == 0 && end == len) {
                return (0, len);
            }
            radius *= 2;
            if radius > len {
                return (0, len);
            }
        }
    }

    /// Returns the slot index of `key` within `[start, end)`.
    ///
    /// It delegates to `simd_search::simd_search_i64` on that sub-slice and converts the
    /// window-relative index it returns into an absolute slot index.
    fn binary_search_exact(&self, start: usize, end: usize, key: i64) -> Option<usize> {
        simd_search::simd_search_i64(&self.keys[start..end], key).map(|pos| start + pos)
    }

    /// Linear scan of `[start, end)` for the slot where `key` should be inserted.
    ///
    /// Returns the first empty slot that comes after every smaller key seen so far and before
    /// the first key `>= key`. If no such gap exists, it returns the slot of the first key
    /// `>= key`, and the caller shifts that entry aside.
    ///
    /// If every occupied key in the range is smaller and no gap follows them, it returns the
    /// last slot of the range. `shift_and_insert` then places the key after that slot's
    /// occupant.
    fn binary_search_gap(&self, start: usize, end: usize, key: i64) -> usize {
        let mut gap_after_smaller: Option<usize> = None;
        for i in start..end {
            match self.keys[i] {
                Some(k) if k >= key => return gap_after_smaller.unwrap_or(i),
                // A smaller key after a gap means that gap sits too early for `key`.
                Some(_) => gap_after_smaller = None,
                None => {
                    if gap_after_smaller.is_none() {
                        gap_after_smaller = Some(i);
                    }
                }
            }
        }
        gap_after_smaller
            .unwrap_or_else(|| end.saturating_sub(1).min(self.keys.len().saturating_sub(1)))
    }

    /// Returns the empty slot nearest to `pos`, or `None` when every slot is occupied.
    ///
    /// For r = 0, 1, 2, … it checks `pos + r` before `pos - r`.
    fn find_nearest_gap(&self, pos: usize) -> Option<usize> {
        let len = self.keys.len();
        (0..len).find_map(|radius| {
            let right = pos + radius;
            if right < len && self.keys[right].is_none() {
                return Some(right);
            }
            pos.checked_sub(radius)
                .filter(|&left| self.keys[left].is_none())
        })
    }

    /// Returns the first empty slot, or the last slot index when every slot is occupied.
    // Not called in this file, hence the `allow` — NOTE(review): confirm it is still needed.
    #[allow(dead_code)]
    fn find_any_gap(&self) -> usize {
        self.keys
            .iter()
            .position(Option::is_none)
            .unwrap_or(self.keys.len() - 1)
    }

    /// Returns the key stored at `pos`, or `None` for an empty or out-of-range slot.
    fn get_key_at(&self, pos: usize) -> Option<i64> {
        self.keys.get(pos).copied().flatten()
    }

    /// Stores `key`/`value` next to the occupied slot `pos`, keeping keys ordered.
    ///
    /// The entries between `pos` and the nearest gap move by one slot to make room.
    ///
    /// Normally `pos` holds the first key `>= key`, so the new key goes just before it. The
    /// exception is the fallback of `binary_search_gap`, where `pos` holds a smaller key; then
    /// the new key goes just after it. Entries are moved with `rotate_*`, so no value is
    /// cloned.
    fn shift_and_insert(&mut self, pos: usize, key: i64, value: Vec<u8>) -> Result<()> {
        let gap = self
            .find_nearest_gap(pos)
            .ok_or_else(|| anyhow::anyhow!("No free slot to insert into"))?;
        let goes_after = self.keys[pos].is_some_and(|k| k < key);
        let target = if gap > pos {
            let target = if goes_after { pos + 1 } else { pos };
            // Move [target, gap) one slot right; the gap's `None` rotates into `target`.
            self.keys[target..=gap].rotate_right(1);
            self.values[target..=gap].rotate_right(1);
            target
        } else {
            // `pos` is occupied, so here the gap is strictly to its left (pos >= 1).
            let target = if goes_after { pos } else { pos - 1 };
            // Move (gap, target] one slot left; the gap's `None` rotates into `target`.
            self.keys[gap..=target].rotate_left(1);
            self.values[gap..=target].rotate_left(1);
            target
        };
        self.keys[target] = Some(key);
        self.values[target] = Some(value);
        self.num_keys += 1;
        Ok(())
    }

    /// Reports whether the model has drifted enough to be worth retraining.
    ///
    /// It is `true` when `LinearModel::max_error` over the current `(key, slot)` pairs exceeds
    /// the larger of 20% of the capacity and 50. It is always `false` with fewer than 10 keys.
    #[must_use]
    pub fn needs_retrain(&self) -> bool {
        if self.num_keys < 10 {
            return false;
        }
        let data = self.sorted_training_data();
        if data.is_empty() {
            return false;
        }
        let current_error = self.model.max_error(&data);
        let error_threshold = (self.keys.len() as f64 * 0.2) as usize;
        current_error > error_threshold.max(50)
    }

    /// Returns the `(key, slot)` pairs of every occupied slot, sorted by key. This is the
    /// model's training set.
    fn sorted_training_data(&self) -> Vec<(i64, usize)> {
        let mut data = self.keys_only();
        data.sort_unstable_by_key(|&(key, _)| key);
        data
    }

    /// Retrains the model on the current `(key, slot)` layout.
    ///
    /// It also sets the search-radius hint to the new model's max error, with a minimum of 4.
    /// This is a no-op for an empty node.
    ///
    /// # Errors
    ///
    /// Currently never fails; the `Result` is kept for API stability.
    pub fn retrain(&mut self) -> Result<()> {
        let data = self.sorted_training_data();
        if !data.is_empty() {
            self.model.train(&data);
            self.max_error_bound = self.model.max_error(&data).max(4);
        }
        Ok(())
    }

    /// Returns the `(key, slot)` pairs of all occupied slots, in slot order.
    #[must_use]
    pub fn keys_only(&self) -> Vec<(i64, usize)> {
        self.keys
            .iter()
            .enumerate()
            .filter_map(|(pos, key_opt)| key_opt.map(|key| (key, pos)))
            .collect()
    }

    /// Returns `(key, slot)` of the smallest stored key `>= search_key`, or `None` if every
    /// stored key is smaller.
    ///
    /// The search starts at the model's predicted slot and widens a window around it. A window
    /// candidate is accepted only when the nearest occupied slot before it holds a smaller key,
    /// or there is none. Keys are ordered by slot, so that proves no earlier key qualifies.
    /// If the window reaches the error bound without a proven answer, a full scan decides.
    #[must_use]
    pub fn lower_bound_position(&self, search_key: i64) -> Option<(i64, usize)> {
        if self.num_keys == 0 {
            return None;
        }
        let len = self.keys.len();
        let predicted_pos = self.predict_slot(search_key);
        let max_radius = self.max_error_bound.max(16);
        let mut radius = 1;
        loop {
            let start = predicted_pos.saturating_sub(radius);
            let end = (predicted_pos + radius).min(len);
            if let Some((key, pos)) = self.first_at_least(start, end, search_key) {
                if self.last_key_before(pos).is_none_or(|prev| prev < search_key) {
                    return Some((key, pos));
                }
            }
            if radius >= max_radius || (start == 0 && end == len) {
                break;
            }
            radius *= 2;
        }
        // The model missed by more than the error bound: fall back to a full scan.
        self.first_at_least(0, len, search_key)
    }

    /// Returns the first occupied slot in `[start, end)` whose key is `>= search_key`, as
    /// `(key, slot)`.
    fn first_at_least(&self, start: usize, end: usize, search_key: i64) -> Option<(i64, usize)> {
        (start..end).find_map(|pos| {
            self.keys[pos]
                .filter(|&key| key >= search_key)
                .map(|key| (key, pos))
        })
    }

    /// Returns the key in the nearest occupied slot strictly before `pos`, if any.
    fn last_key_before(&self, pos: usize) -> Option<i64> {
        self.keys[..pos].iter().rev().flatten().next().copied()
    }

    /// Returns the fraction of slots that are occupied (`num_keys / capacity`).
    #[must_use]
    pub fn density(&self) -> f64 {
        self.num_keys as f64 / self.keys.len() as f64
    }

    /// Returns the number of stored entries.
    #[must_use]
    pub const fn num_keys(&self) -> usize {
        self.num_keys
    }

    /// Returns the total number of slots, occupied or not.
    #[must_use]
    pub const fn capacity(&self) -> usize {
        self.keys.len()
    }

    /// Reports whether the density has reached `MAX_DENSITY`. This is the same threshold at
    /// which `insert` starts returning `Ok(false)`.
    #[must_use]
    pub fn should_split(&self) -> bool {
        self.density() >= MAX_DENSITY
    }

    /// Splits a full node at its median entry.
    ///
    /// `self` keeps the lower half of the entries and the returned node holds the upper half.
    /// The returned key is the smallest key of the upper half. Both halves are rebuilt with an
    /// expansion factor of at least 1.0 and untrained models.
    ///
    /// # Errors
    ///
    /// Returns an error, leaving `self` unchanged, if the density is below `MAX_DENSITY` or the
    /// node holds no keys.
    pub fn split(&mut self) -> Result<(i64, Self)> {
        if !self.should_split() {
            return Err(anyhow::anyhow!(
                "Node doesn't need splitting (density < MAX_DENSITY)"
            ));
        }
        let mut left_pairs = self.pairs();
        if left_pairs.is_empty() {
            return Err(anyhow::anyhow!("Cannot split empty node"));
        }
        let split_idx = left_pairs.len() / 2;
        // `split_idx < len` for a non-empty vector, so the upper half is never empty.
        let right_pairs = left_pairs.split_off(split_idx);
        let split_key = right_pairs[0].0;
        let expansion = self.expansion_factor.max(1.0);
        let right = Self::from_sorted_pairs(right_pairs, expansion);
        *self = Self::from_sorted_pairs(left_pairs, expansion);
        Ok((split_key, right))
    }

    /// Builds a node sized for `pairs`, with the entries packed into the leading slots.
    ///
    /// `pairs` must be sorted by key. This produces the same key layout as inserting the pairs
    /// one by one in ascending order, without a scan per insert. `expansion_factor` must be
    /// non-negative so that the capacity covers every pair.
    fn from_sorted_pairs(pairs: Vec<(i64, Vec<u8>)>, expansion_factor: f64) -> Self {
        let mut node = Self::new(pairs.len(), expansion_factor);
        for (slot, (key, value)) in pairs.into_iter().enumerate() {
            node.keys[slot] = Some(key);
            node.values[slot] = Some(value);
            node.num_keys += 1;
        }
        node
    }

    /// Consumes the node and returns its entries sorted by key. The relative order of
    /// duplicate keys is unspecified.
    #[must_use]
    pub fn into_pairs(self) -> Vec<(i64, Vec<u8>)> {
        let mut pairs: Vec<(i64, Vec<u8>)> = self
            .keys
            .into_iter()
            .zip(self.values)
            .filter_map(|(key, value)| key.zip(value))
            .collect();
        pairs.sort_unstable_by_key(|&(key, _)| key);
        pairs
    }

    /// Returns copies of the node's entries sorted by key. The relative order of duplicate
    /// keys is unspecified.
    #[must_use]
    pub fn pairs(&self) -> Vec<(i64, Vec<u8>)> {
        let mut pairs: Vec<(i64, Vec<u8>)> = self
            .keys
            .iter()
            .zip(&self.values)
            .filter_map(|(key, value)| {
                (*key)
                    .zip(value.as_ref())
                    .map(|(key, value)| (key, value.clone()))
            })
            .collect();
        pairs.sort_unstable_by_key(|&(key, _)| key);
        pairs
    }
}
impl fmt::Display for GappedNode {
    /// Writes a one-line summary: occupied/total slots, the fill percentage to one decimal
    /// place, and the model's own `Display` output.
    fn fmt(&self, out: &mut fmt::Formatter<'_>) -> fmt::Result {
        let occupied = self.num_keys;
        let total_slots = self.capacity();
        let fill_percent = self.density() * 100.0;
        write!(
            out,
            "GappedNode(keys={}/{}, density={:.1}%, model={})",
            occupied, total_slots, fill_percent, self.model
        )
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh node allocates `expected * (1 + factor)` slots and holds nothing.
    #[test]
    fn test_new_node() {
        let empty = GappedNode::new(100, 1.0);
        assert_eq!(empty.capacity(), 200);
        assert_eq!(empty.num_keys(), 0);
        assert_eq!(empty.density(), 0.0);
    }

    /// Ascending inserts all succeed and leave plenty of free slots.
    #[test]
    fn test_insert_sequential() {
        let mut node = GappedNode::new(10, 1.0);
        (0..10_i64).for_each(|i| {
            let accepted = node.insert(i * 10, vec![i as u8]).unwrap();
            assert!(accepted);
        });
        assert_eq!(node.num_keys(), 10);
        assert!(node.density() < 0.6);
    }

    /// Keys arriving in arbitrary order are all retrievable after a retrain.
    #[test]
    fn test_insert_out_of_order() {
        let mut node = GappedNode::new(10, 1.0);
        let arrivals: [(i64, u8); 5] = [(50, 5), (10, 1), (30, 3), (20, 2), (40, 4)];
        for &(key, byte) in &arrivals {
            node.insert(key, vec![byte]).unwrap();
        }
        assert_eq!(node.num_keys(), 5);
        node.retrain().unwrap();
        for (key, byte) in [(10_i64, 1_u8), (30, 3), (50, 5)] {
            assert_eq!(node.get(key).unwrap(), Some(vec![byte]));
        }
        assert_eq!(node.get(99).unwrap(), None);
    }

    /// Keys below, between and above the stored ones are reported as missing.
    #[test]
    fn test_get_nonexistent() {
        let mut node = GappedNode::new(10, 1.0);
        node.insert(10, vec![1]).unwrap();
        node.insert(20, vec![2]).unwrap();
        node.retrain().unwrap();
        for missing in [5_i64, 15, 25] {
            assert_eq!(node.get(missing).unwrap(), None);
        }
    }

    /// With no spare capacity, the insert after the node fills up is refused.
    #[test]
    fn test_density_threshold() {
        let mut full = GappedNode::new(10, 0.0);
        for key in 0..10_i64 {
            assert!(full.insert(key, vec![key as u8]).unwrap());
        }
        let accepted = full.insert(99, vec![99]).unwrap();
        assert!(!accepted);
    }

    /// Every key inserted before a retrain is still found afterwards.
    #[test]
    fn test_retrain_improves_accuracy() {
        let mut node = GappedNode::new(100, 1.0);
        let expected: Vec<(i64, u8)> = (0..50_i64).map(|i| (i * 2, i as u8)).collect();
        for &(key, byte) in &expected {
            node.insert(key, vec![byte]).unwrap();
        }
        node.retrain().unwrap();
        for &(key, byte) in &expected {
            assert_eq!(node.get(key).unwrap(), Some(vec![byte]));
        }
    }

    /// Consuming the node yields its entries sorted by key.
    #[test]
    fn test_into_pairs() {
        let mut node = GappedNode::new(10, 1.0);
        for (key, byte) in [(30_i64, 3_u8), (10, 1), (20, 2)] {
            node.insert(key, vec![byte]).unwrap();
        }
        let pairs = node.into_pairs();
        assert_eq!(pairs, vec![(10, vec![1]), (20, vec![2]), (30, vec![3])]);
    }

    /// A thousand inserts stay sparse, and a sample of them is found after retraining.
    #[test]
    fn test_large_scale() {
        let mut node = GappedNode::new(1000, 1.0);
        let key_for = |i: i64| i * 7 % 10000;
        for i in 0..1000_i64 {
            node.insert(key_for(i), vec![(i % 256) as u8]).unwrap();
        }
        assert_eq!(node.num_keys(), 1000);
        assert!(node.density() < 0.6);
        node.retrain().unwrap();
        for key in (0..1000_i64).step_by(10).map(key_for) {
            assert!(
                node.get(key).unwrap().is_some(),
                "Failed to find key={}",
                key
            );
        }
    }

    /// Re-inserting an existing key stores a second entry.
    #[test]
    fn test_duplicate_inserts() {
        let mut node = GappedNode::new(10, 1.0);
        for byte in [1_u8, 2] {
            node.insert(10, vec![byte]).unwrap();
        }
        assert_eq!(node.num_keys(), 2);
    }

    /// Splitting a full node divides its keys at the median into two sparse halves.
    #[test]
    fn test_node_split() {
        let mut node = GappedNode::new(10, 0.0);
        for i in 0..10_i64 {
            node.insert(i * 10, vec![i as u8]).unwrap();
        }
        assert!(node.should_split());
        assert_eq!(node.num_keys(), 10);

        let (split_key, right) = node.split().unwrap();
        assert_eq!(split_key, 50);

        assert_eq!(node.num_keys(), 5);
        for key in [0_i64, 10, 20, 30, 40] {
            assert!(node.get(key).unwrap().is_some());
        }
        assert!(node.get(50).unwrap().is_none());

        assert_eq!(right.num_keys(), 5);
        for key in [50_i64, 60, 70, 80, 90] {
            assert!(right.get(key).unwrap().is_some());
        }
        assert!(right.get(40).unwrap().is_none());

        assert!(node.density() < 0.6);
        assert!(right.density() < 0.6);
    }

    /// Capacity scales with the expansion factor.
    #[test]
    fn test_expansion_factors() {
        for (factor, expected_capacity) in [(2.0, 300_usize), (0.5, 150)] {
            assert_eq!(GappedNode::new(100, factor).capacity(), expected_capacity);
        }
    }
}