1use std::{ops::Deref, sync::Arc};
2
3use object_rainbow::{
4 Address, ByteNode, Error, FailFuture, Fetch, FetchBytes, FullHash, Hash, ListHashes, Node,
5 Object, Parse, ParseSliceExtra, PointInput, PointVisitor, Resolve, Singular, SingularFetch,
6 Tagged, ToOutput, Topological, Traversible, length_prefixed::Lp,
7};
8use object_rainbow_point::{ExtractResolve, IntoPoint, Point};
9
/// Key material used to encrypt and decrypt the serialized form of [`Encrypted`] objects.
///
/// [`Encrypted`]'s `ToOutput` impl writes `encrypt(plaintext)`, and its `Parse` impl calls
/// `decrypt` on the raw input before parsing the plaintext, so `decrypt` must invert `encrypt`
/// for an encrypted object to round-trip.
pub trait Key: 'static + Sized + Send + Sync + Clone {
    /// Decryption error; `Parse` for [`Encrypted`] maps it into a consistency error.
    type Error: Into<anyhow::Error>;
    /// Encrypts `data` and returns the ciphertext.
    fn encrypt(&self, data: &[u8]) -> Vec<u8>;
    /// Decrypts `data`, returning the plaintext or `Self::Error` on failure.
    fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>, Self::Error>;
}

/// Resolution table of an encrypted object: one encrypted point per point of the decrypted
/// value, in traversal order, each typed with its payload kept as raw plaintext bytes.
/// It is length-prefixed via [`Lp`], presumably so the parser can tell where it ends and the
/// object's own bytes begin.
type Resolution<K> = Arc<Lp<Vec<Point<Encrypted<K, Vec<u8>>>>>>;
17
/// Parsing wrapper: parses the wrapped value from an input whose extra is `(key, extra)`,
/// stripping the key so the wrapped value's parser only sees `extra`.
/// Serialization and cloning are derived from the wrapped value.
#[derive(ToOutput, Clone)]
struct Unkeyed<T>(T);
20
21impl<
22 T: Parse<I::WithExtra<Extra>>,
23 K: 'static + Clone,
24 Extra: 'static + Clone,
25 I: PointInput<Extra = (K, Extra)>,
26> Parse<I> for Unkeyed<T>
27{
28 fn parse(input: I) -> object_rainbow::Result<Self> {
29 Ok(Self(T::parse(input.map_extra(|(_, extra)| extra))?))
30 }
31}
32
/// Plaintext layout of an encrypted object: the resolution table, then the object itself.
///
/// `ToOutput` and `Parse` are derived from these fields, so their order presumably defines the
/// plaintext format. Do not reorder them.
#[derive(ToOutput, Parse)]
struct EncryptedInner<K, T> {
    /// Encrypted points referenced by `decrypted`; `Decrypt` indexes it by `Address::index`.
    resolution: Resolution<K>,
    /// The decrypted value, parsed with the key stripped from the input's extra.
    decrypted: Unkeyed<Arc<T>>,
}
38
39impl<K, T> Clone for EncryptedInner<K, T> {
40 fn clone(&self) -> Self {
41 Self {
42 resolution: self.resolution.clone(),
43 decrypted: self.decrypted.clone(),
44 }
45 }
46}
47
/// Borrowing iterator over the entries of a [`Resolution`] table.
type ResolutionIter<'a, K> = std::slice::Iter<'a, Point<Encrypted<K, Vec<u8>>>>;

/// Visitor adapter used by `EncryptedInner::traverse`. For each point of the decrypted value it
/// takes the next resolution entry and forwards a combined [`Visited`] point to `visitor`.
struct IterateResolution<'a, 'r, K, V> {
    /// Resolution entries not yet paired with a visited point.
    resolution: &'r mut ResolutionIter<'a, K>,
    /// The caller's visitor, which receives the combined points.
    visitor: &'a mut V,
}

