easy_db/
lib.rs

1use axum::{
2    extract::{Path, Query, State},
3    http::StatusCode,
4    routing::{delete, get, post, put},
5    Json, Router,
6};
7use rusqlite::{types::ValueRef, Connection, ToSql};
8use serde_json::{Map, Value};
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex};
11use tower_http::cors::CorsLayer;
12
// --- SECURITY CHECK ---
// Table and column names cannot be bound as `?` parameters, so they are spliced into the SQL
// text. This guard is what keeps that splicing safe from SQL injection.

/// Returns `true` if `name` is a plain, unquoted SQL identifier.
///
/// A valid name is non-empty, starts with an ASCII letter or `_`, and continues with ASCII
/// letters, digits or `_` only. Everything else is rejected: empty strings, leading digits,
/// quotes, spaces, `;`, `-` and non-ASCII characters.
///
/// SQL keywords (e.g. `select`) are not rejected. They may still make a statement fail to
/// parse, but they cannot inject extra SQL.
fn is_valid_identifier(name: &str) -> bool {
    let mut chars = name.chars();
    match chars.next() {
        // An empty name or a leading digit is not an unquoted identifier in SQL.
        Some(first) if first.is_ascii_alphabetic() || first == '_' => {
            chars.all(|c| c.is_ascii_alphanumeric() || c == '_')
        }
        _ => false,
    }
}
18
19// =========================================================
20// 1. SERVER PART (EasyDB)
21// =========================================================
22
23/// Main library structure (Server Engine)
24pub struct EasyDB {
25    pub db_name: String,
26    conn: Arc<Mutex<Connection>>,
27    exposed_tables: Vec<String>,
28}
29
30impl EasyDB {
31    /// Initializes the database connection.
32    pub fn init(name: &str) -> anyhow::Result<Self> {
33        let db_path = format!("{}.db", name);
34        let conn = Connection::open(db_path)?;
35
36        Ok(Self {
37            db_name: name.to_string(),
38            conn: Arc::new(Mutex::new(conn)),
39            exposed_tables: Vec::new(),
40        })
41    }
42
43    /// Creates a table and automatically exposes it to the API.
44    pub fn create_table(&mut self, table_name: &str, columns: &str) -> anyhow::Result<()> {
45        // Security check for table name
46        if !is_valid_identifier(table_name) {
47            return Err(anyhow::anyhow!("Invalid table name: {}", table_name));
48        }
49
50        let sql = format!("CREATE TABLE IF NOT EXISTS {} ({})", table_name, columns);
51
52        let conn = self.conn.lock().unwrap();
53        conn.execute(&sql, [])?;
54
55        self.exposed_tables.push(table_name.to_string());
56        println!("✅ Table '{}' created and exposed to API.", table_name);
57        Ok(())
58    }
59
60    /// Starts the server and generates routes.
61    pub async fn run_server(self, port: u16) -> anyhow::Result<()> {
62        let mut app = Router::new();
63        let shared_state = Arc::clone(&self.conn);
64
65        // Dynamically add routes for each table
66        for table in &self.exposed_tables {
67            let t = table.clone();
68            let state = Arc::clone(&shared_state);
69
70            app = app
71                .route(
72                    &format!("/{}", t),
73                    get({
74                        let t = t.clone();
75                        let s = Arc::clone(&state);
76                        move |q| handle_get(State(s), t, q)
77                    }),
78                )
79                .route(
80                    &format!("/{}", t),
81                    post({
82                        let t = t.clone();
83                        let s = Arc::clone(&state);
84                        move |j| handle_post(State(s), t, j)
85                    }),
86                )
87                // FIX: Changed from /:id to /{id} for Axum 0.7 compatibility
88                // Note: We use double braces {{id}} to escape them in format! macro
89                .route(
90                    &format!("/{}/{{id}}", t),
91                    put({
92                        let t = t.clone();
93                        let s = Arc::clone(&state);
94                        move |p, j| handle_put(State(s), t, p, j)
95                    }),
96                )
97                .route(
98                    &format!("/{}/{{id}}", t),
99                    delete({
100                        let t = t.clone();
101                        let s = Arc::clone(&state);
102                        move |p| handle_delete(State(s), t, p)
103                    }),
104                );
105        }
106
107        // CORS: Allow requests from anywhere (Permissive)
108        app = app.layer(CorsLayer::permissive());
109
110        let addr = format!("0.0.0.0:{}", port);
111        let listener = tokio::net::TcpListener::bind(&addr).await?;
112        println!("🚀 Easy-DB Server is running: http://{}", addr);
113
114        axum::serve(listener, app).await?;
115        Ok(())
116    }
117}
118
119// =========================================================
120// 2. CLIENT PART (EasyClient)
121// =========================================================
122
/// HTTP client for an EasyDB server.
///
/// Holds only the server's base URL (e.g. `http://localhost:9000`), so it is cheap to clone
/// and compare.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct EasyClient {
    /// Base URL that table names are appended to, without a trailing slash.
    pub base_url: String,
}
127
128impl EasyClient {
129    /// Creates a new client (e.g., localhost, 9000)
130    pub fn new(host: &str, port: u16) -> Self {
131        Self {
132            base_url: format!("http://{}:{}", host, port),
133        }
134    }
135
136    /// Sends a GET request (Supports Filtering and Sorting)
137    pub async fn get(
138        &self,
139        table: &str,
140        params: Option<HashMap<&str, &str>>,
141    ) -> anyhow::Result<Value> {
142        let mut url = format!("{}/{}", self.base_url, table);
143
144        // If there are filter parameters, add them to the URL
145        if let Some(p) = params {
146            let query_str: Vec<String> = p.iter().map(|(k, v)| format!("{}={}", k, v)).collect();
147            if !query_str.is_empty() {
148                url.push_str(&format!("?{}", query_str.join("&")));
149            }
150        }
151
152        let res = reqwest::get(url).await?.json::<Value>().await?;
153        Ok(res)
154    }
155
156    /// Sends a POST request (Create Data)
157    pub async fn post(&self, table: &str, data: Value) -> anyhow::Result<Value> {
158        let client = reqwest::Client::new();
159        let url = format!("{}/{}", self.base_url, table);
160
161        let res = client
162            .post(url)
163            .json(&data)
164            .send()
165            .await?
166            .json::<Value>()
167            .await?;
168
169        Ok(res)
170    }
171
172    /// Sends a PUT request (Update Data)
173    pub async fn put(&self, table: &str, id: i64, data: Value) -> anyhow::Result<Value> {
174        let client = reqwest::Client::new();
175        let url = format!("{}/{}/{}", self.base_url, table, id);
176        let res = client
177            .put(url)
178            .json(&data)
179            .send()
180            .await?
181            .json::<Value>()
182            .await?;
183        Ok(res)
184    }
185
186    /// Sends a DELETE request (Delete Data)
187    pub async fn delete(&self, table: &str, id: i64) -> anyhow::Result<Value> {
188        let client = reqwest::Client::new();
189        let url = format!("{}/{}/{}", self.base_url, table, id);
190        let res = client.delete(url).send().await?.json::<Value>().await?;
191        Ok(res)
192    }
193}
194
195// =========================================================
196// 3. HANDLERS (API Logic)
197// =========================================================
198
199/// GET: List, filter, and sort data (SECURE VERSION)
200async fn handle_get(
201    State(db): State<Arc<Mutex<Connection>>>,
202    table_name: String,
203    Query(params): Query<HashMap<String, String>>,
204) -> (StatusCode, Json<Value>) {
205    let conn = db.lock().unwrap();
206    let mut sql = format!("SELECT * FROM {}", table_name);
207    let mut filters = Vec::new();
208    let mut sql_params: Vec<Box<dyn ToSql>> = Vec::new();
209
210    // 1. Secure Filtering (Parameterized Query)
211    for (k, v) in &params {
212        if !k.starts_with('_') {
213            if !is_valid_identifier(k) {
214                return (
215                    StatusCode::BAD_REQUEST,
216                    Json(serde_json::json!({"error": "Invalid column name"})),
217                );
218            }
219            filters.push(format!("{} = ?", k));
220            sql_params.push(Box::new(v.clone()));
221        }
222    }
223
224    if !filters.is_empty() {
225        sql.push_str(&format!(" WHERE {}", filters.join(" AND ")));
226    }
227
228    // 2. Sorting
229    if let Some(sort_col) = params.get("_sort") {
230        if !is_valid_identifier(sort_col) {
231            return (
232                StatusCode::BAD_REQUEST,
233                Json(serde_json::json!({"error": "Invalid sort column"})),
234            );
235        }
236        let order = params
237            .get("_order")
238            .map(|s| s.to_uppercase())
239            .unwrap_or("ASC".to_string());
240        let safe_order = if order == "DESC" { "DESC" } else { "ASC" };
241        sql.push_str(&format!(" ORDER BY {} {}", sort_col, safe_order));
242    }
243
244    // 3. Execute Query
245    let mut stmt = match conn.prepare(&sql) {
246        Ok(s) => s,
247        Err(e) => {
248            return (
249                StatusCode::INTERNAL_SERVER_ERROR,
250                Json(serde_json::json!({"error": e.to_string()})),
251            )
252        }
253    };
254
255    let rows = stmt.query_map(
256        rusqlite::params_from_iter(sql_params.iter().map(|p| p.as_ref())),
257        |row| Ok(row_to_json(row)),
258    );
259
260    match rows {
261        Ok(mapped) => {
262            let results: Vec<Value> = mapped.filter_map(|r| r.ok()).collect();
263            (StatusCode::OK, Json(Value::from(results)))
264        }
265        Err(e) => (
266            StatusCode::INTERNAL_SERVER_ERROR,
267            Json(serde_json::json!({"error": e.to_string()})),
268        ),
269    }
270}
271
272/// POST: Create new record (SECURE VERSION)
273async fn handle_post(
274    State(db): State<Arc<Mutex<Connection>>>,
275    table_name: String,
276    Json(payload): Json<Value>,
277) -> (StatusCode, Json<Value>) {
278    let conn = db.lock().unwrap();
279
280    if let Some(obj) = payload.as_object() {
281        if obj.is_empty() {
282            return (
283                StatusCode::BAD_REQUEST,
284                Json(serde_json::json!({"error": "Empty JSON body"})),
285            );
286        }
287
288        let keys: Vec<String> = obj.keys().cloned().collect();
289        for key in &keys {
290            if !is_valid_identifier(key) {
291                return (
292                    StatusCode::BAD_REQUEST,
293                    Json(serde_json::json!({"error": format!("Invalid column: {}", key)})),
294                );
295            }
296        }
297
298        let placeholders: Vec<String> = keys.iter().map(|_| "?".to_string()).collect();
299        let sql = format!(
300            "INSERT INTO {} ({}) VALUES ({})",
301            table_name,
302            keys.join(", "),
303            placeholders.join(", ")
304        );
305
306        let vals: Vec<String> = obj
307            .values()
308            .map(|v| v.as_str().unwrap_or(&v.to_string()).to_string())
309            .collect();
310
311        match conn.execute(&sql, rusqlite::params_from_iter(vals.iter())) {
312            Ok(_) => (
313                StatusCode::CREATED,
314                Json(serde_json::json!({"status": "success", "message": "Record created"})),
315            ),
316            Err(e) => (
317                StatusCode::INTERNAL_SERVER_ERROR,
318                Json(serde_json::json!({"error": e.to_string()})),
319            ),
320        }
321    } else {
322        (
323            StatusCode::BAD_REQUEST,
324            Json(serde_json::json!({"error": "Invalid JSON format"})),
325        )
326    }
327}
328
329/// PUT: Update record (SECURE VERSION)
330async fn handle_put(
331    State(db): State<Arc<Mutex<Connection>>>,
332    table_name: String,
333    Path(id): Path<i32>,
334    Json(payload): Json<Value>,
335) -> (StatusCode, Json<Value>) {
336    let conn = db.lock().unwrap();
337
338    if let Some(obj) = payload.as_object() {
339        for key in obj.keys() {
340            if !is_valid_identifier(key) {
341                return (
342                    StatusCode::BAD_REQUEST,
343                    Json(serde_json::json!({"error": "Invalid column name"})),
344                );
345            }
346        }
347
348        let updates: Vec<String> = obj.keys().map(|k| format!("{} = ?", k)).collect();
349        let sql = format!(
350            "UPDATE {} SET {} WHERE id = ?",
351            table_name,
352            updates.join(", ")
353        );
354
355        let mut params: Vec<String> = obj
356            .values()
357            .map(|v| v.as_str().unwrap_or(&v.to_string()).to_string())
358            .collect();
359        params.push(id.to_string());
360
361        match conn.execute(&sql, rusqlite::params_from_iter(params.iter())) {
362            Ok(affected) => {
363                if affected == 0 {
364                    (
365                        StatusCode::NOT_FOUND,
366                        Json(serde_json::json!({"error": "Record not found"})),
367                    )
368                } else {
369                    (
370                        StatusCode::OK,
371                        Json(serde_json::json!({"status": "success", "message": "Record updated"})),
372                    )
373                }
374            }
375            Err(e) => (
376                StatusCode::INTERNAL_SERVER_ERROR,
377                Json(serde_json::json!({"error": e.to_string()})),
378            ),
379        }
380    } else {
381        (
382            StatusCode::BAD_REQUEST,
383            Json(serde_json::json!({"error": "Invalid JSON format"})),
384        )
385    }
386}
387
388/// DELETE: Delete record (SECURE VERSION)
389async fn handle_delete(
390    State(db): State<Arc<Mutex<Connection>>>,
391    table_name: String,
392    Path(id): Path<i32>,
393) -> (StatusCode, Json<Value>) {
394    let conn = db.lock().unwrap();
395    let sql = format!("DELETE FROM {} WHERE id = ?", table_name);
396
397    match conn.execute(&sql, [id]) {
398        Ok(affected) => {
399            if affected == 0 {
400                (
401                    StatusCode::NOT_FOUND,
402                    Json(serde_json::json!({"error": "Record not found"})),
403                )
404            } else {
405                (
406                    StatusCode::OK,
407                    Json(serde_json::json!({"status": "success", "message": "Record deleted"})),
408                )
409            }
410        }
411        Err(e) => (
412            StatusCode::INTERNAL_SERVER_ERROR,
413            Json(serde_json::json!({"error": e.to_string()})),
414        ),
415    }
416}
417
418/// Helper: Converts SQLite row to JSON
419fn row_to_json(row: &rusqlite::Row) -> Value {
420    let mut map = Map::new();
421    let column_names = row.as_ref().column_names();
422
423    for (i, name) in column_names.iter().enumerate() {
424        let value = match row.get_ref(i).unwrap() {
425            ValueRef::Null => Value::Null,
426            ValueRef::Integer(n) => Value::from(n),
427            ValueRef::Real(f) => Value::from(f),
428            ValueRef::Text(t) => Value::from(std::str::from_utf8(t).unwrap_or("")),
429            ValueRef::Blob(b) => Value::from(format!("{:?}", b)),
430        };
431        map.insert(name.to_string(), value);
432    }
433    Value::Object(map)
434}