pravega_client_config/
lib.rs1#![deny(
11 clippy::all,
12 clippy::cargo,
13 clippy::else_if_without_else,
14 clippy::empty_line_after_outer_attr,
15 clippy::multiple_inherent_impl,
16 clippy::mut_mut,
17 clippy::path_buf_push_overwrite
18)]
19#![warn(
20 clippy::cargo_common_metadata,
21 clippy::mutex_integer,
22 clippy::needless_borrow,
23 clippy::similar_names
24)]
25#![allow(clippy::multiple_crate_versions)]
26pub mod connection_type;
27pub mod credentials;
28
29use crate::connection_type::ConnectionType;
30use crate::credentials::Credentials;
31use derive_builder::*;
32use getset::{CopyGetters, Getters};
33use pravega_client_retry::retry_policy::RetryWithBackoff;
34use pravega_client_shared::PravegaNodeUri;
35use std::collections::HashMap;
36use std::time::Duration;
37use std::{env, fs};
38use tracing::debug;
39
/// `(host, port)` of a controller on `localhost:9090`; presumably meant for mock/local setups —
/// confirm against callers.
pub const MOCK_CONTROLLER_URI: (&str, u16) = ("localhost", 9090);

// ---- TLS -----------------------------------------------------------------------------------

/// Environment variable naming a certificate file, or a directory of certificate files.
const TLS_CERT_PATH_ENV: &str = "pravega_client_tls_cert_path";
/// Certificate location used when `TLS_CERT_PATH_ENV` is not set.
const DEFAULT_TLS_CERT_PATH: &str = "./certs";
/// Controller URI schemes for which TLS is enabled by default.
const TLS_SCHEMES: [&str; 4] = ["tls", "ssl", "tcps", "pravegas"];

// ---- Authentication ------------------------------------------------------------------------

