Skip to main content

molrs_core/region/
simbox.rs

1//! Triclinic simulation box and periodic operations based on ndarray.
2//!
3//! Conventions (fractional/cartesian):
4//! - cart = origin + H * frac
5//! - frac = H^{-1} * (cart - origin)
6//! - Lattice vectors are the columns of H.
7
8use super::region::Region;
9use crate::math;
10use crate::types::{F, F3, F3View, F3x3, FNx3, FNx3View, Pbc3};
11use ndarray::{Array1, Array2, ArrayView1, array};
12
13/// Box geometry kind, detected once at construction.
14#[derive(Debug, Clone, PartialEq)]
15pub enum BoxKind {
16    /// Orthorhombic (diagonal H): lengths, inverse lengths cached.
17    Ortho { len: F3, inv_len: F3 },
18    /// General triclinic.
19    Triclinic,
20}
21
22/// Simulation box: triclinic cell with origin and per-axis PBC mask
23#[derive(Debug, Clone)]
24pub struct SimBox {
25    /// Triclinic cell matrix H (columns are lattice vectors)
26    h: F3x3,
27    /// Precomputed inverse of H
28    inv: F3x3,
29    /// Origin of the cell in Cartesian coordinates
30    origin: F3,
31    /// Per-axis periodic boundary condition flags (x, y, z)
32    pbc: Pbc3,
33    /// Cached geometry kind
34    kind: BoxKind,
35}
36
/// Error type for simulation box construction.
#[derive(Debug)]
pub enum BoxError {
    /// The cell matrix H is singular (determinant ≈ 0).
    SingularCell,
    /// The matrix does not have shape 3x3.
    InvalidMatrixShape { rows: usize, cols: usize },
    /// A vector does not have the expected length.
    ///
    /// The box factories also report a non-positive edge length with
    /// `len: 0`.
    InvalidVectorLength { len: usize },
    /// A required array is not contiguous in memory.
    NonContiguous(&'static str),
}

impl std::fmt::Display for BoxError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::SingularCell => f.write_str("cell matrix H is singular and cannot be inverted"),
            Self::InvalidMatrixShape { rows, cols } => {
                write!(f, "invalid matrix shape {}x{}, expected 3x3", rows, cols)
            }
            // `len: 0` doubles as "bad edge length" in the factories.
            Self::InvalidVectorLength { len: 0 } => f.write_str(
                "invalid box vector: empty, or an edge length that is not strictly positive",
            ),
            Self::InvalidVectorLength { len } => {
                write!(f, "invalid vector length {}, expected 3", len)
            }
            Self::NonContiguous(name) => {
                write!(f, "array `{}` is not contiguous in memory", name)
            }
        }
    }
}

