Skip to main content

compio_driver/sys/op/managed/
fusion.rs

1use compio_buf::*;
2use rustix::net::{RecvFlags, ReturnFlags};
3use socket2::SockAddr;
4
5use super::{fallback, iour};
6use crate::{
7    BufferPool, BufferRef, IourOpCode, OpEntry, OpType, PollFirst, PollOpCode, sys::pal::*,
8};
9
// Declares a "fused" managed-buffer operation `$name` that wraps either the
// polling implementation (`fallback::$name`) or the io-uring implementation
// (`iour::$name`). The backend is picked once, in `new`, by asking the
// argument named after `with` whether it belongs to an io-uring driver.
//
// Accepted forms:
//
//     mop!(<T: Trait, ...> Name(arg: Type, ...) with pool);
//     mop!(<T: Trait, ...> Name(arg: Type, ...) with pool; Buffer);
//
// The first form uses `crate::BufferRef` as the `TakeBuffer::Buffer` type.
// The second uses `Buffer`, which the buffers taken from both backends must
// convert into via `Into`.
//
// `PollOpCode` methods are forwarded to the polling variant and `IourOpCode`
// methods to the io-uring variant. Calling the family that does not match the
// stored variant panics.
macro_rules! mop {
    (<$($ty:ident: $trait:ident),* $(,)?> $name:ident( $($arg:ident: $arg_t:ty),* $(,)? ) with $pool:ident) => {
        // No explicit buffer type: default to `BufferRef` and reuse the full arm.
        mop!(<$($ty: $trait),*> $name( $($arg: $arg_t),* ) with $pool; crate::BufferRef);
    };
    (<$($ty:ident: $trait:ident),* $(,)?> $name:ident( $($arg:ident: $arg_t:ty),* $(,)? ) with $pool:ident; $inner:ty) => {
        ::paste::paste! {
            // Backend storage. Which variant is present is fixed in `new`.
            enum [< $name Inner >] <$($ty: $trait),*> {
                Poll(fallback::$name<$($ty),*>),
                IoUring(iour::$name<$($ty),*>),
            }

            impl<$($ty: $trait),*> [< $name Inner >]<$($ty),*> {
                // Borrows the polling backend; panics if this op was built for io-uring.
                fn poll(&mut self) -> &mut fallback::$name<$($ty),*> {
                    let Self::Poll(backend) = self else {
                        unreachable!("Current driver is not `io-uring`")
                    };
                    backend
                }

                // Borrows the io-uring backend; panics if this op was built for polling.
                fn iour(&mut self) -> &mut iour::$name<$($ty),*> {
                    let Self::IoUring(backend) = self else {
                        unreachable!("Current driver is not `polling`")
                    };
                    backend
                }
            }

            #[doc = concat!("A fused `", stringify!($name), "` operation, backed by either the polling or the io-uring driver.")]
            pub struct $name <$($ty: $trait),*> {
                inner: [< $name Inner >] <$($ty),*>,
            }

            impl<$($ty: $trait),*> $name <$($ty),*> {
                #[doc = concat!("Create a new `", stringify!($name), "` on the backend matching the driver of `", stringify!($pool), "`.")]
                ///
                /// # Errors
                ///
                /// Fails if the driver kind cannot be queried, or if the chosen
                /// backend's constructor fails.
                pub fn new($($arg: $arg_t),*) -> std::io::Result<Self> {
                    let inner = if $pool.is_io_uring()? {
                        [< $name Inner >]::IoUring(iour::$name::new($($arg),*)?)
                    } else {
                        [< $name Inner >]::Poll(fallback::$name::new($($arg),*)?)
                    };
                    Ok(Self { inner })
                }
            }

            impl<$($ty: $trait),*> crate::TakeBuffer for $name <$($ty),*> {
                type Buffer = $inner;

                // Whichever backend is present, its buffer is converted into
                // the declared buffer type.
                fn take_buffer(self) -> Option<$inner> {
                    match self.inner {
                        [< $name Inner >]::Poll(backend) => backend.take_buffer().map(Into::into),
                        [< $name Inner >]::IoUring(backend) => backend.take_buffer().map(Into::into),
                    }
                }
            }

            // SAFETY: every method only forwards to `fallback::$name`'s own
            // `PollOpCode` implementation with unchanged arguments, so this
            // impl upholds whatever that implementation upholds.
            unsafe impl<$($ty: $trait),*> PollOpCode for $name<$($ty),*> {
                type Control = <fallback::$name<$($ty),*> as PollOpCode>::Control;

                unsafe fn init(&mut self, control: &mut Self::Control) {
                    let backend = self.inner.poll();
                    // SAFETY: the caller's obligations for `init` are passed
                    // on unchanged to the polling backend.
                    unsafe { backend.init(control) }
                }

                fn pre_submit(&mut self, control: &mut Self::Control) -> std::io::Result<crate::Decision> {
                    let backend = self.inner.poll();
                    backend.pre_submit(control)
                }

                fn op_type(&mut self, control: &mut Self::Control) -> Option<OpType> {
                    let backend = self.inner.poll();
                    backend.op_type(control)
                }

                fn operate(&mut self, control: &mut Self::Control) -> std::task::Poll<std::io::Result<usize>> {
                    let backend = self.inner.poll();
                    backend.operate(control)
                }
            }

            // SAFETY: every method only forwards to `iour::$name`'s own
            // `IourOpCode` implementation with unchanged arguments, so this
            // impl upholds whatever that implementation upholds.
            unsafe impl<$($ty: $trait),*> IourOpCode for $name<$($ty),*> {
                type Control = <iour::$name<$($ty),*> as IourOpCode>::Control;

                unsafe fn init(&mut self, control: &mut Self::Control) {
                    let backend = self.inner.iour();
                    // SAFETY: the caller's obligations for `init` are passed
                    // on unchanged to the io-uring backend.
                    unsafe { backend.init(control) }
                }

                fn create_entry(&mut self, control: &mut Self::Control) -> OpEntry {
                    let backend = self.inner.iour();
                    backend.create_entry(control)
                }

                fn create_entry_fallback(&mut self, control: &mut Self::Control) -> OpEntry {
                    let backend = self.inner.iour();
                    backend.create_entry_fallback(control)
                }

                fn call_blocking(&mut self, control: &mut Self::Control) -> std::io::Result<usize> {
                    let backend = self.inner.iour();
                    backend.call_blocking(control)
                }

                unsafe fn set_result(&mut self, control: &mut Self::Control, result: &std::io::Result<usize>, extra: &crate::Extra) {
                    let backend = self.inner.iour();
                    // SAFETY: the caller's obligations for `set_result` are
                    // passed on unchanged to the io-uring backend.
                    unsafe { backend.set_result(control, result, extra) }
                }

                unsafe fn push_multishot(&mut self, control: &mut Self::Control, result: std::io::Result<usize>, extra: crate::Extra) {
                    let backend = self.inner.iour();
                    // SAFETY: the caller's obligations for `push_multishot`
                    // are passed on unchanged to the io-uring backend.
                    unsafe { backend.push_multishot(control, result, extra) }
                }

                fn pop_multishot(&mut self, control: &mut Self::Control) -> Option<BufResult<usize, crate::Extra>> {
                    let backend = self.inner.iour();
                    backend.pop_multishot(control)
                }
            }
        }
    };
}
124
125mop!(<S: AsFd> ReadManagedAt(fd: S, offset: u64, pool: &BufferPool, len: usize) with pool);
126mop!(<S: AsFd> ReadManaged(fd: S, pool: &BufferPool, len: usize) with pool);
127mop!(<S: AsFd> RecvManaged(fd: S, pool: &BufferPool, len: usize, flags: RecvFlags) with pool);
128mop!(<S: AsFd> RecvFromManaged(fd: S, pool: &BufferPool, len: usize, flags: RecvFlags) with pool; (BufferRef, Option<SockAddr>));
129mop!(<C: IoBufMut, S: AsFd> RecvMsgManaged(fd: S, pool: &BufferPool, len: usize, control: C, flags: RecvFlags) with pool; ((BufferRef, C), Option<SockAddr>, usize, ReturnFlags));
130mop!(<S: AsFd> ReadMultiAt(fd: S, offset: u64, pool: &BufferPool, len: usize) with pool);
131mop!(<S: AsFd> ReadMulti(fd: S, pool: &BufferPool, len: usize) with pool);
132mop!(<S: AsFd> RecvMulti(fd: S, pool: &BufferPool, len: usize, flags: RecvFlags) with pool);
133mop!(<S: AsFd> RecvFromMulti(fd: S, pool: &BufferPool, flags: RecvFlags) with pool; RecvFromMultiResult);
134mop!(<S: AsFd> RecvMsgMulti(fd: S, pool: &BufferPool, control_len: usize, flags: RecvFlags) with pool; RecvMsgMultiResult);
135
136impl<S: AsFd> PollFirst for RecvManaged<S> {
137    fn poll_first(&mut self) {
138        match self.inner {
139            RecvManagedInner::Poll(ref mut i) => i.poll_first(),
140            RecvManagedInner::IoUring(ref mut i) => i.poll_first(),
141        }
142    }
143}
144
145impl<S: AsFd> PollFirst for RecvFromManaged<S> {
146    fn poll_first(&mut self) {
147        match self.inner {
148            RecvFromManagedInner::Poll(ref mut i) => i.poll_first(),
149            RecvFromManagedInner::IoUring(ref mut i) => i.poll_first(),
150        }
151    }
152}
153
154impl<C: IoBufMut, S: AsFd> PollFirst for RecvMsgManaged<C, S> {
155    fn poll_first(&mut self) {
156        match self.inner {
157            RecvMsgManagedInner::Poll(ref mut i) => i.poll_first(),
158            RecvMsgManagedInner::IoUring(ref mut i) => i.poll_first(),
159        }
160    }
161}
162
163enum RecvFromMultiResultInner {
164    Poll(fallback::RecvFromMultiResult),
165    IoUring(iour::RecvFromMultiResult),
166}
167
168/// Result of [`RecvFromMulti`].
169pub struct RecvFromMultiResult {
170    inner: RecvFromMultiResultInner,
171}
172
173impl From<fallback::RecvFromMultiResult> for RecvFromMultiResult {
174    fn from(result: fallback::RecvFromMultiResult) -> Self {
175        Self {
176            inner: RecvFromMultiResultInner::Poll(result),
177        }
178    }
179}
180
181impl From<iour::RecvFromMultiResult> for RecvFromMultiResult {
182    fn from(result: iour::RecvFromMultiResult) -> Self {
183        Self {
184            inner: RecvFromMultiResultInner::IoUring(result),
185        }
186    }
187}
188
189impl RecvFromMultiResult {
190    /// Create [`RecvFromMultiResult`] from a buffer received from
191    /// [`RecvFromMulti`]. It should be used for io-uring only.
192    ///
193    /// # Safety
194    ///
195    /// The buffer must be received from [`RecvFromMulti`] or have the same
196    /// format as the buffer received from [`RecvFromMulti`].
197    pub unsafe fn new(buffer: BufferRef) -> Self {
198        Self {
199            inner: RecvFromMultiResultInner::IoUring(unsafe {
200                iour::RecvFromMultiResult::new(buffer)
201            }),
202        }
203    }
204
205    /// Get the payload data.
206    pub fn data(&self) -> &[u8] {
207        match &self.inner {
208            RecvFromMultiResultInner::Poll(result) => result.data(),
209            RecvFromMultiResultInner::IoUring(result) => result.data(),
210        }
211    }
212
213    /// Get the source address if applicable.
214    pub fn addr(&self) -> Option<SockAddr> {
215        match &self.inner {
216            RecvFromMultiResultInner::Poll(result) => result.addr(),
217            RecvFromMultiResultInner::IoUring(result) => result.addr(),
218        }
219    }
220}
221
222impl IntoInner for RecvFromMultiResult {
223    type Inner = BufferRef;
224
225    fn into_inner(self) -> Self::Inner {
226        match self.inner {
227            RecvFromMultiResultInner::Poll(result) => result.into_inner(),
228            RecvFromMultiResultInner::IoUring(result) => result.into_inner(),
229        }
230    }
231}
232
233enum RecvMsgMultiResultInner {
234    Poll(fallback::RecvMsgMultiResult),
235    IoUring(iour::RecvMsgMultiResult),
236}
237
238/// Result of [`RecvMsgMulti`].
239pub struct RecvMsgMultiResult {
240    inner: RecvMsgMultiResultInner,
241}
242
243impl From<fallback::RecvMsgMultiResult> for RecvMsgMultiResult {
244    fn from(result: fallback::RecvMsgMultiResult) -> Self {
245        Self {
246            inner: RecvMsgMultiResultInner::Poll(result),
247        }
248    }
249}
250
251impl From<iour::RecvMsgMultiResult> for RecvMsgMultiResult {
252    fn from(result: iour::RecvMsgMultiResult) -> Self {
253        Self {
254            inner: RecvMsgMultiResultInner::IoUring(result),
255        }
256    }
257}
258
259impl RecvMsgMultiResult {
260    /// Create [`RecvMsgMultiResult`] from a buffer received from
261    /// [`RecvMsgMulti`]. It should be used for io-uring only.
262    ///
263    /// # Safety
264    ///
265    /// The buffer must be received from [`RecvMsgMulti`] or have the same
266    /// format as the buffer received from [`RecvMsgMulti`].
267    pub unsafe fn new(buffer: BufferRef, clen: usize) -> Self {
268        Self {
269            inner: RecvMsgMultiResultInner::IoUring(unsafe {
270                iour::RecvMsgMultiResult::new(buffer, clen)
271            }),
272        }
273    }
274
275    /// Get the payload data.
276    pub fn data(&self) -> &[u8] {
277        match &self.inner {
278            RecvMsgMultiResultInner::Poll(result) => result.data(),
279            RecvMsgMultiResultInner::IoUring(result) => result.data(),
280        }
281    }
282
283    /// Get the ancillary data.
284    pub fn ancillary(&self) -> &[u8] {
285        match &self.inner {
286            RecvMsgMultiResultInner::Poll(result) => result.ancillary(),
287            RecvMsgMultiResultInner::IoUring(result) => result.ancillary(),
288        }
289    }
290
291    /// Get the source address if applicable.
292    pub fn addr(&self) -> Option<SockAddr> {
293        match &self.inner {
294            RecvMsgMultiResultInner::Poll(result) => result.addr(),
295            RecvMsgMultiResultInner::IoUring(result) => result.addr(),
296        }
297    }
298
299    /// Get flags returned by `recvmsg`.
300    pub fn flags(&self) -> ReturnFlags {
301        match &self.inner {
302            RecvMsgMultiResultInner::Poll(result) => result.flags(),
303            RecvMsgMultiResultInner::IoUring(result) => result.flags(),
304        }
305    }
306}
307
308impl IntoInner for RecvMsgMultiResult {
309    type Inner = BufferRef;
310
311    fn into_inner(self) -> Self::Inner {
312        match self.inner {
313            RecvMsgMultiResultInner::Poll(result) => result.into_inner(),
314            RecvMsgMultiResultInner::IoUring(result) => result.into_inner(),
315        }
316    }
317}