cobre_comm/types.rs
1//! Type definitions for the cobre-comm abstraction layer.
2//!
3//! This module provides the shared types used across all communication backends:
4//!
5//! - [`ReduceOp`] — enumeration of supported reduction operations (sum, min, max)
6//! passed to `Communicator::allreduce`.
7//! - [`CommError`] — top-level error type for the communicator API, covering
8//! collective operation errors, buffer size mismatches, and shared memory failures.
9//! - [`BackendError`] — error type for backend selection and initialization failures.
10
/// Element-wise reduction operations for `allreduce`.
///
/// Each variant maps directly to the MPI reduction operation of the same name
/// (`MPI_SUM`, `MPI_MIN`, `MPI_MAX`) used during distributed execution.
///
/// The training loop requires two of them: [`ReduceOp::Sum`] (`MPI_SUM`) for the
/// upper bound statistics and [`ReduceOp::Min`] (`MPI_MIN`) for the lower bound.
/// A single `MPI_Allreduce` call applies exactly one reduction operation to every
/// element of its buffer, so the two reductions cannot be fused into one call:
/// the training loop issues two separate calls — one with [`ReduceOp::Min`] for
/// the lower bound scalar and one with [`ReduceOp::Sum`] for the remaining
/// statistics.
///
/// The type is a plain `Copy` tag and also implements `Hash`, so it can be used
/// as a key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ReduceOp {
    /// Element-wise summation (`MPI_SUM`).
    ///
    /// Used for upper bound statistics: total cost sum, sum of squares, and
    /// trajectory count aggregated across all ranks after the forward pass.
    Sum,

    /// Element-wise minimum (`MPI_MIN`).
    ///
    /// Used for lower bound aggregation: the minimum first-stage LP objective
    /// across all ranks is the global lower bound for the current iteration.
    Min,

    /// Element-wise maximum (`MPI_MAX`).
    ///
    /// Reserved for future use, for example maximum per-rank solve time for
    /// load balance diagnostics.
    Max,
}
40
41/// Errors that can occur during collective communication or shared memory operations.
42///
43/// This type is returned by all fallible methods on `Communicator`
44/// and `SharedMemoryProvider`.
45#[derive(Debug, thiserror::Error)]
46pub enum CommError {
47 /// An MPI collective operation failed at the library level.
48 ///
49 /// Contains the MPI error code and a human-readable description.
50 #[error(
51 "collective operation '{operation}' failed with MPI error code {mpi_error_code}: {message}"
52 )]
53 CollectiveFailed {
54 /// Name of the collective operation that failed (e.g., `"allgatherv"`).
55 operation: &'static str,
56 /// MPI error code returned by the MPI library.
57 mpi_error_code: i32,
58 /// Human-readable description of the failure.
59 message: String,
60 },
61
62 /// The buffer sizes provided to a collective operation are inconsistent.
63 ///
64 /// For example, `recv.len() < sum(counts)` in `allgatherv`, or
65 /// `send.len() != recv.len()` in `allreduce`.
66 #[error("invalid buffer size for '{operation}': expected {expected} elements, got {actual}")]
67 InvalidBufferSize {
68 /// Name of the collective operation (e.g., `"allreduce"`).
69 operation: &'static str,
70 /// Expected buffer length.
71 expected: usize,
72 /// Actual buffer length supplied by the caller.
73 actual: usize,
74 },
75
76 /// The `root` rank argument is out of range (`root >= size()`).
77 #[error("invalid root rank {root}: communicator has only {size} rank(s)")]
78 InvalidRoot {
79 /// The out-of-range root rank value provided by the caller.
80 root: usize,
81 /// Total number of ranks in the communicator.
82 size: usize,
83 },
84
85 /// The communicator has been finalized or is in an invalid state.
86 ///
87 /// This typically occurs if `MPI_Finalize` has been called before all
88 /// collective operations have completed.
89 #[error("communicator is in an invalid state (MPI may have been finalized)")]
90 InvalidCommunicator,
91
92 /// A shared memory allocation request was rejected by the OS.
93 ///
94 /// This can occur if the requested size exceeds system shared memory limits
95 /// (`/proc/sys/kernel/shmmax` on Linux), if the process lacks permissions for
96 /// shared memory operations, or if the system is out of shared memory resources.
97 #[error("shared memory allocation of {requested_bytes} bytes failed: {message}")]
98 AllocationFailed {
99 /// Number of bytes that were requested.
100 requested_bytes: usize,
101 /// Human-readable description of why the allocation failed.
102 message: String,
103 },
104}
105
106/// Errors that can occur during backend selection and initialization.
107///
108/// This type is returned by the factory function `create_communicator` when the
109/// requested backend cannot be selected or initialized.
110#[derive(Debug, thiserror::Error)]
111pub enum BackendError {
112 /// The requested backend is not compiled into this binary.
113 ///
114 /// The user or environment requested a backend that was not enabled via Cargo
115 /// feature flags at compile time.
116 #[error(
117 "communication backend '{requested}' is not available in this build (available: {available})",
118 available = available.join(", ")
119 )]
120 BackendNotAvailable {
121 /// The backend name that was requested (e.g., `"mpi"`).
122 requested: String,
123 /// List of backend names that are compiled into this binary.
124 available: Vec<String>,
125 },
126
127 /// The requested backend name is not recognized.
128 ///
129 /// The value set in `COBRE_COMM_BACKEND` does not match any known backend name.
130 #[error(
131 "unknown communication backend '{requested}' (known backends: {available})",
132 available = available.join(", ")
133 )]
134 InvalidBackend {
135 /// The unrecognized backend name that was requested.
136 requested: String,
137 /// List of all known backend names (compiled in or not).
138 available: Vec<String>,
139 },
140
141 /// The backend initialization failed.
142 ///
143 /// The backend was correctly selected but failed to initialize, for example
144 /// because the MPI runtime is not installed, the TCP coordinator is unreachable,
145 /// or the shared memory segment does not exist.
146 #[error("'{backend}' backend initialization failed: {source}")]
147 InitializationFailed {
148 /// Name of the backend that failed to initialize (e.g., `"mpi"`).
149 backend: String,
150 /// The underlying error from the backend initialization.
151 source: Box<dyn std::error::Error + Send + Sync>,
152 },
153
154 /// Required environment variables for the selected backend are not set.
155 ///
156 /// The TCP and shared memory backends require additional configuration via
157 /// environment variables. This error lists the variables that are missing.
158 #[error(
159 "backend '{backend}' requires missing configuration: {missing_vars}",
160 missing_vars = missing_vars.join(", ")
161 )]
162 MissingConfiguration {
163 /// Name of the backend requiring configuration (e.g., `"tcp"`).
164 backend: String,
165 /// List of environment variable names that are not set.
166 missing_vars: Vec<String>,
167 },
168}
169
#[cfg(test)]
mod tests {
    use super::{BackendError, CommError, ReduceOp};

