Skip to main content

cfgd_core/oci/
pull.rs

1// Pull: download OCI module artifact, verify layer digest, extract to disk.
2// Optional cosign signature verification (real cryptographic check).
3
4use std::io::Read;
5use std::path::Path;
6
7use crate::PathDisplayExt;
8use crate::errors::OciError;
9use crate::output::Printer;
10use crate::sha256_digest;
11
12use super::archive::extract_tar_gz;
13use super::auth::RegistryAuth;
14use super::sign::{VerifyOptions, verify_signature};
15use super::transport::authenticated_request;
16use super::{MEDIA_TYPE_OCI_MANIFEST, OciManifest, OciReference};
17
/// Policy for verifying a module artifact's cosign signature during pull.
///
/// - `None` — skip signature verification entirely. This is the [`Default`]
///   variant.
/// - `RequireKey { path }` — fail unless `cosign verify --key <path>` succeeds.
/// - `RequireKeyless { identity, issuer }` — fail unless keyless verification
///   matches the supplied certificate identity / OIDC issuer constraints.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub enum SignaturePolicy<'a> {
    /// Do not verify signatures.
    #[default]
    None,
    /// Verify with an explicit public key.
    RequireKey {
        /// Path to the verifying key, passed to cosign as `--key <path>`.
        path: &'a str,
    },
    /// Keyless verification constrained by certificate identity and/or issuer.
    RequireKeyless {
        /// Expected certificate identity, if constrained.
        identity: Option<&'a str>,
        /// Expected OIDC issuer, if constrained.
        issuer: Option<&'a str>,
    },
}
35
36impl SignaturePolicy<'_> {
37    fn requires_signature(&self) -> bool {
38        !matches!(self, SignaturePolicy::None)
39    }
40}
41
42/// Pull a module from an OCI registry and extract it to `output_dir`.
43///
44/// `signature_policy` controls cryptographic signature verification:
45/// - `SignaturePolicy::None` — no verification (default).
46/// - `SignaturePolicy::RequireKey { path }` — run real `cosign verify --key`,
47///   fail the pull if it does not succeed.
48/// - `SignaturePolicy::RequireKeyless { identity, issuer }` — run real
49///   keyless verification with the supplied constraints, fail the pull if it
50///   does not succeed.
51///
52/// Prior to v0.4.0 this took a `bool` and only checked for the *presence* of
53/// a signature manifest (HEAD on `<tag>.sig`) — a TOFU sentinel an attacker
54/// who could push to the registry could trivially satisfy. The current API
55/// requires callers to supply the verifying key (or identity/issuer) so the
56/// trust decision is explicit and cryptographically enforced.
57pub fn pull_module(
58    artifact_ref: &str,
59    output_dir: &Path,
60    signature_policy: SignaturePolicy<'_>,
61    printer: Option<&Printer>,
62) -> Result<(), OciError> {
63    let oci_ref = OciReference::parse(artifact_ref)?;
64    let auth = RegistryAuth::resolve(&oci_ref.registry);
65    let agent = crate::http::http_agent(crate::http::HTTP_OCI_TIMEOUT);
66
67    let spinner = printer.map(|p| p.spinner(format!("Pulling module from {artifact_ref}...")));
68
69    if signature_policy.requires_signature() {
70        let opts = match &signature_policy {
71            SignaturePolicy::None => unreachable!("guarded by requires_signature()"),
72            SignaturePolicy::RequireKey { path } => VerifyOptions {
73                key: Some(path),
74                identity: None,
75                issuer: None,
76            },
77            SignaturePolicy::RequireKeyless { identity, issuer } => VerifyOptions {
78                key: None,
79                identity: *identity,
80                issuer: *issuer,
81            },
82        };
83        verify_signature(artifact_ref, &opts)?;
84    }
85
86    // Pull manifest
87    let manifest_url = format!(
88        "{}/{}/manifests/{}",
89        oci_ref.api_base(),
90        oci_ref.repository,
91        oci_ref.reference_str(),
92    );
93
94    let resp = authenticated_request(
95        &agent,
96        "GET",
97        &manifest_url,
98        auth.as_ref(),
99        Some(MEDIA_TYPE_OCI_MANIFEST),
100        None,
101        None,
102    )
103    .map_err(|e| OciError::ManifestNotFound {
104        reference: format!("{}: {e}", oci_ref),
105    })?;
106
107    let manifest_body = resp.into_string().map_err(|e| OciError::RequestFailed {
108        message: format!("cannot read manifest body: {e}"),
109    })?;
110    let manifest: OciManifest =
111        serde_json::from_str(&manifest_body).map_err(|e| OciError::RequestFailed {
112            message: format!("invalid manifest JSON: {e}"),
113        })?;
114
115    // Find our layer
116    let layer = manifest
117        .layers
118        .first()
119        .ok_or_else(|| OciError::RequestFailed {
120            message: "manifest has no layers".to_string(),
121        })?;
122
123    // Download layer blob
124    let blob_url = format!(
125        "{}/{}/blobs/{}",
126        oci_ref.api_base(),
127        oci_ref.repository,
128        layer.digest,
129    );
130
131    let resp = authenticated_request(
132        &agent,
133        "GET",
134        &blob_url,
135        auth.as_ref(),
136        Some("application/octet-stream"),
137        None,
138        None,
139    )
140    .map_err(|e| OciError::BlobNotFound {
141        digest: format!("{}: {e}", layer.digest),
142    })?;
143
144    // Read blob data (cap at 512 MB to prevent OOM from malicious manifests)
145    const MAX_BLOB_SIZE: u64 = 512 * 1024 * 1024;
146    if layer.size > MAX_BLOB_SIZE {
147        return Err(OciError::RequestFailed {
148            message: format!(
149                "layer size {} exceeds maximum allowed size ({} bytes)",
150                layer.size, MAX_BLOB_SIZE
151            ),
152        });
153    }
154    let mut blob_data = Vec::with_capacity(layer.size as usize);
155    resp.into_reader()
156        .take(MAX_BLOB_SIZE + 1024)
157        .read_to_end(&mut blob_data)?;
158
159    // Verify digest
160    let actual_digest = sha256_digest(&blob_data);
161    if actual_digest != layer.digest {
162        return Err(OciError::RequestFailed {
163            message: format!(
164                "layer digest mismatch: expected {}, got {}",
165                layer.digest, actual_digest
166            ),
167        });
168    }
169
170    // Extract
171    extract_tar_gz(&blob_data, output_dir)?;
172
173    if let Some(s) = spinner {
174        let _ = s.finish_ok(format!("Pulled module from {artifact_ref}"));
175    }
176
177    tracing::info!(
178        reference = %oci_ref,
179        output = %output_dir.posix(),
180        "module pulled"
181    );
182
183    Ok(())
184}
185
#[cfg(test)]
mod tests {
    // Each test starts a `mockito` server as a stand-in registry and turns
    // its URL into a registry host with `registry_from_url`. The mocked paths
    // (`/v2/<repo>/manifests/<tag>`, `/v2/<repo>/blobs/sha256:...`) mirror the
    // URLs `pull_module` builds.
    use super::*;
    use crate::oci::archive::create_tar_gz;
    use crate::oci::test_helpers::{create_test_module_dir, registry_from_url};
    use crate::oci::{MEDIA_TYPE_MODULE_CONFIG, MEDIA_TYPE_MODULE_LAYER};

