1use std::io::Read;
5use std::path::Path;
6
7use crate::PathDisplayExt;
8use crate::errors::OciError;
9use crate::output::Printer;
10use crate::sha256_digest;
11
12use super::archive::extract_tar_gz;
13use super::auth::RegistryAuth;
14use super::sign::{VerifyOptions, verify_signature};
15use super::transport::authenticated_request;
16use super::{MEDIA_TYPE_OCI_MANIFEST, OciManifest, OciReference};
17
/// How [`pull_module`] treats artifact signatures before downloading anything.
///
/// The borrowed strings are forwarded unchanged to the signature verifier
/// (`super::sign::verify_signature`) via its `VerifyOptions`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SignaturePolicy<'a> {
    /// Skip signature verification entirely.
    None,
    /// Require a signature that verifies against the key at `path`.
    ///
    /// `path` is passed to the verifier as its `key` option (tests in this
    /// file point it at a `cosign.pub` file).
    RequireKey {
        /// Path to the verification key.
        path: &'a str,
    },
    /// Require a keyless signature, optionally constrained by signer metadata.
    ///
    /// NOTE(review): `identity` / `issuer` are presumably the signing
    /// certificate identity and OIDC issuer (cosign keyless semantics) —
    /// confirm against `sign::verify_signature`.
    RequireKeyless {
        /// Optional expected signer identity; `None` leaves it unconstrained.
        identity: Option<&'a str>,
        /// Optional expected issuer; `None` leaves it unconstrained.
        issuer: Option<&'a str>,
    },
}

impl SignaturePolicy<'_> {
    /// Returns `true` for every policy except [`SignaturePolicy::None`], i.e.
    /// whenever a signature must be verified before pulling.
    fn requires_signature(&self) -> bool {
        !matches!(self, SignaturePolicy::None)
    }
}
41
42pub fn pull_module(
58 artifact_ref: &str,
59 output_dir: &Path,
60 signature_policy: SignaturePolicy<'_>,
61 printer: Option<&Printer>,
62) -> Result<(), OciError> {
63 let oci_ref = OciReference::parse(artifact_ref)?;
64 let auth = RegistryAuth::resolve(&oci_ref.registry);
65 let agent = crate::http::http_agent(crate::http::HTTP_OCI_TIMEOUT);
66
67 let spinner = printer.map(|p| p.spinner(format!("Pulling module from {artifact_ref}...")));
68
69 if signature_policy.requires_signature() {
70 let opts = match &signature_policy {
71 SignaturePolicy::None => unreachable!("guarded by requires_signature()"),
72 SignaturePolicy::RequireKey { path } => VerifyOptions {
73 key: Some(path),
74 identity: None,
75 issuer: None,
76 },
77 SignaturePolicy::RequireKeyless { identity, issuer } => VerifyOptions {
78 key: None,
79 identity: *identity,
80 issuer: *issuer,
81 },
82 };
83 verify_signature(artifact_ref, &opts)?;
84 }
85
86 let manifest_url = format!(
88 "{}/{}/manifests/{}",
89 oci_ref.api_base(),
90 oci_ref.repository,
91 oci_ref.reference_str(),
92 );
93
94 let resp = authenticated_request(
95 &agent,
96 "GET",
97 &manifest_url,
98 auth.as_ref(),
99 Some(MEDIA_TYPE_OCI_MANIFEST),
100 None,
101 None,
102 )
103 .map_err(|e| OciError::ManifestNotFound {
104 reference: format!("{}: {e}", oci_ref),
105 })?;
106
107 let manifest_body = resp.into_string().map_err(|e| OciError::RequestFailed {
108 message: format!("cannot read manifest body: {e}"),
109 })?;
110 let manifest: OciManifest =
111 serde_json::from_str(&manifest_body).map_err(|e| OciError::RequestFailed {
112 message: format!("invalid manifest JSON: {e}"),
113 })?;
114
115 let layer = manifest
117 .layers
118 .first()
119 .ok_or_else(|| OciError::RequestFailed {
120 message: "manifest has no layers".to_string(),
121 })?;
122
123 let blob_url = format!(
125 "{}/{}/blobs/{}",
126 oci_ref.api_base(),
127 oci_ref.repository,
128 layer.digest,
129 );
130
131 let resp = authenticated_request(
132 &agent,
133 "GET",
134 &blob_url,
135 auth.as_ref(),
136 Some("application/octet-stream"),
137 None,
138 None,
139 )
140 .map_err(|e| OciError::BlobNotFound {
141 digest: format!("{}: {e}", layer.digest),
142 })?;
143
144 const MAX_BLOB_SIZE: u64 = 512 * 1024 * 1024;
146 if layer.size > MAX_BLOB_SIZE {
147 return Err(OciError::RequestFailed {
148 message: format!(
149 "layer size {} exceeds maximum allowed size ({} bytes)",
150 layer.size, MAX_BLOB_SIZE
151 ),
152 });
153 }
154 let mut blob_data = Vec::with_capacity(layer.size as usize);
155 resp.into_reader()
156 .take(MAX_BLOB_SIZE + 1024)
157 .read_to_end(&mut blob_data)?;
158
159 let actual_digest = sha256_digest(&blob_data);
161 if actual_digest != layer.digest {
162 return Err(OciError::RequestFailed {
163 message: format!(
164 "layer digest mismatch: expected {}, got {}",
165 layer.digest, actual_digest
166 ),
167 });
168 }
169
170 extract_tar_gz(&blob_data, output_dir)?;
172
173 if let Some(s) = spinner {
174 let _ = s.finish_ok(format!("Pulled module from {artifact_ref}"));
175 }
176
177 tracing::info!(
178 reference = %oci_ref,
179 output = %output_dir.posix(),
180 "module pulled"
181 );
182
183 Ok(())
184}
185
#[cfg(test)]
mod tests {
    //! Unit tests for `pull_module` against a local `mockito` registry.
    //!
    //! Each test registers the manifest/blob endpoints it needs on a fresh mock
    //! server and calls `pull_module` with a reference pointing at that server.

