1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
use crate::sse::qmc_traits::op_container::*;
use crate::sse::qmc_types::*;
use crate::util::allocator::{Factory, StackTuplizer};
use rand::Rng;
#[cfg(feature = "serialize")]
use serde::{Deserialize, Serialize};
use std::cmp::min;

/// A position in imaginary time (`p`) paired with the relative index of a variable.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[cfg_attr(feature = "serialize", derive(Serialize, Deserialize))]
pub struct PRel {
    /// Position in imaginary time.
    pub p: usize,
    /// Relative index of the variable.
    pub relv: usize,
}

impl From<(usize, usize)> for PRel {
    /// Builds a `PRel` from a `(p, relv)` tuple.
    fn from(pair: (usize, usize)) -> Self {
        let (p, relv) = pair;
        PRel { p, relv }
    }
}

/// Add loop updates to OpContainer.
///
/// Implementors provide navigation between occupied positions `p` (globally and per variable);
/// the provided methods build the loop update on top of that navigation.
pub trait LoopUpdater: OpContainer + Factory<Vec<Leg>> + Factory<Vec<f64>> {
    /// The type used to contain the Op and handle movement around the worldlines.
    type Node: OpNode<Self::Op>;

    /// Get a ref to a node at position p
    fn get_node_ref(&self, p: usize) -> Option<&Self::Node>;
    /// Get a mutable ref to the node at position p
    fn get_node_mut(&mut self, p: usize) -> Option<&mut Self::Node>;

    /// Get the first occupied p if it exists.
    fn get_first_p(&self) -> Option<usize>;
    /// Get the last occupied p if it exists.
    fn get_last_p(&self) -> Option<usize>;
    /// Get the first p occupied which covers variable `var`, also returns the relative index.
    fn get_first_p_for_var(&self, var: usize) -> Option<PRel>;
    /// Get the last p occupied which covers variable `var`, also returns the relative index.
    fn get_last_p_for_var(&self, var: usize) -> Option<PRel>;

    /// Get the previous occupied p compared to `node`.
    fn get_previous_p(&self, node: &Self::Node) -> Option<usize>;
    /// Get the next occupied p compared to `node`.
    fn get_next_p(&self, node: &Self::Node) -> Option<usize>;

    /// Get the previous p for a given var, takes the relative var index in node. Also returns the
    /// new relative var index.
    fn get_previous_p_for_rel_var(&self, relvar: usize, node: &Self::Node) -> Option<PRel>;
    /// Get the next p for a given var, takes the relative var index in node. Also returns the new
    /// relative var index.
    fn get_next_p_for_rel_var(&self, relvar: usize, node: &Self::Node) -> Option<PRel>;

    /// Get the previous p for a given var.
    ///
    /// # Errors
    /// Returns an error if the op held by `node` does not contain `var`.
    fn get_previous_p_for_var(&self, var: usize, node: &Self::Node) -> Result<Option<PRel>, &str> {
        node.get_op_ref()
            .index_of_var(var)
            .map(|relvar| self.get_previous_p_for_rel_var(relvar, node))
            .ok_or("Variable not present on given node")
    }
    /// Get the next p for a given var.
    ///
    /// # Errors
    /// Returns an error if the op held by `node` does not contain `var`.
    fn get_next_p_for_var(&self, var: usize, node: &Self::Node) -> Result<Option<PRel>, &str> {
        node.get_op_ref()
            .index_of_var(var)
            .map(|relvar| self.get_next_p_for_rel_var(relvar, node))
            .ok_or("Variable not present on given node")
    }

    /// Get the nth occupied p (zero-indexed: `n == 0` is the first occupied p).
    ///
    /// # Panics
    /// Panics if there is no occupied p, or if stepping `n` times from the first occupied p
    /// runs out of ops (any of the intermediate lookups returns `None`).
    fn get_nth_p(&self, n: usize) -> usize {
        let acc = self
            .get_first_p()
            .map(|p| (p, self.get_node_ref(p).unwrap()))
            .unwrap();
        (0..n)
            .fold(acc, |(_, opnode), _| {
                let p = self.get_next_p(opnode).unwrap();
                (p, self.get_node_ref(p).unwrap())
            })
            .0
    }

    /// Returns if a given variable is covered by any ops.
    fn does_var_have_ops(&self, var: usize) -> bool {
        self.get_first_p_for_var(var).is_some()
    }

