1use std::collections::HashMap;
2
3use arbitrary::Arbitrary;
4use itertools::Itertools;
5use num_traits::ConstOne;
6use triton_vm::memory_layout::MemoryRegion;
7use triton_vm::prelude::*;
8
9use crate::prelude::*;
10
11const STATIC_MEMORY_FIRST_ADDRESS_AS_U64: u64 = BFieldElement::MAX - 1;
19pub const STATIC_MEMORY_FIRST_ADDRESS: BFieldElement =
20 BFieldElement::new(STATIC_MEMORY_FIRST_ADDRESS_AS_U64);
21pub const STATIC_MEMORY_LAST_ADDRESS: BFieldElement =
22 BFieldElement::new(STATIC_MEMORY_FIRST_ADDRESS_AS_U64 - u32::MAX as u64);
23
24#[derive(Clone, Debug)]
27pub struct Library {
28 seen_snippets: HashMap<String, Vec<LabelledInstruction>>,
30
31 num_allocated_words: u32,
33}
34
35#[derive(Debug, Copy, Clone, Eq, PartialEq, Hash, Arbitrary)]
40pub struct StaticAllocation {
41 write_address: BFieldElement,
42 num_words: u32,
43}
44
45impl StaticAllocation {
46 pub fn read_address(&self) -> BFieldElement {
48 let offset = bfe!(self.num_words) - BFieldElement::ONE;
49 self.write_address() + offset
50 }
51
52 pub fn write_address(&self) -> BFieldElement {
54 self.write_address
55 }
56
57 pub fn num_words(&self) -> u32 {
59 self.num_words
60 }
61}
62
63impl Default for Library {
64 fn default() -> Self {
65 Self::new()
66 }
67}
68
69impl Library {
70 pub fn kmalloc_memory_region() -> MemoryRegion {
71 MemoryRegion::new(STATIC_MEMORY_LAST_ADDRESS, 1 << 32)
72 }
73
74 pub fn new() -> Self {
75 Self {
76 seen_snippets: HashMap::default(),
77 num_allocated_words: 0,
78 }
79 }
80
81 pub fn empty() -> Self {
83 Self::new()
84 }
85
86 #[cfg(test)]
87 pub fn with_preallocated_memory(words_statically_allocated: u32) -> Self {
88 Library {
89 num_allocated_words: words_statically_allocated,
90 ..Self::new()
91 }
92 }
93
94 pub fn import(&mut self, snippet: Box<dyn BasicSnippet>) -> String {
103 let dep_entrypoint = snippet.entrypoint();
104
105 let is_new_dependency = !self.seen_snippets.contains_key(&dep_entrypoint);
106 if is_new_dependency {
107 let dep_body = snippet.annotated_code(self);
108 self.seen_snippets.insert(dep_entrypoint.clone(), dep_body);
109 }
110
111 dep_entrypoint
112 }
113
114 pub fn explicit_import(&mut self, name: &str, body: &[LabelledInstruction]) -> String {
122 if !self.seen_snippets.contains_key(name) {
123 self.seen_snippets.insert(name.to_owned(), body.to_vec());
124 }
125
126 name.to_string()
127 }
128
129 pub fn all_external_dependencies(&self) -> Vec<Vec<LabelledInstruction>> {
132 self.seen_snippets
133 .iter()
134 .sorted_by_key(|(k, _)| *k)
135 .map(|(_, code)| code.clone())
136 .collect()
137 }
138
139 pub fn get_all_snippet_names(&self) -> Vec<String> {
142 let mut ret = self.seen_snippets.keys().cloned().collect_vec();
143 ret.sort_unstable();
144 ret
145 }
146
147 pub fn all_imports(&self) -> Vec<LabelledInstruction> {
149 self.all_external_dependencies().concat()
150 }
151
152 pub fn kmalloc(&mut self, num_words: u32) -> StaticAllocation {
160 assert!(num_words > 0, "must allocate a positive number of words");
161 let write_address =
162 STATIC_MEMORY_FIRST_ADDRESS - bfe!(self.num_allocated_words) - bfe!(num_words - 1);
163 self.num_allocated_words = self
164 .num_allocated_words
165 .checked_add(num_words)
166 .expect("Cannot allocate more that u32::MAX words through `kmalloc`.");
167
168 StaticAllocation {
169 write_address,
170 num_words,
171 }
172 }
173}
174
#[cfg(test)]
mod tests {
    use triton_vm::prelude::Program;
    use triton_vm::prelude::triton_asm;

