1use super::*;
2
3#[derive(Debug)]
6pub struct MemoryHandle {
7 handle: ucp_mem_h,
8 context: Arc<Context>,
9}
10
11impl MemoryHandle {
12 pub fn register(context: &Arc<Context>, region: &mut [u8]) -> Self {
14 #[allow(invalid_value)]
15 #[allow(clippy::uninit_assumed_init)]
16 let params = ucp_mem_map_params_t {
17 field_mask: (ucp_mem_map_params_field::UCP_MEM_MAP_PARAM_FIELD_ADDRESS
18 | ucp_mem_map_params_field::UCP_MEM_MAP_PARAM_FIELD_LENGTH)
19 .0 as u64,
20 address: region.as_ptr() as _,
21 length: region.len() as _,
22 ..unsafe { MaybeUninit::uninit().assume_init() }
23 };
24 let mut handle = MaybeUninit::uninit();
25 let status = unsafe { ucp_mem_map(context.handle, ¶ms, handle.as_mut_ptr()) };
26 assert_eq!(status, ucs_status_t::UCS_OK);
27 MemoryHandle {
28 handle: unsafe { handle.assume_init() },
29 context: context.clone(),
30 }
31 }
32
33 pub fn pack(&self) -> RKeyBuffer {
35 let mut buf = MaybeUninit::uninit();
36 let mut len = MaybeUninit::uninit();
37 let status = unsafe {
38 ucp_rkey_pack(
39 self.context.handle,
40 self.handle,
41 buf.as_mut_ptr(),
42 len.as_mut_ptr(),
43 )
44 };
45 assert_eq!(status, ucs_status_t::UCS_OK);
46 RKeyBuffer {
47 buf: unsafe { buf.assume_init() },
48 len: unsafe { len.assume_init() },
49 }
50 }
51}
52
53impl Drop for MemoryHandle {
54 fn drop(&mut self) {
55 unsafe { ucp_mem_unmap(self.context.handle, self.handle) };
56 }
57}
58
/// Buffer holding a packed remote key, as produced by `ucp_rkey_pack`.
///
/// The buffer is handed back to `ucp_rkey_buffer_release` when the value is
/// dropped.
#[derive(Debug)]
pub struct RKeyBuffer {
    /// Start of the packed key.
    buf: *mut c_void,
    /// Size of the packed key, used as the byte length of the `AsRef<[u8]>` view.
    len: u64,
}
65
66impl AsRef<[u8]> for RKeyBuffer {
67 fn as_ref(&self) -> &[u8] {
68 unsafe { std::slice::from_raw_parts(self.buf as _, self.len as _) }
69 }
70}
71
72impl Drop for RKeyBuffer {
73 fn drop(&mut self) {
74 unsafe { ucp_rkey_buffer_release(self.buf as _) }
75 }
76}
77
78#[derive(Debug)]
80pub struct RKey {
81 handle: ucp_rkey_h,
82}
83
84unsafe impl Send for RKey {}
85unsafe impl Sync for RKey {}
86
87impl RKey {
88 pub fn unpack(endpoint: &Endpoint, rkey_buffer: &[u8]) -> Self {
90 let mut handle = MaybeUninit::uninit();
91 let status = unsafe {
92 ucp_ep_rkey_unpack(
93 endpoint.handle,
94 rkey_buffer.as_ptr() as _,
95 handle.as_mut_ptr(),
96 )
97 };
98 assert_eq!(status, ucs_status_t::UCS_OK);
99 RKey {
100 handle: unsafe { handle.assume_init() },
101 }
102 }
103}
104
105impl Drop for RKey {
106 fn drop(&mut self) {
107 unsafe { ucp_rkey_destroy(self.handle) }
108 }
109}
110
111impl Endpoint {
112 pub async fn put(&self, buf: &[u8], remote_addr: u64, rkey: &RKey) -> Result<(), Error> {
114 trace!("put: endpoint={:?} len={}", self.handle, buf.len());
115 unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t) {
116 trace!("put: complete. req={:?}, status={:?}", request, status);
117 let request = &mut *(request as *mut Request);
118 request.waker.wake();
119 }
120 let status = unsafe {
121 ucp_put_nb(
122 self.get_handle()?,
123 buf.as_ptr() as _,
124 buf.len() as _,
125 remote_addr,
126 rkey.handle,
127 Some(callback),
128 )
129 };
130 if status.is_null() {
131 trace!("put: complete.");
132 Ok(())
133 } else if UCS_PTR_IS_PTR(status) {
134 RequestHandle {
135 ptr: status,
136 poll_fn: poll_normal,
137 }
138 .await
139 } else {
140 Error::from_ptr(status)
141 }
142 }
143
144 pub async fn get(&self, buf: &mut [u8], remote_addr: u64, rkey: &RKey) -> Result<(), Error> {
146 trace!("get: endpoint={:?} len={}", self.handle, buf.len());
147 unsafe extern "C" fn callback(request: *mut c_void, status: ucs_status_t) {
148 trace!("get: complete. req={:?}, status={:?}", request, status);
149 let request = &mut *(request as *mut Request);
150 request.waker.wake();
151 }
152 let status = unsafe {
153 ucp_get_nb(
154 self.get_handle()?,
155 buf.as_mut_ptr() as _,
156 buf.len() as _,
157 remote_addr,
158 rkey.handle,
159 Some(callback),
160 )
161 };
162 if status.is_null() {
163 trace!("get: complete.");
164 Ok(())
165 } else if UCS_PTR_IS_PTR(status) {
166 RequestHandle {
167 ptr: status,
168 poll_fn: poll_normal,
169 }
170 .await
171 } else {
172 Error::from_ptr(status)
173 }
174 }
175}
176
#[cfg(test)]
mod tests {
    use super::*;

