1use crate::basis::Basis;
2use crate::types::{Boundary, CoordinateVector, ExtentVector, OffsetVector};
3use crate::unitcell::Unitcell;
4
/// A vertex of a [`Graph`] together with its adjacency lists.
///
/// `neighbors` and `neighbor_bonds` are parallel: entry `k` of `neighbor_bonds`
/// is the index of the bond that connects this site to entry `k` of `neighbors`.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Site {
    /// Caller-defined type label of the site.
    pub site_type: i32,
    /// Indices of the sites adjacent to this one, one entry per incident bond.
    pub neighbors: Vec<usize>,
    /// Indices of the incident bonds, in the same order as `neighbors`.
    pub neighbor_bonds: Vec<usize>,
}
11
/// An edge of a [`Graph`], stored as an ordered pair of site indices plus a type label.
#[derive(Clone, Debug, PartialEq, Eq, Hash)]
pub struct Bond {
    /// Index of the site the bond was declared from.
    pub source: usize,
    /// Index of the site the bond was declared to.
    pub target: usize,
    /// Caller-defined type label of the bond.
    pub bond_type: i32,
}
18
19#[derive(Clone, Debug, PartialEq)]
20pub struct Graph {
21 dim: usize,
22 sites: Vec<Site>,
23 coordinates: Vec<CoordinateVector>,
24 bonds: Vec<Bond>,
25}
26
27impl Graph {
28 pub fn new(dim: usize) -> Self {
29 Self { dim, sites: Vec::new(), coordinates: Vec::new(), bonds: Vec::new() }
30 }
31
32 pub fn simple(dim: usize, length: usize) -> Self {
33 let basis = Basis::simple(dim);
34 let cell = Unitcell::simple(dim);
35 let extent = ExtentVector::from_element(dim, length as i64);
36 let boundary = vec![Boundary::Periodic; dim];
37 Self::from_basis_unitcell_extent(&basis, &cell, &extent, &boundary)
38 }
39
40 pub fn fully_connected(num_sites: usize) -> Self {
41 let mut graph = Self::new(0);
42 let pos = CoordinateVector::zeros(0);
43 for _ in 0..num_sites {
44 graph.add_site(pos.clone(), 0);
45 }
46 for source in 0..num_sites {
47 for target in (source + 1)..num_sites {
48 graph.add_bond(source, target, 0);
49 }
50 }
51 graph
52 }
53
54 pub fn from_basis_unitcell_extent(
55 basis: &Basis,
56 cell: &Unitcell,
57 extent: &ExtentVector,
58 boundary: &[Boundary],
59 ) -> Self {
60 assert_eq!(cell.dimension(), basis.dimension(), "dimension mismatch");
61 assert_eq!(cell.dimension(), extent.len(), "dimension mismatch");
62 assert_eq!(cell.dimension(), boundary.len(), "dimension mismatch");
63 for value in extent.iter() {
64 assert!(*value > 0, "extent must be positive");
65 }
66
67 let dim = cell.dimension();
68 let mut graph = Self::new(dim);
69 let num_cells = extent.iter().fold(1usize, |acc, value| acc * (*value as usize));
70
71 for cell_index in 0..num_cells {
72 let cell_offset = index_to_offset(cell_index, extent);
73 let cell_offset_f = CoordinateVector::from_iterator(
74 dim,
75 cell_offset.iter().map(|value| *value as f64),
76 );
77 for site in 0..cell.num_sites() {
78 let coordinate = basis.basis_vectors() * (cell_offset_f.clone() + cell.site(site).coordinate.clone());
79 graph.add_site(coordinate, cell.site(site).site_type);
80 }
81 }
82
83 for cell_index in 0..num_cells {
84 let cell_offset = index_to_offset(cell_index, extent);
85 for bond_index in 0..cell.num_bonds() {
86 let bond = cell.bond(bond_index);
87 let mut target_offset = cell_offset.clone() + bond.target_offset.clone();
88 if !wrap_offset(&mut target_offset, extent, boundary) {
89 continue;
90 }
91 let target_cell = offset_to_index(&target_offset, extent);
92 let source_site = cell_index * cell.num_sites() + bond.source;
93 let target_site = target_cell * cell.num_sites() + bond.target;
94 if source_site != target_site {
95 graph.add_bond(source_site, target_site, bond.bond_type);
96 }
97 }
98 }
99
100 graph
101 }
102
103 pub fn from_basis_unitcell_length(
104 basis: &Basis,
105 cell: &Unitcell,
106 length: usize,
107 boundary: Boundary,
108 ) -> Self {
109 let extent = ExtentVector::from_element(cell.dimension(), length as i64);
110 let boundary = vec![boundary; cell.dimension()];
111 Self::from_basis_unitcell_extent(basis, cell, &extent, &boundary)
112 }
113
114 pub fn dimension(&self) -> usize {
115 self.dim
116 }
117
118 pub fn num_sites(&self) -> usize {
119 self.sites.len()
120 }
121
122 pub fn site_type(&self, site: usize) -> i32 {
123 self.sites[site].site_type
124 }
125
126 pub fn coordinate(&self, site: usize) -> &CoordinateVector {
127 &self.coordinates[site]
128 }
129
130 pub fn num_neighbors(&self, site: usize) -> usize {
131 self.sites[site].neighbors.len()
132 }
133
134 pub fn neighbor(&self, site: usize, neighbor: usize) -> usize {
135 self.sites[site].neighbors[neighbor]
136 }
137
138 pub fn neighbor_bond(&self, site: usize, neighbor: usize) -> usize {
139 self.sites[site].neighbor_bonds[neighbor]
140 }
141
142 pub fn num_bonds(&self) -> usize {
143 self.bonds.len()
144 }
145
146 pub fn bond_type(&self, bond: usize) -> i32 {
147 self.bonds[bond].bond_type
148 }
149
150 pub fn source(&self, bond: usize) -> usize {
151 self.bonds[bond].source
152 }
153
154 pub fn target(&self, bond: usize) -> usize {
155 self.bonds[bond].target
156 }
157
158 pub fn edge_sites(&self, bond: usize) -> (usize, usize) {
159 let bond = &self.bonds[bond];
160 (bond.source, bond.target)
161 }
162
163 pub fn add_site(&mut self, coordinate: CoordinateVector, site_type: i32) -> usize {
164 assert_eq!(coordinate.len(), self.dim, "dimension mismatch");
165 let index = self.sites.len();
166 self.sites.push(Site { site_type, neighbors: Vec::new(), neighbor_bonds: Vec::new() });
167 self.coordinates.push(coordinate);
168 index
169 }
170
171 pub fn add_bond(&mut self, source: usize, target: usize, bond_type: i32) -> usize {
172 if source >= self.num_sites() || target >= self.num_sites() {
173 panic!("site index out of range");
174 }
175 if source == target {
176 panic!("self loop is not allowed");
177 }
178 let index = self.bonds.len();
179 self.bonds.push(Bond { source, target, bond_type });
180 self.sites[source].neighbors.push(target);
181 self.sites[source].neighbor_bonds.push(index);
182 self.sites[target].neighbors.push(source);
183 self.sites[target].neighbor_bonds.push(index);
184 index
185 }
186}
187
188fn index_to_offset(index: usize, extent: &ExtentVector) -> OffsetVector {
189 let dim = extent.len();
190 let mut remainder = index;
191 let mut offset = OffsetVector::zeros(dim);
192 for axis in 0..dim {
193 let size = extent[axis] as usize;
194 offset[axis] = (remainder % size) as i64;
195 remainder /= size;
196 }
197 offset
198}
199
200fn offset_to_index(offset: &OffsetVector, extent: &ExtentVector) -> usize {
201 let mut index = 0usize;
202 let mut stride = 1usize;
203 for axis in 0..extent.len() {
204 index += (offset[axis] as usize) * stride;
205 stride *= extent[axis] as usize;
206 }
207 index
208}
209
210fn wrap_offset(offset: &mut OffsetVector, extent: &ExtentVector, boundary: &[Boundary]) -> bool {
211 for axis in 0..offset.len() {
212 let size = extent[axis];
213 let value = offset[axis];
214 match boundary[axis] {
215 Boundary::Open => {
216 if value < 0 || value >= size {
217 return false;
218 }
219 }
220 Boundary::Periodic => {
221 let wrapped = ((value % size) + size) % size;
222 offset[axis] = wrapped;
223 }
224 }
225 }
226 true
227}