1use std::collections::HashMap;
7use std::error::Error;
8use std::io::{Read, Result};
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::sync::{Arc, Once, RwLock};
11use std::time::{Duration, SystemTime, UNIX_EPOCH};
12use std::{fmt, thread};
13
14use arc_swap::{ArcSwap, ArcSwapOption};
15use base64::Engine;
16use reqwest::blocking::Response;
17pub use reqwest::header::HeaderMap;
18use reqwest::header::{HeaderValue, CONTENT_LENGTH};
19use reqwest::{Method, StatusCode};
20use url::{ParseError, Url};
21
22use nydus_api::RegistryConfig;
23use nydus_utils::metrics::BackendMetrics;
24
25use crate::backend::connection::{
26 is_success_status, respond, Connection, ConnectionConfig, ConnectionError, ReqBody,
27};
28use crate::backend::{BackendError, BackendResult, BlobBackend, BlobReader};
29
30const REGISTRY_CLIENT_ID: &str = "nydus-registry-client";
31const HEADER_AUTHORIZATION: &str = "Authorization";
32const HEADER_WWW_AUTHENTICATE: &str = "www-authenticate";
33
/// Token lifetime, in seconds, assumed when the auth server response carries no `expires_in`;
/// also the initial interval of the token refresh thread.
const REGISTRY_DEFAULT_TOKEN_EXPIRATION: u64 = 10 * 60;

/// Errors reported by the registry backend.
#[derive(Debug)]
pub enum RegistryError {
    /// Generic failure while accessing a blob; the string carries the reason.
    Common(String),
    /// The given URL string could not be parsed.
    Url(String, ParseError),
    /// The underlying connection failed to issue the request.
    Request(ConnectionError),
    /// The given URL scheme was rejected.
    Scheme(String),
    /// Network transport error while reading the response body.
    Transport(reqwest::Error),
}
45
46impl fmt::Display for RegistryError {
47 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48 match self {
49 RegistryError::Common(s) => write!(f, "failed to access blob from registry, {}", s),
50 RegistryError::Url(u, e) => write!(f, "failed to parse URL {}, {}", u, e),
51 RegistryError::Request(e) => write!(f, "failed to issue request, {}", e),
52 RegistryError::Scheme(s) => write!(f, "invalid scheme, {}", s),
53 RegistryError::Transport(e) => write!(f, "network transport error, {}", e),
54 }
55 }
56}
57
58impl From<RegistryError> for BackendError {
59 fn from(error: RegistryError) -> Self {
60 BackendError::Registry(error)
61 }
62}
63
/// Result type whose error is always a [`RegistryError`].
type RegistryResult<T> = std::result::Result<T, RegistryError>;
65
/// A single string value guarded by a read/write lock.
#[derive(Default)]
struct Cache(RwLock<String>);

impl Cache {
    /// Builds a cache that initially holds `val`.
    fn new(val: String) -> Self {
        Self(RwLock::new(val))
    }

    /// Returns a copy of the stored string; an empty string means nothing is cached.
    ///
    /// Panics if the lock has been poisoned.
    fn get(&self) -> String {
        self.0.read().unwrap().clone()
    }

    /// Stores `current`, unless it equals `last` (the value the caller previously observed),
    /// in which case the write lock is not taken at all.
    ///
    /// Panics if the lock has been poisoned.
    fn set(&self, last: &str, current: String) {
        if last == current {
            return;
        }
        *self.0.write().unwrap() = current;
    }
}
89
/// A string-keyed map guarded by a read/write lock.
#[derive(Default)]
struct HashCache<T>(RwLock<HashMap<String, T>>);

impl<T> HashCache<T> {
    /// Builds an empty cache.
    fn new() -> Self {
        Self(RwLock::default())
    }

    /// Returns a clone of the value stored under `key`, if any.
    ///
    /// Panics if the lock has been poisoned.
    fn get(&self, key: &str) -> Option<T>
    where
        T: Clone,
    {
        let entries = self.0.read().unwrap();
        entries.get(key).cloned()
    }

    /// Inserts `value` under `key`, overwriting any previous entry.
    ///
    /// Panics if the lock has been poisoned.
    fn set(&self, key: String, value: T) {
        self.0.write().unwrap().insert(key, value);
    }

