1use std::{ops::Deref, sync::Arc};
2
3use object_rainbow::{
4 Address, ByteNode, Error, FailFuture, Fetch, FetchBytes, FullHash, Hash, ListHashes, Node,
5 Object, Parse, ParseSliceExtra, PointInput, PointVisitor, Resolve, Singular, SingularFetch,
6 Tagged, ToOutput, Topological, Traversible, length_prefixed::Lp,
7};
8use object_rainbow_point::{ExtractResolve, IntoPoint, Point};
9
/// Parse extra that carries an encryption key alongside the caller's own extra.
///
/// `Encrypted`'s `Parse` reads `key` to decrypt the stored bytes, then parses the
/// decrypted value with only `extra`.
#[derive(Clone)]
pub struct WithKey<K, Extra> {
    /// Key used to decrypt (and later re-encrypt) the object.
    pub key: K,
    /// The extra the decrypted value itself is parsed with.
    pub extra: Extra,
}
15
16pub trait Key: 'static + Sized + Send + Sync + Clone {
17 fn encrypt(&self, data: &[u8]) -> Vec<u8>;
18 fn decrypt(&self, data: &[u8]) -> object_rainbow::Result<Vec<u8>>;
19}
20
/// Shared list of encrypted child points.
///
/// Entry `i` is the encrypted counterpart of the `i`-th point of the decrypted
/// value: traversal consumes entries in order, and `Decrypt` indexes it by
/// address.
type Resolution<K> = Arc<Lp<Vec<Point<Encrypted<K, Vec<u8>>>>>>;
22
/// Newtype whose `Parse` strips the [`WithKey`] wrapper from the input's extra,
/// so the wrapped value is parsed with the caller's extra only.
///
/// `ToOutput` is derived and presumably writes the wrapped value's output as-is.
#[derive(ToOutput, Clone)]
struct Unkeyed<T>(T);
25
26impl<
27 T: Parse<I::WithExtra<Extra>>,
28 K: 'static + Clone,
29 Extra: 'static + Clone,
30 I: PointInput<Extra = WithKey<K, Extra>>,
31> Parse<I> for Unkeyed<T>
32{
33 fn parse(input: I) -> object_rainbow::Result<Self> {
34 Ok(Self(T::parse(
35 input.map_extra(|WithKey { extra, .. }| extra),
36 )?))
37 }
38}
39
/// Plaintext layout of an encrypted object.
///
/// The derived `ToOutput`/`Parse` presumably handle the fields in declaration
/// order: first the (length-prefixed) child-point list, then the value's own
/// bytes. The field order here is therefore part of the stored format.
#[derive(ToOutput, Parse)]
struct EncryptedInner<K, T> {
    // Encrypted counterparts of the decrypted value's points; the value's
    // point addresses index into this list.
    resolution: Resolution<K>,
    // The decrypted value, parsed with the key stripped from the extra.
    decrypted: Unkeyed<Arc<T>>,
}
45
46impl<K, T> Clone for EncryptedInner<K, T> {
47 fn clone(&self) -> Self {
48 Self {
49 resolution: self.resolution.clone(),
50 decrypted: self.decrypted.clone(),
51 }
52 }
53}
54
/// Borrowing iterator over the entries of a [`Resolution`].
type ResolutionIter<'a, K> = std::slice::Iter<'a, Point<Encrypted<K, Vec<u8>>>>;
56
/// Visitor adapter used while traversing the decrypted value of an
/// `EncryptedInner`.
///
/// Each point visited is paired with the next entry of `resolution`, and the
/// combined [`Visited`] point is forwarded to `visitor`.
struct IterateResolution<'a, 'r, K, V> {
    /// Remaining encrypted points, consumed in traversal order.
    resolution: &'r mut ResolutionIter<'a, K>,
    /// Downstream visitor that receives the combined points.
    visitor: &'a mut V,
}
61
/// Point source pairing a decrypted point with its encrypted counterpart.
///
/// Bytes and hash come from `encrypted`. Fetching yields an `Encrypted` whose
/// key and resolution come from `encrypted` and whose value comes from
/// `decrypted`.
struct Visited<K, P> {
    /// Source of the typed, decrypted value.
    decrypted: P,
    /// The stored, encrypted form of the same object.
    encrypted: Point<Encrypted<K, Vec<u8>>>,
}
66
67impl<K, P> FetchBytes for Visited<K, P> {
68 fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
69 self.encrypted.fetch_bytes()
70 }
71
72 fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
73 self.encrypted.fetch_data()
74 }
75
76 fn fetch_bytes_local(&self) -> object_rainbow::Result<Option<ByteNode>> {
77 self.encrypted.fetch_bytes_local()
78 }
79
80 fn fetch_data_local(&self) -> Option<Vec<u8>> {
81 self.encrypted.fetch_data_local()
82 }
83}
84
85impl<K, P: Send + Sync> Singular for Visited<K, P> {
86 fn hash(&self) -> Hash {
87 self.encrypted.hash()
88 }
89}
90
91impl<K: Key, P: Fetch<T: Traversible>> Fetch for Visited<K, P> {
92 type T = Encrypted<K, P::T>;
93
94 fn fetch_full(&'_ self) -> FailFuture<'_, Node<Self::T>> {
95 Box::pin(async move {
96 let (
97 Encrypted {
98 key,
99 inner:
100 EncryptedInner {
101 resolution,
102 decrypted: _,
103 },
104 },
105 resolve,
106 ) = self.encrypted.fetch_full().await?;
107 let decrypted = self.decrypted.fetch().await?;
108 let decrypted = Unkeyed(Arc::new(decrypted));
109 Ok((
110 Encrypted {
111 key,
112 inner: EncryptedInner {
113 resolution,
114 decrypted,
115 },
116 },
117 resolve,
118 ))
119 })
120 }
121
122 fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
123 Box::pin(async move {
124 let Encrypted {
125 key,
126 inner:
127 EncryptedInner {
128 resolution,
129 decrypted: _,
130 },
131 } = self.encrypted.fetch().await?;
132 let decrypted = self.decrypted.fetch().await?;
133 let decrypted = Unkeyed(Arc::new(decrypted));
134 Ok(Encrypted {
135 key,
136 inner: EncryptedInner {
137 resolution,
138 decrypted,
139 },
140 })
141 })
142 }
143
144 fn try_fetch_local(&self) -> object_rainbow::Result<Option<Node<Self::T>>> {
145 let Some((
146 Encrypted {
147 key,
148 inner:
149 EncryptedInner {
150 resolution,
151 decrypted: _,
152 },
153 },
154 resolve,
155 )) = self.encrypted.try_fetch_local()?
156 else {
157 return Ok(None);
158 };
159 let Some((decrypted, _)) = self.decrypted.try_fetch_local()? else {
160 return Ok(None);
161 };
162 let decrypted = Unkeyed(Arc::new(decrypted));
163 Ok(Some((
164 Encrypted {
165 key,
166 inner: EncryptedInner {
167 resolution,
168 decrypted,
169 },
170 },
171 resolve,
172 )))
173 }
174
175 fn fetch_local(&self) -> Option<Self::T> {
176 let Encrypted {
177 key,
178 inner:
179 EncryptedInner {
180 resolution,
181 decrypted: _,
182 },
183 } = self.encrypted.fetch_local()?;
184 let decrypted = Unkeyed(Arc::new(self.decrypted.fetch_local()?));
185 Some(Encrypted {
186 key,
187 inner: EncryptedInner {
188 resolution,
189 decrypted,
190 },
191 })
192 }
193}
194
195impl<'a, K: Key, V: PointVisitor> PointVisitor for IterateResolution<'a, '_, K, V> {
196 fn visit<T: Traversible>(&mut self, decrypted: &(impl 'static + SingularFetch<T = T> + Clone)) {
197 let decrypted = decrypted.clone();
198 let encrypted = self.resolution.next().expect("length mismatch").clone();
199 let point = Point::from_fetch(
200 encrypted.hash(),
201 Arc::new(Visited {
202 decrypted,
203 encrypted,
204 }),
205 );
206 self.visitor.visit(&point);
207 }
208}
209
210impl<K, T> ListHashes for EncryptedInner<K, T> {
211 fn list_hashes(&self, f: &mut impl FnMut(Hash)) {
212 self.resolution.list_hashes(f);
213 }
214
215 fn topology_hash(&self) -> Hash {
216 self.resolution.0.data_hash()
217 }
218
219 fn point_count(&self) -> usize {
220 self.resolution.len()
221 }
222}
223
224impl<K: Key, T: Topological> Topological for EncryptedInner<K, T> {
225 fn traverse(&self, visitor: &mut impl PointVisitor) {
226 let resolution = &mut self.resolution.iter();
227 self.decrypted.0.traverse(&mut IterateResolution {
228 resolution,
229 visitor,
230 });
231 assert!(resolution.next().is_none());
232 }
233}
234
/// A value of type `T` that is stored encrypted with a key of type `K`.
///
/// Serializing it (`ToOutput`) writes only the key-encrypted bytes. The
/// decrypted value stays accessible in memory through `Deref` and
/// `into_inner`.
pub struct Encrypted<K, T> {
    /// Key used to encrypt on output; taken from the parse extra when parsed.
    key: K,
    /// Encrypted child points plus the decrypted value.
    inner: EncryptedInner<K, T>,
}
239
240impl<K, T: Clone> Encrypted<K, T> {
241 pub fn into_inner(self) -> T {
242 Arc::unwrap_or_clone(self.inner.decrypted.0)
243 }
244}
245
246impl<K, T> Deref for Encrypted<K, T> {
247 type Target = T;
248
249 fn deref(&self) -> &Self::Target {
250 self.inner.decrypted.0.as_ref()
251 }
252}
253
254impl<K: Clone, T> Clone for Encrypted<K, T> {
255 fn clone(&self) -> Self {
256 Self {
257 key: self.key.clone(),
258 inner: self.inner.clone(),
259 }
260 }
261}
262
263impl<K, T> ListHashes for Encrypted<K, T> {
264 fn list_hashes(&self, f: &mut impl FnMut(Hash)) {
265 self.inner.list_hashes(f);
266 }
267
268 fn topology_hash(&self) -> Hash {
269 self.inner.topology_hash()
270 }
271
272 fn point_count(&self) -> usize {
273 self.inner.point_count()
274 }
275}
276
277impl<K: Key, T: Topological> Topological for Encrypted<K, T> {
278 fn traverse(&self, visitor: &mut impl PointVisitor) {
279 self.inner.traverse(visitor);
280 }
281}
282
283impl<K: Key, T: ToOutput> ToOutput for Encrypted<K, T> {
284 fn to_output(&self, output: &mut dyn object_rainbow::Output) {
285 let source = self.inner.vec();
286 output.write(&self.key.encrypt(&source));
287 }
288}
289
/// Resolver installed while parsing the decrypted value of an [`Encrypted`].
///
/// Addresses index into `resolution`. Resolving one fetches the referenced
/// encrypted point and yields its decrypted bytes.
#[derive(Clone)]
struct Decrypt<K> {
    /// Encrypted child points of the object being parsed.
    resolution: Resolution<K>,
}
294
295impl<K: Key> Decrypt<K> {
296 async fn resolve_bytes(
297 &self,
298 address: Address,
299 ) -> object_rainbow::Result<(Vec<u8>, Resolution<K>)> {
300 let Encrypted {
301 key: _,
302 inner:
303 EncryptedInner {
304 resolution,
305 decrypted,
306 },
307 } = self
308 .resolution
309 .get(address.index)
310 .ok_or(Error::AddressOutOfBounds)?
311 .clone()
312 .fetch()
313 .await?;
314 let data = Arc::unwrap_or_clone(decrypted.0);
315 Ok((data, resolution))
316 }
317}
318
319impl<K: Key> Resolve for Decrypt<K> {
320 fn resolve(&'_ self, address: Address) -> FailFuture<'_, ByteNode> {
321 Box::pin(async move {
322 let (data, resolution) = self.resolve_bytes(address).await?;
323 Ok((data, Arc::new(Decrypt { resolution }) as _))
324 })
325 }
326
327 fn resolve_data(&'_ self, address: Address) -> FailFuture<'_, Vec<u8>> {
328 Box::pin(async move {
329 let (data, _) = self.resolve_bytes(address).await?;
330 Ok(data)
331 })
332 }
333
334 fn try_resolve_local(&self, address: Address) -> object_rainbow::Result<Option<ByteNode>> {
335 let Some((
336 Encrypted {
337 key: _,
338 inner:
339 EncryptedInner {
340 resolution,
341 decrypted,
342 },
343 },
344 _,
345 )) = self
346 .resolution
347 .get(address.index)
348 .ok_or(Error::AddressOutOfBounds)?
349 .clone()
350 .try_fetch_local()?
351 else {
352 return Ok(None);
353 };
354 let data = Arc::unwrap_or_clone(decrypted.0);
355 Ok(Some((data, Arc::new(Decrypt { resolution }) as _)))
356 }
357}
358
359impl<
360 K: Key,
361 T: Object<Extra>,
362 Extra: 'static + Send + Sync + Clone,
363 I: PointInput<Extra = WithKey<K, Extra>>,
364> Parse<I> for Encrypted<K, T>
365{
366 fn parse(input: I) -> object_rainbow::Result<Self> {
367 let with_key = input.extra().clone();
368 let resolve = input.resolve().clone();
369 let source = with_key.key.decrypt(&input.parse_all()?)?;
370 let EncryptedInner {
371 resolution,
372 decrypted,
373 } = EncryptedInner::<K, Vec<u8>>::parse_slice_extra(&source, &resolve, &with_key)?;
374 let decrypted = T::parse_slice_extra(
375 &decrypted.0,
376 &(Arc::new(Decrypt {
377 resolution: resolution.clone(),
378 }) as _),
379 &with_key.extra,
380 )?;
381 let decrypted = Unkeyed(Arc::new(decrypted));
382 let inner = EncryptedInner {
383 resolution,
384 decrypted,
385 };
386 Ok(Self {
387 key: with_key.key,
388 inner,
389 })
390 }
391}
392
// Uses `Tagged`'s provided items unchanged.
// NOTE(review): nothing from `T`'s tagging is forwarded here — confirm intended.
impl<K, T> Tagged for Encrypted<K, T> {}
394
/// Pending encryption jobs collected during `encrypt`.
///
/// There is one job per child point, in traversal order; each resolves to the
/// child's encrypted point, erased to `Encrypted<K, Vec<u8>>`.
type Extracted<K> = Vec<
    std::pin::Pin<
        Box<dyn Future<Output = Result<Point<Encrypted<K, Vec<u8>>>, Error>> + Send + 'static>,
    >,
>;
400
/// Visitor used by [`encrypt`] that queues one encryption job per child point
/// of the value being encrypted.
struct ExtractResolution<'a, K> {
    /// Destination for the queued jobs.
    extracted: &'a mut Extracted<K>,
    /// Key the children are encrypted with.
    key: &'a K,
}
405
/// Adapter exposing a typed `Point<Encrypted<K, T>>` as a point to
/// `Encrypted<K, Vec<u8>>`, so children of different types can share one
/// resolution list.
struct Untyped<K, T> {
    /// Parse extra (key with unit extra) used to decode fetched ciphertext.
    key: WithKey<K, ()>,
    /// The typed encrypted point being adapted.
    encrypted: Point<Encrypted<K, T>>,
}
410
411impl<K, T> FetchBytes for Untyped<K, T> {
412 fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
413 self.encrypted.fetch_bytes()
414 }
415
416 fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
417 self.encrypted.fetch_data()
418 }
419
420 fn fetch_bytes_local(&self) -> object_rainbow::Result<Option<ByteNode>> {
421 self.encrypted.fetch_bytes_local()
422 }
423
424 fn fetch_data_local(&self) -> Option<Vec<u8>> {
425 self.encrypted.fetch_data_local()
426 }
427}
428
429impl<K: Send + Sync, T> Singular for Untyped<K, T> {
430 fn hash(&self) -> Hash {
431 self.encrypted.hash()
432 }
433}
434
435impl<K: Key, T: FullHash> Fetch for Untyped<K, T> {
436 type T = Encrypted<K, Vec<u8>>;
437
438 fn fetch_full(&'_ self) -> FailFuture<'_, Node<Self::T>> {
439 Box::pin(async move {
440 let (data, resolve) = self.fetch_bytes().await?;
441 let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
442 Ok((encrypted, resolve))
443 })
444 }
445
446 fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
447 Box::pin(async move {
448 let (data, resolve) = self.fetch_bytes().await?;
449 let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
450 Ok(encrypted)
451 })
452 }
453
454 fn try_fetch_local(&self) -> object_rainbow::Result<Option<Node<Self::T>>> {
455 let Some((data, resolve)) = self.fetch_bytes_local()? else {
456 return Ok(None);
457 };
458 let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
459 Ok(Some((encrypted, resolve)))
460 }
461
462 fn fetch_local(&self) -> Option<Self::T> {
463 let Encrypted {
464 key,
465 inner:
466 EncryptedInner {
467 resolution,
468 decrypted,
469 },
470 } = self.encrypted.fetch_local()?;
471 let decrypted = Unkeyed(Arc::new(decrypted.vec()));
472 Some(Encrypted {
473 key,
474 inner: EncryptedInner {
475 resolution,
476 decrypted,
477 },
478 })
479 }
480}
481
482impl<K: Key> PointVisitor for ExtractResolution<'_, K> {
483 fn visit<T: Traversible>(&mut self, decrypted: &(impl 'static + SingularFetch<T = T> + Clone)) {
484 let decrypted = decrypted.clone();
485 let key = self.key.clone();
486 self.extracted.push(Box::pin(async move {
487 let encrypted = encrypt_point(key.clone(), decrypted).await?;
488 let encrypted = Point::from_fetch(
489 encrypted.hash(),
490 Arc::new(Untyped {
491 key: WithKey { key, extra: () },
492 encrypted,
493 }),
494 );
495 Ok(encrypted)
496 }));
497 }
498}
499
500pub async fn encrypt_point<K: Key, T: Traversible>(
501 key: K,
502 decrypted: impl 'static + SingularFetch<T = T>,
503) -> object_rainbow::Result<Point<Encrypted<K, T>>> {
504 if let Some((address, decrypt)) = decrypted.extract_resolve::<Decrypt<K>>() {
505 let encrypted = decrypt
506 .resolution
507 .get(address.index)
508 .ok_or(Error::AddressOutOfBounds)?
509 .clone();
510 let point = Point::from_fetch(
511 encrypted.hash(),
512 Arc::new(Visited {
513 decrypted,
514 encrypted,
515 }),
516 );
517 return Ok(point);
518 }
519 let decrypted = decrypted.fetch().await?;
520 let encrypted = encrypt(key.clone(), decrypted).await?;
521 let point = encrypted.point();
522 Ok(point)
523}
524
525pub async fn encrypt<K: Key, T: Traversible>(
526 key: K,
527 decrypted: T,
528) -> object_rainbow::Result<Encrypted<K, T>> {
529 let mut futures = Vec::with_capacity(decrypted.point_count());
530 decrypted.traverse(&mut ExtractResolution {
531 extracted: &mut futures,
532 key: &key,
533 });
534 let resolution = futures_util::future::try_join_all(futures).await?;
535 let resolution = Arc::new(Lp(resolution));
536 let decrypted = Unkeyed(Arc::new(decrypted));
537 let inner = EncryptedInner {
538 resolution,
539 decrypted,
540 };
541 Ok(Encrypted { key, inner })
542}