use std::collections::HashMap;
use std::io::{Read, Write};
use std::sync::Arc;
use crate::error::{LaurusError, Result};
use crate::lexical::core::field::FieldValue;
use crate::storage::Storage;
/// Suffix appended to a segment name to form the name of its doc-values file.
const DOC_VALUES_EXTENSION: &str = ".dv";
/// Doc values of a single field: a lookup from document ID to the value
/// stored for that document.
#[derive(Debug, Clone)]
pub struct FieldDocValues {
/// Name of the field these values belong to.
pub field_name: String,
/// Value recorded per document ID; documents without a value have no entry.
values: ahash::AHashMap<u64, FieldValue>,
}
impl FieldDocValues {
    /// Builds an empty value table for the field called `field_name`.
    pub fn new(field_name: String) -> Self {
        let values = ahash::AHashMap::new();
        Self { field_name, values }
    }

    /// Records `value` for `doc_id`, replacing any value stored earlier for
    /// the same document.
    pub fn set(&mut self, doc_id: u64, value: FieldValue) {
        // The previously stored value (if any) is intentionally dropped.
        let _replaced = self.values.insert(doc_id, value);
    }

    /// Looks up the value stored for `doc_id`; `None` when the document has
    /// no entry.
    pub fn get(&self, doc_id: u64) -> Option<&FieldValue> {
        let Self { values, .. } = self;
        values.get(&doc_id)
    }

    /// Number of documents that currently have a value.
    pub fn len(&self) -> usize {
        let Self { values, .. } = self;
        values.len()
    }

