1use std::{ops::Deref, sync::Arc};
2
3use object_rainbow::{
4 Address, ByteNode, Error, FailFuture, Fetch, FetchBytes, FullHash, Hash, ListHashes, Node,
5 Object, Parse, ParseSliceExtra, PointInput, PointVisitor, Resolve, Singular, SingularFetch,
6 Tagged, ToOutput, Topological, Traversible, length_prefixed::Lp,
7};
8use object_rainbow_point::{ExtractResolve, IntoPoint, Point};
9
10pub trait Key: 'static + Sized + Send + Sync + Clone {
11 type Error: Into<anyhow::Error>;
12 fn encrypt(&self, data: &[u8]) -> Vec<u8>;
13 fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>, Self::Error>;
14}
15
16type Resolution<K> = Arc<Lp<Vec<Point<Encrypted<K, Vec<u8>>>>>>;
17
18#[derive(ToOutput, Clone)]
19struct Unkeyed<T>(T);
20
21impl<
22 T: Parse<I::WithExtra<Extra>>,
23 K: 'static + Clone,
24 Extra: 'static + Clone,
25 I: PointInput<Extra = (K, Extra)>,
26> Parse<I> for Unkeyed<T>
27{
28 fn parse(input: I) -> object_rainbow::Result<Self> {
29 Ok(Self(T::parse(input.map_extra(|(_, extra)| extra))?))
30 }
31}
32
33#[derive(ToOutput, Parse)]
34struct EncryptedInner<K, T> {
35 resolution: Resolution<K>,
36 decrypted: Unkeyed<Arc<T>>,
37}
38
39impl<K, T> Clone for EncryptedInner<K, T> {
40 fn clone(&self) -> Self {
41 Self {
42 resolution: self.resolution.clone(),
43 decrypted: self.decrypted.clone(),
44 }
45 }
46}
47
48type ResolutionIter<'a, K> = std::slice::Iter<'a, Point<Encrypted<K, Vec<u8>>>>;
49
50struct IterateResolution<'a, 'r, K, V> {
51 resolution: &'r mut ResolutionIter<'a, K>,
52 visitor: &'a mut V,
53}
54
55struct Visited<K, P> {
56 decrypted: P,
57 encrypted: Point<Encrypted<K, Vec<u8>>>,
58}
59
60impl<K, P> FetchBytes for Visited<K, P> {
61 fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
62 self.encrypted.fetch_bytes()
63 }
64
65 fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
66 self.encrypted.fetch_data()
67 }
68
69 fn fetch_bytes_local(&self) -> object_rainbow::Result<Option<ByteNode>> {
70 self.encrypted.fetch_bytes_local()
71 }
72
73 fn fetch_data_local(&self) -> Option<Vec<u8>> {
74 self.encrypted.fetch_data_local()
75 }
76}
77
78impl<K, P: Send + Sync> Singular for Visited<K, P> {
79 fn hash(&self) -> Hash {
80 self.encrypted.hash()
81 }
82}
83
84impl<K: Key, P: Fetch<T: Traversible>> Fetch for Visited<K, P> {
85 type T = Encrypted<K, P::T>;
86
87 fn fetch_full(&'_ self) -> FailFuture<'_, Node<Self::T>> {
88 Box::pin(async move {
89 let (
90 Encrypted {
91 key,
92 inner:
93 EncryptedInner {
94 resolution,
95 decrypted: _,
96 },
97 },
98 resolve,
99 ) = self.encrypted.fetch_full().await?;
100 let decrypted = self.decrypted.fetch().await?;
101 let decrypted = Unkeyed(Arc::new(decrypted));
102 Ok((
103 Encrypted {
104 key,
105 inner: EncryptedInner {
106 resolution,
107 decrypted,
108 },
109 },
110 resolve,
111 ))
112 })
113 }
114
115 fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
116 Box::pin(async move {
117 let Encrypted {
118 key,
119 inner:
120 EncryptedInner {
121 resolution,
122 decrypted: _,
123 },
124 } = self.encrypted.fetch().await?;
125 let decrypted = self.decrypted.fetch().await?;
126 let decrypted = Unkeyed(Arc::new(decrypted));
127 Ok(Encrypted {
128 key,
129 inner: EncryptedInner {
130 resolution,
131 decrypted,
132 },
133 })
134 })
135 }
136
137 fn try_fetch_local(&self) -> object_rainbow::Result<Option<Node<Self::T>>> {
138 let Some((
139 Encrypted {
140 key,
141 inner:
142 EncryptedInner {
143 resolution,
144 decrypted: _,
145 },
146 },
147 resolve,
148 )) = self.encrypted.try_fetch_local()?
149 else {
150 return Ok(None);
151 };
152 let Some((decrypted, _)) = self.decrypted.try_fetch_local()? else {
153 return Ok(None);
154 };
155 let decrypted = Unkeyed(Arc::new(decrypted));
156 Ok(Some((
157 Encrypted {
158 key,
159 inner: EncryptedInner {
160 resolution,
161 decrypted,
162 },
163 },
164 resolve,
165 )))
166 }
167
168 fn fetch_local(&self) -> Option<Self::T> {
169 let Encrypted {
170 key,
171 inner:
172 EncryptedInner {
173 resolution,
174 decrypted: _,
175 },
176 } = self.encrypted.fetch_local()?;
177 let decrypted = Unkeyed(Arc::new(self.decrypted.fetch_local()?));
178 Some(Encrypted {
179 key,
180 inner: EncryptedInner {
181 resolution,
182 decrypted,
183 },
184 })
185 }
186}
187
188impl<'a, K: Key, V: PointVisitor> PointVisitor for IterateResolution<'a, '_, K, V> {
189 fn visit<T: Traversible>(&mut self, decrypted: &(impl 'static + SingularFetch<T = T> + Clone)) {
190 let decrypted = decrypted.clone();
191 let encrypted = self.resolution.next().expect("length mismatch").clone();
192 let point = Point::from_fetch(
193 encrypted.hash(),
194 Arc::new(Visited {
195 decrypted,
196 encrypted,
197 }),
198 );
199 self.visitor.visit(&point);
200 }
201}
202
203impl<K, T> ListHashes for EncryptedInner<K, T> {
204 fn list_hashes(&self, f: &mut impl FnMut(Hash)) {
205 self.resolution.list_hashes(f);
206 }
207
208 fn topology_hash(&self) -> Hash {
209 self.resolution.0.data_hash()
210 }
211
212 fn point_count(&self) -> usize {
213 self.resolution.len()
214 }
215}
216
217impl<K: Key, T: Topological> Topological for EncryptedInner<K, T> {
218 fn traverse(&self, visitor: &mut impl PointVisitor) {
219 let resolution = &mut self.resolution.iter();
220 self.decrypted.0.traverse(&mut IterateResolution {
221 resolution,
222 visitor,
223 });
224 assert!(resolution.next().is_none());
225 }
226}
227
228pub struct Encrypted<K, T> {
229 key: K,
230 inner: EncryptedInner<K, T>,
231}
232
233impl<K, T: Clone> Encrypted<K, T> {
234 pub fn into_inner(self) -> T {
235 Arc::unwrap_or_clone(self.inner.decrypted.0)
236 }
237}
238
239impl<K, T> Deref for Encrypted<K, T> {
240 type Target = T;
241
242 fn deref(&self) -> &Self::Target {
243 self.inner.decrypted.0.as_ref()
244 }
245}
246
247impl<K: Clone, T> Clone for Encrypted<K, T> {
248 fn clone(&self) -> Self {
249 Self {
250 key: self.key.clone(),
251 inner: self.inner.clone(),
252 }
253 }
254}
255
256impl<K, T> ListHashes for Encrypted<K, T> {
257 fn list_hashes(&self, f: &mut impl FnMut(Hash)) {
258 self.inner.list_hashes(f);
259 }
260
261 fn topology_hash(&self) -> Hash {
262 self.inner.topology_hash()
263 }
264
265 fn point_count(&self) -> usize {
266 self.inner.point_count()
267 }
268}
269
270impl<K: Key, T: Topological> Topological for Encrypted<K, T> {
271 fn traverse(&self, visitor: &mut impl PointVisitor) {
272 self.inner.traverse(visitor);
273 }
274}
275
276impl<K: Key, T: ToOutput> ToOutput for Encrypted<K, T> {
277 fn to_output(&self, output: &mut dyn object_rainbow::Output) {
278 let source = self.inner.vec();
279 output.write(&self.key.encrypt(&source));
280 }
281}
282
283#[derive(Clone)]
284struct Decrypt<K> {
285 resolution: Resolution<K>,
286}
287
288impl<K: Key> Decrypt<K> {
289 async fn resolve_bytes(
290 &self,
291 address: Address,
292 ) -> object_rainbow::Result<(Vec<u8>, Resolution<K>)> {
293 let Encrypted {
294 key: _,
295 inner:
296 EncryptedInner {
297 resolution,
298 decrypted,
299 },
300 } = self
301 .resolution
302 .get(address.index)
303 .ok_or(Error::AddressOutOfBounds)?
304 .clone()
305 .fetch()
306 .await?;
307 let data = Arc::unwrap_or_clone(decrypted.0);
308 Ok((data, resolution))
309 }
310}
311
312impl<K: Key> Resolve for Decrypt<K> {
313 fn resolve(&'_ self, address: Address) -> FailFuture<'_, ByteNode> {
314 Box::pin(async move {
315 let (data, resolution) = self.resolve_bytes(address).await?;
316 Ok((data, Arc::new(Decrypt { resolution }) as _))
317 })
318 }
319
320 fn resolve_data(&'_ self, address: Address) -> FailFuture<'_, Vec<u8>> {
321 Box::pin(async move {
322 let (data, _) = self.resolve_bytes(address).await?;
323 Ok(data)
324 })
325 }
326
327 fn try_resolve_local(&self, address: Address) -> object_rainbow::Result<Option<ByteNode>> {
328 let Some((
329 Encrypted {
330 key: _,
331 inner:
332 EncryptedInner {
333 resolution,
334 decrypted,
335 },
336 },
337 _,
338 )) = self
339 .resolution
340 .get(address.index)
341 .ok_or(Error::AddressOutOfBounds)?
342 .clone()
343 .try_fetch_local()?
344 else {
345 return Ok(None);
346 };
347 let data = Arc::unwrap_or_clone(decrypted.0);
348 Ok(Some((data, Arc::new(Decrypt { resolution }) as _)))
349 }
350}
351
/// Parse-time extra data from which a key of type `K` can be obtained.
trait EncryptedExtra<K>: Clone + Send + Sync + 'static {
    /// What is left of the extra data once the key has been taken out.
    type Extra: Clone + Send + Sync + 'static;

    /// Returns an owned `(key, remaining extra)` pair.
    fn parts(&self) -> (K, Self::Extra);
}

/// A `(key, extra)` tuple already has the required shape, so it is cloned
/// component by component.
impl<K, Extra> EncryptedExtra<K> for (K, Extra)
where
    K: Clone + Send + Sync + 'static,
    Extra: Clone + Send + Sync + 'static,
{
    type Extra = Extra;

    fn parts(&self) -> (K, Self::Extra) {
        let (key, extra) = self;
        (key.clone(), extra.clone())
    }
}

/// A bare key acts as extra data with an empty remainder.
impl<K> EncryptedExtra<K> for K
where
    K: Clone + Send + Sync + 'static,
{
    type Extra = ();

    fn parts(&self) -> (K, Self::Extra) {
        (K::clone(self), ())
    }
}
374
375impl<
376 K: Key,
377 T: Object<Extra>,
378 Extra: 'static + Send + Sync + Clone,
379 I: PointInput<Extra: EncryptedExtra<K, Extra = Extra>>,
380> Parse<I> for Encrypted<K, T>
381{
382 fn parse(input: I) -> object_rainbow::Result<Self> {
383 let with_key = input.extra().parts();
384 let resolve = input.resolve().clone();
385 let source = with_key
386 .0
387 .decrypt(&input.parse_all()?)
388 .map_err(object_rainbow::Error::consistency)?;
389 let EncryptedInner {
390 resolution,
391 decrypted,
392 } = EncryptedInner::<K, Vec<u8>>::parse_slice_extra(&source, &resolve, &with_key)?;
393 let decrypted = T::parse_slice_extra(
394 &decrypted.0,
395 &(Arc::new(Decrypt {
396 resolution: resolution.clone(),
397 }) as _),
398 &with_key.1,
399 )?;
400 let decrypted = Unkeyed(Arc::new(decrypted));
401 let inner = EncryptedInner {
402 resolution,
403 decrypted,
404 };
405 Ok(Self {
406 key: with_key.0,
407 inner,
408 })
409 }
410}
411
412impl<K, T> Tagged for Encrypted<K, T> {}
413
414type Extracted<K> = Vec<
415 std::pin::Pin<
416 Box<dyn Future<Output = Result<Point<Encrypted<K, Vec<u8>>>, Error>> + Send + 'static>,
417 >,
418>;
419
420struct ExtractResolution<'a, K> {
421 extracted: &'a mut Extracted<K>,
422 key: &'a K,
423}
424
425struct Untyped<K, T> {
426 key: K,
427 encrypted: Point<Encrypted<K, T>>,
428}
429
430impl<K, T> FetchBytes for Untyped<K, T> {
431 fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
432 self.encrypted.fetch_bytes()
433 }
434
435 fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
436 self.encrypted.fetch_data()
437 }
438
439 fn fetch_bytes_local(&self) -> object_rainbow::Result<Option<ByteNode>> {
440 self.encrypted.fetch_bytes_local()
441 }
442
443 fn fetch_data_local(&self) -> Option<Vec<u8>> {
444 self.encrypted.fetch_data_local()
445 }
446}
447
448impl<K: Send + Sync, T> Singular for Untyped<K, T> {
449 fn hash(&self) -> Hash {
450 self.encrypted.hash()
451 }
452}
453
454impl<K: Key, T: FullHash> Fetch for Untyped<K, T> {
455 type T = Encrypted<K, Vec<u8>>;
456
457 fn fetch_full(&'_ self) -> FailFuture<'_, Node<Self::T>> {
458 Box::pin(async move {
459 let (data, resolve) = self.fetch_bytes().await?;
460 let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
461 Ok((encrypted, resolve))
462 })
463 }
464
465 fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
466 Box::pin(async move {
467 let (data, resolve) = self.fetch_bytes().await?;
468 let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
469 Ok(encrypted)
470 })
471 }
472
473 fn try_fetch_local(&self) -> object_rainbow::Result<Option<Node<Self::T>>> {
474 let Some((data, resolve)) = self.fetch_bytes_local()? else {
475 return Ok(None);
476 };
477 let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
478 Ok(Some((encrypted, resolve)))
479 }
480
481 fn fetch_local(&self) -> Option<Self::T> {
482 let Encrypted {
483 key,
484 inner:
485 EncryptedInner {
486 resolution,
487 decrypted,
488 },
489 } = self.encrypted.fetch_local()?;
490 let decrypted = Unkeyed(Arc::new(decrypted.vec()));
491 Some(Encrypted {
492 key,
493 inner: EncryptedInner {
494 resolution,
495 decrypted,
496 },
497 })
498 }
499}
500
501impl<K: Key> PointVisitor for ExtractResolution<'_, K> {
502 fn visit<T: Traversible>(&mut self, decrypted: &(impl 'static + SingularFetch<T = T> + Clone)) {
503 let decrypted = decrypted.clone();
504 let key = self.key.clone();
505 self.extracted.push(Box::pin(async move {
506 let encrypted = encrypt_point(key.clone(), decrypted).await?;
507 let encrypted =
508 Point::from_fetch(encrypted.hash(), Arc::new(Untyped { key, encrypted }));
509 Ok(encrypted)
510 }));
511 }
512}
513
514pub async fn encrypt_point<K: Key, T: Traversible>(
515 key: K,
516 decrypted: impl 'static + SingularFetch<T = T>,
517) -> object_rainbow::Result<Point<Encrypted<K, T>>> {
518 if let Some((address, decrypt)) = decrypted.extract_resolve::<Decrypt<K>>() {
519 let encrypted = decrypt
520 .resolution
521 .get(address.index)
522 .ok_or(Error::AddressOutOfBounds)?
523 .clone();
524 let point = Point::from_fetch(
525 encrypted.hash(),
526 Arc::new(Visited {
527 decrypted,
528 encrypted,
529 }),
530 );
531 return Ok(point);
532 }
533 let decrypted = decrypted.fetch().await?;
534 let encrypted = encrypt(key.clone(), decrypted).await?;
535 let point = encrypted.point();
536 Ok(point)
537}
538
539pub async fn encrypt<K: Key, T: Traversible>(
540 key: K,
541 decrypted: T,
542) -> object_rainbow::Result<Encrypted<K, T>> {
543 let mut futures = Vec::with_capacity(decrypted.point_count());
544 decrypted.traverse(&mut ExtractResolution {
545 extracted: &mut futures,
546 key: &key,
547 });
548 let resolution = futures_util::future::try_join_all(futures).await?;
549 let resolution = Arc::new(Lp(resolution));
550 let decrypted = Unkeyed(Arc::new(decrypted));
551 let inner = EncryptedInner {
552 resolution,
553 decrypted,
554 };
555 Ok(Encrypted { key, inner })
556}