1use crate::{
2 DeviceProperties,
3 kernel::KernelMetadata,
4 logging::ServerLogger,
5 memory_management::{
6 MemoryAllocationMode, MemoryHandle, MemoryUsage,
7 memory_pool::{SliceBinding, SliceHandle},
8 },
9 storage::{BindingResource, ComputeStorage},
10 tma::{OobFill, TensorMapFormat, TensorMapInterleave, TensorMapPrefetch, TensorMapSwizzle},
11};
12use alloc::collections::BTreeMap;
13use alloc::string::String;
14use alloc::sync::Arc;
15use alloc::vec;
16use alloc::vec::Vec;
17use core::fmt::Debug;
18use cubecl_common::{
19 ExecutionMode, bytes::Bytes, device, future::DynFut, profile::ProfileDuration,
20 stream_id::StreamId,
21};
22use cubecl_ir::StorageType;
23use thiserror::Error;
24
/// Error reported when profiling fails (see `ComputeServer::end_profile`).
#[derive(Debug, Clone)]
pub enum ProfileError {
    /// A failure without a dedicated variant; carries a free-form description.
    Unknown(String),
    /// NOTE(review): presumably returned when no profiling was registered for the
    /// given token — confirm against the server implementations.
    NotRegistered,
}
33
/// Data bundled alongside a [`ComputeServer`]: device properties, server info and logger.
///
/// With the `profile-tracy` feature enabled it additionally carries the Tracy state
/// created in [`ServerUtilities::new`].
#[derive(Debug)]
pub struct ServerUtilities<Server: ComputeServer> {
    /// Instant captured when the utilities were constructed (`profile-tracy` only).
    #[cfg(feature = "profile-tracy")]
    pub epoch_time: web_time::Instant,
    /// Tracy GPU context created at construction time (`profile-tracy` only).
    #[cfg(feature = "profile-tracy")]
    pub gpu_client: tracy_client::GpuContext,
    /// Properties of the device.
    pub properties: DeviceProperties,
    /// Server-specific information, typed by [`ComputeServer::Info`].
    pub info: Server::Info,
    /// Shared logger.
    pub logger: Arc<ServerLogger>,
}
50
51impl<S: ComputeServer> ServerUtilities<S> {
52 pub fn new(properties: DeviceProperties, logger: Arc<ServerLogger>, info: S::Info) -> Self {
54 #[cfg(feature = "profile-tracy")]
56 let client = tracy_client::Client::start();
57
58 Self {
59 properties,
60 logger,
61 #[cfg(feature = "profile-tracy")]
63 gpu_client: client
64 .clone()
65 .new_gpu_context(
66 Some(&format!("{info:?}")),
67 tracy_client::GpuContextType::Invalid,
69 0, 1.0, )
72 .unwrap(),
73 #[cfg(feature = "profile-tracy")]
74 epoch_time: web_time::Instant::now(),
75 info,
76 }
77 }
78}
79
80pub trait ComputeServer:
85 Send + core::fmt::Debug + ServerCommunication + device::DeviceState + 'static
86where
87 Self: Sized,
88{
89 type Kernel: KernelMetadata;
91 type Info: Debug + Send + Sync;
93 type Storage: ComputeStorage;
95
96 fn create(
98 &mut self,
99 descriptors: Vec<AllocationDescriptor<'_>>,
100 stream_id: StreamId,
101 ) -> Result<Vec<Allocation>, IoError>;
102
103 fn logger(&self) -> Arc<ServerLogger>;
105
106 fn utilities(&self) -> Arc<ServerUtilities<Self>>;
108
109 fn create_with_data(&mut self, data: &[u8], stream_id: StreamId) -> Result<Handle, IoError> {
111 let alloc = self
112 .create(
113 vec![AllocationDescriptor::new(
114 AllocationKind::Contiguous,
115 &[data.len()],
116 1,
117 )],
118 stream_id,
119 )?
120 .remove(0);
121 self.write(
122 vec![(
123 CopyDescriptor::new(
124 alloc.handle.clone().binding(),
125 &[data.len()],
126 &alloc.strides,
127 1,
128 ),
129 data,
130 )],
131 stream_id,
132 )?;
133 Ok(alloc.handle)
134 }
135
136 fn read<'a>(
138 &mut self,
139 descriptors: Vec<CopyDescriptor<'a>>,
140 stream_id: StreamId,
141 ) -> DynFut<Result<Vec<Bytes>, IoError>>;
142
143 fn write(
145 &mut self,
146 descriptors: Vec<(CopyDescriptor<'_>, &[u8])>,
147 stream_id: StreamId,
148 ) -> Result<(), IoError>;
149
150 fn sync(&mut self, stream_id: StreamId) -> DynFut<()>;
152
153 fn get_resource(
155 &mut self,
156 binding: Binding,
157 stream_id: StreamId,
158 ) -> BindingResource<<Self::Storage as ComputeStorage>::Resource>;
159
160 unsafe fn execute(
169 &mut self,
170 kernel: Self::Kernel,
171 count: CubeCount,
172 bindings: Bindings,
173 kind: ExecutionMode,
174 stream_id: StreamId,
175 );
176
177 fn flush(&mut self, stream_id: StreamId);
179
180 fn memory_usage(&mut self, stream_id: StreamId) -> MemoryUsage;
182
183 fn memory_cleanup(&mut self, stream_id: StreamId);
185
186 fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken;
188
189 fn end_profile(
191 &mut self,
192 stream_id: StreamId,
193 token: ProfilingToken,
194 ) -> Result<ProfileDuration, ProfileError>;
195
196 fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId);
198}
199
200pub trait ServerCommunication {
203 const SERVER_COMM_ENABLED: bool;
205
206 #[allow(unused_variables)]
225 fn copy(
226 server_src: &mut Self,
227 server_dst: &mut Self,
228 src: CopyDescriptor<'_>,
229 stream_id_src: StreamId,
230 stream_id_dst: StreamId,
231 ) -> Result<Allocation, IoError> {
232 if !Self::SERVER_COMM_ENABLED {
233 panic!("Server-to-server communication is not supported by this server.");
234 } else {
235 panic!(
236 "[Internal Error] The `ServerCommunication` trait is incorrectly implemented by the server."
237 );
238 }
239 }
240}
241
/// Token returned by `ComputeServer::start_profile` and consumed by
/// `ComputeServer::end_profile` to identify a profiling session.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct ProfilingToken {
    /// Numeric identifier of the token.
    pub id: u64,
}
248
/// Handle to a region of server memory, with optional offsets trimming both ends.
#[derive(new, Debug, PartialEq, Eq)]
pub struct Handle {
    /// Handle to the underlying memory slice.
    pub memory: SliceHandle,
    /// Amount trimmed from the start, if any; subtracted by [`Handle::size`].
    pub offset_start: Option<u64>,
    /// Amount trimmed from the end, if any; subtracted by [`Handle::size`].
    pub offset_end: Option<u64>,
    /// Stream associated with the handle; [`Handle::can_mut`] compares it with the
    /// current stream.
    pub stream: cubecl_common::stream_id::StreamId,
    /// NOTE(review): presumably a position on `stream` used for ordering — confirm.
    pub cursor: u64,
    /// Total size before the offsets are subtracted (see [`Handle::size`]).
    size: u64,
}
265
/// Layout strategy requested for an allocation.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum AllocationKind {
    /// Contiguous layout (used by `ComputeServer::create_with_data`).
    Contiguous,
    /// NOTE(review): presumably lets the server pick a layout (e.g. padded strides) that
    /// is faster to access — confirm against the server implementations.
    Optimized,
}
275
/// Describes one allocation request passed to `ComputeServer::create`.
#[derive(new, Debug, Clone, Copy)]
pub struct AllocationDescriptor<'a> {
    /// Requested layout strategy.
    pub kind: AllocationKind,
    /// Shape of the data to allocate.
    pub shape: &'a [usize],
    /// Size of a single element (presumably in bytes — confirm).
    pub elem_size: usize,
}
286
287impl<'a> AllocationDescriptor<'a> {
288 pub fn optimized(shape: &'a [usize], elem_size: usize) -> Self {
290 AllocationDescriptor::new(AllocationKind::Optimized, shape, elem_size)
291 }
292
293 pub fn contiguous(shape: &'a [usize], elem_size: usize) -> Self {
295 AllocationDescriptor::new(AllocationKind::Contiguous, shape, elem_size)
296 }
297}
298
/// Result of an allocation request: the handle to the memory plus its strides.
#[derive(new, Debug)]
pub struct Allocation {
    /// Handle to the allocated memory.
    pub handle: Handle,
    /// Strides of the allocated data.
    pub strides: Vec<usize>,
}
307
308#[derive(Debug, Error)]
311pub enum IoError {
312 #[error("can't allocate buffer of size")]
314 BufferTooBig(usize),
315 #[error("the provided strides are not supported for this operation")]
317 UnsupportedStrides,
318 #[error("couldn't find resource for that handle")]
320 InvalidHandle,
321 #[error("Unknown error happened during execution")]
323 Unknown(String),
324}
325
326impl Handle {
327 pub fn offset_start(mut self, offset: u64) -> Self {
329 if let Some(val) = &mut self.offset_start {
330 *val += offset;
331 } else {
332 self.offset_start = Some(offset);
333 }
334
335 self
336 }
337 pub fn offset_end(mut self, offset: u64) -> Self {
339 if let Some(val) = &mut self.offset_end {
340 *val += offset;
341 } else {
342 self.offset_end = Some(offset);
343 }
344
345 self
346 }
347
348 pub fn size(&self) -> u64 {
350 self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
351 }
352}
353
/// Everything bound to a kernel launch (see `ComputeServer::execute`).
#[derive(Debug, Default)]
pub struct Bindings {
    /// Buffer bindings.
    pub buffers: Vec<Binding>,
    /// Metadata binding.
    pub metadata: MetadataBinding,
    /// Scalar bindings, keyed by storage type (at most one entry per type).
    pub scalars: BTreeMap<StorageType, ScalarBinding>,
    /// Tensor map bindings.
    pub tensor_maps: Vec<TensorMapBinding>,
}
367
368impl Bindings {
369 pub fn new() -> Self {
371 Self::default()
372 }
373
374 pub fn with_buffer(mut self, binding: Binding) -> Self {
376 self.buffers.push(binding);
377 self
378 }
379
380 pub fn with_buffers(mut self, bindings: Vec<Binding>) -> Self {
382 self.buffers.extend(bindings);
383 self
384 }
385
386 pub fn with_scalar(mut self, ty: StorageType, length: usize, data: Vec<u64>) -> Self {
388 self.scalars
389 .insert(ty, ScalarBinding::new(ty, length, data));
390 self
391 }
392
393 pub fn with_scalars(mut self, bindings: Vec<ScalarBinding>) -> Self {
395 self.scalars
396 .extend(bindings.into_iter().map(|binding| (binding.ty, binding)));
397 self
398 }
399
400 pub fn with_metadata(mut self, meta: MetadataBinding) -> Self {
402 self.metadata = meta;
403 self
404 }
405
406 pub fn with_tensor_maps(mut self, bindings: Vec<TensorMapBinding>) -> Self {
408 self.tensor_maps.extend(bindings);
409 self
410 }
411}
412
/// Metadata passed to a kernel as part of its [`Bindings`].
#[derive(new, Debug, Default)]
pub struct MetadataBinding {
    /// Metadata values.
    pub data: Vec<u32>,
    /// NOTE(review): presumably the length of the static portion of `data` — confirm.
    pub static_len: usize,
}
421
/// Scalar values of a single storage type passed to a kernel.
#[derive(new, Debug, Clone)]
pub struct ScalarBinding {
    /// Storage type of the scalars.
    pub ty: StorageType,
    /// NOTE(review): presumably the number of scalars actually stored in `data`
    /// (which may be padded) — confirm.
    pub length: usize,
    /// Scalar data stored as `u64` words; exposed as raw bytes by [`ScalarBinding::data`].
    pub data: Vec<u64>,
}
432
433impl ScalarBinding {
434 pub fn data(&self) -> &[u8] {
436 bytemuck::cast_slice(&self.data)
437 }
438}
439
/// Binding to a region of server memory, as produced by `Handle::binding`.
#[derive(new, Debug)]
pub struct Binding {
    /// Binding to the underlying memory slice.
    pub memory: SliceBinding,
    /// Amount trimmed from the start, if any; subtracted by [`Binding::size`].
    pub offset_start: Option<u64>,
    /// Amount trimmed from the end, if any; subtracted by [`Binding::size`].
    pub offset_end: Option<u64>,
    /// Stream associated with the binding.
    pub stream: cubecl_common::stream_id::StreamId,
    /// NOTE(review): presumably a position on `stream` used for ordering — confirm.
    pub cursor: u64,
    /// Total size before the offsets are subtracted (see [`Binding::size`]).
    size: u64,
}
456
457impl Binding {
458 pub fn size(&self) -> u64 {
460 self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
461 }
462}
463
/// Describes the data involved in a read, write or copy: where it lives and how it is
/// laid out.
#[derive(new, Debug, Clone)]
pub struct CopyDescriptor<'a> {
    /// Binding to the memory holding the data.
    pub binding: Binding,
    /// Shape of the data.
    pub shape: &'a [usize],
    /// Strides of the data.
    pub strides: &'a [usize],
    /// Size of a single element (presumably in bytes — confirm).
    pub elem_size: usize,
}
476
/// A tensor map passed to a kernel: the bound memory plus the map's metadata.
#[derive(new, Debug, Clone)]
pub struct TensorMapBinding {
    /// Binding to the memory backing the tensor map.
    pub binding: Binding,
    /// Metadata describing the tensor map.
    pub map: TensorMapMeta,
}
485
/// Metadata describing a tensor map (see [`TensorMapBinding`]).
#[derive(Debug, Clone)]
pub struct TensorMapMeta {
    /// Format of the tensor map.
    pub format: TensorMapFormat,
    /// Rank of the tensor (presumably the length of `shape` and `strides` — confirm).
    pub rank: usize,
    /// Shape of the tensor.
    pub shape: Vec<usize>,
    /// Strides of the tensor.
    pub strides: Vec<usize>,
    /// Element strides (presumably one entry per dimension — confirm).
    pub elem_stride: Vec<usize>,
    /// Interleave setting.
    pub interleave: TensorMapInterleave,
    /// Swizzle setting.
    pub swizzle: TensorMapSwizzle,
    /// Prefetch setting.
    pub prefetch: TensorMapPrefetch,
    /// Fill behavior for out-of-bounds accesses.
    pub oob_fill: OobFill,
    /// Storage type of the tensor elements.
    pub storage_ty: StorageType,
}
511
512impl Handle {
513 pub fn can_mut(&self) -> bool {
515 self.memory.can_mut() && self.stream == StreamId::current()
516 }
517}
518
519impl Handle {
520 pub fn binding(self) -> Binding {
522 Binding {
523 memory: MemoryHandle::binding(self.memory),
524 offset_start: self.offset_start,
525 offset_end: self.offset_end,
526 size: self.size,
527 stream: self.stream,
528 cursor: self.cursor,
529 }
530 }
531
532 pub fn copy_descriptor<'a>(
534 &'a self,
535 shape: &'a [usize],
536 strides: &'a [usize],
537 elem_size: usize,
538 ) -> CopyDescriptor<'a> {
539 CopyDescriptor {
540 shape,
541 strides,
542 elem_size,
543 binding: self.clone().binding(),
544 }
545 }
546}
547
548impl Clone for Handle {
549 fn clone(&self) -> Self {
550 Self {
551 memory: self.memory.clone(),
552 offset_start: self.offset_start,
553 offset_end: self.offset_end,
554 size: self.size,
555 stream: self.stream,
556 cursor: self.cursor,
557 }
558 }
559}
560
561impl Clone for Binding {
562 fn clone(&self) -> Self {
563 Self {
564 memory: self.memory.clone(),
565 offset_start: self.offset_start,
566 offset_end: self.offset_end,
567 size: self.size,
568 stream: self.stream,
569 cursor: self.cursor,
570 }
571 }
572}
573
/// Number of cubes to launch for a kernel (see `ComputeServer::execute`).
// NOTE(review): the lint is presumably silenced because `Dynamic` carries a full
// `Binding` while `Static` is only three `u32`s — confirm boxing is not wanted.
#[allow(clippy::large_enum_variant)]
pub enum CubeCount {
    /// A count known up front, given as `(x, y, z)`.
    Static(u32, u32, u32),
    /// A count taken from the bound memory (presumably read at launch time; the expected
    /// buffer contents are not visible here — confirm).
    Dynamic(Binding),
}
584
585impl CubeCount {
586 pub fn new_single() -> Self {
588 CubeCount::Static(1, 1, 1)
589 }
590
591 pub fn new_1d(x: u32) -> Self {
593 CubeCount::Static(x, 1, 1)
594 }
595
596 pub fn new_2d(x: u32, y: u32) -> Self {
598 CubeCount::Static(x, y, 1)
599 }
600
601 pub fn new_3d(x: u32, y: u32, z: u32) -> Self {
603 CubeCount::Static(x, y, z)
604 }
605}
606
607impl Debug for CubeCount {
608 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
609 match self {
610 CubeCount::Static(x, y, z) => f.write_fmt(format_args!("({x}, {y}, {z})")),
611 CubeCount::Dynamic(_) => f.write_str("binding"),
612 }
613 }
614}
615
616impl Clone for CubeCount {
617 fn clone(&self) -> Self {
618 match self {
619 Self::Static(x, y, z) => Self::Static(*x, *y, *z),
620 Self::Dynamic(handle) => Self::Dynamic(handle.clone()),
621 }
622 }
623}