wash_cli/
keys.rs

1use std::{collections::HashMap, path::PathBuf};
2
3use anyhow::Result;
4use clap::Subcommand;
5use nkeys::{KeyPair, KeyPairType};
6use serde_json::json;
7use wash_lib::cli::CommandOutput;
8use wash_lib::config::cfg_dir;
9use wash_lib::keys::{fs::KeyDir, KeyManager};
10
/// Extension of key files (`.nk`); `get` strips it from the requested key name for backwards compatibility.
const NKEYS_EXTENSION: &str = ".nk";
12
// Subcommands of the keys CLI: `gen`, `get` and `list`.
//
// NOTE(review): plain `//` comments are used for the notes added here on purpose.
// clap's derive macros turn `///` doc comments into help/about text, so adding doc
// comments could change the generated CLI help — confirm before converting any of these.
#[derive(Debug, Clone, Subcommand)]
#[allow(clippy::enum_variant_names)]
pub enum KeysCliCommand {
    // `gen <keytype>`: the key type is taken as a free-form string and validated
    // later by `keytype_parser`, not by clap.
    #[clap(name = "gen", about = "Generates a keypair")]
    GenCommand {
        /// The type of keypair to generate. May be Account, User, Module (or Component), Service (or Provider), Server (or Host), Operator, Cluster, Curve (xkey)
        keytype: String,
    },
    // `get <keyname> [-d|--directory <path>]`: the directory may also be supplied
    // through the WASH_KEYS environment variable.
    #[clap(name = "get", about = "Retrieves a keypair and prints the contents")]
    GetCommand {
        #[clap(help = "The name of the key to output")]
        keyname: String,
        #[clap(
            short = 'd',
            long = "directory",
            env = "WASH_KEYS",
            hide_env_values = true,
            help = "Absolute path to where keypairs are stored. Defaults to `$HOME/.wash/keys`"
        )]
        directory: Option<PathBuf>,
    },
    // `list [-d|--directory <path>]`: same directory option as `get`.
    #[clap(name = "list", about = "Lists all keypairs in a directory")]
    ListCommand {
        #[clap(
            short = 'd',
            long = "directory",
            env = "WASH_KEYS",
            hide_env_values = true,
            help = "Absolute path to where keypairs are stored. Defaults to `$HOME/.wash/keys`"
        )]
        directory: Option<PathBuf>,
    },
}
46
47pub fn handle_command(command: KeysCliCommand) -> Result<CommandOutput> {
48    match command {
49        KeysCliCommand::GenCommand { keytype } => {
50            let kt = keytype_parser(&keytype)?;
51            generate(&kt)
52        }
53        KeysCliCommand::GetCommand { keyname, directory } => get(&keyname, directory),
54        KeysCliCommand::ListCommand { directory } => list(directory),
55    }
56}
57
58pub fn keytype_parser(keytype: &str) -> Result<KeyPairType> {
59    match keytype.to_lowercase().as_str() {
60        "account" => Ok(KeyPairType::Account),
61        "user" => Ok(KeyPairType::User),
62        "module" | "component" => Ok(KeyPairType::Module),
63        "service" | "provider" => Ok(KeyPairType::Service),
64        "server" | "host" => Ok(KeyPairType::Server),
65        "operator" => Ok(KeyPairType::Operator),
66        "cluster" => Ok(KeyPairType::Cluster),
67        "x25519" | "curve" => Ok(KeyPairType::Curve),
68        _ => Err(anyhow::anyhow!(
69            "Invalid key type. Must be one of Account, User, Module (or Component), Service (or Provider), Server (or Host), Operator, Cluster, Curve (xkey)"
70        )),
71    }
72}
73/// Generates a keypair of the specified KeyPairType
74pub fn generate(kt: &KeyPairType) -> Result<CommandOutput> {
75    let kp = KeyPair::new(kt.clone());
76    let seed = kp.seed()?;
77
78    let mut map = HashMap::new();
79    map.insert("public_key".to_string(), json!(kp.public_key()));
80    map.insert("seed".to_string(), json!(seed));
81    Ok(CommandOutput::new(
82        format!(
83            "Public Key: {}\nSeed: {}\n\nRemember that the seed is private, treat it as a secret.",
84            kp.public_key(),
85            seed,
86        ),
87        map,
88    ))
89}
90
91/// Retrieves a keypair by name in a specified directory, or $WASH_KEYS ($HOME/.wash/keys) if directory is not specified
92pub fn get(keyname: &str, directory: Option<PathBuf>) -> Result<CommandOutput> {
93    let key_dir = KeyDir::new(determine_directory(directory)?)?;
94    // Trim off the ".nk" for backwards compat
95    let key = key_dir
96        .get(keyname.trim_end_matches(NKEYS_EXTENSION))?
97        .ok_or_else(|| anyhow::anyhow!("Key {} doesn't exist", keyname))?;
98
99    Ok(CommandOutput::from_key_and_text("seed", key.seed()?))
100}
101
102/// Lists all keypairs (file extension .nk) in a specified directory or $WASH_KEYS($HOME/.wash/keys) if directory is not specified
103pub fn list(directory: Option<PathBuf>) -> Result<CommandOutput> {
104    let key_dir = KeyDir::new(determine_directory(directory)?)?;
105
106    let keys = key_dir.list_names()?;
107
108    let mut map = HashMap::new();
109    map.insert("keys".to_string(), json!(keys));
110    Ok(CommandOutput::new(
111        format!(
112            "====== Keys found in {} ======\n{}",
113            key_dir.display(),
114            keys.join("\n")
115        ),
116        map,
117    ))
118}
119
120fn determine_directory(directory: Option<PathBuf>) -> Result<PathBuf> {
121    if let Some(d) = directory {
122        Ok(d)
123    } else {
124        let d = cfg_dir()?.join("keys");
125        Ok(d)
126    }
127}
128
#[cfg(test)]
mod tests {

