lay_simulator_gk/
lib.rs

1use std::fmt::Debug;
2
3use rand_core::{RngCore, SeedableRng};
4use rand_xorshift::XorShiftRng;
5use lay::{Layer, gates::{PauliGate, HGate, SGate, CXGate}, operations::{opid, OpArgs}};
6
7mod bitarray;
8pub use bitarray::BitArray;
9
10mod fakerng;
11pub use fakerng::RepeatSeqFakeRng;
12
13pub type DefaultRng = XorShiftRng;
14
/// State of a Gottesman-Knill style simulator, generic over its random
/// number source `Rng`.
///
/// The state is a table with one row per qubit (see `from_rng`): row `i` is
/// described by `xs[i]`, `zs[i]` and bit `i` of `sgns`.
#[derive(Debug)]
pub struct GottesmanKnillSimulator<Rng> {
    /// Per-row X bits; `xs[i].get_bool(q)` is the X bit of row `i` at qubit `q`.
    xs: Vec<BitArray>,
    /// Per-row Z bits; `zs[i].get_bool(q)` is the Z bit of row `i` at qubit `q`.
    zs: Vec<BitArray>,
    /// One sign bit per row, toggled by the gate methods.
    sgns: BitArray,
    /// Measurement results, indexed by the slot passed to `measure`.
    measured: BitArray,
    /// Random number source consulted when a measurement outcome is not fixed.
    rng: Rng,
}
23
// Gate-set marker impls: the bodies are empty, so nothing is overridden here
// and whatever the `lay::gates` traits provide by default is used as-is.
impl<Rng: RngCore + Debug> PauliGate for GottesmanKnillSimulator<Rng> {}
impl<Rng: RngCore + Debug> HGate for GottesmanKnillSimulator<Rng> {}
impl<Rng: RngCore + Debug> SGate for GottesmanKnillSimulator<Rng> {}
impl<Rng: RngCore + Debug> CXGate for GottesmanKnillSimulator<Rng> {}
28
29impl GottesmanKnillSimulator<DefaultRng> {
30    pub fn from_seed(n: u32, seed: u64) -> Self {
31        Self::from_rng(n, DefaultRng::seed_from_u64(seed))
32    }
33}
34
35impl<Rng: RngCore> GottesmanKnillSimulator<Rng> {
36    pub fn from_rng(n: u32, rng: Rng) -> Self {
37        let xs = (0..n).map(|_| BitArray::zeros(n as usize)).collect();
38        let zs = (0..n).map(|i| {
39            let mut arr = BitArray::zeros(n as usize);
40            arr.negate(i as usize);
41            arr
42        }).collect();
43        let sgns = BitArray::zeros(n as usize);
44        let measured = BitArray::zeros(n as usize);
45        Self { xs, zs, sgns, measured, rng }
46    }
47}
48
49impl<Rng> GottesmanKnillSimulator<Rng> {
50    pub fn dump_print(&self) {
51        println!("xs:   {:?}", self.xs);
52        println!("zs:   {:?}", self.zs);
53        println!("sgns: {:?}", self.sgns);
54        println!("measured: {:?}", self.measured);
55    }
56    pub fn n_qubits(&self) -> u32 {
57        self.xs.len() as _
58    }
59}
60
61impl<Rng: RngCore + Debug> Layer for GottesmanKnillSimulator<Rng> {
62    type Operation = OpArgs<Self>;
63    type Qubit = u32;
64    type Slot = u32;
65    type Buffer = BitArray;
66    type Requested = ();
67    type Response = ();
68
69    fn send(&mut self, ops: &[OpArgs<Self>]) {
70        for op in ops.iter() {
71            match op {
72                OpArgs::Empty(id) if *id == opid::INIT =>
73                    self.initialize(),
74                OpArgs::Q(id, q) => {
75                    match *id {
76                        opid::X => self.x(*q),
77                        opid::Y => self.y(*q),
78                        opid::Z => self.z(*q),
79                        opid::H => self.h(*q),
80                        opid::S => self.s(*q),
81                        opid::SDG => self.sdg(*q),
82                        _ => unimplemented!("Unexpected opid {:?}", *op)
83                    }
84                },
85                OpArgs::QS(id, q, s) if *id == opid::MEAS =>
86                    self.measure(*q, *s),
87                OpArgs::QQ(id, c, t) if *id == opid::CX =>
88                    self.cx(*c, *t),
89                _ => unimplemented!("Unexpected op {:?}", *op)
90            }
91        }
92    }
93
94    fn receive(&mut self, buf: &mut BitArray) {
95        buf.copy_from(&self.measured);
96    }
97
98    fn send_receive(&mut self, ops: &[OpArgs<Self>], buf: &mut BitArray) {
99        self.send(ops);
100        self.receive(buf);
101    }
102
103    fn make_buffer(&self) -> Self::Buffer {
104        BitArray::zeros(self.measured.len())
105    }
106}
107
108impl<Rng: RngCore> GottesmanKnillSimulator<Rng> {
109    fn initialize(&mut self) {
110        self.xs.iter_mut().for_each(|a| a.reset());
111        self.zs.iter_mut().for_each(|a| a.reset());
112        self.zs.iter_mut().enumerate().for_each(|(i, a)| a.negate(i as usize));
113        self.sgns.reset();
114        self.measured.reset();
115    }
116
117    fn measure(&mut self, q: u32, ch: u32) {
118        let bit = measure(self, q);
119        self.measured.set_bool(ch as usize, bit);
120    }
121
122    #[inline]
123    fn x(&mut self, q: u32) {
124        for (i, _) in self.zs.iter().enumerate()
125                                    .filter(|(_, zs)| zs.get_bool(q as usize)) {
126            self.sgns.negate(i as usize);
127        }
128    }
129
130    #[inline]
131    fn y(&mut self, q: u32) {
132        for (i, _) in  self.xs.iter().zip(self.zs.iter())
133                           .enumerate()
134                           .filter(|(_, (xs, zs))| (xs.get_masked(q as usize) ^ zs.get_masked(q as usize)) != 0) {
135            self.sgns.negate(i as usize);
136         }
137    }
138
139    #[inline]
140    fn z(&mut self, q: u32) {
141        for (i, _) in self.xs.iter().enumerate()
142                                    .filter(|(_, xs)| xs.get_bool(q as usize)) {
143            self.sgns.negate(i as usize);
144        }
145    }
146
147    #[inline]
148    fn h(&mut self, q: u32) {
149        for (i, (xs, zs)) in self.xs.iter_mut().zip(self.zs.iter_mut()).enumerate() {
150            let x = xs.get_bool(q as usize);
151            let z = zs.get_bool(q as usize);
152            if x && z {
153                self.sgns.negate(i);
154            } else if x || z {
155                xs.negate(q as usize);
156                zs.negate(q as usize);
157            }
158         }
159    }
160
161    #[inline]
162    fn s(&mut self, q: u32) {
163        for (i, (xs, zs)) in self.xs.iter().zip(self.zs.iter_mut())
164                                           .enumerate() {
165            if xs.get_bool(q as usize) {
166                if zs.get_bool(q as usize) {
167                    self.sgns.negate(i as usize);
168                }
169                zs.negate(q as usize);
170            }
171         }
172    }
173
174    #[inline]
175    fn sdg(&mut self, q: u32) {
176        for (i, (xs, zs)) in self.xs.iter().zip(self.zs.iter_mut())
177                                           .enumerate() {
178            if xs.get_bool(q as usize) {
179                if !zs.get_bool(q as usize) {
180                    self.sgns.negate(i as usize);
181                }
182                zs.negate(q as usize);
183            }
184         }
185    }
186
187    #[inline]
188    fn cx(&mut self, c: u32, t: u32) {
189        for (i, (xs, zs)) in self.xs.iter_mut()
190                                 .zip(self.zs.iter_mut())
191                                 .enumerate() {
192            if xs.get_bool(c as usize) {
193                xs.negate(t as usize);
194                if zs.get_bool(c as usize) {
195                    self.sgns.negate(i as usize);
196                }
197            }
198            if zs.get_bool(t as usize) {
199                zs.negate(c as usize);
200            }
201        }
202    }
203}
204
205fn mult_to<Rng>(gk: &mut GottesmanKnillSimulator<Rng>, dest: usize, src: usize) {
206    assert_ne!(dest, src);
207    let from = unsafe { &*(&gk.xs[src] as *const _) };
208    let into = &mut gk.xs[dest];
209    into.xor_all(&*from);
210    let from = unsafe { &*(&gk.zs[src] as *const _) };
211    let into = &mut gk.zs[dest];
212    into.xor_all(&*from);
213    gk.sgns.set_bool(dest, gk.sgns.get_bool(src));
214}
215
216fn measure<Rng: RngCore>(gk: &mut GottesmanKnillSimulator<Rng>, q: u32) -> bool {
217    let noncommutatives: Vec<_> = gk.xs.iter().map(|a| a.get_bool(q as usize))
218                                              .enumerate()
219                                              .filter(|(_, b)| *b)
220                                              .map(|(i, _)| i)
221                                              .collect();
222    if noncommutatives.is_empty() {
223        //eprintln!("stabilized pattern");
224        let n_qubits = gk.n_qubits() as usize;
225        let mut indices: Vec<_> = (0..n_qubits).collect();
226        for i in 0..n_qubits as usize {
227            let x_inds: Vec<_> = indices.iter().enumerate().filter(|(_, &k)| gk.xs[k].get_bool(i)).map(|(i, _)| i).collect();
228            if !x_inds.is_empty() {
229                let xs0 = unsafe { &*(&gk.xs[indices[x_inds[0]]] as *const _) };
230                let zs0 = unsafe { &*(&gk.zs[indices[x_inds[0]]] as *const _) };
231                let sg0 = gk.sgns.get_bool(indices[x_inds[0]]);
232                for j in x_inds[1..].iter() {
233                    gk.xs[indices[*j]].xor_all(&xs0);
234                    gk.zs[indices[*j]].xor_all(&zs0);
235                    if sg0 {
236                        gk.sgns.negate(indices[*j]);
237                    }
238                }
239                indices.swap_remove(x_inds[0]);
240            }
241        }
242        for i in 0..n_qubits as usize {
243            if i == q as usize { continue }
244            let z_inds: Vec<_> = indices.iter().enumerate().filter(|(_, &k)| gk.zs[k].get_bool(i)).map(|(i, _)| i).collect();
245            if !z_inds.is_empty() {
246                let xs0 = unsafe { &*(&gk.xs[indices[z_inds[0]]] as *const _) };
247                let zs0 = unsafe { &*(&gk.zs[indices[z_inds[0]]] as *const _) };
248                let sg0 = gk.sgns.get_bool(indices[z_inds[0]]);
249                for j in z_inds[1..].iter() {
250                    gk.xs[indices[*j]].xor_all(&xs0);
251                    gk.zs[indices[*j]].xor_all(&zs0);
252                    if sg0 {
253                        gk.sgns.negate(indices[*j]);
254                    }
255                }
256                indices.swap_remove(z_inds[0]);
257            }
258        }
259        assert_eq!(indices.len(), 1);
260        // println!("measured xs: {:?}", gk.xs[indices[0]]);
261        // println!("measured zs: {:?}", gk.zs[indices[0]]);
262        // println!("measured sg: {:?}", gk.sgns.get_bool(indices[0]));
263        gk.sgns.get_bool(indices[0])
264    } else {
265        //eprintln!("non-stabilized pattern");
266        let i = noncommutatives[0];
267        for &j in noncommutatives[1..].iter() {
268            mult_to(gk, j, i);
269        }
270        let is_one = (gk.rng.next_u32() & 1) != 0;
271        gk.xs[noncommutatives[0]].reset();
272        gk.zs[noncommutatives[0]].reset();
273        gk.zs[noncommutatives[0]].negate(q as usize);
274        gk.sgns.set_bool(noncommutatives[0], is_one);
275        is_one
276    }
277}
278
#[cfg(test)]
mod tests {
    // Not every import below is used by the current tests.
    #![allow(unused_imports)]
    use crate::{GottesmanKnillSimulator, BitArray, DefaultRng, RepeatSeqFakeRng};
    use rand_core::{RngCore, SeedableRng};
    use rand_xorshift::XorShiftRng;
    use lay::{Layer, OpsVec, Measured};
    use tokio::{prelude::*, runtime::Runtime};

