wash_cli/
plugin.rs

1use std::path::PathBuf;
2
3use anyhow::Context;
4use clap::{Parser, Subcommand};
5use futures::TryStreamExt;
6use oci_client::Reference;
7use sha2::{Digest, Sha256};
8use tokio::io::AsyncWriteExt;
9use wash_lib::{
10    cli::{registry::AuthOpts, CommandOutput, OutputKind},
11    registry::{pull_oci_artifact, OciPullOptions},
12};
13
14use crate::{
15    appearance::spinner::Spinner,
16    ctl::plugins_table,
17    util::{ensure_plugin_dir, load_plugins},
18};
19
20#[derive(Debug, Clone, Subcommand)]
21pub enum PluginCommand {
22    /// Install a wash plugin
23    #[clap(name = "install")]
24    Install(PluginInstallCommand),
25    /// Uninstall a plugin
26    #[clap(name = "uninstall", alias = "delete", alias = "rm")]
27    Uninstall(PluginUninstallCommand),
28    /// List installed plugins
29    #[clap(name = "list", alias = "ls")]
30    List(PluginListCommand),
31}
32
33#[derive(Parser, Debug, Clone)]
34pub struct PluginCommonOpts {
35    /// Path to plugin directory. Defaults to $HOME/.wash/plugins.
36    #[clap(long = "plugin-dir", env = "WASH_PLUGIN_DIR")]
37    pub plugin_dir: Option<PathBuf>,
38}
39
40#[derive(Debug, Clone, Parser)]
41pub struct PluginInstallCommand {
42    #[clap(flatten)]
43    pub oci_auth: AuthOpts,
44
45    /// URL of the plugin to install. Can be a file://, http://, https://, or oci:// URL.
46    #[clap(name = "url")]
47    pub url: String,
48
49    /// Digest to verify plugin against. For OCI manifests, this is the digest format used in the
50    /// manifest. For other types of plugins, this is the SHA256 digest of the plugin binary.
51    #[clap(short = 'd', long = "digest")]
52    pub digest: Option<String>,
53
54    /// Allow latest artifact tags (if pulling from OCI registry)
55    #[clap(long = "allow-latest")]
56    pub allow_latest: bool,
57
58    /// Whether or not to update the plugin if it is already installed. Defaults to false
59    #[clap(long = "update")]
60    pub update: bool,
61
62    #[clap(flatten)]
63    pub opts: PluginCommonOpts,
64}
65
66#[derive(Debug, Clone, Parser)]
67pub struct PluginUninstallCommand {
68    /// ID of the plugin to uninstall
69    #[clap(name = "id")]
70    pub plugin: String,
71
72    #[clap(flatten)]
73    pub opts: PluginCommonOpts,
74}
75
76#[derive(Debug, Clone, Parser)]
77pub struct PluginListCommand {
78    #[clap(flatten)]
79    pub opts: PluginCommonOpts,
80}
81
82pub async fn handle_command(
83    cmd: PluginCommand,
84    output_kind: OutputKind,
85) -> anyhow::Result<CommandOutput> {
86    match cmd {
87        PluginCommand::Install(cmd) => handle_install(cmd, output_kind).await,
88        PluginCommand::Uninstall(cmd) => handle_uninstall(cmd, output_kind).await,
89        PluginCommand::List(cmd) => handle_list(cmd, output_kind).await,
90    }
91}
92
93pub async fn handle_install(
94    cmd: PluginInstallCommand,
95    output_kind: OutputKind,
96) -> anyhow::Result<CommandOutput> {
97    let plugin_dir = ensure_plugin_dir(cmd.opts.plugin_dir).await?;
98    let spinner = Spinner::new(&output_kind)?;
99    // Write the data to a temp file that will be cleaned and then we can move it to its real
100    // location if everything is successful.
101    let tempdir =
102        tempfile::tempdir().context("Unable to create temp directory for plugin download")?;
103    let temp_location = tempdir.path().join("temp_plugin.wasm");
104    let mut file = tokio::fs::OpenOptions::new()
105        .create(true)
106        .truncate(true)
107        .write(true)
108        .read(true)
109        .open(&temp_location)
110        .await
111        .context("Unable to create temp file for plugin download")?;
112
113    let (scheme, rest) = cmd
114        .url
115        .split_once("://")
116        .context("Invalid URL. It should contain a scheme (e.g. file://)")?;
117
118    // OCI checks the digest on pull, so we have to return whether to check the digest for the others
119    let compute_digest = match scheme {
120        "file" => {
121            let path = PathBuf::from(rest);
122            spinner.update_spinner_message(format!(" Opening plugin from {}", path.display()));
123            let mut existing_file = tokio::fs::File::open(&path)
124                .await
125                .context(format!("Unable to open plugin file at {}", path.display()))?;
126            // NOTE(thomastaylor312): This is less efficient than just opening the file as a plugin,
127            // but simplifies the logic so we can just move it at the end with all of the other
128            // checks we have to do. We could also just read bytes in and load those, but that also
129            // results in a whole bunch of extra code in the plugin runner code that we would only
130            // need for this specific subcommand. We can improve this later if we need to.
131            tokio::io::copy(&mut existing_file, &mut file)
132                .await
133                .context("Unable to copy plugin file")?;
134            cmd.digest
135        }
136        "http" | "https" => {
137            spinner.update_spinner_message(format!(" Downloading plugin from URL {}", cmd.url));
138            let resp = reqwest::get(&cmd.url)
139                .await
140                .context("Unable to perform http request")?;
141            if !resp.status().is_success() {
142                anyhow::bail!(
143                    "Unable to fetch plugin from {}. HTTP status code: {}",
144                    cmd.url,
145                    resp.status()
146                );
147            }
148            let mut stream_reader = tokio_util::io::StreamReader::new(
149                resp.bytes_stream()
150                    .map_err(|err| std::io::Error::new(std::io::ErrorKind::Other, err)),
151            );
152            tokio::io::copy(&mut stream_reader, &mut file)
153                .await
154                .context("Unable to save plugin file to disk")?;
155
156            cmd.digest
157        }
158        "oci" => {
159            spinner.update_spinner_message(format!(" Downloading plugin from registry {}", rest));
160            let image: Reference = rest
161                .trim()
162                .to_ascii_lowercase()
163                .parse()
164                .context("Invalid image reference")?;
165
166            // TODO: Add support for pulling via stream to wash_lib
167            let image_data = pull_oci_artifact(
168                &image,
169                OciPullOptions {
170                    digest: cmd.digest.clone(),
171                    allow_latest: cmd.allow_latest,
172                    user: cmd.oci_auth.user,
173                    password: cmd.oci_auth.password,
174                    insecure: cmd.oci_auth.insecure,
175                    insecure_skip_tls_verify: cmd.oci_auth.insecure_skip_tls_verify,
176                },
177            )
178            .await
179            .context("Unable to pull plugin from registry")?;
180            file.write_all(&image_data)
181                .await
182                .context("Unable to write plugin to file")?;
183
184            None
185        }
186        _ => {
187            anyhow::bail!("Invalid URL scheme: {}", scheme);
188        }
189    };
190
191    // Flush the file to make sure we're done writing to it.
192    file.flush()
193        .await
194        .context("Unable to flush plugin file to disk")?;
195    file.shutdown()
196        .await
197        .context("Unable to shutdown plugin file")?;
198
199    // Check the digest if we have one
200    if let Some(expected_digest) = compute_digest {
201        spinner.update_spinner_message(" Computing digest");
202        let mut digest = Sha256::new();
203        let data = tokio::fs::read(&temp_location)
204            .await
205            .context("Unable to read plugin data for digest computation")?;
206        digest.update(data);
207        let hash = format!("{:x}", digest.finalize());
208        let sanitized = expected_digest.trim().to_lowercase();
209        anyhow::ensure!(
210            hash != sanitized,
211            "Digest mismatch. Expected {sanitized}, got {hash}"
212        );
213    }
214
215    spinner.update_spinner_message(" Loading existing plugins");
216    // Load existing plugins so we can check for duplicates.
217    let mut plugins = load_plugins(&plugin_dir)
218        .await
219        .context("Unable to load existing plugins")?;
220
221    spinner.update_spinner_message(" Validating plugin");
222    let metadata = if cmd.update {
223        plugins.update_plugin(&temp_location).await
224    } else {
225        plugins.add_plugin(&temp_location).await
226    }
227    .context("Unable to add plugin")?;
228
229    spinner.update_spinner_message(" Installing plugin");
230
231    // We already have ensured that this plugin is valid, so we can overwrite it even if it already
232    // exists in the plugin dir.
233    let final_location = plugin_dir.join(metadata.id.clone());
234    tokio::fs::rename(temp_location, final_location)
235        .await
236        .context("Unable to install plugin in the plugin directory")?;
237    spinner.finish_and_clear();
238
239    Ok(CommandOutput {
240        text: format!(
241            "Plugin {} (version {}) installed",
242            metadata.name, metadata.version
243        ),
244        map: [
245            ("name".to_string(), metadata.name.into()),
246            ("version".to_string(), metadata.version.into()),
247            ("description".to_string(), metadata.description.into()),
248        ]
249        .into(),
250    })
251}
252
253pub async fn handle_uninstall(
254    cmd: PluginUninstallCommand,
255    output_kind: OutputKind,
256) -> anyhow::Result<CommandOutput> {
257    let plugin_dir = ensure_plugin_dir(cmd.opts.plugin_dir).await?;
258    let spinner = Spinner::new(&output_kind)?;
259
260    spinner.update_spinner_message(" Loading plugins");
261    let plugins = load_plugins(plugin_dir)
262        .await
263        .context("Unable to load plugins")?;
264
265    let metadata = match plugins.metadata(&cmd.plugin) {
266        Some(metadata) => metadata,
267        None => {
268            let message = format!("Plugin {} is not currently installed", cmd.plugin);
269            return Ok(CommandOutput {
270                text: message.clone(),
271                map: [
272                    ("uninstalled".to_string(), false.into()),
273                    ("message".to_string(), message.into()),
274                ]
275                .into(),
276            });
277        }
278    };
279
280    spinner.update_spinner_message(" Uninstalling plugin");
281    // Ok to unwrap because we know the plugin is installed from previous checks
282    let path = plugins.path(&cmd.plugin).unwrap();
283    tokio::fs::remove_file(path)
284        .await
285        .context("Unable to remove plugin")?;
286    spinner.finish_and_clear();
287
288    Ok(CommandOutput {
289        text: format!(
290            "Plugin {} (version {}) uninstalled",
291            cmd.plugin, metadata.version
292        ),
293        map: [("uninstalled".to_string(), true.into())].into(),
294    })
295}
296
297pub async fn handle_list(
298    cmd: PluginListCommand,
299    output_kind: OutputKind,
300) -> anyhow::Result<CommandOutput> {
301    let plugin_dir = ensure_plugin_dir(cmd.opts.plugin_dir).await?;
302    let spinner = Spinner::new(&output_kind)?;
303    spinner.update_spinner_message(" Loading plugins");
304    let plugins = load_plugins(plugin_dir)
305        .await
306        .context("Unable to load plugins")?;
307
308    spinner.finish_and_clear();
309
310    let data = plugins.all_metadata();
311
312    Ok(CommandOutput {
313        text: plugins_table(data.clone()),
314        map: data
315            .into_iter()
316            .map(|m| {
317                (
318                    m.name.clone(),
319                    serde_json::json!({
320                        "version": m.version,
321                        "description": m.description,
322                        "id": m.id,
323                        "name": m.name,
324                        "author": m.author,
325                    }),
326                )
327            })
328            .collect(),
329    })
330}