1use std::{
2 convert::{TryFrom, TryInto},
3 sync::LazyLock,
4};
5
6use log::{trace, warn};
7use regex_lite::Regex;
8use reqwest::{RequestBuilder, StatusCode, Url, header::HeaderValue};
9use serde::{Deserialize, Serialize};
10
11use crate::{
12 errors::{Error, Result},
13 v2::*,
14};
15
16#[derive(Debug, Clone)]
18pub enum Auth {
19 Bearer(BearerAuth),
20 Basic(BasicAuth),
21}
22
23impl Auth {
24 pub(crate) fn add_auth_headers(&self, request_builder: RequestBuilder) -> RequestBuilder {
26 match self {
27 Auth::Bearer(bearer_auth) => request_builder.bearer_auth(bearer_auth.token.clone()),
28 Auth::Basic(basic_auth) => request_builder.basic_auth(basic_auth.user.clone(), basic_auth.password.clone()),
29 }
30 }
31}
32
/// Bearer token obtained from a registry token endpoint.
///
/// Built from a `MultiTokenBearerAuth` response via `TryFrom`.
#[derive(Debug, Clone, Default, Deserialize, Serialize)]
pub struct BearerAuth {
    /// The token sent as `Authorization: Bearer <token>` by `Auth::add_auth_headers`.
    token: String,
    /// Lifetime reported by the token endpoint; presumably seconds — TODO confirm.
    expires_in: Option<u32>,
    /// Issue timestamp exactly as returned by the server (not parsed).
    issued_at: Option<String>,
    /// Refresh token, if the server returned one.
    refresh_token: Option<String>,
}
41
/// Raw token-endpoint response, before it is normalised into `BearerAuth`.
///
/// Accepts the token under either the `token` or the `access_token` key; the
/// `TryFrom<MultiTokenBearerAuth> for BearerAuth` conversion prefers `token`.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct MultiTokenBearerAuth {
    /// Token sent under the `token` key.
    token: Option<String>,
    /// Token sent under the `access_token` key.
    access_token: Option<String>,
    /// Reported lifetime; presumably seconds — TODO confirm.
    expires_in: Option<u32>,
    /// Issue timestamp exactly as returned by the server (not parsed).
    issued_at: Option<String>,
    /// Refresh token, if the server returned one.
    refresh_token: Option<String>,
}
51
52impl TryFrom<MultiTokenBearerAuth> for BearerAuth {
53 type Error = Error;
54
55 fn try_from(value: MultiTokenBearerAuth) -> std::result::Result<Self, Error> {
56 let t = value.token.or(value.access_token).ok_or(Error::NoTokenReceived)?;
57
58 Ok(Self {
59 token: t,
60 expires_in: value.expires_in,
61 issued_at: value.issued_at,
62 refresh_token: value.refresh_token,
63 })
64 }
65}
66
67impl BearerAuth {
68 async fn try_from_header_content(
69 client: Client,
70 scopes: &[&str],
71 credentials: Option<(String, String)>,
72 bearer_header_content: WwwAuthenticateHeaderContentBearer,
73 ) -> Result<Self> {
74 let auth_ep = bearer_header_content.auth_ep(scopes);
75 trace!("authenticate: token endpoint: {auth_ep}");
76
77 let url = reqwest::Url::parse(&auth_ep)?;
78
79 let auth_req = {
80 Client {
81 auth: credentials.map(|(user, password)| {
82 Auth::Basic(BasicAuth {
83 user,
84 password: Some(password),
85 })
86 }),
87 ..client
88 }
89 }
90 .build_reqwest(Method::GET, url);
91
92 let r = auth_req.send().await?;
93 let status = r.status();
94 trace!("authenticate: got status {status}");
95 if status != StatusCode::OK {
96 return Err(Error::UnexpectedHttpStatus(status));
97 }
98
99 let bearer_auth: BearerAuth = r.json::<MultiTokenBearerAuth>().await?.try_into()?;
100
101 match bearer_auth.token.as_str() {
102 "unauthenticated" | "" => return Err(Error::InvalidAuthToken(bearer_auth.token)),
103 _ => {}
104 };
105
106 let chars_count = bearer_auth.token.chars().count();
108 let mask_start = std::cmp::min(1, chars_count - 1);
109 let mask_end = std::cmp::max(chars_count - 1, 1);
110 let mut masked_token = bearer_auth.token.clone();
111 masked_token.replace_range(mask_start..mask_end, &"*".repeat(mask_end - mask_start));
112
113 trace!("authenticate: got token: {masked_token:?}");
114
115 Ok(bearer_auth)
116 }
117}
118
/// HTTP Basic credentials, attached by `Auth::add_auth_headers` via `basic_auth`.
#[derive(Debug, Clone)]
pub struct BasicAuth {
    /// User name.
    user: String,
    /// Optional password.
    password: Option<String>,
}
125
/// A parsed `WWW-Authenticate` challenge, one variant per supported auth scheme.
///
/// Variants deserialize from the lower-cased scheme name (`"bearer"` / `"basic"`).
#[derive(Debug, PartialEq, Eq, Deserialize)]
#[serde(rename_all(deserialize = "lowercase"))]
pub(crate) enum WwwAuthenticateHeaderContent {
    /// `Bearer` challenge pointing at a token endpoint.
    Bearer(WwwAuthenticateHeaderContentBearer),
    /// `Basic` challenge.
    Basic(WwwAuthenticateHeaderContentBasic),
}
133
/// Pattern for one `key="value"` parameter of a `WWW-Authenticate` header.
///
/// The parameter may be preceded by the auth scheme (`method`, e.g. `Bearer`). The pattern is
/// applied repeatedly with `captures_iter`, and only the first match is read for the scheme.
///
/// Verbose mode (`(?x)`) makes the whitespace and line breaks inside the pattern
/// insignificant. Values must be non-empty and cannot contain `"` (`[^"]+`), so escaped
/// quotes are not supported.
///
/// The pattern text is also shown verbatim in `WwwHeaderParseError::InvalidValue`'s message.
const REGEX: &str = r#"(?x)\s*
((?P<method>[A-Za-z]+)\s)?
\s*
(
    (?P<key>[A-Za-z]+)
    \s*
    =
    \s*
    "(?P<value>[^"]+)"
    \s*
)
"#;
146
/// Errors raised while parsing a `WWW-Authenticate` header value.
#[derive(Debug, thiserror::Error)]
pub enum WwwHeaderParseError {
    /// No part of the header matched `REGEX`.
    #[error("header value must conform to {}", REGEX)]
    InvalidValue,
    /// The first matched parameter was not preceded by an auth scheme (e.g. `Bearer`).
    #[error("'method' field missing")]
    FieldMethodMissing,
}
154
/// `REGEX` compiled once, on first use.
static AUTH_REGEX: LazyLock<Regex> = LazyLock::new(|| Regex::new(REGEX).expect("this static regex is valid"));
156
157impl WwwAuthenticateHeaderContent {
158 pub(crate) fn from_www_authentication_header(header_value: HeaderValue) -> Result<Self> {
160 let header = String::from_utf8(header_value.as_bytes().to_vec())?;
161
162 let captures = AUTH_REGEX.captures_iter(&header).collect::<Vec<_>>();
165
166 let method = captures
167 .first()
168 .ok_or(WwwHeaderParseError::InvalidValue)?
169 .name("method")
170 .ok_or(WwwHeaderParseError::FieldMethodMissing)?
171 .as_str()
172 .to_lowercase();
173
174 let serialized_content = {
175 let serialized_captures = captures
176 .iter()
177 .filter_map(|capture| {
178 match (
179 capture.name("key").map(|n| n.as_str().to_lowercase()),
180 capture.name("value").map(|n| n.as_str().to_string()),
181 ) {
182 (Some(key), Some(value)) => Some(format!(
183 r#"{}: {}"#,
184 serde_json::Value::String(key),
185 serde_json::Value::String(value),
186 )),
187 _ => None,
188 }
189 })
190 .collect::<Vec<_>>()
191 .join(", ");
192
193 format!(
194 r#"{{ {}: {{ {} }} }}"#,
195 serde_json::Value::String(method),
196 serialized_captures
197 )
198 };
199
200 let mut unsupported_keys = std::collections::HashSet::new();
202 let content: WwwAuthenticateHeaderContent =
203 serde_ignored::deserialize(&mut serde_json::Deserializer::from_str(&serialized_content), |path| {
204 unsupported_keys.insert(path.to_string());
205 })?;
206
207 if !unsupported_keys.is_empty() {
208 warn!("skipping unrecognized keys in authentication header: {unsupported_keys:#?}");
209 }
210
211 Ok(content)
212 }
213}
214
/// Parameters of a `Bearer` challenge.
#[derive(Debug, Default, PartialEq, Eq, Deserialize)]
pub(crate) struct WwwAuthenticateHeaderContentBearer {
    /// Token endpoint; `auth_ep` uses it as the base of the token request URL.
    realm: String,
    /// Optional `service` parameter, forwarded as `service=` by `auth_ep`.
    service: Option<String>,
    /// Scope suggested by the server. NOTE(review): `auth_ep` does not read it and only uses
    /// the caller-supplied scopes.
    scope: Option<String>,
}
222
223impl WwwAuthenticateHeaderContentBearer {
224 fn auth_ep(&self, scopes: &[&str]) -> String {
225 let service = self
226 .service
227 .as_ref()
228 .map(|sv| format!("?service={sv}"))
229 .unwrap_or_default();
230
231 let scope = scopes.iter().enumerate().fold(String::new(), |acc, (i, &s)| {
232 let separator = if i > 0 { "&" } else { "" };
233 acc + separator + "scope=" + s
234 });
235
236 let scope_prefix = if scopes.is_empty() {
237 ""
238 } else if service.is_empty() {
239 "?"
240 } else {
241 "&"
242 };
243
244 format!("{}{}{}{}", self.realm, service, scope_prefix, scope)
245 }
246}
247
/// Parameters of a `Basic` challenge.
#[derive(Debug, Default, PartialEq, Eq, Deserialize)]
pub(crate) struct WwwAuthenticateHeaderContentBasic {
    /// Protection space announced by the server; `Client::authenticate` does not use it.
    realm: String,
}
253
254impl Client {
255 async fn get_www_authentication_header(&self) -> Result<HeaderValue> {
257 let url = {
258 let ep = format!("{}/v2/", self.base_url.clone(),);
259 reqwest::Url::parse(&ep)?
260 };
261
262 let r = self.build_reqwest(Method::GET, url.clone()).send().await?;
263
264 trace!("GET '{}' status: {:?}", r.url(), r.status());
265 r.headers()
266 .get(reqwest::header::WWW_AUTHENTICATE)
267 .ok_or(Error::MissingAuthHeader("WWW-Authenticate"))
268 .map(ToOwned::to_owned)
269 }
270
271 pub async fn authenticate(mut self, scopes: &[&str]) -> Result<Self> {
275 let credentials = self.credentials.clone();
276
277 let client = Client {
278 auth: None,
279 ..self.clone()
280 };
281
282 let authentication_header = client.get_www_authentication_header().await?;
283 let auth = match WwwAuthenticateHeaderContent::from_www_authentication_header(authentication_header)? {
284 WwwAuthenticateHeaderContent::Basic(_) => {
285 let basic_auth = credentials
286 .map(|(user, password)| BasicAuth {
287 user,
288 password: Some(password),
289 })
290 .ok_or(Error::NoCredentials)?;
291
292 Auth::Basic(basic_auth)
293 }
294 WwwAuthenticateHeaderContent::Bearer(bearer_header_content) => {
295 let bearer_auth =
296 BearerAuth::try_from_header_content(client, scopes, credentials, bearer_header_content).await?;
297
298 Auth::Bearer(bearer_auth)
299 }
300 };
301
302 trace!("authenticate: login succeeded");
303 self.auth = Some(auth);
304
305 Ok(self)
306 }
307
308 pub async fn is_auth(&self) -> Result<bool> {
312 let url = {
313 let ep = format!("{}/v2/", self.base_url.clone(),);
314 Url::parse(&ep)?
315 };
316
317 let req = self.build_reqwest(Method::GET, url.clone());
318
319 trace!("Sending request to '{url}'");
320 let resp = req.send().await?;
321 trace!("GET '{url}' status={}", resp.status());
322
323 let status = resp.status();
324 match status {
325 reqwest::StatusCode::OK => Ok(true),
326 reqwest::StatusCode::UNAUTHORIZED => Ok(false),
327 _ => Err(Error::UnexpectedHttpStatus(status)),
328 }
329 }
330}
331
#[cfg(test)]
mod tests {
    use test_case::test_case;