    /// Smoke test: a simulator can be constructed.
    #[test]
    fn it_works() {
        assert_eq!(2 + 2, 4);
        let _ = GottesmanKnillSimulator::from_seed(3, 0);
    }

    /// Builds a circuit with `f` on `expect.len()` qubits, runs it on a
    /// simulator seeded with 0 and compares slot `i` of the result with
    /// `expect[i]`.
    fn check(f: impl Fn(&mut OpsVec<GottesmanKnillSimulator<DefaultRng>>, u32), expect: &[u32]) {
        let n_qubits = expect.len() as u32;
        let mut ops = OpsVec::new();
        f(&mut ops, n_qubits);

        let mut sim = GottesmanKnillSimulator::from_seed(n_qubits, 0);
        // Size the receive buffer from the simulator instead of relying on
        // `copy_from` to grow an empty one.
        let mut result = sim.make_buffer();
        sim.send_receive(ops.as_ref(), &mut result);

        let actual: Vec<_> = (0..expect.len()).map(|i| result.get_bool(i) as u32).collect();
        assert_eq!(actual.as_slice(), expect);
    }

    /// Same as [`check`], but random measurement outcomes are taken from the
    /// fixed sequence `seq` so the expected bits are reproducible.
    fn check_with_randseq(f: impl Fn(&mut OpsVec<GottesmanKnillSimulator<RepeatSeqFakeRng>>, u32),
                          expect: &[u32],
                          seq: Vec<u64>) {
        let n_qubits = expect.len() as u32;
        let mut ops = OpsVec::new();
        f(&mut ops, n_qubits);

        let mut sim = GottesmanKnillSimulator::from_rng(n_qubits, RepeatSeqFakeRng::new(seq));
        let mut result = sim.make_buffer();
        sim.send_receive(ops.as_ref(), &mut result);

        let actual: Vec<_> = (0..expect.len()).map(|i| result.get_bool(i) as u32).collect();
        assert_eq!(actual.as_slice(), expect);
    }

