reefdb/
lib.rs

1mod storage;
2
3use error::ReefDBError;
4use indexes::fts::{
5    default::{DefaultSearchIdx, OnDiskSearchIdx},
6    search::Search,
7};
8use nom::IResult;
9mod indexes;
10mod sql;
11mod transaction;
12use result::ReefDBResult;
13use sql::{
14    clauses::{join_clause::JoinType, wheres::where_type::WhereType},
15    data_type::DataType,
16    data_value::DataValue,
17    statements::{
18        create::CreateStatement, delete::DeleteStatement, insert::InsertStatement,
19        select::SelectStatement, update::UpdateStatement, Statement,
20    },
21};
22
23mod error;
24mod result;
25
26use storage::{disk::OnDiskStorage, memory::InMemoryStorage, Storage};
27
28pub type InMemoryReefDB = ReefDB<InMemoryStorage, DefaultSearchIdx>;
29
30impl InMemoryReefDB {
31    pub fn new() -> Self {
32        ReefDB::init((), ())
33    }
34}
35
36pub type OnDiskReefDB = ReefDB<OnDiskStorage, OnDiskSearchIdx>;
37
38impl OnDiskReefDB {
39    pub fn new(db_path: String, index_path: String) -> Self {
40        ReefDB::init(db_path, index_path)
41    }
42}
43
44//clone
45#[derive(Clone)]
46pub struct ReefDB<S: Storage, FTS: Search> {
47    tables: S,
48    inverted_index: FTS,
49}
50
51impl<S: Storage, FTS: Search> ReefDB<S, FTS> {
52    fn init(args: S::NewArgs, args2: FTS::NewArgs) -> Self {
53        ReefDB {
54            tables: S::new(args),
55            inverted_index: FTS::new(args2),
56        }
57    }
58
59    pub fn query(&mut self, query: &str) -> Result<ReefDBResult, ReefDBError> {
60        match Statement::parse(query) {
61            Ok((_, stmt)) => self.execute_statement(stmt),
62            Err(err) => {
63                eprintln!("Failed to parse statement: {}", err);
64                Err(ReefDBError::Other(err.to_string()))
65            }
66        }
67    }
68
69    fn execute_statement(&mut self, stmt: Statement) -> Result<ReefDBResult, ReefDBError> {
70        match stmt {
71            Statement::Delete(DeleteStatement::FromTable(table_name, where_type)) => {
72                if let Some((schema, table)) = self.tables.get_table(&table_name) {
73                    let mut deleted_rows = 0;
74                    for i in (0..table.len()).rev() {
75                        if let Some(where_type) = &where_type {
76                            match where_type {
77                                WhereType::Regular(where_clause) => {
78                                    if let Some(col_index) = schema.iter().position(|column_def| {
79                                        &column_def.name == &where_clause.col_name
80                                    }) {
81                                        if table[i][col_index] == where_clause.value {
82                                            table.remove(i);
83                                            //also remove from inverted index if it's a fts column
84                                            if schema[col_index].data_type == DataType::FTSText {
85                                                self.inverted_index.remove_document(
86                                                    &table_name,
87                                                    &where_clause.col_name,
88                                                    i,
89                                                );
90                                            }
91                                            deleted_rows += 1;
92                                        }
93                                    }
94                                }
95                                WhereType::FTS(_) => unimplemented!(),
96                            }
97                        } else {
98                            table.remove(i);
99                            deleted_rows += 1;
100                        }
101                    }
102                    Ok(ReefDBResult::Delete(deleted_rows))
103                } else {
104                    Err(ReefDBError::TableNotFound(table_name))
105                }
106            }
107            Statement::Create(CreateStatement::Table(table_name, cols)) => {
108                self.tables
109                    .insert_table(table_name.clone(), cols.clone(), Vec::new());
110
111                // Add columns with DataType::FTSText to the InvertedIndex
112                for column_def in cols.iter() {
113                    if column_def.data_type == DataType::FTSText {
114                        self.inverted_index
115                            .add_column(&table_name, &column_def.name);
116                    }
117                }
118
119                Ok(ReefDBResult::CreateTable)
120            }
121            Statement::Insert(InsertStatement::IntoTable(table_name, values)) => {
122                let row_id = self.tables.push_value(&table_name, values.clone());
123                if let Some((schema, _)) = self.tables.get_table(&table_name) {
124                    for (i, value) in values.iter().enumerate() {
125                        if schema[i].data_type == DataType::FTSText {
126                            if let DataValue::Text(ref text) = value {
127                                self.inverted_index.add_document(
128                                    &table_name,
129                                    &schema[i].name,
130                                    row_id,
131                                    text,
132                                );
133                            }
134                        }
135                    }
136                }
137                Ok(ReefDBResult::Insert(1))
138            }
139            Statement::Select(SelectStatement::FromTable(
140                table_name,
141                columns,
142                where_type,
143                joins,
144            )) => {
145                if let Some((schema, table)) = self.tables.get_table_ref(&table_name) {
146                    let mut result = Vec::<(usize, Vec<DataValue>)>::new();
147
148                    // If there are no join clauses, perform a regular select operation
149                    if joins.is_empty() {
150                        // if there is a where clause, filter the result
151                        if let Some(where_type) = where_type {
152                            let column_indexes: Vec<_> = columns
153                                .iter()
154                                .map(|column_name| {
155                                    schema
156                                        .iter()
157                                        .position(|column_def| {
158                                            &column_def.name == &column_name.name
159                                        })
160                                        .unwrap()
161                                })
162                                .collect();
163                            match where_type {
164                                WhereType::FTS(fts_where) => {
165                                    let row_ids = self.inverted_index.search(
166                                        &table_name,
167                                        &fts_where.col.name,
168                                        &fts_where.query,
169                                    );
170                                    for (rowid, row) in table.iter().enumerate() {
171                                        if row_ids.contains(&rowid) {
172                                            let selected_columns: Vec<_> = row
173                                                .iter()
174                                                .enumerate()
175                                                .filter_map(|(i, value)| {
176                                                    if column_indexes.contains(&i) {
177                                                        Some(value.clone())
178                                                    } else {
179                                                        None
180                                                    }
181                                                })
182                                                .collect();
183                                            result.push((rowid, selected_columns));
184                                        }
185                                    }
186                                }
187                                WhereType::Regular(where_clause) => {
188                                    for (rowid, row) in table.iter().enumerate() {
189                                        let selected_columns: Vec<_> = row
190                                            .iter()
191                                            .enumerate()
192                                            .filter_map(|(i, value)| {
193                                                if column_indexes.contains(&i) {
194                                                    Some(value.clone())
195                                                } else {
196                                                    None
197                                                }
198                                            })
199                                            .collect();
200                                        if let Some(col_index) =
201                                            schema.iter().position(|column_def| {
202                                                &column_def.name == &where_clause.col_name
203                                            })
204                                        {
205                                            if row[col_index] == where_clause.value {
206                                                result.push((rowid, selected_columns));
207                                            }
208                                        } else {
209                                            eprintln!(
210                                                "Column not found: {}",
211                                                where_clause.col_name
212                                            );
213                                        }
214                                    }
215                                }
216                            }
217                        } else {
218                            for (rowid, row) in table.iter().enumerate() {
219                                let column_indexes: Vec<_> = columns
220                                    .iter()
221                                    .map(|column_name| {
222                                        schema
223                                            .iter()
224                                            .position(|column_def| {
225                                                &column_def.name == &column_name.name
226                                            })
227                                            .unwrap()
228                                    })
229                                    .collect();
230
231                                let selected_columns: Vec<_> = row
232                                    .iter()
233                                    .enumerate()
234                                    .filter_map(|(i, value)| {
235                                        if column_indexes.contains(&i) {
236                                            Some(value.clone())
237                                        } else {
238                                            None
239                                        }
240                                    })
241                                    .collect();
242                                result.push((rowid, selected_columns));
243                            }
244                        }
245                    } else {
246                        println!("Joining tables");
247                        // Iterate over join clauses
248                        for join in joins {
249                            let join_type: JoinType = join.join_type;
250                            let join_table_name = join.table_name;
251                            let left_col = join.on.0;
252                            let right_col = join.on.1;
253
254                            if join_type == JoinType::Inner {
255                                println!("Inner join");
256                                if let Some((join_schema, join_table)) =
257                                    self.tables.get_table_ref(&join_table_name)
258                                {
259                                    let join_schema = join_schema.clone();
260                                    println!(
261                                        "Join schema: {:?} joined table name {:?}",
262                                        join_schema, join_table_name
263                                    );
264
265                                    println!("normal schema: {:?}", schema);
266
267                                    let join_table = join_table.clone();
268                                    println!("Join table: {:?}", join_table);
269                                    let left_col_index = schema
270                                        .iter()
271                                        .position(|col| col.name == left_col.column_name)
272                                        .unwrap();
273                                    println!(
274                                        "Left col index: {:?} left_col.column_name {:?}",
275                                        left_col_index, left_col.column_name
276                                    );
277                                    let right_col_index = join_schema
278                                        .iter()
279                                        .position(|col| col.name == right_col.column_name)
280                                        .unwrap();
281                                    println!(
282                                        "Right col index: {:?} right_col.column_name  {:?}",
283                                        right_col_index, right_col.column_name
284                                    );
285                                    for (rowid, row) in table.iter().enumerate() {
286                                        for join_row in join_table.iter() {
287                                            println!("Join row: {:?}", join_row);
288                                            println!("row {:?}", row);
289
290                                            if row[left_col_index] == join_row[right_col_index] {
291                                                let mut selected_columns = vec![];
292
293                                                for column_name in &columns {
294                                                    if let Some(index) = schema
295                                                        .iter()
296                                                        .position(|column_def| &column_def.name == &column_name.name)
297                                                    {
298                                                        println!(
299                                                            "index {:?} schema.len() {:?}",
300                                                            index, schema.len()
301                                                        );
302                                                        selected_columns.push(row[index].clone());
303                                                    } else if let Some(join_col_index) = join_schema
304                                                        .iter()
305                                                        .position(|col| {
306                                                            println!(
307                                                                "col.name {:?} column_name.name {:?}",
308                                                                col.name, column_name.name
309                                                            );
310                                                            &col.name == &column_name.name
311                                                        }) {
312                                                        println!("idx {:?}", join_col_index);
313                                                        selected_columns.push(join_row[join_col_index].clone());
314                                                        println!("selected_columns {:?}", selected_columns);
315                                                    } else {
316                                                        panic!("Invalid column name.");
317                                                    }
318                                                }
319                                                result.push((rowid, selected_columns));
320                                            }
321                                        }
322                                    }
323                                }
324                            }
325                        }
326                    }
327
328                    Ok(ReefDBResult::Select(result))
329                } else {
330                    Err(ReefDBError::TableNotFound(table_name))
331                }
332            }
333            Statement::Update(UpdateStatement::UpdateTable(table_name, updates, where_clause)) => {
334                match where_clause {
335                    Some(WhereType::Regular(where_clause)) => {
336                        let where_col = (where_clause.col_name, where_clause.value);
337                        let affected_rows =
338                            self.tables
339                                .update_table(&table_name, updates.clone(), Some(where_col));
340
341                        // Update FTSText columns in the InvertedIndex
342                        let fts_columns = self.tables.get_fts_columns(&table_name);
343                        for (column_name, _) in &updates {
344                            if fts_columns.contains(&column_name) {
345                                let (_, rows) = self.tables.get_table_ref(&table_name).unwrap();
346                                for (rowid, row) in rows.iter().enumerate() {
347                                    let schema = self.tables.get_schema_ref(&table_name).unwrap();
348                                    let column_index = schema
349                                        .iter()
350                                        .position(|col| col.name == *column_name)
351                                        .unwrap();
352                                    if let DataValue::Text(ref text) = row[column_index] {
353                                        self.inverted_index.update_document(
354                                            &table_name,
355                                            &column_name,
356                                            rowid,
357                                            text,
358                                        );
359                                    }
360                                }
361                            }
362                        }
363
364                        Ok(ReefDBResult::Update(affected_rows))
365                    }
366                    Some(WhereType::FTS(_)) => {
367                        unimplemented!()
368                    }
369                    None => {
370                        let affected_rows = self.tables.update_table(&table_name, updates, None);
371                        Ok(ReefDBResult::Update(affected_rows))
372                    }
373                }
374            }
375        }
376    }
377}
378
#[cfg(test)]
mod tests {
    use std::fs;