    use super::{generate, keytype_parser, KeysCliCommand};
    use clap::Parser;
    use nkeys::KeyPairType;
    use serde::Deserialize;
    use std::path::PathBuf;

    /// Reference public key; only its length is compared against generated keys.
    const SAMPLE_PUBLIC_KEY: &str = "MBBLAHS7MCGNQ6IR4ZDSGRIAF7NVS7FCKFTKGO5JJJKN2QQRVAH7BSIO";
    /// Reference seed; only its length is compared against generated seeds.
    const SAMPLE_SEED: &str = "SMAH45IUULL57OSX23NOOOTLSVNQOORMDLE3Y3PQLJ4J5MY7MN2K7BIFI4";

    // Minimal top-level parser so the keys subcommands can be parsed in isolation.
    // (Plain comment on purpose: clap would pick up a doc comment as help text.)
    #[derive(Debug, Parser)]
    struct Cmd {
        #[clap(subcommand)]
        keys: KeysCliCommand,
    }

    /// Shape of the structured (`map`) output produced by `generate`.
    #[derive(Debug, Clone, Deserialize)]
    struct KeyPairJson {
        public_key: String,
        seed: String,
    }

    /// Generates a keypair of type `kt` and decodes the structured output by
    /// round-tripping it through its JSON text form.
    fn generate_json(kt: &KeyPairType) -> KeyPairJson {
        let output = generate(kt).unwrap();
        let encoded = serde_json::to_string(&output.map).unwrap();
        serde_json::from_str(&encoded).unwrap()
    }

    #[test]
    fn test_generate_basic_test() {
        let kt = KeyPairType::Account;

        let keypair = generate(&kt).unwrap();

        assert!(keypair.text.contains("Public Key: "));
        assert!(keypair.text.contains("Seed: "));
        assert!(keypair
            .text
            .contains("Remember that the seed is private, treat it as a secret."));
        assert_ne!(keypair.text, "");
        assert!(!keypair.map.is_empty());
    }

    #[test]
    fn test_generate_valid_keypair() {
        let keypair = generate_json(&KeyPairType::Module);

        assert_eq!(keypair.public_key.len(), SAMPLE_PUBLIC_KEY.len());
        assert_eq!(keypair.seed.len(), SAMPLE_SEED.len());
        assert!(keypair.public_key.starts_with('M'));
        assert!(keypair.seed.starts_with("SM"));
    }

    #[test]
    fn test_generate_all_types() {
        // (key type, expected public key prefix, expected seed prefix)
        let expectations = [
            (KeyPairType::Account, 'A', "SA"),
            (KeyPairType::User, 'U', "SU"),
            (KeyPairType::Module, 'M', "SM"),
            (KeyPairType::Service, 'V', "SV"),
            (KeyPairType::Server, 'N', "SN"),
            (KeyPairType::Operator, 'O', "SO"),
            (KeyPairType::Cluster, 'C', "SC"),
        ];

        for (kt, public_prefix, seed_prefix) in &expectations {
            let keypair = generate_json(kt);

            assert!(
                keypair.public_key.starts_with(*public_prefix),
                "public key {} should start with {public_prefix}",
                keypair.public_key
            );
            assert_eq!(keypair.public_key.len(), SAMPLE_PUBLIC_KEY.len());
            assert!(
                keypair.seed.starts_with(*seed_prefix),
                "seed for public key {} should start with {seed_prefix}",
                keypair.public_key
            );
            assert_eq!(keypair.seed.len(), SAMPLE_SEED.len());
        }
    }