/// Allows `BoxError` to be boxed as `dyn std::error::Error` and propagated
/// with `?` into generic error types.
impl std::error::Error for BoxError {}
49
50impl SimBox {
51    /// Construct from triclinic cell matrix `H`, origin `O`, and per-axis PBC flags
52    pub fn new(h: F3x3, origin: F3, pbc: Pbc3) -> Result<Self, BoxError> {
53        if let Some(inv) = math::inv3(&h) {
54            let kind = detect_box_kind(&h);
55            Ok(Self {
56                h,
57                inv,
58                origin,
59                pbc,
60                kind,
61            })
62        } else {
63            Err(BoxError::SingularCell)
64        }
65    }
66
67    pub fn try_new(h: F3x3, origin: F3, pbc: Pbc3) -> Result<Self, BoxError> {
68        Self::new(h, origin, pbc)
69    }
70
71    /// Factory: cubic box with edge length `a` and origin `O`
72    pub fn cube(a: F, origin: F3, pbc: Pbc3) -> Result<Self, BoxError> {
73        if a <= 0.0 {
74            return Err(BoxError::InvalidVectorLength { len: 0 });
75        }
76        let h = array![[a, 0.0, 0.0], [0.0, a, 0.0], [0.0, 0.0, a]];
77        Self::new(h, origin, pbc)
78    }
79
80    /// Factory: ortho box with lengths (ax, ay, az) and origin `O`
81    pub fn ortho(lengths: F3, origin: F3, pbc: Pbc3) -> Result<Self, BoxError> {
82        if lengths.len() != 3 {
83            return Err(BoxError::InvalidVectorLength { len: lengths.len() });
84        }
85        if lengths.iter().any(|v| *v <= 0.0) {
86            return Err(BoxError::InvalidVectorLength { len: 0 });
87        }
88        let h = array![
89            [lengths[0], 0.0, 0.0],
90            [0.0, lengths[1], 0.0],
91            [0.0, 0.0, lengths[2]],
92        ];
93        Self::new(h, origin, pbc)
94    }
95
96    /// Create a non-periodic (free-boundary) box enclosing all points.
97    ///
98    /// Computes the axis-aligned bounding box of `points` and adds `padding`
99    /// on each side. The resulting box has `pbc = [false, false, false]`.
100    ///
101    /// `padding` should be >= the neighbor cutoff distance so that all
102    /// particles sit well inside the box for correct cell assignment.
103    ///
104    /// # Errors
105    /// Returns `BoxError` if padding is non-positive or the resulting box is degenerate.
106    ///
107    /// # Panics
108    /// Panics if `padding <= 0`.
109    pub fn free(points: FNx3View<'_>, padding: F) -> Result<Self, BoxError> {
110        assert!(padding > 0.0, "padding must be positive");
111        let n = points.nrows();
112        if n == 0 {
113            // Empty point set -- return a unit cube at origin
114            return Self::cube(padding, array![0.0 as F, 0.0, 0.0], [false, false, false]);
115        }
116        let mut min = array![points[[0, 0]], points[[0, 1]], points[[0, 2]]];
117        let mut max = min.clone();
118        for i in 1..n {
119            for d in 0..3 {
120                if points[[i, d]] < min[d] {
121                    min[d] = points[[i, d]];
122                }
123                if points[[i, d]] > max[d] {
124                    max[d] = points[[i, d]];
125                }
126            }
127        }
128        let origin = array![min[0] - padding, min[1] - padding, min[2] - padding,];
129        let lengths = array![
130            (max[0] - min[0] + 2.0 * padding).max(padding),
131            (max[1] - min[1] + 2.0 * padding).max(padding),
132            (max[2] - min[2] + 2.0 * padding).max(padding),
133        ];
134        Self::ortho(lengths, origin, [false, false, false])
135    }
136
137    /// View of the cell matrix
138    pub fn h_view(&self) -> FNx3View<'_> {
139        self.h.view()
140    }
141
142    /// View of the inverse cell matrix
143    pub fn inv_view(&self) -> FNx3View<'_> {
144        self.inv.view()
145    }
146
147    /// View of the origin
148    pub fn origin_view(&self) -> F3View<'_> {
149        self.origin.view()
150    }
151
152    /// View of the PBC flags
153    pub fn pbc_view(&self) -> ArrayView1<'_, bool> {
154        ArrayView1::from_shape(3, &self.pbc).expect("pbc_view shape")
155    }
156
157    /// Per-axis PBC flags
158    pub fn pbc(&self) -> Pbc3 {
159        self.pbc
160    }
161
162    /// Cell volume (|det(H)|)
163    pub fn volume(&self) -> F {
164        math::det3(&self.h).abs()
165    }
166
167    /// Off-diagonal tilts [xy, xz, yz] of the cell matrix
168    pub fn tilts(&self) -> F3 {
169        array![self.h[[0, 1]], self.h[[0, 2]], self.h[[1, 2]]]
170    }
171
172    /// Lattice vector lengths
173    pub fn lengths(&self) -> F3 {
174        let a = self.lattice(0);
175        let b = self.lattice(1);
176        let c = self.lattice(2);
177        array![math::norm3(&a), math::norm3(&b), math::norm3(&c)]
178    }
179
180    /// Nearest plane distance (half the box size along each axis)
181    /// For triclinic boxes, this is the perpendicular distance to each face
182    pub fn nearest_plane_distance(&self) -> F3 {
183        let v = self.volume();
184        let a1 = self.lattice(0);
185        let a2 = self.lattice(1);
186        let a3 = self.lattice(2);
187
188        let c23 = math::cross3(&a2, &a3);
189        let c31 = math::cross3(&a3, &a1);
190        let c12 = math::cross3(&a1, &a2);
191
192        array![
193            v / math::norm3(&c23),
194            v / math::norm3(&c31),
195            v / math::norm3(&c12)
196        ]
197    }
198
199    pub fn kind(&self) -> &BoxKind {
200        &self.kind
201    }
202
203    /// Lattice vector by index (0,1,2) — columns of H
204    pub fn lattice(&self, index: usize) -> F3 {
205        assert!(index < 3, "lattice index must be 0..2");
206        self.h.column(index).to_owned()
207    }
208
209    /// Convert Cartesian coordinates to fractional coordinates [0, 1)
210    pub fn make_fractional(&self, r: F3View<'_>) -> F3 {
211        let dr = &r - &self.origin.view();
212        let mut frac = self.inv.dot(&dr);
213        for f in frac.iter_mut() {
214            *f -= f.floor();
215        }
216        frac
217    }
218
219    /// Fractional coordinates with ortho fast-path
220    #[inline(always)]
221    pub fn make_fractional_fast(&self, r: F3View<'_>) -> F3 {
222        match &self.kind {
223            BoxKind::Ortho { inv_len, .. } => {
224                let mut frac = array![
225                    (r[0] - self.origin[0]) * inv_len[0],
226                    (r[1] - self.origin[1]) * inv_len[1],
227                    (r[2] - self.origin[2]) * inv_len[2],
228                ];
229                for f in frac.iter_mut() {
230                    *f -= f.floor();
231                }
232                frac
233            }
234            BoxKind::Triclinic => self.make_fractional(r),
235        }
236    }
237
238    /// Fractional coordinates returned as `[F; 3]` (zero-alloc hot path).
239    ///
240    /// Equivalent to [`make_fractional_fast`](Self::make_fractional_fast) but
241    /// avoids the `Array1<F>` heap allocation by returning a stack array.
242    /// Use in tight inner loops (neighbor-list cell assignment, etc.).
243    #[inline(always)]
244    pub fn make_fractional_fast_arr(&self, r: F3View<'_>) -> [F; 3] {
245        match &self.kind {
246            BoxKind::Ortho { inv_len, .. } => {
247                let fx = (r[0] - self.origin[0]) * inv_len[0];
248                let fy = (r[1] - self.origin[1]) * inv_len[1];
249                let fz = (r[2] - self.origin[2]) * inv_len[2];
250                [fx - fx.floor(), fy - fy.floor(), fz - fz.floor()]
251            }
252            BoxKind::Triclinic => {
253                let f = self.make_fractional(r);
254                [f[0], f[1], f[2]]
255            }
256        }
257    }
258
259    /// Convert fractional coordinates to Cartesian coordinates
260    pub fn make_cartesian(&self, frac: F3View<'_>) -> F3 {
261        &self.origin + &self.h.dot(&frac)
262    }
263
264    /// Minimum image displacement vector from r1 to r2 (r2 - r1)
265    #[inline]
266    pub fn shortest_vector(&self, r1: F3View<'_>, r2: F3View<'_>) -> F3 {
267        let dr = &r2 - &r1;
268        let mut dr_frac = self.inv.dot(&dr);
269        for d in 0..3 {
270            if self.pbc[d] {
271                dr_frac[d] -= dr_frac[d].round();
272            }
273        }
274        self.h.dot(&dr_frac)
275    }
276
277    /// Shortest vector with ortho fast-path
278    #[inline(always)]
279    pub fn shortest_vector_fast(&self, a: F3View<'_>, b: F3View<'_>) -> F3 {
280        match &self.kind {
281            BoxKind::Ortho { len, inv_len } => {
282                let mut dr = array![b[0] - a[0], b[1] - a[1], b[2] - a[2]];
283                for d in 0..3 {
284                    if self.pbc[d] {
285                        dr[d] -= (dr[d] * inv_len[d]).round() * len[d];
286                    }
287                }
288                dr
289            }
290            BoxKind::Triclinic => self.shortest_vector(a, b),
291        }
292    }
293
294    /// Shortest vector returned as `[F; 3]` (zero-alloc hot path).
295    ///
296    /// Equivalent to [`shortest_vector_fast`](Self::shortest_vector_fast) but
297    /// avoids the `Array1<F>` heap allocation by returning a stack array.
298    /// Use inside neighbor-list pair loops where called O(N·k) times per build.
299    #[inline(always)]
300    pub fn shortest_vector_fast_arr(&self, a: F3View<'_>, b: F3View<'_>) -> [F; 3] {
301        match &self.kind {
302            BoxKind::Ortho { len, inv_len } => {
303                let mut dr = [b[0] - a[0], b[1] - a[1], b[2] - a[2]];
304                if self.pbc[0] {
305                    dr[0] -= (dr[0] * inv_len[0]).round() * len[0];
306                }
307                if self.pbc[1] {
308                    dr[1] -= (dr[1] * inv_len[1]).round() * len[1];
309                }
310                if self.pbc[2] {
311                    dr[2] -= (dr[2] * inv_len[2]).round() * len[2];
312                }
313                dr
314            }
315            BoxKind::Triclinic => {
316                let v = self.shortest_vector(a, b);
317                [v[0], v[1], v[2]]
318            }
319        }
320    }
321
322    /// Shortest vector from raw `[F; 3]` inputs, returning `[F; 3]`.
323    ///
324    /// Both inputs and output are stack arrays — no `ArrayView` indexing.
325    /// Called from neighbor-list inner loops where positions are stored in a
326    /// flat `[F; 3]` slab instead of an `Array2<F>`.
327    #[inline(always)]
328    pub fn shortest_vector_raw(&self, a: [F; 3], b: [F; 3]) -> [F; 3] {
329        match &self.kind {
330            BoxKind::Ortho { len, inv_len } => {
331                let mut dr = [b[0] - a[0], b[1] - a[1], b[2] - a[2]];
332                if self.pbc[0] {
333                    dr[0] -= (dr[0] * inv_len[0]).round() * len[0];
334                }
335                if self.pbc[1] {
336                    dr[1] -= (dr[1] * inv_len[1]).round() * len[1];
337                }
338                if self.pbc[2] {
339                    dr[2] -= (dr[2] * inv_len[2]).round() * len[2];
340                }
341                dr
342            }
343            BoxKind::Triclinic => {
344                let av = ndarray::ArrayView1::from(&a[..]);
345                let bv = ndarray::ArrayView1::from(&b[..]);
346                let v = self.shortest_vector(av, bv);
347                [v[0], v[1], v[2]]
348            }
349        }
350    }
351
352    /// Calculate squared distance using MIC.
353    #[inline]
354    pub fn calc_distance2(&self, a: F3View<'_>, b: F3View<'_>) -> F {
355        let dr = self.shortest_vector(a, b);
356        dr.dot(&dr)
357    }
358
359    /// Convert Cartesian points to fractional coordinates (N×3)
360    pub fn to_frac(&self, xyz: FNx3View<'_>) -> FNx3 {
361        let n = xyz.nrows();
362        let mut result = FNx3::zeros((n, 3));
363        for i in 0..n {
364            let dr = &xyz.row(i) - &self.origin.view();
365            result.row_mut(i).assign(&self.inv.dot(&dr));
366        }
367        result
368    }
369
370    /// Convert fractional coordinates to Cartesian points (N×3)
371    pub fn to_cart(&self, frac: FNx3View<'_>) -> FNx3 {
372        let n = frac.nrows();
373        let mut result = FNx3::zeros((n, 3));
374        for i in 0..n {
375            let cart = &self.origin + &self.h.dot(&frac.row(i));
376            result.row_mut(i).assign(&cart);
377        }
378        result
379    }
380
381    /// Check if points lie within [0,1) in fractional space.
382    pub fn isin(&self, xyz: FNx3View<'_>) -> Array1<bool> {
383        let n = xyz.nrows();
384        let mut mask = Vec::with_capacity(n);
385        for i in 0..n {
386            let dr = &xyz.row(i) - &self.origin.view();
387            let frac = self.inv.dot(&dr);
388            let inside = (0..3).all(|d| frac[d] >= 0.0 && frac[d] < 1.0);
389            mask.push(inside);
390        }
391        Array1::from_vec(mask)
392    }
393
394    /// Batched displacement vectors row-wise (N×3).
395    /// Writes result into `out` to avoid allocation.
396    pub fn delta_out(
397        &self,
398        xyzu1: FNx3View<'_>,
399        xyzu2: FNx3View<'_>,
400        out: &mut FNx3,
401        minimum_image: bool,
402    ) {
403        assert_eq!(xyzu1.nrows(), xyzu2.nrows());
404        let n = xyzu1.nrows();
405        if minimum_image {
406            for i in 0..n {
407                let dr = self.shortest_vector(xyzu1.row(i), xyzu2.row(i));
408                out.row_mut(i).assign(&dr);
409            }
410        } else {
411            for i in 0..n {
412                let dr = &xyzu2.row(i) - &xyzu1.row(i);
413                out.row_mut(i).assign(&dr);
414            }
415        }
416    }
417
418    /// Batched displacement vectors row-wise (N×3)
419    pub fn delta(&self, xyzu1: FNx3View<'_>, xyzu2: FNx3View<'_>, minimum_image: bool) -> FNx3 {
420        assert_eq!(xyzu1.nrows(), xyzu2.nrows());
421        let n = xyzu1.nrows();
422        let mut out = FNx3::zeros((n, 3));
423        self.delta_out(xyzu1, xyzu2, &mut out, minimum_image);
424        out
425    }
426
427    /// Wrap Cartesian points into the unit cell according to PBC
428    pub fn wrap(&self, xyz: FNx3View<'_>) -> FNx3 {
429        let mut frac = self.to_frac(xyz);
430        let n = frac.nrows();
431        for i in 0..n {
432            for d in 0..3 {
433                if self.pbc[d] {
434                    frac[[i, d]] -= frac[[i, d]].floor();
435                }
436            }
437        }
438        self.to_cart(frac.view())
439    }
440
441    pub fn get_corners(&self) -> FNx3 {
442        let l = self.lengths();
443        let (ox, oy, oz) = (self.origin[0], self.origin[1], self.origin[2]);
444        let (lx, ly, lz) = (l[0], l[1], l[2]);
445        array![
446            [ox, oy, oz],
447            [ox + lx, oy, oz],
448            [ox + lx, oy + ly, oz],
449            [ox, oy + ly, oz],
450            [ox, oy, oz + lz],
451            [ox + lx, oy, oz + lz],
452            [ox + lx, oy + ly, oz + lz],
453            [ox, oy + ly, oz + lz],
454        ]
455    }
456}
457
458impl Region for SimBox {
459    fn bounds(&self) -> FNx3 {
460        let lengths = self.lengths();
461        let mut b = Array2::zeros((3, 2));
462        for d in 0..3 {
463            b[[d, 0]] = self.origin[d];
464            b[[d, 1]] = self.origin[d] + lengths[d];
465        }
466        b
467    }
468
469    fn contains(&self, points: &FNx3) -> Array1<bool> {
470        self.isin(points.view())
471    }
472
473    fn contains_point(&self, point: &[F; 3]) -> bool {
474        let r = ArrayView1::from_shape(3, point).expect("contains_point shape");
475        let dr = &r - &self.origin.view();
476        let frac = self.inv.dot(&dr);
477        (0..3).all(|d| frac[d] >= 0.0 && frac[d] < 1.0)
478    }
479}
480
481fn detect_box_kind(h: &F3x3) -> BoxKind {
482    let eps: F = 1e-12;
483    let is_ortho = h[[0, 1]].abs() < eps
484        && h[[0, 2]].abs() < eps
485        && h[[1, 0]].abs() < eps
486        && h[[1, 2]].abs() < eps
487        && h[[2, 0]].abs() < eps
488        && h[[2, 1]].abs() < eps;
489    if is_ortho {
490        let len = array![h[[0, 0]], h[[1, 1]], h[[2, 2]]];
491        let inv_len = array![1.0 / len[0], 1.0 / len[1], 1.0 / len[2]];
492        BoxKind::Ortho { len, inv_len }
493    } else {
494        BoxKind::Triclinic
495    }
496}
497
#[cfg(test)]
mod tests {
    use super::*;

