Skip to main content

nydus_storage/backend/
registry.rs

1// Copyright 2020 Ant Group. All rights reserved.
2//
3// SPDX-License-Identifier: Apache-2.0
4
5//! Storage backend driver to access blobs on container image registry.
6use std::collections::HashMap;
7use std::error::Error;
8use std::io::{Read, Result};
9use std::sync::atomic::{AtomicBool, Ordering};
10use std::sync::{Arc, Once, RwLock};
11use std::time::{Duration, SystemTime, UNIX_EPOCH};
12use std::{fmt, thread};
13
14use arc_swap::{ArcSwap, ArcSwapOption};
15use base64::Engine;
16use reqwest::blocking::Response;
17pub use reqwest::header::HeaderMap;
18use reqwest::header::{HeaderValue, CONTENT_LENGTH};
19use reqwest::{Method, StatusCode};
20use url::{ParseError, Url};
21
22use nydus_api::RegistryConfig;
23use nydus_utils::metrics::BackendMetrics;
24
25use crate::backend::connection::{
26    is_success_status, respond, Connection, ConnectionConfig, ConnectionError, ReqBody,
27};
28use crate::backend::{BackendError, BackendResult, BlobBackend, BlobReader};
29
/// Client identifier sent to the registry auth server as `client_id` when requesting a token.
const REGISTRY_CLIENT_ID: &str = "nydus-registry-client";
/// Name of the request header that carries `Basic`/`Bearer` credentials.
const HEADER_AUTHORIZATION: &str = "Authorization";
/// Name of the response header inspected for an auth challenge on `401 Unauthorized`.
const HEADER_WWW_AUTHENTICATE: &str = "www-authenticate";

/// Token lifetime assumed when the auth server response omits `expires_in`.
const REGISTRY_DEFAULT_TOKEN_EXPIRATION: u64 = 10 * 60; // in seconds
35
/// Error codes related to registry storage backend operations.
#[derive(Debug)]
pub enum RegistryError {
    /// Generic failure, carrying a human-readable description.
    Common(String),
    /// The given string could not be parsed as a URL.
    Url(String, ParseError),
    /// An HTTP request issued through the backend connection failed.
    Request(ConnectionError),
    /// The configured blob URL scheme could not be applied to a redirect location.
    Scheme(String),
    /// The response body could not be transferred into the caller's buffer.
    Transport(reqwest::Error),
}
45
46impl fmt::Display for RegistryError {
47    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
48        match self {
49            RegistryError::Common(s) => write!(f, "failed to access blob from registry, {}", s),
50            RegistryError::Url(u, e) => write!(f, "failed to parse URL {}, {}", u, e),
51            RegistryError::Request(e) => write!(f, "failed to issue request, {}", e),
52            RegistryError::Scheme(s) => write!(f, "invalid scheme, {}", s),
53            RegistryError::Transport(e) => write!(f, "network transport error, {}", e),
54        }
55    }
56}
57
58impl From<RegistryError> for BackendError {
59    fn from(error: RegistryError) -> Self {
60        BackendError::Registry(error)
61    }
62}
63
64type RegistryResult<T> = std::result::Result<T, RegistryError>;
65
/// A single `String` value shared between threads behind an `RwLock`.
///
/// An empty string means "nothing cached".
#[derive(Default)]
struct Cache(RwLock<String>);

impl Cache {
    /// Creates a cache pre-populated with `val`.
    fn new(val: String) -> Self {
        Self(RwLock::new(val))
    }

    /// Returns a copy of the cached value; empty when nothing is cached.
    fn get(&self) -> String {
        self.0.read().unwrap().clone()
    }

    /// Stores `current`, unless it equals `last` (the value the caller read
    /// earlier), in which case the write lock is not taken at all.
    fn set(&self, last: &str, current: String) {
        if last == current {
            return;
        }
        *self.0.write().unwrap() = current;
    }
}
89
/// A string-keyed map shared between threads behind an `RwLock`.
#[derive(Default)]
struct HashCache<T>(RwLock<HashMap<String, T>>);

impl<T> HashCache<T> {
    /// Creates an empty cache.
    fn new() -> Self {
        Self(RwLock::new(HashMap::new()))
    }

    /// Returns a clone of the value stored under `key`, if present.
    fn get(&self, key: &str) -> Option<T>
    where
        T: Clone,
    {
        self.0.read().unwrap().get(key).cloned()
    }

    /// Inserts `value` under `key`, replacing any previous value.
    fn set(&self, key: String, value: T) {
        self.0.write().unwrap().insert(key, value);
    }

    /// Removes the entry for `key`; removing a missing key is a no-op.
    fn remove(&self, key: &str) {
        self.0.write().unwrap().remove(key);
    }
}
116
/// JSON body returned by the registry auth server for a token request.
#[derive(Clone, serde::Deserialize)]
struct TokenResponse {
    /// Registry token string.
    /// This field might vary depending on the registry server.
    #[serde(default)]
    token: String,
    /// Alternative field some auth servers use for the token; `from_resp`
    /// falls back to it when `token` is empty.
    #[serde(default)]
    access_token: String,
    /// Registry token period of validity, in seconds.
    #[serde(default = "default_expires_in")]
    expires_in: u64,
}

/// Serde default for `TokenResponse::expires_in` when the server omits it.
fn default_expires_in() -> u64 {
    REGISTRY_DEFAULT_TOKEN_EXPIRATION
}
133
134impl TokenResponse {
135    // Extract the bearer token from the registry auth server response
136    fn from_resp(resp: Response) -> Result<Self> {
137        let mut token: TokenResponse = resp.json().map_err(|e| {
138            einval!(format!(
139                "failed to decode registry auth server response: {:?}",
140                e
141            ))
142        })?;
143
144        if token.token.is_empty() {
145            if token.access_token.is_empty() {
146                return Err(einval!("failed to get auth token from registry"));
147            }
148            token.token = token.access_token.clone();
149        }
150        Ok(token)
151    }
152}
153
/// Parameters of a `Basic` challenge parsed from a `www-authenticate` header.
#[derive(Debug)]
struct BasicAuth {
    // Parsed from the challenge but not read by the visible code, hence the allow.
    #[allow(unused)]
    realm: String,
}

