1use async_trait::async_trait;
37use libp2prs_core::transport::{IListener, ITransport};
38use libp2prs_core::{
39 multiaddr::{protocol, protocol::Protocol, Multiaddr},
40 transport::TransportError,
41 Transport,
42};
43use libp2prs_runtime::net;
44use log::{error, trace};
45use std::{error, fmt, io};
46
/// Transport wrapper that holds another transport in `inner`.
///
/// The wrapper is generic over the wrapped value and can be cloned whenever
/// the wrapped value can.
#[derive(Clone)]
pub struct DnsConfig<T> {
    /// The wrapped transport that calls are delegated to.
    inner: T,
}
59
60impl<T> DnsConfig<T> {
61 pub fn new(inner: T) -> Self {
63 DnsConfig { inner }
64 }
65}
66
67impl<T> fmt::Debug for DnsConfig<T>
68where
69 T: fmt::Debug,
70{
71 fn fmt(&self, fmt: &mut fmt::Formatter<'_>) -> fmt::Result {
72 fmt.debug_tuple("DnsConfig").field(&self.inner).finish()
73 }
74}
75
76#[async_trait]
77impl<T> Transport for DnsConfig<T>
78where
79 T: Transport + Clone + 'static,
80{
81 type Output = T::Output;
82
83 fn listen_on(&mut self, addr: Multiaddr) -> Result<IListener<Self::Output>, TransportError> {
84 self.inner.listen_on(addr)
85 }
86
87 async fn dial(&mut self, addr: Multiaddr) -> Result<Self::Output, TransportError> {
88 let mut iter = addr.iter();
90 let proto = iter.find_map(|x| match x {
91 Protocol::Dns(name) => Some((name, true, true)),
92 Protocol::Dns4(name) => Some((name, true, false)),
93 Protocol::Dns6(name) => Some((name, false, true)),
94 _ => None,
95 });
96
97 let index = addr.iter().count() - iter.count() - 1;
98
99 let (name, dns4, dns6) = match proto {
102 Some((name, dns4, dns6)) => (name, dns4, dns6),
103 None => {
104 trace!("Pass-through address without DNS: {}", addr);
105 return self.inner.dial(addr).await;
106 }
107 };
108
109 let name = name.to_string();
110 let to_resolve = format!("{}:0", name);
111
112 let list = net::resolve_host(to_resolve).await.map_err(|_| {
113 error!("DNS resolver crashed");
114 TransportError::ResolveFail(name.clone())
115 })?;
116 let list = list.map(|s| s.ip()).collect::<Vec<_>>();
117
118 let outcome = list
119 .into_iter()
120 .filter_map(|addr| {
121 if (dns4 && addr.is_ipv4()) || (dns6 && addr.is_ipv6()) {
122 Some(Protocol::from(addr))
123 } else {
124 None
125 }
126 })
127 .next()
128 .ok_or_else(|| TransportError::ResolveFail(name.clone()))?;
129
130 if let Some(addr) = addr.replace(index, |_| Some(outcome.clone())) {
140 return self.inner.dial(addr).await;
141 }
142
143 Err(TransportError::ResolveFail(name))
144 }
145
146 fn box_clone(&self) -> ITransport<Self::Output> {
147 Box::new(self.clone())
148 }
149
150 fn protocols(&self) -> Vec<u32> {
151 let mut p = self.inner.protocols();
152 p.push(protocol::DNS);
153 p
154 }
155}
156
/// Error values describing DNS resolution failures.
#[derive(Debug)]
pub enum DnsErr {
    /// Resolution failed; carries the address that could not be resolved.
    ResolveFail(String),
    /// Resolution failed with an underlying I/O error.
    ResolveError {
        /// The domain name that was being resolved.
        domain_name: String,
        /// The I/O error behind the failure.
        error: io::Error,
    },
    /// The multiaddr is not supported for resolution.
    MultiaddrNotSupported,
}
167
168impl fmt::Display for DnsErr {
169 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
170 match self {
171 DnsErr::ResolveFail(addr) => write!(f, "Failed to resolve DNS address: {:?}", addr),
172 DnsErr::ResolveError { domain_name, error } => write!(f, "Failed to resolve DNS address: {:?}; {:?}", domain_name, error),
173 DnsErr::MultiaddrNotSupported => write!(f, "Resolve multiaddr not supported"),
174 }
175 }
176}
177
178impl error::Error for DnsErr {
179 fn source(&self) -> Option<&(dyn error::Error + 'static)> {
180 match self {
181 DnsErr::ResolveFail(_) => None,
182 DnsErr::ResolveError { error, .. } => Some(error),
183 DnsErr::MultiaddrNotSupported => None,
184 }
185 }
186}
187
#[cfg(test)]
mod tests {
    use super::DnsConfig;
    use futures::{AsyncReadExt, AsyncWriteExt};
    use libp2prs_core::transport::ListenerEvent;
    use libp2prs_core::Transport;
    use libp2prs_multiaddr::Multiaddr;
    use libp2prs_runtime::task;
    use libp2prs_tcp::TcpConfig;

