Skip to main content

d1_orm_engine/
table.rs

1use d1_orm_query::{build_conditions, build_tail, Query, Set};
2use serde::Deserialize;
3use std::marker::PhantomData;
4use worker::D1Database;
5use crate::{convert::to_js_vec, error::OrmError, model::D1Model};
6
7pub struct Table<'db, M: D1Model> {
8    pub(crate) db: &'db D1Database,
9    _m: PhantomData<M>,
10}
11
12impl<'db, M: D1Model> Table<'db, M> {
13    pub fn new(db: &'db D1Database) -> Self {
14        Self { db, _m: PhantomData }
15    }
16
17    pub async fn insert(&self, model: &M) -> Result<M, OrmError> {
18        let cols = M::COLUMNS.join(", ");
19        let phs = (1..=M::COLUMNS.len()).map(|i| format!("?{}", i)).collect::<Vec<_>>().join(", ");
20        let sql = format!("INSERT INTO {} ({}) VALUES ({}) RETURNING *", M::TABLE, cols, phs);
21        let js = to_js_vec(&model.values());
22        self.db.prepare(&sql)
23            .bind(&js).map_err(|_| OrmError::Bind)?
24            .first::<M>(None).await.map_err(|_| OrmError::Execute)?
25            .ok_or(OrmError::Execute)
26    }
27
28    pub async fn insert_batch(&self, models: &[M]) -> Result<Vec<M>, OrmError> {
29        if models.is_empty() { return Ok(vec![]); }
30        let cols = M::COLUMNS.join(", ");
31        let phs = (1..=M::COLUMNS.len()).map(|i| format!("?{}", i)).collect::<Vec<_>>().join(", ");
32        let sql = format!("INSERT INTO {} ({}) VALUES ({}) RETURNING *", M::TABLE, cols, phs);
33        let stmts = models.iter().map(|m| {
34            let js = to_js_vec(&m.values());
35            self.db.prepare(&sql).bind(&js).map_err(|_| OrmError::Bind)
36        }).collect::<Result<Vec<_>, _>>()?;
37        let results = self.db.batch(stmts).await.map_err(|_| OrmError::Execute)?;
38        results.into_iter().map(|r| {
39            r.results::<M>().map_err(|_| OrmError::Deserialize)?
40                .into_iter().next().ok_or(OrmError::Execute)
41        }).collect()
42    }
43
44    pub async fn find_one(&self, query: Query) -> Result<Option<M>, OrmError> {
45        let cols = M::COLUMNS.join(", ");
46        let (where_parts, values) = build_conditions(&query, 1);
47        let where_sql = if where_parts.is_empty() { String::new() } else { format!("WHERE {}", where_parts) };
48        let sql = format!("SELECT {} FROM {} {}", cols, M::TABLE, where_sql);
49        let js = to_js_vec(&values);
50        self.db.prepare(&sql)
51            .bind(&js).map_err(|_| OrmError::Bind)?
52            .first::<M>(None).await.map_err(|_| OrmError::Execute)
53    }
54
55    pub async fn find_all(&self, query: Query) -> Result<Vec<M>, OrmError> {
56        let cols = M::COLUMNS.join(", ");
57        let (where_parts, values) = build_conditions(&query, 1);
58        let where_sql = if where_parts.is_empty() { String::new() } else { format!("WHERE {}", where_parts) };
59        let tail = build_tail(&query);
60        let sql = format!("SELECT {} FROM {} {}{}", cols, M::TABLE, where_sql, tail);
61        let js = to_js_vec(&values);
62        let result = self.db.prepare(&sql)
63            .bind(&js).map_err(|_| OrmError::Bind)?
64            .all().await.map_err(|_| OrmError::Execute)?;
65        result.results::<M>().map_err(|_| OrmError::Deserialize)
66    }
67
68    pub async fn update(&self, set: Set, query: Query) -> Result<Option<M>, OrmError> {
69        if set.is_empty() { return Ok(None); }
70        let (set_sql, set_vals, next_n) = set.build(1);
71        let (where_parts, where_vals) = build_conditions(&query, next_n);
72        let where_sql = if where_parts.is_empty() { String::new() } else { format!("WHERE {}", where_parts) };
73        let sql = format!("UPDATE {} SET {} {} RETURNING *", M::TABLE, set_sql, where_sql);
74        let all_vals: Vec<_> = set_vals.into_iter().chain(where_vals).collect();
75        let js = to_js_vec(&all_vals);
76        self.db.prepare(&sql)
77            .bind(&js).map_err(|_| OrmError::Bind)?
78            .first::<M>(None).await.map_err(|_| OrmError::Execute)
79    }
80
81    pub async fn delete(&self, query: Query) -> Result<(), OrmError> {
82        let (where_parts, values) = build_conditions(&query, 1);
83        let where_sql = if where_parts.is_empty() { String::new() } else { format!("WHERE {}", where_parts) };
84        let sql = format!("DELETE FROM {} {}", M::TABLE, where_sql);
85        let js = to_js_vec(&values);
86        self.db.prepare(&sql)
87            .bind(&js).map_err(|_| OrmError::Bind)?
88            .run().await.map_err(|_| OrmError::Execute)?;
89        Ok(())
90    }
91
92    pub async fn count(&self, query: Query) -> Result<u64, OrmError> {
93        #[derive(Deserialize)]
94        struct CountRow { count: u64 }
95        let (where_parts, values) = build_conditions(&query, 1);
96        let where_sql = if where_parts.is_empty() { String::new() } else { format!("WHERE {}", where_parts) };
97        let sql = format!("SELECT COUNT(*) as count FROM {} {}", M::TABLE, where_sql);
98        let js = to_js_vec(&values);
99        let row = self.db.prepare(&sql)
100            .bind(&js).map_err(|_| OrmError::Bind)?
101            .first::<CountRow>(None).await.map_err(|_| OrmError::Execute)?;
102        Ok(row.map(|r| r.count).unwrap_or(0))
103    }
104}