Skip to main content

compio_driver/sys/op/
ext.rs

//! Extension traits for operation buffers and [`BufResult`] values.
2
3use rustix::net::ReturnFlags;
4
5use crate::sys::prelude::*;
6
/// Take buffer out of an operation.
///
/// A blanket implementation is provided for every type implementing
/// [`IntoInner`] with `Inner = BufferRef`.
pub trait TakeBuffer {
    /// Type of the buffer.
    type Buffer;

    /// Take buffer, consuming `self`.
    ///
    /// Returns `None` if there is no buffer to hand back.
    fn take_buffer(self) -> Option<Self::Buffer>;
}
15
16impl<I> TakeBuffer for I
17where
18    I: IntoInner<Inner = BufferRef>,
19{
20    type Buffer = I::Inner;
21
22    fn take_buffer(self) -> Option<Self::Buffer> {
23        Some(self.into_inner())
24    }
25}
26
/// Helper trait for taking buffer from a [`BufResult`].
pub trait ResultTakeBuffer {
    /// Type of the buffer.
    type Buffer;

    /// Call [`SetLen::advance_to`] if the result is [`Ok`] and return the
    /// buffer as result.
    ///
    /// The implementations in this module produce:
    ///
    /// - `Err(e)` if the operation failed (the buffer is not returned);
    /// - `Ok(None)` if the result value is `0` (nothing is advanced);
    /// - `Ok(Some(buf))` otherwise, with `buf` advanced by the result value.
    ///
    /// An implementation may also fail when a non-zero length was reported
    /// but no buffer is available.
    ///
    /// # Safety
    ///
    /// The result value must be a valid length to advance to.
    unsafe fn take_buffer(self) -> io::Result<Option<Self::Buffer>>;
}
40
41impl ResultTakeBuffer for BufResult<usize, BufferRef> {
42    type Buffer = BufferRef;
43
44    unsafe fn take_buffer(self) -> io::Result<Option<BufferRef>> {
45        let (len, mut buf) = buf_try!(@try self);
46        if len == 0 {
47            return Ok(None);
48        }
49        unsafe { buf.advance_to(len) };
50
51        Ok(Some(buf))
52    }
53}
54
55impl<I: TakeBuffer<Buffer: IoBuf + SetLen>> ResultTakeBuffer for BufResult<usize, I> {
56    type Buffer = I::Buffer;
57
58    unsafe fn take_buffer(self) -> io::Result<Option<I::Buffer>> {
59        let (len, buf) = buf_try!(@try self);
60        // Kernel returns 0 for the operation, return Ok(None)
61        if len == 0 {
62            return Ok(None);
63        }
64        let Some(mut buf) = buf.take_buffer() else {
65            return Err(io::Error::new(
66                io::ErrorKind::UnexpectedEof,
67                format!("Read {len} bytes, but no buffer was selected by kernel"),
68            ));
69        };
70        unsafe { buf.advance_to(len) };
71        Ok(Some(buf))
72    }
73}
74
/// Trait to update the buffer length inside the [`BufResult`].
///
/// Implementations advance the buffer (and, where present, the control
/// buffer) by the length(s) carried in a successful result.
pub trait BufResultExt {
    /// Call [`SetLen::advance_to`] if the result is [`Ok`].
    ///
    /// When the result carries two lengths, the first advances the data
    /// buffer and the second advances the control buffer. Nothing is
    /// advanced on error.
    ///
    /// # Safety
    ///
    /// The result value must be a valid length to advance to.
    unsafe fn map_advanced(self) -> Self;
}
84
/// Trait to update the length of a vectored buffer inside the [`BufResult`].
///
/// Implementations advance the vectored buffer (and, where present, the
/// control buffer) by the length(s) carried in a successful result.
pub trait VecBufResultExt {
    /// Call [`SetLen::advance_vec_to`] if the result is [`Ok`].
    ///
    /// When a control buffer is paired with the vectored buffer, it is
    /// advanced with [`SetLen::advance_to`] by the second length. Nothing is
    /// advanced on error.
    ///
    /// # Safety
    ///
    /// The result value must be a valid length to advance to.
    unsafe fn map_vec_advanced(self) -> Self;
}
94
95impl<T: SetLen + IoBuf> BufResultExt for BufResult<usize, T> {
96    unsafe fn map_advanced(self) -> Self {
97        unsafe {
98            self.map_res(|res| (res, ()))
99                .map_advanced()
100                .map_res(|(res, _)| res)
101        }
102    }
103}
104
105impl<T: SetLen + IoVectoredBuf> VecBufResultExt for BufResult<usize, T> {
106    unsafe fn map_vec_advanced(self) -> Self {
107        unsafe {
108            self.map_res(|res| (res, ()))
109                .map_vec_advanced()
110                .map_res(|(res, _)| res)
111        }
112    }
113}
114
115impl<T: SetLen + IoBuf, O> BufResultExt for BufResult<(usize, O), T> {
116    unsafe fn map_advanced(self) -> Self {
117        self.map(|(init, obj), mut buffer| {
118            unsafe {
119                buffer.advance_to(init);
120            }
121            ((init, obj), buffer)
122        })
123    }
124}
125
126impl<T: SetLen + IoVectoredBuf, O> VecBufResultExt for BufResult<(usize, O), T> {
127    unsafe fn map_vec_advanced(self) -> Self {
128        self.map(|(init, obj), mut buffer| {
129            unsafe {
130                buffer.advance_vec_to(init);
131            }
132            ((init, obj), buffer)
133        })
134    }
135}
136
137impl<T: SetLen + IoBuf, C: SetLen + IoBuf, O> BufResultExt
138    for BufResult<(usize, usize, O), (T, C)>
139{
140    unsafe fn map_advanced(self) -> Self {
141        self.map(
142            |(init_buffer, init_control, obj), (mut buffer, mut control)| {
143                unsafe {
144                    buffer.advance_to(init_buffer);
145                    control.advance_to(init_control);
146                }
147                ((init_buffer, init_control, obj), (buffer, control))
148            },
149        )
150    }
151}
152
153impl<T: SetLen + IoVectoredBuf, C: SetLen + IoBuf, O> VecBufResultExt
154    for BufResult<(usize, usize, O), (T, C)>
155{
156    unsafe fn map_vec_advanced(self) -> Self {
157        self.map(
158            |(init_buffer, init_control, obj), (mut buffer, mut control)| {
159                unsafe {
160                    buffer.advance_vec_to(init_buffer);
161                    control.advance_to(init_control);
162                }
163                ((init_buffer, init_control, obj), (buffer, control))
164            },
165        )
166    }
167}
168
169impl<T: SetLen + IoBuf, C: SetLen + IoBuf, O1, O2> BufResultExt
170    for BufResult<(usize, usize, O1, O2), (T, C)>
171{
172    unsafe fn map_advanced(self) -> Self {
173        self.map(
174            |(init_buffer, init_control, obj1, obj2), (mut buffer, mut control)| {
175                unsafe {
176                    buffer.advance_to(init_buffer);
177                    control.advance_to(init_control);
178                }
179                ((init_buffer, init_control, obj1, obj2), (buffer, control))
180            },
181        )
182    }
183}
184
185impl<T: SetLen + IoVectoredBuf, C: SetLen + IoBuf, O1, O2> VecBufResultExt
186    for BufResult<(usize, usize, O1, O2), (T, C)>
187{
188    unsafe fn map_vec_advanced(self) -> Self {
189        self.map(
190            |(init_buffer, init_control, obj1, obj2), (mut buffer, mut control)| {
191                unsafe {
192                    buffer.advance_vec_to(init_buffer);
193                    control.advance_to(init_control);
194                }
195                ((init_buffer, init_control, obj1, obj2), (buffer, control))
196            },
197        )
198    }
199}
200
/// Helper trait for [`RecvFrom`], [`RecvFromVectored`] and [`RecvMsg`].
///
/// [`RecvFrom`]: crate::op::RecvFrom
/// [`RecvMsg`]: crate::op::RecvMsg
/// [`RecvFromVectored`]: crate::op::RecvFromVectored
pub trait RecvResultExt {
    /// The mapped result.
    ///
    /// In this module's implementations, this is a [`BufResult`] whose value
    /// is the `usize` result followed by the metadata that was stored next to
    /// the buffer, and whose buffer is the bare buffer.
    type RecvResult;

