1use std::{ops::Deref, sync::Arc};
2
3use object_rainbow::{
4 Address, ByteNode, Error, FailFuture, Fetch, FetchBytes, FullHash, Hash, ListHashes, Node,
5 Object, Parse, ParseSliceExtra, PointInput, PointVisitor, Resolve, Singular, SingularFetch,
6 Tagged, ToOutput, Topological, Traversible, length_prefixed::Lp,
7};
8use object_rainbow_point::{ExtractResolve, IntoPoint, Point};
9
/// Cipher applied to the serialized bytes of an [`Encrypted`] value.
///
/// In this file `encrypt` is called on the serialized inner representation
/// when an [`Encrypted`] is written out, and `decrypt` on the raw input when
/// one is parsed. Presumably `decrypt` must invert `encrypt` for a round trip
/// to succeed — confirm with implementors.
pub trait Key: 'static + Sized + Send + Sync + Clone {
    /// Transforms plaintext bytes into their encrypted form.
    fn encrypt(&self, data: &[u8]) -> Vec<u8>;
    /// Transforms encrypted bytes back into plaintext, or reports an error.
    fn decrypt(&self, data: &[u8]) -> object_rainbow::Result<Vec<u8>>;
}
14
/// Shared, length-prefixed list of encrypted points; entries are looked up by
/// `Address::index` when resolving.
type Resolution<K> = Arc<Lp<Vec<Point<Encrypted<K, Vec<u8>>>>>>;
16
/// Wrapper around a value that is parsed without the key half of a
/// `(K, Extra)` parse extra (see its `Parse` impl).
#[derive(ToOutput, Clone)]
struct Unkeyed<T>(T);
19
20impl<
21 T: Parse<I::WithExtra<Extra>>,
22 K: 'static + Clone,
23 Extra: 'static + Clone,
24 I: PointInput<Extra = (K, Extra)>,
25> Parse<I> for Unkeyed<T>
26{
27 fn parse(input: I) -> object_rainbow::Result<Self> {
28 Ok(Self(T::parse(input.map_extra(|(_, extra)| extra))?))
29 }
30}
31
/// Plaintext layout of an [`Encrypted`] value: the list of encrypted points
/// followed by the decrypted payload.
#[derive(ToOutput, Parse)]
struct EncryptedInner<K, T> {
    /// Encrypted counterparts of the points of `decrypted`; traversal consumes
    /// them one per visited point, in visiting order.
    resolution: Resolution<K>,
    /// The decrypted payload, shared behind an `Arc`.
    decrypted: Unkeyed<Arc<T>>,
}
37
38impl<K, T> Clone for EncryptedInner<K, T> {
39 fn clone(&self) -> Self {
40 Self {
41 resolution: self.resolution.clone(),
42 decrypted: self.decrypted.clone(),
43 }
44 }
45}
46
/// Borrowing iterator over the encrypted points of a resolution list.
type ResolutionIter<'a, K> = std::slice::Iter<'a, Point<Encrypted<K, Vec<u8>>>>;
48
/// Visitor adapter that pairs each visited decrypted point with the next entry
/// of a resolution iterator before forwarding it to `visitor`.
struct IterateResolution<'a, 'r, K, V> {
    /// Remaining encrypted points; advanced once per visited point.
    resolution: &'r mut ResolutionIter<'a, K>,
    /// The visitor that receives the paired points.
    visitor: &'a mut V,
}
53
/// Fetch handle pairing a decrypted point with its encrypted counterpart.
///
/// Bytes and hash are taken from `encrypted`; the typed payload is taken from
/// `decrypted`.
struct Visited<K, P> {
    /// Source of the typed, decrypted value.
    decrypted: P,
    /// Source of the stored bytes, the hash and the resolution list.
    encrypted: Point<Encrypted<K, Vec<u8>>>,
}
58
59impl<K, P> FetchBytes for Visited<K, P> {
60 fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
61 self.encrypted.fetch_bytes()
62 }
63
64 fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
65 self.encrypted.fetch_data()
66 }
67
68 fn fetch_bytes_local(&self) -> object_rainbow::Result<Option<ByteNode>> {
69 self.encrypted.fetch_bytes_local()
70 }
71
72 fn fetch_data_local(&self) -> Option<Vec<u8>> {
73 self.encrypted.fetch_data_local()
74 }
75}
76
77impl<K, P: Send + Sync> Singular for Visited<K, P> {
78 fn hash(&self) -> Hash {
79 self.encrypted.hash()
80 }
81}
82
83impl<K: Key, P: Fetch<T: Traversible>> Fetch for Visited<K, P> {
84 type T = Encrypted<K, P::T>;
85
86 fn fetch_full(&'_ self) -> FailFuture<'_, Node<Self::T>> {
87 Box::pin(async move {
88 let (
89 Encrypted {
90 key,
91 inner:
92 EncryptedInner {
93 resolution,
94 decrypted: _,
95 },
96 },
97 resolve,
98 ) = self.encrypted.fetch_full().await?;
99 let decrypted = self.decrypted.fetch().await?;
100 let decrypted = Unkeyed(Arc::new(decrypted));
101 Ok((
102 Encrypted {
103 key,
104 inner: EncryptedInner {
105 resolution,
106 decrypted,
107 },
108 },
109 resolve,
110 ))
111 })
112 }
113
114 fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
115 Box::pin(async move {
116 let Encrypted {
117 key,
118 inner:
119 EncryptedInner {
120 resolution,
121 decrypted: _,
122 },
123 } = self.encrypted.fetch().await?;
124 let decrypted = self.decrypted.fetch().await?;
125 let decrypted = Unkeyed(Arc::new(decrypted));
126 Ok(Encrypted {
127 key,
128 inner: EncryptedInner {
129 resolution,
130 decrypted,
131 },
132 })
133 })
134 }
135
136 fn try_fetch_local(&self) -> object_rainbow::Result<Option<Node<Self::T>>> {
137 let Some((
138 Encrypted {
139 key,
140 inner:
141 EncryptedInner {
142 resolution,
143 decrypted: _,
144 },
145 },
146 resolve,
147 )) = self.encrypted.try_fetch_local()?
148 else {
149 return Ok(None);
150 };
151 let Some((decrypted, _)) = self.decrypted.try_fetch_local()? else {
152 return Ok(None);
153 };
154 let decrypted = Unkeyed(Arc::new(decrypted));
155 Ok(Some((
156 Encrypted {
157 key,
158 inner: EncryptedInner {
159 resolution,
160 decrypted,
161 },
162 },
163 resolve,
164 )))
165 }
166
167 fn fetch_local(&self) -> Option<Self::T> {
168 let Encrypted {
169 key,
170 inner:
171 EncryptedInner {
172 resolution,
173 decrypted: _,
174 },
175 } = self.encrypted.fetch_local()?;
176 let decrypted = Unkeyed(Arc::new(self.decrypted.fetch_local()?));
177 Some(Encrypted {
178 key,
179 inner: EncryptedInner {
180 resolution,
181 decrypted,
182 },
183 })
184 }
185}
186
187impl<'a, K: Key, V: PointVisitor> PointVisitor for IterateResolution<'a, '_, K, V> {
188 fn visit<T: Traversible>(&mut self, decrypted: &(impl 'static + SingularFetch<T = T> + Clone)) {
189 let decrypted = decrypted.clone();
190 let encrypted = self.resolution.next().expect("length mismatch").clone();
191 let point = Point::from_fetch(
192 encrypted.hash(),
193 Arc::new(Visited {
194 decrypted,
195 encrypted,
196 }),
197 );
198 self.visitor.visit(&point);
199 }
200}
201
202impl<K, T> ListHashes for EncryptedInner<K, T> {
203 fn list_hashes(&self, f: &mut impl FnMut(Hash)) {
204 self.resolution.list_hashes(f);
205 }
206
207 fn topology_hash(&self) -> Hash {
208 self.resolution.0.data_hash()
209 }
210
211 fn point_count(&self) -> usize {
212 self.resolution.len()
213 }
214}
215
216impl<K: Key, T: Topological> Topological for EncryptedInner<K, T> {
217 fn traverse(&self, visitor: &mut impl PointVisitor) {
218 let resolution = &mut self.resolution.iter();
219 self.decrypted.0.traverse(&mut IterateResolution {
220 resolution,
221 visitor,
222 });
223 assert!(resolution.next().is_none());
224 }
225}
226
/// A value of type `T` together with the key used to encrypt its serialized
/// form and the encrypted points it refers to.
///
/// Dereferences to the decrypted `T`.
pub struct Encrypted<K, T> {
    /// Key applied when the value is written out.
    key: K,
    /// Resolution list plus the decrypted payload.
    inner: EncryptedInner<K, T>,
}
231
232impl<K, T: Clone> Encrypted<K, T> {
233 pub fn into_inner(self) -> T {
234 Arc::unwrap_or_clone(self.inner.decrypted.0)
235 }
236}
237
238impl<K, T> Deref for Encrypted<K, T> {
239 type Target = T;
240
241 fn deref(&self) -> &Self::Target {
242 self.inner.decrypted.0.as_ref()
243 }
244}
245
246impl<K: Clone, T> Clone for Encrypted<K, T> {
247 fn clone(&self) -> Self {
248 Self {
249 key: self.key.clone(),
250 inner: self.inner.clone(),
251 }
252 }
253}
254
255impl<K, T> ListHashes for Encrypted<K, T> {
256 fn list_hashes(&self, f: &mut impl FnMut(Hash)) {
257 self.inner.list_hashes(f);
258 }
259
260 fn topology_hash(&self) -> Hash {
261 self.inner.topology_hash()
262 }
263
264 fn point_count(&self) -> usize {
265 self.inner.point_count()
266 }
267}
268
269impl<K: Key, T: Topological> Topological for Encrypted<K, T> {
270 fn traverse(&self, visitor: &mut impl PointVisitor) {
271 self.inner.traverse(visitor);
272 }
273}
274
275impl<K: Key, T: ToOutput> ToOutput for Encrypted<K, T> {
276 fn to_output(&self, output: &mut dyn object_rainbow::Output) {
277 let source = self.inner.vec();
278 output.write(&self.key.encrypt(&source));
279 }
280}
281
/// Resolver that answers addresses from a resolution list, returning the
/// decrypted payload of the addressed encrypted point.
#[derive(Clone)]
struct Decrypt<K> {
    /// Encrypted points, indexed by `Address::index`.
    resolution: Resolution<K>,
}
286
287impl<K: Key> Decrypt<K> {
288 async fn resolve_bytes(
289 &self,
290 address: Address,
291 ) -> object_rainbow::Result<(Vec<u8>, Resolution<K>)> {
292 let Encrypted {
293 key: _,
294 inner:
295 EncryptedInner {
296 resolution,
297 decrypted,
298 },
299 } = self
300 .resolution
301 .get(address.index)
302 .ok_or(Error::AddressOutOfBounds)?
303 .clone()
304 .fetch()
305 .await?;
306 let data = Arc::unwrap_or_clone(decrypted.0);
307 Ok((data, resolution))
308 }
309}
310
311impl<K: Key> Resolve for Decrypt<K> {
312 fn resolve(&'_ self, address: Address) -> FailFuture<'_, ByteNode> {
313 Box::pin(async move {
314 let (data, resolution) = self.resolve_bytes(address).await?;
315 Ok((data, Arc::new(Decrypt { resolution }) as _))
316 })
317 }
318
319 fn resolve_data(&'_ self, address: Address) -> FailFuture<'_, Vec<u8>> {
320 Box::pin(async move {
321 let (data, _) = self.resolve_bytes(address).await?;
322 Ok(data)
323 })
324 }
325
326 fn try_resolve_local(&self, address: Address) -> object_rainbow::Result<Option<ByteNode>> {
327 let Some((
328 Encrypted {
329 key: _,
330 inner:
331 EncryptedInner {
332 resolution,
333 decrypted,
334 },
335 },
336 _,
337 )) = self
338 .resolution
339 .get(address.index)
340 .ok_or(Error::AddressOutOfBounds)?
341 .clone()
342 .try_fetch_local()?
343 else {
344 return Ok(None);
345 };
346 let data = Arc::unwrap_or_clone(decrypted.0);
347 Ok(Some((data, Arc::new(Decrypt { resolution }) as _)))
348 }
349}
350
/// Parse extras from which a key `K` and a remaining extra can be obtained.
trait EncryptedExtra<K>: 'static + Send + Sync + Clone {
    /// The extra that remains once the key has been split off.
    type Extra: 'static + Send + Sync + Clone;
    /// Returns owned copies of the key and the remaining extra.
    fn parts(&self) -> (K, Self::Extra);
}
355
356impl<K: 'static + Send + Sync + Clone, Extra: 'static + Send + Sync + Clone> EncryptedExtra<K>
357 for (K, Extra)
358{
359 type Extra = Extra;
360
361 fn parts(&self) -> (K, Self::Extra) {
362 self.clone()
363 }
364}
365
366impl<K: 'static + Send + Sync + Clone> EncryptedExtra<K> for K {
367 type Extra = ();
368
369 fn parts(&self) -> (K, Self::Extra) {
370 (self.clone(), ())
371 }
372}
373
374impl<
375 K: Key,
376 T: Object<Extra>,
377 Extra: 'static + Send + Sync + Clone,
378 I: PointInput<Extra: EncryptedExtra<K, Extra = Extra>>,
379> Parse<I> for Encrypted<K, T>
380{
381 fn parse(input: I) -> object_rainbow::Result<Self> {
382 let with_key = input.extra().parts();
383 let resolve = input.resolve().clone();
384 let source = with_key.0.decrypt(&input.parse_all()?)?;
385 let EncryptedInner {
386 resolution,
387 decrypted,
388 } = EncryptedInner::<K, Vec<u8>>::parse_slice_extra(&source, &resolve, &with_key)?;
389 let decrypted = T::parse_slice_extra(
390 &decrypted.0,
391 &(Arc::new(Decrypt {
392 resolution: resolution.clone(),
393 }) as _),
394 &with_key.1,
395 )?;
396 let decrypted = Unkeyed(Arc::new(decrypted));
397 let inner = EncryptedInner {
398 resolution,
399 decrypted,
400 };
401 Ok(Self {
402 key: with_key.0,
403 inner,
404 })
405 }
406}
407
/// No items are overridden; whatever defaults `Tagged` provides apply.
impl<K, T> Tagged for Encrypted<K, T> {}
409
/// List of pinned, boxed, `Send` futures, each producing one encrypted point
/// (or an error).
type Extracted<K> = Vec<
    std::pin::Pin<
        Box<dyn Future<Output = Result<Point<Encrypted<K, Vec<u8>>>, Error>> + Send + 'static>,
    >,
