krataoci/packer/service.rs

1use std::{
2    collections::{hash_map::Entry, HashMap},
3    fmt::Display,
4    path::{Path, PathBuf},
5    sync::Arc,
6};
7
8use anyhow::{anyhow, Result};
9use oci_spec::image::Descriptor;
10use tokio::{
11    sync::{watch, Mutex},
12    task::JoinHandle,
13};
14
15use crate::{
16    assemble::OciImageAssembler,
17    fetch::{OciImageFetcher, OciResolvedImage},
18    name::ImageName,
19    progress::{OciBoundProgress, OciProgress, OciProgressContext},
20    registry::OciPlatform,
21};
22
23use log::{error, info, warn};
24
25use super::{cache::OciPackerCache, OciPackedFormat, OciPackedImage};
26
/// Bookkeeping for one in-flight packer task, stored in the service's task map.
pub struct OciPackerTask {
    /// Progress tracker bound to the task. Later requesters for the same key attach
    /// their own progress context to it via `also_update`.
    progress: OciBoundProgress,
    /// Broadcasts the task's outcome. It holds `None` until the task finishes; every
    /// waiting requester holds a receiver subscribed to this sender.
    watch: watch::Sender<Option<Result<OciPackedImage>>>,
    /// Handle of the spawned task, kept so the task can be aborted when nobody is
    /// waiting for its result any more.
    task: JoinHandle<()>,
}
32
/// Packs OCI images into packed formats, caching the results and sharing a single
/// packer task between concurrent requests for the same image digest and format.
#[derive(Clone)]
pub struct OciPackerService {
    /// Optional seed path handed to every `OciImageFetcher` the service creates.
    /// NOTE(review): presumably a local image source consulted before a registry — confirm.
    seed: Option<PathBuf>,
    /// Platform handed to every `OciImageFetcher` the service creates.
    platform: OciPlatform,
    /// Cache of packed images, opened from the `cache_dir` given to `new`.
    cache: OciPackerCache,
    /// In-flight packer tasks keyed by digest and format. Behind an `Arc`, so every
    /// clone of the service shares the same task map.
    tasks: Arc<Mutex<HashMap<OciPackerTaskKey, OciPackerTask>>>,
}
40
41impl OciPackerService {
42    pub async fn new(
43        seed: Option<PathBuf>,
44        cache_dir: &Path,
45        platform: OciPlatform,
46    ) -> Result<OciPackerService> {
47        Ok(OciPackerService {
48            seed,
49            cache: OciPackerCache::new(cache_dir).await?,
50            platform,
51            tasks: Arc::new(Mutex::new(HashMap::new())),
52        })
53    }
54
55    pub async fn list(&self) -> Result<Vec<Descriptor>> {
56        self.cache.list().await
57    }
58
59    pub async fn recall(
60        &self,
61        digest: &str,
62        format: OciPackedFormat,
63    ) -> Result<Option<OciPackedImage>> {
64        if digest.contains('/') || digest.contains('\\') || digest.contains("..") {
65            return Ok(None);
66        }
67
68        self.cache
69            .recall(ImageName::parse("cached:latest")?, digest, format)
70            .await
71    }
72
73    pub async fn request(
74        &self,
75        name: ImageName,
76        format: OciPackedFormat,
77        overwrite: bool,
78        pull: bool,
79        progress_context: OciProgressContext,
80    ) -> Result<OciPackedImage> {
81        let progress = OciProgress::new();
82        let progress = OciBoundProgress::new(progress_context.clone(), progress);
83        let mut resolved = None;
84        if !pull && !overwrite {
85            resolved = self.cache.resolve(name.clone(), format).await?;
86        }
87        let fetcher =
88            OciImageFetcher::new(self.seed.clone(), self.platform.clone(), progress.clone());
89        let resolved = if let Some(resolved) = resolved {
90            resolved
91        } else {
92            fetcher.resolve(name.clone()).await?
93        };
94
95        let key = OciPackerTaskKey {
96            digest: resolved.digest.clone(),
97            format,
98        };
99        let (progress_copy_task, mut receiver) = match self.tasks.lock().await.entry(key.clone()) {
100            Entry::Occupied(entry) => {
101                let entry = entry.get();
102                (
103                    Some(entry.progress.also_update(progress_context).await),
104                    entry.watch.subscribe(),
105                )
106            }
107
108            Entry::Vacant(entry) => {
109                let task = self
110                    .clone()
111                    .launch(
112                        name,
113                        key.clone(),
114                        format,
115                        overwrite,
116                        resolved,
117                        fetcher,
118                        progress.clone(),
119                    )
120                    .await;
121                let (watch, receiver) = watch::channel(None);
122
123                let task = OciPackerTask {
124                    progress: progress.clone(),
125                    task,
126                    watch,
127                };
128                entry.insert(task);
129                (None, receiver)
130            }
131        };
132
133        let _progress_task_guard = scopeguard::guard(progress_copy_task, |task| {
134            if let Some(task) = task {
135                task.abort();
136            }
137        });
138
139        let _task_cancel_guard = scopeguard::guard(self.clone(), |service| {
140            service.maybe_cancel_task(key);
141        });
142
143        loop {
144            receiver.changed().await?;
145            let current = receiver.borrow_and_update();
146            if current.is_some() {
147                return current
148                    .as_ref()
149                    .map(|x| x.as_ref().map_err(|err| anyhow!("{}", err)).cloned())
150                    .unwrap();
151            }
152        }
153    }
154
155    #[allow(clippy::too_many_arguments)]
156    async fn launch(
157        self,
158        name: ImageName,
159        key: OciPackerTaskKey,
160        format: OciPackedFormat,
161        overwrite: bool,
162        resolved: OciResolvedImage,
163        fetcher: OciImageFetcher,
164        progress: OciBoundProgress,
165    ) -> JoinHandle<()> {
166        info!("started packer task {}", key);
167        tokio::task::spawn(async move {
168            let _task_drop_guard =
169                scopeguard::guard((key.clone(), self.clone()), |(key, service)| {
170                    service.ensure_task_gone(key);
171                });
172            if let Err(error) = self
173                .task(
174                    name,
175                    key.clone(),
176                    format,
177                    overwrite,
178                    resolved,
179                    fetcher,
180                    progress,
181                )
182                .await
183            {
184                self.finish(&key, Err(error)).await;
185            }
186        })
187    }
188
189    #[allow(clippy::too_many_arguments)]
190    async fn task(
191        &self,
192        name: ImageName,
193        key: OciPackerTaskKey,
194        format: OciPackedFormat,
195        overwrite: bool,
196        resolved: OciResolvedImage,
197        fetcher: OciImageFetcher,
198        progress: OciBoundProgress,
199    ) -> Result<()> {
200        if !overwrite {
201            if let Some(cached) = self
202                .cache
203                .recall(name.clone(), &resolved.digest, format)
204                .await?
205            {
206                self.finish(&key, Ok(cached)).await;
207                return Ok(());
208            }
209        }
210        let assembler =
211            OciImageAssembler::new(fetcher, resolved, progress.clone(), None, None).await?;
212        let assembled = assembler.assemble().await?;
213        let mut file = assembled
214            .tmp_dir
215            .clone()
216            .ok_or(anyhow!("tmp_dir was missing when packing image"))?;
217        file.push("image.pack");
218        let target = file.clone();
219        let packer = format.backend().create();
220        packer
221            .pack(progress, assembled.vfs.clone(), &target)
222            .await?;
223        let packed = OciPackedImage::new(
224            name,
225            assembled.digest.clone(),
226            file,
227            format,
228            assembled.descriptor.clone(),
229            assembled.config.clone(),
230            assembled.manifest.clone(),
231        );
232        let packed = self.cache.store(packed).await?;
233        self.finish(&key, Ok(packed)).await;
234        Ok(())
235    }
236
237    async fn finish(&self, key: &OciPackerTaskKey, result: Result<OciPackedImage>) {
238        let Some(task) = self.tasks.lock().await.remove(key) else {
239            error!("packer task {} was not found when task completed", key);
240            return;
241        };
242
243        match result.as_ref() {
244            Ok(_) => {
245                info!("completed packer task {}", key);
246            }
247
248            Err(err) => {
249                warn!("packer task {} failed: {}", key, err);
250            }
251        }
252
253        task.watch.send_replace(Some(result));
254    }
255
256    fn maybe_cancel_task(self, key: OciPackerTaskKey) {
257        tokio::task::spawn(async move {
258            let tasks = self.tasks.lock().await;
259            if let Some(task) = tasks.get(&key) {
260                if task.watch.is_closed() {
261                    task.task.abort();
262                }
263            }
264        });
265    }
266
267    fn ensure_task_gone(self, key: OciPackerTaskKey) {
268        tokio::task::spawn(async move {
269            let mut tasks = self.tasks.lock().await;
270            if let Some(task) = tasks.remove(&key) {
271                warn!("aborted packer task {}", key);
272                task.watch.send_replace(Some(Err(anyhow!("task aborted"))));
273            }
274        });
275    }
276}
277
/// Identifies a packer task. At most one task runs per (image digest, packed format) pair.
#[derive(Debug, Clone, Eq, PartialEq, Hash)]
struct OciPackerTaskKey {
    /// Digest of the resolved image.
    digest: String,
    /// Format the image is packed into.
    format: OciPackedFormat,
}
283
284impl Display for OciPackerTaskKey {
285    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
286        f.write_fmt(format_args!("{}:{}", self.digest, self.format.extension()))
287    }
288}