initramfs_builder/registry/
client.rs1use anyhow::{Context, Result};
2use oci_distribution::{
3 client::{Client, ClientConfig, ClientProtocol},
4 manifest::OciDescriptor,
5 secrets::RegistryAuth as OciRegistryAuth,
6 Reference,
7};
8use serde::{Deserialize, Serialize};
9use std::sync::Arc;
10use tracing::{debug, info};
11
/// Authentication mode used when talking to a registry.
///
/// The default value is [`RegistryAuth::Anonymous`].
#[derive(Clone, Debug, Default)]
pub enum RegistryAuth {
    /// No credentials are supplied.
    #[default]
    Anonymous,
    /// A username/password pair.
    Basic { username: String, password: String },
}
22
23impl From<RegistryAuth> for OciRegistryAuth {
24 fn from(auth: RegistryAuth) -> Self {
25 match auth {
26 RegistryAuth::Anonymous => OciRegistryAuth::Anonymous,
27 RegistryAuth::Basic { username, password } => {
28 OciRegistryAuth::Basic(username, password)
29 }
30 }
31 }
32}
33
/// Target platform used to select a manifest out of a multi-platform image index.
#[derive(Clone, Debug)]
pub struct PullOptions {
    /// Operating system to match (defaults to `"linux"`).
    pub platform_os: String,
    /// CPU architecture to match (defaults to `"amd64"`).
    pub platform_arch: String,
}

impl Default for PullOptions {
    /// Returns options targeting `linux/amd64`.
    fn default() -> Self {
        let platform_os = String::from("linux");
        let platform_arch = String::from("amd64");
        Self {
            platform_os,
            platform_arch,
        }
    }
}
49
50#[derive(Debug, Clone, Serialize, Deserialize)]
52pub struct LayerDescriptor {
53 pub digest: String,
54 pub size: u64,
55 pub media_type: String,
56}
57
58impl LayerDescriptor {
59 fn to_oci_descriptor(&self) -> OciDescriptor {
60 OciDescriptor {
61 digest: self.digest.clone(),
62 size: self.size as i64,
63 media_type: self.media_type.clone(),
64 ..Default::default()
65 }
66 }
67}
68
69#[derive(Debug, Clone)]
71pub struct ImageManifest {
72 pub config_digest: String,
73 pub layers: Vec<LayerDescriptor>,
74 pub total_size: u64,
75}
76
77pub struct RegistryClient {
79 client: Client,
80 auth: RegistryAuth,
81}
82
83impl RegistryClient {
84 pub fn new(auth: RegistryAuth) -> Self {
85 let config = ClientConfig {
86 protocol: ClientProtocol::Https,
87 ..Default::default()
88 };
89 let client = Client::new(config);
90 Self { client, auth }
91 }
92
93 pub fn parse_reference(image: &str) -> Result<Reference> {
94 image
95 .parse()
96 .with_context(|| format!("Failed to parse image reference: {}", image))
97 }
98
99 pub async fn fetch_manifest(
101 &self,
102 reference: &Reference,
103 options: &PullOptions,
104 ) -> Result<ImageManifest> {
105 info!("Fetching manifest for {}", reference);
106
107 let auth: OciRegistryAuth = self.auth.clone().into();
108
109 let (manifest, _digest) = self
110 .client
111 .pull_manifest(reference, &auth)
112 .await
113 .with_context(|| format!("Failed to pull manifest for {}", reference))?;
114
115 let oci_manifest = match manifest {
116 oci_distribution::manifest::OciManifest::Image(m) => m,
117 oci_distribution::manifest::OciManifest::ImageIndex(index) => {
118 let platform_manifest = index
120 .manifests
121 .iter()
122 .find(|m| {
123 if let Some(p) = &m.platform {
124 p.os == options.platform_os && p.architecture == options.platform_arch
125 } else {
126 false
127 }
128 })
129 .with_context(|| {
130 format!(
131 "Platform {}/{} not found in image index",
132 options.platform_os, options.platform_arch
133 )
134 })?;
135
136 debug!("Found platform manifest: {:?}", platform_manifest.digest);
137
138 let platform_ref = Reference::with_digest(
140 reference.registry().to_string(),
141 reference.repository().to_string(),
142 platform_manifest.digest.clone(),
143 );
144
145 let (platform_manifest, _) = self
146 .client
147 .pull_manifest(&platform_ref, &auth)
148 .await
149 .with_context(|| "Failed to pull platform-specific manifest")?;
150
151 match platform_manifest {
152 oci_distribution::manifest::OciManifest::Image(m) => m,
153 _ => anyhow::bail!("Expected image manifest, got index"),
154 }
155 }
156 };
157
158 let layers: Vec<LayerDescriptor> = oci_manifest
159 .layers
160 .iter()
161 .map(|l| LayerDescriptor {
162 digest: l.digest.clone(),
163 size: l.size as u64,
164 media_type: l.media_type.clone(),
165 })
166 .collect();
167
168 let total_size = layers.iter().map(|l| l.size).sum();
169
170 Ok(ImageManifest {
171 config_digest: oci_manifest.config.digest.clone(),
172 layers,
173 total_size,
174 })
175 }
176
177 pub async fn pull_layer(
179 &self,
180 reference: &Reference,
181 layer: &LayerDescriptor,
182 ) -> Result<Vec<u8>> {
183 debug!("Pulling layer {} ({} bytes)", layer.digest, layer.size);
184
185 let _auth: OciRegistryAuth = self.auth.clone().into();
186 let descriptor = layer.to_oci_descriptor();
187
188 let mut data = Vec::with_capacity(layer.size as usize);
190
191 self.client
192 .pull_blob(reference, &descriptor, &mut data)
193 .await
194 .with_context(|| format!("Failed to pull layer {}", layer.digest))?;
195
196 Ok(data)
197 }
198
199 pub async fn pull_all_layers(
201 &self,
202 reference: &Reference,
203 manifest: &ImageManifest,
204 progress_callback: Option<Arc<dyn Fn(usize, usize) + Send + Sync>>,
205 ) -> Result<Vec<Vec<u8>>> {
206 let mut layers_data = Vec::with_capacity(manifest.layers.len());
207 let total = manifest.layers.len();
208
209 for (idx, layer) in manifest.layers.iter().enumerate() {
210 if let Some(ref cb) = progress_callback {
211 cb(idx + 1, total);
212 }
213 let data = self.pull_layer(reference, layer).await?;
214 layers_data.push(data);
215 }
216
217 Ok(layers_data)
218 }
219}
220
#[cfg(test)]
mod tests {
    use super::*;

    /// A bare `name:tag` reference resolves to the `library/` repository namespace.
    #[test]
    fn test_parse_reference_simple() {
        let input = "alpine:latest";
        let parsed = RegistryClient::parse_reference(input).unwrap();

        assert_eq!(parsed.repository(), "library/alpine");
        assert_eq!(parsed.tag(), Some("latest"));
    }

    /// An explicit registry host is kept separate from the repository path.
    #[test]
    fn test_parse_reference_with_registry() {
        let input = "ghcr.io/user/repo:v1";
        let parsed = RegistryClient::parse_reference(input).unwrap();

        assert_eq!(parsed.registry(), "ghcr.io");
        assert_eq!(parsed.repository(), "user/repo");
        assert_eq!(parsed.tag(), Some("v1"));
    }
}