nixl_sys/
notify.rs

1// SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2// SPDX-License-Identifier: Apache-2.0
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10// Unless required by applicable law or agreed to in writing, software
11// distributed under the License is distributed on an "AS IS" BASIS,
12// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13// See the License for the specific language governing permissions and
14// limitations under the License.
15
16use super::*;
17
18use std::{
19    collections::HashMap,
20    ffi::{CStr, CString},
21    os::raw::c_char, // Added for *const c_char
22    ptr::{self, NonNull},
23};
24
25/// A safe wrapper around NIXL notification map
26pub struct NotificationMap {
27    pub(crate) inner: NonNull<bindings::nixl_capi_notif_map_s>,
28}
29
30impl NotificationMap {
31    /// Creates a new empty notification map
32    pub fn new() -> Result<Self, NixlError> {
33        let mut map = ptr::null_mut();
34        let status = unsafe { nixl_capi_create_notif_map(&mut map) };
35        match status {
36            NIXL_CAPI_SUCCESS => {
37                // SAFETY: If status is NIXL_CAPI_SUCCESS, map is non-null
38                let inner = unsafe { NonNull::new_unchecked(map) };
39                Ok(Self { inner })
40            }
41            NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
42            _ => Err(NixlError::BackendError),
43        }
44    }
45
46    /// Returns the number of agents that have notifications
47    pub fn len(&self) -> Result<usize, NixlError> {
48        let mut size = 0;
49        let status = unsafe { nixl_capi_notif_map_size(self.inner.as_ptr(), &mut size) };
50        match status {
51            NIXL_CAPI_SUCCESS => Ok(size),
52            NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
53            _ => Err(NixlError::BackendError),
54        }
55    }
56
57    /// Returns true if there are no notifications
58    pub fn is_empty(&self) -> Result<bool, NixlError> {
59        Ok(self.len()? == 0)
60    }
61
62    /// Returns an iterator over the agent names that have notifications
63    pub fn agents(&self) -> NotificationMapAgentIterator<'_> {
64        NotificationMapAgentIterator {
65            map: self,
66            index: 0,
67            length: self.len().unwrap_or(0),
68        }
69    }
70
71    /// Returns the number of notifications for a given agent
72    pub fn get_notifications_size(&self, agent_name: &str) -> Result<usize, NixlError> {
73        let mut size = 0;
74        let c_name = CString::new(agent_name).map_err(|_| NixlError::InvalidParam)?;
75        let status = unsafe {
76            nixl_capi_notif_map_get_notifs_size(self.inner.as_ptr(), c_name.as_ptr(), &mut size)
77        };
78        match status {
79            NIXL_CAPI_SUCCESS => Ok(size),
80            NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
81            _ => Err(NixlError::BackendError),
82        }
83    }
84
85    /// Returns an iterator over the notifications for a given agent
86    pub fn get_notifications(
87        &self,
88        agent_name: &str,
89    ) -> Result<NotificationIterator<'_>, NixlError> {
90        let size = self.get_notifications_size(agent_name)?;
91        Ok(NotificationIterator {
92            map: self,
93            agent_name: agent_name.to_string(),
94            index: 0,
95            length: size,
96        })
97    }
98
99    /// Returns a specific notification for a given agent as raw bytes
100    pub fn get_notification_bytes(
101        &self,
102        agent_name: &str,
103        index: usize,
104    ) -> Result<Vec<u8>, NixlError> {
105        let c_name = CString::new(agent_name).map_err(|_| NixlError::InvalidParam)?;
106        let mut data: *const u8 = ptr::null();
107        let mut len = 0;
108        let status = unsafe {
109            nixl_capi_notif_map_get_notif(
110                self.inner.as_ptr(),
111                c_name.as_ptr(),
112                index,
113                &mut data as *mut *const _ as *mut *const std::ffi::c_void,
114                &mut len,
115            )
116        };
117        match status {
118            NIXL_CAPI_SUCCESS => {
119                if data.is_null() {
120                    Ok(Vec::new())
121                } else {
122                    // SAFETY: If status is NIXL_CAPI_SUCCESS, data points to valid memory of size len
123                    // This data is owned by the C side and is valid until the map is cleared or modified.
124                    let bytes = unsafe {
125                        let slice = std::slice::from_raw_parts(data, len);
126                        slice.to_vec()
127                    };
128                    Ok(bytes)
129                }
130            }
131            NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam),
132            _ => Err(NixlError::BackendError),
133        }
134    }
135
136    /// Takes all notifications from the map, converting them to Strings,
137    /// and clears the underlying C map for reuse.
138    ///
139    /// Returns a HashMap where keys are agent names and values are vectors of
140    /// notification strings for that agent.
141    ///
142    /// If a notification\'s byte data is not valid UTF-8, this method will
143    /// return an error (`NixlError::BackendError` in current impl, ideally a specific UTF-8 error).
144    pub fn take_notifs(&mut self) -> Result<HashMap<String, Vec<String>>, NixlError> {
145        let mut all_notifications = HashMap::new();
146        let num_agents = self.len()?;
147
148        for agent_idx in 0..num_agents {
149            let mut c_agent_name_ptr: *const c_char = ptr::null();
150            let status_agent_name = unsafe {
151                nixl_capi_notif_map_get_agent_at(
152                    self.inner.as_ptr(),
153                    agent_idx,
154                    &mut c_agent_name_ptr,
155                )
156            };
157
158            if status_agent_name != NIXL_CAPI_SUCCESS {
159                // This case should ideally not happen if num_agents is correct
160                // and map is consistent.
161                return Err(if status_agent_name == NIXL_CAPI_ERROR_INVALID_PARAM {
162                    NixlError::InvalidParam
163                } else {
164                    NixlError::BackendError
165                });
166            }
167
168            if c_agent_name_ptr.is_null() {
169                // Should not happen if get_agent_at succeeded.
170                return Err(NixlError::BackendError);
171            }
172
173            let agent_name_cstr = unsafe { CStr::from_ptr(c_agent_name_ptr) };
174            let agent_name_string = agent_name_cstr
175                .to_str()
176                .map_err(|_| NixlError::InvalidParam)? // Map UTF-8 error on agent name to InvalidParam
177                .to_owned();
178
179            let mut num_notifs_for_agent = 0;
180            let status_notif_size = unsafe {
181                nixl_capi_notif_map_get_notifs_size(
182                    self.inner.as_ptr(),
183                    c_agent_name_ptr, // Use the C string directly
184                    &mut num_notifs_for_agent,
185                )
186            };
187
188            if status_notif_size != NIXL_CAPI_SUCCESS {
189                return Err(if status_notif_size == NIXL_CAPI_ERROR_INVALID_PARAM {
190                    NixlError::InvalidParam
191                } else {
192                    NixlError::BackendError
193                });
194            }
195
196            let mut agent_specific_notifications = Vec::with_capacity(num_notifs_for_agent);
197
198            for notif_idx in 0..num_notifs_for_agent {
199                let mut data_ptr: *const std::ffi::c_void = ptr::null();
200                let mut data_len: usize = 0;
201
202                let status_notif_data = unsafe {
203                    nixl_capi_notif_map_get_notif(
204                        self.inner.as_ptr(),
205                        c_agent_name_ptr, // Use the C string directly
206                        notif_idx,
207                        &mut data_ptr,
208                        &mut data_len,
209                    )
210                };
211
212                if status_notif_data != NIXL_CAPI_SUCCESS {
213                    return Err(if status_notif_data == NIXL_CAPI_ERROR_INVALID_PARAM {
214                        NixlError::InvalidParam
215                    } else {
216                        NixlError::BackendError
217                    });
218                }
219
220                let notification_bytes = if data_ptr.is_null() || data_len == 0 {
221                    Vec::new()
222                } else {
223                    // SAFETY: Pointer and length are from a successful C API call.
224                    // Data is valid until map is cleared/modified. We copy it immediately.
225                    unsafe { std::slice::from_raw_parts(data_ptr as *const u8, data_len) }.to_vec()
226                };
227
228                // Attempt to convert Vec<u8> to String
229                let notification_string =
230                    String::from_utf8(notification_bytes).map_err(|_e| NixlError::BackendError)?; // FIXME: Ideally, a specific UTF-8 error variant in NixlError (e.g., InvalidNotificationEncoding)
231
232                agent_specific_notifications.push(notification_string);
233            }
234            all_notifications.insert(agent_name_string, agent_specific_notifications);
235        }
236
237        // After successfully extracting all data, clear the C map
238        let clear_status = unsafe { nixl_capi_notif_map_clear(self.inner.as_ptr()) };
239        match clear_status {
240            NIXL_CAPI_SUCCESS => Ok(all_notifications),
241            NIXL_CAPI_ERROR_INVALID_PARAM => Err(NixlError::InvalidParam), // Should not happen if self.inner is valid
242            _ => Err(NixlError::BackendError),
243        }
244    }
245}
246
247/// An iterator over agent names in a NotificationMap
248pub struct NotificationMapAgentIterator<'a> {
249    map: &'a NotificationMap,
250    index: usize,
251    length: usize,
252}
253
254impl<'a> Iterator for NotificationMapAgentIterator<'a> {
255    type Item = Result<&'a str, NixlError>;
256
257    fn next(&mut self) -> Option<Self::Item> {
258        if self.index >= self.length {
259            None
260        } else {
261            let mut agent_name = ptr::null();
262            let status = unsafe {
263                nixl_capi_notif_map_get_agent_at(
264                    self.map.inner.as_ptr(),
265                    self.index,
266                    &mut agent_name,
267                )
268            };
269            self.index += 1;
270            match status {
271                NIXL_CAPI_SUCCESS => {
272                    // SAFETY: If status is NIXL_CAPI_SUCCESS, agent_name points to a valid C string
273                    let name = unsafe { CStr::from_ptr(agent_name) };
274                    Some(name.to_str().map_err(|_| NixlError::InvalidParam))
275                }
276                NIXL_CAPI_ERROR_INVALID_PARAM => Some(Err(NixlError::InvalidParam)),
277                _ => Some(Err(NixlError::BackendError)),
278            }
279        }
280    }
281
282    fn size_hint(&self) -> (usize, Option<usize>) {
283        let remaining = self.length - self.index;
284        (remaining, Some(remaining))
285    }
286}
287
288/// An iterator over notifications for a specific agent
289pub struct NotificationIterator<'a> {
290    map: &'a NotificationMap,
291    agent_name: String,
292    index: usize,
293    length: usize,
294}
295
296impl Iterator for NotificationIterator<'_> {
297    type Item = Result<Vec<u8>, NixlError>;
298
299    fn next(&mut self) -> Option<Self::Item> {
300        if self.index >= self.length {
301            None
302        } else {
303            let result = self
304                .map
305                .get_notification_bytes(&self.agent_name, self.index);
306            self.index += 1;
307            Some(result)
308        }
309    }
310
311    fn size_hint(&self) -> (usize, Option<usize>) {
312        let remaining = self.length - self.index;
313        (remaining, Some(remaining))
314    }
315}
316
317impl Drop for NotificationMap {
318    fn drop(&mut self) {
319        tracing::trace!("Dropping notification map");
320        unsafe {
321            nixl_capi_destroy_notif_map(self.inner.as_ptr());
322        }
323        tracing::trace!("Notification map dropped");
324    }
325}