1use std::cmp::Ordering;
4use std::convert::TryFrom;
5use std::path::Path;
6
7use rusqlite::types::ValueRef;
8use rusqlite::{Connection, OpenFlags, Row};
9
10use crate::db::types::{
11 ColumnInfo, DatabaseSummary, SchemaObjectInfo, SortDirection, SqlValue, TableInfo, TablePage,
12 TableQuery, TableSort, ViewInfo,
13};
14use crate::error::{PatchworksError, Result};
15
/// Identity column name reported by [`identity_columns`] for tables that declare no primary
/// key (such tables are otherwise ordered by SQLite's implicit `rowid` in this module).
///
/// NOTE(review): no query in this section selects `rowid AS __patchworks_rowid`; presumably
/// callers elsewhere alias `rowid` to this name — confirm.
const INTERNAL_ROWID_ALIAS: &str = "__patchworks_rowid";
17
/// Result of [`inspect_database_with_page`]: the database summary together with one page of
/// the preselected table.
#[derive(Clone, Debug, PartialEq)]
pub struct InitialInspection {
    /// Schema summary of the inspected database.
    pub summary: DatabaseSummary,
    /// Name of the table whose page was loaded (the first entry of `summary.tables`), or
    /// `None` when the database contains no tables.
    pub selected_table: Option<String>,
    /// The page of `selected_table` described by the caller's `TableQuery`; `None` exactly
    /// when `selected_table` is `None`.
    pub table_page: Option<TablePage>,
}
28
29pub fn open_read_only(path: &Path) -> Result<Connection> {
31 Ok(Connection::open_with_flags(
32 path,
33 OpenFlags::SQLITE_OPEN_READ_ONLY | OpenFlags::SQLITE_OPEN_URI,
34 )?)
35}
36
37pub fn inspect_database(path: &Path) -> Result<DatabaseSummary> {
39 let connection = open_read_only(path)?;
40 let mut tables = Vec::new();
41 let mut views = Vec::new();
42 let mut indexes = Vec::new();
43 let mut triggers = Vec::new();
44
45 let mut statement = connection.prepare(
46 "
47 SELECT type, name, tbl_name, sql
48 FROM sqlite_master
49 WHERE name NOT LIKE 'sqlite_%'
50 ORDER BY type, name
51 ",
52 )?;
53
54 let entries = statement.query_map([], |row| {
55 Ok((
56 row.get::<_, String>(0)?,
57 row.get::<_, String>(1)?,
58 row.get::<_, String>(2)?,
59 row.get::<_, Option<String>>(3)?,
60 ))
61 })?;
62
63 for entry in entries {
64 let (entry_type, name, table_name, create_sql) = entry?;
65 if entry_type == "table" {
66 let normalized_create_sql = normalize_table_create_sql(create_sql, &name);
67 let columns = load_columns(&connection, &name)?;
68 let primary_key = columns
69 .iter()
70 .filter(|column| column.is_primary_key)
71 .map(|column| column.name.clone())
72 .collect::<Vec<_>>();
73 let row_count = count_rows(&connection, &name)?;
74
75 tables.push(TableInfo {
76 name,
77 columns,
78 row_count,
79 primary_key,
80 create_sql: normalized_create_sql,
81 });
82 } else if entry_type == "view" {
83 views.push(ViewInfo { name, create_sql });
84 } else if entry_type == "index" {
85 indexes.push(SchemaObjectInfo {
86 name,
87 table_name,
88 create_sql,
89 });
90 } else if entry_type == "trigger" {
91 triggers.push(SchemaObjectInfo {
92 name,
93 table_name,
94 create_sql,
95 });
96 }
97 }
98
99 Ok(DatabaseSummary {
100 path: path.to_string_lossy().into_owned(),
101 tables,
102 views,
103 indexes,
104 triggers,
105 })
106}
107
108pub fn inspect_database_with_page(path: &Path, query: &TableQuery) -> Result<InitialInspection> {
110 let summary = inspect_database(path)?;
111 let selected_table = summary.tables.first().map(|table| table.name.clone());
112 let table_page = if let Some(table_name) = &selected_table {
113 let table = summary
114 .tables
115 .iter()
116 .find(|table| table.name == *table_name)
117 .ok_or_else(|| PatchworksError::MissingTable {
118 table: table_name.clone(),
119 path: path.to_path_buf(),
120 })?;
121 Some(read_table_page_for_table(path, table, query)?)
122 } else {
123 None
124 };
125
126 Ok(InitialInspection {
127 summary,
128 selected_table,
129 table_page,
130 })
131}
132
133pub fn read_table_page(path: &Path, table_name: &str, query: &TableQuery) -> Result<TablePage> {
135 let summary = inspect_database(path)?;
136 let table = summary
137 .tables
138 .iter()
139 .find(|table| table.name == table_name)
140 .cloned()
141 .ok_or_else(|| PatchworksError::MissingTable {
142 table: table_name.to_owned(),
143 path: path.to_path_buf(),
144 })?;
145
146 read_table_page_for_table(path, &table, query)
147}
148
149pub fn read_table_page_for_table(
151 path: &Path,
152 table: &TableInfo,
153 query: &TableQuery,
154) -> Result<TablePage> {
155 let table = table.clone();
156
157 let connection = open_read_only(path)?;
158 let order_by = build_order_by_clause(&table, query.sort.as_ref())?;
159 let offset = query.page.saturating_mul(query.page_size);
160 let column_count = table.columns.len();
161
162 let sql = format!(
163 "SELECT {} FROM {}{} LIMIT ? OFFSET ?",
164 select_column_list(&table.columns),
165 quote_identifier(&table.name),
166 order_by
167 );
168 let mut statement = connection.prepare(&sql)?;
169 let rows = statement.query_map(
170 rusqlite::params![query.page_size as i64, offset as i64],
171 move |row| read_value_row(row, column_count, 0),
172 )?;
173
174 let mut values = Vec::new();
175 for row in rows {
176 values.push(row?);
177 }
178
179 Ok(TablePage {
180 table_name: table.name,
181 columns: table.columns,
182 rows: values,
183 page: query.page,
184 page_size: query.page_size,
185 total_rows: table.row_count,
186 sort: query.sort.clone(),
187 })
188}
189
190pub fn load_all_rows(path: &Path, table: &TableInfo) -> Result<Vec<Vec<SqlValue>>> {
192 let connection = open_read_only(path)?;
193 let sql = format!(
194 "SELECT {} FROM {}{}",
195 select_column_list(&table.columns),
196 quote_identifier(&table.name),
197 default_order_clause(table)
198 );
199 let mut statement = connection.prepare(&sql)?;
200 let rows = statement.query_map([], move |row| read_value_row(row, table.columns.len(), 0))?;
201
202 let mut values = Vec::new();
203 for row in rows {
204 values.push(row?);
205 }
206
207 Ok(values)
208}
209
210pub fn for_each_row<F>(path: &Path, table: &TableInfo, mut callback: F) -> Result<()>
213where
214 F: FnMut(&[SqlValue]) -> Result<()>,
215{
216 let connection = open_read_only(path)?;
217 let sql = format!(
218 "SELECT {} FROM {}{}",
219 select_column_list(&table.columns),
220 quote_identifier(&table.name),
221 default_order_clause(table)
222 );
223 let mut statement = connection.prepare(&sql)?;
224 let mut rows = statement.query([])?;
225 let column_count = table.columns.len();
226
227 while let Some(row) = rows.next()? {
228 let values = read_value_row(row, column_count, 0)?;
229 callback(&values)?;
230 }
231
232 Ok(())
233}
234
235pub fn identity_columns(table: &TableInfo) -> Vec<String> {
237 if table.primary_key.is_empty() {
238 vec![INTERNAL_ROWID_ALIAS.to_owned()]
239 } else {
240 table.primary_key.clone()
241 }
242}
243
/// Quotes `identifier` as an SQL double-quoted identifier.
///
/// The identifier is wrapped in `"` and every embedded `"` is doubled, so any name can be
/// embedded safely in generated SQL.
pub fn quote_identifier(identifier: &str) -> String {
    let mut quoted = String::with_capacity(identifier.len() + 2);
    quoted.push('"');
    for ch in identifier.chars() {
        if ch == '"' {
            // SQL escapes a double quote inside a quoted identifier by doubling it.
            quoted.push('"');
        }
        quoted.push(ch);
    }
    quoted.push('"');
    quoted
}
248
249pub fn read_value_row(
251 row: &Row<'_>,
252 count: usize,
253 offset: usize,
254) -> rusqlite::Result<Vec<SqlValue>> {
255 let mut values = Vec::with_capacity(count);
256 for index in offset..(offset + count) {
257 values.push(sql_value_from_ref(row.get_ref(index)?));
258 }
259
260 Ok(values)
261}
262
263pub fn sql_value_from_ref(value: ValueRef<'_>) -> SqlValue {
265 match value {
266 ValueRef::Null => SqlValue::Null,
267 ValueRef::Integer(value) => SqlValue::Integer(value),
268 ValueRef::Real(value) => SqlValue::Real(value),
269 ValueRef::Text(value) => SqlValue::Text(String::from_utf8_lossy(value).into_owned()),
270 ValueRef::Blob(value) => SqlValue::Blob(value.to_vec()),
271 }
272}
273
274pub fn compare_value_slices(left: &[SqlValue], right: &[SqlValue]) -> Ordering {
276 for (left_value, right_value) in left.iter().zip(right.iter()) {
277 let ordering = compare_sql_values(left_value, right_value);
278 if ordering != Ordering::Equal {
279 return ordering;
280 }
281 }
282
283 left.len().cmp(&right.len())
284}
285
286pub fn compare_sql_values(left: &SqlValue, right: &SqlValue) -> Ordering {
288 use SqlValue::{Blob, Integer, Null, Real, Text};
289
290 let rank = |value: &SqlValue| match value {
291 Null => 0,
292 Integer(_) | Real(_) => 1,
293 Text(_) => 2,
294 Blob(_) => 3,
295 };
296
297 let rank_ordering = rank(left).cmp(&rank(right));
298 if rank_ordering != Ordering::Equal {
299 return rank_ordering;
300 }
301
302 match (left, right) {
303 (Null, Null) => Ordering::Equal,
304 (Integer(left), Integer(right)) => left.cmp(right),
305 (Real(left), Real(right)) => left.partial_cmp(right).unwrap_or(Ordering::Equal),
306 (Integer(left), Real(right)) => {
307 (*left as f64).partial_cmp(right).unwrap_or(Ordering::Equal)
308 }
309 (Real(left), Integer(right)) => left
310 .partial_cmp(&(*right as f64))
311 .unwrap_or(Ordering::Equal),
312 (Text(left), Text(right)) => left.cmp(right),
313 (Blob(left), Blob(right)) => left.cmp(right),
314 _ => Ordering::Equal,
315 }
316}
317
318fn normalize_table_create_sql(create_sql: Option<String>, table_name: &str) -> Option<String> {
319 create_sql.map(|sql| normalize_table_create_sql_text(&sql, table_name))
320}
321
322fn normalize_table_create_sql_text(create_sql: &str, table_name: &str) -> String {
323 if !is_simple_identifier(table_name) {
324 return create_sql.to_owned();
325 }
326
327 let Ok(name_start) = create_table_name_start(create_sql) else {
328 return create_sql.to_owned();
329 };
330 let Ok(name_end) = create_table_name_end(create_sql, name_start) else {
331 return create_sql.to_owned();
332 };
333
334 let suffix = &create_sql[name_end..];
335 let normalized_suffix = if suffix.starts_with('(') {
336 format!(" {suffix}")
337 } else {
338 suffix.to_owned()
339 };
340
341 format!(
342 "{}{}{}",
343 &create_sql[..name_start],
344 table_name,
345 normalized_suffix
346 )
347}
348
/// Reports whether `identifier` can appear in SQL without quoting.
///
/// A simple identifier must satisfy all of the following:
/// - it starts with an ASCII letter or `_`;
/// - it continues with ASCII letters, digits or `_`;
/// - it is not an SQLite keyword (compared ASCII case-insensitively).
///
/// Keywords must stay quoted: `CREATE TABLE order (…)` is a syntax error, so stripping the
/// quotes from `"order"` would produce SQL that no longer parses. Treating every keyword as
/// non-simple is conservative; the caller then keeps the SQL exactly as stored.
fn is_simple_identifier(identifier: &str) -> bool {
    // SQLite's keyword list (https://sqlite.org/lang_keywords.html).
    const SQLITE_KEYWORDS: &[&str] = &[
        "ABORT", "ACTION", "ADD", "AFTER", "ALL", "ALTER", "ALWAYS", "ANALYZE", "AND", "AS",
        "ASC", "ATTACH", "AUTOINCREMENT", "BEFORE", "BEGIN", "BETWEEN", "BY", "CASCADE",
        "CASE", "CAST", "CHECK", "COLLATE", "COLUMN", "COMMIT", "CONFLICT", "CONSTRAINT",
        "CREATE", "CROSS", "CURRENT", "CURRENT_DATE", "CURRENT_TIME", "CURRENT_TIMESTAMP",
        "DATABASE", "DEFAULT", "DEFERRABLE", "DEFERRED", "DELETE", "DESC", "DETACH",
        "DISTINCT", "DO", "DROP", "EACH", "ELSE", "END", "ESCAPE", "EXCEPT", "EXCLUDE",
        "EXCLUSIVE", "EXISTS", "EXPLAIN", "FAIL", "FILTER", "FIRST", "FOLLOWING", "FOR",
        "FOREIGN", "FROM", "FULL", "GENERATED", "GLOB", "GROUP", "GROUPS", "HAVING", "IF",
        "IGNORE", "IMMEDIATE", "IN", "INDEX", "INDEXED", "INITIALLY", "INNER", "INSERT",
        "INSTEAD", "INTERSECT", "INTO", "IS", "ISNULL", "JOIN", "KEY", "LAST", "LEFT", "LIKE",
        "LIMIT", "MATCH", "MATERIALIZED", "NATURAL", "NO", "NOT", "NOTHING", "NOTNULL",
        "NULL", "NULLS", "OF", "OFFSET", "ON", "OR", "ORDER", "OTHERS", "OUTER", "OVER",
        "PARTITION", "PLAN", "PRAGMA", "PRECEDING", "PRIMARY", "QUERY", "RAISE", "RANGE",
        "RECURSIVE", "REFERENCES", "REGEXP", "REINDEX", "RELEASE", "RENAME", "REPLACE",
        "RESTRICT", "RETURNING", "RIGHT", "ROLLBACK", "ROW", "ROWS", "SAVEPOINT", "SELECT",
        "SET", "TABLE", "TEMP", "TEMPORARY", "THEN", "TIES", "TO", "TRANSACTION", "TRIGGER",
        "UNBOUNDED", "UNION", "UNIQUE", "UPDATE", "USING", "VACUUM", "VALUES", "VIEW",
        "VIRTUAL", "WHEN", "WHERE", "WINDOW", "WITH", "WITHOUT",
    ];

    let mut chars = identifier.chars();
    let well_formed = match chars.next() {
        Some(first) if first == '_' || first.is_ascii_alphabetic() => {
            chars.all(|ch| ch == '_' || ch.is_ascii_alphanumeric())
        }
        _ => false,
    };

    well_formed
        && !SQLITE_KEYWORDS
            .iter()
            .any(|keyword| keyword.eq_ignore_ascii_case(identifier))
}
358
359fn create_table_name_start(create_sql: &str) -> Result<usize> {
360 let mut index = skip_ascii_whitespace(create_sql, 0);
361 index = consume_keyword(create_sql, index, "CREATE").ok_or_else(|| {
362 PatchworksError::InvalidState(
363 "CREATE TABLE SQL did not start with CREATE while normalizing inspection output"
364 .to_owned(),
365 )
366 })?;
367 index = skip_ascii_whitespace(create_sql, index);
368
369 if let Some(next) = consume_keyword(create_sql, index, "TEMPORARY") {
370 index = skip_ascii_whitespace(create_sql, next);
371 } else if let Some(next) = consume_keyword(create_sql, index, "TEMP") {
372 index = skip_ascii_whitespace(create_sql, next);
373 }
374
375 index = consume_keyword(create_sql, index, "TABLE").ok_or_else(|| {
376 PatchworksError::InvalidState(
377 "CREATE TABLE SQL did not contain TABLE while normalizing inspection output".to_owned(),
378 )
379 })?;
380 index = skip_ascii_whitespace(create_sql, index);
381
382 if let Some(next) = consume_keyword(create_sql, index, "IF") {
383 index = skip_ascii_whitespace(create_sql, next);
384 index = consume_keyword(create_sql, index, "NOT").ok_or_else(|| {
385 PatchworksError::InvalidState(
386 "CREATE TABLE SQL had IF without NOT while normalizing inspection output"
387 .to_owned(),
388 )
389 })?;
390 index = skip_ascii_whitespace(create_sql, index);
391 index = consume_keyword(create_sql, index, "EXISTS").ok_or_else(|| {
392 PatchworksError::InvalidState(
393 "CREATE TABLE SQL had IF NOT without EXISTS while normalizing inspection output"
394 .to_owned(),
395 )
396 })?;
397 index = skip_ascii_whitespace(create_sql, index);
398 }
399
400 Ok(index)
401}
402
/// Returns the byte index just past the table name that starts at `start` in `create_sql`.
///
/// Scanning stops at the first unquoted `(` or ASCII whitespace byte, or at the end of the
/// string; an unterminated quote runs to the end. Text inside `"…"`, `` `…` `` and `[…]` is
/// skipped. Within `"` and `` ` `` quotes, a doubled quote character is an escaped quote, not
/// the closing one. Single quotes get no special handling.
///
/// # Errors
///
/// Returns `PatchworksError::InvalidState` when the name is empty, i.e. the scan stops at
/// `start`.
fn create_table_name_end(create_sql: &str, start: usize) -> Result<usize> {
    let bytes = create_sql.as_bytes();
    let mut index = start;
    // The byte that closes the quoted section currently being scanned, if any.
    let mut quoted_by: Option<u8> = None;

    while let Some(&byte) = bytes.get(index) {
        if let Some(quote) = quoted_by {
            if byte == quote {
                // A doubled `"` or `` ` `` is an escaped quote; skip both bytes and stay quoted.
                if matches!(quote, b'"' | b'`') && bytes.get(index + 1) == Some(&quote) {
                    index += 2;
                    continue;
                }
                quoted_by = None;
            }
            index += 1;
            continue;
        }

        match byte {
            b'"' => quoted_by = Some(b'"'),
            b'`' => quoted_by = Some(b'`'),
            b'[' => quoted_by = Some(b']'),
            // Unquoted `(` or whitespace terminates the name (both are ASCII, so the
            // returned index is a char boundary).
            b'(' => break,
            _ if byte.is_ascii_whitespace() => break,
            _ => {}
        }
        index += 1;
    }

    if index == start {
        Err(PatchworksError::InvalidState(
            "CREATE TABLE SQL is missing a table name while normalizing inspection output"
                .to_owned(),
        ))
    } else {
        Ok(index)
    }
}
441
/// Returns the first byte index at or after `index` that is not ASCII whitespace.
///
/// An `index` past the end of `sql` is returned unchanged.
fn skip_ascii_whitespace(sql: &str, index: usize) -> usize {
    let skipped = sql.as_bytes().get(index..).map_or(0, |rest| {
        rest.iter()
            .take_while(|byte| byte.is_ascii_whitespace())
            .count()
    });
    index + skipped
}
452
/// Matches `keyword` (ASCII case-insensitively) at byte `index` of `sql`.
///
/// The keyword must be followed by ASCII whitespace or the end of the string. On success,
/// returns the index just past the keyword; otherwise returns `None`. This includes the case
/// where `index` is not on a char boundary.
fn consume_keyword(sql: &str, index: usize, keyword: &str) -> Option<usize> {
    let end = index.checked_add(keyword.len())?;
    let (word, rest) = (sql.get(index..end)?, &sql[end..]);
    let at_word_boundary = rest
        .chars()
        .next()
        .map_or(true, |next| next.is_ascii_whitespace());

    (word.eq_ignore_ascii_case(keyword) && at_word_boundary).then_some(end)
}
465
466fn load_columns(connection: &Connection, table_name: &str) -> Result<Vec<ColumnInfo>> {
467 let pragma = format!("PRAGMA table_info({})", quote_identifier(table_name));
468 let mut statement = connection.prepare(&pragma)?;
469 let columns = statement.query_map([], |row| {
470 let declared_type = row
471 .get::<_, Option<String>>(2)?
472 .unwrap_or_else(|| "BLOB".to_owned());
473 let pk_position = row.get::<_, i64>(5)?;
474 Ok((
475 row.get::<_, i64>(0)?,
476 pk_position,
477 ColumnInfo {
478 name: row.get(1)?,
479 col_type: declared_type,
480 nullable: row.get::<_, i64>(3)? == 0,
481 default_value: row.get(4)?,
482 is_primary_key: pk_position > 0,
483 },
484 ))
485 })?;
486
487 let mut values = Vec::new();
488 for column in columns {
489 values.push(column?);
490 }
491
492 let mut ordered_primary = values
493 .iter()
494 .filter(|(_, pk_position, _)| *pk_position > 0)
495 .cloned()
496 .collect::<Vec<_>>();
497 ordered_primary.sort_by_key(|(_, pk_position, _)| *pk_position);
498
499 let primary_names = ordered_primary
500 .into_iter()
501 .map(|(_, _, column)| column.name)
502 .collect::<Vec<_>>();
503
504 values.sort_by_key(|(cid, _, _)| *cid);
505 let mut all_columns = values
506 .into_iter()
507 .map(|(_, _, column)| column)
508 .collect::<Vec<_>>();
509 for column in &mut all_columns {
510 column.is_primary_key = primary_names.iter().any(|name| name == &column.name);
511 }
512
513 Ok(all_columns)
514}
515
516fn count_rows(connection: &Connection, table_name: &str) -> Result<u64> {
517 let sql = format!("SELECT COUNT(*) FROM {}", quote_identifier(table_name));
518 let count = connection.query_row(&sql, [], |row| row.get::<_, i64>(0))?;
519 u64::try_from(count).map_err(|_| {
520 PatchworksError::InvalidState(format!(
521 "received a negative row count while inspecting `{table_name}`"
522 ))
523 })
524}
525
526fn build_order_by_clause(table: &TableInfo, sort: Option<&TableSort>) -> Result<String> {
527 match sort {
528 Some(sort) => {
529 if !table
530 .columns
531 .iter()
532 .any(|column| column.name == sort.column)
533 {
534 return Err(PatchworksError::InvalidState(format!(
535 "column `{}` does not exist on table `{}`",
536 sort.column, table.name
537 )));
538 }
539 let direction = match sort.direction {
540 SortDirection::Asc => "ASC",
541 SortDirection::Desc => "DESC",
542 };
543 let mut order_terms = vec![format!("{} {}", quote_identifier(&sort.column), direction)];
544 order_terms.extend(stable_tie_breaker_terms(table, Some(sort.column.as_str())));
545 Ok(format!(" ORDER BY {}", order_terms.join(", ")))
546 }
547 None => Ok(default_order_clause(table)),
548 }
549}
550
551fn default_order_clause(table: &TableInfo) -> String {
552 format!(
553 " ORDER BY {}",
554 stable_tie_breaker_terms(table, None).join(", ")
555 )
556}
557
558fn stable_tie_breaker_terms(table: &TableInfo, skip_column: Option<&str>) -> Vec<String> {
559 if table.primary_key.is_empty() {
560 return if skip_column == Some("rowid") {
561 Vec::new()
562 } else {
563 vec!["rowid ASC".to_owned()]
564 };
565 }
566
567 table
568 .primary_key
569 .iter()
570 .filter(|column| Some(column.as_str()) != skip_column)
571 .map(|column| format!("{} ASC", quote_identifier(column)))
572 .collect()
573}
574
575fn select_column_list(columns: &[ColumnInfo]) -> String {
576 columns
577 .iter()
578 .map(|column| quote_identifier(&column.name))
579 .collect::<Vec<_>>()
580 .join(", ")
581}
582
#[cfg(test)]
mod tests {
    use super::{
        build_order_by_clause, default_order_clause, normalize_table_create_sql_text,
        quote_identifier,
    };
    use crate::db::types::{ColumnInfo, SortDirection, TableInfo, TableSort};

