1use {
9 crate::{
10 perf_libs,
11 recycler::{RecyclerX, Reset},
12 },
13 rand::{seq::SliceRandom, Rng},
14 rayon::prelude::*,
15 serde::{Deserialize, Serialize},
16 std::{
17 ops::{Deref, DerefMut, Index, IndexMut},
18 os::raw::c_int,
19 slice::{Iter, SliceIndex},
20 sync::Weak,
21 },
22};
23
/// `cudaSuccess` — the CUDA runtime's "no error" return code, checked after
/// every `cuda_host_register`/`cuda_host_unregister` call below.
const CUDA_SUCCESS: c_int = 0;
25
26fn pin<T>(mem: &mut Vec<T>) {
27 if let Some(api) = perf_libs::api() {
28 use std::{ffi::c_void, mem::size_of};
29
30 let ptr = mem.as_mut_ptr();
31 let size = mem.capacity().saturating_mul(size_of::<T>());
32 let err = unsafe {
33 (api.cuda_host_register)(ptr as *mut c_void, size, 0)
34 };
35 assert!(
36 err == CUDA_SUCCESS,
37 "cudaHostRegister error: {err} ptr: {ptr:?} bytes: {size}"
38 );
39 }
40}
41
42fn unpin<T>(mem: *mut T) {
43 if let Some(api) = perf_libs::api() {
44 use std::ffi::c_void;
45
46 let err = unsafe { (api.cuda_host_unregister)(mem as *mut c_void) };
47 assert!(
48 err == CUDA_SUCCESS,
49 "cudaHostUnregister returned: {err} ptr: {mem:?}"
50 );
51 }
52}
53
/// A `Vec<T>` wrapper whose backing memory can be page-locked ("pinned") via
/// `cudaHostRegister` and whose buffers can be pooled through a recycler.
#[cfg_attr(feature = "frozen-abi", derive(AbiExample))]
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct PinnedVec<T: Default + Clone + Sized> {
    // The underlying storage.
    x: Vec<T>,
    // True while `x`'s current allocation is registered with the CUDA runtime.
    pinned: bool,
    // True once the owner opted in to pinning; the realloc hooks only pin
    // vecs that are pinnable.
    pinnable: bool,
    // Weak handle back to the pool; `Drop` uses it to return the buffer.
    // Skipped by serde and re-attached via `Reset::set_recycler`.
    #[serde(skip)]
    recycler: Weak<RecyclerX<PinnedVec<T>>>,
}
66
impl<T: Default + Clone + Sized> Reset for PinnedVec<T> {
    /// Empties the vec for reuse; capacity (and any pinned registration,
    /// since `resize` to a smaller size does not reallocate) is retained.
    fn reset(&mut self) {
        self.resize(0, T::default());
    }
    /// Pre-grows the buffer to `size_hint` elements (pinning it when a perf
    /// library is available) so first use after recycling does not realloc.
    fn warm(&mut self, size_hint: usize) {
        self.set_pinnable();
        self.resize(size_hint, T::default());
    }
    /// Attaches the weak recycler handle that `Drop` uses to return the
    /// buffer to its pool instead of freeing it.
    fn set_recycler(&mut self, recycler: Weak<RecyclerX<Self>>) {
        self.recycler = recycler;
    }
}
79
impl<T: Clone + Default + Sized> From<PinnedVec<T>> for Vec<T> {
    /// Extracts a plain `Vec<T>`, unregistering pinned memory when safe.
    fn from(mut pinned_vec: PinnedVec<T>) -> Self {
        if pinned_vec.pinned {
            // If the recycler is still alive, `Drop` will hand the pinned
            // buffer back to the pool — we must not steal or unpin it.
            // Return a copy of the contents instead.
            if pinned_vec.recycler.strong_count() != 0 {
                return pinned_vec.x.clone();
            }
            // No recycler: unregister before moving the buffer out, so the
            // CUDA runtime is not left holding a registration for memory
            // whose lifetime we no longer control.
            unpin(pinned_vec.x.as_mut_ptr());
            pinned_vec.pinned = false;
        }
        pinned_vec.pinnable = false;
        pinned_vec.recycler = Weak::default();
        // Take the storage, leaving an empty default behind for Drop.
        std::mem::take(&mut pinned_vec.x)
    }
}
97
impl<'a, T: Clone + Default + Sized> IntoIterator for &'a PinnedVec<T> {
    type Item = &'a T;
    type IntoIter = Iter<'a, T>;

