1use crate::{
2    DeviceProperties,
3    kernel::KernelMetadata,
4    logging::ServerLogger,
5    memory_management::{
6        MemoryAllocationMode, MemoryHandle, MemoryUsage,
7        memory_pool::{SliceBinding, SliceHandle},
8    },
9    storage::{BindingResource, ComputeStorage},
10    tma::{OobFill, TensorMapFormat, TensorMapInterleave, TensorMapPrefetch, TensorMapSwizzle},
11};
12use alloc::collections::BTreeMap;
13use alloc::string::String;
14use alloc::sync::Arc;
15use alloc::vec;
16use alloc::vec::Vec;
17use core::fmt::Debug;
18use cubecl_common::{
19    ExecutionMode, bytes::Bytes, device, future::DynFut, profile::ProfileDuration,
20    stream_id::StreamId,
21};
22use cubecl_ir::StorageType;
23use thiserror::Error;
24
/// Errors that can be reported when profiling fails, e.g. by `ComputeServer::end_profile`.
#[derive(Debug, Clone)]
pub enum ProfileError {
    /// A failure without a dedicated variant.
    /// NOTE(review): the `String` presumably carries a human-readable description; confirm.
    Unknown(String),
    /// NOTE(review): presumably reported when no profiling was registered for the given
    /// token; confirm against the server implementations.
    NotRegistered,
}
33
/// Data bundled with a compute server: device properties, server-specific info and the
/// logger, plus Tracy profiling state when the `profile-tracy` feature is enabled.
#[derive(Debug)]
pub struct ServerUtilities<Server: ComputeServer> {
    /// Instant used as time reference; `ServerUtilities::new` sets it to the creation time.
    /// Only present with the `profile-tracy` feature.
    #[cfg(feature = "profile-tracy")]
    pub epoch_time: web_time::Instant,
    /// Tracy GPU context created by `ServerUtilities::new`.
    /// Only present with the `profile-tracy` feature.
    #[cfg(feature = "profile-tracy")]
    pub gpu_client: tracy_client::GpuContext,
    /// Properties of the device.
    pub properties: DeviceProperties,
    /// Server-specific information (the `ComputeServer::Info` associated type).
    pub info: Server::Info,
    /// Shared handle to the server logger.
    pub logger: Arc<ServerLogger>,
}
50
51impl<S: ComputeServer> ServerUtilities<S> {
52    pub fn new(properties: DeviceProperties, logger: Arc<ServerLogger>, info: S::Info) -> Self {
54        #[cfg(feature = "profile-tracy")]
56        let client = tracy_client::Client::start();
57
58        Self {
59            properties,
60            logger,
61            #[cfg(feature = "profile-tracy")]
63            gpu_client: client
64                .clone()
65                .new_gpu_context(
66                    Some(&format!("{info:?}")),
67                    tracy_client::GpuContextType::Invalid,
69                    0,   1.0, )
72                .unwrap(),
73            #[cfg(feature = "profile-tracy")]
74            epoch_time: web_time::Instant::now(),
75            info,
76        }
77    }
78}
79
80pub trait ComputeServer:
85    Send + core::fmt::Debug + ServerCommunication + device::DeviceState + 'static
86where
87    Self: Sized,
88{
89    type Kernel: KernelMetadata;
91    type Info: Debug + Send + Sync;
93    type Storage: ComputeStorage;
95
96    fn create(
98        &mut self,
99        descriptors: Vec<AllocationDescriptor<'_>>,
100        stream_id: StreamId,
101    ) -> Result<Vec<Allocation>, IoError>;
102
103    fn logger(&self) -> Arc<ServerLogger>;
105
106    fn utilities(&self) -> Arc<ServerUtilities<Self>>;
108
109    fn create_with_data(&mut self, data: &[u8], stream_id: StreamId) -> Result<Handle, IoError> {
111        let alloc = self
112            .create(
113                vec![AllocationDescriptor::new(
114                    AllocationKind::Contiguous,
115                    &[data.len()],
116                    1,
117                )],
118                stream_id,
119            )?
120            .remove(0);
121        self.write(
122            vec![(
123                CopyDescriptor::new(
124                    alloc.handle.clone().binding(),
125                    &[data.len()],
126                    &alloc.strides,
127                    1,
128                ),
129                data,
130            )],
131            stream_id,
132        )?;
133        Ok(alloc.handle)
134    }
135
136    fn read<'a>(
138        &mut self,
139        descriptors: Vec<CopyDescriptor<'a>>,
140        stream_id: StreamId,
141    ) -> DynFut<Result<Vec<Bytes>, IoError>>;
142
143    fn write(
145        &mut self,
146        descriptors: Vec<(CopyDescriptor<'_>, &[u8])>,
147        stream_id: StreamId,
148    ) -> Result<(), IoError>;
149
150    fn sync(&mut self, stream_id: StreamId) -> DynFut<()>;
152
153    fn get_resource(
155        &mut self,
156        binding: Binding,
157        stream_id: StreamId,
158    ) -> BindingResource<<Self::Storage as ComputeStorage>::Resource>;
159
160    unsafe fn execute(
169        &mut self,
170        kernel: Self::Kernel,
171        count: CubeCount,
172        bindings: Bindings,
173        kind: ExecutionMode,
174        stream_id: StreamId,
175    );
176
177    fn flush(&mut self, stream_id: StreamId);
179
180    fn memory_usage(&mut self, stream_id: StreamId) -> MemoryUsage;
182
183    fn memory_cleanup(&mut self, stream_id: StreamId);
185
186    fn start_profile(&mut self, stream_id: StreamId) -> ProfilingToken;
188
189    fn end_profile(
191        &mut self,
192        stream_id: StreamId,
193        token: ProfilingToken,
194    ) -> Result<ProfileDuration, ProfileError>;
195
196    fn allocation_mode(&mut self, mode: MemoryAllocationMode, stream_id: StreamId);
198}
199
200pub trait ServerCommunication {
203    const SERVER_COMM_ENABLED: bool;
205
206    #[allow(unused_variables)]
225    fn copy(
226        server_src: &mut Self,
227        server_dst: &mut Self,
228        src: CopyDescriptor<'_>,
229        stream_id_src: StreamId,
230        stream_id_dst: StreamId,
231    ) -> Result<Allocation, IoError> {
232        if !Self::SERVER_COMM_ENABLED {
233            panic!("Server-to-server communication is not supported by this server.");
234        } else {
235            panic!(
236                "[Internal Error] The `ServerCommunication` trait is incorrectly implemented by the server."
237            );
238        }
239    }
240}
241
/// Token returned by `ComputeServer::start_profile` and handed back to
/// `ComputeServer::end_profile`.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub struct ProfilingToken {
    /// Numeric identifier of the token.
    pub id: u64,
}
248
/// Server handle to a slice of memory, with optional offsets, the stream it was recorded
/// on, and its size.
#[derive(new, Debug, PartialEq, Eq)]
pub struct Handle {
    /// Handle to the underlying memory slice.
    pub memory: SliceHandle,
    /// Optional offset at the start of the memory; subtracted from the total size by
    /// `Handle::size`.
    pub offset_start: Option<u64>,
    /// Optional offset at the end of the memory; subtracted from the total size by
    /// `Handle::size`.
    pub offset_end: Option<u64>,
    /// Stream recorded for this handle; `Handle::can_mut` compares it with the current
    /// stream.
    pub stream: cubecl_common::stream_id::StreamId,
    /// NOTE(review): meaning not visible in this file; presumably a position on the
    /// stream's timeline. Confirm.
    pub cursor: u64,
    /// Total size before the offsets are applied.
    /// NOTE(review): presumably expressed in bytes; confirm.
    size: u64,
}
265
/// Kind of allocation requested through an `AllocationDescriptor`.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq)]
pub enum AllocationKind {
    /// NOTE(review): presumably a densely packed (contiguous) layout; confirm.
    Contiguous,
    /// NOTE(review): presumably lets the server choose the layout (strides) it prefers;
    /// confirm.
    Optimized,
}
275
/// Description of an allocation to create: its kind, shape and element size.
#[derive(new, Debug, Clone, Copy)]
pub struct AllocationDescriptor<'a> {
    /// Kind of allocation requested.
    pub kind: AllocationKind,
    /// Shape of the data to allocate.
    pub shape: &'a [usize],
    /// Size of one element.
    /// NOTE(review): presumably in bytes; confirm.
    pub elem_size: usize,
}
286
287impl<'a> AllocationDescriptor<'a> {
288    pub fn optimized(shape: &'a [usize], elem_size: usize) -> Self {
290        AllocationDescriptor::new(AllocationKind::Optimized, shape, elem_size)
291    }
292
293    pub fn contiguous(shape: &'a [usize], elem_size: usize) -> Self {
295        AllocationDescriptor::new(AllocationKind::Contiguous, shape, elem_size)
296    }
297}
298
/// Result of an allocation: the handle to the memory and the strides of the data.
#[derive(new, Debug)]
pub struct Allocation {
    /// Handle to the allocated memory.
    pub handle: Handle,
    /// Strides of the allocated data.
    /// NOTE(review): presumably expressed in elements, one per dimension of the shape;
    /// confirm.
    pub strides: Vec<usize>,
}
307
308#[derive(Debug, Error)]
311pub enum IoError {
312    #[error("can't allocate buffer of size")]
314    BufferTooBig(usize),
315    #[error("the provided strides are not supported for this operation")]
317    UnsupportedStrides,
318    #[error("couldn't find resource for that handle")]
320    InvalidHandle,
321    #[error("Unknown error happened during execution")]
323    Unknown(String),
324}
325
326impl Handle {
327    pub fn offset_start(mut self, offset: u64) -> Self {
329        if let Some(val) = &mut self.offset_start {
330            *val += offset;
331        } else {
332            self.offset_start = Some(offset);
333        }
334
335        self
336    }
337    pub fn offset_end(mut self, offset: u64) -> Self {
339        if let Some(val) = &mut self.offset_end {
340            *val += offset;
341        } else {
342            self.offset_end = Some(offset);
343        }
344
345        self
346    }
347
348    pub fn size(&self) -> u64 {
350        self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
351    }
352}
353
/// Everything bound for a kernel execution (see `ComputeServer::execute`): buffers,
/// metadata, scalars and tensor maps.
#[derive(Debug, Default)]
pub struct Bindings {
    /// Buffer bindings.
    pub buffers: Vec<Binding>,
    /// Metadata binding.
    pub metadata: MetadataBinding,
    /// Scalar bindings, keyed by their storage type (at most one binding per type).
    pub scalars: BTreeMap<StorageType, ScalarBinding>,
    /// Tensor map bindings.
    pub tensor_maps: Vec<TensorMapBinding>,
}
367
368impl Bindings {
369    pub fn new() -> Self {
371        Self::default()
372    }
373
374    pub fn with_buffer(mut self, binding: Binding) -> Self {
376        self.buffers.push(binding);
377        self
378    }
379
380    pub fn with_buffers(mut self, bindings: Vec<Binding>) -> Self {
382        self.buffers.extend(bindings);
383        self
384    }
385
386    pub fn with_scalar(mut self, ty: StorageType, length: usize, data: Vec<u64>) -> Self {
388        self.scalars
389            .insert(ty, ScalarBinding::new(ty, length, data));
390        self
391    }
392
393    pub fn with_scalars(mut self, bindings: Vec<ScalarBinding>) -> Self {
395        self.scalars
396            .extend(bindings.into_iter().map(|binding| (binding.ty, binding)));
397        self
398    }
399
400    pub fn with_metadata(mut self, meta: MetadataBinding) -> Self {
402        self.metadata = meta;
403        self
404    }
405
406    pub fn with_tensor_maps(mut self, bindings: Vec<TensorMapBinding>) -> Self {
408        self.tensor_maps.extend(bindings);
409        self
410    }
411}
412
/// Metadata bound for a kernel execution.
#[derive(new, Debug, Default)]
pub struct MetadataBinding {
    /// Metadata values.
    pub data: Vec<u32>,
    /// NOTE(review): presumably the length of the statically-sized part of `data`; confirm.
    pub static_len: usize,
}
421
/// Scalars of one storage type bound for a kernel execution.
#[derive(new, Debug, Clone)]
pub struct ScalarBinding {
    /// Storage type of the scalars.
    pub ty: StorageType,
    /// NOTE(review): presumably the number of scalars stored in `data`; confirm.
    pub length: usize,
    /// Raw storage of the scalars as `u64` words; exposed as bytes by `ScalarBinding::data`.
    pub data: Vec<u64>,
}
432
433impl ScalarBinding {
434    pub fn data(&self) -> &[u8] {
436        bytemuck::cast_slice(&self.data)
437    }
438}
439
/// Binding to a slice of memory, produced from a `Handle` by `Handle::binding`; it carries
/// the same offsets, stream, cursor and size as the handle.
#[derive(new, Debug)]
pub struct Binding {
    /// Binding to the underlying memory slice.
    pub memory: SliceBinding,
    /// Optional offset at the start of the memory; subtracted from the total size by
    /// `Binding::size`.
    pub offset_start: Option<u64>,
    /// Optional offset at the end of the memory; subtracted from the total size by
    /// `Binding::size`.
    pub offset_end: Option<u64>,
    /// Stream recorded for this binding.
    pub stream: cubecl_common::stream_id::StreamId,
    /// NOTE(review): meaning not visible in this file; presumably a position on the
    /// stream's timeline. Confirm.
    pub cursor: u64,
    /// Total size before the offsets are applied.
    /// NOTE(review): presumably expressed in bytes; confirm.
    size: u64,
}
456
457impl Binding {
458    pub fn size(&self) -> u64 {
460        self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
461    }
462}
463
/// Description of a memory region taking part in a copy (read, write or server-to-server
/// copy): the binding plus the layout of the data.
#[derive(new, Debug, Clone)]
pub struct CopyDescriptor<'a> {
    /// Binding to the memory involved in the copy.
    pub binding: Binding,
    /// Shape of the data.
    pub shape: &'a [usize],
    /// Strides of the data.
    pub strides: &'a [usize],
    /// Size of one element.
    /// NOTE(review): presumably in bytes; confirm.
    pub elem_size: usize,
}
476
/// Tensor map bound for a kernel execution: a memory binding and the tensor map metadata.
#[derive(new, Debug, Clone)]
pub struct TensorMapBinding {
    /// Binding to the memory the tensor map refers to.
    pub binding: Binding,
    /// Metadata describing the tensor map.
    pub map: TensorMapMeta,
}
485
/// Metadata describing a tensor map.
/// NOTE(review): the fields presumably mirror the parameters of a TMA tensor map
/// descriptor; confirm against the `tma` module.
#[derive(Debug, Clone)]
pub struct TensorMapMeta {
    /// Format of the tensor map.
    pub format: TensorMapFormat,
    /// Rank (number of dimensions) of the tensor.
    pub rank: usize,
    /// Shape of the tensor.
    pub shape: Vec<usize>,
    /// Strides of the tensor.
    pub strides: Vec<usize>,
    /// Element stride.
    /// NOTE(review): presumably one entry per dimension; confirm.
    pub elem_stride: Vec<usize>,
    /// Interleave setting.
    pub interleave: TensorMapInterleave,
    /// Swizzle setting.
    pub swizzle: TensorMapSwizzle,
    /// Prefetch setting.
    pub prefetch: TensorMapPrefetch,
    /// Out-of-bounds fill setting.
    pub oob_fill: OobFill,
    /// Storage type of the tensor elements.
    pub storage_ty: StorageType,
}
511
512impl Handle {
513    pub fn can_mut(&self) -> bool {
515        self.memory.can_mut() && self.stream == StreamId::current()
516    }
517}
518
519impl Handle {
520    pub fn binding(self) -> Binding {
522        Binding {
523            memory: MemoryHandle::binding(self.memory),
524            offset_start: self.offset_start,
525            offset_end: self.offset_end,
526            size: self.size,
527            stream: self.stream,
528            cursor: self.cursor,
529        }
530    }
531
532    pub fn copy_descriptor<'a>(
534        &'a self,
535        shape: &'a [usize],
536        strides: &'a [usize],
537        elem_size: usize,
538    ) -> CopyDescriptor<'a> {
539        CopyDescriptor {
540            shape,
541            strides,
542            elem_size,
543            binding: self.clone().binding(),
544        }
545    }
546}
547
548impl Clone for Handle {
549    fn clone(&self) -> Self {
550        Self {
551            memory: self.memory.clone(),
552            offset_start: self.offset_start,
553            offset_end: self.offset_end,
554            size: self.size,
555            stream: self.stream,
556            cursor: self.cursor,
557        }
558    }
559}
560
561impl Clone for Binding {
562    fn clone(&self) -> Self {
563        Self {
564            memory: self.memory.clone(),
565            offset_start: self.offset_start,
566            offset_end: self.offset_end,
567            size: self.size,
568            stream: self.stream,
569            cursor: self.cursor,
570        }
571    }
572}
573
/// Cube count passed to `ComputeServer::execute`: either given directly or provided through
/// a binding.
// NOTE(review): the lint is presumably silenced because `Dynamic` is much larger than
// `Static` and boxing it was not wanted; confirm.
#[allow(clippy::large_enum_variant)]
pub enum CubeCount {
    /// Count given directly as three `u32` values (x, y and z, as named by the `new_1d`,
    /// `new_2d` and `new_3d` constructors).
    Static(u32, u32, u32),
    /// Count provided through a binding.
    /// NOTE(review): presumably a buffer holding the three values; confirm.
    Dynamic(Binding),
}
584
585impl CubeCount {
586    pub fn new_single() -> Self {
588        CubeCount::Static(1, 1, 1)
589    }
590
591    pub fn new_1d(x: u32) -> Self {
593        CubeCount::Static(x, 1, 1)
594    }
595
596    pub fn new_2d(x: u32, y: u32) -> Self {
598        CubeCount::Static(x, y, 1)
599    }
600
601    pub fn new_3d(x: u32, y: u32, z: u32) -> Self {
603        CubeCount::Static(x, y, z)
604    }
605}
606
607impl Debug for CubeCount {
608    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
609        match self {
610            CubeCount::Static(x, y, z) => f.write_fmt(format_args!("({x}, {y}, {z})")),
611            CubeCount::Dynamic(_) => f.write_str("binding"),
612        }
613    }
614}
615
616impl Clone for CubeCount {
617    fn clone(&self) -> Self {
618        match self {
619            Self::Static(x, y, z) => Self::Static(*x, *y, *z),
620            Self::Dynamic(handle) => Self::Dynamic(handle.clone()),
621        }
622    }
623}