1use async_trait::async_trait;
2use backend::D1Backend;
3use bind_collector::D1BindCollector;
4use binding::{D1Database, D1PreparedStatement, D1Result};
5use diesel::{
6 connection::{ConnectionSealed, Instrumentation},
7 query_builder::{AsQuery, QueryFragment, QueryId},
8 ConnectionResult, QueryResult,
9};
10use diesel_async::{AsyncConnection, SimpleAsyncConnection};
11use futures_util::{
12 future::BoxFuture,
13 stream::{self, BoxStream},
14 FutureExt, StreamExt,
15};
16use js_sys::{Array, Object, Reflect};
17use query_builder::D1QueryBuilder;
18use row::D1Row;
19use transaction_manager::D1TransactionManager;
20use utils::{D1Error, SendableFuture};
21use wasm_bindgen::JsValue;
22use wasm_bindgen_futures::JsFuture;
23use worker::{console_error, console_log};
24
25pub mod backend;
26mod bind_collector;
27mod binding;
28mod query_builder;
29mod row;
30mod transaction_manager;
31mod types;
32mod utils;
33mod value;
34
/// A diesel-async connection that runs queries against a D1 database object taken from the
/// Worker environment (constructed with `D1Connection::new`).
pub struct D1Connection {
    // NOTE(review): not read anywhere in this file; the name suggests statements buffered while a
    // transaction is open, presumably used by the transaction manager module — confirm.
    transaction_queries: Vec<D1PreparedStatement>,
    /// Transaction bookkeeping handed out by `AsyncConnection::transaction_state`.
    transaction_manager: D1TransactionManager,
    /// The JS database object looked up by name on the Worker `Env`; all statements are
    /// prepared against it.
    binding: D1Database,
}
40
41impl D1Connection {
42 pub fn new(env: worker::Env, name: &str) -> Self {
43 let binding: D1Database = Reflect::get(&env, &name.to_owned().into()).unwrap().into();
44 D1Connection {
45 transaction_queries: Vec::default(),
46 transaction_manager: D1TransactionManager::default(),
47 binding,
48 }
49 }
50}
51
// SAFETY: NOTE(review): `D1Connection` holds JS object handles (the binding is obtained via
// `Reflect::get(...).into()`), which are not inherently thread-safe. These impls are only sound
// if a connection never actually crosses threads — presumably guaranteed by the single-threaded
// Worker wasm runtime; confirm. They exist so the connection and the futures borrowing it can
// satisfy the `Send` bounds required by `BoxFuture`/`BoxStream` and diesel-async.
unsafe impl Send for D1Connection {}
unsafe impl Sync for D1Connection {}
55
56#[async_trait]
57impl SimpleAsyncConnection for D1Connection {
58 async fn batch_execute(&mut self, query: &str) -> diesel::QueryResult<()> {
59 let statements = [JsValue::from_str(query)].iter().collect::<Array>();
60
61 match SendableFuture(JsFuture::from(self.binding.batch(statements).unwrap())).await {
62 Ok(_) => Ok(()),
63 Err(_) => Err(diesel::result::Error::NotFound),
65 }
66 }
67}
68#[async_trait]
69impl AsyncConnection for D1Connection {
70 type Backend = D1Backend;
71 type TransactionManager = D1TransactionManager;
72
73 #[doc = " The future returned by `AsyncConnection::execute`"]
74 type ExecuteFuture<'conn, 'query> = BoxFuture<'conn, QueryResult<usize>>;
75
76 #[doc = " The future returned by `AsyncConnection::load`"]
77 type LoadFuture<'conn, 'query> = BoxFuture<'conn, QueryResult<Self::Stream<'conn, 'query>>>;
78
79 #[doc = " The inner stream returned by `AsyncConnection::load`"]
80 type Stream<'conn, 'query> = BoxStream<'conn, QueryResult<Self::Row<'conn, 'query>>>;
81
82 #[doc = " The row type used by the stream returned by `AsyncConnection::load`"]
83 type Row<'conn, 'query> = D1Row;
84
85 async fn establish(_unused: &str) -> ConnectionResult<Self> {
86 todo!()
87 }
88
89 fn load<'conn, 'query, T>(&'conn mut self, source: T) -> Self::LoadFuture<'conn, 'query>
90 where
91 T: AsQuery + 'query,
92 T::Query: QueryFragment<Self::Backend> + QueryId + 'query,
93 {
94 let source = source.as_query();
95 let result = prepare_statement_sql(source, &self.binding);
96
97 SendableFuture(async move {
98 let promise = match result.all() {
99 Ok(res) => res,
100 Err(err) => {
101 console_error!("{:?}", err);
102 panic!("not supposed to happen .all call");
103 },
104 };
105
106 let result = match SendableFuture(JsFuture::from(promise)).await {
107 Ok(res) => res,
108 Err(err) => {
109 console_error!("{:?}", err);
110 panic!("not supposed to happen .all promise");
111 },
112 };
113
114 let result: D1Result = result.into();
115
116 let error = result.error().unwrap();
117
118 if let Some(error_str) = error {
119 return Err(diesel::result::Error::DatabaseError(
120 diesel::result::DatabaseErrorKind::Unknown,
121 Box::new(D1Error { message: error_str }),
122 ));
123 }
124
125 let array = result.results().unwrap().unwrap().to_vec();
126
127 if array.is_empty() {
128 return Ok(stream::iter(vec![]).boxed());
129 }
130
131 let field_keys: Vec<String> = js_sys::Object::keys(&Object::from(array[0].clone()))
132 .to_vec()
133 .iter()
134 .map(|val| val.as_string().unwrap())
135 .collect();
136
137 let rows: Vec<QueryResult<D1Row>> = array
139 .iter()
140 .map(|val| Ok(D1Row::new(val.clone(), field_keys.clone())))
141 .collect();
142 let iter = stream::iter(rows).boxed();
143 Ok(iter)
144 })
145 .boxed()
146 }
147
148 #[doc(hidden)]
149 fn execute_returning_count<'conn, 'query, T>(
150 &'conn mut self,
151 source: T,
152 ) -> Self::ExecuteFuture<'conn, 'query>
153 where
154 T: QueryFragment<Self::Backend> + QueryId + 'query,
155 {
156 let result = prepare_statement_sql(source, &self.binding);
157 SendableFuture(async move {
158 let promise = match result.all() {
159 Ok(res) => res,
160 Err(err) => {
161 console_error!("{:?}", err);
162 panic!("not supposed to happen .all call");
163 },
164 };
165
166 let result = match SendableFuture(JsFuture::from(promise)).await {
167 Ok(res) => res,
168 Err(err) => {
169 console_error!("{:?}", err);
170 panic!("not supposed to happen .all promise");
171 },
172 };
173
174 let result: D1Result = result.into();
175
176 let error = result.error().unwrap();
177
178 if let Some(error_str) = error {
179 return Err(diesel::result::Error::DatabaseError(
180 diesel::result::DatabaseErrorKind::Unknown,
181 Box::new(D1Error { message: error_str }),
182 ));
183 }
184
185 let meta = result.meta().unwrap();
187 let value = js_sys::Reflect::get(&meta, &"changes".to_owned().into())
188 .unwrap()
189 .as_f64()
190 .unwrap();
191
192 Ok(value as usize)
193 })
194 .boxed()
195 }
196
197 fn transaction_state(&mut self) -> &mut D1TransactionManager {
198 &mut self.transaction_manager
199 }
200
201 #[doc(hidden)]
202 fn instrumentation(&mut self) -> &mut dyn Instrumentation {
203 todo!()
204 }
205
206 #[doc = " Set a specific [`Instrumentation`] implementation for this connection"]
207 fn set_instrumentation(&mut self, _instrumentation: impl Instrumentation) {
208 todo!()
209 }
210}
211
// Marker impl of diesel's `ConnectionSealed` trait for this connection type; it has no items.
impl ConnectionSealed for D1Connection {}
213
214fn construct_bind_data<T>(query: &T) -> Result<Array, diesel::result::Error>
215where
216 T: QueryFragment<D1Backend>,
217{
218 let mut bind_collector = D1BindCollector::default();
219
220 query.collect_binds(&mut bind_collector, &mut (), &D1Backend)?;
221
222 let array = bind_collector
223 .binds
224 .iter()
225 .map(|(bind, _)| bind)
226 .collect::<Array>();
227 Ok(array)
228}
229
230fn prepare_statement_sql<'conn, 'query, T>(source: T, binding: &D1Database) -> D1PreparedStatement
231where
232 T: QueryFragment<D1Backend> + QueryId + 'query,
233{
234 let mut query_builder = D1QueryBuilder::default();
235 source.to_sql(&mut query_builder, &D1Backend).unwrap();
236 let result = match binding.prepare(&query_builder.sql) {
237 Ok(res) => res,
238 Err(err) => {
239 console_error!("{:?}", err);
240 panic!("not supposed to happen d1preparedstatement");
241 },
242 };
243
244 let binds = construct_bind_data(&source).unwrap();
245
246 match result.bind(binds) {
247 Ok(res) => res,
248 Err(err) => {
249 console_error!("{:?}", err);
250 panic!("not supposed to happen bind");
251 },
252 }
253}