// fire_crypto/cipher/keypair.rs

1use super::{PublicKey, SharedSecret};
2#[cfg(feature = "b64")]
3use crate::error::DecodeError;
4use crate::error::TryFromError;
5
6use std::convert::{TryFrom, TryInto};
7use std::fmt;
8
9use rand::rngs::OsRng;
10
11use x25519_dalek as x;
12
13#[cfg(feature = "b64")]
14use base64::engine::{general_purpose::URL_SAFE_NO_PAD, Engine};
15
16// EphemeralKeypair
17
18/// A Keypair that can only be used once.
19pub struct EphemeralKeypair {
20	secret: x::EphemeralSecret,
21	public: PublicKey,
22}
23
24impl EphemeralKeypair {
25	pub fn new() -> Self {
26		let secret = x::EphemeralSecret::random_from_rng(OsRng);
27		let public = PublicKey::from_ephemeral_secret(&secret);
28
29		Self { secret, public }
30	}
31
32	// maybe return a Key??
33	pub fn diffie_hellman(self, public_key: &PublicKey) -> SharedSecret {
34		let secret = self.secret.diffie_hellman(public_key.inner());
35		SharedSecret::from_shared_secret(secret)
36	}
37
38	pub fn public(&self) -> &PublicKey {
39		&self.public
40	}
41}
42
43impl fmt::Debug for EphemeralKeypair {
44	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45		f.debug_struct("EphemeralKeypair")
46			.field("public", &self.public)
47			.finish()
48	}
49}
50
51// Keypair
52
53/// A Keypair that can be used multiple times.
54#[derive(Clone)]
55pub struct Keypair {
56	pub secret: x::StaticSecret,
57	pub public: PublicKey,
58}
59
60impl Keypair {
61	pub const LEN: usize = 32;
62
63	fn from_static_secret(secret: x::StaticSecret) -> Self {
64		let public = PublicKey::from_static_secret(&secret);
65
66		Self { secret, public }
67	}
68
69	pub fn new() -> Self {
70		Self::from_static_secret(x::StaticSecret::random_from_rng(OsRng))
71	}
72
73	/// ## Panics
74	/// if the slice is not 32 bytes long.
75	pub fn from_slice(slice: &[u8]) -> Self {
76		slice.try_into().unwrap()
77	}
78
79	pub fn to_bytes(&self) -> [u8; 32] {
80		self.secret.to_bytes()
81	}
82
83	// pub fn as_slice(&self) -> &[u8] {
84	// 	self.secret.as_ref()
85	// }
86
87	pub fn public(&self) -> &PublicKey {
88		&self.public
89	}
90
91	pub fn diffie_hellman(&self, public_key: &PublicKey) -> SharedSecret {
92		let secret = self.secret.diffie_hellman(public_key.inner());
93		SharedSecret::from_shared_secret(secret)
94	}
95}
96
97#[cfg(not(feature = "b64"))]
98impl fmt::Debug for Keypair {
99	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100		f.debug_struct("Keypair")
101			.field("secret", &self.to_bytes())
102			.field("public", &self.public)
103			.finish()
104	}
105}
106
107#[cfg(feature = "b64")]
108impl fmt::Debug for Keypair {
109	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
110		f.debug_struct("Keypair")
111			.field("secret", &self.to_string())
112			.field("public", &self.public)
113			.finish()
114	}
115}
116
// Display

