Skip to main content

ibverbs_rs/network/
builder.rs

1use crate::ibverbs::access_config::AccessFlags;
2use crate::ibverbs::error::IbvResult;
3use crate::ibverbs::protection_domain::ProtectionDomain;
4use crate::ibverbs::queue_pair::builder::QueuePairEndpoint;
5use crate::ibverbs::queue_pair::config::*;
6use crate::multi_channel::MultiChannel;
7use crate::multi_channel::{PeerRemoteMemoryRegion, PreparedMultiChannel};
8use crate::network::Node;
9use crate::network::barrier::{BarrierAlgorithm, PreparedBarrier};
10use bon::bon;
11use serde::{Deserialize, Serialize};
12use std::io;
13
14#[bon]
15impl Node {
16    #[builder(state_mod(vis = "pub"))]
17    pub fn builder(
18        rank: usize,
19        world_size: usize,
20        pd: &ProtectionDomain,
21        #[builder(default = BarrierAlgorithm::BinaryTree)] barrier: BarrierAlgorithm,
22        #[builder(default =
23            AccessFlags::new()
24                .with_local_write()
25                .with_remote_read()
26                .with_remote_write()
27        )]
28        access: AccessFlags,
29        #[builder(default = 32)] min_cq_entries: u32,
30        #[builder(default = 16)] max_send_wr: u32,
31        #[builder(default = 16)] max_recv_wr: u32,
32        #[builder(default = 16)] max_send_sge: u32,
33        #[builder(default = 16)] max_recv_sge: u32,
34        #[builder(default)] max_rnr_retries: MaxRnrRetries,
35        #[builder(default)] max_ack_retries: MaxAckRetries,
36        #[builder(default)] min_rnr_timer: MinRnrTimer,
37        #[builder(default)] ack_timeout: AckTimeout,
38        #[builder(default)] mtu: MaximumTransferUnit,
39        #[builder(default)] send_psn: PacketSequenceNumber,
40        #[builder(default)] recv_psn: PacketSequenceNumber,
41    ) -> IbvResult<PreparedNode> {
42        let multi_channel = MultiChannel::builder()
43            .num_channels(world_size)
44            .pd(pd)
45            .min_cq_entries(min_cq_entries)
46            .access(access)
47            .max_send_wr(max_send_wr)
48            .max_recv_wr(max_recv_wr)
49            .max_send_sge(max_send_sge)
50            .max_recv_sge(max_recv_sge)
51            .max_rnr_retries(max_rnr_retries)
52            .max_ack_retries(max_ack_retries)
53            .min_rnr_timer(min_rnr_timer)
54            .ack_timeout(ack_timeout)
55            .mtu(mtu)
56            .send_psn(send_psn)
57            .recv_psn(recv_psn)
58            .build()?;
59        let barrier = barrier.instance(pd, rank, world_size)?;
60
61        Ok(PreparedNode {
62            rank,
63            world_size,
64            multi_channel,
65            barrier,
66        })
67    }
68}
69
70/// A [`Node`] that has been configured but not yet connected to its peers.
71///
72/// Created by [`Node::builder`]. Call [`endpoint`](Self::endpoint) to obtain the local
73/// connection information, exchange endpoints with all peers, then call
74/// [`gather_endpoints`](Self::gather_endpoints) followed by [`handshake`](Self::handshake)
75/// to finish the connections.
76pub struct PreparedNode {
77    rank: usize,
78    world_size: usize,
79    multi_channel: PreparedMultiChannel,
80    barrier: PreparedBarrier,
81}
82
83/// The per-peer endpoint information exchanged during setup, containing both
84/// the queue pair endpoint and the barrier memory region handle.
85#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
86pub struct NetworkChannelEndpoint {
87    pub(crate) single_channel_endpoint: QueuePairEndpoint,
88    pub(crate) barrier_mr_remote: PeerRemoteMemoryRegion,
89}
90
91/// This node's endpoint information, ready to be sent to all peers.
92#[derive(Clone, Serialize, Deserialize)]
93pub struct LocalEndpoint {
94    rank: usize,
95    endpoints: Box<[NetworkChannelEndpoint]>,
96}
97
98/// Validated collection of remote endpoints, one per peer. Produced by
99/// [`PreparedNode::gather_endpoints`] and consumed by [`PreparedNode::handshake`].
100///
101/// Contains one [`NetworkChannelEndpoint`] per rank, indexed in rank order.
102pub struct RemoteEndpoints(Box<[NetworkChannelEndpoint]>);
103
104impl PreparedNode {
105    /// Returns this node's local endpoint information to be exchanged with all peers.
106    pub fn endpoint(&self) -> LocalEndpoint {
107        LocalEndpoint {
108            rank: self.rank,
109            endpoints: self
110                .multi_channel
111                .endpoints()
112                .into_iter()
113                .map(|single_channel_endpoint| NetworkChannelEndpoint {
114                    single_channel_endpoint,
115                    barrier_mr_remote: self.barrier.remote_mr(),
116                })
117                .collect(),
118        }
119    }
120
121    /// Collects and validates endpoints received from all peers.
122    ///
123    /// Each peer's [`LocalEndpoint`] must appear exactly once. Returns an error if
124    /// any rank is out of bounds, duplicated, or missing.
125    pub fn gather_endpoints(
126        &self,
127        endpoints: impl IntoIterator<Item = LocalEndpoint>,
128    ) -> io::Result<RemoteEndpoints> {
129        // Temporary initialization tracker
130        let mut tmp: Vec<Option<NetworkChannelEndpoint>> = vec![None; self.world_size];
131
132        for endpoint in endpoints {
133            // Check rank bounds
134            let in_slot = tmp.get_mut(endpoint.rank).ok_or_else(|| {
135                io::Error::new(
136                    io::ErrorKind::InvalidInput,
137                    format!(
138                        "Input endpoint rank {} out of bounds (0..{})",
139                        endpoint.rank, self.world_size
140                    ),
141                )
142            })?;
143
144            // Detect duplicate ranks
145            if in_slot.is_some() {
146                return Err(io::Error::new(
147                    io::ErrorKind::InvalidInput,
148                    format!("Duplicate endpoint for rank {}", endpoint.rank),
149                ));
150            }
151
152            // Check that the endpoint has a rechannel for our rank
153            let qp_endpoint = endpoint.endpoints.get(self.rank).ok_or_else(|| {
154                io::Error::new(
155                    io::ErrorKind::InvalidInput,
156                    format!(
157                        "Input endpoint for rank {} missing rechannel for local rank {}",
158                        endpoint.rank, self.rank
159                    ),
160                )
161            })?;
162
163            // Fill the temporary slot
164            *in_slot = Some(*qp_endpoint);
165        }
166
167        // Convert Option<Vec<_>> -> Vec<_> in one go, validating all slots are filled
168        let in_endpoints: Vec<NetworkChannelEndpoint> = tmp
169            .into_iter()
170            .enumerate()
171            .map(|(i, opt)| {
172                opt.ok_or_else(|| {
173                    io::Error::new(
174                        io::ErrorKind::InvalidData,
175                        format!("Missing endpoint from rank {}", i),
176                    )
177                })
178            })
179            .collect::<Result<_, _>>()?;
180
181        Ok(RemoteEndpoints(in_endpoints.into_boxed_slice()))
182    }
183
184    /// Connects all channels and the barrier, returning a ready-to-use [`Node`].
185    pub fn handshake(self, endpoints: RemoteEndpoints) -> IbvResult<Node> {
186        let multi_channel = self
187            .multi_channel
188            .handshake(endpoints.0.iter().map(|e| e.single_channel_endpoint))?;
189        let barrier = self
190            .barrier
191            .link_remote(endpoints.0.iter().map(|e| e.barrier_mr_remote).collect());
192
193        Ok(Node {
194            rank: self.rank,
195            world_size: self.world_size,
196            multi_channel,
197            barrier,
198        })
199    }
200}