Skip to main content

initramfs_builder/registry/
client.rs

1use anyhow::{Context, Result};
2use oci_distribution::{
3    client::{Client, ClientConfig, ClientProtocol},
4    manifest::OciDescriptor,
5    secrets::RegistryAuth as OciRegistryAuth,
6    Reference,
7};
8use serde::{Deserialize, Serialize};
9use std::sync::Arc;
10use tracing::{debug, info};
11
/// Authentication credentials for a registry.
///
/// `Debug` is implemented by hand rather than derived so that passwords never
/// end up in logs, error contexts, or panic messages: `Basic` credentials render
/// the username and a `<redacted>` placeholder instead of the password.
#[derive(Clone, Default)]
pub enum RegistryAuth {
    /// No credentials are presented.
    #[default]
    Anonymous,
    /// Username/password credentials.
    Basic {
        username: String,
        password: String,
    },
}

impl std::fmt::Debug for RegistryAuth {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            // Matches the output the derived impl produced for this variant.
            Self::Anonymous => f.write_str("Anonymous"),
            Self::Basic { username, .. } => f
                .debug_struct("Basic")
                .field("username", username)
                .field("password", &"<redacted>")
                .finish(),
        }
    }
}
22
23impl From<RegistryAuth> for OciRegistryAuth {
24    fn from(auth: RegistryAuth) -> Self {
25        match auth {
26            RegistryAuth::Anonymous => OciRegistryAuth::Anonymous,
27            RegistryAuth::Basic { username, password } => {
28                OciRegistryAuth::Basic(username, password)
29            }
30        }
31    }
32}
33
/// Target platform used when an image reference resolves to a multi-arch index.
#[derive(Debug, Clone)]
pub struct PullOptions {
    /// Operating system the platform entry must declare (compared verbatim).
    pub platform_os: String,
    /// CPU architecture the platform entry must declare (compared verbatim).
    pub platform_arch: String,
}
40
41impl Default for PullOptions {
42    fn default() -> Self {
43        Self {
44            platform_os: "linux".to_string(),
45            platform_arch: "amd64".to_string(),
46        }
47    }
48}
49
50/// Describes a layer in an OCI image
51#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct LayerDescriptor {
53    pub digest: String,
54    pub size: u64,
55    pub media_type: String,
56}
57
58impl LayerDescriptor {
59    fn to_oci_descriptor(&self) -> OciDescriptor {
60        OciDescriptor {
61            digest: self.digest.clone(),
62            size: self.size as i64,
63            media_type: self.media_type.clone(),
64            ..Default::default()
65        }
66    }
67}
68
69/// Image manifest with layers info
70#[derive(Debug, Clone)]
71pub struct ImageManifest {
72    pub config_digest: String,
73    pub layers: Vec<LayerDescriptor>,
74    pub total_size: u64,
75}
76
77/// Client for interacting with OCI registries
78pub struct RegistryClient {
79    client: Client,
80    auth: RegistryAuth,
81}
82
83impl RegistryClient {
84    pub fn new(auth: RegistryAuth) -> Self {
85        let config = ClientConfig {
86            protocol: ClientProtocol::Https,
87            ..Default::default()
88        };
89        let client = Client::new(config);
90        Self { client, auth }
91    }
92
93    pub fn parse_reference(image: &str) -> Result<Reference> {
94        image
95            .parse()
96            .with_context(|| format!("Failed to parse image reference: {}", image))
97    }
98
99    /// Fetch the manifest for an image
100    pub async fn fetch_manifest(
101        &self,
102        reference: &Reference,
103        options: &PullOptions,
104    ) -> Result<ImageManifest> {
105        info!("Fetching manifest for {}", reference);
106
107        let auth: OciRegistryAuth = self.auth.clone().into();
108
109        let (manifest, _digest) = self
110            .client
111            .pull_manifest(reference, &auth)
112            .await
113            .with_context(|| format!("Failed to pull manifest for {}", reference))?;
114
115        let oci_manifest = match manifest {
116            oci_distribution::manifest::OciManifest::Image(m) => m,
117            oci_distribution::manifest::OciManifest::ImageIndex(index) => {
118                // Multi-arch image, find the right platform
119                let platform_manifest = index
120                    .manifests
121                    .iter()
122                    .find(|m| {
123                        if let Some(p) = &m.platform {
124                            p.os == options.platform_os && p.architecture == options.platform_arch
125                        } else {
126                            false
127                        }
128                    })
129                    .with_context(|| {
130                        format!(
131                            "Platform {}/{} not found in image index",
132                            options.platform_os, options.platform_arch
133                        )
134                    })?;
135
136                debug!("Found platform manifest: {:?}", platform_manifest.digest);
137
138                // Create a reference with the specific digest
139                let platform_ref = Reference::with_digest(
140                    reference.registry().to_string(),
141                    reference.repository().to_string(),
142                    platform_manifest.digest.clone(),
143                );
144
145                let (platform_manifest, _) = self
146                    .client
147                    .pull_manifest(&platform_ref, &auth)
148                    .await
149                    .with_context(|| "Failed to pull platform-specific manifest")?;
150
151                match platform_manifest {
152                    oci_distribution::manifest::OciManifest::Image(m) => m,
153                    _ => anyhow::bail!("Expected image manifest, got index"),
154                }
155            }
156        };
157
158        let layers: Vec<LayerDescriptor> = oci_manifest
159            .layers
160            .iter()
161            .map(|l| LayerDescriptor {
162                digest: l.digest.clone(),
163                size: l.size as u64,
164                media_type: l.media_type.clone(),
165            })
166            .collect();
167
168        let total_size = layers.iter().map(|l| l.size).sum();
169
170        Ok(ImageManifest {
171            config_digest: oci_manifest.config.digest.clone(),
172            layers,
173            total_size,
174        })
175    }
176
177    /// Pull a specific layer and return its content as bytes
178    pub async fn pull_layer(
179        &self,
180        reference: &Reference,
181        layer: &LayerDescriptor,
182    ) -> Result<Vec<u8>> {
183        debug!("Pulling layer {} ({} bytes)", layer.digest, layer.size);
184
185        let _auth: OciRegistryAuth = self.auth.clone().into();
186        let descriptor = layer.to_oci_descriptor();
187
188        // Create a buffer to receive the blob data
189        let mut data = Vec::with_capacity(layer.size as usize);
190
191        self.client
192            .pull_blob(reference, &descriptor, &mut data)
193            .await
194            .with_context(|| format!("Failed to pull layer {}", layer.digest))?;
195
196        Ok(data)
197    }
198
199    /// Pull all layers and return them in order
200    pub async fn pull_all_layers(
201        &self,
202        reference: &Reference,
203        manifest: &ImageManifest,
204        progress_callback: Option<Arc<dyn Fn(usize, usize) + Send + Sync>>,
205    ) -> Result<Vec<Vec<u8>>> {
206        let mut layers_data = Vec::with_capacity(manifest.layers.len());
207        let total = manifest.layers.len();
208
209        for (idx, layer) in manifest.layers.iter().enumerate() {
210            if let Some(ref cb) = progress_callback {
211                cb(idx + 1, total);
212            }
213            let data = self.pull_layer(reference, layer).await?;
214            layers_data.push(data);
215        }
216
217        Ok(layers_data)
218    }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// A bare `name:tag` reference is expanded into the Docker Hub `library/` namespace.
    #[test]
    fn test_parse_reference_simple() {
        let parsed = RegistryClient::parse_reference("alpine:latest")
            .expect("short Docker Hub reference should parse");
        assert_eq!(parsed.repository(), "library/alpine");
        assert_eq!(parsed.tag(), Some("latest"));
    }

    /// An explicit registry host is kept separate from the repository path.
    #[test]
    fn test_parse_reference_with_registry() {
        let parsed = RegistryClient::parse_reference("ghcr.io/user/repo:v1")
            .expect("fully-qualified reference should parse");
        assert_eq!(parsed.registry(), "ghcr.io");
        assert_eq!(parsed.repository(), "user/repo");
        assert_eq!(parsed.tag(), Some("v1"));
    }
}