    /// Happy path: a correctly-digested layer is downloaded and its files are
    /// extracted into the output directory.
    #[test]
    fn pull_module_downloads_and_verifies_digest() {
        let mut server = mockito::Server::new();
        let registry = registry_from_url(&server.url());

        // Create a layer tarball from a temp module dir
        let src_dir = create_test_module_dir();
        let layer_data = create_tar_gz(src_dir.path()).unwrap();
        let layer_digest = sha256_digest(&layer_data);

        // Build a manifest referencing this layer
        let config_blob = serde_json::to_vec(&serde_json::json!({
            "moduleYaml": "name: test",
        }))
        .unwrap();
        let config_digest = sha256_digest(&config_blob);

        let manifest = serde_json::json!({
            "schemaVersion": 2,
            "mediaType": MEDIA_TYPE_OCI_MANIFEST,
            "config": {
                "mediaType": MEDIA_TYPE_MODULE_CONFIG,
                "digest": config_digest,
                "size": config_blob.len(),
            },
            "layers": [{
                "mediaType": MEDIA_TYPE_MODULE_LAYER,
                "digest": layer_digest,
                "size": layer_data.len(),
            }],
        });

        // Mock manifest GET
        server
            .mock("GET", "/v2/test/pullmod/manifests/v1")
            .with_status(200)
            .with_header("Content-Type", MEDIA_TYPE_OCI_MANIFEST)
            .with_body(serde_json::to_string(&manifest).unwrap())
            .create();

        // Mock layer blob GET
        server
            .mock(
                "GET",
                mockito::Matcher::Regex(r"/v2/test/pullmod/blobs/sha256:.*".to_string()),
            )
            .with_status(200)
            .with_body(layer_data)
            .create();

        let output_dir = tempfile::tempdir().unwrap();
        let artifact_ref = format!("{}/test/pullmod:v1", registry);
        let result = pull_module(
            &artifact_ref,
            output_dir.path(),
            SignaturePolicy::None,
            None,
        );
        assert!(result.is_ok(), "pull_module failed: {:?}", result.err());

        // Verify extracted files
        assert!(output_dir.path().join("module.yaml").exists());
        assert!(output_dir.path().join("README.md").exists());
    }

