1use std::rc::Rc;
2
3use itertools::Itertools;
4use rusqlite::{Connection, Params, Result, Row, limits::Limit, params, types::Value};
5
6pub fn sqlite_parameter_limit(conn: &Connection) -> usize {
8 let limit = conn.limit(Limit::SQLITE_LIMIT_VARIABLE_NUMBER);
9 usize::try_from(limit).expect("SQLite parameter limit should be positive")
10}
11
12pub fn max_rows_per_batch(conn: &Connection, params_per_row: usize) -> usize {
14 let params_per_row = params_per_row.max(1);
15 let max_params = sqlite_parameter_limit(conn);
16 (max_params / params_per_row).max(1)
17}
18
19pub trait Query {
20 type Model;
21 const PRIMARY_KEY: &'static str = "id";
22 const TABLE_NAME: &'static str;
23
24 fn query(conn: &Connection, query: &str, params: impl Params) -> Vec<Self::Model> {
25 let mut stmt = conn.prepare(query).unwrap();
26 let rows = stmt
27 .query_map(params, |row| Ok(Self::process_row(row)))
28 .unwrap();
29 let mut objs = vec![];
30 for row in rows {
31 objs.push(row.unwrap());
32 }
33 objs
34 }
35
36 fn get(conn: &Connection, query: &str, params: impl Params) -> Result<Self::Model> {
37 let mut stmt = conn.prepare(query).unwrap();
38 stmt.query_row(params, |row| Ok(Self::process_row(row)))
39 }
40
41 fn get_by_id<'a, T>(conn: &Connection, id: &'a T) -> Option<Self::Model>
42 where
43 T: Clone + 'a,
44 Value: From<T>,
45 {
46 Self::query(
47 conn,
48 &format!(
49 "select * from {} where {} = ?1",
50 Self::TABLE_NAME,
51 Self::PRIMARY_KEY
52 ),
53 params![Value::from(id.clone())],
54 )
55 .pop()
56 }
57
58 fn query_by_ids<'a, I: ?Sized, T>(conn: &Connection, ids: &'a I) -> Vec<Self::Model>
59 where
60 &'a I: IntoIterator<Item = &'a T>,
61 T: Clone + 'a,
62 Value: From<T>,
63 {
64 let mut results = vec![];
65 let batch_size = max_rows_per_batch(conn, 1);
66 for chunk in &ids.into_iter().chunks(batch_size) {
67 let values: Vec<Value> = chunk
68 .map(|value: &'a T| Value::from(value.clone()))
69 .collect();
70 results.append(&mut Self::query(
71 conn,
72 &format!(
75 "
76 WITH arr AS (
77 SELECT value, rowid AS pos
78 FROM rarray(?1)
79 )
80 SELECT {}.*
81 FROM {}
82 JOIN arr ON {}.{} = arr.value
83 ORDER BY arr.pos;
84 ",
85 Self::TABLE_NAME,
86 Self::TABLE_NAME,
87 Self::TABLE_NAME,
88 Self::PRIMARY_KEY
89 ),
90 params!(Rc::new(values)),
91 ))
92 }
93 results
94 }
95
96 fn delete_by_ids<'a, I: ?Sized, T>(conn: &Connection, ids: &'a I) -> Vec<Self::Model>
97 where
98 &'a I: IntoIterator<Item = &'a T>,
99 T: Clone + 'a,
100 Value: From<T>,
101 {
102 let mut results = vec![];
103 let batch_size = max_rows_per_batch(conn, 1);
104 for chunk in &ids.into_iter().chunks(batch_size) {
105 let values: Vec<Value> = chunk
106 .map(|value: &'a T| Value::from(value.clone()))
107 .collect();
108 results.append(&mut Self::query(
109 conn,
110 &format!(
111 "delete from {} where {} in rarray(?1)",
112 Self::TABLE_NAME,
113 Self::PRIMARY_KEY
114 ),
115 params!(Rc::new(values)),
116 ))
117 }
118 results
119 }
120
121 fn process_row(row: &Row) -> Self::Model;
122
123 fn table_name() -> &'static str {
124 Self::TABLE_NAME
125 }
126
127 fn all(conn: &Connection) -> Vec<Self::Model> {
128 let query = format!("SELECT * FROM {}", Self::TABLE_NAME);
129 Self::query(conn, &query, [])
130 }
131
132 fn all_with_limit(conn: &Connection, limit: usize) -> Vec<Self::Model> {
133 let query = format!("SELECT * FROM {} LIMIT {}", Self::TABLE_NAME, limit);
134 Self::query(conn, &query, [])
135 }
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{collection::Collection, test_helpers::get_connection};

    /// `all` returns nothing for an empty table and every created row afterwards.
    #[test]
    fn test_all_method() {
        let conn = get_connection(None).expect("Failed to get connection");

        let empty_results = Collection::all(&conn);
        assert!(empty_results.is_empty());

        // A fixed-size array suffices; `vec!` here trips clippy::useless_vec.
        let collection_names = ["test_collection_1", "test_collection_2", "test_collection_3"];
        for name in &collection_names {
            Collection::create(&conn, name);
        }

        let all_results = Collection::all(&conn);
        assert_eq!(all_results.len(), collection_names.len());

        // Compare by borrowing instead of allocating a `String` per lookup.
        for name in &collection_names {
            assert!(
                all_results.iter().any(|c| c.name == *name),
                "missing collection {name}"
            );
        }
    }

    /// `all_with_limit` caps the row count at `limit` and never exceeds the table size.
    #[test]
    fn test_all_with_limit_method() {
        let conn = get_connection(None).expect("Failed to get connection");

        for i in 0..10 {
            Collection::create(&conn, &format!("test_collection_{}", i));
        }

        let limited_results = Collection::all_with_limit(&conn, 5);
        assert_eq!(limited_results.len(), 5);

        let all_results = Collection::all(&conn);
        let large_limit_results = Collection::all_with_limit(&conn, all_results.len() + 20);
        assert_eq!(large_limit_results.len(), all_results.len());

        let zero_limit_results = Collection::all_with_limit(&conn, 0);
        assert!(zero_limit_results.is_empty());

        let one_limit_results = Collection::all_with_limit(&conn, 1);
        assert_eq!(one_limit_results.len(), 1);
    }
}