// qudit_tensor/cpu/tnvm.rs

1use std::marker::PhantomPinned;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::sync::Mutex;
5
6use qudit_core::RealScalar;
7use qudit_expr::DifferentiationLevel;
8use qudit_expr::ExpressionCache;
9use qudit_expr::FUNCTION;
10use qudit_expr::GenerationShape;
11use rustc_hash::FxHashMap;
12
13use super::buffer::SizedTensorBuffer;
14use super::instruction::TNVMInstruction;
15use crate::bytecode::Bytecode;
16use crate::cpu::TNVMResult;
17
18use qudit_core::ComplexScalar;
19use qudit_core::memory::MemoryBuffer;
20use qudit_core::memory::alloc_zeroed_memory;
21
/// A pinned, heap-allocated [`TNVM`]; this is what [`TNVM::new`] returns.
/// Pinning is required because evaluation results borrow the VM's internal
/// memory buffer (see the `_pin: PhantomPinned` field on [`TNVM`]).
pub type PinnedTNVM<C, const D: DifferentiationLevel> = Pin<Box<TNVM<C, D>>>;
23
/// Parameters for a TNVM evaluation; tracks constant and variable arguments correctly.
///
/// The full parameter vector lives in `buffer`: constant entries are written
/// once at construction, and variable entries are refreshed on every
/// evaluation via `as_slice_with_var_args`.
struct ParamBuffer<R: RealScalar> {
    /// Buffer storing the entire parameter vector (constants and variables).
    buffer: MemoryBuffer<R>,

    /// Maps variable parameter i to parameter variable_map[i],
    /// i.e. the position in `buffer` where the i-th variable argument goes.
    variable_map: Vec<usize>,

