1use super::region::Region;
9use crate::math;
10use crate::types::{F, F3, F3View, F3x3, FNx3, FNx3View, Pbc3};
11use ndarray::{Array1, Array2, ArrayView1, array};
12
13#[derive(Debug, Clone, PartialEq)]
15pub enum BoxKind {
16 Ortho { len: F3, inv_len: F3 },
18 Triclinic,
20}
21
22#[derive(Debug, Clone)]
24pub struct SimBox {
25 h: F3x3,
27 inv: F3x3,
29 origin: F3,
31 pbc: Pbc3,
33 kind: BoxKind,
35}
36
/// Errors produced while constructing a [`SimBox`].
#[derive(Debug)]
pub enum BoxError {
    /// The cell matrix could not be inverted.
    SingularCell,
    /// A matrix argument had the wrong shape.
    InvalidMatrixShape { rows: usize, cols: usize },
    /// A vector argument was invalid. The box constructors also use this
    /// variant (with `len: 0`) for non-positive or non-finite edge lengths.
    InvalidVectorLength { len: usize },
    /// An array argument was required to be contiguous but was not.
    NonContiguous(&'static str),
}

impl std::fmt::Display for BoxError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::SingularCell => write!(f, "cell matrix is singular and cannot be inverted"),
            Self::InvalidMatrixShape { rows, cols } => {
                write!(f, "invalid matrix shape: {rows}x{cols}")
            }
            Self::InvalidVectorLength { len } => write!(f, "invalid vector (len = {len})"),
            Self::NonContiguous(what) => write!(f, "array is not contiguous: {what}"),
        }
    }
}