    use super::*;
    use crate::mmr::calculate_new_peaks_from_leaf_mutation::MmrCalculateNewPeaksFromLeafMutationMtIndices;
    use crate::test_prelude::*;

    /// Argument type of the dummy closures below, none of which take arguments.
    #[derive(Debug, Copy, Clone, BFieldCodec)]
    struct ZeroSizedType;

    /// Dummy snippet that imports [`DummyTestSnippetB`] and [`DummyTestSnippetC`].
    #[derive(Debug)]
    struct DummyTestSnippetA;

    /// Dummy snippet that imports [`DummyTestSnippetC`].
    #[derive(Debug)]
    struct DummyTestSnippetB;

    /// Dummy snippet without dependencies.
    #[derive(Debug)]
    struct DummyTestSnippetC;

    impl BasicSnippet for DummyTestSnippetA {
        fn parameters(&self) -> Vec<(DataType, String)> {
            vec![]
        }

        fn return_values(&self) -> Vec<(DataType, String)> {
            vec![(DataType::Xfe, "dummy".to_string())]
        }

        fn entrypoint(&self) -> String {
            "tasmlib_a_dummy_test_value".to_string()
        }

        fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
            let b = library.import(Box::new(DummyTestSnippetB));
            let c = library.import(Box::new(DummyTestSnippetC));

