1use crate::{SdkError, SdkResult};
2use chrono::{DateTime, Utc};
3use rand::Rng;
4use serde::{Deserialize, Serialize, de::DeserializeOwned};
5use std::marker::PhantomData;
6
/// Version of the encrypted payload format that `EncryptedPayload::encrypt` produces.
///
/// Must agree with the serde tag of the matching `EncryptedPayload` variant (`"1"` for V1).
const ENCRYPTED_PAYLOAD_CURRENT_VERSION: u8 = 0x01;
8
/// A versioned, encrypted message whose decrypted form is JSON for `I`.
///
/// Serialized with serde adjacent tagging: a `"version"` field holding the
/// variant tag (the string `"1"` for [`EncryptedPayloadV1`]) and a `"payload"`
/// field holding that variant's contents.
///
/// The explicit serde `bound` replaces the bounds serde would otherwise infer.
/// NOTE(review): the std `Clone`/`Debug` derives still add `I: Clone` /
/// `I: Debug` bounds, even though `I` is only carried via `PhantomData`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(
    tag = "version",
    content = "payload",
    bound = "I: Serialize + DeserializeOwned + Sync"
)]
pub enum EncryptedPayload<I: Serialize + DeserializeOwned + Sync> {
    /// Format version 1. Boxed, presumably to keep the enum pointer-sized;
    /// TODO confirm intent.
    #[serde(rename = "1")]
    V1(Box<EncryptedPayloadV1<I>>),
}
25
26impl<I: Serialize + DeserializeOwned + Sync> From<EncryptedPayloadV1<I>> for EncryptedPayload<I> {
27 fn from(payload: EncryptedPayloadV1<I>) -> Self {
28 Self::V1(Box::new(payload))
29 }
30}
31
32impl<I: Serialize + DeserializeOwned + Sync> EncryptedPayload<I> {
33 pub fn json_encrypt(
35 my_secret_key: &[u8; 32],
36 their_public_key: &[u8; 32],
37 msg: &I,
38 ) -> SdkResult<Self> {
39 let msg_json = serde_json::to_vec(msg)?;
40 Ok(Self::encrypt(my_secret_key, their_public_key, &msg_json))
41 }
42
43 pub fn json_decrypt(&self, my_keys: &[u8; 32]) -> SdkResult<(I, [u8; 32])> {
46 let (decrypted, their_public_key) = self.decrypt(my_keys)?;
47 let msg = serde_json::from_slice(&decrypted)?;
48 Ok((msg, their_public_key))
49 }
50
51 pub fn encrypt(my_secret_key: &[u8; 32], their_public_key: &[u8; 32], msg: &[u8]) -> Self {
53 let random = rand::rngs::OsRng.r#gen::<[u8; 16]>();
54 let created_at = Utc::now();
55 let timestamp = created_at.timestamp() as u64;
56 let mut my_public_key = [0u8; 32];
57 sodalite::scalarmult_base(&mut my_public_key, my_secret_key);
58 let aad: Vec<u8> = std::iter::once(ENCRYPTED_PAYLOAD_CURRENT_VERSION)
59 .chain(random.iter().copied())
60 .chain(timestamp.to_be_bytes().iter().copied())
61 .chain(their_public_key.iter().copied())
62 .chain(my_public_key.iter().copied())
63 .collect();
64
65 let mut hash = [0u8; 64];
69 sodalite::hash(&mut hash, &aad);
70 let nonce: [u8; 24] = (&hash[..24]).try_into().expect("Failed to convert nonce");
71
72 let mut plaintext = vec![0u8; 32];
73 plaintext.extend_from_slice(msg);
74 let mut ciphertext_and_tag = vec![0u8; plaintext.len()];
75 {
76 sodalite::box_(
77 &mut ciphertext_and_tag,
78 &plaintext,
79 &nonce,
80 their_public_key,
81 my_secret_key,
82 )
83 .expect("Failed to encrypt message");
84 }
85
86 EncryptedPayloadV1 {
87 verification_key: my_public_key,
88 random,
89 created_at,
90 ciphertext_and_tag,
91 _inner_representation: PhantomData,
92 }
93 .into()
94 }
95
96 pub fn decrypt(&self, my_secret_key: &[u8; 32]) -> SdkResult<(Vec<u8>, [u8; 32])> {
98 let mut my_public_key = [0u8; 32];
99 sodalite::scalarmult_base(&mut my_public_key, my_secret_key);
100 match self {
101 Self::V1(v1) => {
102 let timestamp = v1.created_at.timestamp() as u64;
103 let aad = std::iter::once(ENCRYPTED_PAYLOAD_CURRENT_VERSION)
104 .chain(v1.random.iter().copied())
105 .chain(timestamp.to_be_bytes().iter().copied())
106 .chain(my_public_key.iter().copied())
107 .chain(v1.verification_key.iter().copied())
108 .collect::<Vec<_>>();
109
110 let mut hash = [0u8; 64];
114 sodalite::hash(&mut hash, &aad);
115 let nonce: [u8; 24] = (&hash[..24]).try_into().expect("Failed to convert nonce");
116 let mut plaintext = vec![0u8; v1.ciphertext_and_tag.len()];
117 {
118 sodalite::box_open(
119 &mut plaintext,
120 &v1.ciphertext_and_tag,
121 &nonce,
122 &v1.verification_key,
123 my_secret_key,
124 )
125 .map_err(|_| {
126 SdkError::Decryption("encrypted payload decryption failed".to_string())
127 })?;
128 }
129 Ok((plaintext[32..].to_vec(), v1.verification_key))
130 }
131 }
132 }
133}
134
/// Version 1 of an encrypted payload.
///
/// `I` is the type of the JSON-encoded message inside the ciphertext. It exists
/// only at the type level (`_inner_representation`) and is never stored.
/// NOTE(review): the std `Clone`/`Debug` derives add `I: Clone` / `I: Debug`
/// bounds despite `I` being phantom.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct EncryptedPayloadV1<I: Serialize + DeserializeOwned + Sync> {
    /// Sender's public key, derived from the sender's secret key at encryption
    /// time; hex-encoded when serialized.
    #[serde(with = "hex")]
    verification_key: [u8; 32],
    /// Random bytes chosen per payload and mixed into the nonce derivation;
    /// hex-encoded when serialized.
    #[serde(with = "hex")]
    random: [u8; 16],
    /// Creation time, serialized as an RFC 3339 string. Its whole-second
    /// timestamp is mixed into the nonce derivation, so altering it breaks
    /// decryption.
    #[serde(with = "chrono_rfc3339")]
    created_at: DateTime<Utc>,
    /// Output of `sodalite::box_` (same length as the zero-prefixed plaintext);
    /// hex-encoded when serialized.
    #[serde(with = "hex")]
    ciphertext_and_tag: Vec<u8>,
    /// Type-level marker for the message type `I`; not serialized.
    #[serde(skip)]
    _inner_representation: PhantomData<I>,
}
157
158mod chrono_rfc3339 {
159 use chrono::{DateTime, Utc};
160 use serde::{Deserialize, Deserializer, Serialize, Serializer};
161 pub fn serialize<S>(date: &DateTime<Utc>, s: S) -> Result<S::Ok, S::Error>
162 where
163 S: Serializer,
164 {
165 date.to_rfc3339().serialize(s)
166 }
167
168 pub fn deserialize<'de, D>(d: D) -> Result<DateTime<Utc>, D::Error>
169 where
170 D: Deserializer<'de>,
171 {
172 let s = String::deserialize(d)?;
173 DateTime::parse_from_rfc3339(&s)
174 .map_err(serde::de::Error::custom)
175 .map(DateTime::from)
176 }
177}