pravega_client_config/
lib.rs

1//
2// Copyright (c) Dell Inc., or its subsidiaries. All Rights Reserved.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
10#![deny(
11    clippy::all,
12    clippy::cargo,
13    clippy::else_if_without_else,
14    clippy::empty_line_after_outer_attr,
15    clippy::multiple_inherent_impl,
16    clippy::mut_mut,
17    clippy::path_buf_push_overwrite
18)]
19#![warn(
20    clippy::cargo_common_metadata,
21    clippy::mutex_integer,
22    clippy::needless_borrow,
23    clippy::similar_names
24)]
25#![allow(clippy::multiple_crate_versions)]
26pub mod connection_type;
27pub mod credentials;
28
29use crate::connection_type::ConnectionType;
30use crate::credentials::Credentials;
31use derive_builder::*;
32use getset::{CopyGetters, Getters};
33use pravega_client_retry::retry_policy::RetryWithBackoff;
34use pravega_client_shared::PravegaNodeUri;
35use std::collections::HashMap;
36use std::time::Duration;
37use std::{env, fs};
38use tracing::debug;
39
/// Host/port pair for a local (mock) controller endpoint; it converts into a
/// controller URI through the builder's `controller_uri` setter.
pub const MOCK_CONTROLLER_URI: (&'static str, u16) = ("localhost", 9090);

/// Prefix shared by every authentication environment variable, e.g.
/// `pravega_client_auth_method`. It is stripped before the keys below are
/// looked up.
const AUTH_PROPS_PREFIX_ENV: &'static str = "pravega_client_auth_";
/// Property selecting the authentication method.
const AUTH_METHOD: &'static str = "method";
/// Property holding the username for basic authentication.
const AUTH_USERNAME: &'static str = "username";
/// Property holding the password for basic authentication.
const AUTH_PASSWORD: &'static str = "password";
/// Property holding a pre-built basic token; it takes precedence over
/// username/password.
const AUTH_TOKEN: &'static str = "token";
/// Property holding the path used for keycloak (bearer) authentication.
const AUTH_KEYCLOAK_PATH: &'static str = "keycloak";

/// Environment variable naming a certificate file or a directory of certificates.
const TLS_CERT_PATH_ENV: &'static str = "pravega_client_tls_cert_path";
/// Certificate location used when `TLS_CERT_PATH_ENV` is not set.
const DEFAULT_TLS_CERT_PATH: &'static str = "./certs";
/// Controller URI schemes that imply TLS when `is_tls_enabled` is not set explicitly.
const TLS_SCHEMES: [&'static str; 4] = ["tls", "ssl", "tcps", "pravegas"];
51
52#[derive(Builder, Debug, Getters, CopyGetters, Clone)]
53#[builder(setter(into), build_fn(validate = "Self::validate"))]
54pub struct ClientConfig {
55    #[get_copy = "pub"]
56    #[builder(default = "u32::max_value()")]
57    pub max_connections_in_pool: u32,
58
59    #[get_copy = "pub"]
60    #[builder(default = "3u32")]
61    pub max_controller_connections: u32,
62
63    #[get_copy = "pub"]
64    #[builder(default = "ConnectionType::Tokio")]
65    pub connection_type: ConnectionType,
66
67    #[get_copy = "pub"]
68    #[builder(default = "RetryWithBackoff::default_setting()")]
69    pub retry_policy: RetryWithBackoff,
70
71    #[get]
72    pub controller_uri: PravegaNodeUri,
73
74    #[get_copy = "pub"]
75    #[builder(default = "90 * 1000")]
76    pub transaction_timeout_time: u64,
77
78    #[get_copy = "pub"]
79    #[builder(default = "false")]
80    pub mock: bool,
81
82    #[get_copy = "pub"]
83    #[builder(default = "self.default_is_tls_enabled()")]
84    pub is_tls_enabled: bool,
85
86    #[builder(default = "false")]
87    pub disable_cert_verification: bool,
88
89    #[builder(default = "self.extract_trustcerts()")]
90    pub trustcerts: Vec<String>,
91
92    #[builder(default = "self.extract_credentials()")]
93    pub credentials: Credentials,
94
95    #[get_copy = "pub"]
96    #[builder(default = "false")]
97    pub is_auth_enabled: bool,
98
99    #[get_copy = "pub"]
100    #[builder(default = "1024 * 1024")]
101    pub reader_wrapper_buffer_size: usize,
102
103    #[get_copy = "pub"]
104    #[builder(default = "self.default_timeout()")]
105    pub request_timeout: Duration,
106}
107
108impl ClientConfigBuilder {
109    fn extract_trustcerts(&self) -> Vec<String> {
110        if let Some(enable) = self.is_tls_enabled {
111            if !enable {
112                return vec![];
113            }
114        } else if !self.default_is_tls_enabled() {
115            return vec![];
116        } else {
117            // fall through to parsing certs path
118        }
119        let ret_val = env::vars()
120            .filter(|(k, _v)| k.starts_with(TLS_CERT_PATH_ENV))
121            .collect::<HashMap<String, String>>();
122        let cert_path = if ret_val.contains_key(TLS_CERT_PATH_ENV) {
123            ret_val.get(TLS_CERT_PATH_ENV).expect("get tls cert path")
124        } else {
125            DEFAULT_TLS_CERT_PATH
126        };
127        let mut certs = vec![];
128        let meta = std::fs::metadata(cert_path).expect("should be valid path");
129        if meta.is_dir() {
130            let cert_paths =
131                std::fs::read_dir(cert_path).expect("cannot read from the provided cert directory");
132            for entry in cert_paths {
133                let path = entry.expect("get the cert file").path();
134                debug!("reading cert file {}", path.display());
135                certs.push(fs::read_to_string(path.clone()).expect("read cert file"));
136            }
137        } else if meta.is_file() {
138            certs.push(fs::read_to_string(cert_path).expect("read cert file"));
139        } else {
140            panic!("invalid cert path");
141        }
142        certs
143    }
144
145    fn extract_credentials(&self) -> Credentials {
146        let ret_val = env::vars()
147            .filter(|(k, _v)| k.starts_with(AUTH_PROPS_PREFIX_ENV))
148            .map(|(k, v)| {
149                let k = &k[AUTH_PROPS_PREFIX_ENV.len()..];
150                (k.to_owned(), v)
151            })
152            .collect::<HashMap<String, String>>();
153        if ret_val.contains_key(AUTH_METHOD) {
154            let method = ret_val.get(AUTH_METHOD).expect("get auth method").to_owned();
155            if method == credentials::BASIC {
156                if let Some(token) = ret_val.get(AUTH_TOKEN) {
157                    return Credentials::basic_with_token(token.to_string());
158                }
159                let username = ret_val.get(AUTH_USERNAME).expect("get auth username").to_owned();
160                let password = ret_val.get(AUTH_PASSWORD).expect("get auth password").to_owned();
161                return Credentials::basic(username, password);
162            }
163            if method == credentials::BEARER {
164                let path = ret_val.get(AUTH_KEYCLOAK_PATH).expect("get keycloak json file");
165                let mut disable_cert_verification = false;
166                if self.disable_cert_verification.is_some() && self.disable_cert_verification.unwrap() {
167                    disable_cert_verification = true;
168                }
169                return Credentials::keycloak(path, disable_cert_verification);
170            }
171        }
172        Credentials::basic("".into(), "".into())
173    }
174
175    fn default_timeout(&self) -> Duration {
176        Duration::from_secs(30)
177    }
178
179    fn default_is_tls_enabled(&self) -> bool {
180        if let Some(controller_uri) = &self.controller_uri {
181            return match controller_uri.scheme() {
182                Ok(scheme) => TLS_SCHEMES.contains(&&*scheme),
183                Err(_) => false,
184            };
185        }
186        false
187    }
188    /// validate the builder before returning it
189    ///
190    /// if is_tls_enabled, controller_uri have been set and if controller_uri
191    /// contains a scheme, then verify that the uri scheme matches the is_tls_enabled
192    /// value.
193    fn validate(&self) -> Result<(), String> {
194        if self.is_tls_enabled.is_none()    // is_tls_enabled not specified
195            || self.controller_uri.is_none()    // controller_uri not specified
196            || self
197                .controller_uri
198                .as_ref()
199                .unwrap()
200                .scheme()
201                .unwrap_or_default()
202                .is_empty()
203        {
204            // at least one option has not been specified or uri has no scheme,
205            // therefore cannot have a conflict with is_tls_enabled
206            return Ok(());
207        }
208        let is_tls_enabled = self.is_tls_enabled.unwrap();
209        let scheme_is_type_tls = self.default_is_tls_enabled();
210        if is_tls_enabled != scheme_is_type_tls {
211            Err(format!(
212                "is_tls_enabled option {} does not match scheme in uri {}",
213                is_tls_enabled,
214                **self.controller_uri.as_ref().unwrap()
215            ))
216        } else {
217            Ok(())
218        }
219    }
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use base64::encode;
    use serial_test::serial;
    use std::net::Ipv4Addr;

    /// Authentication variables set by `test_extract_credentials`.
    const AUTH_ENV_VARS: [&str; 4] = [
        "pravega_client_auth_method",
        "pravega_client_auth_username",
        "pravega_client_auth_password",
        "pravega_client_auth_token",
    ];

    /// Removes every auth variable the tests use, so a case starts (and ends) clean.
    fn clear_auth_env() {
        for key in AUTH_ENV_VARS.iter() {
            env::remove_var(key);
        }
    }

    #[test]
    #[serial]
    fn test_get_set() {
        let config = ClientConfigBuilder::default()
            .max_connections_in_pool(15u32)
            .connection_type(ConnectionType::Tokio)
            .retry_policy(RetryWithBackoff::default_setting().initial_delay(Duration::from_millis(1000)))
            .controller_uri(PravegaNodeUri::from("127.0.0.2:9091".to_string()))
            .build()
            .unwrap();

        assert_eq!(config.max_connections_in_pool(), 15u32);
        assert_eq!(config.connection_type(), ConnectionType::Tokio);
        assert_eq!(
            config.retry_policy(),
            RetryWithBackoff::default_setting().initial_delay(Duration::from_millis(1000))
        );
        assert_eq!(
            config.controller_uri().to_socket_addr().ip(),
            Ipv4Addr::new(127, 0, 0, 2)
        );
        assert_eq!(config.controller_uri().to_socket_addr().port(), 9091);
    }

    #[test]
    #[serial]
    fn test_get_default() {
        let config = ClientConfigBuilder::default()
            .controller_uri(MOCK_CONTROLLER_URI)
            .build()
            .unwrap();

        assert_eq!(config.max_connections_in_pool(), u32::MAX);
        assert_eq!(config.max_controller_connections(), 3u32);
        assert_eq!(config.connection_type(), ConnectionType::Tokio);
        assert_eq!(config.retry_policy(), RetryWithBackoff::default_setting());
    }

    #[test]
    #[serial]
    fn test_extract_credentials() {
        let rt = tokio::runtime::Runtime::new().expect("create runtime");
        // Settings inherited from the environment would break the "empty env" case.
        clear_auth_env();

        // test empty env
        let config = ClientConfigBuilder::default()
            .controller_uri("127.0.0.2:9091".to_string())
            .build()
            .unwrap();

        let token = encode(":");

        assert_eq!(
            rt.block_on(config.credentials.get_request_metadata()),
            format!("{} {}", "Basic", token)
        );

        // retrieve from env
        env::set_var("pravega_client_auth_method", "Basic");
        env::set_var("pravega_client_auth_username", "hello");
        env::set_var("pravega_client_auth_password", "12345");

        let config = ClientConfigBuilder::default()
            .controller_uri("127.0.0.2:9091".to_string())
            .build()
            .unwrap();

        let token = encode("hello:12345");
        assert_eq!(
            rt.block_on(config.credentials.get_request_metadata()),
            format!("{} {}", "Basic", token)
        );

        // retrieve from env with priority
        env::set_var("pravega_client_auth_token", "ABCDE");
        let config = ClientConfigBuilder::default()
            .controller_uri("127.0.0.2:9091".to_string())
            .build()
            .unwrap();
        assert_eq!(
            rt.block_on(config.credentials.get_request_metadata()),
            format!("{} {}", "Basic", "ABCDE")
        );
        clear_auth_env();
    }

    #[test]
    #[serial]
    fn test_extract_tls_cert_path() {
        // Start clean. An inherited cert-path variable or leftovers from an
        // aborted run would change the expected certificate counts. The
        // directories may legitimately not exist, so removal errors are ignored.
        env::remove_var("pravega_client_tls_cert_path");
        let _ = fs::remove_dir_all(DEFAULT_TLS_CERT_PATH);
        let _ = fs::remove_dir_all("./bar");

        // test default
        fs::create_dir_all(DEFAULT_TLS_CERT_PATH).expect("create default cert path");
        fs::File::create(format!("{}/foo.crt", DEFAULT_TLS_CERT_PATH)).expect("create crt");
        let config = ClientConfigBuilder::default()
            .controller_uri("tls://127.0.0.2:9091".to_string())
            .is_tls_enabled(true)
            .build()
            .unwrap();
        assert_eq!(config.trustcerts.len(), 1);

        // test w/ tls uri prefix
        let config = ClientConfigBuilder::default()
            .controller_uri("tls://127.0.0.2:9091")
            .build()
            .unwrap();
        assert_eq!(config.trustcerts.len(), 1);

        // test w/o tls uri prefix
        let config = ClientConfigBuilder::default()
            .controller_uri("tcp://127.0.0.2:9091".to_string())
            .build()
            .unwrap();
        assert_eq!(config.trustcerts.len(), 0);

        // test conflicting tls setting vs scheme
        let conflicted_config1 = ClientConfigBuilder::default()
            .controller_uri("pravegas://127.0.0.2:9091")
            .is_tls_enabled(false)
            .build();
        assert!(conflicted_config1.is_err());

        // test alternate conflicting tls setting vs scheme
        let conflicted_config2 = ClientConfigBuilder::default()
            .controller_uri("tcp://127.0.0.2:9091".to_string())
            .is_tls_enabled(true)
            .build();

        assert!(conflicted_config2.is_err());

        fs::remove_dir_all(DEFAULT_TLS_CERT_PATH).expect("remove dir");
        // test with env var set
        fs::create_dir_all("./bar").expect("create custom cert path");
        fs::File::create("./bar/foo.crt").expect("create crt");
        env::set_var("pravega_client_tls_cert_path", "./bar");
        let config = ClientConfigBuilder::default()
            .controller_uri("tls://127.0.0.2:9091".to_string())
            .is_tls_enabled(true)
            .build()
            .unwrap();
        assert_eq!(config.trustcerts.len(), 1);
        // test with file path
        env::set_var("pravega_client_tls_cert_path", "./bar/foo.crt");
        let config = ClientConfigBuilder::default()
            .controller_uri("tls://127.0.0.2:9091".to_string())
            .is_tls_enabled(true)
            .build()
            .unwrap();
        assert_eq!(config.trustcerts.len(), 1);
        env::set_var("pravega_client_tls_cert_path", "./wrong/path");
        // test with invalid path
        let result = std::panic::catch_unwind(|| {
            ClientConfigBuilder::default()
                .controller_uri("tls://127.0.0.2:9091".to_string())
                .is_tls_enabled(true)
                .build()
                .unwrap()
        });
        assert!(result.is_err());
        fs::remove_dir_all("./bar").expect("remove dir");
        env::remove_var("pravega_client_tls_cert_path");
    }
}