1use crate::basis::Basis;
2use crate::graph::Graph;
3use crate::types::{BasisMatrix, CoordinateVector, OffsetVector};
4use crate::unitcell::Unitcell;
5use roxmltree::{Document, Node};
6use std::fmt::{self, Write as _};
7use std::fs;
8use std::path::Path;
9
10#[derive(Debug)]
11pub enum XmlError {
12 Parse(roxmltree::Error),
13 Io(std::io::Error),
14 MissingAttribute(&'static str),
15 InvalidFormat(&'static str),
16 DimensionMismatch(&'static str),
17 UnknownEntry(String),
18}
19
20impl fmt::Display for XmlError {
21 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
22 match self {
23 XmlError::Parse(err) => write!(f, "XML parse error: {err}"),
24 XmlError::Io(err) => write!(f, "I/O error: {err}"),
25 XmlError::MissingAttribute(attr) => write!(f, "missing XML attribute: {attr}"),
26 XmlError::InvalidFormat(msg) => write!(f, "invalid XML format: {msg}"),
27 XmlError::DimensionMismatch(msg) => write!(f, "dimension mismatch: {msg}"),
28 XmlError::UnknownEntry(name) => write!(f, "lattice entry not found: {name}"),
29 }
30 }
31}
32
33impl std::error::Error for XmlError {}
34
35impl From<roxmltree::Error> for XmlError {
36 fn from(value: roxmltree::Error) -> Self {
37 XmlError::Parse(value)
38 }
39}
40
41impl From<std::io::Error> for XmlError {
42 fn from(value: std::io::Error) -> Self {
43 XmlError::Io(value)
44 }
45}
46
47pub type Result<T> = std::result::Result<T, XmlError>;
48
49pub fn read_basis_from_str(xml: &str, name: &str) -> Result<Basis> {
50 let doc = Document::parse(xml)?;
51 let entry = find_entry(&doc, "LATTICE", name)?.ok_or_else(|| XmlError::UnknownEntry(name.to_string()))?;
52 parse_basis(entry)
53}
54
55pub fn read_unitcell_from_str(xml: &str, name: &str) -> Result<Unitcell> {
56 let doc = Document::parse(xml)?;
57 let entry = find_entry(&doc, "UNITCELL", name)?.ok_or_else(|| XmlError::UnknownEntry(name.to_string()))?;
58 parse_unitcell(entry)
59}
60
61pub fn read_graph_from_str(xml: &str, name: &str) -> Result<Graph> {
62 let doc = Document::parse(xml)?;
63 let entry = find_entry(&doc, "GRAPH", name)?.ok_or_else(|| XmlError::UnknownEntry(name.to_string()))?;
64 parse_graph(entry)
65}
66
67pub fn read_basis_from_file(path: impl AsRef<Path>, name: &str) -> Result<Basis> {
68 let xml = fs::read_to_string(path)?;
69 read_basis_from_str(&xml, name)
70}
71
72pub fn read_unitcell_from_file(path: impl AsRef<Path>, name: &str) -> Result<Unitcell> {
73 let xml = fs::read_to_string(path)?;
74 read_unitcell_from_str(&xml, name)
75}
76
77pub fn read_graph_from_file(path: impl AsRef<Path>, name: &str) -> Result<Graph> {
78 let xml = fs::read_to_string(path)?;
79 read_graph_from_str(&xml, name)
80}
81
82pub fn write_basis_to_string(name: &str, basis: &Basis) -> String {
83 let mut xml = String::new();
84 let _ = writeln!(xml, "<LATTICES>");
85 let _ = writeln!(xml, " <LATTICE name=\"{}\" dimension=\"{}\">", escape_attr(name), basis.dimension());
86 let _ = writeln!(xml, " <BASIS>");
87 for column in 0..basis.dimension() {
88 let mut vector = String::new();
89 for row in 0..basis.dimension() {
90 if row > 0 {
91 vector.push(' ');
92 }
93 let _ = write!(vector, "{}", basis.basis_vectors()[(row, column)]);
94 }
95 let _ = writeln!(xml, " <VECTOR>{}</VECTOR>", vector);
96 }
97 let _ = writeln!(xml, " </BASIS>");
98 let _ = writeln!(xml, " </LATTICE>");
99 let _ = writeln!(xml, "</LATTICES>");
100 xml
101}
102
103pub fn write_unitcell_to_string(name: &str, cell: &Unitcell) -> String {
104 let mut xml = String::new();
105 let _ = writeln!(xml, "<LATTICES>");
106 let _ = writeln!(xml, " <UNITCELL name=\"{}\" dimension=\"{}\" vertices=\"{}\">", escape_attr(name), cell.dimension(), cell.num_sites());
107 for site in 0..cell.num_sites() {
108 let site = cell.site(site);
109 let _ = write!(xml, " <VERTEX type=\"{}\">", site.site_type);
110 let _ = write!(xml, "<COORDINATE>");
111 for (index, value) in site.coordinate.iter().enumerate() {
112 if index > 0 {
113 xml.push(' ');
114 }
115 let _ = write!(xml, "{}", value);
116 }
117 let _ = writeln!(xml, "</COORDINATE></VERTEX>");
118 }
119 for bond in 0..cell.num_bonds() {
120 let bond = cell.bond(bond);
121 let _ = write!(xml, " <EDGE type=\"{}\">", bond.bond_type);
122 let _ = write!(xml, "<SOURCE vertex=\"{}\"/>", bond.source + 1);
123 let _ = write!(xml, "<TARGET vertex=\"{}\" offset=\"", bond.target + 1);
124 for (index, value) in bond.target_offset.iter().enumerate() {
125 if index > 0 {
126 xml.push(' ');
127 }
128 let _ = write!(xml, "{}", value);
129 }
130 let _ = writeln!(xml, "\"/></EDGE>");
131 }
132 let _ = writeln!(xml, " </UNITCELL>");
133 let _ = writeln!(xml, "</LATTICES>");
134 xml
135}
136
137pub fn write_graph_to_string(name: &str, graph: &Graph) -> String {
138 let mut xml = String::new();
139 let _ = writeln!(xml, "<LATTICES>");
140 if graph.dimension() > 0 {
141 let _ = writeln!(xml, " <GRAPH name=\"{}\" dimension=\"{}\" vertices=\"{}\">", escape_attr(name), graph.dimension(), graph.num_sites());
142 } else {
143 let _ = writeln!(xml, " <GRAPH name=\"{}\" vertices=\"{}\">", escape_attr(name), graph.num_sites());
144 }
145 for site in 0..graph.num_sites() {
146 if graph.dimension() > 0 {
147 let _ = write!(xml, " <VERTEX type=\"{}\">", graph.site_type(site));
148 let _ = write!(xml, "<COORDINATE>");
149 for (index, value) in graph.coordinate(site).iter().enumerate() {
150 if index > 0 {
151 xml.push(' ');
152 }
153 let _ = write!(xml, "{}", value);
154 }
155 let _ = writeln!(xml, "</COORDINATE></VERTEX>");
156 } else {
157 let _ = writeln!(xml, " <VERTEX type=\"{}\"/>", graph.site_type(site));
158 }
159 }
160 for bond in 0..graph.num_bonds() {
161 let _ = writeln!(xml, " <EDGE type=\"{}\" source=\"{}\" target=\"{}\"/>", graph.bond_type(bond), graph.source(bond) + 1, graph.target(bond) + 1);
162 }
163 let _ = writeln!(xml, " </GRAPH>");
164 let _ = writeln!(xml, "</LATTICES>");
165 xml
166}
167
168fn parse_basis(entry: Node<'_, '_>) -> Result<Basis> {
169 let dimension = parse_optional_usize(entry, "dimension");
170 let basis_node = entry.children().find(|child| child.is_element() && child.tag_name().name() == "BASIS")
171 .ok_or(XmlError::InvalidFormat("missing BASIS element"))?;
172 let mut vectors = Vec::new();
173 for vector_node in basis_node.children().filter(|child| child.is_element() && child.tag_name().name() == "VECTOR") {
174 vectors.push(parse_number_list::<f64>(vector_node.text().unwrap_or(""))?);
175 }
176 let dim = match dimension {
177 Some(value) => value,
178 None => vectors.len(),
179 };
180 if dim == 0 || vectors.len() != dim {
181 return Err(XmlError::DimensionMismatch("basis dimension mismatch"));
182 }
183 for vector in &vectors {
184 if vector.len() != dim {
185 return Err(XmlError::DimensionMismatch("basis dimension mismatch"));
186 }
187 }
188 let mut matrix = BasisMatrix::zeros(dim, dim);
189 for (column, vector) in vectors.iter().enumerate() {
190 for (row, value) in vector.iter().enumerate() {
191 matrix[(row, column)] = *value;
192 }
193 }
194 Ok(Basis::new(matrix))
195}
196
197fn parse_unitcell(entry: Node<'_, '_>) -> Result<Unitcell> {
198 let dimension = parse_required_usize(entry, "dimension")?;
199 let mut cell = Unitcell::new(dimension);
200 let site_count = parse_optional_usize(entry, "vertices").unwrap_or(0);
201
202 for vertex in entry.children().filter(|child| child.is_element() && child.tag_name().name() == "VERTEX") {
203 let site_type = parse_optional_i32(vertex, "type").unwrap_or(0);
204 let mut coordinate = CoordinateVector::zeros(dimension);
205 let mut found_coordinate = false;
206 for coordinate_node in vertex.children().filter(|child| child.is_element() && child.tag_name().name() == "COORDINATE") {
207 if found_coordinate {
208 return Err(XmlError::InvalidFormat("duplicated COORDINATE tag"));
209 }
210 let values = parse_number_list::<f64>(coordinate_node.text().unwrap_or(""))?;
211 if values.len() != dimension {
212 return Err(XmlError::DimensionMismatch("site coordinate dimension mismatch"));
213 }
214 for (index, value) in values.into_iter().enumerate() {
215 coordinate[index] = value;
216 }
217 found_coordinate = true;
218 }
219 cell.add_site(coordinate, site_type);
220 }
221
222 if cell.num_sites() > 0 {
223 if site_count > 0 && cell.num_sites() != site_count {
224 return Err(XmlError::DimensionMismatch("inconsistent number of sites"));
225 }
226 } else {
227 for _ in 0..site_count {
228 cell.add_site(CoordinateVector::zeros(dimension), 0);
229 }
230 }
231
232 for edge in entry.children().filter(|child| child.is_element() && child.tag_name().name() == "EDGE") {
233 let bond_type = parse_optional_i32(edge, "type").unwrap_or(0);
234 let source = parse_edge_vertex(edge, "SOURCE")?;
235 let target = parse_edge_vertex(edge, "TARGET")?;
236 let mut source_offset = OffsetVector::zeros(dimension);
237 let mut target_offset = OffsetVector::zeros(dimension);
238 for source_node in edge.children().filter(|child| child.is_element() && child.tag_name().name() == "SOURCE") {
239 if let Some(vertex) = source_node.attribute("vertex") {
240 let parsed = vertex.parse::<usize>().map_err(|_| XmlError::InvalidFormat("invalid SOURCE vertex"))?;
241 if parsed == 0 {
242 return Err(XmlError::InvalidFormat("SOURCE vertex is 1-based"));
243 }
244 }
245 if let Some(offset) = source_node.attribute("offset") {
246 let values = parse_number_list::<i64>(offset)?;
247 if values.len() != dimension {
248 return Err(XmlError::DimensionMismatch("SOURCE offset dimension mismatch"));
249 }
250 for (index, value) in values.into_iter().enumerate() {
251 source_offset[index] = value;
252 }
253 }
254 }
255 for target_node in edge.children().filter(|child| child.is_element() && child.tag_name().name() == "TARGET") {
256 if let Some(vertex) = target_node.attribute("vertex") {
257 let parsed = vertex.parse::<usize>().map_err(|_| XmlError::InvalidFormat("invalid TARGET vertex"))?;
258 if parsed == 0 {
259 return Err(XmlError::InvalidFormat("TARGET vertex is 1-based"));
260 }
261 }
262 if let Some(offset) = target_node.attribute("offset") {
263 let values = parse_number_list::<i64>(offset)?;
264 if values.len() != dimension {
265 return Err(XmlError::DimensionMismatch("TARGET offset dimension mismatch"));
266 }
267 for (index, value) in values.into_iter().enumerate() {
268 target_offset[index] = value;
269 }
270 }
271 }
272 let final_offset = target_offset - source_offset;
273 cell.add_bond(source - 1, target - 1, final_offset, bond_type);
274 }
275
276 Ok(cell)
277}
278
279fn parse_graph(entry: Node<'_, '_>) -> Result<Graph> {
280 let dimension = parse_optional_usize(entry, "dimension").unwrap_or(0);
281 let mut graph = Graph::new(dimension);
282 let site_count = parse_optional_usize(entry, "vertices").unwrap_or(0);
283
284 for vertex in entry.children().filter(|child| child.is_element() && child.tag_name().name() == "VERTEX") {
285 let site_type = parse_optional_i32(vertex, "type").unwrap_or(0);
286 let coordinate = if dimension > 0 {
287 let coordinate_node = vertex.children().find(|child| child.is_element() && child.tag_name().name() == "COORDINATE")
288 .ok_or(XmlError::InvalidFormat("missing COORDINATE tag"))?;
289 let values = parse_number_list::<f64>(coordinate_node.text().unwrap_or(""))?;
290 if values.len() != dimension {
291 return Err(XmlError::DimensionMismatch("site coordinate dimension mismatch"));
292 }
293 let mut coordinate = CoordinateVector::zeros(dimension);
294 for (index, value) in values.into_iter().enumerate() {
295 coordinate[index] = value;
296 }
297 coordinate
298 } else {
299 CoordinateVector::zeros(0)
300 };
301 graph.add_site(coordinate, site_type);
302 }
303
304 if graph.num_sites() > 0 {
305 if site_count > 0 && graph.num_sites() != site_count {
306 return Err(XmlError::DimensionMismatch("inconsistent number of sites"));
307 }
308 } else {
309 for _ in 0..site_count {
310 graph.add_site(CoordinateVector::zeros(dimension), 0);
311 }
312 }
313
314 for edge in entry.children().filter(|child| child.is_element() && child.tag_name().name() == "EDGE") {
315 let source = parse_required_usize_attr(edge, "source")?;
316 let target = parse_required_usize_attr(edge, "target")?;
317 let bond_type = parse_optional_i32(edge, "type").unwrap_or(0);
318 graph.add_bond(source - 1, target - 1, bond_type);
319 }
320
321 Ok(graph)
322}
323
324fn find_entry<'a>(doc: &'a Document<'a>, tag: &str, name: &str) -> Result<Option<Node<'a, 'a>>> {
325 let root = doc.root_element();
326 if root.tag_name().name() != "LATTICES" {
327 return Err(XmlError::InvalidFormat("root element must be LATTICES"));
328 }
329 Ok(root.children().find(|child| {
330 child.is_element()
331 && child.tag_name().name() == tag
332 && child.attribute("name") == Some(name)
333 }))
334}
335
336fn parse_required_usize(node: Node<'_, '_>, attr: &'static str) -> Result<usize> {
337 parse_required_attr(node, attr)?.parse::<usize>().map_err(|_| XmlError::InvalidFormat(attr))
338}
339
340fn parse_required_usize_attr(node: Node<'_, '_>, attr: &'static str) -> Result<usize> {
341 node.attribute(attr)
342 .ok_or(XmlError::MissingAttribute(attr))?
343 .parse::<usize>()
344 .map_err(|_| XmlError::InvalidFormat(attr))
345}
346
347fn parse_required_attr<'a>(node: Node<'a, 'a>, attr: &'static str) -> Result<&'a str> {
348 node.attribute(attr).ok_or(XmlError::MissingAttribute(attr))
349}
350
351fn parse_optional_usize(node: Node<'_, '_>, attr: &'static str) -> Option<usize> {
352 node.attribute(attr).and_then(|value| value.parse::<usize>().ok())
353}
354
355fn parse_optional_i32(node: Node<'_, '_>, attr: &'static str) -> Option<i32> {
356 node.attribute(attr).and_then(|value| value.parse::<i32>().ok())
357}
358
359fn parse_edge_vertex(edge: Node<'_, '_>, tag: &'static str) -> Result<usize> {
360 let node = edge.children().find(|child| child.is_element() && child.tag_name().name() == tag)
361 .ok_or(XmlError::MissingAttribute(tag))?;
362 let vertex = node.attribute("vertex").ok_or(XmlError::MissingAttribute("vertex"))?.parse::<usize>().map_err(|_| XmlError::InvalidFormat("invalid vertex"))?;
363 if vertex == 0 {
364 return Err(XmlError::InvalidFormat("vertex is 1-based"));
365 }
366 Ok(vertex)
367}
368
369fn parse_number_list<T>(text: &str) -> Result<Vec<T>>
370where
371 T: std::str::FromStr,
372{
373 let mut values = Vec::new();
374 for token in text.split_whitespace() {
375 values.push(token.parse::<T>().map_err(|_| XmlError::InvalidFormat("invalid numeric value"))?);
376 }
377 Ok(values)
378}
379
/// Escapes `text` for use inside a double-quoted XML attribute value.
///
/// `&`, `"`, `<` and `>` become the predefined entities `&amp;`, `&quot;`, `&lt;` and
/// `&gt;`. Tab, line feed and carriage return become numeric character references:
/// XML attribute-value normalization turns literal whitespace characters into spaces,
/// so they would not otherwise survive a write/read round trip. All other characters are
/// copied unchanged.
fn escape_attr(text: &str) -> String {
    let mut escaped = String::with_capacity(text.len());
    for ch in text.chars() {
        match ch {
            '&' => escaped.push_str("&amp;"),
            '"' => escaped.push_str("&quot;"),
            '<' => escaped.push_str("&lt;"),
            '>' => escaped.push_str("&gt;"),
            '\t' => escaped.push_str("&#9;"),
            '\n' => escaped.push_str("&#10;"),
            '\r' => escaped.push_str("&#13;"),
            other => escaped.push(other),
        }
    }
    escaped
}