doh_client/
config.rs

1use crate::{
2    context::Context,
3    helper::{load_certs, load_private_key, load_root_store},
4    listen::Config as ListenConfig,
5    remote::{Host as RemoteHost, Session as RemoteSession},
6    {get_listen_config, get_remote_host, Cache, DohError, DohResult},
7};
8use clap::ArgMatches;
9use futures::lock::Mutex;
10use std::{io::Result as IoResult, num::NonZeroUsize, sync::Arc};
11use tokio::net::UdpSocket;
12use tokio_rustls::rustls::ClientConfig;
13
14fn create_client_config(
15    cafile: Option<&String>,
16    client_auth: Option<(&String, &String)>,
17) -> DohResult<ClientConfig> {
18    let root_store = load_root_store(cafile)?;
19    let config_builder = ClientConfig::builder().with_root_certificates(root_store);
20    let mut config = if let Some((certs, key)) = client_auth {
21        let cert_chain = load_certs(certs)?;
22        let key_der = load_private_key(key)?;
23        config_builder.with_client_auth_cert(cert_chain, key_der)?
24    } else {
25        config_builder.with_no_client_auth()
26    };
27    config.alpn_protocols.push(vec![104, 50]); // h2
28    Ok(config)
29}
30
/// The configuration object for the `doh-client`.
///
/// A value is produced by `Config::new` (explicit arguments) or
/// `Config::try_from` (parsed command line) and is later consumed by
/// `Config::into`, which turns it into the listening socket and the runtime
/// `Context`.
pub struct Config {
    /// Listen configuration; `into` converts it into the local UDP socket.
    listen_config: ListenConfig,
    /// Remote host description handed to the remote session.
    remote_host: RemoteHost,
    /// Domain of the remote server; also the host part of `uri`.
    domain: String,
    /// Shared TLS client configuration (root store, optional client auth, ALPN `h2`).
    client_config: Arc<ClientConfig>,
    /// HTTPS URI assembled from `domain` and the request path in `new`.
    uri: String,
    /// Retry count passed to the remote session.
    retries: u32,
    /// Timeout value passed to the `Context`.
    // NOTE(review): the unit is not visible in this file (presumably seconds) — confirm in `Context`.
    timeout: u64,
    /// Passed to the remote session; `try_from` sets it to `false` when the `get` flag is given.
    post: bool,
    /// Capacity of the cache; `0` means no cache is created.
    cache_size: usize,
    /// Cache-fallback switch passed to the `Context`; `new` rejects it when `cache_size` is `0`.
    cache_fallback: bool,
}
44
45impl Config {
46    /// Create a new `doh_client::Config` object.
47    pub fn new(
48        listen_config: ListenConfig,
49        remote_host: RemoteHost,
50        domain: &str,
51        cafile: Option<&String>,
52        client_auth: Option<(&String, &String)>,
53        path: &str,
54        retries: u32,
55        timeout: u64,
56        post: bool,
57        cache_size: usize,
58        cache_fallback: bool,
59    ) -> DohResult<Config> {
60        let client_config = create_client_config(cafile, client_auth)?;
61
62        let uri = format!("https://{}/{}", domain, path);
63
64        if cache_fallback && cache_size == 0 {
65            return Err(DohError::CacheSize);
66        }
67
68        Ok(Config {
69            listen_config,
70            remote_host,
71            domain: domain.to_string(),
72            client_config: Arc::new(client_config),
73            uri,
74            retries,
75            timeout,
76            post,
77            cache_size,
78            cache_fallback,
79        })
80    }
81
82    pub async fn try_from(matches: ArgMatches) -> DohResult<Config> {
83        let listen_config = get_listen_config(&matches)?;
84        let remote_host = get_remote_host(&matches).await?;
85        let domain = matches.get_one::<String>("domain").unwrap();
86        let cafile = matches.get_one::<String>("cafile");
87        let client_auth = matches
88            .get_one::<String>("client-auth-certs")
89            .map(|certs| (certs, matches.get_one::<String>("client-auth-key").unwrap()));
90        let path = matches.get_one::<String>("path").unwrap();
91        let retries: u32 = *matches.get_one::<u32>("retries").unwrap_or(&3);
92        let timeout: u64 = *matches.get_one::<u64>("timeout").unwrap_or(&2);
93        let post: bool = !matches.get_flag("get");
94        let cache_size: usize = *matches.get_one::<usize>("cache-size").unwrap_or(&1024);
95        let cache_fallback: bool = matches.get_flag("cache-fallback");
96        Config::new(
97            listen_config,
98            remote_host,
99            domain,
100            cafile,
101            client_auth,
102            path,
103            retries,
104            timeout,
105            post,
106            cache_size,
107            cache_fallback,
108        )
109    }
110
111    pub(crate) async fn into(self) -> IoResult<(Arc<UdpSocket>, Context)> {
112        let cache = if self.cache_size == 0 {
113            None
114        } else {
115            let cache_size = NonZeroUsize::new(self.cache_size).unwrap();
116            Some(Mutex::new(Cache::new(cache_size)))
117        };
118        let cache_fallback = self.cache_fallback;
119        let timeout = self.timeout;
120        let socket = self.listen_config.into_socket().await?;
121        let socket = Arc::new(socket);
122        let remote_session = RemoteSession::new(
123            self.remote_host,
124            self.domain,
125            self.client_config,
126            self.uri,
127            self.retries,
128            self.post,
129        );
130        let context = Context::new(
131            cache,
132            cache_fallback,
133            timeout,
134            remote_session,
135            socket.clone(),
136        );
137        Ok((socket, context))
138    }
139}