axum_sql_viewer/database/
sqlite.rs1use crate::database::traits::{DatabaseError, DatabaseProvider};
4use crate::schema::{
5 ColumnInfo, CountResponse, ForeignKey, IndexInfo, QueryResult, RowQuery, RowsResponse,
6 SortOrder, TableInfo, TableSchema,
7};
8use async_trait::async_trait;
9use serde_json::Value;
10use sqlx::sqlite::SqliteRow;
11use sqlx::{Column, Row, SqlitePool, TypeInfo, ValueRef};
12use std::time::Instant;
13
14pub struct SqliteProvider {
16 pool: SqlitePool,
17}
18
19impl SqliteProvider {
20 pub fn new(pool: SqlitePool) -> Self {
26 Self { pool }
27 }
28
29 fn quote_identifier(identifier: &str) -> String {
34 format!("\"{}\"", identifier.replace('"', "\"\""))
35 }
36
37 fn row_to_json(row: &SqliteRow) -> Result<Value, DatabaseError> {
41 let mut map = serde_json::Map::new();
42
43 for column in row.columns() {
44 let column_name = column.name();
45 let value = Self::extract_column_value(row, column)?;
46 map.insert(column_name.to_string(), value);
47 }
48
49 Ok(Value::Object(map))
50 }
51
52 fn extract_column_value(
54 row: &SqliteRow,
55 column: &sqlx::sqlite::SqliteColumn,
56 ) -> Result<Value, DatabaseError> {
57 let column_name = column.name();
58 let type_info = column.type_info();
59 let type_name = type_info.name();
60
61 if row
63 .try_get_raw(column_name)
64 .map_err(|e| DatabaseError::Query(e.to_string()))?
65 .is_null()
66 {
67 return Ok(Value::Null);
68 }
69
70 match type_name {
73 "INTEGER" | "BIGINT" => {
74 if let Ok(value) = row.try_get::<i64, _>(column_name) {
76 return Ok(Value::Number(value.into()));
77 }
78 }
79 "REAL" | "FLOAT" | "DOUBLE" => {
80 if let Ok(value) = row.try_get::<f64, _>(column_name) {
81 if let Some(number) = serde_json::Number::from_f64(value) {
82 return Ok(Value::Number(number));
83 }
84 }
85 }
86 "TEXT" | "VARCHAR" | "CHAR" | "CLOB" => {
87 if let Ok(value) = row.try_get::<String, _>(column_name) {
88 return Ok(Value::String(value));
89 }
90 }
91 "BLOB" => {
92 if let Ok(value) = row.try_get::<Vec<u8>, _>(column_name) {
93 let base64_string = base64_encode(&value);
95 return Ok(Value::String(format!(
96 "[BLOB: {} bytes, base64: {}]",
97 value.len(),
98 base64_string
99 )));
100 }
101 }
102 "BOOLEAN" | "BOOL" => {
103 if let Ok(value) = row.try_get::<bool, _>(column_name) {
104 return Ok(Value::Bool(value));
105 }
106 }
107 "DATE" | "DATETIME" | "TIMESTAMP" => {
108 if let Ok(value) = row.try_get::<String, _>(column_name) {
110 return Ok(Value::String(value));
111 }
112 }
113 _ => {
114 if let Ok(value) = row.try_get::<String, _>(column_name) {
116 return Ok(Value::String(value));
117 }
118 }
119 }
120
121 if let Ok(value) = row.try_get::<i64, _>(column_name) {
123 return Ok(Value::Number(value.into()));
124 }
125 if let Ok(value) = row.try_get::<f64, _>(column_name) {
126 if let Some(number) = serde_json::Number::from_f64(value) {
127 return Ok(Value::Number(number));
128 }
129 }
130 if let Ok(value) = row.try_get::<String, _>(column_name) {
131 return Ok(Value::String(value));
132 }
133 if let Ok(value) = row.try_get::<bool, _>(column_name) {
134 return Ok(Value::Bool(value));
135 }
136 if let Ok(value) = row.try_get::<Vec<u8>, _>(column_name) {
137 let base64_string = base64_encode(&value);
138 return Ok(Value::String(format!(
139 "[BLOB: {} bytes, base64: {}]",
140 value.len(),
141 base64_string
142 )));
143 }
144
145 Ok(Value::Null)
147 }
148
149 fn build_where_clause(filters: &std::collections::HashMap<String, String>) -> (String, Vec<String>) {
151 if filters.is_empty() {
152 return (String::new(), Vec::new());
153 }
154
155 let mut conditions = Vec::new();
156 let mut values = Vec::new();
157
158 for (column, filter_value) in filters {
159 let quoted_column = Self::quote_identifier(column);
160
161 if filter_value.contains('%') {
163 conditions.push(format!("{} LIKE ?", quoted_column));
164 values.push(filter_value.clone());
165 } else {
166 conditions.push(format!("{} = ?", quoted_column));
167 values.push(filter_value.clone());
168 }
169 }
170
171 (format!(" WHERE {}", conditions.join(" AND ")), values)
172 }
173
174 fn build_order_clause(sort_by: Option<&str>, sort_order: Option<SortOrder>) -> String {
176 match (sort_by, sort_order) {
177 (Some(column), Some(order)) => {
178 let quoted_column = Self::quote_identifier(column);
179 let direction = match order {
180 SortOrder::Ascending => "ASC",
181 SortOrder::Descending => "DESC",
182 };
183 format!(" ORDER BY {} {}", quoted_column, direction)
184 }
185 _ => String::new(),
186 }
187 }
188}
189
190#[async_trait]
191impl DatabaseProvider for SqliteProvider {
192 async fn list_tables(&self) -> Result<Vec<TableInfo>, DatabaseError> {
193 let query = "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name";
194
195 let rows = sqlx::query(query)
196 .fetch_all(&self.pool)
197 .await?;
198
199 let mut tables = Vec::new();
200 for row in rows {
201 let name: String = row.try_get("name")?;
202
203 let count_query = format!("SELECT COUNT(*) as count FROM {}", Self::quote_identifier(&name));
205 let row_count: Option<u64> = sqlx::query_scalar(&count_query)
206 .fetch_one(&self.pool)
207 .await
208 .ok()
209 .map(|count: i64| count as u64);
210
211 tables.push(TableInfo { name, row_count });
212 }
213
214 Ok(tables)
215 }
216
217 async fn get_table_schema(&self, table: &str) -> Result<TableSchema, DatabaseError> {
218 let table_info_query = format!("PRAGMA table_info({})", Self::quote_identifier(table));
220 let column_rows = sqlx::query(&table_info_query)
221 .fetch_all(&self.pool)
222 .await?;
223
224 if column_rows.is_empty() {
225 return Err(DatabaseError::TableNotFound(table.to_string()));
226 }
227
228 let mut columns = Vec::new();
229 let mut primary_key_columns = Vec::new();
230
231 for row in column_rows {
232 let _column_id: i32 = row.try_get("cid")?;
234 let name: String = row.try_get("name")?;
235 let data_type: String = row.try_get("type")?;
236 let not_null: i32 = row.try_get("notnull")?;
237 let default_value: Option<String> = row.try_get("dflt_value").ok();
238 let primary_key: i32 = row.try_get("pk")?;
239
240 let is_primary_key = primary_key > 0;
241 if is_primary_key {
242 primary_key_columns.push((primary_key, name.clone()));
243 }
244
245 columns.push(ColumnInfo {
246 name,
247 data_type,
248 nullable: not_null == 0,
249 default_value,
250 is_primary_key,
251 });
252 }
253
254 primary_key_columns.sort_by_key(|(order, _)| *order);
256 let primary_key = if primary_key_columns.is_empty() {
257 None
258 } else {
259 Some(primary_key_columns.into_iter().map(|(_, name)| name).collect())
260 };
261
262 let foreign_key_query = format!("PRAGMA foreign_key_list({})", Self::quote_identifier(table));
264 let foreign_key_rows = sqlx::query(&foreign_key_query)
265 .fetch_all(&self.pool)
266 .await?;
267
268 let mut foreign_keys = Vec::new();
269 for row in foreign_key_rows {
270 let column: String = row.try_get("from")?;
272 let references_table: String = row.try_get("table")?;
273 let references_column: String = row.try_get("to")?;
274
275 foreign_keys.push(ForeignKey {
276 column,
277 references_table,
278 references_column,
279 });
280 }
281
282 let index_list_query = format!("PRAGMA index_list({})", Self::quote_identifier(table));
284 let index_rows = sqlx::query(&index_list_query)
285 .fetch_all(&self.pool)
286 .await?;
287
288 let mut indexes = Vec::new();
289 for row in index_rows {
290 let index_name: String = row.try_get("name")?;
292 let unique: i32 = row.try_get("unique")?;
293
294 let index_info_query = format!("PRAGMA index_info({})", Self::quote_identifier(&index_name));
296 let index_column_rows = sqlx::query(&index_info_query)
297 .fetch_all(&self.pool)
298 .await?;
299
300 let mut index_columns = Vec::new();
301 for col_row in index_column_rows {
302 let column_name: Option<String> = col_row.try_get("name").ok();
304 if let Some(name) = column_name {
305 index_columns.push(name);
306 }
307 }
308
309 indexes.push(IndexInfo {
310 name: index_name,
311 columns: index_columns,
312 unique: unique != 0,
313 });
314 }
315
316 Ok(TableSchema {
317 name: table.to_string(),
318 columns,
319 primary_key,
320 foreign_keys,
321 indexes,
322 })
323 }
324
325 async fn get_rows(&self, table: &str, query: RowQuery) -> Result<RowsResponse, DatabaseError> {
326 let table_exists: Option<i64> = sqlx::query_scalar(
328 "SELECT 1 FROM sqlite_master WHERE type='table' AND name = ? AND name NOT LIKE 'sqlite_%'"
329 )
330 .bind(table)
331 .fetch_optional(&self.pool)
332 .await?;
333
334 if table_exists.is_none() {
335 return Err(DatabaseError::TableNotFound(table.to_string()));
336 }
337
338 const MAX_LIMIT: u64 = 500;
340 let limit = query.limit.min(MAX_LIMIT);
341
342 let (where_clause, filter_values) = Self::build_where_clause(&query.filters);
344
345 let order_clause = Self::build_order_clause(
347 query.sort_by.as_deref(),
348 query.sort_order,
349 );
350
351 let count_query = format!(
353 "SELECT COUNT(*) FROM {}{}",
354 Self::quote_identifier(table),
355 where_clause
356 );
357
358 let mut count_sql_query = sqlx::query_scalar::<_, i64>(&count_query);
359 for value in &filter_values {
360 count_sql_query = count_sql_query.bind(value);
361 }
362 let total: i64 = count_sql_query.fetch_one(&self.pool).await?;
363 let total = total as u64;
364
365 let select_query = format!(
367 "SELECT * FROM {}{}{} LIMIT ? OFFSET ?",
368 Self::quote_identifier(table),
369 where_clause,
370 order_clause
371 );
372
373 let mut sql_query = sqlx::query(&select_query);
375 for value in &filter_values {
376 sql_query = sql_query.bind(value);
377 }
378 sql_query = sql_query.bind(limit as i64).bind(query.offset as i64);
379
380 let rows = sql_query.fetch_all(&self.pool).await?;
381
382 let columns = if let Some(first_row) = rows.first() {
384 first_row
385 .columns()
386 .iter()
387 .map(|col| col.name().to_string())
388 .collect()
389 } else {
390 let schema = self.get_table_schema(table).await?;
392 schema.columns.into_iter().map(|col| col.name).collect()
393 };
394
395 let mut json_rows = Vec::new();
397 for row in &rows {
398 json_rows.push(Self::row_to_json(row)?);
399 }
400
401 let has_more = query.offset + (json_rows.len() as u64) < total;
402
403 Ok(RowsResponse {
404 rows: json_rows,
405 columns,
406 total,
407 offset: query.offset,
408 limit,
409 has_more,
410 })
411 }
412
413 async fn count_rows(&self, table: &str, query: &RowQuery) -> Result<CountResponse, DatabaseError> {
414 let table_exists: Option<i64> = sqlx::query_scalar(
416 "SELECT 1 FROM sqlite_master WHERE type='table' AND name = ? AND name NOT LIKE 'sqlite_%'"
417 )
418 .bind(table)
419 .fetch_optional(&self.pool)
420 .await?;
421
422 if table_exists.is_none() {
423 return Err(DatabaseError::TableNotFound(table.to_string()));
424 }
425
426 let (where_clause, filter_values) = Self::build_where_clause(&query.filters);
428
429 let count_query = format!(
431 "SELECT COUNT(*) FROM {}{}",
432 Self::quote_identifier(table),
433 where_clause
434 );
435
436 let mut sql_query = sqlx::query_scalar::<_, i64>(&count_query);
437 for value in &filter_values {
438 sql_query = sql_query.bind(value);
439 }
440
441 let count: i64 = sql_query.fetch_one(&self.pool).await?;
442
443 Ok(CountResponse {
444 count: count as u64,
445 })
446 }
447
448 async fn execute_query(&self, sql: &str) -> Result<QueryResult, DatabaseError> {
449 let start_time = Instant::now();
450
451 const QUERY_TIMEOUT_SECONDS: u64 = 30;
453
454 const MAX_RESULT_ROWS: u64 = 10000;
456
457 let trimmed_sql = sql.trim().to_uppercase();
459 let is_select_query = trimmed_sql.starts_with("SELECT")
460 || trimmed_sql.starts_with("PRAGMA")
461 || trimmed_sql.starts_with("EXPLAIN");
462
463 if is_select_query {
464 let result = tokio::time::timeout(
466 std::time::Duration::from_secs(QUERY_TIMEOUT_SECONDS),
467 sqlx::query(sql).fetch_all(&self.pool),
468 )
469 .await;
470
471 let execution_time_milliseconds = start_time.elapsed().as_millis() as u64;
472
473 match result {
474 Ok(Ok(rows)) => {
475 if rows.len() > MAX_RESULT_ROWS as usize {
477 return Err(DatabaseError::TooManyRows(MAX_RESULT_ROWS));
478 }
479
480 let columns = if let Some(first_row) = rows.first() {
482 first_row
483 .columns()
484 .iter()
485 .map(|column| column.name().to_string())
486 .collect()
487 } else {
488 Vec::new()
489 };
490
491 let mut json_rows = Vec::new();
493 for row in &rows {
494 json_rows.push(Self::row_to_json(row)?);
495 }
496
497 Ok(QueryResult {
498 columns,
499 rows: json_rows,
500 affected_rows: rows.len() as u64,
501 execution_time_milliseconds,
502 error: None,
503 })
504 }
505 Ok(Err(error)) => {
506 Ok(QueryResult {
508 columns: Vec::new(),
509 rows: Vec::new(),
510 affected_rows: 0,
511 execution_time_milliseconds,
512 error: Some(error.to_string()),
513 })
514 }
515 Err(_) => {
516 Err(DatabaseError::Timeout)
518 }
519 }
520 } else {
521 let result = tokio::time::timeout(
523 std::time::Duration::from_secs(QUERY_TIMEOUT_SECONDS),
524 sqlx::query(sql).execute(&self.pool),
525 )
526 .await;
527
528 let execution_time_milliseconds = start_time.elapsed().as_millis() as u64;
529
530 match result {
531 Ok(Ok(query_result)) => {
532 Ok(QueryResult {
533 columns: Vec::new(),
534 rows: Vec::new(),
535 affected_rows: query_result.rows_affected(),
536 execution_time_milliseconds,
537 error: None,
538 })
539 }
540 Ok(Err(error)) => {
541 Ok(QueryResult {
542 columns: Vec::new(),
543 rows: Vec::new(),
544 affected_rows: 0,
545 execution_time_milliseconds,
546 error: Some(error.to_string()),
547 })
548 }
549 Err(_) => {
550 Err(DatabaseError::Timeout)
551 }
552 }
553 }
554 }
555}
556
/// Encodes at most the first 64 bytes of `data` as standard, padded base64 (RFC 4648 alphabet).
///
/// This is a display preview: when `data` is longer than 64 bytes, the encoding of the first 64
/// bytes is followed by a literal `...`.
fn base64_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8; 64] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    const PREVIEW_BYTES: usize = 64;

    let is_truncated = data.len() > PREVIEW_BYTES;
    let preview = &data[..data.len().min(PREVIEW_BYTES)];

    let mut encoded = String::new();
    for chunk in preview.chunks(3) {
        let first = chunk[0];
        let second = chunk.get(1).copied().unwrap_or(0);
        let third = chunk.get(2).copied().unwrap_or(0);

        let sextets = [
            first >> 2,
            ((first & 0x03) << 4) | (second >> 4),
            ((second & 0x0f) << 2) | (third >> 6),
            third & 0x3f,
        ];

        // A chunk of n input bytes yields n + 1 significant characters; the rest is '=' padding.
        let significant = chunk.len() + 1;
        for (position, &sextet) in sextets.iter().enumerate() {
            if position < significant {
                encoded.push(char::from(ALPHABET[usize::from(sextet)]));
            } else {
                encoded.push('=');
            }
        }
    }

    if is_truncated {
        encoded.push_str("...");
    }

    encoded
}
606
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_quote_identifier() {
        assert_eq!(SqliteProvider::quote_identifier("users"), "\"users\"");
        assert_eq!(
            SqliteProvider::quote_identifier("table\"name"),
            "\"table\"\"name\""
        );
    }

    #[test]
    fn test_build_where_clause() {
        let mut filters = std::collections::HashMap::new();
        filters.insert("name".to_string(), "John".to_string());
        filters.insert("age".to_string(), "30".to_string());

        let (clause, values) = SqliteProvider::build_where_clause(&filters);
        assert!(clause.contains("WHERE"));
        assert!(clause.contains("\"name\""));
        assert!(clause.contains("\"age\""));
        assert_eq!(values.len(), 2);
    }

    #[test]
    fn test_build_where_clause_empty() {
        let filters = std::collections::HashMap::new();
        let (clause, values) = SqliteProvider::build_where_clause(&filters);
        assert!(clause.is_empty());
        assert!(values.is_empty());
    }

    #[test]
    fn test_build_where_clause_percent_uses_like() {
        let mut filters = std::collections::HashMap::new();
        filters.insert("name".to_string(), "Jo%".to_string());

        let (clause, values) = SqliteProvider::build_where_clause(&filters);
        assert_eq!(clause, " WHERE \"name\" LIKE ?");
        assert_eq!(values, vec!["Jo%".to_string()]);
    }

    #[test]
    fn test_build_order_clause() {
        let clause = SqliteProvider::build_order_clause(Some("name"), Some(SortOrder::Ascending));
        assert!(clause.contains("ORDER BY"));
        assert!(clause.contains("\"name\""));
        assert!(clause.contains("ASC"));

        let clause = SqliteProvider::build_order_clause(Some("id"), Some(SortOrder::Descending));
        assert!(clause.contains("DESC"));

        let clause = SqliteProvider::build_order_clause(None, None);
        assert!(clause.is_empty());
    }

    #[test]
    fn test_build_order_clause_requires_direction() {
        // A column without a direction produces no ORDER BY.
        assert!(SqliteProvider::build_order_clause(Some("id"), None).is_empty());
    }

    #[test]
    fn test_base64_encode() {
        // Pin exact output against RFC 4648 test vectors, not just the alphabet.
        assert_eq!(base64_encode(b""), "");
        assert_eq!(base64_encode(b"f"), "Zg==");
        assert_eq!(base64_encode(b"fo"), "Zm8=");
        assert_eq!(base64_encode(b"foo"), "Zm9v");
        assert_eq!(base64_encode(b"foobar"), "Zm9vYmFy");
        assert_eq!(base64_encode(b"Hello, World!"), "SGVsbG8sIFdvcmxkIQ==");
    }

    #[test]
    fn test_base64_encode_truncates_after_64_bytes() {
        // 64 bytes = 21 full triples + 1 leftover byte, so the preview ends in "==".
        assert_eq!(base64_encode(&[0u8; 64]), format!("{}==", "A".repeat(86)));
        assert_eq!(base64_encode(&[0u8; 65]), format!("{}==...", "A".repeat(86)));
    }
}