shadowsocks_service/local/dns/
client_cache.rs1#[cfg(unix)]
4use std::path::Path;
5use std::{
6 collections::{HashMap, VecDeque, hash_map::Entry},
7 future::Future,
8 io,
9 net::SocketAddr,
10 time::Duration,
11};
12
13use hickory_resolver::proto::{ProtoError, op::Message};
14use log::{debug, trace};
15use tokio::sync::Mutex;
16
17use shadowsocks::{config::ServerConfig, net::ConnectOpts, relay::socks5::Address};
18
19use crate::local::context::ServiceContext;
20
21use super::upstream::DnsClient;
22
23#[derive(Clone, Debug, Hash, Eq, PartialEq, PartialOrd, Ord)]
24enum DnsClientKey {
25 TcpLocal(SocketAddr),
26 UdpLocal(SocketAddr),
27 TcpRemote(Address),
28 UdpRemote(Address),
29}
30
31pub struct DnsClientCache {
32 cache: Mutex<HashMap<DnsClientKey, VecDeque<DnsClient>>>,
33 timeout: Duration,
34 retry_count: usize,
35 max_client_per_addr: usize,
36}
37
38impl DnsClientCache {
39 pub fn new(max_client_per_addr: usize) -> Self {
40 Self {
41 cache: Mutex::new(HashMap::new()),
42 timeout: Duration::from_secs(5),
43 retry_count: 1,
44 max_client_per_addr,
45 }
46 }
47
48 pub async fn lookup_local(
49 &self,
50 ns: SocketAddr,
51 msg: Message,
52 connect_opts: &ConnectOpts,
53 is_udp: bool,
54 ) -> Result<Message, ProtoError> {
55 let key = match is_udp {
56 true => DnsClientKey::UdpLocal(ns),
57 false => DnsClientKey::TcpLocal(ns),
58 };
59 self.lookup_dns(&key, msg, Some(connect_opts), None, None).await
60 }
61
62 pub async fn lookup_remote(
63 &self,
64 context: &ServiceContext,
65 svr_cfg: &ServerConfig,
66 ns: &Address,
67 msg: Message,
68 is_udp: bool,
69 ) -> Result<Message, ProtoError> {
70 let key = match is_udp {
71 true => DnsClientKey::UdpRemote(ns.clone()),
72 false => DnsClientKey::TcpRemote(ns.clone()),
73 };
74 self.lookup_dns(&key, msg, None, Some(context), Some(svr_cfg)).await
75 }
76
77 #[cfg(unix)]
78 pub async fn lookup_unix_stream<P: AsRef<Path>>(&self, ns: &P, msg: Message) -> Result<Message, ProtoError> {
79 let mut last_err = None;
80
81 for _ in 0..self.retry_count {
82 let mut client = match DnsClient::connect_unix_stream(ns).await {
90 Ok(client) => client,
91 Err(err) => {
92 last_err = Some(From::from(err));
93 continue;
94 }
95 };
96
97 let res = match client.lookup_timeout(msg.clone(), self.timeout).await {
98 Ok(msg) => msg,
99 Err(error) => {
100 last_err = Some(error);
101 continue;
102 }
103 };
104 return Ok(res);
105 }
106 Err(last_err.unwrap())
107 }
108
109 async fn lookup_dns(
110 &self,
111 dck: &DnsClientKey,
112 msg: Message,
113 connect_opts: Option<&ConnectOpts>,
114 context: Option<&ServiceContext>,
115 svr_cfg: Option<&ServerConfig>,
116 ) -> Result<Message, ProtoError> {
117 let mut last_err = None;
118 for _ in 0..self.retry_count {
119 let create_fn = async {
120 match dck {
121 DnsClientKey::TcpLocal(tcp_l) => {
122 let connect_opts = connect_opts.expect("connect options is required for local DNS");
123 DnsClient::connect_tcp_local(*tcp_l, connect_opts).await
124 }
125 DnsClientKey::UdpLocal(udp_l) => {
126 let connect_opts = connect_opts.expect("connect options is required for local DNS");
127 DnsClient::connect_udp_local(*udp_l, connect_opts).await
128 }
129 DnsClientKey::TcpRemote(tcp_l) => {
130 let context = context.expect("context is required for remote DNS");
131 let svr_cfg = svr_cfg.expect("server config is required for remote DNS");
132
133 DnsClient::connect_tcp_remote(
134 context.context(),
135 svr_cfg,
136 tcp_l,
137 context.connect_opts_ref(),
138 context.flow_stat(),
139 )
140 .await
141 }
142 DnsClientKey::UdpRemote(udp_l) => {
143 let context = context.expect("context is required for remote DNS");
144 let svr_cfg = svr_cfg.expect("server config is required for remote DNS");
145
146 DnsClient::connect_udp_remote(
147 context.context(),
148 svr_cfg,
149 udp_l.clone(),
150 context.connect_opts_ref(),
151 context.flow_stat(),
152 )
153 .await
154 }
155 }
156 };
157 match self.get_client_or_create(dck, create_fn).await {
158 Ok(mut client) => match client.lookup_timeout(msg.clone(), self.timeout).await {
159 Ok(msg) => {
160 self.save_client(dck.clone(), client).await;
161 return Ok(msg);
162 }
163 Err(err) => {
164 last_err = Some(err);
165 continue;
166 }
167 },
168 Err(err) => {
169 last_err = Some(From::from(err));
170 continue;
171 }
172 }
173 }
174 Err(last_err.unwrap())
175 }
176
177 async fn get_client_or_create<C>(&self, key: &DnsClientKey, create_fn: C) -> io::Result<DnsClient>
178 where
179 C: Future<Output = io::Result<DnsClient>>,
180 {
181 if let Some(q) = self.cache.lock().await.get_mut(key) {
183 while let Some(mut c) = q.pop_front() {
184 trace!("take cached DNS client for {:?}", key);
185 if !c.check_connected().await {
186 debug!("cached DNS client for {:?} is lost", key);
187 continue;
188 }
189 return Ok(c);
190 }
191 }
192 trace!("creating connection to DNS server {:?}", key);
193
194 create_fn.await
196 }
197
198 async fn save_client(&self, key: DnsClientKey, client: DnsClient) {
199 match self.cache.lock().await.entry(key) {
200 Entry::Occupied(occ) => {
201 let q = occ.into_mut();
202 q.push_back(client);
203 if q.len() > self.max_client_per_addr {
204 q.pop_front();
205 }
206 }
207 Entry::Vacant(vac) => {
208 let mut q = VecDeque::with_capacity(self.max_client_per_addr);
209 q.push_back(client);
210 vac.insert(q);
211 }
212 }
213 }
214}