polysig_protocol/
keypair.rs1use crate::{
3 constants::{PATTERN, PEM_PATTERN, PEM_PRIVATE, PEM_PUBLIC},
4 snow::params::NoiseParams,
5 Error, Result,
6};
7use pem::Pem;
8use serde::{
9 de::{self, Deserializer, Visitor},
10 ser::Serializer,
11 Deserialize, Serialize,
12};
13use std::fmt;
14
15pub struct Keypair {
17 inner: snow::Keypair,
18}
19
20impl Keypair {
21 pub fn new(params: NoiseParams) -> Result<Self> {
23 let builder = snow::Builder::new(params);
24 Ok(Self {
25 inner: builder.generate_keypair()?,
26 })
27 }
28
29 pub fn public_key(&self) -> &[u8] {
31 &self.inner.public
32 }
33
34 pub fn private_key(&self) -> &[u8] {
36 &self.inner.private
37 }
38}
39
40impl Clone for Keypair {
41 fn clone(&self) -> Self {
42 Keypair {
43 inner: snow::Keypair {
44 public: self.inner.public.clone(),
45 private: self.inner.private.clone(),
46 },
47 }
48 }
49}
50
51impl Serialize for Keypair {
52 fn serialize<S>(
53 &self,
54 serializer: S,
55 ) -> std::result::Result<S::Ok, S::Error>
56 where
57 S: Serializer,
58 {
59 let encoded = encode_keypair(self);
60 serializer.serialize_str(&encoded)
61 }
62}
63
64impl<'de> Deserialize<'de> for Keypair {
65 fn deserialize<D>(
66 deserializer: D,
67 ) -> std::result::Result<Keypair, D::Error>
68 where
69 D: Deserializer<'de>,
70 {
71 deserializer.deserialize_str(KeypairVisitor)
72 }
73}
74
/// Serde visitor that turns a PEM encoded string into a `Keypair`.
struct KeypairVisitor;
76
77impl<'de> Visitor<'de> for KeypairVisitor {
78 type Value = Keypair;
79
80 fn expecting(
81 &self,
82 formatter: &mut fmt::Formatter,
83 ) -> fmt::Result {
84 formatter.write_str("PEM encoded keypair")
85 }
86
87 fn visit_str<E>(
88 self,
89 value: &str,
90 ) -> std::result::Result<Self::Value, E>
91 where
92 E: de::Error,
93 {
94 let decoded = decode_keypair(value.as_bytes())
95 .map_err(de::Error::custom)?;
96 Ok(decoded)
97 }
98}
99
100pub fn generate_keypair() -> Result<Keypair> {
103 Keypair::new(PATTERN.parse()?)
104}
105
106pub fn encode_keypair(keypair: &Keypair) -> String {
108 let pattern_pem = Pem::new(PEM_PATTERN, PATTERN.as_bytes());
109 let public_pem =
110 Pem::new(PEM_PUBLIC, keypair.public_key().to_vec());
111 let private_pem =
112 Pem::new(PEM_PRIVATE, keypair.private_key().to_vec());
113 pem::encode_many(&[pattern_pem, public_pem, private_pem])
114}
115
116pub fn decode_keypair(keypair: impl AsRef<[u8]>) -> Result<Keypair> {
118 let mut pems = pem::parse_many(keypair)?;
119 if pems.len() == 3 {
120 let (first, second, third) =
121 (pems.remove(0), pems.remove(0), pems.remove(0));
122 if (PEM_PATTERN, PEM_PUBLIC, PEM_PRIVATE)
123 == (first.tag(), second.tag(), third.tag())
124 {
125 if first.into_contents() != PATTERN.as_bytes() {
126 return Err(Error::PatternMismatch(
127 PATTERN.to_string(),
128 ));
129 }
130
131 Ok(Keypair {
132 inner: snow::Keypair {
133 public: second.into_contents(),
134 private: third.into_contents(),
135 },
136 })
137 } else {
138 Err(Error::BadKeypairPem)
139 }
140 } else {
141 Err(Error::BadKeypairPem)
142 }
143}
144
#[cfg(test)]
mod tests {
    use super::{decode_keypair, encode_keypair, generate_keypair};
    use crate::{
        Error, PATTERN, PEM_PATTERN, PEM_PRIVATE, PEM_PUBLIC, TAGLEN,
    };
    use anyhow::Result;
    use pem::Pem;