// Lets callers propagate `BoxError` through `Box<dyn Error>` / `?`.
impl std::error::Error for BoxError {}
49
50impl SimBox {
51 pub fn new(h: F3x3, origin: F3, pbc: Pbc3) -> Result<Self, BoxError> {
53 if let Some(inv) = math::inv3(&h) {
54 let kind = detect_box_kind(&h);
55 Ok(Self {
56 h,
57 inv,
58 origin,
59 pbc,
60 kind,
61 })
62 } else {
63 Err(BoxError::SingularCell)
64 }
65 }
66
67 pub fn try_new(h: F3x3, origin: F3, pbc: Pbc3) -> Result<Self, BoxError> {
68 Self::new(h, origin, pbc)
69 }
70
71 pub fn cube(a: F, origin: F3, pbc: Pbc3) -> Result<Self, BoxError> {
73 if a <= 0.0 {
74 return Err(BoxError::InvalidVectorLength { len: 0 });
75 }
76 let h = array![[a, 0.0, 0.0], [0.0, a, 0.0], [0.0, 0.0, a]];
77 Self::new(h, origin, pbc)
78 }
79
80 pub fn ortho(lengths: F3, origin: F3, pbc: Pbc3) -> Result<Self, BoxError> {
82 if lengths.len() != 3 {
83 return Err(BoxError::InvalidVectorLength { len: lengths.len() });
84 }
85 if lengths.iter().any(|v| *v <= 0.0) {
86 return Err(BoxError::InvalidVectorLength { len: 0 });
87 }
88 let h = array![
89 [lengths[0], 0.0, 0.0],
90 [0.0, lengths[1], 0.0],
91 [0.0, 0.0, lengths[2]],
92 ];
93 Self::new(h, origin, pbc)
94 }
95
96 pub fn free(points: FNx3View<'_>, padding: F) -> Result<Self, BoxError> {
110 assert!(padding > 0.0, "padding must be positive");
111 let n = points.nrows();
112 if n == 0 {
113 return Self::cube(padding, array![0.0 as F, 0.0, 0.0], [false, false, false]);
115 }
116 let mut min = array![points[[0, 0]], points[[0, 1]], points[[0, 2]]];
117 let mut max = min.clone();
118 for i in 1..n {
119 for d in 0..3 {
120 if points[[i, d]] < min[d] {
121 min[d] = points[[i, d]];
122 }
123 if points[[i, d]] > max[d] {
124 max[d] = points[[i, d]];
125 }
126 }
127 }
128 let origin = array![min[0] - padding, min[1] - padding, min[2] - padding,];
129 let lengths = array![
130 (max[0] - min[0] + 2.0 * padding).max(padding),
131 (max[1] - min[1] + 2.0 * padding).max(padding),
132 (max[2] - min[2] + 2.0 * padding).max(padding),
133 ];
134 Self::ortho(lengths, origin, [false, false, false])
135 }
136
137 pub fn h_view(&self) -> FNx3View<'_> {
139 self.h.view()
140 }
141
142 pub fn inv_view(&self) -> FNx3View<'_> {
144 self.inv.view()
145 }
146
147 pub fn origin_view(&self) -> F3View<'_> {
149 self.origin.view()
150 }
151
152 pub fn pbc_view(&self) -> ArrayView1<'_, bool> {
154 ArrayView1::from_shape(3, &self.pbc).expect("pbc_view shape")
155 }
156
157 pub fn pbc(&self) -> Pbc3 {
159 self.pbc
160 }
161
162 pub fn volume(&self) -> F {
164 math::det3(&self.h).abs()
165 }
166
167 pub fn tilts(&self) -> F3 {
169 array![self.h[[0, 1]], self.h[[0, 2]], self.h[[1, 2]]]
170 }
171
172 pub fn lengths(&self) -> F3 {
174 let a = self.lattice(0);
175 let b = self.lattice(1);
176 let c = self.lattice(2);
177 array![math::norm3(&a), math::norm3(&b), math::norm3(&c)]
178 }
179
180 pub fn nearest_plane_distance(&self) -> F3 {
183 let v = self.volume();
184 let a1 = self.lattice(0);
185 let a2 = self.lattice(1);
186 let a3 = self.lattice(2);
187
188 let c23 = math::cross3(&a2, &a3);
189 let c31 = math::cross3(&a3, &a1);
190 let c12 = math::cross3(&a1, &a2);
191
192 array![
193 v / math::norm3(&c23),
194 v / math::norm3(&c31),
195 v / math::norm3(&c12)
196 ]
197 }
198
199 pub fn kind(&self) -> &BoxKind {
200 &self.kind
201 }
202
203 pub fn lattice(&self, index: usize) -> F3 {
205 assert!(index < 3, "lattice index must be 0..2");
206 self.h.column(index).to_owned()
207 }
208
209 pub fn make_fractional(&self, r: F3View<'_>) -> F3 {
211 let dr = &r - &self.origin.view();
212 let mut frac = self.inv.dot(&dr);
213 for f in frac.iter_mut() {
214 *f -= f.floor();
215 }
216 frac
217 }
218
219 #[inline(always)]
221 pub fn make_fractional_fast(&self, r: F3View<'_>) -> F3 {
222 match &self.kind {
223 BoxKind::Ortho { inv_len, .. } => {
224 let mut frac = array![
225 (r[0] - self.origin[0]) * inv_len[0],
226 (r[1] - self.origin[1]) * inv_len[1],
227 (r[2] - self.origin[2]) * inv_len[2],
228 ];
229 for f in frac.iter_mut() {
230 *f -= f.floor();
231 }
232 frac
233 }
234 BoxKind::Triclinic => self.make_fractional(r),
235 }
236 }
237
238 #[inline(always)]
244 pub fn make_fractional_fast_arr(&self, r: F3View<'_>) -> [F; 3] {
245 match &self.kind {
246 BoxKind::Ortho { inv_len, .. } => {
247 let fx = (r[0] - self.origin[0]) * inv_len[0];
248 let fy = (r[1] - self.origin[1]) * inv_len[1];
249 let fz = (r[2] - self.origin[2]) * inv_len[2];
250 [fx - fx.floor(), fy - fy.floor(), fz - fz.floor()]
251 }
252 BoxKind::Triclinic => {
253 let f = self.make_fractional(r);
254 [f[0], f[1], f[2]]
255 }
256 }
257 }
258
259 pub fn make_cartesian(&self, frac: F3View<'_>) -> F3 {
261 &self.origin + &self.h.dot(&frac)
262 }
263
264 #[inline]
266 pub fn shortest_vector(&self, r1: F3View<'_>, r2: F3View<'_>) -> F3 {
267 let dr = &r2 - &r1;
268 let mut dr_frac = self.inv.dot(&dr);
269 for d in 0..3 {
270 if self.pbc[d] {
271 dr_frac[d] -= dr_frac[d].round();
272 }
273 }
274 self.h.dot(&dr_frac)
275 }
276
277 #[inline(always)]
279 pub fn shortest_vector_fast(&self, a: F3View<'_>, b: F3View<'_>) -> F3 {
280 match &self.kind {
281 BoxKind::Ortho { len, inv_len } => {
282 let mut dr = array![b[0] - a[0], b[1] - a[1], b[2] - a[2]];
283 for d in 0..3 {
284 if self.pbc[d] {
285 dr[d] -= (dr[d] * inv_len[d]).round() * len[d];
286 }
287 }
288 dr
289 }
290 BoxKind::Triclinic => self.shortest_vector(a, b),
291 }
292 }
293
294 #[inline(always)]
300 pub fn shortest_vector_fast_arr(&self, a: F3View<'_>, b: F3View<'_>) -> [F; 3] {
301 match &self.kind {
302 BoxKind::Ortho { len, inv_len } => {
303 let mut dr = [b[0] - a[0], b[1] - a[1], b[2] - a[2]];
304 if self.pbc[0] {
305 dr[0] -= (dr[0] * inv_len[0]).round() * len[0];
306 }
307 if self.pbc[1] {
308 dr[1] -= (dr[1] * inv_len[1]).round() * len[1];
309 }
310 if self.pbc[2] {
311 dr[2] -= (dr[2] * inv_len[2]).round() * len[2];
312 }
313 dr
314 }
315 BoxKind::Triclinic => {
316 let v = self.shortest_vector(a, b);
317 [v[0], v[1], v[2]]
318 }
319 }
320 }
321
322 #[inline(always)]
328 pub fn shortest_vector_raw(&self, a: [F; 3], b: [F; 3]) -> [F; 3] {
329 match &self.kind {
330 BoxKind::Ortho { len, inv_len } => {
331 let mut dr = [b[0] - a[0], b[1] - a[1], b[2] - a[2]];
332 if self.pbc[0] {
333 dr[0] -= (dr[0] * inv_len[0]).round() * len[0];
334 }
335 if self.pbc[1] {
336 dr[1] -= (dr[1] * inv_len[1]).round() * len[1];
337 }
338 if self.pbc[2] {
339 dr[2] -= (dr[2] * inv_len[2]).round() * len[2];
340 }
341 dr
342 }
343 BoxKind::Triclinic => {
344 let av = ndarray::ArrayView1::from(&a[..]);
345 let bv = ndarray::ArrayView1::from(&b[..]);
346 let v = self.shortest_vector(av, bv);
347 [v[0], v[1], v[2]]
348 }
349 }
350 }
351
352 #[inline]
354 pub fn calc_distance2(&self, a: F3View<'_>, b: F3View<'_>) -> F {
355 let dr = self.shortest_vector(a, b);
356 dr.dot(&dr)
357 }
358
359 pub fn to_frac(&self, xyz: FNx3View<'_>) -> FNx3 {
361 let n = xyz.nrows();
362 let mut result = FNx3::zeros((n, 3));
363 for i in 0..n {
364 let dr = &xyz.row(i) - &self.origin.view();
365 result.row_mut(i).assign(&self.inv.dot(&dr));
366 }
367 result
368 }
369
370 pub fn to_cart(&self, frac: FNx3View<'_>) -> FNx3 {
372 let n = frac.nrows();
373 let mut result = FNx3::zeros((n, 3));
374 for i in 0..n {
375 let cart = &self.origin + &self.h.dot(&frac.row(i));
376 result.row_mut(i).assign(&cart);
377 }
378 result
379 }
380
381 pub fn isin(&self, xyz: FNx3View<'_>) -> Array1<bool> {
383 let n = xyz.nrows();
384 let mut mask = Vec::with_capacity(n);
385 for i in 0..n {
386 let dr = &xyz.row(i) - &self.origin.view();
387 let frac = self.inv.dot(&dr);
388 let inside = (0..3).all(|d| frac[d] >= 0.0 && frac[d] < 1.0);
389 mask.push(inside);
390 }
391 Array1::from_vec(mask)
392 }
393
394 pub fn delta_out(
397 &self,
398 xyzu1: FNx3View<'_>,
399 xyzu2: FNx3View<'_>,
400 out: &mut FNx3,
401 minimum_image: bool,
402 ) {
403 assert_eq!(xyzu1.nrows(), xyzu2.nrows());
404 let n = xyzu1.nrows();
405 if minimum_image {
406 for i in 0..n {
407 let dr = self.shortest_vector(xyzu1.row(i), xyzu2.row(i));
408 out.row_mut(i).assign(&dr);
409 }
410 } else {
411 for i in 0..n {
412 let dr = &xyzu2.row(i) - &xyzu1.row(i);
413 out.row_mut(i).assign(&dr);
414 }
415 }
416 }
417
418 pub fn delta(&self, xyzu1: FNx3View<'_>, xyzu2: FNx3View<'_>, minimum_image: bool) -> FNx3 {
420 assert_eq!(xyzu1.nrows(), xyzu2.nrows());
421 let n = xyzu1.nrows();
422 let mut out = FNx3::zeros((n, 3));
423 self.delta_out(xyzu1, xyzu2, &mut out, minimum_image);
424 out
425 }
426
427 pub fn wrap(&self, xyz: FNx3View<'_>) -> FNx3 {
429 let mut frac = self.to_frac(xyz);
430 let n = frac.nrows();
431 for i in 0..n {
432 for d in 0..3 {
433 if self.pbc[d] {
434 frac[[i, d]] -= frac[[i, d]].floor();
435 }
436 }
437 }
438 self.to_cart(frac.view())
439 }
440
441 pub fn get_corners(&self) -> FNx3 {
442 let l = self.lengths();
443 let (ox, oy, oz) = (self.origin[0], self.origin[1], self.origin[2]);
444 let (lx, ly, lz) = (l[0], l[1], l[2]);
445 array![
446 [ox, oy, oz],
447 [ox + lx, oy, oz],
448 [ox + lx, oy + ly, oz],
449 [ox, oy + ly, oz],
450 [ox, oy, oz + lz],
451 [ox + lx, oy, oz + lz],
452 [ox + lx, oy + ly, oz + lz],
453 [ox, oy + ly, oz + lz],
454 ]
455 }
456}
457
458impl Region for SimBox {
459 fn bounds(&self) -> FNx3 {
460 let lengths = self.lengths();
461 let mut b = Array2::zeros((3, 2));
462 for d in 0..3 {
463 b[[d, 0]] = self.origin[d];
464 b[[d, 1]] = self.origin[d] + lengths[d];
465 }
466 b
467 }
468
469 fn contains(&self, points: &FNx3) -> Array1<bool> {
470 self.isin(points.view())
471 }
472
473 fn contains_point(&self, point: &[F; 3]) -> bool {
474 let r = ArrayView1::from_shape(3, point).expect("contains_point shape");
475 let dr = &r - &self.origin.view();
476 let frac = self.inv.dot(&dr);
477 (0..3).all(|d| frac[d] >= 0.0 && frac[d] < 1.0)
478 }
479}
480
481fn detect_box_kind(h: &F3x3) -> BoxKind {
482 let eps: F = 1e-12;
483 let is_ortho = h[[0, 1]].abs() < eps
484 && h[[0, 2]].abs() < eps
485 && h[[1, 0]].abs() < eps
486 && h[[1, 2]].abs() < eps
487 && h[[2, 0]].abs() < eps
488 && h[[2, 1]].abs() < eps;
489 if is_ortho {
490 let len = array![h[[0, 0]], h[[1, 1]], h[[2, 2]]];
491 let inv_len = array![1.0 / len[0], 1.0 / len[1], 1.0 / len[2]];
492 BoxKind::Ortho { len, inv_len }
493 } else {
494 BoxKind::Triclinic
495 }
496}
497
#[cfg(test)]
mod tests {
    use super::*;

