Skip to main content

object_rainbow_encrypted/
lib.rs

1use std::{ops::Deref, sync::Arc};
2
3use object_rainbow::{
4    Address, ByteNode, Error, FailFuture, Fetch, FetchBytes, Hash, ListHashes, Node, Object, Parse,
5    ParseInline, ParseSliceExtra, PointInput, PointVisitor, Resolve, Singular, SingularFetch,
6    Tagged, ToOutput, TopoVec, Topological, Traversible, derive_for_wrapped, length_prefixed::Lp,
7    map_extra::MappedExtra, tuple_extra::Extra0,
8};
9use object_rainbow_point::{ExtractResolve, Extras, IntoPoint, Point};
10
11#[derive_for_wrapped]
12pub trait Key: 'static + Sized + Send + Sync + Clone + PartialEq + Eq {
13    type Error: Into<anyhow::Error>;
14    fn encrypt(&self, data: &[u8]) -> Vec<u8>;
15    fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>, Self::Error>;
16}
17
18#[derive(ToOutput, Parse, ParseInline, Clone)]
19struct RawResolve<K> {
20    key: Extras<K>,
21    resolve: Arc<dyn Resolve>,
22    addresses: Arc<Lp<Vec<Address>>>,
23}
24
25impl<K> RawResolve<K> {
26    fn translate(&self, address: Address) -> object_rainbow::Result<Address> {
27        self.addresses
28            .get(address.index)
29            .copied()
30            .ok_or(Error::AddressOutOfBounds)
31    }
32}
33
34fn side_parse<K: Key>(
35    key: &K,
36    data: &[u8],
37    resolve: &Arc<dyn Resolve>,
38) -> object_rainbow::Result<(InnerHeader<K>, Vec<u8>)> {
39    let data = key
40        .decrypt(data)
41        .map_err(object_rainbow::Error::consistency)?;
42    <(InnerHeader<K>, Vec<u8>) as ParseSliceExtra<K>>::parse_slice_extra(&data, resolve, key)
43}
44
45impl<K: Key> Resolve for RawResolve<K> {
46    fn resolve<'a>(
47        &'a self,
48        address: Address,
49        _: &'a Arc<dyn Resolve>,
50    ) -> FailFuture<'a, ByteNode> {
51        Box::pin(async move {
52            let address = self.translate(address)?;
53            let (data, resolve) = self.resolve.resolve(address, &self.resolve).await?;
54            let (InnerHeader { resolve, .. }, data) = side_parse(&self.key.0, &data, &resolve)?;
55            let resolve = Arc::new(resolve) as _;
56            Ok((data, resolve))
57        })
58    }
59
60    fn resolve_data(&'_ self, address: Address) -> FailFuture<'_, Vec<u8>> {
61        Box::pin(async move {
62            let address = self.translate(address)?;
63            let (data, resolve) = self.resolve.resolve(address, &self.resolve).await?;
64            let (_, data) = side_parse(&self.key.0, &data, &resolve)?;
65            Ok(data)
66        })
67    }
68
69    fn try_resolve_local(
70        &self,
71        address: Address,
72        _: &Arc<dyn Resolve>,
73    ) -> object_rainbow::Result<Option<ByteNode>> {
74        let address = self.translate(address)?;
75        let Some((data, resolve)) = self.resolve.try_resolve_local(address, &self.resolve)? else {
76            return Ok(None);
77        };
78        let (InnerHeader { resolve, .. }, data) = side_parse(&self.key.0, &data, &resolve)?;
79        let resolve = Arc::new(resolve) as _;
80        Ok(Some((data, resolve)))
81    }
82}
83
84#[derive(Parse, ParseInline)]
85struct InnerHeader<K> {
86    tags: Hash,
87    resolve: RawResolve<K>,
88}
89
90impl<K: Key> InnerHeader<K> {
91    fn with<T: Topological + Tagged>(self, decrypted: T) -> object_rainbow::Result<Inner<K, T>> {
92        if self.tags != T::HASH {
93            return Err(object_rainbow::error_consistency!("tags mismatch"));
94        }
95        let mut topology = TopoVec::new();
96        let mut v = RawVisit {
97            at: 0,
98            resolve: &self.resolve,
99            visitor: &mut topology,
100        };
101        decrypted.traverse(&mut v);
102        v.done()?;
103        let topology = Arc::new(Lp(topology));
104        let decrypted = Arc::new(decrypted);
105        Ok(Inner {
106            tags: self.tags,
107            key: self.resolve.key,
108            topology,
109            decrypted,
110        })
111    }
112}
113
114#[derive(ToOutput)]
115struct Inner<K, T> {
116    tags: Hash,
117    key: Extras<K>,
118    topology: Arc<Lp<TopoVec>>,
119    decrypted: Arc<T>,
120}
121
122struct RawVisit<'a, K, V> {
123    at: usize,
124    resolve: &'a RawResolve<K>,
125    visitor: &'a mut V,
126}
127
128impl<'a, K, V> RawVisit<'a, K, V> {
129    fn done(self) -> object_rainbow::Result<()> {
130        if self.at != self.resolve.addresses.len() {
131            Err(Error::AddressOutOfBounds)
132        } else {
133            Ok(())
134        }
135    }
136}
137
138#[derive(Clone)]
139struct RawFetch<K, D> {
140    key: K,
141    resolve: Arc<dyn Resolve>,
142    address: Address,
143    decrypted: D,
144}
145
146impl<K, D> FetchBytes for RawFetch<K, D> {
147    fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
148        self.resolve.resolve(self.address, &self.resolve)
149    }
150
151    fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
152        self.resolve.resolve_data(self.address)
153    }
154
155    fn fetch_bytes_local(&self) -> object_rainbow::Result<Option<ByteNode>> {
156        self.resolve.try_resolve_local(self.address, &self.resolve)
157    }
158}
159
160impl<K: Key, D: Fetch<T: Topological + Tagged>> Fetch for RawFetch<K, D> {
161    type T = Encrypted<K, D::T>;
162
163    fn fetch_full(&'_ self) -> FailFuture<'_, Node<Self::T>> {
164        Box::pin(async move {
165            let (encrypted, resolve) = self.resolve.resolve(self.address, &self.resolve).await?;
166            let (header, _) = side_parse(&self.key, &encrypted, &resolve)?;
167            let decrypted = self.decrypted.fetch().await?;
168            let inner = header.with(decrypted)?;
169            Ok((Encrypted { inner }, resolve))
170        })
171    }
172
173    fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
174        Box::pin(async move {
175            let (encrypted, resolve) = self.resolve.resolve(self.address, &self.resolve).await?;
176            let (header, _) = side_parse(&self.key, &encrypted, &resolve)?;
177            let decrypted = self.decrypted.fetch().await?;
178            let inner = header.with(decrypted)?;
179            Ok(Encrypted { inner })
180        })
181    }
182
183    fn try_fetch_local(&self) -> object_rainbow::Result<Option<Node<Self::T>>> {
184        let Some((encrypted, resolve)) = self
185            .resolve
186            .try_resolve_local(self.address, &self.resolve)?
187        else {
188            return Ok(None);
189        };
190        let (header, _) = side_parse(&self.key, &encrypted, &resolve)?;
191        let Some((decrypted, _)) = self.decrypted.try_fetch_local()? else {
192            return Ok(None);
193        };
194        let inner = header.with(decrypted)?;
195        Ok(Some((Encrypted { inner }, resolve)))
196    }
197}
198
199impl<K: Send + Sync, D: Send + Sync> Singular for RawFetch<K, D> {
200    fn hash(&self) -> Hash {
201        self.address.hash
202    }
203}
204
205impl<'a, K: Key, V: PointVisitor> PointVisitor for RawVisit<'a, K, V> {
206    fn visit<T: Traversible>(&mut self, point: &(impl 'static + SingularFetch<T = T> + Clone)) {
207        let at = self.at;
208        self.at += 1;
209        if let Some(address) = self.resolve.addresses.get(at).copied() {
210            let key = self.resolve.key.0.clone();
211            let resolve = self.resolve.resolve.clone();
212            let decrypted = point.clone();
213            self.visitor.visit(&RawFetch {
214                key,
215                resolve,
216                address,
217                decrypted,
218            });
219        }
220    }
221}
222
223impl<
224    K: Key,
225    T: Object<Extra>,
226    Extra: 'static + Send + Sync + Clone,
227    I: PointInput<Extra = (K, Extra)>,
228> Parse<I> for Inner<K, T>
229{
230    fn parse(mut input: I) -> object_rainbow::Result<Self> {
231        let header = input
232            .parse_inline::<MappedExtra<InnerHeader<K>, Extra0>>()?
233            .1;
234        let extra = input.extra().1.clone();
235        let decrypted = T::parse_slice_extra(
236            &input.parse_all()?,
237            &(Arc::new(header.resolve.clone()) as _),
238            &extra,
239        )?;
240        header.with(decrypted)
241    }
242}
243
244impl<K: Clone, T> Clone for Inner<K, T> {
245    fn clone(&self) -> Self {
246        Self {
247            tags: self.tags,
248            key: self.key.clone(),
249            topology: self.topology.clone(),
250            decrypted: self.decrypted.clone(),
251        }
252    }
253}
254
255#[derive(Clone)]
256struct InnerFetch<K, D> {
257    key: K,
258    encrypted: Arc<dyn Singular>,
259    decrypted: D,
260}
261
262impl<K, D> FetchBytes for InnerFetch<K, D> {
263    fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
264        self.encrypted.fetch_bytes()
265    }
266
267    fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
268        self.encrypted.fetch_data()
269    }
270
271    fn fetch_bytes_local(&self) -> object_rainbow::Result<Option<ByteNode>> {
272        self.encrypted.fetch_bytes_local()
273    }
274
275    fn fetch_data_local(&self) -> Option<Vec<u8>> {
276        self.encrypted.fetch_data_local()
277    }
278}
279
280impl<K: Key, D: Fetch<T: Topological + Tagged>> Fetch for InnerFetch<K, D> {
281    type T = Encrypted<K, D::T>;
282
283    fn fetch_full(&'_ self) -> FailFuture<'_, Node<Self::T>> {
284        Box::pin(async move {
285            let (encrypted, resolve) = self.encrypted.fetch_bytes().await?;
286            let (header, _) = side_parse(&self.key, &encrypted, &resolve)?;
287            let decrypted = self.decrypted.fetch().await?;
288            let inner = header.with(decrypted)?;
289            Ok((Encrypted { inner }, resolve))
290        })
291    }
292
293    fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
294        Box::pin(async move {
295            let (encrypted, resolve) = self.encrypted.fetch_bytes().await?;
296            let (header, _) = side_parse(&self.key, &encrypted, &resolve)?;
297            let decrypted = self.decrypted.fetch().await?;
298            let inner = header.with(decrypted)?;
299            Ok(Encrypted { inner })
300        })
301    }
302
303    fn try_fetch_local(&self) -> object_rainbow::Result<Option<Node<Self::T>>> {
304        let Some((encrypted, resolve)) = self.encrypted.fetch_bytes_local()? else {
305            return Ok(None);
306        };
307        let (header, _) = side_parse(&self.key, &encrypted, &resolve)?;
308        let Some((decrypted, _)) = self.decrypted.try_fetch_local()? else {
309            return Ok(None);
310        };
311        let inner = header.with(decrypted)?;
312        Ok(Some((Encrypted { inner }, resolve)))
313    }
314}
315
316impl<K: Send + Sync, D: Send + Sync> Singular for InnerFetch<K, D> {
317    fn hash(&self) -> Hash {
318        self.encrypted.hash()
319    }
320}
321
322struct IterateResolution<'a, 'r, K, V> {
323    key: &'a K,
324    topology: &'r mut std::slice::Iter<'a, Arc<dyn Singular>>,
325    visitor: &'a mut V,
326}
327
328impl<'a, K: Key, V: PointVisitor> PointVisitor for IterateResolution<'a, '_, K, V> {
329    fn visit<T: Traversible>(&mut self, decrypted: &(impl 'static + SingularFetch<T = T> + Clone)) {
330        let decrypted = decrypted.clone();
331        let encrypted = self.topology.next().expect("length mismatch").clone();
332        let point = Point::from_fetch(
333            encrypted.hash(),
334            InnerFetch {
335                key: self.key.clone(),
336                decrypted,
337                encrypted,
338            }
339            .into_dyn_fetch(),
340        );
341        self.visitor.visit(&point);
342    }
343}
344
345impl<K, T> ListHashes for Inner<K, T> {
346    fn list_hashes(&self, f: &mut impl FnMut(Hash)) {
347        self.topology.list_hashes(f);
348    }
349
350    fn topology_hash(&self) -> Hash {
351        self.topology.0.data_hash()
352    }
353
354    fn point_count(&self) -> usize {
355        self.topology.len()
356    }
357}
358
359impl<K: Key, T: Topological> Topological for Inner<K, T> {
360    fn traverse(&self, visitor: &mut impl PointVisitor) {
361        let topology = &mut self.topology.iter();
362        self.decrypted.traverse(&mut IterateResolution {
363            key: &self.key.0,
364            topology,
365            visitor,
366        });
367        assert!(topology.next().is_none());
368    }
369
370    fn topology(&self) -> TopoVec {
371        self.topology.0.clone()
372    }
373}
374
375pub struct Encrypted<K, T> {
376    inner: Inner<K, T>,
377}
378
379impl<K, T: Clone> Encrypted<K, T> {
380    pub fn into_inner(self) -> T {
381        Arc::unwrap_or_clone(self.inner.decrypted)
382    }
383}
384
385impl<K, T> Deref for Encrypted<K, T> {
386    type Target = T;
387
388    fn deref(&self) -> &Self::Target {
389        self.inner.decrypted.as_ref()
390    }
391}
392
393impl<K: Clone, T> Clone for Encrypted<K, T> {
394    fn clone(&self) -> Self {
395        Self {
396            inner: self.inner.clone(),
397        }
398    }
399}
400
401impl<K, T> ListHashes for Encrypted<K, T> {
402    fn list_hashes(&self, f: &mut impl FnMut(Hash)) {
403        self.inner.list_hashes(f);
404    }
405
406    fn topology_hash(&self) -> Hash {
407        self.inner.topology_hash()
408    }
409
410    fn point_count(&self) -> usize {
411        self.inner.point_count()
412    }
413}
414
415impl<K: Key, T: Topological> Topological for Encrypted<K, T> {
416    fn traverse(&self, visitor: &mut impl PointVisitor) {
417        self.inner.traverse(visitor);
418    }
419}
420
421impl<K: Key, T: ToOutput> ToOutput for Encrypted<K, T> {
422    fn to_output(&self, output: &mut impl object_rainbow::Output) {
423        if output.is_mangling() {
424            output.write(
425                &self
426                    .inner
427                    .key
428                    .encrypt(b"this encrypted constant is followed by an unencrypted inner hash"),
429            );
430            self.inner.decrypted.data_hash();
431        }
432        if output.is_real() {
433            let source = self.inner.vec();
434            output.write(&self.inner.key.encrypt(&source));
435        }
436    }
437}
438
/// Parse-time extra value from which a key of type `K` and a remaining extra
/// for the inner value can be obtained.
trait EncryptedExtra<K>: 'static + Send + Sync + Clone {
    /// Extra value handed on to the inner (decrypted) type's parser.
    type Extra: 'static + Send + Sync + Clone;

    /// Returns owned copies of the key and the inner extra.
    fn parts(&self) -> (K, Self::Extra);
}
443
444impl<K: 'static + Send + Sync + Clone, Extra: 'static + Send + Sync + Clone> EncryptedExtra<K>
445    for (K, Extra)
446{
447    type Extra = Extra;
448
449    fn parts(&self) -> (K, Self::Extra) {
450        self.clone()
451    }
452}
453
454impl<K: 'static + Send + Sync + Clone> EncryptedExtra<K> for K {
455    type Extra = ();
456
457    fn parts(&self) -> (K, Self::Extra) {
458        (self.clone(), ())
459    }
460}
461
462impl<
463    K: Key,
464    T: Object<Extra>,
465    Extra: 'static + Send + Sync + Clone,
466    I: PointInput<Extra: EncryptedExtra<K, Extra = Extra>>,
467> Parse<I> for Encrypted<K, T>
468{
469    fn parse(input: I) -> object_rainbow::Result<Self> {
470        let with_key = input.extra().parts();
471        let resolve = input.resolve().clone();
472        let source = with_key
473            .0
474            .decrypt(&input.parse_all()?)
475            .map_err(object_rainbow::Error::consistency)?;
476        let inner = Inner::<K, T>::parse_slice_extra(&source, &resolve, &with_key)?;
477        Ok(Self { inner })
478    }
479}
480
481impl<K, T> Tagged for Encrypted<K, T> {}
482
483type Extracted =
484    Vec<std::pin::Pin<Box<dyn Future<Output = Result<Arc<dyn Singular>, Error>> + Send + 'static>>>;
485
486struct ExtractResolution<'a, K> {
487    extracted: &'a mut Extracted,
488    key: &'a K,
489}
490
491impl<K: Key> PointVisitor for ExtractResolution<'_, K> {
492    fn visit<T: Traversible>(&mut self, decrypted: &(impl 'static + SingularFetch<T = T> + Clone)) {
493        let decrypted = decrypted.clone();
494        let key = self.key.clone();
495        self.extracted.push(Box::pin(async move {
496            Ok(Arc::new(encrypt_point(key, decrypted).await?) as _)
497        }));
498    }
499}
500
501pub async fn encrypt_point<K: Key, T: Traversible>(
502    key: K,
503    decrypted: impl 'static + SingularFetch<T = T>,
504) -> object_rainbow::Result<Point<Encrypted<K, T>>> {
505    if let Some((address, resolve)) = decrypted.extract_resolve::<RawResolve<K>>()
506        && resolve.key.0 == key
507    {
508        let address = resolve.translate(*address)?;
509        let point = Point::from_fetch(
510            address.hash,
511            RawFetch {
512                key,
513                resolve: resolve.resolve.clone(),
514                address,
515                decrypted,
516            }
517            .into_dyn_fetch(),
518        );
519        return Ok(point);
520    }
521    let decrypted = decrypted.fetch().await?;
522    let encrypted = encrypt(key, decrypted).await?;
523    let point = encrypted.point();
524    Ok(point)
525}
526
527pub async fn encrypt<K: Key, T: Traversible>(
528    key: K,
529    decrypted: T,
530) -> object_rainbow::Result<Encrypted<K, T>> {
531    let mut futures = Vec::with_capacity(decrypted.point_count());
532    decrypted.traverse(&mut ExtractResolution {
533        extracted: &mut futures,
534        key: &key,
535    });
536    let topology = futures_util::future::try_join_all(futures).await?;
537    let topology = Arc::new(Lp(topology));
538    let decrypted = Arc::new(decrypted);
539    let inner = Inner {
540        tags: T::HASH,
541        key: Extras(key),
542        topology,
543        decrypted,
544    };
545    Ok(Encrypted { inner })
546}