use ndarray::Array1;
use rand::Rng;
use rqism::{
circuit::Circuit, instruction::*, simulator_traits::Simulator, state_vector::QuantumStateVector,
};
/// A small parameterised-circuit classifier trained with finite-difference
/// gradient descent (see `grad` and `epoch`).
struct QNN {
/// Trainable parameters; `run` applies `theta[i]` as the `ry` angle on qubit `i`.
theta: Array1<f32>,
/// Step added to one parameter at a time when `grad` estimates the gradient
/// by forward finite differences.
epsilon: f32,
/// Number of full passes over the supplied data made by a single `epoch` call.
iter_per_epoch: usize,
/// Multiplier applied to the gradient estimate before it is subtracted from `theta`.
learning_rate: f32,
}
impl QNN {
    /// Builds a network from its initial parameters and training hyper-parameters.
    ///
    /// * `theta` - initial rotation angles, one per input feature.
    /// * `epsilon` - finite-difference step used by [`QNN::grad`].
    /// * `iter_per_epoch` - passes over the data per [`QNN::epoch`] call.
    /// * `learning_rate` - gradient-descent step size.
    fn new(theta: Array1<f32>, epsilon: f32, iter_per_epoch: usize, learning_rate: f32) -> Self {
        Self {
            theta,
            epsilon,
            iter_per_epoch,
            learning_rate,
        }
    }

    /// Squared error between `y` and `ex` (symmetric in its arguments).
    fn loss(y: &f32, ex: &f32) -> f32 {
        (*y - *ex).powi(2)
    }

    /// Evaluates the circuit for input `x` with parameters `theta` and returns
    /// the fraction of shots whose outcome key ends in `'1'`.
    ///
    /// The circuit uses `x.len() + 1` qubits: an `rx(x[i])` on each feature
    /// qubit, a chain of CNOTs over all qubits, a closing CNOT back to qubit 0,
    /// an `ry(theta[i])` on each feature qubit, a measurement of qubit 0 and a
    /// final CNOT from qubit 0 onto the last qubit.
    ///
    /// # Panics
    /// Panics if `x` is empty or if `theta` has fewer entries than `x`.
    fn run(theta: &Array1<f32>, x: &Array1<f32>) -> f32 {
        let features = x.len();
        // Without this guard `n - 2` below underflows for an empty input.
        assert!(features >= 1, "QNN::run needs at least one input feature");
        assert!(
            theta.len() >= features,
            "QNN::run needs one parameter per feature ({} < {})",
            theta.len(),
            features
        );
        let n = features + 1;

        let mut ins = Vec::with_capacity(3 * features + 3);
        // Encode the input as rx rotations.
        ins.extend(
            x.iter()
                .enumerate()
                .map(|(i, &angle)| Instruction::rx(i, angle)),
        );
        // Entangle neighbouring qubits, then close the ring back onto qubit 0.
        ins.extend((0..n - 1).map(|i| Instruction::cnot([i, i + 1])));
        ins.push(Instruction::cnot([n - 2, 0]));
        // Trainable ry rotations.
        ins.extend((0..n - 1).map(|i| Instruction::ry(i, theta[i])));
        // NOTE(review): the measurement is queued before the last CNOT; the
        // original ordering is kept as-is — confirm it is what the simulator expects.
        ins.push(Instruction::Measure { indices: vec![0] });
        ins.push(Instruction::cnot([0, n - 1]));

        let circ = Circuit { ins, n };
        let sim = QuantumStateVector::new(n);
        let counts = sim.get_counts(&circ, 100);

        // Sum every outcome ending in '1'. Looking only at the first match
        // panicked when no shot produced a '1' and ignored additional
        // matching keys.
        let ones: f32 = counts
            .data
            .iter()
            .filter(|(key, _)| key.ends_with('1'))
            .map(|(_, hits)| *hits as f32)
            .sum();
        ones / counts.shots as f32
    }

    /// Forward finite-difference estimate of the loss gradient with respect to
    /// every parameter, for a single sample `(x, y)`.
    ///
    /// The unperturbed loss does not depend on which parameter is shifted, so
    /// it is simulated once and reused for all components.
    fn grad(&self, x: &Array1<f32>, y: &f32) -> Array1<f32> {
        let base_loss = Self::loss(y, &Self::run(&self.theta, x));
        (0..self.theta.len())
            .map(|i| {
                let mut shifted = self.theta.clone();
                shifted[i] += self.epsilon;
                let shifted_loss = Self::loss(y, &Self::run(&shifted, x));
                (shifted_loss - base_loss) / self.epsilon
            })
            .collect()
    }

