Skip to main content

molrs_core/block/
mod.rs

1//! Block: dict-like keyed arrays with consistent axis-0 length and heterogeneous types.
2//!
//! A Block stores heterogeneous arrays (float, int, uint, u8, bool, string)
//! keyed by strings, enforcing that all stored arrays share the same axis-0
//! length (nrows).
5//!
6//! # Examples
7//!
8//! ```
9//! use molrs_core::block::Block;
10//! use molrs_core::types::{F, I};
11//! use ndarray::{Array1, ArrayD};
12//!
13//! let mut block = Block::new();
14//!
15//! // Insert different types - generic dispatch handles the conversion
16//! let pos = Array1::from_vec(vec![1.0 as F, 2.0 as F, 3.0 as F]).into_dyn();
17//! let ids = Array1::from_vec(vec![10 as I, 20 as I, 30 as I]).into_dyn();
18//!
19//! block.insert("pos", pos).unwrap();
20//! block.insert("id", ids).unwrap();
21//!
22//! // Type-safe retrieval
23//! let pos_ref = block.get_float("pos").unwrap();
24//! let ids_ref = block.get_int("id").unwrap();
25//!
26//! assert_eq!(block.nrows(), Some(3));
27//! assert_eq!(block.len(), 2);
28//! ```
29
30mod column;
31mod dtype;
32mod error;
33
34pub mod access;
35pub mod block_view;
36pub mod column_view;
37
38pub use access::{BlockAccess, ColumnAccess};
39pub use block_view::BlockView;
40pub use column::{Column, ColumnHolder};
41pub use column_view::ColumnView;
42pub use dtype::{BlockDtype, DType};
43pub use error::BlockError;
44
45use ndarray::ArrayD;
46use std::collections::HashMap;
47use std::ops::{Index, IndexMut};
48
/// A dictionary from string keys to ndarray arrays with a consistent axis-0 length.
///
/// This Block supports heterogeneous column types (float, int, uint, u8,
/// bool, string); every column is stored as a [`Column`].
///
/// `shape` is optional structural metadata that lets a Block declare itself
/// as N-dimensional (e.g. a 3D volumetric grid). Columns themselves are
/// stored row-major, so a grid block with `shape = [Nx, Ny, Nz]` carries
/// columns of axis-0 length `Nx * Ny * Nz` — `shape` only tells consumers
/// how to unflatten that index. When `shape` is `None`, the block is a
/// plain row table and `block.shape()` reports `vec![nrows]`.
#[derive(Default, Clone)]
pub struct Block {
    /// Column storage, keyed by column name.
    map: HashMap<String, Column>,
    /// Common axis-0 length shared by all columns. `None` until the first
    /// column is inserted or a row count / shape is declared explicitly.
    nrows: Option<usize>,
    /// Declared N-D shape, if any. When set through `set_shape`, its product
    /// is validated against `nrows`.
    shape: Option<Vec<usize>>,
}
65
66impl std::fmt::Debug for Block {
67    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
68        let mut map = f.debug_map();
69        for (k, v) in &self.map {
70            map.entry(k, &format!("{}(shape={:?})", v.dtype(), v.shape()));
71        }
72        map.finish()
73    }
74}
75
76impl Block {
77    /// Creates an empty Block.
78    pub fn new() -> Self {
79        Self {
80            map: HashMap::new(),
81            nrows: None,
82            shape: None,
83        }
84    }
85
86    /// Creates an empty Block with the specified capacity.
87    pub fn with_capacity(cap: usize) -> Self {
88        Self {
89            map: HashMap::with_capacity(cap),
90            nrows: None,
91            shape: None,
92        }
93    }
94
95    /// Number of keys (columns).
96    #[inline]
97    pub fn len(&self) -> usize {
98        self.map.len()
99    }
100
101    /// Returns true if there are no arrays in the block.
102    #[inline]
103    pub fn is_empty(&self) -> bool {
104        self.map.is_empty()
105    }
106
107    /// Returns the common axis-0 length of all arrays, or `None` if empty.
108    #[inline]
109    pub fn nrows(&self) -> Option<usize> {
110        self.nrows
111    }
112
113    /// Returns the structural shape of the block.
114    ///
115    /// - For plain row tables (atoms, bonds): `vec![nrows]` — a single
116    ///   axis whose length is the row count.
117    /// - For N-D blocks (volumetric grids): the explicitly-set shape,
118    ///   e.g. `vec![Nx, Ny, Nz]`.
119    /// - For empty blocks: `vec![]`.
120    ///
121    /// The product of the returned shape always equals `nrows.unwrap_or(0)`.
122    /// This API is uniform across atoms / bonds / grid blocks — the
123    /// difference is the rank of the returned vector, not whether the
124    /// accessor exists.
125    pub fn shape(&self) -> Vec<usize> {
126        match (&self.shape, self.nrows) {
127            (Some(s), _) => s.clone(),
128            (None, Some(n)) => vec![n],
129            (None, None) => Vec::new(),
130        }
131    }
132
133    /// Declare this block as N-dimensional with the given `shape`.
134    ///
135    /// `shape` must have at least one axis and `shape.iter().product()`
136    /// must equal the block's current `nrows` (when the block has columns).
137    /// This does **not** change column storage — columns remain row-major
138    /// 1D buffers of length `product(shape)`. `shape` is structural
139    /// metadata used by consumers (e.g. the volumetric renderer) to
140    /// unflatten the row index back into N-D coordinates.
141    ///
142    /// Passing an empty slice clears the shape, reverting the block to
143    /// plain-row-table semantics.
144    pub fn set_shape(&mut self, shape: &[usize]) -> Result<(), BlockError> {
145        if shape.is_empty() {
146            self.shape = None;
147            return Ok(());
148        }
149        let prod: usize = shape.iter().product();
150        if let Some(nrows) = self.nrows {
151            if prod != nrows {
152                return Err(BlockError::validation(format!(
153                    "shape product {} does not match block nrows {}",
154                    prod, nrows
155                )));
156            }
157        } else {
158            // Block is empty — adopt nrows = product(shape) so subsequent
159            // inserts validate against the flattened length.
160            self.nrows = Some(prod);
161        }
162        self.shape = Some(shape.to_vec());
163        Ok(())
164    }
165
166    /// Returns true if the Block contains the specified key.
167    #[inline]
168    pub fn contains_key(&self, key: &str) -> bool {
169        self.map.contains_key(key)
170    }
171
172    /// Inserts an array under `key`, enforcing consistent axis-0 length.
173    ///
174    /// This method uses generic dispatch via the `BlockDtype` trait to accept
175    /// any supported type (float, int, bool) without requiring users to
176    /// manually construct Column enums.
177    ///
178    /// # Errors
179    ///
180    /// - Returns `BlockError::RankZero` if the array has rank 0
181    /// - Returns `BlockError::RaggedAxis0` if the array's axis-0 length doesn't
182    ///   match the Block's existing `nrows`
183    ///
184    /// # Examples
185    ///
186    /// ```
187    /// use molrs_core::block::Block;
188    /// use molrs_core::types::{F, I};
189    /// use ndarray::Array1;
190    ///
191    /// let mut block = Block::new();
192    ///
193    /// // Insert float array
194    /// let arr_float = Array1::from_vec(vec![1.0 as F, 2.0 as F]).into_dyn();
195    /// block.insert("x", arr_float).unwrap();
196    ///
197    /// // Insert int array with same nrows
198    /// let arr_int = Array1::from_vec(vec![10 as I, 20 as I]).into_dyn();
199    /// block.insert("id", arr_int).unwrap();
200    ///
201    /// // This would error - different nrows
202    /// let arr_bad = Array1::from_vec(vec![1.0 as F, 2.0 as F, 3.0 as F]).into_dyn();
203    /// assert!(block.insert("bad", arr_bad).is_err());
204    /// ```
205    pub fn insert<T: BlockDtype>(
206        &mut self,
207        key: impl Into<String>,
208        arr: ArrayD<T>,
209    ) -> Result<(), BlockError> {
210        let key = key.into();
211        let shape = arr.shape();
212
213        // Check rank >= 1
214        if shape.is_empty() {
215            return Err(BlockError::RankZero { key });
216        }
217
218        let len0 = shape[0];
219
220        // Check axis-0 consistency
221        match self.nrows {
222            None => {
223                // First insertion defines nrows
224                self.nrows = Some(len0);
225            }
226            Some(expected) => {
227                if len0 != expected {
228                    return Err(BlockError::RaggedAxis0 {
229                        key,
230                        expected,
231                        got: len0,
232                    });
233                }
234            }
235        }
236
237        let col = T::into_column(arr);
238        self.map.insert(key, col);
239        Ok(())
240    }
241
242    /// Insert a pre-built [`Column`] under `key`, validating axis-0 length.
243    ///
244    /// This is the zero-copy insert path: the caller owns a [`Column`]
245    /// (which internally holds an `Arc<ArrayD<T>>`) and hands it over
246    /// without unwrapping the Arc. Useful when moving a column between
247    /// blocks or re-inserting a clone.
248    pub fn insert_column(&mut self, key: impl Into<String>, col: Column) -> Result<(), BlockError> {
249        let key = key.into();
250        let shape = col.shape();
251
252        if shape.is_empty() {
253            return Err(BlockError::RankZero { key });
254        }
255
256        let len0 = shape[0];
257
258        match self.nrows {
259            None => {
260                self.nrows = Some(len0);
261            }
262            Some(expected) => {
263                if len0 != expected {
264                    return Err(BlockError::RaggedAxis0 {
265                        key,
266                        expected,
267                        got: len0,
268                    });
269                }
270            }
271        }
272
273        self.map.insert(key, col);
274        Ok(())
275    }
276
277    /// Gets an immutable reference to the column for `key` if present.
278    ///
279    /// For type-safe access, prefer using `get_float()`, `get_int()`, etc.
280    #[inline]
281    pub fn get(&self, key: &str) -> Option<&Column> {
282        self.map.get(key)
283    }
284
285    /// Gets a mutable reference to the column for `key` if present.
286    ///
287    /// For type-safe access, prefer using `get_float_mut()`, `get_int_mut()`, etc.
288    ///
289    /// # Warning
290    ///
291    /// Mutating the column's shape through this reference is allowed but NOT
292    /// revalidated. It's the caller's responsibility to maintain axis-0 consistency.
293    #[inline]
294    pub fn get_mut(&mut self, key: &str) -> Option<&mut Column> {
295        self.map.get_mut(key)
296    }
297
298    // Type-specific getters for the compile-time float scalar.
299
300    /// Gets an immutable reference to a float array for `key` if present and of correct type.
301    pub fn get_float(&self, key: &str) -> Option<&ArrayD<crate::types::F>> {
302        self.get(key).and_then(|c| c.as_float())
303    }
304
305    /// Gets a mutable reference to a float array for `key` if present and of correct type.
306    pub fn get_float_mut(&mut self, key: &str) -> Option<&mut ArrayD<crate::types::F>> {
307        self.get_mut(key).and_then(|c| c.as_float_mut())
308    }
309
310    // Type-specific getters for the compile-time signed integer scalar.
311
312    /// Gets an immutable reference to an int array for `key` if present and of correct type.
313    pub fn get_int(&self, key: &str) -> Option<&ArrayD<crate::types::I>> {
314        self.get(key).and_then(|c| c.as_int())
315    }
316
317    /// Gets a mutable reference to an int array for `key` if present and of correct type.
318    pub fn get_int_mut(&mut self, key: &str) -> Option<&mut ArrayD<crate::types::I>> {
319        self.get_mut(key).and_then(|c| c.as_int_mut())
320    }
321
322    // Type-specific getters for bool
323
324    /// Gets an immutable reference to a bool array for `key` if present and of correct type.
325    pub fn get_bool(&self, key: &str) -> Option<&ArrayD<bool>> {
326        self.get(key).and_then(|c| c.as_bool())
327    }
328
329    /// Gets a mutable reference to a bool array for `key` if present and of correct type.
330    pub fn get_bool_mut(&mut self, key: &str) -> Option<&mut ArrayD<bool>> {
331        self.get_mut(key).and_then(|c| c.as_bool_mut())
332    }
333
334    // Type-specific getters for the compile-time unsigned integer scalar.
335
336    /// Gets an immutable reference to a uint array for `key` if present and of correct type.
337    pub fn get_uint(&self, key: &str) -> Option<&ArrayD<crate::types::U>> {
338        self.get(key).and_then(|c| c.as_uint())
339    }
340
341    /// Gets a mutable reference to a uint array for `key` if present and of correct type.
342    pub fn get_uint_mut(&mut self, key: &str) -> Option<&mut ArrayD<crate::types::U>> {
343        self.get_mut(key).and_then(|c| c.as_uint_mut())
344    }
345
346    // Type-specific getters for u8
347
348    /// Gets an immutable reference to a u8 array for `key` if present and of correct type.
349    pub fn get_u8(&self, key: &str) -> Option<&ArrayD<u8>> {
350        self.get(key).and_then(|c| c.as_u8())
351    }
352
353    /// Gets a mutable reference to a u8 array for `key` if present and of correct type.
354    pub fn get_u8_mut(&mut self, key: &str) -> Option<&mut ArrayD<u8>> {
355        self.get_mut(key).and_then(|c| c.as_u8_mut())
356    }
357
358    // Type-specific getters for String
359
360    /// Gets an immutable reference to a String array for `key` if present and of correct type.
361    pub fn get_string(&self, key: &str) -> Option<&ArrayD<String>> {
362        self.get(key).and_then(|c| c.as_string())
363    }
364
365    /// Gets a mutable reference to a String array for `key` if present and of correct type.
366    pub fn get_string_mut(&mut self, key: &str) -> Option<&mut ArrayD<String>> {
367        self.get_mut(key).and_then(|c| c.as_string_mut())
368    }
369
370    /// Removes and returns the column for `key`, if present.
371    ///
372    /// If the Block becomes empty after removal, resets `nrows` and
373    /// `shape` to `None`.
374    pub fn remove(&mut self, key: &str) -> Option<Column> {
375        let out = self.map.remove(key);
376        if self.map.is_empty() {
377            self.nrows = None;
378            self.shape = None;
379        }
380        out
381    }
382
383    /// Renames a column from `old_key` to `new_key`.
384    ///
385    /// Returns `true` if the column was successfully renamed, `false` if `old_key` doesn't exist
386    /// or `new_key` already exists.
387    ///
388    /// # Examples
389    ///
390    /// ```
391    /// use molrs_core::block::Block;
392    /// use molrs_core::types::F;
393    /// use ndarray::Array1;
394    ///
395    /// let mut block = Block::new();
396    /// block.insert("x", Array1::from_vec(vec![1.0 as F]).into_dyn()).unwrap();
397    ///
398    /// assert!(block.rename_column("x", "position_x"));
399    /// assert!(!block.contains_key("x"));
400    /// assert!(block.contains_key("position_x"));
401    /// ```
402    pub fn rename_column(&mut self, old_key: &str, new_key: &str) -> bool {
403        // Check if old_key exists and new_key doesn't exist
404        if !self.map.contains_key(old_key) || self.map.contains_key(new_key) {
405            return false;
406        }
407
408        // Remove the old key and re-insert with new key
409        if let Some(column) = self.map.remove(old_key) {
410            self.map.insert(new_key.to_string(), column);
411            true
412        } else {
413            false
414        }
415    }
416
417    /// Clears the Block, removing all keys and resetting `nrows` / `shape`.
418    pub fn clear(&mut self) {
419        self.map.clear();
420        self.nrows = None;
421        self.shape = None;
422    }
423
424    /// Returns an iterator over (&str, &Column).
425    pub fn iter(&self) -> impl Iterator<Item = (&str, &Column)> {
426        self.map.iter().map(|(k, v)| (k.as_str(), v))
427    }
428
429    /// Returns an iterator over keys.
430    pub fn keys(&self) -> impl Iterator<Item = &str> {
431        self.map.keys().map(|k| k.as_str())
432    }
433
434    /// Returns an iterator over column references.
435    pub fn values(&self) -> impl Iterator<Item = &Column> {
436        self.map.values()
437    }
438
439    /// Returns the data type of the column with the given key, if it exists.
440    pub fn dtype(&self, key: &str) -> Option<DType> {
441        self.get(key).map(|c| c.dtype())
442    }
443
444    /// Resize all columns along axis 0 to `new_nrows`.
445    ///
446    /// - **Shrink** (`new_nrows` < current): slices each column to keep the first `new_nrows` rows.
447    /// - **Grow** (`new_nrows` > current): extends each column with default values
448    ///   (0.0 for Float, 0 for Int/UInt/U8, false for Bool, empty string for String).
449    /// - **Same size**: no-op, returns `Ok(())`.
450    /// - **Empty block** (no columns): sets `nrows` without touching columns.
451    ///
452    /// Multi-dimensional columns (e.g. Nx3 positions) are resized only along
453    /// axis 0; trailing dimensions are preserved.
454    ///
455    /// # Arguments
456    /// * `new_nrows` - The desired number of rows after resize.
457    ///
458    /// # Returns
459    /// * `Ok(())` on success.
460    ///
461    /// # Examples
462    ///
463    /// ```
464    /// use molrs_core::block::Block;
465    /// use molrs_core::types::F;
466    /// use ndarray::Array1;
467    ///
468    /// let mut block = Block::new();
469    /// block.insert("x", Array1::from_vec(vec![1.0 as F, 2.0 as F]).into_dyn()).unwrap();
470    ///
471    /// block.resize(4).unwrap();
472    /// assert_eq!(block.nrows(), Some(4));
473    /// let x = block.get_float("x").unwrap();
474    /// assert_eq!(x.as_slice_memory_order().unwrap(), &[1.0, 2.0, 0.0, 0.0]);
475    /// ```
476    pub fn resize(&mut self, new_nrows: usize) -> Result<(), crate::error::MolRsError> {
477        if self.is_empty() {
478            self.nrows = Some(new_nrows);
479            return Ok(());
480        }
481
482        let current = self.nrows.unwrap_or(0);
483        if new_nrows == current {
484            return Ok(());
485        }
486
487        for col in self.map.values_mut() {
488            col.resize(new_nrows);
489        }
490        self.nrows = Some(new_nrows);
491        // N-D shape becomes meaningless once axis-0 row count is changed
492        // by a 1D resize. Callers that want to preserve a grid shape must
493        // re-declare it via `set_shape` after resizing.
494        self.shape = None;
495        Ok(())
496    }
497
498    /// Merge another block into this one by concatenating columns along axis-0.
499    ///
500    /// Both blocks must have the same set of column keys and matching dtypes.
501    /// The resulting block will have nrows = self.nrows + other.nrows.
502    ///
503    /// # Arguments
504    /// * `other` - The block to merge into this one
505    ///
506    /// # Returns
507    /// * `Ok(())` if merge succeeds
508    /// * `Err(BlockError)` if blocks have incompatible columns
509    ///
510    /// # Examples
511    ///
512    /// ```
513    /// use molrs_core::block::Block;
514    /// use molrs_core::types::F;
515    /// use ndarray::Array1;
516    ///
517    /// let mut block1 = Block::new();
518    /// block1.insert("x", Array1::from_vec(vec![1.0 as F, 2.0 as F]).into_dyn()).unwrap();
519    ///
520    /// let mut block2 = Block::new();
521    /// block2.insert("x", Array1::from_vec(vec![3.0 as F, 4.0 as F]).into_dyn()).unwrap();
522    ///
523    /// block1.merge(&block2).unwrap();
524    /// assert_eq!(block1.nrows(), Some(4));
525    /// ```
526    pub fn merge(&mut self, other: &Block) -> Result<(), BlockError> {
527        use ndarray::Axis;
528        use ndarray::concatenate;
529
530        // If other is empty, nothing to do
531        if other.is_empty() {
532            return Ok(());
533        }
534
535        // If self is empty, clone other
536        if self.is_empty() {
537            self.map = other.map.clone();
538            self.nrows = other.nrows;
539            self.shape = other.shape.clone();
540            return Ok(());
541        }
542
543        // Check that both blocks have the same keys
544        let self_keys: std::collections::HashSet<_> = self.keys().collect();
545        let other_keys: std::collections::HashSet<_> = other.keys().collect();
546
547        if self_keys != other_keys {
548            return Err(BlockError::validation(format!(
549                "Cannot merge blocks with different keys. Self has {:?}, other has {:?}",
550                self_keys, other_keys
551            )));
552        }
553
554        // Merge each column
555        let mut new_map = HashMap::new();
556        for key in self.keys() {
557            let self_col = &self.map[key];
558            let other_col = &other.map[key];
559
560            // Check dtype compatibility
561            if self_col.dtype() != other_col.dtype() {
562                return Err(BlockError::validation(format!(
563                    "Column '{}' has incompatible dtypes: {:?} vs {:?}",
564                    key,
565                    self_col.dtype(),
566                    other_col.dtype()
567                )));
568            }
569
570            // Concatenate based on dtype
571            let merged_col = match (self_col, other_col) {
572                (Column::Float(a), Column::Float(b)) => {
573                    let merged = concatenate(Axis(0), &[a.view(), b.view()]).map_err(|e| {
574                        BlockError::validation(format!(
575                            "Failed to concatenate float column '{}': {}",
576                            key, e
577                        ))
578                    })?;
579                    Column::from_float(merged)
580                }
581                (Column::Int(a), Column::Int(b)) => {
582                    let merged = concatenate(Axis(0), &[a.view(), b.view()]).map_err(|e| {
583                        BlockError::validation(format!(
584                            "Failed to concatenate int column '{}': {}",
585                            key, e
586                        ))
587                    })?;
588                    Column::from_int(merged)
589                }
590                (Column::UInt(a), Column::UInt(b)) => {
591                    let merged = concatenate(Axis(0), &[a.view(), b.view()]).map_err(|e| {
592                        BlockError::validation(format!(
593                            "Failed to concatenate uint column '{}': {}",
594                            key, e
595                        ))
596                    })?;
597                    Column::from_uint(merged)
598                }
599                (Column::U8(a), Column::U8(b)) => {
600                    let merged = concatenate(Axis(0), &[a.view(), b.view()]).map_err(|e| {
601                        BlockError::validation(format!(
602                            "Failed to concatenate u8 column '{}': {}",
603                            key, e
604                        ))
605                    })?;
606                    Column::from_u8(merged)
607                }
608                (Column::Bool(a), Column::Bool(b)) => {
609                    let merged = concatenate(Axis(0), &[a.view(), b.view()]).map_err(|e| {
610                        BlockError::validation(format!(
611                            "Failed to concatenate bool column '{}': {}",
612                            key, e
613                        ))
614                    })?;
615                    Column::from_bool(merged)
616                }
617                (Column::String(a), Column::String(b)) => {
618                    let merged = concatenate(Axis(0), &[a.view(), b.view()]).map_err(|e| {
619                        BlockError::validation(format!(
620                            "Failed to concatenate string column '{}': {}",
621                            key, e
622                        ))
623                    })?;
624                    Column::from_string(merged)
625                }
626                _ => unreachable!("dtype mismatch already checked"),
627            };
628
629            new_map.insert(key.to_string(), merged_col);
630        }
631
632        // Update nrows. As with `resize`, an explicit N-D shape becomes
633        // meaningless once axis-0 grows; the merged block falls back to a
634        // plain row table unless the caller re-declares a shape.
635        let new_nrows = self.nrows.unwrap() + other.nrows.unwrap();
636        self.map = new_map;
637        self.nrows = Some(new_nrows);
638        self.shape = None;
639
640        Ok(())
641    }
642}
643
644// Index trait for convenient access: block["key"]
645impl Index<&str> for Block {
646    type Output = Column;
647
648    fn index(&self, key: &str) -> &Self::Output {
649        self.get(key)
650            .unwrap_or_else(|| panic!("key '{}' not found in Block", key))
651    }
652}
653
654impl IndexMut<&str> for Block {
655    fn index_mut(&mut self, key: &str) -> &mut Self::Output {
656        self.get_mut(key)
657            .unwrap_or_else(|| panic!("key '{}' not found in Block", key))
658    }
659}
660
661#[cfg(test)]
662mod tests {
663    use super::*;
664    use crate::types::{F, I};
665    use ndarray::Array1;
666
667    #[test]
668    fn test_insert_mixed_dtypes() {
669        let mut block = Block::new();
670
671        let arr_float = Array1::from_vec(vec![1.0 as F, 2.0 as F, 3.0 as F]).into_dyn();
672        let arr_float_2 = Array1::from_vec(vec![4.0 as F, 5.0 as F, 6.0 as F]).into_dyn();
673        let arr_i64 = Array1::from_vec(vec![10 as I, 20, 30]).into_dyn();
674        let arr_bool = Array1::from_vec(vec![true, false, true]).into_dyn();
675
676        assert!(block.insert("x", arr_float).is_ok());
677        assert!(block.insert("y", arr_float_2).is_ok());
678        assert!(block.insert("id", arr_i64).is_ok());
679        assert!(block.insert("mask", arr_bool).is_ok());
680
681        assert_eq!(block.len(), 4);
682        assert_eq!(block.nrows(), Some(3));
683    }
684
685    #[test]
686    fn test_axis0_mismatch_error() {
687        let mut block = Block::new();
688
689        let arr1 = Array1::from_vec(vec![1.0 as F, 2.0 as F, 3.0 as F]).into_dyn();
690        let arr2 = Array1::from_vec(vec![1.0 as F, 2.0 as F]).into_dyn();
691
692        block.insert("x", arr1).unwrap();
693        let result = block.insert("y", arr2);
694
695        assert!(result.is_err());
696        match result {
697            Err(BlockError::RaggedAxis0 { expected, got, .. }) => {
698                assert_eq!(expected, 3);
699                assert_eq!(got, 2);
700            }
701            _ => panic!("Expected RaggedAxis0 error"),
702        }
703    }
704
705    #[test]
706    fn test_typed_getters() {
707        let mut block = Block::new();
708
709        let arr_float = Array1::from_vec(vec![1.0 as F, 2.0 as F]).into_dyn();
710        let arr_i64 = Array1::from_vec(vec![10 as I, 20]).into_dyn();
711
712        block.insert("x", arr_float).unwrap();
713        block.insert("id", arr_i64).unwrap();
714
715        // Correct type access
716        assert!(block.get_float("x").is_some());
717        assert!(block.get_int("id").is_some());
718
719        // Wrong type access returns None
720        assert!(block.get_int("x").is_none());
721        assert!(block.get_float("id").is_none());
722
723        // Mutable access
724        if let Some(x_mut) = block.get_float_mut("x") {
725            x_mut[[0]] = 99.0;
726        }
727        assert_eq!(block.get_float("x").unwrap()[[0]], 99.0);
728    }
729
730    #[test]
731    fn test_index_access() {
732        let mut block = Block::new();
733
734        let arr = Array1::from_vec(vec![1.0 as F, 2.0 as F]).into_dyn();
735        block.insert("x", arr).unwrap();
736
737        // Immutable index
738        let col = &block["x"];
739        assert_eq!(col.dtype(), DType::Float);
740
741        // Mutable index
742        let col_mut = &mut block["x"];
743        if let Some(arr_mut) = col_mut.as_float_mut() {
744            arr_mut[[0]] = 42.0;
745        }
746        assert_eq!(block.get_float("x").unwrap()[[0]], 42.0);
747    }
748
749    #[test]
750    #[should_panic(expected = "key 'missing' not found")]
751    fn test_index_panic_on_missing_key() {
752        let block = Block::new();
753        let _ = &block["missing"];
754    }
755
756    #[test]
757    fn test_remove_resets_nrows() {
758        let mut block = Block::new();
759
760        let arr = Array1::from_vec(vec![1.0 as F, 2.0 as F]).into_dyn();
761        block.insert("x", arr).unwrap();
762
763        assert_eq!(block.nrows(), Some(2));
764
765        block.remove("x");
766        assert_eq!(block.nrows(), None);
767        assert!(block.is_empty());
768    }
769
770    #[test]
771    fn test_iter_keys_values() {
772        let mut block = Block::new();
773
774        let arr1 = Array1::from_vec(vec![1.0 as F, 2.0 as F]).into_dyn();
775        let arr2 = Array1::from_vec(vec![10 as I, 20]).into_dyn();
776
777        block.insert("x", arr1).unwrap();
778        block.insert("id", arr2).unwrap();
779
780        let keys: Vec<&str> = block.keys().collect();
781        assert_eq!(keys.len(), 2);
782        assert!(keys.contains(&"x"));
783        assert!(keys.contains(&"id"));
784
785        let dtypes: Vec<DType> = block.values().map(|c| c.dtype()).collect();
786        assert!(dtypes.contains(&DType::Float));
787        assert!(dtypes.contains(&DType::Int));
788    }
789
790    #[test]
791    fn test_rank_zero_error() {
792        let mut block = Block::new();
793
794        // Create a rank-0 array (scalar)
795        let arr = ArrayD::<F>::zeros(vec![]);
796
797        let result = block.insert("scalar", arr);
798        assert!(result.is_err());
799        match result {
800            Err(BlockError::RankZero { key }) => {
801                assert_eq!(key, "scalar");
802            }
803            _ => panic!("Expected RankZero error"),
804        }
805    }
806
807    #[test]
808    fn test_dtype_query() {
809        let mut block = Block::new();
810
811        let arr_float = Array1::from_vec(vec![1.0 as F]).into_dyn();
812        let arr_i64 = Array1::from_vec(vec![10 as I]).into_dyn();
813
814        block.insert("x", arr_float).unwrap();
815        block.insert("id", arr_i64).unwrap();
816
817        assert_eq!(block.dtype("x"), Some(DType::Float));
818        assert_eq!(block.dtype("id"), Some(DType::Int));
819        assert_eq!(block.dtype("missing"), None);
820    }
821
822    #[test]
823    fn test_merge_basic() {
824        let mut block1 = Block::new();
825        let mut block2 = Block::new();
826
827        let arr1 = Array1::from_vec(vec![1.0 as F, 2.0 as F]).into_dyn();
828        let arr2 = Array1::from_vec(vec![3.0 as F, 4.0 as F]).into_dyn();
829
830        block1.insert("x", arr1).unwrap();
831        block2.insert("x", arr2).unwrap();
832
833        block1.merge(&block2).unwrap();
834
835        assert_eq!(block1.nrows(), Some(4));
836        let x = block1.get_float("x").unwrap();
837        assert_eq!(x.as_slice_memory_order().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
838    }
839
840    #[test]
841    fn test_merge_empty_blocks() {
842        let mut block1 = Block::new();
843        let mut block2 = Block::new();
844
845        // Merge empty into empty
846        block1.merge(&block2).unwrap();
847        assert_eq!(block1.nrows(), None);
848
849        // Merge non-empty into empty
850        let arr = Array1::from_vec(vec![1.0 as F, 2.0 as F]).into_dyn();
851        block2.insert("x", arr).unwrap();
852        block1.merge(&block2).unwrap();
853        assert_eq!(block1.nrows(), Some(2));
854
855        // Merge empty into non-empty
856        let block3 = Block::new();
857        block1.merge(&block3).unwrap();
858        assert_eq!(block1.nrows(), Some(2));
859    }
860
861    #[test]
862    fn test_merge_incompatible_keys() {
863        let mut block1 = Block::new();
864        let mut block2 = Block::new();
865
866        let arr1 = Array1::from_vec(vec![1.0 as F, 2.0 as F]).into_dyn();
867        let arr2 = Array1::from_vec(vec![3.0 as F, 4.0 as F]).into_dyn();
868
869        block1.insert("x", arr1).unwrap();
870        block2.insert("y", arr2).unwrap();
871
872        let result = block1.merge(&block2);
873        assert!(result.is_err());
874    }
875
876    #[test]
877    fn test_merge_incompatible_dtypes() {
878        let mut block1 = Block::new();
879        let mut block2 = Block::new();
880
881        let arr1 = Array1::from_vec(vec![1.0 as F, 2.0 as F]).into_dyn();
882        let arr2 = Array1::from_vec(vec![3 as I, 4]).into_dyn();
883
884        block1.insert("x", arr1).unwrap();
885        block2.insert("x", arr2).unwrap();
886
887        let result = block1.merge(&block2);
888        assert!(result.is_err());
889    }
890
891    #[test]
892    fn test_merge_multiple_columns() {
893        let mut block1 = Block::new();
894        let mut block2 = Block::new();
895
896        block1
897            .insert("x", Array1::from_vec(vec![1.0 as F, 2.0 as F]).into_dyn())
898            .unwrap();
899        block1
900            .insert("id", Array1::from_vec(vec![10 as I, 20]).into_dyn())
901            .unwrap();
902
903        block2
904            .insert("x", Array1::from_vec(vec![3.0 as F, 4.0 as F]).into_dyn())
905            .unwrap();
906        block2
907            .insert("id", Array1::from_vec(vec![30 as I, 40]).into_dyn())
908            .unwrap();
909
910        block1.merge(&block2).unwrap();
911
912        assert_eq!(block1.nrows(), Some(4));
913        let x = block1.get_float("x").unwrap();
914        assert_eq!(x.as_slice_memory_order().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
915        let id = block1.get_int("id").unwrap();
916        assert_eq!(id.as_slice_memory_order().unwrap(), &[10, 20, 30, 40]);
917    }
918
919    #[test]
920    fn test_rename_column() {
921        let mut block = Block::new();
922        block
923            .insert("x", Array1::from_vec(vec![1.0 as F, 2.0 as F]).into_dyn())
924            .unwrap();
925        block
926            .insert("y", Array1::from_vec(vec![3.0 as F, 4.0 as F]).into_dyn())
927            .unwrap();
928
929        // Successful rename
930        assert!(block.rename_column("x", "position_x"));
931        assert!(!block.contains_key("x"));
932        assert!(block.contains_key("position_x"));
933        assert_eq!(
934            block
935                .get_float("position_x")
936                .unwrap()
937                .as_slice_memory_order()
938                .unwrap(),
939            &[1.0, 2.0]
940        );
941
942        // Try to rename non-existent column
943        assert!(!block.rename_column("nonexistent", "new_name"));
944
945        // Try to rename to existing column name
946        assert!(!block.rename_column("position_x", "y"));
947    }
948
949    #[test]
950    fn test_resize_shrink() {
951        let mut block = Block::new();
952        block
953            .insert(
954                "x",
955                Array1::from_vec(vec![1.0 as F, 2.0, 3.0, 4.0]).into_dyn(),
956            )
957            .unwrap();
958        block
959            .insert("id", Array1::from_vec(vec![10 as I, 20, 30, 40]).into_dyn())
960            .unwrap();
961
962        block.resize(2).unwrap();
963
964        assert_eq!(block.nrows(), Some(2));
965        let x = block.get_float("x").unwrap();
966        assert_eq!(x.as_slice_memory_order().unwrap(), &[1.0, 2.0]);
967        let id = block.get_int("id").unwrap();
968        assert_eq!(id.as_slice_memory_order().unwrap(), &[10, 20]);
969    }
970
971    #[test]
972    fn test_resize_grow() {
973        let mut block = Block::new();
974        block
975            .insert("x", Array1::from_vec(vec![1.0 as F, 2.0]).into_dyn())
976            .unwrap();
977        block
978            .insert("id", Array1::from_vec(vec![10 as I, 20]).into_dyn())
979            .unwrap();
980
981        block.resize(4).unwrap();
982
983        assert_eq!(block.nrows(), Some(4));
984        let x = block.get_float("x").unwrap();
985        // Original data preserved, new rows are 0.0
986        assert_eq!(x.as_slice_memory_order().unwrap(), &[1.0, 2.0, 0.0, 0.0]);
987        let id = block.get_int("id").unwrap();
988        // Original data preserved, new rows are 0
989        assert_eq!(id.as_slice_memory_order().unwrap(), &[10, 20, 0, 0]);
990    }
991
992    #[test]
993    fn test_resize_same() {
994        let mut block = Block::new();
995        block
996            .insert("x", Array1::from_vec(vec![1.0 as F, 2.0, 3.0]).into_dyn())
997            .unwrap();
998
999        block.resize(3).unwrap();
1000
1001        assert_eq!(block.nrows(), Some(3));
1002        let x = block.get_float("x").unwrap();
1003        assert_eq!(x.as_slice_memory_order().unwrap(), &[1.0, 2.0, 3.0]);
1004    }
1005
1006    #[test]
1007    fn test_resize_empty() {
1008        let mut block = Block::new();
1009
1010        block.resize(5).unwrap();
1011        assert_eq!(block.nrows(), Some(5));
1012        assert!(block.is_empty());
1013    }
1014
1015    #[test]
1016    fn test_resize_multidim() {
1017        use ndarray::Array2;
1018
1019        let mut block = Block::new();
1020        // 4x3 position array
1021        let pos = Array2::from_shape_vec(
1022            (4, 3),
1023            vec![
1024                1.0 as F, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
1025            ],
1026        )
1027        .unwrap()
1028        .into_dyn();
1029        block.insert("pos", pos).unwrap();
1030
1031        // Shrink 4x3 -> 2x3
1032        block.resize(2).unwrap();
1033        assert_eq!(block.nrows(), Some(2));
1034        let pos = block.get_float("pos").unwrap();
1035        assert_eq!(pos.shape(), &[2, 3]);
1036        assert_eq!(
1037            pos.as_slice_memory_order().unwrap(),
1038            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
1039        );
1040
1041        // Grow 2x3 -> 5x3
1042        block.resize(5).unwrap();
1043        assert_eq!(block.nrows(), Some(5));
1044        let pos = block.get_float("pos").unwrap();
1045        assert_eq!(pos.shape(), &[5, 3]);
1046        // Original data followed by zeros
1047        assert_eq!(
1048            pos.as_slice_memory_order().unwrap(),
1049            &[
1050                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
1051            ]
1052        );
1053    }
1054
1055    #[test]
1056    fn test_resize_mixed_dtypes() {
1057        let mut block = Block::new();
1058        block
1059            .insert("x", Array1::from_vec(vec![1.0 as F, 2.0, 3.0]).into_dyn())
1060            .unwrap();
1061        block
1062            .insert("id", Array1::from_vec(vec![10 as I, 20, 30]).into_dyn())
1063            .unwrap();
1064        block
1065            .insert("mask", Array1::from_vec(vec![true, false, true]).into_dyn())
1066            .unwrap();
1067        block
1068            .insert(
1069                "name",
1070                Array1::from_vec(vec!["a".to_string(), "b".to_string(), "c".to_string()])
1071                    .into_dyn(),
1072            )
1073            .unwrap();
1074
1075        // Grow from 3 to 5
1076        block.resize(5).unwrap();
1077        assert_eq!(block.nrows(), Some(5));
1078
1079        let x = block.get_float("x").unwrap();
1080        assert_eq!(
1081            x.as_slice_memory_order().unwrap(),
1082            &[1.0, 2.0, 3.0, 0.0, 0.0]
1083        );
1084        let id = block.get_int("id").unwrap();
1085        assert_eq!(id.as_slice_memory_order().unwrap(), &[10, 20, 30, 0, 0]);
1086        let mask = block.get_bool("mask").unwrap();
1087        assert_eq!(
1088            mask.as_slice_memory_order().unwrap(),
1089            &[true, false, true, false, false]
1090        );
1091        let name = block.get_string("name").unwrap();
1092        assert_eq!(name[[0]], "a");
1093        assert_eq!(name[[1]], "b");
1094        assert_eq!(name[[2]], "c");
1095        assert_eq!(name[[3]], "");
1096        assert_eq!(name[[4]], "");
1097
1098        // Shrink from 5 to 2
1099        block.resize(2).unwrap();
1100        assert_eq!(block.nrows(), Some(2));
1101
1102        let x = block.get_float("x").unwrap();
1103        assert_eq!(x.as_slice_memory_order().unwrap(), &[1.0, 2.0]);
1104        let id = block.get_int("id").unwrap();
1105        assert_eq!(id.as_slice_memory_order().unwrap(), &[10, 20]);
1106        let mask = block.get_bool("mask").unwrap();
1107        assert_eq!(mask.as_slice_memory_order().unwrap(), &[true, false]);
1108        let name = block.get_string("name").unwrap();
1109        assert_eq!(name[[0]], "a");
1110        assert_eq!(name[[1]], "b");
1111    }
1112}