ocl_extras/sub_buffer_pool.rs

use std::collections::{LinkedList, HashMap};

use ocl::{Queue, Buffer};
use ocl::traits::OclPrm;
use ocl::flags::MemFlags;

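/// A region of the pool reserved by one sub-buffer: the sub-buffer's id plus
/// its origin and length, both in elements of `T`.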
pub struct PoolRegion {
    buffer_id: usize,
    origin: u32,
    len: u32,
}

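/// A first-fit sub-buffer allocator layered over a single parent `Buffer`.
///
/// Live regions are kept in a linked list sorted by origin; each region owns
/// one cached sub-buffer, keyed in a map by a pool-unique id.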
pub struct SubBufferPool<T: OclPrm> {
    buffer: Buffer<T>,
    regions: LinkedList<PoolRegion>,
    sub_buffers: HashMap<usize, Buffer<T>>,
    align: u32,
    _next_uid: usize,
}

impl<T: OclPrm> SubBufferPool<T> {
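    /// Creates a pool backed by a new `len`-element buffer on `default_queue`,
    /// recording the device's base address alignment for sub-buffer origins.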
    pub fn new(len: u32, default_queue: Queue) -> SubBufferPool<T> {
        // `mem_base_addr_align` is reported in bits; treating it directly as
        // an element count over-aligns sub-buffer origins, which is safe.
        let align = default_queue.device().mem_base_addr_align().unwrap();
        let flags = MemFlags::new().alloc_host_ptr().read_write();

        let buffer = Buffer::<T>::builder()
            .queue(default_queue)
            .flags(flags)
            .len(len as usize)
            .build().unwrap();

        SubBufferPool {
            buffer,
            regions: LinkedList::new(),
            sub_buffers: HashMap::new(),
            align,
            _next_uid: 0,
        }
    }

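    /// Rounds `unaligned_origin` up to the next alignment boundary. Note that
    /// an already-aligned value is still bumped to the following boundary.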
    fn next_valid_align(&self, unaligned_origin: u32) -> u32 {
        ((unaligned_origin / self.align) + 1) * self.align
    }

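    /// Returns the current uid, then advances the internal counter.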
    fn next_uid(&mut self) -> usize {
        self._next_uid += 1;
        self._next_uid - 1
    }

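    /// Inserts `region` at list position `region_idx`, preserving the
    /// origin-sorted order of the region list.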
    fn insert_region(&mut self, region: PoolRegion, region_idx: usize) {
        let mut tail = self.regions.split_off(region_idx);
        self.regions.push_back(region);
        self.regions.append(&mut tail);
    }

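    /// Creates a sub-buffer spanning `len` elements from `origin`, caches it
    /// under a fresh id, and records its region at `region_idx`.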
    fn create_sub_buffer(&mut self, region_idx: usize, flags: Option<MemFlags>,
            origin: u32, len: u32) -> usize {
        let buffer_id = self.next_uid();
        let region = PoolRegion { buffer_id, origin, len };
        let sbuf = self.buffer.create_sub_buffer(flags, region.origin as usize,
            region.len as usize).unwrap();
        // `insert` returns the displaced value, so `Some` here means the id
        // was already in use.
        if self.sub_buffers.insert(region.buffer_id, sbuf).is_some() {
            panic!("duplicate buffer id: {}", region.buffer_id);
        }
        self.insert_region(region, region_idx);
        buffer_id
    }

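    /// Allocates a `len`-element sub-buffer using a first-fit scan: the new
    /// region goes in the first aligned gap between existing regions that is
    /// large enough, or after the last region if room remains. Returns the
    /// new sub-buffer's id, or `Err(())` if the pool has no suitable space.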
    pub fn alloc(&mut self, len: u32, flags: Option<MemFlags>) -> Result<usize, ()> {
        assert_eq!(self.regions.len(), self.sub_buffers.len());

        match self.regions.front() {
            Some(_) => {
                let mut end_prev = 0;
                let mut create_at = None;

                for (region_idx, region) in self.regions.iter().enumerate() {
                    // `saturating_sub` guards against underflow when the
                    // aligned end of the previous region overshoots this
                    // region's origin (e.g. after an exact-fit allocation).
                    if region.origin.saturating_sub(end_prev) >= len {
                        create_at = Some(region_idx);
                        break;
                    } else {
                        end_prev = self.next_valid_align(region.origin + region.len);
                    }
                }

                if let Some(region_idx) = create_at {
                    Ok(self.create_sub_buffer(region_idx, flags, end_prev, len))
                } else if (self.buffer.len() as u32).saturating_sub(end_prev) >= len {
                    let region_idx = self.regions.len();
                    Ok(self.create_sub_buffer(region_idx, flags, end_prev, len))
                } else {
                    Err(())
                }
            },
            None => {
                Ok(self.create_sub_buffer(0, flags, 0, len))
            },
        }
    }

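    /// Frees the sub-buffer with the given id, dropping the cached handle and
    /// removing its region from the list. Returns the id back on failure.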
    pub fn free(&mut self, buffer_id: usize) -> Result<(), usize> {
        let mut region_idx = None;

        if self.sub_buffers.remove(&buffer_id).is_some() {
            region_idx = self.regions.iter().position(|r| r.buffer_id == buffer_id);
        }

        if let Some(r_idx) = region_idx {
            let mut tail = self.regions.split_off(r_idx);
            tail.pop_front().ok_or(buffer_id)?;
            self.regions.append(&mut tail);
            Ok(())
        } else {
            Err(buffer_id)
        }
    }

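    /// Returns a reference to the sub-buffer with the given id, if any.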
    pub fn get(&self, buffer_id: usize) -> Option<&Buffer<T>> {
        self.sub_buffers.get(&buffer_id)
    }

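    /// Mutable counterpart to `get`.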
    #[allow(dead_code)]
    pub fn get_mut(&mut self, buffer_id: usize) -> Option<&mut Buffer<T>> {
        self.sub_buffers.get_mut(&buffer_id)
    }

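    /// Compacts all regions toward the start of the buffer (currently
    /// unimplemented).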
    #[allow(dead_code)]
    pub fn defrag(&mut self) {
        unimplemented!();
    }

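    /// Resizes the backing buffer (currently unimplemented).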
    #[allow(dead_code, unused_variables)]
    pub fn resize(&mut self, len: u32) {
        unimplemented!();
    }
}
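
// A minimal usage sketch (not part of the original file). It assumes at
// least one OpenCL platform and device are present; the `f32` element type,
// the pool length, and the allocation sizes are arbitrary illustration
// values, and exact builder defaults may vary across `ocl` versions.
#[cfg(test)]
mod tests {
    use ocl::{Context, Queue};
    use super::SubBufferPool;

    #[test]
    fn alloc_and_free() {
        let context = Context::builder().build().unwrap();
        let device = context.devices()[0];
        let queue = Queue::new(&context, device, None).unwrap();

        // A pool of 1 << 20 `f32` elements.
        let mut pool = SubBufferPool::<f32>::new(1 << 20, queue);

        let a = pool.alloc(4096, None).expect("pool full");
        let b = pool.alloc(8192, None).expect("pool full");
        assert!(pool.get(a).is_some());

        // Freeing releases the region for reuse by later `alloc` calls.
        pool.free(b).unwrap();
        assert!(pool.get(b).is_none());
    }
}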