/// Parameters of a `Bearer` challenge parsed from a `www-authenticate` header.
/// They describe where and how to request a token.
#[derive(Debug, Clone)]
#[allow(dead_code)]
struct BearerAuth {
    /// URL of the token endpoint the token request is sent to.
    realm: String,
    /// `service` parameter forwarded to the token endpoint.
    service: String,
    /// `scope` parameter forwarded to the token endpoint; empty when the challenge omits it.
    scope: String,
}

/// An authentication challenge parsed from a `www-authenticate` response header.
#[derive(Debug)]
#[allow(dead_code)]
enum Auth {
    Basic(BasicAuth),
    Bearer(BearerAuth),
}
174
/// HTTP scheme used to reach the registry: `true` is `https`, `false` is `http`.
///
/// The flag is atomic, so it can be switched to `http` through a shared reference.
pub struct Scheme(AtomicBool);

impl Scheme {
    /// Creates a scheme; `true` selects `https`, `false` selects `http`.
    fn new(value: bool) -> Self {
        Self(AtomicBool::new(value))
    }
}

impl fmt::Display for Scheme {
    /// Writes the scheme name without the `://` separator.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let is_https = self.0.load(Ordering::Relaxed);
        let name = if is_https { "https" } else { "http" };
        f.write_str(name)
    }
}
192
/// Shared state of a registry backend: connection target, credentials, and the
/// caches consulted while issuing requests.
struct RegistryState {
    // HTTP scheme like: https, http
    scheme: Scheme,
    // Registry host; combined with `scheme` to build request URLs.
    host: String,
    // Image repo name like: library/ubuntu
    repo: String,
    // Base64 encoded registry auth, sent as `Basic <auth>`
    auth: Option<String>,
    // Retry limit for read operation
    retry_limit: u8,
    // Scheme specified for blob server
    blob_url_scheme: String,
    // Replace registry redirected url host with the given host
    blob_redirected_host: String,
    // Cache bearer token (get from registry authentication server) or basic authentication auth string.
    // It reduces the load on the token authentication server and avoids recomputing the
    // base64 auth string for every request.
    // Use RwLock here to avoid using mut backend trait object.
    // Example: RwLock<"Bearer <token>">
    //          RwLock<"Basic base64(<username:password>)">
    cached_auth: Cache,
    // Cache for the HTTP method when getting auth, it is "true" when using "GET" method.
    // Due to the different implementations of various image registries, auth requests
    // may use the GET or POST methods, we need to cache the method after the
    // fallback, so it can be reused next time and reduce an unnecessary request.
    cached_auth_using_http_get: HashCache<bool>,
    // Cache 30X redirect url
    // Example: RwLock<HashMap<"<blob_id>", "<redirected_url>">>
    cached_redirect: HashCache<String>,
    // The epoch timestamp of token expiration, which is obtained from the registry server.
    token_expired_at: ArcSwapOption<u64>,
    // Cache bearer auth for refreshing token.
    cached_bearer_auth: ArcSwapOption<BearerAuth>,
}
226
227impl RegistryState {
228    fn url(&self, path: &str, query: &[&str]) -> std::result::Result<String, ParseError> {
229        let path = if query.is_empty() {
230            format!("/v2/{}{}", self.repo, path)
231        } else {
232            format!("/v2/{}{}?{}", self.repo, path, query.join("&"))
233        };
234        let url = format!("{}://{}", self.scheme, self.host.as_str());
235        let url = Url::parse(url.as_str())?;
236        let url = url.join(path.as_str())?;
237
238        Ok(url.to_string())
239    }
240
241    fn needs_fallback_http(&self, e: &dyn Error) -> bool {
242        match e.source() {
243            Some(err) => match err.source() {
244                Some(err) => {
245                    if !self.scheme.0.load(Ordering::Relaxed) {
246                        return false;
247                    }
248                    let msg = err.to_string().to_lowercase();
249                    // If we attempt to establish a TLS connection with the HTTP registry server,
250                    // we are likely to encounter these types of error:
251                    // https://github.com/openssl/openssl/blob/6b3d28757620e0781bb1556032bb6961ee39af63/crypto/err/openssl.txt#L1574
252                    // https://github.com/containerd/nerdctl/blob/225a70bdc3b93cdb00efac7db1ceb50c098a8a16/pkg/cmd/image/push.go#LL135C66-L135C66
253                    let fallback = msg.contains("wrong version number")
254                        || msg.contains("connection refused")
255                        || msg.to_lowercase().contains("ssl");
256                    if fallback {
257                        warn!("fallback to http due to tls connection error: {}", err);
258                    }
259                    fallback
260                }
261                None => false,
262            },
263            None => false,
264        }
265    }
266
267    // Request registry authentication server to get bearer token
268    fn get_token(&self, auth: BearerAuth, connection: &Arc<Connection>) -> Result<TokenResponse> {
269        let http_get = self
270            .cached_auth_using_http_get
271            .get(&self.host)
272            .unwrap_or_default();
273        let resp = if http_get {
274            self.fetch_token(&auth, connection, Method::GET)?
275        } else {
276            match self.fetch_token(&auth, connection, Method::POST) {
277                Ok(resp) => resp,
278                Err(_) => {
279                    warn!("retry http GET method to get auth token");
280                    let resp = self.fetch_token(&auth, connection, Method::GET)?;
281                    // Cache http method for next use.
282                    self.cached_auth_using_http_get.set(self.host.clone(), true);
283                    resp
284                }
285            }
286        };
287
288        let ret = TokenResponse::from_resp(resp)
289            .map_err(|e| einval!(format!("failed to get auth token from registry: {:?}", e)))?;
290
291        if let Ok(now_timestamp) = SystemTime::now().duration_since(UNIX_EPOCH) {
292            self.token_expired_at
293                .store(Some(Arc::new(now_timestamp.as_secs() + ret.expires_in)));
294            debug!(
295                "cached bearer auth, next time: {}",
296                now_timestamp.as_secs() + ret.expires_in
297            );
298        }
299
300        // Cache bearer auth for refreshing token.
301        self.cached_bearer_auth.store(Some(Arc::new(auth)));
302
303        Ok(ret)
304    }
305
306    // Fetches a bearer token from the registry's authentication
307    fn fetch_token(
308        &self,
309        auth: &BearerAuth,
310        connection: &Arc<Connection>,
311        method: Method,
312    ) -> Result<Response> {
313        let mut headers = HeaderMap::new();
314
315        if let Some(auth) = &self.auth {
316            headers.insert(
317                HEADER_AUTHORIZATION,
318                format!("Basic {}", auth).parse().unwrap(),
319            );
320        }
321
322        let mut query: Option<&[(&str, &str)]> = None;
323        let mut body = None;
324
325        let query_params_get;
326
327        match method {
328            Method::GET => {
329                query_params_get = [
330                    ("service", auth.service.as_str()),
331                    ("scope", auth.scope.as_str()),
332                    ("client_id", REGISTRY_CLIENT_ID),
333                ];
334                query = Some(&query_params_get);
335            }
336            Method::POST => {
337                let mut form = HashMap::new();
338                form.insert("service".to_string(), auth.service.clone());
339                form.insert("scope".to_string(), auth.scope.clone());
340                form.insert("client_id".to_string(), REGISTRY_CLIENT_ID.to_string());
341                body = Some(ReqBody::Form(form));
342            }
343            _ => return Err(einval!()),
344        }
345
346        let token_resp = connection
347            .call::<&[u8]>(
348                method.clone(),
349                auth.realm.as_str(),
350                query,
351                body,
352                &mut headers,
353                true,
354            )
355            .map_err(move |e| {
356                warn!(
357                    "failed to request registry auth server by {:?} method: {:?}",
358                    method, e
359                );
360                einval!()
361            })?;
362
363        Ok(token_resp)
364    }
365
366    fn get_auth_header(&self, auth: Auth, connection: &Arc<Connection>) -> Result<String> {
367        match auth {
368            Auth::Basic(_) => self
369                .auth
370                .as_ref()
371                .map(|auth| format!("Basic {}", auth))
372                .ok_or_else(|| einval!("invalid auth config")),
373            Auth::Bearer(auth) => {
374                let token = self.get_token(auth, connection)?;
375                Ok(format!("Bearer {}", token.token))
376            }
377        }
378    }
379
380    /// Parse `www-authenticate` response header respond from registry server
381    /// The header format like: `Bearer realm="https://auth.my-registry.com/token",service="my-registry.com",scope="repository:test/repo:pull,push"`
382    fn parse_auth(source: &HeaderValue) -> Option<Auth> {
383        let source = source.to_str().unwrap();
384        let source: Vec<&str> = source.splitn(2, ' ').collect();
385        if source.len() < 2 {
386            return None;
387        }
388        let scheme = source[0].trim();
389        let pairs = source[1].trim();
390        let pairs = pairs.split("\",");
391        let mut paras = HashMap::new();
392        for pair in pairs {
393            let pair: Vec<&str> = pair.trim().split('=').collect();
394            if pair.len() < 2 {
395                return None;
396            }
397            let key = pair[0].trim();
398            let value = pair[1].trim().trim_matches('"');
399            paras.insert(key, value);
400        }
401
402        match scheme {
403            "Basic" => {
404                let realm = if let Some(realm) = paras.get("realm") {
405                    (*realm).to_string()
406                } else {
407                    String::new()
408                };
409                Some(Auth::Basic(BasicAuth { realm }))
410            }
411            "Bearer" => {
412                if !paras.contains_key("realm") || !paras.contains_key("service") {
413                    return None;
414                }
415
416                let scope = if let Some(scope) = paras.get("scope") {
417                    (*scope).to_string()
418                } else {
419                    debug!("no scope specified for token auth challenge");
420                    String::new()
421                };
422
423                Some(Auth::Bearer(BearerAuth {
424                    realm: (*paras.get("realm").unwrap()).to_string(),
425                    service: (*paras.get("service").unwrap()).to_string(),
426                    scope,
427                }))
428            }
429            _ => None,
430        }
431    }
432
433    fn fallback_http(&self) {
434        self.scheme.0.store(false, Ordering::Relaxed);
435    }
436}
437
438#[derive(Clone)]
439struct First {
440    inner: Arc<ArcSwap<Once>>,
441}
442
443impl First {
444    fn new() -> Self {
445        First {
446            inner: Arc::new(ArcSwap::new(Arc::new(Once::new()))),
447        }
448    }
449
450    fn once<F>(&self, f: F)
451    where
452        F: FnOnce(),
453    {
454        self.inner.load().call_once(f)
455    }
456
457    fn renew(&self) {
458        self.inner.store(Arc::new(Once::new()));
459    }
460
461    fn handle<F, T>(&self, handle: &mut F) -> Option<BackendResult<T>>
462    where
463        F: FnMut() -> BackendResult<T>,
464    {
465        let mut ret = None;
466        // Call once twice to ensure the subsequent requests use the new
467        // Once instance after renew happens.
468        for _ in 0..=1 {
469            self.once(|| {
470                ret = Some(handle().inspect_err(|_err| {
471                    // Replace the Once instance so that we can retry it when
472                    // the handle call failed.
473                    self.renew();
474                }));
475            });
476            if ret.is_some() {
477                break;
478            }
479        }
480        ret
481    }
482
483    /// When invoking concurrently, only one of the handle methods will be executed first,
484    /// then subsequent handle methods will be allowed to execute concurrently.
485    ///
486    /// Nydusd uses a registry backend which generates a surge of blob requests without
487    /// auth tokens on initial startup, this caused mirror backends (e.g. dragonfly)
488    /// to process very slowly. The method implements waiting for the first blob request
489    /// to complete before making other blob requests, this ensures the first request
490    /// caches a valid registry auth token, and subsequent concurrent blob requests can
491    /// reuse the cached token.
492    fn handle_force<F, T>(&self, handle: &mut F) -> BackendResult<T>
493    where
494        F: FnMut() -> BackendResult<T>,
495    {
496        self.handle(handle).unwrap_or_else(handle)
497    }
498}
499
/// Reader for a single blob stored in the registry.
struct RegistryReader {
    // Blob digest without the `sha256:` prefix; requests go to `/blobs/sha256:<blob_id>`.
    blob_id: String,
    // Connection used to issue HTTP requests.
    connection: Arc<Connection>,
    // Registry target, credentials and caches.
    state: Arc<RegistryState>,
    // Backend metrics returned by `BlobReader::metrics`.
    metrics: Arc<BackendMetrics>,
    // Gate that lets the first request run before concurrent ones.
    first: First,
}
507
508impl RegistryReader {
509    /// Request registry server with `authorization` header
510    ///
511    /// Bearer token authenticate workflow:
512    ///
513    /// Request:  POST https://my-registry.com/test/repo/blobs/uploads
514    /// Response: status: 401 Unauthorized
515    ///           header: www-authenticate: Bearer realm="https://auth.my-registry.com/token",service="my-registry.com",scope="repository:test/repo:pull,push"
516    ///
517    /// Request:  POST https://auth.my-registry.com/token
518    ///           body: "service=my-registry.com&scope=repository:test/repo:pull,push&grant_type=password&username=x&password=x&client_id=nydus-registry-client"
519    /// Response: status: 200 Ok
520    ///           body: { "token": "<token>" }
521    ///
522    /// Request:  POST https://my-registry.com/test/repo/blobs/uploads
523    ///           header: authorization: Bearer <token>
524    /// Response: status: 200 Ok
525    ///
526    /// Basic authenticate workflow:
527    ///
528    /// Request:  POST https://my-registry.com/test/repo/blobs/uploads
529    /// Response: status: 401 Unauthorized
530    ///           header: www-authenticate: Basic
531    ///
532    /// Request:  POST https://my-registry.com/test/repo/blobs/uploads
533    ///           header: authorization: Basic base64(<username:password>)
534    /// Response: status: 200 Ok
535    fn request<R: Read + Clone + Send + 'static>(
536        &self,
537        method: Method,
538        url: &str,
539        data: Option<ReqBody<R>>,
540        mut headers: HeaderMap,
541        catch_status: bool,
542    ) -> RegistryResult<Response> {
543        // Try get authorization header from cache for this request
544        let mut last_cached_auth = String::new();
545        let cached_auth = self.state.cached_auth.get();
546        if !cached_auth.is_empty() {
547            last_cached_auth = cached_auth.clone();
548            headers.insert(
549                HEADER_AUTHORIZATION,
550                HeaderValue::from_str(cached_auth.as_str()).unwrap(),
551            );
552        }
553
554        // For upload request with payload, the auth header should be cached
555        // after create_upload(), so we can request registry server directly
556        if let Some(data) = data {
557            return self
558                .connection
559                .call(method, url, None, Some(data), &mut headers, catch_status)
560                .map_err(RegistryError::Request);
561        }
562
563        // Try to request registry server with `authorization` header
564        let mut resp = self
565            .connection
566            .call::<&[u8]>(method.clone(), url, None, None, &mut headers, false)
567            .map_err(RegistryError::Request)?;
568        if resp.status() == StatusCode::UNAUTHORIZED {
569            if headers.contains_key(HEADER_AUTHORIZATION) {
570                // If we request registry (harbor server) with expired authorization token,
571                // the `www-authenticate: Basic realm="harbor"` in response headers is not expected.
572                // Related code in harbor:
573                // https://github.com/goharbor/harbor/blob/v2.5.3/src/server/middleware/v2auth/auth.go#L98
574                //
575                // We can remove the expired authorization token and
576                // resend the request to get the correct "www-authenticate" value.
577                headers.remove(HEADER_AUTHORIZATION);
578
579                resp = self
580                    .connection
581                    .call::<&[u8]>(method.clone(), url, None, None, &mut headers, false)
582                    .map_err(RegistryError::Request)?;
583            };
584
585            if let Some(resp_auth_header) = resp.headers().get(HEADER_WWW_AUTHENTICATE) {
586                // Get token from registry authorization server
587                if let Some(auth) = RegistryState::parse_auth(resp_auth_header) {
588                    let auth_header = self
589                        .state
590                        .get_auth_header(auth, &self.connection)
591                        .map_err(|e| RegistryError::Common(e.to_string()))?;
592
593                    headers.insert(
594                        HEADER_AUTHORIZATION,
595                        HeaderValue::from_str(auth_header.as_str()).unwrap(),
596                    );
597
598                    // Try to request registry server with `authorization` header again
599                    let resp = self
600                        .connection
601                        .call(method, url, None, data, &mut headers, catch_status)
602                        .map_err(RegistryError::Request)?;
603
604                    let status = resp.status();
605                    if is_success_status(status) {
606                        // Cache authorization header for next request
607                        self.state.cached_auth.set(&last_cached_auth, auth_header)
608                    }
609                    return respond(resp, catch_status).map_err(RegistryError::Request);
610                }
611            }
612        }
613
614        respond(resp, catch_status).map_err(RegistryError::Request)
615    }
616
617    /// Read data from registry server
618    ///
619    /// Step:
620    ///
621    /// Request:  GET /blobs/sha256:<blob_id>
622    /// Response: status: 307 Temporary Redirect
623    ///           header: location: https://raw-blob-storage-host.com/signature=x
624    ///
625    /// Request:  GET https://raw-blob-storage-host.com/signature=x
626    /// Response: status: 200 Ok / 403 Forbidden
627    /// If responding 403, we need to repeat step one
628    fn _try_read(
629        &self,
630        mut buf: &mut [u8],
631        offset: u64,
632        allow_retry: bool,
633    ) -> RegistryResult<usize> {
634        let url = format!("/blobs/sha256:{}", self.blob_id);
635        let url = self
636            .state
637            .url(url.as_str(), &[])
638            .map_err(|e| RegistryError::Url(url, e))?;
639        let mut headers = HeaderMap::new();
640        let end_at = offset + buf.len() as u64 - 1;
641        let range = format!("bytes={}-{}", offset, end_at);
642        headers.insert("Range", range.parse().unwrap());
643
644        let mut resp;
645        let cached_redirect = self.state.cached_redirect.get(&self.blob_id);
646
647        if let Some(cached_redirect) = cached_redirect {
648            resp = self
649                .connection
650                .call::<&[u8]>(
651                    Method::GET,
652                    cached_redirect.as_str(),
653                    None,
654                    None,
655                    &mut headers,
656                    false,
657                )
658                .map_err(RegistryError::Request)?;
659
660            // The request has expired or has been denied, need to re-request
661            if allow_retry
662                && [StatusCode::UNAUTHORIZED, StatusCode::FORBIDDEN].contains(&resp.status())
663            {
664                warn!(
665                    "The redirected link has expired: {}, will retry read",
666                    cached_redirect.as_str()
667                );
668                self.state.cached_redirect.remove(&self.blob_id);
669                // Try read again only once
670                return self._try_read(buf, offset, false);
671            }
672        } else {
673            resp = match self.request::<&[u8]>(
674                Method::GET,
675                url.as_str(),
676                None,
677                headers.clone(),
678                false,
679            ) {
680                Ok(res) => res,
681                Err(RegistryError::Request(ConnectionError::Common(e)))
682                    if self.state.needs_fallback_http(&e) =>
683                {
684                    self.state.fallback_http();
685                    let url = format!("/blobs/sha256:{}", self.blob_id);
686                    let url = self
687                        .state
688                        .url(url.as_str(), &[])
689                        .map_err(|e| RegistryError::Url(url, e))?;
690                    self.request::<&[u8]>(Method::GET, url.as_str(), None, headers.clone(), false)?
691                }
692                Err(RegistryError::Request(ConnectionError::Common(e))) => {
693                    if e.to_string().contains("self signed certificate") {
694                        warn!("try to enable \"skip_verify: true\" option");
695                    }
696                    return Err(RegistryError::Request(ConnectionError::Common(e)));
697                }
698                Err(e) => {
699                    return Err(e);
700                }
701            };
702            let status = resp.status();
703            let need_redirect =
704                status >= StatusCode::MULTIPLE_CHOICES && status < StatusCode::BAD_REQUEST;
705
706            // Handle redirect request and cache redirect url
707            if need_redirect {
708                if let Some(location) = resp.headers().get("location") {
709                    let location = location.to_str().unwrap();
710                    let mut location = Url::parse(location)
711                        .map_err(|e| RegistryError::Url(location.to_string(), e))?;
712                    // Note: Some P2P proxy server supports only scheme specified origin blob server,
713                    // so we need change scheme to `blob_url_scheme` here
714                    if !self.state.blob_url_scheme.is_empty() {
715                        location
716                            .set_scheme(&self.state.blob_url_scheme)
717                            .map_err(|_| {
718                                RegistryError::Scheme(self.state.blob_url_scheme.clone())
719                            })?;
720                    }
721                    if !self.state.blob_redirected_host.is_empty() {
722                        location
723                            .set_host(Some(self.state.blob_redirected_host.as_str()))
724                            .map_err(|e| {
725                                error!(
726                                    "Failed to set blob redirected host to {}: {:?}",
727                                    self.state.blob_redirected_host.as_str(),
728                                    e
729                                );
730                                RegistryError::Url(location.to_string(), e)
731                            })?;
732                        debug!("New redirected location {:?}", location.host_str());
733                    }
734                    let resp_ret = self
735                        .connection
736                        .call::<&[u8]>(
737                            Method::GET,
738                            location.as_str(),
739                            None,
740                            None,
741                            &mut headers,
742                            true,
743                        )
744                        .map_err(RegistryError::Request);
745                    match resp_ret {
746                        Ok(_resp) => {
747                            resp = _resp;
748                            self.state
749                                .cached_redirect
750                                .set(self.blob_id.clone(), location.as_str().to_string())
751                        }
752                        Err(err) => {
753                            return Err(err);
754                        }
755                    }
756                };
757            } else {
758                resp = respond(resp, true).map_err(RegistryError::Request)?;
759            }
760        }
761
762        resp.copy_to(&mut buf)
763            .map_err(RegistryError::Transport)
764            .map(|size| size as usize)
765    }
766}
767
768impl BlobReader for RegistryReader {
769    fn blob_size(&self) -> BackendResult<u64> {
770        self.first.handle_force(&mut || -> BackendResult<u64> {
771            let url = format!("/blobs/sha256:{}", self.blob_id);
772            let url = self
773                .state
774                .url(&url, &[])
775                .map_err(|e| RegistryError::Url(url, e))?;
776
777            let resp = match self.request::<&[u8]>(
778                Method::HEAD,
779                url.as_str(),
780                None,
781                HeaderMap::new(),
782                true,
783            ) {
784                Ok(res) => res,
785                Err(RegistryError::Request(ConnectionError::Common(e)))
786                    if self.state.needs_fallback_http(&e) =>
787                {
788                    self.state.fallback_http();
789                    let url = format!("/blobs/sha256:{}", self.blob_id);
790                    let url = self
791                        .state
792                        .url(&url, &[])
793                        .map_err(|e| RegistryError::Url(url, e))?;
794                    self.request::<&[u8]>(Method::HEAD, url.as_str(), None, HeaderMap::new(), true)?
795                }
796                Err(e) => {
797                    return Err(BackendError::Registry(e));
798                }
799            };
800            let content_length = resp
801                .headers()
802                .get(CONTENT_LENGTH)
803                .ok_or_else(|| RegistryError::Common("invalid content length".to_string()))?;
804
805            Ok(content_length
806                .to_str()
807                .map_err(|err| RegistryError::Common(format!("invalid content length: {:?}", err)))?
808                .parse::<u64>()
809                .map_err(|err| {
810                    RegistryError::Common(format!("invalid content length: {:?}", err))
811                })?)
812        })
813    }
814
815    fn try_read(&self, buf: &mut [u8], offset: u64) -> BackendResult<usize> {
816        self.first.handle_force(&mut || -> BackendResult<usize> {
817            self._try_read(buf, offset, true)
818                .map_err(BackendError::Registry)
819        })
820    }
821
822    fn metrics(&self) -> &BackendMetrics {
823        &self.metrics
824    }
825
826    fn retry_limit(&self) -> u8 {
827        self.state.retry_limit
828    }
829}
830
/// Storage backend based on image registry.
///
/// Every reader created by [`BlobBackend::get_reader`] shares this backend's connection,
/// registry state, metrics and `First` handle.
pub struct Registry {
    /// Connection used for registry requests; also shared with the token refresh thread.
    connection: Arc<Connection>,
    /// Registry endpoint, credentials and cached auth/redirect information.
    state: Arc<RegistryState>,
    /// Backend metrics, released when the backend is dropped.
    metrics: Arc<BackendMetrics>,
    /// Shared with every reader; see `First` for its semantics.
    first: First,
}
838
839impl Registry {
840    #[allow(clippy::useless_let_if_seq)]
841    pub fn new(config: &RegistryConfig, id: Option<&str>) -> Result<Registry> {
842        let id = id.ok_or_else(|| einval!("Registry backend requires blob_id"))?;
843        let con_config: ConnectionConfig = config.clone().into();
844
845        let retry_limit = con_config.retry_limit;
846        let connection = Connection::new(&con_config)?;
847        let auth = trim(config.auth.clone());
848        let registry_token = trim(config.registry_token.clone());
849        Self::validate_authorization_info(&auth)?;
850        let cached_auth = if let Some(registry_token) = registry_token {
851            // Store the registry bearer token to cached_auth, prefer to
852            // use the token stored in cached_auth to request registry.
853            Cache::new(format!("Bearer {}", registry_token))
854        } else {
855            Cache::new(String::new())
856        };
857
858        let scheme = if !config.scheme.is_empty() && config.scheme == "http" {
859            Scheme::new(false)
860        } else {
861            Scheme::new(true)
862        };
863
864        let state = Arc::new(RegistryState {
865            scheme,
866            host: config.host.clone(),
867            repo: config.repo.clone(),
868            auth,
869            cached_auth,
870            retry_limit,
871            blob_url_scheme: config.blob_url_scheme.clone(),
872            blob_redirected_host: config.blob_redirected_host.clone(),
873            cached_auth_using_http_get: HashCache::new(),
874            cached_redirect: HashCache::new(),
875            token_expired_at: ArcSwapOption::new(None),
876            cached_bearer_auth: ArcSwapOption::new(None),
877        });
878
879        let registry = Registry {
880            connection,
881            state,
882            metrics: BackendMetrics::new(id, "registry"),
883            first: First::new(),
884        };
885
886        if config.disable_token_refresh {
887            info!("Refresh token thread is disabled.");
888        } else {
889            registry.start_refresh_token_thread();
890            info!("Refresh token thread started.");
891        }
892
893        Ok(registry)
894    }
895
896    fn validate_authorization_info(auth: &Option<String>) -> Result<()> {
897        if let Some(auth) = &auth {
898            let auth: Vec<u8> = base64::engine::general_purpose::STANDARD
899                .decode(auth.as_bytes())
900                .map_err(|e| {
901                    einval!(format!(
902                        "Invalid base64 encoded registry auth config: {:?}",
903                        e
904                    ))
905                })?;
906            let auth = std::str::from_utf8(&auth).map_err(|e| {
907                einval!(format!(
908                    "Invalid utf-8 encoded registry auth config: {:?}",
909                    e
910                ))
911            })?;
912            let auth: Vec<&str> = auth.splitn(2, ':').collect();
913            if auth.len() < 2 {
914                return Err(einval!("Invalid registry auth config"));
915            }
916        }
917        Ok(())
918    }
919
920    fn start_refresh_token_thread(&self) {
921        let conn = self.connection.clone();
922        let state = self.state.clone();
923        // FIXME: we'd better allow users to specify the expiration time.
924        let mut refresh_interval = REGISTRY_DEFAULT_TOKEN_EXPIRATION;
925        thread::spawn(move || {
926            loop {
927                if let Ok(now_timestamp) = SystemTime::now().duration_since(UNIX_EPOCH) {
928                    if let Some(token_expired_at) = state.token_expired_at.load().as_deref() {
929                        // If the token will expire within the next refresh interval,
930                        // refresh it immediately.
931                        if now_timestamp.as_secs() + refresh_interval >= *token_expired_at {
932                            if let Some(cached_bearer_auth) =
933                                state.cached_bearer_auth.load().as_deref()
934                            {
935                                if let Ok(token) =
936                                    state.get_token(cached_bearer_auth.to_owned(), &conn)
937                                {
938                                    let new_cached_auth = format!("Bearer {}", token.token);
939                                    debug!(
940                                        "[refresh_token_thread] registry token has been refreshed"
941                                    );
942                                    // Refresh cached token.
943                                    state
944                                        .cached_auth
945                                        .set(&state.cached_auth.get(), new_cached_auth);
946                                    // Reset refresh interval according to real expiration time,
947                                    // and advance 20s to handle the unexpected cases.
948                                    refresh_interval = token
949                                        .expires_in
950                                        .checked_sub(20)
951                                        .unwrap_or(token.expires_in);
952                                } else {
953                                    error!(
954                                        "[refresh_token_thread] failed to refresh registry token"
955                                    );
956                                }
957                            }
958                        }
959                    }
960                }
961
962                if conn.shutdown.load(Ordering::Acquire) {
963                    break;
964                }
965                thread::sleep(Duration::from_secs(refresh_interval));
966                if conn.shutdown.load(Ordering::Acquire) {
967                    break;
968                }
969            }
970        });
971    }
972}
973
974impl BlobBackend for Registry {
975    fn shutdown(&self) {
976        self.connection.shutdown();
977    }
978
979    fn metrics(&self) -> &BackendMetrics {
980        &self.metrics
981    }
982
983    fn get_reader(&self, blob_id: &str) -> BackendResult<Arc<dyn BlobReader>> {
984        Ok(Arc::new(RegistryReader {
985            blob_id: blob_id.to_owned(),
986            state: self.state.clone(),
987            connection: self.connection.clone(),
988            metrics: self.metrics.clone(),
989            first: self.first.clone(),
990        }))
991    }
992}
993
994impl Drop for Registry {
995    fn drop(&mut self) {
996        self.metrics.release().unwrap_or_else(|e| error!("{:?}", e));
997    }
998}
999
/// Normalize an optional configuration string.
///
/// Leading and trailing whitespace is stripped, and a value that is empty after stripping
/// becomes `None`. When there is nothing to strip, the original `String` is returned as-is,
/// so no new allocation is made.
fn trim(value: Option<String>) -> Option<String> {
    let raw = value?;
    let stripped = raw.trim();
    if stripped.is_empty() {
        return None;
    }
    if stripped.len() == raw.len() {
        Some(raw)
    } else {
        Some(stripped.to_string())
    }
}
1014
// Unit tests for the registry backend's caches, URL building, auth-header parsing and
// helper functions.
#[cfg(test)]
mod tests {
    use super::*;
    use http::response;
    use serde_json::json;

