polysig_protocol/
keypair.rs

1//! Helper functions for working with static keys.
2use crate::{
3    constants::{PATTERN, PEM_PATTERN, PEM_PRIVATE, PEM_PUBLIC},
4    snow::params::NoiseParams,
5    Error, Result,
6};
7use pem::Pem;
8use serde::{
9    de::{self, Deserializer, Visitor},
10    ser::Serializer,
11    Deserialize, Serialize,
12};
13use std::fmt;
14
15/// Key pair used by the noise protocol.
16pub struct Keypair {
17    inner: snow::Keypair,
18}
19
20impl Keypair {
21    /// Generate a new keypair.
22    pub fn new(params: NoiseParams) -> Result<Self> {
23        let builder = snow::Builder::new(params);
24        Ok(Self {
25            inner: builder.generate_keypair()?,
26        })
27    }
28
29    /// Public key.
30    pub fn public_key(&self) -> &[u8] {
31        &self.inner.public
32    }
33
34    /// Private key.
35    pub fn private_key(&self) -> &[u8] {
36        &self.inner.private
37    }
38}
39
40impl Clone for Keypair {
41    fn clone(&self) -> Self {
42        Keypair {
43            inner: snow::Keypair {
44                public: self.inner.public.clone(),
45                private: self.inner.private.clone(),
46            },
47        }
48    }
49}
50
51impl Serialize for Keypair {
52    fn serialize<S>(
53        &self,
54        serializer: S,
55    ) -> std::result::Result<S::Ok, S::Error>
56    where
57        S: Serializer,
58    {
59        let encoded = encode_keypair(self);
60        serializer.serialize_str(&encoded)
61    }
62}
63
64impl<'de> Deserialize<'de> for Keypair {
65    fn deserialize<D>(
66        deserializer: D,
67    ) -> std::result::Result<Keypair, D::Error>
68    where
69        D: Deserializer<'de>,
70    {
71        deserializer.deserialize_str(KeypairVisitor)
72    }
73}
74
75struct KeypairVisitor;
76
77impl<'de> Visitor<'de> for KeypairVisitor {
78    type Value = Keypair;
79
80    fn expecting(
81        &self,
82        formatter: &mut fmt::Formatter,
83    ) -> fmt::Result {
84        formatter.write_str("PEM encoded keypair")
85    }
86
87    fn visit_str<E>(
88        self,
89        value: &str,
90    ) -> std::result::Result<Self::Value, E>
91    where
92        E: de::Error,
93    {
94        let decoded = decode_keypair(value.as_bytes())
95            .map_err(de::Error::custom)?;
96        Ok(decoded)
97    }
98}
99
100/// Generate a keypair for the noise protocol using the
101/// standard pattern.
102pub fn generate_keypair() -> Result<Keypair> {
103    Keypair::new(PATTERN.parse()?)
104}
105
106/// Encode a keypair into a PEM-encoded string.
107pub fn encode_keypair(keypair: &Keypair) -> String {
108    let pattern_pem = Pem::new(PEM_PATTERN, PATTERN.as_bytes());
109    let public_pem =
110        Pem::new(PEM_PUBLIC, keypair.public_key().to_vec());
111    let private_pem =
112        Pem::new(PEM_PRIVATE, keypair.private_key().to_vec());
113    pem::encode_many(&[pattern_pem, public_pem, private_pem])
114}
115
116/// Decode from a PEM-encoded string into a keypair.
117pub fn decode_keypair(keypair: impl AsRef<[u8]>) -> Result<Keypair> {
118    let mut pems = pem::parse_many(keypair)?;
119    if pems.len() == 3 {
120        let (first, second, third) =
121            (pems.remove(0), pems.remove(0), pems.remove(0));
122        if (PEM_PATTERN, PEM_PUBLIC, PEM_PRIVATE)
123            == (first.tag(), second.tag(), third.tag())
124        {
125            if first.into_contents() != PATTERN.as_bytes() {
126                return Err(Error::PatternMismatch(
127                    PATTERN.to_string(),
128                ));
129            }
130
131            Ok(Keypair {
132                inner: snow::Keypair {
133                    public: second.into_contents(),
134                    private: third.into_contents(),
135                },
136            })
137        } else {
138            Err(Error::BadKeypairPem)
139        }
140    } else {
141        Err(Error::BadKeypairPem)
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::{decode_keypair, encode_keypair, generate_keypair};
    use crate::{
        Error, PATTERN, PEM_PATTERN, PEM_PRIVATE, PEM_PUBLIC, TAGLEN,
    };
    use anyhow::Result;
    use pem::Pem;

    #[test]
    fn encode_decode_keypair() -> Result<()> {
        let keypair = generate_keypair()?;
        let pem = encode_keypair(&keypair);
        let decoded = decode_keypair(&pem)?;
        assert_eq!(keypair.public_key(), decoded.public_key());
        assert_eq!(keypair.private_key(), decoded.private_key());
        Ok(())
    }

    #[test]
    fn decode_keypair_wrong_length() -> Result<()> {
        let public_pem = Pem::new("INVALID TAG", vec![0; 32]);
        let pem = pem::encode_many(&[public_pem]);
        let result = decode_keypair(&pem);
        assert!(matches!(result, Err(Error::BadKeypairPem)));
        Ok(())
    }

    #[test]
    fn decode_keypair_extra_section() -> Result<()> {
        // A valid keypair followed by a trailing section must be rejected.
        let sections = [
            Pem::new(PEM_PATTERN, PATTERN.as_bytes()),
            Pem::new(PEM_PUBLIC, vec![0; 32]),
            Pem::new(PEM_PRIVATE, vec![0; 32]),
            Pem::new(PEM_PRIVATE, vec![0; 32]),
        ];
        let pem = pem::encode_many(&sections);
        let result = decode_keypair(&pem);
        assert!(matches!(result, Err(Error::BadKeypairPem)));
        Ok(())
    }

    #[test]
    fn decode_keypair_wrong_order() -> Result<()> {
        let pattern_pem = Pem::new(PEM_PATTERN, vec![0; 32]);
        let public_pem = Pem::new(PEM_PUBLIC, vec![0; 32]);
        let private_pem = Pem::new(PEM_PRIVATE, vec![0; 32]);
        let pem =
            pem::encode_many(&[pattern_pem, private_pem, public_pem]);
        let result = decode_keypair(&pem);
        assert!(matches!(result, Err(Error::BadKeypairPem)));
        Ok(())
    }

    #[test]
    fn decode_keypair_pattern_mismatch() -> Result<()> {
        let pattern_pem = Pem::new(PEM_PATTERN, vec![0; 32]);
        let public_pem = Pem::new(PEM_PUBLIC, vec![0; 32]);
        let private_pem = Pem::new(PEM_PRIVATE, vec![0; 32]);
        let pem =
            pem::encode_many(&[pattern_pem, public_pem, private_pem]);
        let result = decode_keypair(&pem);
        assert!(matches!(result, Err(Error::PatternMismatch(_))));
        Ok(())
    }

    #[test]
    fn noise_transport_encrypt_decrypt() -> Result<()> {
        let builder_1 = snow::Builder::new(PATTERN.parse()?);
        let builder_2 = snow::Builder::new(PATTERN.parse()?);

        let keypair1 = builder_1.generate_keypair()?;
        let keypair2 = builder_2.generate_keypair()?;

        let mut initiator = builder_1
            .local_private_key(&keypair1.private)
            .remote_public_key(&keypair2.public)
            .build_initiator()?;

        let mut responder = builder_2
            .local_private_key(&keypair2.private)
            .remote_public_key(&keypair1.public)
            .build_responder()?;

        let (mut read_buf, mut first_msg, mut second_msg) =
            ([0u8; 1024], [0u8; 1024], [0u8; 1024]);

        // -> e
        let len = initiator.write_message(&[], &mut first_msg)?;

        // responder processes the first message...
        responder.read_message(&first_msg[..len], &mut read_buf)?;

        // <- e, ee
        let len = responder.write_message(&[], &mut second_msg)?;

        // initiator processes the response...
        initiator.read_message(&second_msg[..len], &mut read_buf)?;

        // NN handshake complete, transition into transport mode.
        let mut initiator = initiator.into_transport_mode()?;
        let mut responder = responder.into_transport_mode()?;

        let data = "this is the message that is sent out";
        let payload = data.as_bytes();

        // The ciphertext is the payload followed by the AEAD tag.
        let mut ciphertext = vec![0; payload.len() + TAGLEN];
        let len = initiator.write_message(payload, &mut ciphertext)?;
        assert_eq!(payload.len() + TAGLEN, len);

        // Use the length reported by snow instead of re-deriving it, and
        // check it agrees with the expected tag overhead.
        let mut plaintext = vec![0; len];
        let read_len =
            responder.read_message(&ciphertext[..len], &mut plaintext)?;
        assert_eq!(len - TAGLEN, read_len);
        plaintext.truncate(read_len);

        let decoded = std::str::from_utf8(&plaintext)?;
        assert_eq!(data, decoded);

        Ok(())
    }
}