    use super::*;

    /// Executes `queries` in order against `db` and collects every result.
    fn run_queries<S: Storage, FTS: Search>(
        db: &mut ReefDB<S, FTS>,
        queries: &[&str],
    ) -> Vec<Result<ReefDBResult, ReefDBError>> {
        queries.iter().map(|query| db.query(query)).collect()
    }

    #[test]
    fn test_inner_join() {
        let mut db = InMemoryReefDB::new();

        let results = run_queries(
            &mut db,
            &[
                "CREATE TABLE users (id INTEGER PRIMARY KEY, name TEXT)",
                "CREATE TABLE posts (id INTEGER PRIMARY KEY, title TEXT, user_id INTEGER FOREIGN KEY (id) REFERENCES users)",
                "INSERT INTO users VALUES (1, 'Alice')",
                "INSERT INTO users VALUES (2, 'Bob')",
                "INSERT INTO posts VALUES (1, 'Post 1', 1)",
                "INSERT INTO posts VALUES (2, 'Post 2', 2)",
                "SELECT users.name, posts.title FROM users INNER JOIN posts ON users.id = posts.user_id",
            ],
        );

        let expected_results = vec![
            Ok(ReefDBResult::CreateTable),
            Ok(ReefDBResult::CreateTable),
            Ok(ReefDBResult::Insert(1)),
            Ok(ReefDBResult::Insert(1)),
            Ok(ReefDBResult::Insert(1)),
            Ok(ReefDBResult::Insert(1)),
            Ok(ReefDBResult::Select(vec![
                (
                    0,
                    vec![
                        DataValue::Text("Alice".to_string()),
                        DataValue::Text("Post 1".to_string()),
                    ],
                ),
                (
                    1,
                    vec![
                        DataValue::Text("Bob".to_string()),
                        DataValue::Text("Post 2".to_string()),
                    ],
                ),
            ])),
        ];
        assert_eq!(results, expected_results);
    }