/// Formats the raw secret key as URL-safe, unpadded base64.
///
/// Only available with the `b64` feature. The output *is* the secret key, so
/// treat it as sensitive. `FromStr` parses this same format back.
#[cfg(feature = "b64")]
impl fmt::Display for Keypair {
	fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
		let secret_bytes = self.to_bytes();
		let encoded = base64::display::Base64Display::new(&secret_bytes, &URL_SAFE_NO_PAD);
		fmt::Display::fmt(&encoded, f)
	}
}
125
126impl From<[u8; 32]> for Keypair {
127	fn from(bytes: [u8; 32]) -> Self {
128		Self::from_static_secret(x::StaticSecret::from(bytes))
129	}
130}
131
132impl TryFrom<&[u8]> for Keypair {
133	type Error = TryFromError;
134
135	fn try_from(v: &[u8]) -> Result<Self, Self::Error> {
136		<[u8; 32]>::try_from(v)
137			.map(Self::from)
138			.map_err(TryFromError::from_any)
139	}
140}
141
/// Parses the URL-safe, unpadded base64 form of the secret key (the format
/// produced by `Display`).
///
/// # Errors
/// - `DecodeError::InvalidLength` if `s` does not have the expected encoded
///   length for a `Keypair::LEN`-byte key.
/// - The error built by `DecodeError::inv_bytes` if `s` is not valid base64.
#[cfg(feature = "b64")]
impl crate::FromStr for Keypair {
	type Err = DecodeError;

	fn from_str(s: &str) -> Result<Self, Self::Err> {
		let expected_len = crate::calculate_b64_len(Self::LEN);

		if s.len() == expected_len {
			// The length check above is what keeps the fixed-size buffer big
			// enough for the unchecked decode (assumes `calculate_b64_len`
			// returns the unpadded encoded length — TODO confirm).
			let mut secret_bytes = [0u8; Self::LEN];
			match URL_SAFE_NO_PAD.decode_slice_unchecked(s, &mut secret_bytes) {
				Ok(_) => Ok(Self::from(secret_bytes)),
				Err(e) => Err(DecodeError::inv_bytes(e)),
			}
		} else {
			Err(DecodeError::InvalidLength)
		}
	}
}
158
#[cfg(all(feature = "b64", feature = "serde"))]
mod impl_serde {
	//! Serde support: a `Keypair` travels as the base64 string produced by its
	//! `Display` impl and is read back through its `FromStr` impl.

	use super::*;

	use std::borrow::Cow;
	use std::str::FromStr;

	use _serde::de::Error;
	use _serde::{Deserialize, Deserializer, Serialize, Serializer};

	impl Serialize for Keypair {
		fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
		where
			S: Serializer,
		{
			// Serializes the `Display` (base64) representation.
			serializer.collect_str(self)
		}
	}

	impl<'de> Deserialize<'de> for Keypair {
		fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
		where
			D: Deserializer<'de>,
		{
			let encoded: Cow<'_, str> = Deserialize::deserialize(deserializer)?;
			match Self::from_str(&encoded) {
				Ok(keypair) => Ok(keypair),
				Err(decode_err) => Err(D::Error::custom(decode_err)),
			}
		}
	}
}
188
#[cfg(all(feature = "b64", feature = "postgres"))]
mod impl_postgres {
	//! Postgres support: a `Keypair` is written as the base64 string produced
	//! by its `Display` impl, into any column type that `&str` accepts.

	use super::*;

	use bytes::BytesMut;
	use postgres_types::{to_sql_checked, FromSql, IsNull, ToSql, Type};

	impl ToSql for Keypair {
		fn to_sql(
			&self,
			ty: &Type,
			out: &mut BytesMut,
		) -> Result<IsNull, Box<dyn std::error::Error + Sync + Send>>
		where
			Self: Sized,
		{
			let encoded = self.to_string();
			encoded.to_sql(ty, out)
		}

		fn accepts(ty: &Type) -> bool
		where
			Self: Sized,
		{
			// Same set of column types as a plain `&str`.
			<&str as ToSql>::accepts(ty)
		}

		to_sql_checked!();
	}

	impl<'r> FromSql<'r> for Keypair {
		fn from_sql(
			ty: &Type,
			raw: &'r [u8],
		) -> Result<Self, Box<dyn std::error::Error + Sync + Send>> {
			let encoded = <&str as FromSql>::from_sql(ty, raw)?;
			let keypair = encoded.parse::<Self>()?;
			Ok(keypair)
		}

		fn accepts(ty: &Type) -> bool {
			<&str as FromSql>::accepts(ty)
		}
	}
}