1use std::fs;
34use std::io::Write as _;
35use std::path::{Path, PathBuf};
36
37use rand_core::OsRng;
38use ssh_key::{Algorithm, EcdsaCurve, HashAlg, LineEnding, PrivateKey, PublicKey};
39use zeroize::Zeroizing;
40
41use crate::GitwayError;
42
/// The kinds of private key this module knows how to generate.
///
/// Each ECDSA variant pins one NIST curve; RSA takes its modulus size
/// separately (see `generate`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum KeyType {
    /// Ed25519 (EdDSA over Curve25519).
    Ed25519,
    /// ECDSA on the NIST P-256 curve.
    EcdsaP256,
    /// ECDSA on the NIST P-384 curve.
    EcdsaP384,
    /// ECDSA on the NIST P-521 curve.
    EcdsaP521,
    /// RSA with a caller-selected modulus size.
    Rsa,
}

impl KeyType {
    /// Short, user-facing type name: `"ed25519"`, `"ecdsa"` or `"rsa"`.
    ///
    /// All three ECDSA curves report the same name, `"ecdsa"`.
    #[must_use]
    pub fn cli_name(self) -> &'static str {
        match self {
            KeyType::Rsa => "rsa",
            KeyType::Ed25519 => "ed25519",
            KeyType::EcdsaP256 | KeyType::EcdsaP384 | KeyType::EcdsaP521 => "ecdsa",
        }
    }
}
71
72pub fn generate(
85 kind: KeyType,
86 bits: Option<u32>,
87 comment: &str,
88) -> Result<PrivateKey, GitwayError> {
89 let algorithm = match kind {
90 KeyType::Ed25519 => Algorithm::Ed25519,
91 KeyType::EcdsaP256 => Algorithm::Ecdsa {
92 curve: EcdsaCurve::NistP256,
93 },
94 KeyType::EcdsaP384 => Algorithm::Ecdsa {
95 curve: EcdsaCurve::NistP384,
96 },
97 KeyType::EcdsaP521 => Algorithm::Ecdsa {
98 curve: EcdsaCurve::NistP521,
99 },
100 KeyType::Rsa => {
101 let b = bits.unwrap_or(DEFAULT_RSA_BITS);
102 if !(MIN_RSA_BITS..=MAX_RSA_BITS).contains(&b) {
103 return Err(GitwayError::invalid_config(format!(
104 "RSA key size {b} is out of range ({MIN_RSA_BITS}-{MAX_RSA_BITS})"
105 )));
106 }
107 return generate_rsa(b, comment);
108 }
109 };
110
111 let mut rng = OsRng;
112 let mut key = PrivateKey::random(&mut rng, algorithm)
113 .map_err(|e| GitwayError::signing(format!("key generation failed: {e}")))?;
114 key.set_comment(comment);
115 Ok(key)
116}
117
118fn generate_rsa(bits: u32, comment: &str) -> Result<PrivateKey, GitwayError> {
120 let mut rng = OsRng;
124 let usize_bits = usize::try_from(bits)
125 .map_err(|_e| GitwayError::invalid_config(format!("RSA bit count {bits} is too large")))?;
126 let rsa_key = ssh_key::private::RsaKeypair::random(&mut rng, usize_bits)
127 .map_err(|e| GitwayError::signing(format!("RSA key generation failed: {e}")))?;
128 let mut key = PrivateKey::from(rsa_key);
129 key.set_comment(comment);
130 Ok(key)
131}
132
/// RSA modulus size, in bits, used when the caller does not request one.
const DEFAULT_RSA_BITS: u32 = 3_072;
/// Smallest RSA modulus size, in bits, that `generate` accepts.
const MIN_RSA_BITS: u32 = 2_048;
/// Largest RSA modulus size, in bits, that `generate` accepts.
const MAX_RSA_BITS: u32 = 16_384;
142
143pub fn write_keypair(
163 key: &PrivateKey,
164 path: &Path,
165 passphrase: Option<&Zeroizing<String>>,
166) -> Result<(), GitwayError> {
167 let key_to_write = match passphrase {
168 Some(pp) if pp.is_empty() => {
169 return Err(GitwayError::invalid_config(
170 "empty passphrase is not allowed — pass `None` to leave the key unencrypted",
171 ));
172 }
173 Some(pp) => {
174 let mut rng = OsRng;
175 key.encrypt(&mut rng, pp.as_bytes())
176 .map_err(|e| GitwayError::signing(format!("failed to encrypt private key: {e}")))?
177 }
178 None => key.clone(),
179 };
180
181 let private_pem = key_to_write
182 .to_openssh(LineEnding::LF)
183 .map_err(|e| GitwayError::signing(format!("failed to serialize private key: {e}")))?;
184 write_private_file(path, private_pem.as_bytes())?;
185
186 let public = key.public_key();
187 let public_line = public
188 .to_openssh()
189 .map_err(|e| GitwayError::signing(format!("failed to serialize public key: {e}")))?;
190 let pub_path = pub_path_for(path);
191 let mut out = String::with_capacity(public_line.len() + 1);
192 out.push_str(&public_line);
193 out.push('\n');
194 fs::write(&pub_path, out.as_bytes())?;
195 Ok(())
196}
197
198#[cfg(unix)]
204fn write_private_file(path: &Path, bytes: &[u8]) -> Result<(), GitwayError> {
205 use std::os::unix::fs::OpenOptionsExt as _;
206 let mut f = fs::OpenOptions::new()
207 .create(true)
208 .truncate(true)
209 .write(true)
210 .mode(0o600)
213 .open(path)?;
214 f.write_all(bytes)?;
215 Ok(())
216}
217
/// Writes `bytes` to `path`, creating the file or replacing its contents.
///
/// Non-Unix fallback: no explicit permission bits are applied here.
///
/// # Errors
///
/// Propagates any I/O error from the write, converted into `GitwayError`.
#[cfg(not(unix))]
fn write_private_file(path: &Path, bytes: &[u8]) -> Result<(), GitwayError> {
    fs::write(path, bytes).map_err(GitwayError::from)
}
223
/// Returns the companion public-key path for `path`: the same path with
/// `.pub` appended.
///
/// The suffix is appended to the whole file name rather than replacing an
/// extension, so `key.old` maps to `key.old.pub`.
fn pub_path_for(path: &Path) -> PathBuf {
    let mut raw = path.to_path_buf().into_os_string();
    raw.push(".pub");
    raw.into()
}
230
231pub fn change_passphrase(
243 path: &Path,
244 old: Option<&Zeroizing<String>>,
245 new: Option<&Zeroizing<String>>,
246) -> Result<(), GitwayError> {
247 let pem = fs::read_to_string(path)?;
248 let loaded = PrivateKey::from_openssh(&pem)
249 .map_err(|e| GitwayError::signing(format!("failed to parse existing key: {e}")))?;
250
251 let decrypted = if loaded.is_encrypted() {
252 let pp = old.ok_or_else(|| {
253 GitwayError::invalid_config(
254 "existing key is encrypted but no old passphrase was provided",
255 )
256 })?;
257 loaded
258 .decrypt(pp.as_bytes())
259 .map_err(|e| GitwayError::signing(format!("old passphrase is wrong: {e}")))?
260 } else {
261 loaded
262 };
263
264 write_keypair(&decrypted, path, new)
265}
266
267#[must_use]
274pub fn fingerprint(public: &PublicKey, hash: HashAlg) -> String {
275 public.fingerprint(hash).to_string()
276}
277
278pub fn extract_public(path: &Path, out: Option<&Path>) -> Result<(), GitwayError> {
292 let pem = fs::read_to_string(path)?;
293 let key = PrivateKey::from_openssh(&pem)
294 .map_err(|e| GitwayError::signing(format!("failed to parse private key: {e}")))?;
295 let public_line = key
296 .public_key()
297 .to_openssh()
298 .map_err(|e| GitwayError::signing(format!("failed to serialize public key: {e}")))?;
299 let target = match out {
300 Some(p) => p.to_owned(),
301 None => pub_path_for(path),
302 };
303 let mut buf = String::with_capacity(public_line.len() + 1);
304 buf.push_str(&public_line);
305 buf.push('\n');
306 fs::write(&target, buf.as_bytes())?;
307 Ok(())
308}
309
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    #[test]
    fn generate_ed25519_has_expected_algorithm() {
        let key = generate(KeyType::Ed25519, None, "test").unwrap();
        assert_eq!(key.algorithm(), Algorithm::Ed25519);
        assert_eq!(key.comment(), "test");
    }

    #[test]
    fn generate_ecdsa_p256_has_expected_curve() {
        let key = generate(KeyType::EcdsaP256, None, "test").unwrap();
        assert_eq!(
            key.algorithm(),
            Algorithm::Ecdsa {
                curve: EcdsaCurve::NistP256
            }
        );
    }

    #[test]
    fn cli_name_groups_all_ecdsa_curves() {
        assert_eq!(KeyType::Ed25519.cli_name(), "ed25519");
        assert_eq!(KeyType::EcdsaP256.cli_name(), "ecdsa");
        assert_eq!(KeyType::EcdsaP384.cli_name(), "ecdsa");
        assert_eq!(KeyType::EcdsaP521.cli_name(), "ecdsa");
        assert_eq!(KeyType::Rsa.cli_name(), "rsa");
    }

    #[test]
    fn pub_path_appends_suffix_instead_of_replacing_extension() {
        assert_eq!(
            pub_path_for(Path::new("keys/id.old")),
            PathBuf::from("keys/id.old.pub")
        );
    }

    #[test]
    fn write_and_read_roundtrip_unencrypted() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("id_ed25519");
        let key = generate(KeyType::Ed25519, None, "roundtrip@test").unwrap();
        write_keypair(&key, &path, None).unwrap();

        let pem = fs::read_to_string(&path).unwrap();
        let loaded = PrivateKey::from_openssh(&pem).unwrap();
        assert!(!loaded.is_encrypted());
        assert_eq!(
            loaded.public_key().fingerprint(HashAlg::Sha256),
            key.public_key().fingerprint(HashAlg::Sha256)
        );

        let pub_path = path.with_extension("pub");
        assert!(pub_path.exists(), "expected companion .pub file");
        let pub_content = fs::read_to_string(&pub_path).unwrap();
        assert!(pub_content.starts_with("ssh-ed25519 "));
    }

    #[test]
    fn write_and_read_roundtrip_encrypted() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("id_ed25519");
        let key = generate(KeyType::Ed25519, None, "enc@test").unwrap();
        let pp = Zeroizing::new(String::from("correcthorse"));
        write_keypair(&key, &path, Some(&pp)).unwrap();

        let pem = fs::read_to_string(&path).unwrap();
        let loaded = PrivateKey::from_openssh(&pem).unwrap();
        assert!(loaded.is_encrypted());
        let decrypted = loaded.decrypt(pp.as_bytes()).unwrap();
        assert_eq!(decrypted.comment(), "enc@test");
    }

    #[test]
    fn rejects_empty_passphrase() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("id_ed25519");
        let key = generate(KeyType::Ed25519, None, "empty@test").unwrap();
        let pp = Zeroizing::new(String::new());
        let err = write_keypair(&key, &path, Some(&pp)).unwrap_err();
        assert!(err.to_string().contains("empty passphrase"));
    }

    #[test]
    fn change_passphrase_roundtrip() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("id_ed25519");
        let key = generate(KeyType::Ed25519, None, "change@test").unwrap();
        let pp1 = Zeroizing::new(String::from("one"));
        write_keypair(&key, &path, Some(&pp1)).unwrap();

        let pp2 = Zeroizing::new(String::from("two"));
        change_passphrase(&path, Some(&pp1), Some(&pp2)).unwrap();

        // The old passphrase must no longer unlock the key.
        let err = change_passphrase(&path, Some(&pp1), Some(&pp2)).unwrap_err();
        assert!(err.to_string().contains("passphrase"));

        // Removing the passphrase leaves a plaintext key on disk.
        change_passphrase(&path, Some(&pp2), None).unwrap();
        let pem = fs::read_to_string(&path).unwrap();
        let loaded = PrivateKey::from_openssh(&pem).unwrap();
        assert!(!loaded.is_encrypted());
    }

    #[test]
    fn change_passphrase_requires_old_passphrase_for_encrypted_key() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("id_ed25519");
        let key = generate(KeyType::Ed25519, None, "noold@test").unwrap();
        let pp = Zeroizing::new(String::from("secret"));
        write_keypair(&key, &path, Some(&pp)).unwrap();

        let err = change_passphrase(&path, None, None).unwrap_err();
        assert!(err.to_string().contains("no old passphrase"));
    }

    #[test]
    fn change_passphrase_can_encrypt_plaintext_key() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("id_ed25519");
        let key = generate(KeyType::Ed25519, None, "encrypt@test").unwrap();
        write_keypair(&key, &path, None).unwrap();

        let pp = Zeroizing::new(String::from("fresh"));
        change_passphrase(&path, None, Some(&pp)).unwrap();

        let pem = fs::read_to_string(&path).unwrap();
        let loaded = PrivateKey::from_openssh(&pem).unwrap();
        assert!(loaded.is_encrypted());
        let decrypted = loaded.decrypt(pp.as_bytes()).unwrap();
        assert_eq!(decrypted.comment(), "encrypt@test");
    }

    #[test]
    fn fingerprint_format_is_sha256() {
        let key = generate(KeyType::Ed25519, None, "fp@test").unwrap();
        let fp = fingerprint(key.public_key(), HashAlg::Sha256);
        assert!(fp.starts_with("SHA256:"));
    }

    #[test]
    fn extract_public_matches_companion_file() {
        let dir = tempdir().unwrap();
        let path = dir.path().join("id_ed25519");
        let key = generate(KeyType::Ed25519, None, "ext@test").unwrap();
        write_keypair(&key, &path, None).unwrap();

        let pub_path_side = dir.path().join("side.pub");
        extract_public(&path, Some(&pub_path_side)).unwrap();

        let pub_from_generate = fs::read_to_string(path.with_extension("pub")).unwrap();
        let pub_from_extract = fs::read_to_string(&pub_path_side).unwrap();
        assert_eq!(
            pub_from_generate.split_whitespace().nth(1),
            pub_from_extract.split_whitespace().nth(1),
            "base64 key body should match"
        );
    }

    #[test]
    fn rsa_size_bounds_are_enforced() {
        let err = generate(KeyType::Rsa, Some(1024), "rsa@test").unwrap_err();
        assert!(err.to_string().contains("out of range"));
    }

    #[test]
    fn rsa_size_upper_bound_is_enforced() {
        let err = generate(KeyType::Rsa, Some(MAX_RSA_BITS + 1), "rsa@test").unwrap_err();
        assert!(err.to_string().contains("out of range"));
    }
}
434}