    /// Make a loop update to the graph with thread rng.
    ///
    /// See `make_loop_update_with_rng` for the meaning of the arguments.
    fn make_loop_update<H>(&mut self, initial_n: Option<usize>, hamiltonian: H, state: &mut [bool])
    where
        H: Fn(&[usize], usize, &[bool], &[bool]) -> f64,
    {
        self.make_loop_update_with_rng(initial_n, hamiltonian, state, &mut rand::thread_rng())
    }

    /// Make a loop update to the graph.
    ///
    /// * `initial_n`: index (in order of occupied p, zero-indexed) of the op to start the loop
    ///   from. Values past the last op are clamped to the last op. `None` picks an op at random.
    /// * `hamiltonian`: called as `hamiltonian(vars, bond, inputs, outputs)` to get the weight
    ///   of an op with the given input and output states.
    /// * `state`: passed on to the loop update, which may modify it.
    /// * `rng`: source of randomness for the starting op/leg and the loop itself.
    ///
    /// Does nothing to the ops if the container is empty. `post_loop_update_hook` is always
    /// called at the end.
    fn make_loop_update_with_rng<H, R: Rng>(
        &mut self,
        initial_n: Option<usize>,
        hamiltonian: H,
        state: &mut [bool],
        rng: &mut R,
    ) where
        H: Fn(&[usize], usize, &[bool], &[bool]) -> f64,
    {
        // Weight of `op` once both the entrance and the exit legs have been flipped.
        let h = |op: &Self::Op, entrance: Leg, exit: Leg| -> f64 {
            let mut inputs = op.clone_inputs();
            let mut outputs = op.clone_outputs();
            adjust_states(inputs.as_mut(), outputs.as_mut(), entrance);
            adjust_states(inputs.as_mut(), outputs.as_mut(), exit);
            // Call the supplied hamiltonian.
            hamiltonian(
                &op.get_vars(),
                op.get_bond(),
                inputs.as_ref(),
                outputs.as_ref(),
            )
        };

        let n_ops = self.get_n();
        if n_ops > 0 {
            // `get_nth_p` is zero-indexed, so the largest valid index is `n_ops - 1`.
            let initial_n = initial_n
                .map(|n| min(n, n_ops - 1))
                .unwrap_or_else(|| rng.gen_range(0, n_ops));
            let nth_p = self.get_nth_p(initial_n);
            // Get starting leg for pth op.
            let op = self.get_node_ref(nth_p).unwrap();
            let n_vars = op.get_op_ref().get_vars().len();
            let initial_var = rng.gen_range(0, n_vars);
            let initial_direction = if rng.gen() {
                OpSide::Inputs
            } else {
                OpSide::Outputs
            };
            let initial_leg = (initial_var, initial_direction);

            apply_loop_update(
                self,
                (nth_p, initial_leg),
                nth_p,
                initial_leg,
                h,
                state,
                rng,
            );
        }
        self.post_loop_update_hook();
    }

    /// Called after an update.
    fn post_loop_update_hook(&mut self) {}
}

/// Allow recursive loop updates with a trampoline mechanic: `loop_body` returns one of these
/// and `apply_loop_update` keeps calling it until it sees `Return`.
#[derive(Debug, Clone, Copy)]
enum LoopResult {
    /// The loop is finished; the driver stops iterating.
    Return,
    /// Keep going: the op position and entrance leg to use for the next step.
    Iterate(usize, Leg),
}

/// Drive the loop update to completion.
///
/// Repeatedly runs `loop_body`, feeding the op position and entrance leg it hands back into
/// the next call, and stops as soon as it reports `LoopResult::Return`.
fn apply_loop_update<L: LoopUpdater + ?Sized, H, R: Rng>(
    l: &mut L,
    initial_op_and_leg: (usize, Leg),
    mut sel_op_pos: usize,
    mut entrance_leg: Leg,
    h: H,
    state: &mut [bool],
    rng: &mut R,
) where
    H: Copy + Fn(&L::Op, Leg, Leg) -> f64,
{
    // Each step either finishes the loop or tells us where to go next.
    while let LoopResult::Iterate(next_pos, next_leg) = loop_body(
        l,
        initial_op_and_leg,
        sel_op_pos,
        entrance_leg,
        h,
        state,
        rng,
    ) {
        sel_op_pos = next_pos;
        entrance_leg = next_leg;
    }
}

