1use std::marker::PhantomPinned;
2use std::pin::Pin;
3use std::sync::Arc;
4use std::sync::Mutex;
5
6use qudit_core::RealScalar;
7use qudit_expr::DifferentiationLevel;
8use qudit_expr::ExpressionCache;
9use qudit_expr::FUNCTION;
10use qudit_expr::GenerationShape;
11use rustc_hash::FxHashMap;
12
13use super::buffer::SizedTensorBuffer;
14use super::instruction::TNVMInstruction;
15use crate::bytecode::Bytecode;
16use crate::cpu::TNVMResult;
17
18use qudit_core::ComplexScalar;
19use qudit_core::memory::MemoryBuffer;
20use qudit_core::memory::alloc_zeroed_memory;
21
/// A heap-allocated, pinned tensor-network virtual machine.
///
/// `TNVM` opts out of `Unpin` (via `PhantomPinned`), so it is only ever
/// handed out pinned behind a `Box`.
pub type PinnedTNVM<C, const D: DifferentiationLevel> = Pin<Box<TNVM<C, D>>>;
23
/// Scratch storage that merges caller-supplied variable arguments with
/// pre-set constant parameters into one contiguous argument slice.
struct ParamBuffer<R: RealScalar> {
    /// Backing storage holding one value per parameter; constant parameters
    /// are written once at construction time.
    buffer: MemoryBuffer<R>,

    /// For each variable (non-constant) argument, the parameter index it
    /// writes into `buffer`; built in ascending index order.
    variable_map: Vec<usize>,

    /// True when every parameter is variable, allowing the caller's argument
    /// slice to be used directly without copying into `buffer`.
    fully_parameterized: bool,
}
35
36impl<R: RealScalar> ParamBuffer<R> {
37 fn new(num_params: usize, const_map: Option<&FxHashMap<usize, R>>) -> Self {
39 let mut buffer = alloc_zeroed_memory(num_params);
40
41 if let Some(const_map) = const_map {
42 for (idx, arg) in const_map.iter() {
43 buffer[*idx] = *arg;
44 }
45
46 let mut variable_map = Vec::with_capacity(num_params - const_map.len());
47 for candidate_var_idx in 0..num_params {
48 if !const_map.contains_key(&candidate_var_idx) {
49 variable_map.push(candidate_var_idx);
50 }
51 }
52
53 let fully_parameterized = variable_map.len() == num_params;
54
55 Self {
56 buffer,
57 variable_map,
58 fully_parameterized,
59 }
60 } else {
61 Self {
62 buffer,
63 variable_map: (0..num_params).collect(),
64 fully_parameterized: true,
65 }
66 }
67 }
68
69 #[inline(always)]
71 fn as_slice_with_var_args<'a, 'b>(&'a mut self, var_args: &'b [R]) -> &'b [R]
72 where
73 'a: 'b,
74 {
75 debug_assert_eq!(var_args.len(), self.variable_map.len());
76
77 if self.fully_parameterized {
78 return var_args;
79 }
80
81 for (arg, idx) in var_args.iter().zip(self.variable_map.iter()) {
82 self.buffer[*idx] = *arg;
83 }
84
85 self.as_slice()
86 }
87
88 #[inline(always)]
90 fn as_slice(&self) -> &[R] {
91 self.buffer.as_slice()
92 }
93}
94
/// A virtual machine that evaluates a tensor-network `Bytecode` program over
/// complex scalar type `C`, with storage sized for derivatives up to
/// differentiation level `D`.
///
/// Instances are only handed out pinned (`Pin<Box<Self>>`); the type opts out
/// of `Unpin` via `PhantomPinned`.
pub struct TNVM<C: ComplexScalar, const D: DifferentiationLevel> {
    /// Instructions evaluated once, at construction time.
    const_instructions: Vec<TNVMInstruction<C, D>>,
    /// Instructions re-evaluated on every call to `evaluate`.
    dynamic_instructions: Vec<TNVMInstruction<C, D>>,
    /// Shared expression cache; held (but not otherwise read) to keep it
    /// alive for the VM's lifetime.
    #[allow(dead_code)] expressions: Arc<Mutex<ExpressionCache>>,
    /// One flat allocation backing every tensor buffer in the program.
    memory: MemoryBuffer<C>,
    /// Merges constant parameters with caller-supplied variable arguments.
    param_buffer: ParamBuffer<C::R>,
    /// Location and shape of the program's output within `memory`.
    out_buffer: SizedTensorBuffer<C>,
    /// Opts the type out of `Unpin`.
    _pin: PhantomPinned,
}
106
impl<C: ComplexScalar, const D: DifferentiationLevel> TNVM<C, D> {
    /// Builds a TNVM from `program`, optionally fixing some parameters to
    /// constants via `const_map` (parameter index -> constant value).
    ///
    /// Lays out one contiguous memory region holding every tensor buffer,
    /// compiles the const/dynamic instruction lists, then evaluates the
    /// constant instructions once so their results are already present in
    /// memory before the first call to `evaluate`.
    ///
    /// # Panics
    ///
    /// Panics if `program.buffers` is empty or if `program.out_buffer` does
    /// not name one of those buffers.
    pub fn new(program: &Bytecode, const_map: Option<&FxHashMap<usize, C::R>>) -> Pin<Box<Self>> {
        if program.buffers.is_empty() {
            panic!("Cannot build TNVM with zero-length bytecode.");
        };

        // Assign each bytecode buffer a disjoint offset into one flat memory
        // region; the output buffer is forced to a contiguous layout so it
        // can be handed back to callers directly via TNVMResult.
        let mut sized_buffers = Vec::with_capacity(program.buffers.len());
        let mut offset = 0;
        let mut out_buffer = None;
        for (i, buffer) in program.buffers.iter().enumerate() {
            let sized_buffer = if i == program.out_buffer {
                let out = SizedTensorBuffer::contiguous(offset, buffer);
                out_buffer = Some(out.clone());
                out
            } else {
                SizedTensorBuffer::new(offset, buffer)
            };
            // Memory size depends on D — presumably extra space for
            // derivative data at higher differentiation levels; contract of
            // `memory_size` not visible here.
            offset += sized_buffer.memory_size(D);
            sized_buffers.push(sized_buffer);
        }
        let memory_size = offset;
        let expressions = program.expressions.clone();

        // Prepare the shared expression cache up to level D before any
        // instruction is constructed from it.
        expressions.lock().unwrap().prepare(D);

        let mut const_instructions = Vec::new();
        for inst in &program.const_code {
            const_instructions.push(TNVMInstruction::new(
                inst,
                &sized_buffers,
                expressions.clone(),
            ));
        }

        let mut dynamic_instructions = Vec::new();
        for inst in &program.dynamic_code {
            dynamic_instructions.push(TNVMInstruction::new(
                inst,
                &sized_buffers,
                expressions.clone(),
            ));
        }

        let param_buffer = ParamBuffer::new(program.num_params, const_map);

        let mut out = Self {
            const_instructions,
            dynamic_instructions,
            expressions,
            memory: alloc_zeroed_memory::<C>(memory_size),
            param_buffer,
            out_buffer: out_buffer.expect("Error finding output buffer index from bytecode."),
            _pin: PhantomPinned,
        };

        // Evaluate constant instructions once, before pinning.
        // NOTE(review): they are given an empty argument slice — this assumes
        // constant expressions take no runtime parameters (e.g. constant
        // values were baked in at instruction construction); confirm against
        // `TNVMInstruction::new`. The safety contract of the unsafe
        // `evaluate` call is not visible in this file.
        for inst in &out.const_instructions {
            unsafe { inst.evaluate::<FUNCTION>(&[], &mut out.memory) };
        }

        Box::pin(out)
    }

