1use std::{ops::Deref, sync::Arc};
2
3use object_rainbow::{
4 Address, ByteNode, Error, FailFuture, Fetch, FetchBytes, Hash, Object, Parse, ParseSliceExtra,
5 Point, PointInput, PointVisitor, Resolve, Singular, Tagged, ToOutput, Topological, Traversible,
6 length_prefixed::Lp,
7};
8
/// Parsing context ("extra") for encrypted objects.
///
/// Bundles the key needed to decrypt the serialized bytes with the extra data that
/// the wrapped (decrypted) object's own parser expects.
#[derive(Clone)]
pub struct WithKey<K, Extra> {
    /// Key used to decrypt the input (and carried into the parsed [`Encrypted`]).
    pub key: K,
    /// Extra data handed to the inner object's parser once the key is stripped off.
    pub extra: Extra,
}
14
/// A key that can both encrypt and decrypt serialized objects.
///
/// `decrypt` must invert `encrypt`: an [`Encrypted`] value is written out as
/// `key.encrypt(plaintext)` and read back via `key.decrypt(...)`.
/// Keys are cloned freely (for instance once per child point while encrypting).
pub trait Key: 'static + Sized + Send + Sync + Clone {
    /// Encrypts `data`, returning the ciphertext.
    fn encrypt(&self, data: &[u8]) -> Vec<u8>;
    /// Decrypts `data`, returning the plaintext.
    ///
    /// # Errors
    ///
    /// Returns an error if `data` cannot be decrypted with this key; what counts as
    /// undecryptable (wrong key, tampering, truncation, ...) is implementation-defined.
    fn decrypt(&self, data: &[u8]) -> object_rainbow::Result<Vec<u8>>;
}
19
/// The encrypted child points of an encrypted object, shared and length-prefixed.
///
/// Entry `i` is the encrypted form of the decrypted object's `i`-th point. `Decrypt`
/// looks entries up by address index, and `IterateResolution` consumes them in visit
/// order. Entries are held untyped, with their payload kept as raw bytes.
type Resolution<K> = Arc<Lp<Vec<Point<Encrypted<K, Vec<u8>>>>>>;
21
22#[derive(ToOutput, Clone)]
23struct Unkeyed<T>(T);
24
25impl<
26 T: Parse<I::WithExtra<Extra>>,
27 K: 'static + Clone,
28 Extra: 'static + Clone,
29 I: PointInput<Extra = WithKey<K, Extra>>,
30> Parse<I> for Unkeyed<T>
31{
32 fn parse(input: I) -> object_rainbow::Result<Self> {
33 Ok(Self(T::parse(
34 input.map_extra(|WithKey { extra, .. }| extra),
35 )?))
36 }
37}
38
39#[derive(ToOutput, Parse)]
40struct EncryptedInner<K, T> {
41 resolution: Resolution<K>,
42 decrypted: Unkeyed<Arc<T>>,
43}
44
45impl<K, T> Clone for EncryptedInner<K, T> {
46 fn clone(&self) -> Self {
47 Self {
48 resolution: self.resolution.clone(),
49 decrypted: self.decrypted.clone(),
50 }
51 }
52}
53
/// Iterator over the entries of a `Resolution`, in order.
type ResolutionIter<'a, K> = std::slice::Iter<'a, Point<Encrypted<K, Vec<u8>>>>;

/// Point-visitor adaptor used by `EncryptedInner::accept_points`.
///
/// Each decrypted point it receives is paired with the next resolution entry. The
/// pair is forwarded to `visitor` as a point to the corresponding encrypted object.
struct IterateResolution<'a, K, V> {
    /// Remaining resolution entries, consumed one per visited point.
    resolution: ResolutionIter<'a, K>,
    /// Downstream visitor that receives the paired points.
    visitor: &'a mut V,
}
60
/// Fetch origin pairing a decrypted point with the encrypted point it is stored as.
///
/// Raw bytes come from `encrypted`. The typed value comes from `decrypted`.
struct Visited<K, T> {
    /// The plaintext point, used to obtain the typed value.
    decrypted: Point<T>,
    /// The matching encrypted point, used for bytes, key and resolution.
    encrypted: Point<Encrypted<K, Vec<u8>>>,
}

