Skip to main content

pdk_contracts_lib/api/
basic_auth.rs

1// Copyright (c) 2026, Salesforce, Inc.,
2// All rights reserved.
3// For full license text, see the LICENSE.txt file
4
5use base64::{prelude::BASE64_STANDARD, Engine as _};
6use pdk_core::classy::hl::{HeadersHandler, HeadersState, RequestHeadersState};
7use pdk_core::logger;
8
9use zeroize::ZeroizeOnDrop;
10
11use super::credentials::{ClientId, ClientSecret};
12
/// Name of the request header that is inspected for Basic credentials.
const AUTHORIZATION_HEADER: &str = "authorization";

/// Authentication scheme that must appear as the first token of the header value.
const BASIC_AUTHORIZATION_SCHEMA: &str = "Basic";

/// Separator between the client id and the client secret in the decoded payload.
const PAIR_SEPARATOR: &str = ":";
16
17/// Error returned when [basic_auth_credentials()] fails.
18#[non_exhaustive]
19#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
20pub enum BasicAuthError {
21    /// Basic authentication header not found.
22    #[error("Basic authentication header not found.")]
23    HeaderNotFound,
24
25    /// Invalid basic auth header value format.
26    #[error("Invalid basic auth header value format.")]
27    InvalidHeadeValueFormat,
28
29    /// Unknown auth schema.
30    #[error("Unknown auth schema {0}.")]
31    UnknownSchema(String),
32
33    /// Invalid Base-64 encoding.
34    #[error(transparent)]
35    InvalidBase64(InvalidBase64),
36
37    /// Invalid UTF-8 encoding.
38    #[error(transparent)]
39    InvalidUtf8(InvalidUtf8),
40
41    /// Invalid credentials format.
42    #[error("Invalid credentials format.")]
43    InvalidCredentialsFormat,
44}
45
46/// Represents an invalid Base-64 encoding.
47#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
48#[error("Invalid Base64 Encoding")]
49pub struct InvalidBase64(base64::DecodeError);
50
51/// Represents an invalid UTF-8 encoding.
52#[derive(thiserror::Error, Debug, Clone, PartialEq, Eq)]
53#[error("Invalid UTF-8 Encoding")]
54pub struct InvalidUtf8(std::str::Utf8Error);
55
56#[derive(ZeroizeOnDrop)]
57struct AuthorizationHeader(String);
58
59impl AuthorizationHeader {
60    fn from_handler(handler: &dyn HeadersHandler) -> Result<Self, BasicAuthError> {
61        handler
62            .header(AUTHORIZATION_HEADER)
63            .map(Self)
64            .ok_or_else(|| {
65                logger::debug!("Authorization header not present");
66                BasicAuthError::HeaderNotFound
67            })
68    }
69
70    fn auth_value(&self) -> Result<&str, BasicAuthError> {
71        let mut list = self.0.split_whitespace();
72
73        if let (Some(auth_type), Some(auth_value)) = (list.next(), list.next()) {
74            if auth_type != BASIC_AUTHORIZATION_SCHEMA {
75                return Err(BasicAuthError::UnknownSchema(auth_type.to_string()));
76            }
77
78            Ok(auth_value)
79        } else {
80            Err(BasicAuthError::InvalidHeadeValueFormat)
81        }
82    }
83
84    fn decode(&self) -> Result<DecodedHeader, BasicAuthError> {
85        let auth_value = self.auth_value()?;
86
87        let decoded_header = BASE64_STANDARD.decode(auth_value).map_err(|e| {
88            logger::debug!("There was a problem when trying to decoding auth header: {e}");
89
90            BasicAuthError::InvalidBase64(InvalidBase64(e))
91        })?;
92
93        Ok(DecodedHeader(decoded_header))
94    }
95}
96
97#[derive(ZeroizeOnDrop)]
98struct DecodedHeader(Vec<u8>);
99
100impl DecodedHeader {
101    fn as_utf8(&self) -> Result<&str, BasicAuthError> {
102        std::str::from_utf8(self.0.as_slice()).map_err(|e| {
103            logger::debug!("There was a problem when trying to translate auth header: {e}");
104
105            BasicAuthError::InvalidUtf8(InvalidUtf8(e))
106        })
107    }
108
109    fn as_credentials(&self) -> Result<(ClientId, ClientSecret), BasicAuthError> {
110        let Some((client_id, client_secret)) = self.as_utf8()?.split_once(PAIR_SEPARATOR) else {
111            return Err(BasicAuthError::InvalidCredentialsFormat);
112        };
113
114        let result = (
115            ClientId::new(client_id.to_string()),
116            ClientSecret::new(client_secret.to_string()),
117        );
118
119        Ok(result)
120    }
121}
122
123fn credentials_from_handler(
124    handler: &dyn HeadersHandler,
125) -> Result<(ClientId, ClientSecret), BasicAuthError> {
126    AuthorizationHeader::from_handler(handler)?
127        .decode()?
128        .as_credentials()
129}
130
131/// Extracts a pair of credentials from a Basic-Auth header.
132pub fn basic_auth_credentials(
133    request_headers_state: &RequestHeadersState,
134) -> Result<(ClientId, ClientSecret), BasicAuthError> {
135    credentials_from_handler(request_headers_state.handler())
136}
137
#[cfg(test)]
mod tests {
    use pdk_core::classy::hl::HeadersHandler;