    #[test]
    fn test_string_cache() {
        let cache = Cache::new("test".to_owned());

        assert_eq!(cache.get(), "test");

        cache.set("test", "test1".to_owned());
        assert_eq!(cache.get(), "test1");
        cache.set("test1", "test1".to_owned());
        assert_eq!(cache.get(), "test1");
    }

    #[test]
    fn test_hash_cache() {
        let cache = HashCache::new();

        assert_eq!(cache.get("test"), None);
        cache.set("test".to_owned(), "test".to_owned());
        assert_eq!(cache.get("test"), Some("test".to_owned()));
        // Setting an existing key overwrites its value.
        cache.set("test".to_owned(), "test1".to_owned());
        assert_eq!(cache.get("test"), Some("test1".to_owned()));
        cache.remove("test");
        assert_eq!(cache.get("test"), None);
    }

    #[test]
    fn test_state_url() {
        let state = RegistryState {
            scheme: Scheme::new(false),
            host: "alibaba-inc.com".to_string(),
            repo: "nydus".to_string(),
            auth: None,
            retry_limit: 5,
            blob_url_scheme: "https".to_string(),
            blob_redirected_host: "oss.alibaba-inc.com".to_string(),
            cached_auth_using_http_get: Default::default(),
            cached_auth: Default::default(),
            cached_redirect: Default::default(),
            token_expired_at: ArcSwapOption::new(None),
            cached_bearer_auth: ArcSwapOption::new(None),
        };

        // `Scheme::new(false)` produces `http` URLs of the form
        // `<scheme>://<host>/v2/<repo><path>[?<query>]`.
        assert_eq!(
            state.url("image", &["blabla"]).unwrap(),
            "http://alibaba-inc.com/v2/nydusimage?blabla".to_owned()
        );
        assert_eq!(
            state.url("image", &[]).unwrap(),
            "http://alibaba-inc.com/v2/nydusimage".to_owned()
        );
    }