>;
415
/// Visitor that queues, for every visited point, a future that encrypts it.
struct ExtractResolution<'a, K> {
    /// Destination for the queued futures, in visiting order.
    extracted: &'a mut Extracted<K>,
    /// Key cloned into each queued future.
    key: &'a K,
}
420
/// Fetch handle that presents a typed encrypted point as
/// `Encrypted<K, Vec<u8>>`.
struct Untyped<K, T> {
    /// Key passed as the parse extra when fetched bytes are decoded.
    key: K,
    /// The typed encrypted point being wrapped.
    encrypted: Point<Encrypted<K, T>>,
}
425
426impl<K, T> FetchBytes for Untyped<K, T> {
427 fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
428 self.encrypted.fetch_bytes()
429 }
430
431 fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
432 self.encrypted.fetch_data()
433 }
434
435 fn fetch_bytes_local(&self) -> object_rainbow::Result<Option<ByteNode>> {
436 self.encrypted.fetch_bytes_local()
437 }
438
439 fn fetch_data_local(&self) -> Option<Vec<u8>> {
440 self.encrypted.fetch_data_local()
441 }
442}
443
444impl<K: Send + Sync, T> Singular for Untyped<K, T> {
445 fn hash(&self) -> Hash {
446 self.encrypted.hash()
447 }
448}
449
450impl<K: Key, T: FullHash> Fetch for Untyped<K, T> {
451 type T = Encrypted<K, Vec<u8>>;
452
453 fn fetch_full(&'_ self) -> FailFuture<'_, Node<Self::T>> {
454 Box::pin(async move {
455 let (data, resolve) = self.fetch_bytes().await?;
456 let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
457 Ok((encrypted, resolve))
458 })
459 }
460
461 fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
462 Box::pin(async move {
463 let (data, resolve) = self.fetch_bytes().await?;
464 let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
465 Ok(encrypted)
466 })
467 }
468
469 fn try_fetch_local(&self) -> object_rainbow::Result<Option<Node<Self::T>>> {
470 let Some((data, resolve)) = self.fetch_bytes_local()? else {
471 return Ok(None);
472 };
473 let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
474 Ok(Some((encrypted, resolve)))
475 }
476
477 fn fetch_local(&self) -> Option<Self::T> {
478 let Encrypted {
479 key,
480 inner:
481 EncryptedInner {
482 resolution,
483 decrypted,
484 },
485 } = self.encrypted.fetch_local()?;
486 let decrypted = Unkeyed(Arc::new(decrypted.vec()));
487 Some(Encrypted {
488 key,
489 inner: EncryptedInner {
490 resolution,
491 decrypted,
492 },
493 })
494 }
495}
496
497impl<K: Key> PointVisitor for ExtractResolution<'_, K> {
498 fn visit<T: Traversible>(&mut self, decrypted: &(impl 'static + SingularFetch<T = T> + Clone)) {
499 let decrypted = decrypted.clone();
500 let key = self.key.clone();
501 self.extracted.push(Box::pin(async move {
502 let encrypted = encrypt_point(key.clone(), decrypted).await?;
503 let encrypted =
504 Point::from_fetch(encrypted.hash(), Arc::new(Untyped { key, encrypted }));
505 Ok(encrypted)
506 }));
507 }
508}
509
510pub async fn encrypt_point<K: Key, T: Traversible>(
511 key: K,
512 decrypted: impl 'static + SingularFetch<T = T>,
513) -> object_rainbow::Result<Point<Encrypted<K, T>>> {
514 if let Some((address, decrypt)) = decrypted.extract_resolve::<Decrypt<K>>() {
515 let encrypted = decrypt
516 .resolution
517 .get(address.index)
518 .ok_or(Error::AddressOutOfBounds)?
519 .clone();
520 let point = Point::from_fetch(
521 encrypted.hash(),
522 Arc::new(Visited {
523 decrypted,
524 encrypted,
525 }),
526 );
527 return Ok(point);
528 }
529 let decrypted = decrypted.fetch().await?;
530 let encrypted = encrypt(key.clone(), decrypted).await?;
531 let point = encrypted.point();
532 Ok(point)
533}
534
535pub async fn encrypt<K: Key, T: Traversible>(
536 key: K,
537 decrypted: T,
538) -> object_rainbow::Result<Encrypted<K, T>> {
539 let mut futures = Vec::with_capacity(decrypted.point_count());
540 decrypted.traverse(&mut ExtractResolution {
541 extracted: &mut futures,
542 key: &key,
543 });
544 let resolution = futures_util::future::try_join_all(futures).await?;
545 let resolution = Arc::new(Lp(resolution));
546 let decrypted = Unkeyed(Arc::new(decrypted));
547 let inner = EncryptedInner {
548 resolution,
549 decrypted,
550 };
551 Ok(Encrypted { key, inner })
552}