// async_ucx/ucp/endpoint/rma.rs

1use super::*;
2
3/// A memory region allocated through UCP library,
4/// which is optimized for remote memory access operations.
5#[derive(Debug)]
6pub struct MemoryHandle {
7    handle: ucp_mem_h,
8    context: Arc<Context>,
9}
10
11impl MemoryHandle {
12    /// Register memory region.
13    pub fn register(context: &Arc<Context>, region: &mut [u8]) -> Self {
14        #[allow(invalid_value)]
15        #[allow(clippy::uninit_assumed_init)]
16        let params = ucp_mem_map_params_t {
17            field_mask: (ucp_mem_map_params_field::UCP_MEM_MAP_PARAM_FIELD_ADDRESS
18                | ucp_mem_map_params_field::UCP_MEM_MAP_PARAM_FIELD_LENGTH)
19                .0 as u64,
20            address: region.as_ptr() as _,
21            length: region.len() as _,
22            ..unsafe { MaybeUninit::uninit().assume_init() }
23        };
24        let mut handle = MaybeUninit::uninit();
25        let status = unsafe { ucp_mem_map(context.handle, &params, handle.as_mut_ptr()) };
26        assert_eq!(status, ucs_status_t::UCS_OK);
27        MemoryHandle {
28            handle: unsafe { handle.assume_init() },
29            context: context.clone(),
30        }
31    }
32
33    /// Packs into the buffer a remote access key (RKEY) object.
34    pub fn pack(&self) -> RKeyBuffer {
35        let mut buf = MaybeUninit::uninit();
36        let mut len = MaybeUninit::uninit();
37        let status = unsafe {
38            ucp_rkey_pack(
39                self.context.handle,
40                self.handle,
41                buf.as_mut_ptr(),
42                len.as_mut_ptr(),
43            )
44        };
45        assert_eq!(status, ucs_status_t::UCS_OK);
46        RKeyBuffer {
47            buf: unsafe { buf.assume_init() },
48            len: unsafe { len.assume_init() },
49        }
50    }
51}
52
53impl Drop for MemoryHandle {
54    fn drop(&mut self) {
55        unsafe { ucp_mem_unmap(self.context.handle, self.handle) };
56    }
57}
58
/// Owned, UCP-allocated buffer that holds a packed remote access key.
///
/// Produced by [`MemoryHandle::pack`]. The bytes are exposed through
/// `AsRef<[u8]>`, and the buffer is returned to UCP with
/// `ucp_rkey_buffer_release` when dropped.
#[derive(Debug)]
pub struct RKeyBuffer {
    /// Start of the packed key, as written by `ucp_rkey_pack`.
    buf: *mut c_void,
    /// Size of the packed key in bytes.
    len: u64,
}
65
66impl AsRef<[u8]> for RKeyBuffer {
67    fn as_ref(&self) -> &[u8] {
68        unsafe { std::slice::from_raw_parts(self.buf as _, self.len as _) }
69    }
70}
71
72impl Drop for RKeyBuffer {
73    fn drop(&mut self) {
74        unsafe { ucp_rkey_buffer_release(self.buf as _) }
75    }
76}
77
78/// Remote access key.
79#[derive(Debug)]
80pub struct RKey {
81    handle: ucp_rkey_h,
82}
83
84unsafe impl Send for RKey {}
85unsafe impl Sync for RKey {}
86
87impl RKey {
88    /// Create remote access key from packed buffer.
89    pub fn unpack(endpoint: &Endpoint, rkey_buffer: &[u8]) -> Self {
90        let mut handle = MaybeUninit::uninit();
91        let status = unsafe {
92            ucp_ep_rkey_unpack(
93                endpoint.handle,
94                rkey_buffer.as_ptr() as _,
95                handle.as_mut_ptr(),
96            )
97        };
98        assert_eq!(status, ucs_status_t::UCS_OK);
99        RKey {
100            handle: unsafe { handle.assume_init() },
101        }
102    }
103}
104
105impl Drop for RKey {
106    fn drop(&mut self) {
107        unsafe { ucp_rkey_destroy(self.handle) }
108    }
109}
110
111impl Endpoint {
112    /// Stores a contiguous block of data into remote memory.
113    pub async fn put(&self, buf: &[u8], remote_addr: u64, rkey: &RKey) -> Result<(), Error> {
114        trace!("put: endpoint={:?} len={}", self.handle, buf.len());
115        unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t) {
116            trace!("put: complete. req={:?}, status={:?}", request, status);
117            let request = &mut *(request as *mut Request);
118            request.waker.wake();
119        }
120        let status = unsafe {
121            ucp_put_nb(
122                self.get_handle()?,
123                buf.as_ptr() as _,
124                buf.len() as _,
125                remote_addr,
126                rkey.handle,
127                Some(callback),
128            )
129        };
130        if status.is_null() {
131            trace!("put: complete.");
132            Ok(())
133        } else if UCS_PTR_IS_PTR(status) {
134            RequestHandle {
135                ptr: status,
136                poll_fn: poll_normal,
137            }
138            .await
139        } else {
140            Error::from_ptr(status)
141        }
142    }
143
144    /// Loads a contiguous block of data from remote memory.
145    pub async fn get(&self, buf: &mut [u8], remote_addr: u64, rkey: &RKey) -> Result<(), Error> {
146        trace!("get: endpoint={:?} len={}", self.handle, buf.len());
147        unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t) {
148            trace!("get: complete. req={:?}, status={:?}", request, status);
149            let request = &mut *(request as *mut Request);
150            request.waker.wake();
151        }
152        let status = unsafe {
153            ucp_get_nb(
154                self.get_handle()?,
155                buf.as_mut_ptr() as _,
156                buf.len() as _,
157                remote_addr,
158                rkey.handle,
159                Some(callback),
160            )
161        };
162        if status.is_null() {
163            trace!("get: complete.");
164            Ok(())
165        } else if UCS_PTR_IS_PTR(status) {
166            RequestHandle {
167                ptr: status,
168                poll_fn: poll_normal,
169            }
170            .await
171        } else {
172            Error::from_ptr(status)
173        }
174    }
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    #[test_log::test]
    fn put_get() {
        spawn_thread!(_put_get()).join().unwrap();
    }

    async fn _put_get() {
        // Two independent contexts/workers, each driven by its own polling task.
        let context1 = Context::new().unwrap();
        let worker1 = context1.create_worker().unwrap();
        let context2 = Context::new().unwrap();
        let worker2 = context2.create_worker().unwrap();
        tokio::task::spawn_local(worker1.clone().polling());
        tokio::task::spawn_local(worker2.clone().polling());

        // Connect worker2 to a listener on worker1 over loopback.
        let mut listener = worker1
            .create_listener("0.0.0.0:0".parse().unwrap())
            .unwrap();
        let listen_port = listener.socket_addr().unwrap().port();
        let addr = SocketAddr::from(([127, 0, 0, 1], listen_port));
        let (endpoint1, endpoint2) = tokio::join!(
            async {
                let conn1 = listener.next().await;
                worker1.accept(conn1).await.unwrap()
            },
            async { worker2.connect_socket(addr).await.unwrap() },
        );

        let mut buf1: Vec<u8> = vec![0; 0x1000];
        let mut buf2: Vec<u8> = (0..0x1000).map(|x| x as u8).collect();

        // Register buf1 on side 1 and hand its remote key to side 2.
        let mem1 = MemoryHandle::register(&context1, &mut buf1);
        let rkey_buf = mem1.pack();
        let rkey2 = RKey::unpack(&endpoint2, rkey_buf.as_ref());

        // put: local buf2 -> remote buf1
        endpoint2
            .put(&buf2[..], buf1.as_mut_ptr() as u64, &rkey2)
            .await
            .unwrap();
        // Completion of `put` only means the local buffer may be released; the
        // remote side may not see the data yet. Flushing both endpoints acts as a
        // barrier; without it the test fails intermittently.
        endpoint1.flush().await.unwrap();
        endpoint2.flush().await.unwrap();
        assert_eq!(&buf1[..], &buf2[..]);

        // get: clear buf1, then read it back into buf2
        for byte in buf1.iter_mut() {
            *byte = 0;
        }
        endpoint2
            .get(&mut buf2[..], buf1.as_ptr() as u64, &rkey2)
            .await
            .unwrap();
        assert_eq!(&buf1[..], &buf2[..]);

        // Teardown: closing one side resets the peer's connection.
        assert_eq!(endpoint1.get_rc(), (1, 1));
        assert_eq!(endpoint2.get_rc(), (1, 1));
        assert_eq!(endpoint1.close(false).await, Ok(()));
        assert_eq!(endpoint2.close(false).await, Err(Error::ConnectionReset));
        assert_eq!(endpoint1.get_rc(), (1, 0));
        assert_eq!(endpoint2.get_rc(), (1, 1));
        assert_eq!(endpoint2.close(true).await, Ok(()));
        assert_eq!(endpoint2.get_rc(), (1, 0));
    }
}