    // TODO: cross-check stabilized states against the Blueqat simulator
    // (a `check_stabilized` helper was sketched for this but never implemented).

    #[test]
    fn test_zgate1() {
        check(|gk, n_qubits| {
            gk.z(0);
            for i in 0..n_qubits {
                gk.measure(i, i);
            }
        }, &[0]);
    }

    #[test]
    fn test_xgate1() {
        check(|gk, n_qubits| {
            gk.x(0);
            for i in 0..n_qubits {
                gk.measure(i, i);
            }
        }, &[1]);
    }

    #[test]
    fn test_xgate2() {
        check(|gk, n_qubits| {
            gk.x(0);
            gk.x(3);
            gk.z(2);
            gk.x(6);
            for i in 0..n_qubits {
                gk.measure(i, i);
            }
        }, &[1, 0, 0, 1, 0, 0, 1]);
    }

    #[test]
    fn test_cx1() {
        check(|gk, n_qubits| {
            gk.cx(0, 1);
            for i in 0..n_qubits {
                gk.measure(i, i);
            }
        }, &[0, 0]);
    }

    #[test]
    fn test_cx2() {
        check(|gk, n_qubits| {
            gk.x(1);
            gk.cx(0, 1);
            for i in 0..n_qubits {
                gk.measure(i, i);
            }
        }, &[0, 1]);
    }