    /// Move the optional [`SockAddr`] (and any other metadata stored next to
    /// the buffer) into the result value if the result is [`Ok`].
    ///
    /// On error the metadata is discarded and only the buffer is kept.
    fn map_addr(self) -> Self::RecvResult;
}
213
214impl<T> RecvResultExt for BufResult<usize, (T, Option<SockAddr>)> {
215    type RecvResult = BufResult<(usize, Option<SockAddr>), T>;
216
217    fn map_addr(self) -> Self::RecvResult {
218        self.map_buffer(|(buffer, addr)| (buffer, addr, 0))
219            .map_addr()
220            .map_res(|(res, _, addr)| (res, addr))
221    }
222}
223
224impl<T> RecvResultExt for BufResult<usize, (T, Option<SockAddr>, usize)> {
225    type RecvResult = BufResult<(usize, usize, Option<SockAddr>), T>;
226
227    fn map_addr(self) -> Self::RecvResult {
228        self.map2(
229            |res, (buffer, addr, len)| ((res, len, addr), buffer),
230            |(buffer, ..)| buffer,
231        )
232    }
233}
234
235impl<T> RecvResultExt for BufResult<usize, (T, Option<SockAddr>, usize, ReturnFlags)> {
236    type RecvResult = BufResult<(usize, usize, Option<SockAddr>, ReturnFlags), T>;
237
238    fn map_addr(self) -> Self::RecvResult {
239        self.map2(
240            |res, (buffer, addr, len, flags)| ((res, len, addr, flags), buffer),
241            |(buffer, ..)| buffer,
242        )
243    }
244}