object_rainbow_encrypted/
lib.rs

1use std::{ops::Deref, sync::Arc};
2
3use object_rainbow::{
4    Address, ByteNode, Error, FailFuture, Fetch, Hash, Object, Parse, ParseSliceExtra, Point,
5    PointInput, PointVisitor, RawPoint, Resolve, Tagged, ToOutput, Topological,
6    length_prefixed::Lp,
7};
8
/// Extra context for encrypted objects: an encryption key paired with the `Extra`
/// context that the wrapped plaintext object is parsed with.
///
/// This is the `Extra` type under which [`Encrypted`] implements `Object`.
#[derive(Clone)]
pub struct WithKey<K, Extra> {
    /// Key used to decrypt on parse and encrypt on serialization.
    pub key: K,
    /// Context handed through unchanged to the plaintext object's own parsing.
    pub extra: Extra,
}
14
/// Cipher used to seal [`Encrypted`] objects.
///
/// The same key value is used both to encrypt (when an [`Encrypted`] is serialized)
/// and to decrypt (when it is parsed), so `decrypt` is expected to invert `encrypt`.
pub trait Key: 'static + Sized + Send + Sync + Clone {
    /// Encrypts `data`; the returned bytes are the serialized form of an [`Encrypted`].
    fn encrypt(&self, data: &[u8]) -> Vec<u8>;
    /// Decrypts `data` produced by [`Key::encrypt`]. An error here is propagated out of
    /// `Encrypted`'s `Parse` implementation.
    fn decrypt(&self, data: &[u8]) -> object_rainbow::Result<Vec<u8>>;
}
19
/// Shared, length-prefixed table of the encrypted points referenced by an object.
///
/// Entry `i` is paired with the `i`-th point visited by the plaintext object's
/// `accept_points`, and is what address index `i` resolves to through `Decrypt`.
type Resolution<K, Extra> = Arc<Lp<Vec<RawPoint<Encrypted<K, Vec<u8>, Extra>, WithKey<K, Extra>>>>>;
21
/// Newtype whose `Parse` strips the key from a `WithKey<K, Extra>` input so the wrapped
/// value is parsed with the plain `Extra`; `ToOutput` and `Clone` are derived.
#[derive(ToOutput, Clone)]
struct Unkeyed<T>(T);
24
25impl<
26    T: Parse<I::WithExtra<Extra>>,
27    K: 'static + Clone,
28    Extra: 'static + Clone,
29    I: PointInput<Extra = WithKey<K, Extra>>,
30> Parse<I> for Unkeyed<T>
31{
32    fn parse(input: I) -> object_rainbow::Result<Self> {
33        Ok(Self(T::parse(
34            input.map_extra(|WithKey { extra, .. }| extra),
35        )?))
36    }
37}
38
/// Plaintext layout of an encrypted object: the resolution table plus the object itself.
///
/// This is what gets serialized and passed to `Key::encrypt`. Both `ToOutput` and
/// `Parse` are derived over these fields. NOTE(review): this assumes the derives follow
/// declaration order; if so, reordering the fields changes the wire format.
#[derive(ToOutput, Parse)]
struct EncryptedInner<K, T, Extra> {
    /// Encrypted points referenced by `decrypted`, in visit order.
    resolution: Resolution<K, Extra>,
    /// The plaintext object, shared via `Arc` so clones are cheap.
    decrypted: Unkeyed<Arc<T>>,
}
44
45impl<K, T, Extra> Clone for EncryptedInner<K, T, Extra> {
46    fn clone(&self) -> Self {
47        Self {
48            resolution: self.resolution.clone(),
49            decrypted: self.decrypted.clone(),
50        }
51    }
52}
53
/// Borrowing iterator over the entries of a [`Resolution`] table.
type ResolutionIter<'a, K, Extra> =
    std::slice::Iter<'a, RawPoint<Encrypted<K, Vec<u8>, Extra>, WithKey<K, Extra>>>;
