kaccy_db/
generic_repository.rs1use async_trait::async_trait;
9use serde::{de::DeserializeOwned, Serialize};
10use sqlx::PgPool;
11use std::sync::Arc;
12use uuid::Uuid;
13
14use crate::cache::RedisCache;
15use crate::error::Result;
16
/// A request for one zero-based page of results.
///
/// [`Pagination::limit`] and [`Pagination::offset`] expose the request as `i64`
/// values (presumably for SQL `LIMIT`/`OFFSET` binds — confirm with callers).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Pagination {
    /// Zero-based page index (`0` is the first page).
    pub page: u64,
    /// Maximum number of items per page.
    pub page_size: u64,
}

impl Default for Pagination {
    /// The first page (`page = 0`) with [`Pagination::DEFAULT_PAGE_SIZE`] items.
    fn default() -> Self {
        Self::new(0, Self::DEFAULT_PAGE_SIZE)
    }
}

impl Pagination {
    /// Page size used by [`Pagination::default`].
    pub const DEFAULT_PAGE_SIZE: u64 = 20;

    /// Creates a request for page `page` (zero-based) holding up to `page_size` items.
    pub fn new(page: u64, page_size: u64) -> Self {
        Self { page, page_size }
    }

    /// Maximum number of rows to return: `page_size` as an `i64`.
    ///
    /// Values above `i64::MAX` clamp to `i64::MAX` instead of wrapping to a
    /// negative number.
    pub fn limit(&self) -> i64 {
        Self::saturating_i64(self.page_size)
    }

    /// Number of rows to skip: `page * page_size` as an `i64`.
    ///
    /// The multiplication saturates instead of overflowing, and results above
    /// `i64::MAX` clamp to `i64::MAX`, so the offset is never negative.
    pub fn offset(&self) -> i64 {
        // `page`/`page_size` may come from client input. An unchecked `u64`
        // multiply panics in debug builds and wraps in release, and a plain
        // `as i64` can turn a large value into a negative offset.
        Self::saturating_i64(self.page.saturating_mul(self.page_size))
    }

