kn_cuda_eval/
executor.rs

use std::fmt::{Debug, Display, Formatter};
use std::time::Instant;

use bytemuck::{cast_slice, Pod};
use bytemuck::checked::cast_slice_mut;
use itertools::{multizip, zip_eq};

use kn_cuda_sys::wrapper::handle::{CublasHandle, CudaDevice, CudaStream, CudnnHandle};
use kn_cuda_sys::wrapper::mem::device::DevicePtr;
use kn_cuda_sys::wrapper::mem::pinned::PinnedMem;
use kn_graph::{dispatch_dtensor, dispatch_dtype};
use kn_graph::dtype::{DBool, DTensor, Tensor};
use kn_graph::graph::Graph;

use crate::device_tensor::DeviceTensor;
use crate::planner::{MemoryUsage, Plan, Planner};
use crate::step::{Handles, Step, StepInfo};
use crate::util::debug_vec_multiline;

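/// Executor that repeatedly evaluates a compiled [`Graph`] on a single CUDA device,
/// for the fixed batch size chosen at construction time.
///
/// A minimal usage sketch. This is hedged: `CudaDevice::new(0)` and `load_graph_somehow()`
/// are assumed/hypothetical stand-ins, not APIs defined in this file, and the input shape
/// is graph-dependent:
///
/// ```ignore
/// let device = CudaDevice::new(0); // assumed device constructor
/// let graph: Graph = load_graph_somehow(); // hypothetical helper
/// let mut executor = CudaExecutor::new(device, &graph, 1);
///
/// // shape and dtype must match the graph's first input
/// let input = DTensor::F32(Tensor::zeros(vec![1, 3, 8, 8]));
/// let outputs = executor.evaluate(&[input]);
/// ```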
pub struct CudaExecutor {
    pub handles: Handles,

    pub device_inputs: Vec<DeviceTensor>,
    pub device_outputs: Vec<DeviceTensor>,

    pub batch_size: usize,
    pub mem_usage: MemoryUsage,
    steps: Vec<StepInfo<DevicePtr>>,

    profile: bool,
    last_profile: Option<Profile>,

    // TODO switch to single in/out buffer each, so we do a single memcpy between host and device?
    buffer_inputs: Vec<PinnedMem>,
    buffer_outputs: Vec<PinnedMem>,
    tensor_outputs: Vec<DTensor>,
}

// TODO is this safe?
//   check what the CUDA docs say about transferring streams and all the structure pointers across threads
unsafe impl Send for CudaExecutor {}

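/// Per-operation timing breakdown recorded by a profiled run.
///
/// Times are stored in seconds; the [`Display`] impl prints them as milliseconds.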
#[derive(Default, Debug, Clone)]
pub struct Profile {
    pub steps: Vec<String>,

    pub conv: f32,
    pub mat_mul: f32,
    pub scalar_op: f32,
    pub reduce_op: f32,
    pub softmax_op: f32,
    pub layernorm_op: f32,
    pub gather_op: f32,

    pub total_cpu: f32,
    pub total_gpu: f32,
    pub timing_overhead: f32,
}

impl CudaExecutor {
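    /// Plan the graph for the given batch size, allocate device tensors and pinned host
    /// staging buffers, and prepare zero-initialized host output tensors.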
    pub fn new(device: CudaDevice, graph: &Graph, batch_size: usize) -> Self {
        let handles = Handles {
            cudnn: CudnnHandle::new(device),
            cublas: CublasHandle::new(device),
        };

        let Plan {
            inputs,
            outputs,
            steps,
            mem_usage,
        } = Planner::plan(&handles, graph, batch_size);

        let buffer_inputs = inputs
            .iter()
            .map(|x| {
                let len_bytes = x.strided_shape().size() * x.dtype().size().bytes();
                PinnedMem::alloc(len_bytes, false)
            })
            .collect();
        let buffer_outputs = outputs
            .iter()
            .map(|x| {
                let len_bytes = x.strided_shape().size() * x.dtype().size().bytes();
                PinnedMem::alloc(len_bytes, false)
            })
            .collect();
        let tensor_outputs = outputs
            .iter()
            .map(|x| {
                let shape = x.strided_shape().shape().to_vec();
                let dtype = x.dtype();
                dispatch_dtype!(dtype, |_T, _fs, ft| ft(Tensor::zeros(shape)))
            })
            .collect();

        CudaExecutor {
            handles,
            batch_size,
            mem_usage,
            device_inputs: inputs,
            device_outputs: outputs,
            steps,
            profile: false,
            last_profile: None,
            buffer_inputs,
            buffer_outputs,
            tensor_outputs,
        }
    }

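    /// The CUDA stream on which the steps and host/device copies are submitted.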
    pub fn stream(&self) -> &CudaStream {
        self.handles.stream()
    }

    // TODO accept views as inputs? introduce DView struct/alias?
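    /// Copy `inputs` to the device, run all planned steps, and copy the outputs back to the host.
    ///
    /// This call blocks until evaluation has finished. The returned tensors are owned by the
    /// executor and are overwritten by the next call to `evaluate`.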
    pub fn evaluate(&mut self, inputs: &[DTensor]) -> &[DTensor] {
        assert_eq!(inputs.len(), self.device_inputs.len(), "Wrong input count");
        for (i, (input, tensor)) in zip_eq(inputs, &self.device_inputs).enumerate() {
            assert_eq!(
                input.shape(),
                tensor.strided_shape().shape(),
                "Wrong shape for input {}",
                i
            );
            assert_eq!(input.dtype(), tensor.dtype(), "Wrong dtype for input {}", i);
        }

        unsafe {
            // make sure nothing else is using the buffers
            self.stream().synchronize();

            for (input, buffer, tensor) in multizip((inputs, &self.buffer_inputs, &self.device_inputs)) {
                // copy inputs to buffer
                // TODO is there a simple way to avoid the potential extra layout copy?
                dispatch_dtensor!(input, |T, _f, input| {
                    let input = input.as_standard_layout();
                    let input_slice = input.as_slice().unwrap();
                    buffer.as_slice().copy_from_slice(cast_slice::<T, u8>(input_slice));
                });

                // copy buffer to device
                assert!(tensor.strided_shape().has_simple_strides());
                tensor.ptr().copy_linear_from_host_async(buffer, self.stream());
            }

            // run the steps
            self.run_async();

            // copy outputs to buffers
            for (buffer, tensor) in zip_eq(&mut self.buffer_outputs, &self.device_outputs) {
                tensor.ptr().copy_linear_to_host_async(buffer, self.handles.stream());
            }

            // wait for everything to complete
            self.stream().synchronize();

            // copy buffers to tensors
            // TODO interleave this with copying to host?
            for (buffer, tensor) in zip_eq(&self.buffer_outputs, &mut self.tensor_outputs) {
                let buffer: &[u8] = buffer.as_slice();

                unsafe fn branch<T: Pod>(tensor: &mut Tensor<T>, buffer: &[u8]) {
                    cast_slice_mut::<T, u8>(tensor.as_slice_mut().unwrap()).copy_from_slice(buffer);
                }

                match tensor {
                    DTensor::F32(tensor) => branch::<f32>(tensor, buffer),
                    DTensor::F64(tensor) => branch::<f64>(tensor, buffer),
                    DTensor::I8(tensor) => branch::<i8>(tensor, buffer),
                    DTensor::I16(tensor) => branch::<i16>(tensor, buffer),
                    DTensor::I32(tensor) => branch::<i32>(tensor, buffer),
                    DTensor::I64(tensor) => branch::<i64>(tensor, buffer),
                    DTensor::U8(tensor) => branch::<u8>(tensor, buffer),
                    DTensor::U16(tensor) => branch::<u16>(tensor, buffer),
                    DTensor::U32(tensor) => branch::<u32>(tensor, buffer),
                    DTensor::U64(tensor) => branch::<u64>(tensor, buffer),

                    // do a manual copy, with proper error checking
                    // we can't use bytemuck here since it rightfully doesn't want to cast &mut bool/DBool to &mut u8
                    DTensor::Bool(tensor) => {
                        let mut fail = false;
                        for (i, x) in tensor.iter_mut().enumerate() {
                            let y = buffer[i];
                            *x = DBool(y != 0);
                            fail |= y > 1;
                        }
                        assert!(!fail, "bool output buffer contained a byte other than 0 or 1");
                    }
                }
            }
        }

        &self.tensor_outputs
    }

    /// Run the steps in this executor asynchronously on the stream. No synchronization is done
    /// before or after: ensure the inputs have been written before calling this, and
    /// synchronize the stream before reading the outputs.
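    ///
    /// # Safety
    /// The caller must uphold the synchronization requirements above, and the current CUDA
    /// device must match the device of this executor's stream.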
    pub unsafe fn run_async(&mut self) {
        assert_eq!(self.stream().device(), CudaDevice::current());

        if !self.profile {
            for step_info in &self.steps {
                step_info.step.run(&self.handles);
            }

            self.last_profile = None;
        } else {
            let mut timers = vec![];

            // record the GPU start event and wait for it, so the CPU and GPU timers start together
            let start_gpu = self.stream().record_event();
            start_gpu.synchronize();

            let start_cpu = Instant::now();

            // surround each step with events so its GPU time can be measured afterwards
            for step_info in &self.steps {
                let start = self.stream().record_event();
                step_info.step.run(&self.handles);
                let end = self.stream().record_event();

                if self.profile {
                    timers.push((step_info, start, end));
                }
            }

            let end_gpu = self.stream().record_event();
            self.stream().synchronize();

            let end_cpu = Instant::now();

            let mut profile = Profile::default();

            for (i, (step_info, start, end)) in timers.iter().enumerate() {
                let time = end.time_elapsed_since(start);

                // attribute this step's time to the matching per-operation bucket
                *match step_info.step {
                    Step::Conv { .. } => &mut profile.conv,
                    Step::MatMul { .. } => &mut profile.mat_mul,
                    Step::ScalarOp { .. } => &mut profile.scalar_op,
                    Step::ReduceOp { .. } => &mut profile.reduce_op,
                    Step::SoftmaxOp { .. } => &mut profile.softmax_op,
                    Step::LayernormOp { .. } => &mut profile.layernorm_op,
                    Step::GatherOp { .. } => &mut profile.gather_op,
                } += time;

                profile
                    .steps
                    .push(format!("{: >4} time {:>10.4} ms, step {:?}", i, time * 1e3, step_info));
            }

            // everything after end_cpu is timing/formatting overhead
            let overhead_end = Instant::now();
            profile.total_gpu = end_gpu.time_elapsed_since(&start_gpu);
            profile.total_cpu = (end_cpu - start_cpu).as_secs_f32();
            profile.timing_overhead = (overhead_end - end_cpu).as_secs_f32();

            self.last_profile = Some(profile);
        }
    }

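    /// Enable or disable per-step profiling for subsequent runs.
    ///
    /// A minimal sketch of the profiling workflow, assuming `inputs` has already been
    /// prepared to match the graph:
    ///
    /// ```ignore
    /// executor.set_profile(true);
    /// executor.evaluate(&inputs);
    /// println!("{}", executor.last_profile().unwrap());
    /// ```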
    pub fn set_profile(&mut self, profile: bool) {
        self.profile = profile;
    }

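    /// The profile recorded by the most recent run, if profiling was enabled for it.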
    pub fn last_profile(&self) -> Option<&Profile> {
        self.last_profile.as_ref()
    }
}

impl Debug for CudaExecutor {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "CudaExecutor {{")?;

        writeln!(f, "    batch_size: {},", self.batch_size)?;
        writeln!(f, "    mem_usage: {:?},", self.mem_usage)?;
        writeln!(f, "    profile: {},", self.profile)?;

        writeln!(f, "    inputs: {:?},", debug_vec_multiline("    ", &self.device_inputs))?;
        writeln!(
            f,
            "    outputs: {:?},",
            debug_vec_multiline("    ", &self.device_outputs)
        )?;
        writeln!(f, "    steps: {:?},", debug_vec_multiline("    ", &self.steps))?;

        writeln!(f, "}}")?;

        Ok(())
    }
}

impl Display for Profile {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "Profile {{\n  steps: [\n")?;
        for step in &self.steps {
            writeln!(f, "    {}", step)?;
        }
        write!(f, "  ]\n\n")?;

        let total = self.conv
            + self.mat_mul
            + self.scalar_op
            + self.reduce_op
            + self.softmax_op
            + self.layernorm_op
            + self.gather_op;
        let mut line = |name, time| writeln!(f, "  {} {:>10.4} ms  {:>4.2}", name, time * 1e3, time / total);

        line("Conv:      ", self.conv)?;
        line("Matmul:    ", self.mat_mul)?;
        line("Scalar:    ", self.scalar_op)?;
        line("Reduce:    ", self.reduce_op)?;
        line("Softmax:   ", self.softmax_op)?;
        line("Layernorm: ", self.layernorm_op)?;
        line("Gather:    ", self.gather_op)?;

        writeln!(f, "  ==============================")?;
        writeln!(f, "  Total GPU:  {:>10.4} ms", self.total_gpu * 1e3)?;
        writeln!(f, "  Total CPU:  {:>10.4} ms", self.total_cpu * 1e3)?;
        writeln!(f, "  Overhead:   {:>10.4} ms", self.timing_overhead * 1e3)?;

        writeln!(f, "}}")?;

        Ok(())
    }
}