burn_compute/
server.rs

1use core::fmt::Debug;
2
3use crate::{
4    memory_management::{MemoryHandle, MemoryManagement},
5    storage::ComputeStorage,
6    tune::AutotuneKey,
7};
8use alloc::vec::Vec;
9use burn_common::reader::Reader;
10
11/// The compute server is responsible for handling resources and computations over resources.
12///
13/// Everything in the server is mutable, therefore it should be solely accessed through the
14/// [compute channel](crate::channel::ComputeChannel) for thread safety.
15pub trait ComputeServer: Send + core::fmt::Debug
16where
17    Self: Sized,
18{
19    /// The kernel type defines the computation algorithms.
20    type Kernel: Send;
21    /// The [storage](ComputeStorage) type defines how data is stored and accessed.
22    type Storage: ComputeStorage;
23    /// The [memory management](MemoryManagement) type defines strategies for allocation in the [storage](ComputeStorage) type.
24    type MemoryManagement: MemoryManagement<Self::Storage>;
25    /// The key used to cache operations used on specific inputs in autotune
26    type AutotuneKey: AutotuneKey;
27
28    /// Given a handle, returns the owned resource as bytes.
29    fn read(&mut self, handle: &Handle<Self>) -> Reader<Vec<u8>>;
30
31    /// Given a resource as bytes, stores it and returns the memory handle.
32    fn create(&mut self, data: &[u8]) -> Handle<Self>;
33
34    /// Reserves `size` bytes in the storage, and returns a handle over them.
35    fn empty(&mut self, size: usize) -> Handle<Self>;
36
37    /// Executes the `kernel` over the given memory `handles`.
38    ///
39    /// Kernels have mutable access to every resource they are given
40    /// and are responsible of determining which should be read or written.
41    fn execute(&mut self, kernel: Self::Kernel, handles: &[&Handle<Self>]);
42
43    /// Wait for the completion of every task in the server.
44    fn sync(&mut self);
45}
46
47/// Server handle containing the [memory handle](MemoryManagement::Handle).
48#[derive(new, Debug)]
49pub struct Handle<Server: ComputeServer> {
50    /// Handle for the memory in use.
51    pub memory: <Server::MemoryManagement as MemoryManagement<Server::Storage>>::Handle,
52}
53
54impl<Server: ComputeServer> Handle<Server> {
55    /// If the tensor handle can be mut with an inplace operation.
56    pub fn can_mut(&self) -> bool {
57        self.memory.can_mut()
58    }
59}
60
61impl<Server: ComputeServer> Clone for Handle<Server> {
62    fn clone(&self) -> Self {
63        Self {
64            memory: self.memory.clone(),
65        }
66    }
67}