1use duration_str::deserialize_duration;
5use std::time::Duration;
6use std::{collections::HashMap, str::FromStr};
7use tower::ServiceExt;
8
9use http::header::{HeaderMap, HeaderName, HeaderValue};
10use hyper_rustls;
11use hyper_util::client::legacy::connect::HttpConnector;
12use serde::Deserialize;
13use tonic::codegen::{Body, Bytes, StdError};
14use tonic::transport::{Channel, Uri};
15use tracing::warn;
16
17use super::compression::CompressionType;
18use super::errors::ConfigError;
19use super::headers_middleware::SetRequestHeaderLayer;
20use crate::auth::ClientAuthenticator;
21use crate::auth::basic::Config as BasicAuthenticationConfig;
22use crate::auth::bearer::Config as BearerAuthenticationConfig;
23use crate::component::configuration::{Configuration, ConfigurationError};
24use crate::tls::{client::TlsClientConfig as TLSSetting, common::RustlsConfigLoader};
25
26#[derive(Debug, Deserialize, PartialEq, Clone)]
31pub struct KeepaliveConfig {
32 #[serde(
34 default = "default_tcp_keepalive",
35 deserialize_with = "deserialize_duration"
36 )]
37 pub tcp_keepalive: Duration,
38
39 #[serde(
41 default = "default_http2_keepalive",
42 deserialize_with = "deserialize_duration"
43 )]
44 pub http2_keepalive: Duration,
45
46 #[serde(default = "default_timeout", deserialize_with = "deserialize_duration")]
48 pub timeout: Duration,
49
50 #[serde(default = "default_keep_alive_while_idle")]
52 pub keep_alive_while_idle: bool,
53}
54
55impl Default for KeepaliveConfig {
57 fn default() -> Self {
58 KeepaliveConfig {
59 tcp_keepalive: default_tcp_keepalive(),
60 http2_keepalive: default_http2_keepalive(),
61 timeout: default_timeout(),
62 keep_alive_while_idle: default_keep_alive_while_idle(),
63 }
64 }
65}
66
/// Default TCP keepalive interval: one minute.
fn default_tcp_keepalive() -> Duration {
    Duration::new(60, 0)
}
70
/// Default HTTP/2 keepalive ping interval: one minute.
fn default_http2_keepalive() -> Duration {
    Duration::from_millis(60_000)
}
74
/// Default keepalive timeout: ten seconds.
fn default_timeout() -> Duration {
    Duration::new(10, 0)
}
78
/// Keepalive pings are not sent on idle connections by default.
fn default_keep_alive_while_idle() -> bool {
    bool::default()
}
82
83#[derive(Debug, Default, Deserialize, Clone, PartialEq)]
85#[serde(rename_all = "snake_case")]
86pub enum AuthenticationConfig {
87 Basic(BasicAuthenticationConfig),
89 Bearer(BearerAuthenticationConfig),
91 #[default]
93 None,
94}
95
96#[derive(Debug, Deserialize, Clone, PartialEq)]
102pub struct ClientConfig {
103 pub endpoint: String,
105
106 pub origin: Option<String>,
108
109 pub compression: Option<CompressionType>,
111
112 pub rate_limit: Option<String>,
114
115 #[serde(default, rename = "tls")]
117 pub tls_setting: TLSSetting,
118
119 pub keepalive: Option<KeepaliveConfig>,
121
122 #[serde(
124 default = "default_connect_timeout",
125 deserialize_with = "deserialize_duration"
126 )]
127 pub connect_timeout: Duration,
128
129 #[serde(
131 default = "default_request_timeout",
132 deserialize_with = "deserialize_duration"
133 )]
134 pub request_timeout: Duration,
135
136 pub buffer_size: Option<usize>,
138
139 #[serde(default)]
141 pub headers: HashMap<String, String>,
142
143 #[serde(default)]
145 #[serde(with = "serde_yaml::with::singleton_map")]
146 pub auth: AuthenticationConfig,
147}
148
149impl Default for ClientConfig {
151 fn default() -> Self {
152 ClientConfig {
153 endpoint: String::new(),
154 origin: None,
155 compression: None,
156 rate_limit: None,
157 tls_setting: TLSSetting::default(),
158 keepalive: None,
159 connect_timeout: default_connect_timeout(),
160 request_timeout: default_request_timeout(),
161 buffer_size: None,
162 headers: HashMap::new(),
163 auth: AuthenticationConfig::None,
164 }
165 }
166}
167
/// Default connect timeout: zero, which `to_channel` treats as "no timeout".
fn default_connect_timeout() -> Duration {
    Duration::ZERO
}
171
/// Default request timeout: zero, which `to_channel` treats as "no timeout".
fn default_request_timeout() -> Duration {
    Duration::ZERO
}
175
176impl std::fmt::Display for ClientConfig {
178 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179 write!(
180 f,
181 "ClientConfig {{ endpoint: {}, origin: {:?}, compression: {:?}, rate_limit: {:?}, tls_setting: {:?}, keepalive: {:?}, connect_timeout: {:?}, request_timeout: {:?}, buffer_size: {:?}, headers: {:?}, auth: {:?} }}",
182 self.endpoint,
183 self.origin,
184 self.compression,
185 self.rate_limit,
186 self.tls_setting,
187 self.keepalive,
188 self.connect_timeout,
189 self.request_timeout,
190 self.buffer_size,
191 self.headers,
192 self.auth
193 )
194 }
195}
196
197impl Configuration for ClientConfig {
198 fn validate(&self) -> Result<(), ConfigurationError> {
199 self.tls_setting.validate()
201 }
202}
203
204impl ClientConfig {
205 pub fn with_endpoint(endpoint: &str) -> Self {
209 Self {
210 endpoint: endpoint.to_string(),
211 ..Self::default()
212 }
213 }
214
215 pub fn with_origin(self, origin: &str) -> Self {
216 Self {
217 origin: Some(origin.to_string()),
218 ..self
219 }
220 }
221
222 pub fn with_compression(self, compression: CompressionType) -> Self {
223 Self {
224 compression: Some(compression),
225 ..self
226 }
227 }
228
229 pub fn with_rate_limit(self, rate_limit: &str) -> Self {
230 Self {
231 rate_limit: Some(rate_limit.to_string()),
232 ..self
233 }
234 }
235
236 pub fn with_tls_setting(self, tls_setting: TLSSetting) -> Self {
237 Self {
238 tls_setting,
239 ..self
240 }
241 }
242
243 pub fn with_keepalive(self, keepalive: KeepaliveConfig) -> Self {
244 Self {
245 keepalive: Some(keepalive),
246 ..self
247 }
248 }
249
250 pub fn with_connect_timeout(self, connect_timeout: Duration) -> Self {
251 Self {
252 connect_timeout,
253 ..self
254 }
255 }
256
257 pub fn with_request_timeout(self, request_timeout: Duration) -> Self {
258 Self {
259 request_timeout,
260 ..self
261 }
262 }
263
264 pub fn with_buffer_size(self, buffer_size: usize) -> Self {
265 Self {
266 buffer_size: Some(buffer_size),
267 ..self
268 }
269 }
270
271 pub fn with_headers(self, headers: HashMap<String, String>) -> Self {
272 Self { headers, ..self }
273 }
274
275 pub fn with_auth(self, auth: AuthenticationConfig) -> Self {
276 Self { auth, ..self }
277 }
278
279 pub fn to_channel(
285 &self,
286 ) -> Result<
287 impl tonic::client::GrpcService<
288 tonic::body::Body,
289 Error: Into<StdError> + Send,
290 ResponseBody: Body<Data = Bytes, Error: Into<StdError> + std::marker::Send>
291 + Send
292 + 'static,
293 Future: Send,
294 > + Send
295 + use<>,
296 ConfigError,
297 > {
298 if self.endpoint.is_empty() {
300 return Err(ConfigError::MissingEndpoint);
301 }
302
303 let uri =
305 Uri::from_str(&self.endpoint).map_err(|e| ConfigError::UriParseError(e.to_string()))?;
306 let builder = Channel::builder(uri);
307
308 let mut http = HttpConnector::new();
311
312 http.enforce_http(false);
314 http.set_nodelay(false);
315
316 match self.connect_timeout.as_secs() {
318 0 => http.set_connect_timeout(None),
319 _ => http.set_connect_timeout(Some(self.connect_timeout)),
320 }
321
322 let builder = match self.buffer_size {
324 Some(size) => builder.buffer_size(size),
325 None => builder,
326 };
327
328 let builder = match &self.keepalive {
330 Some(keepalive) => {
331 http.set_keepalive(Some(keepalive.tcp_keepalive));
333
334 builder
335 .keep_alive_timeout(keepalive.timeout)
336 .keep_alive_while_idle(keepalive.keep_alive_while_idle)
337 .http2_keep_alive_interval(keepalive.http2_keepalive)
339 }
340 None => builder,
341 };
342
343 let builder = match &self.origin {
345 Some(origin) => {
346 let uri = Uri::from_str(origin.as_str())
347 .map_err(|e| ConfigError::UriParseError(e.to_string()))?;
348
349 builder.origin(uri)
350 }
351 None => builder,
352 };
353
354 let builder = match &self.rate_limit {
355 Some(rate_limit) => {
356 let (limit, duration) = parse_rate_limit(rate_limit)
357 .map_err(|e| ConfigError::RateLimitParseError(e.to_string()))?;
358 builder.rate_limit(limit, duration)
359 }
360 None => builder,
361 };
362
363 let builder = match self.request_timeout.as_secs() {
365 0 => builder,
366 _ => builder.timeout(self.request_timeout),
367 };
368
369 let mut header_map = HeaderMap::new();
371 for (key, value) in &self.headers {
372 let k: HeaderName = key.parse().map_err(|_| {
373 ConfigError::HeaderParseError(format!("error parsing header key {}", key))
374 })?;
375 let v: HeaderValue = value.parse().map_err(|_| {
376 ConfigError::HeaderParseError(format!("error parsing header value {}", key))
377 })?;
378
379 header_map.insert(k, v);
380 }
381
382 let tls_config = TLSSetting::load_rustls_config(&self.tls_setting)
384 .map_err(|e| ConfigError::TLSSettingError(e.to_string()))?;
385
386 let channel = match tls_config {
387 Some(tls) => {
388 let connector = tower::ServiceBuilder::new()
389 .layer_fn(move |s| {
390 let tls = tls.clone();
391
392 hyper_rustls::HttpsConnectorBuilder::new()
393 .with_tls_config(tls)
394 .https_or_http()
395 .enable_http2()
396 .wrap_connector(s)
397 })
398 .service(http);
399
400 builder.connect_with_connector_lazy(connector)
401 }
402 None => builder.connect_with_connector_lazy(http),
403 };
404
405 match &self.auth {
407 AuthenticationConfig::Basic(basic) => {
408 let auth_layer = basic
409 .get_client_layer()
410 .map_err(|e| ConfigError::AuthConfigError(e.to_string()))?;
411
412 if self.tls_setting.insecure {
414 warn!("Auth is enabled without TLS. This is not recommended.");
415 }
416
417 Ok(tower::ServiceBuilder::new()
418 .layer(SetRequestHeaderLayer::new(header_map))
419 .layer(auth_layer)
420 .service(channel)
421 .boxed())
422 }
423 AuthenticationConfig::Bearer(bearer) => {
424 let auth_layer = bearer
425 .get_client_layer()
426 .map_err(|e| ConfigError::AuthConfigError(e.to_string()))?;
427
428 if self.tls_setting.insecure {
430 warn!("Auth is enabled without TLS. This is not recommended.");
431 }
432
433 Ok(tower::ServiceBuilder::new()
434 .layer(SetRequestHeaderLayer::new(header_map))
435 .layer(auth_layer)
436 .service(channel)
437 .boxed())
438 }
439 AuthenticationConfig::None => Ok(tower::ServiceBuilder::new()
440 .layer(SetRequestHeaderLayer::new(header_map))
441 .service(channel)
442 .boxed()),
443 }
444 }
445}
446
447fn parse_rate_limit(rate_limit: &str) -> Result<(u64, Duration), ConfigError> {
453 let parts: Vec<&str> = rate_limit.split('/').collect();
454
455 if parts.len() != 2 {
457 return Err(
458 ConfigError::RateLimitParseError(
459 "rate limit should be in the format of <limit>/<duration>, with duration expressed in seconds".to_string(),
460 ),
461 );
462 }
463
464 let limit = parts[0]
465 .parse::<u64>()
466 .map_err(|e| ConfigError::RateLimitParseError(e.to_string()))?;
467 let duration = Duration::from_secs(
468 parts[1]
469 .parse::<u64>()
470 .map_err(|e| ConfigError::RateLimitParseError(e.to_string()))?,
471 );
472 Ok((limit, duration))
473}
474
#[cfg(test)]
mod test {
    use super::*;
    use tracing::debug;
    use tracing_test::traced_test;