    #[test]
    fn test_parse_auth() {
        let str = "Bearer realm=\"https://auth.my-registry.com/token\",service=\"my-registry.com\",scope=\"repository:test/repo:pull,push\"";
        let header = HeaderValue::from_str(str).unwrap();
        let auth = RegistryState::parse_auth(&header).unwrap();
        match auth {
            Auth::Bearer(auth) => {
                assert_eq!(&auth.realm, "https://auth.my-registry.com/token");
                assert_eq!(&auth.service, "my-registry.com");
                assert_eq!(&auth.scope, "repository:test/repo:pull,push");
            }
            _ => panic!("failed to parse `Bearer` authentication header"),
        }

        // A missing scope is acceptable and parses as an empty string.
        let str = "Bearer realm=\"https://auth.my-registry.com/token\",service=\"my-registry.com\"";
        let header = HeaderValue::from_str(str).unwrap();
        let auth = RegistryState::parse_auth(&header).unwrap();
        match auth {
            Auth::Bearer(auth) => {
                assert_eq!(&auth.realm, "https://auth.my-registry.com/token");
                assert_eq!(&auth.service, "my-registry.com");
                assert_eq!(&auth.scope, "");
            }
            _ => panic!("failed to parse `Bearer` authentication header without scope"),
        }

        let str = "Basic realm=\"https://auth.my-registry.com/token\"";
        let header = HeaderValue::from_str(str).unwrap();
        let auth = RegistryState::parse_auth(&header).unwrap();
        match auth {
            Auth::Basic(auth) => assert_eq!(&auth.realm, "https://auth.my-registry.com/token"),
            _ => panic!("failed to parse `Basic` authentication header"),
        }

        // Unknown authentication schemes are rejected.
        let str = "Base realm=\"https://auth.my-registry.com/token\"";
        let header = HeaderValue::from_str(str).unwrap();
        assert!(RegistryState::parse_auth(&header).is_none());
    }