    /// Drops the entry stored under `key`, if present.
    ///
    /// Panics if the lock has been poisoned.
    fn remove(&self, key: &str) {
        self.0.write().unwrap().remove(key);
    }
}
116
/// Deserialized reply of the registry auth server.
#[derive(Clone, serde::Deserialize)]
struct TokenResponse {
    /// Bearer token; `from_resp` fills it from `access_token` when the server left it empty.
    #[serde(default)]
    token: String,
    /// Alternative field carrying the token; used as a fallback for `token`.
    #[serde(default)]
    access_token: String,
    /// Token lifetime in seconds; defaults to `REGISTRY_DEFAULT_TOKEN_EXPIRATION` when absent.
    #[serde(default = "default_expires_in")]
    expires_in: u64,
}
129
/// Serde default for `TokenResponse::expires_in`.
fn default_expires_in() -> u64 {
    REGISTRY_DEFAULT_TOKEN_EXPIRATION
}
133
134impl TokenResponse {
135 fn from_resp(resp: Response) -> Result<Self> {
137 let mut token: TokenResponse = resp.json().map_err(|e| {
138 einval!(format!(
139 "failed to decode registry auth server response: {:?}",
140 e
141 ))
142 })?;
143
144 if token.token.is_empty() {
145 if token.access_token.is_empty() {
146 return Err(einval!("failed to get auth token from registry"));
147 }
148 token.token = token.access_token.clone();
149 }
150 Ok(token)
151 }
152}
153
/// Parameters of a `Basic` authentication challenge.
#[derive(Debug)]
struct BasicAuth {
    /// `realm` parameter of the challenge (empty when the server sent none).
    #[allow(unused)]
    realm: String,
}
159
/// Parameters of a `Bearer` authentication challenge.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct BearerAuth {
    /// URL of the token endpoint; token requests are sent to it.
    realm: String,
    /// `service` value forwarded to the token endpoint.
    service: String,
    /// `scope` value forwarded to the token endpoint (empty when the challenge had none).
    scope: String,
}
167
/// Authentication challenge parsed from a `www-authenticate` header.
#[derive(Debug)]
#[allow(dead_code)]
enum Auth {
    /// The server asked for HTTP Basic authentication.
    Basic(BasicAuth),
    /// The server asked for a bearer token.
    Bearer(BearerAuth),
}
174
/// URL scheme flag: `true` stands for `https`, `false` for `http`.
pub struct Scheme(AtomicBool);

impl Scheme {
    /// Wraps the initial flag value (`true` = https).
    fn new(value: bool) -> Self {
        Self(AtomicBool::new(value))
    }
}