    #[test]
    fn test_default_keepalive_config() {
        let keepalive = KeepaliveConfig::default();
        assert_eq!(keepalive.tcp_keepalive, Duration::from_secs(60));
        assert_eq!(keepalive.http2_keepalive, Duration::from_secs(60));
        assert_eq!(keepalive.timeout, Duration::from_secs(10));
        assert!(!keepalive.keep_alive_while_idle);
    }

    #[test]
    fn test_default_client_config() {
        let client = ClientConfig::default();
        assert_eq!(client.endpoint, String::new());
        assert_eq!(client.origin, None);
        assert_eq!(client.compression, None);
        assert_eq!(client.rate_limit, None);
        assert_eq!(client.tls_setting, TLSSetting::default());
        assert_eq!(client.keepalive, None);
        assert_eq!(client.connect_timeout, Duration::from_secs(0));
        assert_eq!(client.request_timeout, Duration::from_secs(0));
        assert_eq!(client.buffer_size, None);
        assert_eq!(client.headers, HashMap::new());
        assert_eq!(client.auth, AuthenticationConfig::None);
    }

    #[test]
    fn test_parse_rate_limit() {
        let (limit, duration) = parse_rate_limit("100/10").expect("valid rate limit");
        assert_eq!(limit, 100);
        assert_eq!(duration, Duration::from_secs(10));

        // Missing separator, too many separators, and non-numeric components.
        for invalid in ["", "100", "100/10/5", "abc/10", "100/abc"] {
            assert!(
                parse_rate_limit(invalid).is_err(),
                "expected {invalid:?} to be rejected"
            );
        }
    }

