ocl_stream/executor/
mod.rs1use crate::executor::context::ExecutorContext;
8use crate::executor::stream::{OCLStream, OCLStreamSender};
9use crate::utils::result::OCLStreamResult;
10use crate::utils::shared_buffer::SharedBuffer;
11use ocl::{OclPrm, ProQue};
12use std::any::type_name;
13use std::sync::Arc;
14use std::thread;
15
16pub mod context;
17pub mod stream;
18
19#[derive(Clone)]
21pub struct OCLStreamExecutor {
22 pro_que: ProQue,
23 concurrency: usize,
24}
25
26impl OCLStreamExecutor {
27 pub fn new(pro_que: ProQue) -> Self {
35 Self {
36 pro_que,
37 concurrency: 1,
38 }
39 }
40
41 pub fn set_concurrency(&mut self, mut num_tasks: usize) {
46 if num_tasks == 0 {
47 num_tasks = num_cpus::get();
48 }
49 self.concurrency = num_tasks;
50 }
51
52 pub fn execute_bounded<F, T>(&self, size: usize, func: F) -> OCLStream<T>
54 where
55 F: Fn(ExecutorContext<T>) -> OCLStreamResult<()> + Send + Sync + 'static,
56 T: Send + Sync + 'static,
57 {
58 let (stream, sender) = stream::bounded(size);
59 self.execute(func, sender);
60
61 stream
62 }
63
64 pub fn pro_que(&self) -> &ProQue {
66 &self.pro_que
67 }
68
69 pub fn execute_unbounded<F, T>(&self, func: F) -> OCLStream<T>
72 where
73 F: Fn(ExecutorContext<T>) -> OCLStreamResult<()> + Send + Sync + 'static,
74 T: Send + Sync + 'static,
75 {
76 let (stream, sender) = stream::unbounded();
77 self.execute(func, sender);
78
79 stream
80 }
81
82 fn execute<F, T>(&self, func: F, sender: OCLStreamSender<T>)
84 where
85 F: Fn(ExecutorContext<T>) -> OCLStreamResult<()> + Send + Sync + 'static,
86 T: Send + Sync + 'static,
87 {
88 let func = Arc::new(func);
89
90 log::debug!("Spawning {} executor threads", self.concurrency);
91
92 for task_id in 0..(self.concurrency) {
93 let func = Arc::clone(&func);
94 let context = self.build_context(task_id, sender.clone());
95
96 thread::Builder::new()
97 .name(format!("ocl-{}", task_id))
98 .spawn(move || {
99 let sender = context.sender().clone();
100
101 log::debug!("Running function in thread {}", task_id);
102 if let Err(e) = func(context) {
103 log::error!("Execution of function failed: {}", e);
104 if let Err(e) = sender.err(e) {
105 panic!("Failed to forward error to receiver: {}", e);
106 }
107 }
108 })
109 .expect("Failed to spawn ocl thread");
110 }
111 }
112
113 pub fn create_shared_buffer<T>(&self, len: usize) -> ocl::Result<SharedBuffer<T>>
115 where
116 T: OclPrm,
117 {
118 log::trace!(
119 "Creating shared buffer of length {} and type {}",
120 len,
121 type_name::<T>()
122 );
123 let buffer = self.pro_que.buffer_builder().len(len).build()?;
124 Ok(SharedBuffer::new(buffer))
125 }
126
127 fn build_context<T>(&self, task_id: usize, sender: OCLStreamSender<T>) -> ExecutorContext<T>
129 where
130 T: Send + Sync,
131 {
132 ExecutorContext::new(self.pro_que.clone(), task_id, sender)
133 }
134}