nevermind_neu/layers/
input_layer_ocl.rs

1use ocl::{Buffer, Context, Device, MemFlags, Queue};
2
3use log::{debug, error};
4
5use crate::layers::*;
6use crate::cpu_params::CpuParams;
7use crate::ocl::*;
8use crate::util::*;
9
10use std::{collections::HashMap, error::Error};
11
12#[derive(Clone)]
13pub struct InputLayerOcl {
14    ocl_params: Option<OclParams>,
15    size: usize,
16    batch_size: usize,
17    ocl_queue: Option<Queue>,
18}
19
20impl InputLayerOcl {
21    pub fn new(size: usize) -> Self {
22        Self {
23            ocl_params: None,
24            size,
25            batch_size: 1,
26            ocl_queue: None,
27        }
28    }
29}
30
31impl AbstractLayer for InputLayerOcl {
32    fn layer_type(&self) -> &str {
33        "InputLayerOcl"
34    }
35
36    fn size(&self) -> usize {
37        self.size
38    }
39
40    fn set_batch_size(&mut self, batch_size: usize) {
41        self.batch_size = batch_size;
42    }
43
44    fn cpu_params(&self) -> Option<CpuParams> {
45        None
46    }
47
48    fn trainable_bufs(&self) -> TrainableBufsIds {
49        (&[], &[])
50    }
51
52    fn serializable_bufs(&self) -> &[i32] {
53        &[]
54    }
55
56    fn set_cpu_params(&mut self, lp: CpuParams) {}
57
58    fn set_input_shape(&mut self, sh: &[usize]) {}
59
60    // Do copy layer memory(ws, output, ...)
61    fn copy_layer(&self) -> Box<dyn AbstractLayer> {
62        panic!("Do not copy OCL layers !");
63    }
64
65    // Do copy only Rc
66    fn clone_layer(&self) -> Box<dyn AbstractLayer> {
67        panic!("Do not copy OCL layers !");
68    }
69}
70
71impl AbstractLayerOcl for InputLayerOcl {
72    fn init_ocl(
73        &mut self,
74        _ocl_ctx: &Context,
75        _device: Device,
76        queue: Queue,
77    ) -> Result<(), Box<dyn Error>> {
78        self.ocl_queue = Some(queue);
79        Ok(())
80    }
81
82    fn forward_input_ocl(&mut self, input_data: Array2D) -> LayerOclResult {
83        if input_data.ncols() != self.size {
84            error!("[InputOCL] Invalid data size : {} | {}", input_data.ncols(), self.size);
85            return Result::Err(LayerError::InvalidSize);
86        }
87
88        let ocl_queue = self.ocl_queue.as_ref().unwrap();
89        let ocl_buf = Buffer::builder()
90            .queue(ocl_queue.clone())
91            .flags(MemFlags::new().read_write())
92            .len(input_data.len())
93            .copy_host_slice(input_data.as_slice().unwrap())
94            .build()
95            .expect("[inp_ocl] Couldn't create "); // TODO : handle unwrap
96
97        self.ocl_params = Some(OclParams::only_output(ocl_buf, ocl_queue.clone()));
98
99        Ok(vec![self.ocl_params.as_ref().unwrap().clone()])
100    }
101
102    fn forward_ocl(&mut self, params: OclParamsBlob) -> LayerOclResult {
103        Err(LayerError::NotImpl)
104    }
105
106    fn ocl_params(&self) -> Option<OclParams> {
107        Some(self.ocl_params.as_ref().unwrap().clone())
108    }
109
110    fn copy_layer_ocl(&self) -> Box<dyn AbstractLayerOcl> {
111        todo!()
112    }
113
114    fn clone_layer_ocl(&self) -> Box<dyn AbstractLayerOcl> {
115        Box::new(self.clone())
116    }
117}
118
119impl Default for InputLayerOcl {
120    fn default() -> Self {
121        Self {
122            ocl_params: None,
123            size: 0,
124            batch_size: 1,
125            ocl_queue: None,
126        }
127    }
128}
129
130impl WithParams for InputLayerOcl {
131    fn cfg(&self) -> HashMap<String, Variant> {
132        let mut out = HashMap::new();
133
134        out.insert("size".to_string(), Variant::Int(self.size as i32));
135
136        out
137    }
138
139    fn set_cfg(&mut self, args: &HashMap<String, Variant>) {
140        if let Some(size) = args.get("size") {
141            if let Variant::Int(size) = size {
142                self.size = *size as usize;
143            }
144        }
145    }
146}