use crate::{
channel::ComputeChannel,
server::{Binding, ComputeServer, Handle},
storage::ComputeStorage,
ExecutionMode,
};
use alloc::sync::Arc;
use alloc::vec::Vec;
pub use cubecl_common::sync_type::SyncType;
/// Client handle used to talk to a compute server through a channel.
///
/// The client owns a channel of type `Channel` and a shared, reference-counted
/// copy of the server's feature set.
#[derive(Debug)]
pub struct ComputeClient<Server, Channel>
where
    Server: ComputeServer,
{
    /// Channel through which every request to the server is forwarded.
    channel: Channel,
    /// Feature set of the server, shared between clones via `Arc`.
    features: Arc<Server::FeatureSet>,
}
impl<S, C> Clone for ComputeClient<S, C>
where
S: ComputeServer,
C: ComputeChannel<S>,
{
fn clone(&self) -> Self {
Self {
channel: self.channel.clone(),
features: self.features.clone(),
}
}
}
impl<Server, Channel> ComputeClient<Server, Channel>
where
    Server: ComputeServer,
    Channel: ComputeChannel<Server>,
{
    /// Creates a client from a channel and the server's feature set.
    ///
    /// The feature set is stored behind an `Arc`, so clones of the client share it
    /// rather than copying it.
    pub fn new(channel: Channel, features: Arc<Server::FeatureSet>) -> Self {
        Self { channel, features }
    }

    /// Asynchronously reads the bytes referenced by `binding` through the channel.
    pub async fn read_async(&self, binding: Binding<Server>) -> Vec<u8> {
        self.channel.read(binding).await
    }

    /// Reads the bytes referenced by `binding`, waiting synchronously on the
    /// channel's read future via `cubecl_common::reader::read_sync`.
    ///
    /// Use [`Self::read_async`] from async contexts instead.
    pub fn read(&self, binding: Binding<Server>) -> Vec<u8> {
        cubecl_common::reader::read_sync(self.channel.read(binding))
    }

    /// Returns the storage resource backing `binding`, as reported by the channel.
    pub fn get_resource(
        &self,
        binding: Binding<Server>,
    ) -> <Server::Storage as ComputeStorage>::Resource {
        self.channel.get_resource(binding)
    }

    /// Creates a resource initialized from `data` and returns its handle.
    pub fn create(&self, data: &[u8]) -> Handle<Server> {
        self.channel.create(data)
    }

    /// Requests an empty resource of `size` from the channel and returns its handle.
    ///
    /// NOTE(review): `size` is presumably in bytes — confirm against the channel.
    pub fn empty(&self, size: usize) -> Handle<Server> {
        self.channel.empty(size)
    }

    /// Executes `kernel` with the given dispatch options and bindings in
    /// [`ExecutionMode::Checked`] mode.
    ///
    /// This is the safe entry point; see [`Self::execute_unchecked`] for the
    /// unchecked variant.
    pub fn execute(
        &self,
        kernel: Server::Kernel,
        count: Server::DispatchOptions,
        bindings: Vec<Binding<Server>>,
    ) {
        // SAFETY: the channel's `execute` is only unsafe when run in unchecked mode;
        // requesting `ExecutionMode::Checked` delegates the safety checks to the
        // server. NOTE(review): confirm the server enforces checks in this mode.
        unsafe {
            self.channel
                .execute(kernel, count, bindings, ExecutionMode::Checked)
        }
    }

    /// Executes `kernel` with the given dispatch options and bindings in
    /// [`ExecutionMode::Unchecked`] mode.
    ///
    /// # Safety
    ///
    /// The caller takes responsibility for whatever guarantees checked mode would
    /// otherwise provide. NOTE(review): presumably this means the kernel must not
    /// access its bindings out of bounds — confirm against the channel/server
    /// contract for `ExecutionMode::Unchecked`.
    pub unsafe fn execute_unchecked(
        &self,
        kernel: Server::Kernel,
        count: Server::DispatchOptions,
        bindings: Vec<Binding<Server>>,
    ) {
        // SAFETY: the caller upholds this function's `# Safety` contract, which is
        // exactly what the unchecked channel execution requires.
        unsafe {
            self.channel
                .execute(kernel, count, bindings, ExecutionMode::Unchecked)
        }
    }

    /// Synchronizes with the server through the channel using `sync_type`.
    pub fn sync(&self, sync_type: SyncType) {
        self.channel.sync(sync_type)
    }

    /// Returns the server's feature set.
    pub fn features(&self) -> &Server::FeatureSet {
        self.features.as_ref()
    }
}