gllm_kernels/comm/
traits.rs

1//! Communication traits for distributed computation.
2
3use std::fmt;
4
5use burn::tensor::TensorData;
6
/// Errors reported by communication operations.
///
/// Every variant except [`CommError::Disconnected`] carries a free-form
/// message; the `Display` implementation prints it after a variant-specific
/// prefix.
///
/// `Clone`, `PartialEq` and `Eq` are derived (all payloads are `String`s) so
/// that errors can be duplicated and compared directly, e.g. with
/// `assert_eq!` in tests or `==` when deciding whether to retry.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum CommError {
    /// Invalid rank or configuration.
    InvalidConfig(String),
    /// Connection failed.
    ConnectionFailed(String),
    /// Send operation failed.
    SendFailed(String),
    /// Receive operation failed.
    RecvFailed(String),
    /// Channel disconnected.
    Disconnected,
    /// Serialization/deserialization error.
    Serialization(String),
}
23
24impl fmt::Display for CommError {
25    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
26        match self {
27            CommError::InvalidConfig(msg) => write!(f, "Invalid config: {}", msg),
28            CommError::ConnectionFailed(msg) => write!(f, "Connection failed: {}", msg),
29            CommError::SendFailed(msg) => write!(f, "Send failed: {}", msg),
30            CommError::RecvFailed(msg) => write!(f, "Recv failed: {}", msg),
31            CommError::Disconnected => write!(f, "Channel disconnected"),
32            CommError::Serialization(msg) => write!(f, "Serialization error: {}", msg),
33        }
34    }
35}
36
37impl std::error::Error for CommError {}
38
39/// Result type for communication operations.
40pub type CommResult<T> = Result<T, CommError>;
41
42/// Communicator trait for ring communication pattern.
43pub trait Communicator: Send + Sync {
44    /// Get the rank of this communicator.
45    fn rank(&self) -> usize;
46
47    /// Get the total number of participants.
48    fn world_size(&self) -> usize;
49
50    /// Send data to the next rank in the ring.
51    fn send(&self, data: &TensorData) -> CommResult<()>;
52
53    /// Receive data from the previous rank in the ring.
54    fn recv(&self) -> CommResult<TensorData>;
55
56    /// Send to next and receive from previous simultaneously.
57    fn send_recv(&self, send_data: &TensorData) -> CommResult<TensorData> {
58        self.send(send_data)?;
59        self.recv()
60    }
61
62    /// Barrier synchronization across all ranks.
63    fn barrier(&self) -> CommResult<()>;
64}