1use std::{marker::PhantomData, time::Duration};
2
3use sqlx::{postgres::PgRow, PgPool};
4use tokio::{task::JoinHandle, time};
5
6pub struct PgDbIdleAgent<T, F, E>
13where
14 T: for<'r> sqlx::FromRow<'r, PgRow> + Send + Sync + Unpin + 'static,
15 F: Fn(&T) + Send + Sync + 'static,
16 E: Fn(sqlx::Error) + Send + Sync + 'static, {
18 interval_secs: Duration,
19 pool: PgPool,
20 query: String,
21 action: F,
22 error_handler: E,
23 _marker: PhantomData<T>, }
25
26impl<T, F, E> PgDbIdleAgent<T, F, E>
27where
28 T: for<'r> sqlx::FromRow<'r, PgRow> + Send + Sync + Unpin + 'static,
29
30 F: Fn(&T) + Send + Sync + 'static,
31 E: Fn(sqlx::Error) + Send + Sync + 'static, {
33 pub fn new(
34 interval_secs: Duration,
35 pool: PgPool,
36 query: String,
37 action: F,
38 error_handler: E,
39 ) -> Self {
40 Self {
41 interval_secs,
42 pool,
43 query,
44 action,
45 error_handler,
46 _marker: PhantomData, }
48 }
49
50 pub async fn start(self) -> JoinHandle<()> {
51 let mut ticker = time::interval(self.interval_secs);
52 tokio::task::spawn(async move {
53 loop {
54 ticker.tick().await;
55 if let Err(e) = self.check_data().await {
56 (self.error_handler)(e);
57 }
58 }
59 })
60 }
61
62 async fn check_data(&self) -> Result<(), sqlx::Error>
63 where
64 T: for<'r> sqlx::FromRow<'r, PgRow> + Send + Sync + Unpin,
65 {
66 let rows: Vec<T> = sqlx::query_as::<_, T>(self.query.as_str())
67 .fetch_all(&self.pool)
68 .await?;
69 rows.into_iter().for_each(|element| {
70 (self.action)(&element); });
72 Ok(())
73 }
74}
75
#[cfg(test)]
mod tests {
    use super::*;
    use serial_test::serial;
    use sqlx::{postgres::PgPoolOptions, FromRow, Pool, Postgres};
    use std::sync::{
        atomic::{AtomicUsize, Ordering},
        Arc,
    };

    /// Connection string used when the `TEST_DATABASE_URL` env var is not set.
    const DEFAULT_TEST_DATABASE_URL: &str = "postgres://test:test@localhost:5439/test";

    /// Poll period used by the agent tests.
    const POLL_INTERVAL: Duration = Duration::from_millis(250);

    /// How long the agent tests let the agent run before aborting it. The first
    /// tick of a Tokio interval fires immediately, so at least one poll happens.
    const POLL_WINDOW: Duration = Duration::from_secs(1);

    /// Query decoding into [`Example`], ordered so results are deterministic.
    const SELECT_EXAMPLES: &str = "SELECT id, data, is_sent, version FROM example ORDER BY id";

    #[derive(FromRow, Debug, PartialEq)]
    pub struct Example {
        pub id: i32,
        pub data: String,
        pub is_sent: bool,
        pub version: i32,
    }

    /// Drops the `example` table and its id sequence so each test starts clean.
    async fn drop_examples(pool: &Pool<Postgres>) {
        sqlx::query("DROP TABLE IF EXISTS example CASCADE")
            .execute(pool)
            .await
            .unwrap();

        // A SERIAL column's sequence is owned by the column and normally goes away
        // with the table; the explicit drop also clears a stray, unowned sequence.
        sqlx::query("DROP SEQUENCE IF EXISTS example_id_seq CASCADE")
            .execute(pool)
            .await
            .unwrap();
    }