    fn assert_close(a: F, b: F) {
        assert!((a - b).abs() < 1e-6 as F, "{} != {}", a, b);
    }

    #[test]
    fn roundtrip_frac_cart() {
        let bx = SimBox::ortho(
            array![2.0, 3.0, 4.0],
            array![0.5, -1.0, 2.0],
            [true, true, true],
        )
        .expect("invalid box lengths");
        let pts = array![[0.5, -1.0, 2.0], [2.5, 2.0, 6.0]];
        let frac = bx.to_frac(pts.view());
        let cart = bx.to_cart(frac.view());
        assert!((&pts - &cart).iter().all(|v| v.abs() < 1e-5));
    }

    #[test]
    fn wrap_into_cell() {
        let bx = SimBox::cube(2.0, array![0.0, 0.0, 0.0], [true, true, true])
            .expect("invalid box length");
        let pts = array![[2.1, -0.1, 3.9], [-1.9, 4.2, 0.0]];
        let wrapped = bx.wrap(pts.view());
        let frac = bx.to_frac(wrapped.view());
        for i in 0..wrapped.nrows() {
            let fx = frac[[i, 0]];
            let fy = frac[[i, 1]];
            let fz = frac[[i, 2]];
            assert!((0.0..1.0).contains(&fx));
            assert!((0.0..1.0).contains(&fy));
            assert!((0.0..1.0).contains(&fz));
        }
    }