/// Fetch source that combines a decrypted point with its encrypted counterpart.
///
/// Raw bytes and the hash come from `encrypted`. Fetching the typed value takes the key and
/// resolution table from `encrypted` and the value itself from `decrypted`.
struct Visited<K, P> {
    decrypted: P,
    encrypted: Point<Encrypted<K, Vec<u8>>>,
}
59
60impl<K, P> FetchBytes for Visited<K, P> {
61 fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
62 self.encrypted.fetch_bytes()
63 }
64
65 fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
66 self.encrypted.fetch_data()
67 }
68
69 fn fetch_bytes_local(&self) -> object_rainbow::Result<Option<ByteNode>> {
70 self.encrypted.fetch_bytes_local()
71 }
72
73 fn fetch_data_local(&self) -> Option<Vec<u8>> {
74 self.encrypted.fetch_data_local()
75 }
76}
77
78impl<K, P: Send + Sync> Singular for Visited<K, P> {
79 fn hash(&self) -> Hash {
80 self.encrypted.hash()
81 }
82}
83
84impl<K: Key, P: Fetch<T: Traversible>> Fetch for Visited<K, P> {
85 type T = Encrypted<K, P::T>;
86
87 fn fetch_full(&'_ self) -> FailFuture<'_, Node<Self::T>> {
88 Box::pin(async move {
89 let (
90 Encrypted {
91 key,
92 inner:
93 EncryptedInner {
94 resolution,
95 decrypted: _,
96 },
97 },
98 resolve,
99 ) = self.encrypted.fetch_full().await?;
100 let decrypted = self.decrypted.fetch().await?;
101 let decrypted = Unkeyed(Arc::new(decrypted));
102 Ok((
103 Encrypted {
104 key,
105 inner: EncryptedInner {
106 resolution,
107 decrypted,
108 },
109 },
110 resolve,
111 ))
112 })
113 }
114
115 fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
116 Box::pin(async move {
117 let Encrypted {
118 key,
119 inner:
120 EncryptedInner {
121 resolution,
122 decrypted: _,
123 },
124 } = self.encrypted.fetch().await?;
125 let decrypted = self.decrypted.fetch().await?;
126 let decrypted = Unkeyed(Arc::new(decrypted));
127 Ok(Encrypted {
128 key,
129 inner: EncryptedInner {
130 resolution,
131 decrypted,
132 },
133 })
134 })
135 }
136
137 fn try_fetch_local(&self) -> object_rainbow::Result<Option<Node<Self::T>>> {
138 let Some((
139 Encrypted {
140 key,
141 inner:
142 EncryptedInner {
143 resolution,
144 decrypted: _,
145 },
146 },
147 resolve,
148 )) = self.encrypted.try_fetch_local()?
149 else {
150 return Ok(None);
151 };
152 let Some((decrypted, _)) = self.decrypted.try_fetch_local()? else {
153 return Ok(None);
154 };
155 let decrypted = Unkeyed(Arc::new(decrypted));
156 Ok(Some((
157 Encrypted {
158 key,
159 inner: EncryptedInner {
160 resolution,
161 decrypted,
162 },
163 },
164 resolve,
165 )))
166 }
167
168 fn fetch_local(&self) -> Option<Self::T> {
169 let Encrypted {
170 key,
171 inner:
172 EncryptedInner {
173 resolution,
174 decrypted: _,
175 },
176 } = self.encrypted.fetch_local()?;
177 let decrypted = Unkeyed(Arc::new(self.decrypted.fetch_local()?));
178 Some(Encrypted {
179 key,
180 inner: EncryptedInner {
181 resolution,
182 decrypted,
183 },
184 })
185 }
186}
187
188impl<'a, K: Key, V: PointVisitor> PointVisitor for IterateResolution<'a, '_, K, V> {
189 fn visit<T: Traversible>(&mut self, decrypted: &(impl 'static + SingularFetch<T = T> + Clone)) {
190 let decrypted = decrypted.clone();
191 let encrypted = self.resolution.next().expect("length mismatch").clone();
192 let point = Point::from_fetch(
193 encrypted.hash(),
194 Visited {
195 decrypted,
196 encrypted,
197 }
198 .into_dyn_fetch(),
199 );
200 self.visitor.visit(&point);
201 }
202}
203
204impl<K, T> ListHashes for EncryptedInner<K, T> {
205 fn list_hashes(&self, f: &mut impl FnMut(Hash)) {
206 self.resolution.list_hashes(f);
207 }
208
209 fn topology_hash(&self) -> Hash {
210 self.resolution.0.data_hash()
211 }
212
213 fn point_count(&self) -> usize {
214 self.resolution.len()
215 }
216}
217
218impl<K: Key, T: Topological> Topological for EncryptedInner<K, T> {
219 fn traverse(&self, visitor: &mut impl PointVisitor) {
220 let resolution = &mut self.resolution.iter();
221 self.decrypted.0.traverse(&mut IterateResolution {
222 resolution,
223 visitor,
224 });
225 assert!(resolution.next().is_none());
226 }
227}
228
/// A value of type `T` that is serialized encrypted under key `K`.
///
/// It dereferences to the decrypted `T`. Its serialized form is `key.encrypt(..)` applied to
/// the inner plaintext (resolution table plus `T`). Its points are the entries of the encrypted
/// resolution table rather than `T`'s own points.
pub struct Encrypted<K, T> {
    /// Key used to encrypt on output; taken from the input's extra when parsing.
    key: K,
    /// Resolution table plus the decrypted value.
    inner: EncryptedInner<K, T>,
}
233
234impl<K, T: Clone> Encrypted<K, T> {
235 pub fn into_inner(self) -> T {
236 Arc::unwrap_or_clone(self.inner.decrypted.0)
237 }
238}
239
240impl<K, T> Deref for Encrypted<K, T> {
241 type Target = T;
242
243 fn deref(&self) -> &Self::Target {
244 self.inner.decrypted.0.as_ref()
245 }
246}
247
248impl<K: Clone, T> Clone for Encrypted<K, T> {
249 fn clone(&self) -> Self {
250 Self {
251 key: self.key.clone(),
252 inner: self.inner.clone(),
253 }
254 }
255}
256
257impl<K, T> ListHashes for Encrypted<K, T> {
258 fn list_hashes(&self, f: &mut impl FnMut(Hash)) {
259 self.inner.list_hashes(f);
260 }
261
262 fn topology_hash(&self) -> Hash {
263 self.inner.topology_hash()
264 }
265
266 fn point_count(&self) -> usize {
267 self.inner.point_count()
268 }
269}
270
271impl<K: Key, T: Topological> Topological for Encrypted<K, T> {
272 fn traverse(&self, visitor: &mut impl PointVisitor) {
273 self.inner.traverse(visitor);
274 }
275}
276
277impl<K: Key, T: ToOutput> ToOutput for Encrypted<K, T> {
278 fn to_output(&self, output: &mut dyn object_rainbow::Output) {
279 let source = self.inner.vec();
280 output.write(&self.key.encrypt(&source));
281 }
282}
283
/// Resolver installed for the points of a decrypted object.
///
/// An address's `index` selects an entry of `resolution`. That entry is fetched (and so
/// decrypted), and its plaintext payload is returned. The resolver for the entry's own points
/// is a new `Decrypt` over the entry's resolution table.
#[derive(Clone)]
struct Decrypt<K> {
    resolution: Resolution<K>,
}
288
289impl<K: Key> Decrypt<K> {
290 async fn resolve_bytes(
291 &self,
292 address: Address,
293 ) -> object_rainbow::Result<(Vec<u8>, Resolution<K>)> {
294 let Encrypted {
295 key: _,
296 inner:
297 EncryptedInner {
298 resolution,
299 decrypted,
300 },
301 } = self
302 .resolution
303 .get(address.index)
304 .ok_or(Error::AddressOutOfBounds)?
305 .clone()
306 .fetch()
307 .await?;
308 let data = Arc::unwrap_or_clone(decrypted.0);
309 Ok((data, resolution))
310 }
311}
312
313impl<K: Key> Resolve for Decrypt<K> {
314 fn resolve(&'_ self, address: Address, _: &Arc<dyn Resolve>) -> FailFuture<'_, ByteNode> {
315 Box::pin(async move {
316 let (data, resolution) = self.resolve_bytes(address).await?;
317 Ok((data, Arc::new(Decrypt { resolution }) as _))
318 })
319 }
320
321 fn resolve_data(&'_ self, address: Address) -> FailFuture<'_, Vec<u8>> {
322 Box::pin(async move {
323 let (data, _) = self.resolve_bytes(address).await?;
324 Ok(data)
325 })
326 }
327
328 fn try_resolve_local(
329 &self,
330 address: Address,
331 _: &Arc<dyn Resolve>,
332 ) -> object_rainbow::Result<Option<ByteNode>> {
333 let Some((
334 Encrypted {
335 key: _,
336 inner:
337 EncryptedInner {
338 resolution,
339 decrypted,
340 },
341 },
342 _,
343 )) = self
344 .resolution
345 .get(address.index)
346 .ok_or(Error::AddressOutOfBounds)?
347 .clone()
348 .try_fetch_local()?
349 else {
350 return Ok(None);
351 };
352 let data = Arc::unwrap_or_clone(decrypted.0);
353 Ok(Some((data, Arc::new(Decrypt { resolution }) as _)))
354 }
355}
356
/// Parse-time extra accepted by [`Encrypted`]'s `Parse` impl.
///
/// It splits into the key used for decryption and the extra forwarded to the decrypted value's
/// parser. It is implemented for `(key, extra)` pairs and for a bare key (inner extra `()`).
trait EncryptedExtra<K>: 'static + Send + Sync + Clone {
    /// Extra passed on to the decrypted value's parser.
    type Extra: 'static + Send + Sync + Clone;
    /// Returns the key and the inner extra by value.
    fn parts(&self) -> (K, Self::Extra);
}
361
362impl<K: 'static + Send + Sync + Clone, Extra: 'static + Send + Sync + Clone> EncryptedExtra<K>
363 for (K, Extra)
364{
365 type Extra = Extra;
366
367 fn parts(&self) -> (K, Self::Extra) {
368 self.clone()
369 }
370}
371
372impl<K: 'static + Send + Sync + Clone> EncryptedExtra<K> for K {
373 type Extra = ();
374
375 fn parts(&self) -> (K, Self::Extra) {
376 (self.clone(), ())
377 }
378}
379
380impl<
381 K: Key,
382 T: Object<Extra>,
383 Extra: 'static + Send + Sync + Clone,
384 I: PointInput<Extra: EncryptedExtra<K, Extra = Extra>>,
385> Parse<I> for Encrypted<K, T>
386{
387 fn parse(input: I) -> object_rainbow::Result<Self> {
388 let with_key = input.extra().parts();
389 let resolve = input.resolve().clone();
390 let source = with_key
391 .0
392 .decrypt(&input.parse_all()?)
393 .map_err(object_rainbow::Error::consistency)?;
394 let EncryptedInner {
395 resolution,
396 decrypted,
397 } = EncryptedInner::<K, Vec<u8>>::parse_slice_extra(&source, &resolve, &with_key)?;
398 let decrypted = T::parse_slice_extra(
399 &decrypted.0,
400 &(Arc::new(Decrypt {
401 resolution: resolution.clone(),
402 }) as _),
403 &with_key.1,
404 )?;
405 let decrypted = Unkeyed(Arc::new(decrypted));
406 let inner = EncryptedInner {
407 resolution,
408 decrypted,
409 };
410 Ok(Self {
411 key: with_key.0,
412 inner,
413 })
414 }
415}
416
// Empty impl: `Encrypted` relies entirely on `Tagged`'s default items.
impl<K, T> Tagged for Encrypted<K, T> {}
418
/// Encryptions queued while [`encrypt`] traverses an object.
///
/// There is one boxed future per visited point, and each yields that point's
/// resolution-table entry. [`encrypt`] awaits them together.
type Extracted<K> = Vec<
    std::pin::Pin<
        Box<dyn Future<Output = Result<Point<Encrypted<K, Vec<u8>>>, Error>> + Send + 'static>,
    >,