    #[test]
    fn test_fts_text_search() {
        let mut db = InMemoryReefDB::new();

        let results = run_queries(
            &mut db,
            &[
                "CREATE TABLE books (title TEXT, author TEXT, description FTS_TEXT)",
                "INSERT INTO books VALUES ('Book 1', 'Author 1', 'A book about the history of computer science.')",
                "INSERT INTO books VALUES ('Book 2', 'Author 2', 'A book about modern programming languages.')",
                "INSERT INTO books VALUES ('Book 3', 'Author 3', 'A book about the future of artificial intelligence.')",
                "SELECT title, author FROM books WHERE description MATCH 'computer science'",
                "SELECT title, author FROM books WHERE description MATCH 'artificial intelligence'",
            ],
        );

        let expected_results = vec![
            Ok(ReefDBResult::CreateTable),
            Ok(ReefDBResult::Insert(1)),
            Ok(ReefDBResult::Insert(1)),
            Ok(ReefDBResult::Insert(1)),
            Ok(ReefDBResult::Select(vec![(
                0,
                vec![
                    DataValue::Text("Book 1".to_string()),
                    DataValue::Text("Author 1".to_string()),
                ],
            )])),
            Ok(ReefDBResult::Select(vec![(
                2,
                vec![
                    DataValue::Text("Book 3".to_string()),
                    DataValue::Text("Author 3".to_string()),
                ],
            )])),
        ];

        assert_eq!(results, expected_results);
    }