/// Renders the scheme as the literal `https` or `http`.
impl fmt::Display for Scheme {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = if self.0.load(Ordering::Relaxed) {
            "https"
        } else {
            "http"
        };
        f.write_str(name)
    }
}
192
/// State shared by the registry backend, its readers and the token refresh thread.
struct RegistryState {
    /// Scheme used to build registry URLs; switched to http by `fallback_http`.
    scheme: Scheme,
    /// Registry host placed after the scheme in built URLs.
    host: String,
    /// Repository name, inserted after `/v2/` in built URLs.
    repo: String,
    /// Base64 encoded credentials sent as `Basic <auth>`, when configured.
    auth: Option<String>,
    /// Retry limit reported through `BlobReader::retry_limit`.
    retry_limit: u8,
    /// When non-empty, replaces the scheme of a blob redirect location.
    blob_url_scheme: String,
    /// When non-empty, replaces the host of a blob redirect location.
    blob_redirected_host: String,
    /// Last `Authorization` header value that produced a successful response.
    cached_auth: Cache,
    /// Per-host flag recording that the token endpoint only answered an HTTP GET.
    cached_auth_using_http_get: HashCache<bool>,
    /// Redirect location remembered per blob id.
    cached_redirect: HashCache<String>,
    /// Unix timestamp (seconds) at which the current token expires.
    token_expired_at: ArcSwapOption<u64>,
    /// Bearer challenge used for the last token request; reused by the refresh thread.
    cached_bearer_auth: ArcSwapOption<BearerAuth>,
}
226
227impl RegistryState {
228 fn url(&self, path: &str, query: &[&str]) -> std::result::Result<String, ParseError> {
229 let path = if query.is_empty() {
230 format!("/v2/{}{}", self.repo, path)
231 } else {
232 format!("/v2/{}{}?{}", self.repo, path, query.join("&"))
233 };
234 let url = format!("{}://{}", self.scheme, self.host.as_str());
235 let url = Url::parse(url.as_str())?;
236 let url = url.join(path.as_str())?;
237
238 Ok(url.to_string())
239 }
240
241 fn needs_fallback_http(&self, e: &dyn Error) -> bool {
242 match e.source() {
243 Some(err) => match err.source() {
244 Some(err) => {
245 if !self.scheme.0.load(Ordering::Relaxed) {
246 return false;
247 }
248 let msg = err.to_string().to_lowercase();
249 let fallback = msg.contains("wrong version number")
254 || msg.contains("connection refused")
255 || msg.to_lowercase().contains("ssl");
256 if fallback {
257 warn!("fallback to http due to tls connection error: {}", err);
258 }
259 fallback
260 }
261 None => false,
262 },
263 None => false,
264 }
265 }
266
267 fn get_token(&self, auth: BearerAuth, connection: &Arc<Connection>) -> Result<TokenResponse> {
269 let http_get = self
270 .cached_auth_using_http_get
271 .get(&self.host)
272 .unwrap_or_default();
273 let resp = if http_get {
274 self.fetch_token(&auth, connection, Method::GET)?
275 } else {
276 match self.fetch_token(&auth, connection, Method::POST) {
277 Ok(resp) => resp,
278 Err(_) => {
279 warn!("retry http GET method to get auth token");
280 let resp = self.fetch_token(&auth, connection, Method::GET)?;
281 self.cached_auth_using_http_get.set(self.host.clone(), true);
283 resp
284 }
285 }
286 };
287
288 let ret = TokenResponse::from_resp(resp)
289 .map_err(|e| einval!(format!("failed to get auth token from registry: {:?}", e)))?;
290
291 if let Ok(now_timestamp) = SystemTime::now().duration_since(UNIX_EPOCH) {
292 self.token_expired_at
293 .store(Some(Arc::new(now_timestamp.as_secs() + ret.expires_in)));
294 debug!(
295 "cached bearer auth, next time: {}",
296 now_timestamp.as_secs() + ret.expires_in
297 );
298 }
299
300 self.cached_bearer_auth.store(Some(Arc::new(auth)));
302
303 Ok(ret)
304 }
305
306 fn fetch_token(
308 &self,
309 auth: &BearerAuth,
310 connection: &Arc<Connection>,
311 method: Method,
312 ) -> Result<Response> {
313 let mut headers = HeaderMap::new();
314
315 if let Some(auth) = &self.auth {
316 headers.insert(
317 HEADER_AUTHORIZATION,
318 format!("Basic {}", auth).parse().unwrap(),
319 );
320 }
321
322 let mut query: Option<&[(&str, &str)]> = None;
323 let mut body = None;
324
325 let query_params_get;
326
327 match method {
328 Method::GET => {
329 query_params_get = [
330 ("service", auth.service.as_str()),
331 ("scope", auth.scope.as_str()),
332 ("client_id", REGISTRY_CLIENT_ID),
333 ];
334 query = Some(&query_params_get);
335 }
336 Method::POST => {
337 let mut form = HashMap::new();
338 form.insert("service".to_string(), auth.service.clone());
339 form.insert("scope".to_string(), auth.scope.clone());
340 form.insert("client_id".to_string(), REGISTRY_CLIENT_ID.to_string());
341 body = Some(ReqBody::Form(form));
342 }
343 _ => return Err(einval!()),
344 }
345
346 let token_resp = connection
347 .call::<&[u8]>(
348 method.clone(),
349 auth.realm.as_str(),
350 query,
351 body,
352 &mut headers,
353 true,
354 )
355 .map_err(move |e| {
356 warn!(
357 "failed to request registry auth server by {:?} method: {:?}",
358 method, e
359 );
360 einval!()
361 })?;
362
363 Ok(token_resp)
364 }
365
366 fn get_auth_header(&self, auth: Auth, connection: &Arc<Connection>) -> Result<String> {
367 match auth {
368 Auth::Basic(_) => self
369 .auth
370 .as_ref()
371 .map(|auth| format!("Basic {}", auth))
372 .ok_or_else(|| einval!("invalid auth config")),
373 Auth::Bearer(auth) => {
374 let token = self.get_token(auth, connection)?;
375 Ok(format!("Bearer {}", token.token))
376 }
377 }
378 }
379
380 fn parse_auth(source: &HeaderValue) -> Option<Auth> {
383 let source = source.to_str().unwrap();
384 let source: Vec<&str> = source.splitn(2, ' ').collect();
385 if source.len() < 2 {
386 return None;
387 }
388 let scheme = source[0].trim();
389 let pairs = source[1].trim();
390 let pairs = pairs.split("\",");
391 let mut paras = HashMap::new();
392 for pair in pairs {
393 let pair: Vec<&str> = pair.trim().split('=').collect();
394 if pair.len() < 2 {
395 return None;
396 }
397 let key = pair[0].trim();
398 let value = pair[1].trim().trim_matches('"');
399 paras.insert(key, value);
400 }
401
402 match scheme {
403 "Basic" => {
404 let realm = if let Some(realm) = paras.get("realm") {
405 (*realm).to_string()
406 } else {
407 String::new()
408 };
409 Some(Auth::Basic(BasicAuth { realm }))
410 }
411 "Bearer" => {
412 if !paras.contains_key("realm") || !paras.contains_key("service") {
413 return None;
414 }
415
416 let scope = if let Some(scope) = paras.get("scope") {
417 (*scope).to_string()
418 } else {
419 debug!("no scope specified for token auth challenge");
420 String::new()
421 };
422
423 Some(Auth::Bearer(BearerAuth {
424 realm: (*paras.get("realm").unwrap()).to_string(),
425 service: (*paras.get("service").unwrap()).to_string(),
426 scope,
427 }))
428 }
429 _ => None,
430 }
431 }
432
    /// Switches the shared scheme flag to http for all subsequently built URLs.
    fn fallback_http(&self) {
        self.scheme.0.store(false, Ordering::Relaxed);
    }
436}
437
/// Gate built on a replaceable [`Once`].
///
/// Clones share the same inner cell through the `Arc`, so a `renew` performed through one
/// clone is seen by all of them.
#[derive(Clone)]
struct First {
    inner: Arc<ArcSwap<Once>>,
}

impl First {
    /// Creates a gate holding a `Once` that has not run yet.
    fn new() -> Self {
        First {
            inner: Arc::new(ArcSwap::new(Arc::new(Once::new()))),
        }
    }

    /// Runs `f` through the currently installed `Once`; `f` is skipped when that `Once` has
    /// already completed.
    fn once<F>(&self, f: F)
    where
        F: FnOnce(),
    {
        self.inner.load().call_once(f)
    }

