moyopy/
base.rs

1use nalgebra::{OMatrix, RowVector3, Vector3};
2use pyo3::exceptions::PyValueError;
3use pyo3::prelude::*;
4use pyo3::types::PyType;
5use serde::de::{Deserialize, Deserializer};
6use serde::ser::{Serialize, Serializer};
7use serde_json;
8
9use moyo::base::{Cell, Lattice, MoyoError, Operations};
10
// Unfortunately, "PyCell" is already reserved by pyo3, so the Rust-side name of
// this wrapper is `PyStructure` while Python sees it as `Cell`.
/// Python-visible wrapper around `moyo::base::Cell`, exported as `moyopy.Cell`.
#[derive(Debug, Clone)]
#[pyclass(name = "Cell")]
#[pyo3(module = "moyopy")]
pub struct PyStructure(Cell);
16
17#[pymethods]
18impl PyStructure {
19    #[new]
20    /// basis: row-wise basis vectors
21    pub fn new(
22        basis: [[f64; 3]; 3],
23        positions: Vec<[f64; 3]>,
24        numbers: Vec<i32>,
25    ) -> PyResult<Self> {
26        if positions.len() != numbers.len() {
27            return Err(PyValueError::new_err(
28                "positions and numbers should be the same length",
29            ));
30        }
31
32        // let lattice = Lattice::new(OMatrix::from(basis));
33        let lattice = Lattice::new(OMatrix::from_rows(&[
34            RowVector3::from(basis[0]),
35            RowVector3::from(basis[1]),
36            RowVector3::from(basis[2]),
37        ]));
38        let positions = positions
39            .iter()
40            .map(|x| Vector3::new(x[0], x[1], x[2]))
41            .collect::<Vec<_>>();
42        let cell = Cell::new(lattice, positions, numbers);
43
44        Ok(Self(cell))
45    }
46
47    #[getter]
48    pub fn basis(&self) -> [[f64; 3]; 3] {
49        *self.0.lattice.basis.as_ref()
50    }
51
52    #[getter]
53    pub fn positions(&self) -> Vec<[f64; 3]> {
54        self.0.positions.iter().map(|x| [x.x, x.y, x.z]).collect()
55    }
56
57    #[getter]
58    pub fn numbers(&self) -> Vec<i32> {
59        self.0.numbers.clone()
60    }
61
62    #[getter]
63    pub fn num_atoms(&self) -> usize {
64        self.0.num_atoms()
65    }
66
67    pub fn serialize_json(&self) -> PyResult<String> {
68        serde_json::to_string(self).map_err(|e| PyValueError::new_err(e.to_string()))
69    }
70
71    #[classmethod]
72    pub fn deserialize_json(_cls: &PyType, s: &str) -> PyResult<Self> {
73        serde_json::from_str(s).map_err(|e| PyValueError::new_err(e.to_string()))
74    }
75}
76
77impl From<PyStructure> for Cell {
78    fn from(structure: PyStructure) -> Self {
79        structure.0
80    }
81}
82
83impl From<Cell> for PyStructure {
84    fn from(cell: Cell) -> Self {
85        PyStructure(cell)
86    }
87}
88
89impl Serialize for PyStructure {
90    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
91    where
92        S: Serializer,
93    {
94        Cell::from(self.clone()).serialize(serializer)
95    }
96}
97
98impl<'de> Deserialize<'de> for PyStructure {
99    fn deserialize<D>(deserializer: D) -> Result<PyStructure, D::Error>
100    where
101        D: Deserializer<'de>,
102    {
103        Cell::deserialize(deserializer).map(PyStructure::from)
104    }
105}
106
/// Wrapper around `moyo::base::MoyoError`, registered with Python as `moyopy.MoyoError`.
///
/// Converting it into a `PyErr` produces a Python `ValueError` holding the error's
/// display text.
#[derive(Debug)]
#[pyclass(name = "MoyoError")]
#[pyo3(module = "moyopy")]
pub struct PyMoyoError(MoyoError);
111
112impl From<PyMoyoError> for PyErr {
113    fn from(error: PyMoyoError) -> Self {
114        PyValueError::new_err(error.0.to_string())
115    }
116}
117
118impl From<MoyoError> for PyMoyoError {
119    fn from(error: MoyoError) -> Self {
120        PyMoyoError(error)
121    }
122}
123
/// Python-visible wrapper around `moyo::base::Operations`, exported as
/// `moyopy.Operations`; exposes the rotation and translation lists as getters.
#[derive(Debug)]
#[pyclass(name = "Operations")]
#[pyo3(module = "moyopy")]
pub struct PyOperations(Operations);
128
129#[pymethods]
130impl PyOperations {
131    #[getter]
132    pub fn rotations(&self) -> Vec<[[i32; 3]; 3]> {
133        self.0.rotations.iter().map(|x| *x.as_ref()).collect()
134    }
135
136    #[getter]
137    pub fn translations(&self) -> Vec<[f64; 3]> {
138        self.0.translations.iter().map(|x| *x.as_ref()).collect()
139    }
140
141    #[getter]
142    pub fn num_operations(&self) -> usize {
143        self.0.num_operations()
144    }
145}
146
147impl From<PyOperations> for Operations {
148    fn from(operations: PyOperations) -> Self {
149        operations.0
150    }
151}
152
153impl From<Operations> for PyOperations {
154    fn from(operations: Operations) -> Self {
155        PyOperations(operations)
156    }
157}
158
#[cfg(test)]
mod tests {
    extern crate approx;

    use super::PyStructure;
    use approx::assert_relative_eq;
    use serde_json;

    /// A JSON round trip must reproduce the basis, positions and numbers.
    #[test]
    fn test_serialization() {
        let original = PyStructure::new(
            [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]],
            vec![[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]],
            vec![1, 2],
        )
        .unwrap();

        let json = serde_json::to_string(&original).unwrap();
        let restored: PyStructure = serde_json::from_str(&json).unwrap();

        let (basis_before, basis_after) = (original.basis(), restored.basis());
        for (row_before, row_after) in basis_before.iter().zip(basis_after.iter()) {
            for (&lhs, &rhs) in row_before.iter().zip(row_after.iter()) {
                assert_relative_eq!(lhs, rhs);
            }
        }

        let (positions_before, positions_after) = (original.positions(), restored.positions());
        assert_eq!(positions_before.len(), positions_after.len());
        for (lhs, rhs) in positions_before.iter().zip(positions_after.iter()) {
            for axis in 0..3 {
                assert_relative_eq!(lhs[axis], rhs[axis]);
            }
        }

        assert_eq!(original.numbers(), restored.numbers());
    }
}
192}