    #[test]
    fn calc_distance_matches_components() {
        let bx = SimBox::cube(3.0, array![0.0, 0.0, 0.0], [true, true, true])
            .expect("invalid box length");
        let a = array![0.1, 0.2, 0.3];
        let b = array![2.9, 0.2, 0.3];
        let d2 = bx.calc_distance2(a.view(), b.view());
        let dr = bx.shortest_vector(a.view(), b.view());
        let expected = dr.dot(&dr);
        assert!((d2 - expected).abs() < 1e-6);
    }

    #[test]
    fn test_lengths_ortho() {
        let bx = SimBox::ortho(
            array![2.0, 4.0, 5.0],
            array![0.0, 0.0, 0.0],
            [true, true, true],
        )
        .expect("invalid box lengths");
        let lengths = bx.lengths();
        assert_close(lengths[0], 2.0);
        assert_close(lengths[1], 4.0);
        assert_close(lengths[2], 5.0);
    }

    #[test]
    fn test_tilts_values() {
        let h = array![[2.0, 1.0, 2.0], [0.0, 4.0, 3.0], [0.0, 0.0, 5.0]];
        let bx = SimBox::new(h, array![0.0, 0.0, 0.0], [true, true, true]).expect("invalid box");
        let tilts = bx.tilts();
        assert_close(tilts[0], 1.0);
        assert_close(tilts[1], 2.0);
        assert_close(tilts[2], 3.0);
    }

