genegraph_storage/traits/
backend.rs

1use arrow::array::{Float64Array, UInt32Array};
2use arrow::datatypes::{DataType, Field, Schema};
3use arrow_array::{Array as ArrowArray, FixedSizeListArray, RecordBatch};
4use log::{debug, info, trace};
5use smartcore::linalg::basic::arrays::Array;
6use smartcore::linalg::basic::matrix::DenseMatrix;
7use sprs::{CsMat, TriMat};
8use std::fs;
9use std::path::{Path, PathBuf};
10use std::sync::Arc;
11
12use crate::metadata::GeneMetadata;
13use crate::{StorageError, StorageResult};
14
15/// Async storage backend for Lance-based graph and embedding data.
16///
17/// This trait defines the minimal async API required to persist and reload
18/// all artifacts used by Javelin:
19///
20/// - Dense matrices (embeddings, eigenmaps, energy maps)
21/// - Sparse matrices in CSR form (e.g. Laplacians, adjacency)
22/// - Scalar vectors (eigenvalues, norms, generic f64 sequences)
23/// - Index-like vectors (usize mappings and cluster assignments)
24/// - Clustering metadata (centroid maps, subcentroids, lambdas)
25/// - Global metadata describing the dataset layout and dimensions
26///
27/// ## Initialization
28///
29/// Storage must be initialized before saving any data:
30///
31/// 1. Call `save_metadata()` once to write an initial `*_metadata.json`.
32/// 2. Subsequent `save_*` calls validate that metadata exists and is consistent.
33/// 3. `exists()` can be used to detect and reuse an existing initialized store.
34///
35/// Filenames are conventionally:
36///
37/// ```ignore
38/// <base dir>/<instance name or name id>_<key>.lance
39/// ```
40///
41/// ## Async usage
42///
43/// All I/O functions are async and intended to be called from a Tokio runtime.
44/// Implementations (e.g. `LanceStorage`) must not create their own runtimes or
45/// block on I/O internally.
46///
47/// ## High-level flow
48///
49/// - Dense data:
50///   - `save_dense("raw_input", &matrix, md_path)`
51///   - `load_dense("raw_input")`
52///
53/// - Sparse data:
54///   - `save_sparse("laplacian", &csr, md_path)`
55///   - `load_sparse("laplacian")`
56///
57/// - Scalars and indices:
58///   - `save_lambdas`, `load_lambdas`
59///   - `save_vector`, `load_vector`
60///   - `save_index`, `load_index`
61///   - `save_centroid_map`, `load_centroid_map`
62///   - `save_item_norms`, `load_item_norms`
63///   - `save_cluster_assignments`, `load_cluster_assignments`
64///
65/// - Clustering structure:
66///   - `save_subcentroids`, `load_subcentroids`
67///   - `save_subcentroid_lambdas`, `load_subcentroid_lambdas`
68///
69/// Implementations are free to choose the on-disk layout as long as they honor
70/// these logical keys and round-trip semantics.
71pub trait StorageBackend: Send + Sync {
    /// Base directory of the instance (returned as an owned `String`).
    fn get_base(&self) -> String;
    /// Name of the instance (returned as an owned `String`).
    fn get_name(&self) -> String;
76
77    ///
78    /// Returns `true` and the path to the metadata file if metadata file exists and is valid,
79    /// `false` otherwise.
80    /// This is used to avoid overwriting existing indexes.
81    fn exists(path: &str) -> (bool, Option<PathBuf>) {
82        let base_path = std::path::PathBuf::from(path);
83        if !base_path.exists() {
84            debug!("StorageBackend: path {:?} does not exist", base_path);
85            return (false, None);
86        }
87
88        // Check for any _metadata.json file in the directory
89        if let Ok(entries) = std::fs::read_dir(&base_path) {
90            for entry in entries.flatten() {
91                let path = entry.path();
92                if let Some(name) = path.file_name().and_then(|n| n.to_str())
93                    && name.ends_with("_metadata.json")
94                {
95                    debug!("StorageBackend::exists: found metadata file at {:?}", path);
96                    return (true, Some(path));
97                }
98            }
99        }
100        (false, None)
101    }
102
    /// Returns the base directory path of this instance.
    fn base_path(&self) -> PathBuf;
    /// Returns the path of this instance's `*_metadata.json` file.
    fn metadata_path(&self) -> PathBuf;
    /// Returns the base path as a `file://` URI string.
    fn basepath_to_uri(&self) -> String;

    /// Loads initial data in columnar format from a file path.
    /// Implementations may use this as a helper for async `load_dense`.
    async fn load_dense_from_file(&self, path: &Path) -> StorageResult<DenseMatrix<f64>>;

    /// Computes the full Lance/parquet file path for a logical filetype `key`.
    fn file_path(&self, key: &str) -> PathBuf;
116
117    /// Converts a full file path to a `file://` URI for Lance.
118    fn path_to_uri(path: &Path) -> String {
119        path.canonicalize()
120            .unwrap_or_else(|_| {
121                if path.is_absolute() {
122                    path.to_path_buf()
123                } else if path.is_relative() {
124                    std::env::current_dir()
125                        .unwrap_or_else(|_| PathBuf::from("/"))
126                        .join(path)
127                } else {
128                    PathBuf::from(env!("CARGO_MANIFEST_DIR")).join(path)
129                }
130            })
131            .to_string_lossy()
132            .to_string()
133    }
134
135    /// Validates that the storage directory is properly initialized with metadata.
136    ///
137    /// # Returns
138    ///
139    /// Returns `Ok(())` if metadata file exists, otherwise returns an error.
140    fn validate_initialized(&self, md_path: &Path) -> StorageResult<()> {
141        assert_eq!(self.metadata_path(), *md_path);
142        if !md_path.exists() {
143            return Err(StorageError::Invalid(format!(
144                "Storage not initialized: metadata file missing at {:?}. \
145                 Call save_metadata() or save_eigenmaps_all()/save_energymaps_all() first.",
146                md_path
147            )));
148        }
149        Ok(())
150    }
151
152    // =========
153    // ASYNC API
154    // =========
155
156    /// Converts a dense matrix to a RecordBatch in vector format (Lance-optimized).
157    /// Each row of the matrix becomes a single FixedSizeList entry.
158    ///
159    /// Arguments:
160    /// * matrix - Dense matrix to convert (N rows × F cols)
161    ///
162    /// Returns:
163    /// RecordBatch with schema: { vector: FixedSizeList<Float64>[F] }
164    fn to_dense_record_batch(
165        &self,
166        matrix: &DenseMatrix<f64>,
167    ) -> Result<RecordBatch, StorageError> {
168        let (rows, cols) = (matrix.shape().0, matrix.shape().1);
169
170        debug!(
171            "Converting dense matrix to RecordBatch (vector format): {}x{}",
172            rows, cols
173        );
174
175        if rows == 0 || cols == 0 {
176            return Err(StorageError::Invalid(
177                "Cannot convert empty matrix to RecordBatch".to_string(),
178            ));
179        }
180
181        // Flatten matrix row-by-row into a single Vec<f64>
182        let mut values: Vec<f64> = Vec::with_capacity(rows * cols);
183        for r in 0..rows {
184            for c in 0..cols {
185                values.push(*matrix.get((r, c)));
186            }
187        }
188
189        // Create FixedSizeList field: each entry is a vector of length cols
190        let value_field = Field::new("item", DataType::Float64, false);
191        let list_field = Field::new(
192            "vector",
193            DataType::FixedSizeList(Arc::new(value_field), cols as i32),
194            false,
195        );
196
197        let schema = Schema::new(vec![list_field]);
198
199        // Build the FixedSizeList array
200        let values_array = Float64Array::from(values);
201        let list_array = FixedSizeListArray::new(
202            Arc::new(Field::new("item", DataType::Float64, false)),
203            cols as i32,
204            Arc::new(values_array),
205            None, // No nulls
206        );
207
208        let batch = RecordBatch::try_new(Arc::new(schema), vec![Arc::new(list_array)])
209            .map_err(|e| StorageError::Lance(e.to_string()))?;
210
211        trace!(
212            "RecordBatch created with {} rows (vectors of length {})",
213            batch.num_rows(),
214            cols
215        );
216
217        Ok(batch)
218    }
219
    /// Reconstructs a dense matrix from a RecordBatch in vector format.
    ///
    /// Arguments:
    /// * batch - RecordBatch containing a single FixedSizeList<Float64> column
    ///
    /// Returns:
    /// DenseMatrix in column-major format (smartcore convention)
    ///
    /// Errors:
    /// `StorageError::Invalid` when the batch is not in the expected
    /// single-column row-major layout, or when the decoded matrix would
    /// exceed the 4 GiB allocation cap.
    #[allow(clippy::wrong_self_convention)]
    fn from_dense_record_batch(
        &self,
        batch: &RecordBatch,
    ) -> Result<DenseMatrix<f64>, StorageError> {
        use std::mem;

        debug!("Reconstructing dense matrix from RecordBatch (vector format)");
        debug!("Batch has {} columns", batch.num_columns());

        // Exactly one column is expected; anything else is a wide-format
        // parquet file, and the error message explains how to convert it.
        if batch.num_columns() != 1 {
            return Err(StorageError::Invalid(format!(
                "Expected Lance row-major format with 1 FixedSizeList<Float64> column, but found {} columns. \
                  This parquet file appears to be in wide format (feature-per-column). \
                  Convert it first using: \
                  `python -c \"import pyarrow.parquet as pq; import pyarrow.compute as pc; \
                  tbl = pq.read_table('input.parquet'); \
                  import pyarrow as pa; \
                  vectors = pa.array([row.as_py() for row in tbl.to_pylist()], type=pa.list_(pa.float64(), len(tbl.column_names))); \
                  new_tbl = pa.table({{'vector': vectors}}); \
                  pq.write_table(new_tbl, 'output.parquet')\"` \
                  or use a Lance-native writer in your data pipeline.",
                batch.num_columns()
            )));
        }

        debug!("Extracting FixedSizeList column");
        let column = batch.column(0);
        let list_array = column
            .as_any()
            .downcast_ref::<FixedSizeListArray>()
            .ok_or_else(|| {
                StorageError::Invalid(format!(
                    "Column 0 is not FixedSizeList (found type: {:?}). \
                      Expected Lance row-major format with a single FixedSizeList<Float64> column.",
                    column.data_type()
                ))
            })?;

        // One list entry per matrix row; every entry has the same fixed
        // length, which is the column count.
        let rows = list_array.len();
        let cols = list_array.value_length() as usize;

        debug!("Matrix dimensions: {}x{}", rows, cols);

        // Guard against excessive allocations: checked arithmetic catches
        // usize overflow, then a hard 4 GiB cap rejects oversized loads.
        let total = rows
            .checked_mul(cols)
            .ok_or_else(|| StorageError::Invalid("Matrix size overflow (rows*cols)".to_string()))?;
        let bytes = total
            .checked_mul(mem::size_of::<f64>())
            .ok_or_else(|| StorageError::Invalid("Byte size overflow".to_string()))?;

        const MAX_BYTES: usize = 4usize * 1024 * 1024 * 1024; // 4 GiB
        if bytes > MAX_BYTES {
            return Err(StorageError::Invalid(format!(
                "Dense load would allocate {} bytes for {}x{} matrix; exceeds 4GiB cap. \
                  Enable --reduce-dim or shard your input data.",
                bytes, rows, cols
            )));
        }

        // Extract Float64 values (the flat backing buffer of all vectors)
        let values_array = list_array
            .values()
            .as_any()
            .downcast_ref::<Float64Array>()
            .ok_or_else(|| {
                StorageError::Invalid("FixedSizeList values are not Float64Array".to_string())
            })?;

        // Transpose: stored data is row-major, smartcore wants column-major.
        debug!("Converting row-major to column-major");
        let mut data = vec![0.0f64; total];
        for r in 0..rows {
            for c in 0..cols {
                let row_major_idx = r * cols + c;
                let col_major_idx = c * rows + r;
                data[col_major_idx] = values_array.value(row_major_idx);
            }
        }

        debug!("Creating DenseMatrix");
        // The trailing `true` flags the data as column-major, matching the
        // layout built above — presumably smartcore's `column_major` flag;
        // confirm against the smartcore version in use.
        DenseMatrix::new(rows, cols, data, true).map_err(|e| StorageError::Invalid(e.to_string()))
    }
310
311    /// Converts a sparse CSR matrix to a RecordBatch in columnar format.
312    ///
313    /// Only non-zero entries are stored.
314    fn to_sparse_record_batch(&self, m: &CsMat<f64>) -> StorageResult<RecordBatch> {
315        debug!(
316            "Converting sparse matrix to RecordBatch: {} x {}, nnz={}",
317            m.rows(),
318            m.cols(),
319            m.nnz()
320        );
321
322        let mut row_idx = Vec::with_capacity(m.nnz());
323        let mut col_idx = Vec::with_capacity(m.nnz());
324        let mut vals = Vec::with_capacity(m.nnz());
325
326        for (v, (r, c)) in m.iter() {
327            row_idx.push(r as u32);
328            col_idx.push(c as u32);
329            vals.push(*v);
330        }
331
332        // Store actual dimensions in schema metadata
333        let mut schema_metadata = std::collections::HashMap::new();
334        schema_metadata.insert("rows".to_string(), m.rows().to_string());
335        schema_metadata.insert("cols".to_string(), m.cols().to_string());
336        schema_metadata.insert("nnz".to_string(), m.nnz().to_string());
337
338        let schema = Schema::new(vec![
339            Field::new("row", DataType::UInt32, false),
340            Field::new("col", DataType::UInt32, false),
341            Field::new("value", DataType::Float64, false),
342        ])
343        .with_metadata(schema_metadata);
344
345        let batch = RecordBatch::try_new(
346            Arc::new(schema),
347            vec![
348                Arc::new(UInt32Array::from(row_idx)) as _,
349                Arc::new(UInt32Array::from(col_idx)) as _,
350                Arc::new(Float64Array::from(vals)) as _,
351            ],
352        )
353        .map_err(|e| StorageError::Lance(e.to_string()))?;
354
355        trace!(
356            "Sparse RecordBatch created with {} entries",
357            batch.num_rows()
358        );
359        Ok(batch)
360    }
361
362    /// Reconstructs a sparse CSR matrix from a RecordBatch in columnar format.
363    ///
364    /// * `batch` - RecordBatch containing (`row`, `col`, `value`) triplets
365    /// * `expected_rows` / `expected_cols` - dimensions taken from metadata
366    #[allow(clippy::wrong_self_convention)]
367    fn from_sparse_record_batch(
368        &self,
369        batch: RecordBatch,
370        expected_rows: usize,
371        expected_cols: usize,
372    ) -> StorageResult<CsMat<f64>> {
373        use arrow::array::UInt32Array;
374
375        debug!("Reconstructing sparse matrix from RecordBatch");
376
377        let row_arr = batch
378            .column(0)
379            .as_any()
380            .downcast_ref::<UInt32Array>()
381            .ok_or_else(|| StorageError::Invalid("row column type mismatch".into()))?;
382        let col_arr = batch
383            .column(1)
384            .as_any()
385            .downcast_ref::<UInt32Array>()
386            .ok_or_else(|| StorageError::Invalid("col column type mismatch".into()))?;
387        let val_arr = batch
388            .column(2)
389            .as_any()
390            .downcast_ref::<Float64Array>()
391            .ok_or_else(|| StorageError::Invalid("value column type mismatch".into()))?;
392
393        let n = row_arr.len();
394        if n == 0 {
395            debug!(
396                "Empty RecordBatch, returning {}x{} sparse matrix",
397                expected_rows, expected_cols
398            );
399            return Ok(CsMat::zero((expected_rows, expected_cols)));
400        }
401
402        // Try to read dimensions from schema metadata (for validation)
403        let schema = batch.schema();
404        let schema_metadata = schema.metadata();
405        if let (Some(rows_str), Some(cols_str)) =
406            (schema_metadata.get("rows"), schema_metadata.get("cols"))
407        {
408            let schema_rows = rows_str.parse::<usize>().ok();
409            let schema_cols = cols_str.parse::<usize>().ok();
410            if schema_rows != Some(expected_rows) || schema_cols != Some(expected_cols) {
411                panic!(
412                    "Schema metadata dimensions ({:?}x{:?}) don't match storage metadata ({}x{})",
413                    schema_rows, schema_cols, expected_rows, expected_cols
414                );
415            } else {
416                debug!(
417                    "Schema metadata matches storage metadata: {}x{}",
418                    expected_rows, expected_cols
419                );
420            }
421        }
422
423        let rows = expected_rows;
424        let cols = expected_cols;
425        debug!(
426            "Reconstructing {}x{} sparse matrix from {} entries",
427            rows, cols, n
428        );
429
430        let mut trimat = TriMat::new((rows, cols));
431        for i in 0..n {
432            let r = row_arr.value(i) as usize;
433            let c = col_arr.value(i) as usize;
434            let v = val_arr.value(i);
435
436            if r >= rows || c >= cols {
437                return Err(StorageError::Invalid(format!(
438                    "Index out of bounds: ({}, {}) in {}x{} matrix",
439                    r, c, rows, cols
440                )));
441            }
442            trimat.add_triplet(r, c, v);
443        }
444
445        let result = trimat.to_csr();
446        if result.rows() != rows || result.cols() != cols {
447            return Err(StorageError::Invalid(format!(
448                "Dimension mismatch after reconstruction: expected {}x{}, got {}x{}",
449                rows,
450                cols,
451                result.rows(),
452                result.cols()
453            )));
454        }
455
456        Ok(result)
457    }
458
    /// Saves a dense matrix under `key`. Requires metadata to exist at `md_path`.
    async fn save_dense(
        &self,
        key: &str,
        matrix: &DenseMatrix<f64>,
        md_path: &Path,
    ) -> StorageResult<()>;

    /// Loads the dense matrix stored under `key`.
    async fn load_dense(&self, key: &str) -> StorageResult<DenseMatrix<f64>>;

    /// Saves a sparse matrix under `key`. Requires metadata to exist at `md_path`.
    async fn save_sparse(
        &self,
        key: &str,
        matrix: &CsMat<f64>,
        md_path: &Path,
    ) -> StorageResult<()>;

    /// Loads the sparse matrix stored under `key`.
    async fn load_sparse(&self, key: &str) -> StorageResult<CsMat<f64>>;

    /// Saves lambda eigenvalues. Requires metadata to exist at `md_path`.
    async fn save_lambdas(&self, lambdas: &[f64], md_path: &Path) -> StorageResult<()>;

    /// Loads lambda eigenvalues from storage.
    async fn load_lambdas(&self) -> StorageResult<Vec<f64>>;
486
487    /// Initializes storage by saving metadata. Must be called first.
488    async fn save_metadata(&self, metadata: &GeneMetadata) -> StorageResult<PathBuf> {
489        let path = self.metadata_path();
490        info!("Saving metadata to {:?}", path);
491        fs::create_dir_all(self.base_path()).map_err(|e| StorageError::Io(e.to_string()))?;
492        let s = serde_json::to_string_pretty(metadata).map_err(StorageError::Serde)?;
493        fs::write(&path, s).map_err(|e| StorageError::Io(e.to_string()))?;
494        info!("Metadata saved successfully");
495        Ok(path)
496    }
497
498    /// Loads metadata from storage.
499    async fn load_metadata(&self) -> StorageResult<GeneMetadata> {
500        let filename = self.metadata_path();
501        info!("Loading metadata from {:?}", filename);
502        let s = fs::read_to_string(filename).map_err(|e| StorageError::Io(e.to_string()))?;
503        let md: GeneMetadata = serde_json::from_str(&s).map_err(StorageError::Serde)?;
504        info!("Metadata loaded successfully");
505        Ok(md)
506    }
507
    /// Saves an index-like usize vector under `key`. Requires metadata to exist.
    #[allow(dead_code)]
    async fn save_index(&self, key: &str, vector: &[usize], md_path: &Path) -> StorageResult<()>;

    /// Saves a generic f64 sequence under `key`. Requires metadata to exist.
    async fn save_vector(&self, key: &str, vector: &[f64], md_path: &Path) -> StorageResult<()>;

    /// Saves the centroid map (vector of usize mapping items to centroids).
    async fn save_centroid_map(&self, map: &[usize], md_path: &Path) -> StorageResult<()>;

    /// Loads the centroid map.
    async fn load_centroid_map(&self) -> StorageResult<Vec<usize>>;
    /// Saves subcentroid lambdas (tau values for subcentroids).
    async fn save_subcentroid_lambdas(&self, lambdas: &[f64], md_path: &Path) -> StorageResult<()>;
    /// Loads subcentroid lambdas.
    async fn load_subcentroid_lambdas(&self) -> StorageResult<Vec<f64>>;
    /// Saves subcentroids (dense matrix). Requires metadata to exist.
    async fn save_subcentroids(
        &self,
        subcentroids: &DenseMatrix<f64>,
        md_path: &Path,
    ) -> StorageResult<()>;
    /// Loads subcentroids, one row vector per subcentroid.
    async fn load_subcentroids(&self) -> StorageResult<Vec<Vec<f64>>>;

    /// Saves item norms (precomputed L2 norms for fast distance computation).
    async fn save_item_norms(&self, item_norms: &[f64], md_path: &Path) -> StorageResult<()>;

    /// Loads item norms.
    async fn load_item_norms(&self) -> StorageResult<Vec<f64>>;

    /// Saves cluster assignments; `None` marks unassigned items.
    async fn save_cluster_assignments(
        &self,
        assignments: &[Option<usize>],
        md_path: &Path,
    ) -> StorageResult<()>;

    /// Loads cluster assignments (`None` = unassigned).
    async fn load_cluster_assignments(&self) -> StorageResult<Vec<Option<usize>>>;

    /// Loads an index or generic usize vector stored under `key`.
    #[allow(dead_code)]
    async fn load_index(&self, key: &str) -> StorageResult<Vec<usize>>;

    /// Loads a generic f64 sequence stored under `key`.
    async fn load_vector(&self, key: &str) -> StorageResult<Vec<f64>>;

    /// Writes a dense matrix directly to `path` (associated helper, no `self`).
    async fn save_dense_to_file(data: &DenseMatrix<f64>, path: &Path) -> StorageResult<()>;
556}