    /// A layer whose bytes do not hash to the manifest's digest must be
    /// rejected with a "digest mismatch" error.
    #[test]
    fn pull_module_detects_digest_mismatch() {
        let mut server = mockito::Server::new();
        let registry = registry_from_url(&server.url());

        let real_layer_data = b"real layer content";
        // Use a fake digest that does NOT match the real data
        let fake_digest = "sha256:0000000000000000000000000000000000000000000000000000000000000000";

        // The declared size matches the served bytes, so only the digest
        // check can fail.
        let manifest = serde_json::json!({
            "schemaVersion": 2,
            "mediaType": MEDIA_TYPE_OCI_MANIFEST,
            "config": {
                "mediaType": MEDIA_TYPE_MODULE_CONFIG,
                "digest": "sha256:cfgcfg",
                "size": 10,
            },
            "layers": [{
                "mediaType": MEDIA_TYPE_MODULE_LAYER,
                "digest": fake_digest,
                "size": real_layer_data.len(),
            }],
        });

        server
            .mock("GET", "/v2/test/badmod/manifests/v1")
            .with_status(200)
            .with_body(serde_json::to_string(&manifest).unwrap())
            .create();

        server
            .mock(
                "GET",
                mockito::Matcher::Regex(r"/v2/test/badmod/blobs/sha256:.*".to_string()),
            )
            .with_status(200)
            .with_body(real_layer_data.as_slice())
            .create();

        let output_dir = tempfile::tempdir().unwrap();
        let artifact_ref = format!("{}/test/badmod:v1", registry);
        let result = pull_module(
            &artifact_ref,
            output_dir.path(),
            SignaturePolicy::None,
            None,
        );
        assert!(result.is_err());
        let err_msg = format!("{}", result.unwrap_err());
        assert!(
            err_msg.contains("digest mismatch"),
            "expected digest mismatch error, got: {err_msg}"
        );
    }

    /// With `RequireKey`, a failing cosign verification aborts the pull with
    /// `OciError::VerificationFailed`.
    #[cfg(unix)]
    #[test]
    #[serial_test::serial]
    fn pull_module_with_require_key_fails_when_cosign_verify_rejects() {
        use crate::test_helpers::CosignTestShim;
        // The shim stands in for `cosign` and exits with the configured code.
        // It stays installed while `_shim` is alive. NOTE(review): `#[serial]`
        // suggests the shim is process-global (e.g. PATH) — confirm in
        // test_helpers.
        let _shim = CosignTestShim::builder()
            .with_exit(1)
            .with_stderr("cosign error: signature does not match")
            .install();

        // No mocks are registered. The test expects the failure to come from
        // signature verification, not from the registry.
        let server = mockito::Server::new();
        let registry = registry_from_url(&server.url());
        let output_dir = tempfile::tempdir().unwrap();
        let artifact_ref = format!("{}/test/sigfail:v1", registry);

        let key_dir = tempfile::tempdir().unwrap();
        let key_path = key_dir.path().join("cosign.pub");
        std::fs::write(&key_path, "fake-public-key").unwrap();
        let key_path_str = key_path.to_str().unwrap();

        let policy = SignaturePolicy::RequireKey { path: key_path_str };
        let result = pull_module(&artifact_ref, output_dir.path(), policy, None);
        assert!(result.is_err());
        assert!(
            matches!(result, Err(OciError::VerificationFailed { .. })),
            "expected VerificationFailed, got: {:?}",
            result.err()
        );
    }

    /// With `RequireKey` and a succeeding cosign shim, the pull proceeds to
    /// download and verify the layer as usual.
    #[cfg(unix)]
    #[test]
    #[serial_test::serial]
    fn pull_module_with_require_key_proceeds_when_cosign_verify_succeeds() {
        use crate::test_helpers::CosignTestShim;
        let _shim = CosignTestShim::builder().with_exit(0).install();

        let mut server = mockito::Server::new();
        let registry = registry_from_url(&server.url());

        let src_dir = create_test_module_dir();
        let layer_data = create_tar_gz(src_dir.path()).unwrap();
        let layer_digest = sha256_digest(&layer_data);
        let config_blob =
            serde_json::to_vec(&serde_json::json!({"moduleYaml": "name: t"})).unwrap();
        let config_digest = sha256_digest(&config_blob);
        let manifest = serde_json::json!({
            "schemaVersion": 2,
            "mediaType": MEDIA_TYPE_OCI_MANIFEST,
            "config": {"mediaType": MEDIA_TYPE_MODULE_CONFIG, "digest": config_digest, "size": config_blob.len()},
            "layers": [{"mediaType": MEDIA_TYPE_MODULE_LAYER, "digest": layer_digest, "size": layer_data.len()}],
        });
        server
            .mock("GET", "/v2/test/sigok/manifests/v1")
            .with_status(200)
            .with_body(serde_json::to_string(&manifest).unwrap())
            .create();
        server
            .mock(
                "GET",
                mockito::Matcher::Regex(r"/v2/test/sigok/blobs/sha256:.*".to_string()),
            )
            .with_status(200)
            .with_body(layer_data)
            .create();

        let output_dir = tempfile::tempdir().unwrap();
        let artifact_ref = format!("{}/test/sigok:v1", registry);
        let key_dir = tempfile::tempdir().unwrap();
        let key_path = key_dir.path().join("cosign.pub");
        std::fs::write(&key_path, "fake-public-key").unwrap();
        let key_path_str = key_path.to_str().unwrap();

        let policy = SignaturePolicy::RequireKey { path: key_path_str };
        let result = pull_module(&artifact_ref, output_dir.path(), policy, None);
        assert!(result.is_ok(), "pull_module failed: {:?}", result.err());
    }

