use super::*;
/// Parameters that control how a `Collection` builds and searches its
/// layered graph index.
///
/// Field order is intentionally left as declared: serde's derived
/// (de)serialization and the derived `Debug` output follow declaration order.
#[pyclass(module = "oasysdb.collection")]
#[derive(Debug, Serialize, Deserialize, Clone, Copy)]
pub struct Config {
/// Build-time parameter handed to index construction through the config
/// (default 40). NOTE(review): presumably the candidate-list size used while
/// linking new nodes — confirm against `IndexConstruction`.
#[pyo3(get, set)]
pub ef_construction: usize,
/// Value assigned to the search's `ef` while searching the base layer
/// (default 15).
#[pyo3(get, set)]
pub ef_search: usize,
/// Layer-size multiplier used by `Collection::build`: each layer above holds
/// `len * ml` of the vectors of the layer below (default 0.3).
/// NOTE(review): the layering loop assumes 0 < ml < 1 — confirm before
/// allowing other values.
#[pyo3(get, set)]
pub ml: f32,
}
#[pymethods]
impl Config {
    /// Creates a config from explicit index parameters.
    #[new]
    pub fn new(ef_construction: usize, ef_search: usize, ml: f32) -> Self {
        Self {
            ef_construction,
            ef_search,
            ml,
        }
    }

    /// Returns the default configuration. Exposed as a static method
    /// because Python cannot call the `Default` trait directly.
    #[staticmethod]
    fn create_default() -> Self {
        <Self as Default>::default()
    }

    /// Python `repr()`: the derived `Debug` representation.
    fn __repr__(&self) -> String {
        format!("{self:?}")
    }
}
impl Default for Config {
fn default() -> Self {
Self { ef_construction: 40, ef_search: 15, ml: 0.3 }
}
}
/// A vector collection: records (vector + metadata) keyed by `VectorID`,
/// plus a layered graph index over the vectors.
///
/// Field order is intentionally left as declared: serde's derived
/// (de)serialization and the derived `Debug` output follow declaration order.
#[pyclass(module = "oasysdb.collection")]
#[derive(Debug, Serialize, Deserialize, Clone)]
pub struct Collection {
/// Index parameters this collection was created with.
#[pyo3(get)]
pub config: Config,
/// Metadata of each live record, keyed by ID.
data: HashMap<VectorID, Metadata>,
/// Vector of each live record, keyed by ID.
vectors: HashMap<VectorID, Vector>,
/// One entry per ID ever issued; an ID is its index in this list. Deleted
/// IDs are overwritten with `INVALID`, so IDs are never reused.
slots: Vec<VectorID>,
/// Base-layer graph nodes, indexed by `id.0`.
base_layer: Vec<BaseNode>,
/// `upper_layers[k - 1]` is graph layer `k`. Each layer is a snapshot of a
/// prefix of the base layer taken during `build`, so it only has nodes for
/// the lowest IDs.
upper_layers: Vec<Vec<UpperNode>>,
/// Number of live records.
count: usize,
/// Length every vector must have; 0 until set by the first insert or by
/// `set_dimension`.
dimension: usize,
}
impl Index<&VectorID> for Collection {
    type Output = Vector;

