// pq_jwt/keygen/mod.rs

1mod builder;
2mod generate;
3
4pub use builder::Builder;
5pub use generate::generate_keypair;
6
7use crate::algorithm::MlDsaAlgo;
8
/// Indicates where a keypair came from when using `load_or_generate`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum KeySource {
    /// An existing key was successfully loaded from a file or string.
    Loaded,
    /// A new key was generated because the file was missing or corrupt.
    Generated,
}
17use ml_dsa::{KeyGen as MlDsaKeyGen, KeyPair, MlDsa44, MlDsa65, MlDsa87};
18use std::fs;
19use std::io::Write;
20use std::path::PathBuf;
21
22/// A key generator that can generate and optionally save keypairs
23///
24/// # Example
25/// ```no_run
26/// use pq_jwt::keygen::Builder;
27/// use pq_jwt::MlDsaAlgo;
28///
29/// // Generate and save to default location (keys/)
30/// let (priv_key, pub_key) = Builder::new()
31///     .algorithm(MlDsaAlgo::Dsa65)
32///     .save_to_file()
33///     .generate()
34///     .unwrap();
35/// ```
36#[derive(Debug)]
37pub struct KeyGenerator {
38    algo: MlDsaAlgo,
39    save_path: Option<PathBuf>,
40}
41
42impl KeyGenerator {
43    /// Creates a new KeyGenerator with the specified configuration
44    ///
45    /// # Arguments
46    /// * `algo` - The ML-DSA algorithm variant
47    /// * `save_path` - Optional path to save keys
48    pub(crate) fn new(algo: MlDsaAlgo, save_path: Option<PathBuf>) -> Self {
49        Self { algo, save_path }
50    }
51
52    /// Generates a keypair and optionally saves to file
53    ///
54    /// # Returns
55    /// * `Ok((private_key_hex, public_key_hex))` - Hex-encoded keys
56    /// * `Err(String)` - Error message if generation or save fails
57    ///
58    /// # Example
59    /// ```no_run
60    /// use pq_jwt::keygen::Builder;
61    /// use pq_jwt::MlDsaAlgo;
62    ///
63    /// let generator = Builder::new()
64    ///     .algorithm(MlDsaAlgo::Dsa65)
65    ///     .save_to_file()
66    ///     .build()
67    ///     .unwrap();
68    ///
69    /// let (priv_key, pub_key) = generator.generate().unwrap();
70    /// ```
71    pub fn generate(&self) -> Result<(String, String), String> {
72        // Generate keypair based on algorithm
73        let (private_key_hex, public_key_hex) = match self.algo {
74            MlDsaAlgo::Dsa44 => self.generate_impl::<MlDsa44>()?,
75            MlDsaAlgo::Dsa65 => self.generate_impl::<MlDsa65>()?,
76            MlDsaAlgo::Dsa87 => self.generate_impl::<MlDsa87>()?,
77        };
78
79        // Save to file if path is specified
80        if let Some(path) = &self.save_path {
81            self.save_keys_to_file(path, &private_key_hex, &public_key_hex)?;
82        }
83
84        Ok((private_key_hex, public_key_hex))
85    }
86
87    fn generate_impl<P>(&self) -> Result<(String, String), String>
88    where
89        P: MlDsaKeyGen<KeyPair = KeyPair<P>>,
90    {
91        let mut rng = rand::rng();
92        let kp = P::key_gen(&mut rng);
93
94        // Extract and encode keys
95        let signing_key_encoded = kp.signing_key().encode();
96        let verifying_key_encoded = kp.verifying_key().encode();
97
98        Ok((
99            hex::encode(&signing_key_encoded[..]),
100            hex::encode(&verifying_key_encoded[..]),
101        ))
102    }
103
104    fn save_keys_to_file(
105        &self,
106        path: &PathBuf,
107        private_key: &str,
108        public_key: &str,
109    ) -> Result<(), String> {
110        // Create directory if it doesn't exist
111        fs::create_dir_all(path)
112            .map_err(|e| format!("Failed to create directory {}: {}", path.display(), e))?;
113
114        // Generate filenames based on algorithm
115        let algo_str = self.algo.as_str().to_lowercase().replace("-", "_");
116        let timestamp = std::time::SystemTime::now()
117            .duration_since(std::time::UNIX_EPOCH)
118            .map_err(|e| format!("Failed to get timestamp: {}", e))?
119            .as_secs();
120
121        let private_key_file = path.join(format!("{}_{}_private.key", algo_str, timestamp));
122        let public_key_file = path.join(format!("{}_{}_public.key", algo_str, timestamp));
123
124        // Save private key
125        let mut priv_file = fs::File::create(&private_key_file)
126            .map_err(|e| format!("Failed to create private key file: {}", e))?;
127        priv_file
128            .write_all(private_key.as_bytes())
129            .map_err(|e| format!("Failed to write private key: {}", e))?;
130
131        // Save public key
132        let mut pub_file = fs::File::create(&public_key_file)
133            .map_err(|e| format!("Failed to create public key file: {}", e))?;
134        pub_file
135            .write_all(public_key.as_bytes())
136            .map_err(|e| format!("Failed to write public key: {}", e))?;
137
138        Ok(())
139    }
140
141    /// Returns the algorithm being used by this generator
142    pub fn algorithm(&self) -> MlDsaAlgo {
143        self.algo
144    }
145
146    /// Returns the save path if set
147    pub fn save_path(&self) -> Option<&PathBuf> {
148        self.save_path.as_ref()
149    }
150}
151
#[cfg(test)]
mod tests {
    use super::*;
    use std::fs;