    /// Runs `iter_per_epoch` passes of per-sample gradient descent over
    /// `data`, given as `(inputs, targets)`; pairs beyond the shorter slice
    /// are ignored.
    fn epoch(&mut self, data: (&[Array1<f32>], &[f32])) {
        let (xs, ys) = data;
        // Mean loss of each pass. Collected for monitoring; nothing reads it yet.
        let mut losses = Vec::with_capacity(self.iter_per_epoch);
        for _ in 0..self.iter_per_epoch {
            let mut loss_iter = Vec::with_capacity(xs.len().min(ys.len()));
            for (x, y) in xs.iter().zip(ys) {
                let pr = Self::run(&self.theta, x);
                loss_iter.push(Self::loss(&pr, y));
                // Compute the step first so the in-place update needs no clone.
                let step = self.learning_rate * self.grad(x, y);
                self.theta -= &step;
            }
            // Average over the samples of this pass (not over the number of
            // passes recorded so far, which was 0 on the first pass).
            if !loss_iter.is_empty() {
                losses.push(loss_iter.iter().sum::<f32>() / loss_iter.len() as f32);
            }
        }
    }

    /// Fraction of samples whose thresholded output matches its label.
    ///
    /// An output of at least 0.5 is read as class `1.0`, anything lower as
    /// class `0.0`. Only as many pairs as the shorter of `xs`/`ys` are scored;
    /// returns `0.0` when there is nothing to score.
    fn accuracy(&self, xs: &[Array1<f32>], ys: &[f32]) -> f32 {
        let total = xs.len().min(ys.len());
        if total == 0 {
            return 0.0;
        }
        let correct = xs
            .iter()
            .zip(ys)
            .filter(|&(x, &label)| {
                let predicted = if Self::run(&self.theta, x) >= 0.5 {
                    1.0
                } else {
                    0.0
                };
                predicted == label
            })
            .count();
        correct as f32 / total as f32
    }
}
/// Trains the classifier on the two-class subset of the Iris data and prints
/// the accuracy on a held-out fifth of the samples.
fn main() {
    let mut rng = rand::thread_rng();
    let mut nn = QNN::new(
        Array1::from_vec((0..4).map(|_| rng.gen_range(0.0..1.0)).collect()),
        0.01,
        10,
        0.05,
    );
    // Iris measurements, 50 rows per class, in class order.
    let xs = [
        vec![5.1, 3.5, 1.4, 0.2],
        vec![4.9, 3.0, 1.4, 0.2],
        vec![4.7, 3.2, 1.3, 0.2],
        vec![4.6, 3.1, 1.5, 0.2],
        vec![5.0, 3.6, 1.4, 0.2],
        vec![5.4, 3.9, 1.7, 0.4],
        vec![4.6, 3.4, 1.4, 0.3],
        vec![5.0, 3.4, 1.5, 0.2],
        vec![4.4, 2.9, 1.4, 0.2],
        vec![4.9, 3.1, 1.5, 0.1],
        vec![5.4, 3.7, 1.5, 0.2],
        vec![4.8, 3.4, 1.6, 0.2],
        vec![4.8, 3.0, 1.4, 0.1],
        vec![4.3, 3.0, 1.1, 0.1],
        vec![5.8, 4.0, 1.2, 0.2],
        vec![5.7, 4.4, 1.5, 0.4],
        vec![5.4, 3.9, 1.3, 0.4],
        vec![5.1, 3.5, 1.4, 0.3],
        vec![5.7, 3.8, 1.7, 0.3],
        vec![5.1, 3.8, 1.5, 0.3],
        vec![5.4, 3.4, 1.7, 0.2],
        vec![5.1, 3.7, 1.5, 0.4],
        vec![4.6, 3.6, 1.0, 0.2],
        vec![5.1, 3.3, 1.7, 0.5],
        vec![4.8, 3.4, 1.9, 0.2],
        vec![5.0, 3.0, 1.6, 0.2],
        vec![5.0, 3.4, 1.6, 0.4],
        vec![5.2, 3.5, 1.5, 0.2],
        vec![5.2, 3.4, 1.4, 0.2],
        vec![4.7, 3.2, 1.6, 0.2],
        vec![4.8, 3.1, 1.6, 0.2],
        vec![5.4, 3.4, 1.5, 0.4],
        vec![5.2, 4.1, 1.5, 0.1],
        vec![5.5, 4.2, 1.4, 0.2],
        vec![4.9, 3.1, 1.5, 0.1],
        vec![5.0, 3.2, 1.2, 0.2],
        vec![5.5, 3.5, 1.3, 0.2],
        vec![4.9, 3.1, 1.5, 0.1],
        vec![4.4, 3.0, 1.3, 0.2],
        vec![5.1, 3.4, 1.5, 0.2],
        vec![5.0, 3.5, 1.3, 0.3],
        vec![4.5, 2.3, 1.3, 0.3],
        vec![4.4, 3.2, 1.3, 0.2],
        vec![5.0, 3.5, 1.6, 0.6],
        vec![5.1, 3.8, 1.9, 0.4],
        vec![4.8, 3.0, 1.4, 0.3],
        vec![5.1, 3.8, 1.6, 0.2],
        vec![4.6, 3.2, 1.4, 0.2],
        vec![5.3, 3.7, 1.5, 0.2],
        vec![5.0, 3.3, 1.4, 0.2],
        vec![7.0, 3.2, 4.7, 1.4],
        vec![6.4, 3.2, 4.5, 1.5],
        vec![6.9, 3.1, 4.9, 1.5],
        vec![5.5, 2.3, 4.0, 1.3],
        vec![6.5, 2.8, 4.6, 1.5],
        vec![5.7, 2.8, 4.5, 1.3],
        vec![6.3, 3.3, 4.7, 1.6],
        vec![4.9, 2.4, 3.3, 1.0],
        vec![6.6, 2.9, 4.6, 1.3],
        vec![5.2, 2.7, 3.9, 1.4],
        vec![5.0, 2.0, 3.5, 1.0],
        vec![5.9, 3.0, 4.2, 1.5],
        vec![6.0, 2.2, 4.0, 1.0],
        vec![6.1, 2.9, 4.7, 1.4],
        vec![5.6, 2.9, 3.6, 1.3],
        vec![6.7, 3.1, 4.4, 1.4],
        vec![5.6, 3.0, 4.5, 1.5],
        vec![5.8, 2.7, 4.1, 1.0],
        vec![6.2, 2.2, 4.5, 1.5],
        vec![5.6, 2.5, 3.9, 1.1],
        vec![5.9, 3.2, 4.8, 1.8],
        vec![6.1, 2.8, 4.0, 1.3],
        vec![6.3, 2.5, 4.9, 1.5],
        vec![6.1, 2.8, 4.7, 1.2],
        vec![6.4, 2.9, 4.3, 1.3],
        vec![6.6, 3.0, 4.4, 1.4],
        vec![6.8, 2.8, 4.8, 1.4],
        vec![6.7, 3.0, 5.0, 1.7],
        vec![6.0, 2.9, 4.5, 1.5],
        vec![5.7, 2.6, 3.5, 1.0],
        vec![5.5, 2.4, 3.8, 1.1],
        vec![5.5, 2.4, 3.7, 1.0],
        vec![5.8, 2.7, 3.9, 1.2],
        vec![6.0, 2.7, 5.1, 1.6],
        vec![5.4, 3.0, 4.5, 1.5],
        vec![6.0, 3.4, 4.5, 1.6],
        vec![6.7, 3.1, 4.7, 1.5],
        vec![6.3, 2.3, 4.4, 1.3],
        vec![5.6, 3.0, 4.1, 1.3],
        vec![5.5, 2.5, 4.0, 1.3],
        vec![5.5, 2.6, 4.4, 1.2],
        vec![6.1, 3.0, 4.6, 1.4],
        vec![5.8, 2.6, 4.0, 1.2],
        vec![5.0, 2.3, 3.3, 1.0],
        vec![5.6, 2.7, 4.2, 1.3],
        vec![5.7, 3.0, 4.2, 1.2],
        vec![5.7, 2.9, 4.2, 1.3],
        vec![6.2, 2.9, 4.3, 1.3],
        vec![5.1, 2.5, 3.0, 1.1],
        vec![5.7, 2.8, 4.1, 1.3],
        vec![6.3, 3.3, 6.0, 2.5],
        vec![5.8, 2.7, 5.1, 1.9],
        vec![7.1, 3.0, 5.9, 2.1],
        vec![6.3, 2.9, 5.6, 1.8],
        vec![6.5, 3.0, 5.8, 2.2],
        vec![7.6, 3.0, 6.6, 2.1],
        vec![4.9, 2.5, 4.5, 1.7],
        vec![7.3, 2.9, 6.3, 1.8],
        vec![6.7, 2.5, 5.8, 1.8],
        vec![7.2, 3.6, 6.1, 2.5],
        vec![6.5, 3.2, 5.1, 2.0],
        vec![6.4, 2.7, 5.3, 1.9],
        vec![6.8, 3.0, 5.5, 2.1],
        vec![5.7, 2.5, 5.0, 2.0],
        vec![5.8, 2.8, 5.1, 2.4],
        vec![6.4, 3.2, 5.3, 2.3],
        vec![6.5, 3.0, 5.5, 1.8],
        vec![7.7, 3.8, 6.7, 2.2],
        vec![7.7, 2.6, 6.9, 2.3],
        vec![6.0, 2.2, 5.0, 1.5],
        vec![6.9, 3.2, 5.7, 2.3],
        vec![5.6, 2.8, 4.9, 2.0],
        vec![7.7, 2.8, 6.7, 2.0],
        vec![6.3, 2.7, 4.9, 1.8],
        vec![6.7, 3.3, 5.7, 2.1],
        vec![7.2, 3.2, 6.0, 1.8],
        vec![6.2, 2.8, 4.8, 1.8],
        vec![6.1, 3.0, 4.9, 1.8],
        vec![6.4, 2.8, 5.6, 2.1],
        vec![7.2, 3.0, 5.8, 1.6],
        vec![7.4, 2.8, 6.1, 1.9],
        vec![7.9, 3.8, 6.4, 2.0],
        vec![6.4, 2.8, 5.6, 2.2],
        vec![6.3, 2.8, 5.1, 1.5],
        vec![6.1, 2.6, 5.6, 1.4],
        vec![7.7, 3.0, 6.1, 2.3],
        vec![6.3, 3.4, 5.6, 2.4],
        vec![6.4, 3.1, 5.5, 1.8],
        vec![6.0, 3.0, 4.8, 1.8],
        vec![6.9, 3.1, 5.4, 2.1],
        vec![6.7, 3.1, 5.6, 2.4],
        vec![6.9, 3.1, 5.1, 2.3],
        vec![5.8, 2.7, 5.1, 1.9],
        vec![6.8, 3.2, 5.9, 2.3],
        vec![6.7, 3.3, 5.7, 2.5],
        vec![6.7, 3.0, 5.2, 2.3],
        vec![6.3, 2.5, 5.0, 1.9],
        vec![6.5, 3.0, 5.2, 2.0],
        vec![6.2, 3.4, 5.4, 2.3],
        vec![5.9, 3.0, 5.1, 1.8],
    ]
    .into_iter()
    .map(Array1::from_vec)
    .collect::<Vec<Array1<f32>>>();
    // Labels follow the row order above: 50 x 0.0, then 50 x 1.0, then 50 x 2.0.
    let ys: Vec<f32> = [0.0, 1.0, 2.0]
        .into_iter()
        .flat_map(|label| std::iter::repeat(label).take(50))
        .collect();
    assert_eq!(xs.len(), ys.len(), "every sample needs exactly one label");

    // The model is a binary classifier (`accuracy` only recognises labels 0.0
    // and 1.0), so class 2.0 is dropped rather than being counted as an
    // automatic miss.
    let mut samples: Vec<(Array1<f32>, f32)> = xs
        .into_iter()
        .zip(ys)
        .filter(|(_, label)| *label < 2.0)
        .collect();

    // The rows are sorted by class, so a plain `split_at` would train on one
    // class and test on another. Shuffle first (Fisher-Yates, written out with
    // `gen_range` so no extra import is required).
    for i in (1..samples.len()).rev() {
        let j = rng.gen_range(0..=i);
        samples.swap(i, j);
    }
    let (xs, ys): (Vec<Array1<f32>>, Vec<f32>) = samples.into_iter().unzip();

    // 80 % of the samples for training, the remainder held out for testing.
    let n_train = xs.len() * 4 / 5;
    let (training_data, testing_data) = xs.split_at(n_train);
    let (training_res, testing_res) = ys.split_at(n_train);
    for i in 0..100 {
        println!("epoch {}", i + 1);
        nn.epoch((training_data, training_res));
    }
    println!("{}", nn.accuracy(testing_data, testing_res));
}