    use super::*;
    use crate::oci::archive::create_tar_gz;
    use crate::oci::test_helpers::{create_test_module_dir, registry_from_url};
    use crate::oci::{MEDIA_TYPE_MODULE_CONFIG, MEDIA_TYPE_MODULE_LAYER};

    /// Happy path: a manifest plus a layer whose digest matches is pulled and
    /// extracted into the output directory.
    #[test]
    fn pull_module_downloads_and_verifies_digest() {
        let mut server = mockito::Server::new();
        let registry = registry_from_url(&server.url());

        // Build a real gzip tarball from the fixture module directory and
        // compute its digest, so the manifest describes it exactly.
        let src_dir = create_test_module_dir();
        let layer_data = create_tar_gz(src_dir.path()).unwrap();
        let layer_digest = sha256_digest(&layer_data);

        let config_blob = serde_json::to_vec(&serde_json::json!({
            "moduleYaml": "name: test",
        }))
        .unwrap();
        let config_digest = sha256_digest(&config_blob);

        let manifest = serde_json::json!({
            "schemaVersion": 2,
            "mediaType": MEDIA_TYPE_OCI_MANIFEST,
            "config": {
                "mediaType": MEDIA_TYPE_MODULE_CONFIG,
                "digest": config_digest,
                "size": config_blob.len(),
            },
            "layers": [{
                "mediaType": MEDIA_TYPE_MODULE_LAYER,
                "digest": layer_digest,
                "size": layer_data.len(),
            }],
        });

        server
            .mock("GET", "/v2/test/pullmod/manifests/v1")
            .with_status(200)
            .with_header("Content-Type", MEDIA_TYPE_OCI_MANIFEST)
            .with_body(serde_json::to_string(&manifest).unwrap())
            .create();

        // Any sha256 blob path under the repository serves the layer bytes.
        server
            .mock(
                "GET",
                mockito::Matcher::Regex(r"/v2/test/pullmod/blobs/sha256:.*".to_string()),
            )
            .with_status(200)
            .with_body(layer_data)
            .create();

        let output_dir = tempfile::tempdir().unwrap();
        let artifact_ref = format!("{}/test/pullmod:v1", registry);
        let result = pull_module(
            &artifact_ref,
            output_dir.path(),
            SignaturePolicy::None,
            None,
        );
        assert!(result.is_ok(), "pull_module failed: {:?}", result.err());

        // The extracted tree should contain the fixture's files (presumably
        // written by `create_test_module_dir`).
        assert!(output_dir.path().join("module.yaml").exists());
        assert!(output_dir.path().join("README.md").exists());
    }

