// dns_forward_over_tcp/server.rs
use async_trait::async_trait;
use dns_parser::Packet;
use flume::{unbounded, Receiver};
use log::{error, info, warn};
use std::cmp;
use std::io::{Error, ErrorKind};
use std::net::SocketAddr;
use std::sync::Arc;
use std::time::Duration;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::{TcpStream, UdpSocket};
use tokio::time::sleep;

13#[async_trait]
14pub trait RecordCallback<T>: Send + Sync {
15 async fn request(&self, res: &Packet<'_>) -> (bool, Option<T>);
16 async fn response(&self, req: Option<&Packet<'_>>, context: Option<T>);
17}
18
19pub struct DnsServer<T> {
20 udp_socket: Arc<UdpSocket>,
21 tcp_server: Option<TcpStream>,
22
23 upstream: String,
24 callback: Arc<Box<dyn RecordCallback<T>>>,
25}
26
27impl<T: 'static + std::marker::Send> DnsServer<T> {
28 pub async fn run(
29 port: Option<String>,
30 upstream: Option<String>,
31 thread_num: Option<usize>,
32 callback: Box<dyn RecordCallback<T>>,
33 ) -> Result<(), Error> {
34 let bind_with_port = if let Some(port) = port {
35 if port.contains(":") {
36 port
37 } else {
38 String::from(format!("0.0.0.0:{}", port))
39 }
40 } else {
41 String::from("127.0.0.1:5353")
42 };
43
44 let upstream = if let Some(upstream) = upstream {
45 upstream
46 } else {
47 String::from("8.8.8.8:53")
48 };
49
50 let udp_socket = UdpSocket::bind(bind_with_port).await?;
51
52 let udp_server = Arc::new(udp_socket);
53 let udp_socket = udp_server.clone();
54
55 let (sender, receiver) = unbounded();
56
57 tokio::spawn(async move {
58 loop {
59 let mut buff = [0; 1024];
60 let rr = udp_server.recv_from(&mut buff).await;
61 if rr.is_err() {
62 warn!("udp recv error. {:?}", rr.err());
63 continue;
64 }
65 if let Some((size, src_addr)) = rr.ok() {
66 let _ = sender.send_async((buff[..size].to_vec(), src_addr)).await;
67 }
68 }
69 });
70
71 let mut handles = vec![];
72 let thread_num = if let Some(thread_num) = thread_num {
73 cmp::min(thread_num, num_cpus::get())
74 } else {
75 cmp::min(2, num_cpus::get())
76 };
77
78 let callback = Arc::new(callback);
79
80 for _ in 0..thread_num {
81 let udp_socket = udp_socket.clone();
82 let receiver = receiver.clone();
83
84 let callback = callback.clone();
85
86 let mut s = DnsServer::<T> {
87 udp_socket,
88 upstream: upstream.clone(),
89
90 tcp_server: None,
91 callback,
92 };
93
94 handles.push(tokio::spawn(async move {
95 Self::process(&mut s, receiver).await;
96 }));
97 }
98
99 for h in handles {
100 let _ = h.await;
101 }
102
103 Ok(())
104 }
105
106 async fn process(dns_server: &mut DnsServer<T>, receiver: Receiver<(Vec<u8>, SocketAddr)>) {
107 loop {
108 let rr = receiver.recv_async().await;
109 if rr.is_err() {
110 continue;
111 }
112
113 let (buff, src_addr) = rr.ok().unwrap();
114
115 let dns_res_packet = dns_parser::Packet::parse(&buff);
116 if dns_res_packet.is_err() {
117 warn!(
118 "parse dns packet failed. {:?}",
119 dns_res_packet.as_ref().err()
120 );
121 }
122
123 let callback = dns_server.callback.clone();
124
125 let mut res_context = None;
126
127 if let Ok(dns_res_packet) = dns_res_packet {
128 let (pass, context) = callback.request(&dns_res_packet).await;
129 res_context = context;
130
131 if !pass {
132 if let Ok(record) = dns_parser::Builder::new_query(
133 dns_res_packet.header.id,
134 dns_res_packet.header.recursion_available,
135 )
136 .build()
137 {
138 let _ = dns_server.udp_socket.send_to(&record, src_addr).await;
139 }
140
141 continue;
142 }
143 }
144
145 loop {
146 let req_buff = dns_server.forward(&buff).await;
147
148 if req_buff.is_err() {
149 let err = req_buff.err().unwrap();
150
151 match err.kind() {
152 ErrorKind::BrokenPipe | ErrorKind::UnexpectedEof => {}
153 _ => {
154 warn!("{}", err.to_string());
155 }
156 }
157
158 dns_server.tcp_server = None;
159 continue;
160 }
161
162 let req_buff = req_buff.ok().unwrap();
163
164 let dns_req_packet = dns_parser::Packet::parse(&req_buff);
165 if dns_req_packet.is_err() {
166 warn!(
167 "parse dns packet failed. {:?}",
168 dns_req_packet.as_ref().err()
169 );
170 }
171
172 callback
173 .response(dns_req_packet.ok().as_ref(), res_context)
174 .await;
175
176 let _ = dns_server.udp_socket.send_to(&req_buff, src_addr).await;
177 break;
178 }
179 }
180 }
181
182 async fn forward(&mut self, data: &[u8]) -> Result<Vec<u8>, std::io::Error> {
183 if self.tcp_server.is_none() {
184 self.connect_remote_server().await;
185 }
186
187 let tcp_server = self.tcp_server.as_mut().unwrap();
188
189 let size = data.len() as u16;
190 let r = tcp_server.write(&size.to_be_bytes()).await?;
191 if r < size_of::<u16>() {
192 return Err(Error::new(
193 ErrorKind::Other,
194 format!("forward data failed. {}", size),
195 ));
196 }
197
198 let r = tcp_server.write(data).await?;
199 if r < data.len() {
200 return Err(Error::new(
201 ErrorKind::Other,
202 format!("forward data failed. {}", size),
203 ));
204 }
205
206 let size = tcp_server.read_u16().await?;
207
208 let mut buff = vec![0 as u8; size as usize];
209 let size = tcp_server.read_exact(&mut buff).await?;
210
211 if size < size_of_val(&buff) {
212 return Err(Error::new(ErrorKind::Other, "tcp read data failed."));
213 }
214
215 return Ok(buff);
216 }
217
218 async fn connect_remote_server(&mut self) {
219 loop {
220 if let Ok(s) = TcpStream::connect(&self.upstream).await {
221 self.tcp_server = Some(s);
222 break;
223 }
224
225 warn!("connect {} failed. try again later.", &self.upstream);
226 std::thread::sleep(Duration::from_secs(1));
227 }
228 }
229}