    fn assert_close(a: F, b: F) {
        assert!((a - b).abs() < 1e-6 as F, "{} != {}", a, b);
    }

    /// Looser comparison for results that go through several float operations.
    fn assert_near(a: F, b: F) {
        assert!((a - b).abs() < 1e-5 as F, "{} != {}", a, b);
    }

    #[test]
    fn roundtrip_frac_cart() {
        let bx = SimBox::ortho(
            array![2.0, 3.0, 4.0],
            array![0.5, -1.0, 2.0],
            [true, true, true],
        )
        .expect("invalid box lengths");
        let pts = array![[0.5, -1.0, 2.0], [2.5, 2.0, 6.0]];
        let frac = bx.to_frac(pts.view());
        let cart = bx.to_cart(frac.view());
        assert!((&pts - &cart).iter().all(|v| v.abs() < 1e-5));
    }

    #[test]
    fn wrap_into_cell() {
        let bx = SimBox::cube(2.0, array![0.0, 0.0, 0.0], [true, true, true])
            .expect("invalid box length");
        let pts = array![[2.1, -0.1, 3.9], [-1.9, 4.2, 0.0]];
        let wrapped = bx.wrap(pts.view());
        let frac = bx.to_frac(wrapped.view());
        for i in 0..wrapped.nrows() {
            let fx = frac[[i, 0]];
            let fy = frac[[i, 1]];
            let fz = frac[[i, 2]];
            assert!((0.0..1.0).contains(&fx));
            assert!((0.0..1.0).contains(&fy));
            assert!((0.0..1.0).contains(&fz));
        }
    }