    /// The manifest declares an all-zero digest while the registry serves
    /// different bytes; the pull must fail with a digest-mismatch error.
    #[test]
    fn pull_module_detects_digest_mismatch() {
        let mut server = mockito::Server::new();
        let registry = registry_from_url(&server.url());

        let real_layer_data = b"real layer content";
        let fake_digest = "sha256:0000000000000000000000000000000000000000000000000000000000000000";

        // The config digest is a placeholder: `pull_module` only reads `layers`.
        let manifest = serde_json::json!({
            "schemaVersion": 2,
            "mediaType": MEDIA_TYPE_OCI_MANIFEST,
            "config": {
                "mediaType": MEDIA_TYPE_MODULE_CONFIG,
                "digest": "sha256:cfgcfg",
                "size": 10,
            },
            "layers": [{
                "mediaType": MEDIA_TYPE_MODULE_LAYER,
                "digest": fake_digest,
                "size": real_layer_data.len(),
            }],
        });

        server
            .mock("GET", "/v2/test/badmod/manifests/v1")
            .with_status(200)
            .with_body(serde_json::to_string(&manifest).unwrap())
            .create();

        server
            .mock(
                "GET",
                mockito::Matcher::Regex(r"/v2/test/badmod/blobs/sha256:.*".to_string()),
            )
            .with_status(200)
            .with_body(real_layer_data.as_slice())
            .create();

        let output_dir = tempfile::tempdir().unwrap();
        let artifact_ref = format!("{}/test/badmod:v1", registry);
        let result = pull_module(
            &artifact_ref,
            output_dir.path(),
            SignaturePolicy::None,
            None,
        );
        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(
            err_msg.contains("digest mismatch"),
            "expected digest mismatch error, got: {err_msg}"
        );
    }

    /// With `RequireKey`, a rejecting `cosign` must abort the pull with
    /// `OciError::VerificationFailed`.
    // `serial` because the cosign shim presumably alters process-global state
    // (e.g. PATH) — confirm in `CosignTestShim`.
    #[cfg(unix)]
    #[test]
    #[serial_test::serial]
    fn pull_module_with_require_key_fails_when_cosign_verify_rejects() {
        use crate::test_helpers::CosignTestShim;
        // Bound to `_shim` (not `_`) so the guard stays alive for the whole test.
        let _shim = CosignTestShim::builder()
            .with_exit(1)
            .with_stderr("cosign error: signature does not match")
            .install();

        // No registry mocks are registered: `pull_module` verifies the
        // signature before requesting the manifest, so verification should
        // fail first.
        let server = mockito::Server::new();
        let registry = registry_from_url(&server.url());
        let output_dir = tempfile::tempdir().unwrap();
        let artifact_ref = format!("{}/test/sigfail:v1", registry);

        let key_dir = tempfile::tempdir().unwrap();
        let key_path = key_dir.path().join("cosign.pub");
        std::fs::write(&key_path, "fake-public-key").unwrap();
        let key_path_str = key_path.to_str().unwrap();

        let policy = SignaturePolicy::RequireKey { path: key_path_str };
        let result = pull_module(&artifact_ref, output_dir.path(), policy, None);
        assert!(result.is_err());
        assert!(
            matches!(result, Err(OciError::VerificationFailed { .. })),
            "expected VerificationFailed, got: {:?}",
            result.err()
        );
    }

    /// With `RequireKey`, an accepting `cosign` lets the pull continue through
    /// manifest and blob download to a successful result.
    #[cfg(unix)]
    #[test]
    #[serial_test::serial]
    fn pull_module_with_require_key_proceeds_when_cosign_verify_succeeds() {
        use crate::test_helpers::CosignTestShim;
        // Guard must outlive the `pull_module` call below.
        let _shim = CosignTestShim::builder().with_exit(0).install();

        let mut server = mockito::Server::new();
        let registry = registry_from_url(&server.url());

        let src_dir = create_test_module_dir();
        let layer_data = create_tar_gz(src_dir.path()).unwrap();
        let layer_digest = sha256_digest(&layer_data);
        let config_blob =
            serde_json::to_vec(&serde_json::json!({"moduleYaml": "name: t"})).unwrap();
        let config_digest = sha256_digest(&config_blob);
        let manifest = serde_json::json!({
            "schemaVersion": 2,
            "mediaType": MEDIA_TYPE_OCI_MANIFEST,
            "config": {"mediaType": MEDIA_TYPE_MODULE_CONFIG, "digest": config_digest, "size": config_blob.len()},
            "layers": [{"mediaType": MEDIA_TYPE_MODULE_LAYER, "digest": layer_digest, "size": layer_data.len()}],
        });
        server
            .mock("GET", "/v2/test/sigok/manifests/v1")
            .with_status(200)
            .with_body(serde_json::to_string(&manifest).unwrap())
            .create();
        server
            .mock(
                "GET",
                mockito::Matcher::Regex(r"/v2/test/sigok/blobs/sha256:.*".to_string()),
            )
            .with_status(200)
            .with_body(layer_data)
            .create();

        let output_dir = tempfile::tempdir().unwrap();
        let artifact_ref = format!("{}/test/sigok:v1", registry);
        let key_dir = tempfile::tempdir().unwrap();
        let key_path = key_dir.path().join("cosign.pub");
        std::fs::write(&key_path, "fake-public-key").unwrap();
        let key_path_str = key_path.to_str().unwrap();

        let policy = SignaturePolicy::RequireKey { path: key_path_str };
        let result = pull_module(&artifact_ref, output_dir.path(), policy, None);
        assert!(result.is_ok(), "pull_module failed: {:?}", result.err());
    }