    /// Converts `value` to `i64`, clamping to `i64::MAX` when it does not fit.
    fn saturating_i64(value: u64) -> i64 {
        // After the clamp the value is <= i64::MAX, so the cast is lossless.
        value.min(i64::MAX as u64) as i64
    }
}
51
52#[derive(Debug, Clone, Serialize)]
54pub struct Page<T> {
55 pub items: Vec<T>,
57 pub total: u64,
59 pub page: u64,
61 pub page_size: u64,
63 pub total_pages: u64,
65}
66
67impl<T> Page<T> {
68 pub fn new(items: Vec<T>, total: u64, pagination: &Pagination) -> Self {
70 let total_pages = if pagination.page_size > 0 {
71 total.div_ceil(pagination.page_size)
72 } else {
73 0
74 };
75
76 Self {
77 items,
78 total,
79 page: pagination.page,
80 page_size: pagination.page_size,
81 total_pages,
82 }
83 }
84
85 pub fn has_next(&self) -> bool {
87 self.page + 1 < self.total_pages
88 }
89
90 pub fn has_prev(&self) -> bool {
92 self.page > 0
93 }
94}
95
96#[async_trait]
98pub trait Repository<T>: Send + Sync
99where
100 T: Send + Sync + Clone + Serialize + DeserializeOwned,
101{
102 fn entity_name(&self) -> &str;
104
105 async fn find_by_id(&self, id: Uuid) -> Result<Option<T>>;
107
108 async fn find_all(&self, pagination: Pagination) -> Result<Page<T>>;
110
111 async fn create(&self, entity: &T) -> Result<T>;
113
114 async fn update(&self, id: Uuid, entity: &T) -> Result<T>;
116
117 async fn delete(&self, id: Uuid) -> Result<bool>;
119
120 async fn count(&self) -> Result<u64>;
122
123 async fn exists(&self, id: Uuid) -> Result<bool> {
125 Ok(self.find_by_id(id).await?.is_some())
126 }
127}
128
129#[allow(dead_code)]
131pub struct CachedGenericRepository<T>
132where
133 T: Send + Sync + Clone + Serialize + DeserializeOwned,
134{
135 pool: Arc<PgPool>,
136 cache: Option<Arc<RedisCache>>,
137 entity_name: String,
138 _phantom: std::marker::PhantomData<T>,
139}
140
141impl<T> CachedGenericRepository<T>
142where
143 T: Send + Sync + Clone + Serialize + DeserializeOwned,
144{
145 pub fn new(pool: Arc<PgPool>, cache: Option<Arc<RedisCache>>, entity_name: String) -> Self {
147 Self {
148 pool,
149 cache,
150 entity_name,
151 _phantom: std::marker::PhantomData,
152 }
153 }
154
155 #[allow(dead_code)]
157 fn cache_key(&self, id: Uuid) -> String {
158 format!("{}:{}", self.entity_name, id)
159 }
160
161 #[allow(dead_code)]
163 async fn get_from_cache(&self, id: Uuid) -> Result<Option<T>> {
164 if let Some(cache) = &self.cache {
165 cache.get(&self.cache_key(id)).await
166 } else {
167 Ok(None)
168 }
169 }
170
171 #[allow(dead_code)]
173 async fn set_in_cache(&self, id: Uuid, entity: &T, ttl_secs: u64) -> Result<()> {
174 if let Some(cache) = &self.cache {
175 cache.set(&self.cache_key(id), entity, ttl_secs).await?;
176 }
177 Ok(())
178 }
179
180 #[allow(dead_code)]
182 async fn invalidate_cache(&self, id: Uuid) -> Result<()> {
183 if let Some(cache) = &self.cache {
184 cache.delete(&self.cache_key(id)).await?;
185 }
186 Ok(())
187 }
188
189 pub fn pool(&self) -> &PgPool {
191 &self.pool
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pagination_default() {
        let pagination = Pagination::default();
        assert_eq!(pagination.page, 0);
        assert_eq!(pagination.page_size, 20);
        assert_eq!(pagination.limit(), 20);
        assert_eq!(pagination.offset(), 0);
    }

    #[test]
    fn test_pagination_offset() {
        let pagination = Pagination::new(2, 10);
        assert_eq!(pagination.limit(), 10);
        assert_eq!(pagination.offset(), 20);
    }

    #[test]
    fn test_page_creation() {
        let items = vec![1, 2, 3];
        let pagination = Pagination::new(0, 10);
        let page = Page::new(items, 25, &pagination);

        assert_eq!(page.items.len(), 3);
        assert_eq!(page.total, 25);
        assert_eq!(page.page, 0);
        assert_eq!(page.page_size, 10);
        assert_eq!(page.total_pages, 3);
    }

    #[test]
    fn test_page_has_next() {
        let items = vec![1, 2, 3];
        let pagination = Pagination::new(0, 10);
        let page = Page::new(items, 25, &pagination);

        assert!(page.has_next());
        assert!(!page.has_prev());
    }

    #[test]
    fn test_page_has_prev() {
        let items = vec![1, 2, 3];
        let pagination = Pagination::new(1, 10);
        let page = Page::new(items, 25, &pagination);

        assert!(page.has_next());
        assert!(page.has_prev());
    }

    #[test]
    fn test_page_last_page() {
        let items = vec![1, 2, 3];
        let pagination = Pagination::new(2, 10);
        let page = Page::new(items, 25, &pagination);

        assert!(!page.has_next());
        assert!(page.has_prev());
    }

    /// An empty result set has no pages and nothing to navigate to.
    #[test]
    fn test_page_empty_result() {
        let page: Page<i32> = Page::new(Vec::new(), 0, &Pagination::new(0, 10));

        assert_eq!(page.total_pages, 0);
        assert!(!page.has_next());
        assert!(!page.has_prev());
    }

    /// A zero page size must not divide by zero; it reports zero pages.
    #[test]
    fn test_page_zero_page_size() {
        let page: Page<i32> = Page::new(Vec::new(), 5, &Pagination::new(0, 0));

        assert_eq!(page.total_pages, 0);
        assert!(!page.has_next());
    }

    /// When `total` is an exact multiple of the page size there is no partial last page.
    #[test]
    fn test_page_exact_multiple() {
        let page = Page::new(vec![1; 10], 20, &Pagination::new(1, 10));

        assert_eq!(page.total_pages, 2);
        assert!(!page.has_next());
        assert!(page.has_prev());
    }
}