    #[test]
    fn calc_distance_matches_components() {
        let bx = SimBox::cube(3.0, array![0.0, 0.0, 0.0], [true, true, true])
            .expect("invalid box length");
        let a = array![0.1, 0.2, 0.3];
        let b = array![2.9, 0.2, 0.3];
        let d2 = bx.calc_distance2(a.view(), b.view());
        let dr = bx.shortest_vector(a.view(), b.view());
        let expected = dr.dot(&dr);
        assert!((d2 - expected).abs() < 1e-6);
    }

    #[test]
    fn test_lengths_ortho() {
        let bx = SimBox::ortho(
            array![2.0, 4.0, 5.0],
            array![0.0, 0.0, 0.0],
            [true, true, true],
        )
        .expect("invalid box lengths");
        let lengths = bx.lengths();
        assert_close(lengths[0], 2.0);
        assert_close(lengths[1], 4.0);
        assert_close(lengths[2], 5.0);
    }

    #[test]
    fn test_tilts_values() {
        let h = array![[2.0, 1.0, 2.0], [0.0, 4.0, 3.0], [0.0, 0.0, 5.0]];
        let bx = SimBox::new(h, array![0.0, 0.0, 0.0], [true, true, true]).expect("invalid box");
        let tilts = bx.tilts();
        assert_close(tilts[0], 1.0);
        assert_close(tilts[1], 2.0);
        assert_close(tilts[2], 3.0);
    }

