Skip to main content

agp_config/grpc/
client.rs

1// Copyright AGNTCY Contributors (https://github.com/agntcy)
2// SPDX-License-Identifier: Apache-2.0
3
4use duration_str::deserialize_duration;
5use std::time::Duration;
6use std::{collections::HashMap, str::FromStr};
7use tower::ServiceExt;
8
9use http::header::{HeaderMap, HeaderName, HeaderValue};
10use hyper_rustls;
11use hyper_util::client::legacy::connect::HttpConnector;
12use serde::Deserialize;
13use tonic::codegen::{Body, Bytes, StdError};
14use tonic::transport::{Channel, Uri};
15use tracing::warn;
16
17use super::compression::CompressionType;
18use super::errors::ConfigError;
19use super::headers_middleware::SetRequestHeaderLayer;
20use crate::auth::ClientAuthenticator;
21use crate::auth::basic::Config as BasicAuthenticationConfig;
22use crate::auth::bearer::Config as BearerAuthenticationConfig;
23use crate::component::configuration::{Configuration, ConfigurationError};
24use crate::tls::{client::TlsClientConfig as TLSSetting, common::RustlsConfigLoader};
25
26/// Keepalive configuration for the client.
27/// This struct contains the keepalive time for TCP and HTTP2,
28/// the timeout duration for the keepalive, and whether to permit
29/// keepalive without an active stream.
30#[derive(Debug, Deserialize, PartialEq, Clone)]
31pub struct KeepaliveConfig {
32    /// The duration of the keepalive time for TCP
33    #[serde(
34        default = "default_tcp_keepalive",
35        deserialize_with = "deserialize_duration"
36    )]
37    pub tcp_keepalive: Duration,
38
39    /// The duration of the keepalive time for HTTP2
40    #[serde(
41        default = "default_http2_keepalive",
42        deserialize_with = "deserialize_duration"
43    )]
44    pub http2_keepalive: Duration,
45
46    /// The timeout duration for the keepalive
47    #[serde(default = "default_timeout", deserialize_with = "deserialize_duration")]
48    pub timeout: Duration,
49
50    /// Whether to permit keepalive without an active stream
51    #[serde(default = "default_keep_alive_while_idle")]
52    pub keep_alive_while_idle: bool,
53}
54
55/// Defaults for KeepaliveConfig
56impl Default for KeepaliveConfig {
57    fn default() -> Self {
58        KeepaliveConfig {
59            tcp_keepalive: default_tcp_keepalive(),
60            http2_keepalive: default_http2_keepalive(),
61            timeout: default_timeout(),
62            keep_alive_while_idle: default_keep_alive_while_idle(),
63        }
64    }
65}
66
/// Default for [`KeepaliveConfig::tcp_keepalive`]: 60 seconds.
fn default_tcp_keepalive() -> Duration {
    Duration::from_millis(60_000)
}

/// Default for [`KeepaliveConfig::http2_keepalive`]: 60 seconds.
fn default_http2_keepalive() -> Duration {
    Duration::from_millis(60_000)
}

/// Default for [`KeepaliveConfig::timeout`]: 10 seconds.
fn default_timeout() -> Duration {
    Duration::from_millis(10_000)
}

