axum_sql_viewer/database/sqlite.rs

1//! SQLite database provider implementation
2
3use crate::database::traits::{DatabaseError, DatabaseProvider};
4use crate::schema::{
5    ColumnInfo, CountResponse, ForeignKey, IndexInfo, QueryResult, RowQuery, RowsResponse,
6    SortOrder, TableInfo, TableSchema,
7};
8use async_trait::async_trait;
9use serde_json::Value;
10use sqlx::sqlite::SqliteRow;
11use sqlx::{Column, Row, SqlitePool, TypeInfo, ValueRef};
12use std::time::Instant;
13
14/// SQLite database provider
15pub struct SqliteProvider {
16    pool: SqlitePool,
17}
18
19impl SqliteProvider {
20    /// Create a new SQLite provider
21    ///
22    /// # Arguments
23    ///
24    /// * `pool` - SQLite connection pool
25    pub fn new(pool: SqlitePool) -> Self {
26        Self { pool }
27    }
28
29    /// Quote an identifier (table or column name) to prevent SQL injection
30    ///
31    /// SQLite uses double quotes for identifiers. This function escapes any
32    /// double quotes in the identifier by doubling them.
33    fn quote_identifier(identifier: &str) -> String {
34        format!("\"{}\"", identifier.replace('"', "\"\""))
35    }
36
37    /// Convert a SQLite row to a JSON object
38    ///
39    /// This handles all SQLite data types and converts them to appropriate JSON values.
40    fn row_to_json(row: &SqliteRow) -> Result<Value, DatabaseError> {
41        let mut map = serde_json::Map::new();
42
43        for column in row.columns() {
44            let column_name = column.name();
45            let value = Self::extract_column_value(row, column)?;
46            map.insert(column_name.to_string(), value);
47        }
48
49        Ok(Value::Object(map))
50    }
51
52    /// Extract a column value from a SQLite row and convert to JSON
53    fn extract_column_value(
54        row: &SqliteRow,
55        column: &sqlx::sqlite::SqliteColumn,
56    ) -> Result<Value, DatabaseError> {
57        let column_name = column.name();
58        let type_info = column.type_info();
59        let type_name = type_info.name();
60
61        // Check if the value is NULL first
62        if row
63            .try_get_raw(column_name)
64            .map_err(|e| DatabaseError::Query(e.to_string()))?
65            .is_null()
66        {
67            return Ok(Value::Null);
68        }
69
70        // SQLite has dynamic typing but reports affinities: INTEGER, REAL, TEXT, BLOB, NULL
71        // We'll try to extract the value based on the type affinity
72        match type_name {
73            "INTEGER" | "BIGINT" => {
74                // Try i64 first, which covers most integer cases
75                if let Ok(value) = row.try_get::<i64, _>(column_name) {
76                    return Ok(Value::Number(value.into()));
77                }
78            }
79            "REAL" | "FLOAT" | "DOUBLE" => {
80                if let Ok(value) = row.try_get::<f64, _>(column_name) {
81                    if let Some(number) = serde_json::Number::from_f64(value) {
82                        return Ok(Value::Number(number));
83                    }
84                }
85            }
86            "TEXT" | "VARCHAR" | "CHAR" | "CLOB" => {
87                if let Ok(value) = row.try_get::<String, _>(column_name) {
88                    return Ok(Value::String(value));
89                }
90            }
91            "BLOB" => {
92                if let Ok(value) = row.try_get::<Vec<u8>, _>(column_name) {
93                    // Convert BLOB to base64 string for JSON serialization
94                    let base64_string = base64_encode(&value);
95                    return Ok(Value::String(format!(
96                        "[BLOB: {} bytes, base64: {}]",
97                        value.len(),
98                        base64_string
99                    )));
100                }
101            }
102            "BOOLEAN" | "BOOL" => {
103                if let Ok(value) = row.try_get::<bool, _>(column_name) {
104                    return Ok(Value::Bool(value));
105                }
106            }
107            "DATE" | "DATETIME" | "TIMESTAMP" => {
108                // Try to get as string (ISO format is common in SQLite)
109                if let Ok(value) = row.try_get::<String, _>(column_name) {
110                    return Ok(Value::String(value));
111                }
112            }
113            _ => {
114                // For unknown types, try string first, then other types
115                if let Ok(value) = row.try_get::<String, _>(column_name) {
116                    return Ok(Value::String(value));
117                }
118            }
119        }
120
121        // Fallback: try common types in order
122        if let Ok(value) = row.try_get::<i64, _>(column_name) {
123            return Ok(Value::Number(value.into()));
124        }
125        if let Ok(value) = row.try_get::<f64, _>(column_name) {
126            if let Some(number) = serde_json::Number::from_f64(value) {
127                return Ok(Value::Number(number));
128            }
129        }
130        if let Ok(value) = row.try_get::<String, _>(column_name) {
131            return Ok(Value::String(value));
132        }
133        if let Ok(value) = row.try_get::<bool, _>(column_name) {
134            return Ok(Value::Bool(value));
135        }
136        if let Ok(value) = row.try_get::<Vec<u8>, _>(column_name) {
137            let base64_string = base64_encode(&value);
138            return Ok(Value::String(format!(
139                "[BLOB: {} bytes, base64: {}]",
140                value.len(),
141                base64_string
142            )));
143        }
144
145        // If all else fails, return null
146        Ok(Value::Null)
147    }
148
149    /// Build a WHERE clause from filters
150    fn build_where_clause(filters: &std::collections::HashMap<String, String>) -> (String, Vec<String>) {
151        if filters.is_empty() {
152            return (String::new(), Vec::new());
153        }
154
155        let mut conditions = Vec::new();
156        let mut values = Vec::new();
157
158        for (column, filter_value) in filters {
159            let quoted_column = Self::quote_identifier(column);
160
161            // Support LIKE patterns with % wildcard
162            if filter_value.contains('%') {
163                conditions.push(format!("{} LIKE ?", quoted_column));
164                values.push(filter_value.clone());
165            } else {
166                conditions.push(format!("{} = ?", quoted_column));
167                values.push(filter_value.clone());
168            }
169        }
170
171        (format!(" WHERE {}", conditions.join(" AND ")), values)
172    }
173
174    /// Build an ORDER BY clause from sort parameters
175    fn build_order_clause(sort_by: Option<&str>, sort_order: Option<SortOrder>) -> String {
176        match (sort_by, sort_order) {
177            (Some(column), Some(order)) => {
178                let quoted_column = Self::quote_identifier(column);
179                let direction = match order {
180                    SortOrder::Ascending => "ASC",
181                    SortOrder::Descending => "DESC",
182                };
183                format!(" ORDER BY {} {}", quoted_column, direction)
184            }
185            _ => String::new(),
186        }
187    }
188}
189
190#[async_trait]
191impl DatabaseProvider for SqliteProvider {
192    async fn list_tables(&self) -> Result<Vec<TableInfo>, DatabaseError> {
193        let query = "SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%' ORDER BY name";
194
195        let rows = sqlx::query(query)
196            .fetch_all(&self.pool)
197            .await?;
198
199        let mut tables = Vec::new();
200        for row in rows {
201            let name: String = row.try_get("name")?;
202
203            // Optionally get row count for each table
204            let count_query = format!("SELECT COUNT(*) as count FROM {}", Self::quote_identifier(&name));
205            let row_count: Option<u64> = sqlx::query_scalar(&count_query)
206                .fetch_one(&self.pool)
207                .await
208                .ok()
209                .map(|count: i64| count as u64);
210
211            tables.push(TableInfo { name, row_count });
212        }
213
214        Ok(tables)
215    }
216
217    async fn get_table_schema(&self, table: &str) -> Result<TableSchema, DatabaseError> {
218        // Get column information using PRAGMA table_info
219        let table_info_query = format!("PRAGMA table_info({})", Self::quote_identifier(table));
220        let column_rows = sqlx::query(&table_info_query)
221            .fetch_all(&self.pool)
222            .await?;
223
224        if column_rows.is_empty() {
225            return Err(DatabaseError::TableNotFound(table.to_string()));
226        }
227
228        let mut columns = Vec::new();
229        let mut primary_key_columns = Vec::new();
230
231        for row in column_rows {
232            // PRAGMA table_info returns: cid, name, type, notnull, dflt_value, pk
233            let _column_id: i32 = row.try_get("cid")?;
234            let name: String = row.try_get("name")?;
235            let data_type: String = row.try_get("type")?;
236            let not_null: i32 = row.try_get("notnull")?;
237            let default_value: Option<String> = row.try_get("dflt_value").ok();
238            let primary_key: i32 = row.try_get("pk")?;
239
240            let is_primary_key = primary_key > 0;
241            if is_primary_key {
242                primary_key_columns.push((primary_key, name.clone()));
243            }
244
245            columns.push(ColumnInfo {
246                name,
247                data_type,
248                nullable: not_null == 0,
249                default_value,
250                is_primary_key,
251            });
252        }
253
254        // Sort primary key columns by their pk order and extract names
255        primary_key_columns.sort_by_key(|(order, _)| *order);
256        let primary_key = if primary_key_columns.is_empty() {
257            None
258        } else {
259            Some(primary_key_columns.into_iter().map(|(_, name)| name).collect())
260        };
261
262        // Get foreign key information using PRAGMA foreign_key_list
263        let foreign_key_query = format!("PRAGMA foreign_key_list({})", Self::quote_identifier(table));
264        let foreign_key_rows = sqlx::query(&foreign_key_query)
265            .fetch_all(&self.pool)
266            .await?;
267
268        let mut foreign_keys = Vec::new();
269        for row in foreign_key_rows {
270            // PRAGMA foreign_key_list returns: id, seq, table, from, to, on_update, on_delete, match
271            let column: String = row.try_get("from")?;
272            let references_table: String = row.try_get("table")?;
273            let references_column: String = row.try_get("to")?;
274
275            foreign_keys.push(ForeignKey {
276                column,
277                references_table,
278                references_column,
279            });
280        }
281
282        // Get index information using PRAGMA index_list
283        let index_list_query = format!("PRAGMA index_list({})", Self::quote_identifier(table));
284        let index_rows = sqlx::query(&index_list_query)
285            .fetch_all(&self.pool)
286            .await?;
287
288        let mut indexes = Vec::new();
289        for row in index_rows {
290            // PRAGMA index_list returns: seq, name, unique, origin, partial
291            let index_name: String = row.try_get("name")?;
292            let unique: i32 = row.try_get("unique")?;
293
294            // Get columns in this index using PRAGMA index_info
295            let index_info_query = format!("PRAGMA index_info({})", Self::quote_identifier(&index_name));
296            let index_column_rows = sqlx::query(&index_info_query)
297                .fetch_all(&self.pool)
298                .await?;
299
300            let mut index_columns = Vec::new();
301            for col_row in index_column_rows {
302                // PRAGMA index_info returns: seqno, cid, name
303                let column_name: Option<String> = col_row.try_get("name").ok();
304                if let Some(name) = column_name {
305                    index_columns.push(name);
306                }
307            }
308
309            indexes.push(IndexInfo {
310                name: index_name,
311                columns: index_columns,
312                unique: unique != 0,
313            });
314        }
315
316        Ok(TableSchema {
317            name: table.to_string(),
318            columns,
319            primary_key,
320            foreign_keys,
321            indexes,
322        })
323    }
324
325    async fn get_rows(&self, table: &str, query: RowQuery) -> Result<RowsResponse, DatabaseError> {
326        // Verify the table exists first
327        let table_exists: Option<i64> = sqlx::query_scalar(
328            "SELECT 1 FROM sqlite_master WHERE type='table' AND name = ? AND name NOT LIKE 'sqlite_%'"
329        )
330        .bind(table)
331        .fetch_optional(&self.pool)
332        .await?;
333
334        if table_exists.is_none() {
335            return Err(DatabaseError::TableNotFound(table.to_string()));
336        }
337
338        // Enforce maximum limit
339        const MAX_LIMIT: u64 = 500;
340        let limit = query.limit.min(MAX_LIMIT);
341
342        // Build WHERE clause from filters
343        let (where_clause, filter_values) = Self::build_where_clause(&query.filters);
344
345        // Build ORDER BY clause
346        let order_clause = Self::build_order_clause(
347            query.sort_by.as_deref(),
348            query.sort_order,
349        );
350
351        // Get total count with filters applied
352        let count_query = format!(
353            "SELECT COUNT(*) FROM {}{}",
354            Self::quote_identifier(table),
355            where_clause
356        );
357
358        let mut count_sql_query = sqlx::query_scalar::<_, i64>(&count_query);
359        for value in &filter_values {
360            count_sql_query = count_sql_query.bind(value);
361        }
362        let total: i64 = count_sql_query.fetch_one(&self.pool).await?;
363        let total = total as u64;
364
365        // Build the main query
366        let select_query = format!(
367            "SELECT * FROM {}{}{} LIMIT ? OFFSET ?",
368            Self::quote_identifier(table),
369            where_clause,
370            order_clause
371        );
372
373        // Build and execute query with bindings
374        let mut sql_query = sqlx::query(&select_query);
375        for value in &filter_values {
376            sql_query = sql_query.bind(value);
377        }
378        sql_query = sql_query.bind(limit as i64).bind(query.offset as i64);
379
380        let rows = sql_query.fetch_all(&self.pool).await?;
381
382        // Extract column names from the first row (if any) or from schema
383        let columns = if let Some(first_row) = rows.first() {
384            first_row
385                .columns()
386                .iter()
387                .map(|col| col.name().to_string())
388                .collect()
389        } else {
390            // If no rows, get columns from schema
391            let schema = self.get_table_schema(table).await?;
392            schema.columns.into_iter().map(|col| col.name).collect()
393        };
394
395        // Convert rows to JSON
396        let mut json_rows = Vec::new();
397        for row in &rows {
398            json_rows.push(Self::row_to_json(row)?);
399        }
400
401        let has_more = query.offset + (json_rows.len() as u64) < total;
402
403        Ok(RowsResponse {
404            rows: json_rows,
405            columns,
406            total,
407            offset: query.offset,
408            limit,
409            has_more,
410        })
411    }
412
413    async fn count_rows(&self, table: &str, query: &RowQuery) -> Result<CountResponse, DatabaseError> {
414        // Verify the table exists first
415        let table_exists: Option<i64> = sqlx::query_scalar(
416            "SELECT 1 FROM sqlite_master WHERE type='table' AND name = ? AND name NOT LIKE 'sqlite_%'"
417        )
418        .bind(table)
419        .fetch_optional(&self.pool)
420        .await?;
421
422        if table_exists.is_none() {
423            return Err(DatabaseError::TableNotFound(table.to_string()));
424        }
425
426        // Build WHERE clause from filters
427        let (where_clause, filter_values) = Self::build_where_clause(&query.filters);
428
429        // Build count query
430        let count_query = format!(
431            "SELECT COUNT(*) FROM {}{}",
432            Self::quote_identifier(table),
433            where_clause
434        );
435
436        let mut sql_query = sqlx::query_scalar::<_, i64>(&count_query);
437        for value in &filter_values {
438            sql_query = sql_query.bind(value);
439        }
440
441        let count: i64 = sql_query.fetch_one(&self.pool).await?;
442
443        Ok(CountResponse {
444            count: count as u64,
445        })
446    }
447
448    async fn execute_query(&self, sql: &str) -> Result<QueryResult, DatabaseError> {
449        let start_time = Instant::now();
450
451        // Enforce query timeout (30 seconds)
452        const QUERY_TIMEOUT_SECONDS: u64 = 30;
453
454        // Enforce maximum result row limit
455        const MAX_RESULT_ROWS: u64 = 10000;
456
457        // Check if this is a SELECT query or a write operation
458        let trimmed_sql = sql.trim().to_uppercase();
459        let is_select_query = trimmed_sql.starts_with("SELECT")
460            || trimmed_sql.starts_with("PRAGMA")
461            || trimmed_sql.starts_with("EXPLAIN");
462
463        if is_select_query {
464            // For SELECT queries, fetch all rows
465            let result = tokio::time::timeout(
466                std::time::Duration::from_secs(QUERY_TIMEOUT_SECONDS),
467                sqlx::query(sql).fetch_all(&self.pool),
468            )
469            .await;
470
471            let execution_time_milliseconds = start_time.elapsed().as_millis() as u64;
472
473            match result {
474                Ok(Ok(rows)) => {
475                    // Check row limit
476                    if rows.len() > MAX_RESULT_ROWS as usize {
477                        return Err(DatabaseError::TooManyRows(MAX_RESULT_ROWS));
478                    }
479
480                    // Extract columns from first row or return empty result
481                    let columns = if let Some(first_row) = rows.first() {
482                        first_row
483                            .columns()
484                            .iter()
485                            .map(|column| column.name().to_string())
486                            .collect()
487                    } else {
488                        Vec::new()
489                    };
490
491                    // Convert rows to JSON
492                    let mut json_rows = Vec::new();
493                    for row in &rows {
494                        json_rows.push(Self::row_to_json(row)?);
495                    }
496
497                    Ok(QueryResult {
498                        columns,
499                        rows: json_rows,
500                        affected_rows: rows.len() as u64,
501                        execution_time_milliseconds,
502                        error: None,
503                    })
504                }
505                Ok(Err(error)) => {
506                    // SQL execution error
507                    Ok(QueryResult {
508                        columns: Vec::new(),
509                        rows: Vec::new(),
510                        affected_rows: 0,
511                        execution_time_milliseconds,
512                        error: Some(error.to_string()),
513                    })
514                }
515                Err(_) => {
516                    // Timeout error
517                    Err(DatabaseError::Timeout)
518                }
519            }
520        } else {
521            // For INSERT/UPDATE/DELETE, use execute() to get affected rows
522            let result = tokio::time::timeout(
523                std::time::Duration::from_secs(QUERY_TIMEOUT_SECONDS),
524                sqlx::query(sql).execute(&self.pool),
525            )
526            .await;
527
528            let execution_time_milliseconds = start_time.elapsed().as_millis() as u64;
529
530            match result {
531                Ok(Ok(query_result)) => {
532                    Ok(QueryResult {
533                        columns: Vec::new(),
534                        rows: Vec::new(),
535                        affected_rows: query_result.rows_affected(),
536                        execution_time_milliseconds,
537                        error: None,
538                    })
539                }
540                Ok(Err(error)) => {
541                    Ok(QueryResult {
542                        columns: Vec::new(),
543                        rows: Vec::new(),
544                        affected_rows: 0,
545                        execution_time_milliseconds,
546                        error: Some(error.to_string()),
547                    })
548                }
549                Err(_) => {
550                    Err(DatabaseError::Timeout)
551                }
552            }
553        }
554    }
555}
556
/// Encode a display preview of BLOB data as standard, padded base64.
///
/// Only the first 64 bytes are encoded. When `data` is longer than that, `"..."` is
/// appended to the encoded prefix to signal the truncation. The encoded prefix is itself
/// valid padded base64 of those first bytes.
fn base64_encode(data: &[u8]) -> String {
    const ALPHABET: &[u8; 64] =
        b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
    const PREVIEW_BYTES: usize = 64;

    let truncated = data.len() > PREVIEW_BYTES;
    let preview = &data[..data.len().min(PREVIEW_BYTES)];
    // Map a 6-bit group to its base64 character.
    let symbol = |sextet: u8| ALPHABET[usize::from(sextet)] as char;

    let mut encoded = String::new();
    for chunk in preview.chunks(3) {
        match *chunk {
            [b1, b2, b3] => {
                encoded.push(symbol(b1 >> 2));
                encoded.push(symbol(((b1 & 0x03) << 4) | (b2 >> 4)));
                encoded.push(symbol(((b2 & 0x0f) << 2) | (b3 >> 6)));
                encoded.push(symbol(b3 & 0x3f));
            }
            // Two trailing bytes: three symbols plus one pad character.
            [b1, b2] => {
                encoded.push(symbol(b1 >> 2));
                encoded.push(symbol(((b1 & 0x03) << 4) | (b2 >> 4)));
                encoded.push(symbol((b2 & 0x0f) << 2));
                encoded.push('=');
            }
            // One trailing byte: two symbols plus two pad characters.
            [b1] => {
                encoded.push(symbol(b1 >> 2));
                encoded.push(symbol((b1 & 0x03) << 4));
                encoded.push_str("==");
            }
            _ => unreachable!("chunks(3) only yields slices of length 1..=3"),
        }
    }

    if truncated {
        encoded.push_str("...");
    }

    encoded
}
606
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn test_quote_identifier() {
        assert_eq!(SqliteProvider::quote_identifier("users"), "\"users\"");
        assert_eq!(
            SqliteProvider::quote_identifier("table\"name"),
            "\"table\"\"name\""
        );
        // Degenerate inputs must still yield exactly one quoted identifier.
        assert_eq!(SqliteProvider::quote_identifier(""), "\"\"");
        assert_eq!(SqliteProvider::quote_identifier("\""), "\"\"\"\"");
    }

    #[test]
    fn test_build_where_clause() {
        let mut filters = HashMap::new();
        filters.insert("name".to_string(), "John".to_string());
        filters.insert("age".to_string(), "30".to_string());

        let (clause, values) = SqliteProvider::build_where_clause(&filters);
        assert!(clause.starts_with(" WHERE "));
        assert!(clause.contains("\"name\" = ?"));
        assert!(clause.contains("\"age\" = ?"));
        assert!(clause.contains(" AND "));

        // Bound values must line up with the order of the placeholders.
        let name_first = clause.find("\"name\"") < clause.find("\"age\"");
        let expected = if name_first {
            vec!["John", "30"]
        } else {
            vec!["30", "John"]
        };
        assert_eq!(values, expected);
    }

    #[test]
    fn test_build_where_clause_uses_like_for_wildcards() {
        let mut filters = HashMap::new();
        filters.insert("name".to_string(), "Jo%".to_string());

        let (clause, values) = SqliteProvider::build_where_clause(&filters);
        assert_eq!(clause, " WHERE \"name\" LIKE ?");
        assert_eq!(values, vec!["Jo%".to_string()]);
    }

    #[test]
    fn test_build_where_clause_empty() {
        let (clause, values) = SqliteProvider::build_where_clause(&HashMap::new());
        assert!(clause.is_empty());
        assert!(values.is_empty());
    }

    #[test]
    fn test_build_order_clause() {
        assert_eq!(
            SqliteProvider::build_order_clause(Some("name"), Some(SortOrder::Ascending)),
            " ORDER BY \"name\" ASC"
        );
        assert_eq!(
            SqliteProvider::build_order_clause(Some("id"), Some(SortOrder::Descending)),
            " ORDER BY \"id\" DESC"
        );
        assert!(SqliteProvider::build_order_clause(None, None).is_empty());
        // A column without a direction yields no clause.
        assert!(SqliteProvider::build_order_clause(Some("id"), None).is_empty());
    }

    #[test]
    fn test_base64_encode() {
        assert_eq!(base64_encode(b"Hello, World!"), "SGVsbG8sIFdvcmxkIQ==");
        assert_eq!(base64_encode(b""), "");
        assert_eq!(base64_encode(b"f"), "Zg==");
        assert_eq!(base64_encode(b"fo"), "Zm8=");
        assert_eq!(base64_encode(b"foo"), "Zm9v");
    }

    #[test]
    fn test_base64_encode_truncates_long_input() {
        // Exactly 64 bytes: fully encoded, no truncation marker.
        assert_eq!(base64_encode(&[0u8; 64]), format!("{}AA==", "A".repeat(84)));
        // 65 bytes: only the first 64 are encoded, followed by "...".
        assert_eq!(
            base64_encode(&[0u8; 65]),
            format!("{}AA==...", "A".repeat(84))
        );
    }
}