object_rainbow_encrypted/
lib.rs

1use std::{ops::Deref, sync::Arc};
2
3use object_rainbow::{
4    Address, ByteNode, Error, FailFuture, Fetch, FetchBytes, FullHash, Hash, Node, Object, Parse,
5    ParseSliceExtra, Point, PointInput, PointVisitor, Resolve, Singular, Tagged, ToOutput,
6    Topological, Traversible, length_prefixed::Lp,
7};
8
/// Parse-time context that carries a key next to another, caller-chosen extra value.
///
/// Inputs whose `Extra` is `WithKey<K, Extra>` are what the `Parse` impls in this file accept.
#[derive(Clone)]
pub struct WithKey<K, Extra> {
    /// The key handed to the parser.
    pub key: K,
    /// The remaining context, forwarded to the parser of the decrypted value.
    pub extra: Extra,
}
14
15pub trait Key: 'static + Sized + Send + Sync + Clone {
16    fn encrypt(&self, data: &[u8]) -> Vec<u8>;
17    fn decrypt(&self, data: &[u8]) -> object_rainbow::Result<Vec<u8>>;
18}
19
20type Resolution<K> = Arc<Lp<Vec<Point<Encrypted<K, Vec<u8>>>>>>;
21
22#[derive(ToOutput, Clone)]
23struct Unkeyed<T>(T);
24
25impl<
26    T: Parse<I::WithExtra<Extra>>,
27    K: 'static + Clone,
28    Extra: 'static + Clone,
29    I: PointInput<Extra = WithKey<K, Extra>>,
30> Parse<I> for Unkeyed<T>
31{
32    fn parse(input: I) -> object_rainbow::Result<Self> {
33        Ok(Self(T::parse(
34            input.map_extra(|WithKey { extra, .. }| extra),
35        )?))
36    }
37}
38
39#[derive(ToOutput, Parse)]
40struct EncryptedInner<K, T> {
41    resolution: Resolution<K>,
42    decrypted: Unkeyed<Arc<T>>,
43}
44
45impl<K, T> Clone for EncryptedInner<K, T> {
46    fn clone(&self) -> Self {
47        Self {
48            resolution: self.resolution.clone(),
49            decrypted: self.decrypted.clone(),
50        }
51    }
52}
53
54type ResolutionIter<'a, K> = std::slice::Iter<'a, Point<Encrypted<K, Vec<u8>>>>;
55
56struct IterateResolution<'a, 'r, K, V> {
57    resolution: &'r mut ResolutionIter<'a, K>,
58    visitor: &'a mut V,
59}
60
61struct Visited<K, T> {
62    decrypted: Point<T>,
63    encrypted: Point<Encrypted<K, Vec<u8>>>,
64}
65
66impl<K, T> FetchBytes for Visited<K, T> {
67    fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
68        self.encrypted.fetch_bytes()
69    }
70
71    fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
72        self.encrypted.fetch_data()
73    }
74
75    fn fetch_bytes_local(&self) -> object_rainbow::Result<Option<ByteNode>> {
76        self.encrypted.fetch_bytes_local()
77    }
78
79    fn fetch_data_local(&self) -> Option<Vec<u8>> {
80        self.encrypted.fetch_data_local()
81    }
82}
83
84impl<K, T> Singular for Visited<K, T> {
85    fn hash(&self) -> Hash {
86        self.encrypted.hash()
87    }
88}
89
90impl<K: Key, T: Traversible> Fetch for Visited<K, T> {
91    type T = Encrypted<K, T>;
92
93    fn fetch_full(&'_ self) -> FailFuture<'_, Node<Self::T>> {
94        Box::pin(async move {
95            let (
96                Encrypted {
97                    key,
98                    inner:
99                        EncryptedInner {
100                            resolution,
101                            decrypted: _,
102                        },
103                },
104                resolve,
105            ) = self.encrypted.fetch_full().await?;
106            let decrypted = self.decrypted.fetch().await?;
107            let decrypted = Unkeyed(Arc::new(decrypted));
108            Ok((
109                Encrypted {
110                    key,
111                    inner: EncryptedInner {
112                        resolution,
113                        decrypted,
114                    },
115                },
116                resolve,
117            ))
118        })
119    }
120
121    fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
122        Box::pin(async move {
123            let Encrypted {
124                key,
125                inner:
126                    EncryptedInner {
127                        resolution,
128                        decrypted: _,
129                    },
130            } = self.encrypted.fetch().await?;
131            let decrypted = self.decrypted.fetch().await?;
132            let decrypted = Unkeyed(Arc::new(decrypted));
133            Ok(Encrypted {
134                key,
135                inner: EncryptedInner {
136                    resolution,
137                    decrypted,
138                },
139            })
140        })
141    }
142
143    fn try_fetch_local(&self) -> object_rainbow::Result<Option<Node<Self::T>>> {
144        let Some((
145            Encrypted {
146                key,
147                inner:
148                    EncryptedInner {
149                        resolution,
150                        decrypted: _,
151                    },
152            },
153            resolve,
154        )) = self.encrypted.try_fetch_local()?
155        else {
156            return Ok(None);
157        };
158        let Some((decrypted, _)) = self.decrypted.try_fetch_local()? else {
159            return Ok(None);
160        };
161        let decrypted = Unkeyed(Arc::new(decrypted));
162        Ok(Some((
163            Encrypted {
164                key,
165                inner: EncryptedInner {
166                    resolution,
167                    decrypted,
168                },
169            },
170            resolve,
171        )))
172    }
173
174    fn fetch_local(&self) -> Option<Self::T> {
175        let Encrypted {
176            key,
177            inner:
178                EncryptedInner {
179                    resolution,
180                    decrypted: _,
181                },
182        } = self.encrypted.fetch_local()?;
183        let decrypted = Unkeyed(Arc::new(self.decrypted.fetch_local()?));
184        Some(Encrypted {
185            key,
186            inner: EncryptedInner {
187                resolution,
188                decrypted,
189            },
190        })
191    }
192}
193
194impl<'a, K: Key, V: PointVisitor> PointVisitor for IterateResolution<'a, '_, K, V> {
195    fn visit<T: Traversible>(&mut self, decrypted: &Point<T>) {
196        let decrypted = decrypted.clone();
197        let encrypted = self.resolution.next().expect("length mismatch").clone();
198        let point = Point::from_fetch(
199            encrypted.hash(),
200            Arc::new(Visited {
201                decrypted,
202                encrypted,
203            }),
204        );
205        self.visitor.visit(&point);
206    }
207}
208
209impl<K: Key, T: Topological> Topological for EncryptedInner<K, T> {
210    fn accept_points(&self, visitor: &mut impl PointVisitor) {
211        let resolution = &mut self.resolution.iter();
212        self.decrypted.0.accept_points(&mut IterateResolution {
213            resolution,
214            visitor,
215        });
216        assert!(resolution.next().is_none());
217    }
218
219    fn point_count(&self) -> usize {
220        self.resolution.len()
221    }
222
223    fn topology_hash(&self) -> Hash {
224        self.resolution.0.data_hash()
225    }
226}
227
228pub struct Encrypted<K, T> {
229    key: K,
230    inner: EncryptedInner<K, T>,
231}
232
233impl<K, T: Clone> Encrypted<K, T> {
234    pub fn into_inner(self) -> T {
235        Arc::unwrap_or_clone(self.inner.decrypted.0)
236    }
237}
238
239impl<K, T> Deref for Encrypted<K, T> {
240    type Target = T;
241
242    fn deref(&self) -> &Self::Target {
243        self.inner.decrypted.0.as_ref()
244    }
245}
246
247impl<K: Clone, T> Clone for Encrypted<K, T> {
248    fn clone(&self) -> Self {
249        Self {
250            key: self.key.clone(),
251            inner: self.inner.clone(),
252        }
253    }
254}
255
256impl<K: Key, T: Topological> Topological for Encrypted<K, T> {
257    fn accept_points(&self, visitor: &mut impl PointVisitor) {
258        self.inner.accept_points(visitor);
259    }
260
261    fn topology_hash(&self) -> Hash {
262        self.inner.topology_hash()
263    }
264}
265
266impl<K: Key, T: ToOutput> ToOutput for Encrypted<K, T> {
267    fn to_output(&self, output: &mut dyn object_rainbow::Output) {
268        let source = self.inner.vec();
269        output.write(&self.key.encrypt(&source));
270    }
271}
272
273#[derive(Clone)]
274struct Decrypt<K> {
275    resolution: Resolution<K>,
276}
277
278impl<K: Key> Decrypt<K> {
279    async fn resolve_bytes(
280        &self,
281        address: Address,
282    ) -> object_rainbow::Result<(Vec<u8>, Resolution<K>)> {
283        let Encrypted {
284            key: _,
285            inner:
286                EncryptedInner {
287                    resolution,
288                    decrypted,
289                },
290        } = self
291            .resolution
292            .get(address.index)
293            .ok_or(Error::AddressOutOfBounds)?
294            .clone()
295            .fetch()
296            .await?;
297        let data = Arc::unwrap_or_clone(decrypted.0);
298        Ok((data, resolution))
299    }
300}
301
302impl<K: Key> Resolve for Decrypt<K> {
303    fn resolve(&'_ self, address: Address) -> FailFuture<'_, ByteNode> {
304        Box::pin(async move {
305            let (data, resolution) = self.resolve_bytes(address).await?;
306            Ok((data, Arc::new(Decrypt { resolution }) as _))
307        })
308    }
309
310    fn resolve_data(&'_ self, address: Address) -> FailFuture<'_, Vec<u8>> {
311        Box::pin(async move {
312            let (data, _) = self.resolve_bytes(address).await?;
313            Ok(data)
314        })
315    }
316
317    fn try_resolve_local(&self, address: Address) -> object_rainbow::Result<Option<ByteNode>> {
318        let Some((
319            Encrypted {
320                key: _,
321                inner:
322                    EncryptedInner {
323                        resolution,
324                        decrypted,
325                    },
326            },
327            _,
328        )) = self
329            .resolution
330            .get(address.index)
331            .ok_or(Error::AddressOutOfBounds)?
332            .clone()
333            .try_fetch_local()?
334        else {
335            return Ok(None);
336        };
337        let data = Arc::unwrap_or_clone(decrypted.0);
338        Ok(Some((data, Arc::new(Decrypt { resolution }) as _)))
339    }
340}
341
342impl<
343    K: Key,
344    T: Object<Extra>,
345    Extra: 'static + Send + Sync + Clone,
346    I: PointInput<Extra = WithKey<K, Extra>>,
347> Parse<I> for Encrypted<K, T>
348{
349    fn parse(input: I) -> object_rainbow::Result<Self> {
350        let with_key = input.extra().clone();
351        let resolve = input.resolve().clone();
352        let source = with_key.key.decrypt(&input.parse_all()?)?;
353        let EncryptedInner {
354            resolution,
355            decrypted,
356        } = EncryptedInner::<K, Vec<u8>>::parse_slice_extra(&source, &resolve, &with_key)?;
357        let decrypted = T::parse_slice_extra(
358            &decrypted.0,
359            &(Arc::new(Decrypt {
360                resolution: resolution.clone(),
361            }) as _),
362            &with_key.extra,
363        )?;
364        let decrypted = Unkeyed(Arc::new(decrypted));
365        let inner = EncryptedInner {
366            resolution,
367            decrypted,
368        };
369        Ok(Self {
370            key: with_key.key,
371            inner,
372        })
373    }
374}
375
376impl<K, T> Tagged for Encrypted<K, T> {}
377
378impl<K: Key, T: Object<Extra>, Extra: 'static + Send + Sync + Clone> Object<WithKey<K, Extra>>
379    for Encrypted<K, T>
380{
381}
382
383type Extracted<K> = Vec<
384    std::pin::Pin<
385        Box<dyn Future<Output = Result<Point<Encrypted<K, Vec<u8>>>, Error>> + Send + 'static>,
386    >,
387>;
388
389struct ExtractResolution<'a, K> {
390    extracted: &'a mut Extracted<K>,
391    key: &'a K,
392}
393
394struct Untyped<K, T> {
395    key: WithKey<K, ()>,
396    encrypted: Point<Encrypted<K, T>>,
397}
398
399impl<K, T> FetchBytes for Untyped<K, T> {
400    fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
401        self.encrypted.fetch_bytes()
402    }
403
404    fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
405        self.encrypted.fetch_data()
406    }
407
408    fn fetch_bytes_local(&self) -> object_rainbow::Result<Option<ByteNode>> {
409        self.encrypted.fetch_bytes_local()
410    }
411
412    fn fetch_data_local(&self) -> Option<Vec<u8>> {
413        self.encrypted.fetch_data_local()
414    }
415}
416
417impl<K: Send + Sync, T> Singular for Untyped<K, T> {
418    fn hash(&self) -> Hash {
419        self.encrypted.hash()
420    }
421}
422
423impl<K: Key, T: FullHash> Fetch for Untyped<K, T> {
424    type T = Encrypted<K, Vec<u8>>;
425
426    fn fetch_full(&'_ self) -> FailFuture<'_, Node<Self::T>> {
427        Box::pin(async move {
428            let (data, resolve) = self.fetch_bytes().await?;
429            let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
430            Ok((encrypted, resolve))
431        })
432    }
433
434    fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
435        Box::pin(async move {
436            let (data, resolve) = self.fetch_bytes().await?;
437            let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
438            Ok(encrypted)
439        })
440    }
441
442    fn try_fetch_local(&self) -> object_rainbow::Result<Option<Node<Self::T>>> {
443        let Some((data, resolve)) = self.fetch_bytes_local()? else {
444            return Ok(None);
445        };
446        let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
447        Ok(Some((encrypted, resolve)))
448    }
449
450    fn fetch_local(&self) -> Option<Self::T> {
451        let Encrypted {
452            key,
453            inner:
454                EncryptedInner {
455                    resolution,
456                    decrypted,
457                },
458        } = self.encrypted.fetch_local()?;
459        let decrypted = Unkeyed(Arc::new(decrypted.vec()));
460        Some(Encrypted {
461            key,
462            inner: EncryptedInner {
463                resolution,
464                decrypted,
465            },
466        })
467    }
468}
469
470impl<K: Key> PointVisitor for ExtractResolution<'_, K> {
471    fn visit<T: Traversible>(&mut self, decrypted: &Point<T>) {
472        let decrypted = decrypted.clone();
473        let key = self.key.clone();
474        self.extracted.push(Box::pin(async move {
475            let encrypted = encrypt_point(key.clone(), decrypted).await?;
476            let encrypted = Point::from_fetch(
477                encrypted.hash(),
478                Arc::new(Untyped {
479                    key: WithKey { key, extra: () },
480                    encrypted,
481                }),
482            );
483            Ok(encrypted)
484        }));
485    }
486}
487
488pub async fn encrypt_point<K: Key, T: Traversible>(
489    key: K,
490    decrypted: Point<T>,
491) -> object_rainbow::Result<Point<Encrypted<K, T>>> {
492    if let Some((address, decrypt)) = decrypted.extract_resolve::<Decrypt<K>>() {
493        let encrypted = decrypt
494            .resolution
495            .get(address.index)
496            .ok_or(Error::AddressOutOfBounds)?
497            .clone();
498        let point = Point::from_fetch(
499            encrypted.hash(),
500            Arc::new(Visited {
501                decrypted,
502                encrypted,
503            }),
504        );
505        return Ok(point);
506    }
507    let decrypted = decrypted.fetch().await?;
508    let encrypted = encrypt(key.clone(), decrypted).await?;
509    let point = encrypted.point();
510    Ok(point)
511}
512
513pub async fn encrypt<K: Key, T: Traversible>(
514    key: K,
515    decrypted: T,
516) -> object_rainbow::Result<Encrypted<K, T>> {
517    let mut futures = Vec::with_capacity(decrypted.point_count());
518    decrypted.accept_points(&mut ExtractResolution {
519        extracted: &mut futures,
520        key: &key,
521    });
522    let resolution = futures_util::future::try_join_all(futures).await?;
523    let resolution = Arc::new(Lp(resolution));
524    let decrypted = Unkeyed(Arc::new(decrypted));
525    let inner = EncryptedInner {
526        resolution,
527        decrypted,
528    };
529    Ok(Encrypted { key, inner })
530}