    /// Compiles only if `T` can be sent and shared across threads and holds no
    /// borrowed data, i.e. it can be boxed into `Box<dyn Error + Send + Sync>`.
    fn assert_send_sync_static<T: Send + Sync + 'static>() {}

    #[test]
    fn test_reduce_op_debug_format() {
        assert_eq!(format!("{:?}", ReduceOp::Sum), "Sum");
        assert_eq!(format!("{:?}", ReduceOp::Min), "Min");
        assert_eq!(format!("{:?}", ReduceOp::Max), "Max");
    }

    #[test]
    fn test_reduce_op_copy_eq() {
        let op = ReduceOp::Sum;
        // Rebinding a `Copy` value copies it; `op` stays usable afterwards.
        let copied = op;
        assert_eq!(op, copied);
        // Exercise `Clone` explicitly to check it agrees with `Copy`.
        #[allow(clippy::clone_on_copy)]
        let cloned = op.clone();
        assert_eq!(op, cloned);
        assert_ne!(ReduceOp::Sum, ReduceOp::Min);
        assert_ne!(ReduceOp::Min, ReduceOp::Max);
        assert_ne!(ReduceOp::Sum, ReduceOp::Max);
    }

    #[test]
    fn test_comm_error_display() {
        // Exact comparisons: `contains('4') && contains('3')` would also pass if
        // `expected` and `actual` were interpolated into each other's slots.
        let err = CommError::CollectiveFailed {
            operation: "allgatherv",
            mpi_error_code: 5,
            message: "test".into(),
        };
        assert_eq!(
            err.to_string(),
            "collective operation 'allgatherv' failed with MPI error code 5: test"
        );

        let err = CommError::InvalidBufferSize {
            operation: "allreduce",
            expected: 4,
            actual: 3,
        };
        assert_eq!(
            err.to_string(),
            "invalid buffer size for 'allreduce': expected 4 elements, got 3"
        );

        let err = CommError::InvalidRoot { root: 5, size: 4 };
        assert_eq!(
            err.to_string(),
            "invalid root rank 5: communicator has only 4 rank(s)"
        );

        assert_eq!(
            CommError::InvalidCommunicator.to_string(),
            "communicator is in an invalid state (MPI may have been finalized)"
        );

        let err = CommError::AllocationFailed {
            requested_bytes: 1024,
            message: "permission denied".into(),
        };
        assert_eq!(
            err.to_string(),
            "shared memory allocation of 1024 bytes failed: permission denied"
        );
    }

