ps_hkey/async_store/mixed.rs

1use std::{
2    future::Future,
3    mem::{self, replace},
4    ops::Deref,
5    sync::Arc,
6};
7
8use futures::FutureExt;
9use parking_lot::RwLock;
10use ps_datachunk::{DataChunk, OwnedDataChunk, PsDataChunkError};
11use ps_hash::Hash;
12use ps_promise::{Promise, PromiseRejection};
13
14use crate::{store::combined::DynStore, AsyncStore, PsHkeyError, Store};
15
16pub trait DynAsyncStore: Send + Sync {
17    type Error: From<PsDataChunkError> + From<PsHkeyError> + PromiseRejection + Send + 'static;
18
19    fn get(&self, hash: Arc<Hash>) -> Promise<OwnedDataChunk, Self::Error>;
20    fn put_encrypted(&self, chunk: OwnedDataChunk) -> Promise<(), Self::Error>;
21}
22
23impl<T> DynAsyncStore for T
24where
25    T: AsyncStore + Send + Sync + 'static,
26{
27    type Error = T::Error;
28
29    fn get(&self, hash: Arc<Hash>) -> Promise<OwnedDataChunk, Self::Error> {
30        let store = self.clone();
31
32        Promise::new(async move {
33            let hash = hash;
34
35            Ok(AsyncStore::get(&store, &hash).await?.into_owned())
36        })
37    }
38
39    fn put_encrypted(&self, chunk: OwnedDataChunk) -> Promise<(), Self::Error> {
40        AsyncStore::put_encrypted(self, chunk)
41    }
42}
43
44#[derive(Default)]
45pub struct MixedStoreInner<E: MixedStoreError> {
46    pub async_stores: Vec<Box<dyn DynAsyncStore<Error = E>>>,
47    pub stores: Vec<Box<dyn DynStore<Error = E>>>,
48}
49
50#[derive(Clone, Default)]
51pub struct MixedStore<E: MixedStoreError, const WRITE_TO_ALL: bool> {
52    inner: Arc<RwLock<MixedStoreInner<E>>>,
53}
54
55impl<E: MixedStoreError, const WRITE_TO_ALL: bool> Deref for MixedStore<E, WRITE_TO_ALL> {
56    type Target = Arc<RwLock<MixedStoreInner<E>>>;
57
58    fn deref(&self) -> &Self::Target {
59        &self.inner
60    }
61}
62
63impl<E: MixedStoreError, const WRITE_TO_ALL: bool> MixedStore<E, WRITE_TO_ALL> {
64    /// Creates a `MixedStore` from a list of Stores.
65    #[must_use]
66    pub fn new<S, A, IS, IA>(stores: IS, async_stores: IA) -> Self
67    where
68        S: Store<Error = E> + Send + Sync + 'static,
69        A: AsyncStore<Error = E>,
70        IS: IntoIterator<Item = S>,
71        IA: IntoIterator<Item = A>,
72    {
73        Self {
74            inner: Arc::new(RwLock::new(MixedStoreInner {
75                async_stores: async_stores.into_iter().map(|s| Box::new(s) as _).collect(),
76                stores: stores.into_iter().map(|s| Box::new(s) as _).collect(),
77            })),
78        }
79    }
80
81    pub fn push_sync<S>(&mut self, store: S)
82    where
83        S: Store<Error = E> + Send + Sync + 'static,
84    {
85        self.write().stores.push(Box::new(store));
86    }
87
88    pub fn push_async<A>(&mut self, store: A)
89    where
90        A: AsyncStore<Error = E>,
91    {
92        self.write().async_stores.push(Box::new(store));
93    }
94
95    pub fn extend_sync<S, I>(&mut self, iter: I)
96    where
97        S: Store<Error = E> + Send + Sync + 'static,
98        I: IntoIterator<Item = S>,
99    {
100        self.write()
101            .stores
102            .extend(iter.into_iter().map(|s| Box::new(s) as _));
103    }
104
105    pub fn extend_async<A, I>(&mut self, iter: I)
106    where
107        A: AsyncStore<Error = E>,
108        I: IntoIterator<Item = A>,
109    {
110        self.write()
111            .async_stores
112            .extend(iter.into_iter().map(|s| Box::new(s) as _));
113    }
114
115    #[must_use]
116    pub fn write_to_all(self) -> MixedStore<E, true> {
117        MixedStore { inner: self.inner }
118    }
119
120    #[must_use]
121    pub fn write_to_one(self) -> MixedStore<E, false> {
122        MixedStore { inner: self.inner }
123    }
124
125    fn get_sync(&self, hash: &Hash) -> Result<OwnedDataChunk, E> {
126        let mut last_err = None;
127
128        for s in &self.read().stores {
129            match s.get(hash) {
130                Ok(chunk) => return Ok(chunk),
131                Err(err) => last_err = Some(err),
132            }
133        }
134
135        Err(last_err.unwrap_or_else(E::no_stores))
136    }
137
138    fn get_async(&self, hash: &Arc<Hash>) -> Promise<OwnedDataChunk, E> {
139        let mut last_err = E::no_stores();
140        let guard = self.read();
141
142        for s in &guard.stores {
143            match s.get(hash) {
144                Ok(chunk) => return Promise::Resolved(chunk),
145                Err(err) => last_err = err,
146            }
147        }
148
149        let promises: Vec<Promise<OwnedDataChunk, E>> = guard
150            .async_stores
151            .iter()
152            .map(|store| store.get(hash.clone()))
153            .collect();
154
155        drop(guard);
156
157        Promise::new(GetAsync { last_err, promises })
158    }
159}
160
161struct GetAsync<E: MixedStoreError> {
162    last_err: E,
163    promises: Vec<Promise<OwnedDataChunk, E>>,
164}
165
166impl<E: MixedStoreError> Future for GetAsync<E> {
167    type Output = Result<OwnedDataChunk, E>;
168
169    fn poll(
170        mut self: std::pin::Pin<&mut Self>,
171        cx: &mut std::task::Context<'_>,
172    ) -> std::task::Poll<Self::Output> {
173        use std::task::Poll::{Pending, Ready};
174
175        let queue = mem::take(&mut self.promises);
176
177        for promise in queue {
178            match promise {
179                Promise::Consumed => {}
180                Promise::Pending(mut future) => match future.poll_unpin(cx) {
181                    Pending => self.promises.push(Promise::Pending(future)),
182                    Ready(Ok(chunk)) => return Ready(Ok(chunk)),
183                    Ready(Err(err)) => self.last_err = err,
184                },
185                Promise::Rejected(err) => self.last_err = err,
186                Promise::Resolved(chunk) => return Ready(Ok(chunk)),
187            }
188        }
189
190        if self.promises.is_empty() {
191            return Ready(Err(replace(&mut self.last_err, E::already_consumed())));
192        }
193
194        Pending
195    }
196}
197
198impl<E: MixedStoreError> Store for MixedStore<E, true> {
199    type Chunk<'c> = OwnedDataChunk;
200    type Error = E;
201
202    fn get<'a>(&'a self, hash: &Hash) -> Result<Self::Chunk<'a>, Self::Error> {
203        self.get_sync(hash)
204    }
205
206    fn put_encrypted<C: DataChunk>(&self, chunk: C) -> Result<(), Self::Error> {
207        let guard = self.read();
208
209        if guard.stores.is_empty() {
210            return Err(E::no_stores());
211        }
212
213        for s in &guard.stores {
214            s.put_encrypted(chunk.borrow())?;
215        }
216
217        drop(guard);
218
219        Ok(())
220    }
221}
222
223impl<E: MixedStoreError> Store for MixedStore<E, false> {
224    type Chunk<'c> = OwnedDataChunk;
225    type Error = E;
226
227    fn get<'a>(&'a self, hash: &Hash) -> Result<Self::Chunk<'a>, Self::Error> {
228        self.get_sync(hash)
229    }
230
231    fn put_encrypted<C: DataChunk>(&self, chunk: C) -> Result<(), Self::Error> {
232        let mut last_err = E::no_stores();
233
234        for store in &self.read().stores {
235            match store.put_encrypted(chunk.borrow()) {
236                Ok(()) => return Ok(()),
237                Err(err) => last_err = err,
238            }
239        }
240
241        Err(last_err)
242    }
243}
244
245impl<E: MixedStoreError> AsyncStore for MixedStore<E, true> {
246    type Chunk = OwnedDataChunk;
247    type Error = E;
248
249    fn get(&self, hash: &Hash) -> Promise<Self::Chunk, Self::Error> {
250        let store = self.clone();
251        let hash = Arc::from(*hash);
252
253        Promise::new(async move { store.get_async(&hash).await })
254    }
255
256    fn put_encrypted<C: DataChunk>(&self, chunk: C) -> Promise<(), Self::Error> {
257        let this = self.clone();
258        let chunk = chunk.into_owned();
259
260        let guard = this.read();
261
262        if guard.stores.is_empty() && guard.async_stores.is_empty() {
263            return Promise::reject(E::no_stores());
264        }
265
266        let mut promises = Vec::new();
267
268        promises.extend(guard.stores.iter().map(|store| {
269            let chunk = chunk.clone();
270
271            match store.put_encrypted(chunk.borrow()) {
272                Ok(()) => Promise::resolve(()),
273                Err(err) => Promise::reject(err),
274            }
275        }));
276
277        promises.extend(
278            guard
279                .async_stores
280                .iter()
281                .map(|store| store.put_encrypted(chunk.clone())),
282        );
283
284        drop(guard);
285
286        Promise::all(promises).then(async |_| Ok(()))
287    }
288}
289
290impl<E: MixedStoreError> AsyncStore for MixedStore<E, false> {
291    type Chunk = OwnedDataChunk;
292    type Error = E;
293
294    fn get(&self, hash: &Hash) -> Promise<Self::Chunk, Self::Error> {
295        let store = self.clone();
296        let hash = Arc::from(*hash);
297
298        Promise::new(async move { store.get_async(&hash).await })
299    }
300
301    fn put_encrypted<C: DataChunk>(&self, chunk: C) -> Promise<(), Self::Error> {
302        let this = self.clone();
303        let chunk = chunk.into_owned();
304
305        let guard = this.read();
306
307        if guard.stores.is_empty() && guard.async_stores.is_empty() {
308            return Promise::reject(E::no_stores());
309        }
310
311        let mut last_err = None;
312
313        for store in &guard.stores {
314            match store.put_encrypted(chunk.borrow()) {
315                Ok(()) => return Promise::resolve(()),
316                Err(err) => last_err = Some(err),
317            }
318        }
319
320        if guard.async_stores.is_empty() {
321            return Promise::reject(last_err.unwrap_or_else(E::no_stores));
322        }
323
324        let promises: Vec<Promise<(), E>> = guard
325            .async_stores
326            .iter()
327            .map(|store| store.put_encrypted(chunk.clone()))
328            .collect();
329
330        drop(guard);
331
332        Promise::new(async move {
333            match Promise::any(promises).await {
334                Ok(()) => Ok(()),
335                Err(mut errors) => {
336                    let err = errors
337                        .pop()
338                        .or(last_err)
339                        .unwrap_or_else(E::already_consumed);
340
341                    Err(err)
342                }
343            }
344        })
345    }
346}
347
348pub trait MixedStoreError:
349    Clone + From<PsDataChunkError> + From<PsHkeyError> + PromiseRejection + Send + 'static
350{
351    fn no_stores() -> Self;
352}