//! io_engine/tasks.rs — I/O task and event definitions.
1use std::fmt;
2use std::os::fd::RawFd;
3
4use crossfire::waitgroup::WaitGroupGuard;
5use embed_collections::SegList;
6use io_buffer::{Buffer, safe_copy};
7use rustix::io::Errno;
8
/// Kind of I/O operation carried by an [`IOEvent`].
///
/// The discriminants of `Read`/`Write` mirror the Linux AIO opcodes so the
/// value can be handed to the kernel directly.
#[derive(Copy, Clone, PartialEq, Debug)]
#[repr(u8)]
pub enum IOAction {
    Read = 0,  // The same with IOCB_CMD_PREAD
    Write = 1, // the same with IOCB_CMD_PWRITE
    Alloc = 2, // fallocate-style space allocation (no buffer, carries a length)
    Fsync = 3, // flush; no buffer
}
17
18impl IOAction {
19    #[inline(always)]
20    pub fn is_read_write(&self) -> bool {
21        (*self as u8) < (IOAction::Alloc as u8)
22    }
23}
24
/// Bound for user-supplied callback payloads attached to an [`IOEvent`].
///
/// Implementors are delivered back to the caller when the event completes.
pub trait CbArgs: Sized + 'static + Send + Unpin {
    /// only called in MergeSubmitter (for NOMEM or SHUTDOWN)
    fn set_merge_error(self, _e: Errno) {}
}
29
// Unit payload: fire-and-forget submissions with no completion data.
impl CbArgs for () {}