    /// Encoding then decoding yields the same key material.
    #[test]
    fn encode_decode_keypair() -> Result<()> {
        let original = generate_keypair()?;
        let encoded = encode_keypair(&original);
        let restored = decode_keypair(&encoded)?;
        assert_eq!(original.public_key(), restored.public_key());
        assert_eq!(original.private_key(), restored.private_key());
        Ok(())
    }

    /// A document with a single block is rejected.
    #[test]
    fn decode_keypair_wrong_length() -> Result<()> {
        let blocks = [Pem::new("INVALID TAG", vec![0; 32])];
        let encoded = pem::encode_many(&blocks);
        let outcome = decode_keypair(&encoded);
        assert!(matches!(outcome, Err(Error::BadKeypairPem)));
        Ok(())
    }

    /// Public and private blocks swapped: rejected as a bad PEM.
    #[test]
    fn decode_keypair_wrong_order() -> Result<()> {
        let blocks = [
            Pem::new(PEM_PATTERN, vec![0; 32]),
            Pem::new(PEM_PRIVATE, vec![0; 32]),
            Pem::new(PEM_PUBLIC, vec![0; 32]),
        ];
        let encoded = pem::encode_many(&blocks);
        let outcome = decode_keypair(&encoded);
        assert!(matches!(outcome, Err(Error::BadKeypairPem)));
        Ok(())
    }

    /// Correct tags and order, but the pattern block holds other data.
    #[test]
    fn decode_keypair_pattern_mismatch() -> Result<()> {
        let blocks = [
            Pem::new(PEM_PATTERN, vec![0; 32]),
            Pem::new(PEM_PUBLIC, vec![0; 32]),
            Pem::new(PEM_PRIVATE, vec![0; 32]),
        ];
        let encoded = pem::encode_many(&blocks);
        let outcome = decode_keypair(&encoded);
        assert!(matches!(outcome, Err(Error::PatternMismatch(_))));
        Ok(())
    }

    /// Run a handshake between two peers, then send one encrypted
    /// message from the initiator to the responder.
    #[test]
    fn noise_transport_encrypt_decrypt() -> Result<()> {
        let builder_1 = snow::Builder::new(PATTERN.parse()?);
        let builder_2 = snow::Builder::new(PATTERN.parse()?);

        let keypair1 = builder_1.generate_keypair()?;
        let keypair2 = builder_2.generate_keypair()?;

        let mut initiator = builder_1
            .local_private_key(&keypair1.private)
            .remote_public_key(&keypair2.public)
            .build_initiator()?;

        let mut responder = builder_2
            .local_private_key(&keypair2.private)
            .remote_public_key(&keypair1.public)
            .build_responder()?;

        let mut read_buf = [0u8; 1024];
        let mut first_msg = [0u8; 1024];
        let mut second_msg = [0u8; 1024];

        // Handshake: initiator -> responder, then responder -> initiator.
        let sent = initiator.write_message(&[], &mut first_msg)?;
        responder.read_message(&first_msg[..sent], &mut read_buf)?;

        let sent = responder.write_message(&[], &mut second_msg)?;
        initiator.read_message(&second_msg[..sent], &mut read_buf)?;

        let mut initiator = initiator.into_transport_mode()?;
        let mut responder = responder.into_transport_mode()?;

        let data = "this is the message that is sent out";
        let plaintext = data.as_bytes();

        // The output buffer is sized as payload plus TAGLEN bytes.
        let mut ciphertext = vec![0; plaintext.len() + TAGLEN];
        let len = initiator.write_message(plaintext, &mut ciphertext)?;

        let mut received = vec![0; len];
        responder.read_message(&ciphertext[..len], &mut received)?;

        // Drop the trailing TAGLEN bytes before comparing.
        received.truncate(len - TAGLEN);

        let decoded = std::str::from_utf8(&received)?;
        assert_eq!(data, decoded);

        Ok(())
    }
}