extern crate self as qip;
use crate::errors::CircuitError;
use crate::inverter::RecursiveCircuitBuilder;
use crate::macros::program_ops::*;
use crate::prelude::*;
use qip_macros::*;
use std::num::NonZeroUsize;
/// Expands to a tuple *type* holding one `<B as CircuitBuilder>::Register`
/// element per builder identifier passed in.
///
/// For example `register_tuple!(CB, CB)` expands to
/// `(<CB as CircuitBuilder>::Register, <CB as CircuitBuilder>::Register,)`.
/// Every element is followed by a comma, so a single identifier still yields a
/// one-element tuple rather than a parenthesised type.
macro_rules! register_tuple {
    ($($builder:ident),*) => {
        ( $( <$builder as CircuitBuilder>::Register, )* )
    };
}
/// Two registers of builder `CB`.
type R2<CB> = register_tuple!(CB, CB);
/// Three registers of builder `CB`.
type R3<CB> = register_tuple!(CB, CB, CB);
/// Four registers of builder `CB`.
type R4<CB> = register_tuple!(CB, CB, CB, CB);
/// Five registers of builder `CB`.
type R5<CB> = register_tuple!(CB, CB, CB, CB, CB);
/// Ripple-carry addition built from the `carry` and `sum` gadgets.
///
/// Expects `rc` and `ra` to have the same width `n` and `rb` to have width
/// `n + 1`. With `rc` and `rb[n]` starting zeroed (NOTE(review): confirm
/// callers guarantee this), `rb` ends holding `ra + rb`, and `rc` is restored.
/// The `n == 1` case is the base of the recursion. Wider registers peel off bit 0,
/// recurse on bits `1..n`, then uncompute the bit-0 carry.
///
/// # Errors
/// Returns a `CircuitError` when the register widths do not match
/// `(n, n, n + 1)`, or when any nested gadget fails.
#[invert]
pub fn add<P: Precision, CB: RecursiveCircuitBuilder<P>>(
    b: &mut CB,
    rc: CB::Register,
    ra: CB::Register,
    rb: CB::Register,
) -> CircuitResult<R3<CB>> {
    let shape = (rc.n(), ra.n(), rb.n());
    match shape {
        // Single-bit base case: the carry lands directly in the top bit of `rb`.
        (1, 1, 2) => program!(&mut *b; rc, ra, rb;
            carry rc, ra, rb[0], rb[1];
            sum rc, ra, rb[0];
        ),
        // General case: carry out of bit 0 into rc[1], add the upper bits,
        // then uncompute rc[1] before writing the bit-0 sum.
        (n, na, nb) if na == n && nb == n + 1 => program!(&mut *b; rc, ra, rb;
            carry rc[0], ra[0], rb[0], rc[1];
            add rc[1..n], ra[1..n], rb[1..=n];
            carry_inv rc[0], ra[0], rb[0], rc[1];
            sum rc[0], ra[0], rb[0];
        ),
        (nc, na, nb) => Err(CircuitError::new(format!(
            "Expected rc[n] ra[n] and rb[n+1], but got ({},{},{})",
            nc, na, nb
        ))),
    }
}
/// Sum gadget of the ripple-carry adder.
///
/// Applies two CNOTs targeting `rb`, controlled first by `ra` and then by
/// `rc`. For single qubits this leaves `rb = ra XOR rb XOR rc`. `ra` and `rc`
/// are not modified. The gadget is its own inverse.
fn sum<P: Precision, CB: RecursiveCircuitBuilder<P>>(
    b: &mut CB,
    rc: CB::Register,
    ra: CB::Register,
    rb: CB::Register,
) -> CircuitResult<R3<CB>> {
    let (rc, ra, rb) = program!(&mut *b; rc, ra, rb;
        control x ra, rb;
        control x rc, rb;
    )?;
    Ok((rc, ra, rb))
}
/// Carry gadget of the ripple-carry adder.
///
/// For single qubits `c = rc`, `a = ra`, `b = rb`, this XORs
/// `a·b ⊕ c·(a ⊕ b)` (the majority of `a`, `b`, `c`) into `rcp`. The
/// trailing CNOT undoes the intermediate `rb ^= ra`, so `rb` comes back
/// unchanged. `#[invert]` also generates `carry_inv`.
#[invert]
fn carry<P: Precision, CB: RecursiveCircuitBuilder<P>>(
    b: &mut CB,
    rc: CB::Register,
    ra: CB::Register,
    rb: CB::Register,
    rcp: CB::Register,
) -> CircuitResult<R4<CB>> {
    program!(&mut *b; rc, ra, rb, rcp;
        control x [ra, rb] rcp;
        control x ra, rb;
        control x [rc, rb] rcp;
        control x ra, rb;
    )
}
/// Modular addition: `rb <- (ra + rb) mod rm`.
///
/// `ra` and `rm` must both have width `n`, and `rb` must have width `n + 1`.
/// The extra top bit acts as the sign/overflow bit. One zeroed temporary qubit
/// `rt` records whether `ra + rb - rm` underflowed. An `n`-qubit zeroed
/// carry register `rc` feeds the inner adders. Both are returned to the
/// builder zeroed.
/// NOTE(review): the construction presumably assumes `ra, rb < rm` on
/// input — confirm against callers.
///
/// # Errors
/// Returns a `CircuitError` when the widths are inconsistent, or when a
/// nested adder fails.
#[invert]
pub fn add_mod<P: Precision, CB: RecursiveCircuitBuilder<P>>(
    b: &mut CB,
    ra: CB::Register,
    rb: CB::Register,
    rm: CB::Register,
) -> CircuitResult<R3<CB>> {
    if ra.n() != rm.n() {
        Err(CircuitError::new(format!(
            "Expected rm.n == ra.n == {}, found rm.n={}.",
            ra.n(),
            rm.n()
        )))
    } else if rb.n() != ra.n() + 1 {
        // Report the register that actually failed the check (rb, not rm).
        Err(CircuitError::new(format!(
            "Expected rb.n == ra.n + 1 == {}, found rb.n={}.",
            ra.n() + 1,
            rb.n()
        )))
    } else {
        let n = ra.n();
        let rt = b.make_zeroed_temp_qubit();
        let rc = b.make_zeroed_temp_register(NonZeroUsize::new(n).unwrap());
        let (ra, rb, rm, rt, rc) = program!(&mut *b; ra, rb, rm, rt, rc;
            // rb = ra + rb - rm; the top bit of rb is set on underflow.
            add rc, ra, rb;
            add_inv rc, rm, rb;
            control x rb[n], rt;
            // Add rm back only if we underflowed.
            control add rt, rc, rm, rb;
            // Uncompute rt: after subtracting ra, the top bit is clear exactly
            // when rt was set.
            add_inv rc, ra, rb;
            control(0) x rb[n], rt;
            add rc, ra, rb;
        )?;
        b.return_zeroed_temp_register(rt);
        b.return_zeroed_temp_register(rc);
        Ok((ra, rb, rm))
    }
}
/// Modular multiply-accumulate: `rp <- (rp + ra * rb) mod rm`.
///
/// `rm` has width `n`, `ra` and `rp` have width `n + 1`, and the multiplier
/// `rb` has any width `k`. For each multiplier bit `i`, the loop does the following:
/// 1. Reduces `ra` modulo `rm`, recording the underflow flag in `rt[i]`.
/// 2. Conditionally adds `ra` into `rp` (mod `rm`) when `rb[i]` is set.
/// 3. Doubles `ra` by rotating it one position.
///
/// The second loop runs the bookkeeping in reverse to restore `ra` and zero
/// `rt`. The `add_mod` contributions into `rp` are intentionally not undone.
/// NOTE(review): presumably requires `ra < rm` and `ra[n] == 0` on input —
/// confirm against callers.
///
/// # Errors
/// Returns a `CircuitError` on inconsistent widths or a failing sub-circuit.
#[invert]
pub fn times_mod<P: Precision, CB: RecursiveCircuitBuilder<P>>(
    b: &mut CB,
    ra: CB::Register,
    rb: CB::Register,
    rm: CB::Register,
    rp: CB::Register,
) -> CircuitResult<R4<CB>> {
    let n = rm.n();
    let k = rb.n();
    if ra.n() != n + 1 {
        Err(CircuitError::new(format!(
            "Expected ra.n = rm.n + 1 = {}, but found {}",
            n + 1,
            ra.n()
        )))
    } else if rp.n() != n + 1 {
        Err(CircuitError::new(format!(
            "Expected rp.n = rm.n + 1 = {}, but found {}",
            n + 1,
            rp.n()
        )))
    } else {
        // One underflow flag per multiplier bit, plus carry space for the adders.
        let flags = b.make_zeroed_temp_register(NonZeroUsize::new(k).unwrap());
        let carries = b.make_zeroed_temp_register(NonZeroUsize::new(n).unwrap());
        let mut regs = (ra, rb, rm, rp, flags, carries);

        // Forward pass: reduce, conditionally accumulate, double.
        for indx in 0..k {
            let (ra, rb, rm, rp, rt, rc) = regs;
            regs = program!(&mut *b; ra, rb, rm, rp, rt, rc;
                add_inv rc, rm, ra;
                control x ra[n], rt[indx];
                control add rt[indx], rc, rm, ra;
                control add_mod rb[indx], ra[0 .. n], rp, rm;
                rshift ra;
            )?;
        }

        // Backward pass: undo the doubling/reduction so `ra` and `rt` are restored.
        for indx in (0..k).rev() {
            let (ra, rb, rm, rp, rt, rc) = regs;
            let (ra, rm, rt, rc) = program!(&mut *b; ra, rm, rt, rc;
                lshift ra;
                control add_inv rt[indx], rc, rm, ra;
                control x ra[n], rt[indx];
                add rc, rm, ra;
            )?;
            regs = (ra, rb, rm, rp, rt, rc);
        }

        let (ra, rb, rm, rp, flags, carries) = regs;
        b.return_zeroed_temp_register(carries);
        b.return_zeroed_temp_register(flags);
        Ok((ra, rb, rm, rp))
    }
}
/// Cyclically rotates the qubits of `r` by one position toward higher
/// indices. The qubit at index `i` ends at index `(i + 1) mod n`.
///
/// The rotation is a chain of swaps `(n-2, n-3), …, (1, 0), (0, n-1)`.
/// `#[invert(lshift)]` generates the opposite rotation as `lshift`.
///
/// # Errors
/// Propagates any `CircuitError` from the underlying swaps instead of
/// panicking.
#[invert(lshift)]
pub fn rshift<P: Precision, CB: RecursiveCircuitBuilder<P>>(
    b: &mut CB,
    r: CB::Register,
) -> CircuitResult<CB::Register> {
    let n = r.n();
    let mut qubits: Vec<Option<CB::Register>> =
        b.split_all_register(r).into_iter().map(Some).collect();
    // Registers are non-empty, so `n - 1` cannot underflow; for n == 1 the
    // range is empty and the register is returned unchanged.
    for indx in (0..n - 1).rev() {
        // Partner one position below, with index 0 wrapping around to n - 1.
        let other = (indx + n - 1) % n;
        let qa = qubits[indx]
            .take()
            .expect("each qubit slot is refilled after its swap");
        let qb = qubits[other]
            .take()
            .expect("each qubit slot is refilled after its swap");
        let (qa, qb) = b.swap(qa, qb)?;
        qubits[indx] = Some(qa);
        qubits[other] = Some(qb);
    }
    Ok(b
        .merge_registers(qubits.into_iter().flatten())
        .expect("a non-empty set of qubits always merges into a register"))
}
/// XORs register `ra` into register `rb` qubit by qubit (a CNOT per pair).
/// If `rb` starts zeroed, it ends up as a copy of `ra` in the computational basis.
///
/// # Errors
/// Returns a `CircuitError` if the registers differ in width, or if any CNOT
/// fails.
#[invert]
pub fn copy<P: Precision, CB: RecursiveCircuitBuilder<P>>(
    b: &mut CB,
    ra: CB::Register,
    rb: CB::Register,
) -> CircuitResult<R2<CB>> {
    if ra.n() == rb.n() {
        let width = ra.n();
        let sources = b.split_all_register(ra);
        let targets = b.split_all_register(rb);
        let mut out_a = Vec::with_capacity(width);
        let mut out_b = Vec::with_capacity(width);
        for (src, dst) in sources.into_iter().zip(targets) {
            let (src, dst) = b.cnot(src, dst)?;
            out_a.push(src);
            out_b.push(dst);
        }
        let ra = b.merge_registers(out_a).unwrap();
        let rb = b.merge_registers(out_b).unwrap();
        Ok((ra, rb))
    } else {
        Err(CircuitError::new(format!(
            "Expected ra.n = rb.n, but found {} and {}",
            ra.n(),
            rb.n()
        )))
    }
}
/// Modular squaring-accumulate: `rs <- (rs + ra * ra) mod rm`.
///
/// The low `n` bits of `ra` are copied into a zeroed scratch register. That
/// copy serves as the multiplier for `times_mod`, and the copy is then undone so
/// the scratch register is returned zeroed. `rm` has width `n`, while `ra` and `rs` have width `n + 1`.
///
/// # Errors
/// Returns a `CircuitError` on inconsistent widths or a failing sub-circuit.
#[invert]
pub fn square_mod<P: Precision, CB: RecursiveCircuitBuilder<P>>(
    b: &mut CB,
    ra: CB::Register,
    rm: CB::Register,
    rs: CB::Register,
) -> CircuitResult<R3<CB>> {
    let n = rm.n();
    let wide = n + 1;
    match (ra.n(), rs.n()) {
        (na, _) if na != wide => Err(CircuitError::new(format!(
            "Expected ra.n = rm.n + 1 = {}, but found {}",
            wide, na
        ))),
        (_, ns) if ns != wide => Err(CircuitError::new(format!(
            "Expected rs.n = rm.n + 1 = {}, but found {}",
            wide, ns
        ))),
        _ => {
            let scratch = b.make_zeroed_temp_register(NonZeroUsize::new(n).unwrap());
            let (ra, rm, rs, scratch) = program!(&mut *b; ra, rm, rs, scratch;
                copy ra[0 .. n], scratch;
                times_mod ra, scratch, rm, rs;
                copy_inv ra[0 .. n], scratch;
            )?;
            b.return_zeroed_temp_register(scratch);
            Ok((ra, rm, rs))
        }
    }
}
/// Modular exponentiation by repeated squaring: `re <- rp * ra^rb mod rm`.
///
/// `rm` has width `n`, and `ra`, `rp`, `re` have width `n + 1`. `rb` is the
/// `k`-bit exponent. The result is written into `re`, which presumably must
/// start zeroed (NOTE(review): confirm against callers).
///
/// With one exponent bit, `re` receives a copy of `rp` when `rb[0] == 0` and
/// `ra * rp mod rm` when `rb[0] == 1`. With more bits, the scratch register
/// `rv` is set to `rp * ra^rb[0]`, `ru` is set to `ra^2 mod rm`, and the function recurses on
/// `rb[1..k]` with base `ru` and multiplicand `rv`. Both scratch registers are
/// then uncomputed and returned zeroed.
///
/// # Errors
/// Returns a `CircuitError` on inconsistent widths or a failing sub-circuit.
#[invert]
pub fn exp_mod<P: Precision, CB: RecursiveCircuitBuilder<P>>(
    b: &mut CB,
    ra: CB::Register,
    rb: CB::Register,
    rm: CB::Register,
    rp: CB::Register,
    re: CB::Register,
) -> CircuitResult<R5<CB>> {
    let n = rm.n();
    let k = rb.n();
    if ra.n() != n + 1 {
        Err(CircuitError::new(format!(
            "Expected ra.n = rm.n + 1 = {}, but found {}",
            n + 1,
            ra.n()
        )))
    } else if rp.n() != n + 1 {
        Err(CircuitError::new(format!(
            "Expected rp.n = rm.n + 1 = {}, but found {}",
            n + 1,
            rp.n()
        )))
    } else if re.n() != n + 1 {
        Err(CircuitError::new(format!(
            "Expected re.n = rm.n + 1 = {}, but found {}",
            n + 1,
            re.n()
        )))
    } else if k == 1 {
        program!(&mut *b; ra, rb, rm, rp, re;
            control(0) copy rb[0], rp, re;
            control times_mod rb[0], ra, rp, rm, re;
        )
    } else {
        let ru = b.make_zeroed_temp_register(NonZeroUsize::new(n + 1).unwrap());
        let rv = b.make_zeroed_temp_register(NonZeroUsize::new(n + 1).unwrap());
        let (ra, rb, rm, rp, re, ru, rv) = program!(&mut *b; ra, rb, rm, rp, re, ru, rv;
            // rv = rp * ra^rb[0] (mod rm): a plain copy or a product, never both.
            control(0) copy rb[0], rp, rv;
            control times_mod rb[0], ra, rp, rm, rv;
            // ru = ra^2 (mod rm), the base for the remaining exponent bits.
            square_mod ra, rm, ru;
            // re = rv * ru^(rb >> 1) (mod rm).
            exp_mod ru, rb[1 .. k], rm, rv, re;
            // Uncompute the scratch registers in reverse order.
            square_mod_inv ra, rm, ru;
            control times_mod_inv rb[0], ra, rp, rm, rv;
            control(0) copy_inv rb[0], rp, rv;
        )?;
        b.return_zeroed_temp_register(ru);
        b.return_zeroed_temp_register(rv);
        Ok((ra, rb, rm, rp, re))
    }
}
#[cfg(test)]
mod arithmetic_tests {
    // NOTE(review): this test module is empty. None of the arithmetic circuits
    // above (add, add_mod, times_mod, square_mod, exp_mod, rshift, copy) are
    // exercised here yet.
}