molar_python/
lib.rs

1use std::path::PathBuf;
2
3use anyhow::{anyhow, bail};
4use molar::prelude::*;
5use numpy::{
6    nalgebra::{self, Const, Dyn, VectorView},
7    PyArrayLike1, PyArrayMethods, PyReadonlyArray2, PyUntypedArrayMethods, ToPyArray,
8};
9use pyo3::{IntoPyObjectExt, prelude::*, types::PyTuple};
10
11mod utils;
12use triomphe::Arc;
13use utils::*;
14
15mod atom;
16use atom::Atom;
17
18mod particle;
19use particle::Particle;
20
21mod periodic_box;
22use periodic_box::PeriodicBox;
23
24mod membrane;
25use membrane::*;
26
27//-------------------------------------------
28
29#[pyclass(unsendable)]
30struct Topology(triomphe::Arc<molar::core::Topology>);
31
32#[pyclass(unsendable)]
33struct State(triomphe::Arc<molar::core::State>);
34
35#[pymethods]
36impl State {
37    fn __len__(&self) -> usize {
38        self.0.len()
39    }
40
41    #[getter]
42    fn get_time(&self) -> f32 {
43        self.0.get_time()
44    }
45
46    #[setter]
47    fn set_time(&self, t: f32) {
48        self.0.set_time(t);
49    }
50
51    #[getter]
52    fn get_box(&self) -> anyhow::Result<PeriodicBox> {
53        Ok(PeriodicBox(
54            self.0
55                .get_box()
56                .ok_or_else(|| anyhow!("No periodic box"))?
57                .clone(),
58        ))
59    }
60
61    #[setter]
62    fn set_box(&mut self, val: Bound<'_, PeriodicBox>) -> anyhow::Result<()> {
63        let b = self
64            .0
65            .get_box_mut()
66            .ok_or_else(|| anyhow!("No periodic box"))?;
67        *b = val.borrow().0.clone();
68        Ok(())
69    }
70}
71
72#[pyclass(unsendable)]
73struct FileHandler(
74    Option<molar::io::FileHandler>,
75    Option<molar::io::IoStateIterator>,
76);
77
78const ALREADY_TRANDFORMED: &str = "file handler is already transformed to state iterator";
79
80#[pymethods]
81impl FileHandler {
82    #[new]
83    fn new(fname: &str, mode: &str) -> anyhow::Result<Self> {
84        match mode {
85            "r" => Ok(FileHandler(
86                Some(molar::io::FileHandler::open(fname)?),
87                None,
88            )),
89            "w" => Ok(FileHandler(
90                Some(molar::io::FileHandler::create(fname)?),
91                None,
92            )),
93            _ => Err(anyhow!("Wrong file open mode")),
94        }
95    }
96
97    fn read(&mut self) -> anyhow::Result<(Topology, State)> {
98        let h = self
99            .0
100            .as_mut()
101            .ok_or_else(|| anyhow!(ALREADY_TRANDFORMED))?;
102        let (top, st) = h.read()?;
103        Ok((Topology(top.into()), State(st.into())))
104    }
105
106    fn read_topology(&mut self) -> anyhow::Result<Topology> {
107        let h = self
108            .0
109            .as_mut()
110            .ok_or_else(|| anyhow!(ALREADY_TRANDFORMED))?;
111        let top = h.read_topology()?;
112        Ok(Topology(top.into()))
113    }
114
115    fn read_state(&mut self) -> anyhow::Result<State> {
116        let h = self
117            .0
118            .as_mut()
119            .ok_or_else(|| anyhow!(ALREADY_TRANDFORMED))?;
120        if let Some(st) = h.read_state()? {
121            Ok(State(st.into()))
122        } else {
123            Err(anyhow!("can't read state"))
124        }
125    }
126
127    fn write(&mut self, data: Bound<'_, PyAny>) -> anyhow::Result<()> {
128        let h = self
129            .0
130            .as_mut()
131            .ok_or_else(|| anyhow!(ALREADY_TRANDFORMED))?;
132        if let Ok(s) = data.extract::<PyRef<'_, System>>() {
133            h.write(&s.0)?;
134        } else if let Ok(s) = data.extract::<PyRef<'_, Sel>>() {
135            h.write(&s.0)?;
136        } else if let Ok(s) = data.cast::<PyTuple>() {
137            if s.len() != 2 {
138                return Err(anyhow!("Tuple must have two elements"));
139            }
140            let top = s
141                .iter()
142                .next()
143                .unwrap()
144                .extract::<PyRefMut<'_, Topology>>().unwrap();
145            let st = s.iter().next().unwrap().extract::<PyRefMut<'_, State>>().unwrap();
146            h.write(&(Arc::clone(&top.0), Arc::clone(&st.0)))?;
147        } else {
148            return Err(anyhow!(
149                "Invalid data type {} when writing to file",
150                data.get_type()
151            ));
152        }
153        Ok(())
154    }
155
156    fn write_topology(&mut self, data: Bound<'_, PyAny>) -> anyhow::Result<()> {
157        let h = self
158            .0
159            .as_mut()
160            .ok_or_else(|| anyhow!(ALREADY_TRANDFORMED))?;
161        if let Ok(s) = data.extract::<PyRef<'_, System>>() {
162            h.write_topology(&s.0)?;
163        } else if let Ok(s) = data.extract::<PyRef<'_, Sel>>() {
164            h.write_topology(&s.0)?;
165        } else if let Ok(s) = data.extract::<PyRefMut<'_, Topology>>() {
166            h.write_topology(&s.0)?;
167        } else {
168            return Err(anyhow!(
169                "Invalid data type {} when writing to file",
170                data.get_type()
171            ));
172        }
173        Ok(())
174    }
175
176    fn write_state(&mut self, data: Bound<'_, PyAny>) -> anyhow::Result<()> {
177        let h = self
178            .0
179            .as_mut()
180            .ok_or_else(|| anyhow!(ALREADY_TRANDFORMED))?;
181        if let Ok(s) = data.extract::<PyRef<'_, System>>() {
182            h.write_state(&s.0)?;
183        } else if let Ok(s) = data.extract::<PyRef<'_, Sel>>() {
184            h.write_state(&s.0)?;
185        } else if let Ok(s) = data.extract::<PyRefMut<'_, State>>() {
186            h.write_state(&s.0)?;
187        } else {
188            return Err(anyhow!(
189                "Invalid data type {} when writing to file",
190                data.get_type()
191            ));
192        }
193        Ok(())
194    }
195
196    fn __iter__(mut slf: PyRefMut<'_, Self>) -> PyRefMut<'_, Self> {
197        if slf.1.is_none() {
198            let h = slf.0.take().unwrap();
199            slf.1 = Some(h.into_iter());
200        }
201        slf
202    }
203
204    fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<Py<PyAny>> {
205        let st = slf.1.as_mut().unwrap().next().map(|st| State(st.into()));
206        if st.is_some() {
207            Python::attach(|py| Some(st.unwrap().into_py_any(py)))
208                .unwrap()
209                .ok()
210        } else {
211            None
212        }
213    }
214
215    fn skip_to_frame(&mut self, fr: usize) -> PyResult<()> {
216        let h = self
217            .0
218            .as_mut()
219            .ok_or_else(|| anyhow!(ALREADY_TRANDFORMED))?;
220        h.skip_to_frame(fr).map_err(|e| anyhow!(e))?;
221        Ok(())
222    }
223
224    fn skip_to_time(&mut self, t: f32) -> anyhow::Result<()> {
225        let h = self
226            .0
227            .as_mut()
228            .ok_or_else(|| anyhow!(ALREADY_TRANDFORMED))?;
229        h.skip_to_time(t)?;
230        Ok(())
231    }
232
233    fn tell_first(&self) -> anyhow::Result<(usize, f32)> {
234        let h = self
235            .0
236            .as_ref()
237            .ok_or_else(|| anyhow!(ALREADY_TRANDFORMED))?;
238        Ok(h.tell_first()?)
239    }
240
241    fn tell_current(&self) -> anyhow::Result<(usize, f32)> {
242        let h = self
243            .0
244            .as_ref()
245            .ok_or_else(|| anyhow!(ALREADY_TRANDFORMED))?;
246        Ok(h.tell_current()?)
247    }
248
249    fn tell_last(&self) -> anyhow::Result<(usize, f32)> {
250        let h = self
251            .0
252            .as_ref()
253            .ok_or_else(|| anyhow!(ALREADY_TRANDFORMED))?;
254        Ok(h.tell_last()?)
255    }
256
257    #[getter]
258    fn stats(&self) -> anyhow::Result<FileStats> {
259        let h = self
260            .0
261            .as_ref()
262            .ok_or_else(|| anyhow!(ALREADY_TRANDFORMED))?;
263        Ok(FileStats(h.stats.clone()))
264    }
265
266    #[getter]
267    fn file_name(&self) -> anyhow::Result<PathBuf> {
268        let h = self
269            .0
270            .as_ref()
271            .ok_or_else(|| anyhow!(ALREADY_TRANDFORMED))?;
272        Ok(h.file_path.clone())
273    }
274}
275
276#[pyclass]
277struct FileStats(molar::io::FileStats);
278
279#[pymethods]
280impl FileStats {
281    #[getter]
282    fn elapsed_time(&self) -> std::time::Duration {
283        self.0.elapsed_time
284    }
285
286    #[getter]
287    fn frames_processed(&self) -> usize {
288        self.0.frames_processed
289    }
290
291    #[getter]
292    fn cur_t(&self) -> f32 {
293        self.0.cur_t
294    }
295
296    fn __repr__(&self) -> String {
297        format!("{}", self.0)
298    }
299
300    fn __str__(&self) -> String {
301        format!("{}", self.0)
302    }
303}
304
305#[pyclass(unsendable, sequence)]
306struct System(molar::core::System);
307
308#[pymethods]
309impl System {
310    #[new]
311    #[pyo3(signature = (*py_args))]
312    fn new<'py>(py_args: &Bound<'py, PyTuple>) -> PyResult<Self> {
313        if py_args.len() == 1 {
314            // From file
315            Ok(System(
316                molar::core::System::from_file(&py_args.get_item(0)?.extract::<String>()?)
317                    .map_err(|e| anyhow!(e))?,
318            ))
319        } else if py_args.len() == 2 {
320            let top = py_args
321                .get_item(0)?
322                .cast::<Topology>()?
323                .try_borrow_mut()?;
324            let st = py_args.get_item(1)?.cast::<State>()?.try_borrow_mut()?;
325            Ok(System(
326                molar::core::System::new(Arc::clone(&top.0), Arc::clone(&st.0))
327                    .map_err(|e| anyhow!(e))?,
328            ))
329        } else {
330            // Empty builder
331            Ok(System(molar::core::System::new_empty()))
332        }
333    }
334
335    fn __len__(&self) -> usize {
336        self.0.len()
337    }
338
339    fn select_all(&mut self) -> anyhow::Result<Sel> {
340        Ok(Sel::new_owned(self.0.select_all()?))
341    }
342
343    fn select(&mut self, sel_str: &str) -> anyhow::Result<Sel> {
344        Ok(Sel::new_owned(self.0.select(sel_str)?))
345    }
346
347    #[pyo3(signature = (arg=None))]
348    fn __call__(&self, arg: Option<&Bound<'_, PyAny>>) -> anyhow::Result<Sel> {
349        if let Some(arg) = arg {
350            if let Ok(val) = arg.extract::<String>() {
351                if val.is_empty() {
352                    Ok(Sel::new_owned(self.0.select_all()?))
353                } else {
354                    Ok(Sel::new_owned(self.0.select(val)?))
355                }
356            } else if let Ok(val) = arg.extract::<(usize, usize)>() {
357                Ok(Sel::new_owned(self.0.select(val.0..val.1)?))
358            } else if let Ok(val) = arg.extract::<Vec<usize>>() {
359                Ok(Sel::new_owned(self.0.select(val)?))
360            } else {
361                Err(anyhow!(
362                    "Invalid argument type {} when creating selection",
363                    arg.get_type()
364                )
365                .into())
366            }
367        } else {
368            Ok(Sel::new_owned(self.0.select_all()?))
369        }
370    }
371
372    fn set_state(&mut self, st: &State) -> anyhow::Result<State> {
373        let old_state = self.0.set_state(Arc::clone(&st.0))?;
374        Ok(State(old_state))
375    }
376
377    fn set_topology(&mut self, top: &Topology) -> anyhow::Result<Topology> {
378        let old_top = self.0.set_topology(Arc::clone(&top.0))?;
379        Ok(Topology(old_top))
380    }
381
382    
383    fn get_box(&self) -> anyhow::Result<PeriodicBox> {
384        Ok(self.0.require_box().cloned().map(|b| PeriodicBox(b))?)
385    }
386
387    // fn get_topology(&self) -> Topology {
388    //     Topology(self.0.get_topology())
389    // }
390
391    fn save(&self, fname: &str) -> anyhow::Result<()> {
392        Ok(self.0.save(fname)?)
393    }
394
395    fn remove(&self, arg: &Bound<'_, PyAny>) -> anyhow::Result<()> {
396        // In the future other types can be used as well
397        if let Ok(sel) = arg.cast::<Sel>() {
398            Ok(self.0.remove(&sel.borrow().0)?)
399        } else if let Ok(sel_str) = arg.extract::<String>() {
400            let sel = self.0.select(sel_str)?;
401            Ok(self.0.remove(&sel)?)
402        } else if let Ok(list) = arg.extract::<Vec<usize>>() {
403            Ok(self.0.remove(&list)?)
404        } else {
405            unreachable!()
406        }
407    }
408
409    fn append(&self, arg: &Bound<'_, PyAny>) -> anyhow::Result<()> {
410        // In the future other types can be used as well
411        if let Ok(sel) = arg.cast::<Sel>() {
412            self.0.append(&sel.borrow().0);
413        } else if let Ok(sel) = arg.cast::<System>() {
414            self.0.append(&sel.borrow().0);
415        } else {
416            anyhow::bail!("Unsupported type to append a Source")
417        }
418        Ok(())
419    }
420
421    #[getter]
422    fn get_time(&self) -> f32 {
423        self.0.get_time()
424    }
425
426    // #[setter]
427    // fn set_time(&self, t: f32) {
428    //     self.0.set_time(t);
429    // }
430
431    fn set_box_from(&self, sys: &System) {
432        self.0.set_box_from(&sys.0);
433    }
434}
435
436//====================================
437
438#[pyclass(sequence, unsendable)]
439struct Sel(molar::core::Sel);
440
441impl Sel {
442    fn new_owned(sel: molar::core::Sel) -> Self {
443        Self(sel)
444    }
445
446    fn new_ref(sel: &molar::core::Sel) -> Self {
447        Self(sel.new_view())
448    }
449}
450
451#[pymethods]
452impl Sel {
453    fn __len__(&self) -> usize {
454        self.0.len()
455    }
456
457    fn __call__(&self, arg: &Bound<'_, PyAny>) -> PyResult<Sel> {
458        if let Ok(val) = arg.extract::<String>() {
459            Ok(Sel::new_owned(self.0.select(val).map_err(|e| anyhow!(e))?))
460        } else if let Ok(val) = arg.extract::<(usize, usize)>() {
461            Ok(Sel::new_owned(
462                self.0.select(val.0..=val.1).map_err(|e| anyhow!(e))?,
463            ))
464        } else if let Ok(val) = arg.extract::<Vec<usize>>() {
465            Ok(Sel::new_owned(self.0.select(val).map_err(|e| anyhow!(e))?))
466        } else {
467            Err(anyhow!(
468                "Invalid argument type {} when creating selection",
469                arg.get_type()
470            )
471            .into())
472        }
473    }
474
475    // Indexing
476    fn __getitem__(slf: Bound<Self>, i: isize) -> PyResult<Py<PyAny>> {
477        let s = slf.borrow();
478        let ind = if i < 0 {
479            if i.abs() > s.__len__() as isize {
480                return Err(anyhow!(
481                    "Negative index {i} is out of bounds {}:-1",
482                    -(s.__len__() as isize)
483                )
484                .into());
485            }
486            s.__len__() - i.unsigned_abs()
487        } else if i >= s.__len__() as isize {
488            return Err(anyhow!("Index {} is out of bounds 0:{}", i, s.__len__()).into());
489        } else {
490            i as usize
491        };
492
493        // Call Rust function
494        let p = s.0.get_particle_mut(ind).unwrap();
495        Ok(Particle {
496            atom: unsafe { &mut *(p.atom as *mut molar::core::Atom) },
497            pos: map_pyarray_to_pos(slf.py(), p.pos, &slf),
498            id: p.id,
499        }
500        .into_py_any(slf.py())?)
501    }
502
503    // Iteration protocol
504    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, ParticleIterator> {
505        Bound::new(
506            slf.py(),
507            ParticleIterator {
508                sel: slf.into(),
509                cur: 0,
510            },
511        )
512        .unwrap()
513        .borrow()
514    }
515
516    fn get_index<'py>(&self, py: Python<'py>) -> Bound<'py, numpy::PyArray1<usize>> {
517        numpy::PyArray1::from_iter(py, self.0.iter_index())
518    }
519
520    fn get_coord<'py>(&self, py: Python<'py>) -> Bound<'py, numpy::PyArray2<f32>> {
521        // We allocate an uninitialized PyArray manually and fill it with data.
522        // By doing this we save on unnecessary initiallization and extra allocation
523        unsafe {
524            let arr = numpy::PyArray2::<f32>::new(py, [3, self.0.len()], true);
525            let arr_ptr = arr.data();
526            for i in 0..self.0.len() {
527                let pos_ptr = self.0.get_pos_unchecked(i).coords.as_ptr();
528                // This is faster than copying by element with uget_raw()
529                std::ptr::copy_nonoverlapping(pos_ptr, arr_ptr.offset(i as isize * 3), 3);
530            }
531            arr
532        }
533    }
534
535    fn set_coord(&mut self, arr: PyReadonlyArray2<f32>) -> PyResult<()> {
536        // Check if the shape is correct
537        if arr.shape() != [3, self.__len__()] {
538            return Err(anyhow!(
539                "Array shape must be [3, {}], not {:?}",
540                self.__len__(),
541                arr.shape()
542            ))?;
543        }
544        let ptr = arr.data();
545
546        unsafe {
547            for i in 0..self.__len__() {
548                let pos_ptr = self.0.get_pos_mut_unchecked(i).coords.as_mut_ptr();
549                std::ptr::copy_nonoverlapping(ptr.offset(i as isize * 3), pos_ptr, 3);
550            }
551        }
552
553        Ok(())
554    }
555
556    fn set_state(&mut self, st: &State) -> anyhow::Result<State> {
557        let old_state = self.0.set_state(Arc::clone(&st.0))?;
558        Ok(State(old_state))
559    }
560
561    fn set_state_from(&mut self, arg: &Bound<'_, PyAny>) -> anyhow::Result<State> {
562        if let Ok(val) = arg.cast::<System>() {
563            Ok(State(self.0.set_state_from(&val.borrow().0)?))
564        } else if let Ok(val) = arg.cast::<Sel>() {
565            Ok(State(self.0.set_state_from(&val.borrow().0)?))
566        } else {
567            Err(anyhow!(
568                "Invalid argument type {} in set_state_from()",
569                arg.get_type()
570            )
571            .into())
572        }
573    }
574
575    fn set_topology(&mut self, top: &Topology) -> anyhow::Result<Topology> {
576        let old_top = self.0.set_topology(Arc::clone(&top.0))?;
577        Ok(Topology(old_top))
578    }
579
580    pub fn set_same_chain(&self, val: char) {
581        self.0.set_same_chain(val)
582    }
583
584    pub fn set_same_resname(&mut self, val: &str) {
585        self.0.set_same_resname(val)
586    }
587
588    pub fn set_same_resid(&mut self, val: i32) {
589        self.0.set_same_resid(val)
590    }
591
592    pub fn set_same_name(&mut self, val: &str) {
593        self.0.set_same_name(val)
594    }
595
596    pub fn set_same_mass(&mut self, val: f32) {
597        self.0.set_same_mass(val)
598    }
599
600    pub fn set_same_bfactor(&mut self, val: f32) {
601        self.0.set_same_bfactor(val)
602    }
603
604    #[getter]
605    fn get_time(&self) -> f32 {
606        self.0.get_time()
607    }
608
609    #[pyo3(signature = (dims=[false,false,false]))]
610    fn com<'py>(
611        &self,
612        py: Python<'py>,
613        dims: [bool; 3],
614    ) -> PyResult<Bound<'py, numpy::PyArray1<f32>>> {
615        let pbc_dims = PbcDims::new(dims[0], dims[1], dims[2]);
616        Ok(clone_vec_to_pyarray1(
617            &self
618                .0
619                .center_of_mass_pbc_dims(pbc_dims)
620                .map_err(|e| anyhow!(e))?
621                .coords,
622            py,
623        ))
624    }
625
626    #[pyo3(signature = (dims=[false,false,false]))]
627    fn cog<'py>(
628        &self,
629        py: Python<'py>,
630        dims: [bool; 3],
631    ) -> PyResult<Bound<'py, numpy::PyArray1<f32>>> {
632        let pbc_dims = PbcDims::new(dims[0], dims[1], dims[2]);
633        Ok(clone_vec_to_pyarray1(
634            &self
635                .0
636                .center_of_geometry_pbc_dims(pbc_dims)
637                .map_err(|e| anyhow!(e))?
638                .coords,
639            py,
640        ))
641    }
642
643    fn principal_transform(&self) -> anyhow::Result<IsometryTransform> {
644        let tr = self.0.principal_transform()?;
645        Ok(IsometryTransform(tr))
646    }
647
648    fn principal_transform_pbc(&self) -> anyhow::Result<IsometryTransform> {
649        let tr = self.0.principal_transform_pbc()?;
650        Ok(IsometryTransform(tr))
651    }
652
653    fn apply_transform(&self, tr: &IsometryTransform) {
654        self.0.apply_transform(&tr.0);
655    }
656
657    fn gyration(&self) -> anyhow::Result<f32> {
658        Ok(self.0.gyration()?)
659    }
660
661    fn gyration_pbc(&self) -> anyhow::Result<f32> {
662        Ok(self.0.gyration_pbc()?)
663    }
664
665    fn inertia<'py>(
666        &self,
667        py: Python<'py>,
668    ) -> anyhow::Result<(
669        Bound<'py, numpy::PyArray1<f32>>,
670        Bound<'py, numpy::PyArray2<f32>>,
671    )> {
672        let (moments, axes) = self.0.inertia()?;
673        let mom = clone_vec_to_pyarray1(&moments, py);
674        let ax = axes.to_pyarray(py);
675        Ok((mom, ax))
676    }
677
678    fn inertia_pbc<'py>(
679        &self,
680        py: Python<'py>,
681    ) -> anyhow::Result<(
682        Bound<'py, numpy::PyArray1<f32>>,
683        Bound<'py, numpy::PyArray2<f32>>,
684    )> {
685        let (moments, axes) = self.0.inertia_pbc()?;
686        let mom = clone_vec_to_pyarray1(&moments, py);
687        let ax = axes.to_pyarray(py);
688        Ok((mom, ax))
689    }
690
691    fn save(&self, fname: &str) -> anyhow::Result<()> {
692        Ok(self.0.save(fname)?)
693    }
694
695    fn translate<'py>(&self, arg: PyArrayLike1<'py, f32>) -> anyhow::Result<()> {
696        let vec: VectorView<f32, Const<3>, Dyn> = arg
697            .try_as_matrix()
698            .ok_or_else(|| anyhow!("conversion to Vector3 has failed"))?;
699        self.0.translate(&vec);
700        Ok(())
701    }
702
703    fn split_resindex(&self) -> Vec<Sel> {
704        self.0.split_resindex_iter().map(|s| Sel(s)).collect()
705    }
706
707    fn split_chain(&self) -> Vec<Sel> {
708        self.0
709            .split_iter(|p| Some(p.atom.chain))
710            .map(|s| Sel(s))
711            .collect()
712    }
713
714    fn split_molecule(&self) -> Vec<Sel> {
715        self.0
716            .split_mol_iter()
717            .map(|sel| Sel::new_owned(sel))
718            .collect()
719    }
720
721    fn to_gromacs_ndx(&self, name: &str) -> String {
722        self.0.as_gromacs_ndx_str(name)
723    }
724
725    /// operator |
726    fn __or__(&self, rhs: &Sel) -> Sel {
727        Sel::new_owned(&self.0 | &rhs.0)
728    }
729
730    /// operator &
731    fn __and__(&self, rhs: &Sel) -> Sel {
732        Sel::new_owned(&self.0 & &rhs.0)
733    }
734
735    /// -= (remove other from self)
736    fn __sub__(&self, rhs: &Sel) -> Sel {
737        Sel::new_owned(&self.0 - &rhs.0)
738    }
739
740    /// ~ operator
741    fn __invert__(&self) -> Sel {
742        Sel::new_owned(!&self.0)
743    }
744
745    fn sasa(&self) -> SasaResults {
746        SasaResults(self.0.sasa())
747    }
748
749    fn min_max<'py>(&self, py: Python<'py>) -> (Bound<'py, numpy::PyArray1<f32>>,Bound<'py, numpy::PyArray1<f32>>) {
750        let (min,max) = self.0.min_max();
751        (clone_vec_to_pyarray1(&min.coords, py), clone_vec_to_pyarray1(&max.coords, py))
752    }
753
754    fn unwrap_connectivity(&self, cutoff: f32) -> anyhow::Result<Vec<Sel>> {
755        let mut res = vec![];
756        for s in self.0.unwrap_connectivity(cutoff)? {
757            res.push(Sel(s));
758        }
759        Ok(res)
760    }
761
762    fn split_connectivity(&self, cutoff: f32) -> anyhow::Result<Vec<Sel>> {
763        let mut res = vec![];
764        for s in self.0.split_connectivity(cutoff)? {
765            res.push(Sel(s));
766        }
767        Ok(res)
768    }
769
770    #[pyo3(signature = (cutoff,dims=[true,true,true]))]
771    fn unwrap_connectivity_dim(&self, cutoff: f32, dims: [bool; 3]) -> anyhow::Result<Vec<Sel>> {
772        let mut res = vec![];
773        let pbc_dims = PbcDims::new(dims[0], dims[1], dims[2]);
774        for s in self.0.unwrap_connectivity_dim(cutoff,pbc_dims)? {
775            res.push(Sel(s));
776        }
777        Ok(res)
778    }
779
780    #[pyo3(signature = (dims=[true,true,true]))]
781    fn unwrap_simple_dim(&self, dims: [bool; 3]) -> anyhow::Result<()> {
782        let pbc_dims = PbcDims::new(dims[0], dims[1], dims[2]);
783        Ok(self.0.unwrap_simple_dim(pbc_dims)?)
784    }
785
786    fn unwrap_simple(&self) -> anyhow::Result<()> {
787        Ok(self.0.unwrap_simple()?)
788    }
789}
790
791#[pyclass(unsendable)]
792struct SasaResults(molar::core::SasaResults);
793
794#[pymethods]
795impl SasaResults {
796    #[getter]
797    fn areas(&self) -> &[f32] {
798        self.0.areas()
799    }
800
801    #[getter]
802    fn volumes(&self) -> &[f32] {
803        self.0.volumes()
804    }
805
806    #[getter]
807    fn total_area(&self) -> f32 {
808        self.0.total_area()
809    }
810
811    #[getter]
812    fn total_volume(&self) -> f32 {
813        self.0.total_volume()
814    }
815}
816
817#[pyclass]
818struct IsometryTransform(nalgebra::IsometryMatrix3<f32>);
819
820// Free functions
821
822#[pyfunction(name = "fit_transform")]
823fn fit_transform_py(sel1: &Sel, sel2: &Sel) -> anyhow::Result<IsometryTransform> {
824    let tr = molar::prelude::fit_transform(&sel1.0, &sel2.0)?;
825    Ok(IsometryTransform(tr))
826}
827
828#[pyfunction(name = "fit_transform_matching")]
829fn fit_transform_matching_py(sel1: &Sel, sel2: &Sel) -> anyhow::Result<IsometryTransform> {
830    let tr = molar::core::fit_transform_matching(&sel1.0, &sel2.0)?;
831    Ok(IsometryTransform(tr))
832}
833
834#[pyfunction]
835fn rmsd(sel1: &Sel, sel2: &Sel) -> anyhow::Result<f32> {
836    Ok(molar::core::Sel::rmsd(&sel1.0, &sel2.0)?)
837}
838
839#[pyfunction(name = "rmsd_mw")]
840fn rmsd_mw_py(sel1: &Sel, sel2: &Sel) -> anyhow::Result<f32> {
841    Ok(molar::core::rmsd_mw(&sel1.0, &sel2.0)?)
842}
843
844#[pyclass]
845struct ParticleIterator {
846    sel: Py<Sel>,
847    cur: isize,
848}
849
850#[pymethods]
851impl ParticleIterator {
852    fn __iter__(slf: PyRef<'_, Self>) -> PyRef<'_, Self> {
853        slf
854    }
855
856    fn __next__(mut slf: PyRefMut<'_, Self>) -> Option<Py<PyAny>> {
857        let ret = Python::attach(|py| {
858            let s = slf.sel.bind(py);
859            Sel::__getitem__(s.clone(), slf.cur)
860        })
861        .ok();
862        slf.cur += 1;
863        ret
864    }
865}
866
867#[pyfunction]
868#[pyo3(signature = (cutoff,data1,data2=None,dims=[false,false,false]))]
869fn distance_search<'py>(
870    py: Python<'py>,
871    cutoff: &Bound<'py, PyAny>,
872    data1: &Bound<'py, Sel>,
873    data2: Option<&Bound<'py, Sel>>,
874    dims: [bool; 3],
875) -> anyhow::Result<Bound<'py, PyAny>> {
876    let mut res: Vec<(usize, usize, f32)>;
877    let pbc_dims = PbcDims::new(dims[0], dims[1], dims[2]);
878    let sel1 = data1.borrow();
879
880    if let Ok(d) = cutoff.extract::<f32>() {
881        // Distance cutoff
882        if let Some(d2) = data2 {
883            let sel2 = d2.borrow();
884            if pbc_dims.any() {
885                res = molar::core::distance_search_double_pbc(
886                    d,
887                    sel1.0.iter_pos(),
888                    sel2.0.iter_pos(),
889                    sel1.0.iter_index(),
890                    sel2.0.iter_index(),
891                    sel1.0.get_box().ok_or_else(|| anyhow!("no periodic box"))?,
892                    pbc_dims,
893                );
894            } else {
895                res = molar::core::distance_search_double(
896                    d,
897                    sel1.0.iter_pos(),
898                    sel2.0.iter_pos(),
899                    sel1.0.iter_index(),
900                    sel2.0.iter_index(),
901                );
902            }
903        } else {
904            if pbc_dims.any() {
905                res = molar::core::distance_search_single_pbc(
906                    d,
907                    sel1.0.iter_pos(),
908                    sel1.0.iter_index(),
909                    sel1.0.get_box().ok_or_else(|| anyhow!("no periodic box"))?,
910                    pbc_dims,
911                );
912            } else {
913                res =
914                    molar::core::distance_search_single(d, sel1.0.iter_pos(), sel1.0.iter_index());
915            }
916        }
917    } else if let Ok(s) = cutoff.extract::<String>() {
918        if s != "vdw" {
919            bail!("Unknown cutoff type {s}");
920        }
921
922        // VdW cutof
923        let vdw1: Vec<f32> = sel1.0.iter_atoms().map(|a| a.vdw()).collect();
924
925        if sel1.0.len() != vdw1.len() {
926            bail!("Size mismatch 1: {} {}", sel1.0.len(), vdw1.len());
927        }
928
929        if let Some(d2) = data2 {
930            let sel2 = d2.borrow();
931            let vdw2: Vec<f32> = sel2.0.iter_atoms().map(|a| a.vdw()).collect();
932
933            if sel2.0.len() != vdw2.len() {
934                bail!("Size mismatch 2: {} {}", sel2.0.len(), vdw2.len());
935            }
936
937            if pbc_dims.any() {
938                res = molar::core::distance_search_double_vdw(
939                    sel1.0.iter_pos(),
940                    sel2.0.iter_pos(),
941                    &vdw1,
942                    &vdw2,
943                );
944            } else {
945                res = molar::core::distance_search_double_vdw_pbc(
946                    sel1.0.iter_pos(),
947                    sel2.0.iter_pos(),
948                    &vdw1,
949                    &vdw2,
950                    sel1.0.get_box().ok_or_else(|| anyhow!("no periodic box"))?,
951                    pbc_dims,
952                );
953            }
954
955            // Convert local indices to global
956            unsafe {
957                for el in &mut res {
958                    el.0 = sel1.0.get_index_unchecked(el.0);
959                    el.1 = sel2.0.get_index_unchecked(el.1);
960                }
961            }
962        } else {
963            bail!("VdW distance search is not yet supported for single selection");
964        }
965    } else {
966        unreachable!()
967    };
968
969    // Subdivide the result into two arrays
970    unsafe {
971        // Pairs array
972        let pairs_arr = numpy::PyArray2::<usize>::new(py, [res.len(), 2], true);
973        for i in 0..res.len() {
974            pairs_arr.uget_raw([i, 0]).write(res[i].0);
975            pairs_arr.uget_raw([i, 1]).write(res[i].1);
976        }
977
978        // Distances array
979        let dist_arr = numpy::PyArray1::<f32>::new(py, [res.len()], true);
980        for i in 0..res.len() {
981            dist_arr.uget_raw(i).write(res[i].2);
982        }
983
984        Ok((pairs_arr, dist_arr).into_bound_py_any(py)?)
985    }
986}
987
988#[pyclass]
989struct NdxFile(molar::core::NdxFile);
990
991#[pymethods]
992impl NdxFile {
993    #[new]
994    fn new(fname: &str) -> anyhow::Result<Self> {
995        Ok(NdxFile(molar::core::NdxFile::new(fname)?))
996    }
997
998    fn get_group_as_sel(&self, gr_name: &str, src: &System) -> anyhow::Result<Sel> {
999        Ok(Sel::new_owned(self.0.get_group_as_sel(gr_name, &src.0)?))
1000    }
1001}
1002
1003//====================================
1004#[pyfunction]
1005fn greeting() {
1006    molar::greeting("molar_python");
1007}
1008
1009/// A Python module implemented in Rust.
1010#[pymodule(name = "molar")]
1011//#[pymodule]
1012fn molar_python(m: &Bound<'_, PyModule>) -> PyResult<()> {
1013    pyo3_log::init();
1014    m.add_class::<Atom>()?;
1015    m.add_class::<Particle>()?;
1016    m.add_class::<Topology>()?;
1017    m.add_class::<State>()?;
1018    m.add_class::<PeriodicBox>()?;
1019    m.add_class::<FileHandler>()?;
1020    m.add_class::<System>()?;
1021    m.add_class::<Sel>()?;
1022    m.add_class::<SasaResults>()?;
1023    m.add_class::<NdxFile>()?;
1024    m.add_class::<Histogram1D>()?;
1025    m.add_function(wrap_pyfunction!(greeting, m)?)?;
1026    m.add_function(wrap_pyfunction!(fit_transform_py, m)?)?;
1027    m.add_function(wrap_pyfunction!(fit_transform_matching_py, m)?)?;
1028    m.add_function(wrap_pyfunction!(rmsd, m)?)?;
1029    m.add_function(wrap_pyfunction!(rmsd_mw_py, m)?)?;
1030    m.add_function(wrap_pyfunction!(distance_search, m)?)?;
1031    m.add_class::<LipidMolecule>()?;
1032    m.add_class::<Membrane>()?;
1033    Ok(())
1034}