    #[test]
    fn test_database_on_disk() {
        let kv_path = "kv.db";
        let index = "index.bin";

        // Start from a clean slate in case an earlier failed run left files behind.
        // The files normally do not exist, so removal errors are ignored.
        let _ = fs::remove_file(kv_path);
        let _ = fs::remove_file(index);

        let mut db = OnDiskReefDB::new(kv_path.to_string(), index.to_string());

        let results = run_queries(
            &mut db,
            &[
                "CREATE TABLE users (name TEXT, age INTEGER)",
                "INSERT INTO users VALUES ('alice', 30)",
                "INSERT INTO users VALUES ('bob', 28)",
                "UPDATE users SET age = 31 WHERE name = 'bob'",
                "SELECT name, age FROM users",
                "SELECT name FROM users",
                "SELECT name FROM users WHERE age = 30",
            ],
        );

        let expected_results = vec![
            Ok(ReefDBResult::CreateTable),
            Ok(ReefDBResult::Insert(1)),
            Ok(ReefDBResult::Insert(1)),
            // NOTE(review): only 'bob' matches, yet the storage layer reports 2 affected
            // rows; this pins the current count — confirm it is intended.
            Ok(ReefDBResult::Update(2)),
            Ok(ReefDBResult::Select(vec![
                (
                    0,
                    vec![DataValue::Text("alice".to_string()), DataValue::Integer(30)],
                ),
                (
                    1,
                    vec![DataValue::Text("bob".to_string()), DataValue::Integer(31)],
                ),
            ])),
            Ok(ReefDBResult::Select(vec![
                (0, vec![DataValue::Text("alice".to_string())]),
                (1, vec![DataValue::Text("bob".to_string())]),
            ])),
            Ok(ReefDBResult::Select(vec![(
                0,
                vec![DataValue::Text("alice".to_string())],
            )])),
        ];
        assert_eq!(results, expected_results);

        // Check that the users table has been created.
        assert!(db.tables.table_exists(&"users".to_string()));

        // Get the users table and check the number of rows.
        let (_, users) = db.tables.get_table(&"users".to_string()).unwrap();
        assert_eq!(users.len(), 2);

        // Check the contents of the users table.
        assert_eq!(
            users[0],
            vec![DataValue::Text("alice".to_string()), DataValue::Integer(30)]
        );
        assert_eq!(
            users[1],
            vec![DataValue::Text("bob".to_string()), DataValue::Integer(31)]
        );

        // Cleanup. The search index may or may not have written its file, so removing
        // it is best-effort.
        fs::remove_file(kv_path).unwrap();
        let _ = fs::remove_file(index);
    }