/// Default for [`KeepaliveConfig::keep_alive_while_idle`]: disabled.
fn default_keep_alive_while_idle() -> bool {
    bool::default()
}
82
83/// Enum holding one configuration for the client.
84#[derive(Debug, Default, Deserialize, Clone, PartialEq)]
85#[serde(rename_all = "snake_case")]
86pub enum AuthenticationConfig {
87    /// Basic authentication configuration.
88    Basic(BasicAuthenticationConfig),
89    /// Bearer authentication configuration.
90    Bearer(BearerAuthenticationConfig),
91    /// None
92    #[default]
93    None,
94}
95
96/// Struct for the client configuration.
97/// This struct contains the endpoint, origin, compression type, rate limit,
98/// TLS settings, keepalive settings, timeout settings, buffer size settings,
99/// headers, and auth settings.
100/// The client configuration can be converted to a tonic channel.
101#[derive(Debug, Deserialize, Clone, PartialEq)]
102pub struct ClientConfig {
103    /// The target the client will connect to.
104    pub endpoint: String,
105
106    /// Origin for the client.
107    pub origin: Option<String>,
108
109    /// Compression type - TODO(msardara): not implemented yet.
110    pub compression: Option<CompressionType>,
111
112    /// Rate Limits
113    pub rate_limit: Option<String>,
114
115    /// TLS client configuration.
116    #[serde(default, rename = "tls")]
117    pub tls_setting: TLSSetting,
118
119    /// Keepalive parameters.
120    pub keepalive: Option<KeepaliveConfig>,
121
122    /// Timeout for the connection.
123    #[serde(
124        default = "default_connect_timeout",
125        deserialize_with = "deserialize_duration"
126    )]
127    pub connect_timeout: Duration,
128
129    /// Timeout per request.
130    #[serde(
131        default = "default_request_timeout",
132        deserialize_with = "deserialize_duration"
133    )]
134    pub request_timeout: Duration,
135
136    /// ReadBufferSize.
137    pub buffer_size: Option<usize>,
138
139    /// The headers associated with gRPC requests.
140    #[serde(default)]
141    pub headers: HashMap<String, String>,
142
143    /// Auth configuration for outgoing RPCs.
144    #[serde(default)]
145    #[serde(with = "serde_yaml::with::singleton_map")]
146    pub auth: AuthenticationConfig,
147}
148
149/// Defaults for ClientConfig
150impl Default for ClientConfig {
151    fn default() -> Self {
152        ClientConfig {
153            endpoint: String::new(),
154            origin: None,
155            compression: None,
156            rate_limit: None,
157            tls_setting: TLSSetting::default(),
158            keepalive: None,
159            connect_timeout: default_connect_timeout(),
160            request_timeout: default_request_timeout(),
161            buffer_size: None,
162            headers: HashMap::new(),
163            auth: AuthenticationConfig::None,
164        }
165    }
166}
167
/// Default for [`ClientConfig::connect_timeout`]: zero, i.e. no connect timeout.
fn default_connect_timeout() -> Duration {
    Duration::ZERO
}