    #[test]
    fn test_volume() {
        let bx = SimBox::ortho(
            array![2.0, 3.0, 4.0],
            array![0.0, 0.0, 0.0],
            [true, true, true],
        )
        .expect("invalid box lengths");
        assert_close(bx.volume(), 24.0);
    }

    #[test]
    fn test_wrap_single_and_multi() {
        let bx = SimBox::cube(2.0, array![0.0, 0.0, 0.0], [true, true, true])
            .expect("invalid box length");
        let pts = array![[10.0, -5.0, -5.0], [0.0, 0.5, 0.0]];
        let wrapped = bx.wrap(pts.view());
        assert_close(wrapped[[0, 0]], 0.0);
        assert_close(wrapped[[0, 1]], 1.0);
        assert_close(wrapped[[0, 2]], 1.0);
        assert_close(wrapped[[1, 0]], 0.0);
        assert_close(wrapped[[1, 1]], 0.5);
        assert_close(wrapped[[1, 2]], 0.0);
    }

    #[test]
    fn test_fractional_and_cartesian() {
        let bx = SimBox::cube(2.0, array![0.0, 0.0, 0.0], [true, true, true])
            .expect("invalid box length");
        let p = array![-1.0, -1.0, -1.0];
        let frac = bx.make_fractional(p.view());
        assert_close(frac[0], 0.5);
        assert_close(frac[1], 0.5);
        assert_close(frac[2], 0.5);
        let cart = bx.make_cartesian(frac.view());
        assert_close(cart[0], 1.0);
        assert_close(cart[1], 1.0);
        assert_close(cart[2], 1.0);
    }

    #[test]
    fn test_to_frac_to_cart_roundtrip() {
        let bx = SimBox::ortho(
            array![2.0, 3.0, 4.0],
            array![1.0, 2.0, 3.0],
            [true, true, true],
        )
        .expect("invalid box lengths");
        let pts = array![[1.0, 2.0, 3.0], [2.0, 3.0, 4.0]];
        let frac = bx.to_frac(pts.view());
        let cart = bx.to_cart(frac.view());
        for i in 0..pts.nrows() {
            for j in 0..3 {
                assert_close(pts[[i, j]], cart[[i, j]]);
            }
        }
    }

