fire_crypto/cipher/
keypair.rs1use super::{PublicKey, SharedSecret};
2#[cfg(feature = "b64")]
3use crate::error::DecodeError;
4use crate::error::TryFromError;
5
6use std::convert::{TryFrom, TryInto};
7use std::fmt;
8
9use rand::rngs::OsRng;
10
11use x25519_dalek as x;
12
13#[cfg(feature = "b64")]
14use base64::engine::{general_purpose::URL_SAFE_NO_PAD, Engine};
15
16pub struct EphemeralKeypair {
20 secret: x::EphemeralSecret,
21 public: PublicKey,
22}
23
24impl EphemeralKeypair {
25 pub fn new() -> Self {
26 let secret = x::EphemeralSecret::random_from_rng(OsRng);
27 let public = PublicKey::from_ephemeral_secret(&secret);
28
29 Self { secret, public }
30 }
31
32 pub fn diffie_hellman(self, public_key: &PublicKey) -> SharedSecret {
34 let secret = self.secret.diffie_hellman(public_key.inner());
35 SharedSecret::from_shared_secret(secret)
36 }
37
38 pub fn public(&self) -> &PublicKey {
39 &self.public
40 }
41}
42
43impl fmt::Debug for EphemeralKeypair {
44 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
45 f.debug_struct("EphemeralKeypair")
46 .field("public", &self.public)
47 .finish()
48 }
49}
50
51#[derive(Clone)]
55pub struct Keypair {
56 pub secret: x::StaticSecret,
57 pub public: PublicKey,
58}
59
60impl Keypair {
61 pub const LEN: usize = 32;
62
63 fn from_static_secret(secret: x::StaticSecret) -> Self {
64 let public = PublicKey::from_static_secret(&secret);
65
66 Self { secret, public }
67 }
68
69 pub fn new() -> Self {
70 Self::from_static_secret(x::StaticSecret::random_from_rng(OsRng))
71 }
72
73 pub fn from_slice(slice: &[u8]) -> Self {
76 slice.try_into().unwrap()
77 }
78
79 pub fn to_bytes(&self) -> [u8; 32] {
80 self.secret.to_bytes()
81 }
82
83 pub fn public(&self) -> &PublicKey {
88 &self.public
89 }
90
91 pub fn diffie_hellman(&self, public_key: &PublicKey) -> SharedSecret {
92 let secret = self.secret.diffie_hellman(public_key.inner());
93 SharedSecret::from_shared_secret(secret)
94 }
95}
96
97#[cfg(not(feature = "b64"))]
98impl fmt::Debug for Keypair {
99 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
100 f.debug_struct("Keypair")
101 .field("secret", &self.to_bytes())
102 .field("public", &self.public)
103 .finish()
104 }
105}
106
107#[cfg(feature = "b64")]
108impl fmt::Debug for Keypair {
109 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
110 f.debug_struct("Keypair")
111 .field("secret", &self.to_string())
112 .field("public", &self.public)
113 .finish()
114 }
115}
116
/// Formats the secret key as unpadded URL-safe base64 (`URL_SAFE_NO_PAD`),
/// the same engine the `FromStr` impl decodes with.
///
/// The rendered string is the secret key itself; treat it as sensitive.
#[cfg(feature = "b64")]
impl fmt::Display for Keypair {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let secret_bytes = self.to_bytes();
        let encoded = base64::display::Base64Display::new(&secret_bytes, &URL_SAFE_NO_PAD);
        fmt::Display::fmt(&encoded, f)
    }
}
125
126impl From<[u8; 32]> for Keypair {
127 fn from(bytes: [u8; 32]) -> Self {
128 Self::from_static_secret(x::StaticSecret::from(bytes))
129 }
130}
131
132impl TryFrom<&[u8]> for Keypair {
133 type Error = TryFromError;
134
135 fn try_from(v: &[u8]) -> Result<Self, Self::Error> {
136 <[u8; 32]>::try_from(v)
137 .map(Self::from)
138 .map_err(TryFromError::from_any)
139 }
140}
141
/// Parses a keypair from the unpadded URL-safe base64 encoding of its secret
/// key.
///
/// Returns `DecodeError::InvalidLength` when `s` does not have the expected
/// encoded length for `Keypair::LEN` bytes, and the `inv_bytes` error when
/// the base64 itself is invalid.
#[cfg(feature = "b64")]
impl crate::FromStr for Keypair {
    type Err = DecodeError;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let expected_len = crate::calculate_b64_len(Self::LEN);
        if s.len() != expected_len {
            return Err(DecodeError::InvalidLength);
        }

        // The length check above bounds the decoded size, so the output
        // buffer can be passed to the unchecked decoder.
        let mut bytes = [0u8; Self::LEN];
        match URL_SAFE_NO_PAD.decode_slice_unchecked(s, &mut bytes) {
            Ok(_) => Ok(Self::from(bytes)),
            Err(e) => Err(DecodeError::inv_bytes(e)),
        }
    }
}
158
#[cfg(all(feature = "b64", feature = "serde"))]
mod impl_serde {
    // Serde support: a keypair travels as the base64 string produced by its
    // `Display` impl and is read back through its `FromStr` impl.
    use super::*;

    use std::borrow::Cow;
    use std::str::FromStr;

    use _serde::de::Error;
    use _serde::{Deserialize, Deserializer, Serialize, Serializer};

    impl Serialize for Keypair {
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: Serializer,
        {
            // Emit the `Display` (base64) representation as a string.
            serializer.collect_str(self)
        }
    }

    impl<'de> Deserialize<'de> for Keypair {
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: Deserializer<'de>,
        {
            let encoded: Cow<'_, str> = Cow::deserialize(deserializer)?;
            Self::from_str(&encoded).map_err(D::Error::custom)
        }
    }
}
188
#[cfg(all(feature = "b64", feature = "postgres"))]
mod impl_postgres {
    // Postgres support: a keypair is stored as a string (any column type that
    // `&str` accepts) holding its `Display` base64 form, and parsed back with
    // its `FromStr` impl.
    use super::*;

    use bytes::BytesMut;
    use postgres_types::{to_sql_checked, FromSql, IsNull, ToSql, Type};

    type BoxedError = Box<dyn std::error::Error + Sync + Send>;

    impl ToSql for Keypair {
        fn to_sql(&self, ty: &Type, out: &mut BytesMut) -> Result<IsNull, BoxedError>
        where
            Self: Sized,
        {
            let encoded: String = self.to_string();
            ToSql::to_sql(&encoded, ty, out)
        }

        fn accepts(ty: &Type) -> bool
        where
            Self: Sized,
        {
            <&str as ToSql>::accepts(ty)
        }

        to_sql_checked!();
    }

    impl<'r> FromSql<'r> for Keypair {
        fn from_sql(ty: &Type, raw: &'r [u8]) -> Result<Self, BoxedError> {
            let encoded = <&str as FromSql>::from_sql(ty, raw)?;
            // `?` boxes the `DecodeError` into the driver's error type.
            let keypair = encoded.parse::<Self>()?;
            Ok(keypair)
        }

        fn accepts(ty: &Type) -> bool {
            <&str as FromSql>::accepts(ty)
        }
    }
}