1use crate::util::vec_help::remove_doubles;
2use rand::prelude::*;
3use smallvec::{smallvec, SmallVec};
4use std::cmp::max;
5use std::fmt::{Debug, Error, Formatter};
6
7pub struct GraphState<R: Rng> {
9 pub(crate) edges: Vec<(Edge, f64)>,
10 pub(crate) binding_mat: Vec<Vec<(usize, f64)>>,
11 pub(crate) biases: Vec<f64>,
12 pub(crate) state: Option<Vec<bool>>,
13 cumulative_weight: Option<(Vec<f64>, f64)>,
14 pub(crate) rng: R,
15}
16
17impl<R: Rng> Debug for GraphState<R> {
18 fn fmt(&self, f: &mut Formatter<'_>) -> Result<(), Error> {
19 if let Some(state) = &self.state {
20 let s = state
21 .iter()
22 .map(|b| if *b { "1" } else { "0" })
23 .collect::<Vec<_>>()
24 .join("");
25 let e = self.get_energy();
26 f.write_str(&format!("{}\t{}", s, e))
27 } else {
28 f.write_str("Error")
29 }
30 }
31}
32
33impl<R: Rng + Clone> Clone for GraphState<R> {
34 fn clone(&self) -> Self {
35 Self {
36 edges: self.edges.clone(),
37 binding_mat: self.binding_mat.clone(),
38 biases: self.biases.clone(),
39 state: self.state.clone(),
40 cumulative_weight: self.cumulative_weight.clone(),
41 rng: self.rng.clone(),
42 }
43 }
44}
45
/// One step of a worm update: flip a single vertex, or flip two vertices in sequence
/// (the second field is the vertex the worm continues from).
#[derive(Clone, Copy, Debug)]
enum WormMove {
    Single(usize),
    Double(usize, usize),
}