    #[test]
    fn test_volume() {
        let bx = SimBox::ortho(
            array![2.0, 3.0, 4.0],
            array![0.0, 0.0, 0.0],
            [true, true, true],
        )
        .expect("invalid box lengths");
        assert_close(bx.volume(), 24.0);
    }

    #[test]
    fn test_wrap_single_and_multi() {
        let bx = SimBox::cube(2.0, array![0.0, 0.0, 0.0], [true, true, true])
            .expect("invalid box length");
        let pts = array![[10.0, -5.0, -5.0], [0.0, 0.5, 0.0]];
        let wrapped = bx.wrap(pts.view());
        assert_close(wrapped[[0, 0]], 0.0);
        assert_close(wrapped[[0, 1]], 1.0);
        assert_close(wrapped[[0, 2]], 1.0);
        assert_close(wrapped[[1, 0]], 0.0);
        assert_close(wrapped[[1, 1]], 0.5);
        assert_close(wrapped[[1, 2]], 0.0);
    }

    #[test]
    fn test_fractional_and_cartesian() {
        let bx = SimBox::cube(2.0, array![0.0, 0.0, 0.0], [true, true, true])
            .expect("invalid box length");
        let p = array![-1.0, -1.0, -1.0];
        let frac = bx.make_fractional(p.view());
        assert_close(frac[0], 0.5);
        assert_close(frac[1], 0.5);
        assert_close(frac[2], 0.5);
        let cart = bx.make_cartesian(frac.view());
        assert_close(cart[0], 1.0);
        assert_close(cart[1], 1.0);
        assert_close(cart[2], 1.0);
    }

