1use std::sync::Arc;
4
5use diesel::connection::*;
6use diesel::expression::QueryMetadata;
7use diesel::query_builder::*;
8use diesel::result::*;
9use diesel::sql_types::TypeMetadata;
10use diesel::QueryResult;
11
12use crate::backend::LibSql;
13use crate::bind_collector::LibSqlBindCollector;
14use crate::row::LibSqlRow;
15use crate::value::LibSqlValue;
16
17struct TokioRuntime {
20 runtime: Option<tokio::runtime::Runtime>,
21}
22
23impl TokioRuntime {
24 fn new() -> Self {
25 let runtime = if tokio::runtime::Handle::try_current().is_ok() {
26 None
27 } else {
28 Some(
29 tokio::runtime::Runtime::new()
30 .expect("Failed to create tokio runtime for LibSqlConnection"),
31 )
32 };
33 TokioRuntime { runtime }
34 }
35
36 fn block_on<F: std::future::Future>(&self, future: F) -> F::Output {
37 match &self.runtime {
38 Some(rt) => rt.block_on(future),
39 None => {
40 tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(future))
41 }
42 }
43 }
44}
45
/// Diesel connection backed by libsql: a local file, a remote server, or an
/// embedded replica. libsql's async API is driven through a Tokio runtime.
///
/// NOTE(review): Rust drops struct fields in declaration order, so `database` and
/// `connection` are released before `runtime`. Presumably libsql may still need the
/// runtime while tearing down — confirm before reordering these fields.
// Debug is not derived because not every field implements it (e.g. `TokioRuntime`).
#[allow(missing_debug_implementations)]
pub struct LibSqlConnection {
    // Database handle; also used by `sync()` for embedded replicas.
    database: libsql::Database,
    // The single libsql connection on which every statement is executed.
    connection: libsql::Connection,
    // Runtime used to block on libsql futures.
    runtime: TokioRuntime,
    // Diesel's transaction depth/status bookkeeping.
    transaction_state: AnsiTransactionManager,
    // The LibSql backend's metadata lookup is the unit type: no type lookups are needed.
    metadata_lookup: (),
    // Instrumentation notified around connection establishment and queries.
    instrumentation: DynInstrumentation,
    // True only for embedded replicas; `sync()` is a no-op otherwise.
    is_replica: bool,
}
61
// SAFETY(review): this asserts that `LibSqlConnection` may be moved to another
// thread. No justification was recorded. Presumably every field (libsql `Database`
// and `Connection`, the Tokio runtime wrapper, diesel's transaction manager and
// instrumentation) is already `Send`. If so, the auto trait applies and this impl is
// redundant. TODO: confirm and remove it, so the compiler checks `Send` instead of
// trusting this block.
#[allow(unsafe_code)]
unsafe impl Send for LibSqlConnection {}
66
67impl LibSqlConnection {
68 fn establish_inner(database_url: &str) -> ConnectionResult<Self> {
69 let runtime = TokioRuntime::new();
70
71 let is_remote = database_url.starts_with("libsql://")
72 || database_url.starts_with("https://")
73 || database_url.starts_with("http://");
74
75 let database = if is_remote {
76 let (url, auth_token) = parse_remote_url(database_url)?;
78 runtime
79 .block_on(libsql::Builder::new_remote(url, auth_token).build())
80 .map_err(|e| ConnectionError::BadConnection(e.to_string()))?
81 } else {
82 runtime
83 .block_on(libsql::Builder::new_local(database_url).build())
84 .map_err(|e| ConnectionError::BadConnection(e.to_string()))?
85 };
86
87 let connection = database
88 .connect()
89 .map_err(|e| ConnectionError::BadConnection(e.to_string()))?;
90
91 Ok(LibSqlConnection {
92 database,
93 connection,
94 runtime,
95 transaction_state: AnsiTransactionManager::default(),
96 metadata_lookup: (),
97 instrumentation: DynInstrumentation::none(),
98 is_replica: false,
99 })
100 }
101
102 pub fn establish_replica(
110 local_path: &str,
111 remote_url: &str,
112 auth_token: &str,
113 ) -> ConnectionResult<Self> {
114 let runtime = TokioRuntime::new();
115
116 let database = runtime
117 .block_on(
118 libsql::Builder::new_remote_replica(
119 local_path,
120 remote_url.to_string(),
121 auth_token.to_string(),
122 )
123 .build(),
124 )
125 .map_err(|e| ConnectionError::BadConnection(e.to_string()))?;
126
127 let connection = database
128 .connect()
129 .map_err(|e| ConnectionError::BadConnection(e.to_string()))?;
130
131 Ok(LibSqlConnection {
132 database,
133 connection,
134 runtime,
135 transaction_state: AnsiTransactionManager::default(),
136 metadata_lookup: (),
137 instrumentation: DynInstrumentation::none(),
138 is_replica: true,
139 })
140 }
141
142 pub fn sync(&mut self) -> QueryResult<()> {
147 if !self.is_replica {
148 return Ok(());
149 }
150 self.runtime.block_on(self.database.sync()).map_err(|e| {
151 Error::DatabaseError(DatabaseErrorKind::Unknown, Box::new(e.to_string()))
152 })?;
153 Ok(())
154 }
155
156 pub fn alter_column(
165 &mut self,
166 table: &str,
167 column: &str,
168 new_definition: &str,
169 ) -> QueryResult<()> {
170 let sql = format!(
171 "ALTER TABLE {} ALTER COLUMN {} TO {}",
172 table, column, new_definition
173 );
174 self.batch_execute(&sql)
175 }
176
177 pub fn immediate_transaction<T, E, F>(&mut self, f: F) -> Result<T, E>
182 where
183 F: FnOnce(&mut Self) -> Result<T, E>,
184 E: From<diesel::result::Error>,
185 {
186 self.batch_execute("BEGIN IMMEDIATE")?;
187 match f(self) {
188 Ok(value) => {
189 self.batch_execute("COMMIT")?;
190 Ok(value)
191 }
192 Err(e) => {
193 let _ = self.batch_execute("ROLLBACK");
194 Err(e)
195 }
196 }
197 }
198
199 pub fn exclusive_transaction<T, E, F>(&mut self, f: F) -> Result<T, E>
204 where
205 F: FnOnce(&mut Self) -> Result<T, E>,
206 E: From<diesel::result::Error>,
207 {
208 self.batch_execute("BEGIN EXCLUSIVE")?;
209 match f(self) {
210 Ok(value) => {
211 self.batch_execute("COMMIT")?;
212 Ok(value)
213 }
214 Err(e) => {
215 let _ = self.batch_execute("ROLLBACK");
216 Err(e)
217 }
218 }
219 }
220
221 pub fn last_insert_rowid(&self) -> i64 {
225 self.connection.last_insert_rowid()
226 }
227
228 pub fn replica_builder(
230 local_path: impl Into<String>,
231 remote_url: impl Into<String>,
232 auth_token: impl Into<String>,
233 ) -> ReplicaBuilder {
234 ReplicaBuilder::new(local_path, remote_url, auth_token)
235 }
236
237 #[cfg(feature = "encryption")]
241 pub fn establish_encrypted(
242 database_url: &str,
243 encryption_key: Vec<u8>,
244 ) -> ConnectionResult<Self> {
245 let runtime = TokioRuntime::new();
246 let config =
247 libsql::EncryptionConfig::new(libsql::Cipher::Aes256Cbc, encryption_key.into());
248 let database = runtime
249 .block_on(
250 libsql::Builder::new_local(database_url)
251 .encryption_config(config)
252 .build(),
253 )
254 .map_err(|e| ConnectionError::BadConnection(e.to_string()))?;
255
256 let connection = database
257 .connect()
258 .map_err(|e| ConnectionError::BadConnection(e.to_string()))?;
259
260 Ok(LibSqlConnection {
261 database,
262 connection,
263 runtime,
264 transaction_state: AnsiTransactionManager::default(),
265 metadata_lookup: (),
266 instrumentation: DynInstrumentation::none(),
267 is_replica: false,
268 })
269 }
270
271 fn run_query(&mut self, sql: &str, params: Vec<libsql::Value>) -> QueryResult<Vec<LibSqlRow>> {
272 self.runtime.block_on(async {
273 let stmt = self.connection.prepare(sql).await.map_err(|e| {
274 Error::DatabaseError(DatabaseErrorKind::Unknown, Box::new(e.to_string()))
275 })?;
276
277 let rows_result = stmt.query(params).await.map_err(|e| {
278 Error::DatabaseError(DatabaseErrorKind::Unknown, Box::new(e.to_string()))
279 })?;
280
281 Self::collect_rows(rows_result).await
282 })
283 }
284
285 pub(crate) async fn collect_rows(mut rows: libsql::Rows) -> QueryResult<Vec<LibSqlRow>> {
286 let column_count = rows.column_count();
287 let column_names: Arc<[Option<String>]> = (0..column_count)
288 .map(|i| rows.column_name(i).map(|s| s.to_string()))
289 .collect::<Vec<_>>()
290 .into();
291
292 let mut result = Vec::new();
293 while let Some(row) = rows.next().await.map_err(|e| {
294 Error::DatabaseError(DatabaseErrorKind::Unknown, Box::new(e.to_string()))
295 })? {
296 let mut values = Vec::with_capacity(column_count as usize);
297 for i in 0..column_count {
298 let value = row.get_value(i).map_err(|e| {
299 Error::DatabaseError(DatabaseErrorKind::Unknown, Box::new(e.to_string()))
300 })?;
301 values.push(Some(libsql_value_to_owned(value)));
302 }
303 result.push(LibSqlRow {
304 values,
305 column_names: column_names.clone(),
306 });
307 }
308 Ok(result)
309 }
310
311 fn execute_sql(&mut self, sql: &str, params: Vec<libsql::Value>) -> QueryResult<usize> {
312 self.runtime.block_on(async {
313 match self.connection.execute(sql, params.clone()).await {
314 Ok(affected) => Ok(affected as usize),
315 Err(libsql::Error::ExecuteReturnedRows) => {
316 let mut rows = self
320 .connection
321 .query(sql, params)
322 .await
323 .map_err(|e| {
324 Error::DatabaseError(
325 DatabaseErrorKind::Unknown,
326 Box::new(e.to_string()),
327 )
328 })?;
329 let mut count = 0usize;
330 while rows.next().await.map_err(|e| {
331 Error::DatabaseError(
332 DatabaseErrorKind::Unknown,
333 Box::new(e.to_string()),
334 )
335 })?.is_some() {
336 count += 1;
337 }
338 Ok(count)
339 }
340 Err(e) => Err(Error::DatabaseError(
341 DatabaseErrorKind::Unknown,
342 Box::new(e.to_string()),
343 )),
344 }
345 })
346 }
347}
348
349pub(crate) fn build_query<T>(
351 source: &T,
352 metadata_lookup: &mut (),
353) -> QueryResult<(String, Vec<libsql::Value>)>
354where
355 T: QueryFragment<LibSql>,
356{
357 let mut qb = <LibSql as diesel::backend::Backend>::QueryBuilder::default();
358 source.to_sql(&mut qb, &LibSql)?;
359 let sql = qb.finish();
360
361 let mut bind_collector = LibSqlBindCollector::default();
362 source.collect_binds(&mut bind_collector, metadata_lookup, &LibSql)?;
363
364 let params: Vec<libsql::Value> = bind_collector
365 .binds
366 .iter()
367 .map(|(bind, _ty)| bind.to_libsql_value())
368 .collect();
369
370 Ok((sql, params))
371}
372
373impl SimpleConnection for LibSqlConnection {
374 fn batch_execute(&mut self, query: &str) -> QueryResult<()> {
375 self.instrumentation
376 .on_connection_event(InstrumentationEvent::start_query(&StrQueryHelper::new(
377 query,
378 )));
379
380 let result = self.runtime.block_on(async {
381 self.connection.execute_batch(query).await.map_err(|e| {
382 Error::DatabaseError(DatabaseErrorKind::Unknown, Box::new(e.to_string()))
383 })
384 });
385
386 let result = result.map(|_| ());
387
388 self.instrumentation
389 .on_connection_event(InstrumentationEvent::finish_query(
390 &StrQueryHelper::new(query),
391 result.as_ref().err(),
392 ));
393
394 result
395 }
396}
397
// Diesel's sealing marker trait; presumably required for the `Connection` impl of this type.
impl ConnectionSealed for LibSqlConnection {}
399
400impl Connection for LibSqlConnection {
401 type Backend = LibSql;
402 type TransactionManager = AnsiTransactionManager;
403
404 fn establish(database_url: &str) -> ConnectionResult<Self> {
405 let mut instrumentation = diesel::connection::get_default_instrumentation();
406 instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection(
407 database_url,
408 ));
409
410 let establish_result = Self::establish_inner(database_url);
411 instrumentation.on_connection_event(InstrumentationEvent::finish_establish_connection(
412 database_url,
413 establish_result.as_ref().err(),
414 ));
415
416 let mut conn = establish_result?;
417 conn.instrumentation = instrumentation.into();
418 Ok(conn)
419 }
420
421 fn execute_returning_count<T>(&mut self, source: &T) -> QueryResult<usize>
422 where
423 T: QueryFragment<Self::Backend> + QueryId,
424 {
425 let (sql, params) = build_query(source, &mut self.metadata_lookup)?;
426
427 self.instrumentation
428 .on_connection_event(InstrumentationEvent::start_query(&StrQueryHelper::new(
429 &sql,
430 )));
431
432 let result = self.execute_sql(&sql, params);
433
434 self.instrumentation
435 .on_connection_event(InstrumentationEvent::finish_query(
436 &StrQueryHelper::new(&sql),
437 result.as_ref().err(),
438 ));
439
440 result
441 }
442
443 fn transaction_state(&mut self) -> &mut AnsiTransactionManager
444 where
445 Self: Sized,
446 {
447 &mut self.transaction_state
448 }
449
450 fn instrumentation(&mut self) -> &mut dyn Instrumentation {
451 &mut *self.instrumentation
452 }
453
454 fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
455 self.instrumentation = instrumentation.into();
456 }
457
458 fn set_prepared_statement_cache_size(&mut self, _size: CacheSize) {
459 }
461}
462
/// Cursor returned by `LoadConnection::load` for `LibSqlConnection`.
///
/// All result rows are materialized before the cursor is created, then handed out
/// one at a time.
pub struct LibSqlCursor {
    // Rows not yet yielded.
    rows: std::vec::IntoIter<LibSqlRow>,
}
467
468impl Iterator for LibSqlCursor {
469 type Item = QueryResult<LibSqlRow>;
470
471 fn next(&mut self) -> Option<Self::Item> {
472 self.rows.next().map(Ok)
473 }
474}
475
476impl LoadConnection<DefaultLoadingMode> for LibSqlConnection {
477 type Cursor<'conn, 'query> = LibSqlCursor;
478 type Row<'conn, 'query> = LibSqlRow;
479
480 fn load<'conn, 'query, T>(
481 &'conn mut self,
482 source: T,
483 ) -> QueryResult<Self::Cursor<'conn, 'query>>
484 where
485 T: Query + QueryFragment<Self::Backend> + QueryId + 'query,
486 Self::Backend: QueryMetadata<T::SqlType>,
487 {
488 let (sql, params) = build_query(&source, &mut self.metadata_lookup)?;
489
490 self.instrumentation
491 .on_connection_event(InstrumentationEvent::start_query(&StrQueryHelper::new(
492 &sql,
493 )));
494
495 let result = self.run_query(&sql, params);
496
497 self.instrumentation
498 .on_connection_event(InstrumentationEvent::finish_query(
499 &StrQueryHelper::new(&sql),
500 result.as_ref().err(),
501 ));
502
503 let rows = result?;
504 Ok(LibSqlCursor {
505 rows: rows.into_iter(),
506 })
507 }
508}
509
510impl diesel::migration::MigrationConnection for LibSqlConnection {
511 fn setup(&mut self) -> QueryResult<usize> {
512 use diesel::RunQueryDsl;
513 diesel::sql_query(diesel::migration::CREATE_MIGRATIONS_TABLE).execute(self)
514 }
515}
516
// Exposes the connection's metadata lookup (the unit value for this backend) to diesel.
impl WithMetadataLookup for LibSqlConnection {
    fn metadata_lookup(&mut self) -> &mut <LibSql as TypeMetadata>::MetadataLookup {
        &mut self.metadata_lookup
    }
}
522
523impl MultiConnectionHelper for LibSqlConnection {
524 fn to_any<'a>(
525 lookup: &mut <Self::Backend as TypeMetadata>::MetadataLookup,
526 ) -> &mut (dyn std::any::Any + 'a) {
527 lookup
528 }
529
530 fn from_any(
531 lookup: &mut dyn std::any::Any,
532 ) -> Option<&mut <Self::Backend as TypeMetadata>::MetadataLookup> {
533 lookup.downcast_mut()
534 }
535}
536
537pub(crate) fn parse_remote_url(database_url: &str) -> ConnectionResult<(String, String)> {
542 if let Some(idx) = database_url.find("?authToken=") {
544 let url = database_url[..idx].to_string();
545 let token_start = idx + "?authToken=".len();
546 let token = if let Some(amp) = database_url[token_start..].find('&') {
548 &database_url[token_start..token_start + amp]
549 } else {
550 &database_url[token_start..]
551 };
552 if token.is_empty() {
553 return Err(ConnectionError::BadConnection(
554 "authToken query parameter is empty".to_string(),
555 ));
556 }
557 return Ok((url, token.to_string()));
558 }
559
560 if let Some(idx) = database_url.find("&authToken=") {
562 let url = database_url[..database_url.find('?').unwrap_or(idx)].to_string();
563 let token_start = idx + "&authToken=".len();
564 let token = if let Some(amp) = database_url[token_start..].find('&') {
565 &database_url[token_start..token_start + amp]
566 } else {
567 &database_url[token_start..]
568 };
569 if token.is_empty() {
570 return Err(ConnectionError::BadConnection(
571 "authToken query parameter is empty".to_string(),
572 ));
573 }
574 return Ok((url, token.to_string()));
575 }
576
577 match std::env::var("LIBSQL_AUTH_TOKEN") {
579 Ok(token) if !token.is_empty() => Ok((database_url.to_string(), token)),
580 _ => Err(ConnectionError::BadConnection(
581 "No auth token provided: use ?authToken=TOKEN in the URL or set LIBSQL_AUTH_TOKEN"
582 .to_string(),
583 )),
584 }
585}
586
/// Builder for an embedded-replica `LibSqlConnection`.
///
/// Created with `ReplicaBuilder::new` (or `LibSqlConnection::replica_builder`) and
/// finished with `establish` (or `establish_async` when the `async` feature is on).
pub struct ReplicaBuilder {
    // Path of the local replica database file.
    local_path: String,
    // URL of the remote primary database.
    remote_url: String,
    // Auth token for the remote primary.
    auth_token: String,
    // Forwarded to libsql's `sync_interval` option when set; `None` by default.
    sync_interval: Option<std::time::Duration>,
    // Forwarded to libsql's `read_your_writes` option; `true` by default.
    read_your_writes: bool,
}
598
599impl ReplicaBuilder {
600 pub fn new(
602 local_path: impl Into<String>,
603 remote_url: impl Into<String>,
604 auth_token: impl Into<String>,
605 ) -> Self {
606 Self {
607 local_path: local_path.into(),
608 remote_url: remote_url.into(),
609 auth_token: auth_token.into(),
610 sync_interval: None,
611 read_your_writes: true,
612 }
613 }
614
615 pub fn sync_interval(mut self, interval: std::time::Duration) -> Self {
618 self.sync_interval = Some(interval);
619 self
620 }
621
622 pub fn read_your_writes(mut self, enabled: bool) -> Self {
627 self.read_your_writes = enabled;
628 self
629 }
630
631 pub fn establish(self) -> ConnectionResult<LibSqlConnection> {
633 let runtime = TokioRuntime::new();
634 let mut builder =
635 libsql::Builder::new_remote_replica(self.local_path, self.remote_url, self.auth_token)
636 .read_your_writes(self.read_your_writes);
637
638 if let Some(interval) = self.sync_interval {
639 builder = builder.sync_interval(interval);
640 }
641
642 let database = runtime
643 .block_on(builder.build())
644 .map_err(|e| ConnectionError::BadConnection(e.to_string()))?;
645
646 let connection = database
647 .connect()
648 .map_err(|e| ConnectionError::BadConnection(e.to_string()))?;
649
650 Ok(LibSqlConnection {
651 database,
652 connection,
653 runtime,
654 transaction_state: AnsiTransactionManager::default(),
655 metadata_lookup: (),
656 instrumentation: DynInstrumentation::none(),
657 is_replica: true,
658 })
659 }
660
661 #[cfg(feature = "async")]
663 pub async fn establish_async(
664 self,
665 ) -> ConnectionResult<crate::async_conn::AsyncLibSqlConnection> {
666 let mut builder =
667 libsql::Builder::new_remote_replica(self.local_path, self.remote_url, self.auth_token)
668 .read_your_writes(self.read_your_writes);
669
670 if let Some(interval) = self.sync_interval {
671 builder = builder.sync_interval(interval);
672 }
673
674 let database = builder
675 .build()
676 .await
677 .map_err(|e| ConnectionError::BadConnection(e.to_string()))?;
678
679 let connection = database
680 .connect()
681 .map_err(|e| ConnectionError::BadConnection(e.to_string()))?;
682
683 Ok(crate::async_conn::AsyncLibSqlConnection::from_parts(
684 database, connection,
685 ))
686 }
687}
688
689pub(crate) fn libsql_value_to_owned(value: libsql::Value) -> LibSqlValue {
691 match value {
692 libsql::Value::Null => LibSqlValue::Null,
693 libsql::Value::Integer(i) => LibSqlValue::Integer(i),
694 libsql::Value::Real(f) => LibSqlValue::Real(f),
695 libsql::Value::Text(s) => LibSqlValue::Text(s),
696 libsql::Value::Blob(b) => LibSqlValue::Blob(b),
697 }
698}