    /// Iterates over shared references to elements (enables `for x in &vec`).
    fn into_iter(self) -> Self::IntoIter {
        self.x.iter()
    }
}
106
impl<T: Clone + Default + Sized, I: SliceIndex<[T]>> Index<I> for PinnedVec<T> {
    type Output = I::Output;

    /// Indexes the underlying slice; panics on out-of-bounds, like `Vec`.
    #[inline]
    fn index(&self, index: I) -> &Self::Output {
        &self.x[index]
    }
}
115
impl<T: Clone + Default + Sized, I: SliceIndex<[T]>> IndexMut<I> for PinnedVec<T> {
    /// Mutable indexing; element mutation cannot realloc, so pinning state
    /// is unaffected.
    #[inline]
    fn index_mut(&mut self, index: I) -> &mut Self::Output {
        &mut self.x[index]
    }
}
122
impl<'a, T: Clone + Send + Sync + Default + Sized> IntoParallelIterator for &'a PinnedVec<T> {
    type Iter = rayon::slice::Iter<'a, T>;
    type Item = &'a T;
    /// Rayon parallel iteration over shared references to the elements.
    fn into_par_iter(self) -> Self::Iter {
        self.x.par_iter()
    }
}
130
impl<'a, T: Clone + Send + Sync + Default + Sized> IntoParallelIterator for &'a mut PinnedVec<T> {
    type Iter = rayon::slice::IterMut<'a, T>;
    type Item = &'a mut T;
    /// Rayon parallel iteration over mutable references; in-place mutation
    /// cannot realloc, so pinning state is unaffected.
    fn into_par_iter(self) -> Self::Iter {
        self.x.par_iter_mut()
    }
}
138
139impl<T: Clone + Default + Sized> PinnedVec<T> {
140 pub fn reserve_and_pin(&mut self, size: usize) {
141 if self.x.capacity() < size {
142 if self.pinned {
143 unpin(self.x.as_mut_ptr());
144 self.pinned = false;
145 }
146 self.x.reserve(size);
147 }
148 self.set_pinnable();
149 if !self.pinned {
150 pin(&mut self.x);
151 self.pinned = true;
152 }
153 }
154
155 pub fn set_pinnable(&mut self) {
156 self.pinnable = true;
157 }
158
159 pub fn from_vec(source: Vec<T>) -> Self {
160 Self {
161 x: source,
162 pinned: false,
163 pinnable: false,
164 recycler: Weak::default(),
165 }
166 }
167
168 pub fn with_capacity(capacity: usize) -> Self {
169 Self::from_vec(Vec::with_capacity(capacity))
170 }
171
172 fn prepare_realloc(&mut self, new_size: usize) -> (*mut T, usize) {
173 let old_ptr = self.x.as_mut_ptr();
174 let old_capacity = self.x.capacity();
175 if self.pinned && self.x.capacity() < new_size {
177 unpin(old_ptr);
178 self.pinned = false;
179 }
180 (old_ptr, old_capacity)
181 }
182
183 pub fn push(&mut self, x: T) {
184 let (old_ptr, old_capacity) = self.prepare_realloc(self.x.len().saturating_add(1));
185 self.x.push(x);
186 self.check_ptr(old_ptr, old_capacity, "push");
187 }
188
189 pub fn resize(&mut self, size: usize, elem: T) {
190 let (old_ptr, old_capacity) = self.prepare_realloc(size);
191 self.x.resize(size, elem);
192 self.check_ptr(old_ptr, old_capacity, "resize");
193 }
194
195 pub fn append(&mut self, other: &mut Vec<T>) {
196 let (old_ptr, old_capacity) =
197 self.prepare_realloc(self.x.len().saturating_add(other.len()));
198 self.x.append(other);
199 self.check_ptr(old_ptr, old_capacity, "resize");
200 }
201
202 pub fn append_pinned(&mut self, other: &mut Self) {
203 let (old_ptr, old_capacity) =
204 self.prepare_realloc(self.x.len().saturating_add(other.len()));
205 self.x.append(&mut other.x);
206 self.check_ptr(old_ptr, old_capacity, "resize");
207 }
208
209 pub fn shuffle<R: Rng>(&mut self, rng: &mut R) {
210 self.x.shuffle(rng)
211 }
212
213 fn check_ptr(&mut self, old_ptr: *mut T, old_capacity: usize, from: &'static str) {
214 let api = perf_libs::api();
215 if api.is_some()
216 && self.pinnable
217 && (!std::ptr::eq(self.x.as_ptr(), old_ptr) || self.x.capacity() != old_capacity)
218 {
219 if self.pinned {
220 unpin(old_ptr);
221 }
222
223 trace!(
224 "pinning from check_ptr old: {} size: {} from: {}",
225 old_capacity,
226 self.x.capacity(),
227 from
228 );
229 pin(&mut self.x);
230 self.pinned = true;
231 }
232 }
233}
234
235impl<T: Clone + Default + Sized> Clone for PinnedVec<T> {
236 fn clone(&self) -> Self {
237 let mut x = self.x.clone();
238 let pinned = if self.pinned {
239 pin(&mut x);
240 true
241 } else {
242 false
243 };
244 debug!(
245 "clone PinnedVec: size: {} pinned?: {} pinnable?: {}",
246 self.x.capacity(),
247 self.pinned,
248 self.pinnable
249 );
250 Self {
251 x,
252 pinned,
253 pinnable: self.pinnable,
254 recycler: self.recycler.clone(),
255 }
256 }
257}
258
impl<T: Sized + Default + Clone> Deref for PinnedVec<T> {
    type Target = Vec<T>;

