1use crate::{
7 sys::{DmaBuffer, Source, SourceType},
8 ByteSliceMutExt, Reactor,
9};
10use nix::sys::socket::{MsgFlags, SockaddrLike};
11use std::{
12 cell::Cell,
13 io,
14 os::unix::io::{AsRawFd, FromRawFd, RawFd},
15 rc::{Rc, Weak},
16 time::Duration,
17};
18
/// Initial value stored in `GlommioDatagram::rx_buf_size` when a datagram is
/// built from a `socket2::Socket`.
// NOTE(review): presumably a size in bytes — confirm against the code that reads
// `rx_buf_size`.
const DEFAULT_BUFFER_SIZE: usize = 8192;

/// Local shorthand for `crate::Result` with `()` as its second type parameter.
// NOTE(review): the meaning of the second parameter is defined in the crate
// root and is not visible here — confirm before relying on it.
type Result<T> = crate::Result<T, ()>;
22
/// Datagram socket wrapper that pairs an OS-level socket of type `S` with the
/// reactor used to drive its asynchronous I/O.
///
/// Send and receive operations first attempt a direct call through the
/// `super::yolo_*` helpers; when such an attempt yields nothing, the
/// operation is submitted to the reactor instead.
#[derive(Debug)]
pub struct GlommioDatagram<S: AsRawFd + FromRawFd + From<socket2::Socket>> {
    /// Weak handle to the reactor. It is upgraded (and unwrapped) every time
    /// an operation has to be submitted through the reactor.
    pub(crate) reactor: Weak<Reactor>,
    /// The wrapped socket. The methods of this type only use its raw fd.
    pub(crate) socket: S,

    /// Timeout handed to the reactor for send operations; `None` passes no
    /// timeout.
    pub(crate) write_timeout: Cell<Option<Duration>>,
    /// Timeout handed to the reactor for receive operations; `None` passes no
    /// timeout.
    pub(crate) read_timeout: Cell<Option<Duration>>,

    /// When `true`, send paths try the direct `super::yolo_send*` helper
    /// before going through the reactor. Cleared when that attempt returns
    /// `None`; set again after a reactor-driven send completes.
    pub(crate) tx_yolo: Cell<bool>,
    /// Receive-side counterpart of `tx_yolo`: gates the direct
    /// `super::yolo_recv*` attempt and is set again after a reactor-driven
    /// receive completes.
    pub(crate) rx_yolo: Cell<bool>,

