1use std::convert::TryFrom;
27
28use aes_kw::KekAes256 as Kek;
29use serde::{Deserialize, Serialize};
30
31use super::keys::{KeyError, PublicKey, SecretKey, PUBLIC_KEY_SIZE};
32use super::secret::{Secret, SecretError, SECRET_SIZE};
33
34pub const KW_NONCE_SIZE: usize = 8;
36pub const SHARE_SIZE: usize = PUBLIC_KEY_SIZE + SECRET_SIZE + KW_NONCE_SIZE;
41
42#[derive(Debug, thiserror::Error)]
44pub enum ShareError {
45 #[error("share error: {0}")]
46 Default(#[from] anyhow::Error),
47 #[error("key error: {0}")]
48 Key(#[from] KeyError),
49 #[error("secret error: {0}")]
50 Secret(#[from] SecretError),
51}
52
53#[derive(Clone, Copy, Debug, Eq, PartialEq, Hash)]
79pub struct Share(pub(crate) [u8; SHARE_SIZE]);
80
81impl Serialize for Share {
82 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
83 where
84 S: serde::Serializer,
85 {
86 serializer.serialize_bytes(&self.0)
87 }
88}
89
90impl<'de> Deserialize<'de> for Share {
91 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
92 where
93 D: serde::Deserializer<'de>,
94 {
95 use serde::de::{Error, Visitor};
96 use std::fmt;
97
98 struct ShareVisitor;
99
100 impl<'de> Visitor<'de> for ShareVisitor {
101 type Value = Share;
102
103 fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
104 formatter.write_str("a byte array or sequence of SHARE_SIZE")
105 }
106
107 fn visit_bytes<E>(self, v: &[u8]) -> Result<Self::Value, E>
108 where
109 E: Error,
110 {
111 if v.len() != SHARE_SIZE {
112 return Err(E::invalid_length(
113 v.len(),
114 &format!("expected {} bytes", SHARE_SIZE).as_str(),
115 ));
116 }
117 let mut array = [0u8; SHARE_SIZE];
118 array.copy_from_slice(v);
119 Ok(Share(array))
120 }
121
122 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
123 where
124 A: serde::de::SeqAccess<'de>,
125 {
126 let mut bytes = Vec::new();
127 while let Some(byte) = seq.next_element::<u8>()? {
128 bytes.push(byte);
129 }
130 if bytes.len() != SHARE_SIZE {
131 return Err(A::Error::invalid_length(
132 bytes.len(),
133 &format!("expected {} bytes", SHARE_SIZE).as_str(),
134 ));
135 }
136 let mut array = [0u8; SHARE_SIZE];
137 array.copy_from_slice(&bytes);
138 Ok(Share(array))
139 }
140 }
141
142 deserializer.deserialize_byte_buf(ShareVisitor)
144 }
145}
146
147impl Default for Share {
148 fn default() -> Self {
149 Share([0; SHARE_SIZE])
150 }
151}
152
153impl From<[u8; SHARE_SIZE]> for Share {
154 fn from(bytes: [u8; SHARE_SIZE]) -> Self {
155 Share(bytes)
156 }
157}
158
159impl From<Share> for [u8; SHARE_SIZE] {
160 fn from(share: Share) -> Self {
161 share.0
162 }
163}
164
165impl TryFrom<&[u8]> for Share {
166 type Error = ShareError;
167 fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
168 if bytes.len() != SHARE_SIZE {
169 return Err(anyhow::anyhow!(
170 "invalid share size, expected {}, got {}",
171 SHARE_SIZE,
172 bytes.len()
173 )
174 .into());
175 }
176 let mut share = Share::default();
177 share.0.copy_from_slice(bytes);
178 Ok(share)
179 }
180}
181
182impl Share {
183 pub fn from_hex(hex: &str) -> Result<Self, ShareError> {
187 let hex = hex.strip_prefix("0x").unwrap_or(hex);
188 let mut buff = [0; SHARE_SIZE];
189 hex::decode_to_slice(hex, &mut buff).map_err(|_| anyhow::anyhow!("hex decode error"))?;
190 Ok(Share::from(buff))
191 }
192
193 #[allow(clippy::wrong_self_convention)]
195 pub fn to_hex(&self) -> String {
196 hex::encode(self.0)
197 }
198
199 pub fn new(secret: &Secret, recipient: &PublicKey) -> Result<Self, ShareError> {
217 let ephemeral_private = SecretKey::generate();
219 let ephemeral_public = ephemeral_private.public();
220
221 let ephemeral_x25519_private = ephemeral_private.to_x25519();
223 let recipient_x25519_public = recipient.to_x25519()?;
224
225 let shared_secret = ephemeral_x25519_private.diffie_hellman(&recipient_x25519_public);
227
228 let mut shared_secret_bytes = [0; SECRET_SIZE];
231 shared_secret_bytes.copy_from_slice(shared_secret.as_bytes());
232 let kek = Kek::from(shared_secret_bytes);
233 let wrapped = kek
234 .wrap_vec(secret.bytes())
235 .map_err(|_| anyhow::anyhow!("AES-KW wrap error"))?;
236
237 let mut share = Share::default();
239 let ephemeral_bytes = ephemeral_public.to_bytes();
240
241 if ephemeral_bytes.len() + wrapped.len() != SHARE_SIZE {
243 return Err(anyhow::anyhow!("expected share size is incorrect").into());
244 };
245
246 share.0[..PUBLIC_KEY_SIZE].copy_from_slice(&ephemeral_bytes);
248 share.0[PUBLIC_KEY_SIZE..PUBLIC_KEY_SIZE + wrapped.len()].copy_from_slice(&wrapped);
249
250 Ok(share)
251 }
252
253 pub fn recover(&self, recipient_secret: &SecretKey) -> Result<Secret, ShareError> {
277 let ephemeral_public_bytes = &self.0[..PUBLIC_KEY_SIZE];
279 let ephemeral_public = PublicKey::try_from(ephemeral_public_bytes)?;
280
281 let recipient_x25519_private = recipient_secret.to_x25519();
283 let ephemeral_x25519_public = ephemeral_public.to_x25519()?;
284
285 let shared_secret = recipient_x25519_private.diffie_hellman(&ephemeral_x25519_public);
287
288 let shared_secret_bytes = *shared_secret.as_bytes();
290 let kek = Kek::from(shared_secret_bytes);
291 let wrapped_data = &self.0[PUBLIC_KEY_SIZE..];
292
293 let unwrapped = kek
295 .unwrap_vec(wrapped_data)
296 .map_err(|_| anyhow::anyhow!("AES-KW unwrap error"))?;
297
298 if unwrapped.len() != SECRET_SIZE {
299 return Err(anyhow::anyhow!("unwrapped secret has wrong size").into());
300 }
301
302 let mut secret_bytes = [0; SECRET_SIZE];
303 secret_bytes.copy_from_slice(&unwrapped);
304 Ok(Secret::from(secret_bytes))
305 }
306
307 pub fn bytes(&self) -> &[u8] {
309 &self.0
310 }
311}
312
#[cfg(test)]
mod test {
    use super::*;

    /// A share made for a key can be opened with that key.
    #[test]
    fn test_share_secret() {
        let secret = Secret::from_slice(&[42u8; SECRET_SIZE]).unwrap();
        let private_key = SecretKey::generate();
        let public_key = private_key.public();
        let share = Share::new(&secret, &public_key).unwrap();
        let recovered_secret = share.recover(&private_key).unwrap();
        assert_eq!(secret, recovered_secret);
    }

    /// Only the intended recipient can open a share.
    #[test]
    fn test_share_different_keys() {
        let secret = Secret::generate();
        let alice_private = SecretKey::generate();
        let alice_public = alice_private.public();
        let bob_private = SecretKey::generate();
        let share = Share::new(&secret, &alice_public).unwrap();

        let recovered_by_alice = share.recover(&alice_private).unwrap();
        assert_eq!(secret, recovered_by_alice);

        let result = share.recover(&bob_private);
        assert!(result.is_err());
    }

    #[test]
    fn test_share_hex_roundtrip() {
        let secret = Secret::generate();
        let private_key = SecretKey::generate();
        let public_key = private_key.public();
        let share = Share::new(&secret, &public_key).unwrap();
        let hex = share.to_hex();
        let recovered_share = Share::from_hex(&hex).unwrap();
        assert_eq!(share, recovered_share);
        let recovered_secret = recovered_share.recover(&private_key).unwrap();
        assert_eq!(secret, recovered_secret);
    }

    /// `from_hex` ignores an optional `0x` prefix and accepts upper-case digits.
    #[test]
    fn test_share_from_hex_prefix_and_case() {
        let share = Share::from([7u8; SHARE_SIZE]);
        let hex = share.to_hex();
        assert_eq!(Share::from_hex(&format!("0x{}", hex)).unwrap(), share);
        assert_eq!(Share::from_hex(&hex.to_uppercase()).unwrap(), share);
    }

    /// `from_hex` rejects wrong lengths and non-hex characters instead of panicking.
    #[test]
    fn test_share_from_hex_rejects_bad_input() {
        assert!(Share::from_hex("").is_err());
        assert!(Share::from_hex("0x").is_err());
        assert!(Share::from_hex(&"00".repeat(SHARE_SIZE - 1)).is_err());
        assert!(Share::from_hex(&"00".repeat(SHARE_SIZE + 1)).is_err());
        assert!(Share::from_hex(&"zz".repeat(SHARE_SIZE)).is_err());
    }

    /// `TryFrom<&[u8]>` accepts exactly `SHARE_SIZE` bytes.
    #[test]
    fn test_share_try_from_slice() {
        let bytes = [3u8; SHARE_SIZE];
        let share = Share::try_from(&bytes[..]).unwrap();
        assert_eq!(share.bytes(), &bytes[..]);
        assert!(Share::try_from(&bytes[..SHARE_SIZE - 1]).is_err());
        assert!(Share::try_from(&[0u8; SHARE_SIZE + 1][..]).is_err());
        assert!(Share::try_from(&[][..]).is_err());
    }

    #[test]
    fn test_share_array_conversions() {
        let bytes = [9u8; SHARE_SIZE];
        let share = Share::from(bytes);
        let back: [u8; SHARE_SIZE] = share.into();
        assert_eq!(back, bytes);
        assert_eq!(Share::default().bytes(), &[0u8; SHARE_SIZE][..]);
    }

    #[test]
    fn test_share_serde_json_roundtrip() {
        let secret = Secret::generate();
        let private_key = SecretKey::generate();
        let public_key = private_key.public();
        let share = Share::new(&secret, &public_key).unwrap();

        let json = serde_json::to_string(&share).unwrap();

        let recovered_share: Share = serde_json::from_str(&json).unwrap();

        assert_eq!(share, recovered_share);

        let recovered_secret = recovered_share.recover(&private_key).unwrap();

        assert_eq!(secret, recovered_secret);
    }

    /// JSON arrays (the `visit_seq` path) of the wrong length are rejected.
    #[test]
    fn test_share_serde_json_invalid_length() {
        let short = serde_json::to_string(&vec![0u8; SHARE_SIZE - 1]).unwrap();
        assert!(serde_json::from_str::<Share>(&short).is_err());

        let long = serde_json::to_string(&vec![0u8; SHARE_SIZE + 1]).unwrap();
        assert!(serde_json::from_str::<Share>(&long).is_err());

        assert!(serde_json::from_str::<Share>("[]").is_err());
    }

    #[test]
    fn test_share_serde_bincode_roundtrip() {
        let secret = Secret::generate();
        let private_key = SecretKey::generate();
        let public_key = private_key.public();
        let share = Share::new(&secret, &public_key).unwrap();

        let binary = bincode::serialize(&share).unwrap();

        let recovered_share: Share = bincode::deserialize(&binary).unwrap();

        assert_eq!(share, recovered_share);

        let recovered_secret = recovered_share.recover(&private_key).unwrap();

        assert_eq!(secret, recovered_secret);
    }

    #[test]
    fn test_share_deserialize_invalid_length() {
        let short_data = vec![0u8; SHARE_SIZE - 1];
        let result: Result<Share, _> =
            bincode::deserialize(&bincode::serialize(&short_data).unwrap());
        assert!(result.is_err());

        let long_data = vec![0u8; SHARE_SIZE + 1];
        let result: Result<Share, _> =
            bincode::deserialize(&bincode::serialize(&long_data).unwrap());
        assert!(result.is_err());
    }

    #[test]
    fn test_share_deserialize_exact_size() {
        let exact_data = vec![0u8; SHARE_SIZE];
        let serialized = bincode::serialize(&exact_data).unwrap();
        let result: Result<Share, _> = bincode::deserialize(&serialized);
        assert!(result.is_ok());

        let share = result.unwrap();
        assert_eq!(share.0, [0u8; SHARE_SIZE]);
    }

    #[test]
    fn test_share_serde_multiple_formats() {
        let secret = Secret::generate();
        let private_key = SecretKey::generate();
        let public_key = private_key.public();
        let original_share = Share::new(&secret, &public_key).unwrap();

        let json = serde_json::to_string(&original_share).unwrap();
        let json_share: Share = serde_json::from_str(&json).unwrap();
        assert_eq!(original_share, json_share);

        let binary = bincode::serialize(&original_share).unwrap();
        let binary_share: Share = bincode::deserialize(&binary).unwrap();
        assert_eq!(original_share, binary_share);

        assert_eq!(json_share, binary_share);

        let secret1 = json_share.recover(&private_key).unwrap();
        let secret2 = binary_share.recover(&private_key).unwrap();
        assert_eq!(secret, secret1);
        assert_eq!(secret, secret2);
        assert_eq!(secret1, secret2);
    }
}