1use std::{
4 any::Any,
5 fs::{self, OpenOptions},
6 io::{Read, Write},
7 net::TcpListener,
8 path::{Path, PathBuf},
9 sync::{
10 mpsc::{channel, Receiver, RecvTimeoutError, Sender},
11 Arc, Mutex,
12 },
13 thread,
14 time::Duration,
15};
16
17use backoff::ExponentialBackoff;
18use serde::{de::DeserializeOwned, Deserialize, Serialize};
19use tokio::runtime::{Handle, Runtime};
20use tonic::transport::{Channel, Endpoint};
21
22use crate::{
23 error::{ArcResult, Error, Result},
24 rpc::{
25 local::{self, local_cache_client},
26 remote::{self, remote_cache_client},
27 },
28 run_generator, CacheHandle, Cacheable, CacheableWithState, GenerateFn, GenerateResultFn,
29 GenerateResultWithStateFn, GenerateWithStateFn, Namespace,
30};
31
32use super::server::Server;
33
/// Default connect timeout, in milliseconds, used by `ClientBuilder::build` when no
/// connection timeout was configured.
pub const CONNECTION_TIMEOUT_MS_DEFAULT: u64 = 1_000;
36
/// Default per-request timeout, in milliseconds, used by `ClientBuilder::build` when no
/// request timeout was configured.
pub const REQUEST_TIMEOUT_MS_DEFAULT: u64 = 1_000;
39
40#[derive(Debug, Clone, Copy, Eq, PartialEq, Serialize, Deserialize)]
44pub enum ClientKind {
45 Local,
49 Remote,
53}
54
55#[derive(Debug)]
56struct ClientInner {
57 kind: ClientKind,
58 url: String,
59 poll_backoff: ExponentialBackoff,
60 connection_timeout: Duration,
61 request_timeout: Duration,
62 handle: Handle,
63 #[allow(dead_code)]
65 runtime: Option<Runtime>,
66}
67
68#[derive(Debug, Clone)]
73pub struct Client {
74 inner: Arc<ClientInner>,
75}
76
77#[derive(Default, Clone, Debug)]
79pub struct ClientBuilder {
80 kind: Option<ClientKind>,
81 url: Option<String>,
82 poll_backoff: Option<ExponentialBackoff>,
83 connection_timeout: Option<Duration>,
84 request_timeout: Option<Duration>,
85 handle: Option<Handle>,
86}
87
88struct GenerateState<K, V> {
89 handle: CacheHandle<V>,
90 namespace: Namespace,
91 hash: Vec<u8>,
92 key: K,
93}
94
95trait HeartbeatFn: Fn(&Client) -> Result<()> + Send + Any {}
97impl<T: Fn(&Client) -> Result<()> + Send + Any> HeartbeatFn for T {}
98
99trait LocalWriteValueFn<V>:
102 FnOnce(&Client, u64, String, &ArcResult<V>) -> Result<()> + Send + Any
103{
104}
105impl<V, T: FnOnce(&Client, u64, String, &ArcResult<V>) -> Result<()> + Send + Any>
106 LocalWriteValueFn<V> for T
107{
108}
109
110trait RemoteWriteValueFn<V>: FnOnce(&Client, u64, &ArcResult<V>) -> Result<()> + Send + Any {}
113impl<V, T: FnOnce(&Client, u64, &ArcResult<V>) -> Result<()> + Send + Any> RemoteWriteValueFn<V>
114 for T
115{
116}
117
118trait DeserializeValueFn<V>: FnOnce(&[u8]) -> Result<V> + Send + Any {}
121impl<V, T: FnOnce(&[u8]) -> Result<V> + Send + Any> DeserializeValueFn<V> for T {}
122
123impl ClientBuilder {
124 pub fn new() -> Self {
126 Self::default()
127 }
128
129 pub fn url(&mut self, url: impl Into<String>) -> &mut Self {
131 self.url = Some(url.into());
132 self
133 }
134
135 pub fn kind(&mut self, kind: ClientKind) -> &mut Self {
137 self.kind = Some(kind);
138 self
139 }
140 pub fn local(url: impl Into<String>) -> Self {
143 let mut builder = Self::new();
144 builder.kind(ClientKind::Local).url(url);
145 builder
146 }
147
148 pub fn remote(url: impl Into<String>) -> Self {
151 let mut builder = Self::new();
152 builder.kind(ClientKind::Remote).url(url);
153 builder
154 }
155
156 pub fn poll_backoff(&mut self, backoff: ExponentialBackoff) -> &mut Self {
161 self.poll_backoff = Some(backoff);
162 self
163 }
164
165 pub fn connection_timeout(&mut self, timeout: Duration) -> &mut Self {
169 self.connection_timeout = Some(timeout);
170 self
171 }
172
173 pub fn request_timeout(&mut self, timeout: Duration) -> &mut Self {
177 self.request_timeout = Some(timeout);
178 self
179 }
180
181 pub fn runtime_handle(&mut self, handle: Handle) -> &mut Self {
186 self.handle = Some(handle);
187 self
188 }
189
190 pub fn build(&mut self) -> Client {
192 let (handle, runtime) = match self.handle.clone() {
193 Some(handle) => (handle, None),
194 None => {
195 let runtime = tokio::runtime::Builder::new_multi_thread()
196 .worker_threads(1)
197 .enable_all()
198 .build()
199 .unwrap();
200 (runtime.handle().clone(), Some(runtime))
201 }
202 };
203 Client {
204 inner: Arc::new(ClientInner {
205 kind: self.kind.expect("must specify client kind"),
206 url: self.url.clone().expect("must specify server URL"),
207 poll_backoff: self.poll_backoff.clone().unwrap_or_default(),
208 connection_timeout: self
209 .connection_timeout
210 .unwrap_or(Duration::from_millis(CONNECTION_TIMEOUT_MS_DEFAULT)),
211 request_timeout: self
212 .request_timeout
213 .unwrap_or(Duration::from_millis(REQUEST_TIMEOUT_MS_DEFAULT)),
214 handle,
215 runtime,
216 }),
217 }
218 }
219}
220
221impl Client {
222 pub fn with_default_config(kind: ClientKind, url: impl Into<String>) -> Self {
224 Self::builder().kind(kind).url(url).build()
225 }
226
227 pub fn builder() -> ClientBuilder {
229 ClientBuilder::new()
230 }
231
232 pub fn local(url: impl Into<String>) -> ClientBuilder {
236 ClientBuilder::local(url)
237 }
238
239 pub fn remote(url: impl Into<String>) -> ClientBuilder {
243 ClientBuilder::remote(url)
244 }
245
246 pub fn generate<
273 K: Serialize + Any + Send + Sync,
274 V: Serialize + DeserializeOwned + Send + Sync + Any,
275 >(
276 &self,
277 namespace: impl Into<Namespace>,
278 key: K,
279 generate_fn: impl GenerateFn<K, V>,
280 ) -> CacheHandle<V> {
281 let namespace = namespace.into();
282 let state = Client::setup_generate(namespace, key);
283 let handle = state.handle.clone();
284
285 match self.inner.kind {
286 ClientKind::Local => self.clone().generate_inner_local(state, generate_fn),
287 ClientKind::Remote => self.clone().generate_inner_remote(state, generate_fn),
288 }
289
290 handle
291 }
292
293 pub fn generate_with_state<
330 K: Serialize + Send + Sync + Any,
331 V: Serialize + DeserializeOwned + Send + Sync + Any,
332 S: Send + Sync + Any,
333 >(
334 &self,
335 namespace: impl Into<Namespace>,
336 key: K,
337 state: S,
338 generate_fn: impl GenerateWithStateFn<K, S, V>,
339 ) -> CacheHandle<V> {
340 let namespace = namespace.into();
341 self.generate(namespace, key, move |k| generate_fn(k, state))
342 }
343
344 pub fn generate_result<
379 K: Serialize + Any + Send + Sync,
380 V: Serialize + DeserializeOwned + Send + Sync + Any,
381 E: Send + Sync + Any,
382 >(
383 &self,
384 namespace: impl Into<Namespace>,
385 key: K,
386 generate_fn: impl GenerateResultFn<K, V, E>,
387 ) -> CacheHandle<std::result::Result<V, E>> {
388 let namespace = namespace.into();
389 let state = Client::setup_generate(namespace, key);
390 let handle = state.handle.clone();
391
392 match self.inner.kind {
393 ClientKind::Local => {
394 self.clone().generate_result_inner_local(state, generate_fn);
395 }
396 ClientKind::Remote => {
397 self.clone()
398 .generate_result_inner_remote(state, generate_fn);
399 }
400 }
401
402 handle
403 }
404
405 pub fn generate_result_with_state<
451 K: Serialize + Send + Sync + Any,
452 V: Serialize + DeserializeOwned + Send + Sync + Any,
453 E: Send + Sync + Any,
454 S: Send + Sync + Any,
455 >(
456 &self,
457 namespace: impl Into<Namespace>,
458 key: K,
459 state: S,
460 generate_fn: impl GenerateResultWithStateFn<K, S, V, E>,
461 ) -> CacheHandle<std::result::Result<V, E>> {
462 let namespace = namespace.into();
463 self.generate_result(namespace, key, move |k| generate_fn(k, state))
464 }
465
466 pub fn get<K: Cacheable>(
508 &self,
509 namespace: impl Into<Namespace>,
510 key: K,
511 ) -> CacheHandle<std::result::Result<K::Output, K::Error>> {
512 let namespace = namespace.into();
513 self.generate_result(namespace, key, |key| key.generate())
514 }
515
516 pub fn get_with_err<
559 E: Send + Sync + Serialize + DeserializeOwned + Any,
560 K: Cacheable<Error = E>,
561 >(
562 &self,
563 namespace: impl Into<Namespace>,
564 key: K,
565 ) -> CacheHandle<std::result::Result<K::Output, K::Error>> {
566 let namespace = namespace.into();
567 self.generate(namespace, key, |key| key.generate())
568 }
569
570 pub fn get_with_state<S: Send + Sync + Any, K: CacheableWithState<S>>(
619 &self,
620 namespace: impl Into<Namespace>,
621 key: K,
622 state: S,
623 ) -> CacheHandle<std::result::Result<K::Output, K::Error>> {
624 let namespace = namespace.into();
625 self.generate_result_with_state(namespace, key, state, |key, state| {
626 key.generate_with_state(state)
627 })
628 }
629
630 pub fn get_with_state_and_err<
636 S: Send + Sync + Any,
637 E: Send + Sync + Serialize + DeserializeOwned + Any,
638 K: CacheableWithState<S, Error = E>,
639 >(
640 &self,
641 namespace: impl Into<Namespace>,
642 key: K,
643 state: S,
644 ) -> CacheHandle<std::result::Result<K::Output, K::Error>> {
645 let namespace = namespace.into();
646 self.generate_with_state(namespace, key, state, |key, state| {
647 key.generate_with_state(state)
648 })
649 }
650
651 fn setup_generate<K: Serialize, V>(namespace: Namespace, key: K) -> GenerateState<K, V> {
653 GenerateState {
654 handle: CacheHandle::empty(),
655 namespace,
656 hash: crate::hash(&flexbuffers::to_vec(&key).unwrap()),
657 key,
658 }
659 }
660
661 fn spawn_handler<V: Send + Sync + Any>(
665 self,
666 handle: CacheHandle<V>,
667 handler: impl FnOnce() -> Result<()> + Send + Any,
668 ) {
669 thread::spawn(move || {
670 if let Err(e) = handler() {
671 tracing::error!("encountered error while executing handler: {}", e,);
672 handle.set(Err(Arc::new(e)));
673 }
674 });
675 }
676
677 fn deserialize_cache_value<V: DeserializeOwned>(data: &[u8]) -> Result<V> {
679 let data = flexbuffers::from_slice(data)?;
680 Ok(data)
681 }
682
683 fn deserialize_cache_result<V: DeserializeOwned, E>(
685 data: &[u8],
686 ) -> Result<std::result::Result<V, E>> {
687 let data = flexbuffers::from_slice(data)?;
688 Ok(Ok(data))
689 }
690
691 fn start_heartbeats(
696 &self,
697 heartbeat_interval: Duration,
698 send_heartbeat: impl HeartbeatFn,
699 ) -> (Sender<()>, Receiver<()>) {
700 tracing::debug!("starting heartbeats");
701 let (s_heartbeat_stop, r_heartbeat_stop) = channel();
702 let (s_heartbeat_stopped, r_heartbeat_stopped) = channel();
703 let self_clone = self.clone();
704 thread::spawn(move || {
705 loop {
706 match r_heartbeat_stop.recv_timeout(heartbeat_interval) {
707 Ok(_) | Err(RecvTimeoutError::Disconnected) => {
708 break;
709 }
710 Err(RecvTimeoutError::Timeout) => {
711 if send_heartbeat(&self_clone).is_err() {
712 break;
713 }
714 }
715 }
716 }
717 let _ = s_heartbeat_stopped.send(());
718 });
719 (s_heartbeat_stop, r_heartbeat_stopped)
720 }
721
722 fn run_backoff_loop<S>(&self, get_status_fn: impl Fn() -> Result<(S, bool)>) -> Result<S> {
728 Ok(backoff::retry(self.inner.poll_backoff.clone(), move || {
729 tracing::debug!("attempting get request to retrieve entry status");
730 get_status_fn()
731 .map_err(backoff::Error::Permanent)
732 .and_then(|(status, retry)| {
733 if retry {
734 tracing::debug!("entry is loading, retrying later");
735 Err(backoff::Error::transient(Error::EntryLoading))
736 } else {
737 tracing::debug!("entry status retrieved");
738 Ok(status)
739 }
740 })
741 })
742 .map_err(Box::new)?)
743 }
744
745 fn handle_unassigned<K: Send + Sync + Any, V: Send + Sync + Any>(
747 handle: CacheHandle<V>,
748 key: K,
749 generate_fn: impl GenerateFn<K, V>,
750 ) {
751 tracing::debug!("entry is unassigned, generating locally");
752 let v = run_generator(move || generate_fn(&key));
753 handle.set(v);
754 }
755
756 fn handle_assigned<K: Send + Sync + Any, V: Send + Sync + Any>(
759 &self,
760 key: K,
761 generate_fn: impl GenerateFn<K, V>,
762 heartbeat_interval_ms: u64,
763 send_heartbeat: impl HeartbeatFn,
764 ) -> ArcResult<V> {
765 tracing::debug!("entry has been assigned to the client, generating locally");
766 let (s_heartbeat_stop, r_heartbeat_stopped) =
767 self.start_heartbeats(Duration::from_millis(heartbeat_interval_ms), send_heartbeat);
768 let v = run_generator(move || generate_fn(&key));
769 let _ = s_heartbeat_stop.send(());
770 let _ = r_heartbeat_stopped.recv();
771 tracing::debug!("finished generating, writing value to cache");
772 v
773 }
774
775 async fn connect_local(&self) -> Result<local_cache_client::LocalCacheClient<Channel>> {
777 let endpoint = Endpoint::from_shared(self.inner.url.clone())?
778 .timeout(self.inner.request_timeout)
779 .connect_timeout(self.inner.connection_timeout);
780 let test = local_cache_client::LocalCacheClient::connect(endpoint).await;
781 Ok(test?)
782 }
783
784 fn get_rpc_local(
786 &self,
787 namespace: String,
788 key: Vec<u8>,
789 assign: bool,
790 ) -> Result<local::get_reply::EntryStatus> {
791 let out: Result<local::GetReply> = self.inner.handle.block_on(async {
792 let mut client = self.connect_local().await?;
793 Ok(client
794 .get(local::GetRequest {
795 namespace,
796 key,
797 assign,
798 })
799 .await?
800 .into_inner())
801 });
802 Ok(out?.entry_status.unwrap())
803 }
804
805 fn heartbeat_rpc_local(&self, id: u64) -> Result<()> {
807 self.inner.handle.block_on(async {
808 let mut client = self.connect_local().await?;
809 client.heartbeat(local::HeartbeatRequest { id }).await?;
810 Ok(())
811 })
812 }
813
814 fn done_rpc_local(&self, id: u64) -> Result<()> {
816 self.inner.handle.block_on(async {
817 let mut client = self.connect_local().await?;
818 client.done(local::DoneRequest { id }).await?;
819 Ok(())
820 })
821 }
822
823 fn drop_rpc_local(&self, id: u64) -> Result<()> {
825 self.inner.handle.block_on(async {
826 let mut client = self.connect_local().await?;
827 client.drop(local::DropRequest { id }).await?;
828 Ok(())
829 })
830 }
831
832 fn write_generated_data_to_disk<V: Serialize>(
833 &self,
834 id: u64,
835 path: String,
836 data: &V,
837 ) -> Result<()> {
838 let path = PathBuf::from(path);
839 if let Some(parent) = path.parent() {
840 fs::create_dir_all(parent)?;
841 }
842
843 let mut f = OpenOptions::new()
844 .read(true)
845 .write(true)
846 .create(true)
847 .open(&path)?;
848 f.write_all(&flexbuffers::to_vec(data).unwrap())?;
849 self.done_rpc_local(id)?;
850
851 Ok(())
852 }
853
854 fn write_generated_value_local<V: Serialize>(
856 &self,
857 id: u64,
858 path: String,
859 value: &ArcResult<V>,
860 ) -> Result<()> {
861 if let Ok(data) = value {
862 self.write_generated_data_to_disk(id, path, data)?;
863 }
864 Ok(())
865 }
866
867 fn write_generated_result_local<V: Serialize, E>(
871 &self,
872 id: u64,
873 path: String,
874 value: &ArcResult<std::result::Result<V, E>>,
875 ) -> Result<()> {
876 if let Ok(Ok(data)) = value {
877 self.write_generated_data_to_disk(id, path, data)?;
878 }
879 Ok(())
880 }
881
882 fn generate_loop_local<K: Send + Sync + Any, V: Send + Sync + Any>(
885 &self,
886 state: GenerateState<K, V>,
887 generate_fn: impl GenerateFn<K, V>,
888 write_generated_value: impl LocalWriteValueFn<V>,
889 deserialize_cache_data: impl DeserializeValueFn<V>,
890 ) -> Result<()> {
891 let GenerateState {
892 handle,
893 namespace,
894 hash,
895 key,
896 } = state;
897
898 let status = self.run_backoff_loop(|| {
899 let status = self.get_rpc_local(namespace.clone().into_inner(), hash.clone(), true)?;
900 let retry = matches!(status, local::get_reply::EntryStatus::Loading(_));
901
902 Ok((status, retry))
903 })?;
904
905 match status {
906 local::get_reply::EntryStatus::Unassigned(_) => {
907 Client::handle_unassigned(handle, key, generate_fn);
908 }
909 local::get_reply::EntryStatus::Assign(local::AssignReply {
910 id,
911 path,
912 heartbeat_interval_ms,
913 }) => {
914 let v = self.handle_assigned(
915 key,
916 generate_fn,
917 heartbeat_interval_ms,
918 move |client| -> Result<()> { client.heartbeat_rpc_local(id) },
919 );
920 write_generated_value(self, id, path, &v)?;
921 handle.set(v);
922 }
923 local::get_reply::EntryStatus::Loading(_) => unreachable!(),
924 local::get_reply::EntryStatus::Ready(local::ReadyReply { id, path }) => {
925 tracing::debug!("entry is ready, reading from cache");
926 let mut file = std::fs::File::open(path)?;
927 let mut buf = Vec::new();
928 file.read_to_end(&mut buf)?;
929 self.drop_rpc_local(id)?;
930 tracing::debug!("finished reading entry from disk");
931 handle.set(Ok(deserialize_cache_data(&buf)?));
932 }
933 }
934 Ok(())
935 }
936
937 fn generate_inner_local<
938 K: Serialize + Any + Send + Sync,
939 V: Serialize + DeserializeOwned + Send + Sync + Any,
940 >(
941 self,
942 state: GenerateState<K, V>,
943 generate_fn: impl GenerateFn<K, V>,
944 ) {
945 tracing::debug!("generating using local cache API");
946 self.clone().spawn_handler(state.handle.clone(), move || {
947 self.generate_loop_local(
948 state,
949 generate_fn,
950 Client::write_generated_value_local,
951 Client::deserialize_cache_value,
952 )
953 });
954 }
955
956 fn generate_result_inner_local<
957 K: Serialize + Any + Send + Sync,
958 V: Serialize + DeserializeOwned + Send + Sync + Any,
959 E: Send + Sync + Any,
960 >(
961 self,
962 state: GenerateState<K, std::result::Result<V, E>>,
963 generate_fn: impl GenerateResultFn<K, V, E>,
964 ) {
965 self.clone().spawn_handler(state.handle.clone(), move || {
966 self.generate_loop_local(
967 state,
968 generate_fn,
969 Client::write_generated_result_local,
970 Client::deserialize_cache_result,
971 )
972 });
973 }
974
975 async fn connect_remote(&self) -> Result<remote_cache_client::RemoteCacheClient<Channel>> {
977 let endpoint = Endpoint::from_shared(self.inner.url.clone())?
978 .timeout(self.inner.request_timeout)
979 .connect_timeout(self.inner.connection_timeout);
980 Ok(remote_cache_client::RemoteCacheClient::connect(endpoint).await?)
981 }
982
983 fn get_rpc_remote(
985 &self,
986 namespace: String,
987 key: Vec<u8>,
988 assign: bool,
989 ) -> Result<remote::get_reply::EntryStatus> {
990 let out: Result<remote::GetReply> = self.inner.handle.block_on(async {
991 let mut client = self.connect_remote().await?;
992 Ok(client
993 .get(remote::GetRequest {
994 namespace,
995 key,
996 assign,
997 })
998 .await?
999 .into_inner())
1000 });
1001 Ok(out?.entry_status.unwrap())
1002 }
1003
1004 fn heartbeat_rpc_remote(&self, id: u64) -> Result<()> {
1006 self.inner.handle.block_on(async {
1007 let mut client = self.connect_remote().await?;
1008 client.heartbeat(remote::HeartbeatRequest { id }).await?;
1009 Ok(())
1010 })
1011 }
1012
1013 fn set_rpc_remote(&self, id: u64, value: Vec<u8>) -> Result<()> {
1015 self.inner.handle.block_on(async {
1016 let mut client = self.connect_remote().await?;
1017 client.set(remote::SetRequest { id, value }).await?;
1018 Ok(())
1019 })
1020 }
1021
1022 fn write_generated_value_remote<V: Serialize>(
1024 &self,
1025 id: u64,
1026 value: &ArcResult<V>,
1027 ) -> Result<()> {
1028 if let Ok(data) = value {
1029 self.set_rpc_remote(id, flexbuffers::to_vec(data).unwrap())?;
1030 }
1031 Ok(())
1032 }
1033
1034 fn write_generated_result_remote<V: Serialize, E>(
1038 &self,
1039 id: u64,
1040 value: &ArcResult<std::result::Result<V, E>>,
1041 ) -> Result<()> {
1042 if let Ok(Ok(data)) = value {
1043 self.set_rpc_remote(id, flexbuffers::to_vec(data).unwrap())?;
1044 }
1045 Ok(())
1046 }
1047
1048 fn generate_loop_remote<K: Send + Sync + Any, V: Send + Sync + Any>(
1051 &self,
1052 state: GenerateState<K, V>,
1053 generate_fn: impl GenerateFn<K, V>,
1054 write_generated_value: impl RemoteWriteValueFn<V>,
1055 deserialize_cache_data: impl DeserializeValueFn<V>,
1056 ) -> Result<()> {
1057 let GenerateState {
1058 handle,
1059 namespace,
1060 hash,
1061 key,
1062 } = state;
1063
1064 let status = self.run_backoff_loop(|| {
1065 let status = self.get_rpc_remote(namespace.clone().into_inner(), hash.clone(), true)?;
1066 let retry = matches!(status, remote::get_reply::EntryStatus::Loading(_));
1067
1068 Ok((status, retry))
1069 })?;
1070
1071 match status {
1072 remote::get_reply::EntryStatus::Unassigned(_) => {
1073 Client::handle_unassigned(handle, key, generate_fn);
1074 }
1075 remote::get_reply::EntryStatus::Assign(remote::AssignReply {
1076 id,
1077 heartbeat_interval_ms,
1078 }) => {
1079 let v = self.handle_assigned(
1080 key,
1081 generate_fn,
1082 heartbeat_interval_ms,
1083 move |client| -> Result<()> { client.heartbeat_rpc_remote(id) },
1084 );
1085 write_generated_value(self, id, &v)?;
1086 handle.set(v);
1087 }
1088 remote::get_reply::EntryStatus::Loading(_) => unreachable!(),
1089 remote::get_reply::EntryStatus::Ready(data) => {
1090 tracing::debug!("entry is ready");
1091 handle.set(Ok(deserialize_cache_data(&data)?));
1092 }
1093 }
1094 Ok(())
1095 }
1096
1097 fn generate_inner_remote<
1098 K: Serialize + Any + Send + Sync,
1099 V: Serialize + DeserializeOwned + Send + Sync + Any,
1100 >(
1101 self,
1102 state: GenerateState<K, V>,
1103 generate_fn: impl GenerateFn<K, V>,
1104 ) {
1105 tracing::debug!("generating using remote cache API");
1106 self.clone().spawn_handler(state.handle.clone(), move || {
1107 self.generate_loop_remote(
1108 state,
1109 generate_fn,
1110 Client::write_generated_value_remote,
1111 Client::deserialize_cache_value,
1112 )
1113 });
1114 }
1115
1116 fn generate_result_inner_remote<
1117 K: Serialize + Any + Send + Sync,
1118 V: Serialize + DeserializeOwned + Send + Sync + Any,
1119 E: Send + Sync + Any,
1120 >(
1121 self,
1122 state: GenerateState<K, std::result::Result<V, E>>,
1123 generate_fn: impl GenerateResultFn<K, V, E>,
1124 ) {
1125 self.clone().spawn_handler(state.handle.clone(), move || {
1126 self.generate_loop_remote(
1127 state,
1128 generate_fn,
1129 Client::write_generated_result_remote,
1130 Client::deserialize_cache_result,
1131 )
1132 });
1133 }
1134}
1135
1136pub(crate) const BUILD_DIR: &str = concat!(env!("CARGO_MANIFEST_DIR"), "/build");
/// Heartbeat interval configured on servers started by `create_server_and_clients`.
pub(crate) const TEST_SERVER_HEARTBEAT_INTERVAL: Duration = Duration::from_millis(200);