/// An undirected edge given by its two vertex indices.
pub type Edge = (usize, usize);
54impl<R: Rng> GraphState<R> {
55 pub fn new(edges: &[(Edge, f64)], biases: &[f64], mut rng: R) -> Self {
57 let state = make_random_spin_state(biases.len(), &mut rng);
58 Self::new_with_state_and_rng(state, edges, biases, rng)
59 }
60
61 pub fn new_with_state_and_rng(
63 state: Vec<bool>,
64 edges: &[(Edge, f64)],
65 biases: &[f64],
66 rng: R,
67 ) -> Self {
68 let mut binding_mat: Vec<Vec<(usize, f64)>> = vec![vec![]; biases.len()];
70
71 edges.iter().for_each(|((va, vb), j)| {
72 binding_mat[*va].push((*vb, *j));
73 binding_mat[*vb].push((*va, *j));
74 });
75 binding_mat.iter_mut().for_each(|vs| {
77 vs.sort_by_key(|(i, _)| *i);
78 });
79
80 GraphState {
81 edges: edges.to_vec(),
82 binding_mat,
83 biases: biases.to_vec(),
84 state: Some(state),
85 cumulative_weight: None,
86 rng,
87 }
88 }
89
90 pub fn do_spin_flip(
92 rng: &mut R,
93 beta: f64,
94 binding_mat: &[Vec<(usize, f64)>],
95 biases: &[f64],
96 state: &mut [bool],
97 ) {
98 let random_index = rng.gen_range(0..state.len());
99 let curr_value = state[random_index];
100 let binding_slice = &binding_mat[random_index];
102 let delta_e: f64 = binding_slice
103 .iter()
104 .cloned()
105 .map(|(indx, j)| {
106 let old_coupling = if !(curr_value ^ state[indx]) {
107 1.0
108 } else {
109 -1.0
110 };
111 -2.0 * j * old_coupling
113 })
114 .sum();
115 let delta_e = delta_e + (2.0 * biases[random_index] * if curr_value { 1.0 } else { -1.0 });
116 if Self::should_flip(rng, beta, delta_e) {
117 state[random_index] = !state[random_index]
118 }
119 }
120
121 fn do_edge_flip(
123 rng: &mut R,
124 beta: f64,
125 edges: &[(Edge, f64)],
126 binding_mat: &[Vec<(usize, f64)>],
127 biases: &[f64],
128 state: &mut [bool],
129 cumulative_edge_weights: Option<(&[f64], f64)>,
130 ) {
131 let indx_edge = if let Some((cumulative_edge_weights, totalw)) = cumulative_edge_weights {
132 let p = rng.gen_range(0. ..totalw);
133 let indx = cumulative_edge_weights
134 .binary_search_by(|v| v.partial_cmp(&p).expect("Couldn't compare values"));
135 match indx {
136 Ok(indx) => indx,
137 Err(indx) => indx,
138 }
139 } else {
140 rng.gen_range(0..edges.len())
141 };
142 let ((va, vb), _) = edges[indx_edge];
143
144 let delta_e = |va: usize, vb: usize| -> f64 {
145 let delta_e = Self::delta_e(va, Some(vb), state, binding_mat);
146 delta_e + (2.0 * biases[va] * if state[va] { 1.0 } else { -1.0 })
147 };
148 let delta_e = delta_e(va, vb) + delta_e(vb, va);
149 if Self::should_flip(rng, beta, delta_e) {
150 state[va] = !state[va];
151 state[vb] = !state[vb];
152 }
153 }
154
155 fn delta_e(
156 v: usize,
157 omit: Option<usize>,
158 state: &[bool],
159 binding_mat: &[Vec<(usize, f64)>],
160 ) -> f64 {
161 let curr_value = state[v];
162 binding_mat[v]
163 .iter()
164 .cloned()
165 .filter(|(ov, _)| Some(*ov) != omit)
166 .map(|(indx, j)| {
167 let old_coupling = if !(curr_value ^ state[indx]) {
168 1.0
169 } else {
170 -1.0
171 };
172 -2.0 * j * old_coupling
174 })
175 .sum()
176 }
177
178 fn do_worm_flip(
180 rng: &mut R,
181 beta: f64,
182 binding_mat: &[Vec<(usize, f64)>],
183 biases: &[f64],
184 state: &mut [bool],
185 allow_doubles: bool,
186 ) {
187 let start_index = rng.gen_range(0..state.len());
188 let mut visit_path = vec![WormMove::Single(start_index)];
189 let mut last_index = start_index;
190
191 let delta_e = |wm: WormMove, state: &mut [bool]| -> f64 {
192 match wm {
193 WormMove::Single(va) => Self::delta_e(va, None, state, binding_mat),
194 WormMove::Double(va, vb) => {
195 let de = Self::delta_e(va, Some(vb), state, binding_mat);
196 de + Self::delta_e(vb, Some(va), state, binding_mat)
197 }
198 }
199 };
200 let starting_e = delta_e(WormMove::Single(start_index), state);
201 state[start_index] = !state[start_index];
202
203 let mut update_failed = false;
204
205 let mut smallstack = vec![];
206 loop {
207 smallstack.clear();
208 let sel_move = visit_path[visit_path.len() - 1];
209 let sel_var = match sel_move {
210 WormMove::Single(v) => v,
211 WormMove::Double(_, v) => v,
212 };
213 let mut any_resolve = false;
214 binding_mat[sel_var].iter().cloned().for_each(|(ov, _)| {
215 if ov != last_index {
216 let de = delta_e(WormMove::Single(ov), state);
217 if de.abs() < f64::EPSILON {
218 smallstack.push((WormMove::Single(ov), de));
219 } else if (de + starting_e).abs() < f64::EPSILON {
220 smallstack.push((WormMove::Single(ov), de));
221 any_resolve = true;
222 }
223 if allow_doubles {
226 state[ov] = !state[ov];
227 binding_mat[ov].iter().cloned().for_each(|(oov, _)| {
228 if oov != ov && oov != sel_var {
229 let de = delta_e(WormMove::Single(oov), state) + de;
230 if de.abs() < f64::EPSILON {
231 smallstack.push((WormMove::Double(ov, oov), de));
232 } else if (de + starting_e).abs() < f64::EPSILON {
233 smallstack.push((WormMove::Double(ov, oov), de));
234 any_resolve = true;
235 }
236 }
237 });
238 state[ov] = !state[ov];
240 }
241 }
242 });
243 if any_resolve {
244 smallstack.retain(|(_, de)| (de + starting_e).abs() < f64::EPSILON);
245 }
246 let (ov, de) = if !smallstack.is_empty() {
248 let choice = rng.gen_range(0..smallstack.len());
249 let (ov, de) = smallstack[choice];
250 visit_path.push(ov);
251 (ov, de)
252 } else {
253 let sel_move = match sel_move {
255 WormMove::Single(_) => sel_move,
256 WormMove::Double(va, vb) => WormMove::Double(vb, va),
257 };
258 visit_path.push(sel_move);
259
260 let de = delta_e(sel_move, state);
261 (sel_move, de)
262 };
263 match ov {
264 WormMove::Single(va) => {
265 state[va] = !state[va];
266 }
267 WormMove::Double(va, vb) => {
268 state[va] = !state[va];
269 state[vb] = !state[vb];
270 }
271 }
272 last_index = match (ov, sel_move) {
273 (WormMove::Single(_), WormMove::Single(v)) => v,
274 (WormMove::Single(_), WormMove::Double(_, v)) => v,
275 (WormMove::Double(v, _), _) => v,
276 };
277 if (de + starting_e).abs() < f64::EPSILON {
278 break;
280 }
281 if visit_path.len() > state.len() {
283 update_failed = true;
284 break;
285 }
286 }
287 let mut visit_path = visit_path
288 .into_iter()
289 .map(|wm| -> SmallVec<[usize; 2]> {
290 match wm {
291 WormMove::Single(v) => smallvec![v],
292 WormMove::Double(va, vb) => smallvec![va, vb],
293 }
294 })
295 .flatten()
296 .collect::<Vec<_>>();
297 visit_path.sort_unstable();
298 remove_doubles(&mut visit_path);
299
300 if !update_failed {
301 let total_he = visit_path
303 .iter()
304 .cloned()
305 .map(|v| 2.0 * biases[v] * if state[v] { 1.0 } else { -1.0 })
306 .sum();
307 if !Self::should_flip(rng, beta, total_he) {
308 visit_path.iter().cloned().for_each(|v| {
309 state[v] = !state[v];
310 })
311 }
312 } else {
313 visit_path.iter().cloned().for_each(|v| {
315 state[v] = !state[v];
316 })
317 }
318 }
319
320 pub fn enable_edge_importance_sampling(&mut self, enable: bool) {
322 self.cumulative_weight = if enable {
323 let v = Vec::with_capacity(self.edges.len());
324 let (v, totalw) =
325 self.edges
326 .iter()
327 .map(|(_, w)| *w)
328 .fold((v, 0.), |(mut accv, accw), w| {
329 accv.push(accw + w);
330 (accv, accw + w)
331 });
332 Some((v, totalw))
333 } else {
334 None
335 }
336 }
337
338 pub fn should_flip(rng: &mut R, beta: f64, delta_e: f64) -> bool {
340 if delta_e > 0.0 {
342 let chance = (-beta * delta_e).exp();
343 rng.gen::<f64>() < chance
344 } else {
345 true
346 }
347 }
348
349 pub fn do_time_step(
351 &mut self,
352 beta: f64,
353 nspinupdates: Option<usize>,
354 nedgeupdates: Option<usize>,
355 nwormupdates: Option<usize>,
356 only_basic_moves: Option<bool>,
357 ) -> Result<(), String> {
358 if let Some(mut spin_state) = self.state.take() {
360 let only_basic_moves = only_basic_moves.unwrap_or(false);
361 let nspinupdates = nspinupdates.unwrap_or_else(|| max(1, spin_state.len() / 2));
362 let nedgeupdates = nedgeupdates.unwrap_or_else(|| max(1, self.edges.len() / 2));
363 let nwormupdates = nwormupdates.unwrap_or(1);
364 let t = if only_basic_moves { 2 } else { 3 };
365 let choice: u8 = self.rng.gen_range(0..t);
366 match choice {
367 0 => (0..nspinupdates).for_each(|_| {
368 Self::do_spin_flip(
369 &mut self.rng,
370 beta,
371 &self.binding_mat,
372 &self.biases,
373 &mut spin_state,
374 )
375 }),
376 1 => (0..nedgeupdates).for_each(|_| {
377 Self::do_edge_flip(
378 &mut self.rng,
379 beta,
380 &self.edges,
381 &self.binding_mat,
382 &self.biases,
383 &mut spin_state,
384 self.cumulative_weight
385 .as_ref()
386 .map(|(v, w)| (v.as_slice(), *w)),
387 )
388 }),
389 2 => (0..nwormupdates).for_each(|_| {
390 Self::do_worm_flip(
391 &mut self.rng,
392 beta,
393 &self.binding_mat,
394 &self.biases,
395 &mut spin_state,
396 true,
397 )
398 }),
399 _ => unreachable!(),
400 }
401 self.state = Some(spin_state);
402 Ok(())
403 } else {
404 Err("No state to edit".to_string())
405 }
406 }
407
408 pub fn get_state(self) -> Vec<bool> {
410 self.state.unwrap()
411 }
412
413 pub fn clone_state(&self) -> Vec<bool> {
415 self.state.clone().unwrap()
416 }
417
418 pub fn state_ref(&self) -> &[bool] {
420 self.state.as_ref().unwrap()
421 }
422
423 pub fn set_state(&mut self, state: Vec<bool>) {
425 assert_eq!(self.state.as_ref().unwrap().len(), state.len());
426 self.state = Some(state)
427 }
428
429 pub fn get_energy(&self) -> f64 {
431 if let Some(state) = &self.state {
432 state.iter().enumerate().fold(0.0, |acc, (i, si)| {
433 let binding_slice = &self.binding_mat[i];
434 let total_e: f64 = binding_slice
435 .iter()
436 .map(|(indx, j)| -> f64 {
437 let old_coupling = if !(si ^ state[*indx]) { 1.0 } else { -1.0 };
438 j * old_coupling / 2.0
439 })
440 .sum();
441 let bias_e = if *si { -self.biases[i] } else { self.biases[i] };
442 acc + total_e + bias_e
443 })
444 } else {
445 std::f64::NAN
446 }
447 }
448}
449
450pub fn make_random_spin_state<R: Rng>(n: usize, rng: &mut R) -> Vec<bool> {
452 (0..n).map(|_| -> bool { rng.gen() }).collect()
453}
454
455#[cfg(test)]
456mod classic_tests {
457 use super::*;
458 use itertools::Itertools;
459 use std::cmp::{max, min};
460
461 fn two_d_periodic(l: usize) -> Vec<(Edge, f64)> {
462 let indices: Vec<(usize, usize)> = (0usize..l)
463 .map(|i| (0usize..l).map(|j| (i, j)).collect::<Vec<(usize, usize)>>())
464 .flatten()
465 .collect();
466 let f = |i, j| j * l + i;
467
468 let right_connects = indices
469 .iter()
470 .cloned()
471 .map(|(i, j)| ((f(i, j), f((i + 1) % l, j)), -1.0));
472 let down_connects = indices.iter().cloned().map(|(i, j)| {
473 (
474 (f(i, j), f(i, (j + 1) % l)),
475 if i % 2 == 0 { 1.0 } else { -1.0 },
476 )
477 });
478 right_connects.chain(down_connects).collect()
479 }
480
481 #[test]
482 fn test_worm_flip() {
483 let mut g = GraphState::new_with_state_and_rng(
484 vec![false, false, false],
485 &[((0, 1), 1.0), ((1, 2), 1.0), ((2, 0), 1.0)],
486 &[0., 0., 0.],
487 SmallRng::from_entropy(),
488 );
489 GraphState::do_worm_flip(
490 &mut g.rng,
491 1.0,
492 &g.binding_mat,
493 &g.biases,
494 g.state.as_mut().unwrap(),
495 false,
496 );
497 assert!(g.state.unwrap().into_iter().all(|b| b))
498 }
499
500 #[test]
501 fn test_worm_flip_bias() {
502 let mut g = GraphState::new_with_state_and_rng(
503 vec![false, false, false],
504 &[((0, 1), 1.0), ((1, 2), 1.0), ((2, 0), 1.0)],
505 &[-1., -1., -1.],
506 SmallRng::from_entropy(),
507 );
508 GraphState::do_worm_flip(
509 &mut g.rng,
510 1.0,
511 &g.binding_mat,
512 &g.biases,
513 g.state.as_mut().unwrap(),
514 false,
515 );
516 assert!(g.state.unwrap().into_iter().all(|b| b))
517 }
518
519 #[test]
520 fn test_worm_flip_bias_not() {
521 let mut g = GraphState::new_with_state_and_rng(
522 vec![false, false, false],
523 &[((0, 1), 1.0), ((1, 2), 1.0), ((2, 0), 1.0)],
524 &[1., 1., 1.],
525 SmallRng::from_entropy(),
526 );
527 GraphState::do_worm_flip(
528 &mut g.rng,
529 1000.,
530 &g.binding_mat,
531 &g.biases,
532 g.state.as_mut().unwrap(),
533 false,
534 );
535 assert!(g.state.unwrap().into_iter().all(|b| !b))
536 }
537
538 #[test]
539 fn test_worm_flip_bounce() {
540 let nvars = 20;
541 let edges = (0..nvars - 1)
542 .map(|x| ((x, x + 1), 1.0))
543 .collect::<Vec<_>>();
544 let mut biases = vec![0.0; nvars];
545 biases[0] = 10.;
546 biases[nvars - 1] = 10.;
547 let mut g = GraphState::new_with_state_and_rng(
548 vec![false; nvars],
549 &edges,
550 &biases,
551 SmallRng::from_entropy(),
552 );
553 GraphState::do_worm_flip(
554 &mut g.rng,
555 1000.0,
556 &g.binding_mat,
557 &g.biases,
558 g.state.as_mut().unwrap(),
559 false,
560 );
561 assert!(g.state.unwrap().into_iter().all(|b| !b))
562 }
563
564 #[test]
565 fn test_worm_flip_doubles() {
566 let mut g = GraphState::new_with_state_and_rng(
567 vec![false, false, false],
568 &[((0, 1), 1.0), ((1, 2), 1.0), ((2, 0), 1.0)],
569 &[0., 0., 0.],
570 SmallRng::from_entropy(),
571 );
572 GraphState::do_worm_flip(
573 &mut g.rng,
574 1.0,
575 &g.binding_mat,
576 &g.biases,
577 g.state.as_mut().unwrap(),
578 true,
579 );
580 assert!(g.state.unwrap().into_iter().all_equal())
581 }
582
583 #[test]
584 fn test_worm_2d() {
585 let l = 4;
586 let nvars = l * l;
587 let edges = two_d_periodic(l);
588 let biases = vec![0.0; nvars];
589 let mut g = GraphState::new_with_state_and_rng(
590 vec![false; nvars],
591 &edges,
592 &biases,
593 SmallRng::from_entropy(),
594 );
595 GraphState::do_worm_flip(
596 &mut g.rng,
597 1000.0,
598 &g.binding_mat,
599 &g.biases,
600 g.state.as_mut().unwrap(),
601 true,
602 );
603 }
604
605 fn bathroom_unit_cells(l: usize) -> Vec<(Edge, f64)> {
606 let mut edges = vec![];
607 for x in 0..l {
608 for y in 0..l {
609 for i in 0..4 {
610 let va = y * l * 4 + x * 4 + i;
611 let vb = y * l * 4 + x * 4 + (i + 1) % 4;
612 edges.push(((min(va, vb), max(va, vb)), if i == 0 { 1.0 } else { -1.0 }))
613 }
614 let va = y * l * 4 + x * 4 + 1;
615 let vb = y * l * 4 + ((x + 1) % l) * 4 + 3;
616 edges.push(((min(va, vb), max(va, vb)), -1.0));
617
618 let va = y * l * 4 + x * 4;
619 let vb = ((y + 1) % l) * l * 4 + x * 4 + 2;
620 edges.push(((min(va, vb), max(va, vb)), -1.0));
621 }
622 }
623
624 edges
625 }
626
627 #[test]
628 fn test_worm_2d_bathroom() {
629 let l = 16;
630 let nvars = l * l * 4;
631 let edges = bathroom_unit_cells(l);
632 let biases = vec![0.0; nvars];
633 let mut g = GraphState::new_with_state_and_rng(
634 vec![false; nvars],
635 &edges,
636 &biases,
637 SmallRng::from_entropy(),
638 );
639 GraphState::do_worm_flip(
640 &mut g.rng,
641 1000.0,
642 &g.binding_mat,
643 &g.biases,
644 g.state.as_mut().unwrap(),
645 true,
646 );
647 }
648}