    /// `true` when no document has a value yet.
    pub fn is_empty(&self) -> bool {
        let Self { values, .. } = self;
        values.is_empty()
    }
}
/// Collects doc values per field in memory and writes them to a single file
/// in the configured storage.
pub struct DocValuesWriter {
// Storage backend the doc-values file is created in.
storage: Arc<dyn Storage>,
// Segment name; combined with `DOC_VALUES_EXTENSION` to form the file name.
segment_name: String,
// Values accumulated so far, keyed by field name.
fields: HashMap<String, FieldDocValues>,
}
impl DocValuesWriter {
pub fn new(storage: Arc<dyn Storage>, segment_name: String) -> Self {
DocValuesWriter {
storage,
segment_name,
fields: HashMap::new(),
}
}
pub fn add_value(&mut self, doc_id: u64, field_name: &str, value: FieldValue) {
self.fields
.entry(field_name.to_string())
.or_insert_with(|| FieldDocValues::new(field_name.to_string()))
.set(doc_id, value);
}
pub fn write(&self) -> Result<()> {
let dv_filename = format!("{}{}", self.segment_name, DOC_VALUES_EXTENSION);
let mut output = self.storage.create_output(&dv_filename)?;
output.write_all(b"DVFF")?; output.write_all(&[1u8, 0u8])?;
let num_fields = self.fields.len() as u32;
output.write_all(&num_fields.to_le_bytes())?;
for (field_name, field_dv) in &self.fields {
let name_bytes = field_name.as_bytes();
output.write_all(&(name_bytes.len() as u32).to_le_bytes())?;
output.write_all(name_bytes)?;
let num_values = field_dv.values.len() as u64;
output.write_all(&num_values.to_le_bytes())?;
let values_vec: Vec<(u64, FieldValue)> = field_dv
.values
.iter()
.map(|(k, v)| (*k, v.clone()))
.collect();
let serialized = rkyv::to_bytes::<rkyv::rancor::Error>(&values_vec)
.map_err(|e| LaurusError::Index(format!("Failed to serialize DocValues: {}", e)))?;
output.write_all(&(serialized.len() as u64).to_le_bytes())?;
output.write_all(&serialized)?;
}
output.flush()?;
Ok(())
}
}
/// Read-only view of the doc values of one segment, fully loaded into memory.
#[derive(Debug)]
pub struct DocValuesReader {
// Loaded values, keyed by field name.
fields: HashMap<String, FieldDocValues>,
}
impl DocValuesReader {
pub fn load(storage: Arc<dyn Storage>, segment_name: &str) -> Result<Self> {
let dv_filename = format!("{}{}", segment_name, DOC_VALUES_EXTENSION);
let mut input = match storage.open_input(&dv_filename) {
Ok(input) => input,
Err(_) => {
return Ok(DocValuesReader {
fields: HashMap::new(),
});
}
};
let mut magic = [0u8; 4];
input.read_exact(&mut magic)?;
if &magic != b"DVFF" {
return Err(LaurusError::Index(
"Invalid DocValues file format".to_string(),
));
}
let mut version = [0u8; 2];
input.read_exact(&mut version)?;
if version[0] != 1 {
return Err(LaurusError::Index(format!(
"Unsupported DocValues version: {}.{}",
version[0], version[1]
)));
}
let mut num_fields_bytes = [0u8; 4];
input.read_exact(&mut num_fields_bytes)?;
let num_fields = u32::from_le_bytes(num_fields_bytes);
let mut fields = HashMap::new();
for _ in 0..num_fields {
let mut name_len_bytes = [0u8; 4];
input.read_exact(&mut name_len_bytes)?;
let name_len = u32::from_le_bytes(name_len_bytes) as usize;
let mut name_bytes = vec![0u8; name_len];
input.read_exact(&mut name_bytes)?;
let field_name = String::from_utf8(name_bytes)
.map_err(|e| LaurusError::Index(format!("Invalid field name: {}", e)))?;
let mut num_values_bytes = [0u8; 8];
input.read_exact(&mut num_values_bytes)?;
let _num_values = u64::from_le_bytes(num_values_bytes);
let mut data_len_bytes = [0u8; 8];
input.read_exact(&mut data_len_bytes)?;
let data_len = u64::from_le_bytes(data_len_bytes) as usize;
let mut data = vec![0u8; data_len];
input.read_exact(&mut data)?;
let values_vec: Vec<(u64, FieldValue)> =
rkyv::from_bytes::<Vec<(u64, FieldValue)>, rkyv::rancor::Error>(&data).map_err(
|e| LaurusError::Index(format!("Failed to deserialize DocValues: {}", e)),
)?;
let values = values_vec.into_iter().collect();
fields.insert(field_name.clone(), FieldDocValues { field_name, values });
}
Ok(DocValuesReader { fields })
}
pub fn get_field(&self, field_name: &str) -> Option<&FieldDocValues> {
self.fields.get(field_name)
}
pub fn get_value(&self, field_name: &str, doc_id: u64) -> Option<&FieldValue> {
self.fields.get(field_name).and_then(|dv| dv.get(doc_id))
}
pub fn has_field(&self, field_name: &str) -> bool {
self.fields.contains_key(field_name)
}
pub fn field_names(&self) -> Vec<String> {
self.fields.keys().cloned().collect()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::data::DataValue;
    use crate::storage::memory::MemoryStorage;
    use crate::storage::memory::MemoryStorageConfig;

    /// Values set on a `FieldDocValues` are returned for the same document
    /// IDs, and unset IDs yield `None`.
    #[test]
    fn test_field_doc_values() {
        let mut column = FieldDocValues::new("test_field".to_string());
        column.set(0, DataValue::Int64(100));
        column.set(1, DataValue::Text("hello".to_string()));
        column.set(5, DataValue::Float64(3.15));

        let stored = [
            (0, DataValue::Int64(100)),
            (1, DataValue::Text("hello".to_string())),
        ];
        for (doc_id, want) in &stored {
            assert_eq!(column.get(*doc_id), Some(want));
        }
        assert_eq!(column.get(2), None);
        assert_eq!(column.get(5), Some(&DataValue::Float64(3.15)));
    }

    /// Values written by `DocValuesWriter` are read back unchanged by
    /// `DocValuesReader` from the same storage and segment.
    #[test]
    fn test_doc_values_write_read() {
        let storage = Arc::new(MemoryStorage::new(MemoryStorageConfig::default()));
        let segment_name = "segment_0".to_string();

        // Write phase: the writer is dropped at the end of this scope.
        {
            let mut writer = DocValuesWriter::new(storage.clone(), segment_name.clone());
            writer.add_value(0, "year", DataValue::Int64(2023));
            writer.add_value(1, "year", DataValue::Int64(2024));
            writer.add_value(0, "rating", DataValue::Float64(4.5));
            writer.add_value(1, "rating", DataValue::Float64(5.0));
            writer.write().unwrap();
        }

        // Read phase.
        {
            let reader = DocValuesReader::load(storage.clone(), &segment_name).unwrap();

            for present in ["year", "rating"] {
                assert!(reader.has_field(present));
            }
            assert!(!reader.has_field("unknown"));

            let expected = [
                ("year", 0, DataValue::Int64(2023)),
                ("year", 1, DataValue::Int64(2024)),
                ("rating", 0, DataValue::Float64(4.5)),
                ("rating", 1, DataValue::Float64(5.0)),
            ];
            for (field, doc_id, want) in &expected {
                assert_eq!(reader.get_value(field, *doc_id), Some(want));
            }
        }
    }
}