    #[test]
    fn test_shortest_vector_and_distance() {
        let bx = SimBox::cube(2.0, array![0.0, 0.0, 0.0], [true, true, true])
            .expect("invalid box length");
        let a = array![0.1, 0.0, 0.0];
        let b = array![1.9, 0.0, 0.0];
        let dr = bx.shortest_vector(a.view(), b.view());
        assert_close(dr[0], -0.2);
        assert_close(dr[1], 0.0);
        assert_close(dr[2], 0.0);
        let d2 = bx.calc_distance2(a.view(), b.view());
        assert_close(d2, 0.04);
    }

    #[test]
    fn test_contains_point_non_pbc() {
        let bx = SimBox::cube(2.0, array![0.0, 0.0, 0.0], [false, false, false])
            .expect("invalid box length");
        assert!(bx.contains_point(&[0.5, 0.5, 0.5]));
        assert!(!bx.contains_point(&[-0.1, 0.5, 0.5]));
        assert!(!bx.contains_point(&[2.1, 0.5, 0.5]));
    }

    #[test]
    fn test_contains_mask() {
        let bx = SimBox::cube(2.0, array![0.0, 0.0, 0.0], [true, true, true])
            .expect("invalid box length");
        let pts = array![[0.1, 0.1, 0.1], [2.1, 0.0, 0.0], [-0.1, 0.0, 0.0]];
        let mask = bx.contains(&pts);
        assert!(mask[0]);
        assert!(!mask[1]);
        assert!(!mask[2]);
    }

    #[test]
    fn test_simbox_free_basic() {
        let pts = array![[1.0 as F, 2.0, 3.0], [4.0, 5.0, 6.0]];
        let bx = SimBox::free(pts.view(), 1.0).unwrap();
        assert_eq!(bx.pbc(), [false, false, false]);
        let o = bx.origin_view();
        assert!((o[0] - 0.0).abs() < 1e-5);
        assert!((o[1] - 1.0).abs() < 1e-5);
        assert!((o[2] - 2.0).abs() < 1e-5);
        let l = bx.lengths();
        assert!((l[0] - 5.0).abs() < 1e-5);
        assert!((l[1] - 5.0).abs() < 1e-5);
        assert!((l[2] - 5.0).abs() < 1e-5);
    }

    #[test]
    fn test_simbox_free_single_point() {
        let pts = array![[1.0 as F, 2.0, 3.0]];
        let bx = SimBox::free(pts.view(), 2.0).unwrap();
        assert_eq!(bx.pbc(), [false, false, false]);
        let l = bx.lengths();
        assert!(l[0] >= 2.0);
        assert!(l[1] >= 2.0);
        assert!(l[2] >= 2.0);
    }

    #[test]
    fn test_simbox_free_empty() {
        let pts = Array2::<F>::zeros((0, 3));
        let bx = SimBox::free(pts.view(), 1.0).unwrap();
        assert_eq!(bx.pbc(), [false, false, false]);
    }

    #[test]
    fn test_simbox_pbc_accessor() {
        let bx = SimBox::cube(1.0, array![0.0 as F, 0.0, 0.0], [true, false, true]).unwrap();
        assert_eq!(bx.pbc(), [true, false, true]);
    }

    #[test]
    fn test_new_rejects_singular_cell() {
        let h = array![[1.0 as F, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 0.0]];
        let res = SimBox::new(h, array![0.0, 0.0, 0.0], [true, true, true]);
        assert!(matches!(res, Err(BoxError::SingularCell)));
    }

    #[test]
    fn test_cube_rejects_non_positive_length() {
        assert!(SimBox::cube(0.0, array![0.0, 0.0, 0.0], [true, true, true]).is_err());
        assert!(SimBox::cube(-1.0, array![0.0, 0.0, 0.0], [true, true, true]).is_err());
    }

    #[test]
    fn test_kind_detection() {
        let ortho = SimBox::cube(2.0, array![0.0, 0.0, 0.0], [true, true, true]).unwrap();
        assert!(matches!(ortho.kind(), BoxKind::Ortho { .. }));
        let h = array![[2.0, 1.0, 2.0], [0.0, 4.0, 3.0], [0.0, 0.0, 5.0]];
        let tri = SimBox::new(h, array![0.0, 0.0, 0.0], [true, true, true]).unwrap();
        assert!(matches!(tri.kind(), BoxKind::Triclinic));
    }

