ibverbs_rs/network/
builder.rs1use crate::ibverbs::access_config::AccessFlags;
2use crate::ibverbs::error::IbvResult;
3use crate::ibverbs::protection_domain::ProtectionDomain;
4use crate::ibverbs::queue_pair::builder::QueuePairEndpoint;
5use crate::ibverbs::queue_pair::config::*;
6use crate::multi_channel::MultiChannel;
7use crate::multi_channel::{PeerRemoteMemoryRegion, PreparedMultiChannel};
8use crate::network::Node;
9use crate::network::barrier::{BarrierAlgorithm, PreparedBarrier};
10use bon::bon;
11use serde::{Deserialize, Serialize};
12use std::io;
13
14#[bon]
15impl Node {
16 #[builder(state_mod(vis = "pub"))]
17 pub fn builder(
18 rank: usize,
19 world_size: usize,
20 pd: &ProtectionDomain,
21 #[builder(default = BarrierAlgorithm::BinaryTree)] barrier: BarrierAlgorithm,
22 #[builder(default =
23 AccessFlags::new()
24 .with_local_write()
25 .with_remote_read()
26 .with_remote_write()
27 )]
28 access: AccessFlags,
29 #[builder(default = 32)] min_cq_entries: u32,
30 #[builder(default = 16)] max_send_wr: u32,
31 #[builder(default = 16)] max_recv_wr: u32,
32 #[builder(default = 16)] max_send_sge: u32,
33 #[builder(default = 16)] max_recv_sge: u32,
34 #[builder(default)] max_rnr_retries: MaxRnrRetries,
35 #[builder(default)] max_ack_retries: MaxAckRetries,
36 #[builder(default)] min_rnr_timer: MinRnrTimer,
37 #[builder(default)] ack_timeout: AckTimeout,
38 #[builder(default)] mtu: MaximumTransferUnit,
39 #[builder(default)] send_psn: PacketSequenceNumber,
40 #[builder(default)] recv_psn: PacketSequenceNumber,
41 ) -> IbvResult<PreparedNode> {
42 let multi_channel = MultiChannel::builder()
43 .num_channels(world_size)
44 .pd(pd)
45 .min_cq_entries(min_cq_entries)
46 .access(access)
47 .max_send_wr(max_send_wr)
48 .max_recv_wr(max_recv_wr)
49 .max_send_sge(max_send_sge)
50 .max_recv_sge(max_recv_sge)
51 .max_rnr_retries(max_rnr_retries)
52 .max_ack_retries(max_ack_retries)
53 .min_rnr_timer(min_rnr_timer)
54 .ack_timeout(ack_timeout)
55 .mtu(mtu)
56 .send_psn(send_psn)
57 .recv_psn(recv_psn)
58 .build()?;
59 let barrier = barrier.instance(pd, rank, world_size)?;
60
61 Ok(PreparedNode {
62 rank,
63 world_size,
64 multi_channel,
65 barrier,
66 })
67 }
68}
69
/// A node whose local resources have been created but which is not yet
/// connected to its peers.
///
/// Produced by the `Node` builder and turned into a `Node` by
/// `PreparedNode::handshake`.
pub struct PreparedNode {
    /// Rank of this node.
    rank: usize,
    /// Number of ranks in the world; also the number of channels requested
    /// from the multi-channel builder.
    world_size: usize,
    /// Channels that have been built but not yet handshaken.
    multi_channel: PreparedMultiChannel,
    /// Barrier instance that has not yet been linked to the remote peers.
    barrier: PreparedBarrier,
}
82
/// Connection information for a single channel of a node.
///
/// Bundles the queue pair endpoint of that channel with the remote memory
/// region descriptor of the owning node's barrier.
#[derive(Debug, Copy, Clone, Serialize, Deserialize)]
pub struct NetworkChannelEndpoint {
    /// Queue pair endpoint of the channel.
    pub(crate) single_channel_endpoint: QueuePairEndpoint,
    /// Remote memory region descriptor obtained from the node's barrier.
    pub(crate) barrier_mr_remote: PeerRemoteMemoryRegion,
}
90
91#[derive(Clone, Serialize, Deserialize)]
93pub struct LocalEndpoint {
94 rank: usize,
95 endpoints: Box<[NetworkChannelEndpoint]>,
96}
97
98pub struct RemoteEndpoints(Box<[NetworkChannelEndpoint]>);
103
104impl PreparedNode {
105 pub fn endpoint(&self) -> LocalEndpoint {
107 LocalEndpoint {
108 rank: self.rank,
109 endpoints: self
110 .multi_channel
111 .endpoints()
112 .into_iter()
113 .map(|single_channel_endpoint| NetworkChannelEndpoint {
114 single_channel_endpoint,
115 barrier_mr_remote: self.barrier.remote_mr(),
116 })
117 .collect(),
118 }
119 }
120
121 pub fn gather_endpoints(
126 &self,
127 endpoints: impl IntoIterator<Item = LocalEndpoint>,
128 ) -> io::Result<RemoteEndpoints> {
129 let mut tmp: Vec<Option<NetworkChannelEndpoint>> = vec![None; self.world_size];
131
132 for endpoint in endpoints {
133 let in_slot = tmp.get_mut(endpoint.rank).ok_or_else(|| {
135 io::Error::new(
136 io::ErrorKind::InvalidInput,
137 format!(
138 "Input endpoint rank {} out of bounds (0..{})",
139 endpoint.rank, self.world_size
140 ),
141 )
142 })?;
143
144 if in_slot.is_some() {
146 return Err(io::Error::new(
147 io::ErrorKind::InvalidInput,
148 format!("Duplicate endpoint for rank {}", endpoint.rank),
149 ));
150 }
151
152 let qp_endpoint = endpoint.endpoints.get(self.rank).ok_or_else(|| {
154 io::Error::new(
155 io::ErrorKind::InvalidInput,
156 format!(
157 "Input endpoint for rank {} missing rechannel for local rank {}",
158 endpoint.rank, self.rank
159 ),
160 )
161 })?;
162
163 *in_slot = Some(*qp_endpoint);
165 }
166
167 let in_endpoints: Vec<NetworkChannelEndpoint> = tmp
169 .into_iter()
170 .enumerate()
171 .map(|(i, opt)| {
172 opt.ok_or_else(|| {
173 io::Error::new(
174 io::ErrorKind::InvalidData,
175 format!("Missing endpoint from rank {}", i),
176 )
177 })
178 })
179 .collect::<Result<_, _>>()?;
180
181 Ok(RemoteEndpoints(in_endpoints.into_boxed_slice()))
182 }
183
184 pub fn handshake(self, endpoints: RemoteEndpoints) -> IbvResult<Node> {
186 let multi_channel = self
187 .multi_channel
188 .handshake(endpoints.0.iter().map(|e| e.single_channel_endpoint))?;
189 let barrier = self
190 .barrier
191 .link_remote(endpoints.0.iter().map(|e| e.barrier_mr_remote).collect());
192
193 Ok(Node {
194 rank: self.rank,
195 world_size: self.world_size,
196 multi_channel,
197 barrier,
198 })
199 }
200}