    /// Returns the vector stored under `id`.
    ///
    /// # Panics
    /// Panics if no vector is stored under `id` (delegates to `HashMap`
    /// indexing).
    fn index(&self, id: &VectorID) -> &Vector {
        Index::index(&self.vectors, id)
    }
}
#[pymethods]
impl Collection {
#[new]
pub fn new(config: &Config) -> Self {
Self {
config: *config,
count: 0,
dimension: 0,
data: HashMap::new(),
vectors: HashMap::new(),
slots: vec![],
base_layer: vec![],
upper_layers: vec![],
}
}
#[staticmethod]
fn from_records(
config: &Config,
records: Vec<Record>,
) -> Result<Self, Error> {
Self::build(config, &records)
}
pub fn insert(&mut self, record: &Record) -> Result<(), Error> {
if self.slots.len() == u32::MAX as usize {
return Err(Error::collection_limit());
}
if self.vectors.is_empty() && self.dimension == 0 {
self.dimension = record.vector.len();
} else if record.vector.len() != self.dimension {
let len = record.vector.len();
let err = Error::invalid_dimension(len, self.dimension);
return Err(err);
}
let id: VectorID = self.slots.len().into();
self.vectors.insert(id, record.vector.clone());
self.data.insert(id, record.data.clone());
self.slots.push(id);
self.count += 1;
self.insert_to_layers(&id);
Ok(())
}
pub fn delete(&mut self, id: &VectorID) -> Result<(), Error> {
if !self.contains(id) {
return Err(Error::record_not_found());
}
self.delete_from_layers(id);
self.vectors.remove(id);
self.data.remove(id);
self.slots[id.0 as usize] = INVALID;
self.count -= 1;
Ok(())
}
pub fn list(&self) -> Result<HashMap<VectorID, Record>, Error> {
if self.vectors.is_empty() {
return Ok(HashMap::new());
}
let mapper = |(id, vector): (&VectorID, &Vector)| {
let data = self.data[id].clone();
let record = Record::new(vector, &data);
(*id, record)
};
let records = self.vectors.par_iter().map(mapper).collect();
Ok(records)
}
pub fn get(&self, id: &VectorID) -> Result<Record, Error> {
if !self.contains(id) {
return Err(Error::record_not_found());
}
let vector = self.vectors[id].clone();
let data = self.data[id].clone();
Ok(Record::new(&vector, &data))
}
pub fn update(
&mut self,
id: &VectorID,
record: &Record,
) -> Result<(), Error> {
if !self.contains(id) {
return Err(Error::record_not_found());
}
self.validate_dimension(&record.vector)?;
self.delete_from_layers(id);
self.vectors.insert(*id, record.vector.clone());
self.data.insert(*id, record.data.clone());
self.insert_to_layers(id);
Ok(())
}
pub fn search(
&self,
vector: &Vector,
n: usize,
) -> Result<Vec<SearchResult>, Error> {
let mut search = Search::default();
if self.vectors.is_empty() {
return Ok(vec![]);
}
self.validate_dimension(vector)?;
let slots_iter = self.slots.as_slice().into_par_iter();
let vector_id = match slots_iter.find_first(|id| id.is_valid()) {
Some(id) => id,
None => return Err("Unable to initiate search.".into()),
};
search.visited.resize_capacity(self.vectors.len());
search.push(vector_id, vector, &self.vectors);
for layer in LayerID(self.upper_layers.len()).descend() {
search.ef = if layer.is_zero() { self.config.ef_search } else { 5 };
if layer.0 == 0 {
let layer = self.base_layer.as_slice();
search.search(layer, vector, &self.vectors, M * 2);
} else {
let layer = self.upper_layers[layer.0 - 1].as_slice();
search.search(layer, vector, &self.vectors, M);
}
if !layer.is_zero() {
search.cull();
}
}
let map_result = |candidate: Candidate| {
let id = candidate.vector_id.0;
let distance = candidate.distance.0;
let data = self.data[&candidate.vector_id].clone();
SearchResult { id, distance, data }
};
Ok(search.iter().map(map_result).take(n).collect())
}
pub fn true_search(
&self,
vector: &Vector,
n: usize,
) -> Result<Vec<SearchResult>, Error> {
let mut nearest = Vec::with_capacity(self.vectors.len());
self.validate_dimension(vector)?;
for (id, vec) in self.vectors.iter() {
let distance = vector.distance(vec);
let data = self.data[id].clone();
let res = SearchResult { id: id.0, distance, data };
nearest.push(res);
}
nearest.sort_by(|a, b| a.distance.partial_cmp(&b.distance).unwrap());
nearest.truncate(n);
Ok(nearest)
}
pub fn dimension(&self) -> usize {
self.dimension
}
pub fn set_dimension(&mut self, dimension: usize) -> Result<(), Error> {
if !self.vectors.is_empty() {
return Err("The collection must be empty.".into());
}
self.dimension = dimension;
Ok(())
}
pub fn len(&self) -> usize {
self.count
}
pub fn is_empty(&self) -> bool {
self.count == 0
}
pub fn contains(&self, id: &VectorID) -> bool {
self.vectors.contains_key(id)
}
fn __len__(&self) -> usize {
self.len()
}
}
impl Collection {
    /// Builds a collection and its layered index from `records` in bulk.
    ///
    /// Record `i` gets ID `i`. All vectors must share the first record's
    /// length. An empty `records` yields an empty collection.
    ///
    /// # Errors
    /// Fails when `records.len()` reaches `u32::MAX`, or when the records'
    /// vectors have inconsistent dimensions.
    pub fn build(config: &Config, records: &[Record]) -> Result<Self, Error> {
        if records.is_empty() {
            return Ok(Self::new(config));
        }

        if records.len() >= u32::MAX as usize {
            let message = format!(
                "The collection record limit is {}. Given: {}",
                u32::MAX,
                records.len()
            );
            return Err(message.into());
        }

        let dimension = records[0].vector.len();
        if records.par_iter().any(|i| i.vector.len() != dimension) {
            let message = format!(
                "The vector dimension is inconsistent. Expected: {}.",
                dimension
            );
            return Err(message.into());
        }

        // Compute layer sizes bottom-up. Each layer above keeps `len * ml` of
        // the one below. `layers` holds (band size, cumulative count) pairs.
        let mut len = records.len();
        let mut layers = Vec::new();
        loop {
            let next = (len as f32 * config.ml) as usize;
            // `next >= len` happens when `ml >= 1.0` (ml is settable from
            // Python). Without this guard, `len` never shrinks (ml == 1 loops
            // forever) or `len - next` underflows (ml > 1).
            if next < M || next >= len {
                break;
            }
            layers.push((len - next, len));
            len = next;
        }

        layers.push((len, len));
        layers.reverse();

        let num_layers = layers.len();
        let top_layer = LayerID(num_layers - 1);

        let vectors = records
            .par_iter()
            .enumerate()
            .map(|(i, item)| (i.into(), item.vector.clone()))
            .collect::<HashMap<VectorID, Vector>>();

        // Assign each band of IDs to the layer it is inserted at, top layer
        // first. ID 0 is excluded from every range (ranges start at 1).
        let mut ranges = Vec::with_capacity(num_layers);
        for (i, (size, cumulative)) in layers.into_iter().enumerate() {
            let start = cumulative - size;
            let layer_id = LayerID(num_layers - i - 1);
            let value = max(start, 1)..cumulative;
            ranges.push((layer_id, value));
        }

        let search_pool = SearchPool::new(vectors.len());
        let mut upper_layers = vec![vec![]; top_layer.0];
        let base_layer = vectors
            .par_iter()
            .map(|_| RwLock::new(BaseNode::default()))
            .collect::<Vec<_>>();

        let state = IndexConstruction {
            base_layer: &base_layer,
            search_pool,
            top_layer,
            vectors: &vectors,
            config,
        };

        for (layer, range) in ranges {
            let end = range.end;
            range.into_par_iter().for_each(|i: usize| {
                state.insert(&i.into(), &layer, &upper_layers)
            });

            // Snapshot the first `end` base nodes as this upper layer.
            if !layer.is_zero() {
                (&state.base_layer[..end])
                    .into_par_iter()
                    .map(|zero| UpperNode::from_zero(&zero.read()))
                    .collect_into_vec(&mut upper_layers[layer.0 - 1]);
            }
        }

        let data = records
            .iter()
            .enumerate()
            .map(|(i, item)| (i.into(), item.data.clone()))
            .collect();

        let base_iter = base_layer.into_par_iter();
        let base_layer = base_iter.map(|node| node.into_inner()).collect();
        let slots = (0..vectors.len()).map(|i| i.into()).collect();

        Ok(Self {
            data,
            vectors,
            base_layer,
            upper_layers,
            slots,
            dimension,
            config: *config,
            count: records.len(),
        })
    }

