aad/
tape.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
use crate::operation_record::OperationRecord;
use crate::variable::Variable;
use num_traits::Zero;
use std::cell::RefCell;

/// A tape of [`OperationRecord`]s from which [`Variable`]s are created.
///
/// The records sit behind a [`RefCell`] so that new ones can be appended through a
/// shared `&Tape` (see `create_variable`, which takes `&self`). This is what lets every
/// created variable keep a plain `&Tape` while the tape keeps growing. A variable's
/// `index` is the position of its record in `operations`.
///
/// NOTE(review): the `aad/` module path suggests this is the recording tape for adjoint
/// (reverse-mode) differentiation — confirm.
#[derive(Debug)]
pub struct Tape<F> {
    /// Recorded operations; `create_variable` appends to the end.
    pub(crate) operations: RefCell<Vec<OperationRecord<F>>>,
}

impl<F> Default for Tape<F> {
    /// Returns an empty tape.
    ///
    /// Written by hand rather than derived: `#[derive(Default)]` would add an
    /// unnecessary `F: Default` bound, even though an empty tape needs no `F` value.
    #[inline]
    fn default() -> Self {
        Self {
            operations: RefCell::new(Vec::new()),
        }
    }
}

impl<F> Tape<F> {
    #[inline]
    #[must_use]
    pub const fn new() -> Self {
        Self {
            operations: RefCell::new(Vec::new()),
        }
    }

    #[inline]
    #[must_use]
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            operations: RefCell::new(Vec::with_capacity(capacity)),
        }
    }
}

impl<F: Copy + Zero> Tape<F> {
    /// Registers a new variable holding `value` on this tape.
    ///
    /// A placeholder record `[(0, zero), (0, zero)]` is appended to the tape. The returned
    /// variable's `index` is the position of that record, i.e. the tape length just before
    /// the push. On a fresh tape, indices are therefore `0, 1, 2, …` in creation order.
    ///
    /// # Panics
    ///
    /// Panics if `operations` is already borrowed when called (via `RefCell::borrow_mut`).
    #[inline]
    pub fn create_variable(&self, value: F) -> Variable<F> {
        // Scope the mutable borrow so it is released before the variable (which keeps a
        // shared reference to the tape) is handed back.
        let index = {
            let mut operations = self.operations.borrow_mut();
            let index = operations.len();
            // NOTE(review): the pairs presumably mean (operand index, partial derivative);
            // a zero weight pointing at slot 0 would then mark a leaf — confirm against
            // `OperationRecord`.
            operations.push(OperationRecord([(0, F::zero()), (0, F::zero())]));
            index
        };
        Variable {
            index,
            tape: self,
            value,
        }
    }

    /// Creates one variable per element of `values`, in order, returned as a fixed-size array.
    ///
    /// # Panics
    ///
    /// Panics under the same condition as [`Tape::create_variable`].
    #[inline]
    pub fn create_variables_as_array<const N: usize>(&self, values: &[F; N]) -> [Variable<F>; N] {
        // Grow the record buffer once up front. The temporary borrow ends with this
        // statement, before `create_variable` borrows again.
        self.operations.borrow_mut().reserve(N);
        std::array::from_fn(|i| self.create_variable(values[i]))
    }

    /// Creates one variable per element of `values`, in order, returned as a `Vec`.
    ///
    /// An empty slice yields an empty `Vec` and leaves the tape length unchanged.
    ///
    /// # Panics
    ///
    /// Panics under the same condition as [`Tape::create_variable`].
    #[inline]
    pub fn create_variables(&self, values: &[F]) -> Vec<Variable<F>> {
        // Single up-front reservation; the borrow is dropped at the end of the statement.
        self.operations.borrow_mut().reserve(values.len());
        values
            .iter()
            .copied()
            .map(|value| self.create_variable(value))
            .collect()
    }
}

#[cfg(test)]
mod tests {
    use crate::tape::Tape;

    /// Array creation yields one variable per value, bound to this tape, with indices
    /// `0..N` in creation order and exactly one tape record per variable.
    #[test]
    fn test_create_variables_as_array() {
        const N: usize = 3;
        const VALUES: [f64; N] = [1.0, 2.0, 3.0];

        let tape = Tape::new();
        let variables = tape.create_variables_as_array(&VALUES);

        assert_eq!(variables.len(), N);
        // Each created variable appends exactly one record.
        assert_eq!(tape.operations.borrow().len(), N);

        for (i, variable) in variables.iter().enumerate() {
            // On a fresh tape the index is the record position, so indices are
            // consecutive from 0 (which also makes them unique).
            assert_eq!(variable.index, i);
            assert_eq!(variable.value, VALUES[i]);
            assert!(std::ptr::eq(variable.tape, &tape));
        }
    }

    /// Slice creation continues numbering after previously created variables, and an
    /// empty slice adds nothing to the tape.
    #[test]
    fn test_create_variables_continues_indices() {
        let tape = Tape::new();
        let first = tape.create_variable(0.5_f64);
        let rest = tape.create_variables(&[1.0, 2.0]);

        assert_eq!(first.index, 0);
        let indices: Vec<usize> = rest.iter().map(|var| var.index).collect();
        assert_eq!(indices, vec![1, 2]);
        let values: Vec<f64> = rest.iter().map(|var| var.value).collect();
        assert_eq!(values, vec![1.0, 2.0]);
        assert_eq!(tape.operations.borrow().len(), 3);

        assert!(tape.create_variables(&[]).is_empty());
        assert_eq!(tape.operations.borrow().len(), 3);
    }
}