    /// Creates the `example` table used by every test.
    async fn create_example_table(pool: &Pool<Postgres>) {
        sqlx::query(
            "CREATE TABLE IF NOT EXISTS example (
                id SERIAL PRIMARY KEY,
                data TEXT NOT NULL,
                is_sent BOOLEAN NOT NULL,
                version INT NOT NULL
            )",
        )
        .execute(pool)
        .await
        .unwrap();
    }

    /// Inserts three fixture rows; on a fresh table they get ids 1, 2 and 3.
    async fn insert_example_data(pool: &PgPool) {
        let fixtures = [
            ("Some random text", false, 0_i32),
            ("Another text", true, 1),
            ("third text", true, 0),
        ];

        for (data, is_sent, version) in fixtures {
            sqlx::query("INSERT INTO example (data, is_sent, version) VALUES ($1, $2, $3)")
                .bind(data)
                .bind(is_sent)
                .bind(version)
                .execute(pool)
                .await
                .unwrap();
        }
    }

    /// Connects to the test database and (re)creates the fixture table.
    ///
    /// The URL comes from `TEST_DATABASE_URL` when set, otherwise
    /// [`DEFAULT_TEST_DATABASE_URL`]. Each statement runs directly on the pool
    /// and commits on its own.
    async fn setup_db() -> Pool<Postgres> {
        let url = std::env::var("TEST_DATABASE_URL")
            .unwrap_or_else(|_| DEFAULT_TEST_DATABASE_URL.to_string());
        let pool = PgPoolOptions::new().connect(&url).await.unwrap();

        drop_examples(&pool).await;
        create_example_table(&pool).await;
        insert_example_data(&pool).await;

        pool
    }

    /// Reads every row of `example`, ordered by id.
    async fn get_all_examples(pool: &sqlx::PgPool) -> Vec<Example> {
        sqlx::query_as::<_, Example>(SELECT_EXAMPLES)
            .fetch_all(pool)
            .await
            .unwrap()
    }

    /// Runs an agent with `query` for [`POLL_WINDOW`], aborts it, and returns
    /// `(rows passed to the action, errors passed to the error handler)`.
    async fn run_agent_briefly(pool: PgPool, query: &str) -> (usize, usize) {
        let rows_seen = Arc::new(AtomicUsize::new(0));
        let errors_seen = Arc::new(AtomicUsize::new(0));

        let action = {
            let rows_seen = Arc::clone(&rows_seen);
            move |example: &Example| {
                println!("Processing example {:?}", example);
                rows_seen.fetch_add(1, Ordering::SeqCst);
            }
        };

        let error_handler = {
            let errors_seen = Arc::clone(&errors_seen);
            move |err: sqlx::Error| {
                eprintln!("Error while processing examples: {:?}", err);
                errors_seen.fetch_add(1, Ordering::SeqCst);
            }
        };

        let agent = PgDbIdleAgent::new(
            POLL_INTERVAL,
            pool,
            query.to_string(),
            action,
            error_handler,
        );
        let handle = agent.start().await;

        tokio::time::sleep(POLL_WINDOW).await;

        // Wait for the abort to take effect so no poll is still in flight when
        // the counters are read or the next serial test drops the table.
        handle.abort();
        let join_error = handle
            .await
            .expect_err("the polling loop never finishes on its own");
        assert!(join_error.is_cancelled(), "agent task ended unexpectedly: {join_error}");

        (
            rows_seen.load(Ordering::SeqCst),
            errors_seen.load(Ordering::SeqCst),
        )
    }

    #[tokio::test]
    #[serial]
    async fn test_db_setup() {
        let expected_data = [
            Example {
                id: 1,
                data: "Some random text".to_string(),
                is_sent: false,
                version: 0,
            },
            Example {
                id: 2,
                data: "Another text".to_string(),
                is_sent: true,
                version: 1,
            },
            Example {
                id: 3,
                data: "third text".to_string(),
                is_sent: true,
                version: 0,
            },
        ];
        let pool = setup_db().await;
        let examples = get_all_examples(&pool).await;

        // Whole-collection comparison also catches missing or extra rows.
        assert_eq!(
            examples, expected_data,
            "The fetched data does not match the expected data."
        );
    }

    #[tokio::test]
    #[serial]
    async fn test_pg_db_idle_agent() {
        let pool = setup_db().await;

        let (rows, errors) = run_agent_briefly(pool, SELECT_EXAMPLES).await;

        assert_eq!(errors, 0, "a valid query must not report errors");
        assert!(rows >= 3, "expected at least one full poll of 3 rows, got {rows} rows");
    }

    #[tokio::test]
    #[serial]
    async fn test_pg_db_idle_agent_error() {
        let pool = setup_db().await;

        let (rows, errors) = run_agent_briefly(pool, "INVALID SQL").await;

        assert_eq!(rows, 0, "a failing query must never reach the action");
        assert!(errors >= 1, "the error handler was never invoked");
    }
}