    /// Checks that `vector` has the collection's dimension.
    ///
    /// # Errors
    /// Returns `Error::invalid_dimension(found, expected)` on a mismatch.
    fn validate_dimension(&self, vector: &Vector) -> Result<(), Error> {
        let found = vector.len();
        let expected = self.dimension;
        if found != expected {
            Err(Error::invalid_dimension(found, expected))
        } else {
            Ok(())
        }
    }

    /// Links vector `id` (already stored in `self.vectors`) into the graph.
    ///
    /// Upper layers are only read here, never extended, so vectors added
    /// after `build` have nodes in the base layer only.
    fn insert_to_layers(&mut self, id: &VectorID) {
        // `insert` passes a fresh ID (one past the last node), while `update`
        // re-links an existing one. Grow the base layer only when the ID has
        // no node yet, so updates don't leave orphan nodes behind.
        let index = id.0 as usize;
        if index >= self.base_layer.len() {
            self.base_layer.resize(index + 1, BaseNode::default());
        }

        // Index construction works on lock-wrapped nodes, so wrap a copy of
        // the base layer and write it back afterwards.
        let base_layer = self
            .base_layer
            .par_iter()
            .map(|node| RwLock::new(*node))
            .collect::<Vec<_>>();

        let top_layer = LayerID(self.upper_layers.len());
        let state = IndexConstruction {
            base_layer: base_layer.as_slice(),
            // IDs are slot indices. After deletions they can exceed the
            // number of live vectors, so size the pool by the slot count.
            search_pool: SearchPool::new(self.slots.len()),
            top_layer,
            vectors: &self.vectors,
            config: &self.config,
        };

        state.insert(id, &top_layer, &self.upper_layers);

        let iter = state.base_layer.into_par_iter();
        self.base_layer = iter.map(|node| *node.read()).collect();
    }

