Skip to main content

object_rainbow_encrypted/
lib.rs

1use std::{ops::Deref, sync::Arc};
2
3use object_rainbow::{
4    Address, ByteNode, Error, FailFuture, Fetch, FetchBytes, FullHash, Hash, ListHashes, Node,
5    Object, Parse, ParseSliceExtra, PointInput, PointVisitor, Resolve, Singular, SingularFetch,
6    Tagged, ToOutput, Topological, Traversible, length_prefixed::Lp,
7};
8use object_rainbow_point::{ExtractResolve, IntoPoint, Point};
9
10pub trait Key: 'static + Sized + Send + Sync + Clone {
11    fn encrypt(&self, data: &[u8]) -> Vec<u8>;
12    fn decrypt(&self, data: &[u8]) -> object_rainbow::Result<Vec<u8>>;
13}
14
15type Resolution<K> = Arc<Lp<Vec<Point<Encrypted<K, Vec<u8>>>>>>;
16
17#[derive(ToOutput, Clone)]
18struct Unkeyed<T>(T);
19
20impl<
21    T: Parse<I::WithExtra<Extra>>,
22    K: 'static + Clone,
23    Extra: 'static + Clone,
24    I: PointInput<Extra = (K, Extra)>,
25> Parse<I> for Unkeyed<T>
26{
27    fn parse(input: I) -> object_rainbow::Result<Self> {
28        Ok(Self(T::parse(input.map_extra(|(_, extra)| extra))?))
29    }
30}
31
32#[derive(ToOutput, Parse)]
33struct EncryptedInner<K, T> {
34    resolution: Resolution<K>,
35    decrypted: Unkeyed<Arc<T>>,
36}
37
38impl<K, T> Clone for EncryptedInner<K, T> {
39    fn clone(&self) -> Self {
40        Self {
41            resolution: self.resolution.clone(),
42            decrypted: self.decrypted.clone(),
43        }
44    }
45}
46
47type ResolutionIter<'a, K> = std::slice::Iter<'a, Point<Encrypted<K, Vec<u8>>>>;
48
49struct IterateResolution<'a, 'r, K, V> {
50    resolution: &'r mut ResolutionIter<'a, K>,
51    visitor: &'a mut V,
52}
53
54struct Visited<K, P> {
55    decrypted: P,
56    encrypted: Point<Encrypted<K, Vec<u8>>>,
57}
58
59impl<K, P> FetchBytes for Visited<K, P> {
60    fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
61        self.encrypted.fetch_bytes()
62    }
63
64    fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
65        self.encrypted.fetch_data()
66    }
67
68    fn fetch_bytes_local(&self) -> object_rainbow::Result<Option<ByteNode>> {
69        self.encrypted.fetch_bytes_local()
70    }
71
72    fn fetch_data_local(&self) -> Option<Vec<u8>> {
73        self.encrypted.fetch_data_local()
74    }
75}
76
77impl<K, P: Send + Sync> Singular for Visited<K, P> {
78    fn hash(&self) -> Hash {
79        self.encrypted.hash()
80    }
81}
82
83impl<K: Key, P: Fetch<T: Traversible>> Fetch for Visited<K, P> {
84    type T = Encrypted<K, P::T>;
85
86    fn fetch_full(&'_ self) -> FailFuture<'_, Node<Self::T>> {
87        Box::pin(async move {
88            let (
89                Encrypted {
90                    key,
91                    inner:
92                        EncryptedInner {
93                            resolution,
94                            decrypted: _,
95                        },
96                },
97                resolve,
98            ) = self.encrypted.fetch_full().await?;
99            let decrypted = self.decrypted.fetch().await?;
100            let decrypted = Unkeyed(Arc::new(decrypted));
101            Ok((
102                Encrypted {
103                    key,
104                    inner: EncryptedInner {
105                        resolution,
106                        decrypted,
107                    },
108                },
109                resolve,
110            ))
111        })
112    }
113
114    fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
115        Box::pin(async move {
116            let Encrypted {
117                key,
118                inner:
119                    EncryptedInner {
120                        resolution,
121                        decrypted: _,
122                    },
123            } = self.encrypted.fetch().await?;
124            let decrypted = self.decrypted.fetch().await?;
125            let decrypted = Unkeyed(Arc::new(decrypted));
126            Ok(Encrypted {
127                key,
128                inner: EncryptedInner {
129                    resolution,
130                    decrypted,
131                },
132            })
133        })
134    }
135
136    fn try_fetch_local(&self) -> object_rainbow::Result<Option<Node<Self::T>>> {
137        let Some((
138            Encrypted {
139                key,
140                inner:
141                    EncryptedInner {
142                        resolution,
143                        decrypted: _,
144                    },
145            },
146            resolve,
147        )) = self.encrypted.try_fetch_local()?
148        else {
149            return Ok(None);
150        };
151        let Some((decrypted, _)) = self.decrypted.try_fetch_local()? else {
152            return Ok(None);
153        };
154        let decrypted = Unkeyed(Arc::new(decrypted));
155        Ok(Some((
156            Encrypted {
157                key,
158                inner: EncryptedInner {
159                    resolution,
160                    decrypted,
161                },
162            },
163            resolve,
164        )))
165    }
166
167    fn fetch_local(&self) -> Option<Self::T> {
168        let Encrypted {
169            key,
170            inner:
171                EncryptedInner {
172                    resolution,
173                    decrypted: _,
174                },
175        } = self.encrypted.fetch_local()?;
176        let decrypted = Unkeyed(Arc::new(self.decrypted.fetch_local()?));
177        Some(Encrypted {
178            key,
179            inner: EncryptedInner {
180                resolution,
181                decrypted,
182            },
183        })
184    }
185}
186
187impl<'a, K: Key, V: PointVisitor> PointVisitor for IterateResolution<'a, '_, K, V> {
188    fn visit<T: Traversible>(&mut self, decrypted: &(impl 'static + SingularFetch<T = T> + Clone)) {
189        let decrypted = decrypted.clone();
190        let encrypted = self.resolution.next().expect("length mismatch").clone();
191        let point = Point::from_fetch(
192            encrypted.hash(),
193            Arc::new(Visited {
194                decrypted,
195                encrypted,
196            }),
197        );
198        self.visitor.visit(&point);
199    }
200}
201
202impl<K, T> ListHashes for EncryptedInner<K, T> {
203    fn list_hashes(&self, f: &mut impl FnMut(Hash)) {
204        self.resolution.list_hashes(f);
205    }
206
207    fn topology_hash(&self) -> Hash {
208        self.resolution.0.data_hash()
209    }
210
211    fn point_count(&self) -> usize {
212        self.resolution.len()
213    }
214}
215
216impl<K: Key, T: Topological> Topological for EncryptedInner<K, T> {
217    fn traverse(&self, visitor: &mut impl PointVisitor) {
218        let resolution = &mut self.resolution.iter();
219        self.decrypted.0.traverse(&mut IterateResolution {
220            resolution,
221            visitor,
222        });
223        assert!(resolution.next().is_none());
224    }
225}
226
227pub struct Encrypted<K, T> {
228    key: K,
229    inner: EncryptedInner<K, T>,
230}
231
232impl<K, T: Clone> Encrypted<K, T> {
233    pub fn into_inner(self) -> T {
234        Arc::unwrap_or_clone(self.inner.decrypted.0)
235    }
236}
237
238impl<K, T> Deref for Encrypted<K, T> {
239    type Target = T;
240
241    fn deref(&self) -> &Self::Target {
242        self.inner.decrypted.0.as_ref()
243    }
244}
245
246impl<K: Clone, T> Clone for Encrypted<K, T> {
247    fn clone(&self) -> Self {
248        Self {
249            key: self.key.clone(),
250            inner: self.inner.clone(),
251        }
252    }
253}
254
255impl<K, T> ListHashes for Encrypted<K, T> {
256    fn list_hashes(&self, f: &mut impl FnMut(Hash)) {
257        self.inner.list_hashes(f);
258    }
259
260    fn topology_hash(&self) -> Hash {
261        self.inner.topology_hash()
262    }
263
264    fn point_count(&self) -> usize {
265        self.inner.point_count()
266    }
267}
268
269impl<K: Key, T: Topological> Topological for Encrypted<K, T> {
270    fn traverse(&self, visitor: &mut impl PointVisitor) {
271        self.inner.traverse(visitor);
272    }
273}
274
275impl<K: Key, T: ToOutput> ToOutput for Encrypted<K, T> {
276    fn to_output(&self, output: &mut dyn object_rainbow::Output) {
277        let source = self.inner.vec();
278        output.write(&self.key.encrypt(&source));
279    }
280}
281
282#[derive(Clone)]
283struct Decrypt<K> {
284    resolution: Resolution<K>,
285}
286
287impl<K: Key> Decrypt<K> {
288    async fn resolve_bytes(
289        &self,
290        address: Address,
291    ) -> object_rainbow::Result<(Vec<u8>, Resolution<K>)> {
292        let Encrypted {
293            key: _,
294            inner:
295                EncryptedInner {
296                    resolution,
297                    decrypted,
298                },
299        } = self
300            .resolution
301            .get(address.index)
302            .ok_or(Error::AddressOutOfBounds)?
303            .clone()
304            .fetch()
305            .await?;
306        let data = Arc::unwrap_or_clone(decrypted.0);
307        Ok((data, resolution))
308    }
309}
310
311impl<K: Key> Resolve for Decrypt<K> {
312    fn resolve(&'_ self, address: Address) -> FailFuture<'_, ByteNode> {
313        Box::pin(async move {
314            let (data, resolution) = self.resolve_bytes(address).await?;
315            Ok((data, Arc::new(Decrypt { resolution }) as _))
316        })
317    }
318
319    fn resolve_data(&'_ self, address: Address) -> FailFuture<'_, Vec<u8>> {
320        Box::pin(async move {
321            let (data, _) = self.resolve_bytes(address).await?;
322            Ok(data)
323        })
324    }
325
326    fn try_resolve_local(&self, address: Address) -> object_rainbow::Result<Option<ByteNode>> {
327        let Some((
328            Encrypted {
329                key: _,
330                inner:
331                    EncryptedInner {
332                        resolution,
333                        decrypted,
334                    },
335            },
336            _,
337        )) = self
338            .resolution
339            .get(address.index)
340            .ok_or(Error::AddressOutOfBounds)?
341            .clone()
342            .try_fetch_local()?
343        else {
344            return Ok(None);
345        };
346        let data = Arc::unwrap_or_clone(decrypted.0);
347        Ok(Some((data, Arc::new(Decrypt { resolution }) as _)))
348    }
349}
350
/// Extra data accepted when parsing an [`Encrypted`] value. It is either a bare
/// key `K` or a `(K, Extra)` pair whose second half is forwarded to the
/// decrypted value's parser.
trait EncryptedExtra<K>: 'static + Send + Sync + Clone {
    /// Extra forwarded to the decrypted value's parser.
    type Extra: 'static + Send + Sync + Clone;

    /// Splits `self` into an owned key and an owned inner extra.
    fn parts(&self) -> (K, Self::Extra);
}

impl<K: 'static + Send + Sync + Clone, Extra: 'static + Send + Sync + Clone> EncryptedExtra<K>
    for (K, Extra)
{
    type Extra = Extra;

    fn parts(&self) -> (K, Self::Extra) {
        let (key, extra) = self;
        (key.clone(), extra.clone())
    }
}

// A bare key carries no inner extra.
impl<K: 'static + Send + Sync + Clone> EncryptedExtra<K> for K {
    type Extra = ();

    fn parts(&self) -> (K, Self::Extra) {
        let key = self.clone();
        (key, ())
    }
}
373
374impl<
375    K: Key,
376    T: Object<Extra>,
377    Extra: 'static + Send + Sync + Clone,
378    I: PointInput<Extra: EncryptedExtra<K, Extra = Extra>>,
379> Parse<I> for Encrypted<K, T>
380{
381    fn parse(input: I) -> object_rainbow::Result<Self> {
382        let with_key = input.extra().parts();
383        let resolve = input.resolve().clone();
384        let source = with_key.0.decrypt(&input.parse_all()?)?;
385        let EncryptedInner {
386            resolution,
387            decrypted,
388        } = EncryptedInner::<K, Vec<u8>>::parse_slice_extra(&source, &resolve, &with_key)?;
389        let decrypted = T::parse_slice_extra(
390            &decrypted.0,
391            &(Arc::new(Decrypt {
392                resolution: resolution.clone(),
393            }) as _),
394            &with_key.1,
395        )?;
396        let decrypted = Unkeyed(Arc::new(decrypted));
397        let inner = EncryptedInner {
398            resolution,
399            decrypted,
400        };
401        Ok(Self {
402            key: with_key.0,
403            inner,
404        })
405    }
406}
407
408impl<K, T> Tagged for Encrypted<K, T> {}
409
410type Extracted<K> = Vec<
411    std::pin::Pin<
412        Box<dyn Future<Output = Result<Point<Encrypted<K, Vec<u8>>>, Error>> + Send + 'static>,
413    >,
414>;
415
416struct ExtractResolution<'a, K> {
417    extracted: &'a mut Extracted<K>,
418    key: &'a K,
419}
420
421struct Untyped<K, T> {
422    key: K,
423    encrypted: Point<Encrypted<K, T>>,
424}
425
426impl<K, T> FetchBytes for Untyped<K, T> {
427    fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
428        self.encrypted.fetch_bytes()
429    }
430
431    fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
432        self.encrypted.fetch_data()
433    }
434
435    fn fetch_bytes_local(&self) -> object_rainbow::Result<Option<ByteNode>> {
436        self.encrypted.fetch_bytes_local()
437    }
438
439    fn fetch_data_local(&self) -> Option<Vec<u8>> {
440        self.encrypted.fetch_data_local()
441    }
442}
443
444impl<K: Send + Sync, T> Singular for Untyped<K, T> {
445    fn hash(&self) -> Hash {
446        self.encrypted.hash()
447    }
448}
449
450impl<K: Key, T: FullHash> Fetch for Untyped<K, T> {
451    type T = Encrypted<K, Vec<u8>>;
452
453    fn fetch_full(&'_ self) -> FailFuture<'_, Node<Self::T>> {
454        Box::pin(async move {
455            let (data, resolve) = self.fetch_bytes().await?;
456            let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
457            Ok((encrypted, resolve))
458        })
459    }
460
461    fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
462        Box::pin(async move {
463            let (data, resolve) = self.fetch_bytes().await?;
464            let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
465            Ok(encrypted)
466        })
467    }
468
469    fn try_fetch_local(&self) -> object_rainbow::Result<Option<Node<Self::T>>> {
470        let Some((data, resolve)) = self.fetch_bytes_local()? else {
471            return Ok(None);
472        };
473        let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
474        Ok(Some((encrypted, resolve)))
475    }
476
477    fn fetch_local(&self) -> Option<Self::T> {
478        let Encrypted {
479            key,
480            inner:
481                EncryptedInner {
482                    resolution,
483                    decrypted,
484                },
485        } = self.encrypted.fetch_local()?;
486        let decrypted = Unkeyed(Arc::new(decrypted.vec()));
487        Some(Encrypted {
488            key,
489            inner: EncryptedInner {
490                resolution,
491                decrypted,
492            },
493        })
494    }
495}
496
497impl<K: Key> PointVisitor for ExtractResolution<'_, K> {
498    fn visit<T: Traversible>(&mut self, decrypted: &(impl 'static + SingularFetch<T = T> + Clone)) {
499        let decrypted = decrypted.clone();
500        let key = self.key.clone();
501        self.extracted.push(Box::pin(async move {
502            let encrypted = encrypt_point(key.clone(), decrypted).await?;
503            let encrypted =
504                Point::from_fetch(encrypted.hash(), Arc::new(Untyped { key, encrypted }));
505            Ok(encrypted)
506        }));
507    }
508}
509
510pub async fn encrypt_point<K: Key, T: Traversible>(
511    key: K,
512    decrypted: impl 'static + SingularFetch<T = T>,
513) -> object_rainbow::Result<Point<Encrypted<K, T>>> {
514    if let Some((address, decrypt)) = decrypted.extract_resolve::<Decrypt<K>>() {
515        let encrypted = decrypt
516            .resolution
517            .get(address.index)
518            .ok_or(Error::AddressOutOfBounds)?
519            .clone();
520        let point = Point::from_fetch(
521            encrypted.hash(),
522            Arc::new(Visited {
523                decrypted,
524                encrypted,
525            }),
526        );
527        return Ok(point);
528    }
529    let decrypted = decrypted.fetch().await?;
530    let encrypted = encrypt(key.clone(), decrypted).await?;
531    let point = encrypted.point();
532    Ok(point)
533}
534
535pub async fn encrypt<K: Key, T: Traversible>(
536    key: K,
537    decrypted: T,
538) -> object_rainbow::Result<Encrypted<K, T>> {
539    let mut futures = Vec::with_capacity(decrypted.point_count());
540    decrypted.traverse(&mut ExtractResolution {
541        extracted: &mut futures,
542        key: &key,
543    });
544    let resolution = futures_util::future::try_join_all(futures).await?;
545    let resolution = Arc::new(Lp(resolution));
546    let decrypted = Unkeyed(Arc::new(decrypted));
547    let inner = EncryptedInner {
548        resolution,
549        decrypted,
550    };
551    Ok(Encrypted { key, inner })
552}