// genegraph_storage/traits/backend.rs
use arrow::array::{Float64Array, UInt32Array};
2use arrow::datatypes::{DataType, Field, Schema};
3use arrow_array::{Array as ArrowArray, FixedSizeListArray, RecordBatch};
4use log::{debug, info, trace};
5use smartcore::linalg::basic::arrays::Array;
6use smartcore::linalg::basic::matrix::DenseMatrix;
7use sprs::{CsMat, TriMat};
8use std::fs;
9use std::path::{Path, PathBuf};
10use std::sync::Arc;
11
12use crate::metadata::GeneMetadata;
13use crate::{StorageError, StorageResult};
14
15pub trait StorageBackend: Send + Sync {
72 fn get_base(&self) -> String;
74 fn get_name(&self) -> String;
76
77 fn exists(path: &str) -> (bool, Option<PathBuf>) {
82 let base_path = std::path::PathBuf::from(path);
83 if !base_path.exists() {
84 debug!("StorageBackend: path {:?} does not exist", base_path);
85 return (false, None);
86 }
87
88 if let Ok(entries) = std::fs::read_dir(&base_path) {
90 for entry in entries.flatten() {
91 let path = entry.path();
92 if let Some(name) = path.file_name().and_then(|n| n.to_str())
93 && name.ends_with("_metadata.json")
94 {
95 debug!("StorageBackend::exists: found metadata file at {:?}", path);
96 return (true, Some(path));
97 }
98 }
99 }
100 (false, None)
101 }
102
103 fn base_path(&self) -> PathBuf;
105 fn metadata_path(&self) -> PathBuf;
107 fn basepath_to_uri(&self) -> String;
109
110 async fn load_dense_from_file(&self, path: &Path) -> StorageResult<DenseMatrix<f64>>;
113
114 fn file_path(&self, key: &str) -> PathBuf;
116
117 fn path_to_uri(path: &Path) -> String {
119 path.canonicalize()
120 .unwrap_or_else(|_| {
121 if path.is_absolute() {
122 path.to_path_buf()
123 } else if path.is_relative() {
124 std::env::current_dir()
125 .unwrap_or_else(|_| PathBuf::from("/"))
126 .join(path)
127 } else {
128 PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(path)
129 }
130 })
131 .to_string_lossy()
132 .to_string()
133 }
134
135 fn validate_initialized(&self, md_path: &Path) -> StorageResult<()> {
141 assert_eq!(self.metadata_path(), *md_path);
142 if !md_path.exists() {
143 return Err(StorageError::Invalid(format!(
144 "Storage not initialized: metadata file missing at {:?}. \
145 Call save_metadata() or save_eigenmaps_all()/save_energymaps_all() first.",
146 md_path
147 )));
148 }
149 Ok(())
150 }
151
152 fn to_dense_record_batch(
165 &self,
166 matrix: &DenseMatrix<f64>,
167 ) -> Result<RecordBatch, StorageError> {
168 let (rows, cols) = (matrix.shape().0, matrix.shape().1);
169
170 debug!(
171 "Converting dense matrix to RecordBatch (vector format): {}x{}",
172 rows, cols
173 );
174
175 if rows == 0 || cols == 0 {
176 return Err(StorageError::Invalid(
177 "Cannot convert empty matrix to RecordBatch".to_string(),
178 ));
179 }
180
181 let mut values: Vec<f64> = Vec::with_capacity(rows * cols);
183 for r in 0..rows {
184 for c in 0..cols {
185 values.push(*matrix.get((r, c)));
186 }
187 }
188
189 let value_field = Field::new("item", DataType::Float64, false);
191 let list_field = Field::new(
192 "vector",
193 DataType::FixedSizeList(Arc::new(value_field), cols as i32),
194 false,
195 );
196
197 let schema = Schema::new(vec![list_field]);
198
199 let values_array = Float64Array::from(values);
201 let list_array = FixedSizeListArray::new(
202 Arc::new(Field::new("item", DataType::Float64, false)),
203 cols as i32,
204 Arc::new(values_array),
205 None, );
207
208 let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(list_array)])
209 .map_err(|e| StorageError::Lance(e.to_string()))?;
210
211 trace!(
212 "RecordBatch created with {} rows (vectors of length {})",
213 batch.num_rows(),
214 cols
215 );
216
217 Ok(batch)
218 }
219
220 #[allow(clippy::wrong_self_convention)]
228 fn from_dense_record_batch(
229 &self,
230 batch: &RecordBatch,
231 ) -> Result<DenseMatrix<f64>, StorageError> {
232 use std::mem;
233
234 debug!("Reconstructing dense matrix from RecordBatch (vector format)");
235 debug!("Batch has {} columns", batch.num_columns());
236
237 if batch.num_columns() != 1 {
238 return Err(StorageError::Invalid(format!(
239 "Expected Lance row-major format with 1 FixedSizeList<Float64> column, but found {} columns. \
240 This parquet file appears to be in wide format (feature-per-column). \
241 Convert it first using: \
242 `python -c \"import pyarrow.parquet as pq; import pyarrow.compute as pc; \
243 tbl = pq.read_table('input.parquet'); \
244 import pyarrow as pa; \
245 vectors = pa.array([row.as_py() for row in tbl.to_pylist()], type=pa.list_(pa.float64(), len(tbl.column_names))); \
246 new_tbl = pa.table({{'vector': vectors}}); \
247 pq.write_table(new_tbl, 'output.parquet')\"` \
248 or use a Lance-native writer in your data pipeline.",
249 batch.num_columns()
250 )));
251 }
252
253 debug!("Extracting FixedSizeList column");
254 let column = batch.column(0);
255 let list_array = column
256 .as_any()
257 .downcast_ref::<FixedSizeListArray>()
258 .ok_or_else(|| {
259 StorageError::Invalid(format!(
260 "Column 0 is not FixedSizeList (found type: {:?}). \
261 Expected Lance row-major format with a single FixedSizeList<Float64> column.",
262 column.data_type()
263 ))
264 })?;
265
266 let rows = list_array.len();
267 let cols = list_array.value_length() as usize;
268
269 debug!("Matrix dimensions: {}x{}", rows, cols);
270
271 let total = rows
273 .checked_mul(cols)
274 .ok_or_else(|| StorageError::Invalid("Matrix size overflow (rows*cols)".to_string()))?;
275 let bytes = total
276 .checked_mul(mem::size_of::<f64>())
277 .ok_or_else(|| StorageError::Invalid("Byte size overflow".to_string()))?;
278
279 const MAX_BYTES: usize = 4usize * 1024 * 1024 * 1024; if bytes > MAX_BYTES {
281 return Err(StorageError::Invalid(format!(
282 "Dense load would allocate {} bytes for {}x{} matrix; exceeds 4GiB cap. \
283 Enable --reduce-dim or shard your input data.",
284 bytes, rows, cols
285 )));
286 }
287
288 let values_array = list_array
290 .values()
291 .as_any()
292 .downcast_ref::<Float64Array>()
293 .ok_or_else(|| {
294 StorageError::Invalid("FixedSizeList values are not Float64Array".to_string())
295 })?;
296
297 debug!("Converting row-major to column-major");
298 let mut data = vec![0.0f64; total];
299 for r in 0..rows {
300 for c in 0..cols {
301 let row_major_idx = r * cols + c;
302 let col_major_idx = c * rows + r;
303 data[col_major_idx] = values_array.value(row_major_idx);
304 }
305 }
306
307 debug!("Creating DenseMatrix");
308 DenseMatrix::new(rows, cols, data, true).map_err(|e| StorageError::Invalid(e.to_string()))
309 }
310
311 fn to_sparse_record_batch(&self, m: &CsMat<f64>) -> StorageResult<RecordBatch> {
315 debug!(
316 "Converting sparse matrix to RecordBatch: {} x {}, nnz={}",
317 m.rows(),
318 m.cols(),
319 m.nnz()
320 );
321
322 let mut row_idx = Vec::with_capacity(m.nnz());
323 let mut col_idx = Vec::with_capacity(m.nnz());
324 let mut vals = Vec::with_capacity(m.nnz());
325
326 for (v, (r, c)) in m.iter() {
327 row_idx.push(r as u32);
328 col_idx.push(c as u32);
329 vals.push(*v);
330 }
331
332 let mut schema_metadata = std::collections::HashMap::new();
334 schema_metadata.insert("rows".to_string(), m.rows().to_string());
335 schema_metadata.insert("cols".to_string(), m.cols().to_string());
336 schema_metadata.insert("nnz".to_string(), m.nnz().to_string());
337
338 let schema = Schema::new(vec![
339 Field::new("row", DataType::UInt32, false),
340 Field::new("col", DataType::UInt32, false),
341 Field::new("value", DataType::Float64, false),
342 ])
343 .with_metadata(schema_metadata);
344
345 let batch = RecordBatch::try_new(
346 Arc::new(schema),
347 vec![
348 Arc::new(UInt32Array::from(row_idx)) as _,
349 Arc::new(UInt32Array::from(col_idx)) as _,
350 Arc::new(Float64Array::from(vals)) as _,
351 ],
352 )
353 .map_err(|e| StorageError::Lance(e.to_string()))?;
354
355 trace!(
356 "Sparse RecordBatch created with {} entries",
357 batch.num_rows()
358 );
359 Ok(batch)
360 }
361
362 #[allow(clippy::wrong_self_convention)]
367 fn from_sparse_record_batch(
368 &self,
369 batch: RecordBatch,
370 expected_rows: usize,
371 expected_cols: usize,
372 ) -> StorageResult<CsMat<f64>> {
373 use arrow::array::UInt32Array;
374
375 debug!("Reconstructing sparse matrix from RecordBatch");
376
377 let row_arr = batch
378 .column(0)
379 .as_any()
380 .downcast_ref::<UInt32Array>()
381 .ok_or_else(|| StorageError::Invalid("row column type mismatch".into()))?;
382 let col_arr = batch
383 .column(1)
384 .as_any()
385 .downcast_ref::<UInt32Array>()
386 .ok_or_else(|| StorageError::Invalid("col column type mismatch".into()))?;
387 let val_arr = batch
388 .column(2)
389 .as_any()
390 .downcast_ref::<Float64Array>()
391 .ok_or_else(|| StorageError::Invalid("value column type mismatch".into()))?;
392
393 let n = row_arr.len();
394 if n == 0 {
395 debug!(
396 "Empty RecordBatch, returning {}x{} sparse matrix",
397 expected_rows, expected_cols
398 );
399 return Ok(CsMat::zero((expected_rows, expected_cols)));
400 }
401
402 let schema = batch.schema();
404 let schema_metadata = schema.metadata();
405 if let (Some(rows_str), Some(cols_str)) =
406 (schema_metadata.get("rows"), schema_metadata.get("cols"))
407 {
408 let schema_rows = rows_str.parse::<usize>().ok();
409 let schema_cols = cols_str.parse::<usize>().ok();
410 if schema_rows != Some(expected_rows) || schema_cols != Some(expected_cols) {
411 panic!(
412 "Schema metadata dimensions ({:?}x{:?}) don't match storage metadata ({}x{})",
413 schema_rows, schema_cols, expected_rows, expected_cols
414 );
415 } else {
416 debug!(
417 "Schema metadata matches storage metadata: {}x{}",
418 expected_rows, expected_cols
419 );
420 }
421 }
422
423 let rows = expected_rows;
424 let cols = expected_cols;
425 debug!(
426 "Reconstructing {}x{} sparse matrix from {} entries",
427 rows, cols, n
428 );
429
430 let mut trimat = TriMat::new((rows, cols));
431 for i in 0..n {
432 let r = row_arr.value(i) as usize;
433 let c = col_arr.value(i) as usize;
434 let v = val_arr.value(i);
435
436 if r >= rows || c >= cols {
437 return Err(StorageError::Invalid(format!(
438 "Index out of bounds: ({}, {}) in {}x{} matrix",
439 r, c, rows, cols
440 )));
441 }
442 trimat.add_triplet(r, c, v);
443 }
444
445 let result = trimat.to_csr();
446 if result.rows() != rows || result.cols() != cols {
447 return Err(StorageError::Invalid(format!(
448 "Dimension mismatch after reconstruction: expected {}x{}, got {}x{}",
449 rows,
450 cols,
451 result.rows(),
452 result.cols()
453 )));
454 }
455
456 Ok(result)
457 }
458
459 async fn save_dense(
461 &self,
462 key: &str,
463 matrix: &DenseMatrix<f64>,
464 md_path: &Path,
465 ) -> StorageResult<()>;
466
467 async fn load_dense(&self, key: &str) -> StorageResult<DenseMatrix<f64>>;
469
470 async fn save_sparse(
472 &self,
473 key: &str,
474 matrix: &CsMat<f64>,
475 md_path: &Path,
476 ) -> StorageResult<()>;
477
478 async fn load_sparse(&self, key: &str) -> StorageResult<CsMat<f64>>;
480
481 async fn save_lambdas(&self, lambdas: &[f64], md_path: &Path) -> StorageResult<()>;
483
484 async fn load_lambdas(&self) -> StorageResult<Vec<f64>>;
486
487 async fn save_metadata(&self, metadata: &GeneMetadata) -> StorageResult<PathBuf> {
489 let path = self.metadata_path();
490 info!("Saving metadata to {:?}", path);
491 fs::create_dir_all(self.base_path()).map_err(|e| StorageError::Io(e.to_string()))?;
492 let s = serde_json::to_string_pretty(metadata).map_err(StorageError::Serde)?;
493 fs::write(&path, s).map_err(|e| StorageError::Io(e.to_string()))?;
494 info!("Metadata saved successfully");
495 Ok(path)
496 }
497
498 async fn load_metadata(&self) -> StorageResult<GeneMetadata> {
500 let filename = self.metadata_path();
501 info!("Loading metadata from {:?}", filename);
502 let s = fs::read_to_string(filename).map_err(|e| StorageError::Io(e.to_string()))?;
503 let md: GeneMetadata = serde_json::from_str(&s).map_err(StorageError::Serde)?;
504 info!("Metadata loaded successfully");
505 Ok(md)
506 }
507
508 #[allow(dead_code)]
510 async fn save_index(&self, key: &str, vector: &[usize], md_path: &Path) -> StorageResult<()>;
511
512 async fn save_vector(&self, key: &str, vector: &[f64], md_path: &Path) -> StorageResult<()>;
514
515 async fn save_centroid_map(&self, map: &[usize], md_path: &Path) -> StorageResult<()>;
517
518 async fn load_centroid_map(&self) -> StorageResult<Vec<usize>>;
520 async fn save_subcentroid_lambdas(&self, lambdas: &[f64], md_path: &Path) -> StorageResult<()>;
522 async fn load_subcentroid_lambdas(&self) -> StorageResult<Vec<f64>>;
524 async fn save_subcentroids(
526 &self,
527 subcentroids: &DenseMatrix<f64>,
528 md_path: &Path,
529 ) -> StorageResult<()>;
530 async fn load_subcentroids(&self) -> StorageResult<Vec<Vec<f64>>>;
532
533 async fn save_item_norms(&self, item_norms: &[f64], md_path: &Path) -> StorageResult<()>;
535
536 async fn load_item_norms(&self) -> StorageResult<Vec<f64>>;
538
539 async fn save_cluster_assignments(
541 &self,
542 assignments: &[Option<usize>],
543 md_path: &Path,
544 ) -> StorageResult<()>;
545
546 async fn load_cluster_assignments(&self) -> StorageResult<Vec<Option<usize>>>;
548
549 #[allow(dead_code)]
551 async fn load_index(&self, key: &str) -> StorageResult<Vec<usize>>;
552
553 async fn load_vector(&self, key: &str) -> StorageResult<Vec<f64>>;
554
555 async fn save_dense_to_file(data: &DenseMatrix<f64>, path: &Path) -> StorageResult<()>;
556}