1use std::{
2 collections::{hash_map::Entry, HashMap},
3 fmt::Display,
4 path::{Path, PathBuf},
5 sync::Arc,
6};
7
8use anyhow::{anyhow, Result};
9use oci_spec::image::Descriptor;
10use tokio::{
11 sync::{watch, Mutex},
12 task::JoinHandle,
13};
14
15use crate::{
16 assemble::OciImageAssembler,
17 fetch::{OciImageFetcher, OciResolvedImage},
18 name::ImageName,
19 progress::{OciBoundProgress, OciProgress, OciProgressContext},
20 registry::OciPlatform,
21};
22
23use log::{error, info, warn};
24
25use super::{cache::OciPackerCache, OciPackedFormat, OciPackedImage};
26
/// Bookkeeping for one in-flight packer task, stored in `OciPackerService::tasks`
/// under its `OciPackerTaskKey`.
pub struct OciPackerTask {
    /// Progress handle of the running task. When another caller requests the same
    /// key, `also_update` is invoked on it with that caller's progress context
    /// (presumably to forward updates to it — see `OciBoundProgress`).
    progress: OciBoundProgress,
    /// Result channel: starts as `None` and is set to `Some(result)` exactly once,
    /// when the task finishes or is aborted. Waiters get receivers via `subscribe`.
    watch: watch::Sender<Option<Result<OciPackedImage>>>,
    /// Handle of the spawned task; aborted once no receiver of `watch` remains.
    task: JoinHandle<()>,
}
32
/// Packs OCI images into the requested formats, caching packed results and
/// sharing a single packer task between concurrent requests for the same
/// resolved digest and format.
///
/// Clones share the same table of in-flight tasks (`tasks` is behind an `Arc`).
#[derive(Clone)]
pub struct OciPackerService {
    /// Passed to every `OciImageFetcher` this service creates.
    seed: Option<PathBuf>,
    /// Platform passed to every `OciImageFetcher` this service creates.
    platform: OciPlatform,
    /// Cache of packed images, created from the `cache_dir` given to `new`.
    cache: OciPackerCache,
    /// In-flight packer tasks, keyed by resolved digest and packed format.
    tasks: Arc<Mutex<HashMap<OciPackerTaskKey, OciPackerTask>>>,
}
40
41impl OciPackerService {
42 pub async fn new(
43 seed: Option<PathBuf>,
44 cache_dir: &Path,
45 platform: OciPlatform,
46 ) -> Result<OciPackerService> {
47 Ok(OciPackerService {
48 seed,
49 cache: OciPackerCache::new(cache_dir).await?,
50 platform,
51 tasks: Arc::new(Mutex::new(HashMap::new())),
52 })
53 }
54
55 pub async fn list(&self) -> Result<Vec<Descriptor>> {
56 self.cache.list().await
57 }
58
59 pub async fn recall(
60 &self,
61 digest: &str,
62 format: OciPackedFormat,
63 ) -> Result<Option<OciPackedImage>> {
64 if digest.contains('/') || digest.contains('\\') || digest.contains("..") {
65 return Ok(None);
66 }
67
68 self.cache
69 .recall(ImageName::parse("cached:latest")?, digest, format)
70 .await
71 }
72
73 pub async fn request(
74 &self,
75 name: ImageName,
76 format: OciPackedFormat,
77 overwrite: bool,
78 pull: bool,
79 progress_context: OciProgressContext,
80 ) -> Result<OciPackedImage> {
81 let progress = OciProgress::new();
82 let progress = OciBoundProgress::new(progress_context.clone(), progress);
83 let mut resolved = None;
84 if !pull && !overwrite {
85 resolved = self.cache.resolve(name.clone(), format).await?;
86 }
87 let fetcher =
88 OciImageFetcher::new(self.seed.clone(), self.platform.clone(), progress.clone());
89 let resolved = if let Some(resolved) = resolved {
90 resolved
91 } else {
92 fetcher.resolve(name.clone()).await?
93 };
94
95 let key = OciPackerTaskKey {
96 digest: resolved.digest.clone(),
97 format,
98 };
99 let (progress_copy_task, mut receiver) = match self.tasks.lock().await.entry(key.clone()) {
100 Entry::Occupied(entry) => {
101 let entry = entry.get();
102 (
103 Some(entry.progress.also_update(progress_context).await),
104 entry.watch.subscribe(),
105 )
106 }
107
108 Entry::Vacant(entry) => {
109 let task = self
110 .clone()
111 .launch(
112 name,
113 key.clone(),
114 format,
115 overwrite,
116 resolved,
117 fetcher,
118 progress.clone(),
119 )
120 .await;
121 let (watch, receiver) = watch::channel(None);
122
123 let task = OciPackerTask {
124 progress: progress.clone(),
125 task,
126 watch,
127 };
128 entry.insert(task);
129 (None, receiver)
130 }
131 };
132
133 let _progress_task_guard = scopeguard::guard(progress_copy_task, |task| {
134 if let Some(task) = task {
135 task.abort();
136 }
137 });
138
139 let _task_cancel_guard = scopeguard::guard(self.clone(), |service| {
140 service.maybe_cancel_task(key);
141 });
142
143 loop {
144 receiver.changed().await?;
145 let current = receiver.borrow_and_update();
146 if current.is_some() {
147 return current
148 .as_ref()
149 .map(|x| x.as_ref().map_err(|err| anyhow!("{}", err)).cloned())
150 .unwrap();
151 }
152 }
153 }
154
155 #[allow(clippy::too_many_arguments)]
156 async fn launch(
157 self,
158 name: ImageName,
159 key: OciPackerTaskKey,
160 format: OciPackedFormat,
161 overwrite: bool,
162 resolved: OciResolvedImage,
163 fetcher: OciImageFetcher,
164 progress: OciBoundProgress,
165 ) -> JoinHandle<()> {
166 info!("started packer task {}", key);
167 tokio::task::spawn(async move {
168 let _task_drop_guard =
169 scopeguard::guard((key.clone(), self.clone()), |(key, service)| {
170 service.ensure_task_gone(key);
171 });
172 if let Err(error) = self
173 .task(
174 name,
175 key.clone(),
176 format,
177 overwrite,
178 resolved,
179 fetcher,
180 progress,
181 )
182 .await
183 {
184 self.finish(&key, Err(error)).await;
185 }
186 })
187 }
188
189 #[allow(clippy::too_many_arguments)]
190 async fn task(
191 &self,
192 name: ImageName,
193 key: OciPackerTaskKey,
194 format: OciPackedFormat,
195 overwrite: bool,
196 resolved: OciResolvedImage,
197 fetcher: OciImageFetcher,
198 progress: OciBoundProgress,
199 ) -> Result<()> {
200 if !overwrite {
201 if let Some(cached) = self
202 .cache
203 .recall(name.clone(), &resolved.digest, format)
204 .await?
205 {
206 self.finish(&key, Ok(cached)).await;
207 return Ok(());
208 }
209 }
210 let assembler =
211 OciImageAssembler::new(fetcher, resolved, progress.clone(), None, None).await?;
212 let assembled = assembler.assemble().await?;
213 let mut file = assembled
214 .tmp_dir
215 .clone()
216 .ok_or(anyhow!("tmp_dir was missing when packing image"))?;
217 file.push("image.pack");
218 let target = file.clone();
219 let packer = format.backend().create();
220 packer
221 .pack(progress, assembled.vfs.clone(), &target)
222 .await?;
223 let packed = OciPackedImage::new(
224 name,
225 assembled.digest.clone(),
226 file,
227 format,
228 assembled.descriptor.clone(),
229 assembled.config.clone(),
230 assembled.manifest.clone(),
231 );
232 let packed = self.cache.store(packed).await?;
233 self.finish(&key, Ok(packed)).await;
234 Ok(())
235 }
236
237 async fn finish(&self, key: &OciPackerTaskKey, result: Result<OciPackedImage>) {
238 let Some(task) = self.tasks.lock().await.remove(key) else {
239 error!("packer task {} was not found when task completed", key);
240 return;
241 };
242
243 match result.as_ref() {
244 Ok(_) => {
245 info!("completed packer task {}", key);
246 }
247
248 Err(err) => {
249 warn!("packer task {} failed: {}", key, err);
250 }
251 }
252
253 task.watch.send_replace(Some(result));
254 }
255
256 fn maybe_cancel_task(self, key: OciPackerTaskKey) {
257 tokio::task::spawn(async move {
258 let tasks = self.tasks.lock().await;
259 if let Some(task) = tasks.get(&key) {
260 if task.watch.is_closed() {
261 task.task.abort();
262 }
263 }
264 });
265 }
266
267 fn ensure_task_gone(self, key: OciPackerTaskKey) {
268 tokio::task::spawn(async move {
269 let mut tasks = self.tasks.lock().await;
270 if let Some(task) = tasks.remove(&key) {
271 warn!("aborted packer task {}", key);
272 task.watch.send_replace(Some(Err(anyhow!("task aborted"))));
273 }
274 });
275 }
276}
277
/// Identifies a packer task: one resolved image digest packed into one format.
/// `request` shares a single running task between callers whose keys are equal.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
struct OciPackerTaskKey {
    /// Digest of the resolved image (taken from `OciResolvedImage::digest`).
    digest: String,
    /// Format the image is packed into.
    format: OciPackedFormat,
}
283
284impl Display for OciPackerTaskKey {
285 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
286 f.write_fmt(format_args!("{}:{}", self.digest, self.format.extension()))
287 }
288}