    /// Exposes the read-only `Vec`/slice API (`len`, `iter`, `get`, …).
    fn deref(&self) -> &Self::Target {
        &self.x
    }
}
266
impl<T: Sized + Default + Clone> DerefMut for PinnedVec<T> {
    /// Mutable access to the inner `Vec`.
    ///
    /// NOTE(review): growing the vec through this handle bypasses the
    /// `prepare_realloc`/`check_ptr` pinning bookkeeping — callers should
    /// only mutate elements in place, not reallocate.
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.x
    }
}
272
impl<T: Sized + Default + Clone> Drop for PinnedVec<T> {
    /// Returns the buffer to its recycler when one is still alive; otherwise
    /// unregisters pinned memory before `Vec` frees it.
    fn drop(&mut self) {
        if let Some(recycler) = self.recycler.upgrade() {
            // `mem::take` moves the pinned buffer into the pool and leaves a
            // default (unpinned, recycler-less) value behind, so the rest of
            // this drop has nothing further to clean up.
            recycler.recycle(std::mem::take(self));
        } else if self.pinned {
            // No pool: unregister with the CUDA runtime before deallocation.
            unpin(self.x.as_mut_ptr());
        }
    }
}
282
impl<T: Sized + Default + Clone + PartialEq> PartialEq for PinnedVec<T> {
    /// Equality compares element contents only; pinning and recycler state
    /// are ignored.
    fn eq(&self, other: &Self) -> bool {
        self.x.eq(&other.x)
    }
}

impl<T: Sized + Default + Clone + PartialEq + Eq> Eq for PinnedVec<T> {}
290
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pinned_vec() {
        // When no perf library is loaded, pin/unpin are no-ops, so this
        // exercises the plain Vec bookkeeping paths of push/resize/iter.
        let mut mem = PinnedVec::with_capacity(10);
        mem.set_pinnable();
        mem.push(50);
        // Growing from len 1 to 2 fills the new slot with 10.
        mem.resize(2, 10);
        assert_eq!(mem[0], 50);
        assert_eq!(mem[1], 10);
        assert_eq!(mem.len(), 2);
        assert!(!mem.is_empty());
        let mut iter = mem.iter();
        assert_eq!(*iter.next().unwrap(), 50);
        assert_eq!(*iter.next().unwrap(), 10);
        assert_eq!(iter.next(), None);
    }
}