Skip to main content

object_rainbow_encrypted/
lib.rs

1use std::{ops::Deref, sync::Arc};
2
3use object_rainbow::{
4    Address, ByteNode, Error, FailFuture, Fetch, FetchBytes, FullHash, Hash, ListHashes, Node,
5    Object, Parse, ParseSliceExtra, PointInput, PointVisitor, Resolve, Singular, SingularFetch,
6    Tagged, ToOutput, Topological, Traversible, length_prefixed::Lp,
7};
8use object_rainbow_point::{ExtractResolve, IntoPoint, Point};
9
10pub trait Key: 'static + Sized + Send + Sync + Clone {
11    type Error: Into<anyhow::Error>;
12    fn encrypt(&self, data: &[u8]) -> Vec<u8>;
13    fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>, Self::Error>;
14}
15
16type Resolution<K> = Arc<Lp<Vec<Point<Encrypted<K, Vec<u8>>>>>>;
17
18#[derive(ToOutput, Clone)]
19struct Unkeyed<T>(T);
20
21impl<
22    T: Parse<I::WithExtra<Extra>>,
23    K: 'static + Clone,
24    Extra: 'static + Clone,
25    I: PointInput<Extra = (K, Extra)>,
26> Parse<I> for Unkeyed<T>
27{
28    fn parse(input: I) -> object_rainbow::Result<Self> {
29        Ok(Self(T::parse(input.map_extra(|(_, extra)| extra))?))
30    }
31}
32
33#[derive(ToOutput, Parse)]
34struct EncryptedInner<K, T> {
35    tags: Hash,
36    resolution: Resolution<K>,
37    decrypted: Unkeyed<Arc<T>>,
38}
39
40impl<K, T> Clone for EncryptedInner<K, T> {
41    fn clone(&self) -> Self {
42        Self {
43            tags: self.tags,
44            resolution: self.resolution.clone(),
45            decrypted: self.decrypted.clone(),
46        }
47    }
48}
49
50type ResolutionIter<'a, K> = std::slice::Iter<'a, Point<Encrypted<K, Vec<u8>>>>;
51
52struct IterateResolution<'a, 'r, K, V> {
53    resolution: &'r mut ResolutionIter<'a, K>,
54    visitor: &'a mut V,
55}
56
57struct Visited<K, P> {
58    decrypted: P,
59    encrypted: Point<Encrypted<K, Vec<u8>>>,
60}
61
62impl<K, P> FetchBytes for Visited<K, P> {
63    fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
64        self.encrypted.fetch_bytes()
65    }
66
67    fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
68        self.encrypted.fetch_data()
69    }
70
71    fn fetch_bytes_local(&self) -> object_rainbow::Result<Option<ByteNode>> {
72        self.encrypted.fetch_bytes_local()
73    }
74
75    fn fetch_data_local(&self) -> Option<Vec<u8>> {
76        self.encrypted.fetch_data_local()
77    }
78}
79
80impl<K, P: Send + Sync> Singular for Visited<K, P> {
81    fn hash(&self) -> Hash {
82        self.encrypted.hash()
83    }
84}
85
86impl<K: Key, P: Fetch<T: Traversible>> Fetch for Visited<K, P> {
87    type T = Encrypted<K, P::T>;
88
89    fn fetch_full(&'_ self) -> FailFuture<'_, Node<Self::T>> {
90        Box::pin(async move {
91            let (
92                Encrypted {
93                    key,
94                    inner:
95                        EncryptedInner {
96                            tags,
97                            resolution,
98                            decrypted: _,
99                        },
100                },
101                resolve,
102            ) = self.encrypted.fetch_full().await?;
103            let decrypted = self.decrypted.fetch().await?;
104            let decrypted = Unkeyed(Arc::new(decrypted));
105            Ok((
106                Encrypted {
107                    key,
108                    inner: EncryptedInner {
109                        tags,
110                        resolution,
111                        decrypted,
112                    },
113                },
114                resolve,
115            ))
116        })
117    }
118
119    fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
120        Box::pin(async move {
121            let Encrypted {
122                key,
123                inner:
124                    EncryptedInner {
125                        tags,
126                        resolution,
127                        decrypted: _,
128                    },
129            } = self.encrypted.fetch().await?;
130            let decrypted = self.decrypted.fetch().await?;
131            let decrypted = Unkeyed(Arc::new(decrypted));
132            Ok(Encrypted {
133                key,
134                inner: EncryptedInner {
135                    tags,
136                    resolution,
137                    decrypted,
138                },
139            })
140        })
141    }
142
143    fn try_fetch_local(&self) -> object_rainbow::Result<Option<Node<Self::T>>> {
144        let Some((
145            Encrypted {
146                key,
147                inner:
148                    EncryptedInner {
149                        tags,
150                        resolution,
151                        decrypted: _,
152                    },
153            },
154            resolve,
155        )) = self.encrypted.try_fetch_local()?
156        else {
157            return Ok(None);
158        };
159        let Some((decrypted, _)) = self.decrypted.try_fetch_local()? else {
160            return Ok(None);
161        };
162        let decrypted = Unkeyed(Arc::new(decrypted));
163        Ok(Some((
164            Encrypted {
165                key,
166                inner: EncryptedInner {
167                    tags,
168                    resolution,
169                    decrypted,
170                },
171            },
172            resolve,
173        )))
174    }
175
176    fn fetch_local(&self) -> Option<Self::T> {
177        let Encrypted {
178            key,
179            inner:
180                EncryptedInner {
181                    tags,
182                    resolution,
183                    decrypted: _,
184                },
185        } = self.encrypted.fetch_local()?;
186        let decrypted = Unkeyed(Arc::new(self.decrypted.fetch_local()?));
187        Some(Encrypted {
188            key,
189            inner: EncryptedInner {
190                tags,
191                resolution,
192                decrypted,
193            },
194        })
195    }
196}
197
198impl<'a, K: Key, V: PointVisitor> PointVisitor for IterateResolution<'a, '_, K, V> {
199    fn visit<T: Traversible>(&mut self, decrypted: &(impl 'static + SingularFetch<T = T> + Clone)) {
200        let decrypted = decrypted.clone();
201        let encrypted = self.resolution.next().expect("length mismatch").clone();
202        let point = Point::from_fetch(
203            encrypted.hash(),
204            Visited {
205                decrypted,
206                encrypted,
207            }
208            .into_dyn_fetch(),
209        );
210        self.visitor.visit(&point);
211    }
212}
213
214impl<K, T> ListHashes for EncryptedInner<K, T> {
215    fn list_hashes(&self, f: &mut impl FnMut(Hash)) {
216        self.resolution.list_hashes(f);
217    }
218
219    fn topology_hash(&self) -> Hash {
220        self.resolution.0.data_hash()
221    }
222
223    fn point_count(&self) -> usize {
224        self.resolution.len()
225    }
226}
227
228impl<K: Key, T: Topological> Topological for EncryptedInner<K, T> {
229    fn traverse(&self, visitor: &mut impl PointVisitor) {
230        let resolution = &mut self.resolution.iter();
231        self.decrypted.0.traverse(&mut IterateResolution {
232            resolution,
233            visitor,
234        });
235        assert!(resolution.next().is_none());
236    }
237}
238
239pub struct Encrypted<K, T> {
240    key: K,
241    inner: EncryptedInner<K, T>,
242}
243
244impl<K, T: Clone> Encrypted<K, T> {
245    pub fn into_inner(self) -> T {
246        Arc::unwrap_or_clone(self.inner.decrypted.0)
247    }
248}
249
250impl<K, T> Deref for Encrypted<K, T> {
251    type Target = T;
252
253    fn deref(&self) -> &Self::Target {
254        self.inner.decrypted.0.as_ref()
255    }
256}
257
258impl<K: Clone, T> Clone for Encrypted<K, T> {
259    fn clone(&self) -> Self {
260        Self {
261            key: self.key.clone(),
262            inner: self.inner.clone(),
263        }
264    }
265}
266
267impl<K, T> ListHashes for Encrypted<K, T> {
268    fn list_hashes(&self, f: &mut impl FnMut(Hash)) {
269        self.inner.list_hashes(f);
270    }
271
272    fn topology_hash(&self) -> Hash {
273        self.inner.topology_hash()
274    }
275
276    fn point_count(&self) -> usize {
277        self.inner.point_count()
278    }
279}
280
281impl<K: Key, T: Topological> Topological for Encrypted<K, T> {
282    fn traverse(&self, visitor: &mut impl PointVisitor) {
283        self.inner.traverse(visitor);
284    }
285}
286
287impl<K: Key, T: ToOutput> ToOutput for Encrypted<K, T> {
288    fn to_output(&self, output: &mut dyn object_rainbow::Output) {
289        if output.is_mangling() {
290            output.write(
291                &self
292                    .key
293                    .encrypt(b"this encrypted constant is followed by an unencrypted inner hash"),
294            );
295            self.inner.decrypted.0.data_hash();
296        }
297        if output.is_real() {
298            let source = self.inner.vec();
299            output.write(&self.key.encrypt(&source));
300        }
301    }
302}
303
304#[derive(Clone)]
305struct Decrypt<K> {
306    resolution: Resolution<K>,
307}
308
309impl<K: Key> Decrypt<K> {
310    async fn resolve_bytes(
311        &self,
312        address: Address,
313    ) -> object_rainbow::Result<(Vec<u8>, Resolution<K>)> {
314        let Encrypted {
315            key: _,
316            inner:
317                EncryptedInner {
318                    tags: _,
319                    resolution,
320                    decrypted,
321                },
322        } = self
323            .resolution
324            .get(address.index)
325            .ok_or(Error::AddressOutOfBounds)?
326            .clone()
327            .fetch()
328            .await?;
329        let data = Arc::unwrap_or_clone(decrypted.0);
330        Ok((data, resolution))
331    }
332}
333
334impl<K: Key> Resolve for Decrypt<K> {
335    fn resolve(&'_ self, address: Address, _: &Arc<dyn Resolve>) -> FailFuture<'_, ByteNode> {
336        Box::pin(async move {
337            let (data, resolution) = self.resolve_bytes(address).await?;
338            Ok((data, Arc::new(Decrypt { resolution }) as _))
339        })
340    }
341
342    fn resolve_data(&'_ self, address: Address) -> FailFuture<'_, Vec<u8>> {
343        Box::pin(async move {
344            let (data, _) = self.resolve_bytes(address).await?;
345            Ok(data)
346        })
347    }
348
349    fn try_resolve_local(
350        &self,
351        address: Address,
352        _: &Arc<dyn Resolve>,
353    ) -> object_rainbow::Result<Option<ByteNode>> {
354        let Some((
355            Encrypted {
356                key: _,
357                inner:
358                    EncryptedInner {
359                        tags: _,
360                        resolution,
361                        decrypted,
362                    },
363            },
364            _,
365        )) = self
366            .resolution
367            .get(address.index)
368            .ok_or(Error::AddressOutOfBounds)?
369            .clone()
370            .try_fetch_local()?
371        else {
372            return Ok(None);
373        };
374        let data = Arc::unwrap_or_clone(decrypted.0);
375        Ok(Some((data, Arc::new(Decrypt { resolution }) as _)))
376    }
377}
378
/// Parse extra accepted by `Encrypted`: it yields the key, plus the extra that is
/// forwarded to the plaintext type's parser.
trait EncryptedExtra<K>: 'static + Send + Sync + Clone {
    /// Extra handed to the plaintext parser.
    type Extra: 'static + Send + Sync + Clone;

    /// Returns owned copies of the key and of the remaining extra.
    fn parts(&self) -> (K, Self::Extra);
}

/// A `(key, extra)` pair splits into its two components.
impl<K, Extra> EncryptedExtra<K> for (K, Extra)
where
    K: 'static + Send + Sync + Clone,
    Extra: 'static + Send + Sync + Clone,
{
    type Extra = Extra;

    fn parts(&self) -> (K, Self::Extra) {
        let (key, extra) = self;
        (key.clone(), extra.clone())
    }
}

/// A bare key is a keyed extra with no further extra.
impl<K> EncryptedExtra<K> for K
where
    K: 'static + Send + Sync + Clone,
{
    type Extra = ();

    fn parts(&self) -> (K, Self::Extra) {
        (K::clone(self), ())
    }
}
401
402impl<
403    K: Key,
404    T: Object<Extra>,
405    Extra: 'static + Send + Sync + Clone,
406    I: PointInput<Extra: EncryptedExtra<K, Extra = Extra>>,
407> Parse<I> for Encrypted<K, T>
408{
409    fn parse(input: I) -> object_rainbow::Result<Self> {
410        let with_key = input.extra().parts();
411        let resolve = input.resolve().clone();
412        let source = with_key
413            .0
414            .decrypt(&input.parse_all()?)
415            .map_err(object_rainbow::Error::consistency)?;
416        let EncryptedInner {
417            tags,
418            resolution,
419            decrypted,
420        } = EncryptedInner::<K, Vec<u8>>::parse_slice_extra(&source, &resolve, &with_key)?;
421        let decrypted = T::parse_slice_extra(
422            &decrypted.0,
423            &(Arc::new(Decrypt {
424                resolution: resolution.clone(),
425            }) as _),
426            &with_key.1,
427        )?;
428        let decrypted = Unkeyed(Arc::new(decrypted));
429        let inner = EncryptedInner {
430            tags,
431            resolution,
432            decrypted,
433        };
434        Ok(Self {
435            key: with_key.0,
436            inner,
437        })
438    }
439}
440
441impl<K, T> Tagged for Encrypted<K, T> {}
442
443type Extracted<K> = Vec<
444    std::pin::Pin<
445        Box<dyn Future<Output = Result<Point<Encrypted<K, Vec<u8>>>, Error>> + Send + 'static>,
446    >,
447>;
448
449struct ExtractResolution<'a, K> {
450    extracted: &'a mut Extracted<K>,
451    key: &'a K,
452}
453
454struct Untyped<K, T> {
455    key: K,
456    encrypted: Point<Encrypted<K, T>>,
457}
458
459impl<K, T> FetchBytes for Untyped<K, T> {
460    fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
461        self.encrypted.fetch_bytes()
462    }
463
464    fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
465        self.encrypted.fetch_data()
466    }
467
468    fn fetch_bytes_local(&self) -> object_rainbow::Result<Option<ByteNode>> {
469        self.encrypted.fetch_bytes_local()
470    }
471
472    fn fetch_data_local(&self) -> Option<Vec<u8>> {
473        self.encrypted.fetch_data_local()
474    }
475}
476
477impl<K: Send + Sync, T> Singular for Untyped<K, T> {
478    fn hash(&self) -> Hash {
479        self.encrypted.hash()
480    }
481}
482
483impl<K: Key, T: FullHash> Fetch for Untyped<K, T> {
484    type T = Encrypted<K, Vec<u8>>;
485
486    fn fetch_full(&'_ self) -> FailFuture<'_, Node<Self::T>> {
487        Box::pin(async move {
488            let (data, resolve) = self.fetch_bytes().await?;
489            let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
490            Ok((encrypted, resolve))
491        })
492    }
493
494    fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
495        Box::pin(async move {
496            let (data, resolve) = self.fetch_bytes().await?;
497            let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
498            Ok(encrypted)
499        })
500    }
501
502    fn try_fetch_local(&self) -> object_rainbow::Result<Option<Node<Self::T>>> {
503        let Some((data, resolve)) = self.fetch_bytes_local()? else {
504            return Ok(None);
505        };
506        let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
507        Ok(Some((encrypted, resolve)))
508    }
509
510    fn fetch_local(&self) -> Option<Self::T> {
511        let Encrypted {
512            key,
513            inner:
514                EncryptedInner {
515                    tags,
516                    resolution,
517                    decrypted,
518                },
519        } = self.encrypted.fetch_local()?;
520        let decrypted = Unkeyed(Arc::new(decrypted.vec()));
521        Some(Encrypted {
522            key,
523            inner: EncryptedInner {
524                tags,
525                resolution,
526                decrypted,
527            },
528        })
529    }
530}
531
532impl<K: Key> PointVisitor for ExtractResolution<'_, K> {
533    fn visit<T: Traversible>(&mut self, decrypted: &(impl 'static + SingularFetch<T = T> + Clone)) {
534        let decrypted = decrypted.clone();
535        let key = self.key.clone();
536        self.extracted.push(Box::pin(async move {
537            let encrypted = encrypt_point(key.clone(), decrypted).await?;
538            let encrypted = Point::from_fetch(
539                encrypted.hash(),
540                Untyped { key, encrypted }.into_dyn_fetch(),
541            );
542            Ok(encrypted)
543        }));
544    }
545}
546
547pub async fn encrypt_point<K: Key, T: Traversible>(
548    key: K,
549    decrypted: impl 'static + SingularFetch<T = T>,
550) -> object_rainbow::Result<Point<Encrypted<K, T>>> {
551    if let Some((address, decrypt)) = decrypted.extract_resolve::<Decrypt<K>>() {
552        let encrypted = decrypt
553            .resolution
554            .get(address.index)
555            .ok_or(Error::AddressOutOfBounds)?
556            .clone();
557        let point = Point::from_fetch(
558            encrypted.hash(),
559            Visited {
560                decrypted,
561                encrypted,
562            }
563            .into_dyn_fetch(),
564        );
565        return Ok(point);
566    }
567    let decrypted = decrypted.fetch().await?;
568    let encrypted = encrypt(key.clone(), decrypted).await?;
569    let point = encrypted.point();
570    Ok(point)
571}
572
573pub async fn encrypt<K: Key, T: Traversible>(
574    key: K,
575    decrypted: T,
576) -> object_rainbow::Result<Encrypted<K, T>> {
577    let mut futures = Vec::with_capacity(decrypted.point_count());
578    decrypted.traverse(&mut ExtractResolution {
579        extracted: &mut futures,
580        key: &key,
581    });
582    let resolution = futures_util::future::try_join_all(futures).await?;
583    let resolution = Arc::new(Lp(resolution));
584    let decrypted = Unkeyed(Arc::new(decrypted));
585    let inner = EncryptedInner {
586        tags: T::HASH,
587        resolution,
588        decrypted,
589    };
590    Ok(Encrypted { key, inner })
591}