1use std::{ops::Deref, sync::Arc};
2
3use object_rainbow::{
4 Address, ByteNode, Error, FailFuture, Fetch, FetchBytes, FullHash, Hash, Node, Object, Parse,
5 ParseSliceExtra, Point, PointInput, PointVisitor, Resolve, Singular, Tagged, ToOutput,
6 Topological, Traversible, length_prefixed::Lp,
7};
8
/// Parsing context that pairs an encryption key with the caller's own extra data.
///
/// `Encrypted` values are parsed with `WithKey<K, Extra>` as their extra: `key`
/// decrypts the stored bytes, and `extra` is handed unchanged to the parser of the
/// decrypted payload.
//
// Intentionally not `Debug`: a derived impl would print the key.
#[derive(Clone)]
pub struct WithKey<K, Extra> {
    /// Key used to encrypt and decrypt the wrapped data.
    pub key: K,
    /// Caller-supplied extra data forwarded to the inner (decrypted) parser.
    pub extra: Extra,
}
14
15pub trait Key: 'static + Sized + Send + Sync + Clone {
16 fn encrypt(&self, data: &[u8]) -> Vec<u8>;
17 fn decrypt(&self, data: &[u8]) -> object_rainbow::Result<Vec<u8>>;
18}
19
20type Resolution<K> = Arc<Lp<Vec<Point<Encrypted<K, Vec<u8>>>>>>;
21
22#[derive(ToOutput, Clone)]
23struct Unkeyed<T>(T);
24
25impl<
26 T: Parse<I::WithExtra<Extra>>,
27 K: 'static + Clone,
28 Extra: 'static + Clone,
29 I: PointInput<Extra = WithKey<K, Extra>>,
30> Parse<I> for Unkeyed<T>
31{
32 fn parse(input: I) -> object_rainbow::Result<Self> {
33 Ok(Self(T::parse(
34 input.map_extra(|WithKey { extra, .. }| extra),
35 )?))
36 }
37}
38
39#[derive(ToOutput, Parse)]
40struct EncryptedInner<K, T> {
41 resolution: Resolution<K>,
42 decrypted: Unkeyed<Arc<T>>,
43}
44
45impl<K, T> Clone for EncryptedInner<K, T> {
46 fn clone(&self) -> Self {
47 Self {
48 resolution: self.resolution.clone(),
49 decrypted: self.decrypted.clone(),
50 }
51 }
52}
53
54type ResolutionIter<'a, K> = std::slice::Iter<'a, Point<Encrypted<K, Vec<u8>>>>;
55
56struct IterateResolution<'a, 'r, K, V> {
57 resolution: &'r mut ResolutionIter<'a, K>,
58 visitor: &'a mut V,
59}
60
61struct Visited<K, T> {
62 decrypted: Point<T>,
63 encrypted: Point<Encrypted<K, Vec<u8>>>,
64}
65
66impl<K, T> FetchBytes for Visited<K, T> {
67 fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
68 self.encrypted.fetch_bytes()
69 }
70
71 fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
72 self.encrypted.fetch_data()
73 }
74
75 fn fetch_bytes_local(&self) -> object_rainbow::Result<Option<ByteNode>> {
76 self.encrypted.fetch_bytes_local()
77 }
78
79 fn fetch_data_local(&self) -> Option<Vec<u8>> {
80 self.encrypted.fetch_data_local()
81 }
82}
83
84impl<K, T> Singular for Visited<K, T> {
85 fn hash(&self) -> Hash {
86 self.encrypted.hash()
87 }
88}
89
90impl<K: Key, T: Traversible> Fetch for Visited<K, T> {
91 type T = Encrypted<K, T>;
92
93 fn fetch_full(&'_ self) -> FailFuture<'_, Node<Self::T>> {
94 Box::pin(async move {
95 let (
96 Encrypted {
97 key,
98 inner:
99 EncryptedInner {
100 resolution,
101 decrypted: _,
102 },
103 },
104 resolve,
105 ) = self.encrypted.fetch_full().await?;
106 let decrypted = self.decrypted.fetch().await?;
107 let decrypted = Unkeyed(Arc::new(decrypted));
108 Ok((
109 Encrypted {
110 key,
111 inner: EncryptedInner {
112 resolution,
113 decrypted,
114 },
115 },
116 resolve,
117 ))
118 })
119 }
120
121 fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
122 Box::pin(async move {
123 let Encrypted {
124 key,
125 inner:
126 EncryptedInner {
127 resolution,
128 decrypted: _,
129 },
130 } = self.encrypted.fetch().await?;
131 let decrypted = self.decrypted.fetch().await?;
132 let decrypted = Unkeyed(Arc::new(decrypted));
133 Ok(Encrypted {
134 key,
135 inner: EncryptedInner {
136 resolution,
137 decrypted,
138 },
139 })
140 })
141 }
142
143 fn try_fetch_local(&self) -> object_rainbow::Result<Option<Node<Self::T>>> {
144 let Some((
145 Encrypted {
146 key,
147 inner:
148 EncryptedInner {
149 resolution,
150 decrypted: _,
151 },
152 },
153 resolve,
154 )) = self.encrypted.try_fetch_local()?
155 else {
156 return Ok(None);
157 };
158 let Some((decrypted, _)) = self.decrypted.try_fetch_local()? else {
159 return Ok(None);
160 };
161 let decrypted = Unkeyed(Arc::new(decrypted));
162 Ok(Some((
163 Encrypted {
164 key,
165 inner: EncryptedInner {
166 resolution,
167 decrypted,
168 },
169 },
170 resolve,
171 )))
172 }
173
174 fn fetch_local(&self) -> Option<Self::T> {
175 let Encrypted {
176 key,
177 inner:
178 EncryptedInner {
179 resolution,
180 decrypted: _,
181 },
182 } = self.encrypted.fetch_local()?;
183 let decrypted = Unkeyed(Arc::new(self.decrypted.fetch_local()?));
184 Some(Encrypted {
185 key,
186 inner: EncryptedInner {
187 resolution,
188 decrypted,
189 },
190 })
191 }
192}
193
194impl<'a, K: Key, V: PointVisitor> PointVisitor for IterateResolution<'a, '_, K, V> {
195 fn visit<T: Traversible>(&mut self, decrypted: &Point<T>) {
196 let decrypted = decrypted.clone();
197 let encrypted = self.resolution.next().expect("length mismatch").clone();
198 let point = Point::from_fetch(
199 encrypted.hash(),
200 Arc::new(Visited {
201 decrypted,
202 encrypted,
203 }),
204 );
205 self.visitor.visit(&point);
206 }
207}
208
209impl<K: Key, T: Topological> Topological for EncryptedInner<K, T> {
210 fn accept_points(&self, visitor: &mut impl PointVisitor) {
211 let resolution = &mut self.resolution.iter();
212 self.decrypted.0.accept_points(&mut IterateResolution {
213 resolution,
214 visitor,
215 });
216 assert!(resolution.next().is_none());
217 }
218
219 fn point_count(&self) -> usize {
220 self.resolution.len()
221 }
222
223 fn topology_hash(&self) -> Hash {
224 self.resolution.0.data_hash()
225 }
226}
227
228pub struct Encrypted<K, T> {
229 key: K,
230 inner: EncryptedInner<K, T>,
231}
232
233impl<K, T: Clone> Encrypted<K, T> {
234 pub fn into_inner(self) -> T {
235 Arc::unwrap_or_clone(self.inner.decrypted.0)
236 }
237}
238
239impl<K, T> Deref for Encrypted<K, T> {
240 type Target = T;
241
242 fn deref(&self) -> &Self::Target {
243 self.inner.decrypted.0.as_ref()
244 }
245}
246
247impl<K: Clone, T> Clone for Encrypted<K, T> {
248 fn clone(&self) -> Self {
249 Self {
250 key: self.key.clone(),
251 inner: self.inner.clone(),
252 }
253 }
254}
255
256impl<K: Key, T: Topological> Topological for Encrypted<K, T> {
257 fn accept_points(&self, visitor: &mut impl PointVisitor) {
258 self.inner.accept_points(visitor);
259 }
260
261 fn topology_hash(&self) -> Hash {
262 self.inner.topology_hash()
263 }
264}
265
266impl<K: Key, T: ToOutput> ToOutput for Encrypted<K, T> {
267 fn to_output(&self, output: &mut dyn object_rainbow::Output) {
268 let source = self.inner.vec();
269 output.write(&self.key.encrypt(&source));
270 }
271}
272
273#[derive(Clone)]
274struct Decrypt<K> {
275 resolution: Resolution<K>,
276}
277
278impl<K: Key> Decrypt<K> {
279 async fn resolve_bytes(
280 &self,
281 address: Address,
282 ) -> object_rainbow::Result<(Vec<u8>, Resolution<K>)> {
283 let Encrypted {
284 key: _,
285 inner:
286 EncryptedInner {
287 resolution,
288 decrypted,
289 },
290 } = self
291 .resolution
292 .get(address.index)
293 .ok_or(Error::AddressOutOfBounds)?
294 .clone()
295 .fetch()
296 .await?;
297 let data = Arc::unwrap_or_clone(decrypted.0);
298 Ok((data, resolution))
299 }
300}
301
302impl<K: Key> Resolve for Decrypt<K> {
303 fn resolve(&'_ self, address: Address) -> FailFuture<'_, ByteNode> {
304 Box::pin(async move {
305 let (data, resolution) = self.resolve_bytes(address).await?;
306 Ok((data, Arc::new(Decrypt { resolution }) as _))
307 })
308 }
309
310 fn resolve_data(&'_ self, address: Address) -> FailFuture<'_, Vec<u8>> {
311 Box::pin(async move {
312 let (data, _) = self.resolve_bytes(address).await?;
313 Ok(data)
314 })
315 }
316
317 fn try_resolve_local(&self, address: Address) -> object_rainbow::Result<Option<ByteNode>> {
318 let Some((
319 Encrypted {
320 key: _,
321 inner:
322 EncryptedInner {
323 resolution,
324 decrypted,
325 },
326 },
327 _,
328 )) = self
329 .resolution
330 .get(address.index)
331 .ok_or(Error::AddressOutOfBounds)?
332 .clone()
333 .try_fetch_local()?
334 else {
335 return Ok(None);
336 };
337 let data = Arc::unwrap_or_clone(decrypted.0);
338 Ok(Some((data, Arc::new(Decrypt { resolution }) as _)))
339 }
340}
341
342impl<
343 K: Key,
344 T: Object<Extra>,
345 Extra: 'static + Send + Sync + Clone,
346 I: PointInput<Extra = WithKey<K, Extra>>,
347> Parse<I> for Encrypted<K, T>
348{
349 fn parse(input: I) -> object_rainbow::Result<Self> {
350 let with_key = input.extra().clone();
351 let resolve = input.resolve().clone();
352 let source = with_key.key.decrypt(&input.parse_all()?)?;
353 let EncryptedInner {
354 resolution,
355 decrypted,
356 } = EncryptedInner::<K, Vec<u8>>::parse_slice_extra(&source, &resolve, &with_key)?;
357 let decrypted = T::parse_slice_extra(
358 &decrypted.0,
359 &(Arc::new(Decrypt {
360 resolution: resolution.clone(),
361 }) as _),
362 &with_key.extra,
363 )?;
364 let decrypted = Unkeyed(Arc::new(decrypted));
365 let inner = EncryptedInner {
366 resolution,
367 decrypted,
368 };
369 Ok(Self {
370 key: with_key.key,
371 inner,
372 })
373 }
374}
375
376impl<K, T> Tagged for Encrypted<K, T> {}
377
378impl<K: Key, T: Object<Extra>, Extra: 'static + Send + Sync + Clone> Object<WithKey<K, Extra>>
379 for Encrypted<K, T>
380{
381}
382
383type Extracted<K> = Vec<
384 std::pin::Pin<
385 Box<dyn Future<Output = Result<Point<Encrypted<K, Vec<u8>>>, Error>> + Send + 'static>,
386 >,
387>;
388
389struct ExtractResolution<'a, K> {
390 extracted: &'a mut Extracted<K>,
391 key: &'a K,
392}
393
394struct Untyped<K, T> {
395 key: WithKey<K, ()>,
396 encrypted: Point<Encrypted<K, T>>,
397}
398
399impl<K, T> FetchBytes for Untyped<K, T> {
400 fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
401 self.encrypted.fetch_bytes()
402 }
403
404 fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
405 self.encrypted.fetch_data()
406 }
407
408 fn fetch_bytes_local(&self) -> object_rainbow::Result<Option<ByteNode>> {
409 self.encrypted.fetch_bytes_local()
410 }
411
412 fn fetch_data_local(&self) -> Option<Vec<u8>> {
413 self.encrypted.fetch_data_local()
414 }
415}
416
417impl<K: Send + Sync, T> Singular for Untyped<K, T> {
418 fn hash(&self) -> Hash {
419 self.encrypted.hash()
420 }
421}
422
423impl<K: Key, T: FullHash> Fetch for Untyped<K, T> {
424 type T = Encrypted<K, Vec<u8>>;
425
426 fn fetch_full(&'_ self) -> FailFuture<'_, Node<Self::T>> {
427 Box::pin(async move {
428 let (data, resolve) = self.fetch_bytes().await?;
429 let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
430 Ok((encrypted, resolve))
431 })
432 }
433
434 fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
435 Box::pin(async move {
436 let (data, resolve) = self.fetch_bytes().await?;
437 let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
438 Ok(encrypted)
439 })
440 }
441
442 fn try_fetch_local(&self) -> object_rainbow::Result<Option<Node<Self::T>>> {
443 let Some((data, resolve)) = self.fetch_bytes_local()? else {
444 return Ok(None);
445 };
446 let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
447 Ok(Some((encrypted, resolve)))
448 }
449
450 fn fetch_local(&self) -> Option<Self::T> {
451 let Encrypted {
452 key,
453 inner:
454 EncryptedInner {
455 resolution,
456 decrypted,
457 },
458 } = self.encrypted.fetch_local()?;
459 let decrypted = Unkeyed(Arc::new(decrypted.vec()));
460 Some(Encrypted {
461 key,
462 inner: EncryptedInner {
463 resolution,
464 decrypted,
465 },
466 })
467 }
468}
469
470impl<K: Key> PointVisitor for ExtractResolution<'_, K> {
471 fn visit<T: Traversible>(&mut self, decrypted: &Point<T>) {
472 let decrypted = decrypted.clone();
473 let key = self.key.clone();
474 self.extracted.push(Box::pin(async move {
475 let encrypted = encrypt_point(key.clone(), decrypted).await?;
476 let encrypted = Point::from_fetch(
477 encrypted.hash(),
478 Arc::new(Untyped {
479 key: WithKey { key, extra: () },
480 encrypted,
481 }),
482 );
483 Ok(encrypted)
484 }));
485 }
486}
487
488pub async fn encrypt_point<K: Key, T: Traversible>(
489 key: K,
490 decrypted: Point<T>,
491) -> object_rainbow::Result<Point<Encrypted<K, T>>> {
492 if let Some((address, decrypt)) = decrypted.extract_resolve::<Decrypt<K>>() {
493 let encrypted = decrypt
494 .resolution
495 .get(address.index)
496 .ok_or(Error::AddressOutOfBounds)?
497 .clone();
498 let point = Point::from_fetch(
499 encrypted.hash(),
500 Arc::new(Visited {
501 decrypted,
502 encrypted,
503 }),
504 );
505 return Ok(point);
506 }
507 let decrypted = decrypted.fetch().await?;
508 let encrypted = encrypt(key.clone(), decrypted).await?;
509 let point = encrypted.point();
510 Ok(point)
511}
512
513pub async fn encrypt<K: Key, T: Traversible>(
514 key: K,
515 decrypted: T,
516) -> object_rainbow::Result<Encrypted<K, T>> {
517 let mut futures = Vec::with_capacity(decrypted.point_count());
518 decrypted.accept_points(&mut ExtractResolution {
519 extracted: &mut futures,
520 key: &key,
521 });
522 let resolution = futures_util::future::try_join_all(futures).await?;
523 let resolution = Arc::new(Lp(resolution));
524 let decrypted = Unkeyed(Arc::new(decrypted));
525 let inner = EncryptedInner {
526 resolution,
527 decrypted,
528 };
529 Ok(Encrypted { key, inner })
530}