1use crate::{
2 context::Context,
3 helper::{load_certs, load_private_key, load_root_store},
4 listen::Config as ListenConfig,
5 remote::{Host as RemoteHost, Session as RemoteSession},
6 {get_listen_config, get_remote_host, Cache, DohError, DohResult},
7};
8use clap::ArgMatches;
9use futures::lock::Mutex;
10use std::{io::Result as IoResult, num::NonZeroUsize, sync::Arc};
11use tokio::net::UdpSocket;
12use tokio_rustls::rustls::ClientConfig;
13
14fn create_client_config(
15 cafile: Option<&String>,
16 client_auth: Option<(&String, &String)>,
17) -> DohResult<ClientConfig> {
18 let root_store = load_root_store(cafile)?;
19 let config_builder = ClientConfig::builder().with_root_certificates(root_store);
20 let mut config = if let Some((certs, key)) = client_auth {
21 let cert_chain = load_certs(certs)?;
22 let key_der = load_private_key(key)?;
23 config_builder.with_client_auth_cert(cert_chain, key_der)?
24 } else {
25 config_builder.with_no_client_auth()
26 };
27 config.alpn_protocols.push(vec![104, 50]); Ok(config)
29}
30
31pub struct Config {
33 listen_config: ListenConfig,
34 remote_host: RemoteHost,
35 domain: String,
36 client_config: Arc<ClientConfig>,
37 uri: String,
38 retries: u32,
39 timeout: u64,
40 post: bool,
41 cache_size: usize,
42 cache_fallback: bool,
43}
44
45impl Config {
46 pub fn new(
48 listen_config: ListenConfig,
49 remote_host: RemoteHost,
50 domain: &str,
51 cafile: Option<&String>,
52 client_auth: Option<(&String, &String)>,
53 path: &str,
54 retries: u32,
55 timeout: u64,
56 post: bool,
57 cache_size: usize,
58 cache_fallback: bool,
59 ) -> DohResult<Config> {
60 let client_config = create_client_config(cafile, client_auth)?;
61
62 let uri = format!("https://{}/{}", domain, path);
63
64 if cache_fallback && cache_size == 0 {
65 return Err(DohError::CacheSize);
66 }
67
68 Ok(Config {
69 listen_config,
70 remote_host,
71 domain: domain.to_string(),
72 client_config: Arc::new(client_config),
73 uri,
74 retries,
75 timeout,
76 post,
77 cache_size,
78 cache_fallback,
79 })
80 }
81
82 pub async fn try_from(matches: ArgMatches) -> DohResult<Config> {
83 let listen_config = get_listen_config(&matches)?;
84 let remote_host = get_remote_host(&matches).await?;
85 let domain = matches.get_one::<String>("domain").unwrap();
86 let cafile = matches.get_one::<String>("cafile");
87 let client_auth = matches
88 .get_one::<String>("client-auth-certs")
89 .map(|certs| (certs, matches.get_one::<String>("client-auth-key").unwrap()));
90 let path = matches.get_one::<String>("path").unwrap();
91 let retries: u32 = *matches.get_one::<u32>("retries").unwrap_or(&3);
92 let timeout: u64 = *matches.get_one::<u64>("timeout").unwrap_or(&2);
93 let post: bool = !matches.get_flag("get");
94 let cache_size: usize = *matches.get_one::<usize>("cache-size").unwrap_or(&1024);
95 let cache_fallback: bool = matches.get_flag("cache-fallback");
96 Config::new(
97 listen_config,
98 remote_host,
99 domain,
100 cafile,
101 client_auth,
102 path,
103 retries,
104 timeout,
105 post,
106 cache_size,
107 cache_fallback,
108 )
109 }
110
111 pub(crate) async fn into(self) -> IoResult<(Arc<UdpSocket>, Context)> {
112 let cache = if self.cache_size == 0 {
113 None
114 } else {
115 let cache_size = NonZeroUsize::new(self.cache_size).unwrap();
116 Some(Mutex::new(Cache::new(cache_size)))
117 };
118 let cache_fallback = self.cache_fallback;
119 let timeout = self.timeout;
120 let socket = self.listen_config.into_socket().await?;
121 let socket = Arc::new(socket);
122 let remote_session = RemoteSession::new(
123 self.remote_host,
124 self.domain,
125 self.client_config,
126 self.uri,
127 self.retries,
128 self.post,
129 );
130 let context = Context::new(
131 cache,
132 cache_fallback,
133 timeout,
134 remote_session,
135 socket.clone(),
136 );
137 Ok((socket, context))
138 }
139}