/// Heartbeat timeout configured on servers started by `create_server_and_clients`.
pub(crate) const TEST_SERVER_HEARTBEAT_TIMEOUT: Duration = Duration::from_millis(500);
1139
/// Binds `n` TCP listeners to OS-assigned ports on `127.0.0.1`.
///
/// Returns each listener together with its port.
///
/// # Panics
///
/// Panics if binding or querying a local address fails.
pub(crate) fn get_listeners(n: usize) -> Vec<(TcpListener, u16)> {
    (0..n)
        .map(|_| {
            let listener = TcpListener::bind("127.0.0.1:0").unwrap();
            let port = listener.local_addr().unwrap().port();
            (listener, port)
        })
        .collect()
}
1151
/// Which cache APIs a server started by `create_server_and_clients` exposes.
#[doc(hidden)]
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum ServerKind {
    /// Only the local cache API.
    Local,
    /// Only the remote cache API.
    Remote,
    /// Both the local and the remote cache APIs.
    Both,
}
1159
1160impl From<ClientKind> for ServerKind {
1161 fn from(value: ClientKind) -> Self {
1162 match value {
1163 ClientKind::Local => ServerKind::Local,
1164 ClientKind::Remote => ServerKind::Remote,
1165 }
1166 }
1167}
1168
/// Returns the plain-HTTP URL of a server listening on `port` at `127.0.0.1`.
pub(crate) fn client_url(port: u16) -> String {
    format!("http://127.0.0.1:{}", port)
}
1172
1173#[doc(hidden)]
1174pub fn create_server_and_clients(
1175 root: PathBuf,
1176 kind: ServerKind,
1177 handle: &Handle,
1178) -> (CacheHandle<Result<()>>, Client, Client) {
1179 let mut listeners = handle.block_on(async {
1180 get_listeners(2)
1181 .into_iter()
1182 .map(|(listener, port)| {
1183 listener.set_nonblocking(true).unwrap();
1184 (tokio::net::TcpListener::from_std(listener).unwrap(), port)
1185 })
1186 .collect::<Vec<_>>()
1187 });
1188 let (local_listener, local_port) = listeners.pop().unwrap();
1189 let (remote_listener, remote_port) = listeners.pop().unwrap();
1190
1191 (
1192 {
1193 let mut builder = Server::builder();
1194
1195 builder = builder
1196 .heartbeat_interval(TEST_SERVER_HEARTBEAT_INTERVAL)
1197 .heartbeat_timeout(TEST_SERVER_HEARTBEAT_TIMEOUT)
1198 .root(root);
1199
1200 let server = match kind {
1201 ServerKind::Local => builder.local_with_incoming(local_listener),
1202 ServerKind::Remote => builder.remote_with_incoming(remote_listener),
1203 ServerKind::Both => builder
1204 .local_with_incoming(local_listener)
1205 .remote_with_incoming(remote_listener),
1206 }
1207 .build();
1208
1209 let join_handle = handle.spawn(async move { server.start().await });
1210 let handle_clone = handle.clone();
1211 CacheHandle::new(move || {
1212 let res = handle_clone.block_on(join_handle).unwrap_or_else(|res| {
1213 if res.is_cancelled() {
1214 Ok(())
1215 } else {
1216 Err(Error::Panic)
1217 }
1218 });
1219 if let Err(e) = res.as_ref() {
1220 tracing::error!("server failed to start: {:?}", e);
1221 }
1222 res
1223 })
1224 },
1225 Client::builder()
1226 .kind(ClientKind::Local)
1227 .url(client_url(local_port))
1228 .connection_timeout(Duration::from_secs(3))
1229 .request_timeout(Duration::from_secs(3))
1230 .build(),
1231 Client::builder()
1232 .kind(ClientKind::Remote)
1233 .url(client_url(remote_port))
1234 .connection_timeout(Duration::from_secs(3))
1235 .request_timeout(Duration::from_secs(3))
1236 .build(),
1237 )
1238}
1239
1240pub(crate) fn reset_directory(path: impl AsRef<Path>) -> Result<()> {
1241 let path = path.as_ref();
1242 if path.exists() {
1243 fs::remove_dir_all(path)?;
1244 }
1245 fs::create_dir_all(path)?;
1246 Ok(())
1247}
1248
1249pub(crate) fn create_runtime() -> Runtime {
1250 tokio::runtime::Builder::new_multi_thread()
1251 .worker_threads(1)
1252 .enable_all()
1253 .build()
1254 .unwrap()
1255}
1256
1257#[doc(hidden)]
1258pub fn setup_test(test_name: &str) -> Result<(PathBuf, Arc<Mutex<u64>>, Runtime)> {
1259 let path = PathBuf::from(BUILD_DIR).join(test_name);
1260 reset_directory(&path)?;
1261 Ok((path, Arc::new(Mutex::new(0)), create_runtime()))
1262}