    #[test]
    fn test_fractional_fast_paths_match_general_ortho() {
        let bx = SimBox::ortho(
            array![2.0, 3.0, 4.0],
            array![0.5, -1.0, 2.0],
            [true, true, true],
        )
        .unwrap();
        // Points chosen so that no fractional coordinate lies near a cell face.
        let pts = array![[0.1, 0.2, 0.3], [1.9, -2.5, 7.5], [-3.3, 4.4, 0.0]];
        for i in 0..pts.nrows() {
            let r = pts.row(i);
            let slow = bx.make_fractional(r);
            let fast = bx.make_fractional_fast(r);
            let arr = bx.make_fractional_fast_arr(r);
            for d in 0..3 {
                assert_near(slow[d], fast[d]);
                assert_near(slow[d], arr[d]);
            }
        }
    }

    #[test]
    fn test_shortest_vector_fast_paths_match_general_ortho() {
        let bx = SimBox::ortho(
            array![2.0, 3.0, 4.0],
            array![0.0, 0.0, 0.0],
            [true, true, true],
        )
        .unwrap();
        let a = array![0.1, 0.2, 0.3];
        let b = array![1.9, 2.9, 3.9];
        let expected: [F; 3] = [-0.2, -0.3, -0.4];
        let slow = bx.shortest_vector(a.view(), b.view());
        let fast = bx.shortest_vector_fast(a.view(), b.view());
        let arr = bx.shortest_vector_fast_arr(a.view(), b.view());
        let raw = bx.shortest_vector_raw([0.1, 0.2, 0.3], [1.9, 2.9, 3.9]);
        for d in 0..3 {
            assert_near(slow[d], expected[d]);
            assert_near(fast[d], expected[d]);
            assert_near(arr[d], expected[d]);
            assert_near(raw[d], expected[d]);
        }
    }

    #[test]
    fn test_delta_with_and_without_minimum_image() {
        let bx = SimBox::cube(2.0, array![0.0, 0.0, 0.0], [true, true, true]).unwrap();
        let p1 = array![[0.1, 0.0, 0.0]];
        let p2 = array![[1.9, 0.0, 0.0]];
        let plain = bx.delta(p1.view(), p2.view(), false);
        assert_near(plain[[0, 0]], 1.8);
        let mic = bx.delta(p1.view(), p2.view(), true);
        assert_near(mic[[0, 0]], -0.2);
    }

    #[test]
    fn test_nearest_plane_distance() {
        let bx = SimBox::ortho(
            array![2.0, 3.0, 4.0],
            array![0.0, 0.0, 0.0],
            [true, true, true],
        )
        .unwrap();
        let d = bx.nearest_plane_distance();
        assert_near(d[0], 2.0);
        assert_near(d[1], 3.0);
        assert_near(d[2], 4.0);

        let h = array![[2.0, 1.0, 2.0], [0.0, 4.0, 3.0], [0.0, 0.0, 5.0]];
        let tri = SimBox::new(h, array![0.0, 0.0, 0.0], [true, true, true]).unwrap();
        let d = tri.nearest_plane_distance();
        assert_near(d[0], 1.885_618_1);
        assert_near(d[1], 3.429_971_7);
        assert_near(d[2], 5.0);
    }

    #[test]
    fn test_bounds_and_corners_ortho() {
        let bx = SimBox::ortho(
            array![2.0, 3.0, 4.0],
            array![1.0, 2.0, 3.0],
            [true, true, true],
        )
        .unwrap();
        let b = bx.bounds();
        assert_close(b[[0, 0]], 1.0);
        assert_close(b[[0, 1]], 3.0);
        assert_close(b[[1, 0]], 2.0);
        assert_close(b[[1, 1]], 5.0);
        assert_close(b[[2, 0]], 3.0);
        assert_close(b[[2, 1]], 7.0);

        let c = bx.get_corners();
        assert_eq!(c.nrows(), 8);
        assert_close(c[[0, 0]], 1.0);
        assert_close(c[[0, 1]], 2.0);
        assert_close(c[[0, 2]], 3.0);
        assert_close(c[[6, 0]], 3.0);
        assert_close(c[[6, 1]], 5.0);
        assert_close(c[[6, 2]], 7.0);
    }
}