    #[test]
    fn test_cx3() {
        check(|gk, n_qubits| {
            gk.x(0);
            gk.cx(0, 1);
            for i in 0..n_qubits {
                gk.measure(i, i);
            }
        }, &[1, 1]);
    }

    #[test]
    fn test_cx4() {
        check(|gk, n_qubits| {
            gk.x(0);
            gk.cx(0, 1);
            gk.cx(1, 2);
            gk.cx(2, 0);
            for i in 0..n_qubits {
                gk.measure(i, i);
            }
        }, &[0, 1, 1]);
    }

    #[test]
    fn test_h_and_z() {
        check(|gk, n_qubits| {
            gk.h(0);
            gk.z(0);
            gk.h(0);
            gk.x(1);
            gk.h(1);
            gk.h(1);
            for i in 0..n_qubits {
                gk.measure(i, i);
            }
        }, &[1, 1]);
    }

    #[test]
    fn test_h_and_s() {
        check(|gk, n_qubits| {
            gk.h(0);
            gk.s(0);
            gk.s(0);
            gk.s(0);
            gk.s(0);
            gk.h(0);
            gk.h(1);
            gk.sdg(1);
            gk.sdg(1);
            gk.h(1);
            for i in 0..n_qubits {
                gk.measure(i, i);
            }
        }, &[0, 1]);
    }

