//! object_rainbow_encrypted/lib.rs

1use std::{ops::Deref, sync::Arc};
2
3use object_rainbow::{
4    Address, ByteNode, Error, FailFuture, Fetch, FetchBytes, Hash, Object, Parse, ParseSliceExtra,
5    Point, PointInput, PointVisitor, Resolve, Singular, Tagged, ToOutput, Topological, Traversible,
6    length_prefixed::Lp,
7};
8
/// Parse-time extra for encrypted objects: the key used to decrypt the bytes,
/// plus the extra that the wrapped (decrypted) type itself is parsed with.
#[derive(Clone)]
pub struct WithKey<K, Extra> {
    /// Key used to decrypt the ciphertext when parsing an `Encrypted` object.
    pub key: K,
    /// Extra handed on to the inner type's parser once the key is stripped off.
    pub extra: Extra,
}
14
/// A cipher that turns serialized object bytes into ciphertext and back.
///
/// `encrypt` is used when an `Encrypted` object is written (`ToOutput`), and
/// `decrypt` when one is parsed. NOTE(review): nothing here enforces that
/// `decrypt(encrypt(x)) == x`; implementors must guarantee it.
pub trait Key: 'static + Sized + Send + Sync + Clone {
    /// Encrypts `data` and returns the ciphertext.
    fn encrypt(&self, data: &[u8]) -> Vec<u8>;
    /// Decrypts `data` and returns the plaintext.
    ///
    /// # Errors
    /// Implementation-defined; presumably returned when `data` was not
    /// produced by `encrypt` with this key — confirm per implementation.
    fn decrypt(&self, data: &[u8]) -> object_rainbow::Result<Vec<u8>>;
}
19
/// Shared, length-prefixed list of the encrypted counterparts of an object's
/// points, in the same order the decrypted object visits its points.
type Resolution<K> = Arc<Lp<Vec<Point<Encrypted<K, Vec<u8>>>>>>;
21
/// Newtype whose `Parse` impl hides the key: the inner value is parsed with
/// only `WithKey::extra` as its extra.
#[derive(ToOutput, Clone)]
struct Unkeyed<T>(T);
24
25impl<
26    T: Parse<I::WithExtra<Extra>>,
27    K: 'static + Clone,
28    Extra: 'static + Clone,
29    I: PointInput<Extra = WithKey<K, Extra>>,
30> Parse<I> for Unkeyed<T>
31{
32    fn parse(input: I) -> object_rainbow::Result<Self> {
33        Ok(Self(T::parse(
34            input.map_extra(|WithKey { extra, .. }| extra),
35        )?))
36    }
37}
38
/// Plaintext of an encrypted object: this is what `Encrypted::to_output`
/// serializes and then passes to `Key::encrypt`.
///
/// NOTE(review): `ToOutput`/`Parse` are derived, so the field order below
/// presumably defines the serialized layout — keep it stable.
#[derive(ToOutput, Parse)]
struct EncryptedInner<K, T> {
    /// Encrypted counterparts of the decrypted object's points.
    resolution: Resolution<K>,
    /// The decrypted object, parsed without the key in its extra.
    decrypted: Unkeyed<Arc<T>>,
}
44
45impl<K, T> Clone for EncryptedInner<K, T> {
46    fn clone(&self) -> Self {
47        Self {
48            resolution: self.resolution.clone(),
49            decrypted: self.decrypted.clone(),
50        }
51    }
52}
53
/// Borrowing iterator over the entries of a `Resolution`.
type ResolutionIter<'a, K> = std::slice::Iter<'a, Point<Encrypted<K, Vec<u8>>>>;
55
/// `PointVisitor` adapter that pairs each visited decrypted point with the next
/// resolution entry and forwards an `Encrypted`-typed point to `visitor`.
struct IterateResolution<'a, 'r, K, V> {
    /// Resolution entries not yet paired with a visited point.
    resolution: &'r mut ResolutionIter<'a, K>,
    /// Downstream visitor receiving the re-typed points.
    visitor: &'a mut V,
}
60
/// Fetch source for a point of type `Encrypted<K, T>` assembled from two halves:
/// the typed object comes from `decrypted`; bytes, key and resolution come from
/// `encrypted`.
struct Visited<K, T> {
    /// Point to the plain object.
    decrypted: Point<T>,
    /// The matching encrypted point (its hash is the hash used for this point).
    encrypted: Point<Encrypted<K, Vec<u8>>>,
}
65
66impl<K, T> FetchBytes for Visited<K, T> {
67    fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
68        self.encrypted.fetch_bytes()
69    }
70
71    fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
72        self.encrypted.fetch_data()
73    }
74}
75
76impl<K: Key, T: Traversible> Fetch for Visited<K, T> {
77    type T = Encrypted<K, T>;
78
79    fn fetch_full(&'_ self) -> FailFuture<'_, (Self::T, Arc<dyn Resolve>)> {
80        Box::pin(async move {
81            let (
82                Encrypted {
83                    key,
84                    inner:
85                        EncryptedInner {
86                            resolution,
87                            decrypted: _,
88                        },
89                },
90                resolve,
91            ) = self.encrypted.fetch_full().await?;
92            let decrypted = self.decrypted.fetch().await?;
93            let decrypted = Unkeyed(Arc::new(decrypted));
94            Ok((
95                Encrypted {
96                    key,
97                    inner: EncryptedInner {
98                        resolution,
99                        decrypted,
100                    },
101                },
102                resolve,
103            ))
104        })
105    }
106
107    fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
108        Box::pin(async move {
109            let Encrypted {
110                key,
111                inner:
112                    EncryptedInner {
113                        resolution,
114                        decrypted: _,
115                    },
116            } = self.encrypted.fetch().await?;
117            let decrypted = self.decrypted.fetch().await?;
118            let decrypted = Unkeyed(Arc::new(decrypted));
119            Ok(Encrypted {
120                key,
121                inner: EncryptedInner {
122                    resolution,
123                    decrypted,
124                },
125            })
126        })
127    }
128}
129
130impl<'a, K: Key, V: PointVisitor> PointVisitor for IterateResolution<'a, '_, K, V> {
131    fn visit<T: Traversible>(&mut self, decrypted: &Point<T>) {
132        let decrypted = decrypted.clone();
133        let encrypted = self.resolution.next().expect("length mismatch").clone();
134        let point = Point::from_fetch(
135            encrypted.hash(),
136            Arc::new(Visited {
137                decrypted,
138                encrypted,
139            }),
140        );
141        self.visitor.visit(&point);
142    }
143}
144
145impl<K: Key, T: Topological> Topological for EncryptedInner<K, T> {
146    fn accept_points(&self, visitor: &mut impl PointVisitor) {
147        let resolution = &mut self.resolution.iter();
148        self.decrypted.0.accept_points(&mut IterateResolution {
149            resolution,
150            visitor,
151        });
152        assert!(resolution.next().is_none());
153    }
154
155    fn point_count(&self) -> usize {
156        self.resolution.len()
157    }
158
159    fn topology_hash(&self) -> Hash {
160        self.resolution.0.data_hash()
161    }
162}
163
/// An object of type `T` that is stored encrypted under a key of type `K`.
///
/// It derefs to the decrypted `T`. When serialized (`ToOutput`) it writes
/// `key.encrypt(..)` of the inner plaintext; the key is not written directly.
pub struct Encrypted<K, T> {
    /// Key used to encrypt this object's serialized form.
    key: K,
    /// Resolution plus decrypted object (the plaintext).
    inner: EncryptedInner<K, T>,
}
168
169impl<K, T: Clone> Encrypted<K, T> {
170    pub fn into_inner(self) -> T {
171        Arc::unwrap_or_clone(self.inner.decrypted.0)
172    }
173}
174
175impl<K, T> Deref for Encrypted<K, T> {
176    type Target = T;
177
178    fn deref(&self) -> &Self::Target {
179        self.inner.decrypted.0.as_ref()
180    }
181}
182
183impl<K: Clone, T> Clone for Encrypted<K, T> {
184    fn clone(&self) -> Self {
185        Self {
186            key: self.key.clone(),
187            inner: self.inner.clone(),
188        }
189    }
190}
191
192impl<K: Key, T: Topological> Topological for Encrypted<K, T> {
193    fn accept_points(&self, visitor: &mut impl PointVisitor) {
194        self.inner.accept_points(visitor);
195    }
196
197    fn topology_hash(&self) -> Hash {
198        self.inner.topology_hash()
199    }
200}
201
202impl<K: Key, T: ToOutput> ToOutput for Encrypted<K, T> {
203    fn to_output(&self, output: &mut dyn object_rainbow::Output) {
204        let source = self.inner.vec();
205        output.write(&self.key.encrypt(&source));
206    }
207}
208
/// Resolver installed while parsing a decrypted object. Address index `i`
/// resolves to the `i`-th resolution entry, fetched and decrypted.
#[derive(Clone)]
struct Decrypt<K> {
    resolution: Resolution<K>,
}
213
214impl<K: Key> Decrypt<K> {
215    async fn resolve_bytes(
216        &self,
217        address: Address,
218    ) -> object_rainbow::Result<(Vec<u8>, Resolution<K>)> {
219        let Encrypted {
220            key: _,
221            inner:
222                EncryptedInner {
223                    resolution,
224                    decrypted,
225                },
226        } = self
227            .resolution
228            .get(address.index)
229            .ok_or(Error::AddressOutOfBounds)?
230            .clone()
231            .fetch()
232            .await?;
233        let data = Arc::unwrap_or_clone(decrypted.0);
234        Ok((data, resolution))
235    }
236}
237
238impl<K: Key> Resolve for Decrypt<K> {
239    fn resolve(&'_ self, address: Address) -> FailFuture<'_, ByteNode> {
240        Box::pin(async move {
241            let (data, resolution) = self.resolve_bytes(address).await?;
242            Ok((data, Arc::new(Decrypt { resolution }) as _))
243        })
244    }
245
246    fn resolve_data(&'_ self, address: Address) -> FailFuture<'_, Vec<u8>> {
247        Box::pin(async move {
248            let (data, _) = self.resolve_bytes(address).await?;
249            Ok(data)
250        })
251    }
252}
253
254impl<
255    K: Key,
256    T: Object<Extra>,
257    Extra: 'static + Send + Sync + Clone,
258    I: PointInput<Extra = WithKey<K, Extra>>,
259> Parse<I> for Encrypted<K, T>
260{
261    fn parse(input: I) -> object_rainbow::Result<Self> {
262        let with_key = input.extra().clone();
263        let resolve = input.resolve().clone();
264        let source = with_key.key.decrypt(&input.parse_all()?)?;
265        let EncryptedInner {
266            resolution,
267            decrypted,
268        } = EncryptedInner::<K, Vec<u8>>::parse_slice_extra(&source, &resolve, &with_key)?;
269        let decrypted = T::parse_slice_extra(
270            &decrypted.0,
271            &(Arc::new(Decrypt {
272                resolution: resolution.clone(),
273            }) as _),
274            &with_key.extra,
275        )?;
276        let decrypted = Unkeyed(Arc::new(decrypted));
277        let inner = EncryptedInner {
278            resolution,
279            decrypted,
280        };
281        Ok(Self {
282            key: with_key.key,
283            inner,
284        })
285    }
286}
287
/// Declares no tags of its own: every `Tagged` item keeps the trait default.
impl<K, T> Tagged for Encrypted<K, T> {}
289
/// `Encrypted<K, T>` is an object under the `WithKey<K, Extra>` extra whenever
/// `T` is an object under `Extra`. All items use the trait defaults.
impl<K: Key, T: Object<Extra>, Extra: 'static + Send + Sync + Clone> Object<WithKey<K, Extra>>
    for Encrypted<K, T>
{
}
294
/// Futures queued by `ExtractResolution`, one per visited point and in visit
/// order. Each resolves to the encrypted counterpart of its point, typed as
/// `Encrypted<K, Vec<u8>>`.
type Extracted<K> = Vec<
    std::pin::Pin<
        Box<dyn Future<Output = Result<Point<Encrypted<K, Vec<u8>>>, Error>> + Send + 'static>,
    >,
>;
300
/// `PointVisitor` that queues an encryption future for every visited point.
struct ExtractResolution<'a, K> {
    /// Output queue, filled in visit order.
    extracted: &'a mut Extracted<K>,
    /// Key every visited point is encrypted under.
    key: &'a K,
}
305
/// Fetch source that exposes an `Encrypted<K, T>` point as an
/// `Encrypted<K, Vec<u8>>` by re-parsing its bytes with `key`.
struct Untyped<K, T> {
    /// Parse extra (key only, unit inner extra) used when re-parsing.
    key: WithKey<K, ()>,
    /// The typed encrypted point whose bytes are re-interpreted.
    encrypted: Point<Encrypted<K, T>>,
}
310
311impl<K, T> FetchBytes for Untyped<K, T> {
312    fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
313        self.encrypted.fetch_bytes()
314    }
315
316    fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
317        self.encrypted.fetch_data()
318    }
319}
320
321impl<K: Key, T> Fetch for Untyped<K, T> {
322    type T = Encrypted<K, Vec<u8>>;
323
324    fn fetch_full(&'_ self) -> FailFuture<'_, (Self::T, Arc<dyn Resolve>)> {
325        Box::pin(async move {
326            let (data, resolve) = self.fetch_bytes().await?;
327            let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
328            Ok((encrypted, resolve))
329        })
330    }
331
332    fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
333        Box::pin(async move {
334            let (data, resolve) = self.fetch_bytes().await?;
335            let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
336            Ok(encrypted)
337        })
338    }
339}
340
341impl<K: Key> PointVisitor for ExtractResolution<'_, K> {
342    fn visit<T: Traversible>(&mut self, decrypted: &Point<T>) {
343        let decrypted = decrypted.clone();
344        let key = self.key.clone();
345        self.extracted.push(Box::pin(async move {
346            let encrypted = encrypt_point(key.clone(), decrypted).await?;
347            let encrypted = Point::from_fetch(
348                encrypted.hash(),
349                Arc::new(Untyped {
350                    key: WithKey { key, extra: () },
351                    encrypted,
352                }),
353            );
354            Ok(encrypted)
355        }));
356    }
357}
358
359pub async fn encrypt_point<K: Key, T: Traversible>(
360    key: K,
361    decrypted: Point<T>,
362) -> object_rainbow::Result<Point<Encrypted<K, T>>> {
363    if let Some((address, decrypt)) = decrypted.extract_resolve::<Decrypt<K>>() {
364        let encrypted = decrypt
365            .resolution
366            .get(address.index)
367            .ok_or(Error::AddressOutOfBounds)?
368            .clone();
369        let point = Point::from_fetch(
370            encrypted.hash(),
371            Arc::new(Visited {
372                decrypted,
373                encrypted,
374            }),
375        );
376        return Ok(point);
377    }
378    let decrypted = decrypted.fetch().await?;
379    let encrypted = encrypt(key.clone(), decrypted).await?;
380    let point = encrypted.point();
381    Ok(point)
382}
383
384pub async fn encrypt<K: Key, T: Traversible>(
385    key: K,
386    decrypted: T,
387) -> object_rainbow::Result<Encrypted<K, T>> {
388    let mut futures = Vec::with_capacity(decrypted.point_count());
389    decrypted.accept_points(&mut ExtractResolution {
390        extracted: &mut futures,
391        key: &key,
392    });
393    let resolution = futures_util::future::try_join_all(futures).await?;
394    let resolution = Arc::new(Lp(resolution));
395    let decrypted = Unkeyed(Arc::new(decrypted));
396    let inner = EncryptedInner {
397        resolution,
398        decrypted,
399    };
400    Ok(Encrypted { key, inner })
401}