1mod task_builder;
2mod kernel_arg;
3use std::os::raw::c_void;
4use std::sync::Arc;
5
6use crate::{
7 async_executor::task_builder::TaskBuilder,
8 cl_types::{
9 cl_buffer::ClBuffer,
10 cl_event::ClEvent,
11 cl_command_queue::{ClCommandQueue, command_queue_parameters::{CommandQueueProperties, Version20}},
12 cl_context::ClContext,
13 cl_device::ClDevice,
14 cl_platform::ClPlatform,
15 cl_kernel::ClKernel,
16 cl_device::opencl_version::OpenCLVersion,
17 cl_program::{ClProgram, Builded, NotBuilded, program_parameters::ProgramParameters},
18 cl_image::{ClImage, image_desc::ClImageDesc, image_formats::ClImageFormats},
19 cl_svm_buffer::ClSvmBuffer,
20 memory_flags::MemoryFlags,
21 },
22 error::ClError
23};
24
/// Executor that owns a single OpenCL context together with one command
/// queue per device in that context.
///
/// `queues`, `weights`, `device_versions` and `devices` are filled in lockstep
/// by the constructors, so index `i` refers to the same device in all four.
#[cfg(feature = "CL_VERSION_1_1")]
pub struct AsyncExecutor {
    /// Context created over every device managed by this executor.
    context: Arc<ClContext>,
    /// One command queue per device.
    queues: Vec<ClCommandQueue>,
    /// Capacity score of each device, as computed by `measure_device_capacity`.
    weights: Vec<u64>,
    /// Whether the queues were requested with profiling enabled.
    profiling_enabled: bool,
    /// OpenCL version reported by each device.
    device_versions: Vec<OpenCLVersion>,
    /// The devices the context and queues were created for.
    devices: Vec<ClDevice>,
}
41
42#[cfg(feature = "CL_VERSION_1_1")]
43impl AsyncExecutor {
44 pub fn new_best_platform() -> Result<Self, ClError> {
55 Self::new_best_platform_with_options(false)
56 }
57
58 pub fn new_best_platform_with_options(profiling_enabled: bool) -> Result<Self, ClError> {
63 let platforms = ClPlatform::get_all()?;
64
65 let scores: Vec<u64> = platforms
66 .iter()
67 .map(|f| Self::measure_platform_capacity(f))
68 .collect::<Result<Vec<u64>, ClError>>()?;
69
70 let best_score_index = {
71 let mut idx: usize = 0;
72 let mut max: u64 = 0;
73 for (i, v) in scores.iter().enumerate() {
74 if *v > max {
75 idx = i;
76 max = *v
77 }
78 }
79 idx
80 };
81
82 let devices = match platforms.get(best_score_index) {
83 Some(plat) => plat.get_all_devices()?,
84 None => {
85 return Err(ClError::Wrapper(
86 crate::error::wrapper_error::WrapperError::PlatformsNotFound,
87 ));
88 }
89 };
90
91 let context = Arc::new(ClContext::new(&devices)?);
92 let mut queues = Vec::new();
93 let mut weights = Vec::new();
94 let mut device_versions = Vec::new();
95
96 for device in &devices {
97 let version = device.get_opencl_version();
98 let mut supports_out_of_order = false;
99
100 let queue = if version >= OpenCLVersion::V2_0 {
101 if let Ok(host_props) = device.get_queue_on_host_properties() {
102 supports_out_of_order = (host_props as u64 & cl3::command_queue::CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE) != 0;
103 }
104
105 let properties = CommandQueueProperties::<Version20>::new()
106 .set_cl_queue_properties(supports_out_of_order, profiling_enabled, false, false)
107 .get_properties();
108 ClCommandQueue::create_command_queue_with_properties(&context, device, &properties)?
109 } else {
110 let mut properties = 0;
111 if profiling_enabled {
112 properties |= cl3::command_queue::CL_QUEUE_PROFILING_ENABLE;
113 }
114 #[allow(deprecated)]
115 ClCommandQueue::create_command_queue(&context, device, properties)?
116 };
117
118 queues.push(queue);
119 weights.push(Self::measure_device_capacity(device)?);
120 device_versions.push(version);
121 }
122
123 let executor = Self {
124 context,
125 queues,
126 weights,
127 device_versions,
128 profiling_enabled,
129 devices: devices.into_iter().map(|d| d.clone()).collect(),
130 };
131
132 Ok(executor)
133 }
134
135 pub fn new_all_platforms() -> Result<Self, ClError> {
136 let platforms = ClPlatform::get_all()?;
137 Self::new_from_platforms(&platforms)
138 }
139
140 pub fn new_from_platforms(platforms: &[ClPlatform]) -> Result<Self, ClError> {
141 let mut all_devices = Vec::new();
142 for platform in platforms {
143 let devices = platform.get_all_devices()?;
144 all_devices.extend(devices);
145 }
146 Self::new_from_devices(&all_devices)
147 }
148
149 pub fn new_from_devices(devices: &[ClDevice]) -> Result<Self, ClError> {
150 Self::new_from_devices_with_options(devices, false)
151 }
152
153 pub fn new_from_devices_with_options(devices: &[ClDevice], profiling_enabled: bool) -> Result<Self, ClError> {
154 let devices_vec = devices.to_vec();
155 let context = Arc::new(ClContext::new(&devices_vec)?);
156 let mut queues = Vec::new();
157 let mut weights = Vec::new();
158 let mut device_versions = Vec::new();
159
160 for device in devices {
161 let version = device.get_opencl_version();
162 let mut supports_out_of_order = false;
163
164 let queue = if version >= OpenCLVersion::V2_0 {
165 if let Ok(host_props) = device.get_queue_on_host_properties() {
166 supports_out_of_order = (host_props as u64 & cl3::command_queue::CL_QUEUE_OUT_OF_ORDER_EXEC_MODE_ENABLE) != 0;
167 }
168
169 let properties = CommandQueueProperties::<Version20>::new()
170 .set_cl_queue_properties(supports_out_of_order, profiling_enabled, false, false)
171 .get_properties();
172 ClCommandQueue::create_command_queue_with_properties(&context, device, &properties)?
173 } else {
174 let mut properties = 0;
175 if profiling_enabled {
176 properties |= cl3::command_queue::CL_QUEUE_PROFILING_ENABLE;
177 }
178 #[allow(deprecated)]
179 ClCommandQueue::create_command_queue(&context, device, properties)?
180 };
181
182 queues.push(queue);
183 weights.push(Self::measure_device_capacity(device)?);
184 device_versions.push(version);
185 }
186
187 Ok(Self {
188 context,
189 queues,
190 weights,
191 device_versions,
192 profiling_enabled,
193 devices: devices_vec,
194 })
195 }
196
197 pub fn is_profiling_enabled(&self) -> bool {
198 self.profiling_enabled
199 }
200
201 pub fn get_context(&self) -> Arc<ClContext> {
202 self.context.clone()
203 }
204
205 pub fn get_device_versions(&self) -> &[OpenCLVersion] {
206 &self.device_versions
207 }
208
209 pub fn get_devices(&self) -> Result<Vec<ClDevice>, ClError> {
210 Ok(self.devices.iter().map(|d| d.clone()).collect())
211 }
212
213 pub fn get_queues(&self) -> &[ClCommandQueue] {
214 &self.queues
215 }
216
217 pub fn create_task(&self, kernel: ClKernel) -> TaskBuilder<'_> {
227 TaskBuilder::new(self, kernel)
228 }
229
230 pub fn build_program(&self, source: String, options: Option<&str>) -> Result<ClProgram<Builded>, ClError> {
242 let unbuilded = ClProgram::<NotBuilded>::from_src(&self.context, source)?;
243 let devices = self.context.get_devices()?;
244
245 let params = match options {
246 Some(opt) => ProgramParameters::default().custom(opt).get_parameters(),
247 None => ProgramParameters::default().get_parameters(),
248 };
249
250 unbuilded.build(¶ms, &devices)
251 }
252
253 pub fn compile_or_binary(
258 &self,
259 src_path: &str,
260 binary_dest_folder: &str,
261 options: Option<&str>,
262 ) -> Result<ClProgram<Builded>, ClError> {
263 use std::fs;
264 use std::path::Path;
265 use std::io::Read;
266 use crate::error::wrapper_error::WrapperError;
267
268 let path = Path::new(src_path);
269 let file_stem = path.file_stem()
270 .and_then(|s| s.to_str())
271 .ok_or(ClError::Wrapper(WrapperError::FailedToConvertStrToCString))?;
272
273 let devices = self.context.get_devices()?;
274
275 let mut binaries: Vec<Vec<u8>> = Vec::new();
276 let mut use_binaries = true;
277
278 for (i, device) in devices.iter().enumerate() {
279 let device_name = device.get_name()?.replace(" ", "_");
280 let bin_filename = format!("{}_{}_{}.bin", file_stem, device_name, i);
281 let bin_path = Path::new(binary_dest_folder).join(bin_filename);
282
283 if bin_path.exists() {
284 match fs::read(&bin_path) {
285 Ok(content) => binaries.push(content),
286 Err(_) => {
287 use_binaries = false;
288 break;
289 }
290 }
291 } else {
292 use_binaries = false;
293 break;
294 }
295 }
296
297 if use_binaries && binaries.len() == devices.len() {
298 let binary_slices: Vec<&[u8]> = binaries.iter().map(|b| b.as_slice()).collect();
299
300 match ClProgram::<NotBuilded>::from_binary(&self.context, &devices, &binary_slices) {
301 Ok(program) => {
302 let params = match options {
303 Some(opt) => ProgramParameters::default().custom(opt).get_parameters(),
304 None => ProgramParameters::default().get_parameters(),
305 };
306
307 match program.build(¶ms, &devices) {
308 Ok(built) => return Ok(built),
309 Err(_) => {
310 }
312 }
313 },
314 Err(_) => {
315 }
317 }
318 }
319
320 let mut src_content = String::new();
322 fs::File::open(src_path)
323 .map_err(|_| ClError::Wrapper(WrapperError::FileIOError))?
324 .read_to_string(&mut src_content)
325 .map_err(|_| ClError::Wrapper(WrapperError::FileIOError))?;
326
327 let built_program = self.build_program(src_content, options)?;
328
329 let _ = built_program.save_binary(binary_dest_folder, file_stem);
331
332 Ok(built_program)
333 }
334
335 pub fn create_kernel(&self, program: &ClProgram<Builded>, name: &str) -> Result<ClKernel, ClError> {
338 ClKernel::new(program, name)
339 }
340
341 pub fn create_buffer(&self, flags: &[MemoryFlags], size: usize, host_ptr: *mut c_void) -> Result<ClBuffer, ClError> {
345 ClBuffer::new(&self.context, &flags.to_vec(), size, host_ptr)
346 }
347
348 #[cfg(feature = "CL_VERSION_1_2")]
351 pub fn create_image(
352 &self,
353 flags: &[MemoryFlags],
354 format: &ClImageFormats,
355 desc: &ClImageDesc,
356 host_ptr: *mut c_void
357 ) -> Result<ClImage, ClError> {
358 ClImage::new(&self.context, &flags.to_vec(), format, desc, host_ptr)
359 }
360
361 #[cfg(feature = "CL_VERSION_2_0")]
364 pub fn create_svm_buffer<T>(&self, flags: &[MemoryFlags], len: usize) -> Result<ClSvmBuffer<T>, ClError> {
365 ClSvmBuffer::<T>::new(&self.context, &flags.to_vec(), len, 0)
366 }
367
368 pub async fn read_buffer<T: Sized>(
371 &self,
372 buffer: &ClBuffer,
373 host_memory: &mut [T],
374 ) -> Result<ClEvent, ClError> {
375 let queue = self.get_optimal_queue();
376 queue.enqueue_read_buffer(buffer, None, host_memory, None).await
377 }
378
379 pub async fn write_buffer<T: Sized>(
382 &self,
383 buffer: &ClBuffer,
384 host_memory: &mut [T],
385 ) -> Result<ClEvent, ClError> {
386 let queue = self.get_optimal_queue();
387 let size = host_memory.len() * std::mem::size_of::<T>();
388 queue.write_buffer(buffer, host_memory.as_mut_ptr() as *mut c_void, 0, size, None).await
389 }
390
391 #[cfg(feature = "CL_VERSION_1_2")]
393 pub async fn read_image<T: Sized>(
394 &self,
395 image: &ClImage,
396 host_memory: &mut [T],
397 origin: [usize; 3],
398 region: [usize; 3],
399 ) -> Result<ClEvent, ClError> {
400 let queue = self.get_optimal_queue();
401 queue.read_image_raw(
402 image,
403 origin,
404 region,
405 0,
406 0,
407 host_memory.as_mut_ptr() as *mut c_void,
408 None
409 ).await
410 }
411
412 #[cfg(feature = "CL_VERSION_1_2")]
414 pub async fn write_image<T: Sized>(
415 &self,
416 image: &ClImage,
417 host_memory: &mut [T],
418 origin: [usize; 3],
419 region: [usize; 3],
420 ) -> Result<ClEvent, ClError> {
421 let queue = self.get_optimal_queue();
422 queue.write_image_raw(
423 image,
424 origin,
425 region,
426 0,
427 0,
428 host_memory.as_mut_ptr() as *mut c_void,
429 None
430 ).await
431 }
432
433 fn get_optimal_queue(&self) -> &ClCommandQueue {
439 let mut max_weight = 0;
440 let mut idx = 0;
441 for (i, &weight) in self.weights.iter().enumerate() {
442 if weight > max_weight {
443 max_weight = weight;
444 idx = i;
445 }
446 }
447 &self.queues[idx]
448 }
449
450 fn measure_platform_capacity(platform: &ClPlatform) -> Result<u64, ClError> {
451 let mut score: u64 = 0;
452
453 let devices = platform.get_all_devices()?;
454 for device in &devices {
455 score += Self::measure_device_capacity(device)?;
456 }
457 Ok(score)
458 }
459
460 fn measure_device_capacity(device: &ClDevice) -> Result<u64, ClError> {
461 let compute_units = device.get_max_compute_units()?;
462 let clock_frequency = device.get_max_clock_frequency()?;
463 let memory = device.get_global_mem_size()? / (1024 * 1024);
464
465 Ok(((compute_units as u64 * clock_frequency as u64) / 100) + (memory / 10))
466 }
467}
468
// The impls carry the same feature gate as `AsyncExecutor` itself; without it
// the file does not compile when `CL_VERSION_1_1` is disabled, because the
// type they name is compiled out.
//
// SAFETY: NOTE(review): these impls assert that the wrapped OpenCL handles
// (`ClContext`, `ClCommandQueue`, `ClDevice`) may be shared between and moved
// across threads. Nothing in this file proves that — confirm against those
// wrappers before relying on it.
#[cfg(feature = "CL_VERSION_1_1")]
unsafe impl Sync for AsyncExecutor {}
#[cfg(feature = "CL_VERSION_1_1")]
unsafe impl Send for AsyncExecutor {}