1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
mod shared;

use self::shared::SharedValidator;
use crossbeam_queue::ArrayQueue;
use rkyv::validation::{validators::ArchiveValidator, ArchiveContext, SharedContext};
use std::{
	any::TypeId,
	mem,
	num::NonZeroUsize,
	ops::Range,
	sync::{Arc, Weak},
};

#[derive(Debug)]
struct Inner {
	shared: ArrayQueue<shared::SharedValidator>,
}

/// A handle to a bounded pool of reusable shared validators.
///
/// Cloning is cheap. Every clone holds the same `Arc`, so all of them draw from and return to
/// one underlying queue.
#[derive(Clone, Debug)]
pub struct ValidatorPool {
	/// Reference-counted pool state shared by all clones of this handle.
	inner: Arc<Inner>,
}

impl ValidatorPool {
	/// Creates an empty pool that retains at most `capacity` idle shared validators.
	///
	/// # Panics
	///
	/// Panics if `capacity` is zero, because `crossbeam_queue::ArrayQueue::new` rejects a zero
	/// capacity.
	pub fn new(capacity: usize) -> Self {
		let queue = ArrayQueue::new(capacity);
		let inner = Arc::new(Inner { shared: queue });
		Self { inner }
	}

	/// Returns a validator over `bytes` with `None` as its maximum depth.
	///
	/// Shorthand for [`ValidatorPool::validator_with_max_depth`] with `max_depth = None`.
	pub fn validator(&self, bytes: &[u8]) -> PooledValidator {
		let no_limit: Option<NonZeroUsize> = None;
		self.validator_with_max_depth(bytes, no_limit)
	}

	/// Returns a validator over `bytes`, passing `max_depth` through to
	/// `ArchiveValidator::with_max_depth`.
	///
	/// The shared-validator state is taken from the pool when one is queued; otherwise it is
	/// created with `Default`. The archive state is always built fresh for `bytes`. The result
	/// holds only a `Weak` reference back to the pool, so it does not keep the pool alive.
	pub fn validator_with_max_depth(
		&self,
		bytes: &[u8],
		max_depth: Option<NonZeroUsize>,
	) -> PooledValidator {
		let recycled = self.inner.shared.pop();
		let shared = recycled.unwrap_or_default();
		let pool_ref = Arc::downgrade(&self.inner);
		let archive = ArchiveValidator::with_max_depth(bytes, max_depth);

		PooledValidator {
			pool_ref,
			archive,
			shared,
		}
	}
}

/// A validator handed out by [`ValidatorPool`].
///
/// It pairs freshly built archive state with shared-validator state borrowed from the pool.
/// `ArchiveContext` calls go to the archive part and `SharedContext` calls go to the shared part.
/// On drop, the shared part is cleared and offered back to the originating pool, if that pool
/// still exists.
#[derive(Debug)]
pub struct PooledValidator {
	/// Non-owning handle to the pool the shared state came from.
	pool_ref: Weak<Inner>,
	/// Receives the `ArchiveContext` calls; constructed anew for every validator.
	archive: ArchiveValidator,
	/// Receives the `SharedContext` calls; recycled through the pool.
	shared: SharedValidator,
}

impl Drop for PooledValidator {
	/// Offers the shared-validator state back to its pool.
	///
	/// If the pool has already been dropped, nothing is returned and the state is dropped along
	/// with `self`. Otherwise `SharedValidator::clear` is called, the state is moved out (leaving
	/// a default in its place), and it is pushed onto the pool's queue.
	fn drop(&mut self) {
		let pool = match self.pool_ref.upgrade() {
			Some(alive) => alive,
			// No pool left to recycle into.
			None => return,
		};

		self.shared.clear();
		let recycled = mem::take(&mut self.shared);

		// `ArrayQueue::push` hands the value back as `Err` when the queue is full; discarding it
		// in that case is the intended best-effort behavior.
		let _ = pool.shared.push(recycled);
	}
}

// SAFETY: every method forwards its arguments unchanged to the wrapped `ArchiveValidator`. The
// where-clause requires that validator to implement `ArchiveContext<E>` itself. This wrapper
// performs no pointer checks or range bookkeeping of its own, so it provides exactly the
// guarantees of the inner implementation.
unsafe impl<E> ArchiveContext<E> for PooledValidator
where
	ArchiveValidator: ArchiveContext<E>,
{
	/// Delegates the subtree bounds check to the inner `ArchiveValidator`.
	#[inline]
	fn check_subtree_ptr(&mut self, ptr: *const u8, layout: &core::alloc::Layout) -> Result<(), E> {
		self.archive.check_subtree_ptr(ptr, layout)
	}

	/// Delegates to the inner `ArchiveValidator`.
	///
	/// # Safety
	///
	/// The caller must satisfy the contract of `ArchiveContext::push_prefix_subtree_range`
	/// for `root` and `end`.
	#[inline]
	unsafe fn push_prefix_subtree_range(
		&mut self,
		root: *const u8,
		end: *const u8,
	) -> Result<Range<usize>, E> {
		// SAFETY: the arguments are forwarded unchanged, and the caller's obligations for this
		// `unsafe fn` are exactly those the inner implementation requires.
		unsafe { self.archive.push_prefix_subtree_range(root, end) }
	}

	/// Delegates to the inner `ArchiveValidator`.
	///
	/// # Safety
	///
	/// The caller must satisfy the contract of `ArchiveContext::push_suffix_subtree_range`
	/// for `start` and `root`.
	#[inline]
	unsafe fn push_suffix_subtree_range(
		&mut self,
		start: *const u8,
		root: *const u8,
	) -> Result<Range<usize>, E> {
		// SAFETY: the arguments are forwarded unchanged, and the caller's obligations for this
		// `unsafe fn` are exactly those the inner implementation requires.
		unsafe { self.archive.push_suffix_subtree_range(start, root) }
	}

	/// Delegates to the inner `ArchiveValidator`.
	///
	/// # Safety
	///
	/// The caller must satisfy the contract of `ArchiveContext::pop_subtree_range` for `range`.
	/// Presumably `range` must come from an earlier push on this same validator; confirm against
	/// the rkyv trait documentation.
	#[inline]
	unsafe fn pop_subtree_range(&mut self, range: Range<usize>) -> Result<(), E> {
		// SAFETY: `range` is forwarded unchanged to the same inner validator that produced the
		// ranges returned by the push methods above.
		unsafe { self.archive.pop_subtree_range(range) }
	}
}

/// Routes shared-pointer registration to the pooled `SharedValidator`.
impl<E> SharedContext<E> for PooledValidator
where
	SharedValidator: SharedContext<E>,
{
	/// Forwards `address` and `type_id` to the pooled shared validator and returns its result.
	#[inline]
	fn register_shared_ptr(&mut self, address: usize, type_id: TypeId) -> Result<bool, E> {
		let Self { shared, .. } = self;
		shared.register_shared_ptr(address, type_id)
	}
}