1use crate::{impl_storable_bounded, manager::MEMORY_MANAGER};
4use candid::CandidType;
5use canic_cdk::{
6 structures::{
7 BTreeMap as StableBTreeMap, DefaultMemoryImpl,
8 memory::{MemoryId, VirtualMemory},
9 },
10 types::BoundedString256,
11 utils::time::now_secs,
12};
13use serde::{Deserialize, Serialize};
14use std::cell::RefCell;
15use thiserror::Error as ThisError;
16
/// `MemoryId` slot of the stable map that stores `memory ID -> MemoryRegistryEntry`.
pub const MEMORY_REGISTRY_ID: u8 = 0;

/// `MemoryId` slot of the stable map that stores `crate name -> MemoryRange`.
pub const MEMORY_RANGES_ID: u8 = 1;
22
23thread_local! {
28 static MEMORY_REGISTRY: RefCell<StableBTreeMap<u8, MemoryRegistryEntry, VirtualMemory<DefaultMemoryImpl>>> =
29 RefCell::new(StableBTreeMap::init(
30 MEMORY_MANAGER.with_borrow(|this| {
31 this.get(MemoryId::new(MEMORY_REGISTRY_ID))
32 }),
33 ));
34}
35
36thread_local! {
41 static MEMORY_RANGES: RefCell<StableBTreeMap<String, MemoryRange, VirtualMemory<DefaultMemoryImpl>>> =
42 RefCell::new(StableBTreeMap::init(
43 MEMORY_MANAGER.with_borrow(|mgr| {
44 mgr.get(MemoryId::new(MEMORY_RANGES_ID))
45 }),
46 ));
47}
48
thread_local! {
    /// Queue of `(id, crate_name, label)` registrations recorded by
    /// `defer_register` and handed back, in insertion order, by
    /// `drain_pending_registrations`.
    // NOTE(review): presumably replayed into `MemoryRegistry::register` during
    // init — confirm against the caller.
    static PENDING_REGISTRATIONS: RefCell<Vec<(u8, &'static str, &'static str)>> = const {
        RefCell::new(Vec::new())
    };
}

/// Appends an ID registration to the pending queue.
///
/// Nothing is validated here; the tuple is stored exactly as given.
pub fn defer_register(id: u8, crate_name: &'static str, label: &'static str) {
    let pending = (id, crate_name, label);
    PENDING_REGISTRATIONS.with_borrow_mut(|queue| queue.push(pending));
}

/// Removes and returns every queued registration, oldest first, leaving the
/// queue empty.
#[must_use]
pub fn drain_pending_registrations() -> Vec<(u8, &'static str, &'static str)> {
    PENDING_REGISTRATIONS.with_borrow_mut(std::mem::take)
}
77
thread_local! {
    /// Queue of `(crate_name, start, end)` range reservations recorded by
    /// `defer_reserve_range` and handed back, in insertion order, by
    /// `drain_pending_ranges`.
    pub static PENDING_RANGES: RefCell<Vec<(&'static str, u8, u8)>> = const {
        RefCell::new(Vec::new())
    };
}

/// Appends a range reservation for `crate_name` to the pending queue.
///
/// No validation happens here (not even `start <= end`); the tuple is stored
/// exactly as given.
pub fn defer_reserve_range(crate_name: &'static str, start: u8, end: u8) {
    let pending = (crate_name, start, end);
    PENDING_RANGES.with_borrow_mut(|queue| queue.push(pending));
}