    /// Listens on the IPv4 loopback address, dials `/dns4/localhost` and
    /// checks that the listener receives the bytes the dialer wrote.
    #[test]
    fn basic_resolve_v4() {
        task::block_on(async move {
            let bind_addr: Multiaddr = "/ip4/127.0.0.1/tcp/8384".parse().unwrap();
            let dial_addr: Multiaddr = "/dns4/localhost/tcp/8384".parse().unwrap();
            let mut server = DnsConfig::new(TcpConfig::default());
            let mut dialer = server.clone();

            let payload = b"Hello World";

            let mut listener = server.listen_on(bind_addr).unwrap();
            // Accepting side: take one connection, read the payload back and
            // compare it with what the dialer sends.
            let server_task = task::spawn(async move {
                let event = listener.accept().await.unwrap();
                let mut stream = match event {
                    ListenerEvent::Accepted(s) => s,
                    _ => panic!("unreachable"),
                };

                let mut received = vec![0; payload.len()];
                stream.read_exact(&mut received).await.expect("server read exact");

                assert_eq!(&payload[..], &received[..]);

                stream.close().await.expect("server close connection");
            });

            // Dialing side: connect through the DNS address and send the payload.
            let mut stream = dialer.dial(dial_addr).await.expect("client dial");
            stream.write_all(&payload[..]).await.expect("client write all");
            stream.close().await.expect("client close connection");

            server_task.await;
        });
    }

    /// Listens on the IPv6 loopback address, dials `/dns6/localhost` and
    /// checks that the listener receives the bytes the dialer wrote.
    #[test]
    fn basic_resolve_v6() {
        task::block_on(async move {
            let bind_addr: Multiaddr = "/ip6/::1/tcp/8384".parse().unwrap();
            let dial_addr: Multiaddr = "/dns6/localhost/tcp/8384".parse().unwrap();
            let mut server = DnsConfig::new(TcpConfig::default());
            let mut dialer = server.clone();

            let payload = b"Hello World";

            let mut listener = server.listen_on(bind_addr).expect("S listen");
            // Accepting side: take one connection, read the payload back and
            // compare it with what the dialer sends.
            let server_task = task::spawn(async move {
                let event = listener.accept().await.unwrap();
                let mut stream = match event {
                    ListenerEvent::Accepted(s) => s,
                    _ => panic!("unreachable"),
                };

                let mut received = vec![0; payload.len()];
                stream.read_exact(&mut received).await.expect("S read exact");

                assert_eq!(&payload[..], &received[..]);

                stream.close().await.expect("S close connection");
            });

            // Dialing side: connect through the DNS address and send the payload.
            let mut stream = dialer.dial(dial_addr).await.expect("C dial");
            stream.write_all(&payload[..]).await.expect("C write all");
            stream.close().await.expect("C close connection");

            server_task.await;
        });
    }
}