dns_forward_over_tcp/
server.rs

1use async_trait::async_trait;
2use dns_parser::Packet;
3use flume::{unbounded, Receiver};
4use log::{error, info, warn};
5use std::cmp;
6use std::io::{Error, ErrorKind};
7use std::net::SocketAddr;
8use std::sync::Arc;
9use std::time::Duration;
10use tokio::io::{AsyncReadExt, AsyncWriteExt};
11use tokio::net::{TcpStream, UdpSocket};
12
13#[async_trait]
14pub trait RecordCallback<T>: Send + Sync {
15    async fn request(&self, res: &Packet<'_>) -> (bool, Option<T>);
16    async fn response(&self, req: Option<&Packet<'_>>, context: Option<T>);
17}
18
/// Per-worker forwarding state; `run` builds one of these for every worker task it spawns.
pub struct DnsServer<T> {
    /// UDP socket bound in `run` and shared (via `Arc`) by every worker; client replies are
    /// sent back on it.
    udp_socket: Arc<UdpSocket>,
    /// TCP connection to `upstream`. `None` until the first query is forwarded, and reset to
    /// `None` after a failed forward so the next attempt reconnects.
    // NOTE(review): despite the name this is a client connection to the upstream server.
    tcp_server: Option<TcpStream>,

    /// Upstream DNS server address handed to `TcpStream::connect`.
    upstream: String,
    /// Query/reply hooks, shared by all workers.
    callback: Arc<Box<dyn RecordCallback<T>>>,
}
26
27impl<T: 'static + std::marker::Send> DnsServer<T> {
28    pub async fn run(
29        port: Option<String>,
30        upstream: Option<String>,
31        thread_num: Option<usize>,
32        callback: Box<dyn RecordCallback<T>>,
33    ) -> Result<(), Error> {
34        let bind_with_port = if let Some(port) = port {
35            if port.contains(":") {
36                port
37            } else {
38                String::from(format!("0.0.0.0:{}", port))
39            }
40        } else {
41            String::from("127.0.0.1:5353")
42        };
43
44        let upstream = if let Some(upstream) = upstream {
45            upstream
46        } else {
47            String::from("8.8.8.8:53")
48        };
49
50        let udp_socket = UdpSocket::bind(bind_with_port).await?;
51
52        let udp_server = Arc::new(udp_socket);
53        let udp_socket = udp_server.clone();
54
55        let (sender, receiver) = unbounded();
56
57        tokio::spawn(async move {
58            loop {
59                let mut buff = [0; 1024];
60                let rr = udp_server.recv_from(&mut buff).await;
61                if rr.is_err() {
62                    warn!("udp recv error. {:?}", rr.err());
63                    continue;
64                }
65                if let Some((size, src_addr)) = rr.ok() {
66                    let _ = sender.send_async((buff[..size].to_vec(), src_addr)).await;
67                }
68            }
69        });
70
71        let mut handles = vec![];
72        let thread_num = if let Some(thread_num) = thread_num {
73            cmp::min(thread_num, num_cpus::get())
74        } else {
75            cmp::min(2, num_cpus::get())
76        };
77
78        let callback = Arc::new(callback);
79
80        for _ in 0..thread_num {
81            let udp_socket = udp_socket.clone();
82            let receiver = receiver.clone();
83
84            let callback = callback.clone();
85
86            let mut s = DnsServer::<T> {
87                udp_socket,
88                upstream: upstream.clone(),
89
90                tcp_server: None,
91                callback,
92            };
93
94            handles.push(tokio::spawn(async move {
95                Self::process(&mut s, receiver).await;
96            }));
97        }
98
99        for h in handles {
100            let _ = h.await;
101        }
102
103        Ok(())
104    }
105
106    async fn process(dns_server: &mut DnsServer<T>, receiver: Receiver<(Vec<u8>, SocketAddr)>) {
107        loop {
108            let rr = receiver.recv_async().await;
109            if rr.is_err() {
110                continue;
111            }
112
113            let (buff, src_addr) = rr.ok().unwrap();
114
115            let dns_res_packet = dns_parser::Packet::parse(&buff);
116            if dns_res_packet.is_err() {
117                warn!(
118                    "parse dns packet failed. {:?}",
119                    dns_res_packet.as_ref().err()
120                );
121            }
122
123            let callback = dns_server.callback.clone();
124
125            let mut res_context = None;
126
127            if let Ok(dns_res_packet) = dns_res_packet {
128                let (pass, context) = callback.request(&dns_res_packet).await;
129                res_context = context;
130
131                if !pass {
132                    if let Ok(record) = dns_parser::Builder::new_query(
133                        dns_res_packet.header.id,
134                        dns_res_packet.header.recursion_available,
135                    )
136                    .build()
137                    {
138                        let _ = dns_server.udp_socket.send_to(&record, src_addr).await;
139                    }
140
141                    continue;
142                }
143            }
144
145            loop {
146                let req_buff = dns_server.forward(&buff).await;
147
148                if req_buff.is_err() {
149                    let err = req_buff.err().unwrap();
150
151                    match err.kind() {
152                        ErrorKind::BrokenPipe | ErrorKind::UnexpectedEof => {}
153                        _ => {
154                            warn!("{}", err.to_string());
155                        }
156                    }
157
158                    dns_server.tcp_server = None;
159                    continue;
160                }
161
162                let req_buff = req_buff.ok().unwrap();
163
164                let dns_req_packet = dns_parser::Packet::parse(&req_buff);
165                if dns_req_packet.is_err() {
166                    warn!(
167                        "parse dns packet failed. {:?}",
168                        dns_req_packet.as_ref().err()
169                    );
170                }
171
172                callback
173                    .response(dns_req_packet.ok().as_ref(), res_context)
174                    .await;
175
176                let _ = dns_server.udp_socket.send_to(&req_buff, src_addr).await;
177                break;
178            }
179        }
180    }
181
182    async fn forward(&mut self, data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
183        if self.tcp_server.is_none() {
184            self.connect_remote_server().await;
185        }
186
187        let tcp_server = self.tcp_server.as_mut().unwrap();
188
189        let size = data.len() as u16;
190        let r = tcp_server.write(&size.to_be_bytes()).await?;
191        if r < size_of::<u16>() {
192            return Err(Error::new(
193                ErrorKind::Other,
194                format!("forward data failed. {}", size),
195            ));
196        }
197
198        let r = tcp_server.write(data).await?;
199        if r < data.len() {
200            return Err(Error::new(
201                ErrorKind::Other,
202                format!("forward data failed. {}", size),
203            ));
204        }
205
206        let size = tcp_server.read_u16().await?;
207
208        let mut buff = vec![0 as u8; size as usize];
209        let size = tcp_server.read_exact(&mut buff).await?;
210
211        if size < size_of_val(&buff) {
212            return Err(Error::new(ErrorKind::Other, "tcp read data failed."));
213        }
214
215        return Ok(buff);
216    }
217
218    async fn connect_remote_server(&mut self) {
219        loop {
220            if let Ok(s) = TcpStream::connect(&self.upstream).await {
221                self.tcp_server = Some(s);
222                break;
223            }
224
225            warn!("connect {} failed. try again later.", &self.upstream);
226            std::thread::sleep(Duration::from_secs(1));
227        }
228    }
229}