    /// Only `SignaturePolicy::None` skips verification.
    #[test]
    fn signature_policy_requires_signature_predicate() {
        assert!(!SignaturePolicy::None.requires_signature());
        assert!(SignaturePolicy::RequireKey { path: "k" }.requires_signature());
        assert!(
            SignaturePolicy::RequireKeyless {
                identity: Some("u@example"),
                issuer: None,
            }
            .requires_signature()
        );
    }

    /// A failing manifest request (here HTTP 404) maps to `ManifestNotFound`.
    #[test]
    fn pull_module_returns_manifest_not_found_on_404() {
        let mut server = mockito::Server::new();
        let registry = registry_from_url(&server.url());

        server
            .mock("GET", "/v2/test/missingmod/manifests/v1")
            .with_status(404)
            .create();

        let output_dir = tempfile::tempdir().unwrap();
        let artifact_ref = format!("{}/test/missingmod:v1", registry);
        let result = pull_module(
            &artifact_ref,
            output_dir.path(),
            SignaturePolicy::None,
            None,
        );
        assert!(matches!(result, Err(OciError::ManifestNotFound { .. })));
    }

    /// A valid manifest whose layer blob request fails (HTTP 404) maps to
    /// `BlobNotFound`.
    #[test]
    fn pull_module_returns_blob_not_found_when_layer_missing() {
        let mut server = mockito::Server::new();
        let registry = registry_from_url(&server.url());

        let fake_digest = "sha256:0000000000000000000000000000000000000000000000000000000000000000";
        // Small declared layer size, well under the blob size limit, so the
        // blob request is actually issued.
        let manifest = serde_json::json!({
            "schemaVersion": 2,
            "mediaType": MEDIA_TYPE_OCI_MANIFEST,
            "config": {
                "mediaType": MEDIA_TYPE_MODULE_CONFIG,
                "digest": "sha256:cfgcfg",
                "size": 10,
            },
            "layers": [{
                "mediaType": MEDIA_TYPE_MODULE_LAYER,
                "digest": fake_digest,
                "size": 16,
            }],
        });

        server
            .mock("GET", "/v2/test/noblob/manifests/v1")
            .with_status(200)
            .with_body(serde_json::to_string(&manifest).unwrap())
            .create();

        server
            .mock(
                "GET",
                mockito::Matcher::Regex(r"/v2/test/noblob/blobs/sha256:.*".to_string()),
            )
            .with_status(404)
            .create();

        let output_dir = tempfile::tempdir().unwrap();
        let artifact_ref = format!("{}/test/noblob:v1", registry);
        let result = pull_module(
            &artifact_ref,
            output_dir.path(),
            SignaturePolicy::None,
            None,
        );
        assert!(matches!(result, Err(OciError::BlobNotFound { .. })));
    }

    /// A 200 response whose body is not valid manifest JSON maps to
    /// `RequestFailed`.
    #[test]
    fn pull_module_returns_request_failed_on_invalid_manifest_json() {
        let mut server = mockito::Server::new();
        let registry = registry_from_url(&server.url());

        server
            .mock("GET", "/v2/test/badjson/manifests/v1")
            .with_status(200)
            .with_body("not valid json")
            .create();

        let output_dir = tempfile::tempdir().unwrap();
        let artifact_ref = format!("{}/test/badjson:v1", registry);
        let result = pull_module(
            &artifact_ref,
            output_dir.path(),
            SignaturePolicy::None,
            None,
        );
        assert!(matches!(result, Err(OciError::RequestFailed { .. })));
    }
}
496}