Skip to main content

object_rainbow_encrypted/
lib.rs

1use std::{ops::Deref, sync::Arc};
2
3use object_rainbow::{
4    Address, ByteNode, Error, FailFuture, Fetch, FetchBytes, FullHash, Hash, ListHashes, Node,
5    Object, Parse, ParseSliceExtra, PointInput, PointVisitor, Resolve, Singular, SingularFetch,
6    Tagged, ToOutput, Topological, Traversible, length_prefixed::Lp,
7};
8use object_rainbow_point::{ExtractResolve, IntoPoint, Point};
9
/// Encryption/decryption capability used by [`Encrypted`] to turn serialized objects into
/// ciphertext and back.
pub trait Key: 'static + Sized + Send + Sync + Clone {
    /// Error reported by [`Key::decrypt`]; must be convertible into [`anyhow::Error`].
    type Error: Into<anyhow::Error>;
    /// Encrypts `data` and returns the resulting bytes. The signature leaves no room for failure.
    fn encrypt(&self, data: &[u8]) -> Vec<u8>;
    /// Decrypts `data`, returning the recovered bytes or [`Key::Error`].
    fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>, Self::Error>;
}
15
/// Shared list (wrapped in `Lp`) of points to encrypted objects whose payload is kept as raw
/// bytes. `Decrypt` indexes into it with `Address::index`.
type Resolution<K> = Arc<Lp<Vec<Point<Encrypted<K, Vec<u8>>>>>>;
17
/// Wrapper around `T` whose `Parse` impl accepts an input carrying a `(key, extra)` pair and
/// parses `T` with only the `extra` part.
#[derive(ToOutput, Clone)]
struct Unkeyed<T>(T);
20
21impl<
22    T: Parse<I::WithExtra<Extra>>,
23    K: 'static + Clone,
24    Extra: 'static + Clone,
25    I: PointInput<Extra = (K, Extra)>,
26> Parse<I> for Unkeyed<T>
27{
28    fn parse(input: I) -> object_rainbow::Result<Self> {
29        Ok(Self(T::parse(input.map_extra(|(_, extra)| extra))?))
30    }
31}
32
/// The part of an [`Encrypted`] value that is serialized (and then encrypted as a whole).
#[derive(ToOutput, Parse)]
struct EncryptedInner<K, T> {
    /// Encrypted points, consumed one per point visited when traversing `decrypted`.
    resolution: Resolution<K>,
    /// The plaintext value; parsed without the key in its extra (see `Unkeyed`).
    decrypted: Unkeyed<Arc<T>>,
}
38
39impl<K, T> Clone for EncryptedInner<K, T> {
40    fn clone(&self) -> Self {
41        Self {
42            resolution: self.resolution.clone(),
43            decrypted: self.decrypted.clone(),
44        }
45    }
46}
47
/// Borrowing iterator over the encrypted points of a resolution list.
type ResolutionIter<'a, K> = std::slice::Iter<'a, Point<Encrypted<K, Vec<u8>>>>;
49
/// Visitor adapter that pairs every visited plaintext point with the next encrypted point from
/// `resolution` and forwards the combined point to `visitor`.
struct IterateResolution<'a, 'r, K, V> {
    /// Remaining encrypted points; advanced by one on every visit.
    resolution: &'r mut ResolutionIter<'a, K>,
    /// The wrapped visitor that receives the combined points.
    visitor: &'a mut V,
}
54
/// A plaintext point paired with its encrypted counterpart: bytes and hash are served from
/// `encrypted`, the typed value from `decrypted`.
struct Visited<K, P> {
    /// Fetchable source of the plaintext value.
    decrypted: P,
    /// Point to the encrypted object with a raw-bytes payload.
    encrypted: Point<Encrypted<K, Vec<u8>>>,
}
59
60impl<K, P> FetchBytes for Visited<K, P> {
61    fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
62        self.encrypted.fetch_bytes()
63    }
64
65    fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
66        self.encrypted.fetch_data()
67    }
68
69    fn fetch_bytes_local(&self) -> object_rainbow::Result<Option<ByteNode>> {
70        self.encrypted.fetch_bytes_local()
71    }
72
73    fn fetch_data_local(&self) -> Option<Vec<u8>> {
74        self.encrypted.fetch_data_local()
75    }
76}
77
78impl<K, P: Send + Sync> Singular for Visited<K, P> {
79    fn hash(&self) -> Hash {
80        self.encrypted.hash()
81    }
82}
83
84impl<K: Key, P: Fetch<T: Traversible>> Fetch for Visited<K, P> {
85    type T = Encrypted<K, P::T>;
86
87    fn fetch_full(&'_ self) -> FailFuture<'_, Node<Self::T>> {
88        Box::pin(async move {
89            let (
90                Encrypted {
91                    key,
92                    inner:
93                        EncryptedInner {
94                            resolution,
95                            decrypted: _,
96                        },
97                },
98                resolve,
99            ) = self.encrypted.fetch_full().await?;
100            let decrypted = self.decrypted.fetch().await?;
101            let decrypted = Unkeyed(Arc::new(decrypted));
102            Ok((
103                Encrypted {
104                    key,
105                    inner: EncryptedInner {
106                        resolution,
107                        decrypted,
108                    },
109                },
110                resolve,
111            ))
112        })
113    }
114
115    fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
116        Box::pin(async move {
117            let Encrypted {
118                key,
119                inner:
120                    EncryptedInner {
121                        resolution,
122                        decrypted: _,
123                    },
124            } = self.encrypted.fetch().await?;
125            let decrypted = self.decrypted.fetch().await?;
126            let decrypted = Unkeyed(Arc::new(decrypted));
127            Ok(Encrypted {
128                key,
129                inner: EncryptedInner {
130                    resolution,
131                    decrypted,
132                },
133            })
134        })
135    }
136
137    fn try_fetch_local(&self) -> object_rainbow::Result<Option<Node<Self::T>>> {
138        let Some((
139            Encrypted {
140                key,
141                inner:
142                    EncryptedInner {
143                        resolution,
144                        decrypted: _,
145                    },
146            },
147            resolve,
148        )) = self.encrypted.try_fetch_local()?
149        else {
150            return Ok(None);
151        };
152        let Some((decrypted, _)) = self.decrypted.try_fetch_local()? else {
153            return Ok(None);
154        };
155        let decrypted = Unkeyed(Arc::new(decrypted));
156        Ok(Some((
157            Encrypted {
158                key,
159                inner: EncryptedInner {
160                    resolution,
161                    decrypted,
162                },
163            },
164            resolve,
165        )))
166    }
167
168    fn fetch_local(&self) -> Option<Self::T> {
169        let Encrypted {
170            key,
171            inner:
172                EncryptedInner {
173                    resolution,
174                    decrypted: _,
175                },
176        } = self.encrypted.fetch_local()?;
177        let decrypted = Unkeyed(Arc::new(self.decrypted.fetch_local()?));
178        Some(Encrypted {
179            key,
180            inner: EncryptedInner {
181                resolution,
182                decrypted,
183            },
184        })
185    }
186}
187
188impl<'a, K: Key, V: PointVisitor> PointVisitor for IterateResolution<'a, '_, K, V> {
189    fn visit<T: Traversible>(&mut self, decrypted: &(impl 'static + SingularFetch<T = T> + Clone)) {
190        let decrypted = decrypted.clone();
191        let encrypted = self.resolution.next().expect("length mismatch").clone();
192        let point = Point::from_fetch(
193            encrypted.hash(),
194            Visited {
195                decrypted,
196                encrypted,
197            }
198            .into_dyn_fetch(),
199        );
200        self.visitor.visit(&point);
201    }
202}
203
204impl<K, T> ListHashes for EncryptedInner<K, T> {
205    fn list_hashes(&self, f: &mut impl FnMut(Hash)) {
206        self.resolution.list_hashes(f);
207    }
208
209    fn topology_hash(&self) -> Hash {
210        self.resolution.0.data_hash()
211    }
212
213    fn point_count(&self) -> usize {
214        self.resolution.len()
215    }
216}
217
218impl<K: Key, T: Topological> Topological for EncryptedInner<K, T> {
219    fn traverse(&self, visitor: &mut impl PointVisitor) {
220        let resolution = &mut self.resolution.iter();
221        self.decrypted.0.traverse(&mut IterateResolution {
222            resolution,
223            visitor,
224        });
225        assert!(resolution.next().is_none());
226    }
227}
228
/// A value of type `T` together with the key used to encrypt it when it is written out.
///
/// Dereferences to the plaintext `T`.
pub struct Encrypted<K, T> {
    /// Key applied to the serialized `inner` on output.
    key: K,
    /// Resolution list plus plaintext value; this is what gets serialized and encrypted.
    inner: EncryptedInner<K, T>,
}
233
234impl<K, T: Clone> Encrypted<K, T> {
235    pub fn into_inner(self) -> T {
236        Arc::unwrap_or_clone(self.inner.decrypted.0)
237    }
238}
239
240impl<K, T> Deref for Encrypted<K, T> {
241    type Target = T;
242
243    fn deref(&self) -> &Self::Target {
244        self.inner.decrypted.0.as_ref()
245    }
246}
247
248impl<K: Clone, T> Clone for Encrypted<K, T> {
249    fn clone(&self) -> Self {
250        Self {
251            key: self.key.clone(),
252            inner: self.inner.clone(),
253        }
254    }
255}
256
257impl<K, T> ListHashes for Encrypted<K, T> {
258    fn list_hashes(&self, f: &mut impl FnMut(Hash)) {
259        self.inner.list_hashes(f);
260    }
261
262    fn topology_hash(&self) -> Hash {
263        self.inner.topology_hash()
264    }
265
266    fn point_count(&self) -> usize {
267        self.inner.point_count()
268    }
269}
270
271impl<K: Key, T: Topological> Topological for Encrypted<K, T> {
272    fn traverse(&self, visitor: &mut impl PointVisitor) {
273        self.inner.traverse(visitor);
274    }
275}
276
277impl<K: Key, T: ToOutput> ToOutput for Encrypted<K, T> {
278    fn to_output(&self, output: &mut dyn object_rainbow::Output) {
279        let source = self.inner.vec();
280        output.write(&self.key.encrypt(&source));
281    }
282}
283
/// Resolver backed by a resolution list: an address is resolved by fetching the encrypted
/// point at `address.index` and handing out its decrypted bytes.
#[derive(Clone)]
struct Decrypt<K> {
    /// Encrypted points addressable by index.
    resolution: Resolution<K>,
}
288
289impl<K: Key> Decrypt<K> {
290    async fn resolve_bytes(
291        &self,
292        address: Address,
293    ) -> object_rainbow::Result<(Vec<u8>, Resolution<K>)> {
294        let Encrypted {
295            key: _,
296            inner:
297                EncryptedInner {
298                    resolution,
299                    decrypted,
300                },
301        } = self
302            .resolution
303            .get(address.index)
304            .ok_or(Error::AddressOutOfBounds)?
305            .clone()
306            .fetch()
307            .await?;
308        let data = Arc::unwrap_or_clone(decrypted.0);
309        Ok((data, resolution))
310    }
311}
312
313impl<K: Key> Resolve for Decrypt<K> {
314    fn resolve(&'_ self, address: Address, _: &Arc<dyn Resolve>) -> FailFuture<'_, ByteNode> {
315        Box::pin(async move {
316            let (data, resolution) = self.resolve_bytes(address).await?;
317            Ok((data, Arc::new(Decrypt { resolution }) as _))
318        })
319    }
320
321    fn resolve_data(&'_ self, address: Address) -> FailFuture<'_, Vec<u8>> {
322        Box::pin(async move {
323            let (data, _) = self.resolve_bytes(address).await?;
324            Ok(data)
325        })
326    }
327
328    fn try_resolve_local(
329        &self,
330        address: Address,
331        _: &Arc<dyn Resolve>,
332    ) -> object_rainbow::Result<Option<ByteNode>> {
333        let Some((
334            Encrypted {
335                key: _,
336                inner:
337                    EncryptedInner {
338                        resolution,
339                        decrypted,
340                    },
341            },
342            _,
343        )) = self
344            .resolution
345            .get(address.index)
346            .ok_or(Error::AddressOutOfBounds)?
347            .clone()
348            .try_fetch_local()?
349        else {
350            return Ok(None);
351        };
352        let data = Arc::unwrap_or_clone(decrypted.0);
353        Ok(Some((data, Arc::new(Decrypt { resolution }) as _)))
354    }
355}
356
/// Parse-extra values that can be split into a key of type `K` and a remaining extra.
trait EncryptedExtra<K>: 'static + Send + Sync + Clone {
    /// The extra that is left once the key has been taken out.
    type Extra: 'static + Send + Sync + Clone;
    /// Returns owned copies of the key and of the remaining extra.
    fn parts(&self) -> (K, Self::Extra);
}
361
362impl<K: 'static + Send + Sync + Clone, Extra: 'static + Send + Sync + Clone> EncryptedExtra<K>
363    for (K, Extra)
364{
365    type Extra = Extra;
366
367    fn parts(&self) -> (K, Self::Extra) {
368        self.clone()
369    }
370}
371
372impl<K: 'static + Send + Sync + Clone> EncryptedExtra<K> for K {
373    type Extra = ();
374
375    fn parts(&self) -> (K, Self::Extra) {
376        (self.clone(), ())
377    }
378}
379
380impl<
381    K: Key,
382    T: Object<Extra>,
383    Extra: 'static + Send + Sync + Clone,
384    I: PointInput<Extra: EncryptedExtra<K, Extra = Extra>>,
385> Parse<I> for Encrypted<K, T>
386{
387    fn parse(input: I) -> object_rainbow::Result<Self> {
388        let with_key = input.extra().parts();
389        let resolve = input.resolve().clone();
390        let source = with_key
391            .0
392            .decrypt(&input.parse_all()?)
393            .map_err(object_rainbow::Error::consistency)?;
394        let EncryptedInner {
395            resolution,
396            decrypted,
397        } = EncryptedInner::<K, Vec<u8>>::parse_slice_extra(&source, &resolve, &with_key)?;
398        let decrypted = T::parse_slice_extra(
399            &decrypted.0,
400            &(Arc::new(Decrypt {
401                resolution: resolution.clone(),
402            }) as _),
403            &with_key.1,
404        )?;
405        let decrypted = Unkeyed(Arc::new(decrypted));
406        let inner = EncryptedInner {
407            resolution,
408            decrypted,
409        };
410        Ok(Self {
411            key: with_key.0,
412            inner,
413        })
414    }
415}
416
/// `Encrypted` overrides nothing from `Tagged`; the empty impl relies on the trait's defaults.
impl<K, T> Tagged for Encrypted<K, T> {}
418
/// Collection of boxed, `Send`, `'static` futures, each producing a point to an encrypted
/// object with a raw-bytes payload (or an `Error`).
type Extracted<K> = Vec<
    std::pin::Pin<
        Box<dyn Future<Output = Result<Point<Encrypted<K, Vec<u8>>>, Error>> + Send + 'static>,
    >,