56
/// Visitor adapter: for every plaintext point it is shown, it forwards the next entry of
/// `resolution` (as a point to an `Encrypted` object) to the wrapped `visitor`.
struct IterateResolution<'a, K, V, Extra> {
    /// Remaining resolution entries, consumed one per visited point.
    resolution: ResolutionIter<'a, K, Extra>,
    /// Visitor that receives the encrypted points.
    visitor: &'a mut V,
}
61
62impl<'a, K: Key, V: PointVisitor<WithKey<K, Extra>>, Extra: 'static + Send + Sync + Clone>
63    PointVisitor<Extra> for IterateResolution<'a, K, V, Extra>
64{
65    fn visit<T: Object<Extra>>(&mut self, _: &Point<T, Extra>) {
66        let point = self
67            .resolution
68            .next()
69            .expect("length mismatch")
70            .clone()
71            .cast::<Encrypted<K, T, Extra>>()
72            .point();
73        self.visitor.visit(&point);
74    }
75}
76
77impl<K: Key, T: Topological<Extra>, Extra: 'static + Send + Sync + Clone>
78    Topological<WithKey<K, Extra>> for EncryptedInner<K, T, Extra>
79{
80    fn accept_points(&self, visitor: &mut impl PointVisitor<WithKey<K, Extra>>) {
81        self.decrypted.0.accept_points(&mut IterateResolution {
82            resolution: self.resolution.iter(),
83            visitor,
84        });
85    }
86
87    fn point_count(&self) -> usize {
88        self.resolution.len()
89    }
90
91    fn topology_hash(&self) -> Hash {
92        self.resolution.0.data_hash()
93    }
94}
95
/// An object of type `T` that is stored encrypted under a key of type `K`.
///
/// In memory it keeps the plaintext object, reachable through `Deref` and
/// [`Encrypted::into_inner`]. Its serialized form is `key.encrypt(..)` applied to the
/// serialized resolution table plus object bytes. Its points live in the
/// `WithKey<K, Extra>` context.
pub struct Encrypted<K, T, Extra> {
    /// Key used when serializing this object.
    key: K,
    /// Resolution table and shared plaintext object.
    inner: EncryptedInner<K, T, Extra>,
}
100
101impl<K, T: Clone, Extra> Encrypted<K, T, Extra> {
102    pub fn into_inner(self) -> T {
103        Arc::unwrap_or_clone(self.inner.decrypted.0)
104    }
105}
106
107impl<K, T, Extra> Deref for Encrypted<K, T, Extra> {
108    type Target = T;
109
110    fn deref(&self) -> &Self::Target {
111        self.inner.decrypted.0.as_ref()
112    }
113}
114
115impl<K: Clone, T, Extra> Clone for Encrypted<K, T, Extra> {
116    fn clone(&self) -> Self {
117        Self {
118            key: self.key.clone(),
119            inner: self.inner.clone(),
120        }
121    }
122}
123
124impl<K: Key, T: Topological<Extra>, Extra: 'static + Send + Sync + Clone>
125    Topological<WithKey<K, Extra>> for Encrypted<K, T, Extra>
126{
127    fn accept_points(&self, visitor: &mut impl PointVisitor<WithKey<K, Extra>>) {
128        self.inner.accept_points(visitor);
129    }
130
131    fn topology_hash(&self) -> Hash {
132        self.inner.topology_hash()
133    }
134}
135
136impl<K: Key, T: ToOutput, Extra> ToOutput for Encrypted<K, T, Extra> {
137    fn to_output(&self, output: &mut dyn object_rainbow::Output) {
138        let source = self.inner.vec();
139        output.write(&self.key.encrypt(&source));
140    }
141}
142
/// Resolver installed while parsing the plaintext object inside an [`Encrypted`].
///
/// Address index `i` is looked up in `resolution`; the encrypted node found there is
/// fetched, and its plaintext bytes are returned.
#[derive(Clone)]
struct Decrypt<K, Extra> {
    /// Encrypted points of the object currently being parsed, indexed by address.
    resolution: Resolution<K, Extra>,
}
147
148impl<K: Key, Extra: 'static + Send + Sync + Clone> Resolve for Decrypt<K, Extra> {
149    fn resolve(&'_ self, address: Address) -> FailFuture<'_, ByteNode> {
150        Box::pin(async move {
151            let Encrypted {
152                key: _,
153                inner:
154                    EncryptedInner {
155                        resolution,
156                        decrypted,
157                    },
158            } = self
159                .resolution
160                .get(address.index)
161                .ok_or(Error::AddressOutOfBounds)?
162                .clone()
163                .fetch()
164                .await?;
165            Ok((
166                Arc::into_inner(decrypted.0).expect("not shared because reconstructed"),
167                Arc::new(Decrypt { resolution }) as _,
168            ))
169        })
170    }
171
172    fn name(&self) -> &str {
173        "decrypt"
174    }
175}
176
177impl<
178    K: Key,
179    T: Object<Extra>,
180    Extra: 'static + Send + Sync + Clone,
181    I: PointInput<Extra = WithKey<K, Extra>>,
182> Parse<I> for Encrypted<K, T, Extra>
183{
184    fn parse(input: I) -> object_rainbow::Result<Self> {
185        let with_key = input.extra().clone();
186        let resolve = input.resolve().clone();
187        let source = with_key.key.decrypt(input.parse_all()?)?;
188        let EncryptedInner {
189            resolution,
190            decrypted,
191        } = EncryptedInner::<K, Vec<u8>, Extra>::parse_slice_extra(&source, &resolve, &with_key)?;
192        let decrypted = T::parse_slice_extra(
193            &decrypted.0,
194            &(Arc::new(Decrypt {
195                resolution: resolution.clone(),
196            }) as _),
197            &with_key.extra,
198        )?;
199        let decrypted = Unkeyed(Arc::new(decrypted));
200        let inner = EncryptedInner {
201            resolution,
202            decrypted,
203        };
204        Ok(Self {
205            key: with_key.key,
206            inner,
207        })
208    }
209}
210
// Marker impls with empty bodies: `Tagged` and `Object` rely entirely on their trait
// defaults. The encrypted behaviour comes from the `ToOutput`, `Topological` and
// `Parse` impls for `Encrypted`.
impl<K, T, Extra> Tagged for Encrypted<K, T, Extra> {}

