1mod builder;
2mod generate;
3
4pub use builder::Builder;
5pub use generate::generate_keypair;
6
7use crate::algorithm::MlDsaAlgo;
8
/// Records how a key pair was obtained.
///
/// NOTE(review): the per-variant meanings below are inferred from the variant
/// names only — confirm against the `builder` module that produces them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KeySource {
    /// Presumably: the key was loaded from existing key material.
    Loaded,
    /// Presumably: the key was freshly generated (e.g. by `KeyGenerator`).
    Generated,
}
17use ml_dsa::{KeyGen as MlDsaKeyGen, KeyPair, MlDsa44, MlDsa65, MlDsa87};
18use std::fs;
19use std::io::Write;
20use std::path::PathBuf;
21
22#[derive(Debug)]
37pub struct KeyGenerator {
38 algo: MlDsaAlgo,
39 save_path: Option<PathBuf>,
40}
41
42impl KeyGenerator {
43 pub(crate) fn new(algo: MlDsaAlgo, save_path: Option<PathBuf>) -> Self {
49 Self { algo, save_path }
50 }
51
52 pub fn generate(&self) -> Result<(String, String), String> {
72 let (private_key_hex, public_key_hex) = match self.algo {
74 MlDsaAlgo::Dsa44 => self.generate_impl::<MlDsa44>()?,
75 MlDsaAlgo::Dsa65 => self.generate_impl::<MlDsa65>()?,
76 MlDsaAlgo::Dsa87 => self.generate_impl::<MlDsa87>()?,
77 };
78
79 if let Some(path) = &self.save_path {
81 self.save_keys_to_file(path, &private_key_hex, &public_key_hex)?;
82 }
83
84 Ok((private_key_hex, public_key_hex))
85 }
86
87 fn generate_impl<P>(&self) -> Result<(String, String), String>
88 where
89 P: MlDsaKeyGen<KeyPair = KeyPair<P>>,
90 {
91 let mut rng = rand::rng();
92 let kp = P::key_gen(&mut rng);
93
94 let signing_key_encoded = kp.signing_key().encode();
96 let verifying_key_encoded = kp.verifying_key().encode();
97
98 Ok((
99 hex::encode(&signing_key_encoded[..]),
100 hex::encode(&verifying_key_encoded[..]),
101 ))
102 }
103
104 fn save_keys_to_file(
105 &self,
106 path: &PathBuf,
107 private_key: &str,
108 public_key: &str,
109 ) -> Result<(), String> {
110 fs::create_dir_all(path)
112 .map_err(|e| format!("Failed to create directory {}: {}", path.display(), e))?;
113
114 let algo_str = self.algo.as_str().to_lowercase().replace("-", "_");
116 let timestamp = std::time::SystemTime::now()
117 .duration_since(std::time::UNIX_EPOCH)
118 .map_err(|e| format!("Failed to get timestamp: {}", e))?
119 .as_secs();
120
121 let private_key_file = path.join(format!("{}_{}_private.key", algo_str, timestamp));
122 let public_key_file = path.join(format!("{}_{}_public.key", algo_str, timestamp));
123
124 let mut priv_file = fs::File::create(&private_key_file)
126 .map_err(|e| format!("Failed to create private key file: {}", e))?;
127 priv_file
128 .write_all(private_key.as_bytes())
129 .map_err(|e| format!("Failed to write private key: {}", e))?;
130
131 let mut pub_file = fs::File::create(&public_key_file)
133 .map_err(|e| format!("Failed to create public key file: {}", e))?;
134 pub_file
135 .write_all(public_key.as_bytes())
136 .map_err(|e| format!("Failed to write public key: {}", e))?;
137
138 Ok(())
139 }
140
141 pub fn algorithm(&self) -> MlDsaAlgo {
143 self.algo
144 }
145
146 pub fn save_path(&self) -> Option<&PathBuf> {
148 self.save_path.as_ref()
149 }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Owns a scratch directory under the system temp dir and removes it on
    /// drop, so tests clean up even when an assertion fails part-way through.
    struct TempDir(PathBuf);

    impl TempDir {
        /// Returns a guard for a not-yet-created directory whose name is
        /// unique per process and per `tag`.
        fn new(tag: &str) -> Self {
            let dir = std::env::temp_dir().join(format!(
                "mldsa_keygen_{}_{}",
                tag,
                std::process::id()
            ));
            // Start from a clean slate in case a previous run left it behind.
            let _ = fs::remove_dir_all(&dir);
            TempDir(dir)
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    /// True when `s` is a non-empty, even-length string of ASCII hex digits.
    fn is_hex(s: &str) -> bool {
        !s.is_empty() && s.len() % 2 == 0 && s.bytes().all(|b| b.is_ascii_hexdigit())
    }

    #[test]
    fn test_keygen_basic() {
        let generator = KeyGenerator::new(MlDsaAlgo::Dsa65, None);
        let (priv_key, pub_key) = generator.generate().expect("key generation failed");

        assert!(is_hex(&priv_key), "private key is not hex");
        assert!(is_hex(&pub_key), "public key is not hex");
        assert_ne!(priv_key, pub_key);
    }

    #[test]
    fn test_keygen_all_algorithms() {
        for algo in [MlDsaAlgo::Dsa44, MlDsaAlgo::Dsa65, MlDsaAlgo::Dsa87] {
            let result = KeyGenerator::new(algo, None).generate();
            assert!(result.is_ok(), "{:?} failed: {:?}", algo, result.err());
        }
    }

    #[test]
    fn test_keygen_produces_fresh_keys() {
        let generator = KeyGenerator::new(MlDsaAlgo::Dsa44, None);
        let first = generator.generate().expect("first generation failed");
        let second = generator.generate().expect("second generation failed");
        assert_ne!(first.0, second.0);
        assert_ne!(first.1, second.1);
    }

    #[test]
    fn test_keygen_getters() {
        let path = PathBuf::from("test/keys");
        let generator = KeyGenerator::new(MlDsaAlgo::Dsa87, Some(path.clone()));

        assert_eq!(generator.algorithm(), MlDsaAlgo::Dsa87);
        assert_eq!(generator.save_path(), Some(&path));
    }

    #[test]
    fn test_keygen_save_to_file() {
        let test_dir = TempDir::new("save");
        let generator = KeyGenerator::new(MlDsaAlgo::Dsa65, Some(test_dir.0.clone()));

        let (priv_key, pub_key) = generator
            .generate()
            .expect("generate with save path failed");

        assert!(test_dir.0.is_dir());
        let saved: Vec<PathBuf> = fs::read_dir(&test_dir.0)
            .expect("failed to read test dir")
            .map(|entry| entry.expect("failed to read dir entry").path())
            .collect();
        assert_eq!(saved.len(), 2, "expected one private and one public key file");

        let find = |suffix: &str| -> PathBuf {
            saved
                .iter()
                .find(|p| p.to_string_lossy().ends_with(suffix))
                .cloned()
                .unwrap_or_else(|| panic!("no *{} file in {:?}", suffix, saved))
        };
        assert_eq!(
            fs::read_to_string(find("_private.key")).expect("read private key"),
            priv_key
        );
        assert_eq!(
            fs::read_to_string(find("_public.key")).expect("read public key"),
            pub_key
        );
    }
}