// NOTE(review): dropping the guard on completion presumably releases the
// wait group — confirm against the crossfire waitgroup semantics.
impl<T: Send + Sync + 'static> CbArgs for WaitGroupGuard<T> {}
33
// Carries the information of read/write event
pub struct IOEvent<C: CbArgs> {
    pub action: IOAction,
    /// Result of the IO operation.
    /// Initialized to i32::MIN.
    /// - `>= 0`: Accumulated bytes transferred (used for partial IO retries).
    /// - `<0`: Error code (negative errno).
    pub(crate) res: i32,
    /// make sure SListNode always in the front.
    /// This is for putting sub_tasks in the link list, without additional allocation.
    /// NOTE(review): this comment mentions `SListNode` but the field is a
    /// `BufOrLen`, and it is not the first field — looks stale; confirm.
    pub(crate) buf_or_len: BufOrLen,
    /// Absolute file offset of the operation.
    pub offset: i64,
    /// Target file descriptor.
    pub fd: RawFd,
    /// Completion payload: a single callback arg, or the merged sub-tasks.
    pub(crate) args: Option<TaskArgs<C>>,
}
49
/// Completion payload of an [`IOEvent`].
pub(crate) enum TaskArgs<C: CbArgs> {
    /// Single event: one callback argument.
    Callback(C),
    /// Master event of a merge: the sub-tasks whose results are fanned out
    /// from this event's combined buffer on completion.
    Merged(SegList<IOEventMerged<C>>),
}
54
/// Either the data buffer of a read/write, or just a byte count for
/// buffer-less actions.
pub(crate) enum BufOrLen {
    Buffer(Buffer),
    /// for fallocate
    Len(u64),
}
60
/// A sub-task folded into a merged master event: its own buffer plus the
/// caller's callback payload (None when no callback was attached).
pub(crate) struct IOEventMerged<C: CbArgs> {
    pub buf: Buffer,
    pub args: Option<C>,
}
65
66impl<C: CbArgs> fmt::Debug for IOEvent<C> {
67    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
68        if let Some(TaskArgs::Merged(sub_tasks)) = self.args.as_ref() {
69            write!(f, "offset={} {:?} merged {}", self.offset, self.action, sub_tasks.len())
70        } else {
71            write!(f, "offset={} {:?}", self.offset, self.action)
72        }
73    }
74}
75
76impl<C: CbArgs> IOEvent<C> {
    /// For IOAction::Read / IOAction::Write
    ///
    /// `res` starts at the `i32::MIN` "not done" sentinel and no callback is
    /// attached; use [`Self::set_args`] to add one.
    #[inline]
    pub fn new(fd: RawFd, buf: Buffer, action: IOAction, offset: i64) -> Self {
        // Zero-length I/O is never valid here; catch it at construction time.
        log_assert!(!buf.is_empty(), "{:?} offset={}, buffer size == 0", action, offset);
        Self { buf_or_len: BufOrLen::Buffer(buf), fd, action, offset, res: i32::MIN, args: None }
    }
83
84    /// For IOAction::Alloc / IOAction::Fsync
85    #[inline]
86    pub fn new_no_buf(fd: RawFd, action: IOAction, offset: i64, len: u64) -> Self {
87        Self {
88            buf_or_len: BufOrLen::Len(len), // No buffer for this action
89            fd,
90            action,
91            offset,
92            res: i32::MIN,
93            args: None,
94        }
95    }
96
    /// Re-target the event at a different file descriptor.
    #[inline(always)]
    pub fn set_fd(&mut self, fd: RawFd) {
        self.fd = fd;
    }
101
102    /// Set callback for IOEvent, might be closure or a custom struct
103    #[inline(always)]
104    pub fn set_args(&mut self, args: C) {
105        self.args.replace(TaskArgs::Callback(args));
106    }
107
108    #[inline(always)]
109    pub fn get_size(&self) -> u64 {
110        match &self.buf_or_len {
111            BufOrLen::Buffer(buf) => buf.len() as u64,
112            BufOrLen::Len(l) => *l,
113        }
114    }
115
116    /// Set merged buffer and subtasks for the master event after merging.
117    #[inline(always)]
118    pub(crate) fn set_merged_tasks(
119        &mut self, merged_buf: Buffer, sub_tasks: SegList<IOEventMerged<C>>,
120    ) {
121        self.buf_or_len = BufOrLen::Buffer(merged_buf);
122        self.args.replace(TaskArgs::Merged(sub_tasks));
123    }
124
125    /// Convert this IOEvent into an IOEventMerged for storing in merge buffer.
126    /// Extracts the buffer and callback from the event.
127    #[inline(always)]
128    pub(crate) fn into_merged(mut self) -> IOEventMerged<C> {
129        let buf = match std::mem::replace(&mut self.buf_or_len, BufOrLen::Len(0)) {
130            BufOrLen::Buffer(buf) => buf,
131            BufOrLen::Len(_) => panic!("into_merged called on IOEvent with no buffer"),
132        };
133        let args = match self.args.take() {
134            Some(TaskArgs::Callback(args)) => Some(args),
135            _ => None,
136        };
137        IOEventMerged { buf, args }
138    }
139
140    /// Extract buffer and callback to create IOEventMerged, leaving this event with empty buffer.
141    /// Used when moving first event to merged_events list.
142    #[inline(always)]
143    pub(crate) fn extract_merged(&mut self) -> IOEventMerged<C> {
144        let buf = match std::mem::replace(&mut self.buf_or_len, BufOrLen::Len(0)) {
145            BufOrLen::Buffer(buf) => buf,
146            BufOrLen::Len(_) => panic!("extract_merged called on IOEvent with no buffer"),
147        };
148        let args = match self.args.take() {
149            Some(TaskArgs::Callback(args)) => Some(args),
150            _ => None,
151        };
152        IOEventMerged { buf, args }
153    }
154
155    /// return (offset, ptr, len)
156    #[inline(always)]
157    pub(crate) fn get_param_for_io(&mut self) -> (u64, *mut u8, u32) {
158        if let BufOrLen::Buffer(buf) = &mut self.buf_or_len {
159            let mut offset = self.offset as u64;
160            let mut p = buf.get_raw_mut();
161            let mut l = buf.len() as u32;
162            if self.res <= 0 {
163                (offset, p, l)
164            } else {
165                // resubmited I/O
166                offset += self.res as u64;
167                p = unsafe { p.add(self.res as usize) };
168                l += self.res as u32;
169                (offset, p, l)
170            }
171        } else {
172            panic!("get_buf_raw called on IOEvent with no buffer");
173        }
174    }
175
176    #[inline(always)]
177    pub fn get_write_result(self) -> Result<(), Errno> {
178        let res = self.res;
179        if res >= 0 {
180            return Ok(());
181        } else if res == i32::MIN {
182            panic!("IOEvent get_result before it's done");
183        } else {
184            return Err(Errno::from_raw_os_error(-res));
185        }
186    }
187
188    /// Get the result of the IO operation (bytes read/written or error).
189    /// Returns the number of bytes successfully transferred.
190    #[inline(always)]
191    pub fn get_result(&self) -> Result<usize, Errno> {
192        let res = self.res;
193        if res >= 0 {
194            return Ok(res as usize);
195        } else if res == i32::MIN {
196            panic!("IOEvent get_result before it's done");
197        } else {
198            return Err(Errno::from_raw_os_error(-res));
199        }
200    }
201
202    /// Get the buffer from a read operation.
203    /// Note: The buffer length is NOT modified. Use `get_result()` to get actual bytes read.
204    #[inline(always)]
205    pub fn get_read_result(mut self) -> Result<Buffer, Errno> {
206        let res = self.res;
207        if res >= 0 {
208            // XXX?
209            let buf_or_len = std::mem::replace(&mut self.buf_or_len, BufOrLen::Len(0));
210            if let BufOrLen::Buffer(buf) = buf_or_len {
211                // Do NOT modify buffer length - caller should use get_result() to know actual bytes read
212                return Ok(buf);
213            } else {
214                panic!("get_read_result called on IOEvent with no buffer");
215            }
216        } else if res == i32::MIN {
217            panic!("IOEvent get_result before it's done");
218        } else {
219            return Err(Errno::from_raw_os_error(-res));
220        }
221    }
222
    /// Record a failure: `res` becomes the negative errno value.
    #[inline(always)]
    pub(crate) fn set_errno(&mut self, errno: Errno) {
        self.res = -errno.raw_os_error();
    }
227
228    #[inline(always)]
229    pub(crate) fn set_error(&mut self, mut errno: i32) {
230        if errno == 0 {
231            // TODO when errno == 0?
232            // XXX: EOF does not have code to represent,
233            // also when offset is not align to 4096, may return result 0,
234            errno = Errno::INVAL.raw_os_error();
235        }
236        if errno > 0 {
237            errno = -errno;
238        }
239        self.res = errno;
240    }
241
242    #[inline(always)]
243    pub(crate) fn set_copied(&mut self, len: usize) {
244        if self.res == i32::MIN {
245            // the initial state
246            self.res = len as i32;
247        } else {
248            // resubmit for short I/O
249            self.res += len as i32;
250        }
251    }
252
    /// For writing custom callback workers
    ///
    /// Callback worker should always call this function on receiving IOEvent from Driver
    ///
    /// parameter: `check_short_read(offset: u64)` should be checking the offset exceed file end.
    /// If `check_short_read()` return true, the callback function will return Err(IOEvent) for I/O resubmit.
    ///
    /// NOTE: you should always use a weak reference in `check_short_read` closure and
    /// re-submission.
    ///
    /// NOTE(review): when `res < 0` (error) this returns `Ok(())` WITHOUT
    /// invoking `cb` — presumably errors are reported through another path;
    /// confirm against the driver.
    #[inline(always)]
    pub fn callback<F, B>(mut self: Box<Self>, check_short_read: F, cb: B) -> Result<(), Box<Self>>
    where
        F: FnOnce(u64) -> bool,
        B: Fn(C, i64, Result<Option<Buffer>, Errno>),
    {
        if self.res >= 0 {
            if let BufOrLen::Buffer(buf) = &mut self.buf_or_len {
                if buf.len() == self.res as usize {
                    // most frequent case in the front, for cpu branch prediction
                    self._callback_unchecked::<B>(false, cb);
                } else if self.action == IOAction::Read {
                    // Short read: either a retryable condition or a true EOF.
                    if check_short_read(self.offset as u64 + self.res as u64) {
                        return Err(self);
                    } else {
                        // reach file ending
                        buf.set_len(self.res as usize);
                        self._callback_unchecked::<B>(false, cb);
                    }
                } else {
                    // short write always need to resubmit
                    return Err(self);
                }
            } else {
                // Buffer-less actions (Alloc/Fsync) complete directly.
                self._callback_unchecked::<B>(false, cb);
            }
        }
        Ok(())
    }
291
    /// Perform callback on the IOEvent when cannot re-submit for short i/o
    ///
    /// Short I/O is surfaced to the receiver by truncating the buffer length
    /// to the bytes actually transferred (the `true` flag below).
    #[inline(always)]
    pub fn callback_unchecked<B>(self, cb: B)
    where
        B: Fn(C, i64, Result<Option<Buffer>, Errno>),
    {
        self._callback_unchecked::<B>(true, cb);
    }
300
    /// Perform callback on the IOEvent when cannot re-submit for short i/o
    ///
    /// # Arguments
    ///
    /// - to_fix_short_io: should always be true, fix the buffer len of short I/O
    ///
    /// # Safety
    ///
    /// Only for callback worker does not re-submit when short I/O.
    /// Buffer::len() will changed to actual I/O copied size during callback.
    #[inline(always)]
    pub(crate) fn _callback_unchecked<B>(mut self, to_fix_short_io: bool, cb: B)
    where
        B: Fn(C, i64, Result<Option<Buffer>, Errno>),
    {
        match self.args.take() {
            // Single (non-merged) event: invoke the callback once.
            Some(TaskArgs::Callback(args)) => {
                let res: Result<Option<Buffer>, Errno> = if self.res >= 0 {
                    match self.buf_or_len {
                        BufOrLen::Buffer(mut buf) => {
                            // Truncate so the receiver sees exactly the bytes
                            // that were transferred.
                            if to_fix_short_io && buf.len() > self.res as usize {
                                buf.set_len(self.res as usize);
                            }
                            Ok(Some(buf))
                        }
                        // Alloc/Fsync carry no buffer.
                        BufOrLen::Len(_) => Ok(None),
                    }
                } else {
                    Err(Errno::from_raw_os_error(-self.res))
                };
                cb(args, self.offset, res);
            }
            // Merged master event: fan the single kernel result out to every
            // sub-task, in submission order.
            Some(TaskArgs::Merged(sub_tasks)) => {
                if self.res >= 0 {
                    let mut offset = self.offset;
                    if self.action == IOAction::Read {
                        if let BufOrLen::Buffer(parent_buf) = &self.buf_or_len {
                            // Walk the transferred prefix of the parent buffer,
                            // copying each sub-task's slice into its own buffer.
                            let mut b: &[u8] = &parent_buf[0..self.res as usize];
                            for IOEventMerged { mut buf, args } in sub_tasks {
                                // NOTE(review): a sub-task with `args == None`
                                // does not advance `b`/`offset` — confirm that
                                // merged read sub-tasks always carry args.
                                if let Some(_args) = args {
                                    let copied = safe_copy(&mut buf, b);
                                    if copied < buf.len() {
                                        buf.set_len(copied); // short I/O
                                    }
                                    cb(_args, offset, Ok(Some(buf)));
                                    b = &b[copied..];
                                    offset += copied as i64
                                }
                            }
                        }
                    } else if self.action == IOAction::Write {
                        // `l` = bytes the kernel actually wrote; attribute them
                        // to sub-tasks front-to-back, truncating the one that
                        // straddles the short-write boundary.
                        let mut l = self.res as usize;
                        for IOEventMerged { mut buf, args } in sub_tasks {
                            let mut copied = buf.len();
                            if copied > l {
                                // short write
                                copied = l;
                                buf.set_len(l);
                            }
                            if let Some(_args) = args {
                                cb(_args, offset, Ok(Some(buf)));
                            }
                            l -= copied;
                            offset += copied as i64;
                        }
                    }
                } else {
                    // Whole merged I/O failed: report the same errno to every
                    // sub-task at its own offset.
                    let mut offset = self.offset;
                    for IOEventMerged { buf, args } in sub_tasks {
                        let _l = buf.len() as i64;
                        if let Some(_args) = args {
                            cb(_args, offset, Err(Errno::from_raw_os_error(-self.res)));
                        }
                        offset += _l;
                    }
                }
            }
            // No payload attached: nothing to notify.
            None => {}
        }
    }
381}
382
#[cfg(test)]
mod tests {

