mountpoint_s3_client/
endpoint_config.rs

1use std::os::unix::prelude::OsStrExt;
2use std::sync::LazyLock;
3use std::time::Instant;
4
5use mountpoint_s3_crt::{
6    common::allocator::Allocator,
7    s3::endpoint_resolver::{RequestContext, ResolvedEndpoint, ResolverError, RuleEngine},
8};
9use thiserror::Error;
10
11pub use mountpoint_s3_crt::auth::signing_config::SigningAlgorithm;
12pub use mountpoint_s3_crt::common::uri::Uri;
13
14/// A static s3 endpoint rule engine that can be shared across s3 client
15static S3_ENDPOINT_RULE_ENGINE: LazyLock<RuleEngine> = LazyLock::new(|| RuleEngine::new(&Default::default()).unwrap());
16
/// How the bucket name is placed in the request URL when resolving an S3 endpoint.
///
/// Only [AddressingStyle::Path] affects resolution: it sets the `ForcePathStyle` rule parameter.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum AddressingStyle {
    /// Prefer virtual-hosted addressing, falling back to path addressing when necessary.
    /// This is the default style.
    #[default]
    Automatic,
    /// Always use path addressing.
    Path,
}
25
26#[derive(Debug, Clone)]
27pub struct AuthScheme {
28    disable_double_encoding: bool,
29    scheme_name: SigningAlgorithm,
30    signing_name: String,
31    signing_region: String,
32}
33
34impl AuthScheme {
35    /// Get the siging name from [AuthScheme]
36    pub fn signing_name(&self) -> &str {
37        &self.signing_name
38    }
39
40    /// Get the signing region from [AuthScheme]
41    pub fn signing_region(&self) -> &str {
42        &self.signing_region
43    }
44
45    /// Get Disable double encoding value for [AuthScheme]
46    pub fn disable_double_encoding(&self) -> bool {
47        self.disable_double_encoding
48    }
49
50    /// Get the name of [AuthScheme]
51    pub fn scheme_name(&self) -> SigningAlgorithm {
52        self.scheme_name
53    }
54}
55
56/// Configuration for resolution of S3 endpoints
57#[derive(Debug, Clone)]
58pub struct EndpointConfig {
59    region: String,
60    use_fips: bool,
61    use_accelerate: bool,
62    use_dual_stack: bool,
63    endpoint: Option<Uri>,
64    addressing_style: AddressingStyle,
65}
66
67impl EndpointConfig {
68    /// Create a new endpoint configuration for a given region
69    pub fn new(region: &str) -> Self {
70        Self {
71            region: region.to_owned(),
72            use_fips: false,
73            use_accelerate: false,
74            use_dual_stack: false,
75            endpoint: None,
76            addressing_style: AddressingStyle::Automatic,
77        }
78    }
79
80    /// Set region for a given endpoint config
81    #[must_use = "EndpointConfig follows a builder pattern"]
82    pub fn region(mut self, region: &str) -> Self {
83        region.clone_into(&mut self.region);
84        self
85    }
86
87    /// use FIPS config for S3
88    #[must_use = "EndpointConfig follows a builder pattern"]
89    pub fn use_fips(mut self, fips: bool) -> Self {
90        self.use_fips = fips;
91        self
92    }
93
94    /// use Transfer Acceleration config for S3
95    #[must_use = "EndpointConfig follows a builder pattern"]
96    pub fn use_accelerate(mut self, accelerate: bool) -> Self {
97        self.use_accelerate = accelerate;
98        self
99    }
100
101    /// use dual stack config for S3
102    #[must_use = "EndpointConfig follows a builder pattern"]
103    pub fn use_dual_stack(mut self, dual_stack: bool) -> Self {
104        self.use_dual_stack = dual_stack;
105        self
106    }
107
108    /// Set predefined url for endpoint configuration
109    #[must_use = "EndpointConfig follows a builder pattern"]
110    pub fn endpoint(mut self, endpoint: Uri) -> Self {
111        self.endpoint = Some(endpoint);
112        self
113    }
114
115    /// Set addressing style for [EndpointConfig]
116    #[must_use = "EndpointConfig follows a builder pattern"]
117    pub fn addressing_style(mut self, addressing_style: AddressingStyle) -> Self {
118        self.addressing_style = addressing_style;
119        self
120    }
121
122    /// get the region from the [EndpointConfig]
123    pub fn get_region(&self) -> &str {
124        &self.region
125    }
126
127    /// get the fips config from the [EndpointConfig]
128    pub fn get_fips(&self) -> bool {
129        self.use_fips
130    }
131
132    /// get the transfer acceleration config from the [EndpointConfig]
133    pub fn get_accelerate(&self) -> bool {
134        self.use_accelerate
135    }
136
137    /// get the dual stack config from the [EndpointConfig]
138    pub fn get_dual_stack(&self) -> bool {
139        self.use_dual_stack
140    }
141
142    /// get the endpoint uri if provided from [EndpointConfig]
143    pub fn get_endpoint(&self) -> Option<Uri> {
144        self.endpoint.clone()
145    }
146
147    /// get the addressing style from the [EndpointConfig]
148    pub fn get_addressing_style(&self) -> AddressingStyle {
149        self.addressing_style
150    }
151
152    /// resolve the endpoint from the [EndpointConfig] and the bucket name
153    pub fn resolve_for_bucket(&self, bucket: &str) -> Result<ResolvedEndpointInfo, EndpointError> {
154        let allocator = Allocator::default();
155        let mut endpoint_request_context: RequestContext = RequestContext::new(&allocator).unwrap();
156
157        endpoint_request_context
158            .add_string(&allocator, "Region", &self.region)
159            .unwrap();
160        endpoint_request_context
161            .add_string(&allocator, "Bucket", bucket)
162            .unwrap();
163        if let Some(endpoint_uri) = &self.endpoint {
164            endpoint_request_context
165                .add_string(&allocator, "Endpoint", endpoint_uri.as_os_str())
166                .unwrap()
167        };
168        if self.use_fips {
169            endpoint_request_context
170                .add_boolean(&allocator, "UseFIPS", true)
171                .unwrap()
172        };
173        if self.use_dual_stack {
174            endpoint_request_context
175                .add_boolean(&allocator, "UseDualStack", true)
176                .unwrap()
177        };
178        if self.use_accelerate {
179            endpoint_request_context
180                .add_boolean(&allocator, "Accelerate", true)
181                .unwrap()
182        };
183        if self.addressing_style == AddressingStyle::Path {
184            endpoint_request_context
185                .add_boolean(&allocator, "ForcePathStyle", true)
186                .unwrap()
187        };
188
189        let resolved_endpoint = {
190            let start_time = Instant::now();
191            let endpoint_result = S3_ENDPOINT_RULE_ENGINE.resolve(endpoint_request_context);
192            metrics::histogram!("s3.endpoint_resolution_us").record(start_time.elapsed().as_micros() as f64);
193            endpoint_result.map_err(EndpointError::UnresolvedEndpoint)?
194        };
195
196        Ok(ResolvedEndpointInfo(resolved_endpoint))
197    }
198}
199
200/// Wrapper for [ResolvedEndpoint] from CRT to get [Uri] and [AuthScheme]
201#[derive(Debug)]
202pub struct ResolvedEndpointInfo(ResolvedEndpoint);
203
204impl ResolvedEndpointInfo {
205    /// Get the [Uri] from [ResolvedEndpointInfo]
206    pub fn uri(&self) -> Result<Uri, EndpointError> {
207        let allocator = Allocator::default();
208        let endpoint_url = self.0.get_url();
209        Uri::new_from_str(&allocator, endpoint_url)
210            .map_err(|e| EndpointError::InvalidUri(InvalidUriError::CouldNotParse(e)))
211    }
212
213    /// Get the [AuthScheme] from [ResolvedEndpointInfo] for the signing config
214    pub fn auth_scheme(&self) -> Result<AuthScheme, EndpointError> {
215        // ResolvedEndpoint is wrapper for aws_endpoints_resolved_endpoint which has url, properties and header for the endpoint.
216        // Property if in json format containing the AuthScheme. Egs. -
217        // {\"authSchemes\":[{\"disableDoubleEncoding\":true,\"name\":\"sigv4\",\"signingName\":\"s3\",\"signingRegion\":\"us-east-2\"}]}
218        let endpoint_properties = self.0.get_properties();
219        let auth_scheme_data: serde_json::Value = serde_json::from_slice(endpoint_properties.as_bytes())?;
220        let auth_scheme_value = auth_scheme_data["authSchemes"]
221            .get(0)
222            .ok_or(EndpointError::MissingAuthSchemeField("authSchemes"))?;
223        let disable_double_encoding = auth_scheme_value["disableDoubleEncoding"]
224            .as_bool()
225            .ok_or(EndpointError::MissingAuthSchemeField("disableDoubleEncoding"))?;
226        let scheme_name = auth_scheme_value["name"]
227            .as_str()
228            .ok_or(EndpointError::MissingAuthSchemeField("name"))?;
229        let scheme_name = match scheme_name {
230            "sigv4" => SigningAlgorithm::SigV4,
231            "sigv4a" => SigningAlgorithm::SigV4A,
232            "sigv4-s3express" => SigningAlgorithm::SigV4Express,
233            _ => return Err(EndpointError::InvalidAuthSchemeField("name", scheme_name.to_owned())),
234        };
235
236        let signing_name = auth_scheme_value["signingName"]
237            .as_str()
238            .ok_or(EndpointError::MissingAuthSchemeField("signingName"))?;
239        let signing_region = auth_scheme_value
240            .get("signingRegion")
241            .or_else(|| auth_scheme_value["signingRegionSet"].get(0))
242            .and_then(|t| t.as_str())
243            .ok_or(EndpointError::MissingAuthSchemeField(
244                "signingRegion or signingRegionSet",
245            ))?;
246
247        Ok(AuthScheme {
248            disable_double_encoding,
249            scheme_name,
250            signing_name: signing_name.to_owned(),
251            signing_region: signing_region.to_owned(),
252        })
253    }
254}
255
256#[derive(Debug, Error)]
257pub enum EndpointError {
258    #[error("invalid URI")]
259    InvalidUri(#[from] InvalidUriError),
260    #[error("endpoint could not be resolved")]
261    UnresolvedEndpoint(#[from] ResolverError),
262    #[error("Endpoint properties could not be parsed")]
263    ParseError(#[from] serde_json::Error),
264    #[error("AuthScheme field missing: {0}")]
265    MissingAuthSchemeField(&'static str),
266    #[error("invalid value {1} for AuthScheme field {0}")]
267    InvalidAuthSchemeField(&'static str, String),
268}
269
270#[derive(Debug, Error)]
271pub enum InvalidUriError {
272    #[error("URI could not be parsed")]
273    CouldNotParse(#[from] mountpoint_s3_crt::common::error::Error),
274}
275
#[cfg(test)]
mod test {
    use super::*;

    /// Resolve `bucket` with `config` and return the endpoint URI, panicking on any error.
    fn resolve_uri(config: &EndpointConfig, bucket: &str) -> Uri {
        config.resolve_for_bucket(bucket).unwrap().uri().unwrap()
    }

    /// Resolve `bucket` with `config` and return its auth scheme, panicking on any error.
    fn resolve_auth_scheme(config: &EndpointConfig, bucket: &str) -> AuthScheme {
        config.resolve_for_bucket(bucket).unwrap().auth_scheme().unwrap()
    }

    /// Parse `url` into a [Uri] for use as a custom endpoint.
    fn parse_uri(url: &str) -> Uri {
        Uri::new_from_str(&Allocator::default(), url).unwrap()
    }

    #[test]
    fn test_virtual_addr() {
        let config = EndpointConfig::new("eu-west-1").addressing_style(AddressingStyle::Automatic);
        let uri = resolve_uri(&config, "amzn-s3-demo-bucket");
        assert_eq!(
            "https://amzn-s3-demo-bucket.s3.eu-west-1.amazonaws.com",
            uri.as_os_str()
        );
    }

    #[test]
    fn test_path_addr_endpoint_arg() {
        let config = EndpointConfig::new("eu-west-1")
            .addressing_style(AddressingStyle::Path)
            .endpoint(parse_uri("https://example.com"));
        let uri = resolve_uri(&config, "amzn-s3-demo-bucket");
        assert_eq!("https://example.com/amzn-s3-demo-bucket", uri.as_os_str());
    }

    #[test]
    fn test_endpoint_arg_with_region() {
        let config = EndpointConfig::new("us-east-1").endpoint(parse_uri("https://s3.eu-west-1.amazonaws.com"));
        let resolved = config.resolve_for_bucket("amzn-s3-demo-bucket").unwrap();

        // The configured region does not change the host when a custom endpoint is given.
        let uri = resolved.uri().unwrap();
        assert_eq!(
            "https://amzn-s3-demo-bucket.s3.eu-west-1.amazonaws.com",
            uri.as_os_str()
        );

        // Signing still uses the configured region.
        let auth_scheme = resolved.auth_scheme().unwrap();
        assert_eq!(auth_scheme.signing_region(), "us-east-1");
    }

    #[test]
    fn test_fips_dual_stack() {
        let config = EndpointConfig::new("eu-west-1").use_fips(true).use_dual_stack(true);
        let uri = resolve_uri(&config, "amzn-s3-demo-bucket");
        assert_eq!(
            "https://amzn-s3-demo-bucket.s3-fips.dualstack.eu-west-1.amazonaws.com",
            uri.as_os_str()
        );
    }

    #[test]
    fn test_dual_stack_accelerate() {
        let config = EndpointConfig::new("eu-west-1")
            .use_accelerate(true)
            .use_dual_stack(true);
        let uri = resolve_uri(&config, "amzn-s3-demo-bucket");
        assert_eq!(
            "https://amzn-s3-demo-bucket.s3-accelerate.dualstack.amazonaws.com",
            uri.as_os_str()
        );
    }

    #[test]
    fn test_dual_stack_path_addr() {
        let config = EndpointConfig::new("eu-west-1")
            .use_dual_stack(true)
            .addressing_style(AddressingStyle::Path);
        let uri = resolve_uri(&config, "amzn-s3-demo-bucket");
        assert_eq!(
            "https://s3.dualstack.eu-west-1.amazonaws.com/amzn-s3-demo-bucket",
            uri.as_os_str()
        );
    }

    #[test]
    fn test_arn_as_bucket() {
        let config = EndpointConfig::new("eu-west-1");
        let uri = resolve_uri(&config, "arn:aws:s3::accountID:accesspoint/s3-bucket-test.mrap");
        assert_eq!(
            "https://s3-bucket-test.mrap.accesspoint.s3-global.amazonaws.com",
            uri.as_os_str()
        );
    }

    #[test]
    fn test_arn_override_region() {
        // The ARN's region (cn-north-2) overrides the configured one; also covers a China region.
        let config = EndpointConfig::new("cn-north-1");
        let uri = resolve_uri(
            &config,
            "arn:aws-cn:s3:cn-north-2:555555555555:accesspoint/china-region-ap",
        );
        assert_eq!(
            "https://china-region-ap-555555555555.s3-accesspoint.cn-north-2.amazonaws.com.cn",
            uri.as_os_str()
        );
    }

    #[test]
    fn test_outpost() {
        let config = EndpointConfig::new("us-gov-west-1");
        let uri = resolve_uri(
            &config,
            "mountpoint-o-o000s3-bucket-test0000000000000000000000000--op-s3",
        );
        assert_eq!(
            "https://mountpoint-o-o000s3-bucket-test0000000000000000000000000--op-s3.op-000s3-bucket-test.s3-outposts.us-gov-west-1.amazonaws.com",
            uri.as_os_str()
        );
    }

    #[test]
    fn test_bucket_arn() {
        let config = EndpointConfig::new("eu-west-1");
        let err = config.resolve_for_bucket("arn:aws:s3:::testbucket").unwrap_err();
        match err {
            EndpointError::UnresolvedEndpoint(ResolverError::EndpointNotResolved(message)) => {
                let expected =
                    "Invalid ARN: Unrecognized format: arn:aws:s3:::testbucket (type: testbucket)".to_owned();
                assert_eq!(message, expected);
            }
            other => panic!("unexpected endpoint error: {other:?}"),
        }
    }

    #[test]
    fn test_auth_scheme_for_arn() {
        let config = EndpointConfig::new("eu-west-1");
        let auth_scheme = resolve_auth_scheme(&config, "arn:aws:s3::accountID:accesspoint/s3-bucket-test.mrap");
        assert_eq!(auth_scheme.signing_region(), "*");
        assert_eq!(auth_scheme.signing_name(), "s3");
    }

    #[test]
    fn test_auth_scheme_for_bucket() {
        let config = EndpointConfig::new("eu-west-1");
        let auth_scheme = resolve_auth_scheme(&config, "test-bucket");
        assert_eq!(auth_scheme.signing_region(), "eu-west-1");
        assert_eq!(auth_scheme.signing_name(), "s3");
    }
}