    #[test]
    fn test_to_frac_to_cart_roundtrip() {
        let bx = SimBox::ortho(
            array![2.0, 3.0, 4.0],
            array![1.0, 2.0, 3.0],
            [true, true, true],
        )
        .expect("invalid box lengths");
        let pts = array![[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]];
        let frac = bx.to_frac(pts.view());
        let cart = bx.to_cart(frac.view());
        for i in 0..pts.nrows() {
            for j in 0..3 {
                assert_close(pts[[i, j]], cart[[i, j]]);
            }
        }
    }

    #[test]
    fn test_shortest_vector_and_distance() {
        let bx = SimBox::cube(2.0, array![0.0, 0.0, 0.0], [true, true, true])
            .expect("invalid box length");
        let a = array![0.1, 0.0, 0.0];
        let b = array![1.9, 0.0, 0.0];
        let dr = bx.shortest_vector(a.view(), b.view());
        assert_close(dr[0], -0.2);
        assert_close(dr[1], 0.0);
        assert_close(dr[2], 0.0);
        let d2 = bx.calc_distance2(a.view(), b.view());
        assert_close(d2, 0.04);
    }

    #[test]
    fn test_contains_point_non_pbc() {
        let bx = SimBox::cube(2.0, array![0.0, 0.0, 0.0], [false, false, false])
            .expect("invalid box length");
        assert!(bx.contains_point(&[0.5, 0.5, 0.5]));
        assert!(!bx.contains_point(&[-0.1, 0.5, 0.5]));
        assert!(!bx.contains_point(&[2.1, 0.5, 0.5]));
    }