/// Removes and returns every queued range reservation, oldest first, leaving
/// the queue empty.
#[must_use]
pub fn drain_pending_ranges() -> Vec<(&'static str, u8, u8)> {
    PENDING_RANGES.with_borrow_mut(std::mem::take)
}
99
100#[derive(Debug, ThisError)]
105pub enum MemoryRegistryError {
106 #[error("ID {0} is already registered with type {1}, tried to register type {2}")]
107 AlreadyRegistered(u8, String, String),
108
109 #[error("crate `{0}` key too long ({1} bytes), max 256")]
110 CrateKeyTooLong(String, usize),
111
112 #[error("crate `{0}` already has a reserved range")]
113 DuplicateRange(String),
114
115 #[error("crate `{0}` provided invalid range {1}-{2} (start > end)")]
116 InvalidRange(String, u8, u8),
117
118 #[error("label for crate `{0}` too long ({1} bytes), max 256")]
119 LabelTooLong(String, usize),
120
121 #[error("crate `{0}` attempted to register ID {1}, but it is outside its allowed ranges")]
122 OutOfRange(String, u8),
123
124 #[error("crate `{0}` range {1}-{2} overlaps with crate `{3}` range {4}-{5}")]
125 Overlap(String, u8, u8, String, u8, u8),
126
127 #[error("crate `{0}` has not reserved any memory range")]
128 NoRange(String),
129}
130
131#[derive(Clone, Debug, Deserialize, Serialize)]
136pub struct MemoryRange {
137 pub start: u8,
138 pub end: u8,
139 pub created_at: u64,
140}
141
142impl MemoryRange {
143 pub(crate) fn try_new(
144 crate_name: &str,
145 start: u8,
146 end: u8,
147 ) -> Result<Self, MemoryRegistryError> {
148 let _ = BoundedString256::try_new(crate_name).map_err(|_| {
149 MemoryRegistryError::CrateKeyTooLong(crate_name.to_string(), crate_name.len())
150 })?;
151
152 Ok(Self {
153 start,
154 end,
155 created_at: now_secs(),
156 })
157 }
158
159 #[must_use]
160 pub fn contains(&self, id: u8) -> bool {
161 (self.start..=self.end).contains(&id)
162 }
163}
164
165impl_storable_bounded!(MemoryRange, 320, false);
166
167#[derive(CandidType, Clone, Debug, Deserialize, Serialize)]
172pub struct MemoryRegistryEntry {
173 pub label: BoundedString256,
174 pub created_at: u64,
175}
176
177impl MemoryRegistryEntry {
178 pub(crate) fn try_new(crate_name: &str, label: &str) -> Result<Self, MemoryRegistryError> {
179 let label = BoundedString256::try_new(label)
180 .map_err(|_| MemoryRegistryError::LabelTooLong(crate_name.to_string(), label.len()))?;
181
182 Ok(Self {
183 label,
184 created_at: now_secs(),
185 })
186 }
187}
188
189impl_storable_bounded!(MemoryRegistryEntry, 320, false);
190
191pub type MemoryRegistryView = Vec<(u8, MemoryRegistryEntry)>;
196
197pub struct MemoryRegistry;
202
203impl MemoryRegistry {
204 pub fn register(id: u8, crate_name: &str, label: &str) -> Result<(), MemoryRegistryError> {
211 let crate_key = crate_name.to_string();
212
213 let range = MEMORY_RANGES.with_borrow(|ranges| ranges.get(&crate_key));
215 match range {
216 None => {
217 return Err(MemoryRegistryError::NoRange(crate_key));
218 }
219 Some(r) if !r.contains(id) => {
220 return Err(MemoryRegistryError::OutOfRange(crate_key, id));
221 }
222 Some(_) => {
223 }
225 }
226
227 let existing = MEMORY_REGISTRY.with_borrow(|map| map.get(&id));
229 if let Some(existing) = existing {
230 if existing.label.as_ref() != label {
231 return Err(MemoryRegistryError::AlreadyRegistered(
232 id,
233 existing.label.to_string(),
234 label.to_string(),
235 ));
236 }
237
238 return Ok(());
240 }
241
242 let entry = MemoryRegistryEntry::try_new(crate_name, label)?;
244 MEMORY_REGISTRY.with_borrow_mut(|map| {
245 map.insert(id, entry);
246 });
247
248 Ok(())
249 }
250
251 pub fn reserve_range(crate_name: &str, start: u8, end: u8) -> Result<(), MemoryRegistryError> {
255 if start > end {
256 return Err(MemoryRegistryError::InvalidRange(
257 crate_name.to_string(),
258 start,
259 end,
260 ));
261 }
262
263 let crate_key = crate_name.to_string();
264
265 let conflict = MEMORY_RANGES.with_borrow(|ranges| {
267 if let Some(existing) = ranges.get(&crate_key) {
268 if existing.start == start && existing.end == end {
269 return None;
270 }
271
272 return Some(MemoryRegistryError::DuplicateRange(crate_key.clone()));
273 }
274
275 for entry in ranges.iter() {
276 let other_crate = entry.key();
277 let other_range = entry.value();
278
279 if !(end < other_range.start || start > other_range.end) {
280 return Some(MemoryRegistryError::Overlap(
281 crate_key.clone(),
282 start,
283 end,
284 other_crate.clone(),
285 other_range.start,
286 other_range.end,
287 ));
288 }
289 }
290
291 None
292 });
293
294 if let Some(err) = conflict {
295 return Err(err);
296 }
297
298 let range = MemoryRange::try_new(crate_name, start, end)?;
300 MEMORY_RANGES.with_borrow_mut(|ranges| {
301 ranges.insert(crate_name.to_string(), range);
302 });
303
304 Ok(())
305 }
306
307 #[must_use]
308 pub fn get(id: u8) -> Option<MemoryRegistryEntry> {
309 MEMORY_REGISTRY.with_borrow(|map| map.get(&id))
310 }
311
312 #[must_use]
313 pub fn export() -> MemoryRegistryView {
314 MEMORY_REGISTRY.with_borrow(|map| {
315 map.iter()
316 .map(|entry| (*entry.key(), entry.value()))
317 .collect()
318 })
319 }
320
321 #[must_use]
322 pub fn export_ranges() -> Vec<(String, MemoryRange)> {
323 MEMORY_RANGES.with_borrow(|ranges| {
324 ranges
325 .iter()
326 .map(|e| (e.key().clone(), e.value()))
327 .collect()
328 })
329 }
330}
331
/// Test-only helper: empties both stable maps and both pending queues of the
/// current thread so a test starts from a clean state.
#[cfg(test)]
pub(crate) fn reset_for_tests() {
    MEMORY_REGISTRY.with_borrow_mut(StableBTreeMap::clear);
    MEMORY_RANGES.with_borrow_mut(StableBTreeMap::clear);
    PENDING_REGISTRATIONS.with_borrow_mut(Vec::clear);
    PENDING_RANGES.with_borrow_mut(Vec::clear);
}
339
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn reserve_range_happy_path_and_reject_overlap() {
        reset_for_tests();
        MemoryRegistry::reserve_range("crate_a", 10, 20).unwrap();

        let err = MemoryRegistry::reserve_range("crate_b", 15, 25).unwrap_err();
        assert!(matches!(
            err,
            MemoryRegistryError::Overlap(_, _, _, _, _, _)
        ));

        MemoryRegistry::reserve_range("crate_b", 30, 40).unwrap();

        let ranges = MemoryRegistry::export_ranges();
        assert_eq!(ranges.len(), 2);
    }

    #[test]
    fn reserve_range_rejects_invalid_order() {
        reset_for_tests();
        let err = MemoryRegistry::reserve_range("crate_a", 5, 4).unwrap_err();
        assert!(matches!(err, MemoryRegistryError::InvalidRange(_, _, _)));
        assert!(MemoryRegistry::export_ranges().is_empty());
    }

    #[test]
    fn reserve_range_is_idempotent_for_identical_range() {
        reset_for_tests();
        MemoryRegistry::reserve_range("crate_a", 10, 20).unwrap();
        MemoryRegistry::reserve_range("crate_a", 10, 20).unwrap();

        let ranges = MemoryRegistry::export_ranges();
        assert_eq!(ranges.len(), 1);
        assert_eq!(ranges[0].0, "crate_a");
        assert_eq!((ranges[0].1.start, ranges[0].1.end), (10, 20));
    }

    #[test]
    fn reserve_range_rejects_different_range_for_same_crate() {
        reset_for_tests();
        MemoryRegistry::reserve_range("crate_a", 10, 20).unwrap();

        let err = MemoryRegistry::reserve_range("crate_a", 30, 40).unwrap_err();
        assert!(matches!(err, MemoryRegistryError::DuplicateRange(_)));
        assert_eq!(MemoryRegistry::export_ranges().len(), 1);
    }

    #[test]
    fn reserve_range_allows_adjacent_but_rejects_shared_endpoint() {
        reset_for_tests();
        MemoryRegistry::reserve_range("crate_a", 1, 3).unwrap();
        // Adjacent (touching but not sharing an ID) is fine.
        MemoryRegistry::reserve_range("crate_b", 4, 6).unwrap();

        // Sharing a single endpoint ID is an overlap because ranges are inclusive.
        let err = MemoryRegistry::reserve_range("crate_c", 6, 6).unwrap_err();
        assert!(matches!(err, MemoryRegistryError::Overlap(..)));
    }

    #[test]
    fn register_id_requires_range_and_checks_bounds() {
        reset_for_tests();
        MemoryRegistry::reserve_range("crate_a", 1, 3).unwrap();

        let err = MemoryRegistry::register(5, "crate_a", "Foo").unwrap_err();
        assert!(matches!(err, MemoryRegistryError::OutOfRange(_, _)));

        MemoryRegistry::register(2, "crate_a", "Foo").unwrap();

        // Same ID + same label is idempotent.
        MemoryRegistry::register(2, "crate_a", "Foo").unwrap();

        let err = MemoryRegistry::register(2, "crate_a", "Bar").unwrap_err();
        assert!(matches!(
            err,
            MemoryRegistryError::AlreadyRegistered(_, _, _)
        ));

        let view = MemoryRegistry::export();
        assert_eq!(view.len(), 1);
        assert_eq!(view[0].0, 2);
    }

    #[test]
    fn register_without_range_is_rejected() {
        reset_for_tests();
        let err = MemoryRegistry::register(1, "crate_a", "Foo").unwrap_err();
        assert!(matches!(err, MemoryRegistryError::NoRange(_)));
        assert!(MemoryRegistry::get(1).is_none());
        assert!(MemoryRegistry::export().is_empty());
    }

    #[test]
    fn register_accepts_inclusive_range_bounds() {
        reset_for_tests();
        MemoryRegistry::reserve_range("crate_a", 5, 7).unwrap();

        MemoryRegistry::register(5, "crate_a", "Lo").unwrap();
        MemoryRegistry::register(7, "crate_a", "Hi").unwrap();

        let below = MemoryRegistry::register(4, "crate_a", "X").unwrap_err();
        assert!(matches!(below, MemoryRegistryError::OutOfRange(_, 4)));
        let above = MemoryRegistry::register(8, "crate_a", "X").unwrap_err();
        assert!(matches!(above, MemoryRegistryError::OutOfRange(_, 8)));

        let entry = MemoryRegistry::get(5).unwrap();
        let label: &str = entry.label.as_ref();
        assert_eq!(label, "Lo");
        assert_eq!(MemoryRegistry::export().len(), 2);
    }

    #[test]
    fn pending_queues_drain_in_order() {
        reset_for_tests();
        defer_reserve_range("crate_a", 1, 2);
        defer_reserve_range("crate_b", 3, 4);
        defer_register(1, "crate_a", "A1");
        defer_register(3, "crate_b", "B3");

        let ranges = drain_pending_ranges();
        assert_eq!(ranges, vec![("crate_a", 1, 2), ("crate_b", 3, 4)]);
        let regs = drain_pending_registrations();
        assert_eq!(regs, vec![(1, "crate_a", "A1"), (3, "crate_b", "B3")]);

        assert!(drain_pending_ranges().is_empty());
        assert!(drain_pending_registrations().is_empty());
    }

    #[test]
    fn reserve_range_rejects_too_long_crate_key() {
        reset_for_tests();

        let crate_name = "a".repeat(257);
        let err = MemoryRegistry::reserve_range(&crate_name, 1, 2).unwrap_err();
        assert!(matches!(err, MemoryRegistryError::CrateKeyTooLong(_, 257)));
    }

    #[test]
    fn register_rejects_too_long_label() {
        reset_for_tests();
        MemoryRegistry::reserve_range("crate_a", 1, 3).unwrap();

        let label = "a".repeat(257);
        let err = MemoryRegistry::register(2, "crate_a", &label).unwrap_err();
        assert!(matches!(err, MemoryRegistryError::LabelTooLong(_, 257)));
    }
}
438}