// object_rainbow_encrypted/lib.rs

1use std::{ops::Deref, sync::Arc};
2
3use object_rainbow::{
4    Address, ByteNode, Error, FailFuture, Fetch, FetchBytes, FullHash, Hash, ListHashes, Node,
5    Object, Parse, ParseSliceExtra, PointInput, PointVisitor, Resolve, Singular, SingularFetch,
6    Tagged, ToOutput, Topological, Traversible, length_prefixed::Lp,
7};
8use object_rainbow_point::{ExtractResolve, IntoPoint, Point};
9
/// A key able to encrypt and decrypt byte buffers.
///
/// Implementors must be `'static`, `Send`, `Sync` and `Clone`.
pub trait Key: 'static + Sized + Send + Sync + Clone {
    /// Error produced by [`Key::decrypt`]; must be convertible into `anyhow::Error`.
    type Error: Into<anyhow::Error>;
    /// Encrypts `data` and returns the resulting bytes (infallible by signature).
    fn encrypt(&self, data: &[u8]) -> Vec<u8>;
    /// Decrypts `data`, returning the recovered bytes or `Self::Error` on failure.
    fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>, Self::Error>;
}
15
/// Shared (`Arc`) `Lp`-wrapped list of points to `Encrypted<K, Vec<u8>>` values.
type Resolution<K> = Arc<Lp<Vec<Point<Encrypted<K, Vec<u8>>>>>>;
17
/// Newtype around `T`. Its `Parse` impl in this file strips the key half of a `(K, Extra)`
/// parse extra before delegating to `T`.
#[derive(ToOutput, Clone)]
struct Unkeyed<T>(T);
20
21impl<
22    T: Parse<I::WithExtra<Extra>>,
23    K: 'static + Clone,
24    Extra: 'static + Clone,
25    I: PointInput<Extra = (K, Extra)>,
26> Parse<I> for Unkeyed<T>
27{
28    fn parse(input: I) -> object_rainbow::Result<Self> {
29        Ok(Self(T::parse(input.map_extra(|(_, extra)| extra))?))
30    }
31}
32
/// Contents of an [`Encrypted`] value apart from its key.
///
/// NOTE(review): the derived `ToOutput`/`Parse` impls presumably process fields in declaration
/// order, so reordering them would change the serialized layout — confirm against the derives.
#[derive(ToOutput, Parse)]
struct EncryptedInner<K, T> {
    /// Points to the encrypted counterparts of the points reachable from `decrypted`.
    resolution: Resolution<K>,
    /// The decrypted value, shared behind an `Arc`.
    decrypted: Unkeyed<Arc<T>>,
}
38
39impl<K, T> Clone for EncryptedInner<K, T> {
40    fn clone(&self) -> Self {
41        Self {
42            resolution: self.resolution.clone(),
43            decrypted: self.decrypted.clone(),
44        }
45    }
46}
47
/// Borrowing slice iterator over the points stored in a [`Resolution`].
type ResolutionIter<'a, K> = std::slice::Iter<'a, Point<Encrypted<K, Vec<u8>>>>;
49
/// `PointVisitor` adapter that pairs every visited decrypted point with the next encrypted
/// point taken from `resolution`, then forwards the pair to `visitor`.
struct IterateResolution<'a, 'r, K, V> {
    /// Cursor over the encrypted points; advanced once per visited point.
    resolution: &'r mut ResolutionIter<'a, K>,
    /// The wrapped visitor that receives the combined points.
    visitor: &'a mut V,
}
54
/// Pairs a decrypted fetchable `P` with the point to its encrypted form.
///
/// The impls in this file take bytes and hash from `encrypted`, and the typed payload from
/// `decrypted`.
struct Visited<K, P> {
    /// Source of the decrypted, typed value.
    decrypted: P,
    /// Point to the encrypted representation (payload kept as raw bytes).
    encrypted: Point<Encrypted<K, Vec<u8>>>,
}
59
60impl<K, P> FetchBytes for Visited<K, P> {
61    fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
62        self.encrypted.fetch_bytes()
63    }
64
65    fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
66        self.encrypted.fetch_data()
67    }
68
69    fn fetch_bytes_local(&self) -> object_rainbow::Result<Option<ByteNode>> {
70        self.encrypted.fetch_bytes_local()
71    }
72
73    fn fetch_data_local(&self) -> Option<Vec<u8>> {
74        self.encrypted.fetch_data_local()
75    }
76}
77
78impl<K, P: Send + Sync> Singular for Visited<K, P> {
79    fn hash(&self) -> Hash {
80        self.encrypted.hash()
81    }
82}
83
84impl<K: Key, P: Fetch<T: Traversible>> Fetch for Visited<K, P> {
85    type T = Encrypted<K, P::T>;
86
87    fn fetch_full(&'_ self) -> FailFuture<'_, Node<Self::T>> {
88        Box::pin(async move {
89            let (
90                Encrypted {
91                    key,
92                    inner:
93                        EncryptedInner {
94                            resolution,
95                            decrypted: _,
96                        },
97                },
98                resolve,
99            ) = self.encrypted.fetch_full().await?;
100            let decrypted = self.decrypted.fetch().await?;
101            let decrypted = Unkeyed(Arc::new(decrypted));
102            Ok((
103                Encrypted {
104                    key,
105                    inner: EncryptedInner {
106                        resolution,
107                        decrypted,
108                    },
109                },
110                resolve,
111            ))
112        })
113    }
114
115    fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
116        Box::pin(async move {
117            let Encrypted {
118                key,
119                inner:
120                    EncryptedInner {
121                        resolution,
122                        decrypted: _,
123                    },
124            } = self.encrypted.fetch().await?;
125            let decrypted = self.decrypted.fetch().await?;
126            let decrypted = Unkeyed(Arc::new(decrypted));
127            Ok(Encrypted {
128                key,
129                inner: EncryptedInner {
130                    resolution,
131                    decrypted,
132                },
133            })
134        })
135    }
136
137    fn try_fetch_local(&self) -> object_rainbow::Result<Option<Node<Self::T>>> {
138        let Some((
139            Encrypted {
140                key,
141                inner:
142                    EncryptedInner {
143                        resolution,
144                        decrypted: _,
145                    },
146            },
147            resolve,
148        )) = self.encrypted.try_fetch_local()?
149        else {
150            return Ok(None);
151        };
152        let Some((decrypted, _)) = self.decrypted.try_fetch_local()? else {
153            return Ok(None);
154        };
155        let decrypted = Unkeyed(Arc::new(decrypted));
156        Ok(Some((
157            Encrypted {
158                key,
159                inner: EncryptedInner {
160                    resolution,
161                    decrypted,
162                },
163            },
164            resolve,
165        )))
166    }
167
168    fn fetch_local(&self) -> Option<Self::T> {
169        let Encrypted {
170            key,
171            inner:
172                EncryptedInner {
173                    resolution,
174                    decrypted: _,
175                },
176        } = self.encrypted.fetch_local()?;
177        let decrypted = Unkeyed(Arc::new(self.decrypted.fetch_local()?));
178        Some(Encrypted {
179            key,
180            inner: EncryptedInner {
181                resolution,
182                decrypted,
183            },
184        })
185    }
186}
187
188impl<'a, K: Key, V: PointVisitor> PointVisitor for IterateResolution<'a, '_, K, V> {
189    fn visit<T: Traversible>(&mut self, decrypted: &(impl 'static + SingularFetch<T = T> + Clone)) {
190        let decrypted = decrypted.clone();
191        let encrypted = self.resolution.next().expect("length mismatch").clone();
192        let point = Point::from_fetch(
193            encrypted.hash(),
194            Arc::new(Visited {
195                decrypted,
196                encrypted,
197            }),
198        );
199        self.visitor.visit(&point);
200    }
201}
202
203impl<K, T> ListHashes for EncryptedInner<K, T> {
204    fn list_hashes(&self, f: &mut impl FnMut(Hash)) {
205        self.resolution.list_hashes(f);
206    }
207
208    fn topology_hash(&self) -> Hash {
209        self.resolution.0.data_hash()
210    }
211
212    fn point_count(&self) -> usize {
213        self.resolution.len()
214    }
215}
216
217impl<K: Key, T: Topological> Topological for EncryptedInner<K, T> {
218    fn traverse(&self, visitor: &mut impl PointVisitor) {
219        let resolution = &mut self.resolution.iter();
220        self.decrypted.0.traverse(&mut IterateResolution {
221            resolution,
222            visitor,
223        });
224        assert!(resolution.next().is_none());
225    }
226}
227
/// A value of type `T` together with the key used to encrypt it on output.
///
/// Dereferences to `T`; its `ToOutput` impl writes the key-encrypted serialization of `inner`.
pub struct Encrypted<K, T> {
    /// Key passed to `Key::encrypt` when the value is written out.
    key: K,
    /// Encrypted points plus the decrypted value.
    inner: EncryptedInner<K, T>,
}
232
233impl<K, T: Clone> Encrypted<K, T> {
234    pub fn into_inner(self) -> T {
235        Arc::unwrap_or_clone(self.inner.decrypted.0)
236    }
237}
238
239impl<K, T> Deref for Encrypted<K, T> {
240    type Target = T;
241
242    fn deref(&self) -> &Self::Target {
243        self.inner.decrypted.0.as_ref()
244    }
245}
246
247impl<K: Clone, T> Clone for Encrypted<K, T> {
248    fn clone(&self) -> Self {
249        Self {
250            key: self.key.clone(),
251            inner: self.inner.clone(),
252        }
253    }
254}
255
256impl<K, T> ListHashes for Encrypted<K, T> {
257    fn list_hashes(&self, f: &mut impl FnMut(Hash)) {
258        self.inner.list_hashes(f);
259    }
260
261    fn topology_hash(&self) -> Hash {
262        self.inner.topology_hash()
263    }
264
265    fn point_count(&self) -> usize {
266        self.inner.point_count()
267    }
268}
269
270impl<K: Key, T: Topological> Topological for Encrypted<K, T> {
271    fn traverse(&self, visitor: &mut impl PointVisitor) {
272        self.inner.traverse(visitor);
273    }
274}
275
276impl<K: Key, T: ToOutput> ToOutput for Encrypted<K, T> {
277    fn to_output(&self, output: &mut dyn object_rainbow::Output) {
278        let source = self.inner.vec();
279        output.write(&self.key.encrypt(&source));
280    }
281}
282
/// Resolver that answers address lookups by fetching the encrypted point at the requested
/// index and returning its decrypted bytes.
#[derive(Clone)]
struct Decrypt<K> {
    /// Encrypted points, indexed by `Address::index`.
    resolution: Resolution<K>,
}
287
288impl<K: Key> Decrypt<K> {
289    async fn resolve_bytes(
290        &self,
291        address: Address,
292    ) -> object_rainbow::Result<(Vec<u8>, Resolution<K>)> {
293        let Encrypted {
294            key: _,
295            inner:
296                EncryptedInner {
297                    resolution,
298                    decrypted,
299                },
300        } = self
301            .resolution
302            .get(address.index)
303            .ok_or(Error::AddressOutOfBounds)?
304            .clone()
305            .fetch()
306            .await?;
307        let data = Arc::unwrap_or_clone(decrypted.0);
308        Ok((data, resolution))
309    }
310}
311
312impl<K: Key> Resolve for Decrypt<K> {
313    fn resolve(&'_ self, address: Address) -> FailFuture<'_, ByteNode> {
314        Box::pin(async move {
315            let (data, resolution) = self.resolve_bytes(address).await?;
316            Ok((data, Arc::new(Decrypt { resolution }) as _))
317        })
318    }
319
320    fn resolve_data(&'_ self, address: Address) -> FailFuture<'_, Vec<u8>> {
321        Box::pin(async move {
322            let (data, _) = self.resolve_bytes(address).await?;
323            Ok(data)
324        })
325    }
326
327    fn try_resolve_local(&self, address: Address) -> object_rainbow::Result<Option<ByteNode>> {
328        let Some((
329            Encrypted {
330                key: _,
331                inner:
332                    EncryptedInner {
333                        resolution,
334                        decrypted,
335                    },
336            },
337            _,
338        )) = self
339            .resolution
340            .get(address.index)
341            .ok_or(Error::AddressOutOfBounds)?
342            .clone()
343            .try_fetch_local()?
344        else {
345            return Ok(None);
346        };
347        let data = Arc::unwrap_or_clone(decrypted.0);
348        Ok(Some((data, Arc::new(Decrypt { resolution }) as _)))
349    }
350}
351
/// Parse extras from which a key `K` and a remaining extra can be obtained.
///
/// Implemented in this file for `(K, Extra)` tuples and for a bare `K` (with `()` as the
/// remaining extra).
trait EncryptedExtra<K>: 'static + Send + Sync + Clone {
    /// The part of the extra left over once the key is taken out.
    type Extra: 'static + Send + Sync + Clone;
    /// Returns owned copies of the key and the remaining extra.
    fn parts(&self) -> (K, Self::Extra);
}
356
357impl<K: 'static + Send + Sync + Clone, Extra: 'static + Send + Sync + Clone> EncryptedExtra<K>
358    for (K, Extra)
359{
360    type Extra = Extra;
361
362    fn parts(&self) -> (K, Self::Extra) {
363        self.clone()
364    }
365}
366
367impl<K: 'static + Send + Sync + Clone> EncryptedExtra<K> for K {
368    type Extra = ();
369
370    fn parts(&self) -> (K, Self::Extra) {
371        (self.clone(), ())
372    }
373}
374
375impl<
376    K: Key,
377    T: Object<Extra>,
378    Extra: 'static + Send + Sync + Clone,
379    I: PointInput<Extra: EncryptedExtra<K, Extra = Extra>>,
380> Parse<I> for Encrypted<K, T>
381{
382    fn parse(input: I) -> object_rainbow::Result<Self> {
383        let with_key = input.extra().parts();
384        let resolve = input.resolve().clone();
385        let source = with_key
386            .0
387            .decrypt(&input.parse_all()?)
388            .map_err(object_rainbow::Error::consistency)?;
389        let EncryptedInner {
390            resolution,
391            decrypted,
392        } = EncryptedInner::<K, Vec<u8>>::parse_slice_extra(&source, &resolve, &with_key)?;
393        let decrypted = T::parse_slice_extra(
394            &decrypted.0,
395            &(Arc::new(Decrypt {
396                resolution: resolution.clone(),
397            }) as _),
398            &with_key.1,
399        )?;
400        let decrypted = Unkeyed(Arc::new(decrypted));
401        let inner = EncryptedInner {
402            resolution,
403            decrypted,
404        };
405        Ok(Self {
406            key: with_key.0,
407            inner,
408        })
409    }
410}
411
// Empty impl: `Encrypted` relies entirely on the items `Tagged` provides by default.
impl<K, T> Tagged for Encrypted<K, T> {}
413
/// List of pinned, boxed, `Send + 'static` futures, each producing either a point to an
/// `Encrypted<K, Vec<u8>>` or an `Error`.
type Extracted<K> = Vec<
    std::pin::Pin<
        Box<dyn Future<Output = Result<Point<Encrypted<K, Vec<u8>>>, Error>> + Send + 'static>,
    >,