    /// Overwrites occurrences of `id` with `INVALID` in the graph node stored
    /// at `id`, in the base layer and in every upper layer that has one.
    ///
    /// NOTE(review): this scans only node `id`'s own neighbour list for `id`.
    /// Links to `id` held by *other* nodes are not touched here. Confirm that
    /// the search tolerates links to removed vectors.
    fn delete_from_layers(&mut self, id: &VectorID) {
        let index = id.0 as usize;

        if let Some(base_node) = self.base_layer.get_mut(index) {
            let position = base_node.par_iter().position_first(|x| *x == *id);
            if let Some(position) = position {
                base_node.set(position, &INVALID);
            }
        }

        // Upper layers are prefix snapshots taken in `build`, so IDs beyond a
        // layer's length have no node there. Skip instead of indexing out of
        // bounds.
        for upper_layer in self.upper_layers.iter_mut() {
            if let Some(node) = upper_layer.get_mut(index) {
                let position = node.0.par_iter().position_first(|x| *x == *id);
                if let Some(position) = position {
                    node.set(position, &INVALID);
                }
            }
        }
    }
}
/// A single entry of a collection: a vector plus its metadata.
#[pyclass(module = "oasysdb.collection")]
#[derive(Serialize, Deserialize, Clone, Debug)]
pub struct Record {
/// The record's vector (readable and writable from Python).
#[pyo3(get, set)]
pub vector: Vector,
/// Metadata attached to the vector (read-only from Python). The Python
/// constructor builds it from an arbitrary Python object.
#[pyo3(get)]
pub data: Metadata,
}
#[pymethods]
impl Record {
#[new]
fn py_new(vector: Vec<f32>, data: &PyAny) -> Self {
let vector = Vector::from(vector);
let data = Metadata::from(data);
Self::new(&vector, &data)
}
#[staticmethod]
pub fn random(dimension: usize) -> Self {
let vector = Vector::random(dimension);
let data = random::<usize>().into();
Self::new(&vector, &data)
}
#[staticmethod]
pub fn many_random(dimension: usize, len: usize) -> Vec<Self> {
(0..len).map(|_| Self::random(dimension)).collect()
}
fn __repr__(&self) -> String {
format!("{:?}", self)
}
}
impl Record {
pub fn new(vector: &Vector, data: &Metadata) -> Self {
Self { vector: vector.clone(), data: data.clone() }
}
}
/// One hit returned by `Collection::search` or `Collection::true_search`.
#[pyclass(module = "oasysdb.collection")]
#[derive(Serialize, Deserialize, Debug)]
pub struct SearchResult {
/// ID of the matching record.
#[pyo3(get)]
pub id: u32,
/// Distance between the query and the matching record's vector.
#[pyo3(get)]
pub distance: f32,
/// Metadata of the matching record.
#[pyo3(get)]
pub data: Metadata,
}
#[pymethods]
impl SearchResult {
    /// Python `repr()`: the derived `Debug` representation.
    fn __repr__(&self) -> String {
        let text = format!("{self:?}");
        text
    }
}