// cobre_comm/types.rs

//! Type definitions for the cobre-comm abstraction layer.
//!
//! This module provides the shared types used across all communication backends:
//!
//! - [`ReduceOp`] — enumeration of supported reduction operations (sum, min, max)
//!   passed to `Communicator::allreduce`.
//! - [`CommError`] — top-level error type for the communicator API, covering
//!   collective operation failures, buffer size mismatches, invalid root ranks,
//!   invalid communicator state, and shared memory allocation failures.
//! - [`BackendError`] — error type for backend selection and initialization failures.
10
/// Element-wise reduction operations for `allreduce`.
///
/// Each variant maps directly to an MPI reduction operation used during
/// distributed execution.
///
/// `Sum` and `Min` are the two operations the training loop depends on:
/// `MPI_SUM` for upper bound statistics and `MPI_MIN` for the lower bound.
/// Because MPI may not support mixed reduction operations in a single
/// `allreduce` call, the training loop issues two separate calls — one with
/// [`ReduceOp::Min`] for the lower bound scalar and one with [`ReduceOp::Sum`]
/// for the remaining statistics.
///
/// The type is a plain fieldless enum, so it is `Copy`, totally comparable for
/// equality, and hashable (usable as a `HashMap`/`HashSet` key, e.g. for
/// per-operation counters).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReduceOp {
    /// Element-wise summation.
    ///
    /// Used for upper bound statistics: total cost sum, sum of squares, and
    /// trajectory count aggregated across all ranks after the forward pass.
    Sum,

    /// Element-wise minimum.
    ///
    /// Used for lower bound aggregation: the minimum first-stage LP objective
    /// across all ranks is the global lower bound for the current iteration.
    Min,

    /// Element-wise maximum.
    ///
    /// Reserved for future use, for example maximum per-rank solve time for
    /// load balance diagnostics.
    Max,
}
40
41/// Errors that can occur during collective communication or shared memory operations.
42///
43/// This type is returned by all fallible methods on `Communicator`
44/// and `SharedMemoryProvider`.
45#[derive(Debug, thiserror::Error)]
46pub enum CommError {
47    /// An MPI collective operation failed at the library level.
48    ///
49    /// Contains the MPI error code and a human-readable description.
50    #[error(
51        "collective operation '{operation}' failed with MPI error code {mpi_error_code}: {message}"
52    )]
53    CollectiveFailed {
54        /// Name of the collective operation that failed (e.g., `"allgatherv"`).
55        operation: &'static str,
56        /// MPI error code returned by the MPI library.
57        mpi_error_code: i32,
58        /// Human-readable description of the failure.
59        message: String,
60    },
61
62    /// The buffer sizes provided to a collective operation are inconsistent.
63    ///
64    /// For example, `recv.len() < sum(counts)` in `allgatherv`, or
65    /// `send.len() != recv.len()` in `allreduce`.
66    #[error("invalid buffer size for '{operation}': expected {expected} elements, got {actual}")]
67    InvalidBufferSize {
68        /// Name of the collective operation (e.g., `"allreduce"`).
69        operation: &'static str,
70        /// Expected buffer length.
71        expected: usize,
72        /// Actual buffer length supplied by the caller.
73        actual: usize,
74    },
75
76    /// The `root` rank argument is out of range (`root >= size()`).
77    #[error("invalid root rank {root}: communicator has only {size} rank(s)")]
78    InvalidRoot {
79        /// The out-of-range root rank value provided by the caller.
80        root: usize,
81        /// Total number of ranks in the communicator.
82        size: usize,
83    },
84
85    /// The communicator has been finalized or is in an invalid state.
86    ///
87    /// This typically occurs if `MPI_Finalize` has been called before all
88    /// collective operations have completed.
89    #[error("communicator is in an invalid state (MPI may have been finalized)")]
90    InvalidCommunicator,
91
92    /// A shared memory allocation request was rejected by the OS.
93    ///
94    /// This can occur if the requested size exceeds system shared memory limits
95    /// (`/proc/sys/kernel/shmmax` on Linux), if the process lacks permissions for
96    /// shared memory operations, or if the system is out of shared memory resources.
97    #[error("shared memory allocation of {requested_bytes} bytes failed: {message}")]
98    AllocationFailed {
99        /// Number of bytes that were requested.
100        requested_bytes: usize,
101        /// Human-readable description of why the allocation failed.
102        message: String,
103    },
104}
105
106/// Errors that can occur during backend selection and initialization.
107///
108/// This type is returned by the factory function `create_communicator` when the
109/// requested backend cannot be selected or initialized.
110#[derive(Debug, thiserror::Error)]
111pub enum BackendError {
112    /// The requested backend is not compiled into this binary.
113    ///
114    /// The user or environment requested a backend that was not enabled via Cargo
115    /// feature flags at compile time.
116    #[error(
117        "communication backend '{requested}' is not available in this build (available: {available})",
118        available = available.join(", ")
119    )]
120    BackendNotAvailable {
121        /// The backend name that was requested (e.g., `"mpi"`).
122        requested: String,
123        /// List of backend names that are compiled into this binary.
124        available: Vec<String>,
125    },
126
127    /// The requested backend name is not recognized.
128    ///
129    /// The value set in `COBRE_COMM_BACKEND` does not match any known backend name.
130    #[error(
131        "unknown communication backend '{requested}' (known backends: {available})",
132        available = available.join(", ")
133    )]
134    InvalidBackend {
135        /// The unrecognized backend name that was requested.
136        requested: String,
137        /// List of all known backend names (compiled in or not).
138        available: Vec<String>,
139    },
140
141    /// The backend initialization failed.
142    ///
143    /// The backend was correctly selected but failed to initialize, for example
144    /// because the MPI runtime is not installed, the TCP coordinator is unreachable,
145    /// or the shared memory segment does not exist.
146    #[error("'{backend}' backend initialization failed: {source}")]
147    InitializationFailed {
148        /// Name of the backend that failed to initialize (e.g., `"mpi"`).
149        backend: String,
150        /// The underlying error from the backend initialization.
151        source: Box<dyn std::error::Error + Send + Sync>,
152    },
153
154    /// Required environment variables for the selected backend are not set.
155    ///
156    /// The TCP and shared memory backends require additional configuration via
157    /// environment variables. This error lists the variables that are missing.
158    #[error(
159        "backend '{backend}' requires missing configuration: {missing_vars}",
160        missing_vars = missing_vars.join(", ")
161    )]
162    MissingConfiguration {
163        /// Name of the backend requiring configuration (e.g., `"tcp"`).
164        backend: String,
165        /// List of environment variable names that are not set.
166        missing_vars: Vec<String>,
167    },
168}
169
#[cfg(test)]
mod tests {
    use super::{BackendError, CommError, ReduceOp};

