1use async_trait::async_trait;
2use std::net::SocketAddr;
3
4#[async_trait]
6pub trait UdpSocketFactory: Sized {
7 type Socket: UdpSocket;
8 type Error: std::error::Error;
9
10 async fn bind(&mut self, addr: &SocketAddr) -> Result<Self::Socket, Self::Error>;
12}
13
14#[async_trait]
17pub trait UdpSocket: Sized {
18 type Error: std::error::Error;
19
20 async fn enable_broadcast(&mut self) -> Result<(), Self::Error>;
22
23 async fn connect(&mut self, addr: &SocketAddr) -> Result<(), Self::Error>;
26
27 async fn send(&mut self, buf: &[u8]) -> Result<usize, Self::Error>;
30
31 async fn send_to(&mut self, buf: &[u8], addr: &SocketAddr) -> Result<usize, Self::Error>;
34
35 async fn recv(&mut self, but: &mut [u8]) -> Result<usize, Self::Error>;
38
39 async fn recv_from(&mut self, buf: &mut [u8]) -> Result<(usize, SocketAddr), Self::Error>;
42}
43
/// Socket factory picked by the enabled runtime feature.
///
/// This is the `tokio` flavour; it is selected whenever the `tokio` feature is
/// enabled, even if `async-std` is enabled as well.
#[cfg(feature = "tokio")]
pub type DefaultSocketFactory = TokioSocketFactory;

/// Socket factory picked by the enabled runtime feature.
///
/// This is the `async-std` flavour; it is selected only when `async-std` is
/// enabled and `tokio` is not.
#[cfg(all(not(feature = "tokio"), feature = "async-std"))]
pub type DefaultSocketFactory = AsyncStdSocketFactory;
49
/// [`UdpSocketFactory`] that hands out `tokio::net::UdpSocket`s.
///
/// The type holds no state, so it is zero-sized and trivially copyable.
#[cfg(feature = "tokio")]
#[derive(Debug, Default, Clone, Copy)]
pub struct TokioSocketFactory;

#[cfg(feature = "tokio")]
impl TokioSocketFactory {
    /// Creates a new factory; equivalent to `TokioSocketFactory::default()`.
    pub const fn new() -> Self {
        Self
    }
}
59
/// Factory implementation backed by tokio's UDP socket.
#[cfg(feature = "tokio")]
#[async_trait]
impl UdpSocketFactory for TokioSocketFactory {
    type Socket = tokio::net::UdpSocket;
    type Error = tokio::io::Error;

    /// Binds a `tokio::net::UdpSocket` to `addr` and returns it.
    async fn bind(&mut self, addr: &SocketAddr) -> Result<Self::Socket, Self::Error> {
        let socket = tokio::net::UdpSocket::bind(addr).await?;
        Ok(socket)
    }
}
70
/// Adapts `tokio::net::UdpSocket` to the runtime-neutral [`UdpSocket`] trait by
/// forwarding every operation to tokio's method for it.
#[cfg(feature = "tokio")]
#[async_trait]
impl UdpSocket for tokio::net::UdpSocket {
    type Error = tokio::io::Error;

    // NOTE(review): the calls below use path syntax on the concrete type
    // instead of `self.method(..)`. Presumably method-call syntax could
    // resolve back to the trait method of the same name — confirm before
    // changing the call form.

    /// Enables broadcast by calling tokio's `set_broadcast(true)`.
    async fn enable_broadcast(&mut self) -> Result<(), Self::Error> {
        tokio::net::UdpSocket::set_broadcast(self, true)
    }

    /// Forwards to tokio's `connect`.
    async fn connect(&mut self, addr: &SocketAddr) -> Result<(), Self::Error> {
        tokio::net::UdpSocket::connect(self, addr).await
    }

    /// Forwards to tokio's `send`.
    async fn send(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
        tokio::net::UdpSocket::send(self, buf).await
    }

    /// Forwards to tokio's `send_to`.
    async fn send_to(&mut self, buf: &[u8], addr: &SocketAddr) -> Result<usize, Self::Error> {
        tokio::net::UdpSocket::send_to(self, buf, addr).await
    }

    /// Forwards to tokio's `recv`.
    async fn recv(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
        tokio::net::UdpSocket::recv(self, buf).await
    }

    /// Forwards to tokio's `recv_from`.
    async fn recv_from(&mut self, buf: &mut [u8]) -> Result<(usize, SocketAddr), Self::Error> {
        tokio::net::UdpSocket::recv_from(self, buf).await
    }
}
100
/// [`UdpSocketFactory`] that hands out `async_std::net::UdpSocket`s.
///
/// The type holds no state, so it is zero-sized and trivially copyable.
#[cfg(feature = "async-std")]
#[derive(Debug, Default, Clone, Copy)]
pub struct AsyncStdSocketFactory;

#[cfg(feature = "async-std")]
impl AsyncStdSocketFactory {
    /// Creates a new factory; equivalent to `AsyncStdSocketFactory::default()`.
    pub const fn new() -> Self {
        Self
    }
}
110
/// Factory implementation backed by async-std's UDP socket.
#[cfg(feature = "async-std")]
#[async_trait]
impl UdpSocketFactory for AsyncStdSocketFactory {
    type Socket = async_std::net::UdpSocket;
    type Error = async_std::io::Error;

    /// Binds an `async_std::net::UdpSocket` to `addr` and returns it.
    async fn bind(&mut self, addr: &SocketAddr) -> Result<Self::Socket, Self::Error> {
        let socket = async_std::net::UdpSocket::bind(addr).await?;
        Ok(socket)
    }
}
121
/// Adapts `async_std::net::UdpSocket` to the runtime-neutral [`UdpSocket`]
/// trait by forwarding every operation to async-std's method for it.
#[cfg(feature = "async-std")]
#[async_trait]
impl UdpSocket for async_std::net::UdpSocket {
    type Error = async_std::io::Error;

    // NOTE(review): calls use `Self::method(self, ..)` path syntax rather than
    // `self.method(..)`. Presumably method-call syntax could resolve back to
    // the trait method of the same name — confirm before changing the form.

    /// Enables broadcast by calling async-std's `set_broadcast(true)`.
    async fn enable_broadcast(&mut self) -> Result<(), Self::Error> {
        Self::set_broadcast(self, true)
    }

    /// Forwards to async-std's `connect`.
    async fn connect(&mut self, addr: &SocketAddr) -> Result<(), Self::Error> {
        Self::connect(self, addr).await
    }

    /// Forwards to async-std's `send`.
    async fn send(&mut self, buf: &[u8]) -> Result<usize, Self::Error> {
        Self::send(self, buf).await
    }

    /// Forwards to async-std's `send_to`.
    async fn send_to(&mut self, buf: &[u8], addr: &SocketAddr) -> Result<usize, Self::Error> {
        Self::send_to(self, buf, addr).await
    }

    /// Forwards to async-std's `recv`.
    async fn recv(&mut self, buf: &mut [u8]) -> Result<usize, Self::Error> {
        Self::recv(self, buf).await
    }

    /// Forwards to async-std's `recv_from`.
    async fn recv_from(&mut self, buf: &mut [u8]) -> Result<(usize, SocketAddr), Self::Error> {
        // async-std already reports the sender as a `std::net::SocketAddr`, so
        // the result is passed through untouched. The previous round-trip via
        // `ToSocketAddrs` added two `unwrap()` panic paths for no gain.
        Self::recv_from(self, buf).await
    }
}