use burn_backend::{DTypeUsageSet, ExecutionError, TensorData};
use burn_communication::{Address, data_service::TensorTransferId};
use burn_ir::{OperationIr, TensorId, TensorIr};
use burn_std::{
DType,
id::{IdGenerator, StreamId},
};
use serde::{Deserialize, Serialize};
use std::fmt::Display;
/// Identifier attached to a compute task.
///
/// It is carried by [`Task::Compute`] and also appears as [`TaskResponse::id`]. NOTE(review):
/// presumably this lets a response be matched to the task that produced it; confirm against
/// the client/server code.
///
/// `#[derive(new)]` generates `ConnectionId::new(position, stream_id)`.
///
/// Field order matters: the derived `PartialOrd`/`Ord` compare `position` first, then
/// `stream_id`. Derived serde serializes the fields in declaration order.
#[allow(missing_docs)]
#[derive(new, Serialize, Deserialize, Debug, Hash, PartialEq, Eq, Clone, Copy, PartialOrd, Ord)]
pub struct ConnectionId {
    /// NOTE(review): presumably a per-stream sequence/position counter; confirm its meaning.
    pub position: u64,
    /// The [`StreamId`] this identifier is associated with.
    pub stream_id: StreamId,
}
/// Opaque identifier of a session, wrapping a `u64`.
///
/// The field is private. Construct values with [`SessionId::new`], which draws the number
/// from `IdGenerator::generate()`. The type is `Copy`, hashable, totally ordered and
/// serde-serializable.
#[derive(Debug, PartialEq, Eq, Clone, Copy, Hash, Serialize, Deserialize, PartialOrd, Ord)]
pub struct SessionId {
    /// Raw numeric identifier.
    id: u64,
}
impl Display for SessionId {
    /// Formats the identifier as `SessionId(<id>)`.
    ///
    /// Uses `write!` rather than `writeln!`. A `Display` impl must not append a trailing
    /// newline, because the value is routinely embedded inside larger strings (`format!`,
    /// `to_string()`, log messages), where a stray `\n` breaks the surrounding text.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "SessionId({})", self.id)
    }
}
impl SessionId {
    /// Creates a `SessionId` holding a freshly generated value from `IdGenerator::generate()`.
    // NOTE(review): the `dead_code` allow suggests some build configurations never call this
    // constructor; confirm before removing the attribute.
    #[allow(dead_code)]
    pub fn new() -> Self {
        let id = IdGenerator::generate();
        Self { id }
    }
}
/// Top-level task message.
///
/// NOTE(review): presumably sent from the client to the remote server; confirm.
///
/// Variant order is part of the wire format for serde formats that encode enum variants by
/// index, so new variants should be appended at the end.
#[allow(missing_docs)]
#[derive(Serialize, Deserialize, Debug)]
pub enum Task {
    /// A [`ComputeTask`] tagged with the [`ConnectionId`] it belongs to.
    Compute(ComputeTask, ConnectionId),
    /// Presumably opens the session with the given [`SessionId`]; confirm.
    Init(SessionId),
    /// Presumably closes the session with the given [`SessionId`]; confirm.
    Close(SessionId),
}
/// Reference to a tensor held elsewhere: a transfer id plus the [`Address`] to reach it.
///
/// NOTE(review): `transfer_id` presumably matches the one given in
/// [`ComputeTask::ExposeTensorRemote`] on the node at `address`; confirm.
#[allow(missing_docs)]
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct TensorRemote {
    /// Identifier of the tensor transfer.
    pub transfer_id: TensorTransferId,
    /// Address of the node the tensor is fetched from (presumably; confirm).
    pub address: Address,
}
/// Work carried inside [`Task::Compute`].
///
/// `ReadTensor`, `SyncBackend` and `DTypeUsage` have same-named variants in
/// [`TaskResponseContent`]. NOTE(review): presumably those are their replies, and the other
/// variants produce no response; confirm.
///
/// Variant order is part of the wire format for index-based serde encodings; append new
/// variants at the end.
#[allow(missing_docs)]
#[derive(Serialize, Deserialize, Debug)]
pub enum ComputeTask {
    /// Presumably seeds the backend's random number generator; confirm.
    Seed(u64),
    /// Registers an operation described by its IR.
    /// Boxed, presumably to keep the enum small; confirm.
    RegisterOperation(Box<OperationIr>),
    /// Registers tensor data under the given [`TensorId`].
    RegisterTensor(TensorId, TensorData),
    /// Registers, under the given [`TensorId`], a tensor located by a [`TensorRemote`].
    /// NOTE(review): presumably the data is fetched from the remote node; confirm.
    RegisterTensorRemote(TensorRemote, TensorId),
    /// Makes `tensor` available for remote transfer under `transfer_id`.
    ExposeTensorRemote {
        /// The tensor to expose.
        tensor: TensorIr,
        /// NOTE(review): presumably how many times the tensor may be transferred; confirm.
        count: u32,
        /// Identifier under which the tensor is exposed.
        transfer_id: TensorTransferId,
    },
    /// Requests the data of the given tensor.
    ReadTensor(TensorIr),
    /// Requests a backend synchronization.
    SyncBackend,
    /// Queries usage information for the given [`DType`].
    DTypeUsage(DType),
}
/// A response payload paired with the [`ConnectionId`] it relates to.
///
/// NOTE(review): `id` is presumably copied from the originating [`Task::Compute`]; confirm.
#[allow(missing_docs)]
#[derive(Serialize, Deserialize, Debug)]
pub struct TaskResponse {
    /// The response payload.
    pub content: TaskResponseContent,
    /// Identifier of the task this response relates to.
    pub id: ConnectionId,
}
/// Payload of a [`TaskResponse`].
///
/// Each variant shares its name with a [`ComputeTask`] variant (presumably the request it
/// answers; confirm). Variant order is part of the wire format for index-based serde
/// encodings; append new variants at the end.
#[allow(missing_docs)]
#[derive(Serialize, Deserialize, Debug)]
pub enum TaskResponseContent {
    /// Tensor data, or the [`ExecutionError`] raised while producing it.
    ReadTensor(Result<TensorData, ExecutionError>),
    /// Outcome of a backend synchronization.
    SyncBackend(Result<(), ExecutionError>),
    /// Usage information for the queried dtype.
    DTypeUsage(DTypeUsageSet),
}