use super::GeneratedKey;
use anyhow::{Context, Result, bail};
use russh::keys::{Algorithm, HashAlg, PrivateKey};
use ssh_key::LineEnding;
use std::io::Write;
use std::path::Path;
/// Smallest RSA modulus size, in bits, that `generate` accepts.
const MIN_RSA_BITS: u32 = 2_048;
/// Largest RSA modulus size, in bits, that `generate` accepts.
const MAX_RSA_BITS: u32 = 16_384;
pub fn generate(output_path: &Path, bits: u32, comment: Option<&str>) -> Result<GeneratedKey> {
if bits < MIN_RSA_BITS {
bail!(
"RSA key size must be at least {} bits for security. Got: {}",
MIN_RSA_BITS,
bits
);
}
if bits > MAX_RSA_BITS {
bail!(
"RSA key size must not exceed {} bits. Got: {}",
MAX_RSA_BITS,
bits
);
}
tracing::info!(bits = bits, "Generating RSA key pair");
let keypair = PrivateKey::random(
&mut rand::rng(),
Algorithm::Rsa {
hash: Some(HashAlg::Sha256),
},
)
.context("Failed to generate RSA key")?;
let public_key = keypair.public_key();
let fingerprint = format!("{}", public_key.fingerprint(HashAlg::Sha256));
let private_key_pem = keypair
.to_openssh(LineEnding::LF)
.context("Failed to encode private key to OpenSSH format")?;
let comment_str = comment.unwrap_or("bssh-keygen");
let public_key_base64 = public_key
.to_openssh()
.context("Failed to encode public key to OpenSSH format")?;
let public_key_openssh = format!("{} {}", public_key_base64, comment_str);
write_private_key(output_path, &private_key_pem)?;
let pub_path = format!("{}.pub", output_path.display());
std::fs::write(&pub_path, format!("{}\n", public_key_openssh))
.with_context(|| format!("Failed to write public key to {}", pub_path))?;
tracing::info!(
path = %output_path.display(),
bits = bits,
fingerprint = %fingerprint,
"Generated RSA key"
);
Ok(GeneratedKey {
private_key_pem: private_key_pem.to_string(),
public_key_openssh,
fingerprint,
key_type: format!("rsa-{}", bits),
})
}
/// Writes `content` (an encoded private key) to `path`, creating or
/// truncating the file.
///
/// On Unix the file is opened with mode `0o600`, and that mode is then applied
/// explicitly to the open handle before any secret bytes are written.
/// `OpenOptionsExt::mode` only takes effect when the file is newly created, so
/// without the explicit `set_permissions` an existing file with looser
/// permissions (e.g. `0o644`) would keep them. The explicit call also makes the
/// result independent of the process umask.
///
/// On non-Unix platforms the file is written with `std::fs::write` and default
/// permissions.
///
/// # Errors
/// Returns an error if the file cannot be opened, its permissions cannot be
/// restricted (Unix), or the content cannot be written.
fn write_private_key(path: &Path, content: &str) -> Result<()> {
    #[cfg(unix)]
    {
        use std::fs::{OpenOptions, Permissions};
        use std::os::unix::fs::{OpenOptionsExt, PermissionsExt};

        let mut file = OpenOptions::new()
            .write(true)
            .create(true)
            .truncate(true)
            .mode(0o600)
            .open(path)
            .with_context(|| format!("Failed to create private key file: {}", path.display()))?;

        // Tighten permissions before writing: covers pre-existing files and umask.
        file.set_permissions(Permissions::from_mode(0o600))
            .with_context(|| {
                format!(
                    "Failed to restrict permissions on private key file: {}",
                    path.display()
                )
            })?;

        file.write_all(content.as_bytes())
            .with_context(|| format!("Failed to write private key: {}", path.display()))?;
    }
    #[cfg(not(unix))]
    {
        std::fs::write(path, content)
            .with_context(|| format!("Failed to write private key: {}", path.display()))?;
    }
    Ok(())
}
#[cfg(test)]
mod tests {
    // Successful calls use `expect(..)` rather than `assert!(result.is_ok())`
    // so a failing test reports the underlying error instead of just `false`.
    use super::*;
    use std::fs;
    use tempfile::tempdir;