    #[test]
    fn test_trim() {
        assert_eq!(trim(None), None);
        assert_eq!(trim(Some("".to_owned())), None);
        assert_eq!(trim(Some("    ".to_owned())), None);
        assert_eq!(trim(Some("  test  ".to_owned())), Some("test".to_owned()));
        assert_eq!(trim(Some("test  ".to_owned())), Some("test".to_owned()));
        assert_eq!(trim(Some("  test".to_owned())), Some("test".to_owned()));
        assert_eq!(trim(Some("  te st  ".to_owned())), Some("te st".to_owned()));
        assert_eq!(trim(Some("te st".to_owned())), Some("te st".to_owned()));
    }

    // `once` runs its closure only the first time, even through a clone sharing the same
    // state, until `renew` re-arms it.
    #[test]
    #[allow(clippy::redundant_clone)]
    fn test_first_basically() {
        let first = First::new();
        let mut val = 0;
        first.once(|| {
            val += 1;
        });
        assert_eq!(val, 1);

        first.clone().once(|| {
            val += 1;
        });
        assert_eq!(val, 1);

        first.renew();
        first.clone().once(|| {
            val += 1;
        });
        assert_eq!(val, 2);
    }

    // 100 threads race through `First::handle`. The attempt that observes 0 sleeps and
    // fails; every executed attempt increments `val`.
    // NOTE(review): the expected final value of 2 presumably means exactly one further
    // attempt runs after the failure, per `First::handle` semantics defined elsewhere; confirm.
    #[test]
    #[allow(clippy::redundant_clone)]
    fn test_first_concurrently() {
        let val = Arc::new(ArcSwap::new(Arc::new(0)));
        let first = First::new();

        let mut handlers = Vec::new();
        for _ in 0..100 {
            let val_cloned = val.clone();
            let first_cloned = first.clone();
            handlers.push(std::thread::spawn(move || {
                let _ = first_cloned.handle(&mut || -> BackendResult<()> {
                    let val = val_cloned.load();
                    let ret = if *val.as_ref() == 0 {
                        std::thread::sleep(std::time::Duration::from_secs(2));
                        Err(BackendError::Registry(RegistryError::Common(String::from(
                            "network error",
                        ))))
                    } else {
                        Ok(())
                    };
                    val_cloned.store(Arc::new(val.as_ref() + 1));
                    ret
                });
            }));
        }

        for handler in handlers {
            handler.join().unwrap();
        }

        assert_eq!(*val.load().as_ref(), 2);
    }

