Skip to main content

warpgate/
loader.rs

1use crate::clients::*;
2use crate::helpers::{
3    create_cache_key, determine_cache_extension, download_from_url_to_file, move_or_unpack_download,
4};
5use crate::loader_error::WarpgateLoaderError;
6use crate::protocols::{FileLoader, GitHubLoader, HttpLoader, LoadFrom, LoaderProtocol, OciLoader};
7use crate::registry::RegistryConfig;
8use once_cell::sync::OnceCell;
9use starbase_styles::color;
10use starbase_utils::{fs, path};
11use std::fmt::Debug;
12use std::path::{Path, PathBuf};
13use std::sync::Arc;
14use std::time::Duration;
15use tracing::{instrument, trace, warn};
16use warpgate_api::{Id, PluginLocator};
17
/// Shared handle to a plain function that reports whether the host is offline.
/// The function returns `true` when there is no internet connection.
pub type OfflineChecker = Arc<fn() -> bool>;
19
/// A system for loading plugins from a locator strategy,
/// and caching the plugin file (`.wasm`) to the host's file system.
///
/// Clients and protocol loaders are held in `OnceCell`s and are only
/// created the first time their accessor is called.
#[derive(Clone)]
pub struct PluginLoader {
    /// How long a cached plugin file is treated as fresh before it is
    /// re-acquired. A zero duration disables caching.
    cache_duration: Duration,

    /// Loader for referencing local plugins using file paths.
    file_loader: OnceCell<FileLoader>,

    /// Loader for downloading plugins from GitHub releases.
    github_loader: OnceCell<GitHubLoader>,

    /// Instance of our HTTP client, shared with the loaders that need it.
    http_client: OnceCell<Arc<HttpClient>>,

    /// Loader for acquiring plugins from URLs.
    http_loader: OnceCell<HttpLoader>,

    /// Options to pass to the HTTP client when it is created.
    http_options: HttpOptions,

    /// Checks whether there's an internet connection or not.
    /// When unset, the loader behaves as if it is online.
    offline_checker: Option<OfflineChecker>,

    /// Location where acquired plugins are stored.
    plugins_dir: PathBuf,

    /// Location where temporary files (like archives) are stored.
    temp_dir: PathBuf,

    /// Plugin registry locations, handed to the OCI loader when
    /// resolving registry locators.
    registries: Vec<RegistryConfig>,

    /// A unique seed for generating hashes (used for URL cache keys).
    seed: Option<String>,

    /// OCI client instance.
    oci_client: OnceCell<Arc<OciClient>>,