/// Default for [`ClientConfig::request_timeout`]: zero, i.e. no request timeout.
fn default_request_timeout() -> Duration {
    Duration::ZERO
}
175
176// Display for ClientConfig
177impl std::fmt::Display for ClientConfig {
178    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179        write!(
180            f,
181            "ClientConfig {{ endpoint: {}, origin: {:?}, compression: {:?}, rate_limit: {:?}, tls_setting: {:?}, keepalive: {:?}, connect_timeout: {:?}, request_timeout: {:?}, buffer_size: {:?}, headers: {:?}, auth: {:?} }}",
182            self.endpoint,
183            self.origin,
184            self.compression,
185            self.rate_limit,
186            self.tls_setting,
187            self.keepalive,
188            self.connect_timeout,
189            self.request_timeout,
190            self.buffer_size,
191            self.headers,
192            self.auth
193        )
194    }
195}
196
197impl Configuration for ClientConfig {
198    fn validate(&self) -> Result<(), ConfigurationError> {
199        // Validate the client configuration
200        self.tls_setting.validate()
201    }
202}
203
204impl ClientConfig {
205    /// Creates a new client configuration with the given endpoint.
206    /// This function will return a ClientConfig with the endpoint set
207    /// and all other fields set to default.
208    pub fn with_endpoint(endpoint: &str) -> Self {
209        Self {
210            endpoint: endpoint.to_string(),
211            ..Self::default()
212        }
213    }
214
215    pub fn with_origin(self, origin: &str) -> Self {
216        Self {
217            origin: Some(origin.to_string()),
218            ..self
219        }
220    }
221
222    pub fn with_compression(self, compression: CompressionType) -> Self {
223        Self {
224            compression: Some(compression),
225            ..self
226        }
227    }
228
229    pub fn with_rate_limit(self, rate_limit: &str) -> Self {
230        Self {
231            rate_limit: Some(rate_limit.to_string()),
232            ..self
233        }
234    }
235
236    pub fn with_tls_setting(self, tls_setting: TLSSetting) -> Self {
237        Self {
238            tls_setting,
239            ..self
240        }
241    }
242
243    pub fn with_keepalive(self, keepalive: KeepaliveConfig) -> Self {
244        Self {
245            keepalive: Some(keepalive),
246            ..self
247        }
248    }
249
250    pub fn with_connect_timeout(self, connect_timeout: Duration) -> Self {
251        Self {
252            connect_timeout,
253            ..self
254        }
255    }
256
257    pub fn with_request_timeout(self, request_timeout: Duration) -> Self {
258        Self {
259            request_timeout,
260            ..self
261        }
262    }
263
264    pub fn with_buffer_size(self, buffer_size: usize) -> Self {
265        Self {
266            buffer_size: Some(buffer_size),
267            ..self
268        }
269    }
270
271    pub fn with_headers(self, headers: HashMap<String, String>) -> Self {
272        Self { headers, ..self }
273    }
274
275    pub fn with_auth(self, auth: AuthenticationConfig) -> Self {
276        Self { auth, ..self }
277    }
278
279    /// Converts the client configuration to a tonic channel.
280    /// This function will return a Result with the channel if the configuration is valid.
281    /// If the configuration is invalid, it will return a ConfigError.
282    /// The function will set the headers, tls settings, keepalive settings, rate limit settings
283    /// timeout settings, buffer size settings, and origin settings.
284    pub fn to_channel(
285        &self,
286    ) -> Result<
287        impl tonic::client::GrpcService<
288            tonic::body::Body,
289            Error: Into<StdError> + Send,
290            ResponseBody: Body<Data = Bytes, Error: Into<StdError> + std::marker::Send>
291                              + Send
292                              + 'static,
293            Future: Send,
294        > + Send
295        + use<>,
296        ConfigError,
297    > {
298        // Make sure the endpoint is set and is valid
299        if self.endpoint.is_empty() {
300            return Err(ConfigError::MissingEndpoint);
301        }
302
303        // channel builder
304        let uri =
305            Uri::from_str(&self.endpoint).map_err(|e| ConfigError::UriParseError(e.to_string()))?;
306        let builder = Channel::builder(uri);
307
308        // HTTP2 connector. We need this to be able to use directly a rustls config
309        // cf. https://github.com/hyperium/tonic/issues/1615
310        let mut http = HttpConnector::new();
311
312        // NOTE(msardara): we might want to make these configurable as well.
313        http.enforce_http(false);
314        http.set_nodelay(false);
315
316        // set the connection timeout
317        match self.connect_timeout.as_secs() {
318            0 => http.set_connect_timeout(None),
319            _ => http.set_connect_timeout(Some(self.connect_timeout)),
320        }
321
322        // set the buffer size
323        let builder = match self.buffer_size {
324            Some(size) => builder.buffer_size(size),
325            None => builder,
326        };
327
328        // set keepalive settings
329        let builder = match &self.keepalive {
330            Some(keepalive) => {
331                // TCP level keepalive
332                http.set_keepalive(Some(keepalive.tcp_keepalive));
333
334                builder
335                    .keep_alive_timeout(keepalive.timeout)
336                    .keep_alive_while_idle(keepalive.keep_alive_while_idle)
337                    // HTTP level keepalive
338                    .http2_keep_alive_interval(keepalive.http2_keepalive)
339            }
340            None => builder,
341        };
342
343        // set origin settings
344        let builder = match &self.origin {
345            Some(origin) => {
346                let uri = Uri::from_str(origin.as_str())
347                    .map_err(|e| ConfigError::UriParseError(e.to_string()))?;
348
349                builder.origin(uri)
350            }
351            None => builder,
352        };
353
354        let builder = match &self.rate_limit {
355            Some(rate_limit) => {
356                let (limit, duration) = parse_rate_limit(rate_limit)
357                    .map_err(|e| ConfigError::RateLimitParseError(e.to_string()))?;
358                builder.rate_limit(limit, duration)
359            }
360            None => builder,
361        };
362
363        // set the request timeout
364        let builder = match self.request_timeout.as_secs() {
365            0 => builder,
366            _ => builder.timeout(self.request_timeout),
367        };
368
369        // set header to http connector
370        let mut header_map = HeaderMap::new();
371        for (key, value) in &self.headers {
372            let k: HeaderName = key.parse().map_err(|_| {
373                ConfigError::HeaderParseError(format!("error parsing header key {}", key))
374            })?;
375            let v: HeaderValue = value.parse().map_err(|_| {
376                ConfigError::HeaderParseError(format!("error parsing header value {}", key))
377            })?;
378
379            header_map.insert(k, v);
380        }
381
382        // TLS configuration
383        let tls_config = TLSSetting::load_rustls_config(&self.tls_setting)
384            .map_err(|e| ConfigError::TLSSettingError(e.to_string()))?;
385
386        let channel = match tls_config {
387            Some(tls) => {
388                let connector = tower::ServiceBuilder::new()
389                    .layer_fn(move |s| {
390                        let tls = tls.clone();
391
392                        hyper_rustls::HttpsConnectorBuilder::new()
393                            .with_tls_config(tls)
394                            .https_or_http()
395                            .enable_http2()
396                            .wrap_connector(s)
397                    })
398                    .service(http);
399
400                builder.connect_with_connector_lazy(connector)
401            }
402            None => builder.connect_with_connector_lazy(http),
403        };
404
405        // Auth configuration
406        match &self.auth {
407            AuthenticationConfig::Basic(basic) => {
408                let auth_layer = basic
409                    .get_client_layer()
410                    .map_err(|e| ConfigError::AuthConfigError(e.to_string()))?;
411
412                // If auth is enabled without TLS, print a warning
413                if self.tls_setting.insecure {
414                    warn!("Auth is enabled without TLS. This is not recommended.");
415                }
416
417                Ok(tower::ServiceBuilder::new()
418                    .layer(SetRequestHeaderLayer::new(header_map))
419                    .layer(auth_layer)
420                    .service(channel)
421                    .boxed())
422            }
423            AuthenticationConfig::Bearer(bearer) => {
424                let auth_layer = bearer
425                    .get_client_layer()
426                    .map_err(|e| ConfigError::AuthConfigError(e.to_string()))?;
427
428                // If auth is enabled without TLS, print a warning
429                if self.tls_setting.insecure {
430                    warn!("Auth is enabled without TLS. This is not recommended.");
431                }
432
433                Ok(tower::ServiceBuilder::new()
434                    .layer(SetRequestHeaderLayer::new(header_map))
435                    .layer(auth_layer)
436                    .service(channel)
437                    .boxed())
438            }
439            AuthenticationConfig::None => Ok(tower::ServiceBuilder::new()
440                .layer(SetRequestHeaderLayer::new(header_map))
441                .service(channel)
442                .boxed()),
443        }
444    }
445}
446
447/// Parse the rate limit string into a limit and a duration.
448/// The rate limit string should be in the format of <limit>/<duration>,
449/// with duration expressed in seconds.
450/// This function will return a Result with the limit and duration if the
451/// rate limit is valid.
452fn parse_rate_limit(rate_limit: &str) -> Result<(u64, Duration), ConfigError> {
453    let parts: Vec<&str> = rate_limit.split('/').collect();
454
455    // Check the parts has two elements
456    if parts.len() != 2 {
457        return Err(
458            ConfigError::RateLimitParseError(
459                "rate limit should be in the format of <limit>/<duration>, with duration expressed in seconds".to_string(),
460            ),
461        );
462    }
463
464    let limit = parts[0]
465        .parse::<u64>()
466        .map_err(|e| ConfigError::RateLimitParseError(e.to_string()))?;
467    let duration = Duration::from_secs(
468        parts[1]
469            .parse::<u64>()
470            .map_err(|e| ConfigError::RateLimitParseError(e.to_string()))?,
471    );
472    Ok((limit, duration))
473}
474
#[cfg(test)]
mod test {
    use super::*;
    use tracing::debug;
    use tracing_test::traced_test;

