1#[cfg(feature = "alloc")]
4extern crate alloc;
5
6#[cfg(feature = "alloc")]
7use alloc::{
8 string::ToString,
9 vec::Vec,
10};
11
12#[cfg(feature = "alloc")]
13use lib_q_core::{
14 Aead,
15 AeadDecryptSemantic,
16 AeadKey,
17 DecryptSemanticOutcome,
18 Error,
19 Nonce,
20 Result,
21};
22use zeroize::Zeroizing;
23
24use crate::crypto::{
25 decrypt as tweak_decrypt,
26 decrypt_semantic_outcome,
27 encrypt as tweak_encrypt,
28};
29use crate::params::{
30 KEY_BYTES,
31 NONCE_BYTES,
32 TAG_BYTES,
33};
34
/// Stateless handle for the AEAD implemented by the `crate::crypto` encrypt/decrypt primitives.
///
/// The handle is a unit struct, so constructing, copying, or dropping it costs nothing. The
/// required key, nonce, and tag lengths are exposed through [`TweakAead::key_size`],
/// [`TweakAead::nonce_size`], and [`TweakAead::tag_size`].
#[derive(Debug, Clone, Copy)]
pub struct TweakAead;
37
38impl TweakAead {
39 pub const fn new() -> Self {
40 Self
41 }
42
43 pub const fn key_size() -> usize {
44 KEY_BYTES
45 }
46
47 pub const fn nonce_size() -> usize {
48 NONCE_BYTES
49 }
50
51 pub const fn tag_size() -> usize {
52 TAG_BYTES
53 }
54}
55
56impl Default for TweakAead {
57 fn default() -> Self {
58 Self::new()
59 }
60}
61
/// Validates the key and nonce lengths and copies both into zeroize-on-drop fixed-size arrays.
///
/// Every entry point below shares this helper, so the checks run in the same order (key, then
/// nonce) and produce identical error values.
///
/// # Errors
///
/// - [`Error::InvalidKeySize`] if `key` is not exactly `KEY_BYTES` long.
/// - [`Error::InvalidNonceSize`] if `nonce` is not exactly `NONCE_BYTES` long.
#[cfg(feature = "alloc")]
fn load_key_and_nonce(
    key: &AeadKey,
    nonce: &Nonce,
) -> Result<(Zeroizing<[u8; KEY_BYTES]>, Zeroizing<[u8; NONCE_BYTES]>)> {
    let kb = key.as_bytes();
    if kb.len() != KEY_BYTES {
        return Err(Error::InvalidKeySize {
            expected: KEY_BYTES,
            actual: kb.len(),
        });
    }
    let nb = nonce.as_bytes();
    if nb.len() != NONCE_BYTES {
        return Err(Error::InvalidNonceSize {
            expected: NONCE_BYTES,
            actual: nb.len(),
        });
    }

    let mut key_arr = Zeroizing::new([0u8; KEY_BYTES]);
    key_arr.copy_from_slice(kb);
    let mut nonce_arr = Zeroizing::new([0u8; NONCE_BYTES]);
    nonce_arr.copy_from_slice(nb);
    Ok((key_arr, nonce_arr))
}

/// Rejects ciphertexts too short to contain an authentication tag.
///
/// # Errors
///
/// Returns the error built by [`Error::aead_ciphertext_shorter_than_tag`] when
/// `ciphertext.len() < TAG_BYTES`.
#[cfg(feature = "alloc")]
fn ensure_holds_tag(ciphertext: &[u8]) -> Result<()> {
    if ciphertext.len() < TAG_BYTES {
        return Err(Error::aead_ciphertext_shorter_than_tag(
            TAG_BYTES,
            ciphertext.len(),
        ));
    }
    Ok(())
}

/// Error reported whenever a decryption primitive rejects its input.
#[cfg(feature = "alloc")]
fn verification_failed() -> Error {
    Error::VerificationFailed {
        operation: "AEAD tag verification".to_string(),
    }
}