>;

/// Visitor that queues, for every visited point, a future encrypting that point under `key`.
struct ExtractResolution<'a, K> {
    /// Destination for the queued futures.
    extracted: &'a mut Extracted<K>,
    /// Key the visited points are encrypted with.
    key: &'a K,
}

/// Adapter that exposes a typed `Point<Encrypted<K, T>>` as an `Encrypted<K, Vec<u8>>`, with
/// the decrypted payload kept as plaintext bytes, so it can be stored in a [`Resolution`] table.
/// `key` is used to re-parse the fetched bytes.
struct Untyped<K, T> {
    key: K,
    encrypted: Point<Encrypted<K, T>>,
}
434
435impl<K, T> FetchBytes for Untyped<K, T> {
436 fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
437 self.encrypted.fetch_bytes()
438 }
439
440 fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
441 self.encrypted.fetch_data()
442 }
443
444 fn fetch_bytes_local(&self) -> object_rainbow::Result<Option<ByteNode>> {
445 self.encrypted.fetch_bytes_local()
446 }
447
448 fn fetch_data_local(&self) -> Option<Vec<u8>> {
449 self.encrypted.fetch_data_local()
450 }
451}
452
453impl<K: Send + Sync, T> Singular for Untyped<K, T> {
454 fn hash(&self) -> Hash {
455 self.encrypted.hash()
456 }
457}
458
459impl<K: Key, T: FullHash> Fetch for Untyped<K, T> {
460 type T = Encrypted<K, Vec<u8>>;
461
462 fn fetch_full(&'_ self) -> FailFuture<'_, Node<Self::T>> {
463 Box::pin(async move {
464 let (data, resolve) = self.fetch_bytes().await?;
465 let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
466 Ok((encrypted, resolve))
467 })
468 }
469
470 fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
471 Box::pin(async move {
472 let (data, resolve) = self.fetch_bytes().await?;
473 let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
474 Ok(encrypted)
475 })
476 }
477
478 fn try_fetch_local(&self) -> object_rainbow::Result<Option<Node<Self::T>>> {
479 let Some((data, resolve)) = self.fetch_bytes_local()? else {
480 return Ok(None);
481 };
482 let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
483 Ok(Some((encrypted, resolve)))
484 }
485
486 fn fetch_local(&self) -> Option<Self::T> {
487 let Encrypted {
488 key,
489 inner:
490 EncryptedInner {
491 resolution,
492 decrypted,
493 },
494 } = self.encrypted.fetch_local()?;
495 let decrypted = Unkeyed(Arc::new(decrypted.vec()));
496 Some(Encrypted {
497 key,
498 inner: EncryptedInner {
499 resolution,
500 decrypted,
501 },
502 })
503 }
504}
505
506impl<K: Key> PointVisitor for ExtractResolution<'_, K> {
507 fn visit<T: Traversible>(&mut self, decrypted: &(impl 'static + SingularFetch<T = T> + Clone)) {
508 let decrypted = decrypted.clone();
509 let key = self.key.clone();
510 self.extracted.push(Box::pin(async move {
511 let encrypted = encrypt_point(key.clone(), decrypted).await?;
512 let encrypted = Point::from_fetch(
513 encrypted.hash(),
514 Untyped { key, encrypted }.into_dyn_fetch(),
515 );
516 Ok(encrypted)
517 }));
518 }
519}
520
521pub async fn encrypt_point<K: Key, T: Traversible>(
522 key: K,
523 decrypted: impl 'static + SingularFetch<T = T>,
524) -> object_rainbow::Result<Point<Encrypted<K, T>>> {
525 if let Some((address, decrypt)) = decrypted.extract_resolve::<Decrypt<K>>() {
526 let encrypted = decrypt
527 .resolution
528 .get(address.index)
529 .ok_or(Error::AddressOutOfBounds)?
530 .clone();
531 let point = Point::from_fetch(
532 encrypted.hash(),
533 Visited {
534 decrypted,
535 encrypted,
536 }
537 .into_dyn_fetch(),
538 );
539 return Ok(point);
540 }
541 let decrypted = decrypted.fetch().await?;
542 let encrypted = encrypt(key.clone(), decrypted).await?;
543 let point = encrypted.point();
544 Ok(point)
545}
546
547pub async fn encrypt<K: Key, T: Traversible>(
548 key: K,
549 decrypted: T,
550) -> object_rainbow::Result<Encrypted<K, T>> {
551 let mut futures = Vec::with_capacity(decrypted.point_count());
552 decrypted.traverse(&mut ExtractResolution {
553 extracted: &mut futures,
554 key: &key,
555 });
556 let resolution = futures_util::future::try_join_all(futures).await?;
557 let resolution = Arc::new(Lp(resolution));
558 let decrypted = Unkeyed(Arc::new(decrypted));
559 let inner = EncryptedInner {
560 resolution,
561 decrypted,
562 };
563 Ok(Encrypted { key, inner })
564}