    /// A two-column table keyed by an integer `id` primary key.
    fn sample_table() -> TableInfo {
        TableInfo {
            name: "items".to_owned(),
            columns: vec![
                ColumnInfo {
                    name: "id".to_owned(),
                    col_type: "INTEGER".to_owned(),
                    nullable: false,
                    default_value: None,
                    is_primary_key: true,
                },
                ColumnInfo {
                    name: "name".to_owned(),
                    col_type: "TEXT".to_owned(),
                    nullable: true,
                    default_value: None,
                    is_primary_key: false,
                },
            ],
            row_count: 0,
            primary_key: vec!["id".to_owned()],
            create_sql: None,
        }
    }

    /// The same columns as `sample_table` but with no declared primary key, so ordering falls
    /// back to `rowid`.
    fn rowid_table() -> TableInfo {
        let mut table = sample_table();
        table.primary_key.clear();
        for column in &mut table.columns {
            column.is_primary_key = false;
        }
        table
    }

    #[test]
    fn sorted_pages_include_primary_key_tie_breaker() {
        let order = build_order_by_clause(
            &sample_table(),
            Some(&TableSort {
                column: "name".to_owned(),
                direction: SortDirection::Desc,
            }),
        )
        .expect("build order clause");

        assert_eq!(order, " ORDER BY \"name\" DESC, \"id\" ASC");
    }

