1use std::{ops::Deref, sync::Arc};
2
3use object_rainbow::{
4 Address, ByteNode, Error, FailFuture, Fetch, FetchBytes, Hash, ListHashes, Node, Object, Parse,
5 ParseInline, ParseSliceExtra, PointInput, PointVisitor, Resolve, Singular, SingularFetch,
6 Tagged, ToOutput, TopoVec, Topological, Traversible, derive_for_wrapped, length_prefixed::Lp,
7 map_extra::MappedExtra, tuple_extra::Extra0,
8};
9use object_rainbow_point::{ExtractResolve, Extras, IntoPoint, Point};
10
/// Symmetric cipher used to protect the serialized form of [`Encrypted`] values.
///
/// `encrypt` is applied to the serialized inner state when an `Encrypted` value is
/// written, and `decrypt` to bytes read back when it (or one of its points) is
/// parsed or resolved. `PartialEq` is used by [`encrypt_point`] to decide whether
/// an already-encrypted point can be reused without re-encrypting it.
///
/// NOTE(review): implementations are presumably expected to satisfy
/// `decrypt(&encrypt(x)) == Ok(x)`; nothing in this file checks that.
#[derive_for_wrapped]
pub trait Key: 'static + Sized + Send + Sync + Clone + PartialEq + Eq {
    /// Decryption failure; callers in this file turn it into a consistency error.
    type Error: Into<anyhow::Error>;
    /// Returns the ciphertext for `data`.
    fn encrypt(&self, data: &[u8]) -> Vec<u8>;
    /// Returns the plaintext for `data`, or an error if it cannot be decrypted.
    fn decrypt(&self, data: &[u8]) -> Result<Vec<u8>, Self::Error>;
}
17
/// `Resolve` implementation handed to values parsed out of an encrypted payload.
///
/// Addresses seen by the decrypted value are indices into `addresses`;
/// `translate` maps them to addresses in the underlying `resolve`, and the bytes
/// found there are decrypted with `key` before being returned.
///
/// NOTE(review): the derives presumably (de)serialize fields in declaration order,
/// so this order is part of the wire format — do not reorder.
#[derive(ToOutput, Parse, ParseInline, Clone)]
struct RawResolve<K> {
    // Key used to decrypt resolved bytes (wrapped in `Extras`, presumably so it is
    // taken from the parse extra rather than from the bytes — TODO confirm).
    key: Extras<K>,
    // Resolver for the encrypted objects the table's addresses point at.
    resolve: Arc<dyn Resolve>,
    // Length-prefixed address table; entry `i` is the target of the value's `i`-th point.
    addresses: Arc<Lp<Vec<Address>>>,
}
24
25impl<K> RawResolve<K> {
26 fn translate(&self, address: Address) -> object_rainbow::Result<Address> {
27 self.addresses
28 .get(address.index)
29 .copied()
30 .ok_or(Error::AddressOutOfBounds)
31 }
32}
33
34fn side_parse<K: Key>(
35 key: &K,
36 data: &[u8],
37 resolve: &Arc<dyn Resolve>,
38) -> object_rainbow::Result<(InnerHeader<K>, Vec<u8>)> {
39 let data = key
40 .decrypt(data)
41 .map_err(object_rainbow::Error::consistency)?;
42 <(InnerHeader<K>, Vec<u8>) as ParseSliceExtra<K>>::parse_slice_extra(&data, resolve, key)
43}
44
45impl<K: Key> Resolve for RawResolve<K> {
46 fn resolve<'a>(
47 &'a self,
48 address: Address,
49 _: &'a Arc<dyn Resolve>,
50 ) -> FailFuture<'a, ByteNode> {
51 Box::pin(async move {
52 let address = self.translate(address)?;
53 let (data, resolve) = self.resolve.resolve(address, &self.resolve).await?;
54 let (InnerHeader { resolve, .. }, data) = side_parse(&self.key.0, &data, &resolve)?;
55 let resolve = Arc::new(resolve) as _;
56 Ok((data, resolve))
57 })
58 }
59
60 fn resolve_data(&'_ self, address: Address) -> FailFuture<'_, Vec<u8>> {
61 Box::pin(async move {
62 let address = self.translate(address)?;
63 let (data, resolve) = self.resolve.resolve(address, &self.resolve).await?;
64 let (_, data) = side_parse(&self.key.0, &data, &resolve)?;
65 Ok(data)
66 })
67 }
68
69 fn try_resolve_local(
70 &self,
71 address: Address,
72 _: &Arc<dyn Resolve>,
73 ) -> object_rainbow::Result<Option<ByteNode>> {
74 let address = self.translate(address)?;
75 let Some((data, resolve)) = self.resolve.try_resolve_local(address, &self.resolve)? else {
76 return Ok(None);
77 };
78 let (InnerHeader { resolve, .. }, data) = side_parse(&self.key.0, &data, &resolve)?;
79 let resolve = Arc::new(resolve) as _;
80 Ok(Some((data, resolve)))
81 }
82}
83
/// Leading part of a decrypted `Inner` payload: the value's type tag and the
/// `RawResolve` (key + address table) used to resolve the value's points.
///
/// Produced by `side_parse` / `Inner::parse` and turned into an `Inner` by
/// `InnerHeader::with`.
///
/// NOTE(review): this field order presumably has to mirror the leading fields of
/// `Inner`'s derived `ToOutput` (`tags`, `key`, `topology`) — do not reorder.
#[derive(Parse, ParseInline)]
struct InnerHeader<K> {
    // Tag of the encrypted value's type; compared against `T::HASH` in `with`.
    tags: Hash,
    // Resolver for the value's points.
    resolve: RawResolve<K>,
}
89
90impl<K: Key> InnerHeader<K> {
91 fn with<T: Topological + Tagged>(self, decrypted: T) -> object_rainbow::Result<Inner<K, T>> {
92 if self.tags != T::HASH {
93 return Err(object_rainbow::error_consistency!("tags mismatch"));
94 }
95 let mut topology = TopoVec::new();
96 let mut v = RawVisit {
97 at: 0,
98 resolve: &self.resolve,
99 visitor: &mut topology,
100 };
101 decrypted.traverse(&mut v);
102 v.done()?;
103 let topology = Arc::new(Lp(topology));
104 let decrypted = Arc::new(decrypted);
105 Ok(Inner {
106 tags: self.tags,
107 key: self.resolve.key,
108 topology,
109 decrypted,
110 })
111 }
112}
113
/// In-memory state of an [`Encrypted`] value.
///
/// NOTE(review): `ToOutput` is derived, so this field order presumably defines the
/// plaintext layout that `InnerHeader` parses back — do not reorder.
#[derive(ToOutput)]
struct Inner<K, T> {
    // `T::HASH` of the wrapped type.
    tags: Hash,
    // Encryption key, wrapped in `Extras` (its serialization is not visible here).
    key: Extras<K>,
    // One encrypted point per point of `decrypted`, in traversal order.
    topology: Arc<Lp<TopoVec>>,
    // The plaintext value.
    decrypted: Arc<T>,
}
121
/// `PointVisitor` adapter used by `InnerHeader::with`: pairs the `at`-th visited
/// point with `resolve.addresses[at]` and forwards a `RawFetch` for it to `visitor`.
struct RawVisit<'a, K, V> {
    // Number of points visited so far (also the next address-table index).
    at: usize,
    // Header resolver supplying the key, the underlying store and the address table.
    resolve: &'a RawResolve<K>,
    // Downstream visitor receiving the `RawFetch` points.
    visitor: &'a mut V,
}
127
128impl<'a, K, V> RawVisit<'a, K, V> {
129 fn done(self) -> object_rainbow::Result<()> {
130 if self.at != self.resolve.addresses.len() {
131 Err(Error::AddressOutOfBounds)
132 } else {
133 Ok(())
134 }
135 }
136}
137
/// Fetcher for an encrypted point whose ciphertext lives at `address` in `resolve`
/// and whose plaintext counterpart is available through `decrypted`.
#[derive(Clone)]
struct RawFetch<K, D> {
    // Key used to decrypt the ciphertext's header.
    key: K,
    // Underlying store holding the ciphertext.
    resolve: Arc<dyn Resolve>,
    // Already-translated address of the ciphertext in `resolve`.
    address: Address,
    // Fetcher for the plaintext value.
    decrypted: D,
}
145
146impl<K, D> FetchBytes for RawFetch<K, D> {
147 fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
148 self.resolve.resolve(self.address, &self.resolve)
149 }
150
151 fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
152 self.resolve.resolve_data(self.address)
153 }
154
155 fn fetch_bytes_local(&self) -> object_rainbow::Result<Option<ByteNode>> {
156 self.resolve.try_resolve_local(self.address, &self.resolve)
157 }
158}
159
160impl<K: Key, D: Fetch<T: Topological + Tagged>> Fetch for RawFetch<K, D> {
161 type T = Encrypted<K, D::T>;
162
163 fn fetch_full(&'_ self) -> FailFuture<'_, Node<Self::T>> {
164 Box::pin(async move {
165 let (encrypted, resolve) = self.resolve.resolve(self.address, &self.resolve).await?;
166 let (header, _) = side_parse(&self.key, &encrypted, &resolve)?;
167 let decrypted = self.decrypted.fetch().await?;
168 let inner = header.with(decrypted)?;
169 Ok((Encrypted { inner }, resolve))
170 })
171 }
172
173 fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
174 Box::pin(async move {
175 let (encrypted, resolve) = self.resolve.resolve(self.address, &self.resolve).await?;
176 let (header, _) = side_parse(&self.key, &encrypted, &resolve)?;
177 let decrypted = self.decrypted.fetch().await?;
178 let inner = header.with(decrypted)?;
179 Ok(Encrypted { inner })
180 })
181 }
182
183 fn try_fetch_local(&self) -> object_rainbow::Result<Option<Node<Self::T>>> {
184 let Some((encrypted, resolve)) = self
185 .resolve
186 .try_resolve_local(self.address, &self.resolve)?
187 else {
188 return Ok(None);
189 };
190 let (header, _) = side_parse(&self.key, &encrypted, &resolve)?;
191 let Some((decrypted, _)) = self.decrypted.try_fetch_local()? else {
192 return Ok(None);
193 };
194 let inner = header.with(decrypted)?;
195 Ok(Some((Encrypted { inner }, resolve)))
196 }
197}
198
199impl<K: Send + Sync, D: Send + Sync> Singular for RawFetch<K, D> {
200 fn hash(&self) -> Hash {
201 self.address.hash
202 }
203}
204
205impl<'a, K: Key, V: PointVisitor> PointVisitor for RawVisit<'a, K, V> {
206 fn visit<T: Traversible>(&mut self, point: &(impl 'static + SingularFetch<T = T> + Clone)) {
207 let at = self.at;
208 self.at += 1;
209 if let Some(address) = self.resolve.addresses.get(at).copied() {
210 let key = self.resolve.key.0.clone();
211 let resolve = self.resolve.resolve.clone();
212 let decrypted = point.clone();
213 self.visitor.visit(&RawFetch {
214 key,
215 resolve,
216 address,
217 decrypted,
218 });
219 }
220 }
221}
222
223impl<
224 K: Key,
225 T: Object<Extra>,
226 Extra: 'static + Send + Sync + Clone,
227 I: PointInput<Extra = (K, Extra)>,
228> Parse<I> for Inner<K, T>
229{
230 fn parse(mut input: I) -> object_rainbow::Result<Self> {
231 let header = input
232 .parse_inline::<MappedExtra<InnerHeader<K>, Extra0>>()?
233 .1;
234 let extra = input.extra().1.clone();
235 let decrypted = T::parse_slice_extra(
236 &input.parse_all()?,
237 &(Arc::new(header.resolve.clone()) as _),
238 &extra,
239 )?;
240 header.with(decrypted)
241 }
242}
243
244impl<K: Clone, T> Clone for Inner<K, T> {
245 fn clone(&self) -> Self {
246 Self {
247 tags: self.tags,
248 key: self.key.clone(),
249 topology: self.topology.clone(),
250 decrypted: self.decrypted.clone(),
251 }
252 }
253}
254
/// Fetcher that presents an existing encrypted point (`encrypted`) as a typed point
/// to `Encrypted<K, D::T>`. It decrypts the header with `key` and takes the
/// plaintext value from `decrypted`.
#[derive(Clone)]
struct InnerFetch<K, D> {
    // Key used to decrypt the encrypted point's header.
    key: K,
    // The stored encrypted point; supplies bytes and hash.
    encrypted: Arc<dyn Singular>,
    // Fetcher for the plaintext value.
    decrypted: D,
}
261
262impl<K, D> FetchBytes for InnerFetch<K, D> {
263 fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
264 self.encrypted.fetch_bytes()
265 }
266
267 fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
268 self.encrypted.fetch_data()
269 }
270
271 fn fetch_bytes_local(&self) -> object_rainbow::Result<Option<ByteNode>> {
272 self.encrypted.fetch_bytes_local()
273 }
274
275 fn fetch_data_local(&self) -> Option<Vec<u8>> {
276 self.encrypted.fetch_data_local()
277 }
278}
279
280impl<K: Key, D: Fetch<T: Topological + Tagged>> Fetch for InnerFetch<K, D> {
281 type T = Encrypted<K, D::T>;
282
283 fn fetch_full(&'_ self) -> FailFuture<'_, Node<Self::T>> {
284 Box::pin(async move {
285 let (encrypted, resolve) = self.encrypted.fetch_bytes().await?;
286 let (header, _) = side_parse(&self.key, &encrypted, &resolve)?;
287 let decrypted = self.decrypted.fetch().await?;
288 let inner = header.with(decrypted)?;
289 Ok((Encrypted { inner }, resolve))
290 })
291 }
292
293 fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
294 Box::pin(async move {
295 let (encrypted, resolve) = self.encrypted.fetch_bytes().await?;
296 let (header, _) = side_parse(&self.key, &encrypted, &resolve)?;
297 let decrypted = self.decrypted.fetch().await?;
298 let inner = header.with(decrypted)?;
299 Ok(Encrypted { inner })
300 })
301 }
302
303 fn try_fetch_local(&self) -> object_rainbow::Result<Option<Node<Self::T>>> {
304 let Some((encrypted, resolve)) = self.encrypted.fetch_bytes_local()? else {
305 return Ok(None);
306 };
307 let (header, _) = side_parse(&self.key, &encrypted, &resolve)?;
308 let Some((decrypted, _)) = self.decrypted.try_fetch_local()? else {
309 return Ok(None);
310 };
311 let inner = header.with(decrypted)?;
312 Ok(Some((Encrypted { inner }, resolve)))
313 }
314}
315
316impl<K: Send + Sync, D: Send + Sync> Singular for InnerFetch<K, D> {
317 fn hash(&self) -> Hash {
318 self.encrypted.hash()
319 }
320}
321
/// `PointVisitor` adapter used by `Inner::traverse`. It pairs each point of the
/// decrypted value with the next entry of the stored encrypted topology and
/// forwards the combined `InnerFetch`-backed point to `visitor`.
struct IterateResolution<'a, 'r, K, V> {
    // Key given to every produced `InnerFetch`.
    key: &'a K,
    // Remaining stored encrypted points, consumed one per visited point.
    topology: &'r mut std::slice::Iter<'a, Arc<dyn Singular>>,
    // Downstream visitor.
    visitor: &'a mut V,
}
327
328impl<'a, K: Key, V: PointVisitor> PointVisitor for IterateResolution<'a, '_, K, V> {
329 fn visit<T: Traversible>(&mut self, decrypted: &(impl 'static + SingularFetch<T = T> + Clone)) {
330 let decrypted = decrypted.clone();
331 let encrypted = self.topology.next().expect("length mismatch").clone();
332 let point = Point::from_fetch(
333 encrypted.hash(),
334 InnerFetch {
335 key: self.key.clone(),
336 decrypted,
337 encrypted,
338 }
339 .into_dyn_fetch(),
340 );
341 self.visitor.visit(&point);
342 }
343}
344
345impl<K, T> ListHashes for Inner<K, T> {
346 fn list_hashes(&self, f: &mut impl FnMut(Hash)) {
347 self.topology.list_hashes(f);
348 }
349
350 fn topology_hash(&self) -> Hash {
351 self.topology.0.data_hash()
352 }
353
354 fn point_count(&self) -> usize {
355 self.topology.len()
356 }
357}
358
359impl<K: Key, T: Topological> Topological for Inner<K, T> {
360 fn traverse(&self, visitor: &mut impl PointVisitor) {
361 let topology = &mut self.topology.iter();
362 self.decrypted.traverse(&mut IterateResolution {
363 key: &self.key.0,
364 topology,
365 visitor,
366 });
367 assert!(topology.next().is_none());
368 }
369
370 fn topology(&self) -> TopoVec {
371 self.topology.0.clone()
372 }
373}
374
/// A value of type `T` whose serialized form is encrypted with a key of type `K`.
///
/// The plaintext is kept decrypted in memory and is reachable through `Deref` and
/// `into_inner`. Writing it (`ToOutput`) emits `K::encrypt` of the serialized inner
/// state. Parsing it (`Parse`) decrypts the input with the key taken from the
/// parse extra (see `EncryptedExtra`). Values are built with [`encrypt`] or
/// reached through points returned by [`encrypt_point`].
pub struct Encrypted<K, T> {
    // Decrypted value plus the encrypted topology of its points.
    inner: Inner<K, T>,
}
378
379impl<K, T: Clone> Encrypted<K, T> {
380 pub fn into_inner(self) -> T {
381 Arc::unwrap_or_clone(self.inner.decrypted)
382 }
383}
384
385impl<K, T> Deref for Encrypted<K, T> {
386 type Target = T;
387
388 fn deref(&self) -> &Self::Target {
389 self.inner.decrypted.as_ref()
390 }
391}
392
393impl<K: Clone, T> Clone for Encrypted<K, T> {
394 fn clone(&self) -> Self {
395 Self {
396 inner: self.inner.clone(),
397 }
398 }
399}
400
401impl<K, T> ListHashes for Encrypted<K, T> {
402 fn list_hashes(&self, f: &mut impl FnMut(Hash)) {
403 self.inner.list_hashes(f);
404 }
405
406 fn topology_hash(&self) -> Hash {
407 self.inner.topology_hash()
408 }
409
410 fn point_count(&self) -> usize {
411 self.inner.point_count()
412 }
413}
414
415impl<K: Key, T: Topological> Topological for Encrypted<K, T> {
416 fn traverse(&self, visitor: &mut impl PointVisitor) {
417 self.inner.traverse(visitor);
418 }
419}
420
421impl<K: Key, T: ToOutput> ToOutput for Encrypted<K, T> {
422 fn to_output(&self, output: &mut impl object_rainbow::Output) {
423 if output.is_mangling() {
424 output.write(
425 &self
426 .inner
427 .key
428 .encrypt(b"this encrypted constant is followed by an unencrypted inner hash"),
429 );
430 self.inner.decrypted.data_hash();
431 }
432 if output.is_real() {
433 let source = self.inner.vec();
434 output.write(&self.inner.key.encrypt(&source));
435 }
436 }
437}
438
/// Parse-time extra from which an [`Encrypted`] value obtains its key.
///
/// Implemented both for a bare key `K` (no further extra) and for `(K, Extra)`
/// pairs, so `Encrypted<K, T>` can be parsed with either.
trait EncryptedExtra<K>: 'static + Send + Sync + Clone {
    /// Extra passed on to the inner value's parser.
    type Extra: 'static + Send + Sync + Clone;
    /// Splits the extra into the key and the inner value's extra.
    fn parts(&self) -> (K, Self::Extra);
}
443
444impl<K: 'static + Send + Sync + Clone, Extra: 'static + Send + Sync + Clone> EncryptedExtra<K>
445 for (K, Extra)
446{
447 type Extra = Extra;
448
449 fn parts(&self) -> (K, Self::Extra) {
450 self.clone()
451 }
452}
453
454impl<K: 'static + Send + Sync + Clone> EncryptedExtra<K> for K {
455 type Extra = ();
456
457 fn parts(&self) -> (K, Self::Extra) {
458 (self.clone(), ())
459 }
460}
461
462impl<
463 K: Key,
464 T: Object<Extra>,
465 Extra: 'static + Send + Sync + Clone,
466 I: PointInput<Extra: EncryptedExtra<K, Extra = Extra>>,
467> Parse<I> for Encrypted<K, T>
468{
469 fn parse(input: I) -> object_rainbow::Result<Self> {
470 let with_key = input.extra().parts();
471 let resolve = input.resolve().clone();
472 let source = with_key
473 .0
474 .decrypt(&input.parse_all()?)
475 .map_err(object_rainbow::Error::consistency)?;
476 let inner = Inner::<K, T>::parse_slice_extra(&source, &resolve, &with_key)?;
477 Ok(Self { inner })
478 }
479}
480
// Relies entirely on `Tagged`'s default items (presumably its default `HASH`).
// NOTE(review): confirm `Encrypted` is meant to use the default tag rather than one
// derived from `K` or `T`.
impl<K, T> Tagged for Encrypted<K, T> {}
482
/// Pending per-point encryption jobs queued by `ExtractResolution`. Each job
/// resolves to the encrypted point that replaces one point of the plaintext value;
/// `encrypt` awaits them all to build the stored topology.
type Extracted =
    Vec<std::pin::Pin<Box<dyn Future<Output = Result<Arc<dyn Singular>, Error>> + Send + 'static>>>;
