diesel_d1/
lib.rs

1use async_trait::async_trait;
2use backend::D1Backend;
3use bind_collector::D1BindCollector;
4use binding::{D1Database, D1PreparedStatement, D1Result};
5use diesel::{
6    connection::{ConnectionSealed, Instrumentation},
7    query_builder::{AsQuery, QueryFragment, QueryId},
8    ConnectionResult, QueryResult,
9};
10use diesel_async::{AsyncConnection, SimpleAsyncConnection};
11use futures_util::{
12    future::BoxFuture,
13    stream::{self, BoxStream},
14    FutureExt, StreamExt,
15};
16use js_sys::{Array, Object, Reflect};
17use query_builder::D1QueryBuilder;
18use row::D1Row;
19use transaction_manager::D1TransactionManager;
20use utils::{D1Error, SendableFuture};
21use wasm_bindgen::JsValue;
22use wasm_bindgen_futures::JsFuture;
23use worker::{console_error, console_log};
24
25pub mod backend;
26mod bind_collector;
27mod binding;
28mod query_builder;
29mod row;
30mod transaction_manager;
31mod types;
32mod utils;
33mod value;
34
35pub struct D1Connection {
36    transaction_queries: Vec<D1PreparedStatement>,
37    transaction_manager: D1TransactionManager,
38    binding: D1Database,
39}
40
41impl D1Connection {
42    pub fn new(env: worker::Env, name: &str) -> Self {
43        let binding: D1Database = Reflect::get(&env, &name.to_owned().into()).unwrap().into();
44        D1Connection {
45            transaction_queries: Vec::default(),
46            transaction_manager: D1TransactionManager::default(),
47            binding,
48        }
49    }
50}
51
52// SAFETY: this is safe under WASM and workers because there's no threads and therefore no race conditions (at least memory ones)
53unsafe impl Send for D1Connection {}
54unsafe impl Sync for D1Connection {}
55
56#[async_trait]
57impl SimpleAsyncConnection for D1Connection {
58    async fn batch_execute(&mut self, query: &str) -> diesel::QueryResult<()> {
59        let statements = [JsValue::from_str(query)].iter().collect::<Array>();
60
61        match SendableFuture(JsFuture::from(self.binding.batch(statements).unwrap())).await {
62            Ok(_) => Ok(()),
63            // FIXME(lduarte): I don't send a proper error becase I don't have time at the moment
64            Err(_) => Err(diesel::result::Error::NotFound),
65        }
66    }
67}
68#[async_trait]
69impl AsyncConnection for D1Connection {
70    type Backend = D1Backend;
71    type TransactionManager = D1TransactionManager;
72
73    #[doc = " The future returned by `AsyncConnection::execute`"]
74    type ExecuteFuture<'conn, 'query> = BoxFuture<'conn, QueryResult<usize>>;
75
76    #[doc = " The future returned by `AsyncConnection::load`"]
77    type LoadFuture<'conn, 'query> = BoxFuture<'conn, QueryResult<Self::Stream<'conn, 'query>>>;
78
79    #[doc = " The inner stream returned by `AsyncConnection::load`"]
80    type Stream<'conn, 'query> = BoxStream<'conn, QueryResult<Self::Row<'conn, 'query>>>;
81
82    #[doc = " The row type used by the stream returned by `AsyncConnection::load`"]
83    type Row<'conn, 'query> = D1Row;
84
85    async fn establish(_unused: &str) -> ConnectionResult<Self> {
86        todo!()
87    }
88
89    fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
90    where
91        T: AsQuery + 'query,
92        T::Query: QueryFragment<Self::Backend> + QueryId + 'query,
93    {
94        let source = source.as_query();
95        let result = prepare_statement_sql(source, &self.binding);
96
97        SendableFuture(async move {
98            let promise = match result.all() {
99                Ok(res) => res,
100                Err(err) => {
101                    console_error!("{:?}", err);
102                    panic!("not supposed to happen .all call");
103                },
104            };
105
106            let result = match SendableFuture(JsFuture::from(promise)).await {
107                Ok(res) => res,
108                Err(err) => {
109                    console_error!("{:?}", err);
110                    panic!("not supposed to happen .all promise");
111                },
112            };
113
114            let result: D1Result = result.into();
115
116            let error = result.error().unwrap();
117
118            if let Some(error_str) = error {
119                return Err(diesel::result::Error::DatabaseError(
120                    diesel::result::DatabaseErrorKind::Unknown,
121                    Box::new(D1Error { message: error_str }),
122                ));
123            }
124
125            let array = result.results().unwrap().unwrap().to_vec();
126
127            if array.is_empty() {
128                return Ok(stream::iter(vec![]).boxed());
129            }
130
131            let field_keys: Vec<String> = js_sys::Object::keys(&Object::from(array[0].clone()))
132                .to_vec()
133                .iter()
134                .map(|val| val.as_string().unwrap())
135                .collect();
136
137            // FIXME: not performant at all, should work well enough
138            let rows: Vec<QueryResult<D1Row>> = array
139                .iter()
140                .map(|val| Ok(D1Row::new(val.clone(), field_keys.clone())))
141                .collect();
142            let iter = stream::iter(rows).boxed();
143            Ok(iter)
144        })
145        .boxed()
146    }
147
148    #[doc(hidden)]
149    fn execute_returning_count<'conn, 'query, T>(
150        &'conn mut self,
151        source: T,
152    ) -> Self::ExecuteFuture<'conn, 'query>
153    where
154        T: QueryFragment<Self::Backend> + QueryId + 'query,
155    {
156        let result = prepare_statement_sql(source, &self.binding);
157        SendableFuture(async move {
158            let promise = match result.all() {
159                Ok(res) => res,
160                Err(err) => {
161                    console_error!("{:?}", err);
162                    panic!("not supposed to happen .all call");
163                },
164            };
165
166            let result = match SendableFuture(JsFuture::from(promise)).await {
167                Ok(res) => res,
168                Err(err) => {
169                    console_error!("{:?}", err);
170                    panic!("not supposed to happen .all promise");
171                },
172            };
173
174            let result: D1Result = result.into();
175
176            let error = result.error().unwrap();
177
178            if let Some(error_str) = error {
179                return Err(diesel::result::Error::DatabaseError(
180                    diesel::result::DatabaseErrorKind::Unknown,
181                    Box::new(D1Error { message: error_str }),
182                ));
183            }
184
185            // if it's successful, meta exists with a `changes` key that is a number
186            let meta = result.meta().unwrap();
187            let value = js_sys::Reflect::get(&meta, &"changes".to_owned().into())
188                .unwrap()
189                .as_f64()
190                .unwrap();
191
192            Ok(value as usize)
193        })
194        .boxed()
195    }
196
197    fn transaction_state(&mut self) -> &mut D1TransactionManager {
198        &mut self.transaction_manager
199    }
200
201    #[doc(hidden)]
202    fn instrumentation(&mut self) -> &mut dyn Instrumentation {
203        todo!()
204    }
205
206    #[doc = " Set a specific [`Instrumentation`] implementation for this connection"]
207    fn set_instrumentation(&mut self, _instrumentation: impl Instrumentation) {
208        todo!()
209    }
210}
211
212impl ConnectionSealed for D1Connection {}
213
214fn construct_bind_data<T>(query: &T) -> Result<Array, diesel::result::Error>
215where
216    T: QueryFragment<D1Backend>,
217{
218    let mut bind_collector = D1BindCollector::default();
219
220    query.collect_binds(&mut bind_collector, &mut (), &D1Backend)?;
221
222    let array = bind_collector
223        .binds
224        .iter()
225        .map(|(bind, _)| bind)
226        .collect::<Array>();
227    Ok(array)
228}
229
230fn prepare_statement_sql<'conn, 'query, T>(source: T, binding: &D1Database) -> D1PreparedStatement
231where
232    T: QueryFragment<D1Backend> + QueryId + 'query,
233{
234    let mut query_builder = D1QueryBuilder::default();
235    source.to_sql(&mut query_builder, &D1Backend).unwrap();
236    let result = match binding.prepare(&query_builder.sql) {
237        Ok(res) => res,
238        Err(err) => {
239            console_error!("{:?}", err);
240            panic!("not supposed to happen d1preparedstatement");
241        },
242    };
243
244    let binds = construct_bind_data(&source).unwrap();
245
246    match result.bind(binds) {
247        Ok(res) => res,
248        Err(err) => {
249            console_error!("{:?}", err);
250            panic!("not supposed to happen bind");
251        },
252    }
253}