    #[test]
    fn test_comm_error_debug() {
        let err = CommError::CollectiveFailed {
            operation: "broadcast",
            mpi_error_code: 1,
            message: "rank died".into(),
        };
        let debug = format!("{err:?}");
        assert!(debug.contains("CollectiveFailed"), "debug was: {debug}");

        let debug = format!("{:?}", CommError::InvalidCommunicator);
        assert!(debug.contains("InvalidCommunicator"), "debug was: {debug}");
    }

    #[test]
    fn test_backend_error_display() {
        let err = BackendError::BackendNotAvailable {
            requested: "mpi".into(),
            available: vec!["local".into()],
        };
        assert_eq!(
            err.to_string(),
            "communication backend 'mpi' is not available in this build (available: local)"
        );

        // Multi-element lists are joined with ", ".
        let err = BackendError::InvalidBackend {
            requested: "foobar".into(),
            available: vec!["mpi".into(), "local".into()],
        };
        assert_eq!(
            err.to_string(),
            "unknown communication backend 'foobar' (known backends: mpi, local)"
        );

        let err = BackendError::InitializationFailed {
            backend: "mpi".into(),
            source: "MPI runtime not found".into(),
        };
        assert_eq!(
            err.to_string(),
            "'mpi' backend initialization failed: MPI runtime not found"
        );

        let err = BackendError::MissingConfiguration {
            backend: "tcp".into(),
            missing_vars: vec!["COBRE_TCP_COORDINATOR".into(), "COBRE_TCP_RANK".into()],
        };
        assert_eq!(
            err.to_string(),
            "backend 'tcp' requires missing configuration: COBRE_TCP_COORDINATOR, COBRE_TCP_RANK"
        );
    }

    #[test]
    fn test_backend_error_source_chain() {
        // A field named `source` is exposed through `Error::source`, so callers
        // walking the error chain can reach the backend's own error.
        let err = BackendError::InitializationFailed {
            backend: "mpi".into(),
            source: "MPI runtime not found".into(),
        };
        let source = std::error::Error::source(&err)
            .expect("InitializationFailed must expose its `source` field");
        assert_eq!(source.to_string(), "MPI runtime not found");

        let err = BackendError::MissingConfiguration {
            backend: "tcp".into(),
            missing_vars: vec!["COBRE_TCP_RANK".into()],
        };
        assert!(std::error::Error::source(&err).is_none());
    }

    #[test]
    fn test_comm_error_std_error() {
        fn accepts_std_error(_e: &dyn std::error::Error) {}
        accepts_std_error(&CommError::InvalidCommunicator);
        accepts_std_error(&CommError::InvalidRoot { root: 2, size: 1 });
        // No `CommError` variant wraps another error.
        assert!(std::error::Error::source(&CommError::InvalidCommunicator).is_none());
    }

    #[test]
    fn test_backend_error_std_error() {
        fn accepts_std_error(_e: &dyn std::error::Error) {}
        accepts_std_error(&BackendError::BackendNotAvailable {
            requested: "mpi".into(),
            available: vec!["local".into()],
        });
        accepts_std_error(&BackendError::MissingConfiguration {
            backend: "tcp".into(),
            missing_vars: vec!["COBRE_TCP_RANK".into()],
        });
        accepts_std_error(&BackendError::InitializationFailed {
            backend: "mpi".into(),
            source: "MPI runtime not found".into(),
        });
    }

    #[test]
    fn test_error_types_are_send_sync() {
        // Both error types must be usable across threads and convertible into
        // `Box<dyn Error + Send + Sync>` (e.g. by `?` in application code).
        assert_send_sync_static::<CommError>();
        assert_send_sync_static::<BackendError>();
    }
}