/// Perform a single step of the loop update at the op located at `sel_op_pos`.
///
/// The step:
/// 1. weighs every leg of the op as a possible exit using `h(op, entrance_leg, leg)`,
/// 2. samples an exit leg with probability proportional to its weight,
/// 3. flips both the entrance and the exit leg on the stored op,
/// 4. follows the exit leg to the neighbouring op covering the same variable. If there is
///    none in that direction it wraps around to the first/last op for that variable and
///    writes the variable's new value into `state`.
///
/// Returns `LoopResult::Return` when the loop closes: either the exit leg is the initial leg
/// of the initial op, or the next entrance is. Otherwise returns `LoopResult::Iterate` with
/// the next op position and entrance leg.
///
/// # Panics
/// Panics if there is no node at `sel_op_pos`, or if no exit leg can be sampled (for example
/// when the weights do not sum to a positive value).
fn loop_body<L: LoopUpdater + ?Sized, H, R: Rng>(
    l: &mut L,
    initial_op_and_leg: (usize, Leg),
    sel_op_pos: usize,
    entrance_leg: Leg,
    h: H,
    state: &mut [bool],
    rng: &mut R,
) -> LoopResult
where
    H: Fn(&L::Op, Leg, Leg) -> f64,
{
    // Scratch buffers borrowed from the factory; handed back with `dissolve` below.
    let mut legs = StackTuplizer::<Leg, f64>::new(l);
    let sel_opnode = l.get_node_mut(sel_op_pos).unwrap();
    let sel_op = sel_opnode.get_op();

    let inputs_legs = (0..sel_op.get_vars().len()).map(|v| (v, OpSide::Inputs));
    let outputs_legs = (0..sel_op.get_vars().len()).map(|v| (v, OpSide::Outputs));

    // Weight of leaving through each leg, given where we entered.
    legs.extend(
        inputs_legs
            .chain(outputs_legs)
            .map(|leg| (leg, h(&sel_op, entrance_leg, leg))),
    );

    // Sample an exit leg proportionally to its weight.
    let total_weight: f64 = legs.iter().map(|(_, w)| *w).sum();
    let choice = rng.gen_range(0.0, total_weight);
    let exit_leg = legs
        .iter()
        .try_fold(choice, |c, (leg, weight)| {
            if c < *weight {
                Err(*leg)
            } else {
                Ok(c - *weight)
            }
        })
        .unwrap_err();

    // Change the op now that we passed through: the weights above were computed with both
    // the entrance and the exit leg flipped, so the stored op must receive both flips
    // (a bounce, where exit == entrance, leaves the op unchanged).
    sel_opnode.get_op_mut().edit_in_out(|ins, outs| {
        // Reborrow so the same buffers can be adjusted a second time.
        adjust_states(&mut *ins, &mut *outs, entrance_leg);
        adjust_states(ins, outs, exit_leg);
    });

    // The candidate legs are no longer needed; return the buffers on every exit path.
    legs.dissolve(l);

    // Check if we closed the loop before going to next opnode.
    if (sel_op_pos, exit_leg) == initial_op_and_leg {
        return LoopResult::Return;
    }

    // No longer need mutability.
    let sel_opnode = l.get_node_ref(sel_op_pos).unwrap();
    let sel_op = sel_opnode.get_op_ref();

    // Get the next opnode and entrance leg, let us know if it changes the initial/final.
    let PRel {
        p: next_p,
        relv: next_rel,
    } = match exit_leg {
        (var, OpSide::Outputs) => {
            let next_var_op = l.get_next_p_for_rel_var(var, sel_opnode);
            next_var_op.unwrap_or_else(|| {
                // Adjust the state to reflect new output.
                state[sel_op.get_vars()[var]] = sel_op.get_outputs()[var];
                l.get_first_p_for_var(sel_op.get_vars()[var]).unwrap()
            })
        }
        (var, OpSide::Inputs) => {
            let prev_var_op = l.get_previous_p_for_rel_var(var, sel_opnode);
            prev_var_op.unwrap_or_else(|| {
                // Adjust the state to reflect new input.
                state[sel_op.get_vars()[var]] = sel_op.get_inputs()[var];
                l.get_last_p_for_var(sel_op.get_vars()[var]).unwrap()
            })
        }
    };
    // Leaving through an output means entering the neighbour through an input, and vice versa.
    let new_entrance_leg = (next_rel, exit_leg.1.reverse());

    // If back where we started, close loop and return state changes.
    if (next_p, new_entrance_leg) == initial_op_and_leg {
        LoopResult::Return
    } else {
        LoopResult::Iterate(next_p, new_entrance_leg)
    }
}