1use oci_distribution::client::{ClientConfig, ClientProtocol};
6use oci_distribution::manifest::OciDescriptor;
7use oci_distribution::secrets::RegistryAuth;
8use oci_distribution::{Client, Reference};
9use sha2::{Digest, Sha256};
10use std::panic::{AssertUnwindSafe, catch_unwind};
11use std::path::Path;
12use std::sync::OnceLock;
13use tokio::io::{AsyncReadExt, AsyncWriteExt};
14use tracing::{debug, info, trace};
15
16use crate::cache::OciCache;
17use crate::platform::Platform;
18use crate::{Error, Result};
19
/// Registry client wrapper used to resolve image manifests and download layer blobs.
///
/// The underlying [`Client`] is not built at construction time; it is created on first use
/// and the outcome (client or initialization error) is cached for the lifetime of this value.
pub struct OciClient {
    /// Lazily-initialized registry client, or the error message from a failed initialization.
    client: OnceLock<std::result::Result<Client, String>>,
}
24
25impl Default for OciClient {
26 fn default() -> Self {
27 Self::new()
28 }
29}
30
31impl OciClient {
32 fn create_client() -> std::result::Result<Client, String> {
33 let config = ClientConfig {
34 protocol: ClientProtocol::Https,
35 ..Default::default()
36 };
37
38 catch_unwind(AssertUnwindSafe(|| Client::new(config))).map_err(|_| {
39 "Failed to initialize OCI client because system proxy discovery panicked".to_string()
40 })
41 }
42
43 fn client(&self) -> Result<&Client> {
44 match self.client.get_or_init(Self::create_client) {
45 Ok(client) => Ok(client),
46 Err(err) => Err(Error::Oci(err.clone())),
47 }
48 }
49
50 #[must_use]
52 pub fn new() -> Self {
53 Self {
54 client: OnceLock::new(),
55 }
56 }
57
58 pub async fn resolve_digest(&self, image: &str, platform: &Platform) -> Result<ResolvedImage> {
62 let reference = parse_reference(image)?;
63 info!(%image, %platform, "Resolving image digest");
64
65 let auth = self.get_auth(&reference);
66 let client = self.client()?;
67
68 let (manifest, digest, _config) = client
70 .pull_manifest_and_config(&reference, &auth)
71 .await
72 .map_err(|e| Error::Oci(e.to_string()))?;
73
74 trace!(?manifest, "Got manifest");
75
76 let layers: Vec<String> = manifest.layers.iter().map(|l| l.digest.clone()).collect();
78
79 let layer_descriptors: Vec<OciDescriptor> = manifest.layers.clone();
81
82 debug!(
83 %image,
84 %platform,
85 %digest,
86 layer_count = layers.len(),
87 "Resolved image"
88 );
89
90 Ok(ResolvedImage {
91 reference,
92 digest,
93 layers,
94 layer_descriptors,
95 })
96 }
97
98 pub async fn pull_blob_by_descriptor(
104 &self,
105 reference: &Reference,
106 descriptor: &OciDescriptor,
107 dest: &Path,
108 ) -> Result<()> {
109 debug!(digest = %descriptor.digest, ?dest, "Pulling blob");
110
111 if let Some(parent) = dest.parent() {
113 tokio::fs::create_dir_all(parent).await?;
114 }
115
116 let mut file = tokio::fs::File::create(dest).await?;
118 let client = self.client()?;
119
120 client
121 .pull_blob(reference, descriptor, &mut file)
122 .await
123 .map_err(|e| Error::blob_pull_failed(&descriptor.digest, e.to_string()))?;
124
125 file.flush().await?;
126
127 let computed_digest = compute_file_digest(dest).await?;
129 if computed_digest != descriptor.digest {
130 tokio::fs::remove_file(dest).await.ok();
132 return Err(Error::digest_mismatch(&descriptor.digest, &computed_digest));
133 }
134
135 debug!(digest = %descriptor.digest, ?dest, "Pulled and verified blob");
136 Ok(())
137 }
138
139 pub async fn pull_layers(
141 &self,
142 resolved: &ResolvedImage,
143 cache: &OciCache,
144 ) -> Result<Vec<std::path::PathBuf>> {
145 let mut paths = Vec::new();
146
147 for descriptor in &resolved.layer_descriptors {
148 let path = cache.blob_path(&descriptor.digest);
149
150 if path.exists() {
151 trace!(digest = %descriptor.digest, "Layer already cached");
152 } else {
153 self.pull_blob_by_descriptor(&resolved.reference, descriptor, &path)
154 .await?;
155 }
156
157 paths.push(path);
158 }
159
160 Ok(paths)
161 }
162
163 fn get_auth(&self, reference: &Reference) -> RegistryAuth {
170 if reference.registry() == "ghcr.io" {
172 if let Ok(token) = std::env::var("GITHUB_TOKEN") {
173 return RegistryAuth::Basic("".to_string(), token);
174 }
175 if let Ok(token) = std::env::var("GH_TOKEN") {
176 return RegistryAuth::Basic("".to_string(), token);
177 }
178 }
179
180 RegistryAuth::Anonymous
181 }
182}
183
/// Result of [`OciClient::resolve_digest`]: the parsed reference plus the manifest digest and
/// its layers.
#[derive(Debug, Clone)]
pub struct ResolvedImage {
    /// Parsed form of the image string that was resolved.
    pub reference: Reference,
    /// Manifest digest returned alongside the manifest by `pull_manifest_and_config`.
    pub digest: String,
    /// Digests of the manifest's layers, in manifest order.
    pub layers: Vec<String>,
    /// Full layer descriptors from the manifest, in the same order as `layers`.
    pub layer_descriptors: Vec<OciDescriptor>,
}
196
197fn parse_reference(image: &str) -> Result<Reference> {
199 image
200 .parse()
201 .map_err(|e: oci_distribution::ParseError| Error::invalid_reference(image, e.to_string()))
202}
203
204async fn compute_file_digest(path: &Path) -> Result<String> {
208 let mut file = tokio::fs::File::open(path).await?;
209 let mut hasher = Sha256::new();
210 let mut buffer = vec![0u8; 8192];
211
212 loop {
213 let n = file.read(&mut buffer).await?;
214 if n == 0 {
215 break;
216 }
217 hasher.update(&buffer[..n]);
218 }
219
220 Ok(format!("sha256:{:x}", hasher.finalize()))
221}
222
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::TempDir;

    #[test]
    fn test_parse_reference() {
        let r = parse_reference("ghcr.io/distroless/static:nonroot").unwrap();
        assert_eq!(r.registry(), "ghcr.io");
        assert_eq!(r.repository(), "distroless/static");
        assert_eq!(r.tag(), Some("nonroot"));
    }

    #[test]
    fn test_parse_reference_with_digest() {
        let digest = "sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855";
        let r = parse_reference(&format!("nginx@{}", digest)).unwrap();
        assert_eq!(r.repository(), "library/nginx");
    }

    #[test]
    fn test_parse_reference_invalid() {
        let r = parse_reference("not a valid reference!!!");
        assert!(r.is_err());
    }

    #[test]
    fn test_parse_reference_docker_hub_short() {
        let r = parse_reference("nginx:latest").unwrap();
        assert_eq!(r.registry(), "docker.io");
        assert_eq!(r.repository(), "library/nginx");
        assert_eq!(r.tag(), Some("latest"));
    }

    #[test]
    fn test_parse_reference_with_port() {
        let r = parse_reference("localhost:5000/myimage:v1").unwrap();
        assert_eq!(r.registry(), "localhost:5000");
        assert_eq!(r.repository(), "myimage");
        assert_eq!(r.tag(), Some("v1"));
    }

    #[test]
    fn test_parse_reference_no_tag() {
        let r = parse_reference("nginx").unwrap();
        assert_eq!(r.repository(), "library/nginx");
    }

    #[test]
    fn test_parse_reference_private_registry() {
        let r = parse_reference("registry.example.com/org/repo:v2.0.0").unwrap();
        assert_eq!(r.registry(), "registry.example.com");
        assert_eq!(r.repository(), "org/repo");
        assert_eq!(r.tag(), Some("v2.0.0"));
    }

    #[tokio::test]
    async fn test_compute_file_digest() {
        let temp = TempDir::new().unwrap();
        let file_path = temp.path().join("test.txt");

        std::fs::write(&file_path, b"").unwrap();
        let digest = compute_file_digest(&file_path).await.unwrap();
        assert_eq!(
            digest,
            "sha256:e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855"
        );

        std::fs::write(&file_path, b"hello").unwrap();
        let digest = compute_file_digest(&file_path).await.unwrap();
        assert_eq!(
            digest,
            "sha256:2cf24dba5fb0a30e26e83b2ac5b9e29e1b161e5c1fa7425e73043362938b9824"
        );
    }

    #[tokio::test]
    async fn test_compute_file_digest_larger_content() {
        let temp = TempDir::new().unwrap();
        let file_path = temp.path().join("large.bin");

        // Larger than the 8 KiB read buffer, so several reads are needed.
        let content: Vec<u8> = (0..20000).map(|i| (i % 256) as u8).collect();
        std::fs::write(&file_path, &content).unwrap();

        let digest = compute_file_digest(&file_path).await.unwrap();
        assert!(digest.starts_with("sha256:"));
        assert_eq!(digest.len(), 7 + 64);
    }

    #[tokio::test]
    async fn test_compute_file_digest_nonexistent() {
        let result = compute_file_digest(std::path::Path::new("/nonexistent/path")).await;
        assert!(result.is_err());
    }

    #[test]
    fn test_digest_mismatch_error() {
        let err = Error::digest_mismatch("sha256:expected", "sha256:actual");
        let msg = err.to_string();
        assert!(msg.contains("expected"));
        assert!(msg.contains("actual"));
    }

    #[test]
    fn test_invalid_reference_error() {
        let err = Error::invalid_reference("bad image", "parse error");
        let msg = err.to_string();
        assert!(msg.contains("bad image") || msg.contains("parse error"));
    }

    #[test]
    fn test_blob_pull_failed_error() {
        let err = Error::blob_pull_failed("sha256:abc123", "connection refused");
        let msg = err.to_string();
        assert!(msg.contains("sha256:abc123") || msg.contains("connection refused"));
    }

    #[test]
    fn test_oci_client_new() {
        let client = OciClient::new();
        let _ = client;
    }

    #[test]
    fn test_oci_client_default() {
        let client = OciClient::default();
        let _ = client;
    }

    #[test]
    fn test_oci_client_get_auth_anonymous() {
        let client = OciClient::new();
        let reference = parse_reference("docker.io/library/nginx:latest").unwrap();
        let auth = client.get_auth(&reference);
        assert!(matches!(auth, RegistryAuth::Anonymous));
    }

    #[test]
    fn test_oci_client_get_auth_ghcr_no_token() {
        // This test only reads the environment and never modifies it. Calling
        // `std::env::remove_var` is unsound while other test threads may read the
        // environment; for example, `TempDir::new` resolves the temp directory from it.
        // Clearing the tokens would also leak into every other test in this process.
        // Instead, assert the outcome implied by whatever environment the runner provides.
        let any_token_var_set = ["GITHUB_TOKEN", "GH_TOKEN"]
            .into_iter()
            .any(|name| std::env::var_os(name).is_some());
        let any_token_non_empty = ["GITHUB_TOKEN", "GH_TOKEN"]
            .into_iter()
            .any(|name| std::env::var(name).is_ok_and(|token| !token.is_empty()));

        let client = OciClient::new();
        let reference = parse_reference("ghcr.io/owner/image:latest").unwrap();
        let auth = client.get_auth(&reference);

        if !any_token_var_set {
            assert!(matches!(auth, RegistryAuth::Anonymous));
        } else if any_token_non_empty {
            assert!(matches!(auth, RegistryAuth::Basic(..)));
        }
    }

    #[test]
    fn test_resolved_image_debug() {
        let reference = parse_reference("nginx:latest").unwrap();
        let resolved = ResolvedImage {
            reference,
            digest: "sha256:abc123".to_string(),
            layers: vec!["sha256:layer1".to_string()],
            layer_descriptors: vec![],
        };

        let debug_str = format!("{:?}", resolved);
        assert!(debug_str.contains("sha256:abc123"));
    }

    #[test]
    fn test_resolved_image_clone() {
        let reference = parse_reference("nginx:latest").unwrap();
        let resolved = ResolvedImage {
            reference,
            digest: "sha256:abc123".to_string(),
            layers: vec!["sha256:layer1".to_string(), "sha256:layer2".to_string()],
            layer_descriptors: vec![],
        };

        let cloned = resolved.clone();
        assert_eq!(cloned.digest, "sha256:abc123");
        assert_eq!(cloned.layers.len(), 2);
    }

    #[test]
    fn test_resolved_image_empty_layers() {
        let reference = parse_reference("scratch:latest").unwrap();
        let resolved = ResolvedImage {
            reference,
            digest: "sha256:empty".to_string(),
            layers: vec![],
            layer_descriptors: vec![],
        };

        assert!(resolved.layers.is_empty());
        assert!(resolved.layer_descriptors.is_empty());
    }
}