Skip to main content

lattice_core/
xml.rs

1use crate::basis::Basis;
2use crate::graph::Graph;
3use crate::types::{BasisMatrix, CoordinateVector, OffsetVector};
4use crate::unitcell::Unitcell;
5use roxmltree::{Document, Node};
6use std::fmt::{self, Write as _};
7use std::fs;
8use std::path::Path;
9
/// Errors reported by the XML readers in this module.
#[derive(Debug)]
pub enum XmlError {
  /// The XML text was rejected by the `roxmltree` parser.
  Parse(roxmltree::Error),
  /// An I/O operation failed (produced when reading the input file).
  Io(std::io::Error),
  /// A required attribute is absent; carries the attribute name.
  /// Also used when an `EDGE` lacks its `SOURCE`/`TARGET` child, in which
  /// case the payload is that tag name.
  MissingAttribute(&'static str),
  /// The document structure or a value is malformed; carries a short
  /// description or the name of the offending attribute.
  InvalidFormat(&'static str),
  /// A count or vector length disagrees with the declared dimension or
  /// vertex count; carries a short description.
  DimensionMismatch(&'static str),
  /// No entry with the requested name exists; carries the requested name.
  UnknownEntry(String),
}
19
20impl fmt::Display for XmlError {
21  fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
22    match self {
23      XmlError::Parse(err) => write!(f, "XML parse error: {err}"),
24      XmlError::Io(err) => write!(f, "I/O error: {err}"),
25      XmlError::MissingAttribute(attr) => write!(f, "missing XML attribute: {attr}"),
26      XmlError::InvalidFormat(msg) => write!(f, "invalid XML format: {msg}"),
27      XmlError::DimensionMismatch(msg) => write!(f, "dimension mismatch: {msg}"),
28      XmlError::UnknownEntry(name) => write!(f, "lattice entry not found: {name}"),
29    }
30  }
31}
32
33impl std::error::Error for XmlError {}
34
35impl From<roxmltree::Error> for XmlError {
36  fn from(value: roxmltree::Error) -> Self {
37    XmlError::Parse(value)
38  }
39}
40
41impl From<std::io::Error> for XmlError {
42  fn from(value: std::io::Error) -> Self {
43    XmlError::Io(value)
44  }
45}
46
/// Result alias used throughout this module; the error type is [`XmlError`].
pub type Result<T> = std::result::Result<T, XmlError>;
48
49pub fn read_basis_from_str(xml: &str, name: &str) -> Result<Basis> {
50  let doc = Document::parse(xml)?;
51  let entry = find_entry(&doc, "LATTICE", name)?.ok_or_else(|| XmlError::UnknownEntry(name.to_string()))?;
52  parse_basis(entry)
53}
54
55pub fn read_unitcell_from_str(xml: &str, name: &str) -> Result<Unitcell> {
56  let doc = Document::parse(xml)?;
57  let entry = find_entry(&doc, "UNITCELL", name)?.ok_or_else(|| XmlError::UnknownEntry(name.to_string()))?;
58  parse_unitcell(entry)
59}
60
61pub fn read_graph_from_str(xml: &str, name: &str) -> Result<Graph> {
62  let doc = Document::parse(xml)?;
63  let entry = find_entry(&doc, "GRAPH", name)?.ok_or_else(|| XmlError::UnknownEntry(name.to_string()))?;
64  parse_graph(entry)
65}
66
67pub fn read_basis_from_file(path: impl AsRef<Path>, name: &str) -> Result<Basis> {
68  let xml = fs::read_to_string(path)?;
69  read_basis_from_str(&xml, name)
70}
71
72pub fn read_unitcell_from_file(path: impl AsRef<Path>, name: &str) -> Result<Unitcell> {
73  let xml = fs::read_to_string(path)?;
74  read_unitcell_from_str(&xml, name)
75}
76
77pub fn read_graph_from_file(path: impl AsRef<Path>, name: &str) -> Result<Graph> {
78  let xml = fs::read_to_string(path)?;
79  read_graph_from_str(&xml, name)
80}
81
82pub fn write_basis_to_string(name: &str, basis: &Basis) -> String {
83  let mut xml = String::new();
84  let _ = writeln!(xml, "<LATTICES>");
85  let _ = writeln!(xml, "  <LATTICE name=\"{}\" dimension=\"{}\">", escape_attr(name), basis.dimension());
86  let _ = writeln!(xml, "    <BASIS>");
87  for column in 0..basis.dimension() {
88    let mut vector = String::new();
89    for row in 0..basis.dimension() {
90      if row > 0 {
91        vector.push(' ');
92      }
93      let _ = write!(vector, "{}", basis.basis_vectors()[(row, column)]);
94    }
95    let _ = writeln!(xml, "      <VECTOR>{}</VECTOR>", vector);
96  }
97  let _ = writeln!(xml, "    </BASIS>");
98  let _ = writeln!(xml, "  </LATTICE>");
99  let _ = writeln!(xml, "</LATTICES>");
100  xml
101}
102
103pub fn write_unitcell_to_string(name: &str, cell: &Unitcell) -> String {
104  let mut xml = String::new();
105  let _ = writeln!(xml, "<LATTICES>");
106  let _ = writeln!(xml, "  <UNITCELL name=\"{}\" dimension=\"{}\" vertices=\"{}\">", escape_attr(name), cell.dimension(), cell.num_sites());
107  for site in 0..cell.num_sites() {
108    let site = cell.site(site);
109    let _ = write!(xml, "    <VERTEX type=\"{}\">", site.site_type);
110    let _ = write!(xml, "<COORDINATE>");
111    for (index, value) in site.coordinate.iter().enumerate() {
112      if index > 0 {
113        xml.push(' ');
114      }
115      let _ = write!(xml, "{}", value);
116    }
117    let _ = writeln!(xml, "</COORDINATE></VERTEX>");
118  }
119  for bond in 0..cell.num_bonds() {
120    let bond = cell.bond(bond);
121    let _ = write!(xml, "    <EDGE type=\"{}\">", bond.bond_type);
122    let _ = write!(xml, "<SOURCE vertex=\"{}\"/>", bond.source + 1);
123    let _ = write!(xml, "<TARGET vertex=\"{}\" offset=\"", bond.target + 1);
124    for (index, value) in bond.target_offset.iter().enumerate() {
125      if index > 0 {
126        xml.push(' ');
127      }
128      let _ = write!(xml, "{}", value);
129    }
130    let _ = writeln!(xml, "\"/></EDGE>");
131  }
132  let _ = writeln!(xml, "  </UNITCELL>");
133  let _ = writeln!(xml, "</LATTICES>");
134  xml
135}
136
137pub fn write_graph_to_string(name: &str, graph: &Graph) -> String {
138  let mut xml = String::new();
139  let _ = writeln!(xml, "<LATTICES>");
140  if graph.dimension() > 0 {
141    let _ = writeln!(xml, "  <GRAPH name=\"{}\" dimension=\"{}\" vertices=\"{}\">", escape_attr(name), graph.dimension(), graph.num_sites());
142  } else {
143    let _ = writeln!(xml, "  <GRAPH name=\"{}\" vertices=\"{}\">", escape_attr(name), graph.num_sites());
144  }
145  for site in 0..graph.num_sites() {
146    if graph.dimension() > 0 {
147      let _ = write!(xml, "    <VERTEX type=\"{}\">", graph.site_type(site));
148      let _ = write!(xml, "<COORDINATE>");
149      for (index, value) in graph.coordinate(site).iter().enumerate() {
150        if index > 0 {
151          xml.push(' ');
152        }
153        let _ = write!(xml, "{}", value);
154      }
155      let _ = writeln!(xml, "</COORDINATE></VERTEX>");
156    } else {
157      let _ = writeln!(xml, "    <VERTEX type=\"{}\"/>", graph.site_type(site));
158    }
159  }
160  for bond in 0..graph.num_bonds() {
161    let _ = writeln!(xml, "    <EDGE type=\"{}\" source=\"{}\" target=\"{}\"/>", graph.bond_type(bond), graph.source(bond) + 1, graph.target(bond) + 1);
162  }
163  let _ = writeln!(xml, "  </GRAPH>");
164  let _ = writeln!(xml, "</LATTICES>");
165  xml
166}
167
168fn parse_basis(entry: Node<'_, '_>) -> Result<Basis> {
169  let dimension = parse_optional_usize(entry, "dimension");
170  let basis_node = entry.children().find(|child| child.is_element() && child.tag_name().name() == "BASIS")
171    .ok_or(XmlError::InvalidFormat("missing BASIS element"))?;
172  let mut vectors = Vec::new();
173  for vector_node in basis_node.children().filter(|child| child.is_element() && child.tag_name().name() == "VECTOR") {
174    vectors.push(parse_number_list::<f64>(vector_node.text().unwrap_or(""))?);
175  }
176  let dim = match dimension {
177    Some(value) => value,
178    None => vectors.len(),
179  };
180  if dim == 0 || vectors.len() != dim {
181    return Err(XmlError::DimensionMismatch("basis dimension mismatch"));
182  }
183  for vector in &vectors {
184    if vector.len() != dim {
185      return Err(XmlError::DimensionMismatch("basis dimension mismatch"));
186    }
187  }
188  let mut matrix = BasisMatrix::zeros(dim, dim);
189  for (column, vector) in vectors.iter().enumerate() {
190    for (row, value) in vector.iter().enumerate() {
191      matrix[(row, column)] = *value;
192    }
193  }
194  Ok(Basis::new(matrix))
195}
196
197fn parse_unitcell(entry: Node<'_, '_>) -> Result<Unitcell> {
198  let dimension = parse_required_usize(entry, "dimension")?;
199  let mut cell = Unitcell::new(dimension);
200  let site_count = parse_optional_usize(entry, "vertices").unwrap_or(0);
201
202  for vertex in entry.children().filter(|child| child.is_element() && child.tag_name().name() == "VERTEX") {
203    let site_type = parse_optional_i32(vertex, "type").unwrap_or(0);
204    let mut coordinate = CoordinateVector::zeros(dimension);
205    let mut found_coordinate = false;
206    for coordinate_node in vertex.children().filter(|child| child.is_element() && child.tag_name().name() == "COORDINATE") {
207      if found_coordinate {
208        return Err(XmlError::InvalidFormat("duplicated COORDINATE tag"));
209      }
210      let values = parse_number_list::<f64>(coordinate_node.text().unwrap_or(""))?;
211      if values.len() != dimension {
212        return Err(XmlError::DimensionMismatch("site coordinate dimension mismatch"));
213      }
214      for (index, value) in values.into_iter().enumerate() {
215        coordinate[index] = value;
216      }
217      found_coordinate = true;
218    }
219    cell.add_site(coordinate, site_type);
220  }
221
222  if cell.num_sites() > 0 {
223    if site_count > 0 && cell.num_sites() != site_count {
224      return Err(XmlError::DimensionMismatch("inconsistent number of sites"));
225    }
226  } else {
227    for _ in 0..site_count {
228      cell.add_site(CoordinateVector::zeros(dimension), 0);
229    }
230  }
231
232  for edge in entry.children().filter(|child| child.is_element() && child.tag_name().name() == "EDGE") {
233    let bond_type = parse_optional_i32(edge, "type").unwrap_or(0);
234    let source = parse_edge_vertex(edge, "SOURCE")?;
235    let target = parse_edge_vertex(edge, "TARGET")?;
236    let mut source_offset = OffsetVector::zeros(dimension);
237    let mut target_offset = OffsetVector::zeros(dimension);
238    for source_node in edge.children().filter(|child| child.is_element() && child.tag_name().name() == "SOURCE") {
239      if let Some(vertex) = source_node.attribute("vertex") {
240        let parsed = vertex.parse::<usize>().map_err(|_| XmlError::InvalidFormat("invalid SOURCE vertex"))?;
241        if parsed == 0 {
242          return Err(XmlError::InvalidFormat("SOURCE vertex is 1-based"));
243        }
244      }
245      if let Some(offset) = source_node.attribute("offset") {
246        let values = parse_number_list::<i64>(offset)?;
247        if values.len() != dimension {
248          return Err(XmlError::DimensionMismatch("SOURCE offset dimension mismatch"));
249        }
250        for (index, value) in values.into_iter().enumerate() {
251          source_offset[index] = value;
252        }
253      }
254    }
255    for target_node in edge.children().filter(|child| child.is_element() && child.tag_name().name() == "TARGET") {
256      if let Some(vertex) = target_node.attribute("vertex") {
257        let parsed = vertex.parse::<usize>().map_err(|_| XmlError::InvalidFormat("invalid TARGET vertex"))?;
258        if parsed == 0 {
259          return Err(XmlError::InvalidFormat("TARGET vertex is 1-based"));
260        }
261      }
262      if let Some(offset) = target_node.attribute("offset") {
263        let values = parse_number_list::<i64>(offset)?;
264        if values.len() != dimension {
265          return Err(XmlError::DimensionMismatch("TARGET offset dimension mismatch"));
266        }
267        for (index, value) in values.into_iter().enumerate() {
268          target_offset[index] = value;
269        }
270      }
271    }
272    let final_offset = target_offset - source_offset;
273    cell.add_bond(source - 1, target - 1, final_offset, bond_type);
274  }
275
276  Ok(cell)
277}
278
279fn parse_graph(entry: Node<'_, '_>) -> Result<Graph> {
280  let dimension = parse_optional_usize(entry, "dimension").unwrap_or(0);
281  let mut graph = Graph::new(dimension);
282  let site_count = parse_optional_usize(entry, "vertices").unwrap_or(0);
283
284  for vertex in entry.children().filter(|child| child.is_element() && child.tag_name().name() == "VERTEX") {
285    let site_type = parse_optional_i32(vertex, "type").unwrap_or(0);
286    let coordinate = if dimension > 0 {
287      let coordinate_node = vertex.children().find(|child| child.is_element() && child.tag_name().name() == "COORDINATE")
288        .ok_or(XmlError::InvalidFormat("missing COORDINATE tag"))?;
289      let values = parse_number_list::<f64>(coordinate_node.text().unwrap_or(""))?;
290      if values.len() != dimension {
291        return Err(XmlError::DimensionMismatch("site coordinate dimension mismatch"));
292      }
293      let mut coordinate = CoordinateVector::zeros(dimension);
294      for (index, value) in values.into_iter().enumerate() {
295        coordinate[index] = value;
296      }
297      coordinate
298    } else {
299      CoordinateVector::zeros(0)
300    };
301    graph.add_site(coordinate, site_type);
302  }
303
304  if graph.num_sites() > 0 {
305    if site_count > 0 && graph.num_sites() != site_count {
306      return Err(XmlError::DimensionMismatch("inconsistent number of sites"));
307    }
308  } else {
309    for _ in 0..site_count {
310      graph.add_site(CoordinateVector::zeros(dimension), 0);
311    }
312  }
313
314  for edge in entry.children().filter(|child| child.is_element() && child.tag_name().name() == "EDGE") {
315    let source = parse_required_usize_attr(edge, "source")?;
316    let target = parse_required_usize_attr(edge, "target")?;
317    let bond_type = parse_optional_i32(edge, "type").unwrap_or(0);
318    graph.add_bond(source - 1, target - 1, bond_type);
319  }
320
321  Ok(graph)
322}
323
324fn find_entry<'a>(doc: &'a Document<'a>, tag: &str, name: &str) -> Result<Option<Node<'a, 'a>>> {
325  let root = doc.root_element();
326  if root.tag_name().name() != "LATTICES" {
327    return Err(XmlError::InvalidFormat("root element must be LATTICES"));
328  }
329  Ok(root.children().find(|child| {
330    child.is_element()
331      && child.tag_name().name() == tag
332      && child.attribute("name") == Some(name)
333  }))
334}
335
336fn parse_required_usize(node: Node<'_, '_>, attr: &'static str) -> Result<usize> {
337  parse_required_attr(node, attr)?.parse::<usize>().map_err(|_| XmlError::InvalidFormat(attr))
338}
339
340fn parse_required_usize_attr(node: Node<'_, '_>, attr: &'static str) -> Result<usize> {
341  node.attribute(attr)
342    .ok_or(XmlError::MissingAttribute(attr))?
343    .parse::<usize>()
344    .map_err(|_| XmlError::InvalidFormat(attr))
345}
346
347fn parse_required_attr<'a>(node: Node<'a, 'a>, attr: &'static str) -> Result<&'a str> {
348  node.attribute(attr).ok_or(XmlError::MissingAttribute(attr))
349}
350
351fn parse_optional_usize(node: Node<'_, '_>, attr: &'static str) -> Option<usize> {
352  node.attribute(attr).and_then(|value| value.parse::<usize>().ok())
353}
354
355fn parse_optional_i32(node: Node<'_, '_>, attr: &'static str) -> Option<i32> {
356  node.attribute(attr).and_then(|value| value.parse::<i32>().ok())
357}
358
359fn parse_edge_vertex(edge: Node<'_, '_>, tag: &'static str) -> Result<usize> {
360  let node = edge.children().find(|child| child.is_element() && child.tag_name().name() == tag)
361    .ok_or(XmlError::MissingAttribute(tag))?;
362  let vertex = node.attribute("vertex").ok_or(XmlError::MissingAttribute("vertex"))?.parse::<usize>().map_err(|_| XmlError::InvalidFormat("invalid vertex"))?;
363  if vertex == 0 {
364    return Err(XmlError::InvalidFormat("vertex is 1-based"));
365  }
366  Ok(vertex)
367}
368
369fn parse_number_list<T>(text: &str) -> Result<Vec<T>>
370where
371  T: std::str::FromStr,
372{
373  let mut values = Vec::new();
374  for token in text.split_whitespace() {
375    values.push(token.parse::<T>().map_err(|_| XmlError::InvalidFormat("invalid numeric value"))?);
376  }
377  Ok(values)
378}
379
/// Escapes `text` for use inside a double-quoted XML attribute value.
///
/// `&`, `"`, `<` and `>` are replaced by `&amp;`, `&quot;`, `&lt;` and
/// `&gt;`; every other character is copied unchanged.
fn escape_attr(text: &str) -> String {
  let mut escaped = String::with_capacity(text.len());
  for ch in text.chars() {
    match ch {
      '&' => escaped.push_str("&amp;"),
      '"' => escaped.push_str("&quot;"),
      '<' => escaped.push_str("&lt;"),
      '>' => escaped.push_str("&gt;"),
      other => escaped.push(other),
    }
  }
  escaped
}