>;
424
/// Visitor that queues, for every visited point, a future which encrypts that point.
struct ExtractResolution<'a, K> {
    /// Destination for the queued futures, in visiting order.
    extracted: &'a mut Extracted<K>,
    /// Key cloned into each queued future.
    key: &'a K,
}
429
/// Presents a typed encrypted point as one whose fetched value is `Encrypted<K, Vec<u8>>`.
struct Untyped<K, T> {
    /// Key passed as the parse extra when the fetched bytes are re-parsed.
    key: K,
    /// The underlying typed point; supplies bytes and hash.
    encrypted: Point<Encrypted<K, T>>,
}
434
435impl<K, T> FetchBytes for Untyped<K, T> {
436    fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
437        self.encrypted.fetch_bytes()
438    }
439
440    fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
441        self.encrypted.fetch_data()
442    }
443
444    fn fetch_bytes_local(&self) -> object_rainbow::Result<Option<ByteNode>> {
445        self.encrypted.fetch_bytes_local()
446    }
447
448    fn fetch_data_local(&self) -> Option<Vec<u8>> {
449        self.encrypted.fetch_data_local()
450    }
451}
452
453impl<K: Send + Sync, T> Singular for Untyped<K, T> {
454    fn hash(&self) -> Hash {
455        self.encrypted.hash()
456    }
457}
458
459impl<K: Key, T: FullHash> Fetch for Untyped<K, T> {
460    type T = Encrypted<K, Vec<u8>>;
461
462    fn fetch_full(&'_ self) -> FailFuture<'_, Node<Self::T>> {
463        Box::pin(async move {
464            let (data, resolve) = self.fetch_bytes().await?;
465            let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
466            Ok((encrypted, resolve))
467        })
468    }
469
470    fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
471        Box::pin(async move {
472            let (data, resolve) = self.fetch_bytes().await?;
473            let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
474            Ok(encrypted)
475        })
476    }
477
478    fn try_fetch_local(&self) -> object_rainbow::Result<Option<Node<Self::T>>> {
479        let Some((data, resolve)) = self.fetch_bytes_local()? else {
480            return Ok(None);
481        };
482        let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
483        Ok(Some((encrypted, resolve)))
484    }
485
486    fn fetch_local(&self) -> Option<Self::T> {
487        let Encrypted {
488            key,
489            inner:
490                EncryptedInner {
491                    resolution,
492                    decrypted,
493                },
494        } = self.encrypted.fetch_local()?;
495        let decrypted = Unkeyed(Arc::new(decrypted.vec()));
496        Some(Encrypted {
497            key,
498            inner: EncryptedInner {
499                resolution,
500                decrypted,
501            },
502        })
503    }
504}
505
506impl<K: Key> PointVisitor for ExtractResolution<'_, K> {
507    fn visit<T: Traversible>(&mut self, decrypted: &(impl 'static + SingularFetch<T = T> + Clone)) {
508        let decrypted = decrypted.clone();
509        let key = self.key.clone();
510        self.extracted.push(Box::pin(async move {
511            let encrypted = encrypt_point(key.clone(), decrypted).await?;
512            let encrypted = Point::from_fetch(
513                encrypted.hash(),
514                Untyped { key, encrypted }.into_dyn_fetch(),
515            );
516            Ok(encrypted)
517        }));
518    }
519}
520
521pub async fn encrypt_point<K: Key, T: Traversible>(
522    key: K,
523    decrypted: impl 'static + SingularFetch<T = T>,
524) -> object_rainbow::Result<Point<Encrypted<K, T>>> {
525    if let Some((address, decrypt)) = decrypted.extract_resolve::<Decrypt<K>>() {
526        let encrypted = decrypt
527            .resolution
528            .get(address.index)
529            .ok_or(Error::AddressOutOfBounds)?
530            .clone();
531        let point = Point::from_fetch(
532            encrypted.hash(),
533            Visited {
534                decrypted,
535                encrypted,
536            }
537            .into_dyn_fetch(),
538        );
539        return Ok(point);
540    }
541    let decrypted = decrypted.fetch().await?;
542    let encrypted = encrypt(key.clone(), decrypted).await?;
543    let point = encrypted.point();
544    Ok(point)
545}
546
547pub async fn encrypt<K: Key, T: Traversible>(
548    key: K,
549    decrypted: T,
550) -> object_rainbow::Result<Encrypted<K, T>> {
551    let mut futures = Vec::with_capacity(decrypted.point_count());
552    decrypted.traverse(&mut ExtractResolution {
553        extracted: &mut futures,
554        key: &key,
555    });
556    let resolution = futures_util::future::try_join_all(futures).await?;
557    let resolution = Arc::new(Lp(resolution));
558    let decrypted = Unkeyed(Arc::new(decrypted));
559    let inner = EncryptedInner {
560        resolution,
561        decrypted,
562    };
563    Ok(Encrypted { key, inner })
564}