    #[test]
    fn test_default_keepalive_config() {
        let keepalive = KeepaliveConfig::default();
        assert_eq!(keepalive.tcp_keepalive, Duration::from_secs(60));
        assert_eq!(keepalive.http2_keepalive, Duration::from_secs(60));
        assert_eq!(keepalive.timeout, Duration::from_secs(10));
        assert!(!keepalive.keep_alive_while_idle);
    }

    #[test]
    fn test_default_client_config() {
        let client = ClientConfig::default();
        assert_eq!(client.endpoint, String::new());
        assert_eq!(client.origin, None);
        assert_eq!(client.compression, None);
        assert_eq!(client.rate_limit, None);
        assert_eq!(client.tls_setting, TLSSetting::default());
        assert_eq!(client.keepalive, None);
        assert_eq!(client.connect_timeout, Duration::from_secs(0));
        assert_eq!(client.request_timeout, Duration::from_secs(0));
        assert_eq!(client.buffer_size, None);
        assert_eq!(client.headers, HashMap::new());
        assert_eq!(client.auth, AuthenticationConfig::None);
    }

    #[test]
    fn test_parse_rate_limit() {
        let (limit, duration) =
            parse_rate_limit("100/10").expect("\"100/10\" is a valid rate limit");
        assert_eq!(limit, 100);
        assert_eq!(duration, Duration::from_secs(10));

        // Missing or extra separators are rejected
        assert!(parse_rate_limit("100").is_err());
        assert!(parse_rate_limit("1/2/3").is_err());

        // Non-numeric components are rejected
        assert!(parse_rate_limit("abc/10").is_err());
        assert!(parse_rate_limit("100/abc").is_err());
    }