    #[test]
    fn test_delete() {
        let mut db = InMemoryReefDB::new();

        let results = run_queries(
            &mut db,
            &[
                "CREATE TABLE users (name TEXT, age INTEGER)",
                "INSERT INTO users VALUES ('alice', 30)",
                "INSERT INTO users VALUES ('bob', 28)",
                "DELETE FROM users WHERE name = 'bob'",
                "SELECT name, age FROM users",
                "SELECT name FROM users",
                "SELECT name FROM users WHERE age = 30",
            ],
        );

        let expected_results = vec![
            Ok(ReefDBResult::CreateTable),
            Ok(ReefDBResult::Insert(1)),
            Ok(ReefDBResult::Insert(1)),
            Ok(ReefDBResult::Delete(1)),
            Ok(ReefDBResult::Select(vec![(
                0,
                vec![DataValue::Text("alice".to_string()), DataValue::Integer(30)],
            )])),
            Ok(ReefDBResult::Select(vec![(
                0,
                vec![DataValue::Text("alice".to_string())],
            )])),
            Ok(ReefDBResult::Select(vec![(
                0,
                vec![DataValue::Text("alice".to_string())],
            )])),
        ];
        assert_eq!(results, expected_results);
    }

    #[test]
    fn test_database() {
        let mut db = InMemoryReefDB::new();

        let results = run_queries(
            &mut db,
            &[
                "CREATE TABLE users (name TEXT, age INTEGER)",
                "INSERT INTO users VALUES ('alice', 30)",
                "INSERT INTO users VALUES ('bob', 28)",
                "UPDATE users SET age = 31 WHERE name = 'bob'",
                "SELECT name, age FROM users",
                "SELECT name FROM users",
                "SELECT name FROM users WHERE age = 30",
            ],
        );

        let expected_results = vec![
            Ok(ReefDBResult::CreateTable),
            Ok(ReefDBResult::Insert(1)),
            Ok(ReefDBResult::Insert(1)),
            // NOTE(review): only 'bob' matches, yet the storage layer reports 2 affected
            // rows; this pins the current count — confirm it is intended.
            Ok(ReefDBResult::Update(2)),
            Ok(ReefDBResult::Select(vec![
                (
                    0,
                    vec![DataValue::Text("alice".to_string()), DataValue::Integer(30)],
                ),
                (
                    1,
                    vec![DataValue::Text("bob".to_string()), DataValue::Integer(31)],
                ),
            ])),
            Ok(ReefDBResult::Select(vec![
                (0, vec![DataValue::Text("alice".to_string())]),
                (1, vec![DataValue::Text("bob".to_string())]),
            ])),
            Ok(ReefDBResult::Select(vec![(
                0,
                vec![DataValue::Text("alice".to_string())],
            )])),
        ];
        assert_eq!(results, expected_results);

        // Check that the users table has been created.
        assert!(db.tables.table_exists(&"users".to_string()));

        // Get the users table and check the number of rows.
        let (_, users) = db.tables.get_table(&"users".to_string()).unwrap();
        assert_eq!(users.len(), 2);

        // Check the contents of the users table.
        assert_eq!(
            users[0],
            vec![DataValue::Text("alice".to_string()), DataValue::Integer(30)]
        );
        assert_eq!(
            users[1],
            vec![DataValue::Text("bob".to_string()), DataValue::Integer(31)]
        );
    }
}