object_rainbow_encrypted/
lib.rs

1use std::{ops::Deref, sync::Arc};
2
3use object_rainbow::{
4    Address, ByteNode, Error, FailFuture, Fetch, FetchBytes, FullHash, Hash, ListHashes, Node,
5    Object, Parse, ParseSliceExtra, PointInput, PointVisitor, Resolve, Singular, SingularFetch,
6    Tagged, ToOutput, Topological, Traversible, length_prefixed::Lp,
7};
8use object_rainbow_point::{ExtractResolve, IntoPoint, Point};
9
/// Parse-time context for encrypted objects: the [`Key`] used to decrypt the
/// stored bytes, plus the `extra` context forwarded to the decrypted type's
/// parser.
#[derive(Clone)]
pub struct WithKey<K, Extra> {
    /// Key used to decrypt the parsed bytes; it is also kept in the resulting
    /// `Encrypted` value, which uses it to encrypt on output.
    pub key: K,
    /// Extra parse context passed through (without the key) to the inner type.
    pub extra: Extra,
}
15
/// A key that can both encrypt and decrypt the serialized form of objects.
///
/// The same key value is used to decrypt when parsing an `Encrypted` object and
/// to encrypt when writing it back out.
///
/// NOTE(review): `Encrypted`'s `ToOutput` emits `encrypt(plaintext)`, so if an
/// implementation of `encrypt` is non-deterministic (e.g. a random nonce), the
/// serialized bytes (and anything hashed from them) may differ between calls.
/// Confirm whether implementors are required to be deterministic.
pub trait Key: 'static + Sized + Send + Sync + Clone {
    /// Encrypts `data`, returning the ciphertext.
    fn encrypt(&self, data: &[u8]) -> Vec<u8>;
    /// Decrypts `data`, returning the plaintext, or an error if decryption fails.
    fn decrypt(&self, data: &[u8]) -> object_rainbow::Result<Vec<u8>>;
}
20
/// Shared, length-prefixed list of encrypted child points: one entry per point
/// of the decrypted value, in the order that value's `traverse` visits them.
type Resolution<K> = Arc<Lp<Vec<Point<Encrypted<K, Vec<u8>>>>>>;
22
/// Newtype whose `Parse` impl strips the key from a `WithKey` extra before
/// delegating to the wrapped type's parser.
#[derive(ToOutput, Clone)]
struct Unkeyed<T>(T);
25
26impl<
27    T: Parse<I::WithExtra<Extra>>,
28    K: 'static + Clone,
29    Extra: 'static + Clone,
30    I: PointInput<Extra = WithKey<K, Extra>>,
31> Parse<I> for Unkeyed<T>
32{
33    fn parse(input: I) -> object_rainbow::Result<Self> {
34        Ok(Self(T::parse(
35            input.map_extra(|WithKey { extra, .. }| extra),
36        )?))
37    }
38}
39
/// Plaintext payload of an [`Encrypted`] object. Its `ToOutput` bytes are what
/// `Encrypted` encrypts. It holds the encrypted child points and the decrypted
/// value.
///
/// NOTE(review): the derived `ToOutput`/`Parse` presumably (de)serialize fields
/// in declaration order, so reordering them would likely change the stored
/// format. Confirm before touching.
#[derive(ToOutput, Parse)]
struct EncryptedInner<K, T> {
    /// Encrypted counterparts of the decrypted value's points.
    resolution: Resolution<K>,
    /// The decrypted value, parsed through `Unkeyed` so it never sees the key.
    decrypted: Unkeyed<Arc<T>>,
}
45
46impl<K, T> Clone for EncryptedInner<K, T> {
47    fn clone(&self) -> Self {
48        Self {
49            resolution: self.resolution.clone(),
50            decrypted: self.decrypted.clone(),
51        }
52    }
53}
54
/// Borrowing iterator over the encrypted child points of a [`Resolution`].
type ResolutionIter<'a, K> = std::slice::Iter<'a, Point<Encrypted<K, Vec<u8>>>>;
56
/// `PointVisitor` adapter used when traversing an `EncryptedInner`. For each
/// point of the decrypted value it takes the next entry of `resolution` and
/// passes the combined `Visited` point on to `visitor`.
struct IterateResolution<'a, 'r, K, V> {
    /// Remaining encrypted points, consumed one per visited decrypted point.
    resolution: &'r mut ResolutionIter<'a, K>,
    /// Downstream visitor that receives the combined points.
    visitor: &'a mut V,
}
61
/// A point pairing a decrypted point with its encrypted counterpart.
///
/// Hashing and byte-level fetches are delegated to `encrypted`. Typed fetches
/// take the decrypted value from `decrypted` and the key/resolution from
/// `encrypted`.
struct Visited<K, P> {
    decrypted: P,
    encrypted: Point<Encrypted<K, Vec<u8>>>,
}
66
67impl<K, P> FetchBytes for Visited<K, P> {
68    fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
69        self.encrypted.fetch_bytes()
70    }
71
72    fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
73        self.encrypted.fetch_data()
74    }
75
76    fn fetch_bytes_local(&self) -> object_rainbow::Result<Option<ByteNode>> {
77        self.encrypted.fetch_bytes_local()
78    }
79
80    fn fetch_data_local(&self) -> Option<Vec<u8>> {
81        self.encrypted.fetch_data_local()
82    }
83}
84
85impl<K, P: Send + Sync> Singular for Visited<K, P> {
86    fn hash(&self) -> Hash {
87        self.encrypted.hash()
88    }
89}
90
91impl<K: Key, P: Fetch<T: Traversible>> Fetch for Visited<K, P> {
92    type T = Encrypted<K, P::T>;
93
94    fn fetch_full(&'_ self) -> FailFuture<'_, Node<Self::T>> {
95        Box::pin(async move {
96            let (
97                Encrypted {
98                    key,
99                    inner:
100                        EncryptedInner {
101                            resolution,
102                            decrypted: _,
103                        },
104                },
105                resolve,
106            ) = self.encrypted.fetch_full().await?;
107            let decrypted = self.decrypted.fetch().await?;
108            let decrypted = Unkeyed(Arc::new(decrypted));
109            Ok((
110                Encrypted {
111                    key,
112                    inner: EncryptedInner {
113                        resolution,
114                        decrypted,
115                    },
116                },
117                resolve,
118            ))
119        })
120    }
121
122    fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
123        Box::pin(async move {
124            let Encrypted {
125                key,
126                inner:
127                    EncryptedInner {
128                        resolution,
129                        decrypted: _,
130                    },
131            } = self.encrypted.fetch().await?;
132            let decrypted = self.decrypted.fetch().await?;
133            let decrypted = Unkeyed(Arc::new(decrypted));
134            Ok(Encrypted {
135                key,
136                inner: EncryptedInner {
137                    resolution,
138                    decrypted,
139                },
140            })
141        })
142    }
143
144    fn try_fetch_local(&self) -> object_rainbow::Result<Option<Node<Self::T>>> {
145        let Some((
146            Encrypted {
147                key,
148                inner:
149                    EncryptedInner {
150                        resolution,
151                        decrypted: _,
152                    },
153            },
154            resolve,
155        )) = self.encrypted.try_fetch_local()?
156        else {
157            return Ok(None);
158        };
159        let Some((decrypted, _)) = self.decrypted.try_fetch_local()? else {
160            return Ok(None);
161        };
162        let decrypted = Unkeyed(Arc::new(decrypted));
163        Ok(Some((
164            Encrypted {
165                key,
166                inner: EncryptedInner {
167                    resolution,
168                    decrypted,
169                },
170            },
171            resolve,
172        )))
173    }
174
175    fn fetch_local(&self) -> Option<Self::T> {
176        let Encrypted {
177            key,
178            inner:
179                EncryptedInner {
180                    resolution,
181                    decrypted: _,
182                },
183        } = self.encrypted.fetch_local()?;
184        let decrypted = Unkeyed(Arc::new(self.decrypted.fetch_local()?));
185        Some(Encrypted {
186            key,
187            inner: EncryptedInner {
188                resolution,
189                decrypted,
190            },
191        })
192    }
193}
194
195impl<'a, K: Key, V: PointVisitor> PointVisitor for IterateResolution<'a, '_, K, V> {
196    fn visit<T: Traversible>(&mut self, decrypted: &(impl 'static + SingularFetch<T = T> + Clone)) {
197        let decrypted = decrypted.clone();
198        let encrypted = self.resolution.next().expect("length mismatch").clone();
199        let point = Point::from_fetch(
200            encrypted.hash(),
201            Arc::new(Visited {
202                decrypted,
203                encrypted,
204            }),
205        );
206        self.visitor.visit(&point);
207    }
208}
209
210impl<K, T> ListHashes for EncryptedInner<K, T> {
211    fn list_hashes(&self, f: &mut impl FnMut(Hash)) {
212        self.resolution.list_hashes(f);
213    }
214
215    fn topology_hash(&self) -> Hash {
216        self.resolution.0.data_hash()
217    }
218
219    fn point_count(&self) -> usize {
220        self.resolution.len()
221    }
222}
223
224impl<K: Key, T: Topological> Topological for EncryptedInner<K, T> {
225    fn traverse(&self, visitor: &mut impl PointVisitor) {
226        let resolution = &mut self.resolution.iter();
227        self.decrypted.0.traverse(&mut IterateResolution {
228            resolution,
229            visitor,
230        });
231        assert!(resolution.next().is_none());
232    }
233}
234
/// A value of type `T` together with the key and data needed to store it in
/// encrypted form.
///
/// Serializing it (`ToOutput`) writes the key-encrypted bytes of the inner
/// payload. Dereferencing it yields the decrypted `T`.
pub struct Encrypted<K, T> {
    /// Key used to encrypt the payload on output.
    key: K,
    /// Encrypted child points plus the decrypted value.
    inner: EncryptedInner<K, T>,
}
239
240impl<K, T: Clone> Encrypted<K, T> {
241    pub fn into_inner(self) -> T {
242        Arc::unwrap_or_clone(self.inner.decrypted.0)
243    }
244}
245
246impl<K, T> Deref for Encrypted<K, T> {
247    type Target = T;
248
249    fn deref(&self) -> &Self::Target {
250        self.inner.decrypted.0.as_ref()
251    }
252}
253
254impl<K: Clone, T> Clone for Encrypted<K, T> {
255    fn clone(&self) -> Self {
256        Self {
257            key: self.key.clone(),
258            inner: self.inner.clone(),
259        }
260    }
261}
262
263impl<K, T> ListHashes for Encrypted<K, T> {
264    fn list_hashes(&self, f: &mut impl FnMut(Hash)) {
265        self.inner.list_hashes(f);
266    }
267
268    fn topology_hash(&self) -> Hash {
269        self.inner.topology_hash()
270    }
271
272    fn point_count(&self) -> usize {
273        self.inner.point_count()
274    }
275}
276
277impl<K: Key, T: Topological> Topological for Encrypted<K, T> {
278    fn traverse(&self, visitor: &mut impl PointVisitor) {
279        self.inner.traverse(visitor);
280    }
281}
282
283impl<K: Key, T: ToOutput> ToOutput for Encrypted<K, T> {
284    fn to_output(&self, output: &mut dyn object_rainbow::Output) {
285        let source = self.inner.vec();
286        output.write(&self.key.encrypt(&source));
287    }
288}
289
/// `Resolve` implementation handed to the decrypted value's parser. An address
/// index selects an entry of `resolution`, and resolving it yields that entry's
/// decrypted bytes.
#[derive(Clone)]
struct Decrypt<K> {
    /// Encrypted children addressable by index.
    resolution: Resolution<K>,
}
294
295impl<K: Key> Decrypt<K> {
296    async fn resolve_bytes(
297        &self,
298        address: Address,
299    ) -> object_rainbow::Result<(Vec<u8>, Resolution<K>)> {
300        let Encrypted {
301            key: _,
302            inner:
303                EncryptedInner {
304                    resolution,
305                    decrypted,
306                },
307        } = self
308            .resolution
309            .get(address.index)
310            .ok_or(Error::AddressOutOfBounds)?
311            .clone()
312            .fetch()
313            .await?;
314        let data = Arc::unwrap_or_clone(decrypted.0);
315        Ok((data, resolution))
316    }
317}
318
319impl<K: Key> Resolve for Decrypt<K> {
320    fn resolve(&'_ self, address: Address) -> FailFuture<'_, ByteNode> {
321        Box::pin(async move {
322            let (data, resolution) = self.resolve_bytes(address).await?;
323            Ok((data, Arc::new(Decrypt { resolution }) as _))
324        })
325    }
326
327    fn resolve_data(&'_ self, address: Address) -> FailFuture<'_, Vec<u8>> {
328        Box::pin(async move {
329            let (data, _) = self.resolve_bytes(address).await?;
330            Ok(data)
331        })
332    }
333
334    fn try_resolve_local(&self, address: Address) -> object_rainbow::Result<Option<ByteNode>> {
335        let Some((
336            Encrypted {
337                key: _,
338                inner:
339                    EncryptedInner {
340                        resolution,
341                        decrypted,
342                    },
343            },
344            _,
345        )) = self
346            .resolution
347            .get(address.index)
348            .ok_or(Error::AddressOutOfBounds)?
349            .clone()
350            .try_fetch_local()?
351        else {
352            return Ok(None);
353        };
354        let data = Arc::unwrap_or_clone(decrypted.0);
355        Ok(Some((data, Arc::new(Decrypt { resolution }) as _)))
356    }
357}
358
359impl<
360    K: Key,
361    T: Object<Extra>,
362    Extra: 'static + Send + Sync + Clone,
363    I: PointInput<Extra = WithKey<K, Extra>>,
364> Parse<I> for Encrypted<K, T>
365{
366    fn parse(input: I) -> object_rainbow::Result<Self> {
367        let with_key = input.extra().clone();
368        let resolve = input.resolve().clone();
369        let source = with_key.key.decrypt(&input.parse_all()?)?;
370        let EncryptedInner {
371            resolution,
372            decrypted,
373        } = EncryptedInner::<K, Vec<u8>>::parse_slice_extra(&source, &resolve, &with_key)?;
374        let decrypted = T::parse_slice_extra(
375            &decrypted.0,
376            &(Arc::new(Decrypt {
377                resolution: resolution.clone(),
378            }) as _),
379            &with_key.extra,
380        )?;
381        let decrypted = Unkeyed(Arc::new(decrypted));
382        let inner = EncryptedInner {
383            resolution,
384            decrypted,
385        };
386        Ok(Self {
387            key: with_key.key,
388            inner,
389        })
390    }
391}
392
// Empty impl: `Encrypted` uses only `Tagged`'s default items; nothing is overridden.
impl<K, T> Tagged for Encrypted<K, T> {}
394
/// Pending encryptions of a value's child points. They are queued by
/// `ExtractResolution` during traversal and awaited together by `encrypt`.
type Extracted<K> = Vec<
    std::pin::Pin<
        Box<dyn Future<Output = Result<Point<Encrypted<K, Vec<u8>>>, Error>> + Send + 'static>,
    >,
>;
400
/// `PointVisitor` that, for each point of a value being encrypted, queues a
/// future producing that point's encrypted counterpart under `key`.
struct ExtractResolution<'a, K> {
    /// Output queue, filled in visit order.
    extracted: &'a mut Extracted<K>,
    /// Key each child is encrypted with (cloned into every queued future).
    key: &'a K,
}
405
/// Type-erasing adapter that exposes a `Point<Encrypted<K, T>>` as a fetcher of
/// `Encrypted<K, Vec<u8>>` (the decrypted payload kept as raw bytes). This lets
/// children of any type be stored in a single `Resolution<K>`.
struct Untyped<K, T> {
    /// Key (with unit extra) used to re-parse the fetched encrypted bytes.
    key: WithKey<K, ()>,
    /// The typed encrypted point being erased.
    encrypted: Point<Encrypted<K, T>>,
}
410
411impl<K, T> FetchBytes for Untyped<K, T> {
412    fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
413        self.encrypted.fetch_bytes()
414    }
415
416    fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
417        self.encrypted.fetch_data()
418    }
419
420    fn fetch_bytes_local(&self) -> object_rainbow::Result<Option<ByteNode>> {
421        self.encrypted.fetch_bytes_local()
422    }
423
424    fn fetch_data_local(&self) -> Option<Vec<u8>> {
425        self.encrypted.fetch_data_local()
426    }
427}
428
429impl<K: Send + Sync, T> Singular for Untyped<K, T> {
430    fn hash(&self) -> Hash {
431        self.encrypted.hash()
432    }
433}
434
435impl<K: Key, T: FullHash> Fetch for Untyped<K, T> {
436    type T = Encrypted<K, Vec<u8>>;
437
438    fn fetch_full(&'_ self) -> FailFuture<'_, Node<Self::T>> {
439        Box::pin(async move {
440            let (data, resolve) = self.fetch_bytes().await?;
441            let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
442            Ok((encrypted, resolve))
443        })
444    }
445
446    fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
447        Box::pin(async move {
448            let (data, resolve) = self.fetch_bytes().await?;
449            let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
450            Ok(encrypted)
451        })
452    }
453
454    fn try_fetch_local(&self) -> object_rainbow::Result<Option<Node<Self::T>>> {
455        let Some((data, resolve)) = self.fetch_bytes_local()? else {
456            return Ok(None);
457        };
458        let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
459        Ok(Some((encrypted, resolve)))
460    }
461
462    fn fetch_local(&self) -> Option<Self::T> {
463        let Encrypted {
464            key,
465            inner:
466                EncryptedInner {
467                    resolution,
468                    decrypted,
469                },
470        } = self.encrypted.fetch_local()?;
471        let decrypted = Unkeyed(Arc::new(decrypted.vec()));
472        Some(Encrypted {
473            key,
474            inner: EncryptedInner {
475                resolution,
476                decrypted,
477            },
478        })
479    }
480}
481
482impl<K: Key> PointVisitor for ExtractResolution<'_, K> {
483    fn visit<T: Traversible>(&mut self, decrypted: &(impl 'static + SingularFetch<T = T> + Clone)) {
484        let decrypted = decrypted.clone();
485        let key = self.key.clone();
486        self.extracted.push(Box::pin(async move {
487            let encrypted = encrypt_point(key.clone(), decrypted).await?;
488            let encrypted = Point::from_fetch(
489                encrypted.hash(),
490                Arc::new(Untyped {
491                    key: WithKey { key, extra: () },
492                    encrypted,
493                }),
494            );
495            Ok(encrypted)
496        }));
497    }
498}
499
500pub async fn encrypt_point<K: Key, T: Traversible>(
501    key: K,
502    decrypted: impl 'static + SingularFetch<T = T>,
503) -> object_rainbow::Result<Point<Encrypted<K, T>>> {
504    if let Some((address, decrypt)) = decrypted.extract_resolve::<Decrypt<K>>() {
505        let encrypted = decrypt
506            .resolution
507            .get(address.index)
508            .ok_or(Error::AddressOutOfBounds)?
509            .clone();
510        let point = Point::from_fetch(
511            encrypted.hash(),
512            Arc::new(Visited {
513                decrypted,
514                encrypted,
515            }),
516        );
517        return Ok(point);
518    }
519    let decrypted = decrypted.fetch().await?;
520    let encrypted = encrypt(key.clone(), decrypted).await?;
521    let point = encrypted.point();
522    Ok(point)
523}
524
525pub async fn encrypt<K: Key, T: Traversible>(
526    key: K,
527    decrypted: T,
528) -> object_rainbow::Result<Encrypted<K, T>> {
529    let mut futures = Vec::with_capacity(decrypted.point_count());
530    decrypted.traverse(&mut ExtractResolution {
531        extracted: &mut futures,
532        key: &key,
533    });
534    let resolution = futures_util::future::try_join_all(futures).await?;
535    let resolution = Arc::new(Lp(resolution));
536    let decrypted = Unkeyed(Arc::new(decrypted));
537    let inner = EncryptedInner {
538        resolution,
539        decrypted,
540    };
541    Ok(Encrypted { key, inner })
542}