            triton_asm!(
                {self.entrypoint()}:
                    call {b}
                    call {c}
                    return
            )
        }
    }

    impl BasicSnippet for DummyTestSnippetB {
        fn parameters(&self) -> Vec<(DataType, String)> {
            vec![]
        }

        fn return_values(&self) -> Vec<(DataType, String)> {
            ["1"; 2]
                .map(|name| (DataType::Bfe, name.to_string()))
                .to_vec()
        }

        fn entrypoint(&self) -> String {
            "tasmlib_b_dummy_test_value".to_string()
        }

        fn code(&self, library: &mut Library) -> Vec<LabelledInstruction> {
            let c = library.import(Box::new(DummyTestSnippetC));

            triton_asm!(
                {self.entrypoint()}:
                    call {c}
                    call {c}
                    return
            )
        }
    }

    impl BasicSnippet for DummyTestSnippetC {
        fn parameters(&self) -> Vec<(DataType, String)> {
            vec![]
        }

        fn return_values(&self) -> Vec<(DataType, String)> {
            vec![(DataType::Bfe, "1".to_string())]
        }

        fn entrypoint(&self) -> String {
            "tasmlib_c_dummy_test_value".to_string()
        }

        fn code(&self, _: &mut Library) -> Vec<LabelledInstruction> {
            triton_asm!({self.entrypoint()}: push 1 return)
        }
    }

    impl Closure for DummyTestSnippetA {
        type Args = ZeroSizedType;

        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) -> Result<(), RustShadowError> {
            push_encodable(stack, &xfe![[1, 1, 1]]);
            Ok(())
        }

        fn pseudorandom_args(&self, _: [u8; 32], _: Option<BenchmarkCase>) -> Self::Args {
            ZeroSizedType
        }
    }

    impl Closure for DummyTestSnippetB {
        type Args = ZeroSizedType;

        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) -> Result<(), RustShadowError> {
            stack.push(bfe!(1));
            stack.push(bfe!(1));
            Ok(())
        }

        fn pseudorandom_args(&self, _: [u8; 32], _: Option<BenchmarkCase>) -> Self::Args {
            ZeroSizedType
        }
    }

    impl Closure for DummyTestSnippetC {
        type Args = ZeroSizedType;

        fn rust_shadow(&self, stack: &mut Vec<BFieldElement>) -> Result<(), RustShadowError> {
            stack.push(bfe!(1));
            Ok(())
        }

        fn pseudorandom_args(&self, _: [u8; 32], _: Option<BenchmarkCase>) -> Self::Args {
            ZeroSizedType
        }
    }

    #[macro_rules_attr::apply(test)]
    fn library_includes() {
        ShadowedClosure::new(DummyTestSnippetA).test();
        ShadowedClosure::new(DummyTestSnippetB).test();
        ShadowedClosure::new(DummyTestSnippetC).test();
    }

    #[macro_rules_attr::apply(test)]
    fn get_all_snippet_names_test_a() {
        let mut lib = Library::new();
        lib.import(Box::new(DummyTestSnippetA));
        assert_eq!(
            vec![
                "tasmlib_a_dummy_test_value",
                "tasmlib_b_dummy_test_value",
                "tasmlib_c_dummy_test_value",
            ],
            lib.get_all_snippet_names()
        );
    }

    #[macro_rules_attr::apply(test)]
    fn get_all_snippet_names_test_b() {
        let mut lib = Library::new();
        lib.import(Box::new(DummyTestSnippetB));
        assert_eq!(
            vec!["tasmlib_b_dummy_test_value", "tasmlib_c_dummy_test_value"],
            lib.get_all_snippet_names()
        );
    }

    #[macro_rules_attr::apply(test)]
    fn all_imports_as_instruction_lists() {
        let mut lib = Library::new();
        lib.import(Box::new(DummyTestSnippetA));
        lib.import(Box::new(DummyTestSnippetA));
        lib.import(Box::new(DummyTestSnippetC));

        // Repeated imports must not duplicate code: only A, B, and C are present.
        let dependencies = lib.all_external_dependencies();
        assert_eq!(3, dependencies.len());
        assert_eq!(dependencies.concat(), lib.all_imports());
    }

    #[macro_rules_attr::apply(test)]
    fn explicit_import_keeps_first_body() {
        let mut lib = Library::new();
        let first_body = triton_asm!(foo: push 1 return);
        let second_body = triton_asm!(foo: push 2 return);

        assert_eq!("foo", lib.explicit_import("foo", &first_body));
        assert_eq!("foo", lib.explicit_import("foo", &second_body));
        assert_eq!(vec![first_body], lib.all_external_dependencies());
    }

    #[macro_rules_attr::apply(test)]
    fn program_is_deterministic() {
        fn smaller_program() -> Program {
            let mut library = Library::new();
            let memcpy = library.import(Box::new(MemCpy));
            let calculate_new_peaks_from_leaf_mutation =
                library.import(Box::new(MmrCalculateNewPeaksFromLeafMutationMtIndices));

            let mut src = triton_asm!(
                lala_entrypoint:
                    push 1 call {memcpy}
                    call {calculate_new_peaks_from_leaf_mutation}

                    return
            );
            let mut imports = library.all_imports();

            // `all_imports` must agree with concatenating the individual dependencies.
            let imports_repeated = library.all_external_dependencies().concat();
            assert_eq!(imports, imports_repeated);

            src.append(&mut imports);

            Program::new(&src)
        }

        for _ in 0..100 {
            let program = smaller_program();
            let same_program = smaller_program();
            assert_eq!(program, same_program);
        }
    }

    #[macro_rules_attr::apply(test)]
    fn kmalloc_test() {
        const MINUS_TWO: BFieldElement = BFieldElement::new(BFieldElement::MAX - 1);
        let mut lib = Library::new();

        let first_chunk = lib.kmalloc(1);
        assert_eq!(MINUS_TWO, first_chunk.write_address());
        assert_eq!(MINUS_TWO, first_chunk.read_address());
        assert_eq!(1, first_chunk.num_words());

        // occupies -9 through -3, directly below the first chunk
        let second_chunk = lib.kmalloc(7);
        assert_eq!(-bfe!(9), second_chunk.write_address());
        assert_eq!(-bfe!(3), second_chunk.read_address());
        assert_eq!(7, second_chunk.num_words());

        // occupies -1009 through -10, directly below the second chunk
        let third_chunk = lib.kmalloc(1000);
        assert_eq!(-bfe!(1009), third_chunk.write_address());
        assert_eq!(-bfe!(10), third_chunk.read_address());
        assert_eq!(1000, third_chunk.num_words());
    }

    #[macro_rules_attr::apply(test)]
    fn kmalloc_respects_preallocated_memory() {
        // words -2 through -6 count as already taken
        let mut lib = Library::with_preallocated_memory(5);

        let chunk = lib.kmalloc(3);
        assert_eq!(-bfe!(9), chunk.write_address());
        assert_eq!(-bfe!(7), chunk.read_address());
    }

    #[macro_rules_attr::apply(test)]
    fn kmalloc_can_hand_out_last_allocatable_word() {
        let chunk = Library::with_preallocated_memory(u32::MAX - 1).kmalloc(1);
        assert_eq!(STATIC_MEMORY_LAST_ADDRESS + bfe!(1), chunk.write_address());
    }

    #[macro_rules_attr::apply(test)]
    fn kmalloc_rejects_invalid_requests() {
        let zero_words = std::panic::catch_unwind(|| Library::new().kmalloc(0));
        assert!(zero_words.is_err());

        let exhausted =
            std::panic::catch_unwind(|| Library::with_preallocated_memory(u32::MAX).kmalloc(1));
        assert!(exhausted.is_err());
    }
}