Skip to main content

compio_driver/sys/op/
ext.rs

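//! Extension traits for taking buffers out of operations and for
//! post-processing [`BufResult`] values (advancing buffer lengths and
//! moving received addresses into the result).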
2
3use crate::sys::prelude::*;
4
/// Extracts the buffer owned by an operation, consuming the operation.
///
/// An implementation may return [`None`] when it has no buffer to hand back.
pub trait TakeBuffer {
    /// The buffer type yielded by [`TakeBuffer::take_buffer`].
    type Buffer;

    /// Consume `self` and yield its buffer, if there is one.
    fn take_buffer(self) -> Option<Self::Buffer>;
}
13
14impl<I> TakeBuffer for I
15where
16    I: IntoInner<Inner = BufferRef>,
17{
18    type Buffer = I::Inner;
19
20    fn take_buffer(self) -> Option<Self::Buffer> {
21        Some(self.into_inner())
22    }
23}
24
/// Helper trait for taking the buffer out of a [`BufResult`] whose success
/// value is a byte count.
pub trait ResultTakeBuffer {
    /// Type of the buffer.
    type Buffer;

    /// Consume the result and hand back its buffer, advanced to the byte count.
    ///
    /// - If the result is [`Err`], the error is returned.
    /// - If the byte count is `0`, `Ok(None)` is returned and the buffer is dropped.
    /// - Otherwise the buffer is taken, [`SetLen::advance_to`] is called with the
    ///   byte count, and `Ok(Some(buffer))` is returned.
    ///
    /// # Errors
    ///
    /// Returns the operation's own error when the result is [`Err`]. An
    /// implementation may also fail when the count is non-zero but no buffer
    /// can be taken out of the operation.
    ///
    /// # Safety
    ///
    /// The result value must be a valid length to advance to.
    unsafe fn take_buffer(self) -> io::Result<Option<Self::Buffer>>;
}
38
39impl ResultTakeBuffer for BufResult<usize, BufferRef> {
40    type Buffer = BufferRef;
41
42    unsafe fn take_buffer(self) -> io::Result<Option<BufferRef>> {
43        let (len, mut buf) = buf_try!(@try self);
44        if len == 0 {
45            return Ok(None);
46        }
47        unsafe { buf.advance_to(len) };
48
49        Ok(Some(buf))
50    }
51}
52
53impl<I: TakeBuffer<Buffer: IoBuf + SetLen>> ResultTakeBuffer for BufResult<usize, I> {
54    type Buffer = I::Buffer;
55
56    unsafe fn take_buffer(self) -> io::Result<Option<I::Buffer>> {
57        let (len, buf) = buf_try!(@try self);
58        // Kernel returns 0 for the operation, return Ok(None)
59        if len == 0 {
60            return Ok(None);
61        }
62        let Some(mut buf) = buf.take_buffer() else {
63            return Err(io::Error::new(
64                io::ErrorKind::UnexpectedEof,
65                format!("Read {len} bytes, but no buffer was selected by kernel"),
66            ));
67        };
68        unsafe { buf.advance_to(len) };
69        Ok(Some(buf))
70    }
71}
72
/// Extension for writing the length reported by an operation back into the
/// buffer held by a [`BufResult`].
pub trait BufResultExt {
    /// If the result is [`Ok`], call [`SetLen::advance_to`] on the buffer with
    /// the reported length.
    ///
    /// # Safety
    ///
    /// The reported length must be a valid length to advance the buffer to.
    unsafe fn map_advanced(self) -> Self;
}
82
/// Vectored counterpart of the single-buffer extension: writes the length
/// reported by an operation back into a vectored buffer held by a
/// [`BufResult`].
pub trait VecBufResultExt {
    /// If the result is [`Ok`], call [`SetLen::advance_vec_to`] on the buffer
    /// with the reported length.
    ///
    /// # Safety
    ///
    /// The reported length must be a valid length to advance the buffer to.
    unsafe fn map_vec_advanced(self) -> Self;
}
92
93impl<T: SetLen + IoBuf> BufResultExt for BufResult<usize, T> {
94    unsafe fn map_advanced(self) -> Self {
95        unsafe {
96            self.map_res(|res| (res, ()))
97                .map_advanced()
98                .map_res(|(res, _)| res)
99        }
100    }
101}
102
103impl<T: SetLen + IoVectoredBuf> VecBufResultExt for BufResult<usize, T> {
104    unsafe fn map_vec_advanced(self) -> Self {
105        unsafe {
106            self.map_res(|res| (res, ()))
107                .map_vec_advanced()
108                .map_res(|(res, _)| res)
109        }
110    }
111}
112
113impl<T: SetLen + IoBuf, O> BufResultExt for BufResult<(usize, O), T> {
114    unsafe fn map_advanced(self) -> Self {
115        self.map(|(init, obj), mut buffer| {
116            unsafe {
117                buffer.advance_to(init);
118            }
119            ((init, obj), buffer)
120        })
121    }
122}
123
124impl<T: SetLen + IoVectoredBuf, O> VecBufResultExt for BufResult<(usize, O), T> {
125    unsafe fn map_vec_advanced(self) -> Self {
126        self.map(|(init, obj), mut buffer| {
127            unsafe {
128                buffer.advance_vec_to(init);
129            }
130            ((init, obj), buffer)
131        })
132    }
133}
134
135impl<T: SetLen + IoBuf, C: SetLen + IoBuf, O> BufResultExt
136    for BufResult<(usize, usize, O), (T, C)>
137{
138    unsafe fn map_advanced(self) -> Self {
139        self.map(
140            |(init_buffer, init_control, obj), (mut buffer, mut control)| {
141                unsafe {
142                    buffer.advance_to(init_buffer);
143                    control.advance_to(init_control);
144                }
145                ((init_buffer, init_control, obj), (buffer, control))
146            },
147        )
148    }
149}
150
151impl<T: SetLen + IoVectoredBuf, C: SetLen + IoBuf, O> VecBufResultExt
152    for BufResult<(usize, usize, O), (T, C)>
153{
154    unsafe fn map_vec_advanced(self) -> Self {
155        self.map(
156            |(init_buffer, init_control, obj), (mut buffer, mut control)| {
157                unsafe {
158                    buffer.advance_vec_to(init_buffer);
159                    control.advance_to(init_control);
160                }
161                ((init_buffer, init_control, obj), (buffer, control))
162            },
163        )
164    }
165}
166
/// Helper trait for [`RecvFrom`], [`RecvFromVectored`] and [`RecvMsg`].
///
/// It reshapes a completed receive result so that the peer address travels
/// in the success value rather than alongside the buffer.
///
/// [`RecvFrom`]: crate::op::RecvFrom
/// [`RecvMsg`]: crate::op::RecvMsg
/// [`RecvFromVectored`]: crate::op::RecvFromVectored
pub trait RecvResultExt {
    /// The reshaped result produced by [`RecvResultExt::map_addr`].
    type RecvResult;

    /// Produce the [`SockAddr`] together with the byte count if the result is
    /// [`Ok`].
    fn map_addr(self) -> Self::RecvResult;
}
179
180impl<T> RecvResultExt for BufResult<usize, (T, Option<SockAddr>)> {
181    type RecvResult = BufResult<(usize, Option<SockAddr>), T>;
182
183    fn map_addr(self) -> Self::RecvResult {
184        self.map_buffer(|(buffer, addr)| (buffer, addr, 0))
185            .map_addr()
186            .map_res(|(res, _, addr)| (res, addr))
187    }
188}
189
190impl<T> RecvResultExt for BufResult<usize, (T, Option<SockAddr>, usize)> {
191    type RecvResult = BufResult<(usize, usize, Option<SockAddr>), T>;
192
193    fn map_addr(self) -> Self::RecvResult {
194        self.map2(
195            |res, (buffer, addr, len)| ((res, len, addr), buffer),
196            |(buffer, ..)| buffer,
197        )
198    }
199}