    #[test]
    fn test_contains_mask() {
        let bx = SimBox::cube(2.0, array![0.0, 0.0, 0.0], [true, true, true])
            .expect("invalid box length");
        let pts = array![[0.1, 0.1, 0.1], [2.1, 0.0, 0.0], [-0.1, 0.0, 0.0]];
        let mask = bx.contains(&pts);
        assert!(mask[0]);
        assert!(!mask[1]);
        assert!(!mask[2]);
    }

    #[test]
    fn test_simbox_free_basic() {
        let pts = array![[1.0 as F, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let bx = SimBox::free(pts.view(), 1.0).unwrap();
        assert_eq!(bx.pbc(), [false, false, false]);
        // origin should be min - padding = [0.0, 1.0, 2.0]
        let o = bx.origin_view();
        assert!((o[0] - 0.0).abs() < 1e-5);
        assert!((o[1] - 1.0).abs() < 1e-5);
        assert!((o[2] - 2.0).abs() < 1e-5);
        // lengths should be (max-min) + 2*padding = [5.0, 5.0, 5.0]
        let l = bx.lengths();
        assert!((l[0] - 5.0).abs() < 1e-5);
        assert!((l[1] - 5.0).abs() < 1e-5);
        assert!((l[2] - 5.0).abs() < 1e-5);
    }

    #[test]
    fn test_simbox_free_single_point() {
        let pts = array![[1.0 as F, 2.0, 3.0]];
        let bx = SimBox::free(pts.view(), 2.0).unwrap();
        assert_eq!(bx.pbc(), [false, false, false]);
        // Zero spread: lengths = max(0 + 2 * 2, 2) = 4, origin = point - 2.
        let l = bx.lengths();
        assert_close(l[0], 4.0);
        assert_close(l[1], 4.0);
        assert_close(l[2], 4.0);
        let o = bx.origin_view();
        assert_close(o[0], -1.0);
        assert_close(o[1], 0.0);
        assert_close(o[2], 1.0);
    }

    #[test]
    fn test_simbox_free_empty() {
        use ndarray::Array2;
        let pts = Array2::<F>::zeros((0, 3));
        let bx = SimBox::free(pts.view(), 1.0).unwrap();
        assert_eq!(bx.pbc(), [false, false, false]);
    }

    #[test]
    fn test_simbox_pbc_accessor() {
        let bx = SimBox::cube(1.0, array![0.0 as F, 0.0, 0.0], [true, false, true]).unwrap();
        assert_eq!(bx.pbc(), [true, false, true]);
    }

    #[test]
    fn test_factories_reject_non_positive_lengths() {
        let o = array![0.0 as F, 0.0, 0.0];
        assert!(SimBox::cube(0.0, o.clone(), [true, true, true]).is_err());
        assert!(SimBox::cube(-1.0, o.clone(), [true, true, true]).is_err());
        assert!(SimBox::ortho(array![1.0, -2.0, 3.0], o, [true, true, true]).is_err());
    }

    #[test]
    fn test_singular_cell_rejected() {
        // Second row is twice the first, so det(H) == 0 exactly.
        let h = array![[1.0 as F, 2.0, 3.0], [2.0, 4.0, 6.0], [0.0, 0.0, 1.0]];
        let res = SimBox::new(h, array![0.0, 0.0, 0.0], [true, true, true]);
        assert!(matches!(res, Err(BoxError::SingularCell)));
    }

    #[test]
    fn test_kind_detection() {
        let cube = SimBox::cube(2.0, array![0.0, 0.0, 0.0], [true, true, true]).unwrap();
        assert!(matches!(cube.kind(), BoxKind::Ortho { .. }));
        let h = array![[2.0, 1.0, 0.0], [0.0, 2.0, 0.0], [0.0, 0.0, 2.0]];
        let tri = SimBox::new(h, array![0.0, 0.0, 0.0], [true, true, true]).unwrap();
        assert!(matches!(tri.kind(), BoxKind::Triclinic));
    }

    #[test]
    fn test_nearest_plane_distance_ortho() {
        let bx = SimBox::ortho(
            array![2.0, 3.0, 4.0],
            array![0.0, 0.0, 0.0],
            [true, true, true],
        )
        .unwrap();
        let npd = bx.nearest_plane_distance();
        assert_close(npd[0], 2.0);
        assert_close(npd[1], 3.0);
        assert_close(npd[2], 4.0);
    }

    #[test]
    fn test_fast_paths_match_generic_ortho() {
        let bx = SimBox::ortho(
            array![2.0, 3.0, 4.0],
            array![0.5, -1.0, 2.0],
            [true, false, true],
        )
        .unwrap();
        let p = array![-0.7, 3.1, 5.5];
        let slow = bx.make_fractional(p.view());
        let fast = bx.make_fractional_fast(p.view());
        let arr = bx.make_fractional_fast_arr(p.view());
        for d in 0..3 {
            assert_close(slow[d], fast[d]);
            assert_close(slow[d], arr[d]);
        }

        let a: [F; 3] = [0.1, 0.2, 2.3];
        let b: [F; 3] = [1.9, 2.7, 5.9];
        let av = ArrayView1::from(&a[..]);
        let bv = ArrayView1::from(&b[..]);
        let generic = bx.shortest_vector(av, bv);
        let fast = bx.shortest_vector_fast(av, bv);
        let arr = bx.shortest_vector_fast_arr(av, bv);
        let raw = bx.shortest_vector_raw(a, b);
        // x and z are periodic and get wrapped; y is not.
        assert_close(generic[0], -0.2);
        assert_close(generic[1], 2.5);
        assert_close(generic[2], -0.4);
        for d in 0..3 {
            assert_close(generic[d], fast[d]);
            assert_close(generic[d], arr[d]);
            assert_close(generic[d], raw[d]);
        }
    }

    #[test]
    fn test_shortest_vector_triclinic_variants_agree() {
        let h = array![[2.0, 0.5, 0.3], [0.0, 2.5, 0.4], [0.0, 0.0, 3.0]];
        let bx = SimBox::new(h, array![0.0, 0.0, 0.0], [true, true, true]).unwrap();
        let a: [F; 3] = [0.1, 0.2, 0.3];
        let b: [F; 3] = [1.8, 2.1, 2.9];
        let av = ArrayView1::from(&a[..]);
        let bv = ArrayView1::from(&b[..]);
        let generic = bx.shortest_vector(av, bv);
        let fast = bx.shortest_vector_fast(av, bv);
        let arr = bx.shortest_vector_fast_arr(av, bv);
        let raw = bx.shortest_vector_raw(a, b);
        for d in 0..3 {
            assert_close(generic[d], fast[d]);
            assert_close(generic[d], arr[d]);
            assert_close(generic[d], raw[d]);
        }
    }
}