1use nalgebra::{OMatrix, RowVector3, Vector3};
2use pyo3::exceptions::PyValueError;
3use pyo3::prelude::*;
4use pyo3::types::PyType;
5use serde::de::{Deserialize, Deserializer};
6use serde::ser::{Serialize, Serializer};
7use serde_json;
8
9use moyo::base::{Cell, Lattice, MoyoError, Operations};
10
11#[derive(Debug, Clone)]
13#[pyclass(name = "Cell")]
14#[pyo3(module = "moyopy")]
15pub struct PyStructure(Cell);
16
17#[pymethods]
18impl PyStructure {
19 #[new]
20 pub fn new(
22 basis: [[f64; 3]; 3],
23 positions: Vec<[f64; 3]>,
24 numbers: Vec<i32>,
25 ) -> PyResult<Self> {
26 if positions.len() != numbers.len() {
27 return Err(PyValueError::new_err(
28 "positions and numbers should be the same length",
29 ));
30 }
31
32 let lattice = Lattice::new(OMatrix::from_rows(&[
34 RowVector3::from(basis[0]),
35 RowVector3::from(basis[1]),
36 RowVector3::from(basis[2]),
37 ]));
38 let positions = positions
39 .iter()
40 .map(|x| Vector3::new(x[0], x[1], x[2]))
41 .collect::<Vec<_>>();
42 let cell = Cell::new(lattice, positions, numbers);
43
44 Ok(Self(cell))
45 }
46
47 #[getter]
48 pub fn basis(&self) -> [[f64; 3]; 3] {
49 *self.0.lattice.basis.as_ref()
50 }
51
52 #[getter]
53 pub fn positions(&self) -> Vec<[f64; 3]> {
54 self.0.positions.iter().map(|x| [x.x, x.y, x.z]).collect()
55 }
56
57 #[getter]
58 pub fn numbers(&self) -> Vec<i32> {
59 self.0.numbers.clone()
60 }
61
62 #[getter]
63 pub fn num_atoms(&self) -> usize {
64 self.0.num_atoms()
65 }
66
67 pub fn serialize_json(&self) -> PyResult<String> {
68 serde_json::to_string(self).map_err(|e| PyValueError::new_err(e.to_string()))
69 }
70
71 #[classmethod]
72 pub fn deserialize_json(_cls: &PyType, s: &str) -> PyResult<Self> {
73 serde_json::from_str(s).map_err(|e| PyValueError::new_err(e.to_string()))
74 }
75}
76
77impl From<PyStructure> for Cell {
78 fn from(structure: PyStructure) -> Self {
79 structure.0
80 }
81}
82
83impl From<Cell> for PyStructure {
84 fn from(cell: Cell) -> Self {
85 PyStructure(cell)
86 }
87}
88
89impl Serialize for PyStructure {
90 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
91 where
92 S: Serializer,
93 {
94 Cell::from(self.clone()).serialize(serializer)
95 }
96}
97
98impl<'de> Deserialize<'de> for PyStructure {
99 fn deserialize<D>(deserializer: D) -> Result<PyStructure, D::Error>
100 where
101 D: Deserializer<'de>,
102 {
103 Cell::deserialize(deserializer).map(PyStructure::from)
104 }
105}
106
107#[derive(Debug)]
108#[pyclass(name = "MoyoError")]
109#[pyo3(module = "moyopy")]
110pub struct PyMoyoError(MoyoError);
111
112impl From<PyMoyoError> for PyErr {
113 fn from(error: PyMoyoError) -> Self {
114 PyValueError::new_err(error.0.to_string())
115 }
116}
117
118impl From<MoyoError> for PyMoyoError {
119 fn from(error: MoyoError) -> Self {
120 PyMoyoError(error)
121 }
122}
123
124#[derive(Debug)]
125#[pyclass(name = "Operations")]
126#[pyo3(module = "moyopy")]
127pub struct PyOperations(Operations);
128
129#[pymethods]
130impl PyOperations {
131 #[getter]
132 pub fn rotations(&self) -> Vec<[[i32; 3]; 3]> {
133 self.0.rotations.iter().map(|x| *x.as_ref()).collect()
134 }
135
136 #[getter]
137 pub fn translations(&self) -> Vec<[f64; 3]> {
138 self.0.translations.iter().map(|x| *x.as_ref()).collect()
139 }
140
141 #[getter]
142 pub fn num_operations(&self) -> usize {
143 self.0.num_operations()
144 }
145}
146
147impl From<PyOperations> for Operations {
148 fn from(operations: PyOperations) -> Self {
149 operations.0
150 }
151}
152
153impl From<Operations> for PyOperations {
154 fn from(operations: Operations) -> Self {
155 PyOperations(operations)
156 }
157}
158
#[cfg(test)]
mod tests {
    extern crate approx;

    use super::PyStructure;
    use approx::assert_relative_eq;
    use serde_json;

    /// Round-trips a two-atom structure through JSON and checks that the
    /// basis, positions, and numbers come back unchanged.
    #[test]
    fn test_serialization() {
        let basis = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]];
        let positions = vec![[0.0, 0.0, 0.0], [0.5, 0.5, 0.5]];
        let numbers = vec![1, 2];
        let structure = PyStructure::new(basis, positions, numbers).unwrap();

        let json = serde_json::to_string(&structure).unwrap();
        let restored: PyStructure = serde_json::from_str(&json).unwrap();

        // Compare the basis entry by entry.
        let basis_before = structure.basis();
        let basis_after = restored.basis();
        for (row_before, row_after) in basis_before.iter().zip(basis_after.iter()) {
            for (lhs, rhs) in row_before.iter().zip(row_after.iter()) {
                assert_relative_eq!(*lhs, *rhs);
            }
        }

        // Compare positions: same count first, then component by component.
        let positions_before = structure.positions();
        let positions_after = restored.positions();
        assert_eq!(positions_before.len(), positions_after.len());
        for (actual, expect) in positions_before.iter().zip(positions_after.iter()) {
            for (lhs, rhs) in actual.iter().zip(expect.iter()) {
                assert_relative_eq!(*lhs, *rhs);
            }
        }

        assert_eq!(structure.numbers(), restored.numbers());
    }
}