/// Prefix shared by the authentication environment variables; the keys below are the suffixes
/// that follow it (e.g. `pravega_client_auth_method`).
const AUTH_PROPS_PREFIX_ENV: &str = "pravega_client_auth_";
/// Selects the authentication method (compared against the credential method names).
const AUTH_METHOD: &str = "method";
/// Basic-auth user name.
const AUTH_USERNAME: &str = "username";
/// Basic-auth password.
const AUTH_PASSWORD: &str = "password";
/// Basic-auth token; takes precedence over username/password when present.
const AUTH_TOKEN: &str = "token";
/// Path to the Keycloak JSON file used for bearer authentication.
const AUTH_KEYCLOAK_PATH: &str = "keycloak";
51
52#[derive(Builder, Debug, Getters, CopyGetters, Clone)]
53#[builder(setter(into), build_fn(validate = "Self::validate"))]
54pub struct ClientConfig {
55 #[get_copy = "pub"]
56 #[builder(default = "u32::max_value()")]
57 pub max_connections_in_pool: u32,
58
59 #[get_copy = "pub"]
60 #[builder(default = "3u32")]
61 pub max_controller_connections: u32,
62
63 #[get_copy = "pub"]
64 #[builder(default = "ConnectionType::Tokio")]
65 pub connection_type: ConnectionType,
66
67 #[get_copy = "pub"]
68 #[builder(default = "RetryWithBackoff::default_setting()")]
69 pub retry_policy: RetryWithBackoff,
70
71 #[get]
72 pub controller_uri: PravegaNodeUri,
73
74 #[get_copy = "pub"]
75 #[builder(default = "90 * 1000")]
76 pub transaction_timeout_time: u64,
77
78 #[get_copy = "pub"]
79 #[builder(default = "false")]
80 pub mock: bool,
81
82 #[get_copy = "pub"]
83 #[builder(default = "self.default_is_tls_enabled()")]
84 pub is_tls_enabled: bool,
85
86 #[builder(default = "false")]
87 pub disable_cert_verification: bool,
88
89 #[builder(default = "self.extract_trustcerts()")]
90 pub trustcerts: Vec<String>,
91
92 #[builder(default = "self.extract_credentials()")]
93 pub credentials: Credentials,
94
95 #[get_copy = "pub"]
96 #[builder(default = "false")]
97 pub is_auth_enabled: bool,
98
99 #[get_copy = "pub"]
100 #[builder(default = "1024 * 1024")]
101 pub reader_wrapper_buffer_size: usize,
102
103 #[get_copy = "pub"]
104 #[builder(default = "self.default_timeout()")]
105 pub request_timeout: Duration,
106}
107
108impl ClientConfigBuilder {
109 fn extract_trustcerts(&self) -> Vec<String> {
110 if let Some(enable) = self.is_tls_enabled {
111 if !enable {
112 return vec![];
113 }
114 } else if !self.default_is_tls_enabled() {
115 return vec![];
116 } else {
117 }
119 let ret_val = env::vars()
120 .filter(|(k, _v)| k.starts_with(TLS_CERT_PATH_ENV))
121 .collect::<HashMap<String, String>>();
122 let cert_path = if ret_val.contains_key(TLS_CERT_PATH_ENV) {
123 ret_val.get(TLS_CERT_PATH_ENV).expect("get tls cert path")
124 } else {
125 DEFAULT_TLS_CERT_PATH
126 };
127 let mut certs = vec![];
128 let meta = std::fs::metadata(cert_path).expect("should be valid path");
129 if meta.is_dir() {
130 let cert_paths =
131 std::fs::read_dir(cert_path).expect("cannot read from the provided cert directory");
132 for entry in cert_paths {
133 let path = entry.expect("get the cert file").path();
134 debug!("reading cert file {}", path.display());
135 certs.push(fs::read_to_string(path.clone()).expect("read cert file"));
136 }
137 } else if meta.is_file() {
138 certs.push(fs::read_to_string(cert_path).expect("read cert file"));
139 } else {
140 panic!("invalid cert path");
141 }
142 certs
143 }
144
145 fn extract_credentials(&self) -> Credentials {
146 let ret_val = env::vars()
147 .filter(|(k, _v)| k.starts_with(AUTH_PROPS_PREFIX_ENV))
148 .map(|(k, v)| {
149 let k = &k[AUTH_PROPS_PREFIX_ENV.len()..];
150 (k.to_owned(), v)
151 })
152 .collect::<HashMap<String, String>>();
153 if ret_val.contains_key(AUTH_METHOD) {
154 let method = ret_val.get(AUTH_METHOD).expect("get auth method").to_owned();
155 if method == credentials::BASIC {
156 if let Some(token) = ret_val.get(AUTH_TOKEN) {
157 return Credentials::basic_with_token(token.to_string());
158 }
159 let username = ret_val.get(AUTH_USERNAME).expect("get auth username").to_owned();
160 let password = ret_val.get(AUTH_PASSWORD).expect("get auth password").to_owned();
161 return Credentials::basic(username, password);
162 }
163 if method == credentials::BEARER {
164 let path = ret_val.get(AUTH_KEYCLOAK_PATH).expect("get keycloak json file");
165 let mut disable_cert_verification = false;
166 if self.disable_cert_verification.is_some() && self.disable_cert_verification.unwrap() {
167 disable_cert_verification = true;
168 }
169 return Credentials::keycloak(path, disable_cert_verification);
170 }
171 }
172 Credentials::basic("".into(), "".into())
173 }
174
175 fn default_timeout(&self) -> Duration {
176 Duration::from_secs(30)
177 }
178
179 fn default_is_tls_enabled(&self) -> bool {
180 if let Some(controller_uri) = &self.controller_uri {
181 return match controller_uri.scheme() {
182 Ok(scheme) => TLS_SCHEMES.contains(&&*scheme),
183 Err(_) => false,
184 };
185 }
186 false
187 }
188 fn validate(&self) -> Result<(), String> {
194 if self.is_tls_enabled.is_none() || self.controller_uri.is_none() || self
197 .controller_uri
198 .as_ref()
199 .unwrap()
200 .scheme()
201 .unwrap_or_default()
202 .is_empty()
203 {
204 return Ok(());
207 }
208 let is_tls_enabled = self.is_tls_enabled.unwrap();
209 let scheme_is_type_tls = self.default_is_tls_enabled();
210 if is_tls_enabled != scheme_is_type_tls {
211 Err(format!(
212 "is_tls_enabled option {} does not match scheme in uri {}",
213 is_tls_enabled,
214 **self.controller_uri.as_ref().unwrap()
215 ))
216 } else {
217 Ok(())
218 }
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use base64::encode;
    use serial_test::serial;
    use std::net::Ipv4Addr;

    /// Clears the given environment variables when created and again when dropped.
    ///
    /// Each test therefore starts from a known environment, and a failed assertion cannot leak
    /// variables into the tests that run after it.
    struct EnvGuard(&'static [&'static str]);

    impl EnvGuard {
        fn clear(keys: &'static [&'static str]) -> Self {
            for key in keys {
                env::remove_var(key);
            }
            EnvGuard(keys)
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            for key in self.0 {
                env::remove_var(key);
            }
        }
    }

    #[test]
    #[serial]
    fn test_get_set() {
        let config = ClientConfigBuilder::default()
            .max_connections_in_pool(15u32)
            .connection_type(ConnectionType::Tokio)
            .retry_policy(RetryWithBackoff::default_setting().initial_delay(Duration::from_millis(1000)))
            .controller_uri(PravegaNodeUri::from("127.0.0.2:9091".to_string()))
            .build()
            .unwrap();

        assert_eq!(config.max_connections_in_pool(), 15u32);
        assert_eq!(config.connection_type(), ConnectionType::Tokio);
        assert_eq!(
            config.retry_policy(),
            RetryWithBackoff::default_setting().initial_delay(Duration::from_millis(1000))
        );
        assert_eq!(
            config.controller_uri().to_socket_addr().ip(),
            Ipv4Addr::new(127, 0, 0, 2)
        );
        assert_eq!(config.controller_uri().to_socket_addr().port(), 9091);
    }

    #[test]
    #[serial]
    fn test_get_default() {
        let config = ClientConfigBuilder::default()
            .controller_uri(MOCK_CONTROLLER_URI)
            .build()
            .unwrap();

        assert_eq!(config.max_connections_in_pool(), u32::MAX);
        assert_eq!(config.max_controller_connections(), 3u32);
        assert_eq!(config.connection_type(), ConnectionType::Tokio);
        assert_eq!(config.retry_policy(), RetryWithBackoff::default_setting());
    }

    #[test]
    #[serial]
    fn test_extract_credentials() {
        let _env = EnvGuard::clear(&[
            "pravega_client_auth_method",
            "pravega_client_auth_username",
            "pravega_client_auth_password",
            "pravega_client_auth_token",
        ]);
        let rt = tokio::runtime::Runtime::new().expect("create runtime");

        // No auth variables: empty basic credentials.
        let config = ClientConfigBuilder::default()
            .controller_uri("127.0.0.2:9091".to_string())
            .build()
            .unwrap();
        let token = encode(":");
        assert_eq!(
            rt.block_on(config.credentials.get_request_metadata()),
            format!("{} {}", "Basic", token)
        );

        // Basic auth with username and password.
        env::set_var("pravega_client_auth_method", "Basic");
        env::set_var("pravega_client_auth_username", "hello");
        env::set_var("pravega_client_auth_password", "12345");
        let config = ClientConfigBuilder::default()
            .controller_uri("127.0.0.2:9091".to_string())
            .build()
            .unwrap();
        let token = encode("hello:12345");
        assert_eq!(
            rt.block_on(config.credentials.get_request_metadata()),
            format!("{} {}", "Basic", token)
        );

        // An explicit token wins over username/password.
        env::set_var("pravega_client_auth_token", "ABCDE");
        let config = ClientConfigBuilder::default()
            .controller_uri("127.0.0.2:9091".to_string())
            .build()
            .unwrap();
        assert_eq!(
            rt.block_on(config.credentials.get_request_metadata()),
            format!("{} {}", "Basic", "ABCDE")
        );
    }

    #[test]
    #[serial]
    fn test_extract_tls_cert_path() {
        let _env = EnvGuard::clear(&["pravega_client_tls_cert_path"]);

        fs::create_dir_all(DEFAULT_TLS_CERT_PATH).expect("create default cert path");
        fs::File::create(format!("{}/foo.crt", DEFAULT_TLS_CERT_PATH)).expect("create crt");
        let config = ClientConfigBuilder::default()
            .controller_uri("tls://127.0.0.2:9091".to_string())
            .is_tls_enabled(true)
            .build()
            .unwrap();
        assert_eq!(config.trustcerts.len(), 1);

        let config = ClientConfigBuilder::default()
            .controller_uri("tls://127.0.0.2:9091")
            .build()
            .unwrap();
        assert_eq!(config.trustcerts.len(), 1);

        let config = ClientConfigBuilder::default()
            .controller_uri("tcp://127.0.0.2:9091".to_string())
            .build()
            .unwrap();
        assert_eq!(config.trustcerts.len(), 0);

        let conflicted_config1 = ClientConfigBuilder::default()
            .controller_uri("pravegas://127.0.0.2:9091")
            .is_tls_enabled(false)
            .build();
        assert!(conflicted_config1.is_err());

        let conflicted_config2 = ClientConfigBuilder::default()
            .controller_uri("tcp://127.0.0.2:9091".to_string())
            .is_tls_enabled(true)
            .build();
        assert!(conflicted_config2.is_err());

        fs::remove_dir_all(DEFAULT_TLS_CERT_PATH).expect("remove dir");
        fs::create_dir_all("./bar").expect("create default cert path");
        fs::File::create("./bar/foo.crt").expect("create crt");

        // Cert path pointing at a directory.
        env::set_var("pravega_client_tls_cert_path", "./bar");
        let config = ClientConfigBuilder::default()
            .controller_uri("tls://127.0.0.2:9091".to_string())
            .is_tls_enabled(true)
            .build()
            .unwrap();
        assert_eq!(config.trustcerts.len(), 1);

        // Cert path pointing at a single file.
        env::set_var("pravega_client_tls_cert_path", "./bar/foo.crt");
        let config = ClientConfigBuilder::default()
            .controller_uri("tls://127.0.0.2:9091".to_string())
            .is_tls_enabled(true)
            .build()
            .unwrap();
        assert_eq!(config.trustcerts.len(), 1);

        // A missing path makes the build panic.
        env::set_var("pravega_client_tls_cert_path", "./wrong/path");
        let result = std::panic::catch_unwind(|| {
            ClientConfigBuilder::default()
                .controller_uri("tls://127.0.0.2:9091".to_string())
                .is_tls_enabled(true)
                .build()
                .unwrap()
        });
        assert!(result.is_err());
        fs::remove_dir_all("./bar").expect("remove dir");
    }
}