object_rainbow_encrypted/
lib.rs

1use std::{ops::Deref, sync::Arc};
2
3use object_rainbow::{
4    Address, ByteNode, Error, FailFuture, Fetch, FetchBytes, Hash, Object, Parse, ParseSliceExtra,
5    Point, PointInput, PointVisitor, Resolve, Singular, Tagged, ToOutput, Topological, Traversible,
6    length_prefixed::Lp,
7};
8
/// Parsing "extra" context that pairs a key with the caller's own extra value.
///
/// `Encrypted` values are parsed with `WithKey<K, Extra>`: `key` decrypts the stored
/// bytes (and is kept for re-encryption on output), while `extra` is forwarded to the
/// parser of the wrapped value.
#[derive(Clone)]
pub struct WithKey<K, Extra> {
    /// Key used to decrypt the payload while parsing.
    pub key: K,
    /// Extra value passed through to the inner value's parser.
    pub extra: Extra,
}
14
/// Key material able to both encrypt and decrypt `Encrypted` payloads.
///
/// `Encrypted`'s `ToOutput` impl writes `encrypt(plaintext)` and its `Parse` impl reads
/// `decrypt(stored_bytes)`. A stored value only round-trips if `decrypt` inverts `encrypt`.
pub trait Key: 'static + Sized + Send + Sync + Clone {
    /// Encrypts `data`, returning the bytes to be stored.
    fn encrypt(&self, data: &[u8]) -> Vec<u8>;
    /// Decrypts `data`, returning the plaintext bytes.
    ///
    /// # Errors
    /// Any `Err` returned here is propagated (via `?`) out of `Encrypted`'s `Parse` impl.
    fn decrypt(&self, data: &[u8]) -> object_rainbow::Result<Vec<u8>>;
}
19
/// Shared table (wrapped in `Lp` from `length_prefixed`) of points to the encrypted forms
/// of a value's child points.
///
/// Entry `i` is paired with the `i`-th point visited by the plaintext's `accept_points`
/// (see `IterateResolution`). `Decrypt` resolves address index `i` through this table.
type Resolution<K> = Arc<Lp<Vec<Point<Encrypted<K, Vec<u8>>>>>>;
21
/// Wrapper whose `Parse` impl drops the key from a `WithKey` extra before parsing the
/// inner value, so the inner value's parser only sees the caller's own extra.
#[derive(ToOutput, Clone)]
struct Unkeyed<T>(T);
24
25impl<
26    T: Parse<I::WithExtra<Extra>>,
27    K: 'static + Clone,
28    Extra: 'static + Clone,
29    I: PointInput<Extra = WithKey<K, Extra>>,
30> Parse<I> for Unkeyed<T>
31{
32    fn parse(input: I) -> object_rainbow::Result<Self> {
33        Ok(Self(T::parse(
34            input.map_extra(|WithKey { extra, .. }| extra),
35        )?))
36    }
37}
38
/// Plaintext contents of an `Encrypted` value: the resolution table plus the value itself.
///
/// `Encrypted::to_output` serializes this struct (via the derived `ToOutput`) and encrypts
/// the resulting bytes. `Encrypted`'s `Parse` impl decrypts and parses it back.
#[derive(ToOutput, Parse)]
struct EncryptedInner<K, T> {
    /// Encrypted counterparts of the value's child points, in visitation order.
    resolution: Resolution<K>,
    /// The decrypted value; parsed without the key (see `Unkeyed`).
    decrypted: Unkeyed<Arc<T>>,
}
44
45impl<K, T> Clone for EncryptedInner<K, T> {
46    fn clone(&self) -> Self {
47        Self {
48            resolution: self.resolution.clone(),
49            decrypted: self.decrypted.clone(),
50        }
51    }
52}
53
/// Borrowing iterator over the entries of a `Resolution`.
type ResolutionIter<'a, K> = std::slice::Iter<'a, Point<Encrypted<K, Vec<u8>>>>;
55
/// Visitor adapter: pairs each plaintext child point with the next resolution entry and
/// forwards a combined (`Visited`-backed) point to the wrapped visitor.
struct IterateResolution<'a, K, V> {
    /// Remaining resolution entries, consumed one per visited point.
    resolution: ResolutionIter<'a, K>,
    /// Downstream visitor that receives the combined points.
    visitor: &'a mut V,
}
60
/// Fetch source backed by a plaintext point and its encrypted counterpart.
///
/// Byte-level fetches are served from `encrypted`. Typed fetches take the key and
/// resolution from `encrypted` and the value from `decrypted`.
struct Visited<K, T> {
    /// Point to the plaintext value.
    decrypted: Point<T>,
    /// Point to the same value's encrypted form (typed as raw bytes).
    encrypted: Point<Encrypted<K, Vec<u8>>>,
}
65
66impl<K, T> FetchBytes for Visited<K, T> {
67    fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
68        self.encrypted.fetch_bytes()
69    }
70
71    fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
72        self.encrypted.fetch_data()
73    }
74}
75
76impl<K: Key, T: Traversible> Fetch for Visited<K, T> {
77    type T = Encrypted<K, T>;
78
79    fn fetch_full(&'_ self) -> FailFuture<'_, (Self::T, Arc<dyn Resolve>)> {
80        Box::pin(async move {
81            let (
82                Encrypted {
83                    key,
84                    inner:
85                        EncryptedInner {
86                            resolution,
87                            decrypted: _,
88                        },
89                },
90                resolve,
91            ) = self.encrypted.fetch_full().await?;
92            let decrypted = self.decrypted.fetch().await?;
93            let decrypted = Unkeyed(Arc::new(decrypted));
94            Ok((
95                Encrypted {
96                    key,
97                    inner: EncryptedInner {
98                        resolution,
99                        decrypted,
100                    },
101                },
102                resolve,
103            ))
104        })
105    }
106
107    fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
108        Box::pin(async move {
109            let Encrypted {
110                key,
111                inner:
112                    EncryptedInner {
113                        resolution,
114                        decrypted: _,
115                    },
116            } = self.encrypted.fetch().await?;
117            let decrypted = self.decrypted.fetch().await?;
118            let decrypted = Unkeyed(Arc::new(decrypted));
119            Ok(Encrypted {
120                key,
121                inner: EncryptedInner {
122                    resolution,
123                    decrypted,
124                },
125            })
126        })
127    }
128}
129
130impl<'a, K: Key, V: PointVisitor> PointVisitor for IterateResolution<'a, K, V> {
131    fn visit<T: Traversible>(&mut self, decrypted: &Point<T>) {
132        let decrypted = decrypted.clone();
133        let encrypted = self.resolution.next().expect("length mismatch").clone();
134        let point = Point::from_origin(
135            encrypted.hash(),
136            Arc::new(Visited {
137                decrypted,
138                encrypted,
139            }),
140        );
141        self.visitor.visit(&point);
142    }
143}
144
145impl<K: Key, T: Topological> Topological for EncryptedInner<K, T> {
146    fn accept_points(&self, visitor: &mut impl PointVisitor) {
147        self.decrypted.0.accept_points(&mut IterateResolution {
148            resolution: self.resolution.iter(),
149            visitor,
150        });
151    }
152
153    fn point_count(&self) -> usize {
154        self.resolution.len()
155    }
156
157    fn topology_hash(&self) -> Hash {
158        self.resolution.0.data_hash()
159    }
160}
161
/// A value of type `T` that is serialized encrypted under key `K`.
///
/// The decrypted value is held in memory (see the `Deref` impl and `into_inner`). The
/// `ToOutput` impl writes `key.encrypt(...)` of the inner serialization.
pub struct Encrypted<K, T> {
    /// Key used to encrypt this value on output.
    key: K,
    /// Resolution table and decrypted value.
    inner: EncryptedInner<K, T>,
}
166
167impl<K, T: Clone> Encrypted<K, T> {
168    pub fn into_inner(self) -> T {
169        Arc::unwrap_or_clone(self.inner.decrypted.0)
170    }
171}
172
173impl<K, T> Deref for Encrypted<K, T> {
174    type Target = T;
175
176    fn deref(&self) -> &Self::Target {
177        self.inner.decrypted.0.as_ref()
178    }
179}
180
181impl<K: Clone, T> Clone for Encrypted<K, T> {
182    fn clone(&self) -> Self {
183        Self {
184            key: self.key.clone(),
185            inner: self.inner.clone(),
186        }
187    }
188}
189
190impl<K: Key, T: Topological> Topological for Encrypted<K, T> {
191    fn accept_points(&self, visitor: &mut impl PointVisitor) {
192        self.inner.accept_points(visitor);
193    }
194
195    fn topology_hash(&self) -> Hash {
196        self.inner.topology_hash()
197    }
198}
199
200impl<K: Key, T: ToOutput> ToOutput for Encrypted<K, T> {
201    fn to_output(&self, output: &mut dyn object_rainbow::Output) {
202        let source = self.inner.vec();
203        output.write(&self.key.encrypt(&source));
204    }
205}
206
/// Resolver installed while parsing a decrypted value.
///
/// Address index `i` is looked up in `resolution`, and the referenced
/// `Encrypted<K, Vec<u8>>` is fetched and decrypted.
#[derive(Clone)]
struct Decrypt<K> {
    /// Table mapping address indices to encrypted points.
    resolution: Resolution<K>,
}
211
212impl<K: Key> Decrypt<K> {
213    async fn resolve_bytes(
214        &self,
215        address: Address,
216    ) -> object_rainbow::Result<(Vec<u8>, Resolution<K>)> {
217        let Encrypted {
218            key: _,
219            inner:
220                EncryptedInner {
221                    resolution,
222                    decrypted,
223                },
224        } = self
225            .resolution
226            .get(address.index)
227            .ok_or(Error::AddressOutOfBounds)?
228            .clone()
229            .fetch()
230            .await?;
231        let data = Arc::unwrap_or_clone(decrypted.0);
232        Ok((data, resolution))
233    }
234}
235
236impl<K: Key> Resolve for Decrypt<K> {
237    fn resolve(&'_ self, address: Address) -> FailFuture<'_, ByteNode> {
238        Box::pin(async move {
239            let (data, resolution) = self.resolve_bytes(address).await?;
240            Ok((data, Arc::new(Decrypt { resolution }) as _))
241        })
242    }
243
244    fn resolve_data(&'_ self, address: Address) -> FailFuture<'_, Vec<u8>> {
245        Box::pin(async move {
246            let (data, _) = self.resolve_bytes(address).await?;
247            Ok(data)
248        })
249    }
250
251    fn name(&self) -> &str {
252        "decrypt"
253    }
254}
255
256impl<
257    K: Key,
258    T: Object<Extra>,
259    Extra: 'static + Send + Sync + Clone,
260    I: PointInput<Extra = WithKey<K, Extra>>,
261> Parse<I> for Encrypted<K, T>
262{
263    fn parse(input: I) -> object_rainbow::Result<Self> {
264        let with_key = input.extra().clone();
265        let resolve = input.resolve().clone();
266        let source = with_key.key.decrypt(&input.parse_all()?)?;
267        let EncryptedInner {
268            resolution,
269            decrypted,
270        } = EncryptedInner::<K, Vec<u8>>::parse_slice_extra(&source, &resolve, &with_key)?;
271        let decrypted = T::parse_slice_extra(
272            &decrypted.0,
273            &(Arc::new(Decrypt {
274                resolution: resolution.clone(),
275            }) as _),
276            &with_key.extra,
277        )?;
278        let decrypted = Unkeyed(Arc::new(decrypted));
279        let inner = EncryptedInner {
280            resolution,
281            decrypted,
282        };
283        Ok(Self {
284            key: with_key.key,
285            inner,
286        })
287    }
288}
289
// Empty impl: `Encrypted` uses whatever defaults `Tagged` provides.
impl<K, T> Tagged for Encrypted<K, T> {}
291
/// `Encrypted<K, T>` is an object (with parsing extra `WithKey<K, Extra>`) whenever `T`
/// is an object with extra `Extra`.
impl<K: Key, T: Object<Extra>, Extra: 'static + Send + Sync + Clone> Object<WithKey<K, Extra>>
    for Encrypted<K, T>
{
}
296
/// Pending encryptions of a value's child points.
///
/// `ExtractResolution` pushes one boxed future per visited point, and `encrypt` awaits
/// them together.
type Extracted<K> = Vec<
    std::pin::Pin<
        Box<dyn Future<Output = Result<Point<Encrypted<K, Vec<u8>>>, Error>> + Send + 'static>,
    >,
>;
302
/// Visitor that queues a boxed future encrypting each child point it visits.
struct ExtractResolution<'a, K> {
    /// Destination for the queued futures, in visitation order.
    extracted: &'a mut Extracted<K>,
    /// Key each child point is encrypted with (cloned into every future).
    key: &'a K,
}
307
/// Fetch source that re-reads a typed `Encrypted<K, T>` point as `Encrypted<K, Vec<u8>>`.
///
/// This erases `T`, so points of different types can share one `Resolution`.
struct Untyped<K, T> {
    /// Key (with unit extra) used to parse the fetched bytes.
    key: WithKey<K, ()>,
    /// The typed encrypted point whose bytes are fetched.
    encrypted: Point<Encrypted<K, T>>,
}
312
313impl<K, T> FetchBytes for Untyped<K, T> {
314    fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
315        self.encrypted.fetch_bytes()
316    }
317
318    fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
319        self.encrypted.fetch_data()
320    }
321}
322
323impl<K: Key, T> Fetch for Untyped<K, T> {
324    type T = Encrypted<K, Vec<u8>>;
325
326    fn fetch_full(&'_ self) -> FailFuture<'_, (Self::T, Arc<dyn Resolve>)> {
327        Box::pin(async move {
328            let (data, resolve) = self.fetch_bytes().await?;
329            let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
330            Ok((encrypted, resolve))
331        })
332    }
333
334    fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
335        Box::pin(async move {
336            let (data, resolve) = self.fetch_bytes().await?;
337            let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
338            Ok(encrypted)
339        })
340    }
341}
342
343impl<K: Key> PointVisitor for ExtractResolution<'_, K> {
344    fn visit<T: Traversible>(&mut self, decrypted: &Point<T>) {
345        let decrypted = decrypted.clone();
346        let key = self.key.clone();
347        self.extracted.push(Box::pin(async move {
348            let encrypted = encrypt_point(key.clone(), decrypted).await?;
349            let encrypted = Point::from_origin(
350                encrypted.hash(),
351                Arc::new(Untyped {
352                    key: WithKey { key, extra: () },
353                    encrypted,
354                }),
355            );
356            Ok(encrypted)
357        }));
358    }
359}
360
361pub async fn encrypt_point<K: Key, T: Traversible>(
362    key: K,
363    decrypted: Point<T>,
364) -> object_rainbow::Result<Point<Encrypted<K, T>>> {
365    if let Some((address, decrypt)) = decrypted.extract_resolve::<Decrypt<K>>() {
366        let encrypted = decrypt
367            .resolution
368            .get(address.index)
369            .ok_or(Error::AddressOutOfBounds)?
370            .clone();
371        let point = Point::from_origin(
372            encrypted.hash(),
373            Arc::new(Visited {
374                decrypted,
375                encrypted,
376            }),
377        );
378        return Ok(point);
379    }
380    let decrypted = decrypted.fetch().await?;
381    let encrypted = encrypt(key.clone(), decrypted).await?;
382    let point = encrypted.point();
383    Ok(point)
384}
385
386pub async fn encrypt<K: Key, T: Traversible>(
387    key: K,
388    decrypted: T,
389) -> object_rainbow::Result<Encrypted<K, T>> {
390    let mut futures = Vec::with_capacity(decrypted.point_count());
391    decrypted.accept_points(&mut ExtractResolution {
392        extracted: &mut futures,
393        key: &key,
394    });
395    let resolution = futures_util::future::try_join_all(futures).await?;
396    let resolution = Arc::new(Lp(resolution));
397    let decrypted = Unkeyed(Arc::new(decrypted));
398    let inner = EncryptedInner {
399        resolution,
400        decrypted,
401    };
402    Ok(Encrypted { key, inner })
403}