easy_db/
lib.rs

1use axum::{
2    extract::{Path, Query, State},
3    http::StatusCode,
4    routing::{delete, get, post, put},
5    Json, Router,
6};
7use rusqlite::{types::ValueRef, Connection, ToSql};
8use serde_json::{Map, Value};
9use std::collections::HashMap;
10use std::sync::{Arc, Mutex};
11use tower_http::cors::CorsLayer;
12
// --- SECURITY CHECK ---
/// Returns `true` if `name` is safe to splice into SQL as a bare table or column name.
///
/// SQL injection protection: identifiers cannot be bound as query parameters, so every
/// table/column name that reaches a `format!`-built statement must pass this check.
/// A valid identifier is non-empty, contains only ASCII letters, digits and `_`, and does
/// not start with a digit (SQLite rejects unquoted names such as `1abc`).
fn is_valid_identifier(name: &str) -> bool {
    let mut chars = name.chars();
    match chars.next() {
        // The first character decides whether the name can be a bare identifier at all;
        // an empty name (no first character) is rejected here too.
        Some(first) if first.is_ascii_alphabetic() || first == '_' => {
            chars.all(|c| c.is_ascii_alphanumeric() || c == '_')
        }
        _ => false,
    }
}
18
19// =========================================================
20// 1. SERVER PART (EasyDB)
21// =========================================================
22
23/// Main library structure (Server Engine)
24pub struct EasyDB {
25    pub db_name: String,
26    conn: Arc<Mutex<Connection>>,
27    exposed_tables: Vec<String>,
28}
29
30impl EasyDB {
31    /// Initializes the database connection.
32    pub fn init(name: &str) -> anyhow::Result<Self> {
33        let db_path = format!("{}.db", name);
34        let conn = Connection::open(db_path)?;
35
36        Ok(Self {
37            db_name: name.to_string(),
38            conn: Arc::new(Mutex::new(conn)),
39            exposed_tables: Vec::new(),
40        })
41    }
42
43    /// Creates a table and automatically exposes it to the API.
44    pub fn create_table(&mut self, table_name: &str, columns: &str) -> anyhow::Result<()> {
45        // Security check for table name
46        if !is_valid_identifier(table_name) {
47            return Err(anyhow::anyhow!("Invalid table name: {}", table_name));
48        }
49
50        let sql = format!("CREATE TABLE IF NOT EXISTS {} ({})", table_name, columns);
51
52        let conn = self.conn.lock().unwrap();
53        conn.execute(&sql, [])?;
54
55        self.exposed_tables.push(table_name.to_string());
56        println!("✅ Table '{}' created and exposed to API.", table_name);
57        Ok(())
58    }
59
60    /// Starts the server and generates routes.
61    pub async fn run_server(self, port: u16) -> anyhow::Result<()> {
62        let mut app = Router::new();
63        let shared_state = Arc::clone(&self.conn);
64
65        // Dynamically add routes for each table
66        for table in &self.exposed_tables {
67            let t = table.clone();
68            let state = Arc::clone(&shared_state);
69
70            app = app
71                .route(
72                    &format!("/{}", t),
73                    get({
74                        let t = t.clone();
75                        let s = Arc::clone(&state);
76                        move |q| handle_get(State(s), t, q)
77                    }),
78                )
79                .route(
80                    &format!("/{}", t),
81                    post({
82                        let t = t.clone();
83                        let s = Arc::clone(&state);
84                        move |j| handle_post(State(s), t, j)
85                    }),
86                )
87                .route(
88                    &format!("/{}/:id", t),
89                    put({
90                        let t = t.clone();
91                        let s = Arc::clone(&state);
92                        move |p, j| handle_put(State(s), t, p, j)
93                    }),
94                )
95                .route(
96                    &format!("/{}/:id", t),
97                    delete({
98                        let t = t.clone();
99                        let s = Arc::clone(&state);
100                        move |p| handle_delete(State(s), t, p)
101                    }),
102                );
103        }
104
105        // CORS: Allow requests from anywhere (Permissive)
106        app = app.layer(CorsLayer::permissive());
107
108        let addr = format!("0.0.0.0:{}", port);
109        let listener = tokio::net::TcpListener::bind(&addr).await?;
110        println!("🚀 Easy-DB Server is running: http://{}", addr);
111
112        axum::serve(listener, app).await?;
113        Ok(())
114    }
115}
116
117// =========================================================
118// 2. CLIENT PART (EasyClient)
119// =========================================================
120
121/// Client Structure: Allows users to easily connect to the server
122pub struct EasyClient {
123    pub base_url: String,
124}
125
126impl EasyClient {
127    /// Creates a new client (e.g., localhost, 9000)
128    pub fn new(host: &str, port: u16) -> Self {
129        Self {
130            base_url: format!("http://{}:{}", host, port),
131        }
132    }
133
134    /// Sends a GET request (Supports Filtering and Sorting)
135    pub async fn get(
136        &self,
137        table: &str,
138        params: Option<HashMap<&str, &str>>,
139    ) -> anyhow::Result<Value> {
140        let mut url = format!("{}/{}", self.base_url, table);
141
142        // If there are filter parameters, add them to the URL
143        if let Some(p) = params {
144            let query_str: Vec<String> = p.iter().map(|(k, v)| format!("{}={}", k, v)).collect();
145            if !query_str.is_empty() {
146                url.push_str(&format!("?{}", query_str.join("&")));
147            }
148        }
149
150        let res = reqwest::get(url).await?.json::<Value>().await?;
151        Ok(res)
152    }
153
154    /// Sends a POST request (Create Data)
155    pub async fn post(&self, table: &str, data: Value) -> anyhow::Result<Value> {
156        let client = reqwest::Client::new();
157        let url = format!("{}/{}", self.base_url, table);
158
159        let res = client
160            .post(url)
161            .json(&data)
162            .send()
163            .await?
164            .json::<Value>()
165            .await?;
166
167        Ok(res)
168    }
169
170    /// Sends a PUT request (Update Data)
171    pub async fn put(&self, table: &str, id: i64, data: Value) -> anyhow::Result<Value> {
172        let client = reqwest::Client::new();
173        let url = format!("{}/{}/{}", self.base_url, table, id);
174        let res = client
175            .put(url)
176            .json(&data)
177            .send()
178            .await?
179            .json::<Value>()
180            .await?;
181        Ok(res)
182    }
183
184    /// Sends a DELETE request (Delete Data)
185    pub async fn delete(&self, table: &str, id: i64) -> anyhow::Result<Value> {
186        let client = reqwest::Client::new();
187        let url = format!("{}/{}/{}", self.base_url, table, id);
188        let res = client.delete(url).send().await?.json::<Value>().await?;
189        Ok(res)
190    }
191}
192
193// =========================================================
194// 3. HANDLERS (API Logic)
195// =========================================================
196
197/// GET: List, filter, and sort data (SECURE VERSION)
198async fn handle_get(
199    State(db): State<Arc<Mutex<Connection>>>,
200    table_name: String,
201    Query(params): Query<HashMap<String, String>>,
202) -> (StatusCode, Json<Value>) {
203    let conn = db.lock().unwrap();
204    let mut sql = format!("SELECT * FROM {}", table_name);
205    let mut filters = Vec::new();
206    let mut sql_params: Vec<Box<dyn ToSql>> = Vec::new();
207
208    // 1. Secure Filtering (Parameterized Query)
209    for (k, v) in &params {
210        if !k.starts_with('_') {
211            if !is_valid_identifier(k) {
212                return (
213                    StatusCode::BAD_REQUEST,
214                    Json(serde_json::json!({"error": "Invalid column name"})),
215                );
216            }
217            filters.push(format!("{} = ?", k));
218            sql_params.push(Box::new(v.clone()));
219        }
220    }
221
222    if !filters.is_empty() {
223        sql.push_str(&format!(" WHERE {}", filters.join(" AND ")));
224    }
225
226    // 2. Sorting
227    if let Some(sort_col) = params.get("_sort") {
228        if !is_valid_identifier(sort_col) {
229            return (
230                StatusCode::BAD_REQUEST,
231                Json(serde_json::json!({"error": "Invalid sort column"})),
232            );
233        }
234        let order = params
235            .get("_order")
236            .map(|s| s.to_uppercase())
237            .unwrap_or("ASC".to_string());
238        let safe_order = if order == "DESC" { "DESC" } else { "ASC" };
239        sql.push_str(&format!(" ORDER BY {} {}", sort_col, safe_order));
240    }
241
242    // 3. Execute Query
243    let mut stmt = match conn.prepare(&sql) {
244        Ok(s) => s,
245        Err(e) => {
246            return (
247                StatusCode::INTERNAL_SERVER_ERROR,
248                Json(serde_json::json!({"error": e.to_string()})),
249            )
250        }
251    };
252
253    let rows = stmt.query_map(
254        rusqlite::params_from_iter(sql_params.iter().map(|p| p.as_ref())),
255        |row| Ok(row_to_json(row)),
256    );
257
258    match rows {
259        Ok(mapped) => {
260            let results: Vec<Value> = mapped.filter_map(|r| r.ok()).collect();
261            (StatusCode::OK, Json(Value::from(results)))
262        }
263        Err(e) => (
264            StatusCode::INTERNAL_SERVER_ERROR,
265            Json(serde_json::json!({"error": e.to_string()})),
266        ),
267    }
268}
269
270/// POST: Create new record (SECURE VERSION)
271async fn handle_post(
272    State(db): State<Arc<Mutex<Connection>>>,
273    table_name: String,
274    Json(payload): Json<Value>,
275) -> (StatusCode, Json<Value>) {
276    let conn = db.lock().unwrap();
277
278    if let Some(obj) = payload.as_object() {
279        if obj.is_empty() {
280            return (
281                StatusCode::BAD_REQUEST,
282                Json(serde_json::json!({"error": "Empty JSON body"})),
283            );
284        }
285
286        let keys: Vec<String> = obj.keys().cloned().collect();
287        for key in &keys {
288            if !is_valid_identifier(key) {
289                return (
290                    StatusCode::BAD_REQUEST,
291                    Json(serde_json::json!({"error": format!("Invalid column: {}", key)})),
292                );
293            }
294        }
295
296        let placeholders: Vec<String> = keys.iter().map(|_| "?".to_string()).collect();
297        let sql = format!(
298            "INSERT INTO {} ({}) VALUES ({})",
299            table_name,
300            keys.join(", "),
301            placeholders.join(", ")
302        );
303
304        let vals: Vec<String> = obj
305            .values()
306            .map(|v| v.as_str().unwrap_or(&v.to_string()).to_string())
307            .collect();
308
309        match conn.execute(&sql, rusqlite::params_from_iter(vals.iter())) {
310            Ok(_) => (
311                StatusCode::CREATED,
312                Json(serde_json::json!({"status": "success", "message": "Record created"})),
313            ),
314            Err(e) => (
315                StatusCode::INTERNAL_SERVER_ERROR,
316                Json(serde_json::json!({"error": e.to_string()})),
317            ),
318        }
319    } else {
320        (
321            StatusCode::BAD_REQUEST,
322            Json(serde_json::json!({"error": "Invalid JSON format"})),
323        )
324    }
325}
326
327/// PUT: Update record (SECURE VERSION)
328async fn handle_put(
329    State(db): State<Arc<Mutex<Connection>>>,
330    table_name: String,
331    Path(id): Path<i32>,
332    Json(payload): Json<Value>,
333) -> (StatusCode, Json<Value>) {
334    let conn = db.lock().unwrap();
335
336    if let Some(obj) = payload.as_object() {
337        for key in obj.keys() {
338            if !is_valid_identifier(key) {
339                return (
340                    StatusCode::BAD_REQUEST,
341                    Json(serde_json::json!({"error": "Invalid column name"})),
342                );
343            }
344        }
345
346        let updates: Vec<String> = obj.keys().map(|k| format!("{} = ?", k)).collect();
347        let sql = format!(
348            "UPDATE {} SET {} WHERE id = ?",
349            table_name,
350            updates.join(", ")
351        );
352
353        let mut params: Vec<String> = obj
354            .values()
355            .map(|v| v.as_str().unwrap_or(&v.to_string()).to_string())
356            .collect();
357        params.push(id.to_string());
358
359        match conn.execute(&sql, rusqlite::params_from_iter(params.iter())) {
360            Ok(affected) => {
361                if affected == 0 {
362                    (
363                        StatusCode::NOT_FOUND,
364                        Json(serde_json::json!({"error": "Record not found"})),
365                    )
366                } else {
367                    (
368                        StatusCode::OK,
369                        Json(serde_json::json!({"status": "success", "message": "Record updated"})),
370                    )
371                }
372            }
373            Err(e) => (
374                StatusCode::INTERNAL_SERVER_ERROR,
375                Json(serde_json::json!({"error": e.to_string()})),
376            ),
377        }
378    } else {
379        (
380            StatusCode::BAD_REQUEST,
381            Json(serde_json::json!({"error": "Invalid JSON format"})),
382        )
383    }
384}
385
386/// DELETE: Delete record (SECURE VERSION)
387async fn handle_delete(
388    State(db): State<Arc<Mutex<Connection>>>,
389    table_name: String,
390    Path(id): Path<i32>,
391) -> (StatusCode, Json<Value>) {
392    let conn = db.lock().unwrap();
393    let sql = format!("DELETE FROM {} WHERE id = ?", table_name);
394
395    match conn.execute(&sql, [id]) {
396        Ok(affected) => {
397            if affected == 0 {
398                (
399                    StatusCode::NOT_FOUND,
400                    Json(serde_json::json!({"error": "Record not found"})),
401                )
402            } else {
403                (
404                    StatusCode::OK,
405                    Json(serde_json::json!({"status": "success", "message": "Record deleted"})),
406                )
407            }
408        }
409        Err(e) => (
410            StatusCode::INTERNAL_SERVER_ERROR,
411            Json(serde_json::json!({"error": e.to_string()})),
412        ),
413    }
414}
415
416/// Helper: Converts SQLite row to JSON
417fn row_to_json(row: &rusqlite::Row) -> Value {
418    let mut map = Map::new();
419    let column_names = row.as_ref().column_names();
420
421    for (i, name) in column_names.iter().enumerate() {
422        let value = match row.get_ref(i).unwrap() {
423            ValueRef::Null => Value::Null,
424            ValueRef::Integer(n) => Value::from(n),
425            ValueRef::Real(f) => Value::from(f),
426            ValueRef::Text(t) => Value::from(std::str::from_utf8(t).unwrap_or("")),
427            ValueRef::Blob(b) => Value::from(format!("{:?}", b)),
428        };
429        map.insert(name.to_string(), value);
430    }
431    Value::Object(map)
432}