    use super::{credentials_from_handler, BasicAuthError, AUTHORIZATION_HEADER};

    /// Minimal `HeadersHandler` that only answers lookups of the `authorization` header.
    /// Every other operation is unreachable in these tests.
    struct HandlerMock(Option<String>);

    impl HandlerMock {
        /// A request without an `authorization` header.
        fn absent() -> Self {
            Self(None)
        }

        /// A request whose `authorization` header has the given value.
        fn new(value: impl Into<String>) -> Self {
            Self(Some(value.into()))
        }
    }

    impl HeadersHandler for HandlerMock {
        fn headers(&self) -> Vec<(String, String)> {
            unreachable!()
        }

        fn header(&self, name: &str) -> Option<String> {
            (name == AUTHORIZATION_HEADER)
                .then(|| self.0.clone())
                .flatten()
        }

        fn add_header(&self, _: &str, _: &str) {
            unreachable!()
        }

        fn set_header(&self, _: &str, _: &str) {
            unreachable!()
        }

        fn set_headers(&self, _: Vec<(&str, &str)>) {
            unreachable!()
        }

        fn remove_header(&self, _: &str) {
            unreachable!()
        }
    }

    #[test]
    fn header_not_found() {
        let result = credentials_from_handler(&HandlerMock::absent());

        assert_eq!(result.unwrap_err(), BasicAuthError::HeaderNotFound);
    }

    #[test]
    fn unknown_schema() {
        let result = credentials_from_handler(&HandlerMock::new("Unknown aaaaa"));

        assert_eq!(
            result.unwrap_err(),
            BasicAuthError::UnknownSchema("Unknown".to_string())
        );
    }

    #[test]
    fn invalid_header_format() {
        let result = credentials_from_handler(&HandlerMock::new("Invalid"));

        assert_eq!(result.unwrap_err(), BasicAuthError::InvalidHeadeValueFormat);
    }

    #[test]
    fn empty_header_value() {
        let result = credentials_from_handler(&HandlerMock::new(""));

        assert_eq!(result.unwrap_err(), BasicAuthError::InvalidHeadeValueFormat);
    }

    #[test]
    fn invalid_base64() {
        let result = credentials_from_handler(&HandlerMock::new("Basic ####"));

        assert!(matches!(
            result.unwrap_err(),
            BasicAuthError::InvalidBase64(_)
        ));
    }

    #[test]
    fn invalid_utf8() {
        let result = credentials_from_handler(&HandlerMock::new("Basic aaaa"));

        assert!(matches!(
            result.unwrap_err(),
            BasicAuthError::InvalidUtf8(_)
        ));
    }

    #[test]
    fn invalid_credentials_format() {
        let result = credentials_from_handler(&HandlerMock::new("Basic c29tZSB1c2Vy"));

        assert_eq!(
            result.unwrap_err(),
            BasicAuthError::InvalidCredentialsFormat
        );
    }

    #[test]
    fn valid_credentials() {
        let (id, secret) =
            credentials_from_handler(&HandlerMock::new("Basic dXNlcjE6cGFzc3dvcmQx")).unwrap();

        assert_eq!(id.as_str(), "user1");
        assert_eq!(secret.as_str(), "password1");
    }

    #[test]
    fn repeated_whitespace_between_scheme_and_value() {
        let (id, secret) =
            credentials_from_handler(&HandlerMock::new("Basic   dXNlcjE6cGFzc3dvcmQx")).unwrap();

        assert_eq!(id.as_str(), "user1");
        assert_eq!(secret.as_str(), "password1");
    }

    #[test]
    fn secret_may_contain_separator() {
        // Encodes "user:pa:ss": only the first ':' separates id from secret.
        let (id, secret) =
            credentials_from_handler(&HandlerMock::new("Basic dXNlcjpwYTpzcw==")).unwrap();

        assert_eq!(id.as_str(), "user");
        assert_eq!(secret.as_str(), "pa:ss");
    }

    #[test]
    fn empty_secret() {
        // Encodes "user1:".
        let (id, secret) = credentials_from_handler(&HandlerMock::new("Basic dXNlcjE6")).unwrap();

        assert_eq!(id.as_str(), "user1");
        assert_eq!(secret.as_str(), "");
    }
}