    use super::*;
    use io_buffer::Buffer;
    use rustix::io::Errno;
    use std::mem::size_of;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicI64, Ordering};

    /// Print type sizes — a quick check that the hot-path types stay small.
    #[test]
    fn test_ioevent_size() {
        println!("IOEvent size {}", size_of::<IOEvent<()>>());
        println!("BufOrLen size {}", size_of::<crate::tasks::BufOrLen>());
        println!("IOEventMerged size {}", size_of::<IOEventMerged<()>>());
    }

    /// Test normal callback (non-merged case)
    #[test]
    fn test_callback_normal() {
        let buffer = Buffer::alloc(4096).unwrap();
        let mut event = IOEvent::<()>::new(0, buffer, IOAction::Write, 1024);

        let result = Arc::new(std::sync::Mutex::new(None));
        let result_clone = result.clone();

        event.set_args(());
        event.set_copied(4096); // mark the full buffer as transferred
        event.callback_unchecked(move |_args, offset, res| {
            *result_clone.lock().unwrap() = Some((offset, res));
        });

        // Callback must have fired exactly once with the original offset.
        let (offset, res) = result.lock().unwrap().take().unwrap();
        assert_eq!(offset, 1024);
        assert!(res.is_ok());
        assert!(res.unwrap().is_some());
    }

    /// Test merged read callback - verifies offset correctness
    #[test]
    fn test_callback_merged_read() {
        let offsets = Arc::new([AtomicI64::new(0), AtomicI64::new(0), AtomicI64::new(0)]);
        let offsets_clone = offsets.clone();

        // Create sub-tasks with their own buffers first
        let mut sub_tasks = SegList::new();

        sub_tasks.push(IOEventMerged { buf: Buffer::alloc(16).unwrap(), args: Some(()) });

        sub_tasks.push(IOEventMerged { buf: Buffer::alloc(16).unwrap(), args: Some(()) });

        sub_tasks.push(IOEventMerged { buf: Buffer::alloc(16).unwrap(), args: Some(()) });

        // Create parent buffer and event
        let parent_buf = Buffer::alloc(48).unwrap();
        let mut event = IOEvent::<()>::new(0, parent_buf, IOAction::Read, 1000);
        event.set_copied(48); // 48 bytes read

        // Get the parent buffer back and fill with data
        let parent_buf = match std::mem::replace(&mut event.buf_or_len, BufOrLen::Len(0)) {
            BufOrLen::Buffer(buf) => buf,
            BufOrLen::Len(_) => panic!("expected buffer"),
        };
        let mut parent_buf = parent_buf;
        parent_buf.copy_from(0, b"ABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@#$%^&*()");

        event.set_merged_tasks(parent_buf, sub_tasks);
        event.callback_unchecked(move |(), offset, res| {
            // Each sub-task should be reported at parent offset + its
            // position within the merged buffer.
            let idx = (offset - 1000) / 16;
            offsets_clone[idx as usize].store(offset, Ordering::SeqCst);
            assert!(res.is_ok());
            assert!(res.unwrap().is_some());
        });

        // Verify offsets
        assert_eq!(offsets[0].load(Ordering::SeqCst), 1000);
        assert_eq!(offsets[1].load(Ordering::SeqCst), 1016);
        assert_eq!(offsets[2].load(Ordering::SeqCst), 1032);
    }

    /// Test merged write callback - verifies offset correctness
    #[test]
    fn test_callback_merged_write() {
        let parent_buf = Buffer::alloc(4096).unwrap();

        let mut event = IOEvent::<()>::new(0, parent_buf, IOAction::Write, 2000);
        event.set_copied(48); // All 48 bytes written

        let mut sub_tasks = SegList::new();

        sub_tasks.push(IOEventMerged { buf: Buffer::alloc(16).unwrap(), args: Some(()) });

        sub_tasks.push(IOEventMerged { buf: Buffer::alloc(16).unwrap(), args: Some(()) });

        sub_tasks.push(IOEventMerged { buf: Buffer::alloc(16).unwrap(), args: Some(()) });

        event.set_merged_tasks(Buffer::alloc(4096).unwrap(), sub_tasks);
        event.callback_unchecked(move |(), offset, res| {
            // Offsets advance by each sub-task's buffer size from 2000.
            assert!(offset >= 2000 && offset <= 2032);
            assert!(res.is_ok());
            assert!(res.unwrap().is_some());
        });
    }

    /// Test merged callback with error result
    #[test]
    fn test_callback_merged_error() {
        let parent_buf = Buffer::alloc(4096).unwrap();
        let mut event = IOEvent::<()>::new(0, parent_buf, IOAction::Read, 3000);
        event.set_error(Errno::IO.raw_os_error()); // IO error

        let mut sub_tasks = SegList::new();

        sub_tasks.push(IOEventMerged { buf: Buffer::alloc(16).unwrap(), args: Some(()) });

        sub_tasks.push(IOEventMerged { buf: Buffer::alloc(16).unwrap(), args: Some(()) });

        event.set_merged_tasks(Buffer::alloc(48).unwrap(), sub_tasks);
        event.callback_unchecked(|(), offset, res| {
            // Every sub-task gets the same errno at its own offset.
            assert!(offset == 3000 || offset == 3016);
            assert!(res.is_err());
            assert_eq!(res.err().unwrap(), Errno::IO);
        });
    }

    /// Test short read in merged callback
    #[test]
    fn test_callback_merged_short_read() {
        let offsets = Arc::new([AtomicI64::new(0), AtomicI64::new(0)]);
        let offsets_clone = offsets.clone();

        // Parent buffer with 32 bytes
        let parent_buf = Buffer::alloc(32).unwrap();
        let mut event = IOEvent::<()>::new(0, parent_buf, IOAction::Read, 4000);
        event.set_copied(24); // Short read: only 24 bytes (16 + 8)

        let mut sub_tasks = SegList::new();

        sub_tasks.push(IOEventMerged { buf: Buffer::alloc(16).unwrap(), args: Some(()) });

        sub_tasks.push(IOEventMerged { buf: Buffer::alloc(16).unwrap(), args: Some(()) });

        let parent_buf = match std::mem::replace(&mut event.buf_or_len, BufOrLen::Len(0)) {
            BufOrLen::Buffer(buf) => buf,
            BufOrLen::Len(_) => panic!("expected buffer"),
        };

        event.set_merged_tasks(parent_buf, sub_tasks);
        event.callback_unchecked(move |(), offset, res| {
            // Second sub-task receives a truncated (8-byte) buffer but is
            // still reported at its nominal offset.
            let idx = (offset - 4000) / 16;
            offsets_clone[idx as usize].store(offset, Ordering::SeqCst);
            assert!(res.is_ok());
            assert!(res.unwrap().is_some());
        });

        // Verify
        assert_eq!(offsets[0].load(Ordering::SeqCst), 4000);
        assert_eq!(offsets[1].load(Ordering::SeqCst), 4016);
    }
}