Skip to main content

molrs/block/
mod.rs

1//! Block: dict-like keyed arrays with consistent axis-0 length and heterogeneous types.
2//!
3//! A Block stores heterogeneous arrays (float, int, bool) keyed by strings,
4//! enforcing that all stored arrays share the same axis-0 length (nrows).
5//!
6//! # Examples
7//!
8//! ```
9//! use molrs::block::Block;
10//! use molrs::types::{F, I};
11//! use ndarray::{Array1, ArrayD};
12//!
13//! let mut block = Block::new();
14//!
15//! // Insert different types - generic dispatch handles the conversion
16//! let pos = Array1::from_vec(vec![1.0 as F, 2.0 as F, 3.0 as F]).into_dyn();
17//! let ids = Array1::from_vec(vec![10 as I, 20 as I, 30 as I]).into_dyn();
18//!
19//! block.insert("pos", pos).unwrap();
20//! block.insert("id", ids).unwrap();
21//!
22//! // Type-safe retrieval
23//! let pos_ref = block.get_float("pos").unwrap();
24//! let ids_ref = block.get_int("id").unwrap();
25//!
26//! assert_eq!(block.nrows(), Some(3));
27//! assert_eq!(block.len(), 2);
28//! ```
29
30mod column;
31mod dtype;
32mod error;
33
34pub use column::Column;
35pub use dtype::{BlockDtype, DType};
36pub use error::BlockError;
37
38use ndarray::ArrayD;
39use std::collections::HashMap;
40use std::ops::{Index, IndexMut};
41
42/// A dictionary from string keys to ndarray arrays with a consistent axis-0 length.
43///
44/// This Block supports heterogeneous column types (float, int, bool).
45#[derive(Default, Clone)]
46pub struct Block {
47    map: HashMap<String, Column>,
48    nrows: Option<usize>,
49}
50
51impl std::fmt::Debug for Block {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        let mut map = f.debug_map();
54        for (k, v) in &self.map {
55            map.entry(k, &format!("{}(shape={:?})", v.dtype(), v.shape()));
56        }
57        map.finish()
58    }
59}
60
61impl Block {
62    /// Creates an empty Block.
63    pub fn new() -> Self {
64        Self {
65            map: HashMap::new(),
66            nrows: None,
67        }
68    }
69
70    /// Creates an empty Block with the specified capacity.
71    pub fn with_capacity(cap: usize) -> Self {
72        Self {
73            map: HashMap::with_capacity(cap),
74            nrows: None,
75        }
76    }
77
78    /// Number of keys (columns).
79    #[inline]
80    pub fn len(&self) -> usize {
81        self.map.len()
82    }
83
84    /// Returns true if there are no arrays in the block.
85    #[inline]
86    pub fn is_empty(&self) -> bool {
87        self.map.is_empty()
88    }
89
90    /// Returns the common axis-0 length of all arrays, or `None` if empty.
91    #[inline]
92    pub fn nrows(&self) -> Option<usize> {
93        self.nrows
94    }
95
96    /// Returns true if the Block contains the specified key.
97    #[inline]
98    pub fn contains_key(&self, key: &str) -> bool {
99        self.map.contains_key(key)
100    }
101
102    /// Inserts an array under `key`, enforcing consistent axis-0 length.
103    ///
104    /// This method uses generic dispatch via the `BlockDtype` trait to accept
105    /// any supported type (float, int, bool) without requiring users to
106    /// manually construct Column enums.
107    ///
108    /// # Errors
109    ///
110    /// - Returns `BlockError::RankZero` if the array has rank 0
111    /// - Returns `BlockError::RaggedAxis0` if the array's axis-0 length doesn't
112    ///   match the Block's existing `nrows`
113    ///
114    /// # Examples
115    ///
116    /// ```
117    /// use molrs::block::Block;
118    /// use molrs::types::{F, I};
119    /// use ndarray::Array1;
120    ///
121    /// let mut block = Block::new();
122    ///
123    /// // Insert float array
124    /// let arr_float = Array1::from_vec(vec![1.0 as F, 2.0 as F]).into_dyn();
125    /// block.insert("x", arr_float).unwrap();
126    ///
127    /// // Insert int array with same nrows
128    /// let arr_int = Array1::from_vec(vec![10 as I, 20 as I]).into_dyn();
129    /// block.insert("id", arr_int).unwrap();
130    ///
131    /// // This would error - different nrows
132    /// let arr_bad = Array1::from_vec(vec![1.0 as F, 2.0 as F, 3.0 as F]).into_dyn();
133    /// assert!(block.insert("bad", arr_bad).is_err());
134    /// ```
135    pub fn insert<T: BlockDtype>(
136        &mut self,
137        key: impl Into<String>,
138        arr: ArrayD<T>,
139    ) -> Result<(), BlockError> {
140        let key = key.into();
141        let shape = arr.shape();
142
143        // Check rank >= 1
144        if shape.is_empty() {
145            return Err(BlockError::RankZero { key });
146        }
147
148        let len0 = shape[0];
149
150        // Check axis-0 consistency
151        match self.nrows {
152            None => {
153                // First insertion defines nrows
154                self.nrows = Some(len0);
155            }
156            Some(expected) => {
157                if len0 != expected {
158                    return Err(BlockError::RaggedAxis0 {
159                        key,
160                        expected,
161                        got: len0,
162                    });
163                }
164            }
165        }
166
167        let col = T::into_column(arr);
168        self.map.insert(key, col);
169        Ok(())
170    }
171
172    /// Gets an immutable reference to the column for `key` if present.
173    ///
174    /// For type-safe access, prefer using `get_float()`, `get_int()`, etc.
175    #[inline]
176    pub fn get(&self, key: &str) -> Option<&Column> {
177        self.map.get(key)
178    }
179
180    /// Gets a mutable reference to the column for `key` if present.
181    ///
182    /// For type-safe access, prefer using `get_float_mut()`, `get_int_mut()`, etc.
183    ///
184    /// # Warning
185    ///
186    /// Mutating the column's shape through this reference is allowed but NOT
187    /// revalidated. It's the caller's responsibility to maintain axis-0 consistency.
188    #[inline]
189    pub fn get_mut(&mut self, key: &str) -> Option<&mut Column> {
190        self.map.get_mut(key)
191    }
192
193    // Type-specific getters for the compile-time float scalar.
194
195    /// Gets an immutable reference to a float array for `key` if present and of correct type.
196    pub fn get_float(&self, key: &str) -> Option<&ArrayD<crate::types::F>> {
197        self.get(key).and_then(|c| c.as_float())
198    }
199
200    /// Gets a mutable reference to a float array for `key` if present and of correct type.
201    pub fn get_float_mut(&mut self, key: &str) -> Option<&mut ArrayD<crate::types::F>> {
202        self.get_mut(key).and_then(|c| c.as_float_mut())
203    }
204
205    // Type-specific getters for the compile-time signed integer scalar.
206
207    /// Gets an immutable reference to an int array for `key` if present and of correct type.
208    pub fn get_int(&self, key: &str) -> Option<&ArrayD<crate::types::I>> {
209        self.get(key).and_then(|c| c.as_int())
210    }
211
212    /// Gets a mutable reference to an int array for `key` if present and of correct type.
213    pub fn get_int_mut(&mut self, key: &str) -> Option<&mut ArrayD<crate::types::I>> {
214        self.get_mut(key).and_then(|c| c.as_int_mut())
215    }
216
217    // Type-specific getters for bool
218
219    /// Gets an immutable reference to a bool array for `key` if present and of correct type.
220    pub fn get_bool(&self, key: &str) -> Option<&ArrayD<bool>> {
221        self.get(key).and_then(|c| c.as_bool())
222    }
223
224    /// Gets a mutable reference to a bool array for `key` if present and of correct type.
225    pub fn get_bool_mut(&mut self, key: &str) -> Option<&mut ArrayD<bool>> {
226        self.get_mut(key).and_then(|c| c.as_bool_mut())
227    }
228
229    // Type-specific getters for the compile-time unsigned integer scalar.
230
231    /// Gets an immutable reference to a uint array for `key` if present and of correct type.
232    pub fn get_uint(&self, key: &str) -> Option<&ArrayD<crate::types::U>> {
233        self.get(key).and_then(|c| c.as_uint())
234    }
235
236    /// Gets a mutable reference to a uint array for `key` if present and of correct type.
237    pub fn get_uint_mut(&mut self, key: &str) -> Option<&mut ArrayD<crate::types::U>> {
238        self.get_mut(key).and_then(|c| c.as_uint_mut())
239    }
240
241    // Type-specific getters for u8
242
243    /// Gets an immutable reference to a u8 array for `key` if present and of correct type.
244    pub fn get_u8(&self, key: &str) -> Option<&ArrayD<u8>> {
245        self.get(key).and_then(|c| c.as_u8())
246    }
247
248    /// Gets a mutable reference to a u8 array for `key` if present and of correct type.
249    pub fn get_u8_mut(&mut self, key: &str) -> Option<&mut ArrayD<u8>> {
250        self.get_mut(key).and_then(|c| c.as_u8_mut())
251    }
252
253    // Type-specific getters for String
254
255    /// Gets an immutable reference to a String array for `key` if present and of correct type.
256    pub fn get_string(&self, key: &str) -> Option<&ArrayD<String>> {
257        self.get(key).and_then(|c| c.as_string())
258    }
259
260    /// Gets a mutable reference to a String array for `key` if present and of correct type.
261    pub fn get_string_mut(&mut self, key: &str) -> Option<&mut ArrayD<String>> {
262        self.get_mut(key).and_then(|c| c.as_string_mut())
263    }
264
265    /// Removes and returns the column for `key`, if present.
266    ///
267    /// If the Block becomes empty after removal, resets `nrows` to `None`.
268    pub fn remove(&mut self, key: &str) -> Option<Column> {
269        let out = self.map.remove(key);
270        if self.map.is_empty() {
271            self.nrows = None;
272        }
273        out
274    }
275
276    /// Renames a column from `old_key` to `new_key`.
277    ///
278    /// Returns `true` if the column was successfully renamed, `false` if `old_key` doesn't exist
279    /// or `new_key` already exists.
280    ///
281    /// # Examples
282    ///
283    /// ```
284    /// use molrs::block::Block;
285    /// use molrs::types::F;
286    /// use ndarray::Array1;
287    ///
288    /// let mut block = Block::new();
289    /// block.insert("x", Array1::from_vec(vec![1.0 as F]).into_dyn()).unwrap();
290    ///
291    /// assert!(block.rename_column("x", "position_x"));
292    /// assert!(!block.contains_key("x"));
293    /// assert!(block.contains_key("position_x"));
294    /// ```
295    pub fn rename_column(&mut self, old_key: &str, new_key: &str) -> bool {
296        // Check if old_key exists and new_key doesn't exist
297        if !self.map.contains_key(old_key) || self.map.contains_key(new_key) {
298            return false;
299        }
300
301        // Remove the old key and re-insert with new key
302        if let Some(column) = self.map.remove(old_key) {
303            self.map.insert(new_key.to_string(), column);
304            true
305        } else {
306            false
307        }
308    }
309
310    /// Clears the Block, removing all keys and resetting `nrows`.
311    pub fn clear(&mut self) {
312        self.map.clear();
313        self.nrows = None;
314    }
315
316    /// Returns an iterator over (&str, &Column).
317    pub fn iter(&self) -> impl Iterator<Item = (&str, &Column)> {
318        self.map.iter().map(|(k, v)| (k.as_str(), v))
319    }
320
321    /// Returns an iterator over keys.
322    pub fn keys(&self) -> impl Iterator<Item = &str> {
323        self.map.keys().map(|k| k.as_str())
324    }
325
326    /// Returns an iterator over column references.
327    pub fn values(&self) -> impl Iterator<Item = &Column> {
328        self.map.values()
329    }
330
331    /// Returns the data type of the column with the given key, if it exists.
332    pub fn dtype(&self, key: &str) -> Option<DType> {
333        self.get(key).map(|c| c.dtype())
334    }
335
336    /// Resize all columns along axis 0 to `new_nrows`.
337    ///
338    /// - **Shrink** (`new_nrows` < current): slices each column to keep the first `new_nrows` rows.
339    /// - **Grow** (`new_nrows` > current): extends each column with default values
340    ///   (0.0 for Float, 0 for Int/UInt/U8, false for Bool, empty string for String).
341    /// - **Same size**: no-op, returns `Ok(())`.
342    /// - **Empty block** (no columns): sets `nrows` without touching columns.
343    ///
344    /// Multi-dimensional columns (e.g. Nx3 positions) are resized only along
345    /// axis 0; trailing dimensions are preserved.
346    ///
347    /// # Arguments
348    /// * `new_nrows` - The desired number of rows after resize.
349    ///
350    /// # Returns
351    /// * `Ok(())` on success.
352    ///
353    /// # Examples
354    ///
355    /// ```
356    /// use molrs::block::Block;
357    /// use molrs::types::F;
358    /// use ndarray::Array1;
359    ///
360    /// let mut block = Block::new();
361    /// block.insert("x", Array1::from_vec(vec![1.0 as F, 2.0 as F]).into_dyn()).unwrap();
362    ///
363    /// block.resize(4).unwrap();
364    /// assert_eq!(block.nrows(), Some(4));
365    /// let x = block.get_float("x").unwrap();
366    /// assert_eq!(x.as_slice_memory_order().unwrap(), &[1.0, 2.0, 0.0, 0.0]);
367    /// ```
368    pub fn resize(&mut self, new_nrows: usize) -> Result<(), crate::error::MolRsError> {
369        if self.is_empty() {
370            self.nrows = Some(new_nrows);
371            return Ok(());
372        }
373
374        let current = self.nrows.unwrap_or(0);
375        if new_nrows == current {
376            return Ok(());
377        }
378
379        for col in self.map.values_mut() {
380            col.resize(new_nrows);
381        }
382        self.nrows = Some(new_nrows);
383        Ok(())
384    }
385
386    /// Merge another block into this one by concatenating columns along axis-0.
387    ///
388    /// Both blocks must have the same set of column keys and matching dtypes.
389    /// The resulting block will have nrows = self.nrows + other.nrows.
390    ///
391    /// # Arguments
392    /// * `other` - The block to merge into this one
393    ///
394    /// # Returns
395    /// * `Ok(())` if merge succeeds
396    /// * `Err(BlockError)` if blocks have incompatible columns
397    ///
398    /// # Examples
399    ///
400    /// ```
401    /// use molrs::block::Block;
402    /// use molrs::types::F;
403    /// use ndarray::Array1;
404    ///
405    /// let mut block1 = Block::new();
406    /// block1.insert("x", Array1::from_vec(vec![1.0 as F, 2.0 as F]).into_dyn()).unwrap();
407    ///
408    /// let mut block2 = Block::new();
409    /// block2.insert("x", Array1::from_vec(vec![3.0 as F, 4.0 as F]).into_dyn()).unwrap();
410    ///
411    /// block1.merge(&block2).unwrap();
412    /// assert_eq!(block1.nrows(), Some(4));
413    /// ```
414    pub fn merge(&mut self, other: &Block) -> Result<(), BlockError> {
415        use ndarray::Axis;
416        use ndarray::concatenate;
417
418        // If other is empty, nothing to do
419        if other.is_empty() {
420            return Ok(());
421        }
422
423        // If self is empty, clone other
424        if self.is_empty() {
425            self.map = other.map.clone();
426            self.nrows = other.nrows;
427            return Ok(());
428        }
429
430        // Check that both blocks have the same keys
431        let self_keys: std::collections::HashSet<_> = self.keys().collect();
432        let other_keys: std::collections::HashSet<_> = other.keys().collect();
433
434        if self_keys != other_keys {
435            return Err(BlockError::validation(format!(
436                "Cannot merge blocks with different keys. Self has {:?}, other has {:?}",
437                self_keys, other_keys
438            )));
439        }
440
441        // Merge each column
442        let mut new_map = HashMap::new();
443        for key in self.keys() {
444            let self_col = &self.map[key];
445            let other_col = &other.map[key];
446
447            // Check dtype compatibility
448            if self_col.dtype() != other_col.dtype() {
449                return Err(BlockError::validation(format!(
450                    "Column '{}' has incompatible dtypes: {:?} vs {:?}",
451                    key,
452                    self_col.dtype(),
453                    other_col.dtype()
454                )));
455            }
456
457            // Concatenate based on dtype
458            let merged_col = match (self_col, other_col) {
459                (Column::Float(a), Column::Float(b)) => {
460                    let merged = concatenate(Axis(0), &[a.view(), b.view()]).map_err(|e| {
461                        BlockError::validation(format!(
462                            "Failed to concatenate float column '{}': {}",
463                            key, e
464                        ))
465                    })?;
466                    Column::Float(merged)
467                }
468                (Column::Int(a), Column::Int(b)) => {
469                    let merged = concatenate(Axis(0), &[a.view(), b.view()]).map_err(|e| {
470                        BlockError::validation(format!(
471                            "Failed to concatenate int column '{}': {}",
472                            key, e
473                        ))
474                    })?;
475                    Column::Int(merged)
476                }
477                (Column::UInt(a), Column::UInt(b)) => {
478                    let merged = concatenate(Axis(0), &[a.view(), b.view()]).map_err(|e| {
479                        BlockError::validation(format!(
480                            "Failed to concatenate uint column '{}': {}",
481                            key, e
482                        ))
483                    })?;
484                    Column::UInt(merged)
485                }
486                (Column::U8(a), Column::U8(b)) => {
487                    let merged = concatenate(Axis(0), &[a.view(), b.view()]).map_err(|e| {
488                        BlockError::validation(format!(
489                            "Failed to concatenate u8 column '{}': {}",
490                            key, e
491                        ))
492                    })?;
493                    Column::U8(merged)
494                }
495                (Column::Bool(a), Column::Bool(b)) => {
496                    let merged = concatenate(Axis(0), &[a.view(), b.view()]).map_err(|e| {
497                        BlockError::validation(format!(
498                            "Failed to concatenate bool column '{}': {}",
499                            key, e
500                        ))
501                    })?;
502                    Column::Bool(merged)
503                }
504                (Column::String(a), Column::String(b)) => {
505                    let merged = concatenate(Axis(0), &[a.view(), b.view()]).map_err(|e| {
506                        BlockError::validation(format!(
507                            "Failed to concatenate string column '{}': {}",
508                            key, e
509                        ))
510                    })?;
511                    Column::String(merged)
512                }
513                _ => unreachable!("dtype mismatch already checked"),
514            };
515
516            new_map.insert(key.to_string(), merged_col);
517        }
518
519        // Update nrows
520        let new_nrows = self.nrows.unwrap() + other.nrows.unwrap();
521        self.map = new_map;
522        self.nrows = Some(new_nrows);
523
524        Ok(())
525    }
526}
527
528// Index trait for convenient access: block["key"]
529impl Index<&str> for Block {
530    type Output = Column;
531
532    fn index(&self, key: &str) -> &Self::Output {
533        self.get(key)
534            .unwrap_or_else(|| panic!("key '{}' not found in Block", key))
535    }
536}
537
538impl IndexMut<&str> for Block {
539    fn index_mut(&mut self, key: &str) -> &mut Self::Output {
540        self.get_mut(key)
541            .unwrap_or_else(|| panic!("key '{}' not found in Block", key))
542    }
543}
544
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{F, I};
    use ndarray::{Array1, Array2};

    /// Builds a 1-D float column from a slice.
    fn floats(values: &[F]) -> ArrayD<F> {
        Array1::from_vec(values.to_vec()).into_dyn()
    }

    /// Builds a 1-D signed-int column from a slice.
    fn ints(values: &[I]) -> ArrayD<I> {
        Array1::from_vec(values.to_vec()).into_dyn()
    }

    #[test]
    fn test_insert_mixed_dtypes() {
        let mut block = Block::new();

        assert!(block.insert("x", floats(&[1.0, 2.0, 3.0])).is_ok());
        assert!(block.insert("y", floats(&[4.0, 5.0, 6.0])).is_ok());
        assert!(block.insert("id", ints(&[10, 20, 30])).is_ok());
        assert!(block
            .insert("mask", Array1::from_vec(vec![true, false, true]).into_dyn())
            .is_ok());

        assert_eq!(block.len(), 4);
        assert_eq!(block.nrows(), Some(3));
    }

    #[test]
    fn test_axis0_mismatch_error() {
        let mut block = Block::new();

        block.insert("x", floats(&[1.0, 2.0, 3.0])).unwrap();
        let result = block.insert("y", floats(&[1.0, 2.0]));

        match result {
            Err(BlockError::RaggedAxis0 { expected, got, .. }) => {
                assert_eq!(expected, 3);
                assert_eq!(got, 2);
            }
            _ => panic!("Expected RaggedAxis0 error"),
        }
    }

    #[test]
    fn test_insert_replaces_existing_key() {
        let mut block = Block::new();
        block.insert("x", floats(&[1.0, 2.0])).unwrap();
        block.insert("x", floats(&[3.0, 4.0])).unwrap();

        assert_eq!(block.len(), 1);
        assert_eq!(block.nrows(), Some(2));
        assert_eq!(
            block.get_float("x").unwrap().as_slice_memory_order().unwrap(),
            &[3.0, 4.0]
        );
    }

    #[test]
    fn test_replacing_sole_column_still_checks_nrows() {
        let mut block = Block::new();
        block.insert("x", floats(&[1.0, 2.0])).unwrap();

        let result = block.insert("x", floats(&[1.0, 2.0, 3.0]));
        assert!(matches!(
            result,
            Err(BlockError::RaggedAxis0 {
                expected: 2,
                got: 3,
                ..
            })
        ));
        // A failed insert must not touch the existing column.
        assert_eq!(block.get_float("x").unwrap().len(), 2);
    }

    #[test]
    fn test_typed_getters() {
        let mut block = Block::new();

        block.insert("x", floats(&[1.0, 2.0])).unwrap();
        block.insert("id", ints(&[10, 20])).unwrap();

        // Correct type access
        assert!(block.get_float("x").is_some());
        assert!(block.get_int("id").is_some());

        // Wrong type access returns None
        assert!(block.get_int("x").is_none());
        assert!(block.get_float("id").is_none());

        // Mutable access
        if let Some(x_mut) = block.get_float_mut("x") {
            x_mut[[0]] = 99.0;
        }
        assert_eq!(block.get_float("x").unwrap()[[0]], 99.0);
    }

    #[test]
    fn test_index_access() {
        let mut block = Block::new();
        block.insert("x", floats(&[1.0, 2.0])).unwrap();

        // Immutable index
        assert_eq!(block["x"].dtype(), DType::Float);

        // Mutable index
        if let Some(arr_mut) = block["x"].as_float_mut() {
            arr_mut[[0]] = 42.0;
        }
        assert_eq!(block.get_float("x").unwrap()[[0]], 42.0);
    }

    #[test]
    #[should_panic(expected = "key 'missing' not found")]
    fn test_index_panic_on_missing_key() {
        let block = Block::new();
        let _ = &block["missing"];
    }

    #[test]
    fn test_remove_resets_nrows() {
        let mut block = Block::new();
        block.insert("x", floats(&[1.0, 2.0])).unwrap();
        assert_eq!(block.nrows(), Some(2));

        block.remove("x");
        assert_eq!(block.nrows(), None);
        assert!(block.is_empty());
    }

    #[test]
    fn test_remove_missing_key_on_nonempty_block() {
        let mut block = Block::new();
        block.insert("x", floats(&[1.0, 2.0])).unwrap();

        assert!(block.remove("missing").is_none());
        assert_eq!(block.nrows(), Some(2));
        assert_eq!(block.len(), 1);
    }

    #[test]
    fn test_iter_keys_values() {
        let mut block = Block::new();
        block.insert("x", floats(&[1.0, 2.0])).unwrap();
        block.insert("id", ints(&[10, 20])).unwrap();

        let keys: Vec<&str> = block.keys().collect();
        assert_eq!(keys.len(), 2);
        assert!(keys.contains(&"x"));
        assert!(keys.contains(&"id"));

        let dtypes: Vec<DType> = block.values().map(|c| c.dtype()).collect();
        assert!(dtypes.contains(&DType::Float));
        assert!(dtypes.contains(&DType::Int));
    }

    #[test]
    fn test_rank_zero_error() {
        let mut block = Block::new();

        // A rank-0 array (scalar) has no axis 0.
        let result = block.insert("scalar", ArrayD::<F>::zeros(vec![]));
        match result {
            Err(BlockError::RankZero { key }) => assert_eq!(key, "scalar"),
            _ => panic!("Expected RankZero error"),
        }
    }

    #[test]
    fn test_dtype_query() {
        let mut block = Block::new();
        block.insert("x", floats(&[1.0])).unwrap();
        block.insert("id", ints(&[10])).unwrap();

        assert_eq!(block.dtype("x"), Some(DType::Float));
        assert_eq!(block.dtype("id"), Some(DType::Int));
        assert_eq!(block.dtype("missing"), None);
    }

    #[test]
    fn test_merge_basic() {
        let mut block1 = Block::new();
        let mut block2 = Block::new();
        block1.insert("x", floats(&[1.0, 2.0])).unwrap();
        block2.insert("x", floats(&[3.0, 4.0])).unwrap();

        block1.merge(&block2).unwrap();

        assert_eq!(block1.nrows(), Some(4));
        let x = block1.get_float("x").unwrap();
        assert_eq!(x.as_slice_memory_order().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
    }

    #[test]
    fn test_merge_empty_blocks() {
        let mut block1 = Block::new();
        let mut block2 = Block::new();

        // Merge empty into empty
        block1.merge(&block2).unwrap();
        assert_eq!(block1.nrows(), None);

        // Merge non-empty into empty
        block2.insert("x", floats(&[1.0, 2.0])).unwrap();
        block1.merge(&block2).unwrap();
        assert_eq!(block1.nrows(), Some(2));

        // Merge empty into non-empty
        let block3 = Block::new();
        block1.merge(&block3).unwrap();
        assert_eq!(block1.nrows(), Some(2));
    }

    #[test]
    fn test_merge_incompatible_keys() {
        let mut block1 = Block::new();
        let mut block2 = Block::new();
        block1.insert("x", floats(&[1.0, 2.0])).unwrap();
        block2.insert("y", floats(&[3.0, 4.0])).unwrap();

        assert!(block1.merge(&block2).is_err());
    }

    #[test]
    fn test_merge_incompatible_dtypes() {
        let mut block1 = Block::new();
        let mut block2 = Block::new();
        block1.insert("x", floats(&[1.0, 2.0])).unwrap();
        block2.insert("x", ints(&[3, 4])).unwrap();

        assert!(block1.merge(&block2).is_err());
    }

    #[test]
    fn test_merge_trailing_shape_mismatch_leaves_self_unchanged() {
        let mut block1 = Block::new();
        let mut block2 = Block::new();
        block1
            .insert("p", Array2::<F>::zeros((2, 3)).into_dyn())
            .unwrap();
        block2
            .insert("p", Array2::<F>::zeros((2, 2)).into_dyn())
            .unwrap();

        assert!(block1.merge(&block2).is_err());
        assert_eq!(block1.nrows(), Some(2));
        assert_eq!(block1.get_float("p").unwrap().shape(), &[2, 3]);
    }

    #[test]
    fn test_merge_multiple_columns() {
        let mut block1 = Block::new();
        let mut block2 = Block::new();

        block1.insert("x", floats(&[1.0, 2.0])).unwrap();
        block1.insert("id", ints(&[10, 20])).unwrap();
        block2.insert("x", floats(&[3.0, 4.0])).unwrap();
        block2.insert("id", ints(&[30, 40])).unwrap();

        block1.merge(&block2).unwrap();

        assert_eq!(block1.nrows(), Some(4));
        let x = block1.get_float("x").unwrap();
        assert_eq!(x.as_slice_memory_order().unwrap(), &[1.0, 2.0, 3.0, 4.0]);
        let id = block1.get_int("id").unwrap();
        assert_eq!(id.as_slice_memory_order().unwrap(), &[10, 20, 30, 40]);
    }

    #[test]
    fn test_rename_column() {
        let mut block = Block::new();
        block.insert("x", floats(&[1.0, 2.0])).unwrap();
        block.insert("y", floats(&[3.0, 4.0])).unwrap();

        // Successful rename
        assert!(block.rename_column("x", "position_x"));
        assert!(!block.contains_key("x"));
        assert!(block.contains_key("position_x"));
        assert_eq!(
            block
                .get_float("position_x")
                .unwrap()
                .as_slice_memory_order()
                .unwrap(),
            &[1.0, 2.0]
        );

        // Try to rename non-existent column
        assert!(!block.rename_column("nonexistent", "new_name"));

        // Try to rename to existing column name
        assert!(!block.rename_column("position_x", "y"));
    }

    #[test]
    fn test_rename_column_onto_itself_is_rejected() {
        let mut block = Block::new();
        block.insert("x", floats(&[1.0])).unwrap();

        assert!(!block.rename_column("x", "x"));
        assert!(block.contains_key("x"));
        assert_eq!(block.len(), 1);
    }

    #[test]
    fn test_resize_shrink() {
        let mut block = Block::new();
        block.insert("x", floats(&[1.0, 2.0, 3.0, 4.0])).unwrap();
        block.insert("id", ints(&[10, 20, 30, 40])).unwrap();

        block.resize(2).unwrap();

        assert_eq!(block.nrows(), Some(2));
        let x = block.get_float("x").unwrap();
        assert_eq!(x.as_slice_memory_order().unwrap(), &[1.0, 2.0]);
        let id = block.get_int("id").unwrap();
        assert_eq!(id.as_slice_memory_order().unwrap(), &[10, 20]);
    }

    #[test]
    fn test_resize_grow() {
        let mut block = Block::new();
        block.insert("x", floats(&[1.0, 2.0])).unwrap();
        block.insert("id", ints(&[10, 20])).unwrap();

        block.resize(4).unwrap();

        assert_eq!(block.nrows(), Some(4));
        // Original data preserved, new rows are zero-filled.
        let x = block.get_float("x").unwrap();
        assert_eq!(x.as_slice_memory_order().unwrap(), &[1.0, 2.0, 0.0, 0.0]);
        let id = block.get_int("id").unwrap();
        assert_eq!(id.as_slice_memory_order().unwrap(), &[10, 20, 0, 0]);
    }

    #[test]
    fn test_resize_same() {
        let mut block = Block::new();
        block.insert("x", floats(&[1.0, 2.0, 3.0])).unwrap();

        block.resize(3).unwrap();

        assert_eq!(block.nrows(), Some(3));
        let x = block.get_float("x").unwrap();
        assert_eq!(x.as_slice_memory_order().unwrap(), &[1.0, 2.0, 3.0]);
    }

    #[test]
    fn test_resize_empty() {
        let mut block = Block::new();

        block.resize(5).unwrap();
        assert_eq!(block.nrows(), Some(5));
        assert!(block.is_empty());
    }

    #[test]
    fn test_resize_multidim() {
        let mut block = Block::new();
        // 4x3 position array
        let pos = Array2::from_shape_vec(
            (4, 3),
            vec![
                1.0 as F, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0,
            ],
        )
        .unwrap()
        .into_dyn();
        block.insert("pos", pos).unwrap();

        // Shrink 4x3 -> 2x3
        block.resize(2).unwrap();
        assert_eq!(block.nrows(), Some(2));
        let pos = block.get_float("pos").unwrap();
        assert_eq!(pos.shape(), &[2, 3]);
        assert_eq!(
            pos.as_slice_memory_order().unwrap(),
            &[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]
        );

        // Grow 2x3 -> 5x3
        block.resize(5).unwrap();
        assert_eq!(block.nrows(), Some(5));
        let pos = block.get_float("pos").unwrap();
        assert_eq!(pos.shape(), &[5, 3]);
        // Original data followed by zeros
        assert_eq!(
            pos.as_slice_memory_order().unwrap(),
            &[
                1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0, 0.0
            ]
        );
    }

    #[test]
    fn test_resize_mixed_dtypes() {
        let mut block = Block::new();
        block.insert("x", floats(&[1.0, 2.0, 3.0])).unwrap();
        block.insert("id", ints(&[10, 20, 30])).unwrap();
        block
            .insert("mask", Array1::from_vec(vec![true, false, true]).into_dyn())
            .unwrap();
        block
            .insert(
                "name",
                Array1::from_vec(vec!["a".to_string(), "b".to_string(), "c".to_string()])
                    .into_dyn(),
            )
            .unwrap();

        // Grow from 3 to 5
        block.resize(5).unwrap();
        assert_eq!(block.nrows(), Some(5));

        let x = block.get_float("x").unwrap();
        assert_eq!(
            x.as_slice_memory_order().unwrap(),
            &[1.0, 2.0, 3.0, 0.0, 0.0]
        );
        let id = block.get_int("id").unwrap();
        assert_eq!(id.as_slice_memory_order().unwrap(), &[10, 20, 30, 0, 0]);
        let mask = block.get_bool("mask").unwrap();
        assert_eq!(
            mask.as_slice_memory_order().unwrap(),
            &[true, false, true, false, false]
        );
        let name = block.get_string("name").unwrap();
        assert_eq!(name[[0]], "a");
        assert_eq!(name[[1]], "b");
        assert_eq!(name[[2]], "c");
        assert_eq!(name[[3]], "");
        assert_eq!(name[[4]], "");

        // Shrink from 5 to 2
        block.resize(2).unwrap();
        assert_eq!(block.nrows(), Some(2));

        let x = block.get_float("x").unwrap();
        assert_eq!(x.as_slice_memory_order().unwrap(), &[1.0, 2.0]);
        let id = block.get_int("id").unwrap();
        assert_eq!(id.as_slice_memory_order().unwrap(), &[10, 20]);
        let mask = block.get_bool("mask").unwrap();
        assert_eq!(mask.as_slice_memory_order().unwrap(), &[true, false]);
        let name = block.get_string("name").unwrap();
        assert_eq!(name[[0]], "a");
        assert_eq!(name[[1]], "b");
    }
}