    /// Only `SignaturePolicy::None` opts out of signature verification.
    #[test]
    fn signature_policy_requires_signature_predicate() {
        assert!(!SignaturePolicy::None.requires_signature());
        assert!(SignaturePolicy::RequireKey { path: "k" }.requires_signature());
        assert!(
            SignaturePolicy::RequireKeyless {
                identity: Some("u@example"),
                issuer: None,
            }
            .requires_signature()
        );
    }

    /// A 404 from the manifest endpoint maps to `OciError::ManifestNotFound`.
    #[test]
    fn pull_module_returns_manifest_not_found_on_404() {
        let mut server = mockito::Server::new();
        let registry = registry_from_url(&server.url());

        // Manifest endpoint returns 404 → maps to ManifestNotFound
        server
            .mock("GET", "/v2/test/missingmod/manifests/v1")
            .with_status(404)
            .create();

        let output_dir = tempfile::tempdir().unwrap();
        let artifact_ref = format!("{}/test/missingmod:v1", registry);
        let result = pull_module(
            &artifact_ref,
            output_dir.path(),
            SignaturePolicy::None,
            None,
        );
        assert!(matches!(result, Err(OciError::ManifestNotFound { .. })));
    }

    /// A valid manifest whose layer blob 404s maps to `OciError::BlobNotFound`.
    #[test]
    fn pull_module_returns_blob_not_found_when_layer_missing() {
        let mut server = mockito::Server::new();
        let registry = registry_from_url(&server.url());

        // Manifest succeeds but references a layer the registry won't serve.
        let fake_digest = "sha256:0000000000000000000000000000000000000000000000000000000000000000";
        let manifest = serde_json::json!({
            "schemaVersion": 2,
            "mediaType": MEDIA_TYPE_OCI_MANIFEST,
            "config": {
                "mediaType": MEDIA_TYPE_MODULE_CONFIG,
                "digest": "sha256:cfgcfg",
                "size": 10,
            },
            "layers": [{
                "mediaType": MEDIA_TYPE_MODULE_LAYER,
                "digest": fake_digest,
                "size": 16,
            }],
        });

        server
            .mock("GET", "/v2/test/noblob/manifests/v1")
            .with_status(200)
            .with_body(serde_json::to_string(&manifest).unwrap())
            .create();

        // Blob fetch returns 404 → maps to BlobNotFound
        server
            .mock(
                "GET",
                mockito::Matcher::Regex(r"/v2/test/noblob/blobs/sha256:.*".to_string()),
            )
            .with_status(404)
            .create();

        let output_dir = tempfile::tempdir().unwrap();
        let artifact_ref = format!("{}/test/noblob:v1", registry);
        let result = pull_module(
            &artifact_ref,
            output_dir.path(),
            SignaturePolicy::None,
            None,
        );
        assert!(matches!(result, Err(OciError::BlobNotFound { .. })));
    }

    /// A 200 manifest response with an unparseable body maps to
    /// `OciError::RequestFailed`.
    #[test]
    fn pull_module_returns_request_failed_on_invalid_manifest_json() {
        let mut server = mockito::Server::new();
        let registry = registry_from_url(&server.url());

        // Manifest GET succeeds (200) but body is unparseable → RequestFailed
        server
            .mock("GET", "/v2/test/badjson/manifests/v1")
            .with_status(200)
            .with_body("not valid json")
            .create();

        let output_dir = tempfile::tempdir().unwrap();
        let artifact_ref = format!("{}/test/badjson:v1", registry);
        let result = pull_module(
            &artifact_ref,
            output_dir.path(),
            SignaturePolicy::None,
            None,
        );
        assert!(matches!(result, Err(OciError::RequestFailed { .. })));
    }
}