noosphere_gateway/
try_or_reset.rs1use anyhow::Result;
2use std::future::Future;
3use std::sync::Arc;
4use tokio::sync::OnceCell;
5
6pub struct TryOrReset<I, O, F>
31where
32 F: Future<Output = Result<O, anyhow::Error>>,
33 I: Fn() -> F,
34{
35 init: I,
36 initialized: OnceCell<Arc<O>>,
37}
38
39impl<I, O, F> TryOrReset<I, O, F>
40where
41 F: Future<Output = Result<O, anyhow::Error>>,
42 I: Fn() -> F,
43{
44 pub fn new(init: I) -> Self {
45 TryOrReset {
46 init,
47 initialized: OnceCell::new(),
48 }
49 }
50
51 pub async fn invoke<Ii, Oo, Ff>(&mut self, invoke: Ii) -> Result<Oo>
56 where
57 Ii: FnOnce(Arc<O>) -> Ff,
58 Ff: Future<Output = Result<Oo, anyhow::Error>>,
59 {
60 match self
61 .initialized
62 .get_or_try_init(|| async { Ok(Arc::new((self.init)().await?)) })
63 .await
64 {
65 Ok(initialized) => match invoke(initialized.clone()).await {
66 Ok(output) => Ok(output),
67 Err(error) => {
68 self.initialized.take();
69 Err(error)
70 }
71 },
72 Err(error) => Err(error),
73 }
74 }
75}
76
#[cfg(test)]
mod tests {
    use anyhow::{anyhow, Result};
    use std::sync::Arc;

    use tokio::sync::Mutex;

    use super::TryOrReset;

    #[tokio::test]
    async fn it_initializes_context_before_invocation_and_recovers_from_failure() {
        let count = Arc::new(Mutex::new(0u32));

        let mut again = TryOrReset::new(|| async {
            let mut count = count.lock().await;
            *count += 1;
            Ok(format!("Hello {}", *count))
        });

        again
            .invoke(|context| async move {
                assert_eq!("Hello 1", context.as_str());
                Ok(())
            })
            .await
            .unwrap();

        // A failing invocation must surface its error to the caller...
        let result: Result<()> = again
            .invoke(|_| async move { Err(anyhow!("Arbitrary error")) })
            .await;
        assert_eq!(
            "Arbitrary error",
            result
                .expect_err("invocation error should be propagated")
                .to_string()
        );

        // ...and cause the context to be rebuilt on the next call.
        again
            .invoke(|context| async move {
                assert_eq!("Hello 2", context.as_str());
                Ok(())
            })
            .await
            .unwrap();
    }

    #[tokio::test]
    async fn it_only_initializes_context_once_as_long_as_results_are_ok() {
        let count = Arc::new(Mutex::new(0u32));

        let mut again = TryOrReset::new(|| async {
            let mut count = count.lock().await;
            *count += 1;
            Ok(format!("Hello {}", *count))
        });

        for _ in 0..10 {
            again
                .invoke(|context| async move {
                    assert_eq!("Hello 1", context.as_str());
                    Ok(())
                })
                .await
                .unwrap();
        }

        assert_eq!(1, *count.lock().await);
    }

    #[tokio::test]
    async fn it_will_try_again_next_time_if_initialization_fails() {
        let count = Arc::new(Mutex::new(0u32));
        let mut again = TryOrReset::new(|| async {
            let mut count = count.lock().await;
            *count += 1;
            if *count == 1 {
                Err(anyhow!("Arbitrary failure"))
            } else {
                Ok(format!("Hello {}", *count))
            }
        });

        // The initializer's error must be returned to the caller.
        let result = again
            .invoke::<_, (), _>(|_| async move {
                unreachable!("First initialization should not have succeeded");
            })
            .await;
        assert_eq!(
            "Arbitrary failure",
            result
                .expect_err("initialization error should be propagated")
                .to_string()
        );

        again
            .invoke(|context| async move {
                assert_eq!("Hello 2", context.as_str());
                Ok(())
            })
            .await
            .unwrap()
    }
}