>;
419
/// `PointVisitor` that queues one encryption future per visited point.
struct ExtractResolution<'a, K> {
    /// Destination for the queued futures, in visit order.
    extracted: &'a mut Extracted<K>,
    /// Key cloned into every queued future.
    key: &'a K,
}
424
/// Presents a point to `Encrypted<K, T>` as a fetchable `Encrypted<K, Vec<u8>>`, i.e. with the
/// payload type erased to raw bytes.
struct Untyped<K, T> {
    /// Key supplied as the parse extra when the fetched bytes are re-parsed.
    key: K,
    /// The typed encrypted point being wrapped.
    encrypted: Point<Encrypted<K, T>>,
}
429
430impl<K, T> FetchBytes for Untyped<K, T> {
431    fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
432        self.encrypted.fetch_bytes()
433    }
434
435    fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
436        self.encrypted.fetch_data()
437    }
438
439    fn fetch_bytes_local(&self) -> object_rainbow::Result<Option<ByteNode>> {
440        self.encrypted.fetch_bytes_local()
441    }
442
443    fn fetch_data_local(&self) -> Option<Vec<u8>> {
444        self.encrypted.fetch_data_local()
445    }
446}
447
448impl<K: Send + Sync, T> Singular for Untyped<K, T> {
449    fn hash(&self) -> Hash {
450        self.encrypted.hash()
451    }
452}
453
454impl<K: Key, T: FullHash> Fetch for Untyped<K, T> {
455    type T = Encrypted<K, Vec<u8>>;
456
457    fn fetch_full(&'_ self) -> FailFuture<'_, Node<Self::T>> {
458        Box::pin(async move {
459            let (data, resolve) = self.fetch_bytes().await?;
460            let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
461            Ok((encrypted, resolve))
462        })
463    }
464
465    fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
466        Box::pin(async move {
467            let (data, resolve) = self.fetch_bytes().await?;
468            let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
469            Ok(encrypted)
470        })
471    }
472
473    fn try_fetch_local(&self) -> object_rainbow::Result<Option<Node<Self::T>>> {
474        let Some((data, resolve)) = self.fetch_bytes_local()? else {
475            return Ok(None);
476        };
477        let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
478        Ok(Some((encrypted, resolve)))
479    }
480
481    fn fetch_local(&self) -> Option<Self::T> {
482        let Encrypted {
483            key,
484            inner:
485                EncryptedInner {
486                    resolution,
487                    decrypted,
488                },
489        } = self.encrypted.fetch_local()?;
490        let decrypted = Unkeyed(Arc::new(decrypted.vec()));
491        Some(Encrypted {
492            key,
493            inner: EncryptedInner {
494                resolution,
495                decrypted,
496            },
497        })
498    }
499}
500
501impl<K: Key> PointVisitor for ExtractResolution<'_, K> {
502    fn visit<T: Traversible>(&mut self, decrypted: &(impl 'static + SingularFetch<T = T> + Clone)) {
503        let decrypted = decrypted.clone();
504        let key = self.key.clone();
505        self.extracted.push(Box::pin(async move {
506            let encrypted = encrypt_point(key.clone(), decrypted).await?;
507            let encrypted =
508                Point::from_fetch(encrypted.hash(), Arc::new(Untyped { key, encrypted }));
509            Ok(encrypted)
510        }));
511    }
512}
513
514pub async fn encrypt_point<K: Key, T: Traversible>(
515    key: K,
516    decrypted: impl 'static + SingularFetch<T = T>,
517) -> object_rainbow::Result<Point<Encrypted<K, T>>> {
518    if let Some((address, decrypt)) = decrypted.extract_resolve::<Decrypt<K>>() {
519        let encrypted = decrypt
520            .resolution
521            .get(address.index)
522            .ok_or(Error::AddressOutOfBounds)?
523            .clone();
524        let point = Point::from_fetch(
525            encrypted.hash(),
526            Arc::new(Visited {
527                decrypted,
528                encrypted,
529            }),
530        );
531        return Ok(point);
532    }
533    let decrypted = decrypted.fetch().await?;
534    let encrypted = encrypt(key.clone(), decrypted).await?;
535    let point = encrypted.point();
536    Ok(point)
537}
538
539pub async fn encrypt<K: Key, T: Traversible>(
540    key: K,
541    decrypted: T,
542) -> object_rainbow::Result<Encrypted<K, T>> {
543    let mut futures = Vec::with_capacity(decrypted.point_count());
544    decrypted.traverse(&mut ExtractResolution {
545        extracted: &mut futures,
546        key: &key,
547    });
548    let resolution = futures_util::future::try_join_all(futures).await?;
549    let resolution = Arc::new(Lp(resolution));
550    let decrypted = Unkeyed(Arc::new(decrypted));
551    let inner = EncryptedInner {
552        resolution,
553        decrypted,
554    };
555    Ok(Encrypted { key, inner })
556}