    /// Runs the dynamic instructions at differentiation level `E` and returns
    /// a view over the output buffer.
    ///
    /// `var_args` supplies one value per variable (non-constant) parameter,
    /// in ascending parameter-index order (see `ParamBuffer`).
    ///
    /// # Panics
    ///
    /// Panics if `E > D`, i.e. if a higher differentiation level is requested
    /// than this TNVM's memory was sized for.
    pub fn evaluate<'a, const E: DifferentiationLevel>(
        self: &'a mut Pin<Box<Self>>,
        var_args: &[C::R],
    ) -> TNVMResult<'a, C> {
        if E > D {
            panic!("Unsafe TNVM evaluation.");
        }

        // SAFETY: the pinned value is only mutated in place through `this`;
        // nothing is moved out of it, so the `Pin` contract is upheld. The
        // safety requirements of `TNVMInstruction::evaluate` are not visible
        // here — presumably the `E <= D` check above (memory sized for level
        // D) is what makes the calls sound; confirm at its definition.
        unsafe {
            let this = self.as_mut().get_unchecked_mut();

            let arg_slice = this.param_buffer.as_slice_with_var_args(var_args);

            for inst in &this.dynamic_instructions {
                inst.evaluate::<E>(arg_slice, &mut this.memory);
            }

            TNVMResult::new(&this.memory, &this.out_buffer)
        }
    }

    /// NOTE(review): reports the output buffer's parameter count, not the
    /// number of variable arguments `evaluate` expects
    /// (`param_buffer.variable_map.len()`); confirm these agree when a
    /// `const_map` is supplied.
    pub fn num_params(&self) -> usize {
        self.out_buffer.nparams()
    }

    /// The generation shape of the program's output tensor.
    pub fn out_shape(&self) -> GenerationShape {
        self.out_buffer.shape()
    }
}