    #[tokio::test]
    #[traced_test]
    async fn test_to_channel() {
        let test_path: &str = env!("CARGO_MANIFEST_DIR");

        // create a new client config
        let mut client = ClientConfig::default();

        // as the endpoint is missing, this should fail
        let mut channel = client.to_channel();
        assert!(channel.is_err());

        // Set the endpoint
        client.endpoint = "http://localhost:8080".to_string();
        channel = client.to_channel();
        assert!(channel.is_ok());

        // Disable TLS
        client.tls_setting.insecure = true;
        channel = client.to_channel();
        assert!(channel.is_ok());

        // Enable TLS with a custom CA file
        let ca_file = format!("{}/testdata/grpc/{}", test_path, "ca.crt");
        debug!("using CA file {}", ca_file);
        client.tls_setting = {
            let mut tls = TLSSetting::default();
            tls.config.ca_file = Some(ca_file);
            tls.insecure = false;
            tls
        };
        channel = client.to_channel();
        assert!(channel.is_ok());

        // Set keepalive settings
        client.keepalive = Some(KeepaliveConfig::default());
        channel = client.to_channel();
        assert!(channel.is_ok());

        // Set rate limit settings
        client.rate_limit = Some("100/10".to_string());
        channel = client.to_channel();
        assert!(channel.is_ok());

        // Set rate limit settings wrong
        client.rate_limit = Some("100".to_string());
        channel = client.to_channel();
        assert!(channel.is_err());

        // reset config
        client.rate_limit = None;

        // Set timeout settings
        client.request_timeout = Duration::from_secs(10);
        channel = client.to_channel();
        assert!(channel.is_ok());

        // Sub-second timeouts must be accepted as well
        client.connect_timeout = Duration::from_millis(500);
        client.request_timeout = Duration::from_millis(500);
        channel = client.to_channel();
        assert!(channel.is_ok());

        // Set buffer size settings
        client.buffer_size = Some(1024);
        channel = client.to_channel();
        assert!(channel.is_ok());

        // Set origin settings
        client.origin = Some("http://example.com".to_string());
        channel = client.to_channel();
        assert!(channel.is_ok());

        // set additional header to add to the request
        client
            .headers
            .insert("X-Test".to_string(), "test".to_string());
        channel = client.to_channel();
        assert!(channel.is_ok());
    }
}