use arrow::array::{Float64Array, UInt32Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow_array::{Array as ArrowArray, FixedSizeListArray, RecordBatch};
use log::{debug, info, trace};
use smartcore::linalg::basic::arrays::Array;
use smartcore::linalg::basic::matrix::DenseMatrix;
use sprs::{CsMat, TriMat};
use std::fs;
use std::path::{Path, PathBuf};
use std::sync::Arc;
use crate::metadata::GeneMetadata;
use crate::{StorageError, StorageResult};
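/// Backend abstraction for persisting the pipeline's artifacts: dense and
/// sparse matrices (as Arrow RecordBatches), eigenvalue vectors, cluster
/// assignments, centroid data, and the JSON `GeneMetadata` file.
///
/// A minimal usage sketch, assuming a hypothetical `LocalBackend` type that
/// implements this trait:
///
/// ```ignore
/// let backend = LocalBackend::new("/data/store");
/// let md_path = backend.save_metadata(&metadata).await?;
/// backend.save_dense("embeddings", &matrix, &md_path).await?;
/// let restored = backend.load_dense("embeddings").await?;
/// ```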
pub trait StorageBackend: Send + Sync {
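/// The backend's base location as a string.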
fn get_base(&self) -> String;
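/// A short name identifying this backend.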
fn get_name(&self) -> String;
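/// Checks whether `path` points to an initialized store, i.e. a directory
/// containing a `*_metadata.json` file, and returns that file's path when
/// found.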
fn exists(path: &str) -> (bool, Option<PathBuf>) {
let base_path = std::path::PathBuf::from(path);
if !base_path.exists() {
debug!("StorageBackend: path {:?} does not exist", base_path);
return (false, None);
}
if let Ok(entries) = std::fs::read_dir(&base_path) {
for entry in entries.flatten() {
let path = entry.path();
if let Some(name) = path.file_name().and_then(|n| n.to_str())
&& name.ends_with("_metadata.json")
{
debug!("StorageBackend::exists: found metadata file at {:?}", path);
return (true, Some(path));
}
}
}
(false, None)
}
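/// Root directory under which this backend stores its files.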
fn base_path(&self) -> PathBuf;
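/// Location of this backend's `*_metadata.json` file.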
fn metadata_path(&self) -> PathBuf;
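/// The base path rendered as a URI string.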
fn basepath_to_uri(&self) -> String;
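/// Loads a dense matrix directly from the file at `path`.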
async fn load_dense_from_file(&self, path: &Path) -> StorageResult<DenseMatrix<f64>>;
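/// Resolves the on-disk path for a logical storage `key`.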
fn file_path(&self, key: &str) -> PathBuf;
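/// Renders `path` as an absolute path string, canonicalizing when possible
/// and otherwise resolving relative paths against the working directory.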
fn path_to_uri(path: &Path) -> String {
path.canonicalize()
.unwrap_or_else(|_| {
// Canonicalization fails if the path does not exist yet; fall back to
// resolving relative paths against the current working directory.
if path.is_absolute() {
path.to_path_buf()
} else {
std::env::current_dir()
.unwrap_or_else(|_| PathBuf::from("/"))
.join(path)
}
})
.to_string_lossy()
.to_string()
}
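/// Ensures the metadata file exists at `md_path` before dependent artifacts
/// are saved.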
fn validate_initialized(&self, md_path: &Path) -> StorageResult<()> {
if self.metadata_path() != *md_path {
return Err(StorageError::Invalid(format!(
"Metadata path mismatch: expected {:?}, got {:?}",
self.metadata_path(),
md_path
)));
}
if !md_path.exists() {
return Err(StorageError::Invalid(format!(
"Storage not initialized: metadata file missing at {:?}. \
Call save_metadata() or save_eigenmaps_all()/save_energymaps_all() first.",
md_path
)));
}
Ok(())
}
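/// Serializes a dense matrix into a single-column RecordBatch in which each
/// matrix row becomes one `FixedSizeList<Float64>` value of length `cols`.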
fn to_dense_record_batch(
&self,
matrix: &DenseMatrix<f64>,
) -> Result<RecordBatch, StorageError> {
let (rows, cols) = matrix.shape();
debug!(
"Converting dense matrix to RecordBatch (vector format): {}x{}",
rows, cols
);
if rows == 0 || cols == 0 {
return Err(StorageError::Invalid(
"Cannot convert empty matrix to RecordBatch".to_string(),
));
}
let mut values: Vec<f64> = Vec::with_capacity(rows * cols);
for r in 0..rows {
for c in 0..cols {
values.push(*matrix.get((r, c)));
}
}
let value_field = Arc::new(Field::new("item", DataType::Float64, false));
let list_field = Field::new(
"vector",
DataType::FixedSizeList(value_field.clone(), cols as i32),
false,
);
let schema = Schema::new(vec![list_field]);
let values_array = Float64Array::from(values);
// Reuse the same Arc'd field so the schema and the array agree exactly.
let list_array = FixedSizeListArray::new(
value_field,
cols as i32,
Arc::new(values_array),
None,
);
let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(list_array)])
.map_err(|e| StorageError::Lance(e.to_string()))?;
trace!(
"RecordBatch created with {} rows (vectors of length {})",
batch.num_rows(),
cols
);
Ok(batch)
}
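/// Inverse of [`Self::to_dense_record_batch`]: rebuilds a `DenseMatrix` from
/// a single `FixedSizeList<Float64>` column, converting the row-major batch
/// layout to the column-major layout smartcore expects.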
#[allow(clippy::wrong_self_convention)]
fn from_dense_record_batch(
&self,
batch: &RecordBatch,
) -> Result<DenseMatrix<f64>, StorageError> {
use std::mem;
debug!("Reconstructing dense matrix from RecordBatch (vector format)");
debug!("Batch has {} columns", batch.num_columns());
if batch.num_columns() != 1 {
return Err(StorageError::Invalid(format!(
"Expected Lance row-major format with 1 FixedSizeList<Float64> column, but found {} columns. \
This parquet file appears to be in wide format (feature-per-column). \
Convert it first using: \
`python -c \"import pyarrow as pa; import pyarrow.parquet as pq; \
tbl = pq.read_table('input.parquet'); \
vectors = pa.array([list(row.values()) for row in tbl.to_pylist()], type=pa.list_(pa.float64(), len(tbl.column_names))); \
new_tbl = pa.table({{'vector': vectors}}); \
pq.write_table(new_tbl, 'output.parquet')\"` \
or use a Lance-native writer in your data pipeline.",
batch.num_columns()
)));
}
debug!("Extracting FixedSizeList column");
let column = batch.column(0);
let list_array = column
.as_any()
.downcast_ref::<FixedSizeListArray>()
.ok_or_else(|| {
StorageError::Invalid(format!(
"Column 0 is not FixedSizeList (found type: {:?}). \
Expected Lance row-major format with a single FixedSizeList<Float64> column.",
column.data_type()
))
})?;
let rows = list_array.len();
let cols = list_array.value_length() as usize;
debug!("Matrix dimensions: {}x{}", rows, cols);
let total = rows
.checked_mul(cols)
.ok_or_else(|| StorageError::Invalid("Matrix size overflow (rows*cols)".to_string()))?;
let bytes = total
.checked_mul(mem::size_of::<f64>())
.ok_or_else(|| StorageError::Invalid("Byte size overflow".to_string()))?;
// Cap dense loads at 4 GiB to catch corrupt dimensions before allocating.
const MAX_BYTES: usize = 4usize * 1024 * 1024 * 1024;
if bytes > MAX_BYTES {
return Err(StorageError::Invalid(format!(
"Dense load would allocate {} bytes for {}x{} matrix; exceeds 4GiB cap. \
Enable --reduce-dim or shard your input data.",
bytes, rows, cols
)));
}
let values_array = list_array
.values()
.as_any()
.downcast_ref::<Float64Array>()
.ok_or_else(|| {
StorageError::Invalid("FixedSizeList values are not Float64Array".to_string())
})?;
debug!("Converting row-major to column-major");
let mut data = vec![0.0f64; total];
for r in 0..rows {
for c in 0..cols {
let row_major_idx = r * cols + c;
let col_major_idx = c * rows + r;
data[col_major_idx] = values_array.value(row_major_idx);
}
}
debug!("Creating DenseMatrix");
DenseMatrix::new(rows, cols, data, true).map_err(|e| StorageError::Invalid(e.to_string()))
}
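/// Serializes a sparse matrix as COO triplets (`row`, `col`, `value`
/// columns), recording the full dimensions and nnz in the schema metadata.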
fn to_sparse_record_batch(&self, m: &CsMat<f64>) -> StorageResult<RecordBatch> {
debug!(
"Converting sparse matrix to RecordBatch: {} x {}, nnz={}",
m.rows(),
m.cols(),
m.nnz()
);
if m.rows() > u32::MAX as usize || m.cols() > u32::MAX as usize {
return Err(StorageError::Invalid(format!(
"Sparse matrix dimensions {}x{} exceed the u32 index range",
m.rows(),
m.cols()
)));
}
let mut row_idx = Vec::with_capacity(m.nnz());
let mut col_idx = Vec::with_capacity(m.nnz());
let mut vals = Vec::with_capacity(m.nnz());
for (v, (r, c)) in m.iter() {
row_idx.push(r as u32);
col_idx.push(c as u32);
vals.push(*v);
}
let mut schema_metadata = std::collections::HashMap::new();
schema_metadata.insert("rows".to_string(), m.rows().to_string());
schema_metadata.insert("cols".to_string(), m.cols().to_string());
schema_metadata.insert("nnz".to_string(), m.nnz().to_string());
let schema = Schema::new(vec![
Field::new("row", DataType::UInt32, false),
Field::new("col", DataType::UInt32, false),
Field::new("value", DataType::Float64, false),
])
.with_metadata(schema_metadata);
let batch = RecordBatch::try_new(
Arc::new(schema),
vec![
Arc::new(UInt32Array::from(row_idx)) as _,
Arc::new(UInt32Array::from(col_idx)) as _,
Arc::new(Float64Array::from(vals)) as _,
],
)
.map_err(|e| StorageError::Lance(e.to_string()))?;
trace!(
"Sparse RecordBatch created with {} entries",
batch.num_rows()
);
Ok(batch)
}
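/// Rebuilds a CSR matrix from a COO-triplet RecordBatch, validating every
/// index against the expected dimensions and cross-checking the schema
/// metadata written by [`Self::to_sparse_record_batch`] when present.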
#[allow(clippy::wrong_self_convention)]
fn from_sparse_record_batch(
&self,
batch: RecordBatch,
expected_rows: usize,
expected_cols: usize,
) -> StorageResult<CsMat<f64>> {
debug!("Reconstructing sparse matrix from RecordBatch");
if batch.num_columns() < 3 {
return Err(StorageError::Invalid(format!(
"Expected COO format with 3 columns (row, col, value), found {}",
batch.num_columns()
)));
}
let row_arr = batch
.column(0)
.as_any()
.downcast_ref::<UInt32Array>()
.ok_or_else(|| StorageError::Invalid("row column type mismatch".into()))?;
let col_arr = batch
.column(1)
.as_any()
.downcast_ref::<UInt32Array>()
.ok_or_else(|| StorageError::Invalid("col column type mismatch".into()))?;
let val_arr = batch
.column(2)
.as_any()
.downcast_ref::<Float64Array>()
.ok_or_else(|| StorageError::Invalid("value column type mismatch".into()))?;
let n = row_arr.len();
if n == 0 {
debug!(
"Empty RecordBatch, returning {}x{} sparse matrix",
expected_rows, expected_cols
);
return Ok(CsMat::zero((expected_rows, expected_cols)));
}
let schema = batch.schema();
let schema_metadata = schema.metadata();
if let (Some(rows_str), Some(cols_str)) =
(schema_metadata.get("rows"), schema_metadata.get("cols"))
{
let schema_rows = rows_str.parse::<usize>().ok();
let schema_cols = cols_str.parse::<usize>().ok();
if schema_rows != Some(expected_rows) || schema_cols != Some(expected_cols) {
return Err(StorageError::Invalid(format!(
"Schema metadata dimensions ({:?}x{:?}) don't match storage metadata ({}x{})",
schema_rows, schema_cols, expected_rows, expected_cols
)));
}
debug!(
"Schema metadata matches storage metadata: {}x{}",
expected_rows, expected_cols
);
}
let rows = expected_rows;
let cols = expected_cols;
debug!(
"Reconstructing {}x{} sparse matrix from {} entries",
rows, cols, n
);
let mut trimat = TriMat::new((rows, cols));
for i in 0..n {
let r = row_arr.value(i) as usize;
let c = col_arr.value(i) as usize;
let v = val_arr.value(i);
if r >= rows || c >= cols {
return Err(StorageError::Invalid(format!(
"Index out of bounds: ({}, {}) in {}x{} matrix",
r, c, rows, cols
)));
}
trimat.add_triplet(r, c, v);
}
let result = trimat.to_csr();
if result.rows() != rows || result.cols() != cols {
return Err(StorageError::Invalid(format!(
"Dimension mismatch after reconstruction: expected {}x{}, got {}x{}",
rows,
cols,
result.rows(),
result.cols()
)));
}
Ok(result)
}
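/// Persists a dense matrix under `key`, tied to the metadata at `md_path`.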
async fn save_dense(
&self,
key: &str,
matrix: &DenseMatrix<f64>,
md_path: &Path,
) -> StorageResult<()>;
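/// Loads the dense matrix previously saved under `key`.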
async fn load_dense(&self, key: &str) -> StorageResult<DenseMatrix<f64>>;
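/// Persists a sparse matrix under `key`.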
async fn save_sparse(
&self,
key: &str,
matrix: &CsMat<f64>,
md_path: &Path,
) -> StorageResult<()>;
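/// Loads the sparse matrix previously saved under `key`.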
async fn load_sparse(&self, key: &str) -> StorageResult<CsMat<f64>>;
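/// Persists the eigenvalue (lambda) vector.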
async fn save_lambdas(&self, lambdas: &[f64], md_path: &Path) -> StorageResult<()>;
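/// Loads the previously saved eigenvalue vector.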
async fn load_lambdas(&self) -> StorageResult<Vec<f64>>;
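/// Writes `metadata` as pretty-printed JSON to [`Self::metadata_path`],
/// creating the base directory if needed, and returns the written path.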
async fn save_metadata(&self, metadata: &GeneMetadata) -> StorageResult<PathBuf> {
let path = self.metadata_path();
info!("Saving metadata to {:?}", path);
fs::create_dir_all(self.base_path()).map_err(|e| StorageError::Io(e.to_string()))?;
let s = serde_json::to_string_pretty(metadata).map_err(StorageError::Serde)?;
fs::write(&path, s).map_err(|e| StorageError::Io(e.to_string()))?;
info!("Metadata saved successfully");
Ok(path)
}
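/// Reads and deserializes the `GeneMetadata` JSON written by
/// [`Self::save_metadata`].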
async fn load_metadata(&self) -> StorageResult<GeneMetadata> {
let filename = self.metadata_path();
info!("Loading metadata from {:?}", filename);
let s = fs::read_to_string(filename).map_err(|e| StorageError::Io(e.to_string()))?;
let md: GeneMetadata = serde_json::from_str(&s).map_err(StorageError::Serde)?;
info!("Metadata loaded successfully");
Ok(md)
}
#[allow(dead_code)]
async fn save_index(&self, key: &str, vector: &[usize], md_path: &Path) -> StorageResult<()>;
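/// Persists a flat `f64` vector under `key`.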
async fn save_vector(&self, key: &str, vector: &[f64], md_path: &Path) -> StorageResult<()>;
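/// Persists the centroid assignment map.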
async fn save_centroid_map(&self, map: &[usize], md_path: &Path) -> StorageResult<()>;
async fn load_centroid_map(&self) -> StorageResult<Vec<usize>>;
async fn save_subcentroid_lambdas(&self, lambdas: &[f64], md_path: &Path) -> StorageResult<()>;
async fn load_subcentroid_lambdas(&self) -> StorageResult<Vec<f64>>;
async fn save_subcentroids(
&self,
subcentroids: &DenseMatrix<f64>,
md_path: &Path,
) -> StorageResult<()>;
async fn load_subcentroids(&self) -> StorageResult<Vec<Vec<f64>>>;
async fn save_item_norms(&self, item_norms: &[f64], md_path: &Path) -> StorageResult<()>;
async fn load_item_norms(&self) -> StorageResult<Vec<f64>>;
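/// Persists per-item cluster assignments (`None` = unassigned).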
async fn save_cluster_assignments(
&self,
assignments: &[Option<usize>],
md_path: &Path,
) -> StorageResult<()>;
async fn load_cluster_assignments(&self) -> StorageResult<Vec<Option<usize>>>;
#[allow(dead_code)]
async fn load_index(&self, key: &str) -> StorageResult<Vec<usize>>;
async fn load_vector(&self, key: &str) -> StorageResult<Vec<f64>>;
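/// Writes a dense matrix directly to the file at `path`.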
async fn save_dense_to_file(data: &DenseMatrix<f64>, path: &Path) -> StorageResult<()>;
}