remotecache/persistent/
client.rs

1//! A client for interacting with a cache server.
2
3use std::{
4    any::Any,
5    fs::{self, OpenOptions},
6    io::{Read, Write},
7    net::TcpListener,
8    path::{Path, PathBuf},
9    sync::{
10        mpsc::{channel, Receiver, RecvTimeoutError, Sender},
11        Arc, Mutex,
12    },
13    thread,
14    time::Duration,
15};
16
17use backoff::ExponentialBackoff;
18use serde::{de::DeserializeOwned, Deserialize, Serialize};
19use tokio::runtime::{Handle, Runtime};
20use tonic::transport::{Channel, Endpoint};
21
22use crate::{
23    error::{ArcResult, Error, Result},
24    rpc::{
25        local::{self, local_cache_client},
26        remote::{self, remote_cache_client},
27    },
28    run_generator, CacheHandle, Cacheable, CacheableWithState, GenerateFn, GenerateResultFn,
29    GenerateResultWithStateFn, GenerateWithStateFn, Namespace,
30};
31
32use super::server::Server;
33
/// Default timeout, in milliseconds, for establishing a connection to the cache server.
pub const CONNECTION_TIMEOUT_MS_DEFAULT: u64 = 1_000;

/// Default timeout, in milliseconds, for receiving a reply to a request sent to the cache server.
pub const REQUEST_TIMEOUT_MS_DEFAULT: u64 = 1_000;
39
40/// An enumeration of client kinds.
41///
42/// Each interacts with a different cache server API, depending on the desired functionality.
43#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
44pub enum ClientKind {
45    /// A client that shares a filesystem with the server.
46    ///
47    /// Enables storing data in the cache via the filesystem without sending large bytestreams over gRPC.
48    Local,
49    /// A client that does not share a filseystem with the server.
50    ///
51    /// Sends data to the cache server via gRPC.
52    Remote,
53}
54
55#[derive(Debug)]
56struct ClientInner {
57    kind: ClientKind,
58    url: String,
59    poll_backoff: ExponentialBackoff,
60    connection_timeout: Duration,
61    request_timeout: Duration,
62    handle: Handle,
63    // Only used to own a runtime created by the builder.
64    #[allow(dead_code)]
65    runtime: Option<Runtime>,
66}
67
68/// A gRPC cache client.
69///
70/// The semantics of the [`Client`] API are the same as those of the
71/// [`NamespaceCache`](crate::mem::NamespaceCache) API.
72#[derive(Debug, Clone)]
73pub struct Client {
74    inner: Arc<ClientInner>,
75}
76
77/// A builder for a [`Client`].
78#[derive(Default, Clone, Debug)]
79pub struct ClientBuilder {
80    kind: Option<ClientKind>,
81    url: Option<String>,
82    poll_backoff: Option<ExponentialBackoff>,
83    connection_timeout: Option<Duration>,
84    request_timeout: Option<Duration>,
85    handle: Option<Handle>,
86}
87
88struct GenerateState<K, V> {
89    handle: CacheHandle<V>,
90    namespace: Namespace,
91    hash: Vec<u8>,
92    key: K,
93}
94
95/// Sends a heartbeat RPC to the server.
96trait HeartbeatFn: Fn(&Client) -> Result<()> + Send + Any {}
97impl<T: Fn(&Client) -> Result<()> + Send + Any> HeartbeatFn for T {}
98
99/// Writes a generated value to the given `String` path, using the provided assignment ID `u64` to
100/// notify the cache server once completed.
101trait LocalWriteValueFn<V>:
102    FnOnce(&Client, u64, String, &ArcResult<V>) -> Result<()> + Send + Any
103{
104}
105impl<V, T: FnOnce(&Client, u64, String, &ArcResult<V>) -> Result<()> + Send + Any>
106    LocalWriteValueFn<V> for T
107{
108}
109
110/// Writes a generated value to the cache server, using the provided assignment ID `u64` to
111/// tell the cache server which task completed.
112trait RemoteWriteValueFn<V>: FnOnce(&Client, u64, &ArcResult<V>) -> Result<()> + Send + Any {}
113impl<V, T: FnOnce(&Client, u64, &ArcResult<V>) -> Result<()> + Send + Any> RemoteWriteValueFn<V>
114    for T
115{
116}
117
118/// Deserializes desired value from bytes stored in the cache. If `V` is a result, would need to
119/// wrap the bytes from the cache with an `Ok` since `Err` results are not stored in the cache.
120trait DeserializeValueFn<V>: FnOnce(&[u8]) -> Result<V> + Send + Any {}
121impl<V, T: FnOnce(&[u8]) -> Result<V> + Send + Any> DeserializeValueFn<V> for T {}
122
123impl ClientBuilder {
124    /// Creates a new [`ClientBuilder`].
125    pub fn new() -> Self {
126        Self::default()
127    }
128
129    /// Sets the configured server URL.
130    pub fn url(&mut self, url: impl Into<String>) -> &mut Self {
131        self.url = Some(url.into());
132        self
133    }
134
135    /// Sets the configured client type.
136    pub fn kind(&mut self, kind: ClientKind) -> &mut Self {
137        self.kind = Some(kind);
138        self
139    }
140    /// Creates a new [`ClientBuilder`] with configured client type [`ClientKind::Local`] and a
141    /// server URL `url`.
142    pub fn local(url: impl Into<String>) -> Self {
143        let mut builder = Self::new();
144        builder.kind(ClientKind::Local).url(url);
145        builder
146    }
147
148    /// Creates a new [`ClientBuilder`] with configured client type [`ClientKind::Remote`] and a
149    /// server URL `url`.
150    pub fn remote(url: impl Into<String>) -> Self {
151        let mut builder = Self::new();
152        builder.kind(ClientKind::Remote).url(url);
153        builder
154    }
155
156    /// Configures the exponential backoff used when polling the server for cache entry
157    /// statuses.
158    ///
159    /// Defaults to [`ExponentialBackoff::default`].
160    pub fn poll_backoff(&mut self, backoff: ExponentialBackoff) -> &mut Self {
161        self.poll_backoff = Some(backoff);
162        self
163    }
164
165    /// Sets the timeout for connecting to the server.
166    ///
167    /// Defaults to [`CONNECTION_TIMEOUT_MS_DEFAULT`].
168    pub fn connection_timeout(&mut self, timeout: Duration) -> &mut Self {
169        self.connection_timeout = Some(timeout);
170        self
171    }
172
173    /// Sets the timeout for receiving a reply from the server.
174    ///
175    /// Defaults to [`REQUEST_TIMEOUT_MS_DEFAULT`].
176    pub fn request_timeout(&mut self, timeout: Duration) -> &mut Self {
177        self.request_timeout = Some(timeout);
178        self
179    }
180
181    /// Configures a [`Handle`] for making asynchronous gRPC requests.
182    ///
183    /// If no handle is specified, starts a new [`tokio::runtime::Runtime`] upon building the
184    /// [`Client`] object.
185    pub fn runtime_handle(&mut self, handle: Handle) -> &mut Self {
186        self.handle = Some(handle);
187        self
188    }
189
190    /// Builds a [`Client`] object with the configured parameters.
191    pub fn build(&mut self) -> Client {
192        let (handle, runtime) = match self.handle.clone() {
193            Some(handle) => (handle, None),
194            None => {
195                let runtime = tokio::runtime::Builder::new_multi_thread()
196                    .worker_threads(1)
197                    .enable_all()
198                    .build()
199                    .unwrap();
200                (runtime.handle().clone(), Some(runtime))
201            }
202        };
203        Client {
204            inner: Arc::new(ClientInner {
205                kind: self.kind.expect("must specify client kind"),
206                url: self.url.clone().expect("must specify server URL"),
207                poll_backoff: self.poll_backoff.clone().unwrap_or_default(),
208                connection_timeout: self
209                    .connection_timeout
210                    .unwrap_or(Duration::from_millis(CONNECTION_TIMEOUT_MS_DEFAULT)),
211                request_timeout: self
212                    .request_timeout
213                    .unwrap_or(Duration::from_millis(REQUEST_TIMEOUT_MS_DEFAULT)),
214                handle,
215                runtime,
216            }),
217        }
218    }
219}
220
221impl Client {
222    /// Creates a new gRPC cache client for a server at `url` with default configuration values.
223    pub fn with_default_config(kind: ClientKind, url: impl Into<String>) -> Self {
224        Self::builder().kind(kind).url(url).build()
225    }
226
227    /// Creates a new gRPC cache client.
228    pub fn builder() -> ClientBuilder {
229        ClientBuilder::new()
230    }
231
232    /// Creates a new local gRPC cache client.
233    ///
234    /// See [`ClientKind`] for an explanation of the different kinds of clients.
235    pub fn local(url: impl Into<String>) -> ClientBuilder {
236        ClientBuilder::local(url)
237    }
238
239    /// Creates a new remote gRPC cache client.
240    ///
241    /// See [`ClientKind`] for an explanation of the different kinds of clients.
242    pub fn remote(url: impl Into<String>) -> ClientBuilder {
243        ClientBuilder::remote(url)
244    }
245
246    /// Ensures that a value corresponding to `key` is generated, using `generate_fn`
247    /// to generate it if it has not already been generated.
248    ///
249    /// Returns a handle to the value. If the value is not yet generated, it is generated
250    /// in the background.
251    ///
252    /// For more detailed examples, refer to
253    /// [`NamespaceCache::generate`](crate::mem::NamespaceCache::generate).
254    ///
255    /// # Examples
256    ///
257    /// ```
258    /// use cache::{persistent::client::{Client, ClientKind}, error::Error, Cacheable};
259    ///
260    /// let client = Client::with_default_config(ClientKind::Local, "http://127.0.0.1:28055");
261    /// # use cache::persistent::client::{setup_test, create_server_and_clients, ServerKind};
262    /// # let (root, _, runtime) = setup_test("persistent_client_Client_generate").unwrap();
263    /// # let (_, client, _) = create_server_and_clients(root, ServerKind::Local, runtime.handle());
264    ///
265    /// fn generate_fn(tuple: &(u64, u64)) -> u64 {
266    ///     tuple.0 + tuple.1
267    /// }
268    ///
269    /// let handle = client.generate("example.namespace", (5, 6), generate_fn);
270    /// assert_eq!(*handle.get(), 11);
271    /// ```
272    pub fn generate<
273        K: Serialize + Any + Send + Sync,
274        V: Serialize + DeserializeOwned + Send + Sync + Any,
275    >(
276        &self,
277        namespace: impl Into<Namespace>,
278        key: K,
279        generate_fn: impl GenerateFn<K, V>,
280    ) -> CacheHandle<V> {
281        let namespace = namespace.into();
282        let state = Client::setup_generate(namespace, key);
283        let handle = state.handle.clone();
284
285        match self.inner.kind {
286            ClientKind::Local => self.clone().generate_inner_local(state, generate_fn),
287            ClientKind::Remote => self.clone().generate_inner_remote(state, generate_fn),
288        }
289
290        handle
291    }
292
293    /// Ensures that a value corresponding to `key` is generated, using `generate_fn`
294    /// to generate it if it has not already been generated.
295    ///
296    /// Returns a handle to the value. If the value is not yet generated, it is generated
297    /// in the background.
298    ///
299    /// For more detailed examples, refer to
300    /// [`NamespaceCache::generate_with_state`](crate::mem::NamespaceCache::generate_with_state).
301    ///
302    /// # Examples
303    ///
304    /// ```
305    /// use std::sync::{Arc, Mutex};
306    /// use cache::{persistent::client::{Client, ClientKind}, error::Error, Cacheable};
307    ///
308    /// #[derive(Clone)]
309    /// pub struct Log(Arc<Mutex<Vec<(u64, u64)>>>);
310    ///
311    /// let client = Client::with_default_config(ClientKind::Local, "http://127.0.0.1:28055");
312    /// let log = Log(Arc::new(Mutex::new(Vec::new())));
313    /// # use cache::persistent::client::{setup_test, create_server_and_clients, ServerKind};
314    /// # let (root, _, runtime) = setup_test("persistent_client_Client_generate_with_state").unwrap();
315    /// # let (_, client, _) = create_server_and_clients(root, ServerKind::Local, runtime.handle());
316    ///
317    /// fn generate_fn(tuple: &(u64, u64), state: Log) -> u64 {
318    ///     println!("Logging parameters...");
319    ///     state.0.lock().unwrap().push(*tuple);
320    ///     tuple.0 + tuple.1
321    /// }
322    ///
323    /// let handle = client.generate_with_state(
324    ///     "example.namespace", (5, 6), log.clone(), generate_fn
325    /// );
326    /// assert_eq!(*handle.get(), 11);
327    /// assert_eq!(log.0.lock().unwrap().clone(), vec![(5, 6)]);
328    /// ```
329    pub fn generate_with_state<
330        K: Serialize + Send + Sync + Any,
331        V: Serialize + DeserializeOwned + Send + Sync + Any,
332        S: Send + Sync + Any,
333    >(
334        &self,
335        namespace: impl Into<Namespace>,
336        key: K,
337        state: S,
338        generate_fn: impl GenerateWithStateFn<K, S, V>,
339    ) -> CacheHandle<V> {
340        let namespace = namespace.into();
341        self.generate(namespace, key, move |k| generate_fn(k, state))
342    }
343
344    /// Ensures that a result corresponding to `key` is generated, using `generate_fn`
345    /// to generate it if it has not already been generated.
346    ///
347    /// Does not cache on failure as errors are not constrained to be serializable/deserializable.
348    /// As such, failures should happen quickly, or should be serializable and stored as part of
349    /// cached value using [`Client::generate`].
350    ///
351    /// Returns a handle to the value. If the value is not yet generated, it is generated in the
352    /// background.
353    ///
354    /// For more detailed examples, refer to
355    /// [`NamespaceCache::generate_result`](crate::mem::NamespaceCache::generate_result).
356    ///
357    /// # Examples
358    ///
359    /// ```
360    /// use cache::{persistent::client::{Client, ClientKind}, error::Error, Cacheable};
361    ///
362    /// let client = Client::with_default_config(ClientKind::Local, "http://127.0.0.1:28055");
363    /// # use cache::persistent::client::{setup_test, create_server_and_clients, ServerKind};
364    /// # let (root, _, runtime) = setup_test("persistent_client_Client_generate_result").unwrap();
365    /// # let (_, client, _) = create_server_and_clients(root, ServerKind::Local, runtime.handle());
366    ///
367    /// fn generate_fn(tuple: &(u64, u64)) -> anyhow::Result<u64> {
368    ///     if *tuple == (5, 5) {
369    ///         Err(anyhow::anyhow!("invalid tuple"))
370    ///     } else {
371    ///         Ok(tuple.0 + tuple.1)
372    ///     }
373    /// }
374    ///
375    /// let handle = client.generate_result("example.namespace", (5, 5), generate_fn);
376    /// assert_eq!(format!("{}", handle.unwrap_err_inner().root_cause()), "invalid tuple");
377    /// ```
378    pub fn generate_result<
379        K: Serialize + Any + Send + Sync,
380        V: Serialize + DeserializeOwned + Send + Sync + Any,
381        E: Send + Sync + Any,
382    >(
383        &self,
384        namespace: impl Into<Namespace>,
385        key: K,
386        generate_fn: impl GenerateResultFn<K, V, E>,
387    ) -> CacheHandle<std::result::Result<V, E>> {
388        let namespace = namespace.into();
389        let state = Client::setup_generate(namespace, key);
390        let handle = state.handle.clone();
391
392        match self.inner.kind {
393            ClientKind::Local => {
394                self.clone().generate_result_inner_local(state, generate_fn);
395            }
396            ClientKind::Remote => {
397                self.clone()
398                    .generate_result_inner_remote(state, generate_fn);
399            }
400        }
401
402        handle
403    }
404
405    /// Ensures that a value corresponding to `key` is generated, using `generate_fn`
406    /// to generate it if it has not already been generated.
407    ///
408    /// Does not cache on failure as errors are not constrained to be serializable/deserializable.
409    /// As such, failures should happen quickly, or should be serializable and stored as part of
410    /// cached value using [`Client::generate_with_state`].
411    ///
412    /// Returns a handle to the value. If the value is not yet generated, it is generated
413    /// in the background.
414    ///
415    /// For more detailed examples, refer to
416    /// [`NamespaceCache::generate_result_with_state`](crate::mem::NamespaceCache::generate_result_with_state).
417    ///
418    /// # Examples
419    ///
420    /// ```
421    /// use std::sync::{Arc, Mutex};
422    /// use cache::{persistent::client::{Client, ClientKind}, error::Error, Cacheable};
423    ///
424    /// #[derive(Clone)]
425    /// pub struct Log(Arc<Mutex<Vec<(u64, u64)>>>);
426    ///
427    /// let client = Client::with_default_config(ClientKind::Local, "http://127.0.0.1:28055");
428    /// let log = Log(Arc::new(Mutex::new(Vec::new())));
429    /// # use cache::persistent::client::{setup_test, create_server_and_clients, ServerKind};
430    /// # let (root, _, runtime) = setup_test("persistent_client_Client_generate_result_with_state").unwrap();
431    /// # let (_, client, _) = create_server_and_clients(root, ServerKind::Local, runtime.handle());
432    ///
433    /// fn generate_fn(tuple: &(u64, u64), state: Log) -> anyhow::Result<u64> {
434    ///     println!("Logging parameters...");
435    ///     state.0.lock().unwrap().push(*tuple);
436    ///
437    ///     if *tuple == (5, 5) {
438    ///         Err(anyhow::anyhow!("invalid tuple"))
439    ///     } else {
440    ///         Ok(tuple.0 + tuple.1)
441    ///     }
442    /// }
443    ///
444    /// let handle = client.generate_result_with_state(
445    ///     "example.namespace", (5, 5), log.clone(), generate_fn,
446    /// );
447    /// assert_eq!(format!("{}", handle.unwrap_err_inner().root_cause()), "invalid tuple");
448    /// assert_eq!(log.0.lock().unwrap().clone(), vec![(5, 5)]);
449    /// ```
450    pub fn generate_result_with_state<
451        K: Serialize + Send + Sync + Any,
452        V: Serialize + DeserializeOwned + Send + Sync + Any,
453        E: Send + Sync + Any,
454        S: Send + Sync + Any,
455    >(
456        &self,
457        namespace: impl Into<Namespace>,
458        key: K,
459        state: S,
460        generate_fn: impl GenerateResultWithStateFn<K, S, V, E>,
461    ) -> CacheHandle<std::result::Result<V, E>> {
462        let namespace = namespace.into();
463        self.generate_result(namespace, key, move |k| generate_fn(k, state))
464    }
465
466    /// Gets a handle to a cacheable object from the cache, generating the object in the background
467    /// if needed.
468    ///
469    /// Does not cache errors, so any errors thrown should be thrown quickly. Any errors that need
470    /// to be cached should be included in the cached output or should be cached using
471    /// [`Client::get_with_err`].
472    ///
473    /// For more detailed examples, refer to
474    /// [`NamespaceCache::get`](crate::mem::NamespaceCache::get).
475    ///
476    /// # Examples
477    ///
478    /// ```
479    /// use cache::{persistent::client::{Client, ClientKind}, error::Error, Cacheable};
480    /// use serde::{Deserialize, Serialize};
481    ///
482    /// #[derive(Deserialize, Serialize, Hash, Eq, PartialEq)]
483    /// pub struct Params {
484    ///     param1: u64,
485    ///     param2: String,
486    /// };
487    ///
488    /// impl Cacheable for Params {
489    ///     type Output = u64;
490    ///     type Error = anyhow::Error;
491    ///
492    ///     fn generate(&self) -> anyhow::Result<u64> {
493    ///         Ok(2 * self.param1)
494    ///     }
495    /// }
496    ///
497    /// let client = Client::with_default_config(ClientKind::Local, "http://127.0.0.1:28055");
498    /// # use cache::persistent::client::{setup_test, create_server_and_clients, ServerKind};
499    /// # let (root, _, runtime) = setup_test("persistent_client_Client_get").unwrap();
500    /// # let (_, client, _) = create_server_and_clients(root, ServerKind::Local, runtime.handle());
501    ///
502    /// let handle = client.get(
503    ///     "example.namespace", Params { param1: 50, param2: "cache".to_string() }
504    /// );
505    /// assert_eq!(*handle.unwrap_inner(), 100);
506    /// ```
507    pub fn get<K: Cacheable>(
508        &self,
509        namespace: impl Into<Namespace>,
510        key: K,
511    ) -> CacheHandle<std::result::Result<K::Output, K::Error>> {
512        let namespace = namespace.into();
513        self.generate_result(namespace, key, |key| key.generate())
514    }
515
516    /// Gets a handle to a cacheable object from the cache, caching failures as well.
517    ///
518    /// Generates the object in the background if needed.
519    ///
520    /// For more detailed examples, refer to
521    /// [`NamespaceCache::get_with_err`](crate::mem::NamespaceCache::get_with_err).
522    ///
523    /// # Examples
524    ///
525    /// ```
526    /// use cache::{persistent::client::{Client, ClientKind}, error::Error, Cacheable};
527    /// use serde::{Deserialize, Serialize};
528    ///
529    /// #[derive(Deserialize, Serialize, Hash, Eq, PartialEq)]
530    /// pub struct Params {
531    ///     param1: u64,
532    ///     param2: String,
533    /// };
534    ///
535    /// impl Cacheable for Params {
536    ///     type Output = u64;
537    ///     type Error = String;
538    ///
539    ///     fn generate(&self) -> Result<Self::Output, Self::Error> {
540    ///         if self.param1 == 5 {
541    ///             Err("invalid param".to_string())
542    ///         } else {
543    ///             Ok(2 * self.param1)
544    ///         }
545    ///     }
546    /// }
547    ///
548    /// let client = Client::with_default_config(ClientKind::Local, "http://127.0.0.1:28055");
549    /// # use cache::persistent::client::{setup_test, create_server_and_clients, ServerKind};
550    /// # let (root, _, runtime) = setup_test("persistent_client_Client_get_with_err").unwrap();
551    /// # let (_, client, _) = create_server_and_clients(root, ServerKind::Local, runtime.handle());
552    ///
553    /// let handle = client.get_with_err(
554    ///     "example.namespace", Params { param1: 5, param2: "cache".to_string() }
555    /// );
556    /// assert_eq!(handle.unwrap_err_inner(), "invalid param");
557    /// ```
558    pub fn get_with_err<
559        E: Send + Sync + Serialize + DeserializeOwned + Any,
560        K: Cacheable<Error = E>,
561    >(
562        &self,
563        namespace: impl Into<Namespace>,
564        key: K,
565    ) -> CacheHandle<std::result::Result<K::Output, K::Error>> {
566        let namespace = namespace.into();
567        self.generate(namespace, key, |key| key.generate())
568    }
569
570    /// Gets a handle to a cacheable object from the cache, generating the object in the background
571    /// if needed.
572    ///
573    /// Does not cache errors, so any errors thrown should be thrown quickly. Any errors that need
574    /// to be cached should be included in the cached output or should be cached using
575    /// [`Client::get_with_state_and_err`].
576    ///
577    /// For more detailed examples, refer to
578    /// [`NamespaceCache::get_with_state`](crate::mem::NamespaceCache::get_with_state).
579    ///
580    /// # Examples
581    ///
582    /// ```
583    /// use std::sync::{Arc, Mutex};
584    /// use cache::{persistent::client::{Client, ClientKind}, error::Error, CacheableWithState};
585    /// use serde::{Deserialize, Serialize};
586    ///
587    /// #[derive(Debug, Deserialize, Serialize, Clone, Hash, Eq, PartialEq)]
588    /// pub struct Params(u64);
589    ///
590    /// #[derive(Clone)]
591    /// pub struct Log(Arc<Mutex<Vec<Params>>>);
592    ///
593    /// impl CacheableWithState<Log> for Params {
594    ///     type Output = u64;
595    ///     type Error = anyhow::Error;
596    ///
597    ///     fn generate_with_state(&self, state: Log) -> anyhow::Result<u64> {
598    ///         println!("Logging parameters...");
599    ///         state.0.lock().unwrap().push(self.clone());
600    ///         Ok(2 * self.0)
601    ///     }
602    /// }
603    ///
604    /// let client = Client::with_default_config(ClientKind::Local, "http://127.0.0.1:28055");
605    /// let log = Log(Arc::new(Mutex::new(Vec::new())));
606    /// # use cache::persistent::client::{setup_test, create_server_and_clients, ServerKind};
607    /// # let (root, _, runtime) = setup_test("persistent_client_Client_get_with_state").unwrap();
608    /// # let (_, client, _) = create_server_and_clients(root, ServerKind::Local, runtime.handle());
609    ///
610    /// let handle = client.get_with_state(
611    ///     "example.namespace",
612    ///     Params(0),
613    ///     log.clone(),
614    /// );
615    /// assert_eq!(*handle.unwrap_inner(), 0);
616    /// assert_eq!(log.0.lock().unwrap().clone(), vec![Params(0)]);
617    /// ```
618    pub fn get_with_state<S: Send + Sync + Any, K: CacheableWithState<S>>(
619        &self,
620        namespace: impl Into<Namespace>,
621        key: K,
622        state: S,
623    ) -> CacheHandle<std::result::Result<K::Output, K::Error>> {
624        let namespace = namespace.into();
625        self.generate_result_with_state(namespace, key, state, |key, state| {
626            key.generate_with_state(state)
627        })
628    }
629
630    /// Gets a handle to a cacheable object from the cache, caching failures as well.
631    ///
632    /// Generates the object in the background if needed.
633    ///
634    /// See [`Client::get_with_err`] and [`Client::get_with_state`] for related examples.
635    pub fn get_with_state_and_err<
636        S: Send + Sync + Any,
637        E: Send + Sync + Serialize + DeserializeOwned + Any,
638        K: CacheableWithState<S, Error = E>,
639    >(
640        &self,
641        namespace: impl Into<Namespace>,
642        key: K,
643        state: S,
644    ) -> CacheHandle<std::result::Result<K::Output, K::Error>> {
645        let namespace = namespace.into();
646        self.generate_with_state(namespace, key, state, |key, state| {
647            key.generate_with_state(state)
648        })
649    }
650
651    /// Sets up the necessary objects to be passed in to [`Client::spawn_handler`].
652    fn setup_generate<K: Serialize, V>(namespace: Namespace, key: K) -> GenerateState<K, V> {
653        GenerateState {
654            handle: CacheHandle::empty(),
655            namespace,
656            hash: crate::hash(&flexbuffers::to_vec(&key).unwrap()),
657            key,
658        }
659    }
660
661    /// Spawns a new thread to generate the desired value asynchronously.
662    ///
663    /// If the provided handler returns a error, stores an [`Arc`]ed error in the handle.
664    fn spawn_handler<V: Send + Sync + Any>(
665        self,
666        handle: CacheHandle<V>,
667        handler: impl FnOnce() -> Result<()> + Send + Any,
668    ) {
669        thread::spawn(move || {
670            if let Err(e) = handler() {
671                tracing::error!("encountered error while executing handler: {}", e,);
672                handle.set(Err(Arc::new(e)));
673            }
674        });
675    }
676
677    /// Deserializes a cached value into a [`Result`] that can be stored in a [`CacheHandle`].
678    fn deserialize_cache_value<V: DeserializeOwned>(data: &[u8]) -> Result<V> {
679        let data = flexbuffers::from_slice(data)?;
680        Ok(data)
681    }
682
683    /// Deserializes a cached value into a containing result with the appropriate error type.
684    fn deserialize_cache_result<V: DeserializeOwned, E>(
685        data: &[u8],
686    ) -> Result<std::result::Result<V, E>> {
687        let data = flexbuffers::from_slice(data)?;
688        Ok(Ok(data))
689    }
690
691    /// Starts sending heartbeats to the server in a new thread .
692    ///
693    /// Returns a sender for telling the spawned thread to stop sending heartbeats and
694    /// a receiver for waiting for heartbeats to terminate.
695    fn start_heartbeats(
696        &self,
697        heartbeat_interval: Duration,
698        send_heartbeat: impl HeartbeatFn,
699    ) -> (Sender<()>, Receiver<()>) {
700        tracing::debug!("starting heartbeats");
701        let (s_heartbeat_stop, r_heartbeat_stop) = channel();
702        let (s_heartbeat_stopped, r_heartbeat_stopped) = channel();
703        let self_clone = self.clone();
704        thread::spawn(move || {
705            loop {
706                match r_heartbeat_stop.recv_timeout(heartbeat_interval) {
707                    Ok(_) | Err(RecvTimeoutError::Disconnected) => {
708                        break;
709                    }
710                    Err(RecvTimeoutError::Timeout) => {
711                        if send_heartbeat(&self_clone).is_err() {
712                            break;
713                        }
714                    }
715                }
716            }
717            let _ = s_heartbeat_stopped.send(());
718        });
719        (s_heartbeat_stop, r_heartbeat_stopped)
720    }
721
722    /// Converts a [`Result<(S, bool)>`] to a [`std::result::Result<S, backoff::Error<Error>>`].
723    ///
724    /// If the `retry` boolean is `true`, returns a [`backoff::Error::Transient`]. If the provided
725    /// result is [`Err`], returns a [`backoff::Error::Permanent`]. Otherwise, returns the entry
726    /// status of type `S`.
727    fn run_backoff_loop<S>(&self, get_status_fn: impl Fn() -> Result<(S, bool)>) -> Result<S> {
728        Ok(backoff::retry(self.inner.poll_backoff.clone(), move || {
729            tracing::debug!("attempting get request to retrieve entry status");
730            get_status_fn()
731                .map_err(backoff::Error::Permanent)
732                .and_then(|(status, retry)| {
733                    if retry {
734                        tracing::debug!("entry is loading, retrying later");
735                        Err(backoff::Error::transient(Error::EntryLoading))
736                    } else {
737                        tracing::debug!("entry status retrieved");
738                        Ok(status)
739                    }
740                })
741        })
742        .map_err(Box::new)?)
743    }
744
745    /// Handles an unassigned entry by generating it locally.
746    fn handle_unassigned<K: Send + Sync + Any, V: Send + Sync + Any>(
747        handle: CacheHandle<V>,
748        key: K,
749        generate_fn: impl GenerateFn<K, V>,
750    ) {
751        tracing::debug!("entry is unassigned, generating locally");
752        let v = run_generator(move || generate_fn(&key));
753        handle.set(v);
754    }
755
756    /// Handles an assigned entry by generating it locally and sending heartbeats periodically
757    /// while the generator is running.
758    fn handle_assigned<K: Send + Sync + Any, V: Send + Sync + Any>(
759        &self,
760        key: K,
761        generate_fn: impl GenerateFn<K, V>,
762        heartbeat_interval_ms: u64,
763        send_heartbeat: impl HeartbeatFn,
764    ) -> ArcResult<V> {
765        tracing::debug!("entry has been assigned to the client, generating locally");
766        let (s_heartbeat_stop, r_heartbeat_stopped) =
767            self.start_heartbeats(Duration::from_millis(heartbeat_interval_ms), send_heartbeat);
768        let v = run_generator(move || generate_fn(&key));
769        let _ = s_heartbeat_stop.send(());
770        let _ = r_heartbeat_stopped.recv();
771        tracing::debug!("finished generating, writing value to cache");
772        v
773    }
774
775    /// Connects to a local cache gRPC server.
776    async fn connect_local(&self) -> Result<local_cache_client::LocalCacheClient<Channel>> {
777        let endpoint = Endpoint::from_shared(self.inner.url.clone())?
778            .timeout(self.inner.request_timeout)
779            .connect_timeout(self.inner.connection_timeout);
780        let test = local_cache_client::LocalCacheClient::connect(endpoint).await;
781        Ok(test?)
782    }
783
784    /// Issues a `Get` RPC to a local cache gRPC server.
785    fn get_rpc_local(
786        &self,
787        namespace: String,
788        key: Vec<u8>,
789        assign: bool,
790    ) -> Result<local::get_reply::EntryStatus> {
791        let out: Result<local::GetReply> = self.inner.handle.block_on(async {
792            let mut client = self.connect_local().await?;
793            Ok(client
794                .get(local::GetRequest {
795                    namespace,
796                    key,
797                    assign,
798                })
799                .await?
800                .into_inner())
801        });
802        Ok(out?.entry_status.unwrap())
803    }
804
805    /// Issues a `Heartbeat` RPC to a local cache gRPC server.
806    fn heartbeat_rpc_local(&self, id: u64) -> Result<()> {
807        self.inner.handle.block_on(async {
808            let mut client = self.connect_local().await?;
809            client.heartbeat(local::HeartbeatRequest { id }).await?;
810            Ok(())
811        })
812    }
813
814    /// Issues a `Done` RPC to a local cache gRPC server.
815    fn done_rpc_local(&self, id: u64) -> Result<()> {
816        self.inner.handle.block_on(async {
817            let mut client = self.connect_local().await?;
818            client.done(local::DoneRequest { id }).await?;
819            Ok(())
820        })
821    }
822
823    /// Issues a `Drop` RPC to a local cache gRPC server.
824    fn drop_rpc_local(&self, id: u64) -> Result<()> {
825        self.inner.handle.block_on(async {
826            let mut client = self.connect_local().await?;
827            client.drop(local::DropRequest { id }).await?;
828            Ok(())
829        })
830    }
831
832    fn write_generated_data_to_disk<V: Serialize>(
833        &self,
834        id: u64,
835        path: String,
836        data: &V,
837    ) -> Result<()> {
838        let path = PathBuf::from(path);
839        if let Some(parent) = path.parent() {
840            fs::create_dir_all(parent)?;
841        }
842
843        let mut f = OpenOptions::new()
844            .read(true)
845            .write(true)
846            .create(true)
847            .open(&path)?;
848        f.write_all(&flexbuffers::to_vec(data).unwrap())?;
849        self.done_rpc_local(id)?;
850
851        Ok(())
852    }
853
854    /// Writes a generated value to a local cache via the `Set` RPC.
855    fn write_generated_value_local<V: Serialize>(
856        &self,
857        id: u64,
858        path: String,
859        value: &ArcResult<V>,
860    ) -> Result<()> {
861        if let Ok(data) = value {
862            self.write_generated_data_to_disk(id, path, data)?;
863        }
864        Ok(())
865    }
866
867    /// Writes data contained in a generated result to a local cache via the `Set` RPC.
868    ///
869    /// Does not write to the cache if the generated result is an [`Err`].
870    fn write_generated_result_local<V: Serialize, E>(
871        &self,
872        id: u64,
873        path: String,
874        value: &ArcResult<std::result::Result<V, E>>,
875    ) -> Result<()> {
876        if let Ok(Ok(data)) = value {
877            self.write_generated_data_to_disk(id, path, data)?;
878        }
879        Ok(())
880    }
881
882    /// Runs the generate loop for the local cache protocol, checking whether the desired entry is
883    /// loaded and generating it if needed.
884    fn generate_loop_local<K: Send + Sync + Any, V: Send + Sync + Any>(
885        &self,
886        state: GenerateState<K, V>,
887        generate_fn: impl GenerateFn<K, V>,
888        write_generated_value: impl LocalWriteValueFn<V>,
889        deserialize_cache_data: impl DeserializeValueFn<V>,
890    ) -> Result<()> {
891        let GenerateState {
892            handle,
893            namespace,
894            hash,
895            key,
896        } = state;
897
898        let status = self.run_backoff_loop(|| {
899            let status = self.get_rpc_local(namespace.clone().into_inner(), hash.clone(), true)?;
900            let retry = matches!(status, local::get_reply::EntryStatus::Loading(_));
901
902            Ok((status, retry))
903        })?;
904
905        match status {
906            local::get_reply::EntryStatus::Unassigned(_) => {
907                Client::handle_unassigned(handle, key, generate_fn);
908            }
909            local::get_reply::EntryStatus::Assign(local::AssignReply {
910                id,
911                path,
912                heartbeat_interval_ms,
913            }) => {
914                let v = self.handle_assigned(
915                    key,
916                    generate_fn,
917                    heartbeat_interval_ms,
918                    move |client| -> Result<()> { client.heartbeat_rpc_local(id) },
919                );
920                write_generated_value(self, id, path, &v)?;
921                handle.set(v);
922            }
923            local::get_reply::EntryStatus::Loading(_) => unreachable!(),
924            local::get_reply::EntryStatus::Ready(local::ReadyReply { id, path }) => {
925                tracing::debug!("entry is ready, reading from cache");
926                let mut file = std::fs::File::open(path)?;
927                let mut buf = Vec::new();
928                file.read_to_end(&mut buf)?;
929                self.drop_rpc_local(id)?;
930                tracing::debug!("finished reading entry from disk");
931                handle.set(Ok(deserialize_cache_data(&buf)?));
932            }
933        }
934        Ok(())
935    }
936
937    fn generate_inner_local<
938        K: Serialize + Any + Send + Sync,
939        V: Serialize + DeserializeOwned + Send + Sync + Any,
940    >(
941        self,
942        state: GenerateState<K, V>,
943        generate_fn: impl GenerateFn<K, V>,
944    ) {
945        tracing::debug!("generating using local cache API");
946        self.clone().spawn_handler(state.handle.clone(), move || {
947            self.generate_loop_local(
948                state,
949                generate_fn,
950                Client::write_generated_value_local,
951                Client::deserialize_cache_value,
952            )
953        });
954    }
955
956    fn generate_result_inner_local<
957        K: Serialize + Any + Send + Sync,
958        V: Serialize + DeserializeOwned + Send + Sync + Any,
959        E: Send + Sync + Any,
960    >(
961        self,
962        state: GenerateState<K, std::result::Result<V, E>>,
963        generate_fn: impl GenerateResultFn<K, V, E>,
964    ) {
965        self.clone().spawn_handler(state.handle.clone(), move || {
966            self.generate_loop_local(
967                state,
968                generate_fn,
969                Client::write_generated_result_local,
970                Client::deserialize_cache_result,
971            )
972        });
973    }
974
975    /// Connects to a remote cache gRPC server.
976    async fn connect_remote(&self) -> Result<remote_cache_client::RemoteCacheClient<Channel>> {
977        let endpoint = Endpoint::from_shared(self.inner.url.clone())?
978            .timeout(self.inner.request_timeout)
979            .connect_timeout(self.inner.connection_timeout);
980        Ok(remote_cache_client::RemoteCacheClient::connect(endpoint).await?)
981    }
982
983    /// Issues a `Get` RPC to a remote cache gRPC server.
984    fn get_rpc_remote(
985        &self,
986        namespace: String,
987        key: Vec<u8>,
988        assign: bool,
989    ) -> Result<remote::get_reply::EntryStatus> {
990        let out: Result<remote::GetReply> = self.inner.handle.block_on(async {
991            let mut client = self.connect_remote().await?;
992            Ok(client
993                .get(remote::GetRequest {
994                    namespace,
995                    key,
996                    assign,
997                })
998                .await?
999                .into_inner())
1000        });
1001        Ok(out?.entry_status.unwrap())
1002    }
1003
1004    /// Issues a `Heartbeat` RPC to a remote cache gRPC server.
1005    fn heartbeat_rpc_remote(&self, id: u64) -> Result<()> {
1006        self.inner.handle.block_on(async {
1007            let mut client = self.connect_remote().await?;
1008            client.heartbeat(remote::HeartbeatRequest { id }).await?;
1009            Ok(())
1010        })
1011    }
1012
1013    /// Issues a `Set` RPC to a remote cache gRPC server.
1014    fn set_rpc_remote(&self, id: u64, value: Vec<u8>) -> Result<()> {
1015        self.inner.handle.block_on(async {
1016            let mut client = self.connect_remote().await?;
1017            client.set(remote::SetRequest { id, value }).await?;
1018            Ok(())
1019        })
1020    }
1021
1022    /// Writes a generated value to a remote cache via the `Set` RPC.
1023    fn write_generated_value_remote<V: Serialize>(
1024        &self,
1025        id: u64,
1026        value: &ArcResult<V>,
1027    ) -> Result<()> {
1028        if let Ok(data) = value {
1029            self.set_rpc_remote(id, flexbuffers::to_vec(data).unwrap())?;
1030        }
1031        Ok(())
1032    }
1033
1034    /// Writes data contained in a generated result to a remote cache via the `Set` RPC.
1035    ///
1036    /// Does not write to the cache if the generated result is an [`Err`].
1037    fn write_generated_result_remote<V: Serialize, E>(
1038        &self,
1039        id: u64,
1040        value: &ArcResult<std::result::Result<V, E>>,
1041    ) -> Result<()> {
1042        if let Ok(Ok(data)) = value {
1043            self.set_rpc_remote(id, flexbuffers::to_vec(data).unwrap())?;
1044        }
1045        Ok(())
1046    }
1047
1048    /// Runs the generate loop for the remote cache protocol, checking whether the desired entry is
1049    /// loaded and generating it if needed.
1050    fn generate_loop_remote<K: Send + Sync + Any, V: Send + Sync + Any>(
1051        &self,
1052        state: GenerateState<K, V>,
1053        generate_fn: impl GenerateFn<K, V>,
1054        write_generated_value: impl RemoteWriteValueFn<V>,
1055        deserialize_cache_data: impl DeserializeValueFn<V>,
1056    ) -> Result<()> {
1057        let GenerateState {
1058            handle,
1059            namespace,
1060            hash,
1061            key,
1062        } = state;
1063
1064        let status = self.run_backoff_loop(|| {
1065            let status = self.get_rpc_remote(namespace.clone().into_inner(), hash.clone(), true)?;
1066            let retry = matches!(status, remote::get_reply::EntryStatus::Loading(_));
1067
1068            Ok((status, retry))
1069        })?;
1070
1071        match status {
1072            remote::get_reply::EntryStatus::Unassigned(_) => {
1073                Client::handle_unassigned(handle, key, generate_fn);
1074            }
1075            remote::get_reply::EntryStatus::Assign(remote::AssignReply {
1076                id,
1077                heartbeat_interval_ms,
1078            }) => {
1079                let v = self.handle_assigned(
1080                    key,
1081                    generate_fn,
1082                    heartbeat_interval_ms,
1083                    move |client| -> Result<()> { client.heartbeat_rpc_remote(id) },
1084                );
1085                write_generated_value(self, id, &v)?;
1086                handle.set(v);
1087            }
1088            remote::get_reply::EntryStatus::Loading(_) => unreachable!(),
1089            remote::get_reply::EntryStatus::Ready(data) => {
1090                tracing::debug!("entry is ready");
1091                handle.set(Ok(deserialize_cache_data(&data)?));
1092            }
1093        }
1094        Ok(())
1095    }
1096
1097    fn generate_inner_remote<
1098        K: Serialize + Any + Send + Sync,
1099        V: Serialize + DeserializeOwned + Send + Sync + Any,
1100    >(
1101        self,
1102        state: GenerateState<K, V>,
1103        generate_fn: impl GenerateFn<K, V>,
1104    ) {
1105        tracing::debug!("generating using remote cache API");
1106        self.clone().spawn_handler(state.handle.clone(), move || {
1107            self.generate_loop_remote(
1108                state,
1109                generate_fn,
1110                Client::write_generated_value_remote,
1111                Client::deserialize_cache_value,
1112            )
1113        });
1114    }
1115
1116    fn generate_result_inner_remote<
1117        K: Serialize + Any + Send + Sync,
1118        V: Serialize + DeserializeOwned + Send + Sync + Any,
1119        E: Send + Sync + Any,
1120    >(
1121        self,
1122        state: GenerateState<K, std::result::Result<V, E>>,
1123        generate_fn: impl GenerateResultFn<K, V, E>,
1124    ) {
1125        self.clone().spawn_handler(state.handle.clone(), move || {
1126            self.generate_loop_remote(
1127                state,
1128                generate_fn,
1129                Client::write_generated_result_remote,
1130                Client::deserialize_cache_result,
1131            )
1132        });
1133    }
1134}
1135
/// Base directory for test scratch space: `<CARGO_MANIFEST_DIR>/build`.
///
/// `setup_test` creates one subdirectory per test beneath it.
pub(crate) const BUILD_DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/build");
/// Heartbeat interval configured on servers built by `create_server_and_clients`.
pub(crate) const TEST_SERVER_HEARTBEAT_INTERVAL: Duration = Duration::from_millis(200);
/// Heartbeat timeout configured on servers built by `create_server_and_clients`.
pub(crate) const TEST_SERVER_HEARTBEAT_TIMEOUT: Duration = Duration::from_millis(500);
1139
/// Binds `n` TCP listeners to OS-assigned ports on `127.0.0.1`.
///
/// Each listener is returned together with the port it ended up bound to.
///
/// # Panics
///
/// Panics if a listener cannot be bound or its local address cannot be read.
pub(crate) fn get_listeners(n: usize) -> Vec<(TcpListener, u16)> {
    (0..n)
        .map(|_| {
            let listener = TcpListener::bind("127.0.0.1:0").unwrap();
            let port = listener.local_addr().unwrap().port();
            (listener, port)
        })
        .collect()
}
1151
/// Selects which cache server APIs `create_server_and_clients` enables on its test server.
#[doc(hidden)]
#[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ServerKind {
    /// Serve only the local API.
    Local,
    /// Serve only the remote API.
    Remote,
    /// Serve both the local and the remote API.
    Both,
}
1159
1160impl From<ClientKind> for ServerKind {
1161    fn from(value: ClientKind) -> Self {
1162        match value {
1163            ClientKind::Local => ServerKind::Local,
1164            ClientKind::Remote => ServerKind::Remote,
1165        }
1166    }
1167}
1168
/// Builds the `http://` URL that a client uses to reach a server listening on `port` of the
/// IPv4 loopback interface.
pub(crate) fn client_url(port: u16) -> String {
    let url = format!("http://127.0.0.1:{port}");
    url
}
1172
/// Test helper that starts a cache server on `handle`'s runtime and builds one client per API.
///
/// Two loopback listeners are bound up front, one for the local API and one for the remote API.
/// `kind` decides which of them the server is given. A listener the server does not use is
/// simply dropped when this function returns. The server uses `root` as its root directory,
/// together with `TEST_SERVER_HEARTBEAT_INTERVAL` and `TEST_SERVER_HEARTBEAT_TIMEOUT`.
///
/// Returns, in order:
/// - a `CacheHandle` whose value is obtained by blocking on the spawned server task's join
///   handle. A cancelled task yields `Ok(())` and any other join failure yields `Error::Panic`.
///   Any error is logged with `tracing::error!`; note the message says "failed to start" even
///   for errors raised after startup.
/// - a `ClientKind::Local` client for the local listener's port;
/// - a `ClientKind::Remote` client for the remote listener's port.
///
/// Both clients are built with 3-second connection and request timeouts.
///
/// # Panics
///
/// Panics if the listeners cannot be bound, switched to non-blocking mode, or converted into
/// Tokio listeners.
#[doc(hidden)]
pub fn create_server_and_clients(
    root: PathBuf,
    kind: ServerKind,
    handle: &Handle,
) -> (CacheHandle<Result<()>>, Client, Client) {
    // `tokio::net::TcpListener::from_std` must be called inside a runtime context (hence
    // `block_on`) and expects the std listener to already be in non-blocking mode.
    let mut listeners = handle.block_on(async {
        get_listeners(2)
            .into_iter()
            .map(|(listener, port)| {
                listener.set_nonblocking(true).unwrap();
                (tokio::net::TcpListener::from_std(listener).unwrap(), port)
            })
            .collect::<Vec<_>>()
    });
    // `pop` takes from the back: the second listener serves the local API, the first the remote.
    let (local_listener, local_port) = listeners.pop().unwrap();
    let (remote_listener, remote_port) = listeners.pop().unwrap();

    (
        {
            let mut builder = Server::builder();

            builder = builder
                .heartbeat_interval(TEST_SERVER_HEARTBEAT_INTERVAL)
                .heartbeat_timeout(TEST_SERVER_HEARTBEAT_TIMEOUT)
                .root(root);

            let server = match kind {
                ServerKind::Local => builder.local_with_incoming(local_listener),
                ServerKind::Remote => builder.remote_with_incoming(remote_listener),
                ServerKind::Both => builder
                    .local_with_incoming(local_listener)
                    .remote_with_incoming(remote_listener),
            }
            .build();

            let join_handle = handle.spawn(async move { server.start().await });
            let handle_clone = handle.clone();
            CacheHandle::new(move || {
                // In the fallback closure `res` is the task's `JoinError`: cancellation counts as
                // a clean shutdown, anything else (such as a panic in the task) as `Error::Panic`.
                let res = handle_clone.block_on(join_handle).unwrap_or_else(|res| {
                    if res.is_cancelled() {
                        Ok(())
                    } else {
                        Err(Error::Panic)
                    }
                });
                if let Err(e) = res.as_ref() {
                    tracing::error!("server failed to start: {:?}", e);
                }
                res
            })
        },
        Client::builder()
            .kind(ClientKind::Local)
            .url(client_url(local_port))
            .connection_timeout(Duration::from_secs(3))
            .request_timeout(Duration::from_secs(3))
            .build(),
        Client::builder()
            .kind(ClientKind::Remote)
            .url(client_url(remote_port))
            .connection_timeout(Duration::from_secs(3))
            .request_timeout(Duration::from_secs(3))
            .build(),
    )
}
1239
1240pub(crate) fn reset_directory(path: impl AsRef<Path>) -> Result<()> {
1241    let path = path.as_ref();
1242    if path.exists() {
1243        fs::remove_dir_all(path)?;
1244    }
1245    fs::create_dir_all(path)?;
1246    Ok(())
1247}
1248
1249pub(crate) fn create_runtime() -> Runtime {
1250    tokio::runtime::Builder::new_multi_thread()
1251        .worker_threads(1)
1252        .enable_all()
1253        .build()
1254        .unwrap()
1255}
1256
1257#[doc(hidden)]
1258pub fn setup_test(test_name: &str) -> Result<(PathBuf, Arc<Mutex<u64>>, Runtime)> {
1259    let path = PathBuf::from(BUILD_DIR).join(test_name);
1260    reset_directory(&path)?;
1261    Ok((path, Arc::new(Mutex::new(0)), create_runtime()))
1262}