485
/// `PointVisitor` that queues, for every visited point, an [`encrypt_point`] job
/// into `extracted` using `key`.
struct ExtractResolution<'a, K> {
    // Output list of pending jobs, appended in visitation order.
    extracted: &'a mut Extracted,
    // Key every point is encrypted with.
    key: &'a K,
}
490
491impl<K: Key> PointVisitor for ExtractResolution<'_, K> {
492 fn visit<T: Traversible>(&mut self, decrypted: &(impl 'static + SingularFetch<T = T> + Clone)) {
493 let decrypted = decrypted.clone();
494 let key = self.key.clone();
495 self.extracted.push(Box::pin(async move {
496 Ok(Arc::new(encrypt_point(key, decrypted).await?) as _)
497 }));
498 }
499}
500
501pub async fn encrypt_point<K: Key, T: Traversible>(
502 key: K,
503 decrypted: impl 'static + SingularFetch<T = T>,
504) -> object_rainbow::Result<Point<Encrypted<K, T>>> {
505 if let Some((address, resolve)) = decrypted.extract_resolve::<RawResolve<K>>()
506 && resolve.key.0 == key
507 {
508 let address = resolve.translate(*address)?;
509 let point = Point::from_fetch(
510 address.hash,
511 RawFetch {
512 key,
513 resolve: resolve.resolve.clone(),
514 address,
515 decrypted,
516 }
517 .into_dyn_fetch(),
518 );
519 return Ok(point);
520 }
521 let decrypted = decrypted.fetch().await?;
522 let encrypted = encrypt(key, decrypted).await?;
523 let point = encrypted.point();
524 Ok(point)
525}
526
527pub async fn encrypt<K: Key, T: Traversible>(
528 key: K,
529 decrypted: T,
530) -> object_rainbow::Result<Encrypted<K, T>> {
531 let mut futures = Vec::with_capacity(decrypted.point_count());
532 decrypted.traverse(&mut ExtractResolution {
533 extracted: &mut futures,
534 key: &key,
535 });
536 let topology = futures_util::future::try_join_all(futures).await?;
537 let topology = Arc::new(Lp(topology));
538 let decrypted = Arc::new(decrypted);
539 let inner = Inner {
540 tags: T::HASH,
541 key: Extras(key),
542 topology,
543 decrypted,
544 };
545 Ok(Encrypted { inner })
546}