    /// Compile-time check helper: accepts anything usable as a
    /// `&dyn std::error::Error` trait object.
    fn accepts_std_error(_e: &dyn std::error::Error) {}

    /// `Debug` output of each variant is the bare variant name.
    #[test]
    fn test_reduce_op_debug_format() {
        assert_eq!(format!("{:?}", ReduceOp::Sum), "Sum");
        assert_eq!(format!("{:?}", ReduceOp::Min), "Min");
        assert_eq!(format!("{:?}", ReduceOp::Max), "Max");
    }

    /// `ReduceOp` is `Clone`, `Copy`, and compares by variant.
    #[test]
    fn test_reduce_op_copy_eq() {
        let op = ReduceOp::Sum;

        // Explicit `.clone()` so the `Clone` impl is really exercised; a plain
        // assignment of a `Copy` type would only test `Copy`.
        #[allow(clippy::clone_on_copy)]
        let cloned = op.clone();
        assert_eq!(op, cloned);

        // `op` remains usable after being assigned by value, which requires `Copy`.
        let copied: ReduceOp = op;
        assert_eq!(op, copied);

        assert_ne!(ReduceOp::Sum, ReduceOp::Min);
        assert_ne!(ReduceOp::Min, ReduceOp::Max);
        assert_ne!(ReduceOp::Sum, ReduceOp::Max);
    }

