Skip to main content

io_engine/
tasks.rs

1use std::cell::UnsafeCell;
2use std::fmt;
3use std::ops::{Deref, DerefMut};
4use std::os::fd::RawFd;
5
6use nix::errno::Errno;
7
8use embed_collections::slist::{SLinkedList, SListItem, SListNode};
9use io_buffer::{Buffer, safe_copy};
10
11pub enum BufOrLen {
12    Buffer(Buffer),
13    Len(u64),
14}
15
/// Kind of I/O operation an [`IOEvent`] performs.
///
/// Discriminants are explicit so each variant's numeric value (e.g. via `as i32`) is stable.
/// `Eq` and `Hash` are derived along with `PartialEq`, so actions can be used as set or map keys.
#[derive(Copy, Clone, PartialEq, Eq, Hash, Debug)]
pub enum IOAction {
    /// Read from the event's file descriptor into its buffer.
    Read = 0,
    /// Write the event's buffer to its file descriptor.
    Write = 1,
    /// Space allocation request (presumably `fallocate`-like; confirm against the driver).
    Alloc = 2,
    /// Flush request (presumably `fsync`; confirm against the driver).
    Fsync = 3,
}
23
24/// Define your callback with this trait
25pub trait IOCallback: Sized + 'static + Send + Unpin {
26    fn call(self, _event: IOEvent<Self>);
27}
28
29/// Closure callback for IOEvent
30pub struct ClosureCb(pub Box<dyn FnOnce(IOEvent<Self>) + Send + 'static>);
31
32impl IOCallback for ClosureCb {
33    fn call(self, event: IOEvent<Self>) {
34        (self.0)(event)
35    }
36}
37
38pub struct IOEvent<C: IOCallback>(pub Box<IOEvent_<C>>);
39
40impl<C: IOCallback> Deref for IOEvent<C> {
41    type Target = IOEvent_<C>;
42    fn deref(&self) -> &Self::Target {
43        &self.0
44    }
45}
46
47impl<C: IOCallback> DerefMut for IOEvent<C> {
48    fn deref_mut(&mut self) -> &mut Self::Target {
49        &mut self.0
50    }
51}
52
53impl<C: IOCallback> fmt::Debug for IOEvent<C> {
54    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
55        self.0.fmt(f)
56    }
57}
58
59// Carries the information of read/write event
60#[repr(C)]
61pub struct IOEvent_<C: IOCallback> {
62    /// make sure SListNode always in the front.
63    /// This is for putting sub_tasks in the link list, without additional allocation.
64    pub(crate) node: UnsafeCell<SListNode<Self, ()>>,
65    pub buf_or_len: BufOrLen,
66    pub offset: i64,
67    pub action: IOAction,
68    pub fd: RawFd,
69    /// Result of the IO operation.
70    /// Initialized to i32::MIN.
71    /// >= 0: Accumulated bytes transferred (used for partial IO retries).
72    /// < 0: Error code (negative errno).
73    pub(crate) res: i32,
74    cb: Option<C>,
75    sub_tasks: SLinkedList<Box<Self>, ()>,
76}
77
78// Implement SListItem for IOEvent_ to allow it to be linked
79unsafe impl<C: IOCallback> SListItem<()> for IOEvent_<C> {
80    fn get_node(&self) -> &mut SListNode<Self, ()> {
81        unsafe { &mut *self.node.get() }
82    }
83}
84
85impl<C: IOCallback> fmt::Debug for IOEvent_<C> {
86    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
87        write!(f, "offset={} {:?} sub_tasks {}", self.offset, self.action, self.sub_tasks.len())
88    }
89}
90
91impl<C: IOCallback> IOEvent<C> {
92    #[inline]
93    pub fn new(fd: RawFd, buf: Buffer, action: IOAction, offset: i64) -> IOEvent<C> {
94        log_assert!(buf.len() > 0, "{:?} offset={}, buffer size == 0", action, offset);
95        IOEvent(Box::new(IOEvent_ {
96            buf_or_len: BufOrLen::Buffer(buf),
97            fd,
98            action,
99            offset,
100            res: i32::MIN,
101            cb: None,
102            sub_tasks: SLinkedList::new(),
103            node: UnsafeCell::new(SListNode::default()),
104        }))
105    }
106
107    #[inline]
108    pub fn new_no_buf(fd: RawFd, action: IOAction, offset: i64, len: u64) -> IOEvent<C> {
109        IOEvent(Box::new(IOEvent_ {
110            buf_or_len: BufOrLen::Len(len), // No buffer for this action
111            fd,
112            action,
113            offset,
114            res: i32::MIN,
115            cb: None,
116            sub_tasks: SLinkedList::new(),
117            node: UnsafeCell::new(SListNode::default()),
118        }))
119    }
120
121    #[inline(always)]
122    pub fn set_fd(&mut self, fd: RawFd) {
123        self.fd = fd;
124    }
125
126    /// Set callback for IOEvent, might be closure or a custom struct
127    #[inline(always)]
128    pub fn set_callback(&mut self, cb: C) {
129        self.cb = Some(cb);
130    }
131
132    #[inline(always)]
133    pub fn get_size(&self) -> usize {
134        if let BufOrLen::Buffer(buf) = &self.buf_or_len { buf.len() } else { 0 }
135    }
136
137    #[inline(always)]
138    pub(crate) fn push_to_list(self, events: &mut SLinkedList<Box<IOEvent_<C>>, ()>) {
139        events.push_back(self.0);
140    }
141
142    #[inline(always)]
143    pub(crate) fn pop_from_list(events: &mut SLinkedList<Box<IOEvent_<C>>, ()>) -> Option<Self> {
144        events.pop_front().map(IOEvent)
145    }
146
147    #[inline(always)]
148    pub(crate) fn set_subtasks(&mut self, sub_tasks: SLinkedList<Box<IOEvent_<C>>, ()>) {
149        self.sub_tasks = sub_tasks;
150    }
151
152    #[inline(always)]
153    pub fn get_buf_ref<'a>(&'a self) -> &'a [u8] {
154        if let BufOrLen::Buffer(buf) = &self.buf_or_len {
155            buf.as_ref()
156        } else {
157            panic!("get_buf_ref called on IOEvent with no buffer");
158        }
159    }
160
161    #[inline(always)]
162    pub fn is_done(&self) -> bool {
163        self.res != i32::MIN
164    }
165
166    #[inline(always)]
167    pub fn get_write_result(self) -> Result<(), Errno> {
168        let res = self.res;
169        if res >= 0 {
170            return Ok(());
171        } else if res == i32::MIN {
172            panic!("IOEvent get_result before it's done");
173        } else {
174            return Err(Errno::from_raw(-res));
175        }
176    }
177
178    /// Get the result of the IO operation (bytes read/written or error).
179    /// Returns the number of bytes successfully transferred.
180    #[inline(always)]
181    pub fn get_result(&self) -> Result<usize, Errno> {
182        let res = self.res;
183        if res >= 0 {
184            return Ok(res as usize);
185        } else if res == i32::MIN {
186            panic!("IOEvent get_result before it's done");
187        } else {
188            return Err(Errno::from_raw(-res));
189        }
190    }
191
192    /// Get the buffer from a read operation.
193    /// Note: The buffer length is NOT modified. Use `get_result()` to get actual bytes read.
194    #[inline(always)]
195    pub fn get_read_result(mut self) -> Result<Buffer, Errno> {
196        let res = self.res;
197        if res >= 0 {
198            let buf_or_len = std::mem::replace(&mut self.buf_or_len, BufOrLen::Len(0));
199            if let BufOrLen::Buffer(buf) = buf_or_len {
200                // Do NOT modify buffer length - caller should use get_result() to know actual bytes read
201                return Ok(buf);
202            } else {
203                panic!("get_read_result called on IOEvent with no buffer");
204            }
205        } else if res == i32::MIN {
206            panic!("IOEvent get_result before it's done");
207        } else {
208            return Err(Errno::from_raw(-res));
209        }
210    }
211
212    #[inline(always)]
213    pub(crate) fn set_error(&mut self, mut errno: i32) {
214        if errno == 0 {
215            // XXX: EOF does not have code to represent,
216            // also when offset is not align to 4096, may return result 0,
217            errno = Errno::EINVAL as i32;
218        }
219        if errno > 0 {
220            errno = -errno;
221        }
222        self.res = errno;
223    }
224
225    #[inline(always)]
226    pub(crate) fn set_copied(&mut self, len: usize) {
227        if self.res == i32::MIN {
228            self.res = len as i32;
229        } else {
230            self.res += len as i32;
231        }
232    }
233
234    /// Trigger the callback for this IOEvent.
235    /// This consumes the event and calls the associated callback.
236    #[inline(always)]
237    pub(crate) fn callback(mut self) {
238        match self.cb.take() {
239            Some(cb) => {
240                cb.call(self);
241            }
242            None => return,
243        }
244    }
245
246    /// For writing custom callback workers
247    ///
248    /// Callback worker should always call this function on receiving IOEvent from Driver
249    #[inline(always)]
250    pub fn callback_merged(mut self) {
251        if !self.sub_tasks.is_empty() {
252            let res = self.res;
253            if res >= 0 {
254                if self.action == IOAction::Read {
255                    let buf_or_len = std::mem::replace(&mut self.buf_or_len, BufOrLen::Len(0));
256                    if let BufOrLen::Buffer(buffer) = buf_or_len {
257                        let mut b = buffer.as_ref();
258                        for event_box in self.sub_tasks.drain() {
259                            let mut event = IOEvent(event_box);
260                            if let BufOrLen::Buffer(sub_buf) = &mut event.buf_or_len {
261                                if b.len() == 0 {
262                                    // short read
263                                    event.set_copied(0);
264                                } else {
265                                    let copied = safe_copy(sub_buf, b);
266                                    event.set_copied(copied);
267                                    b = &b[copied..];
268                                }
269                            }
270                            event.callback();
271                        }
272                    }
273                } else {
274                    let l = self.get_size();
275                    for event_box in self.sub_tasks.drain() {
276                        let mut event = IOEvent(event_box);
277                        let mut sub_len = event.get_size();
278                        if sub_len > l {
279                            // short write
280                            sub_len = l;
281                        }
282                        event.set_copied(sub_len);
283                        event.callback();
284                    }
285                }
286            } else {
287                let errno = -res;
288                for event_box in self.sub_tasks.drain() {
289                    let mut event = IOEvent(event_box);
290                    event.set_error(errno);
291                    event.callback();
292                }
293            }
294        } else {
295            self.callback();
296        }
297    }
298
299    // New constructor for exit signal events
300    pub(crate) fn new_exit_signal(fd: RawFd) -> Self {
301        // Exit signal wraps a IOEvent
302        Self(Box::new(IOEvent_ {
303            node: UnsafeCell::new(SListNode::default()),
304            buf_or_len: BufOrLen::Len(0),
305            offset: 0,
306            action: IOAction::Read, // Exit signal is a read
307            fd,
308            res: i32::MIN,
309            cb: None, // No callback for exit signal
310            sub_tasks: SLinkedList::new(),
311        }))
312    }
313}