object_rainbow_encrypted/
lib.rs

1use std::{ops::Deref, sync::Arc};
2
3use object_rainbow::{
4    Address, ByteNode, Error, ExtractResolve, FailFuture, Fetch, FetchBytes, FullHash, Hash,
5    ListHashes, Node, Object, Parse, ParseSliceExtra, Point, PointInput, PointVisitor, Resolve,
6    Singular, SingularFetch, Tagged, ToOutput, Topological, Traversible, length_prefixed::Lp,
7};
8
/// Bundles a key with an additional `extra` value.
///
/// In this file it is the `Extra` type of the `PointInput` that `Unkeyed`
/// and `Encrypted` are parsed from.
#[derive(Clone)]
pub struct WithKey<K, Extra> {
    /// The key; `Encrypted`'s parser calls `decrypt` on it.
    pub key: K,
    /// The remaining extra value; parsers in this file pass it on to the
    /// wrapped type.
    pub extra: Extra,
}
14
/// A cloneable, `'static`, thread-shareable key that transforms byte slices.
///
/// NOTE(review): `decrypt` is presumably expected to invert `encrypt` for the
/// same key — nothing in this file enforces that; confirm with implementors.
pub trait Key: 'static + Sized + Send + Sync + Clone {
    /// Returns the encrypted form of `data`.
    fn encrypt(&self, data: &[u8]) -> Vec<u8>;
    /// Returns the decrypted form of `data`, or an error if it cannot be produced.
    fn decrypt(&self, data: &[u8]) -> object_rainbow::Result<Vec<u8>>;
}
19
/// Shared list of points to encrypted nodes whose payload is kept as raw
/// bytes (`Vec<u8>`), wrapped in `Lp` and an `Arc`.
type Resolution<K> = Arc<Lp<Vec<Point<Encrypted<K, Vec<u8>>>>>>;
21
/// Transparent wrapper around `T`; its `Parse` impl removes the key from the
/// input's extra before delegating to `T`.
#[derive(ToOutput, Clone)]
struct Unkeyed<T>(T);
24
25impl<
26    T: Parse<I::WithExtra<Extra>>,
27    K: 'static + Clone,
28    Extra: 'static + Clone,
29    I: PointInput<Extra = WithKey<K, Extra>>,
30> Parse<I> for Unkeyed<T>
31{
32    fn parse(input: I) -> object_rainbow::Result<Self> {
33        Ok(Self(T::parse(
34            input.map_extra(|WithKey { extra, .. }| extra),
35        )?))
36    }
37}
38
/// The part of an `Encrypted` value that gets serialized and then encrypted:
/// the list of encrypted points plus the decrypted payload.
///
/// NOTE(review): the derived `ToOutput`/`Parse` presumably handle the fields
/// in declaration order — confirm against the derive macros before reordering.
#[derive(ToOutput, Parse)]
struct EncryptedInner<K, T> {
    /// Encrypted counterparts of the points reachable from `decrypted`.
    resolution: Resolution<K>,
    /// The decrypted payload, shared behind an `Arc`.
    decrypted: Unkeyed<Arc<T>>,
}
44
45impl<K, T> Clone for EncryptedInner<K, T> {
46    fn clone(&self) -> Self {
47        Self {
48            resolution: self.resolution.clone(),
49            decrypted: self.decrypted.clone(),
50        }
51    }
52}
53
/// Borrowing iterator over the points stored in a `Resolution`.
type ResolutionIter<'a, K> = std::slice::Iter<'a, Point<Encrypted<K, Vec<u8>>>>;
55
/// Visitor adaptor: for every point it is shown, it takes the next entry from
/// `resolution` and forwards a combined point to `visitor`.
struct IterateResolution<'a, 'r, K, V> {
    /// Iterator over the encrypted points; advanced once per visited point.
    resolution: &'r mut ResolutionIter<'a, K>,
    /// The wrapped visitor that receives the combined points.
    visitor: &'a mut V,
}
60
/// Pairs a handle to decrypted content with the point of its encrypted
/// counterpart. Byte-level fetches and the hash come from `encrypted`; the
/// typed `Fetch` impl takes the payload from `decrypted`.
struct Visited<K, P> {
    /// Handle to the decrypted content.
    decrypted: P,
    /// Point to the matching encrypted node.
    encrypted: Point<Encrypted<K, Vec<u8>>>,
}
65
66impl<K, P> FetchBytes for Visited<K, P> {
67    fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
68        self.encrypted.fetch_bytes()
69    }
70
71    fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
72        self.encrypted.fetch_data()
73    }
74
75    fn fetch_bytes_local(&self) -> object_rainbow::Result<Option<ByteNode>> {
76        self.encrypted.fetch_bytes_local()
77    }
78
79    fn fetch_data_local(&self) -> Option<Vec<u8>> {
80        self.encrypted.fetch_data_local()
81    }
82}
83
84impl<K, P: Send + Sync> Singular for Visited<K, P> {
85    fn hash(&self) -> Hash {
86        self.encrypted.hash()
87    }
88}
89
90impl<K: Key, P: Fetch<T: Traversible>> Fetch for Visited<K, P> {
91    type T = Encrypted<K, P::T>;
92
93    fn fetch_full(&'_ self) -> FailFuture<'_, Node<Self::T>> {
94        Box::pin(async move {
95            let (
96                Encrypted {
97                    key,
98                    inner:
99                        EncryptedInner {
100                            resolution,
101                            decrypted: _,
102                        },
103                },
104                resolve,
105            ) = self.encrypted.fetch_full().await?;
106            let decrypted = self.decrypted.fetch().await?;
107            let decrypted = Unkeyed(Arc::new(decrypted));
108            Ok((
109                Encrypted {
110                    key,
111                    inner: EncryptedInner {
112                        resolution,
113                        decrypted,
114                    },
115                },
116                resolve,
117            ))
118        })
119    }
120
121    fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
122        Box::pin(async move {
123            let Encrypted {
124                key,
125                inner:
126                    EncryptedInner {
127                        resolution,
128                        decrypted: _,
129                    },
130            } = self.encrypted.fetch().await?;
131            let decrypted = self.decrypted.fetch().await?;
132            let decrypted = Unkeyed(Arc::new(decrypted));
133            Ok(Encrypted {
134                key,
135                inner: EncryptedInner {
136                    resolution,
137                    decrypted,
138                },
139            })
140        })
141    }
142
143    fn try_fetch_local(&self) -> object_rainbow::Result<Option<Node<Self::T>>> {
144        let Some((
145            Encrypted {
146                key,
147                inner:
148                    EncryptedInner {
149                        resolution,
150                        decrypted: _,
151                    },
152            },
153            resolve,
154        )) = self.encrypted.try_fetch_local()?
155        else {
156            return Ok(None);
157        };
158        let Some((decrypted, _)) = self.decrypted.try_fetch_local()? else {
159            return Ok(None);
160        };
161        let decrypted = Unkeyed(Arc::new(decrypted));
162        Ok(Some((
163            Encrypted {
164                key,
165                inner: EncryptedInner {
166                    resolution,
167                    decrypted,
168                },
169            },
170            resolve,
171        )))
172    }
173
174    fn fetch_local(&self) -> Option<Self::T> {
175        let Encrypted {
176            key,
177            inner:
178                EncryptedInner {
179                    resolution,
180                    decrypted: _,
181                },
182        } = self.encrypted.fetch_local()?;
183        let decrypted = Unkeyed(Arc::new(self.decrypted.fetch_local()?));
184        Some(Encrypted {
185            key,
186            inner: EncryptedInner {
187                resolution,
188                decrypted,
189            },
190        })
191    }
192}
193
194impl<'a, K: Key, V: PointVisitor> PointVisitor for IterateResolution<'a, '_, K, V> {
195    fn visit<T: Traversible>(&mut self, decrypted: &(impl 'static + SingularFetch<T = T> + Clone)) {
196        let decrypted = decrypted.clone();
197        let encrypted = self.resolution.next().expect("length mismatch").clone();
198        let point = Point::from_fetch(
199            encrypted.hash(),
200            Arc::new(Visited {
201                decrypted,
202                encrypted,
203            }),
204        );
205        self.visitor.visit(&point);
206    }
207}
208
209impl<K, T> ListHashes for EncryptedInner<K, T> {
210    fn list_hashes(&self, f: &mut impl FnMut(Hash)) {
211        self.resolution.list_hashes(f);
212    }
213
214    fn topology_hash(&self) -> Hash {
215        self.resolution.0.data_hash()
216    }
217
218    fn point_count(&self) -> usize {
219        self.resolution.len()
220    }
221}
222
223impl<K: Key, T: Topological> Topological for EncryptedInner<K, T> {
224    fn traverse(&self, visitor: &mut impl PointVisitor) {
225        let resolution = &mut self.resolution.iter();
226        self.decrypted.0.traverse(&mut IterateResolution {
227            resolution,
228            visitor,
229        });
230        assert!(resolution.next().is_none());
231    }
232}
233
/// A `T` held in decrypted form together with the key and the encrypted
/// points it refers to.
///
/// Its `ToOutput` impl in this file writes `key.encrypt(..)` of the
/// serialized `inner`; its `Parse` impl calls `key.decrypt(..)` on the input.
pub struct Encrypted<K, T> {
    /// Key used when serializing this value.
    key: K,
    /// Resolution list and decrypted payload.
    inner: EncryptedInner<K, T>,
}
238
239impl<K, T: Clone> Encrypted<K, T> {
240    pub fn into_inner(self) -> T {
241        Arc::unwrap_or_clone(self.inner.decrypted.0)
242    }
243}
244
245impl<K, T> Deref for Encrypted<K, T> {
246    type Target = T;
247
248    fn deref(&self) -> &Self::Target {
249        self.inner.decrypted.0.as_ref()
250    }
251}
252
253impl<K: Clone, T> Clone for Encrypted<K, T> {
254    fn clone(&self) -> Self {
255        Self {
256            key: self.key.clone(),
257            inner: self.inner.clone(),
258        }
259    }
260}
261
262impl<K, T> ListHashes for Encrypted<K, T> {
263    fn list_hashes(&self, f: &mut impl FnMut(Hash)) {
264        self.inner.list_hashes(f);
265    }
266
267    fn topology_hash(&self) -> Hash {
268        self.inner.topology_hash()
269    }
270
271    fn point_count(&self) -> usize {
272        self.inner.point_count()
273    }
274}
275
276impl<K: Key, T: Topological> Topological for Encrypted<K, T> {
277    fn traverse(&self, visitor: &mut impl PointVisitor) {
278        self.inner.traverse(visitor);
279    }
280}
281
282impl<K: Key, T: ToOutput> ToOutput for Encrypted<K, T> {
283    fn to_output(&self, output: &mut dyn object_rainbow::Output) {
284        let source = self.inner.vec();
285        output.write(&self.key.encrypt(&source));
286    }
287}
288
/// Resolver backed by a list of encrypted points: an address index selects a
/// point, and the decrypted payload bytes of the fetched node are returned.
#[derive(Clone)]
struct Decrypt<K> {
    /// The encrypted points, indexed by `Address::index`.
    resolution: Resolution<K>,
}
293
294impl<K: Key> Decrypt<K> {
295    async fn resolve_bytes(
296        &self,
297        address: Address,
298    ) -> object_rainbow::Result<(Vec<u8>, Resolution<K>)> {
299        let Encrypted {
300            key: _,
301            inner:
302                EncryptedInner {
303                    resolution,
304                    decrypted,
305                },
306        } = self
307            .resolution
308            .get(address.index)
309            .ok_or(Error::AddressOutOfBounds)?
310            .clone()
311            .fetch()
312            .await?;
313        let data = Arc::unwrap_or_clone(decrypted.0);
314        Ok((data, resolution))
315    }
316}
317
318impl<K: Key> Resolve for Decrypt<K> {
319    fn resolve(&'_ self, address: Address) -> FailFuture<'_, ByteNode> {
320        Box::pin(async move {
321            let (data, resolution) = self.resolve_bytes(address).await?;
322            Ok((data, Arc::new(Decrypt { resolution }) as _))
323        })
324    }
325
326    fn resolve_data(&'_ self, address: Address) -> FailFuture<'_, Vec<u8>> {
327        Box::pin(async move {
328            let (data, _) = self.resolve_bytes(address).await?;
329            Ok(data)
330        })
331    }
332
333    fn try_resolve_local(&self, address: Address) -> object_rainbow::Result<Option<ByteNode>> {
334        let Some((
335            Encrypted {
336                key: _,
337                inner:
338                    EncryptedInner {
339                        resolution,
340                        decrypted,
341                    },
342            },
343            _,
344        )) = self
345            .resolution
346            .get(address.index)
347            .ok_or(Error::AddressOutOfBounds)?
348            .clone()
349            .try_fetch_local()?
350        else {
351            return Ok(None);
352        };
353        let data = Arc::unwrap_or_clone(decrypted.0);
354        Ok(Some((data, Arc::new(Decrypt { resolution }) as _)))
355    }
356}
357
358impl<
359    K: Key,
360    T: Object<Extra>,
361    Extra: 'static + Send + Sync + Clone,
362    I: PointInput<Extra = WithKey<K, Extra>>,
363> Parse<I> for Encrypted<K, T>
364{
365    fn parse(input: I) -> object_rainbow::Result<Self> {
366        let with_key = input.extra().clone();
367        let resolve = input.resolve().clone();
368        let source = with_key.key.decrypt(&input.parse_all()?)?;
369        let EncryptedInner {
370            resolution,
371            decrypted,
372        } = EncryptedInner::<K, Vec<u8>>::parse_slice_extra(&source, &resolve, &with_key)?;
373        let decrypted = T::parse_slice_extra(
374            &decrypted.0,
375            &(Arc::new(Decrypt {
376                resolution: resolution.clone(),
377            }) as _),
378            &with_key.extra,
379        )?;
380        let decrypted = Unkeyed(Arc::new(decrypted));
381        let inner = EncryptedInner {
382            resolution,
383            decrypted,
384        };
385        Ok(Self {
386            key: with_key.key,
387            inner,
388        })
389    }
390}
391
// Empty impl: no item of `Tagged` is overridden for `Encrypted`.
impl<K, T> Tagged for Encrypted<K, T> {}
393
/// List of pinned, boxed, `Send + 'static` futures, each producing either a
/// point to a raw-byte encrypted node or an error.
type Extracted<K> = Vec<
    std::pin::Pin<
        Box<dyn Future<Output = Result<Point<Encrypted<K, Vec<u8>>>, Error>> + Send + 'static>,
    >,
>;
399
/// Visitor that queues one encryption future per visited point.
struct ExtractResolution<'a, K> {
    /// Destination for the queued futures.
    extracted: &'a mut Extracted<K>,
    /// Key that is cloned into every queued future.
    key: &'a K,
}
404
/// Handle over a typed encrypted point whose `Fetch` impl yields the node
/// with its payload kept as raw bytes (`Encrypted<K, Vec<u8>>`).
struct Untyped<K, T> {
    /// Key with a unit extra; passed as the extra when parsing fetched bytes.
    key: WithKey<K, ()>,
    /// The typed encrypted point being wrapped.
    encrypted: Point<Encrypted<K, T>>,
}
409
410impl<K, T> FetchBytes for Untyped<K, T> {
411    fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
412        self.encrypted.fetch_bytes()
413    }
414
415    fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
416        self.encrypted.fetch_data()
417    }
418
419    fn fetch_bytes_local(&self) -> object_rainbow::Result<Option<ByteNode>> {
420        self.encrypted.fetch_bytes_local()
421    }
422
423    fn fetch_data_local(&self) -> Option<Vec<u8>> {
424        self.encrypted.fetch_data_local()
425    }
426}
427
428impl<K: Send + Sync, T> Singular for Untyped<K, T> {
429    fn hash(&self) -> Hash {
430        self.encrypted.hash()
431    }
432}
433
434impl<K: Key, T: FullHash> Fetch for Untyped<K, T> {
435    type T = Encrypted<K, Vec<u8>>;
436
437    fn fetch_full(&'_ self) -> FailFuture<'_, Node<Self::T>> {
438        Box::pin(async move {
439            let (data, resolve) = self.fetch_bytes().await?;
440            let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
441            Ok((encrypted, resolve))
442        })
443    }
444
445    fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
446        Box::pin(async move {
447            let (data, resolve) = self.fetch_bytes().await?;
448            let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
449            Ok(encrypted)
450        })
451    }
452
453    fn try_fetch_local(&self) -> object_rainbow::Result<Option<Node<Self::T>>> {
454        let Some((data, resolve)) = self.fetch_bytes_local()? else {
455            return Ok(None);
456        };
457        let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
458        Ok(Some((encrypted, resolve)))
459    }
460
461    fn fetch_local(&self) -> Option<Self::T> {
462        let Encrypted {
463            key,
464            inner:
465                EncryptedInner {
466                    resolution,
467                    decrypted,
468                },
469        } = self.encrypted.fetch_local()?;
470        let decrypted = Unkeyed(Arc::new(decrypted.vec()));
471        Some(Encrypted {
472            key,
473            inner: EncryptedInner {
474                resolution,
475                decrypted,
476            },
477        })
478    }
479}
480
481impl<K: Key> PointVisitor for ExtractResolution<'_, K> {
482    fn visit<T: Traversible>(&mut self, decrypted: &(impl 'static + SingularFetch<T = T> + Clone)) {
483        let decrypted = decrypted.clone();
484        let key = self.key.clone();
485        self.extracted.push(Box::pin(async move {
486            let encrypted = encrypt_point(key.clone(), decrypted).await?;
487            let encrypted = Point::from_fetch(
488                encrypted.hash(),
489                Arc::new(Untyped {
490                    key: WithKey { key, extra: () },
491                    encrypted,
492                }),
493            );
494            Ok(encrypted)
495        }));
496    }
497}
498
499pub async fn encrypt_point<K: Key, T: Traversible>(
500    key: K,
501    decrypted: impl 'static + SingularFetch<T = T>,
502) -> object_rainbow::Result<Point<Encrypted<K, T>>> {
503    if let Some((address, decrypt)) = decrypted.extract_resolve::<Decrypt<K>>() {
504        let encrypted = decrypt
505            .resolution
506            .get(address.index)
507            .ok_or(Error::AddressOutOfBounds)?
508            .clone();
509        let point = Point::from_fetch(
510            encrypted.hash(),
511            Arc::new(Visited {
512                decrypted,
513                encrypted,
514            }),
515        );
516        return Ok(point);
517    }
518    let decrypted = decrypted.fetch().await?;
519    let encrypted = encrypt(key.clone(), decrypted).await?;
520    let point = encrypted.point();
521    Ok(point)
522}
523
524pub async fn encrypt<K: Key, T: Traversible>(
525    key: K,
526    decrypted: T,
527) -> object_rainbow::Result<Encrypted<K, T>> {
528    let mut futures = Vec::with_capacity(decrypted.point_count());
529    decrypted.traverse(&mut ExtractResolution {
530        extracted: &mut futures,
531        key: &key,
532    });
533    let resolution = futures_util::future::try_join_all(futures).await?;
534    let resolution = Arc::new(Lp(resolution));
535    let decrypted = Unkeyed(Arc::new(decrypted));
536    let inner = EncryptedInner {
537        resolution,
538        decrypted,
539    };
540    Ok(Encrypted { key, inner })
541}