Skip to main content

lattice_core/
graph.rs

1use crate::basis::Basis;
2use crate::types::{Boundary, CoordinateVector, ExtentVector, OffsetVector};
3use crate::unitcell::Unitcell;
4
/// Per-site record stored by a [`Graph`]: the site's type label and its
/// adjacency list.
///
/// The two vectors are parallel: `neighbors[k]` is the site reached through
/// bond `neighbor_bonds[k]`. `Graph::add_bond` pushes onto both of them
/// together, once for each endpoint of the new bond.
#[derive(Debug, PartialEq, Clone)]
pub struct Site {
  /// Integer type label of this site.
  pub site_type: i32,
  /// Adjacent site indices, one entry per incident bond, in insertion order.
  pub neighbors: Vec<usize>,
  /// Bond index belonging to the entry at the same position in `neighbors`.
  pub neighbor_bonds: Vec<usize>,
}
11
/// A bond (edge) of a [`Graph`] joining sites `source` and `target`.
///
/// `Graph::add_bond` records the bond in the neighbor lists of both
/// endpoints, so adjacency queries see it from either side.
#[derive(Debug, PartialEq, Clone)]
pub struct Bond {
  /// Index of the first endpoint.
  pub source: usize,
  /// Index of the second endpoint.
  pub target: usize,
  /// Integer type label of this bond.
  pub bond_type: i32,
}
18
19#[derive(Clone, Debug, PartialEq)]
20pub struct Graph {
21  dim: usize,
22  sites: Vec<Site>,
23  coordinates: Vec<CoordinateVector>,
24  bonds: Vec<Bond>,
25}
26
27impl Graph {
28  pub fn new(dim: usize) -> Self {
29    Self { dim, sites: Vec::new(), coordinates: Vec::new(), bonds: Vec::new() }
30  }
31
32  pub fn simple(dim: usize, length: usize) -> Self {
33    let basis = Basis::simple(dim);
34    let cell = Unitcell::simple(dim);
35    let extent = ExtentVector::from_element(dim, length as i64);
36    let boundary = vec![Boundary::Periodic; dim];
37    Self::from_basis_unitcell_extent(&basis, &cell, &extent, &boundary)
38  }
39
40  pub fn fully_connected(num_sites: usize) -> Self {
41    let mut graph = Self::new(0);
42    let pos = CoordinateVector::zeros(0);
43    for _ in 0..num_sites {
44      graph.add_site(pos.clone(), 0);
45    }
46    for source in 0..num_sites {
47      for target in (source + 1)..num_sites {
48        graph.add_bond(source, target, 0);
49      }
50    }
51    graph
52  }
53
54  pub fn from_basis_unitcell_extent(
55    basis: &Basis,
56    cell: &Unitcell,
57    extent: &ExtentVector,
58    boundary: &[Boundary],
59  ) -> Self {
60    assert_eq!(cell.dimension(), basis.dimension(), "dimension mismatch");
61    assert_eq!(cell.dimension(), extent.len(), "dimension mismatch");
62    assert_eq!(cell.dimension(), boundary.len(), "dimension mismatch");
63    for value in extent.iter() {
64      assert!(*value > 0, "extent must be positive");
65    }
66
67    let dim = cell.dimension();
68    let mut graph = Self::new(dim);
69    let num_cells = extent.iter().fold(1usize, |acc, value| acc * (*value as usize));
70
71    for cell_index in 0..num_cells {
72      let cell_offset = index_to_offset(cell_index, extent);
73      let cell_offset_f = CoordinateVector::from_iterator(
74        dim,
75        cell_offset.iter().map(|value| *value as f64),
76      );
77      for site in 0..cell.num_sites() {
78        let coordinate = basis.basis_vectors() * (cell_offset_f.clone() + cell.site(site).coordinate.clone());
79        graph.add_site(coordinate, cell.site(site).site_type);
80      }
81    }
82
83    for cell_index in 0..num_cells {
84      let cell_offset = index_to_offset(cell_index, extent);
85      for bond_index in 0..cell.num_bonds() {
86        let bond = cell.bond(bond_index);
87        let mut target_offset = cell_offset.clone() + bond.target_offset.clone();
88        if !wrap_offset(&mut target_offset, extent, boundary) {
89          continue;
90        }
91        let target_cell = offset_to_index(&target_offset, extent);
92        let source_site = cell_index * cell.num_sites() + bond.source;
93        let target_site = target_cell * cell.num_sites() + bond.target;
94        if source_site != target_site {
95          graph.add_bond(source_site, target_site, bond.bond_type);
96        }
97      }
98    }
99
100    graph
101  }
102
103  pub fn from_basis_unitcell_length(
104    basis: &Basis,
105    cell: &Unitcell,
106    length: usize,
107    boundary: Boundary,
108  ) -> Self {
109    let extent = ExtentVector::from_element(cell.dimension(), length as i64);
110    let boundary = vec![boundary; cell.dimension()];
111    Self::from_basis_unitcell_extent(basis, cell, &extent, &boundary)
112  }
113
114  pub fn dimension(&self) -> usize {
115    self.dim
116  }
117
118  pub fn num_sites(&self) -> usize {
119    self.sites.len()
120  }
121
122  pub fn site_type(&self, site: usize) -> i32 {
123    self.sites[site].site_type
124  }
125
126  pub fn coordinate(&self, site: usize) -> &CoordinateVector {
127    &self.coordinates[site]
128  }
129
130  pub fn num_neighbors(&self, site: usize) -> usize {
131    self.sites[site].neighbors.len()
132  }
133
134  pub fn neighbor(&self, site: usize, neighbor: usize) -> usize {
135    self.sites[site].neighbors[neighbor]
136  }
137
138  pub fn neighbor_bond(&self, site: usize, neighbor: usize) -> usize {
139    self.sites[site].neighbor_bonds[neighbor]
140  }
141
142  pub fn num_bonds(&self) -> usize {
143    self.bonds.len()
144  }
145
146  pub fn bond_type(&self, bond: usize) -> i32 {
147    self.bonds[bond].bond_type
148  }
149
150  pub fn source(&self, bond: usize) -> usize {
151    self.bonds[bond].source
152  }
153
154  pub fn target(&self, bond: usize) -> usize {
155    self.bonds[bond].target
156  }
157
158  pub fn edge_sites(&self, bond: usize) -> (usize, usize) {
159    let bond = &self.bonds[bond];
160    (bond.source, bond.target)
161  }
162
163  pub fn add_site(&mut self, coordinate: CoordinateVector, site_type: i32) -> usize {
164    assert_eq!(coordinate.len(), self.dim, "dimension mismatch");
165    let index = self.sites.len();
166    self.sites.push(Site { site_type, neighbors: Vec::new(), neighbor_bonds: Vec::new() });
167    self.coordinates.push(coordinate);
168    index
169  }
170
171  pub fn add_bond(&mut self, source: usize, target: usize, bond_type: i32) -> usize {
172    if source >= self.num_sites() || target >= self.num_sites() {
173      panic!("site index out of range");
174    }
175    if source == target {
176      panic!("self loop is not allowed");
177    }
178    let index = self.bonds.len();
179    self.bonds.push(Bond { source, target, bond_type });
180    self.sites[source].neighbors.push(target);
181    self.sites[source].neighbor_bonds.push(index);
182    self.sites[target].neighbors.push(source);
183    self.sites[target].neighbor_bonds.push(index);
184    index
185  }
186}
187
188fn index_to_offset(index: usize, extent: &ExtentVector) -> OffsetVector {
189  let dim = extent.len();
190  let mut remainder = index;
191  let mut offset = OffsetVector::zeros(dim);
192  for axis in 0..dim {
193    let size = extent[axis] as usize;
194    offset[axis] = (remainder % size) as i64;
195    remainder /= size;
196  }
197  offset
198}
199
200fn offset_to_index(offset: &OffsetVector, extent: &ExtentVector) -> usize {
201  let mut index = 0usize;
202  let mut stride = 1usize;
203  for axis in 0..extent.len() {
204    index += (offset[axis] as usize) * stride;
205    stride *= extent[axis] as usize;
206  }
207  index
208}
209
210fn wrap_offset(offset: &mut OffsetVector, extent: &ExtentVector, boundary: &[Boundary]) -> bool {
211  for axis in 0..offset.len() {
212    let size = extent[axis];
213    let value = offset[axis];
214    match boundary[axis] {
215      Boundary::Open => {
216        if value < 0 || value >= size {
217          return false;
218        }
219      }
220      Boundary::Periodic => {
221        let wrapped = ((value % size) + size) % size;
222        offset[axis] = wrapped;
223      }
224    }
225  }
226  true
227}