impl<K: Key, T: Object<Extra>, Extra: 'static + Send + Sync + Clone> Object<WithKey<K, Extra>>
    for Encrypted<K, T, Extra>
{
}
217
/// Pending encryption jobs queued by `ExtractResolution`: one boxed `Send + 'static`
/// future per visited point, each yielding that point's encrypted raw point.
type Extracted<K, Extra> = Vec<
    std::pin::Pin<
        Box<
            dyn Future<
                    Output = Result<
                        RawPoint<Encrypted<K, Vec<u8>, Extra>, WithKey<K, Extra>>,
                        Error,
                    >,
                > + Send
                + 'static,
        >,
    >,
>;
231
/// Visitor that queues one `encrypt_point` job, using the borrowed key, for every point
/// it visits; jobs are pushed onto the vector in visit order.
struct ExtractResolution<'a, K, Extra>(&'a mut Extracted<K, Extra>, &'a K);
233
234impl<K: Key, Extra: 'static + Send + Sync + Clone> PointVisitor<Extra>
235    for ExtractResolution<'_, K, Extra>
236{
237    fn visit<T: Object<Extra>>(&mut self, point: &Point<T, Extra>) {
238        let point = point.clone();
239        let key = self.1.clone();
240        self.0.push(Box::pin(async move {
241            let point = encrypt_point(key, point).await?.raw().cast();
242            Ok(point)
243        }));
244    }
245}
246
247pub async fn encrypt_point<K: Key, T: Object<Extra>, Extra: 'static + Send + Sync + Clone>(
248    key: K,
249    point: Point<T, Extra>,
250) -> object_rainbow::Result<Point<Encrypted<K, T, Extra>, WithKey<K, Extra>>> {
251    if let Some((address, decrypt)) = point.extract_resolve::<Decrypt<K, Extra>>() {
252        let point = decrypt
253            .resolution
254            .get(address.index)
255            .ok_or(Error::AddressOutOfBounds)?;
256        return Ok(point.clone().cast().point());
257    }
258    let decrypted = point.fetch().await?;
259    let encrypted = encrypt(key.clone(), decrypted).await?;
260    let point = encrypted.point_extra(WithKey {
261        key,
262        extra: point.extra().clone(),
263    });
264    Ok(point)
265}
266
267pub async fn encrypt<K: Key, T: Object<Extra>, Extra: 'static + Send + Sync + Clone>(
268    key: K,
269    decrypted: T,
270) -> object_rainbow::Result<Encrypted<K, T, Extra>> {
271    let mut futures = Vec::with_capacity(decrypted.point_count());
272    decrypted.accept_points(&mut ExtractResolution(&mut futures, &key));
273    let resolution = futures_util::future::try_join_all(futures).await?;
274    let resolution = Arc::new(Lp(resolution));
275    let decrypted = Unkeyed(Arc::new(decrypted));
276    let inner = EncryptedInner {
277        resolution,
278        decrypted,
279    };
280    Ok(Encrypted { key, inner })
281}