// gllm_kernels/comm/traits.rs
use std::fmt;

use burn::tensor::TensorData;

/// Errors returned by [`Communicator`] operations (via [`CommResult`]).
///
/// Every variant except [`CommError::Disconnected`] carries a free-form
/// `String` with details; the `Display` impl renders it after a fixed label.
///
/// `Clone`/`PartialEq`/`Eq` are derived so errors can be duplicated when
/// propagated and compared directly in tests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CommError {
    /// Displayed as `Invalid config: <msg>`.
    InvalidConfig(String),
    /// Displayed as `Connection failed: <msg>`.
    ConnectionFailed(String),
    /// Displayed as `Send failed: <msg>`.
    SendFailed(String),
    /// Displayed as `Recv failed: <msg>`.
    RecvFailed(String),
    /// Carries no payload; displayed as `Channel disconnected`.
    Disconnected,
    /// Displayed as `Serialization error: <msg>`.
    Serialization(String),
}

24impl fmt::Display for CommError {
25 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26 match self {
27 CommError::InvalidConfig(msg) => write!(f, "Invalid config: {}", msg),
28 CommError::ConnectionFailed(msg) => write!(f, "Connection failed: {}", msg),
29 CommError::SendFailed(msg) => write!(f, "Send failed: {}", msg),
30 CommError::RecvFailed(msg) => write!(f, "Recv failed: {}", msg),
31 CommError::Disconnected => write!(f, "Channel disconnected"),
32 CommError::Serialization(msg) => write!(f, "Serialization error: {}", msg),
33 }
34 }
35}
36
37impl std::error::Error for CommError {}
38
39pub type CommResult<T> = Result<T, CommError>;
41
42pub trait Communicator: Send + Sync {
44 fn rank(&self) -> usize;
46
47 fn world_size(&self) -> usize;
49
50 fn send(&self, data: &TensorData) -> CommResult<()>;
52
53 fn recv(&self) -> CommResult<TensorData>;
55
56 fn send_recv(&self, send_data: &TensorData) -> CommResult<TensorData> {
58 self.send(send_data)?;
59 self.recv()
60 }
61
62 fn barrier(&self) -> CommResult<()>;
64}