// ibverbs_rs/multi_channel/builder.rs

1use crate::channel::Channel;
2use crate::channel::PreparedChannel;
3use crate::ibverbs::access_config::AccessFlags;
4use crate::ibverbs::error::{IbvError, IbvResult};
5use crate::ibverbs::protection_domain::ProtectionDomain;
6use crate::ibverbs::queue_pair::builder::QueuePairEndpoint;
7use crate::ibverbs::queue_pair::config::{
8    AckTimeout, MaxAckRetries, MaxRnrRetries, MaximumTransferUnit, MinRnrTimer,
9    PacketSequenceNumber,
10};
11use crate::multi_channel::MultiChannel;
12use bon::bon;
13
14#[bon]
15impl MultiChannel {
16    #[builder(state_mod(vis = "pub"))]
17    pub fn builder(
18        num_channels: usize,
19        pd: &ProtectionDomain,
20        #[builder(default =
21            AccessFlags::new()
22                .with_local_write()
23                .with_remote_read()
24                .with_remote_write()
25        )]
26        access: AccessFlags,
27        #[builder(default = 32)] min_cq_entries: u32,
28        #[builder(default = 16)] max_send_wr: u32,
29        #[builder(default = 16)] max_recv_wr: u32,
30        #[builder(default = 16)] max_send_sge: u32,
31        #[builder(default = 16)] max_recv_sge: u32,
32        #[builder(default)] max_rnr_retries: MaxRnrRetries,
33        #[builder(default)] max_ack_retries: MaxAckRetries,
34        #[builder(default)] min_rnr_timer: MinRnrTimer,
35        #[builder(default)] ack_timeout: AckTimeout,
36        #[builder(default)] mtu: MaximumTransferUnit,
37        #[builder(default)] send_psn: PacketSequenceNumber,
38        #[builder(default)] recv_psn: PacketSequenceNumber,
39    ) -> IbvResult<PreparedMultiChannel> {
40        let channels = (0..num_channels)
41            .map(|_| {
42                Channel::builder()
43                    .pd(pd)
44                    .min_cq_entries(min_cq_entries)
45                    .access(access)
46                    .max_send_wr(max_send_wr)
47                    .max_recv_wr(max_recv_wr)
48                    .max_send_sge(max_send_sge)
49                    .max_recv_sge(max_recv_sge)
50                    .max_rnr_retries(max_rnr_retries)
51                    .max_ack_retries(max_ack_retries)
52                    .min_rnr_timer(min_rnr_timer)
53                    .ack_timeout(ack_timeout)
54                    .mtu(mtu)
55                    .send_psn(send_psn)
56                    .recv_psn(recv_psn)
57                    .build()
58            })
59            .collect::<IbvResult<_>>()?;
60
61        Ok(PreparedMultiChannel {
62            channels,
63            pd: pd.clone(),
64        })
65    }
66}
67
68/// A [`MultiChannel`] that has been configured but not yet connected to a remote peer.
69///
70/// Created by [`MultiChannel::builder`]. Call [`endpoints`](Self::endpoints) to obtain the
71/// local connection information for each channel, exchange them with the remote side, then
72/// call [`handshake`](Self::handshake) with the remote's endpoints to finish the connections.
73pub struct PreparedMultiChannel {
74    channels: Box<[PreparedChannel]>,
75    pd: ProtectionDomain,
76}
77
78impl PreparedMultiChannel {
79    /// Returns the local endpoint information for each channel, needed by the remote peer.
80    pub fn endpoints(&self) -> Box<[QueuePairEndpoint]> {
81        self.channels.iter().map(|c| c.endpoint()).collect()
82    }
83
84    /// Connects each channel to the remote endpoint at the same index and returns a ready-to-use [`MultiChannel`].
85    pub fn handshake<I>(self, endpoints: I) -> IbvResult<MultiChannel>
86    where
87        I: IntoIterator<Item = QueuePairEndpoint>,
88        I::IntoIter: ExactSizeIterator,
89    {
90        let endpoints = endpoints.into_iter();
91        if self.channels.len() != endpoints.len() {
92            return Err(IbvError::InvalidInput(format!(
93                "Expected {} endpoints but got {}",
94                self.channels.len(),
95                endpoints.len()
96            )));
97        }
98
99        let channels = self
100            .channels
101            .into_iter()
102            .zip(endpoints)
103            .map(|(channel, endpoint)| channel.handshake(endpoint))
104            .collect::<Result<_, _>>()?;
105
106        Ok(MultiChannel {
107            channels,
108            pd: self.pd,
109        })
110    }
111}