Skip to main content

reqsign_aws_v4/provide_credential/
s3_express_session.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::Credential;
19use async_trait::async_trait;
20use bytes::Bytes;
21use http::{Method, Request, header};
22use log::debug;
23use reqsign_core::{Context, Error, ProvideCredential, Result, SignRequest};
24use serde::Deserialize;
25
26/// S3 Express One Zone session provider that creates session credentials.
27///
28/// This provider implements the CreateSession API for S3 Express One Zone buckets,
29/// which provides low-latency access through temporary session-based credentials.
30///
31/// # Important
32///
33/// - The session token returned by this provider should be used with the
34///   `x-amz-s3session-token` header instead of the standard `x-amz-security-token`
35///   header when making requests to S3 Express One Zone buckets.
36/// - This provider does not cache sessions internally. The upper layer (e.g., Signer)
37///   handles credential caching and will request new sessions when they expire.
38///
39/// # Example
40///
41/// ```no_run
42/// use reqsign_aws_v4::{S3ExpressSessionProvider, DefaultCredentialProvider};
43/// use reqsign_core::ProvideCredential;
44///
45/// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
46/// let provider = S3ExpressSessionProvider::new(
47///     "my-bucket--usw2-az1--x-s3",
48///     DefaultCredentialProvider::new(),
49/// );
50///
51/// // Each call to provide_credential creates a new session
52/// # Ok(())
53/// # }
54/// ```
55#[derive(Debug)]
56pub struct S3ExpressSessionProvider {
57    bucket: String,
58    base_provider: Box<dyn ProvideCredential<Credential = Credential>>,
59}
60
61#[derive(Debug, Deserialize)]
62#[serde(rename = "CreateSessionResult", rename_all = "PascalCase")]
63struct CreateSessionResponse {
64    credentials: SessionCredentials,
65}
66
67#[derive(Debug, Deserialize)]
68#[serde(rename_all = "PascalCase")]
69struct SessionCredentials {
70    session_token: String,
71    secret_access_key: String,
72    access_key_id: String,
73    expiration: String,
74}
75
76impl S3ExpressSessionProvider {
77    /// Create a new S3 Express session provider for a specific bucket.
78    ///
79    /// # Arguments
80    ///
81    /// * `bucket` - The S3 Express One Zone bucket name (e.g., "my-bucket--usw2-az1--x-s3")
82    /// * `provider` - The base credential provider to use for CreateSession API calls
83    pub fn new(
84        bucket: impl Into<String>,
85        provider: impl ProvideCredential<Credential = Credential> + 'static,
86    ) -> Self {
87        Self {
88            bucket: bucket.into(),
89            base_provider: Box::new(provider),
90        }
91    }
92
93    /// Create a new session for the bucket using the CreateSession API.
94    async fn create_session(&self, ctx: &Context, base_cred: &Credential) -> Result<Credential> {
95        debug!(
96            "Creating new S3 Express session for bucket: {}",
97            self.bucket
98        );
99
100        // Extract region from bucket name (format: name--azid--x-s3)
101        let parts: Vec<&str> = self.bucket.split("--").collect();
102        if parts.len() != 3 || !parts[2].ends_with("x-s3") {
103            return Err(Error::unexpected(format!(
104                "Invalid S3 Express bucket name format: {}",
105                self.bucket
106            )));
107        }
108
109        // Extract region from AZ ID (e.g., "usw2-az1" -> "us-west-2")
110        let az_id = parts[1];
111        let region = self.parse_region_from_az_id(az_id)?;
112
113        // Build CreateSession request
114        let url = format!(
115            "https://{}.s3express-{}.amazonaws.com/?session",
116            self.bucket, az_id
117        );
118        let req = Request::builder()
119            .method(Method::GET)
120            .uri(&url)
121            .header(
122                header::HOST,
123                format!("{}.s3express-{}.amazonaws.com", self.bucket, az_id),
124            )
125            .header("x-amz-content-sha256", crate::EMPTY_STRING_SHA256)
126            .header("x-amz-create-session-mode", "ReadWrite")
127            .body(Bytes::new())
128            .map_err(|e| Error::unexpected(format!("Failed to build request: {e}")))?;
129
130        // Sign the request using base credentials
131        let (mut parts, body) = req.into_parts();
132        let signer = crate::RequestSigner::new("s3express", &region);
133        signer
134            .sign_request(ctx, &mut parts, Some(base_cred), None)
135            .await?;
136
137        // Send the request
138        let req = Request::from_parts(parts, body);
139        let resp = ctx.http_send(req).await?;
140
141        // Check response status
142        let status = resp.status();
143        if !status.is_success() {
144            let body = resp.into_body();
145            let error_msg = String::from_utf8_lossy(&body);
146            return Err(Error::unexpected(format!(
147                "CreateSession failed with status {status}: {error_msg}"
148            )));
149        }
150
151        // Parse XML response
152        let body = resp.into_body();
153        let body_str = String::from_utf8_lossy(&body);
154        debug!("CreateSession response body: {body_str}");
155
156        let create_session_resp: CreateSessionResponse = quick_xml::de::from_str(&body_str)
157            .map_err(|e| {
158                Error::unexpected(format!("Failed to parse CreateSession XML response: {e}"))
159            })?;
160
161        // Parse expiration time from ISO8601 format
162        let expiration = create_session_resp
163            .credentials
164            .expiration
165            .parse()
166            .map_err(|e| {
167                Error::unexpected(format!(
168                    "failed to parse expiration time '{}': {e}",
169                    create_session_resp.credentials.expiration
170                ))
171            })?;
172
173        // Convert to Credential with expiration time
174        let creds = create_session_resp.credentials;
175        Ok(Credential {
176            access_key_id: creds.access_key_id,
177            secret_access_key: creds.secret_access_key,
178            session_token: Some(creds.session_token),
179            expires_in: Some(expiration),
180        })
181    }
182
183    /// Parse region from AZ ID (e.g., "usw2-az1" -> "us-west-2")
184    fn parse_region_from_az_id(&self, az_id: &str) -> Result<String> {
185        // Common region mappings
186        let region = match az_id {
187            az if az.starts_with("use1-") => "us-east-1",
188            az if az.starts_with("use2-") => "us-east-2",
189            az if az.starts_with("usw1-") => "us-west-1",
190            az if az.starts_with("usw2-") => "us-west-2",
191            az if az.starts_with("euw1-") => "eu-west-1",
192            az if az.starts_with("euc1-") => "eu-central-1",
193            az if az.starts_with("apne1-") => "ap-northeast-1",
194            az if az.starts_with("apse1-") => "ap-southeast-1",
195            az if az.starts_with("apse2-") => "ap-southeast-2",
196            _ => return Err(Error::unexpected(format!("Unknown AZ ID format: {az_id}"))),
197        };
198        Ok(region.to_string())
199    }
200}
201
202#[async_trait]
203impl ProvideCredential for S3ExpressSessionProvider {
204    type Credential = Credential;
205
206    async fn provide_credential(&self, ctx: &Context) -> Result<Option<Self::Credential>> {
207        debug!("Creating S3 Express session for bucket: {}", self.bucket);
208
209        // Get base credentials - required for S3 Express
210        let base_cred = self.base_provider.provide_credential(ctx).await?
211            .ok_or_else(|| {
212                Error::unexpected(
213                    "No base credentials found. S3 Express requires valid AWS credentials to create sessions"
214                )
215            })?;
216
217        // Create new session
218        let session_cred = self.create_session(ctx, &base_cred).await?;
219
220        Ok(Some(session_cred))
221    }
222}
223
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a provider backed by static dummy credentials.
    fn make_provider(bucket: &str) -> S3ExpressSessionProvider {
        S3ExpressSessionProvider::new(
            bucket,
            crate::StaticCredentialProvider::new("test", "test"),
        )
    }

    #[test]
    fn test_parse_region_from_az_id() {
        let provider = make_provider("test--usw2-az1--x-s3");

        let cases = [
            ("usw2-az1", "us-west-2"),
            ("use1-az4", "us-east-1"),
            ("euw1-az2", "eu-west-1"),
            ("apne1-az4", "ap-northeast-1"),
            ("apse2-az1", "ap-southeast-2"),
        ];
        for (az_id, expected) in cases {
            assert_eq!(
                provider.parse_region_from_az_id(az_id).unwrap(),
                expected,
                "az_id: {az_id}"
            );
        }
    }

    #[test]
    fn test_parse_region_from_unknown_az_id() {
        let provider = make_provider("test--usw2-az1--x-s3");

        // No `-` separator, unknown region codes, and empty input must all fail.
        for az_id in ["", "usw2", "az1", "xyz9-az1"] {
            assert!(
                provider.parse_region_from_az_id(az_id).is_err(),
                "expected error for az_id: {az_id:?}"
            );
        }
    }

    #[test]
    fn test_invalid_bucket_format() {
        // Bucket-name validation happens lazily in `create_session`; constructing
        // the provider must succeed and keep the name verbatim.
        let provider = make_provider("invalid-bucket-name");
        assert_eq!(provider.bucket, "invalid-bucket-name");
    }

    #[test]
    fn test_parse_create_session_response() {
        let xml = r#"<?xml version="1.0" encoding="UTF-8"?>
            <CreateSessionResult xmlns="http://s3.amazonaws.com/doc/2006-03-01/">
                <Credentials>
                    <SessionToken>TESTSESSIONTOKEN</SessionToken>
                    <SecretAccessKey>TESTSECRETKEY</SecretAccessKey>
                    <AccessKeyId>ASIARTESTID</AccessKeyId>
                    <Expiration>2024-01-29T18:53:01Z</Expiration>
                </Credentials>
            </CreateSessionResult>"#;

        let response: CreateSessionResponse = quick_xml::de::from_str(xml).unwrap();
        assert_eq!(response.credentials.access_key_id, "ASIARTESTID");
        assert_eq!(response.credentials.secret_access_key, "TESTSECRETKEY");
        assert_eq!(response.credentials.session_token, "TESTSESSIONTOKEN");
        assert_eq!(response.credentials.expiration, "2024-01-29T18:53:01Z");
    }

    #[test]
    fn test_parse_create_session_response_missing_field() {
        // `Expiration` is a required field; its absence must be a parse error.
        let xml = r#"<CreateSessionResult>
                <Credentials>
                    <SessionToken>TESTSESSIONTOKEN</SessionToken>
                    <SecretAccessKey>TESTSECRETKEY</SecretAccessKey>
                    <AccessKeyId>ASIARTESTID</AccessKeyId>
                </Credentials>
            </CreateSessionResult>"#;

        assert!(quick_xml::de::from_str::<CreateSessionResponse>(xml).is_err());
    }
}