    #[test]
    fn test_token_response_from_resp() {
        // Case 1: Response contains "token"
        let json_with_token = json!({
            "token": "test_token_value",
            "expires_in": 3600
        });
        let response = Response::from(
            response::Builder::new()
                .body(json_with_token.to_string())
                .unwrap(),
        );
        let result = TokenResponse::from_resp(response).unwrap();
        assert_eq!(result.token, "test_token_value");
        assert_eq!(result.expires_in, 3600);

        // Case 2: Response contains "access_token"
        let json_with_access_token = json!({
            "access_token": "test_access_token_value",
            "expires_in": 7200
        });
        let response = Response::from(
            response::Builder::new()
                .body(json_with_access_token.to_string())
                .unwrap(),
        );
        let result = TokenResponse::from_resp(response).unwrap();
        assert_eq!(result.token, "test_access_token_value");
        assert_eq!(result.expires_in, 7200);

        // Case 3: Default expiration time when "expires_in" is missing
        let json_with_default_expiration = json!({
            "token": "default_expiration_token"
        });
        let response = Response::from(
            response::Builder::new()
                .body(json_with_default_expiration.to_string())
                .unwrap(),
        );
        let result = TokenResponse::from_resp(response).unwrap();
        assert_eq!(result.token, "default_expiration_token");
        assert_eq!(result.expires_in, REGISTRY_DEFAULT_TOKEN_EXPIRATION);

        // Case 4: Response contains both token and access_token; "token" takes precedence.
        let json_with_both_tokens = json!({
            "token": "test_token_value",
            "access_token": "test_access_token_value",
        });
        let response = Response::from(
            response::Builder::new()
                .body(json_with_both_tokens.to_string())
                .unwrap(),
        );
        let result = TokenResponse::from_resp(response).unwrap();
        assert_eq!(result.token, "test_token_value");

        // Case 5: Response contains no token, which is an error.
        let json_with_no_token = json!({});
        let response = Response::from(
            response::Builder::new()
                .body(json_with_no_token.to_string())
                .unwrap(),
        );
        let result = TokenResponse::from_resp(response);
        assert!(result.is_err());
    }
}