1use crate::{
2 memory_management::{
3 memory_pool::{SliceBinding, SliceHandle},
4 MemoryHandle, MemoryUsage,
5 },
6 storage::{BindingResource, ComputeStorage},
7 ExecutionMode,
8};
9use alloc::vec::Vec;
10use core::{fmt::Debug, future::Future};
11use cubecl_common::benchmark::TimestampsResult;
12
13pub trait ComputeServer: Send + core::fmt::Debug
18where
19 Self: Sized,
20{
21 type Kernel: Send;
23 type Storage: ComputeStorage;
25 type Feature: Ord + Copy + Debug + Send + Sync;
27
28 fn read(
30 &mut self,
31 bindings: Vec<Binding>,
32 ) -> impl Future<Output = Vec<Vec<u8>>> + Send + 'static;
33
34 fn get_resource(&mut self, binding: Binding) -> BindingResource<Self>;
36
37 fn create(&mut self, data: &[u8]) -> Handle;
39
40 fn empty(&mut self, size: usize) -> Handle;
42
43 unsafe fn execute(
52 &mut self,
53 kernel: Self::Kernel,
54 count: CubeCount,
55 bindings: Vec<Binding>,
56 kind: ExecutionMode,
57 );
58
59 fn flush(&mut self);
61
62 fn sync(&mut self) -> impl Future<Output = ()> + Send + 'static;
64
65 fn sync_elapsed(&mut self) -> impl Future<Output = TimestampsResult> + Send + 'static;
69
70 fn memory_usage(&self) -> MemoryUsage;
72
73 fn enable_timestamps(&mut self);
75
76 fn disable_timestamps(&mut self);
78}
79
80#[derive(new, Debug)]
82pub struct Handle {
83 pub memory: SliceHandle,
85 pub offset_start: Option<u64>,
87 pub offset_end: Option<u64>,
89 size: u64,
91}
92
93impl Handle {
94 pub fn offset_start(mut self, offset: u64) -> Self {
96 if let Some(val) = &mut self.offset_start {
97 *val += offset;
98 } else {
99 self.offset_start = Some(offset);
100 }
101
102 self
103 }
104 pub fn offset_end(mut self, offset: u64) -> Self {
106 if let Some(val) = &mut self.offset_end {
107 *val += offset;
108 } else {
109 self.offset_end = Some(offset);
110 }
111
112 self
113 }
114
115 pub fn size(&self) -> u64 {
117 self.size - self.offset_start.unwrap_or(0) - self.offset_end.unwrap_or(0)
118 }
119}
120
121#[derive(new, Debug)]
123pub struct Binding {
124 pub memory: SliceBinding,
126 pub offset_start: Option<u64>,
128 pub offset_end: Option<u64>,
130}
131
132impl Handle {
133 pub fn can_mut(&self) -> bool {
135 self.memory.can_mut()
136 }
137}
138
139impl Handle {
140 pub fn binding(self) -> Binding {
142 Binding {
143 memory: MemoryHandle::binding(self.memory),
144 offset_start: self.offset_start,
145 offset_end: self.offset_end,
146 }
147 }
148}
149
150impl Clone for Handle {
151 fn clone(&self) -> Self {
152 Self {
153 memory: self.memory.clone(),
154 offset_start: self.offset_start,
155 offset_end: self.offset_end,
156 size: self.size,
157 }
158 }
159}
160
161impl Clone for Binding {
162 fn clone(&self) -> Self {
163 Self {
164 memory: self.memory.clone(),
165 offset_start: self.offset_start,
166 offset_end: self.offset_end,
167 }
168 }
169}
170
171pub enum CubeCount {
175 Static(u32, u32, u32),
177 Dynamic(Binding),
179}
180
181impl CubeCount {
182 pub fn new_single() -> Self {
184 CubeCount::Static(1, 1, 1)
185 }
186
187 pub fn new_1d(x: u32) -> Self {
189 CubeCount::Static(x, 1, 1)
190 }
191
192 pub fn new_2d(x: u32, y: u32) -> Self {
194 CubeCount::Static(x, y, 1)
195 }
196
197 pub fn new_3d(x: u32, y: u32) -> Self {
199 CubeCount::Static(x, y, 1)
200 }
201}
202
203impl Debug for CubeCount {
204 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
205 match self {
206 CubeCount::Static(x, y, z) => f.write_fmt(format_args!("({x}, {y}, {z})")),
207 CubeCount::Dynamic(_) => f.write_str("binding"),
208 }
209 }
210}
211
212impl Clone for CubeCount {
213 fn clone(&self) -> Self {
214 match self {
215 Self::Static(x, y, z) => Self::Static(*x, *y, *z),
216 Self::Dynamic(handle) => Self::Dynamic(handle.clone()),
217 }
218 }
219}