// Byte-level fetches go straight to the encrypted point: the serialized bytes of an
// `Encrypted` are its ciphertext.
impl<K, T> FetchBytes for Visited<K, T> {
    fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
        self.encrypted.fetch_bytes()
    }

    fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
        self.encrypted.fetch_data()
    }
}
75
76impl<K: Key, T: Traversible> Fetch for Visited<K, T> {
77 type T = Encrypted<K, T>;
78
79 fn fetch_full(&'_ self) -> FailFuture<'_, (Self::T, Arc<dyn Resolve>)> {
80 Box::pin(async move {
81 let (
82 Encrypted {
83 key,
84 inner:
85 EncryptedInner {
86 resolution,
87 decrypted: _,
88 },
89 },
90 resolve,
91 ) = self.encrypted.fetch_full().await?;
92 let decrypted = self.decrypted.fetch().await?;
93 let decrypted = Unkeyed(Arc::new(decrypted));
94 Ok((
95 Encrypted {
96 key,
97 inner: EncryptedInner {
98 resolution,
99 decrypted,
100 },
101 },
102 resolve,
103 ))
104 })
105 }
106
107 fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
108 Box::pin(async move {
109 let Encrypted {
110 key,
111 inner:
112 EncryptedInner {
113 resolution,
114 decrypted: _,
115 },
116 } = self.encrypted.fetch().await?;
117 let decrypted = self.decrypted.fetch().await?;
118 let decrypted = Unkeyed(Arc::new(decrypted));
119 Ok(Encrypted {
120 key,
121 inner: EncryptedInner {
122 resolution,
123 decrypted,
124 },
125 })
126 })
127 }
128}
129
130impl<'a, K: Key, V: PointVisitor> PointVisitor for IterateResolution<'a, K, V> {
131 fn visit<T: Traversible>(&mut self, decrypted: &Point<T>) {
132 let decrypted = decrypted.clone();
133 let encrypted = self.resolution.next().expect("length mismatch").clone();
134 let point = Point::from_origin(
135 encrypted.hash(),
136 Arc::new(Visited {
137 decrypted,
138 encrypted,
139 }),
140 );
141 self.visitor.visit(&point);
142 }
143}
144
145impl<K: Key, T: Topological> Topological for EncryptedInner<K, T> {
146 fn accept_points(&self, visitor: &mut impl PointVisitor) {
147 self.decrypted.0.accept_points(&mut IterateResolution {
148 resolution: self.resolution.iter(),
149 visitor,
150 });
151 }
152
153 fn point_count(&self) -> usize {
154 self.resolution.len()
155 }
156
157 fn topology_hash(&self) -> Hash {
158 self.resolution.0.data_hash()
159 }
160}
161
/// A value of type `T` stored encrypted with a key of type `K`.
///
/// - It dereferences to the decrypted `T`.
/// - Its serialized form is the encryption of the inner layout (resolution plus value).
/// - Its points are the encrypted counterparts of `T`'s points.
pub struct Encrypted<K, T> {
    /// Key used to encrypt the serialized form.
    key: K,
    /// Resolution of encrypted child points plus the decrypted value.
    inner: EncryptedInner<K, T>,
}
166
167impl<K, T: Clone> Encrypted<K, T> {
168 pub fn into_inner(self) -> T {
169 Arc::unwrap_or_clone(self.inner.decrypted.0)
170 }
171}
172
173impl<K, T> Deref for Encrypted<K, T> {
174 type Target = T;
175
176 fn deref(&self) -> &Self::Target {
177 self.inner.decrypted.0.as_ref()
178 }
179}
180
181impl<K: Clone, T> Clone for Encrypted<K, T> {
182 fn clone(&self) -> Self {
183 Self {
184 key: self.key.clone(),
185 inner: self.inner.clone(),
186 }
187 }
188}
189
190impl<K: Key, T: Topological> Topological for Encrypted<K, T> {
191 fn accept_points(&self, visitor: &mut impl PointVisitor) {
192 self.inner.accept_points(visitor);
193 }
194
195 fn topology_hash(&self) -> Hash {
196 self.inner.topology_hash()
197 }
198}
199
200impl<K: Key, T: ToOutput> ToOutput for Encrypted<K, T> {
201 fn to_output(&self, output: &mut dyn object_rainbow::Output) {
202 let source = self.inner.vec();
203 output.write(&self.key.encrypt(&source));
204 }
205}
206
207#[derive(Clone)]
208struct Decrypt<K> {
209 resolution: Resolution<K>,
210}
211
212impl<K: Key> Decrypt<K> {
213 async fn resolve_bytes(
214 &self,
215 address: Address,
216 ) -> object_rainbow::Result<(Vec<u8>, Resolution<K>)> {
217 let Encrypted {
218 key: _,
219 inner:
220 EncryptedInner {
221 resolution,
222 decrypted,
223 },
224 } = self
225 .resolution
226 .get(address.index)
227 .ok_or(Error::AddressOutOfBounds)?
228 .clone()
229 .fetch()
230 .await?;
231 let data = Arc::unwrap_or_clone(decrypted.0);
232 Ok((data, resolution))
233 }
234}
235
236impl<K: Key> Resolve for Decrypt<K> {
237 fn resolve(&'_ self, address: Address) -> FailFuture<'_, ByteNode> {
238 Box::pin(async move {
239 let (data, resolution) = self.resolve_bytes(address).await?;
240 Ok((data, Arc::new(Decrypt { resolution }) as _))
241 })
242 }
243
244 fn resolve_data(&'_ self, address: Address) -> FailFuture<'_, Vec<u8>> {
245 Box::pin(async move {
246 let (data, _) = self.resolve_bytes(address).await?;
247 Ok(data)
248 })
249 }
250
251 fn name(&self) -> &str {
252 "decrypt"
253 }
254}
255
256impl<
257 K: Key,
258 T: Object<Extra>,
259 Extra: 'static + Send + Sync + Clone,
260 I: PointInput<Extra = WithKey<K, Extra>>,
261> Parse<I> for Encrypted<K, T>
262{
263 fn parse(input: I) -> object_rainbow::Result<Self> {
264 let with_key = input.extra().clone();
265 let resolve = input.resolve().clone();
266 let source = with_key.key.decrypt(&input.parse_all()?)?;
267 let EncryptedInner {
268 resolution,
269 decrypted,
270 } = EncryptedInner::<K, Vec<u8>>::parse_slice_extra(&source, &resolve, &with_key)?;
271 let decrypted = T::parse_slice_extra(
272 &decrypted.0,
273 &(Arc::new(Decrypt {
274 resolution: resolution.clone(),
275 }) as _),
276 &with_key.extra,
277 )?;
278 let decrypted = Unkeyed(Arc::new(decrypted));
279 let inner = EncryptedInner {
280 resolution,
281 decrypted,
282 };
283 Ok(Self {
284 key: with_key.key,
285 inner,
286 })
287 }
288}
289
/// Empty impl: `Encrypted` takes the trait's defaults.
impl<K, T> Tagged for Encrypted<K, T> {}

/// An encrypted `T` is itself an object. It is parsed with a [`WithKey`] extra that
/// supplies the key plus the extra `T` needs.
impl<K: Key, T: Object<Extra>, Extra: 'static + Send + Sync + Clone> Object<WithKey<K, Extra>>
    for Encrypted<K, T>
{
}
296
/// Pending encryptions of child points.
///
/// There is one boxed future per visited point, in visit order. Each yields the child
/// as an untyped encrypted point.
type Extracted<K> = Vec<
    std::pin::Pin<
        Box<dyn Future<Output = Result<Point<Encrypted<K, Vec<u8>>>, Error>> + Send + 'static>,
    >,
>;

/// Point visitor that queues, for every visited point, a future encrypting it with
/// `key`.
struct ExtractResolution<'a, K> {
    /// Output queue; one entry is pushed per visited point.
    extracted: &'a mut Extracted<K>,
    /// Key each visited point is encrypted with.
    key: &'a K,
}
307
/// Fetch origin that re-exposes a typed encrypted point as an untyped one
/// (`Encrypted<K, Vec<u8>>`). This lets it be stored in a resolution.
struct Untyped<K, T> {
    /// Key (with no extra) used to parse the fetched bytes.
    key: WithKey<K, ()>,
    /// The typed encrypted point whose bytes are re-parsed.
    encrypted: Point<Encrypted<K, T>>,
}

// Byte-level fetches are forwarded unchanged; only the typed view differs.
impl<K, T> FetchBytes for Untyped<K, T> {
    fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
        self.encrypted.fetch_bytes()
    }

    fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
        self.encrypted.fetch_data()
    }
}
322
323impl<K: Key, T> Fetch for Untyped<K, T> {
324 type T = Encrypted<K, Vec<u8>>;
325
326 fn fetch_full(&'_ self) -> FailFuture<'_, (Self::T, Arc<dyn Resolve>)> {
327 Box::pin(async move {
328 let (data, resolve) = self.fetch_bytes().await?;
329 let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
330 Ok((encrypted, resolve))
331 })
332 }
333
334 fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
335 Box::pin(async move {
336 let (data, resolve) = self.fetch_bytes().await?;
337 let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
338 Ok(encrypted)
339 })
340 }
341}
342
343impl<K: Key> PointVisitor for ExtractResolution<'_, K> {
344 fn visit<T: Traversible>(&mut self, decrypted: &Point<T>) {
345 let decrypted = decrypted.clone();
346 let key = self.key.clone();
347 self.extracted.push(Box::pin(async move {
348 let encrypted = encrypt_point(key.clone(), decrypted).await?;
349 let encrypted = Point::from_origin(
350 encrypted.hash(),
351 Arc::new(Untyped {
352 key: WithKey { key, extra: () },
353 encrypted,
354 }),
355 );
356 Ok(encrypted)
357 }));
358 }
359}
360
361pub async fn encrypt_point<K: Key, T: Traversible>(
362 key: K,
363 decrypted: Point<T>,
364) -> object_rainbow::Result<Point<Encrypted<K, T>>> {
365 if let Some((address, decrypt)) = decrypted.extract_resolve::<Decrypt<K>>() {
366 let encrypted = decrypt
367 .resolution
368 .get(address.index)
369 .ok_or(Error::AddressOutOfBounds)?
370 .clone();
371 let point = Point::from_origin(
372 encrypted.hash(),
373 Arc::new(Visited {
374 decrypted,
375 encrypted,
376 }),
377 );
378 return Ok(point);
379 }
380 let decrypted = decrypted.fetch().await?;
381 let encrypted = encrypt(key.clone(), decrypted).await?;
382 let point = encrypted.point();
383 Ok(point)
384}
385
386pub async fn encrypt<K: Key, T: Traversible>(
387 key: K,
388 decrypted: T,
389) -> object_rainbow::Result<Encrypted<K, T>> {
390 let mut futures = Vec::with_capacity(decrypted.point_count());
391 decrypted.accept_points(&mut ExtractResolution {
392 extracted: &mut futures,
393 key: &key,
394 });
395 let resolution = futures_util::future::try_join_all(futures).await?;
396 let resolution = Arc::new(Lp(resolution));
397 let decrypted = Unkeyed(Arc::new(decrypted));
398 let inner = EncryptedInner {
399 resolution,
400 decrypted,
401 };
402 Ok(Encrypted { key, inner })
403}