    /// Every field of every `CommError` variant shows up in its `Display` text.
    #[test]
    fn test_comm_error_display() {
        let err = CommError::CollectiveFailed {
            operation: "allgatherv",
            mpi_error_code: 5,
            message: "test".into(),
        };
        let display = format!("{err}");
        assert!(display.contains("allgatherv"), "display was: {display}");
        assert!(display.contains('5'), "display was: {display}");
        assert!(display.contains("test"), "display was: {display}");

        let err = CommError::InvalidBufferSize {
            operation: "allreduce",
            expected: 4,
            actual: 3,
        };
        let display = format!("{err}");
        assert!(display.contains("allreduce"), "display was: {display}");
        assert!(display.contains('4'), "display was: {display}");
        assert!(display.contains('3'), "display was: {display}");

        let err = CommError::InvalidRoot { root: 5, size: 4 };
        let display = format!("{err}");
        assert!(display.contains('5'), "display was: {display}");
        assert!(display.contains('4'), "display was: {display}");

        let display = format!("{}", CommError::InvalidCommunicator);
        assert!(!display.is_empty(), "display was empty");
        assert!(display.contains("communicator"), "display was: {display}");

        let err = CommError::AllocationFailed {
            requested_bytes: 1024,
            message: "permission denied".into(),
        };
        let display = format!("{err}");
        assert!(display.contains("1024"), "display was: {display}");
        assert!(
            display.contains("permission denied"),
            "display was: {display}"
        );
    }

    /// `Debug` output names the variant.
    #[test]
    fn test_comm_error_debug() {
        let err = CommError::CollectiveFailed {
            operation: "broadcast",
            mpi_error_code: 1,
            message: "rank died".into(),
        };
        let debug = format!("{err:?}");
        assert!(debug.contains("CollectiveFailed"), "debug was: {debug}");

        let debug = format!("{:?}", CommError::InvalidCommunicator);
        assert!(debug.contains("InvalidCommunicator"), "debug was: {debug}");
    }

    /// Every field of every `BackendError` variant shows up in its `Display`
    /// text, including list fields (joined with `", "`) and the wrapped source.
    #[test]
    fn test_backend_error_display() {
        let err = BackendError::BackendNotAvailable {
            requested: "mpi".into(),
            available: vec!["local".into()],
        };
        let display = format!("{err}");
        assert!(display.contains("mpi"), "display was: {display}");
        assert!(display.contains("local"), "display was: {display}");

        let err = BackendError::InvalidBackend {
            requested: "foobar".into(),
            available: vec!["mpi".into(), "local".into()],
        };
        let display = format!("{err}");
        assert!(display.contains("foobar"), "display was: {display}");
        assert!(display.contains("mpi, local"), "display was: {display}");

        let err = BackendError::InitializationFailed {
            backend: "mpi".into(),
            source: "MPI runtime not found".into(),
        };
        let display = format!("{err}");
        assert!(display.contains("mpi"), "display was: {display}");
        assert!(
            display.contains("MPI runtime not found"),
            "display was: {display}"
        );

        let err = BackendError::MissingConfiguration {
            backend: "tcp".into(),
            missing_vars: vec!["COBRE_TCP_COORDINATOR".into(), "COBRE_TCP_RANK".into()],
        };
        let display = format!("{err}");
        assert!(display.contains("tcp"), "display was: {display}");
        assert!(
            display.contains("COBRE_TCP_COORDINATOR, COBRE_TCP_RANK"),
            "display was: {display}"
        );
    }

    /// Only `InitializationFailed` carries an underlying cause in the
    /// `std::error::Error::source` chain.
    #[test]
    fn test_backend_error_source() {
        use std::error::Error as _;

        let err = BackendError::InitializationFailed {
            backend: "mpi".into(),
            source: "MPI runtime not found".into(),
        };
        let cause = err
            .source()
            .expect("InitializationFailed must expose its underlying error");
        assert_eq!(cause.to_string(), "MPI runtime not found");

        let err = BackendError::MissingConfiguration {
            backend: "tcp".into(),
            missing_vars: vec!["COBRE_TCP_RANK".into()],
        };
        assert!(err.source().is_none());
    }

    /// `CommError` is usable as a `dyn std::error::Error`.
    #[test]
    fn test_comm_error_std_error() {
        accepts_std_error(&CommError::InvalidCommunicator);
        accepts_std_error(&CommError::InvalidRoot { root: 2, size: 1 });
    }

    /// `BackendError` is usable as a `dyn std::error::Error`.
    #[test]
    fn test_backend_error_std_error() {
        accepts_std_error(&BackendError::BackendNotAvailable {
            requested: "mpi".into(),
            available: vec!["local".into()],
        });
        accepts_std_error(&BackendError::MissingConfiguration {
            backend: "tcp".into(),
            missing_vars: vec!["COBRE_TCP_RANK".into()],
        });
    }
}