    #[test]
    fn default_order_clause_uses_primary_key_columns() {
        let order = default_order_clause(&sample_table());

        assert_eq!(order, " ORDER BY \"id\" ASC");
    }

    #[test]
    fn sorting_by_primary_key_does_not_repeat_it() {
        let order = build_order_by_clause(
            &sample_table(),
            Some(&TableSort {
                column: "id".to_owned(),
                direction: SortDirection::Desc,
            }),
        )
        .expect("build order clause");

        assert_eq!(order, " ORDER BY \"id\" DESC");
    }

    #[test]
    fn tables_without_primary_key_order_by_rowid() {
        assert_eq!(default_order_clause(&rowid_table()), " ORDER BY rowid ASC");

        let order = build_order_by_clause(
            &rowid_table(),
            Some(&TableSort {
                column: "name".to_owned(),
                direction: SortDirection::Asc,
            }),
        )
        .expect("build order clause");

        assert_eq!(order, " ORDER BY \"name\" ASC, rowid ASC");
    }

    #[test]
    fn sorting_by_unknown_column_is_rejected() {
        let result = build_order_by_clause(
            &sample_table(),
            Some(&TableSort {
                column: "missing".to_owned(),
                direction: SortDirection::Asc,
            }),
        );

        assert!(result.is_err());
    }

    #[test]
    fn quote_identifier_doubles_embedded_quotes() {
        assert_eq!(quote_identifier("items"), "\"items\"");
        assert_eq!(quote_identifier("say \"hi\""), "\"say \"\"hi\"\"\"");
    }

    #[test]
    fn create_sql_normalization_unquotes_simple_table_names() {
        assert_eq!(
            normalize_table_create_sql_text("CREATE TABLE \"items\"(id INTEGER)", "items"),
            "CREATE TABLE items (id INTEGER)"
        );
        assert_eq!(
            normalize_table_create_sql_text("CREATE TABLE IF NOT EXISTS [items](id)", "items"),
            "CREATE TABLE IF NOT EXISTS items (id)"
        );
        assert_eq!(
            normalize_table_create_sql_text("CREATE TABLE \"my items\"(id)", "my items"),
            "CREATE TABLE \"my items\"(id)"
        );
    }
}