    /// Installs a fresh `Once`, so the next `once` call runs its closure again.
    fn renew(&self) {
        self.inner.store(Arc::new(Once::new()));
    }

    /// Tries to run `handle` through the gate.
    ///
    /// Returns `Some(result)` when this call executed `handle`, and `None` when neither of the
    /// two attempts did. When `handle` fails the gate is renewed.
    fn handle<F, T>(&self, handle: &mut F) -> Option<BackendResult<T>>
    where
        F: FnMut() -> BackendResult<T>,
    {
        let mut ret = None;
        // NOTE(review): the second attempt presumably covers a gate renewed by a concurrent
        // failed caller while this thread was waiting in `call_once` - confirm.
        for _ in 0..=1 {
            self.once(|| {
                ret = Some(handle().inspect_err(|_err| {
                    // Failure: re-arm the gate with a fresh `Once`.
                    self.renew();
                }));
            });
            if ret.is_some() {
                break;
            }
        }
        ret
    }

    /// Like `handle`, but when the gate did not execute `handle`, calls it directly so a
    /// result is always produced.
    fn handle_force<F, T>(&self, handle: &mut F) -> BackendResult<T>
    where
        F: FnMut() -> BackendResult<T>,
    {
        self.handle(handle).unwrap_or_else(handle)
    }
}
499
/// Reader for a single blob stored in the registry.
struct RegistryReader {
    /// Blob digest without the `sha256:` prefix (the prefix is added when building URLs).
    blob_id: String,
    /// Connection used to issue HTTP requests.
    connection: Arc<Connection>,
    /// State shared with the owning `Registry`.
    state: Arc<RegistryState>,
    /// Metrics handle returned by `BlobReader::metrics`.
    metrics: Arc<BackendMetrics>,
    /// Gate shared with the owning `Registry` and the readers it created.
    first: First,
}
507
508impl RegistryReader {
509 fn request<R: Read + Clone + Send + 'static>(
536 &self,
537 method: Method,
538 url: &str,
539 data: Option<ReqBody<R>>,
540 mut headers: HeaderMap,
541 catch_status: bool,
542 ) -> RegistryResult<Response> {
543 let mut last_cached_auth = String::new();
545 let cached_auth = self.state.cached_auth.get();
546 if !cached_auth.is_empty() {
547 last_cached_auth = cached_auth.clone();
548 headers.insert(
549 HEADER_AUTHORIZATION,
550 HeaderValue::from_str(cached_auth.as_str()).unwrap(),
551 );
552 }
553
554 if let Some(data) = data {
557 return self
558 .connection
559 .call(method, url, None, Some(data), &mut headers, catch_status)
560 .map_err(RegistryError::Request);
561 }
562
563 let mut resp = self
565 .connection
566 .call::<&[u8]>(method.clone(), url, None, None, &mut headers, false)
567 .map_err(RegistryError::Request)?;
568 if resp.status() == StatusCode::UNAUTHORIZED {
569 if headers.contains_key(HEADER_AUTHORIZATION) {
570 headers.remove(HEADER_AUTHORIZATION);
578
579 resp = self
580 .connection
581 .call::<&[u8]>(method.clone(), url, None, None, &mut headers, false)
582 .map_err(RegistryError::Request)?;
583 };
584
585 if let Some(resp_auth_header) = resp.headers().get(HEADER_WWW_AUTHENTICATE) {
586 if let Some(auth) = RegistryState::parse_auth(resp_auth_header) {
588 let auth_header = self
589 .state
590 .get_auth_header(auth, &self.connection)
591 .map_err(|e| RegistryError::Common(e.to_string()))?;
592
593 headers.insert(
594 HEADER_AUTHORIZATION,
595 HeaderValue::from_str(auth_header.as_str()).unwrap(),
596 );
597
598 let resp = self
600 .connection
601 .call(method, url, None, data, &mut headers, catch_status)
602 .map_err(RegistryError::Request)?;
603
604 let status = resp.status();
605 if is_success_status(status) {
606 self.state.cached_auth.set(&last_cached_auth, auth_header)
608 }
609 return respond(resp, catch_status).map_err(RegistryError::Request);
610 }
611 }
612 }
613
614 respond(resp, catch_status).map_err(RegistryError::Request)
615 }
616
617 fn _try_read(
629 &self,
630 mut buf: &mut [u8],
631 offset: u64,
632 allow_retry: bool,
633 ) -> RegistryResult<usize> {
634 let url = format!("/blobs/sha256:{}", self.blob_id);
635 let url = self
636 .state
637 .url(url.as_str(), &[])
638 .map_err(|e| RegistryError::Url(url, e))?;
639 let mut headers = HeaderMap::new();
640 let end_at = offset + buf.len() as u64 - 1;
641 let range = format!("bytes={}-{}", offset, end_at);
642 headers.insert("Range", range.parse().unwrap());
643
644 let mut resp;
645 let cached_redirect = self.state.cached_redirect.get(&self.blob_id);
646
647 if let Some(cached_redirect) = cached_redirect {
648 resp = self
649 .connection
650 .call::<&[u8]>(
651 Method::GET,
652 cached_redirect.as_str(),
653 None,
654 None,
655 &mut headers,
656 false,
657 )
658 .map_err(RegistryError::Request)?;
659
660 if allow_retry
662 && [StatusCode::UNAUTHORIZED, StatusCode::FORBIDDEN].contains(&resp.status())
663 {
664 warn!(
665 "The redirected link has expired: {}, will retry read",
666 cached_redirect.as_str()
667 );
668 self.state.cached_redirect.remove(&self.blob_id);
669 return self._try_read(buf, offset, false);
671 }
672 } else {
673 resp = match self.request::<&[u8]>(
674 Method::GET,
675 url.as_str(),
676 None,
677 headers.clone(),
678 false,
679 ) {
680 Ok(res) => res,
681 Err(RegistryError::Request(ConnectionError::Common(e)))
682 if self.state.needs_fallback_http(&e) =>
683 {
684 self.state.fallback_http();
685 let url = format!("/blobs/sha256:{}", self.blob_id);
686 let url = self
687 .state
688 .url(url.as_str(), &[])
689 .map_err(|e| RegistryError::Url(url, e))?;
690 self.request::<&[u8]>(Method::GET, url.as_str(), None, headers.clone(), false)?
691 }
692 Err(RegistryError::Request(ConnectionError::Common(e))) => {
693 if e.to_string().contains("self signed certificate") {
694 warn!("try to enable \"skip_verify: true\" option");
695 }
696 return Err(RegistryError::Request(ConnectionError::Common(e)));
697 }
698 Err(e) => {
699 return Err(e);
700 }
701 };
702 let status = resp.status();
703 let need_redirect =
704 status >= StatusCode::MULTIPLE_CHOICES && status < StatusCode::BAD_REQUEST;
705
706 if need_redirect {
708 if let Some(location) = resp.headers().get("location") {
709 let location = location.to_str().unwrap();
710 let mut location = Url::parse(location)
711 .map_err(|e| RegistryError::Url(location.to_string(), e))?;
712 if !self.state.blob_url_scheme.is_empty() {
715 location
716 .set_scheme(&self.state.blob_url_scheme)
717 .map_err(|_| {
718 RegistryError::Scheme(self.state.blob_url_scheme.clone())
719 })?;
720 }
721 if !self.state.blob_redirected_host.is_empty() {
722 location
723 .set_host(Some(self.state.blob_redirected_host.as_str()))
724 .map_err(|e| {
725 error!(
726 "Failed to set blob redirected host to {}: {:?}",
727 self.state.blob_redirected_host.as_str(),
728 e
729 );
730 RegistryError::Url(location.to_string(), e)
731 })?;
732 debug!("New redirected location {:?}", location.host_str());
733 }
734 let resp_ret = self
735 .connection
736 .call::<&[u8]>(
737 Method::GET,
738 location.as_str(),
739 None,
740 None,
741 &mut headers,
742 true,
743 )
744 .map_err(RegistryError::Request);
745 match resp_ret {
746 Ok(_resp) => {
747 resp = _resp;
748 self.state
749 .cached_redirect
750 .set(self.blob_id.clone(), location.as_str().to_string())
751 }
752 Err(err) => {
753 return Err(err);
754 }
755 }
756 };
757 } else {
758 resp = respond(resp, true).map_err(RegistryError::Request)?;
759 }
760 }
761
762 resp.copy_to(&mut buf)
763 .map_err(RegistryError::Transport)
764 .map(|size| size as usize)
765 }
766}
767
768impl BlobReader for RegistryReader {
769 fn blob_size(&self) -> BackendResult<u64> {
770 self.first.handle_force(&mut || -> BackendResult<u64> {
771 let url = format!("/blobs/sha256:{}", self.blob_id);
772 let url = self
773 .state
774 .url(&url, &[])
775 .map_err(|e| RegistryError::Url(url, e))?;
776
777 let resp = match self.request::<&[u8]>(
778 Method::HEAD,
779 url.as_str(),
780 None,
781 HeaderMap::new(),
782 true,
783 ) {
784 Ok(res) => res,
785 Err(RegistryError::Request(ConnectionError::Common(e)))
786 if self.state.needs_fallback_http(&e) =>
787 {
788 self.state.fallback_http();
789 let url = format!("/blobs/sha256:{}", self.blob_id);
790 let url = self
791 .state
792 .url(&url, &[])
793 .map_err(|e| RegistryError::Url(url, e))?;
794 self.request::<&[u8]>(Method::HEAD, url.as_str(), None, HeaderMap::new(), true)?
795 }
796 Err(e) => {
797 return Err(BackendError::Registry(e));
798 }
799 };
800 let content_length = resp
801 .headers()
802 .get(CONTENT_LENGTH)
803 .ok_or_else(|| RegistryError::Common("invalid content length".to_string()))?;
804
805 Ok(content_length
806 .to_str()
807 .map_err(|err| RegistryError::Common(format!("invalid content length: {:?}", err)))?
808 .parse::<u64>()
809 .map_err(|err| {
810 RegistryError::Common(format!("invalid content length: {:?}", err))
811 })?)
812 })
813 }
814
815 fn try_read(&self, buf: &mut [u8], offset: u64) -> BackendResult<usize> {
816 self.first.handle_force(&mut || -> BackendResult<usize> {
817 self._try_read(buf, offset, true)
818 .map_err(BackendError::Registry)
819 })
820 }
821
822 fn metrics(&self) -> &BackendMetrics {
823 &self.metrics
824 }
825
826 fn retry_limit(&self) -> u8 {
827 self.state.retry_limit
828 }
829}
830
/// Storage backend that reads blobs from a container image registry.
pub struct Registry {
    /// Connection shared with every reader and the token refresh thread.
    connection: Arc<Connection>,
    /// State shared with every reader and the token refresh thread.
    state: Arc<RegistryState>,
    /// Backend metrics, released when the backend is dropped.
    metrics: Arc<BackendMetrics>,
    /// Gate cloned into every reader created by `get_reader`.
    first: First,
}
838
839impl Registry {
840 #[allow(clippy::useless_let_if_seq)]
841 pub fn new(config: &RegistryConfig, id: Option<&str>) -> Result<Registry> {
842 let id = id.ok_or_else(|| einval!("Registry backend requires blob_id"))?;
843 let con_config: ConnectionConfig = config.clone().into();
844
845 let retry_limit = con_config.retry_limit;
846 let connection = Connection::new(&con_config)?;
847 let auth = trim(config.auth.clone());
848 let registry_token = trim(config.registry_token.clone());
849 Self::validate_authorization_info(&auth)?;
850 let cached_auth = if let Some(registry_token) = registry_token {
851 Cache::new(format!("Bearer {}", registry_token))
854 } else {
855 Cache::new(String::new())
856 };
857
858 let scheme = if !config.scheme.is_empty() && config.scheme == "http" {
859 Scheme::new(false)
860 } else {
861 Scheme::new(true)
862 };
863
864 let state = Arc::new(RegistryState {
865 scheme,
866 host: config.host.clone(),
867 repo: config.repo.clone(),
868 auth,
869 cached_auth,
870 retry_limit,
871 blob_url_scheme: config.blob_url_scheme.clone(),
872 blob_redirected_host: config.blob_redirected_host.clone(),
873 cached_auth_using_http_get: HashCache::new(),
874 cached_redirect: HashCache::new(),
875 token_expired_at: ArcSwapOption::new(None),
876 cached_bearer_auth: ArcSwapOption::new(None),
877 });
878
879 let registry = Registry {
880 connection,
881 state,
882 metrics: BackendMetrics::new(id, "registry"),
883 first: First::new(),
884 };
885
886 if config.disable_token_refresh {
887 info!("Refresh token thread is disabled.");
888 } else {
889 registry.start_refresh_token_thread();
890 info!("Refresh token thread started.");
891 }
892
893 Ok(registry)
894 }
895
896 fn validate_authorization_info(auth: &Option<String>) -> Result<()> {
897 if let Some(auth) = &auth {
898 let auth: Vec<u8> = base64::engine::general_purpose::STANDARD
899 .decode(auth.as_bytes())
900 .map_err(|e| {
901 einval!(format!(
902 "Invalid base64 encoded registry auth config: {:?}",
903 e
904 ))
905 })?;
906 let auth = std::str::from_utf8(&auth).map_err(|e| {
907 einval!(format!(
908 "Invalid utf-8 encoded registry auth config: {:?}",
909 e
910 ))
911 })?;
912 let auth: Vec<&str> = auth.splitn(2, ':').collect();
913 if auth.len() < 2 {
914 return Err(einval!("Invalid registry auth config"));
915 }
916 }
917 Ok(())
918 }
919
920 fn start_refresh_token_thread(&self) {
921 let conn = self.connection.clone();
922 let state = self.state.clone();
923 let mut refresh_interval = REGISTRY_DEFAULT_TOKEN_EXPIRATION;
925 thread::spawn(move || {
926 loop {
927 if let Ok(now_timestamp) = SystemTime::now().duration_since(UNIX_EPOCH) {
928 if let Some(token_expired_at) = state.token_expired_at.load().as_deref() {
929 if now_timestamp.as_secs() + refresh_interval >= *token_expired_at {
932 if let Some(cached_bearer_auth) =
933 state.cached_bearer_auth.load().as_deref()
934 {
935 if let Ok(token) =
936 state.get_token(cached_bearer_auth.to_owned(), &conn)
937 {
938 let new_cached_auth = format!("Bearer {}", token.token);
939 debug!(
940 "[refresh_token_thread] registry token has been refreshed"
941 );
942 state
944 .cached_auth
945 .set(&state.cached_auth.get(), new_cached_auth);
946 refresh_interval = token
949 .expires_in
950 .checked_sub(20)
951 .unwrap_or(token.expires_in);
952 } else {
953 error!(
954 "[refresh_token_thread] failed to refresh registry token"
955 );
956 }
957 }
958 }
959 }
960 }
961
962 if conn.shutdown.load(Ordering::Acquire) {
963 break;
964 }
965 thread::sleep(Duration::from_secs(refresh_interval));
966 if conn.shutdown.load(Ordering::Acquire) {
967 break;
968 }
969 }
970 });
971 }
972}
973
974impl BlobBackend for Registry {
975 fn shutdown(&self) {
976 self.connection.shutdown();
977 }
978
979 fn metrics(&self) -> &BackendMetrics {
980 &self.metrics
981 }
982
983 fn get_reader(&self, blob_id: &str) -> BackendResult<Arc<dyn BlobReader>> {
984 Ok(Arc::new(RegistryReader {
985 blob_id: blob_id.to_owned(),
986 state: self.state.clone(),
987 connection: self.connection.clone(),
988 metrics: self.metrics.clone(),
989 first: self.first.clone(),
990 }))
991 }
992}
993
994impl Drop for Registry {
995 fn drop(&mut self) {
996 self.metrics.release().unwrap_or_else(|e| error!("{:?}", e));
997 }
998}
999
/// Strips surrounding whitespace from an optional string.
///
/// Returns `None` for `None` and for strings that are empty after trimming. A string without
/// surrounding whitespace is handed back without reallocating.
fn trim(value: Option<String>) -> Option<String> {
    let raw = value?;
    let stripped = raw.trim();
    match stripped.len() {
        0 => None,
        kept if kept == raw.len() => Some(raw),
        _ => Some(stripped.to_string()),
    }
}
1014
1015#[cfg(test)]
1016mod tests {
1017 use super::*;
1018 use http::response;
1019 use serde_json::json;
1020
1021 #[test]
1022 fn test_string_cache() {
1023 let cache = Cache::new("test".to_owned());
1024
1025 assert_eq!(cache.get(), "test");
1026
1027 cache.set("test", "test1".to_owned());
1028 assert_eq!(cache.get(), "test1");
1029 cache.set("test1", "test1".to_owned());
1030 assert_eq!(cache.get(), "test1");
1031 }
1032
1033 #[test]
1034 fn test_hash_cache() {
1035 let cache = HashCache::new();
1036
1037 assert_eq!(cache.get("test"), None);
1038 cache.set("test".to_owned(), "test".to_owned());
1039 assert_eq!(cache.get("test"), Some("test".to_owned()));
1040 cache.set("test".to_owned(), "test1".to_owned());
1041 assert_eq!(cache.get("test"), Some("test1".to_owned()));
1042 cache.remove("test");
1043 assert_eq!(cache.get("test"), None);
1044 }
1045
1046 #[test]
1047 fn test_state_url() {
1048 let state = RegistryState {
1049 scheme: Scheme::new(false),
1050 host: "alibaba-inc.com".to_string(),
1051 repo: "nydus".to_string(),
1052 auth: None,
1053 retry_limit: 5,
1054 blob_url_scheme: "https".to_string(),
1055 blob_redirected_host: "oss.alibaba-inc.com".to_string(),
1056 cached_auth_using_http_get: Default::default(),
1057 cached_auth: Default::default(),
1058 cached_redirect: Default::default(),
1059 token_expired_at: ArcSwapOption::new(None),
1060 cached_bearer_auth: ArcSwapOption::new(None),
1061 };
1062
1063 assert_eq!(
1064 state.url("image", &["blabla"]).unwrap(),
1065 "http://alibaba-inc.com/v2/nydusimage?blabla".to_owned()
1066 );
1067 assert_eq!(
1068 state.url("image", &[]).unwrap(),
1069 "http://alibaba-inc.com/v2/nydusimage".to_owned()
1070 );
1071 }
1072
1073 #[test]
1074 fn test_parse_auth() {
1075 let str = "Bearer realm=\"https://auth.my-registry.com/token\",service=\"my-registry.com\",scope=\"repository:test/repo:pull,push\"";
1076 let header = HeaderValue::from_str(str).unwrap();
1077 let auth = RegistryState::parse_auth(&header).unwrap();
1078 match auth {
1079 Auth::Bearer(auth) => {
1080 assert_eq!(&auth.realm, "https://auth.my-registry.com/token");
1081 assert_eq!(&auth.service, "my-registry.com");
1082 assert_eq!(&auth.scope, "repository:test/repo:pull,push");
1083 }
1084 _ => panic!("failed to parse `Bearer` authentication header"),
1085 }
1086
1087 let str = "Bearer realm=\"https://auth.my-registry.com/token\",service=\"my-registry.com\"";
1089 let header = HeaderValue::from_str(str).unwrap();
1090 let auth = RegistryState::parse_auth(&header).unwrap();
1091 match auth {
1092 Auth::Bearer(auth) => {
1093 assert_eq!(&auth.realm, "https://auth.my-registry.com/token");
1094 assert_eq!(&auth.service, "my-registry.com");
1095 assert_eq!(&auth.scope, "");
1096 }
1097 _ => panic!("failed to parse `Bearer` authentication header without scope"),
1098 }
1099
1100 let str = "Basic realm=\"https://auth.my-registry.com/token\"";
1101 let header = HeaderValue::from_str(str).unwrap();
1102 let auth = RegistryState::parse_auth(&header).unwrap();
1103 match auth {
1104 Auth::Basic(auth) => assert_eq!(&auth.realm, "https://auth.my-registry.com/token"),
1105 _ => panic!("failed to parse `Basic` authentication header"),
1106 }
1107
1108 let str = "Base realm=\"https://auth.my-registry.com/token\"";
1109 let header = HeaderValue::from_str(str).unwrap();
1110 assert!(RegistryState::parse_auth(&header).is_none());
1111 }
1112
1113 #[test]
1114 fn test_trim() {
1115 assert_eq!(trim(None), None);
1116 assert_eq!(trim(Some("".to_owned())), None);
1117 assert_eq!(trim(Some(" ".to_owned())), None);
1118 assert_eq!(trim(Some(" test ".to_owned())), Some("test".to_owned()));
1119 assert_eq!(trim(Some("test ".to_owned())), Some("test".to_owned()));
1120 assert_eq!(trim(Some(" test".to_owned())), Some("test".to_owned()));
1121 assert_eq!(trim(Some(" te st ".to_owned())), Some("te st".to_owned()));
1122 assert_eq!(trim(Some("te st".to_owned())), Some("te st".to_owned()));
1123 }
1124
1125 #[test]
1126 #[allow(clippy::redundant_clone)]
1127 fn test_first_basically() {
1128 let first = First::new();
1129 let mut val = 0;
1130 first.once(|| {
1131 val += 1;
1132 });
1133 assert_eq!(val, 1);
1134
1135 first.clone().once(|| {
1136 val += 1;
1137 });
1138 assert_eq!(val, 1);
1139
1140 first.renew();
1141 first.clone().once(|| {
1142 val += 1;
1143 });
1144 assert_eq!(val, 2);
1145 }
1146
    /// 100 threads go through one shared `First`; the handler fails on its first run (after a
    /// 2 second sleep) and succeeds afterwards, and the test asserts it ran exactly twice.
    #[test]
    #[allow(clippy::redundant_clone)]
    fn test_first_concurrently() {
        // Counter of handler executions.
        let val = Arc::new(ArcSwap::new(Arc::new(0)));
        let first = First::new();

        let mut handlers = Vec::new();
        for _ in 0..100 {
            let val_cloned = val.clone();
            let first_cloned = first.clone();
            handlers.push(std::thread::spawn(move || {
                let _ = first_cloned.handle(&mut || -> BackendResult<()> {
                    let val = val_cloned.load();
                    // The first execution simulates a slow, failing request.
                    let ret = if *val.as_ref() == 0 {
                        std::thread::sleep(std::time::Duration::from_secs(2));
                        Err(BackendError::Registry(RegistryError::Common(String::from(
                            "network error",
                        ))))
                    } else {
                        Ok(())
                    };
                    val_cloned.store(Arc::new(val.as_ref() + 1));
                    ret
                });
            }));
        }

        for handler in handlers {
            handler.join().unwrap();
        }

        // One failed run plus one successful run.
        assert_eq!(*val.load().as_ref(), 2);
    }
1180
1181 #[test]
1182 fn test_token_response_from_resp() {
1183 let json_with_token = json!({
1185 "token": "test_token_value",
1186 "expires_in": 3600
1187 });
1188 let response = Response::from(
1189 response::Builder::new()
1190 .body(json_with_token.to_string())
1191 .unwrap(),
1192 );
1193 let result = TokenResponse::from_resp(response).unwrap();
1194 assert_eq!(result.token, "test_token_value");
1195 assert_eq!(result.expires_in, 3600);
1196
1197 let json_with_access_token = json!({
1199 "access_token": "test_access_token_value",
1200 "expires_in": 7200
1201 });
1202 let response = Response::from(
1203 response::Builder::new()
1204 .body(json_with_access_token.to_string())
1205 .unwrap(),
1206 );
1207 let result = TokenResponse::from_resp(response).unwrap();
1208 assert_eq!(result.token, "test_access_token_value");
1209 assert_eq!(result.expires_in, 7200);
1210
1211 let json_with_default_expiration = json!({
1213 "token": "default_expiration_token"
1214 });
1215 let response = Response::from(
1216 response::Builder::new()
1217 .body(json_with_default_expiration.to_string())
1218 .unwrap(),
1219 );
1220 let result = TokenResponse::from_resp(response).unwrap();
1221 assert_eq!(result.token, "default_expiration_token");
1222 assert_eq!(result.expires_in, REGISTRY_DEFAULT_TOKEN_EXPIRATION);
1223
1224 let json_with_both_tokens = json!({
1226 "token": "test_token_value",
1227 "access_token": "test_access_token_value",
1228 });
1229 let response = Response::from(
1230 response::Builder::new()
1231 .body(json_with_both_tokens.to_string())
1232 .unwrap(),
1233 );
1234 let result = TokenResponse::from_resp(response).unwrap();
1235 assert_eq!(result.token, "test_token_value");
1236
1237 let json_with_no_token = json!({});
1239 let response = Response::from(
1240 response::Builder::new()
1241 .body(json_with_no_token.to_string())
1242 .unwrap(),
1243 );
1244 let result = TokenResponse::from_resp(response);
1245 assert!(result.is_err());
1246 }
1247}