    /// Loader for downloading plugins from OCI registries.
    oci_loader: OnceCell<OciLoader>,
}
63
64impl PluginLoader {
65    /// Create a new loader that stores plugins and downloads in the provided directories.
66    pub fn new<P: AsRef<Path>, T: AsRef<Path>>(plugins_dir: P, temp_dir: T) -> Self {
67        let plugins_dir = plugins_dir.as_ref();
68
69        trace!(cache_dir = ?plugins_dir, "Creating plugin loader");
70
71        Self {
72            cache_duration: Duration::from_secs(86400 * 30), // 30 days
73            file_loader: OnceCell::new(),
74            github_loader: OnceCell::new(),
75            http_client: OnceCell::new(),
76            http_loader: OnceCell::new(),
77            http_options: HttpOptions::default(),
78            oci_client: OnceCell::new(),
79            oci_loader: OnceCell::new(),
80            offline_checker: None,
81            plugins_dir: plugins_dir.to_owned(),
82            registries: vec![],
83            seed: None,
84            temp_dir: temp_dir.as_ref().to_owned(),
85        }
86    }
87
    /// Add an OCI registry as a backend.
    ///
    /// The registry is appended to the end of the loader's registry list,
    /// which is passed to the OCI loader when a registry locator is loaded.
    pub fn add_registry(&mut self, registry: RegistryConfig) {
        self.registries.push(registry);
    }
92
93    /// Add multiple OCI registries as a backend.
94    pub fn add_registries(&mut self, registries: Vec<RegistryConfig>) {
95        for registry in registries {
96            self.add_registry(registry);
97        }
98    }
99
100    /// Return a file loader for use with [`FileLocator`]s.
101    pub fn get_file_loader(&self) -> Result<&FileLoader, WarpgateLoaderError> {
102        self.file_loader.get_or_try_init(|| Ok(FileLoader {}))
103    }
104
105    /// Return a GitHub loader for use with [`GitHubLocator`]s.
106    pub fn get_github_loader(&self) -> Result<&GitHubLoader, WarpgateLoaderError> {
107        self.github_loader.get_or_try_init(|| {
108            Ok(GitHubLoader {
109                client: Arc::clone(self.get_http_client()?),
110            })
111        })
112    }
113
114    /// Return an HTTP loader for use with [`UrlLocator`]s.
115    pub fn get_http_loader(&self) -> Result<&HttpLoader, WarpgateLoaderError> {
116        self.http_loader.get_or_try_init(|| Ok(HttpLoader {}))
117    }
118
119    /// Return an OCI loader for use with [`RegistryLocator`]s.
120    pub fn get_oci_loader(&self) -> Result<&OciLoader, WarpgateLoaderError> {
121        self.oci_loader.get_or_try_init(|| {
122            Ok(OciLoader {
123                client: Arc::clone(self.get_oci_client()?),
124            })
125        })
126    }
127
128    /// Return the HTTP client, or create it if it does not exist.
129    pub fn get_http_client(&self) -> Result<&Arc<HttpClient>, WarpgateHttpClientError> {
130        self.http_client
131            .get_or_try_init(|| create_http_client_with_options(&self.http_options).map(Arc::new))
132    }
133
134    /// Return an OCI client, or create it if it does not exist.
135    pub fn get_oci_client(&self) -> Result<&Arc<OciClient>, WarpgateHttpClientError> {
136        self.oci_client
137            .get_or_try_init(|| Ok(Arc::new(OciClient::default())))
138    }
139
140    /// Load a plugin using the provided locator. File system plugins are loaded directly,
141    /// while remote/URL plugins are downloaded and cached.
142    #[instrument(skip(self))]
143    pub async fn load_plugin<I: AsRef<Id> + Debug, L: AsRef<PluginLocator> + Debug>(
144        &self,
145        id: I,
146        locator: L,
147    ) -> Result<PathBuf, WarpgateLoaderError> {
148        let id = id.as_ref();
149        let locator = locator.as_ref();
150
151        trace!(
152            id = id.as_str(),
153            locator = locator.to_string(),
154            "Loading plugin {}",
155            color::id(id.as_str())
156        );
157
158        // Determine the source location
159        let (source, is_latest) = match locator {
160            PluginLocator::File(file) => {
161                let loader = self.get_file_loader()?;
162
163                (loader.load(id, file, &()).await?, loader.is_latest(file))
164            }
165            PluginLocator::GitHub(github) => {
166                let loader = self.get_github_loader()?;
167
168                (
169                    loader.load(id, github, &()).await?,
170                    loader.is_latest(github),
171                )
172            }
173            PluginLocator::Url(url) => {
174                let loader = self.get_http_loader()?;
175
176                (loader.load(id, url, &()).await?, loader.is_latest(url))
177            }
178            PluginLocator::Registry(registry) => {
179                let loader = self.get_oci_loader()?;
180
181                (
182                    loader.load(id, registry, &self.registries).await?,
183                    loader.is_latest(registry),
184                )
185            }
186        };
187
188        // Check if destination already exists
189        let cache_path = match &source {
190            LoadFrom::Blob { ext, hash, .. } => self.create_cache_path(id, hash, ext, is_latest),
191            LoadFrom::File(path) => {
192                // Files ignore caching rules
193                return Ok(path.to_path_buf());
194            }
195            LoadFrom::Url(url) => self.create_cache_path(
196                id,
197                create_cache_key(url, self.seed.as_deref()).as_str(),
198                determine_cache_extension(url).unwrap_or(".wasm"),
199                is_latest,
200            ),
201        };
202
203        if self.is_cached(id, &cache_path)? {
204            return Ok(cache_path);
205        }
206
207        // Acquire the source and write to the destination
208        match source {
209            LoadFrom::Blob { data, .. } => {
210                fs::write_file(&cache_path, data)?;
211            }
212            LoadFrom::Url(url) => {
213                self.download_plugin(id, &url, &cache_path).await?;
214            }
215            _ => {}
216        };
217
218        Ok(cache_path)
219    }
220
221    /// Create an absolute path to the plugin's destination file,
222    /// located in the cached plugins directory.
223    pub fn create_cache_path(&self, id: &Id, hash: &str, ext: &str, is_latest: bool) -> PathBuf {
224        self.plugins_dir.join(format!(
225            "{}{}{hash}{ext}",
226            path::encode_component(id.as_str()),
227            if is_latest { "-latest-" } else { "-" },
228        ))
229    }
230
231    /// Check if the plugin has been acquired and is cached.
232    /// If using a latest strategy (no explicit version or tag), the cache
233    /// is only valid for a duration (to ensure not stale), otherwise forever.
234    #[instrument(name = "is_plugin_cached", skip(self))]
235    pub fn is_cached(&self, id: &Id, path: &Path) -> Result<bool, WarpgateLoaderError> {
236        if !path.exists() {
237            trace!(id = id.as_str(), "Plugin not cached, acquiring");
238
239            return Ok(false);
240        }
241
242        if self.cache_duration.is_zero() {
243            trace!(
244                id = id.as_str(),
245                "Plugin caching has been disabled, acquiring"
246            );
247
248            return Ok(false);
249        }
250
251        let mut cached = !fs::is_stale(path, false, self.cache_duration)?;
252
253        if !cached && self.is_offline() {
254            cached = true;
255        }
256
257        if !cached && path.exists() {
258            fs::remove_file(path)?;
259        }
260
261        if cached {
262            trace!(id = id.as_str(), path = ?path, "Plugin already acquired and cached");
263        } else {
264            trace!(id = id.as_str(), path = ?path, "Plugin cached but stale, re-acquiring");
265        }
266
267        Ok(cached)
268    }
269
270    /// Check for an internet connection.
271    pub fn is_offline(&self) -> bool {
272        self.offline_checker
273            .as_ref()
274            .map(|op| op())
275            .unwrap_or_default()
276    }
277
    /// Set the cache duration, i.e. how long a cached plugin file is
    /// considered fresh. A zero duration disables caching.
    pub fn set_cache_duration(&mut self, duration: Duration) {
        self.cache_duration = duration;
    }
282
283    /// Set the options to pass to the HTTP client.
284    pub fn set_http_client_options(&mut self, options: &HttpOptions) {
285        options.clone_into(&mut self.http_options);
286    }
287
288    /// Set the function that checks for offline state.
289    pub fn set_offline_checker(&mut self, op: fn() -> bool) {
290        self.offline_checker = Some(Arc::new(op));
291    }
292
293    /// Set the provided value as a seed for generating hashes.
294    pub fn set_seed(&mut self, value: &str) {
295        self.seed = Some(value.to_owned());
296    }
297
298    #[instrument(skip(self))]
299    async fn download_plugin(
300        &self,
301        id: &Id,
302        source_url: &str,
303        dest_file: &Path,
304    ) -> Result<(), WarpgateLoaderError> {
305        if self.is_offline() {
306            return Err(WarpgateLoaderError::RequiredInternetConnection {
307                message: "Unable to download plugin.".into(),
308                url: source_url.to_owned(),
309            });
310        }
311
312        trace!(
313            id = id.as_str(),
314            from = source_url,
315            to = ?dest_file,
316            "Downloading plugin from URL"
317        );
318
319        let temp_file = self.temp_dir.join(fs::file_name(dest_file));
320
321        download_from_url_to_file(source_url, &temp_file, self.get_http_client()?).await?;
322        move_or_unpack_download(&temp_file, dest_file)?;
323
324        Ok(())
325    }
326}