1use std::{ops::Deref, sync::Arc};
2
3use object_rainbow::{
4 Address, ByteNode, Error, FailFuture, Fetch, FetchBytes, Hash, Object, Parse, ParseSliceExtra,
5 Point, PointInput, PointVisitor, Resolve, Singular, Tagged, ToOutput, Topological, Traversible,
6 length_prefixed::Lp,
7};
8
/// Parsing context that pairs a decryption key with the caller's own extra data.
///
/// [`Encrypted`]'s parser uses `key` to decrypt its input and forwards `extra`
/// unchanged to the parser of the wrapped (decrypted) value.
#[derive(Clone)]
pub struct WithKey<K, Extra> {
    /// Key used to decrypt the serialized bytes; the parsed value keeps it for re-encryption.
    pub key: K,
    /// Extra data handed through to the wrapped value's parser.
    pub extra: Extra,
}
14
15pub trait Key: 'static + Sized + Send + Sync + Clone {
16 fn encrypt(&self, data: &[u8]) -> Vec<u8>;
17 fn decrypt(&self, data: &[u8]) -> object_rainbow::Result<Vec<u8>>;
18}
19
20type Resolution<K> = Arc<Lp<Vec<Point<Encrypted<K, Vec<u8>>>>>>;
21
22#[derive(ToOutput, Clone)]
23struct Unkeyed<T>(T);
24
25impl<
26 T: Parse<I::WithExtra<Extra>>,
27 K: 'static + Clone,
28 Extra: 'static + Clone,
29 I: PointInput<Extra = WithKey<K, Extra>>,
30> Parse<I> for Unkeyed<T>
31{
32 fn parse(input: I) -> object_rainbow::Result<Self> {
33 Ok(Self(T::parse(
34 input.map_extra(|WithKey { extra, .. }| extra),
35 )?))
36 }
37}
38
39#[derive(ToOutput, Parse)]
40struct EncryptedInner<K, T> {
41 resolution: Resolution<K>,
42 decrypted: Unkeyed<Arc<T>>,
43}
44
45impl<K, T> Clone for EncryptedInner<K, T> {
46 fn clone(&self) -> Self {
47 Self {
48 resolution: self.resolution.clone(),
49 decrypted: self.decrypted.clone(),
50 }
51 }
52}
53
54type ResolutionIter<'a, K> = std::slice::Iter<'a, Point<Encrypted<K, Vec<u8>>>>;
55
56struct IterateResolution<'a, K, V> {
57 resolution: ResolutionIter<'a, K>,
58 visitor: &'a mut V,
59}
60
61struct Visited<K, T> {
62 decrypted: Point<T>,
63 encrypted: Point<Encrypted<K, Vec<u8>>>,
64}
65
66impl<K, T> FetchBytes for Visited<K, T> {
67 fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
68 self.encrypted.fetch_bytes()
69 }
70
71 fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
72 self.encrypted.fetch_data()
73 }
74}
75
76impl<K: Key, T: Traversible> Fetch for Visited<K, T> {
77 type T = Encrypted<K, T>;
78
79 fn fetch_full(&'_ self) -> FailFuture<'_, (Self::T, Arc<dyn Resolve>)> {
80 todo!()
81 }
82
83 fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
84 Box::pin(async move {
85 let Encrypted {
86 key,
87 inner:
88 EncryptedInner {
89 resolution,
90 decrypted: _,
91 },
92 } = self.encrypted.fetch().await?;
93 let decrypted = self.decrypted.fetch().await?;
94 let decrypted = Unkeyed(Arc::new(decrypted));
95 Ok(Encrypted {
96 key,
97 inner: EncryptedInner {
98 resolution,
99 decrypted,
100 },
101 })
102 })
103 }
104}
105
106impl<'a, K: Key, V: PointVisitor> PointVisitor for IterateResolution<'a, K, V> {
107 fn visit<T: Traversible>(&mut self, decrypted: &Point<T>) {
108 let decrypted = decrypted.clone();
109 let encrypted = self.resolution.next().expect("length mismatch").clone();
110 let point = Point::from_origin(
111 encrypted.hash(),
112 Arc::new(Visited {
113 decrypted,
114 encrypted,
115 }),
116 );
117 self.visitor.visit(&point);
118 }
119}
120
121impl<K: Key, T: Topological> Topological for EncryptedInner<K, T> {
122 fn accept_points(&self, visitor: &mut impl PointVisitor) {
123 self.decrypted.0.accept_points(&mut IterateResolution {
124 resolution: self.resolution.iter(),
125 visitor,
126 });
127 }
128
129 fn point_count(&self) -> usize {
130 self.resolution.len()
131 }
132
133 fn topology_hash(&self) -> Hash {
134 self.resolution.0.data_hash()
135 }
136}
137
138pub struct Encrypted<K, T> {
139 key: K,
140 inner: EncryptedInner<K, T>,
141}
142
143impl<K, T: Clone> Encrypted<K, T> {
144 pub fn into_inner(self) -> T {
145 Arc::unwrap_or_clone(self.inner.decrypted.0)
146 }
147}
148
149impl<K, T> Deref for Encrypted<K, T> {
150 type Target = T;
151
152 fn deref(&self) -> &Self::Target {
153 self.inner.decrypted.0.as_ref()
154 }
155}
156
157impl<K: Clone, T> Clone for Encrypted<K, T> {
158 fn clone(&self) -> Self {
159 Self {
160 key: self.key.clone(),
161 inner: self.inner.clone(),
162 }
163 }
164}
165
166impl<K: Key, T: Topological> Topological for Encrypted<K, T> {
167 fn accept_points(&self, visitor: &mut impl PointVisitor) {
168 self.inner.accept_points(visitor);
169 }
170
171 fn topology_hash(&self) -> Hash {
172 self.inner.topology_hash()
173 }
174}
175
176impl<K: Key, T: ToOutput> ToOutput for Encrypted<K, T> {
177 fn to_output(&self, output: &mut dyn object_rainbow::Output) {
178 let source = self.inner.vec();
179 output.write(&self.key.encrypt(&source));
180 }
181}
182
183#[derive(Clone)]
184struct Decrypt<K> {
185 resolution: Resolution<K>,
186}
187
188impl<K: Key> Decrypt<K> {
189 async fn resolve_bytes(
190 &self,
191 address: Address,
192 ) -> object_rainbow::Result<(Vec<u8>, Resolution<K>)> {
193 let Encrypted {
194 key: _,
195 inner:
196 EncryptedInner {
197 resolution,
198 decrypted,
199 },
200 } = self
201 .resolution
202 .get(address.index)
203 .ok_or(Error::AddressOutOfBounds)?
204 .clone()
205 .fetch()
206 .await?;
207 let data = Arc::unwrap_or_clone(decrypted.0);
208 Ok((data, resolution))
209 }
210}
211
212impl<K: Key> Resolve for Decrypt<K> {
213 fn resolve(&'_ self, address: Address) -> FailFuture<'_, ByteNode> {
214 Box::pin(async move {
215 let (data, resolution) = self.resolve_bytes(address).await?;
216 Ok((data, Arc::new(Decrypt { resolution }) as _))
217 })
218 }
219
220 fn resolve_data(&'_ self, address: Address) -> FailFuture<'_, Vec<u8>> {
221 Box::pin(async move {
222 let (data, _) = self.resolve_bytes(address).await?;
223 Ok(data)
224 })
225 }
226
227 fn name(&self) -> &str {
228 "decrypt"
229 }
230}
231
232impl<
233 K: Key,
234 T: Object<Extra>,
235 Extra: 'static + Send + Sync + Clone,
236 I: PointInput<Extra = WithKey<K, Extra>>,
237> Parse<I> for Encrypted<K, T>
238{
239 fn parse(input: I) -> object_rainbow::Result<Self> {
240 let with_key = input.extra().clone();
241 let resolve = input.resolve().clone();
242 let source = with_key.key.decrypt(input.parse_all()?)?;
243 let EncryptedInner {
244 resolution,
245 decrypted,
246 } = EncryptedInner::<K, Vec<u8>>::parse_slice_extra(&source, &resolve, &with_key)?;
247 let decrypted = T::parse_slice_extra(
248 &decrypted.0,
249 &(Arc::new(Decrypt {
250 resolution: resolution.clone(),
251 }) as _),
252 &with_key.extra,
253 )?;
254 let decrypted = Unkeyed(Arc::new(decrypted));
255 let inner = EncryptedInner {
256 resolution,
257 decrypted,
258 };
259 Ok(Self {
260 key: with_key.key,
261 inner,
262 })
263 }
264}
265
266impl<K, T> Tagged for Encrypted<K, T> {}
267
268impl<K: Key, T: Object<Extra>, Extra: 'static + Send + Sync + Clone> Object<WithKey<K, Extra>>
269 for Encrypted<K, T>
270{
271}
272
273type Extracted<K> = Vec<
274 std::pin::Pin<
275 Box<dyn Future<Output = Result<Point<Encrypted<K, Vec<u8>>>, Error>> + Send + 'static>,
276 >,
277>;
278
279struct ExtractResolution<'a, K> {
280 extracted: &'a mut Extracted<K>,
281 key: &'a K,
282}
283
284struct Untyped<K, T> {
285 key: WithKey<K, ()>,
286 encrypted: Point<Encrypted<K, T>>,
287}
288
289impl<K, T> FetchBytes for Untyped<K, T> {
290 fn fetch_bytes(&'_ self) -> FailFuture<'_, ByteNode> {
291 self.encrypted.fetch_bytes()
292 }
293
294 fn fetch_data(&'_ self) -> FailFuture<'_, Vec<u8>> {
295 self.encrypted.fetch_data()
296 }
297}
298
299impl<K: Key, T> Fetch for Untyped<K, T> {
300 type T = Encrypted<K, Vec<u8>>;
301
302 fn fetch_full(&'_ self) -> FailFuture<'_, (Self::T, Arc<dyn Resolve>)> {
303 Box::pin(async move {
304 let (data, resolve) = self.fetch_bytes().await?;
305 let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
306 Ok((encrypted, resolve))
307 })
308 }
309
310 fn fetch(&'_ self) -> FailFuture<'_, Self::T> {
311 Box::pin(async move {
312 let (data, resolve) = self.fetch_bytes().await?;
313 let encrypted = Self::T::parse_slice_extra(&data, &resolve, &self.key)?;
314 Ok(encrypted)
315 })
316 }
317}
318
319impl<K: Key> PointVisitor for ExtractResolution<'_, K> {
320 fn visit<T: Traversible>(&mut self, decrypted: &Point<T>) {
321 let decrypted = decrypted.clone();
322 let key = self.key.clone();
323 self.extracted.push(Box::pin(async move {
324 let encrypted = encrypt_point(key.clone(), decrypted).await?;
325 let encrypted = Point::from_origin(
326 encrypted.hash(),
327 Arc::new(Untyped {
328 key: WithKey { key, extra: () },
329 encrypted,
330 }),
331 );
332 Ok(encrypted)
333 }));
334 }
335}
336
337pub async fn encrypt_point<K: Key, T: Traversible>(
338 key: K,
339 decrypted: Point<T>,
340) -> object_rainbow::Result<Point<Encrypted<K, T>>> {
341 if let Some((address, decrypt)) = decrypted.extract_resolve::<Decrypt<K>>() {
342 let encrypted = decrypt
343 .resolution
344 .get(address.index)
345 .ok_or(Error::AddressOutOfBounds)?
346 .clone();
347 let point = Point::from_origin(
348 encrypted.hash(),
349 Arc::new(Visited {
350 decrypted,
351 encrypted,
352 }),
353 );
354 return Ok(point);
355 }
356 let decrypted = decrypted.fetch().await?;
357 let encrypted = encrypt(key.clone(), decrypted).await?;
358 let point = encrypted.point();
359 Ok(point)
360}
361
362pub async fn encrypt<K: Key, T: Traversible>(
363 key: K,
364 decrypted: T,
365) -> object_rainbow::Result<Encrypted<K, T>> {
366 let mut futures = Vec::with_capacity(decrypted.point_count());
367 decrypted.accept_points(&mut ExtractResolution {
368 extracted: &mut futures,
369 key: &key,
370 });
371 let resolution = futures_util::future::try_join_all(futures).await?;
372 let resolution = Arc::new(Lp(resolution));
373 let decrypted = Unkeyed(Arc::new(decrypted));
374 let inner = EncryptedInner {
375 resolution,
376 decrypted,
377 };
378 Ok(Encrypted { key, inner })
379}