1use std::{ops::Deref, sync::Arc};
2
3use object_rainbow::{
4 Address, ByteNode, Error, ExtractResolve, FailFuture, Fetch, FetchBytes, FullHash, Hash,
5 ListHashes, Node, Object, Parse, ParseSliceExtra, Point, PointInput, PointVisitor, Resolve,
6 Singular, SingularFetch, Tagged, ToOutput, Topological, Traversible, length_prefixed::Lp,
7};
8
/// Parse "extra" context that carries a key next to the caller's own extra.
///
/// `Encrypted` parsing reads `key` to decrypt its input and forwards `extra` to the
/// value found inside the plaintext.
#[derive(Clone)]
pub struct WithKey<K, Extra> {
    /// Key used to decrypt `Encrypted` values while parsing.
    pub key: K,
    /// Extra context handed to the decrypted value's parser.
    pub extra: Extra,
}
14
/// Key used both to encrypt and to decrypt the serialized form of `Encrypted` values.
///
/// A copy of the key is stored in every `Encrypted` value and moved into extraction
/// tasks, hence the `'static + Send + Sync + Clone` requirements.
pub trait Key: 'static + Sized + Send + Sync + Clone {
    /// Encrypt `data` and return the ciphertext.
    ///
    /// This is called from `Encrypted::to_output`, so it has no way to report failure.
    fn encrypt(&self, data: &[u8]) -> Vec<u8>;
    /// Decrypt `data` and return the plaintext.
    ///
    /// # Errors
    /// Implementation-defined; `Encrypted`'s `Parse` impl propagates the error unchanged.
    fn decrypt(&self, data: &[u8]) -> object_rainbow::Result<Vec<u8>>;
}
19
/// Shared list of the encrypted points referenced by an encrypted object's plaintext.
///
/// Addresses inside the plaintext index into this list (see `Decrypt`). `Lp` presumably
/// encodes it with a length prefix — TODO confirm.
type Resolution<K> = Arc<Lp<Vec<Point<Encrypted<K, Vec<u8>>>>>>;
21
/// Wrapper whose `Parse` impl strips the `WithKey` envelope from the input's extra before
/// parsing `T`, so the wrapped value never sees the key.
///
/// `ToOutput` and `Clone` are derived; presumably they simply forward to `T`.
#[derive(ToOutput, Clone)]
struct Unkeyed<T>(T);
24
25impl<
26 T: Parse<I::WithExtra<Extra>>,
27 K: 'static + Clone,
28 Extra: 'static + Clone,
29 I: PointInput<Extra = WithKey<K, Extra>>,
30> Parse<I> for Unkeyed<T>
31{
32 fn parse(input: I) -> object_rainbow::Result<Self> {
33 Ok(Self(T::parse(
34 input.map_extra(|WithKey { extra, .. }| extra),
35 )?))
36 }
37}
38
/// Plaintext body of an encrypted object: the encrypted points it references, followed
/// by the decrypted value itself.
///
/// `ToOutput` and `Parse` are derived, so the field order presumably defines the
/// plaintext encoding — do not reorder the fields.
#[derive(ToOutput, Parse)]
struct EncryptedInner<K, T> {
    /// Encrypted counterparts of the points referenced by `decrypted`, in traversal order.
    resolution: Resolution<K>,
    /// The decrypted value; its parser never sees the key (see `Unkeyed`).
    decrypted: Unkeyed<Arc<T>>,
}
44
45impl<K, T> Clone for EncryptedInner<K, T> {
46 fn clone(&self) -> Self {
47 Self {
48 resolution: self.resolution.clone(),
49 decrypted: self.decrypted.clone(),
50 }
51 }
52}
53
/// Borrowing iterator over a resolution's encrypted points, used by
/// `EncryptedInner::traverse` to pair each visited point with its encrypted counterpart.
type ResolutionIter<'a, K> = std::slice::Iter<'a, Point<Encrypted<K, Vec<u8>>>>;
55
/// `PointVisitor` adapter used by `EncryptedInner::traverse`.
///
/// For every point the decrypted value visits, it takes the next entry of `resolution`
/// and forwards a `Visited` point pairing the two to `visitor`.
struct IterateResolution<'a, 'r, K, V> {
    /// Remaining encrypted points; one is consumed per visited point.
    resolution: &'r mut ResolutionIter<'a, K>,
    /// Downstream visitor that receives the paired points.
    visitor: &'a mut V,
}
60
/// Point handle pairing an encrypted point with the plaintext handle it corresponds to.
///
/// Byte-level fetches and the hash are answered by `encrypted`. Typed fetches combine the
/// encrypted envelope (key and resolution) with the value fetched through `decrypted`,
/// instead of re-parsing it from the decrypted bytes.
struct Visited<K, P> {
    /// Handle to the plaintext value.
    decrypted: P,
    /// The matching encrypted point, whose plaintext is kept as raw bytes.
    encrypted: Point<Encrypted<K, Vec<u8>>>,
}
65
66impl<K, P> FetchBytes for Visited<K, P> {
67 fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
68 self.encrypted.fetch_bytes()
69 }
70
71 fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
72 self.encrypted.fetch_data()
73 }
74
75 fn fetch_bytes_local(&self) -> object_rainbow::Result<Option<ByteNode>> {
76 self.encrypted.fetch_bytes_local()
77 }
78
79 fn fetch_data_local(&self) -> Option<Vec<u8>> {
80 self.encrypted.fetch_data_local()
81 }
82}
83
84impl<K, P: Send + Sync> Singular for Visited<K, P> {
85 fn hash(&self) -> Hash {
86 self.encrypted.hash()
87 }
88}
89
90impl<K: Key, P: Fetch<T: Traversible>> Fetch for Visited<K, P> {
91 type T = Encrypted<K, P::T>;
92
93 fn fetch_full(&'_ self) -> FailFuture<'_, Node<Self::T>> {
94 Box::pin(async move {
95 let (
96 Encrypted {
97 key,
98 inner:
99 EncryptedInner {
100 resolution,
101 decrypted: _,
102 },
103 },
104 resolve,
105 ) = self.encrypted.fetch_full().await?;
106 let decrypted = self.decrypted.fetch().await?;
107 let decrypted = Unkeyed(Arc::new(decrypted));
108 Ok((
109 Encrypted {
110 key,
111 inner: EncryptedInner {
112 resolution,
113 decrypted,
114 },
115 },
116 resolve,
117 ))
118 })
119 }
120
121 fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
122 Box::pin(async move {
123 let Encrypted {
124 key,
125 inner:
126 EncryptedInner {
127 resolution,
128 decrypted: _,
129 },
130 } = self.encrypted.fetch().await?;
131 let decrypted = self.decrypted.fetch().await?;
132 let decrypted = Unkeyed(Arc::new(decrypted));
133 Ok(Encrypted {
134 key,
135 inner: EncryptedInner {
136 resolution,
137 decrypted,
138 },
139 })
140 })
141 }
142
143 fn try_fetch_local(&self) -> object_rainbow::Result<Option<Node<Self::T>>> {
144 let Some((
145 Encrypted {
146 key,
147 inner:
148 EncryptedInner {
149 resolution,
150 decrypted: _,
151 },
152 },
153 resolve,
154 )) = self.encrypted.try_fetch_local()?
155 else {
156 return Ok(None);
157 };
158 let Some((decrypted, _)) = self.decrypted.try_fetch_local()? else {
159 return Ok(None);
160 };
161 let decrypted = Unkeyed(Arc::new(decrypted));
162 Ok(Some((
163 Encrypted {
164 key,
165 inner: EncryptedInner {
166 resolution,
167 decrypted,
168 },
169 },
170 resolve,
171 )))
172 }
173
174 fn fetch_local(&self) -> Option<Self::T> {
175 let Encrypted {
176 key,
177 inner:
178 EncryptedInner {
179 resolution,
180 decrypted: _,
181 },
182 } = self.encrypted.fetch_local()?;
183 let decrypted = Unkeyed(Arc::new(self.decrypted.fetch_local()?));
184 Some(Encrypted {
185 key,
186 inner: EncryptedInner {
187 resolution,
188 decrypted,
189 },
190 })
191 }
192}
193
194impl<'a, K: Key, V: PointVisitor> PointVisitor for IterateResolution<'a, '_, K, V> {
195 fn visit<T: Traversible>(&mut self, decrypted: &(impl 'static + SingularFetch<T = T> + Clone)) {
196 let decrypted = decrypted.clone();
197 let encrypted = self.resolution.next().expect("length mismatch").clone();
198 let point = Point::from_fetch(
199 encrypted.hash(),
200 Arc::new(Visited {
201 decrypted,
202 encrypted,
203 }),
204 );
205 self.visitor.visit(&point);
206 }
207}
208
209impl<K, T> ListHashes for EncryptedInner<K, T> {
210 fn list_hashes(&self, f: &mut impl FnMut(Hash)) {
211 self.resolution.list_hashes(f);
212 }
213
214 fn topology_hash(&self) -> Hash {
215 self.resolution.0.data_hash()
216 }
217
218 fn point_count(&self) -> usize {
219 self.resolution.len()
220 }
221}
222
223impl<K: Key, T: Topological> Topological for EncryptedInner<K, T> {
224 fn traverse(&self, visitor: &mut impl PointVisitor) {
225 let resolution = &mut self.resolution.iter();
226 self.decrypted.0.traverse(&mut IterateResolution {
227 resolution,
228 visitor,
229 });
230 assert!(resolution.next().is_none());
231 }
232}
233
/// A value of type `T` that serializes as ciphertext under a key of type `K`.
///
/// The decrypted value stays available in memory (via `Deref` and `into_inner`).
/// `to_output` writes the encryption of the inner `(resolution, value)` encoding.
pub struct Encrypted<K, T> {
    /// Key used by `to_output` to encrypt the serialized inner body.
    key: K,
    /// Encrypted points referenced by the value, plus the decrypted value itself.
    inner: EncryptedInner<K, T>,
}
238
239impl<K, T: Clone> Encrypted<K, T> {
240 pub fn into_inner(self) -> T {
241 Arc::unwrap_or_clone(self.inner.decrypted.0)
242 }
243}
244
245impl<K, T> Deref for Encrypted<K, T> {
246 type Target = T;
247
248 fn deref(&self) -> &Self::Target {
249 self.inner.decrypted.0.as_ref()
250 }
251}
252
253impl<K: Clone, T> Clone for Encrypted<K, T> {
254 fn clone(&self) -> Self {
255 Self {
256 key: self.key.clone(),
257 inner: self.inner.clone(),
258 }
259 }
260}
261
262impl<K, T> ListHashes for Encrypted<K, T> {
263 fn list_hashes(&self, f: &mut impl FnMut(Hash)) {
264 self.inner.list_hashes(f);
265 }
266
267 fn topology_hash(&self) -> Hash {
268 self.inner.topology_hash()
269 }
270
271 fn point_count(&self) -> usize {
272 self.inner.point_count()
273 }
274}
275
276impl<K: Key, T: Topological> Topological for Encrypted<K, T> {
277 fn traverse(&self, visitor: &mut impl PointVisitor) {
278 self.inner.traverse(visitor);
279 }
280}
281
282impl<K: Key, T: ToOutput> ToOutput for Encrypted<K, T> {
283 fn to_output(&self, output: &mut dyn object_rainbow::Output) {
284 let source = self.inner.vec();
285 output.write(&self.key.encrypt(&source));
286 }
287}
288
289#[derive(Clone)]
290struct Decrypt<K> {
291 resolution: Resolution<K>,
292}
293
294impl<K: Key> Decrypt<K> {
295 async fn resolve_bytes(
296 &self,
297 address: Address,
298 ) -> object_rainbow::Result<(Vec<u8>, Resolution<K>)> {
299 let Encrypted {
300 key: _,
301 inner:
302 EncryptedInner {
303 resolution,
304 decrypted,
305 },
306 } = self
307 .resolution
308 .get(address.index)
309 .ok_or(Error::AddressOutOfBounds)?
310 .clone()
311 .fetch()
312 .await?;
313 let data = Arc::unwrap_or_clone(decrypted.0);
314 Ok((data, resolution))
315 }
316}
317
318impl<K: Key> Resolve for Decrypt<K> {
319 fn resolve(&'_ self, address: Address) -> FailFuture<'_, ByteNode> {
320 Box::pin(async move {
321 let (data, resolution) = self.resolve_bytes(address).await?;
322 Ok((data, Arc::new(Decrypt { resolution }) as _))
323 })
324 }
325
326 fn resolve_data(&'_ self, address: Address) -> FailFuture<'_, Vec<u8>> {
327 Box::pin(async move {
328 let (data, _) = self.resolve_bytes(address).await?;
329 Ok(data)
330 })
331 }
332
333 fn try_resolve_local(&self, address: Address) -> object_rainbow::Result<Option<ByteNode>> {
334 let Some((
335 Encrypted {
336 key: _,
337 inner:
338 EncryptedInner {
339 resolution,
340 decrypted,
341 },
342 },
343 _,
344 )) = self
345 .resolution
346 .get(address.index)
347 .ok_or(Error::AddressOutOfBounds)?
348 .clone()
349 .try_fetch_local()?
350 else {
351 return Ok(None);
352 };
353 let data = Arc::unwrap_or_clone(decrypted.0);
354 Ok(Some((data, Arc::new(Decrypt { resolution }) as _)))
355 }
356}
357
358impl<
359 K: Key,
360 T: Object<Extra>,
361 Extra: 'static + Send + Sync + Clone,
362 I: PointInput<Extra = WithKey<K, Extra>>,
363> Parse<I> for Encrypted<K, T>
364{
365 fn parse(input: I) -> object_rainbow::Result<Self> {
366 let with_key = input.extra().clone();
367 let resolve = input.resolve().clone();
368 let source = with_key.key.decrypt(&input.parse_all()?)?;
369 let EncryptedInner {
370 resolution,
371 decrypted,
372 } = EncryptedInner::<K, Vec<u8>>::parse_slice_extra(&source, &resolve, &with_key)?;
373 let decrypted = T::parse_slice_extra(
374 &decrypted.0,
375 &(Arc::new(Decrypt {
376 resolution: resolution.clone(),
377 }) as _),
378 &with_key.extra,
379 )?;
380 let decrypted = Unkeyed(Arc::new(decrypted));
381 let inner = EncryptedInner {
382 resolution,
383 decrypted,
384 };
385 Ok(Self {
386 key: with_key.key,
387 inner,
388 })
389 }
390}
391
/// `Encrypted` relies entirely on the default items of `Tagged`.
impl<K, T> Tagged for Encrypted<K, T> {}
393
/// Pending encryption tasks collected by `ExtractResolution`, one per point of the value
/// being encrypted; each resolves to the encrypted counterpart of that point.
type Extracted<K> = Vec<
    std::pin::Pin<
        Box<dyn Future<Output = Result<Point<Encrypted<K, Vec<u8>>>, Error>> + Send + 'static>,
    >,
>;
399
/// `PointVisitor` used by `encrypt`: for every point of the plaintext value it queues a
/// task (in `extracted`) that encrypts the referenced object with `key`.
struct ExtractResolution<'a, K> {
    /// Output list of pending encryption tasks, in traversal order.
    extracted: &'a mut Extracted<K>,
    /// Key the referenced objects are encrypted with.
    key: &'a K,
}
404
/// Point handle that exposes a typed `Encrypted<K, T>` point as `Encrypted<K, Vec<u8>>`
/// (plaintext kept as raw bytes). This lets points of any `T` share one resolution list.
struct Untyped<K, T> {
    /// Parse extra used when re-reading the ciphertext: the key with no further context.
    key: WithKey<K, ()>,
    /// The typed encrypted point being wrapped.
    encrypted: Point<Encrypted<K, T>>,
}
409
410impl<K, T> FetchBytes for Untyped<K, T> {
411 fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
412 self.encrypted.fetch_bytes()
413 }
414
415 fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
416 self.encrypted.fetch_data()
417 }
418
419 fn fetch_bytes_local(&self) -> object_rainbow::Result<Option<ByteNode>> {
420 self.encrypted.fetch_bytes_local()
421 }
422
423 fn fetch_data_local(&self) -> Option<Vec<u8>> {
424 self.encrypted.fetch_data_local()
425 }
426}
427
428impl<K: Send + Sync, T> Singular for Untyped<K, T> {
429 fn hash(&self) -> Hash {
430 self.encrypted.hash()
431 }
432}
433
434impl<K: Key, T: FullHash> Fetch for Untyped<K, T> {
435 type T = Encrypted<K, Vec<u8>>;
436
437 fn fetch_full(&'_ self) -> FailFuture<'_, Node<Self::T>> {
438 Box::pin(async move {
439 let (data, resolve) = self.fetch_bytes().await?;
440 let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
441 Ok((encrypted, resolve))
442 })
443 }
444
445 fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
446 Box::pin(async move {
447 let (data, resolve) = self.fetch_bytes().await?;
448 let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
449 Ok(encrypted)
450 })
451 }
452
453 fn try_fetch_local(&self) -> object_rainbow::Result<Option<Node<Self::T>>> {
454 let Some((data, resolve)) = self.fetch_bytes_local()? else {
455 return Ok(None);
456 };
457 let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
458 Ok(Some((encrypted, resolve)))
459 }
460
461 fn fetch_local(&self) -> Option<Self::T> {
462 let Encrypted {
463 key,
464 inner:
465 EncryptedInner {
466 resolution,
467 decrypted,
468 },
469 } = self.encrypted.fetch_local()?;
470 let decrypted = Unkeyed(Arc::new(decrypted.vec()));
471 Some(Encrypted {
472 key,
473 inner: EncryptedInner {
474 resolution,
475 decrypted,
476 },
477 })
478 }
479}
480
481impl<K: Key> PointVisitor for ExtractResolution<'_, K> {
482 fn visit<T: Traversible>(&mut self, decrypted: &(impl 'static + SingularFetch<T = T> + Clone)) {
483 let decrypted = decrypted.clone();
484 let key = self.key.clone();
485 self.extracted.push(Box::pin(async move {
486 let encrypted = encrypt_point(key.clone(), decrypted).await?;
487 let encrypted = Point::from_fetch(
488 encrypted.hash(),
489 Arc::new(Untyped {
490 key: WithKey { key, extra: () },
491 encrypted,
492 }),
493 );
494 Ok(encrypted)
495 }));
496 }
497}
498
499pub async fn encrypt_point<K: Key, T: Traversible>(
500 key: K,
501 decrypted: impl 'static + SingularFetch<T = T>,
502) -> object_rainbow::Result<Point<Encrypted<K, T>>> {
503 if let Some((address, decrypt)) = decrypted.extract_resolve::<Decrypt<K>>() {
504 let encrypted = decrypt
505 .resolution
506 .get(address.index)
507 .ok_or(Error::AddressOutOfBounds)?
508 .clone();
509 let point = Point::from_fetch(
510 encrypted.hash(),
511 Arc::new(Visited {
512 decrypted,
513 encrypted,
514 }),
515 );
516 return Ok(point);
517 }
518 let decrypted = decrypted.fetch().await?;
519 let encrypted = encrypt(key.clone(), decrypted).await?;
520 let point = encrypted.point();
521 Ok(point)
522}
523
524pub async fn encrypt<K: Key, T: Traversible>(
525 key: K,
526 decrypted: T,
527) -> object_rainbow::Result<Encrypted<K, T>> {
528 let mut futures = Vec::with_capacity(decrypted.point_count());
529 decrypted.traverse(&mut ExtractResolution {
530 extracted: &mut futures,
531 key: &key,
532 });
533 let resolution = futures_util::future::try_join_all(futures).await?;
534 let resolution = Arc::new(Lp(resolution));
535 let decrypted = Unkeyed(Arc::new(decrypted));
536 let inner = EncryptedInner {
537 resolution,
538 decrypted,
539 };
540 Ok(Encrypted { key, inner })
541}