    use super::*;

    /// Builds a `HeaderValue` from a raw header string; panics on invalid header bytes.
    fn header_value(raw: &str) -> HeaderValue {
        HeaderValue::from_str(raw).unwrap()
    }

    #[test]
    fn bearer_realm_parses_correctly() -> Result<()> {
        let realm = "https://sat-r220-02.lab.eng.rdu2.redhat.com/v2/token";
        let service = "sat-r220-02.lab.eng.rdu2.redhat.com";
        let scope = "repository:registry:pull,push";

        let expected = WwwAuthenticateHeaderContent::Bearer(WwwAuthenticateHeaderContentBearer {
            realm: realm.to_string(),
            service: Some(service.to_string()),
            scope: Some(scope.to_string()),
        });

        // The scheme and the parameter names must be matched case-insensitively.
        let raw_headers = [
            format!(r#"Bearer realm="{realm}",service="{service}",scope="{scope}""#),
            format!(r#"bearer realm="{realm}",service="{service}",scope="{scope}""#),
            format!(r#"BEARER realm="{realm}",service="{service}",scope="{scope}""#),
            format!(r#"Bearer Realm="{realm}",Service="{service}",Scope="{scope}""#),
            format!(r#"Bearer REALM="{realm}",SERVICE="{service}",SCOPE="{scope}""#),
        ];

        for raw in &raw_headers {
            let parsed = WwwAuthenticateHeaderContent::from_www_authentication_header(header_value(raw))?;
            assert_eq!(expected, parsed);
        }

        Ok(())
    }

    #[test]
    fn basic_realm_parses_correctly() -> Result<()> {
        let realm = "Registry realm";

        let expected = WwwAuthenticateHeaderContent::Basic(WwwAuthenticateHeaderContentBasic {
            realm: realm.to_string(),
        });

        // The scheme and the parameter names must be matched case-insensitively.
        let raw_headers = [
            format!(r#"Basic realm="{realm}""#),
            format!(r#"basic realm="{realm}""#),
            format!(r#"BASIC realm="{realm}""#),
            format!(r#"Basic Realm="{realm}""#),
            format!(r#"Basic REALM="{realm}""#),
        ];

        for raw in &raw_headers {
            let parsed = WwwAuthenticateHeaderContent::from_www_authentication_header(header_value(raw))?;
            assert_eq!(expected, parsed);
        }

        Ok(())
    }

    #[test_case(&[], true; "Test with no scopes and with service")]
    #[test_case(&["repository:test:pull"], true; "Test with single scope and service")]
    #[test_case(&["repository:test:pull", "repository:example:pull,push", "repository:another:*"], false;
    "Test with multiple scopes")]
    fn bearer_auth_ep_scope_construction(scopes: &[&str], include_service: bool) {
        let realm = "https://sat-r220-02.lab.eng.rdu2.redhat.com/v2/token";
        let service = "sat-r220-02.lab.eng.rdu2.redhat.com";

        let bearer_header_content = WwwAuthenticateHeaderContentBearer {
            realm: realm.to_string(),
            service: include_service.then(|| service.to_string()),
            scope: None,
        };

        // `service` (when present) comes first, then one `scope` pair per requested scope.
        let expected_pairs: Vec<(String, String)> = include_service
            .then(|| ("service".to_owned(), service.to_owned()))
            .into_iter()
            .chain(scopes.iter().map(|s| ("scope".to_owned(), (*s).to_owned())))
            .collect();

        let url = Url::parse(&bearer_header_content.auth_ep(scopes)).unwrap();
        let actual_pairs: Vec<(String, String)> = url.query_pairs().into_owned().collect();

        assert_eq!(actual_pairs, expected_pairs);
    }
}