Skip to main content

object_rainbow_encrypted/
lib.rs

1use std::{ops::Deref, sync::Arc};
2
3use object_rainbow::{
4    Address, ByteNode, Error, FailFuture, Fetch, FetchBytes, FullHash, Hash, ListHashes, Node,
5    Object, Parse, ParseSliceExtra, PointInput, PointVisitor, Resolve, Singular, SingularFetch,
6    Tagged, ToOutput, Topological, Traversible, derive_for_wrapped, length_prefixed::Lp,
7};
8use object_rainbow_point::{ExtractResolve, IntoPoint, Point};
9
10#[derive_for_wrapped]
11pub trait Key: 'static + Sized + Send + Sync + Clone {
12    type Error: Into<anyhow::Error>;
13    fn encrypt(&self, data: &[u8]) -> Vec<u8>;
14    fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>, Self::Error>;
15}
16
17type Resolution<K> = Arc<Lp<Vec<Point<Encrypted<K, Vec<u8>>>>>>;
18
19#[derive(ToOutput, Clone)]
20struct Unkeyed<T>(T);
21
22impl<
23    T: Parse<I::WithExtra<Extra>>,
24    K: 'static + Clone,
25    Extra: 'static + Clone,
26    I: PointInput<Extra = (K, Extra)>,
27> Parse<I> for Unkeyed<T>
28{
29    fn parse(input: I) -> object_rainbow::Result<Self> {
30        Ok(Self(T::parse(input.map_extra(|(_, extra)| extra))?))
31    }
32}
33
34#[derive(ToOutput, Parse)]
35struct EncryptedInner<K, T> {
36    tags: Hash,
37    resolution: Resolution<K>,
38    decrypted: Unkeyed<Arc<T>>,
39}
40
41impl<K, T> Clone for EncryptedInner<K, T> {
42    fn clone(&self) -> Self {
43        Self {
44            tags: self.tags,
45            resolution: self.resolution.clone(),
46            decrypted: self.decrypted.clone(),
47        }
48    }
49}
50
51type ResolutionIter<'a, K> = std::slice::Iter<'a, Point<Encrypted<K, Vec<u8>>>>;
52
53struct IterateResolution<'a, 'r, K, V> {
54    resolution: &'r mut ResolutionIter<'a, K>,
55    visitor: &'a mut V,
56}
57
58struct Visited<K, P> {
59    decrypted: P,
60    encrypted: Point<Encrypted<K, Vec<u8>>>,
61}
62
63impl<K, P> FetchBytes for Visited<K, P> {
64    fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
65        self.encrypted.fetch_bytes()
66    }
67
68    fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
69        self.encrypted.fetch_data()
70    }
71
72    fn fetch_bytes_local(&self) -> object_rainbow::Result<Option<ByteNode>> {
73        self.encrypted.fetch_bytes_local()
74    }
75
76    fn fetch_data_local(&self) -> Option<Vec<u8>> {
77        self.encrypted.fetch_data_local()
78    }
79}
80
81impl<K, P: Send + Sync> Singular for Visited<K, P> {
82    fn hash(&self) -> Hash {
83        self.encrypted.hash()
84    }
85}
86
87impl<K: Key, P: Fetch<T: Traversible>> Fetch for Visited<K, P> {
88    type T = Encrypted<K, P::T>;
89
90    fn fetch_full(&'_ self) -> FailFuture<'_, Node<Self::T>> {
91        Box::pin(async move {
92            let (
93                Encrypted {
94                    key,
95                    inner:
96                        EncryptedInner {
97                            tags,
98                            resolution,
99                            decrypted: _,
100                        },
101                },
102                resolve,
103            ) = self.encrypted.fetch_full().await?;
104            let decrypted = self.decrypted.fetch().await?;
105            let decrypted = Unkeyed(Arc::new(decrypted));
106            Ok((
107                Encrypted {
108                    key,
109                    inner: EncryptedInner {
110                        tags,
111                        resolution,
112                        decrypted,
113                    },
114                },
115                resolve,
116            ))
117        })
118    }
119
120    fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
121        Box::pin(async move {
122            let Encrypted {
123                key,
124                inner:
125                    EncryptedInner {
126                        tags,
127                        resolution,
128                        decrypted: _,
129                    },
130            } = self.encrypted.fetch().await?;
131            let decrypted = self.decrypted.fetch().await?;
132            let decrypted = Unkeyed(Arc::new(decrypted));
133            Ok(Encrypted {
134                key,
135                inner: EncryptedInner {
136                    tags,
137                    resolution,
138                    decrypted,
139                },
140            })
141        })
142    }
143
144    fn try_fetch_local(&self) -> object_rainbow::Result<Option<Node<Self::T>>> {
145        let Some((
146            Encrypted {
147                key,
148                inner:
149                    EncryptedInner {
150                        tags,
151                        resolution,
152                        decrypted: _,
153                    },
154            },
155            resolve,
156        )) = self.encrypted.try_fetch_local()?
157        else {
158            return Ok(None);
159        };
160        let Some((decrypted, _)) = self.decrypted.try_fetch_local()? else {
161            return Ok(None);
162        };
163        let decrypted = Unkeyed(Arc::new(decrypted));
164        Ok(Some((
165            Encrypted {
166                key,
167                inner: EncryptedInner {
168                    tags,
169                    resolution,
170                    decrypted,
171                },
172            },
173            resolve,
174        )))
175    }
176
177    fn fetch_local(&self) -> Option<Self::T> {
178        let Encrypted {
179            key,
180            inner:
181                EncryptedInner {
182                    tags,
183                    resolution,
184                    decrypted: _,
185                },
186        } = self.encrypted.fetch_local()?;
187        let decrypted = Unkeyed(Arc::new(self.decrypted.fetch_local()?));
188        Some(Encrypted {
189            key,
190            inner: EncryptedInner {
191                tags,
192                resolution,
193                decrypted,
194            },
195        })
196    }
197}
198
199impl<'a, K: Key, V: PointVisitor> PointVisitor for IterateResolution<'a, '_, K, V> {
200    fn visit<T: Traversible>(&mut self, decrypted: &(impl 'static + SingularFetch<T = T> + Clone)) {
201        let decrypted = decrypted.clone();
202        let encrypted = self.resolution.next().expect("length mismatch").clone();
203        let point = Point::from_fetch(
204            encrypted.hash(),
205            Visited {
206                decrypted,
207                encrypted,
208            }
209            .into_dyn_fetch(),
210        );
211        self.visitor.visit(&point);
212    }
213}
214
215impl<K, T> ListHashes for EncryptedInner<K, T> {
216    fn list_hashes(&self, f: &mut impl FnMut(Hash)) {
217        self.resolution.list_hashes(f);
218    }
219
220    fn topology_hash(&self) -> Hash {
221        self.resolution.0.data_hash()
222    }
223
224    fn point_count(&self) -> usize {
225        self.resolution.len()
226    }
227}
228
229impl<K: Key, T: Topological> Topological for EncryptedInner<K, T> {
230    fn traverse(&self, visitor: &mut impl PointVisitor) {
231        let resolution = &mut self.resolution.iter();
232        self.decrypted.0.traverse(&mut IterateResolution {
233            resolution,
234            visitor,
235        });
236        assert!(resolution.next().is_none());
237    }
238}
239
240pub struct Encrypted<K, T> {
241    key: K,
242    inner: EncryptedInner<K, T>,
243}
244
245impl<K, T: Clone> Encrypted<K, T> {
246    pub fn into_inner(self) -> T {
247        Arc::unwrap_or_clone(self.inner.decrypted.0)
248    }
249}
250
251impl<K, T> Deref for Encrypted<K, T> {
252    type Target = T;
253
254    fn deref(&self) -> &Self::Target {
255        self.inner.decrypted.0.as_ref()
256    }
257}
258
259impl<K: Clone, T> Clone for Encrypted<K, T> {
260    fn clone(&self) -> Self {
261        Self {
262            key: self.key.clone(),
263            inner: self.inner.clone(),
264        }
265    }
266}
267
268impl<K, T> ListHashes for Encrypted<K, T> {
269    fn list_hashes(&self, f: &mut impl FnMut(Hash)) {
270        self.inner.list_hashes(f);
271    }
272
273    fn topology_hash(&self) -> Hash {
274        self.inner.topology_hash()
275    }
276
277    fn point_count(&self) -> usize {
278        self.inner.point_count()
279    }
280}
281
282impl<K: Key, T: Topological> Topological for Encrypted<K, T> {
283    fn traverse(&self, visitor: &mut impl PointVisitor) {
284        self.inner.traverse(visitor);
285    }
286}
287
288impl<K: Key, T: ToOutput> ToOutput for Encrypted<K, T> {
289    fn to_output(&self, output: &mut dyn object_rainbow::Output) {
290        if output.is_mangling() {
291            output.write(
292                &self
293                    .key
294                    .encrypt(b"this encrypted constant is followed by an unencrypted inner hash"),
295            );
296            self.inner.decrypted.0.data_hash();
297        }
298        if output.is_real() {
299            let source = self.inner.vec();
300            output.write(&self.key.encrypt(&source));
301        }
302    }
303}
304
305#[derive(Clone)]
306struct Decrypt<K> {
307    resolution: Resolution<K>,
308}
309
310impl<K: Key> Decrypt<K> {
311    async fn resolve_bytes(
312        &self,
313        address: Address,
314    ) -> object_rainbow::Result<(Vec<u8>, Resolution<K>)> {
315        let Encrypted {
316            key: _,
317            inner:
318                EncryptedInner {
319                    tags: _,
320                    resolution,
321                    decrypted,
322                },
323        } = self
324            .resolution
325            .get(address.index)
326            .ok_or(Error::AddressOutOfBounds)?
327            .clone()
328            .fetch()
329            .await?;
330        let data = Arc::unwrap_or_clone(decrypted.0);
331        Ok((data, resolution))
332    }
333}
334
335impl<K: Key> Resolve for Decrypt<K> {
336    fn resolve(&'_ self, address: Address, _: &Arc<dyn Resolve>) -> FailFuture<'_, ByteNode> {
337        Box::pin(async move {
338            let (data, resolution) = self.resolve_bytes(address).await?;
339            Ok((data, Arc::new(Decrypt { resolution }) as _))
340        })
341    }
342
343    fn resolve_data(&'_ self, address: Address) -> FailFuture<'_, Vec<u8>> {
344        Box::pin(async move {
345            let (data, _) = self.resolve_bytes(address).await?;
346            Ok(data)
347        })
348    }
349
350    fn try_resolve_local(
351        &self,
352        address: Address,
353        _: &Arc<dyn Resolve>,
354    ) -> object_rainbow::Result<Option<ByteNode>> {
355        let Some((
356            Encrypted {
357                key: _,
358                inner:
359                    EncryptedInner {
360                        tags: _,
361                        resolution,
362                        decrypted,
363                    },
364            },
365            _,
366        )) = self
367            .resolution
368            .get(address.index)
369            .ok_or(Error::AddressOutOfBounds)?
370            .clone()
371            .try_fetch_local()?
372        else {
373            return Ok(None);
374        };
375        let data = Arc::unwrap_or_clone(decrypted.0);
376        Ok(Some((data, Arc::new(Decrypt { resolution }) as _)))
377    }
378}
379
/// Parser extra from which an [`Encrypted`] value obtains its key and the extra for its
/// plaintext. Implemented for `(key, extra)` pairs and for a bare key (plaintext extra `()`).
trait EncryptedExtra<K>: 'static + Send + Sync + Clone {
    /// Extra passed on to the plaintext parser.
    type Extra: 'static + Send + Sync + Clone;
    /// Returns owned copies of the key and the plaintext extra.
    fn parts(&self) -> (K, Self::Extra);
}

impl<K: 'static + Send + Sync + Clone, Extra: 'static + Send + Sync + Clone> EncryptedExtra<K>
    for (K, Extra)
{
    type Extra = Extra;

    fn parts(&self) -> (K, Self::Extra) {
        let (key, extra) = self;
        (key.clone(), extra.clone())
    }
}

impl<K: 'static + Send + Sync + Clone> EncryptedExtra<K> for K {
    type Extra = ();

    fn parts(&self) -> (K, Self::Extra) {
        (K::clone(self), ())
    }
}
402
403impl<
404    K: Key,
405    T: Object<Extra>,
406    Extra: 'static + Send + Sync + Clone,
407    I: PointInput<Extra: EncryptedExtra<K, Extra = Extra>>,
408> Parse<I> for Encrypted<K, T>
409{
410    fn parse(input: I) -> object_rainbow::Result<Self> {
411        let with_key = input.extra().parts();
412        let resolve = input.resolve().clone();
413        let source = with_key
414            .0
415            .decrypt(&input.parse_all()?)
416            .map_err(object_rainbow::Error::consistency)?;
417        let EncryptedInner {
418            tags,
419            resolution,
420            decrypted,
421        } = EncryptedInner::<K, Vec<u8>>::parse_slice_extra(&source, &resolve, &with_key)?;
422        let decrypted = T::parse_slice_extra(
423            &decrypted.0,
424            &(Arc::new(Decrypt {
425                resolution: resolution.clone(),
426            }) as _),
427            &with_key.1,
428        )?;
429        let decrypted = Unkeyed(Arc::new(decrypted));
430        let inner = EncryptedInner {
431            tags,
432            resolution,
433            decrypted,
434        };
435        Ok(Self {
436            key: with_key.0,
437            inner,
438        })
439    }
440}
441
442impl<K, T> Tagged for Encrypted<K, T> {}
443
444type Extracted<K> = Vec<
445    std::pin::Pin<
446        Box<dyn Future<Output = Result<Point<Encrypted<K, Vec<u8>>>, Error>> + Send + 'static>,
447    >,
448>;
449
450struct ExtractResolution<'a, K> {
451    extracted: &'a mut Extracted<K>,
452    key: &'a K,
453}
454
455struct Untyped<K, T> {
456    key: K,
457    encrypted: Point<Encrypted<K, T>>,
458}
459
460impl<K, T> FetchBytes for Untyped<K, T> {
461    fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
462        self.encrypted.fetch_bytes()
463    }
464
465    fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
466        self.encrypted.fetch_data()
467    }
468
469    fn fetch_bytes_local(&self) -> object_rainbow::Result<Option<ByteNode>> {
470        self.encrypted.fetch_bytes_local()
471    }
472
473    fn fetch_data_local(&self) -> Option<Vec<u8>> {
474        self.encrypted.fetch_data_local()
475    }
476}
477
478impl<K: Send + Sync, T> Singular for Untyped<K, T> {
479    fn hash(&self) -> Hash {
480        self.encrypted.hash()
481    }
482}
483
484impl<K: Key, T: FullHash> Fetch for Untyped<K, T> {
485    type T = Encrypted<K, Vec<u8>>;
486
487    fn fetch_full(&'_ self) -> FailFuture<'_, Node<Self::T>> {
488        Box::pin(async move {
489            let (data, resolve) = self.fetch_bytes().await?;
490            let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
491            Ok((encrypted, resolve))
492        })
493    }
494
495    fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
496        Box::pin(async move {
497            let (data, resolve) = self.fetch_bytes().await?;
498            let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
499            Ok(encrypted)
500        })
501    }
502
503    fn try_fetch_local(&self) -> object_rainbow::Result<Option<Node<Self::T>>> {
504        let Some((data, resolve)) = self.fetch_bytes_local()? else {
505            return Ok(None);
506        };
507        let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
508        Ok(Some((encrypted, resolve)))
509    }
510
511    fn fetch_local(&self) -> Option<Self::T> {
512        let Encrypted {
513            key,
514            inner:
515                EncryptedInner {
516                    tags,
517                    resolution,
518                    decrypted,
519                },
520        } = self.encrypted.fetch_local()?;
521        let decrypted = Unkeyed(Arc::new(decrypted.vec()));
522        Some(Encrypted {
523            key,
524            inner: EncryptedInner {
525                tags,
526                resolution,
527                decrypted,
528            },
529        })
530    }
531}
532
533impl<K: Key> PointVisitor for ExtractResolution<'_, K> {
534    fn visit<T: Traversible>(&mut self, decrypted: &(impl 'static + SingularFetch<T = T> + Clone)) {
535        let decrypted = decrypted.clone();
536        let key = self.key.clone();
537        self.extracted.push(Box::pin(async move {
538            let encrypted = encrypt_point(key.clone(), decrypted).await?;
539            let encrypted = Point::from_fetch(
540                encrypted.hash(),
541                Untyped { key, encrypted }.into_dyn_fetch(),
542            );
543            Ok(encrypted)
544        }));
545    }
546}
547
548pub async fn encrypt_point<K: Key, T: Traversible>(
549    key: K,
550    decrypted: impl 'static + SingularFetch<T = T>,
551) -> object_rainbow::Result<Point<Encrypted<K, T>>> {
552    if let Some((address, decrypt)) = decrypted.extract_resolve::<Decrypt<K>>() {
553        let encrypted = decrypt
554            .resolution
555            .get(address.index)
556            .ok_or(Error::AddressOutOfBounds)?
557            .clone();
558        let point = Point::from_fetch(
559            encrypted.hash(),
560            Visited {
561                decrypted,
562                encrypted,
563            }
564            .into_dyn_fetch(),
565        );
566        return Ok(point);
567    }
568    let decrypted = decrypted.fetch().await?;
569    let encrypted = encrypt(key.clone(), decrypted).await?;
570    let point = encrypted.point();
571    Ok(point)
572}
573
574pub async fn encrypt<K: Key, T: Traversible>(
575    key: K,
576    decrypted: T,
577) -> object_rainbow::Result<Encrypted<K, T>> {
578    let mut futures = Vec::with_capacity(decrypted.point_count());
579    decrypted.traverse(&mut ExtractResolution {
580        extracted: &mut futures,
581        key: &key,
582    });
583    let resolution = futures_util::future::try_join_all(futures).await?;
584    let resolution = Arc::new(Lp(resolution));
585    let decrypted = Unkeyed(Arc::new(decrypted));
586    let inner = EncryptedInner {
587        tags: T::HASH,
588        resolution,
589        decrypted,
590    };
591    Ok(Encrypted { key, inner })
592}