    /// True when no parameters are constant; lets `as_slice_with_var_args`
    /// return the caller's slice directly without copying.
    fully_parameterized: bool,
}
35
36impl<R: RealScalar> ParamBuffer<R> {
37    /// Allocate a new parameter buffer with constant arguments cached
38    fn new(num_params: usize, const_map: Option<&FxHashMap<usize, R>>) -> Self {
39        let mut buffer = alloc_zeroed_memory(num_params);
40
41        if let Some(const_map) = const_map {
42            for (idx, arg) in const_map.iter() {
43                buffer[*idx] = *arg;
44            }
45
46            let mut variable_map = Vec::with_capacity(num_params - const_map.len());
47            for candidate_var_idx in 0..num_params {
48                if !const_map.contains_key(&candidate_var_idx) {
49                    variable_map.push(candidate_var_idx);
50                }
51            }
52
53            let fully_parameterized = variable_map.len() == num_params;
54
55            Self {
56                buffer,
57                variable_map,
58                fully_parameterized,
59            }
60        } else {
61            Self {
62                buffer,
63                variable_map: (0..num_params).collect(),
64                fully_parameterized: true,
65            }
66        }
67    }
68
69    /// Places the variable arguments into the buffer
70    #[inline(always)]
71    fn as_slice_with_var_args<'a, 'b>(&'a mut self, var_args: &'b [R]) -> &'b [R]
72    where
73        'a: 'b,
74    {
75        debug_assert_eq!(var_args.len(), self.variable_map.len());
76
77        if self.fully_parameterized {
78            return var_args;
79        }
80
81        for (arg, idx) in var_args.iter().zip(self.variable_map.iter()) {
82            self.buffer[*idx] = *arg;
83        }
84
85        self.as_slice()
86    }
87
88    /// Convert the buffer to a slice of arguments
89    #[inline(always)]
90    fn as_slice(&self) -> &[R] {
91        self.buffer.as_slice()
92    }
93}
94
/// A tensor-network virtual machine for CPU evaluation of compiled bytecode.
///
/// `D` is the maximum differentiation level the VM is prepared for; calls to
/// `evaluate::<E>` require `E <= D`.
pub struct TNVM<C: ComplexScalar, const D: DifferentiationLevel> {
    /// Instructions run once at construction time (parameter-independent).
    const_instructions: Vec<TNVMInstruction<C, D>>,
    /// Instructions re-run on every `evaluate` call.
    dynamic_instructions: Vec<TNVMInstruction<C, D>>,
    #[allow(dead_code)] // Necessary to hold handle on expressions for safety.
    expressions: Arc<Mutex<ExpressionCache>>,
    /// One arena holding every instruction's tensor buffer, laid out back-to-back.
    memory: MemoryBuffer<C>,
    /// Tracks constant vs. variable parameters across evaluations.
    param_buffer: ParamBuffer<C::R>,
    /// Descriptor locating the final result inside `memory`.
    out_buffer: SizedTensorBuffer<C>,
    /// Evaluation results borrow `memory`, so the VM must never move once
    /// results may be outstanding; hence `new` returns `Pin<Box<Self>>`.
    _pin: PhantomPinned,
    // TODO: hold a mutable borrow of the expressions to prevent any uncompiling of it
}
106
107impl<C: ComplexScalar, const D: DifferentiationLevel> TNVM<C, D> {
    /// Build a TNVM from compiled `Bytecode`, eagerly evaluating all
    /// parameter-independent (const) instructions.
    ///
    /// `const_map` optionally fixes some parameter indices to constant
    /// values; those indices are then excluded from the variable arguments
    /// expected by `evaluate`.
    ///
    /// # Panics
    ///
    /// Panics if the bytecode declares no buffers, if no buffer matches the
    /// bytecode's output-buffer index, or if the expression-cache mutex is
    /// poisoned.
    pub fn new(program: &Bytecode, const_map: Option<&FxHashMap<usize, C::R>>) -> Pin<Box<Self>> {
        if program.buffers.is_empty() {
            panic!("Cannot build TNVM with zero-length bytecode.");
        };

        // Lay the tensor buffers out back-to-back in a single arena,
        // accumulating each one's memory footprint into `offset`.
        let mut sized_buffers = Vec::with_capacity(program.buffers.len());
        let mut offset = 0;
        let mut out_buffer = None;
        for (i, buffer) in program.buffers.iter().enumerate() {
            let sized_buffer = if i == program.out_buffer {
                // The output buffer is forced contiguous so results can be
                // viewed directly; keep a handle to it for TNVMResult.
                let out = SizedTensorBuffer::contiguous(offset, buffer);
                out_buffer = Some(out.clone());
                out
            } else {
                SizedTensorBuffer::new(offset, buffer)
            };
            offset += sized_buffer.memory_size(D);
            sized_buffers.push(sized_buffer);
        }
        let memory_size = offset;
        // println!("ALLOCATING {} bytes for {} units", memory_size * std::mem::size_of::<C>(), memory_size);
        // TODO: Log some of this stuff with proper logging utilities
        // TODO: Explore overlapping buffers to reduce memory usage and increase locality
        // TODO: Can further optimize FRPR after knowing strides: simple reshapes on continuous
        // buffers can be skipped with the input and output buffer having the same offset
        // but different strides.

        let expressions = program.expressions.clone();

        // Ensure that all expressions are prepared up to diff_level D.
        expressions.lock().unwrap().prepare(D);

        // Generate instructions
        let mut const_instructions = Vec::new();
        for inst in &program.const_code {
            const_instructions.push(TNVMInstruction::new(
                inst,
                &sized_buffers,
                expressions.clone(),
            ));
        }

        let mut dynamic_instructions = Vec::new();
        for inst in &program.dynamic_code {
            dynamic_instructions.push(TNVMInstruction::new(
                inst,
                &sized_buffers,
                expressions.clone(),
            ));
        }

        // Initialize parameter buffer
        let param_buffer = ParamBuffer::new(program.num_params, const_map);

        let mut out = Self {
            const_instructions,
            dynamic_instructions,
            expressions,
            memory: alloc_zeroed_memory::<C>(memory_size),
            param_buffer,
            out_buffer: out_buffer.expect("Error finding output buffer index from bytecode."),
            _pin: PhantomPinned,
        };

        // Evaluate const code once, before pinning. Const instructions take no
        // arguments; the disjoint field borrows (&const_instructions,
        // &mut memory) are valid on the unpinned local.
        for inst in &out.const_instructions {
            unsafe { inst.evaluate::<FUNCTION>(&[], &mut out.memory) };
        }

        Box::pin(out)
    }
179
180    // TODO: evaluate_into
181
    /// Evaluate the TNVM with the provided arguments for all variable parameters.
    ///
    /// `E` is the differentiation level requested for this call; it must not
    /// exceed the level `D` the VM was prepared with. Returns a result view
    /// that borrows this VM's internal memory for `'a`.
    ///
    /// # Panics
    ///
    /// Panics if `E > D`, since expressions were only prepared up to `D`.
    pub fn evaluate<'a, const E: DifferentiationLevel>(
        self: &'a mut Pin<Box<Self>>,
        var_args: &[C::R],
    ) -> TNVMResult<'a, C> {
        if E > D {
            panic!("Unsafe TNVM evaluation.");
        }

        // Safety: Self is not moved
        unsafe {
            let this = self.as_mut().get_unchecked_mut();

            // Splice the variable arguments into the full parameter vector
            // (constants were cached at construction time).
            let arg_slice = this.param_buffer.as_slice_with_var_args(var_args);

            for inst in &this.dynamic_instructions {
                // Safety: Whole structure of TNVM ensures that the instruction
                // evaluates only on memory it has access to.
                inst.evaluate::<E>(arg_slice, &mut this.memory);
            }

            // Safety: Projection of const reference from mutable pin. Caller
            // cannot move data from this structure.
            TNVMResult::new(&this.memory, &this.out_buffer)
        }
    }
208
    /// The number of parameters recorded on the output buffer.
    ///
    /// NOTE(review): this delegates to `out_buffer.nparams()`. When a
    /// `const_map` was supplied to `new`, the number of variable arguments
    /// `evaluate` expects is `param_buffer.variable_map.len()`, which may
    /// differ — confirm which count callers rely on.
    pub fn num_params(&self) -> usize {
        self.out_buffer.nparams()
    }
212
    /// The generation shape of the evaluation result, as recorded on the
    /// output buffer.
    pub fn out_shape(&self) -> GenerationShape {
        self.out_buffer.shape()
    }
216}