1use std::fs;
34#[cfg(unix)]
35use std::io::Write as _;
36use std::path::{Path, PathBuf};
37
38use rand_core::OsRng;
39use ssh_key::{Algorithm, EcdsaCurve, HashAlg, LineEnding, PrivateKey, PublicKey};
40use zeroize::Zeroizing;
41
42use crate::AnvilError;
43
/// The SSH key algorithms this module can generate.
///
/// Only [`KeyType::Rsa`] consults a caller-provided bit length when a key is
/// generated; the other variants select a fixed algorithm/curve.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KeyType {
    /// Ed25519.
    Ed25519,
    /// ECDSA over the NIST P-256 curve.
    EcdsaP256,
    /// ECDSA over the NIST P-384 curve.
    EcdsaP384,
    /// ECDSA over the NIST P-521 curve.
    EcdsaP521,
    /// RSA with a caller-selectable modulus size.
    Rsa,
}

impl KeyType {
    /// Returns the short, lowercase algorithm-family name for this key type
    /// (`"ed25519"`, `"ecdsa"`, or `"rsa"`).
    ///
    /// All three ECDSA curves collapse to `"ecdsa"`, so this name alone does
    /// not identify the curve.
    #[must_use]
    pub fn cli_name(self) -> &'static str {
        match self {
            Self::Rsa => "rsa",
            Self::Ed25519 => "ed25519",
            // The curve is deliberately not part of the family name.
            Self::EcdsaP256 | Self::EcdsaP384 | Self::EcdsaP521 => "ecdsa",
        }
    }
}
72
73pub fn generate(kind: KeyType, bits: Option<u32>, comment: &str) -> Result<PrivateKey, AnvilError> {
86 let algorithm = match kind {
87 KeyType::Ed25519 => Algorithm::Ed25519,
88 KeyType::EcdsaP256 => Algorithm::Ecdsa {
89 curve: EcdsaCurve::NistP256,
90 },
91 KeyType::EcdsaP384 => Algorithm::Ecdsa {
92 curve: EcdsaCurve::NistP384,
93 },
94 KeyType::EcdsaP521 => Algorithm::Ecdsa {
95 curve: EcdsaCurve::NistP521,
96 },
97 KeyType::Rsa => {
98 let b = bits.unwrap_or(DEFAULT_RSA_BITS);
99 if !(MIN_RSA_BITS..=MAX_RSA_BITS).contains(&b) {
100 return Err(AnvilError::invalid_config(format!(
101 "RSA key size {b} is out of range ({MIN_RSA_BITS}-{MAX_RSA_BITS})"
102 )));
103 }
104 return generate_rsa(b, comment);
105 }
106 };
107
108 let mut rng = OsRng;
109 let mut key = PrivateKey::random(&mut rng, algorithm)
110 .map_err(|e| AnvilError::signing(format!("key generation failed: {e}")))?;
111 key.set_comment(comment);
112 Ok(key)
113}
114
115fn generate_rsa(bits: u32, comment: &str) -> Result<PrivateKey, AnvilError> {
117 let mut rng = OsRng;
121 let usize_bits = usize::try_from(bits)
122 .map_err(|_e| AnvilError::invalid_config(format!("RSA bit count {bits} is too large")))?;
123 let rsa_key = ssh_key::private::RsaKeypair::random(&mut rng, usize_bits)
124 .map_err(|e| AnvilError::signing(format!("RSA key generation failed: {e}")))?;
125 let mut key = PrivateKey::from(rsa_key);
126 key.set_comment(comment);
127 Ok(key)
128}
129
/// Smallest RSA modulus size (in bits) accepted by `generate` (inclusive).
const MIN_RSA_BITS: u32 = 2048;
/// RSA modulus size (in bits) used by `generate` when the caller passes `bits: None`.
const DEFAULT_RSA_BITS: u32 = 3072;
/// Largest RSA modulus size (in bits) accepted by `generate` (inclusive).
const MAX_RSA_BITS: u32 = 16384;
139
140pub fn write_keypair(
160 key: &PrivateKey,
161 path: &Path,
162 passphrase: Option<&Zeroizing<String>>,
163) -> Result<(), AnvilError> {
164 let key_to_write = match passphrase {
165 Some(pp) if pp.is_empty() => {
166 return Err(AnvilError::invalid_config(
167 "empty passphrase is not allowed — pass `None` to leave the key unencrypted",
168 ));
169 }
170 Some(pp) => {
171 let mut rng = OsRng;
172 key.encrypt(&mut rng, pp.as_bytes())
173 .map_err(|e| AnvilError::signing(format!("failed to encrypt private key: {e}")))?
174 }
175 None => key.clone(),
176 };
177
178 let private_pem = key_to_write
179 .to_openssh(LineEnding::LF)
180 .map_err(|e| AnvilError::signing(format!("failed to serialize private key: {e}")))?;
181 write_private_file(path, private_pem.as_bytes())?;
182
183 let public = key.public_key();
184 let public_line = public
185 .to_openssh()
186 .map_err(|e| AnvilError::signing(format!("failed to serialize public key: {e}")))?;
187 let pub_path = pub_path_for(path);
188 let mut out = String::with_capacity(public_line.len() + 1);
189 out.push_str(&public_line);
190 out.push('\n');
191 fs::write(&pub_path, out.as_bytes())?;
192 Ok(())
193}
194
195#[cfg(unix)]
201fn write_private_file(path: &Path, bytes: &[u8]) -> Result<(), AnvilError> {
202 use std::os::unix::fs::OpenOptionsExt as _;
203 let mut f = fs::OpenOptions::new()
204 .create(true)
205 .truncate(true)
206 .write(true)
207 .mode(0o600)
210 .open(path)?;
211 f.write_all(bytes)?;
212 Ok(())
213}
214
/// Writes private-key bytes to `path` on non-Unix targets.
///
/// This is a plain `fs::write` (create or truncate). No permission hardening
/// is applied here, so the file receives the platform's default permissions.
///
/// # Errors
///
/// Any I/O error from the write, converted into `AnvilError`.
#[cfg(not(unix))]
fn write_private_file(path: &Path, bytes: &[u8]) -> Result<(), AnvilError> {
    let outcome = fs::write(path, bytes);
    outcome?;
    Ok(())
}
220
/// Returns the companion public-key path for `path`, formed by appending
/// `.pub` to the full path.
///
/// Any existing extension is kept, not replaced (`id.old` becomes `id.old.pub`).
fn pub_path_for(path: &Path) -> PathBuf {
    let mut name = path.as_os_str().to_os_string();
    name.push(".pub");
    name.into()
}
227
228pub fn change_passphrase(
240 path: &Path,
241 old: Option<&Zeroizing<String>>,
242 new: Option<&Zeroizing<String>>,
243) -> Result<(), AnvilError> {
244 let pem = fs::read_to_string(path)?;
245 let loaded = PrivateKey::from_openssh(&pem)
246 .map_err(|e| AnvilError::signing(format!("failed to parse existing key: {e}")))?;
247
248 let decrypted = if loaded.is_encrypted() {
249 let pp = old.ok_or_else(|| {
250 AnvilError::invalid_config(
251 "existing key is encrypted but no old passphrase was provided",
252 )
253 })?;
254 loaded
255 .decrypt(pp.as_bytes())
256 .map_err(|e| AnvilError::signing(format!("old passphrase is wrong: {e}")))?
257 } else {
258 loaded
259 };
260
261 write_keypair(&decrypted, path, new)
262}
263
264#[must_use]
271pub fn fingerprint(public: &PublicKey, hash: HashAlg) -> String {
272 public.fingerprint(hash).to_string()
273}
274
275pub fn extract_public(path: &Path, out: Option<&Path>) -> Result<(), AnvilError> {
289 let pem = fs::read_to_string(path)?;
290 let key = PrivateKey::from_openssh(&pem)
291 .map_err(|e| AnvilError::signing(format!("failed to parse private key: {e}")))?;
292 let public_line = key
293 .public_key()
294 .to_openssh()
295 .map_err(|e| AnvilError::signing(format!("failed to serialize public key: {e}")))?;
296 let target = match out {
297 Some(p) => p.to_owned(),
298 None => pub_path_for(path),
299 };
300 let mut buf = String::with_capacity(public_line.len() + 1);
301 buf.push_str(&public_line);
302 buf.push('\n');
303 fs::write(&target, buf.as_bytes())?;
304 Ok(())
305}
306
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn generate_ed25519_has_expected_algorithm() {
        let key = generate(KeyType::Ed25519, None, "test").unwrap();
        assert_eq!(key.algorithm(), Algorithm::Ed25519);
        assert_eq!(key.comment(), "test");
    }

    #[test]
    fn generate_ecdsa_p256_has_expected_curve() {
        let key = generate(KeyType::EcdsaP256, None, "test").unwrap();
        assert_eq!(
            key.algorithm(),
            Algorithm::Ecdsa {
                curve: EcdsaCurve::NistP256
            }
        );
    }

    #[test]
    fn cli_name_groups_ecdsa_curves() {
        assert_eq!(KeyType::Ed25519.cli_name(), "ed25519");
        assert_eq!(KeyType::EcdsaP256.cli_name(), "ecdsa");
        assert_eq!(KeyType::EcdsaP384.cli_name(), "ecdsa");
        assert_eq!(KeyType::EcdsaP521.cli_name(), "ecdsa");
        assert_eq!(KeyType::Rsa.cli_name(), "rsa");
    }

    #[test]
    fn write_and_read_roundtrip_unencrypted() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("id_ed25519");
        let key = generate(KeyType::Ed25519, None, "roundtrip@test").unwrap();
        write_keypair(&key, &path, None).unwrap();

        let pem = fs::read_to_string(&path).unwrap();
        let loaded = PrivateKey::from_openssh(&pem).unwrap();
        assert!(!loaded.is_encrypted());
        assert_eq!(
            loaded.public_key().fingerprint(HashAlg::Sha256),
            key.public_key().fingerprint(HashAlg::Sha256)
        );

        // Use the production helper so the test agrees with the real naming
        // rule (append `.pub`, never replace an extension).
        let pub_path = pub_path_for(&path);
        assert!(pub_path.exists(), "expected companion .pub file");
        let pub_content = fs::read_to_string(&pub_path).unwrap();
        assert!(pub_content.starts_with("ssh-ed25519 "));
    }

    #[test]
    fn write_and_read_roundtrip_encrypted() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("id_ed25519");
        let key = generate(KeyType::Ed25519, None, "enc@test").unwrap();
        let pp = Zeroizing::new(String::from("correcthorse"));
        write_keypair(&key, &path, Some(&pp)).unwrap();

        let pem = fs::read_to_string(&path).unwrap();
        let loaded = PrivateKey::from_openssh(&pem).unwrap();
        assert!(loaded.is_encrypted());
        let decrypted = loaded.decrypt(pp.as_bytes()).unwrap();
        assert_eq!(decrypted.comment(), "enc@test");
    }

    #[cfg(unix)]
    #[test]
    fn private_key_file_is_owner_only() {
        use std::os::unix::fs::PermissionsExt;
        let dir = tempdir().unwrap();
        let path = dir.path().join("id_ed25519");
        let key = generate(KeyType::Ed25519, None, "perm@test").unwrap();
        write_keypair(&key, &path, None).unwrap();

        let mode = fs::metadata(&path).unwrap().permissions().mode() & 0o777;
        assert_eq!(mode, 0o600);
    }

    #[test]
    fn rejects_empty_passphrase() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("id_ed25519");
        let key = generate(KeyType::Ed25519, None, "empty@test").unwrap();
        let pp = Zeroizing::new(String::new());
        let err = write_keypair(&key, &path, Some(&pp)).unwrap_err();
        assert!(err.to_string().contains("empty passphrase"));
    }

    #[test]
    fn change_passphrase_roundtrip() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("id_ed25519");
        let key = generate(KeyType::Ed25519, None, "change@test").unwrap();
        let pp1 = Zeroizing::new(String::from("one"));
        write_keypair(&key, &path, Some(&pp1)).unwrap();

        let pp2 = Zeroizing::new(String::from("two"));
        change_passphrase(&path, Some(&pp1), Some(&pp2)).unwrap();

        // The old passphrase must no longer open the key.
        let err = change_passphrase(&path, Some(&pp1), Some(&pp2)).unwrap_err();
        assert!(err.to_string().contains("passphrase"));

        // Passing `None` as the new passphrase strips encryption.
        change_passphrase(&path, Some(&pp2), None).unwrap();
        let pem = fs::read_to_string(&path).unwrap();
        let loaded = PrivateKey::from_openssh(&pem).unwrap();
        assert!(!loaded.is_encrypted());
    }

    #[test]
    fn change_passphrase_requires_old_for_encrypted_key() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("id_ed25519");
        let key = generate(KeyType::Ed25519, None, "noold@test").unwrap();
        let pp = Zeroizing::new(String::from("secret"));
        write_keypair(&key, &path, Some(&pp)).unwrap();

        let err = change_passphrase(&path, None, None).unwrap_err();
        assert!(err.to_string().contains("no old passphrase"));
    }

    #[test]
    fn fingerprint_format_is_sha256() {
        let key = generate(KeyType::Ed25519, None, "fp@test").unwrap();
        let fp = fingerprint(key.public_key(), HashAlg::Sha256);
        assert!(fp.starts_with("SHA256:"));
    }

    #[test]
    fn extract_public_matches_companion_file() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("id_ed25519");
        let key = generate(KeyType::Ed25519, None, "ext@test").unwrap();
        write_keypair(&key, &path, None).unwrap();

        let pub_path_side = dir.path().join("side.pub");
        extract_public(&path, Some(&pub_path_side)).unwrap();

        let pub_from_generate = fs::read_to_string(pub_path_for(&path)).unwrap();
        let pub_from_extract = fs::read_to_string(&pub_path_side).unwrap();
        assert_eq!(
            pub_from_generate.split_whitespace().nth(1),
            pub_from_extract.split_whitespace().nth(1),
            "base64 key body should match"
        );
    }

    #[test]
    fn rsa_size_bounds_are_enforced() {
        let err = generate(KeyType::Rsa, Some(1024), "rsa@test").unwrap_err();
        assert!(err.to_string().contains("out of range"));

        // The upper bound is rejected before any (slow) RSA generation happens.
        let err = generate(KeyType::Rsa, Some(MAX_RSA_BITS + 1), "rsa@test").unwrap_err();
        assert!(err.to_string().contains("out of range"));
    }
}
431}