Skip to main content

cuenv_tools_oci/
registry.rs

1//! OCI registry client for resolving and pulling images.
2//!
3//! Uses `oci-distribution` for registry operations.
4
5use oci_distribution::client::{ClientConfig, ClientProtocol};
6use oci_distribution::manifest::OciDescriptor;
7use oci_distribution::secrets::RegistryAuth;
8use oci_distribution::{Client, Reference};
9use sha2::{Digest, Sha256};
10use std::panic::{AssertUnwindSafe, catch_unwind};
11use std::path::Path;
12use std::sync::OnceLock;
13use tokio::io::{AsyncReadExt, AsyncWriteExt};
14use tracing::{debug, info, trace};
15
16use crate::cache::OciCache;
17use crate::platform::Platform;
18use crate::{Error, Result};
19
20/// OCI registry client for image resolution and blob pulling.
21pub struct OciClient {
22    client: OnceLock<std::result::Result<Client, String>>,
23}
24
25impl Default for OciClient {
26    fn default() -> Self {
27        Self::new()
28    }
29}
30
31impl OciClient {
32    fn create_client() -> std::result::Result<Client, String> {
33        let config = ClientConfig {
34            protocol: ClientProtocol::Https,
35            ..Default::default()
36        };
37
38        catch_unwind(AssertUnwindSafe(|| Client::new(config))).map_err(|_| {
39            "Failed to initialize OCI client because system proxy discovery panicked".to_string()
40        })
41    }
42
43    fn client(&self) -> Result<&Client> {
44        match self.client.get_or_init(Self::create_client) {
45            Ok(client) => Ok(client),
46            Err(err) => Err(Error::Oci(err.clone())),
47        }
48    }
49
50    /// Create a new OCI client with default configuration.
51    #[must_use]
52    pub fn new() -> Self {
53        Self {
54            client: OnceLock::new(),
55        }
56    }
57
58    /// Resolve an image reference to a digest for a specific platform.
59    ///
60    /// Returns the manifest digest for the platform-specific image.
61    pub async fn resolve_digest(&self, image: &str, platform: &Platform) -> Result<ResolvedImage> {
62        let reference = parse_reference(image)?;
63        info!(%image, %platform, "Resolving image digest");
64
65        let auth = self.get_auth(&reference);
66        let client = self.client()?;
67
68        // Pull the manifest and config
69        let (manifest, digest, _config) = client
70            .pull_manifest_and_config(&reference, &auth)
71            .await
72            .map_err(|e| Error::Oci(e.to_string()))?;
73
74        trace!(?manifest, "Got manifest");
75
76        // Extract layer digests from manifest
77        let layers: Vec<String> = manifest.layers.iter().map(|l| l.digest.clone()).collect();
78
79        // Also store layer descriptors for pulling
80        let layer_descriptors: Vec<OciDescriptor> = manifest.layers.clone();
81
82        debug!(
83            %image,
84            %platform,
85            %digest,
86            layer_count = layers.len(),
87            "Resolved image"
88        );
89
90        Ok(ResolvedImage {
91            reference,
92            digest,
93            layers,
94            layer_descriptors,
95        })
96    }
97
98    /// Pull a blob (layer) to a file using its descriptor.
99    ///
100    /// After downloading, the blob's SHA256 digest is verified against the
101    /// expected digest from the descriptor. If verification fails, the file
102    /// is deleted and an error is returned.
103    pub async fn pull_blob_by_descriptor(
104        &self,
105        reference: &Reference,
106        descriptor: &OciDescriptor,
107        dest: &Path,
108    ) -> Result<()> {
109        debug!(digest = %descriptor.digest, ?dest, "Pulling blob");
110
111        // Create parent directories
112        if let Some(parent) = dest.parent() {
113            tokio::fs::create_dir_all(parent).await?;
114        }
115
116        // Pull the blob
117        let mut file = tokio::fs::File::create(dest).await?;
118        let client = self.client()?;
119
120        client
121            .pull_blob(reference, descriptor, &mut file)
122            .await
123            .map_err(|e| Error::blob_pull_failed(&descriptor.digest, e.to_string()))?;
124
125        file.flush().await?;
126
127        // Verify the digest matches
128        let computed_digest = compute_file_digest(dest).await?;
129        if computed_digest != descriptor.digest {
130            // Remove the corrupted/invalid file
131            tokio::fs::remove_file(dest).await.ok();
132            return Err(Error::digest_mismatch(&descriptor.digest, &computed_digest));
133        }
134
135        debug!(digest = %descriptor.digest, ?dest, "Pulled and verified blob");
136        Ok(())
137    }
138
139    /// Pull all layers for an image and cache them.
140    pub async fn pull_layers(
141        &self,
142        resolved: &ResolvedImage,
143        cache: &OciCache,
144    ) -> Result<Vec<std::path::PathBuf>> {
145        let mut paths = Vec::new();
146
147        for descriptor in &resolved.layer_descriptors {
148            let path = cache.blob_path(&descriptor.digest);
149
150            if path.exists() {
151                trace!(digest = %descriptor.digest, "Layer already cached");
152            } else {
153                self.pull_blob_by_descriptor(&resolved.reference, descriptor, &path)
154                    .await?;
155            }
156
157            paths.push(path);
158        }
159
160        Ok(paths)
161    }
162
163    /// Get authentication for a registry.
164    ///
165    /// Currently returns anonymous auth. Can be extended to support:
166    /// - Docker config credentials
167    /// - Environment variables
168    /// - Keychain integration
169    fn get_auth(&self, reference: &Reference) -> RegistryAuth {
170        // Check for GHCR token in environment
171        if reference.registry() == "ghcr.io" {
172            if let Ok(token) = std::env::var("GITHUB_TOKEN") {
173                return RegistryAuth::Basic("".to_string(), token);
174            }
175            if let Ok(token) = std::env::var("GH_TOKEN") {
176                return RegistryAuth::Basic("".to_string(), token);
177            }
178        }
179
180        RegistryAuth::Anonymous
181    }
182}
183
184/// A resolved OCI image with digest and layer information.
185#[derive(Debug, Clone)]
186pub struct ResolvedImage {
187    /// The parsed reference.
188    pub reference: Reference,
189    /// Content-addressable digest of the manifest.
190    pub digest: String,
191    /// Layer digests (for reference).
192    pub layers: Vec<String>,
193    /// Layer descriptors (for pulling blobs).
194    pub layer_descriptors: Vec<OciDescriptor>,
195}
196
197/// Parse an image reference string.
198fn parse_reference(image: &str) -> Result<Reference> {
199    image
200        .parse()
201        .map_err(|e: oci_distribution::ParseError| Error::invalid_reference(image, e.to_string()))
202}
203
204/// Compute the SHA256 digest of a file.
205///
206/// Returns the digest in OCI format: `sha256:<hex>`.
207async fn compute_file_digest(path: &Path) -> Result<String> {
208    let mut file = tokio::fs::File::open(path).await?;
209    let mut hasher = Sha256::new();
210    let mut buffer = vec![0u8; 8192];
211
212    loop {
213        let n = file.read(&mut buffer).await?;
214        if n == 0 {
215            break;
216        }
217        hasher.update(&buffer[..n]);
218    }
219
220    Ok(format!("sha256:{:x}", hasher.finalize()))
221}
222
223#[cfg(test)]
224#[allow(unsafe_code)]
225mod tests {
226    use super::*;
227    use tempfile::TempDir;
228
229    // ==========================================================================
230    // parse_reference tests
231    // ==========================================================================
232
233    #[test]
234    fn test_parse_reference() {
235        let r = parse_reference("ghcr.io/distroless/static:nonroot").unwrap();
236        assert_eq!(r.registry(), "ghcr.io");
237        assert_eq!(r.repository(), "distroless/static");
238        assert_eq!(r.tag(), Some("nonroot"));
239    }
240
241    #[test]
242    fn test_parse_reference_with_digest() {
243        // Digest must be valid SHA256 (64 hex chars)
244        let digest = "sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
245        let r = parse_reference(&format!("nginx@{}", digest)).unwrap();
246        assert_eq!(r.repository(), "library/nginx");
247    }
248
249    #[test]
250    fn test_parse_reference_invalid() {
251        let r = parse_reference("not a valid reference!!!");
252        assert!(r.is_err());
253    }
254
255    #[test]
256    fn test_parse_reference_docker_hub_short() {
257        let r = parse_reference("nginx:latest").unwrap();
258        assert_eq!(r.registry(), "docker.io");
259        assert_eq!(r.repository(), "library/nginx");
260        assert_eq!(r.tag(), Some("latest"));
261    }
262
263    #[test]
264    fn test_parse_reference_with_port() {
265        let r = parse_reference("localhost:5000/myimage:v1").unwrap();
266        assert_eq!(r.registry(), "localhost:5000");
267        assert_eq!(r.repository(), "myimage");
268        assert_eq!(r.tag(), Some("v1"));
269    }
270
271    #[test]
272    fn test_parse_reference_no_tag() {
273        // Without a tag, it should default to "latest"
274        let r = parse_reference("nginx").unwrap();
275        assert_eq!(r.repository(), "library/nginx");
276    }
277
278    #[test]
279    fn test_parse_reference_private_registry() {
280        let r = parse_reference("registry.example.com/org/repo:v2.0.0").unwrap();
281        assert_eq!(r.registry(), "registry.example.com");
282        assert_eq!(r.repository(), "org/repo");
283        assert_eq!(r.tag(), Some("v2.0.0"));
284    }
285
286    // ==========================================================================
287    // compute_file_digest tests
288    // ==========================================================================
289
290    #[tokio::test]
291    async fn test_compute_file_digest() {
292        let temp = TempDir::new().unwrap();
293        let file_path = temp.path().join("test.txt");
294
295        // Write known content - empty file has a known SHA256
296        std::fs::write(&file_path, b"").unwrap();
297        let digest = compute_file_digest(&file_path).await.unwrap();
298        // SHA256 of empty string
299        assert_eq!(
300            digest,
301            "sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
302        );
303
304        // Write "hello" and verify
305        std::fs::write(&file_path, b"hello").unwrap();
306        let digest = compute_file_digest(&file_path).await.unwrap();
307        // SHA256 of "hello"
308        assert_eq!(
309            digest,
310            "sha256:2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"
311        );
312    }
313
314    #[tokio::test]
315    async fn test_compute_file_digest_larger_content() {
316        let temp = TempDir::new().unwrap();
317        let file_path = temp.path().join("large.bin");
318
319        // Write content larger than the buffer size (8192 bytes)
320        let content: Vec<u8> = (0..20000).map(|i| (i % 256) as u8).collect();
321        std::fs::write(&file_path, &content).unwrap();
322
323        let digest = compute_file_digest(&file_path).await.unwrap();
324        assert!(digest.starts_with("sha256:"));
325        assert_eq!(digest.len(), 7 + 64); // "sha256:" + 64 hex chars
326    }
327
328    #[tokio::test]
329    async fn test_compute_file_digest_nonexistent() {
330        let result = compute_file_digest(std::path::Path::new("/nonexistent/path")).await;
331        assert!(result.is_err());
332    }
333
334    // ==========================================================================
335    // Error tests
336    // ==========================================================================
337
338    #[test]
339    fn test_digest_mismatch_error() {
340        let err = Error::digest_mismatch("sha256:expected", "sha256:actual");
341        let msg = err.to_string();
342        assert!(msg.contains("expected"));
343        assert!(msg.contains("actual"));
344    }
345
346    #[test]
347    fn test_invalid_reference_error() {
348        let err = Error::invalid_reference("bad image", "parse error");
349        let msg = err.to_string();
350        assert!(msg.contains("bad image") || msg.contains("parse error"));
351    }
352
353    #[test]
354    fn test_blob_pull_failed_error() {
355        let err = Error::blob_pull_failed("sha256:abc123", "connection refused");
356        let msg = err.to_string();
357        assert!(msg.contains("sha256:abc123") || msg.contains("connection refused"));
358    }
359
360    // ==========================================================================
361    // OciClient tests
362    // ==========================================================================
363
364    #[test]
365    fn test_oci_client_new() {
366        let client = OciClient::new();
367        // Just verify it can be created
368        let _ = client;
369    }
370
371    #[test]
372    fn test_oci_client_default() {
373        let client = OciClient::default();
374        // Verify Default trait works
375        let _ = client;
376    }
377
378    #[test]
379    fn test_oci_client_get_auth_anonymous() {
380        let client = OciClient::new();
381        let reference = parse_reference("docker.io/library/nginx:latest").unwrap();
382        let auth = client.get_auth(&reference);
383        assert!(matches!(auth, RegistryAuth::Anonymous));
384    }
385
386    #[test]
387    fn test_oci_client_get_auth_ghcr_no_token() {
388        // Ensure no token env vars are set
389        // SAFETY: Test runs in isolation
390        unsafe {
391            std::env::remove_var("GITHUB_TOKEN");
392            std::env::remove_var("GH_TOKEN");
393        }
394
395        let client = OciClient::new();
396        let reference = parse_reference("ghcr.io/owner/image:latest").unwrap();
397        let auth = client.get_auth(&reference);
398        assert!(matches!(auth, RegistryAuth::Anonymous));
399    }
400
401    // ==========================================================================
402    // ResolvedImage tests
403    // ==========================================================================
404
405    #[test]
406    fn test_resolved_image_debug() {
407        let reference = parse_reference("nginx:latest").unwrap();
408        let resolved = ResolvedImage {
409            reference,
410            digest: "sha256:abc123".to_string(),
411            layers: vec!["sha256:layer1".to_string()],
412            layer_descriptors: vec![],
413        };
414
415        let debug_str = format!("{:?}", resolved);
416        assert!(debug_str.contains("sha256:abc123"));
417    }
418
419    #[test]
420    fn test_resolved_image_clone() {
421        let reference = parse_reference("nginx:latest").unwrap();
422        let resolved = ResolvedImage {
423            reference,
424            digest: "sha256:abc123".to_string(),
425            layers: vec!["sha256:layer1".to_string(), "sha256:layer2".to_string()],
426            layer_descriptors: vec![],
427        };
428
429        let cloned = resolved.clone();
430        assert_eq!(cloned.digest, "sha256:abc123");
431        assert_eq!(cloned.layers.len(), 2);
432    }
433
434    #[test]
435    fn test_resolved_image_empty_layers() {
436        let reference = parse_reference("scratch:latest").unwrap();
437        let resolved = ResolvedImage {
438            reference,
439            digest: "sha256:empty".to_string(),
440            layers: vec![],
441            layer_descriptors: vec![],
442        };
443
444        assert!(resolved.layers.is_empty());
445        assert!(resolved.layer_descriptors.is_empty());
446    }
447}