1use std::{ops::Deref, sync::Arc};
2
3use object_rainbow::{
4 Address, ByteNode, Error, FailFuture, Fetch, Hash, Object, Parse, ParseSliceExtra, Point,
5 PointInput, PointVisitor, RawPoint, Resolve, Tagged, ToOutput, Topological,
6 length_prefixed::Lp,
7};
8
/// Parsing context that carries an encryption key next to the wrapped extra.
///
/// `Encrypted` values are parsed with this as their extra: `key` is used to
/// decrypt the raw input and `extra` is handed to the decrypted object's parser.
#[derive(Clone)]
pub struct WithKey<K, Extra> {
    /// Key used to decrypt parsed input; also stored in the resulting `Encrypted`.
    pub key: K,
    /// Extra passed through to the decrypted (inner) object.
    pub extra: Extra,
}
14
15pub trait Key: 'static + Sized + Send + Sync + Clone {
16 fn encrypt(&self, data: &[u8]) -> Vec<u8>;
17 fn decrypt(&self, data: &[u8]) -> object_rainbow::Result<Vec<u8>>;
18}
19
20type Resolution<K, Extra> = Arc<Lp<Vec<RawPoint<Encrypted<K, Vec<u8>, Extra>, WithKey<K, Extra>>>>>;
21
22#[derive(ToOutput, Clone)]
23struct Unkeyed<T>(T);
24
25impl<
26 T: Parse<I::WithExtra<Extra>>,
27 K: 'static + Clone,
28 Extra: 'static + Clone,
29 I: PointInput<Extra = WithKey<K, Extra>>,
30> Parse<I> for Unkeyed<T>
31{
32 fn parse(input: I) -> object_rainbow::Result<Self> {
33 Ok(Self(T::parse(
34 input.map_extra(|WithKey { extra, .. }| extra),
35 )?))
36 }
37}
38
39#[derive(ToOutput, Parse)]
40struct EncryptedInner<K, T, Extra> {
41 resolution: Resolution<K, Extra>,
42 decrypted: Unkeyed<Arc<T>>,
43}
44
45impl<K, T, Extra> Clone for EncryptedInner<K, T, Extra> {
46 fn clone(&self) -> Self {
47 Self {
48 resolution: self.resolution.clone(),
49 decrypted: self.decrypted.clone(),
50 }
51 }
52}
53
54type ResolutionIter<'a, K, Extra> =
55 std::slice::Iter<'a, RawPoint<Encrypted<K, Vec<u8>, Extra>, WithKey<K, Extra>>>;
56
57struct IterateResolution<'a, K, V, Extra> {
58 resolution: ResolutionIter<'a, K, Extra>,
59 visitor: &'a mut V,
60}
61
62impl<'a, K: Key, V: PointVisitor<WithKey<K, Extra>>, Extra: 'static + Send + Sync + Clone>
63 PointVisitor<Extra> for IterateResolution<'a, K, V, Extra>
64{
65 fn visit<T: Object<Extra>>(&mut self, _: &Point<T, Extra>) {
66 let point = self
67 .resolution
68 .next()
69 .expect("length mismatch")
70 .clone()
71 .cast::<Encrypted<K, T, Extra>>()
72 .point();
73 self.visitor.visit(&point);
74 }
75}
76
77impl<K: Key, T: Topological<Extra>, Extra: 'static + Send + Sync + Clone>
78 Topological<WithKey<K, Extra>> for EncryptedInner<K, T, Extra>
79{
80 fn accept_points(&self, visitor: &mut impl PointVisitor<WithKey<K, Extra>>) {
81 self.decrypted.0.accept_points(&mut IterateResolution {
82 resolution: self.resolution.iter(),
83 visitor,
84 });
85 }
86
87 fn point_count(&self) -> usize {
88 self.resolution.len()
89 }
90
91 fn topology_hash(&self) -> Hash {
92 self.resolution.0.data_hash()
93 }
94}
95
96pub struct Encrypted<K, T, Extra> {
97 key: K,
98 inner: EncryptedInner<K, T, Extra>,
99}
100
101impl<K, T: Clone, Extra> Encrypted<K, T, Extra> {
102 pub fn into_inner(self) -> T {
103 Arc::unwrap_or_clone(self.inner.decrypted.0)
104 }
105}
106
107impl<K, T, Extra> Deref for Encrypted<K, T, Extra> {
108 type Target = T;
109
110 fn deref(&self) -> &Self::Target {
111 self.inner.decrypted.0.as_ref()
112 }
113}
114
115impl<K: Clone, T, Extra> Clone for Encrypted<K, T, Extra> {
116 fn clone(&self) -> Self {
117 Self {
118 key: self.key.clone(),
119 inner: self.inner.clone(),
120 }
121 }
122}
123
124impl<K: Key, T: Topological<Extra>, Extra: 'static + Send + Sync + Clone>
125 Topological<WithKey<K, Extra>> for Encrypted<K, T, Extra>
126{
127 fn accept_points(&self, visitor: &mut impl PointVisitor<WithKey<K, Extra>>) {
128 self.inner.accept_points(visitor);
129 }
130
131 fn topology_hash(&self) -> Hash {
132 self.inner.topology_hash()
133 }
134}
135
136impl<K: Key, T: ToOutput, Extra> ToOutput for Encrypted<K, T, Extra> {
137 fn to_output(&self, output: &mut dyn object_rainbow::Output) {
138 let source = self.inner.vec();
139 output.write(&self.key.encrypt(&source));
140 }
141}
142
143#[derive(Clone)]
144struct Decrypt<K, Extra> {
145 resolution: Resolution<K, Extra>,
146}
147
148impl<K: Key, Extra: 'static + Send + Sync + Clone> Resolve for Decrypt<K, Extra> {
149 fn resolve(&'_ self, address: Address) -> FailFuture<'_, ByteNode> {
150 Box::pin(async move {
151 let Encrypted {
152 key: _,
153 inner:
154 EncryptedInner {
155 resolution,
156 decrypted,
157 },
158 } = self
159 .resolution
160 .get(address.index)
161 .ok_or(Error::AddressOutOfBounds)?
162 .clone()
163 .fetch()
164 .await?;
165 Ok((
166 Arc::into_inner(decrypted.0).expect("not shared because reconstructed"),
167 Arc::new(Decrypt { resolution }) as _,
168 ))
169 })
170 }
171
172 fn name(&self) -> &str {
173 "decrypt"
174 }
175}
176
177impl<
178 K: Key,
179 T: Object<Extra>,
180 Extra: 'static + Send + Sync + Clone,
181 I: PointInput<Extra = WithKey<K, Extra>>,
182> Parse<I> for Encrypted<K, T, Extra>
183{
184 fn parse(input: I) -> object_rainbow::Result<Self> {
185 let with_key = input.extra().clone();
186 let resolve = input.resolve().clone();
187 let source = with_key.key.decrypt(input.parse_all()?)?;
188 let EncryptedInner {
189 resolution,
190 decrypted,
191 } = EncryptedInner::<K, Vec<u8>, Extra>::parse_slice_extra(&source, &resolve, &with_key)?;
192 let decrypted = T::parse_slice_extra(
193 &decrypted.0,
194 &(Arc::new(Decrypt {
195 resolution: resolution.clone(),
196 }) as _),
197 &with_key.extra,
198 )?;
199 let decrypted = Unkeyed(Arc::new(decrypted));
200 let inner = EncryptedInner {
201 resolution,
202 decrypted,
203 };
204 Ok(Self {
205 key: with_key.key,
206 inner,
207 })
208 }
209}
210
211impl<K, T, Extra> Tagged for Encrypted<K, T, Extra> {}
212
213impl<K: Key, T: Object<Extra>, Extra: 'static + Send + Sync + Clone> Object<WithKey<K, Extra>>
214 for Encrypted<K, T, Extra>
215{
216}
217
218type Extracted<K, Extra> = Vec<
219 std::pin::Pin<
220 Box<
221 dyn Future<
222 Output = Result<
223 RawPoint<Encrypted<K, Vec<u8>, Extra>, WithKey<K, Extra>>,
224 Error,
225 >,
226 > + Send
227 + 'static,
228 >,
229 >,
230>;
231
232struct ExtractResolution<'a, K, Extra>(&'a mut Extracted<K, Extra>, &'a K);
233
234impl<K: Key, Extra: 'static + Send + Sync + Clone> PointVisitor<Extra>
235 for ExtractResolution<'_, K, Extra>
236{
237 fn visit<T: Object<Extra>>(&mut self, point: &Point<T, Extra>) {
238 let point = point.clone();
239 let key = self.1.clone();
240 self.0.push(Box::pin(async move {
241 let point = encrypt_point(key, point).await?.raw().cast();
242 Ok(point)
243 }));
244 }
245}
246
247pub async fn encrypt_point<K: Key, T: Object<Extra>, Extra: 'static + Send + Sync + Clone>(
248 key: K,
249 point: Point<T, Extra>,
250) -> object_rainbow::Result<Point<Encrypted<K, T, Extra>, WithKey<K, Extra>>> {
251 if let Some((address, decrypt)) = point.extract_resolve::<Decrypt<K, Extra>>() {
252 let point = decrypt
253 .resolution
254 .get(address.index)
255 .ok_or(Error::AddressOutOfBounds)?;
256 return Ok(point.clone().cast().point());
257 }
258 let decrypted = point.fetch().await?;
259 let encrypted = encrypt(key.clone(), decrypted).await?;
260 let point = encrypted.point_extra(WithKey {
261 key,
262 extra: point.extra().clone(),
263 });
264 Ok(point)
265}
266
267pub async fn encrypt<K: Key, T: Object<Extra>, Extra: 'static + Send + Sync + Clone>(
268 key: K,
269 decrypted: T,
270) -> object_rainbow::Result<Encrypted<K, T, Extra>> {
271 let mut futures = Vec::with_capacity(decrypted.point_count());
272 decrypted.accept_points(&mut ExtractResolution(&mut futures, &key));
273 let resolution = futures_util::future::try_join_all(futures).await?;
274 let resolution = Arc::new(Lp(resolution));
275 let decrypted = Unkeyed(Arc::new(decrypted));
276 let inner = EncryptedInner {
277 resolution,
278 decrypted,
279 };
280 Ok(Encrypted { key, inner })
281}