object_rainbow_encrypted/
lib.rs

1use std::{ops::Deref, sync::Arc};
2
3use object_rainbow::{
4    Address, ByteNode, Error, FailFuture, Fetch, FetchBytes, Hash, Object, Parse, ParseSliceExtra,
5    Point, PointInput, PointVisitor, Resolve, Singular, Tagged, ToOutput, Topological, Traversible,
6    length_prefixed::Lp,
7};
8
/// Parsing context ("extra") that carries an encryption key next to the
/// caller's own extra value.
///
/// `Encrypted` parsing takes the key from here. `Unkeyed` strips the key
/// again, so the wrapped payload's parser only sees `extra`.
#[derive(Clone)]
pub struct WithKey<K, Extra> {
    /// Key used to decrypt `Encrypted` values parsed with this context.
    pub key: K,
    /// Extra value forwarded to the decrypted payload's parser.
    pub extra: Extra,
}
14
15pub trait Key: 'static + Sized + Send + Sync + Clone {
16    fn encrypt(&self, data: &[u8]) -> Vec<u8>;
17    fn decrypt(&self, data: &[u8]) -> object_rainbow::Result<Vec<u8>>;
18}
19
20type Resolution<K> = Arc<Lp<Vec<Point<Encrypted<K, Vec<u8>>>>>>;
21
22#[derive(ToOutput, Clone)]
23struct Unkeyed<T>(T);
24
25impl<
26    T: Parse<I::WithExtra<Extra>>,
27    K: 'static + Clone,
28    Extra: 'static + Clone,
29    I: PointInput<Extra = WithKey<K, Extra>>,
30> Parse<I> for Unkeyed<T>
31{
32    fn parse(input: I) -> object_rainbow::Result<Self> {
33        Ok(Self(T::parse(
34            input.map_extra(|WithKey { extra, .. }| extra),
35        )?))
36    }
37}
38
39#[derive(ToOutput, Parse)]
40struct EncryptedInner<K, T> {
41    resolution: Resolution<K>,
42    decrypted: Unkeyed<Arc<T>>,
43}
44
45impl<K, T> Clone for EncryptedInner<K, T> {
46    fn clone(&self) -> Self {
47        Self {
48            resolution: self.resolution.clone(),
49            decrypted: self.decrypted.clone(),
50        }
51    }
52}
53
54type ResolutionIter<'a, K> = std::slice::Iter<'a, Point<Encrypted<K, Vec<u8>>>>;
55
56struct IterateResolution<'a, K, V> {
57    resolution: ResolutionIter<'a, K>,
58    visitor: &'a mut V,
59}
60
61struct Visited<K, T> {
62    decrypted: Point<T>,
63    encrypted: Point<Encrypted<K, Vec<u8>>>,
64}
65
66impl<K, T> FetchBytes for Visited<K, T> {
67    fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
68        self.encrypted.fetch_bytes()
69    }
70
71    fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
72        self.encrypted.fetch_data()
73    }
74}
75
76impl<K: Key, T: Traversible> Fetch for Visited<K, T> {
77    type T = Encrypted<K, T>;
78
79    fn fetch_full(&'_ self) -> FailFuture<'_, (Self::T, Arc<dyn Resolve>)> {
80        todo!()
81    }
82
83    fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
84        Box::pin(async move {
85            let Encrypted {
86                key,
87                inner:
88                    EncryptedInner {
89                        resolution,
90                        decrypted: _,
91                    },
92            } = self.encrypted.fetch().await?;
93            let decrypted = self.decrypted.fetch().await?;
94            let decrypted = Unkeyed(Arc::new(decrypted));
95            Ok(Encrypted {
96                key,
97                inner: EncryptedInner {
98                    resolution,
99                    decrypted,
100                },
101            })
102        })
103    }
104}
105
106impl<'a, K: Key, V: PointVisitor> PointVisitor for IterateResolution<'a, K, V> {
107    fn visit<T: Traversible>(&mut self, decrypted: &Point<T>) {
108        let decrypted = decrypted.clone();
109        let encrypted = self.resolution.next().expect("length mismatch").clone();
110        let point = Point::from_origin(
111            encrypted.hash(),
112            Arc::new(Visited {
113                decrypted,
114                encrypted,
115            }),
116        );
117        self.visitor.visit(&point);
118    }
119}
120
121impl<K: Key, T: Topological> Topological for EncryptedInner<K, T> {
122    fn accept_points(&self, visitor: &mut impl PointVisitor) {
123        self.decrypted.0.accept_points(&mut IterateResolution {
124            resolution: self.resolution.iter(),
125            visitor,
126        });
127    }
128
129    fn point_count(&self) -> usize {
130        self.resolution.len()
131    }
132
133    fn topology_hash(&self) -> Hash {
134        self.resolution.0.data_hash()
135    }
136}
137
138pub struct Encrypted<K, T> {
139    key: K,
140    inner: EncryptedInner<K, T>,
141}
142
143impl<K, T: Clone> Encrypted<K, T> {
144    pub fn into_inner(self) -> T {
145        Arc::unwrap_or_clone(self.inner.decrypted.0)
146    }
147}
148
149impl<K, T> Deref for Encrypted<K, T> {
150    type Target = T;
151
152    fn deref(&self) -> &Self::Target {
153        self.inner.decrypted.0.as_ref()
154    }
155}
156
157impl<K: Clone, T> Clone for Encrypted<K, T> {
158    fn clone(&self) -> Self {
159        Self {
160            key: self.key.clone(),
161            inner: self.inner.clone(),
162        }
163    }
164}
165
166impl<K: Key, T: Topological> Topological for Encrypted<K, T> {
167    fn accept_points(&self, visitor: &mut impl PointVisitor) {
168        self.inner.accept_points(visitor);
169    }
170
171    fn topology_hash(&self) -> Hash {
172        self.inner.topology_hash()
173    }
174}
175
176impl<K: Key, T: ToOutput> ToOutput for Encrypted<K, T> {
177    fn to_output(&self, output: &mut dyn object_rainbow::Output) {
178        let source = self.inner.vec();
179        output.write(&self.key.encrypt(&source));
180    }
181}
182
183#[derive(Clone)]
184struct Decrypt<K> {
185    resolution: Resolution<K>,
186}
187
188impl<K: Key> Decrypt<K> {
189    async fn resolve_bytes(
190        &self,
191        address: Address,
192    ) -> object_rainbow::Result<(Vec<u8>, Resolution<K>)> {
193        let Encrypted {
194            key: _,
195            inner:
196                EncryptedInner {
197                    resolution,
198                    decrypted,
199                },
200        } = self
201            .resolution
202            .get(address.index)
203            .ok_or(Error::AddressOutOfBounds)?
204            .clone()
205            .fetch()
206            .await?;
207        let data = Arc::unwrap_or_clone(decrypted.0);
208        Ok((data, resolution))
209    }
210}
211
212impl<K: Key> Resolve for Decrypt<K> {
213    fn resolve(&'_ self, address: Address) -> FailFuture<'_, ByteNode> {
214        Box::pin(async move {
215            let (data, resolution) = self.resolve_bytes(address).await?;
216            Ok((data, Arc::new(Decrypt { resolution }) as _))
217        })
218    }
219
220    fn resolve_data(&'_ self, address: Address) -> FailFuture<'_, Vec<u8>> {
221        Box::pin(async move {
222            let (data, _) = self.resolve_bytes(address).await?;
223            Ok(data)
224        })
225    }
226
227    fn name(&self) -> &str {
228        "decrypt"
229    }
230}
231
232impl<
233    K: Key,
234    T: Object<Extra>,
235    Extra: 'static + Send + Sync + Clone,
236    I: PointInput<Extra = WithKey<K, Extra>>,
237> Parse<I> for Encrypted<K, T>
238{
239    fn parse(input: I) -> object_rainbow::Result<Self> {
240        let with_key = input.extra().clone();
241        let resolve = input.resolve().clone();
242        let source = with_key.key.decrypt(input.parse_all()?)?;
243        let EncryptedInner {
244            resolution,
245            decrypted,
246        } = EncryptedInner::<K, Vec<u8>>::parse_slice_extra(&source, &resolve, &with_key)?;
247        let decrypted = T::parse_slice_extra(
248            &decrypted.0,
249            &(Arc::new(Decrypt {
250                resolution: resolution.clone(),
251            }) as _),
252            &with_key.extra,
253        )?;
254        let decrypted = Unkeyed(Arc::new(decrypted));
255        let inner = EncryptedInner {
256            resolution,
257            decrypted,
258        };
259        Ok(Self {
260            key: with_key.key,
261            inner,
262        })
263    }
264}
265
266impl<K, T> Tagged for Encrypted<K, T> {}
267
268impl<K: Key, T: Object<Extra>, Extra: 'static + Send + Sync + Clone> Object<WithKey<K, Extra>>
269    for Encrypted<K, T>
270{
271}
272
273type Extracted<K> = Vec<
274    std::pin::Pin<
275        Box<dyn Future<Output = Result<Point<Encrypted<K, Vec<u8>>>, Error>> + Send + 'static>,
276    >,
277>;
278
279struct ExtractResolution<'a, K> {
280    extracted: &'a mut Extracted<K>,
281    key: &'a K,
282}
283
284struct Untyped<K, T> {
285    key: WithKey<K, ()>,
286    encrypted: Point<Encrypted<K, T>>,
287}
288
289impl<K, T> FetchBytes for Untyped<K, T> {
290    fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
291        self.encrypted.fetch_bytes()
292    }
293
294    fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
295        self.encrypted.fetch_data()
296    }
297}
298
299impl<K: Key, T> Fetch for Untyped<K, T> {
300    type T = Encrypted<K, Vec<u8>>;
301
302    fn fetch_full(&'_ self) -> FailFuture<'_, (Self::T, Arc<dyn Resolve>)> {
303        Box::pin(async move {
304            let (data, resolve) = self.fetch_bytes().await?;
305            let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
306            Ok((encrypted, resolve))
307        })
308    }
309
310    fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
311        Box::pin(async move {
312            let (data, resolve) = self.fetch_bytes().await?;
313            let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
314            Ok(encrypted)
315        })
316    }
317}
318
319impl<K: Key> PointVisitor for ExtractResolution<'_, K> {
320    fn visit<T: Traversible>(&mut self, decrypted: &Point<T>) {
321        let decrypted = decrypted.clone();
322        let key = self.key.clone();
323        self.extracted.push(Box::pin(async move {
324            let encrypted = encrypt_point(key.clone(), decrypted).await?;
325            let encrypted = Point::from_origin(
326                encrypted.hash(),
327                Arc::new(Untyped {
328                    key: WithKey { key, extra: () },
329                    encrypted,
330                }),
331            );
332            Ok(encrypted)
333        }));
334    }
335}
336
337pub async fn encrypt_point<K: Key, T: Traversible>(
338    key: K,
339    decrypted: Point<T>,
340) -> object_rainbow::Result<Point<Encrypted<K, T>>> {
341    if let Some((address, decrypt)) = decrypted.extract_resolve::<Decrypt<K>>() {
342        let encrypted = decrypt
343            .resolution
344            .get(address.index)
345            .ok_or(Error::AddressOutOfBounds)?
346            .clone();
347        let point = Point::from_origin(
348            encrypted.hash(),
349            Arc::new(Visited {
350                decrypted,
351                encrypted,
352            }),
353        );
354        return Ok(point);
355    }
356    let decrypted = decrypted.fetch().await?;
357    let encrypted = encrypt(key.clone(), decrypted).await?;
358    let point = encrypted.point();
359    Ok(point)
360}
361
362pub async fn encrypt<K: Key, T: Traversible>(
363    key: K,
364    decrypted: T,
365) -> object_rainbow::Result<Encrypted<K, T>> {
366    let mut futures = Vec::with_capacity(decrypted.point_count());
367    decrypted.accept_points(&mut ExtractResolution {
368        extracted: &mut futures,
369        key: &key,
370    });
371    let resolution = futures_util::future::try_join_all(futures).await?;
372    let resolution = Arc::new(Lp(resolution));
373    let decrypted = Unkeyed(Arc::new(decrypted));
374    let inner = EncryptedInner {
375        resolution,
376        decrypted,
377    };
378    Ok(Encrypted { key, inner })
379}