reqsign_core/
api.rs

1// Licensed to the Apache Software Foundation (ASF) under one
2// or more contributor license agreements.  See the NOTICE file
3// distributed with this work for additional information
4// regarding copyright ownership.  The ASF licenses this file
5// to you under the Apache License, Version 2.0 (the
6// "License"); you may not use this file except in compliance
7// with the License.  You may obtain a copy of the License at
8//
9//   http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing,
12// software distributed under the License is distributed on an
13// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
14// KIND, either express or implied.  See the License for the
15// specific language governing permissions and limitations
16// under the License.
17
18use crate::{Context, Result};
19use std::fmt::Debug;
20use std::time::Duration;
21
/// Credential material that a signer consumes when signing requests.
///
/// Implementors must be `Clone`, `Debug`, `Send`, `Sync`, `Unpin` and `'static`.
pub trait SigningCredential
where
    Self: Clone + Debug + Send + Sync + Unpin + 'static,
{
    /// Returns `true` if this credential is valid for signing.
    fn is_valid(&self) -> bool;
}
27
28impl<T: SigningCredential> SigningCredential for Option<T> {
29    fn is_valid(&self) -> bool {
30        let Some(ctx) = self else {
31            return false;
32        };
33
34        ctx.is_valid()
35    }
36}
37
38/// ProvideCredential is the trait used by signer to load the credential from the environment.
39///`
40/// Service may require different credential to sign the request, for example, AWS require
41/// access key and secret key, while Google Cloud Storage require token.
42#[async_trait::async_trait]
43pub trait ProvideCredential: Debug + Send + Sync + Unpin + 'static {
44    /// Credential returned by this loader.
45    ///
46    /// Typically, it will be a credential.
47    type Credential: Send + Sync + Unpin + 'static;
48
49    /// Load signing credential from current env.
50    async fn provide_credential(&self, ctx: &Context) -> Result<Option<Self::Credential>>;
51}
52
53/// SignRequest is the trait used by signer to build the signing request.
54#[async_trait::async_trait]
55pub trait SignRequest: Debug + Send + Sync + Unpin + 'static {
56    /// Credential used by this builder.
57    ///
58    /// Typically, it will be a credential.
59    type Credential: Send + Sync + Unpin + 'static;
60
61    /// Construct the signing request.
62    ///
63    /// ## Credential
64    ///
65    /// The `credential` parameter is the credential required by the signer to sign the request.
66    ///
67    /// ## Expires In
68    ///
69    /// The `expires_in` parameter specifies the expiration time for the result.
70    /// If the signer does not support expiration, it should return an error.
71    ///
72    /// Implementation details determine how to handle the expiration logic. For instance,
73    /// AWS uses a query string that includes an `Expires` parameter.
74    async fn sign_request(
75        &self,
76        ctx: &Context,
77        req: &mut http::request::Parts,
78        credential: Option<&Self::Credential>,
79        expires_in: Option<Duration>,
80    ) -> Result<()>;
81}
82
83/// A chain of credential providers that will be tried in order.
84///
85/// This is a generic implementation that can be used by any service to chain multiple
86/// credential providers together. The chain will try each provider in order until one
87/// returns credentials or all providers have been exhausted.
88///
89/// # Example
90///
91/// ```no_run
92/// use reqsign_core::{ProvideCredentialChain, Context, ProvideCredential, Result};
93/// use async_trait::async_trait;
94///
95/// #[derive(Debug)]
96/// struct MyCredential {
97///     token: String,
98/// }
99///
100/// #[derive(Debug)]
101/// struct EnvironmentProvider;
102///
103/// #[async_trait]
104/// impl ProvideCredential for EnvironmentProvider {
105///     type Credential = MyCredential;
106///
107///     async fn provide_credential(&self, ctx: &Context) -> Result<Option<Self::Credential>> {
108///         // Implementation
109///         Ok(None)
110///     }
111/// }
112///
113/// # async fn example(ctx: Context) {
114/// let chain = ProvideCredentialChain::new()
115///     .push(EnvironmentProvider);
116///
117/// let credentials = chain.provide_credential(&ctx).await;
118/// # }
119/// ```
120pub struct ProvideCredentialChain<C> {
121    providers: Vec<Box<dyn ProvideCredential<Credential = C>>>,
122}
123
124impl<C> ProvideCredentialChain<C>
125where
126    C: Send + Sync + Unpin + 'static,
127{
128    /// Create a new empty credential provider chain.
129    pub fn new() -> Self {
130        Self {
131            providers: Vec::new(),
132        }
133    }
134
135    /// Add a credential provider to the chain.
136    pub fn push(mut self, provider: impl ProvideCredential<Credential = C> + 'static) -> Self {
137        self.providers.push(Box::new(provider));
138        self
139    }
140
141    /// Add a credential provider to the front of the chain.
142    ///
143    /// This provider will be tried first before all existing providers.
144    pub fn push_front(
145        mut self,
146        provider: impl ProvideCredential<Credential = C> + 'static,
147    ) -> Self {
148        self.providers.insert(0, Box::new(provider));
149        self
150    }
151
152    /// Create a credential provider chain from a vector of providers.
153    pub fn from_vec(providers: Vec<Box<dyn ProvideCredential<Credential = C>>>) -> Self {
154        Self { providers }
155    }
156
157    /// Get the number of providers in the chain.
158    pub fn len(&self) -> usize {
159        self.providers.len()
160    }
161
162    /// Check if the chain is empty.
163    pub fn is_empty(&self) -> bool {
164        self.providers.is_empty()
165    }
166}
167
168impl<C> Default for ProvideCredentialChain<C>
169where
170    C: Send + Sync + Unpin + 'static,
171{
172    fn default() -> Self {
173        Self::new()
174    }
175}
176
177impl<C> Debug for ProvideCredentialChain<C>
178where
179    C: Send + Sync + Unpin + 'static,
180{
181    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
182        f.debug_struct("ProvideCredentialChain")
183            .field("providers_count", &self.providers.len())
184            .finish()
185    }
186}
187
188#[async_trait::async_trait]
189impl<C> ProvideCredential for ProvideCredentialChain<C>
190where
191    C: Send + Sync + Unpin + 'static,
192{
193    type Credential = C;
194
195    async fn provide_credential(&self, ctx: &Context) -> Result<Option<Self::Credential>> {
196        for provider in &self.providers {
197            log::debug!("Trying credential provider: {provider:?}");
198
199            match provider.provide_credential(ctx).await {
200                Ok(Some(cred)) => {
201                    log::debug!("Successfully loaded credential from provider: {provider:?}");
202                    return Ok(Some(cred));
203                }
204                Ok(None) => {
205                    log::debug!("No credential found in provider: {provider:?}");
206                    continue;
207                }
208                Err(e) => {
209                    log::warn!("Error loading credential from provider {provider:?}: {e:?}");
210                    // Continue to next provider on error
211                    continue;
212                }
213            }
214        }
215
216        Ok(None)
217    }
218}