1use std::{ops::Deref, sync::Arc};
2
3use object_rainbow::{
4 Address, ByteNode, Error, FailFuture, Fetch, FetchBytes, Hash, Object, Parse, ParseSliceExtra,
5 Point, PointInput, PointVisitor, Resolve, Singular, Tagged, ToOutput, Topological, Traversible,
6 length_prefixed::Lp,
7};
8
/// Parsing context that pairs a key with the caller's own `extra` value.
///
/// `Encrypted` is parsed from inputs whose extra is `WithKey<K, Extra>`: the
/// `key` is used to decrypt the input, and `extra` is handed on to the parser
/// of the wrapped type.
#[derive(Clone)]
pub struct WithKey<K, Extra> {
    /// Key used to decrypt while parsing.
    pub key: K,
    /// Extra context passed through to the parser of the decrypted value.
    pub extra: Extra,
}
14
/// A key that can transform byte buffers in both directions.
///
/// `Encrypted` calls [`Key::encrypt`] on its serialized contents when writing
/// output, and [`Key::decrypt`] on the raw input bytes when parsing.
pub trait Key: 'static + Sized + Send + Sync + Clone {
    /// Returns the encrypted form of `data`.
    fn encrypt(&self, data: &[u8]) -> Vec<u8>;
    /// Returns the decrypted form of `data`; failures are reported through
    /// the returned `Result`.
    fn decrypt(&self, data: &[u8]) -> object_rainbow::Result<Vec<u8>>;
}
19
/// Shared, `Lp`-wrapped list of points to untyped (`Vec<u8>`) encrypted
/// objects. `Decrypt` and `encrypt_point` look entries up by `Address::index`.
type Resolution<K> = Arc<Lp<Vec<Point<Encrypted<K, Vec<u8>>>>>>;
21
/// Wrapper whose `Parse` impl removes the `WithKey` layer from the input's
/// extra before parsing the inner `T`.
#[derive(ToOutput, Clone)]
struct Unkeyed<T>(T);
24
25impl<
26 T: Parse<I::WithExtra<Extra>>,
27 K: 'static + Clone,
28 Extra: 'static + Clone,
29 I: PointInput<Extra = WithKey<K, Extra>>,
30> Parse<I> for Unkeyed<T>
31{
32 fn parse(input: I) -> object_rainbow::Result<Self> {
33 Ok(Self(T::parse(
34 input.map_extra(|WithKey { extra, .. }| extra),
35 )?))
36 }
37}
38
/// Plaintext layout of an `Encrypted` value: these are the bytes that
/// `Encrypted`'s `ToOutput` impl serializes and passes to `Key::encrypt`.
#[derive(ToOutput, Parse)]
struct EncryptedInner<K, T> {
    /// Encrypted points, paired one-to-one and in visit order with the points
    /// that `decrypted` reports through `accept_points`.
    resolution: Resolution<K>,
    /// The decrypted value, parsed without the key in its extra.
    decrypted: Unkeyed<Arc<T>>,
}
44
45impl<K, T> Clone for EncryptedInner<K, T> {
46 fn clone(&self) -> Self {
47 Self {
48 resolution: self.resolution.clone(),
49 decrypted: self.decrypted.clone(),
50 }
51 }
52}
53
/// Borrowing slice iterator over the encrypted points of a resolution list.
type ResolutionIter<'a, K> = std::slice::Iter<'a, Point<Encrypted<K, Vec<u8>>>>;
55
/// Point-visitor adaptor used by `EncryptedInner::accept_points`: for every
/// point it is shown it takes the next entry from `resolution` and forwards a
/// combined point to `visitor`.
struct IterateResolution<'a, 'r, K, V> {
    /// Cursor over the encrypted points; advanced once per visited point.
    resolution: &'r mut ResolutionIter<'a, K>,
    /// The caller's visitor that receives the combined points.
    visitor: &'a mut V,
}
60
/// Fetch source that pairs a decrypted point with its encrypted counterpart.
///
/// Raw bytes, key and resolution are taken from `encrypted`; the typed value
/// is taken from `decrypted`.
struct Visited<K, T> {
    /// Point to the typed, decrypted object.
    decrypted: Point<T>,
    /// Point to the same object in its untyped, encrypted form.
    encrypted: Point<Encrypted<K, Vec<u8>>>,
}
65
66impl<K, T> FetchBytes for Visited<K, T> {
67 fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
68 self.encrypted.fetch_bytes()
69 }
70
71 fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
72 self.encrypted.fetch_data()
73 }
74}
75
76impl<K: Key, T: Traversible> Fetch for Visited<K, T> {
77 type T = Encrypted<K, T>;
78
79 fn fetch_full(&'_ self) -> FailFuture<'_, (Self::T, Arc<dyn Resolve>)> {
80 Box::pin(async move {
81 let (
82 Encrypted {
83 key,
84 inner:
85 EncryptedInner {
86 resolution,
87 decrypted: _,
88 },
89 },
90 resolve,
91 ) = self.encrypted.fetch_full().await?;
92 let decrypted = self.decrypted.fetch().await?;
93 let decrypted = Unkeyed(Arc::new(decrypted));
94 Ok((
95 Encrypted {
96 key,
97 inner: EncryptedInner {
98 resolution,
99 decrypted,
100 },
101 },
102 resolve,
103 ))
104 })
105 }
106
107 fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
108 Box::pin(async move {
109 let Encrypted {
110 key,
111 inner:
112 EncryptedInner {
113 resolution,
114 decrypted: _,
115 },
116 } = self.encrypted.fetch().await?;
117 let decrypted = self.decrypted.fetch().await?;
118 let decrypted = Unkeyed(Arc::new(decrypted));
119 Ok(Encrypted {
120 key,
121 inner: EncryptedInner {
122 resolution,
123 decrypted,
124 },
125 })
126 })
127 }
128}
129
130impl<'a, K: Key, V: PointVisitor> PointVisitor for IterateResolution<'a, '_, K, V> {
131 fn visit<T: Traversible>(&mut self, decrypted: &Point<T>) {
132 let decrypted = decrypted.clone();
133 let encrypted = self.resolution.next().expect("length mismatch").clone();
134 let point = Point::from_fetch(
135 encrypted.hash(),
136 Arc::new(Visited {
137 decrypted,
138 encrypted,
139 }),
140 );
141 self.visitor.visit(&point);
142 }
143}
144
145impl<K: Key, T: Topological> Topological for EncryptedInner<K, T> {
146 fn accept_points(&self, visitor: &mut impl PointVisitor) {
147 let resolution = &mut self.resolution.iter();
148 self.decrypted.0.accept_points(&mut IterateResolution {
149 resolution,
150 visitor,
151 });
152 assert!(resolution.next().is_none());
153 }
154
155 fn point_count(&self) -> usize {
156 self.resolution.len()
157 }
158
159 fn topology_hash(&self) -> Hash {
160 self.resolution.0.data_hash()
161 }
162}
163
/// A value of type `T` held together with the key used to encrypt it.
///
/// Its `ToOutput` impl writes the serialized inner layout passed through
/// `Key::encrypt`, and its `Deref` impl gives access to the decrypted `T`.
pub struct Encrypted<K, T> {
    /// Key used when serializing this value.
    key: K,
    /// Resolution list plus the decrypted value.
    inner: EncryptedInner<K, T>,
}
168
169impl<K, T: Clone> Encrypted<K, T> {
170 pub fn into_inner(self) -> T {
171 Arc::unwrap_or_clone(self.inner.decrypted.0)
172 }
173}
174
175impl<K, T> Deref for Encrypted<K, T> {
176 type Target = T;
177
178 fn deref(&self) -> &Self::Target {
179 self.inner.decrypted.0.as_ref()
180 }
181}
182
183impl<K: Clone, T> Clone for Encrypted<K, T> {
184 fn clone(&self) -> Self {
185 Self {
186 key: self.key.clone(),
187 inner: self.inner.clone(),
188 }
189 }
190}
191
192impl<K: Key, T: Topological> Topological for Encrypted<K, T> {
193 fn accept_points(&self, visitor: &mut impl PointVisitor) {
194 self.inner.accept_points(visitor);
195 }
196
197 fn topology_hash(&self) -> Hash {
198 self.inner.topology_hash()
199 }
200}
201
202impl<K: Key, T: ToOutput> ToOutput for Encrypted<K, T> {
203 fn to_output(&self, output: &mut dyn object_rainbow::Output) {
204 let source = self.inner.vec();
205 output.write(&self.key.encrypt(&source));
206 }
207}
208
/// Resolver that answers address lookups with the decrypted bytes of the
/// encrypted point stored at that index of `resolution`.
#[derive(Clone)]
struct Decrypt<K> {
    /// Encrypted points addressable through this resolver.
    resolution: Resolution<K>,
}
213
214impl<K: Key> Decrypt<K> {
215 async fn resolve_bytes(
216 &self,
217 address: Address,
218 ) -> object_rainbow::Result<(Vec<u8>, Resolution<K>)> {
219 let Encrypted {
220 key: _,
221 inner:
222 EncryptedInner {
223 resolution,
224 decrypted,
225 },
226 } = self
227 .resolution
228 .get(address.index)
229 .ok_or(Error::AddressOutOfBounds)?
230 .clone()
231 .fetch()
232 .await?;
233 let data = Arc::unwrap_or_clone(decrypted.0);
234 Ok((data, resolution))
235 }
236}
237
238impl<K: Key> Resolve for Decrypt<K> {
239 fn resolve(&'_ self, address: Address) -> FailFuture<'_, ByteNode> {
240 Box::pin(async move {
241 let (data, resolution) = self.resolve_bytes(address).await?;
242 Ok((data, Arc::new(Decrypt { resolution }) as _))
243 })
244 }
245
246 fn resolve_data(&'_ self, address: Address) -> FailFuture<'_, Vec<u8>> {
247 Box::pin(async move {
248 let (data, _) = self.resolve_bytes(address).await?;
249 Ok(data)
250 })
251 }
252}
253
254impl<
255 K: Key,
256 T: Object<Extra>,
257 Extra: 'static + Send + Sync + Clone,
258 I: PointInput<Extra = WithKey<K, Extra>>,
259> Parse<I> for Encrypted<K, T>
260{
261 fn parse(input: I) -> object_rainbow::Result<Self> {
262 let with_key = input.extra().clone();
263 let resolve = input.resolve().clone();
264 let source = with_key.key.decrypt(&input.parse_all()?)?;
265 let EncryptedInner {
266 resolution,
267 decrypted,
268 } = EncryptedInner::<K, Vec<u8>>::parse_slice_extra(&source, &resolve, &with_key)?;
269 let decrypted = T::parse_slice_extra(
270 &decrypted.0,
271 &(Arc::new(Decrypt {
272 resolution: resolution.clone(),
273 }) as _),
274 &with_key.extra,
275 )?;
276 let decrypted = Unkeyed(Arc::new(decrypted));
277 let inner = EncryptedInner {
278 resolution,
279 decrypted,
280 };
281 Ok(Self {
282 key: with_key.key,
283 inner,
284 })
285 }
286}
287
// Empty impl: nothing from `Tagged` is overridden for `Encrypted`.
impl<K, T> Tagged for Encrypted<K, T> {}
289
// Marks `Encrypted<K, T>` as an `Object` whose extra is `WithKey<K, Extra>`,
// matching the extra its `Parse` impl expects. No items are overridden.
impl<K: Key, T: Object<Extra>, Extra: 'static + Send + Sync + Clone> Object<WithKey<K, Extra>>
    for Encrypted<K, T>
{
}
294
/// Boxed futures collected by `ExtractResolution`, one per visited point;
/// each resolves to the untyped encrypted point for that visit.
type Extracted<K> = Vec<
    std::pin::Pin<
        Box<dyn Future<Output = Result<Point<Encrypted<K, Vec<u8>>>, Error>> + Send + 'static>,
    >,
>;
300
/// Point visitor used by `encrypt`: queues one encryption future per visited
/// point instead of doing any work during the visit itself.
struct ExtractResolution<'a, K> {
    /// Destination for the queued futures, in visit order.
    extracted: &'a mut Extracted<K>,
    /// Key cloned into every queued future.
    key: &'a K,
}
305
/// Fetch source that presents a typed encrypted point as an untyped
/// `Encrypted<K, Vec<u8>>` by re-parsing the fetched bytes with `key`.
struct Untyped<K, T> {
    /// Parsing context (key with unit extra) used for the re-parse.
    key: WithKey<K, ()>,
    /// The typed encrypted point whose bytes are fetched.
    encrypted: Point<Encrypted<K, T>>,
}
310
311impl<K, T> FetchBytes for Untyped<K, T> {
312 fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
313 self.encrypted.fetch_bytes()
314 }
315
316 fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
317 self.encrypted.fetch_data()
318 }
319}
320
321impl<K: Key, T> Fetch for Untyped<K, T> {
322 type T = Encrypted<K, Vec<u8>>;
323
324 fn fetch_full(&'_ self) -> FailFuture<'_, (Self::T, Arc<dyn Resolve>)> {
325 Box::pin(async move {
326 let (data, resolve) = self.fetch_bytes().await?;
327 let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
328 Ok((encrypted, resolve))
329 })
330 }
331
332 fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
333 Box::pin(async move {
334 let (data, resolve) = self.fetch_bytes().await?;
335 let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
336 Ok(encrypted)
337 })
338 }
339}
340
341impl<K: Key> PointVisitor for ExtractResolution<'_, K> {
342 fn visit<T: Traversible>(&mut self, decrypted: &Point<T>) {
343 let decrypted = decrypted.clone();
344 let key = self.key.clone();
345 self.extracted.push(Box::pin(async move {
346 let encrypted = encrypt_point(key.clone(), decrypted).await?;
347 let encrypted = Point::from_fetch(
348 encrypted.hash(),
349 Arc::new(Untyped {
350 key: WithKey { key, extra: () },
351 encrypted,
352 }),
353 );
354 Ok(encrypted)
355 }));
356 }
357}
358
359pub async fn encrypt_point<K: Key, T: Traversible>(
360 key: K,
361 decrypted: Point<T>,
362) -> object_rainbow::Result<Point<Encrypted<K, T>>> {
363 if let Some((address, decrypt)) = decrypted.extract_resolve::<Decrypt<K>>() {
364 let encrypted = decrypt
365 .resolution
366 .get(address.index)
367 .ok_or(Error::AddressOutOfBounds)?
368 .clone();
369 let point = Point::from_fetch(
370 encrypted.hash(),
371 Arc::new(Visited {
372 decrypted,
373 encrypted,
374 }),
375 );
376 return Ok(point);
377 }
378 let decrypted = decrypted.fetch().await?;
379 let encrypted = encrypt(key.clone(), decrypted).await?;
380 let point = encrypted.point();
381 Ok(point)
382}
383
384pub async fn encrypt<K: Key, T: Traversible>(
385 key: K,
386 decrypted: T,
387) -> object_rainbow::Result<Encrypted<K, T>> {
388 let mut futures = Vec::with_capacity(decrypted.point_count());
389 decrypted.accept_points(&mut ExtractResolution {
390 extracted: &mut futures,
391 key: &key,
392 });
393 let resolution = futures_util::future::try_join_all(futures).await?;
394 let resolution = Arc::new(Lp(resolution));
395 let decrypted = Unkeyed(Arc::new(decrypted));
396 let inner = EncryptedInner {
397 resolution,
398 decrypted,
399 };
400 Ok(Encrypted { key, inner })
401}