    #[test]
    fn test_h_and_x() {
        check(|gk, n_qubits| {
            gk.h(0);
            gk.s(0);
            gk.h(0);
            gk.x(0);
            gk.h(0);
            gk.sdg(0);
            gk.h(0);
            for i in 0..n_qubits {
                gk.measure(i, i);
            }
        }, &[1]);
    }

    /// Two entangled pairs; the fake RNG fixes the first pair to 1 and the
    /// second pair to 0.
    #[test]
    fn test_hh() {
        check_with_randseq(|gk, n_qubits| {
            gk.h(0);
            gk.cx(0, 1);
            gk.h(2);
            gk.cx(2, 3);
            for i in 0..n_qubits {
                gk.measure(i, i);
            }
        }, &[1, 1, 0, 0], vec![1, 0, 0, 0]);
    }

    /// 200 qubits, each flipped and measured immediately.
    #[test]
    fn test_manyqubit1() {
        let n_qubits = 200;
        let mut sim = GottesmanKnillSimulator::from_seed(n_qubits, 0);
        let mut ops = OpsVec::<GottesmanKnillSimulator<_>>::new();
        ops.initialize();
        for i in 0..n_qubits {
            ops.x(i);
            ops.measure(i, i);
        }
        let mut buf = sim.make_buffer();
        sim.send_receive(ops.as_ref(), &mut buf);
        for i in 0..n_qubits as usize {
            assert!(buf.get_bool(i));
        }
    }

    /// 200 qubits, all flipped first and measured afterwards.
    #[test]
    fn test_manyqubit2() {
        let n_qubits = 200;
        let mut sim = GottesmanKnillSimulator::from_seed(n_qubits, 0);
        let mut ops = OpsVec::<GottesmanKnillSimulator<_>>::new();
        ops.initialize();
        for i in 0..n_qubits {
            ops.x(i);
        }
        for i in 0..n_qubits {
            ops.measure(i, i);
        }
        let mut buf = sim.make_buffer();
        sim.send_receive(ops.as_ref(), &mut buf);
        for i in 0..n_qubits as usize {
            assert!(buf.get_bool(i));
        }
    }

    #[test]
    fn test_measure() {
        let n_qubits = 5;
        let mut sim = GottesmanKnillSimulator::from_seed(n_qubits, 0);
        let mut ops = OpsVec::<GottesmanKnillSimulator<_>>::new();
        ops.initialize();
        for i in 0..n_qubits {
            ops.x(i);
            ops.measure(i, i);
        }
        let mut buf = sim.make_buffer();
        sim.send_receive(ops.as_ref(), &mut buf);
        for i in 0..5 {
            assert!(buf.get(i));
        }
    }

    #[test]
    fn test_measure2() {
        let n_qubits = 5;
        let mut sim = GottesmanKnillSimulator::from_seed(n_qubits, 0);
        let mut ops = OpsVec::<GottesmanKnillSimulator<_>>::new();
        ops.initialize();
        for i in 0..n_qubits {
            ops.x(i);
        }
        for i in 0..n_qubits {
            ops.measure(i, i);
        }
        let mut buf = sim.make_buffer();
        sim.send_receive(ops.as_ref(), &mut buf);
        for i in 0..5 {
            assert!(buf.get(i));
        }
    }

    /// Only the upper half of 14 qubits is flipped.
    #[test]
    fn test_measure3() {
        let mut sim = GottesmanKnillSimulator::from_seed(14, 0);
        let mut ops = OpsVec::<GottesmanKnillSimulator<_>>::new();
        ops.initialize();
        for i in 7..14 {
            ops.x(i);
        }
        for i in 0..14 {
            ops.measure(i, i);
        }
        let mut buf = sim.make_buffer();
        sim.send_receive(ops.as_ref(), &mut buf);
        for i in 0..7 {
            assert!(!buf.get(i));
        }
        for i in 7..14 {
            assert!(buf.get(i));
        }
    }
}