Skip to main content

gen_models/
traits.rs

1use std::rc::Rc;
2
3use itertools::Itertools;
4use rusqlite::{Connection, Params, Result, Row, limits::Limit, params, types::Value};
5
6/// Returns the SQLite variable parameter limit for the provided connection.
7pub fn sqlite_parameter_limit(conn: &Connection) -> usize {
8    let limit = conn.limit(Limit::SQLITE_LIMIT_VARIABLE_NUMBER);
9    usize::try_from(limit).expect("SQLite parameter limit should be positive")
10}
11
12/// Computes how many rows can be inserted per batch given a parameter count.
13pub fn max_rows_per_batch(conn: &Connection, params_per_row: usize) -> usize {
14    let params_per_row = params_per_row.max(1);
15    let max_params = sqlite_parameter_limit(conn);
16    (max_params / params_per_row).max(1)
17}
18
19pub trait Query {
20    type Model;
21    const PRIMARY_KEY: &'static str = "id";
22    const TABLE_NAME: &'static str;
23
24    fn query(conn: &Connection, query: &str, params: impl Params) -> Vec<Self::Model> {
25        let mut stmt = conn.prepare(query).unwrap();
26        let rows = stmt
27            .query_map(params, |row| Ok(Self::process_row(row)))
28            .unwrap();
29        let mut objs = vec![];
30        for row in rows {
31            objs.push(row.unwrap());
32        }
33        objs
34    }
35
36    fn get(conn: &Connection, query: &str, params: impl Params) -> Result<Self::Model> {
37        let mut stmt = conn.prepare(query).unwrap();
38        stmt.query_row(params, |row| Ok(Self::process_row(row)))
39    }
40
41    fn get_by_id<'a, T>(conn: &Connection, id: &'a T) -> Option<Self::Model>
42    where
43        T: Clone + 'a,
44        Value: From<T>,
45    {
46        Self::query(
47            conn,
48            &format!(
49                "select * from {} where {} = ?1",
50                Self::TABLE_NAME,
51                Self::PRIMARY_KEY
52            ),
53            params![Value::from(id.clone())],
54        )
55        .pop()
56    }
57
58    fn query_by_ids<'a, I: ?Sized, T>(conn: &Connection, ids: &'a I) -> Vec<Self::Model>
59    where
60        &'a I: IntoIterator<Item = &'a T>,
61        T: Clone + 'a,
62        Value: From<T>,
63    {
64        let mut results = vec![];
65        let batch_size = max_rows_per_batch(conn, 1);
66        for chunk in &ids.into_iter().chunks(batch_size) {
67            let values: Vec<Value> = chunk
68                .map(|value: &'a T| Value::from(value.clone()))
69                .collect();
70            results.append(&mut Self::query(
71                conn,
72                // The use of rarray/rowid is to preserve the order of the input IDs. If it becomes a performance hit,
73                // we can consider seeing if the input is an ordered preserving structure or not.
74                &format!(
75                    "
76                    WITH arr AS (
77                    SELECT value, rowid AS pos
78                    FROM rarray(?1)
79                    )
80                    SELECT {}.*
81                    FROM {}
82                    JOIN arr ON {}.{} = arr.value
83                    ORDER BY arr.pos;
84                    ",
85                    Self::TABLE_NAME,
86                    Self::TABLE_NAME,
87                    Self::TABLE_NAME,
88                    Self::PRIMARY_KEY
89                ),
90                params!(Rc::new(values)),
91            ))
92        }
93        results
94    }
95
96    fn delete_by_ids<'a, I: ?Sized, T>(conn: &Connection, ids: &'a I) -> Vec<Self::Model>
97    where
98        &'a I: IntoIterator<Item = &'a T>,
99        T: Clone + 'a,
100        Value: From<T>,
101    {
102        let mut results = vec![];
103        let batch_size = max_rows_per_batch(conn, 1);
104        for chunk in &ids.into_iter().chunks(batch_size) {
105            let values: Vec<Value> = chunk
106                .map(|value: &'a T| Value::from(value.clone()))
107                .collect();
108            results.append(&mut Self::query(
109                conn,
110                &format!(
111                    "delete from {} where {} in rarray(?1)",
112                    Self::TABLE_NAME,
113                    Self::PRIMARY_KEY
114                ),
115                params!(Rc::new(values)),
116            ))
117        }
118        results
119    }
120
121    fn process_row(row: &Row) -> Self::Model;
122
123    fn table_name() -> &'static str {
124        Self::TABLE_NAME
125    }
126
127    fn all(conn: &Connection) -> Vec<Self::Model> {
128        let query = format!("SELECT * FROM {}", Self::TABLE_NAME);
129        Self::query(conn, &query, [])
130    }
131
132    fn all_with_limit(conn: &Connection, limit: usize) -> Vec<Self::Model> {
133        let query = format!("SELECT * FROM {} LIMIT {}", Self::TABLE_NAME, limit);
134        Self::query(conn, &query, [])
135    }
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{collection::Collection, test_helpers::get_connection};

    /// `all` returns nothing for an empty table and every inserted row afterwards.
    #[test]
    fn test_all_method() {
        let conn = get_connection(None).expect("Failed to get connection");

        // Test with empty table
        let empty_results = Collection::all(&conn);
        assert!(empty_results.is_empty());

        // Create test collections. A fixed-size array suffices: it is only
        // iterated and measured (avoids clippy::useless_vec).
        let collection_names = ["test_collection_1", "test_collection_2", "test_collection_3"];
        for name in collection_names {
            Collection::create(&conn, name);
        }

        let all_results = Collection::all(&conn);
        assert_eq!(all_results.len(), collection_names.len());

        // Borrow the names rather than cloning each one into a new String.
        let returned_names: Vec<&str> = all_results.iter().map(|c| c.name.as_str()).collect();
        for name in collection_names {
            assert!(returned_names.contains(&name));
        }
    }

    /// `all_with_limit` caps the row count and tolerates limits of 0, 1 and
    /// values larger than the table.
    #[test]
    fn test_all_with_limit_method() {
        let conn = get_connection(None).expect("Failed to get connection");

        for i in 0..10 {
            Collection::create(&conn, &format!("test_collection_{}", i));
        }

        let limited_results = Collection::all_with_limit(&conn, 5);
        assert_eq!(limited_results.len(), 5);

        // Test limit larger than available records
        let all_results = Collection::all(&conn);
        let large_limit_results = Collection::all_with_limit(&conn, all_results.len() + 20);
        assert_eq!(large_limit_results.len(), all_results.len());

        // Test limit of 0
        let zero_limit_results = Collection::all_with_limit(&conn, 0);
        assert!(zero_limit_results.is_empty());

        // Test limit of 1
        let one_limit_results = Collection::all_with_limit(&conn, 1);
        assert_eq!(one_limit_results.len(), 1);
    }
}