#[cfg(feature = "alloc")]
impl Aead for TweakAead {
    /// Encrypts `plaintext` and returns a buffer of `plaintext.len() + TAG_BYTES` bytes.
    ///
    /// A missing `associated_data` is treated as empty associated data.
    ///
    /// # Errors
    ///
    /// - [`Error::InvalidKeySize`] / [`Error::InvalidNonceSize`] on wrong input lengths.
    /// - [`Error::InvalidMessageSize`] if the underlying primitive fails for any reason.
    fn encrypt(
        &self,
        key: &AeadKey,
        nonce: &Nonce,
        plaintext: &[u8],
        associated_data: Option<&[u8]>,
    ) -> Result<Vec<u8>> {
        let (key_arr, nonce_arr) = load_key_and_nonce(key, nonce)?;
        let ad = associated_data.unwrap_or(&[]);

        // Room for the encrypted body plus one tag.
        let mut out = alloc::vec![0u8; plaintext.len() + TAG_BYTES];
        // NOTE(review): every primitive failure is reported as InvalidMessageSize with
        // `max: usize::MAX`, which states no real limit. Confirm what
        // `crate::crypto::encrypt` can fail with and map it precisely.
        tweak_encrypt(&key_arr, &nonce_arr, ad, plaintext, &mut out).map_err(|_| {
            Error::InvalidMessageSize {
                max: usize::MAX,
                actual: plaintext.len(),
            }
        })?;
        Ok(out)
    }

    /// Verifies and decrypts `ciphertext`, returning `ciphertext.len() - TAG_BYTES` bytes.
    ///
    /// A missing `associated_data` is treated as empty associated data.
    ///
    /// # Errors
    ///
    /// - [`Error::InvalidKeySize`] / [`Error::InvalidNonceSize`] on wrong input lengths.
    /// - The short-ciphertext error if `ciphertext` is shorter than `TAG_BYTES`.
    /// - [`Error::VerificationFailed`] if the underlying primitive rejects the input.
    fn decrypt(
        &self,
        key: &AeadKey,
        nonce: &Nonce,
        ciphertext: &[u8],
        associated_data: Option<&[u8]>,
    ) -> Result<Vec<u8>> {
        let (key_arr, nonce_arr) = load_key_and_nonce(key, nonce)?;
        ensure_holds_tag(ciphertext)?;
        let ad = associated_data.unwrap_or(&[]);

        // Cannot underflow: `ensure_holds_tag` guarantees `ciphertext.len() >= TAG_BYTES`.
        let body_len = ciphertext.len() - TAG_BYTES;
        let mut pt = alloc::vec![0u8; body_len];
        if tweak_decrypt(&key_arr, &nonce_arr, ad, ciphertext, &mut pt).is_err() {
            // Whether the primitive writes unauthenticated plaintext into `pt` before
            // rejecting is not visible here. Wipe the buffer so a rejected message never
            // lingers in freed heap memory.
            zeroize::Zeroize::zeroize(pt.as_mut_slice());
            return Err(verification_failed());
        }
        Ok(pt)
    }
}

#[cfg(feature = "alloc")]
impl AeadDecryptSemantic for TweakAead {
    /// Runs `crate::crypto::decrypt_semantic_outcome` after the same input validation as
    /// [`Aead::decrypt`].
    ///
    /// A missing `associated_data` is treated as empty associated data.
    ///
    /// # Errors
    ///
    /// - [`Error::InvalidKeySize`] / [`Error::InvalidNonceSize`] on wrong input lengths.
    /// - The short-ciphertext error if `ciphertext` is shorter than `TAG_BYTES`.
    /// - [`Error::VerificationFailed`] if the underlying primitive returns an error.
    fn decrypt_semantic(
        &self,
        key: &AeadKey,
        nonce: &Nonce,
        ciphertext: &[u8],
        associated_data: Option<&[u8]>,
    ) -> Result<DecryptSemanticOutcome> {
        let (key_arr, nonce_arr) = load_key_and_nonce(key, nonce)?;
        ensure_holds_tag(ciphertext)?;
        let ad = associated_data.unwrap_or(&[]);

        decrypt_semantic_outcome(&key_arr, &nonce_arr, ad, ciphertext)
            .map_err(|_| verification_failed())
    }
}