    #[tokio::test]
    #[traced_test]
    async fn test_to_channel() {
        let test_path: &str = env!("CARGO_MANIFEST_DIR");

        let mut client = ClientConfig::default();

        // An empty endpoint must be rejected.
        let mut channel = client.to_channel();
        assert!(channel.is_err());

        client.endpoint = "http://localhost:8080".to_string();
        channel = client.to_channel();
        assert!(channel.is_ok());

        client.tls_setting.insecure = true;
        channel = client.to_channel();
        assert!(channel.is_ok());

        // Build the CA path once so the logged path is the one actually used.
        let ca_file = format!("{}/testdata/grpc/{}", test_path, "ca.crt");
        debug!("{}", ca_file);
        client.tls_setting = {
            let mut tls = TLSSetting::default();
            tls.config.ca_file = Some(ca_file);
            tls.insecure = false;
            tls
        };
        channel = client.to_channel();
        assert!(channel.is_ok());

        client.keepalive = Some(KeepaliveConfig::default());
        channel = client.to_channel();
        assert!(channel.is_ok());

        client.rate_limit = Some("100/10".to_string());
        channel = client.to_channel();
        assert!(channel.is_ok());

        client.rate_limit = Some("100".to_string());
        channel = client.to_channel();
        assert!(channel.is_err());

        client.rate_limit = None;

        client.request_timeout = Duration::from_secs(10);
        channel = client.to_channel();
        assert!(channel.is_ok());

        client.buffer_size = Some(1024);
        channel = client.to_channel();
        assert!(channel.is_ok());

        client.origin = Some("http://example.com".to_string());
        channel = client.to_channel();
        assert!(channel.is_ok());

        client
            .headers
            .insert("X-Test".to_string(), "test".to_string());
        channel = client.to_channel();
        assert!(channel.is_ok());
    }

    #[tokio::test]
    async fn test_to_channel_rejects_invalid_header_name() {
        // A space is not a valid header-name character.
        let mut headers = HashMap::new();
        headers.insert("bad header".to_string(), "value".to_string());

        let client = ClientConfig::with_endpoint("http://localhost:8080").with_headers(headers);
        assert!(client.to_channel().is_err());
    }
}