    #[test]
    fn test_generate_rsa_2048() {
        let temp_dir = tempdir().unwrap();
        let key_path = temp_dir.path().join("id_rsa");
        let key = generate(&key_path, 2048, Some("test@example.com"))
            .expect("2048-bit RSA key generation should succeed");
        assert!(
            key.private_key_pem
                .contains("-----BEGIN OPENSSH PRIVATE KEY-----")
        );
        assert!(
            key.private_key_pem
                .contains("-----END OPENSSH PRIVATE KEY-----")
        );
        assert!(key.public_key_openssh.starts_with("ssh-rsa "));
        assert!(key.public_key_openssh.ends_with("test@example.com"));
        assert!(key.fingerprint.starts_with("SHA256:"));
        assert_eq!(key.key_type, "rsa-2048");
    }

    #[test]
    fn test_generate_rsa_4096() {
        let temp_dir = tempdir().unwrap();
        let key_path = temp_dir.path().join("id_rsa");
        let key = generate(&key_path, 4096, None)
            .expect("4096-bit RSA key generation should succeed");
        assert_eq!(key.key_type, "rsa-4096");
    }

    #[test]
    fn test_reject_small_key_size() {
        let temp_dir = tempdir().unwrap();
        let key_path = temp_dir.path().join("id_rsa");
        let result = generate(&key_path, 1024, None);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("2048"));
        assert!(err.contains("1024"));
    }

    #[test]
    fn test_reject_huge_key_size() {
        let temp_dir = tempdir().unwrap();
        let key_path = temp_dir.path().join("id_rsa");
        let result = generate(&key_path, 32768, None);
        assert!(result.is_err());
        let err = result.unwrap_err().to_string();
        assert!(err.contains("16384"));
        assert!(err.contains("32768"));
    }

    #[test]
    fn test_files_created() {
        let temp_dir = tempdir().unwrap();
        let key_path = temp_dir.path().join("id_rsa");
        generate(&key_path, 2048, None).expect("RSA key generation should succeed");
        assert!(key_path.exists());
        let pub_path = temp_dir.path().join("id_rsa.pub");
        assert!(pub_path.exists());
        let pub_content = fs::read_to_string(&pub_path).unwrap();
        assert!(pub_content.ends_with('\n'));
    }

    #[test]
    fn test_default_comment() {
        let temp_dir = tempdir().unwrap();
        let key_path = temp_dir.path().join("id_rsa");
        let key = generate(&key_path, 2048, None).expect("RSA key generation should succeed");
        assert!(key.public_key_openssh.ends_with("bssh-keygen"));
    }

    #[test]
    #[cfg(unix)]
    fn test_private_key_permissions() {
        use std::os::unix::fs::PermissionsExt;
        let temp_dir = tempdir().unwrap();
        let key_path = temp_dir.path().join("id_rsa");
        generate(&key_path, 2048, None).expect("RSA key generation should succeed");
        let metadata = fs::metadata(&key_path).unwrap();
        let permissions = metadata.permissions();
        assert_eq!(permissions.mode() & 0o777, 0o600);
    }

    #[test]
    fn test_unique_keys() {
        let temp_dir = tempdir().unwrap();
        let key_path1 = temp_dir.path().join("id_rsa_1");
        let key_path2 = temp_dir.path().join("id_rsa_2");
        let result1 = generate(&key_path1, 2048, None).expect("first key generation should succeed");
        let result2 = generate(&key_path2, 2048, None).expect("second key generation should succeed");
        assert_ne!(result1.private_key_pem, result2.private_key_pem);
        assert_ne!(result1.public_key_openssh, result2.public_key_openssh);
        assert_ne!(result1.fingerprint, result2.fingerprint);
    }

    #[test]
    fn test_key_can_be_read_back() {
        let temp_dir = tempdir().unwrap();
        let key_path = temp_dir.path().join("id_rsa");
        let result = generate(&key_path, 2048, Some("test")).expect("RSA key generation should succeed");
        let private_key_content = fs::read_to_string(&key_path).unwrap();
        assert_eq!(private_key_content, result.private_key_pem);
        let pub_path = temp_dir.path().join("id_rsa.pub");
        let public_key_content = fs::read_to_string(&pub_path).unwrap();
        assert_eq!(public_key_content.trim(), result.public_key_openssh);
    }

    #[test]
    fn test_boundary_key_sizes() {
        let temp_dir = tempdir().unwrap();

        let key_path = temp_dir.path().join("id_rsa_min");
        generate(&key_path, 2048, None).expect("minimum key size should be accepted");

        let key_path = temp_dir.path().join("id_rsa_below_min");
        assert!(generate(&key_path, 2047, None).is_err());

        let key_path = temp_dir.path().join("id_rsa_above_max");
        assert!(generate(&key_path, 16385, None).is_err());
    }
}