    /// Initialised to `DEFAULT_BUFFER_SIZE`.
    // NOTE(review): not read by any method visible in this part of the file —
    // confirm where (or whether) it is consumed.
    pub(crate) rx_buf_size: usize,
}
43
44impl<S: AsRawFd + FromRawFd + From<socket2::Socket>> From<socket2::Socket> for GlommioDatagram<S> {
45 fn from(socket: socket2::Socket) -> GlommioDatagram<S> {
46 let socket = socket.into();
47 GlommioDatagram {
48 reactor: Rc::downgrade(&crate::executor().reactor()),
49 socket,
50 tx_yolo: Cell::new(true),
51 rx_yolo: Cell::new(true),
52 write_timeout: Cell::new(None),
53 read_timeout: Cell::new(None),
54 rx_buf_size: DEFAULT_BUFFER_SIZE,
55 }
56 }
57}
58
59impl<S: AsRawFd + FromRawFd + From<socket2::Socket>> AsRawFd for GlommioDatagram<S> {
60 fn as_raw_fd(&self) -> RawFd {
61 self.socket.as_raw_fd()
62 }
63}
64
65impl<S: FromRawFd + AsRawFd + From<socket2::Socket>> FromRawFd for GlommioDatagram<S> {
66 unsafe fn from_raw_fd(fd: RawFd) -> Self {
67 let socket = socket2::Socket::from_raw_fd(fd);
68 GlommioDatagram::from(socket)
69 }
70}
71
72impl<S: AsRawFd + FromRawFd + From<socket2::Socket>> GlommioDatagram<S> {
73 async fn consume_receive_buffer(&self, source: Source, buf: &mut [u8]) -> io::Result<usize> {
74 let sz = source.collect_rw().await?;
75 let src = match source.extract_source_type() {
76 SourceType::SockRecv(mut buf) => {
77 let mut buf = buf.take().unwrap();
78 buf.trim_to_size(sz);
79 buf
80 }
81 _ => unreachable!(),
82 };
83 buf[0..sz].copy_from_slice(&src.as_bytes()[0..sz]);
84 self.rx_yolo.set(true);
85 Ok(sz)
86 }
87
88 pub(crate) async fn peek(&self, buf: &mut [u8]) -> io::Result<usize> {
89 let source = self.reactor.upgrade().unwrap().recv(
90 self.socket.as_raw_fd(),
91 buf.len(),
92 MsgFlags::MSG_PEEK,
93 );
94
95 self.consume_receive_buffer(source, buf).await
96 }
97
98 pub(crate) async fn peek_from<T: SockaddrLike>(
99 &self,
100 buf: &mut [u8],
101 ) -> io::Result<(usize, T)> {
102 match self.yolo_recvmsg(buf, MsgFlags::MSG_PEEK) {
103 Some(res) => res,
104 None => self.recv_from_blocking(buf, MsgFlags::MSG_PEEK).await,
105 }
106 }
107
108 pub(crate) async fn recv(&self, buf: &mut [u8]) -> io::Result<usize> {
109 match self.yolo_rx(buf) {
110 Some(x) => x,
111 None => {
112 let source = self.reactor.upgrade().unwrap().rushed_recv(
113 self.socket.as_raw_fd(),
114 buf.len(),
115 self.read_timeout.get(),
116 )?;
117 self.consume_receive_buffer(source, buf).await
118 }
119 }
120 }
121
122 pub(crate) async fn recv_from_blocking<T: SockaddrLike>(
123 &self,
124 buf: &mut [u8],
125 flags: MsgFlags,
126 ) -> io::Result<(usize, T)> {
127 let source = self.reactor.upgrade().unwrap().rushed_recvmsg(
128 self.socket.as_raw_fd(),
129 buf.len(),
130 flags,
131 self.read_timeout.get(),
132 )?;
133 let sz = source.collect_rw().await?;
134 match source.extract_source_type() {
135 SourceType::SockRecvMsg(mut src, _iov, hdr, addr) => {
136 let mut src = src.take().unwrap();
137 src.trim_to_size(sz);
138 buf[0..sz].copy_from_slice(&src.as_bytes()[0..sz]);
139 let addr = unsafe {
140 T::from_raw(addr.as_ptr() as *const _, Some(hdr.msg_namelen)).unwrap()
141 };
142 self.rx_yolo.set(true);
143 Ok((sz, addr))
144 }
145 _ => unreachable!(),
146 }
147 }
148
149 pub(crate) fn set_write_timeout(&self, dur: Option<Duration>) -> Result<()> {
150 if let Some(dur) = dur.as_ref() {
151 if dur.as_nanos() == 0 {
152 return Err(io::Error::from_raw_os_error(libc::EINVAL).into());
153 }
154 }
155 self.write_timeout.set(dur);
156 Ok(())
157 }
158
159 pub(crate) fn set_read_timeout(&self, dur: Option<Duration>) -> Result<()> {
160 if let Some(dur) = dur.as_ref() {
161 if dur.as_nanos() == 0 {
162 return Err(io::Error::from_raw_os_error(libc::EINVAL).into());
163 }
164 }
165 self.read_timeout.set(dur);
166 Ok(())
167 }
168
169 pub(crate) fn write_timeout(&self) -> Option<Duration> {
170 self.write_timeout.get()
171 }
172
173 pub(crate) fn read_timeout(&self) -> Option<Duration> {
174 self.read_timeout.get()
175 }
176
177 pub(crate) async fn recv_from<T: SockaddrLike>(
178 &self,
179 buf: &mut [u8],
180 ) -> io::Result<(usize, T)> {
181 match self.yolo_recvmsg(buf, MsgFlags::empty()) {
182 Some(res) => res,
183 None => self.recv_from_blocking(buf, MsgFlags::empty()).await,
184 }
185 }
186
187 pub(crate) async fn send_to_blocking(
188 &self,
189 buf: &[u8],
190 sockaddr: impl nix::sys::socket::SockaddrLike,
191 ) -> io::Result<usize> {
192 let mut dma = self.allocate_buffer(buf.len());
193 assert_eq!(dma.write_at(0, buf), buf.len());
194 let source = self.reactor.upgrade().unwrap().rushed_sendmsg(
195 self.socket.as_raw_fd(),
196 dma,
197 sockaddr,
198 self.write_timeout.get(),
199 )?;
200 let ret = source.collect_rw().await?;
201 self.tx_yolo.set(true);
202 Ok(ret)
203 }
204
205 pub(crate) async fn send_to(
206 &self,
207 buf: &[u8],
208 addr: impl nix::sys::socket::SockaddrLike,
209 ) -> io::Result<usize> {
210 match self.yolo_sendmsg(buf, &addr) {
211 Some(res) => res,
212 None => self.send_to_blocking(buf, addr).await,
213 }
214 }
215
216 pub(crate) async fn send(&self, buf: &[u8]) -> io::Result<usize> {
217 match self.yolo_tx(buf) {
218 Some(r) => r,
219 None => {
220 let mut dma = self.allocate_buffer(buf.len());
221 assert_eq!(dma.write_at(0, buf), buf.len());
222 let source = self.reactor.upgrade().unwrap().rushed_send(
223 self.socket.as_raw_fd(),
224 dma,
225 self.write_timeout.get(),
226 )?;
227 let ret = source.collect_rw().await?;
228 self.tx_yolo.set(true);
229 Ok(ret)
230 }
231 }
232 }
233
234 fn allocate_buffer(&self, size: usize) -> DmaBuffer {
235 self.reactor.upgrade().unwrap().alloc_dma_buffer(size)
236 }
237
238 fn yolo_rx(&self, buf: &mut [u8]) -> Option<io::Result<usize>> {
239 if self.rx_yolo.get() {
240 super::yolo_recv(self.socket.as_raw_fd(), buf)
241 } else {
242 None
243 }
244 .or_else(|| {
245 self.rx_yolo.set(false);
246 None
247 })
248 }
249
250 fn yolo_recvmsg<T: SockaddrLike>(
251 &self,
252 buf: &mut [u8],
253 flags: MsgFlags,
254 ) -> Option<io::Result<(usize, T)>> {
255 if self.rx_yolo.get() {
256 super::yolo_recvmsg(self.socket.as_raw_fd(), buf, flags)
257 } else {
258 None
259 }
260 .or_else(|| {
261 self.rx_yolo.set(false);
262 None
263 })
264 }
265
266 fn yolo_tx(&self, buf: &[u8]) -> Option<io::Result<usize>> {
267 if self.tx_yolo.get() {
268 super::yolo_send(self.socket.as_raw_fd(), buf)
269 } else {
270 None
271 }
272 .or_else(|| {
273 self.tx_yolo.set(false);
274 None
275 })
276 }
277
278 fn yolo_sendmsg(
279 &self,
280 buf: &[u8],
281 addr: &impl nix::sys::socket::SockaddrLike,
282 ) -> Option<io::Result<usize>> {
283 if self.tx_yolo.get() {
284 super::yolo_sendmsg(self.socket.as_raw_fd(), buf, addr)
285 } else {
286 None
287 }
288 .or_else(|| {
289 self.tx_yolo.set(false);
290 None
291 })
292 }
293}