use std::fmt::{Debug, Display, Formatter};
use std::time::Instant;

use bytemuck::{cast_slice, Pod};
use bytemuck::checked::cast_slice_mut;
use itertools::{multizip, zip_eq};

use kn_cuda_sys::wrapper::handle::{CublasHandle, CudaDevice, CudaStream, CudnnHandle};
use kn_cuda_sys::wrapper::mem::device::DevicePtr;
use kn_cuda_sys::wrapper::mem::pinned::PinnedMem;
use kn_graph::{dispatch_dtensor, dispatch_dtype};
use kn_graph::dtype::{DBool, DTensor, Tensor};
use kn_graph::graph::Graph;

use crate::device_tensor::DeviceTensor;
use crate::planner::{MemoryUsage, Plan, Planner};
use crate::step::{Handles, Step, StepInfo};
use crate::util::debug_vec_multiline;

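/// Executes a [`Graph`] on a CUDA device with a fixed batch size.
///
/// All planning and memory allocation happens once, in [`CudaExecutor::new`];
/// after that, [`CudaExecutor::evaluate`] can be called repeatedly without
/// further allocation.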
pub struct CudaExecutor {
    pub handles: Handles,

    pub device_inputs: Vec<DeviceTensor>,
    pub device_outputs: Vec<DeviceTensor>,

    pub batch_size: usize,
    pub mem_usage: MemoryUsage,
    steps: Vec<StepInfo<DevicePtr>>,

    profile: bool,
    last_profile: Option<Profile>,

    buffer_inputs: Vec<PinnedMem>,
    buffer_outputs: Vec<PinnedMem>,
    tensor_outputs: Vec<DTensor>,
}

// Safety: the executor exclusively owns its device memory, buffers, and handles,
// and `run_async` asserts that the stream's device is current before running.
unsafe impl Send for CudaExecutor {}

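/// Per-operation timing breakdown of a profiled run, collected by
/// [`CudaExecutor::run_async`] when profiling is enabled.
/// All times are in seconds; `steps` holds one pre-formatted line per executed step.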
#[derive(Default, Debug, Clone)]
pub struct Profile {
    pub steps: Vec<String>,

    pub conv: f32,
    pub mat_mul: f32,
    pub scalar_op: f32,
    pub reduce_op: f32,
    pub softmax_op: f32,
    pub layernorm_op: f32,
    pub gather_op: f32,

    pub total_cpu: f32,
    pub total_gpu: f32,
    pub timing_overhead: f32,
}

impl CudaExecutor {
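    /// Plan the execution of `graph` at a fixed `batch_size` on `device`, and
    /// allocate all device memory, pinned staging buffers, and host output tensors up front.
    ///
    /// A minimal usage sketch, assuming a graph with a single f32 input of shape
    /// `[batch_size, 16]`; the ONNX loader comes from the companion `kn-graph`
    /// crate, and the file path is illustrative:
    ///
    /// ```ignore
    /// use kn_cuda_eval::executor::CudaExecutor;
    /// use kn_cuda_sys::wrapper::handle::CudaDevice;
    /// use kn_graph::dtype::{DTensor, Tensor};
    /// use kn_graph::onnx::load_graph_from_onnx_path;
    ///
    /// // load the graph and pick a device
    /// let graph = load_graph_from_onnx_path("model.onnx", false).unwrap();
    /// let device = CudaDevice::new(0).unwrap();
    ///
    /// // plan and allocate once, then evaluate as often as needed
    /// let batch_size = 8;
    /// let mut executor = CudaExecutor::new(device, &graph, batch_size);
    /// let inputs = [DTensor::F32(Tensor::zeros(vec![batch_size, 16]))];
    /// let outputs: &[DTensor] = executor.evaluate(&inputs);
    /// ```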
    pub fn new(device: CudaDevice, graph: &Graph, batch_size: usize) -> Self {
        let handles = Handles {
            cudnn: CudnnHandle::new(device),
            cublas: CublasHandle::new(device),
        };

        let Plan {
            inputs,
            outputs,
            steps,
            mem_usage,
        } = Planner::plan(&handles, graph, batch_size);

        // allocate a pinned host staging buffer per input and output,
        // sized to the exact byte length of the corresponding tensor
        let buffer_inputs = inputs
            .iter()
            .map(|x| {
                let len_bytes = x.strided_shape().size() * x.dtype().size().bytes();
                PinnedMem::alloc(len_bytes, false)
            })
            .collect();
        let buffer_outputs = outputs
            .iter()
            .map(|x| {
                let len_bytes = x.strided_shape().size() * x.dtype().size().bytes();
                PinnedMem::alloc(len_bytes, false)
            })
            .collect();

        // pre-allocate a zeroed host tensor per output, reused across calls to `evaluate`
        let tensor_outputs = outputs
            .iter()
            .map(|x| {
                let shape = x.strided_shape().shape().to_vec();
                let dtype = x.dtype();
                dispatch_dtype!(dtype, |_T, _fs, ft| ft(Tensor::zeros(shape)))
            })
            .collect();

        CudaExecutor {
            handles,
            batch_size,
            mem_usage,
            device_inputs: inputs,
            device_outputs: outputs,
            steps,
            profile: false,
            last_profile: None,
            buffer_inputs,
            buffer_outputs,
            tensor_outputs,
        }
    }

    pub fn stream(&self) -> &CudaStream {
        self.handles.stream()
    }

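    /// Run the network on the given `inputs`, blocking until the outputs are available.
    ///
    /// Inputs are staged through pinned host buffers and copied to the device,
    /// the planned steps run on the stream, and the outputs are copied back and
    /// converted to typed tensors. The returned tensors are reused and
    /// overwritten by the next call.
    ///
    /// Panics if the number, shapes, or dtypes of `inputs` do not match the graph.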
    pub fn evaluate(&mut self, inputs: &[DTensor]) -> &[DTensor] {
        assert_eq!(inputs.len(), self.device_inputs.len(), "Wrong input count");
        for (i, (input, tensor)) in zip_eq(inputs, &self.device_inputs).enumerate() {
            assert_eq!(
                input.shape(),
                tensor.strided_shape().shape(),
                "Wrong shape for input {}",
                i
            );
            assert_eq!(input.dtype(), tensor.dtype(), "Wrong dtype for input {}", i);
        }

        unsafe {
            // make sure any previous work on the stream is done before reusing the staging buffers
            self.stream().synchronize();

            // stage each input in pinned memory, then start an async copy to the device
            for (input, buffer, tensor) in multizip((inputs, &self.buffer_inputs, &self.device_inputs)) {
                dispatch_dtensor!(input, |T, _f, input| {
                    // ensure a contiguous layout before reinterpreting the data as bytes
                    let input = input.as_standard_layout();
                    let input_slice = input.as_slice().unwrap();
                    buffer.as_slice().copy_from_slice(cast_slice::<T, u8>(input_slice));
                });

                assert!(tensor.strided_shape().has_simple_strides());
                tensor.ptr().copy_linear_from_host_async(buffer, self.stream());
            }

            self.run_async();

            // start async copies of the outputs back to pinned memory
            for (buffer, tensor) in zip_eq(&mut self.buffer_outputs, &self.device_outputs) {
                tensor.ptr().copy_linear_to_host_async(buffer, self.handles.stream());
            }

            // wait for all transfers and kernels to finish
            self.stream().synchronize();

            // reinterpret the output bytes as typed tensors
            for (buffer, tensor) in zip_eq(&self.buffer_outputs, &mut self.tensor_outputs) {
                let buffer: &[u8] = buffer.as_slice();

                unsafe fn branch<T: Pod>(tensor: &mut Tensor<T>, buffer: &[u8]) {
                    cast_slice_mut::<T, u8>(tensor.as_slice_mut().unwrap()).copy_from_slice(buffer);
                }

                match tensor {
                    DTensor::F32(tensor) => branch::<f32>(tensor, buffer),
                    DTensor::F64(tensor) => branch::<f64>(tensor, buffer),
                    DTensor::I8(tensor) => branch::<i8>(tensor, buffer),
                    DTensor::I16(tensor) => branch::<i16>(tensor, buffer),
                    DTensor::I32(tensor) => branch::<i32>(tensor, buffer),
                    DTensor::I64(tensor) => branch::<i64>(tensor, buffer),
                    DTensor::U8(tensor) => branch::<u8>(tensor, buffer),
                    DTensor::U16(tensor) => branch::<u16>(tensor, buffer),
                    DTensor::U32(tensor) => branch::<u32>(tensor, buffer),
                    DTensor::U64(tensor) => branch::<u64>(tensor, buffer),

                    DTensor::Bool(tensor) => {
                        // bool is not `Pod`, so convert byte-by-byte and
                        // check that every value is a valid bool (0 or 1)
                        let mut fail = false;
                        for (i, x) in tensor.iter_mut().enumerate() {
                            let y = buffer[i];
                            *x = DBool(y != 0);
                            fail |= y > 1;
                        }
                        assert!(!fail);
                    }
                }
            }
        }

        &self.tensor_outputs
    }

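    /// Run all planned steps on the stream, without copying inputs or outputs
    /// and without blocking the host (except when profiling is enabled).
    ///
    /// # Safety
    /// The device inputs must already be populated, the stream's device must be
    /// the current CUDA device (this is asserted), and the caller must
    /// synchronize the stream before reading any outputs.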
    pub unsafe fn run_async(&mut self) {
        assert_eq!(self.stream().device(), CudaDevice::current());

        if !self.profile {
            for step_info in &self.steps {
                step_info.step.run(&self.handles);
            }

            self.last_profile = None;
        } else {
            let mut timers = vec![];

            // record and immediately await a start event so GPU and CPU timing start together
            let start_gpu = self.stream().record_event();
            start_gpu.synchronize();

            let start_cpu = Instant::now();

            for step_info in &self.steps {
                let start = self.stream().record_event();
                step_info.step.run(&self.handles);
                let end = self.stream().record_event();
                timers.push((step_info, start, end));
            }

            let end_gpu = self.stream().record_event();
            self.stream().synchronize();

            let end_cpu = Instant::now();

            let mut profile = Profile::default();

            for (i, (step_info, start, end)) in timers.iter().enumerate() {
                let time = end.time_elapsed_since(start);

                // accumulate the per-operation totals
                *match step_info.step {
                    Step::Conv { .. } => &mut profile.conv,
                    Step::MatMul { .. } => &mut profile.mat_mul,
                    Step::ScalarOp { .. } => &mut profile.scalar_op,
                    Step::ReduceOp { .. } => &mut profile.reduce_op,
                    Step::SoftmaxOp { .. } => &mut profile.softmax_op,
                    Step::LayernormOp { .. } => &mut profile.layernorm_op,
                    Step::GatherOp { .. } => &mut profile.gather_op,
                } += time;

                profile
                    .steps
                    .push(format!("{: >4} time {:>10.4} ms, step {:?}", i, time * 1e3, step_info));
            }

            let overhead_end = Instant::now();
            profile.total_gpu = end_gpu.time_elapsed_since(&start_gpu);
            profile.total_cpu = (end_cpu - start_cpu).as_secs_f32();
            profile.timing_overhead = (overhead_end - end_cpu).as_secs_f32();

            self.last_profile = Some(profile);
        }
    }

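    /// Enable or disable per-step profiling for subsequent runs.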
    pub fn set_profile(&mut self, profile: bool) {
        self.profile = profile;
    }

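    /// The profile collected during the most recent run, if profiling was enabled for it.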
    pub fn last_profile(&self) -> Option<&Profile> {
        self.last_profile.as_ref()
    }
}

impl Debug for CudaExecutor {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        writeln!(f, "CudaExecutor {{")?;

        writeln!(f, "    batch_size: {},", self.batch_size)?;
        writeln!(f, "    mem_usage: {:?},", self.mem_usage)?;
        writeln!(f, "    profile: {},", self.profile)?;

        writeln!(f, "    inputs: {:?},", debug_vec_multiline("    ", &self.device_inputs))?;
        writeln!(
            f,
            "    outputs: {:?},",
            debug_vec_multiline("    ", &self.device_outputs)
        )?;
        writeln!(f, "    steps: {:?},", debug_vec_multiline("    ", &self.steps))?;

        writeln!(f, "}}")?;

        Ok(())
    }
}

impl Display for Profile {
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        write!(f, "Profile {{\n  steps: [\n")?;
        for step in &self.steps {
            writeln!(f, "    {}", step)?;
        }
        write!(f, "  ]\n\n")?;

        // total of the per-operation times, used for the fraction column
        let total = self.conv
            + self.mat_mul
            + self.scalar_op
            + self.reduce_op
            + self.softmax_op
            + self.layernorm_op
            + self.gather_op;
        let mut line = |name, time| writeln!(f, "  {} {:>10.4} ms {:>4.2}", name, time * 1e3, time / total);

        line("Conv:      ", self.conv)?;
        line("Matmul:    ", self.mat_mul)?;
        line("Scalar:    ", self.scalar_op)?;
        line("Reduce:    ", self.reduce_op)?;
        line("Softmax:   ", self.softmax_op)?;
        line("Layernorm: ", self.layernorm_op)?;
        line("Gather:    ", self.gather_op)?;

        writeln!(f, "  ==============================")?;
        writeln!(f, "  Total GPU: {:>10.4} ms", self.total_gpu * 1e3)?;
        writeln!(f, "  Total CPU: {:>10.4} ms", self.total_cpu * 1e3)?;
        writeln!(f, "  Overhead:  {:>10.4} ms", self.timing_overhead * 1e3)?;

        writeln!(f, "}}")?;

        Ok(())
    }
}