1use std::fs;
34#[cfg(unix)]
35use std::io::Write as _;
36use std::path::{Path, PathBuf};
37
38use rand_core::OsRng;
39use ssh_key::{Algorithm, EcdsaCurve, HashAlg, LineEnding, PrivateKey, PublicKey};
40use zeroize::Zeroizing;
41
42use crate::GitwayError;
43
/// Key algorithms that [`generate`] can produce.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KeyType {
    /// Ed25519 key.
    Ed25519,
    /// ECDSA key on the NIST P-256 curve.
    EcdsaP256,
    /// ECDSA key on the NIST P-384 curve.
    EcdsaP384,
    /// ECDSA key on the NIST P-521 curve.
    EcdsaP521,
    /// RSA key; the modulus size is chosen separately via the `bits` argument of [`generate`].
    Rsa,
}

impl KeyType {
    /// Returns the short lowercase algorithm name for this key type.
    ///
    /// Every ECDSA variant maps to the same name, `"ecdsa"`; the curve is not
    /// encoded in the returned string.
    #[must_use]
    pub fn cli_name(self) -> &'static str {
        match self {
            Self::Rsa => "rsa",
            Self::EcdsaP256 | Self::EcdsaP384 | Self::EcdsaP521 => "ecdsa",
            Self::Ed25519 => "ed25519",
        }
    }
}
72
73pub fn generate(
86 kind: KeyType,
87 bits: Option<u32>,
88 comment: &str,
89) -> Result<PrivateKey, GitwayError> {
90 let algorithm = match kind {
91 KeyType::Ed25519 => Algorithm::Ed25519,
92 KeyType::EcdsaP256 => Algorithm::Ecdsa {
93 curve: EcdsaCurve::NistP256,
94 },
95 KeyType::EcdsaP384 => Algorithm::Ecdsa {
96 curve: EcdsaCurve::NistP384,
97 },
98 KeyType::EcdsaP521 => Algorithm::Ecdsa {
99 curve: EcdsaCurve::NistP521,
100 },
101 KeyType::Rsa => {
102 let b = bits.unwrap_or(DEFAULT_RSA_BITS);
103 if !(MIN_RSA_BITS..=MAX_RSA_BITS).contains(&b) {
104 return Err(GitwayError::invalid_config(format!(
105 "RSA key size {b} is out of range ({MIN_RSA_BITS}-{MAX_RSA_BITS})"
106 )));
107 }
108 return generate_rsa(b, comment);
109 }
110 };
111
112 let mut rng = OsRng;
113 let mut key = PrivateKey::random(&mut rng, algorithm)
114 .map_err(|e| GitwayError::signing(format!("key generation failed: {e}")))?;
115 key.set_comment(comment);
116 Ok(key)
117}
118
119fn generate_rsa(bits: u32, comment: &str) -> Result<PrivateKey, GitwayError> {
121 let mut rng = OsRng;
125 let usize_bits = usize::try_from(bits)
126 .map_err(|_e| GitwayError::invalid_config(format!("RSA bit count {bits} is too large")))?;
127 let rsa_key = ssh_key::private::RsaKeypair::random(&mut rng, usize_bits)
128 .map_err(|e| GitwayError::signing(format!("RSA key generation failed: {e}")))?;
129 let mut key = PrivateKey::from(rsa_key);
130 key.set_comment(comment);
131 Ok(key)
132}
133
/// Smallest RSA modulus size, in bits, that `generate` accepts (inclusive).
const MIN_RSA_BITS: u32 = 2_048;
/// RSA modulus size, in bits, used when the caller does not specify one.
const DEFAULT_RSA_BITS: u32 = 3_072;
/// Largest RSA modulus size, in bits, that `generate` accepts (inclusive).
const MAX_RSA_BITS: u32 = 16_384;
143
144pub fn write_keypair(
164 key: &PrivateKey,
165 path: &Path,
166 passphrase: Option<&Zeroizing<String>>,
167) -> Result<(), GitwayError> {
168 let key_to_write = match passphrase {
169 Some(pp) if pp.is_empty() => {
170 return Err(GitwayError::invalid_config(
171 "empty passphrase is not allowed — pass `None` to leave the key unencrypted",
172 ));
173 }
174 Some(pp) => {
175 let mut rng = OsRng;
176 key.encrypt(&mut rng, pp.as_bytes())
177 .map_err(|e| GitwayError::signing(format!("failed to encrypt private key: {e}")))?
178 }
179 None => key.clone(),
180 };
181
182 let private_pem = key_to_write
183 .to_openssh(LineEnding::LF)
184 .map_err(|e| GitwayError::signing(format!("failed to serialize private key: {e}")))?;
185 write_private_file(path, private_pem.as_bytes())?;
186
187 let public = key.public_key();
188 let public_line = public
189 .to_openssh()
190 .map_err(|e| GitwayError::signing(format!("failed to serialize public key: {e}")))?;
191 let pub_path = pub_path_for(path);
192 let mut out = String::with_capacity(public_line.len() + 1);
193 out.push_str(&public_line);
194 out.push('\n');
195 fs::write(&pub_path, out.as_bytes())?;
196 Ok(())
197}
198
199#[cfg(unix)]
205fn write_private_file(path: &Path, bytes: &[u8]) -> Result<(), GitwayError> {
206 use std::os::unix::fs::OpenOptionsExt as _;
207 let mut f = fs::OpenOptions::new()
208 .create(true)
209 .truncate(true)
210 .write(true)
211 .mode(0o600)
214 .open(path)?;
215 f.write_all(bytes)?;
216 Ok(())
217}
218
/// Writes `bytes` to `path` on non-Unix targets.
///
/// This variant applies no explicit permission restrictions; the file gets
/// whatever `fs::write` produces on the platform.
///
/// # Errors
///
/// Propagates the I/O error from `fs::write`.
#[cfg(not(unix))]
fn write_private_file(path: &Path, bytes: &[u8]) -> Result<(), GitwayError> {
    fs::write(path, bytes).map_err(GitwayError::from)
}
224
/// Returns `path` with `.pub` appended to its final component.
///
/// The suffix is appended to the raw path string rather than replacing an
/// extension, so `key.pem` becomes `key.pem.pub`.
fn pub_path_for(path: &Path) -> PathBuf {
    let mut with_suffix = path.to_path_buf().into_os_string();
    with_suffix.push(".pub");
    with_suffix.into()
}
231
232pub fn change_passphrase(
244 path: &Path,
245 old: Option<&Zeroizing<String>>,
246 new: Option<&Zeroizing<String>>,
247) -> Result<(), GitwayError> {
248 let pem = fs::read_to_string(path)?;
249 let loaded = PrivateKey::from_openssh(&pem)
250 .map_err(|e| GitwayError::signing(format!("failed to parse existing key: {e}")))?;
251
252 let decrypted = if loaded.is_encrypted() {
253 let pp = old.ok_or_else(|| {
254 GitwayError::invalid_config(
255 "existing key is encrypted but no old passphrase was provided",
256 )
257 })?;
258 loaded
259 .decrypt(pp.as_bytes())
260 .map_err(|e| GitwayError::signing(format!("old passphrase is wrong: {e}")))?
261 } else {
262 loaded
263 };
264
265 write_keypair(&decrypted, path, new)
266}
267
268#[must_use]
275pub fn fingerprint(public: &PublicKey, hash: HashAlg) -> String {
276 public.fingerprint(hash).to_string()
277}
278
279pub fn extract_public(path: &Path, out: Option<&Path>) -> Result<(), GitwayError> {
293 let pem = fs::read_to_string(path)?;
294 let key = PrivateKey::from_openssh(&pem)
295 .map_err(|e| GitwayError::signing(format!("failed to parse private key: {e}")))?;
296 let public_line = key
297 .public_key()
298 .to_openssh()
299 .map_err(|e| GitwayError::signing(format!("failed to serialize public key: {e}")))?;
300 let target = match out {
301 Some(p) => p.to_owned(),
302 None => pub_path_for(path),
303 };
304 let mut buf = String::with_capacity(public_line.len() + 1);
305 buf.push_str(&public_line);
306 buf.push('\n');
307 fs::write(&target, buf.as_bytes())?;
308 Ok(())
309}
310
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Generates an Ed25519 key carrying `comment`, panicking on failure.
    fn ed25519(comment: &str) -> PrivateKey {
        generate(KeyType::Ed25519, None, comment).unwrap()
    }

    /// Creates a temp directory and a key path inside it.
    ///
    /// The directory guard is returned so it outlives the test body.
    fn scratch_key_path() -> (tempfile::TempDir, PathBuf) {
        let dir = tempdir().unwrap();
        let path = dir.path().join("id_ed25519");
        (dir, path)
    }

    /// Reads and parses the OpenSSH private key stored at `path`.
    fn load_private(path: &Path) -> PrivateKey {
        let pem = fs::read_to_string(path).unwrap();
        PrivateKey::from_openssh(&pem).unwrap()
    }

    /// Wraps `text` as a zeroizing passphrase.
    fn passphrase(text: &str) -> Zeroizing<String> {
        Zeroizing::new(String::from(text))
    }

    #[test]
    fn generate_ed25519_has_expected_algorithm() {
        let key = ed25519("test");
        assert_eq!(key.algorithm(), Algorithm::Ed25519);
        assert_eq!(key.comment(), "test");
    }

    #[test]
    fn generate_ecdsa_p256_has_expected_curve() {
        let key = generate(KeyType::EcdsaP256, None, "test").unwrap();
        let expected = Algorithm::Ecdsa {
            curve: EcdsaCurve::NistP256,
        };
        assert_eq!(key.algorithm(), expected);
    }

    #[test]
    fn write_and_read_roundtrip_unencrypted() {
        let (_dir, path) = scratch_key_path();
        let key = ed25519("roundtrip@test");
        write_keypair(&key, &path, None).unwrap();

        let loaded = load_private(&path);
        assert!(!loaded.is_encrypted());
        assert_eq!(
            loaded.public_key().fingerprint(HashAlg::Sha256),
            key.public_key().fingerprint(HashAlg::Sha256)
        );

        let pub_path = path.with_extension("pub");
        assert!(pub_path.exists(), "expected companion .pub file");
        let pub_content = fs::read_to_string(&pub_path).unwrap();
        assert!(pub_content.starts_with("ssh-ed25519 "));
    }

    #[test]
    fn write_and_read_roundtrip_encrypted() {
        let (_dir, path) = scratch_key_path();
        let key = ed25519("enc@test");
        let pp = passphrase("correcthorse");
        write_keypair(&key, &path, Some(&pp)).unwrap();

        let loaded = load_private(&path);
        assert!(loaded.is_encrypted());
        let decrypted = loaded.decrypt(pp.as_bytes()).unwrap();
        assert_eq!(decrypted.comment(), "enc@test");
    }

    #[test]
    fn rejects_empty_passphrase() {
        let (_dir, path) = scratch_key_path();
        let key = ed25519("empty@test");
        let pp = passphrase("");
        let err = write_keypair(&key, &path, Some(&pp)).unwrap_err();
        assert!(err.to_string().contains("empty passphrase"));
    }

    #[test]
    fn change_passphrase_roundtrip() {
        let (_dir, path) = scratch_key_path();
        let key = ed25519("change@test");
        let pp1 = passphrase("one");
        write_keypair(&key, &path, Some(&pp1)).unwrap();

        let pp2 = passphrase("two");
        change_passphrase(&path, Some(&pp1), Some(&pp2)).unwrap();

        // The old passphrase must no longer unlock the key.
        let err = change_passphrase(&path, Some(&pp1), Some(&pp2)).unwrap_err();
        assert!(err.to_string().contains("passphrase"));

        // Passing `None` as the new passphrase leaves the key unencrypted.
        change_passphrase(&path, Some(&pp2), None).unwrap();
        assert!(!load_private(&path).is_encrypted());
    }

    #[test]
    fn fingerprint_format_is_sha256() {
        let key = ed25519("fp@test");
        let fp = fingerprint(key.public_key(), HashAlg::Sha256);
        assert!(fp.starts_with("SHA256:"));
    }

    #[test]
    fn extract_public_matches_companion_file() {
        let (dir, path) = scratch_key_path();
        let key = ed25519("ext@test");
        write_keypair(&key, &path, None).unwrap();

        let pub_path_side = dir.path().join("side.pub");
        extract_public(&path, Some(&pub_path_side)).unwrap();

        let pub_from_generate = fs::read_to_string(path.with_extension("pub")).unwrap();
        let pub_from_extract = fs::read_to_string(&pub_path_side).unwrap();
        assert_eq!(
            pub_from_generate.split_whitespace().nth(1),
            pub_from_extract.split_whitespace().nth(1),
            "base64 key body should match"
        );
    }

    #[test]
    fn rsa_size_bounds_are_enforced() {
        let err = generate(KeyType::Rsa, Some(1024), "rsa@test").unwrap_err();
        assert!(err.to_string().contains("out of range"));
    }
}