    /// Scratch directory under the system temp dir that is removed on drop,
    /// so tests clean up even when an assertion fails part-way through.
    struct TempDir(PathBuf);

    impl TempDir {
        /// Creates a path unique to this test label and process.
        ///
        /// Concurrent test runs therefore do not collide, and nothing is
        /// written into the crate's working directory.
        fn new(label: &str) -> Self {
            let dir = std::env::temp_dir().join(format!(
                "pq_jwt_keygen_{}_{}",
                label,
                std::process::id()
            ));
            // Start from a clean slate in case a previous run was killed.
            let _ = fs::remove_dir_all(&dir);
            Self(dir)
        }
    }

    impl Drop for TempDir {
        fn drop(&mut self) {
            let _ = fs::remove_dir_all(&self.0);
        }
    }

    /// Returns the entries of `dir` whose paths end with `suffix`.
    fn files_ending_with(dir: &std::path::Path, suffix: &str) -> Vec<PathBuf> {
        fs::read_dir(dir)
            .expect("failed to read key directory")
            .map(|entry| entry.expect("failed to read directory entry").path())
            .filter(|p| p.to_string_lossy().ends_with(suffix))
            .collect()
    }

    #[test]
    fn test_keygen_basic() {
        let generator = KeyGenerator::new(MlDsaAlgo::Dsa65, None);
        let (priv_key, pub_key) = generator.generate().expect("key generation failed");

        assert!(!priv_key.is_empty());
        assert!(!pub_key.is_empty());
        assert!(hex::decode(&priv_key).is_ok(), "private key is not valid hex");
        assert!(hex::decode(&pub_key).is_ok(), "public key is not valid hex");
    }

    #[test]
    fn test_keygen_all_algorithms() {
        for algo in [MlDsaAlgo::Dsa44, MlDsaAlgo::Dsa65, MlDsaAlgo::Dsa87] {
            let generator = KeyGenerator::new(algo, None);
            if let Err(e) = generator.generate() {
                panic!("key generation failed for {:?}: {}", algo, e);
            }
        }
    }

    #[test]
    fn test_keygen_getters() {
        let path = PathBuf::from("test/keys");
        let generator = KeyGenerator::new(MlDsaAlgo::Dsa87, Some(path.clone()));

        assert_eq!(generator.algorithm(), MlDsaAlgo::Dsa87);
        assert_eq!(generator.save_path(), Some(&path));
    }

    #[test]
    fn test_keygen_save_to_file() {
        let tmp = TempDir::new("save_to_file");
        let generator = KeyGenerator::new(MlDsaAlgo::Dsa65, Some(tmp.0.clone()));

        let (priv_key, pub_key) = generator.generate().expect("key generation failed");

        assert!(tmp.0.is_dir(), "save directory was not created");

        let private_files = files_ending_with(&tmp.0, "_private.key");
        let public_files = files_ending_with(&tmp.0, "_public.key");
        assert_eq!(private_files.len(), 1, "expected exactly one private key file");
        assert_eq!(public_files.len(), 1, "expected exactly one public key file");

        // The files must hold exactly the hex strings returned to the caller.
        assert_eq!(fs::read_to_string(&private_files[0]).unwrap(), priv_key);
        assert_eq!(fs::read_to_string(&public_files[0]).unwrap(), pub_key);
    }
}