nevermind_neu/layers/
input_layer_ocl.rs1use ocl::{Buffer, Context, Device, MemFlags, Queue};
2
3use log::{debug, error};
4
5use crate::layers::*;
6use crate::cpu_params::CpuParams;
7use crate::ocl::*;
8use crate::util::*;
9
10use std::{collections::HashMap, error::Error};
11
12#[derive(Clone)]
13pub struct InputLayerOcl {
14 ocl_params: Option<OclParams>,
15 size: usize,
16 batch_size: usize,
17 ocl_queue: Option<Queue>,
18}
19
20impl InputLayerOcl {
21 pub fn new(size: usize) -> Self {
22 Self {
23 ocl_params: None,
24 size,
25 batch_size: 1,
26 ocl_queue: None,
27 }
28 }
29}
30
31impl AbstractLayer for InputLayerOcl {
32 fn layer_type(&self) -> &str {
33 "InputLayerOcl"
34 }
35
36 fn size(&self) -> usize {
37 self.size
38 }
39
40 fn set_batch_size(&mut self, batch_size: usize) {
41 self.batch_size = batch_size;
42 }
43
44 fn cpu_params(&self) -> Option<CpuParams> {
45 None
46 }
47
48 fn trainable_bufs(&self) -> TrainableBufsIds {
49 (&[], &[])
50 }
51
52 fn serializable_bufs(&self) -> &[i32] {
53 &[]
54 }
55
56 fn set_cpu_params(&mut self, lp: CpuParams) {}
57
58 fn set_input_shape(&mut self, sh: &[usize]) {}
59
60 fn copy_layer(&self) -> Box<dyn AbstractLayer> {
62 panic!("Do not copy OCL layers !");
63 }
64
65 fn clone_layer(&self) -> Box<dyn AbstractLayer> {
67 panic!("Do not copy OCL layers !");
68 }
69}
70
71impl AbstractLayerOcl for InputLayerOcl {
72 fn init_ocl(
73 &mut self,
74 _ocl_ctx: &Context,
75 _device: Device,
76 queue: Queue,
77 ) -> Result<(), Box<dyn Error>> {
78 self.ocl_queue = Some(queue);
79 Ok(())
80 }
81
82 fn forward_input_ocl(&mut self, input_data: Array2D) -> LayerOclResult {
83 if input_data.ncols() != self.size {
84 error!("[InputOCL] Invalid data size : {} | {}", input_data.ncols(), self.size);
85 return Result::Err(LayerError::InvalidSize);
86 }
87
88 let ocl_queue = self.ocl_queue.as_ref().unwrap();
89 let ocl_buf = Buffer::builder()
90 .queue(ocl_queue.clone())
91 .flags(MemFlags::new().read_write())
92 .len(input_data.len())
93 .copy_host_slice(input_data.as_slice().unwrap())
94 .build()
95 .expect("[inp_ocl] Couldn't create "); self.ocl_params = Some(OclParams::only_output(ocl_buf, ocl_queue.clone()));
98
99 Ok(vec![self.ocl_params.as_ref().unwrap().clone()])
100 }
101
102 fn forward_ocl(&mut self, params: OclParamsBlob) -> LayerOclResult {
103 Err(LayerError::NotImpl)
104 }
105
106 fn ocl_params(&self) -> Option<OclParams> {
107 Some(self.ocl_params.as_ref().unwrap().clone())
108 }
109
110 fn copy_layer_ocl(&self) -> Box<dyn AbstractLayerOcl> {
111 todo!()
112 }
113
114 fn clone_layer_ocl(&self) -> Box<dyn AbstractLayerOcl> {
115 Box::new(self.clone())
116 }
117}
118
119impl Default for InputLayerOcl {
120 fn default() -> Self {
121 Self {
122 ocl_params: None,
123 size: 0,
124 batch_size: 1,
125 ocl_queue: None,
126 }
127 }
128}
129
130impl WithParams for InputLayerOcl {
131 fn cfg(&self) -> HashMap<String, Variant> {
132 let mut out = HashMap::new();
133
134 out.insert("size".to_string(), Variant::Int(self.size as i32));
135
136 out
137 }
138
139 fn set_cfg(&mut self, args: &HashMap<String, Variant>) {
140 if let Some(size) = args.get("size") {
141 if let Variant::Int(size) = size {
142 self.size = *size as usize;
143 }
144 }
145 }
146}