    /// Enumerates multiple options of the `gen` command to ensure API doesn't
    /// change between versions. This test will fail if `wash keys gen <type>`
    /// changes syntax, ordering of required elements, or flags.
    #[test]
    fn test_gen_comprehensive() {
        let key_gen_types = [
            "acCount",
            "usEr",
            "module",
            "COMPONENT",
            "SERVICE",
            "provider",
            "server",
            "HOST",
            "operator",
            "CLUSTER",
            "curve",
            "X25519",
        ];

        for raw in key_gen_types.iter().copied() {
            // The fixture is passed through untouched so that the parser's own
            // case-insensitivity is what gets exercised.
            let gen_cmd = Cmd::try_parse_from(["keys", "gen", raw]).unwrap();
            match gen_cmd.keys {
                KeysCliCommand::GenCommand { keytype } => {
                    let lowered = raw.to_lowercase();
                    match keytype_parser(&keytype).unwrap() {
                        KeyPairType::Account => assert_eq!(lowered, "account"),
                        KeyPairType::User => assert_eq!(lowered, "user"),
                        KeyPairType::Module => {
                            assert!(lowered == "module" || lowered == "component")
                        }
                        KeyPairType::Service => {
                            assert!(lowered == "service" || lowered == "provider")
                        }
                        KeyPairType::Server => assert!(lowered == "server" || lowered == "host"),
                        KeyPairType::Operator => assert_eq!(lowered, "operator"),
                        KeyPairType::Cluster => assert_eq!(lowered, "cluster"),
                        KeyPairType::Curve => assert!(lowered == "curve" || lowered == "x25519"),
                    }
                }
                _ => panic!("`keys gen` constructed incorrect command"),
            }
        }
    }

    #[test]
    fn test_invalid_keytype_input() {
        let key_gen_types = [
            "accout", "USE", "moDUl", "actors", "SEVICE", "provder", "srver", "hos", "opERtoR",
            "cluter",
        ];

        for raw in key_gen_types.iter().copied() {
            // Passed through untouched: clap accepts any string here, the rejection
            // must come from `keytype_parser` regardless of letter case.
            let gen_cmd = Cmd::try_parse_from(["keys", "gen", raw]).unwrap();
            match gen_cmd.keys {
                KeysCliCommand::GenCommand { keytype } => {
                    assert!(
                        keytype_parser(&keytype).is_err(),
                        "Invalid keytype parsed successfully"
                    );
                }
                _ => panic!("`keys gen` constructed incorrect command"),
            }
        }
    }

    #[test]
    fn test_get_basic() {
        const KEYNAME: &str = "get_basic_test.nk";
        const KEYPATH: &str = "./tests/fixtures";

        let get_basic =
            Cmd::try_parse_from(["keys", "get", KEYNAME, "--directory", KEYPATH]).unwrap();
        match get_basic.keys {
            KeysCliCommand::GetCommand { keyname, .. } => assert_eq!(keyname, KEYNAME),
            other_cmd => panic!("keys get generated other command {other_cmd:?}"),
        }
    }

    /// Enumerates multiple options of the `get` command to ensure API doesn't
    /// change between versions. This test will fail if `wash keys get`
    /// changes syntax, ordering of required elements, or flags.
    #[test]
    fn test_get_comprehensive() {
        const KEYPATH: &str = "./tests/fixtures";
        const KEYNAME: &str = "get_comprehensive_test.nk";

        let get_all_flags = Cmd::try_parse_from(["keys", "get", KEYNAME, "-d", KEYPATH]).unwrap();
        match get_all_flags.keys {
            KeysCliCommand::GetCommand { keyname, directory } => {
                assert_eq!(keyname, KEYNAME);
                assert_eq!(directory, Some(PathBuf::from(KEYPATH)));
            }
            other_cmd => panic!("keys get generated other command {other_cmd:?}"),
        }
    }

    /// Enumerates multiple options of the `list` command to ensure API doesn't
    /// change between versions. This test will fail if `wash keys list`
    /// changes syntax, ordering of required elements, or flags.
    #[test]
    fn test_list_comprehensive() {
        const KEYPATH: &str = "./";

        // Both spellings of the directory option must be accepted and yield the
        // same parsed value.
        for flag in ["--directory", "-d"].iter().copied() {
            let list_cmd = Cmd::try_parse_from(["keys", "list", flag, KEYPATH]).unwrap();
            match list_cmd.keys {
                KeysCliCommand::ListCommand { directory } => {
                    assert_eq!(directory, Some(PathBuf::from(KEYPATH)));
                }
                other_cmd => panic!("keys list generated other command {other_cmd:?}"),
            }
        }
    }
}