    /// Drives the async put/get scenario via `spawn_thread!` and joins it.
    #[test_log::test]
    fn put_get() {
        spawn_thread!(_put_get()).join().unwrap();
    }

    /// Connects two workers over a local listener, then checks that a `put`
    /// followed by a `get` moves data between a registered buffer and a plain
    /// one, and finally checks the endpoint close sequence.
    async fn _put_get() {
        // Two independent contexts, each with its own worker being polled.
        let context1 = Context::new().unwrap();
        let worker1 = context1.create_worker().unwrap();
        let context2 = Context::new().unwrap();
        let worker2 = context2.create_worker().unwrap();
        tokio::task::spawn_local(worker1.clone().polling());
        tokio::task::spawn_local(worker2.clone().polling());

        // worker1 listens on a port picked by the system; worker2 connects to
        // that port on the loopback address.
        let mut listener = worker1
            .create_listener("0.0.0.0:0".parse().unwrap())
            .unwrap();
        let listen_port = listener.socket_addr().unwrap().port();
        let mut addr: SocketAddr = "127.0.0.1:0".parse().unwrap();
        addr.set_port(listen_port);
        let (endpoint1, endpoint2) = tokio::join!(
            async {
                let incoming = listener.next().await;
                worker1.accept(incoming).await.unwrap()
            },
            async { worker2.connect_socket(addr).await.unwrap() },
        );

        // buf1 starts zeroed; buf2 holds a repeating 0..=255 byte pattern.
        let mut buf1: Vec<u8> = vec![0; 0x1000];
        let mut buf2: Vec<u8> = (0..0x1000).map(|i| i as u8).collect();

        // Register buf1 on context1 and make its key usable from endpoint2.
        let mem1 = MemoryHandle::register(&context1, &mut buf1);
        let rkey_buf = mem1.pack();
        let rkey2 = RKey::unpack(&endpoint2, rkey_buf.as_ref());

        // put: write buf2 to buf1's address through endpoint2, flush both
        // endpoints, then compare.
        endpoint2
            .put(&buf2[..], buf1.as_mut_ptr() as u64, &rkey2)
            .await
            .unwrap();
        endpoint1.flush().await.unwrap();
        endpoint2.flush().await.unwrap();
        assert_eq!(&buf1[..], &buf2[..]);

        // get: zero buf1 locally, read it back into buf2 through endpoint2,
        // then compare.
        for byte in buf1.iter_mut() {
            *byte = 0;
        }
        endpoint2
            .get(&mut buf2[..], buf1.as_ptr() as u64, &rkey2)
            .await
            .unwrap();
        assert_eq!(&buf1[..], &buf2[..]);

        // Close sequence and the values `get_rc` reports along the way.
        assert_eq!(endpoint1.get_rc(), (1, 1));
        assert_eq!(endpoint2.get_rc(), (1, 1));
        assert_eq!(endpoint1.close(false).await, Ok(()));
        assert_eq!(endpoint2.close(false).await, Err(Error::ConnectionReset));
        assert_eq!(endpoint1.get_rc(), (1, 0));
        assert_eq!(endpoint2.get_rc(), (1, 1));
        assert_eq!(endpoint2.close(true).await, Ok(()));
        assert_eq!(endpoint2.get_rc(), (1, 0));
    }
}