Skip to main content

diesel_libsql/
connection.rs

1//! LibSql connection implementation.
2
3use std::sync::Arc;
4
5use diesel::connection::*;
6use diesel::expression::QueryMetadata;
7use diesel::query_builder::*;
8use diesel::result::*;
9use diesel::sql_types::TypeMetadata;
10use diesel::QueryResult;
11
12use crate::backend::LibSql;
13use crate::bind_collector::LibSqlBindCollector;
14use crate::row::LibSqlRow;
15use crate::value::LibSqlValue;
16
17/// Wrapper around a tokio runtime handle that works whether or not
18/// we're already inside a tokio runtime.
19struct TokioRuntime {
20    runtime: Option<tokio::runtime::Runtime>,
21}
22
23impl TokioRuntime {
24    fn new() -> Self {
25        let runtime = if tokio::runtime::Handle::try_current().is_ok() {
26            None
27        } else {
28            Some(
29                tokio::runtime::Runtime::new()
30                    .expect("Failed to create tokio runtime for LibSqlConnection"),
31            )
32        };
33        TokioRuntime { runtime }
34    }
35
36    fn block_on<F: std::future::Future>(&self, future: F) -> F::Output {
37        match &self.runtime {
38            Some(rt) => rt.block_on(future),
39            None => {
40                tokio::task::block_in_place(|| tokio::runtime::Handle::current().block_on(future))
41            }
42        }
43    }
44}
45
46/// A Diesel connection backed by libsql.
47///
48/// Supports local SQLite databases (`:memory:` and file-based) as well as
49/// remote Turso databases and embedded replicas.
50#[allow(missing_debug_implementations)]
51pub struct LibSqlConnection {
52    database: libsql::Database,
53    connection: libsql::Connection,
54    runtime: TokioRuntime,
55    transaction_state: AnsiTransactionManager,
56    metadata_lookup: (),
57    instrumentation: DynInstrumentation,
58    /// Whether this connection is backed by an embedded replica.
59    is_replica: bool,
60}
61
// SAFETY: NOTE(review) — this asserts that every field of `LibSqlConnection`
// may be moved to, and dropped on, another thread. The previous justification
// ("only used from a single thread at a time") describes exclusive access,
// which `&mut self` already guarantees and which is an argument about `Sync`,
// not `Send`. TODO confirm which field (libsql handles, `TokioRuntime`,
// `DynInstrumentation`) is not `Send` on its own and why moving it across
// threads is sound. `Sync` is not implemented here.
#[allow(unsafe_code)]
unsafe impl Send for LibSqlConnection {}
66
67impl LibSqlConnection {
68    fn establish_inner(database_url: &str) -> ConnectionResult<Self> {
69        let runtime = TokioRuntime::new();
70
71        let is_remote = database_url.starts_with("libsql://")
72            || database_url.starts_with("https://")
73            || database_url.starts_with("http://");
74
75        let database = if is_remote {
76            // Parse auth token from ?authToken=TOKEN query param or LIBSQL_AUTH_TOKEN env var
77            let (url, auth_token) = parse_remote_url(database_url)?;
78            runtime
79                .block_on(libsql::Builder::new_remote(url, auth_token).build())
80                .map_err(|e| ConnectionError::BadConnection(e.to_string()))?
81        } else {
82            runtime
83                .block_on(libsql::Builder::new_local(database_url).build())
84                .map_err(|e| ConnectionError::BadConnection(e.to_string()))?
85        };
86
87        let connection = database
88            .connect()
89            .map_err(|e| ConnectionError::BadConnection(e.to_string()))?;
90
91        Ok(LibSqlConnection {
92            database,
93            connection,
94            runtime,
95            transaction_state: AnsiTransactionManager::default(),
96            metadata_lookup: (),
97            instrumentation: DynInstrumentation::none(),
98            is_replica: false,
99        })
100    }
101
102    /// Establish an embedded replica connection.
103    ///
104    /// The replica maintains a local SQLite file at `local_path` that syncs
105    /// from `remote_url` using the provided `auth_token`. Reads are served
106    /// locally; writes are delegated to the remote primary.
107    ///
108    /// Call [`sync`](Self::sync) to pull the latest state from the remote.
109    pub fn establish_replica(
110        local_path: &str,
111        remote_url: &str,
112        auth_token: &str,
113    ) -> ConnectionResult<Self> {
114        let runtime = TokioRuntime::new();
115
116        let database = runtime
117            .block_on(
118                libsql::Builder::new_remote_replica(
119                    local_path,
120                    remote_url.to_string(),
121                    auth_token.to_string(),
122                )
123                .build(),
124            )
125            .map_err(|e| ConnectionError::BadConnection(e.to_string()))?;
126
127        let connection = database
128            .connect()
129            .map_err(|e| ConnectionError::BadConnection(e.to_string()))?;
130
131        Ok(LibSqlConnection {
132            database,
133            connection,
134            runtime,
135            transaction_state: AnsiTransactionManager::default(),
136            metadata_lookup: (),
137            instrumentation: DynInstrumentation::none(),
138            is_replica: true,
139        })
140    }
141
142    /// Sync the embedded replica with the remote primary.
143    ///
144    /// Returns `Ok(())` on success. If this connection is not a replica
145    /// (i.e., it is a local or pure-remote connection), this is a no-op.
146    pub fn sync(&mut self) -> QueryResult<()> {
147        if !self.is_replica {
148            return Ok(());
149        }
150        self.runtime.block_on(self.database.sync()).map_err(|e| {
151            Error::DatabaseError(DatabaseErrorKind::Unknown, Box::new(e.to_string()))
152        })?;
153        Ok(())
154    }
155
156    /// Execute a libSQL-specific `ALTER TABLE ... ALTER COLUMN ... TO ...` statement.
157    ///
158    /// The `new_definition` should include the column name, type, and any constraints.
159    /// For example:
160    /// ```ignore
161    /// conn.alter_column("users", "name", "name TEXT NOT NULL DEFAULT 'unknown'")?;
162    /// ```
163    /// This generates: `ALTER TABLE users ALTER COLUMN name TO name TEXT NOT NULL DEFAULT 'unknown'`
164    pub fn alter_column(
165        &mut self,
166        table: &str,
167        column: &str,
168        new_definition: &str,
169    ) -> QueryResult<()> {
170        let sql = format!(
171            "ALTER TABLE {} ALTER COLUMN {} TO {}",
172            table, column, new_definition
173        );
174        self.batch_execute(&sql)
175    }
176
177    /// Run a transaction with `BEGIN IMMEDIATE`.
178    ///
179    /// Acquires a reserved lock immediately, preventing other writers.
180    /// Useful when you know you will write and want to avoid `SQLITE_BUSY`.
181    pub fn immediate_transaction<T, E, F>(&mut self, f: F) -> Result<T, E>
182    where
183        F: FnOnce(&mut Self) -> Result<T, E>,
184        E: From<diesel::result::Error>,
185    {
186        self.batch_execute("BEGIN IMMEDIATE")?;
187        match f(self) {
188            Ok(value) => {
189                self.batch_execute("COMMIT")?;
190                Ok(value)
191            }
192            Err(e) => {
193                let _ = self.batch_execute("ROLLBACK");
194                Err(e)
195            }
196        }
197    }
198
199    /// Run a transaction with `BEGIN EXCLUSIVE`.
200    ///
201    /// Acquires an exclusive lock immediately, preventing all other connections
202    /// from reading or writing.
203    pub fn exclusive_transaction<T, E, F>(&mut self, f: F) -> Result<T, E>
204    where
205        F: FnOnce(&mut Self) -> Result<T, E>,
206        E: From<diesel::result::Error>,
207    {
208        self.batch_execute("BEGIN EXCLUSIVE")?;
209        match f(self) {
210            Ok(value) => {
211                self.batch_execute("COMMIT")?;
212                Ok(value)
213            }
214            Err(e) => {
215                let _ = self.batch_execute("ROLLBACK");
216                Err(e)
217            }
218        }
219    }
220
221    /// Returns the row ID of the last successful `INSERT`.
222    ///
223    /// Returns `0` if no `INSERT` has been performed on this connection.
224    pub fn last_insert_rowid(&self) -> i64 {
225        self.connection.last_insert_rowid()
226    }
227
228    /// Create a [`ReplicaBuilder`] for configuring an embedded replica connection.
229    pub fn replica_builder(
230        local_path: impl Into<String>,
231        remote_url: impl Into<String>,
232        auth_token: impl Into<String>,
233    ) -> ReplicaBuilder {
234        ReplicaBuilder::new(local_path, remote_url, auth_token)
235    }
236
237    /// Establish a local connection with encryption at rest.
238    ///
239    /// Uses AES-256-CBC encryption. The key must be exactly 32 bytes.
240    #[cfg(feature = "encryption")]
241    pub fn establish_encrypted(
242        database_url: &str,
243        encryption_key: Vec<u8>,
244    ) -> ConnectionResult<Self> {
245        let runtime = TokioRuntime::new();
246        let config =
247            libsql::EncryptionConfig::new(libsql::Cipher::Aes256Cbc, encryption_key.into());
248        let database = runtime
249            .block_on(
250                libsql::Builder::new_local(database_url)
251                    .encryption_config(config)
252                    .build(),
253            )
254            .map_err(|e| ConnectionError::BadConnection(e.to_string()))?;
255
256        let connection = database
257            .connect()
258            .map_err(|e| ConnectionError::BadConnection(e.to_string()))?;
259
260        Ok(LibSqlConnection {
261            database,
262            connection,
263            runtime,
264            transaction_state: AnsiTransactionManager::default(),
265            metadata_lookup: (),
266            instrumentation: DynInstrumentation::none(),
267            is_replica: false,
268        })
269    }
270
271    fn run_query(&mut self, sql: &str, params: Vec<libsql::Value>) -> QueryResult<Vec<LibSqlRow>> {
272        self.runtime.block_on(async {
273            let stmt = self.connection.prepare(sql).await.map_err(|e| {
274                Error::DatabaseError(DatabaseErrorKind::Unknown, Box::new(e.to_string()))
275            })?;
276
277            let rows_result = stmt.query(params).await.map_err(|e| {
278                Error::DatabaseError(DatabaseErrorKind::Unknown, Box::new(e.to_string()))
279            })?;
280
281            Self::collect_rows(rows_result).await
282        })
283    }
284
285    pub(crate) async fn collect_rows(mut rows: libsql::Rows) -> QueryResult<Vec<LibSqlRow>> {
286        let column_count = rows.column_count();
287        let column_names: Arc<[Option<String>]> = (0..column_count)
288            .map(|i| rows.column_name(i).map(|s| s.to_string()))
289            .collect::<Vec<_>>()
290            .into();
291
292        let mut result = Vec::new();
293        while let Some(row) = rows.next().await.map_err(|e| {
294            Error::DatabaseError(DatabaseErrorKind::Unknown, Box::new(e.to_string()))
295        })? {
296            let mut values = Vec::with_capacity(column_count as usize);
297            for i in 0..column_count {
298                let value = row.get_value(i).map_err(|e| {
299                    Error::DatabaseError(DatabaseErrorKind::Unknown, Box::new(e.to_string()))
300                })?;
301                values.push(Some(libsql_value_to_owned(value)));
302            }
303            result.push(LibSqlRow {
304                values,
305                column_names: column_names.clone(),
306            });
307        }
308        Ok(result)
309    }
310
311    fn execute_sql(&mut self, sql: &str, params: Vec<libsql::Value>) -> QueryResult<usize> {
312        self.runtime.block_on(async {
313            match self.connection.execute(sql, params.clone()).await {
314                Ok(affected) => Ok(affected as usize),
315                Err(libsql::Error::ExecuteReturnedRows) => {
316                    // libsql's execute() rejects SELECT statements. Fall back to
317                    // query() and return the row count. This happens when diesel's
318                    // migration harness runs SELECT via execute_returning_count().
319                    let mut rows = self
320                        .connection
321                        .query(sql, params)
322                        .await
323                        .map_err(|e| {
324                            Error::DatabaseError(
325                                DatabaseErrorKind::Unknown,
326                                Box::new(e.to_string()),
327                            )
328                        })?;
329                    let mut count = 0usize;
330                    while rows.next().await.map_err(|e| {
331                        Error::DatabaseError(
332                            DatabaseErrorKind::Unknown,
333                            Box::new(e.to_string()),
334                        )
335                    })?.is_some() {
336                        count += 1;
337                    }
338                    Ok(count)
339                }
340                Err(e) => Err(Error::DatabaseError(
341                    DatabaseErrorKind::Unknown,
342                    Box::new(e.to_string()),
343                )),
344            }
345        })
346    }
347}
348
349/// Extract SQL string and owned params from a query source.
350pub(crate) fn build_query<T>(
351    source: &T,
352    metadata_lookup: &mut (),
353) -> QueryResult<(String, Vec<libsql::Value>)>
354where
355    T: QueryFragment<LibSql>,
356{
357    let mut qb = <LibSql as diesel::backend::Backend>::QueryBuilder::default();
358    source.to_sql(&mut qb, &LibSql)?;
359    let sql = qb.finish();
360
361    let mut bind_collector = LibSqlBindCollector::default();
362    source.collect_binds(&mut bind_collector, metadata_lookup, &LibSql)?;
363
364    let params: Vec<libsql::Value> = bind_collector
365        .binds
366        .iter()
367        .map(|(bind, _ty)| bind.to_libsql_value())
368        .collect();
369
370    Ok((sql, params))
371}
372
373impl SimpleConnection for LibSqlConnection {
374    fn batch_execute(&mut self, query: &str) -> QueryResult<()> {
375        self.instrumentation
376            .on_connection_event(InstrumentationEvent::start_query(&StrQueryHelper::new(
377                query,
378            )));
379
380        let result = self.runtime.block_on(async {
381            self.connection.execute_batch(query).await.map_err(|e| {
382                Error::DatabaseError(DatabaseErrorKind::Unknown, Box::new(e.to_string()))
383            })
384        });
385
386        let result = result.map(|_| ());
387
388        self.instrumentation
389            .on_connection_event(InstrumentationEvent::finish_query(
390                &StrQueryHelper::new(query),
391                result.as_ref().err(),
392            ));
393
394        result
395    }
396}
397
398impl ConnectionSealed for LibSqlConnection {}
399
400impl Connection for LibSqlConnection {
401    type Backend = LibSql;
402    type TransactionManager = AnsiTransactionManager;
403
404    fn establish(database_url: &str) -> ConnectionResult<Self> {
405        let mut instrumentation = diesel::connection::get_default_instrumentation();
406        instrumentation.on_connection_event(InstrumentationEvent::start_establish_connection(
407            database_url,
408        ));
409
410        let establish_result = Self::establish_inner(database_url);
411        instrumentation.on_connection_event(InstrumentationEvent::finish_establish_connection(
412            database_url,
413            establish_result.as_ref().err(),
414        ));
415
416        let mut conn = establish_result?;
417        conn.instrumentation = instrumentation.into();
418        Ok(conn)
419    }
420
421    fn execute_returning_count<T>(&mut self, source: &T) -> QueryResult<usize>
422    where
423        T: QueryFragment<Self::Backend> + QueryId,
424    {
425        let (sql, params) = build_query(source, &mut self.metadata_lookup)?;
426
427        self.instrumentation
428            .on_connection_event(InstrumentationEvent::start_query(&StrQueryHelper::new(
429                &sql,
430            )));
431
432        let result = self.execute_sql(&sql, params);
433
434        self.instrumentation
435            .on_connection_event(InstrumentationEvent::finish_query(
436                &StrQueryHelper::new(&sql),
437                result.as_ref().err(),
438            ));
439
440        result
441    }
442
443    fn transaction_state(&mut self) -> &mut AnsiTransactionManager
444    where
445        Self: Sized,
446    {
447        &mut self.transaction_state
448    }
449
450    fn instrumentation(&mut self) -> &mut dyn Instrumentation {
451        &mut *self.instrumentation
452    }
453
454    fn set_instrumentation(&mut self, instrumentation: impl Instrumentation) {
455        self.instrumentation = instrumentation.into();
456    }
457
458    fn set_prepared_statement_cache_size(&mut self, _size: CacheSize) {
459        // No-op: we don't use a prepared statement cache currently
460    }
461}
462
463/// Iterator over rows returned from a query.
464pub struct LibSqlCursor {
465    rows: std::vec::IntoIter<LibSqlRow>,
466}
467
468impl Iterator for LibSqlCursor {
469    type Item = QueryResult<LibSqlRow>;
470
471    fn next(&mut self) -> Option<Self::Item> {
472        self.rows.next().map(Ok)
473    }
474}
475
476impl LoadConnection<DefaultLoadingMode> for LibSqlConnection {
477    type Cursor<'conn, 'query> = LibSqlCursor;
478    type Row<'conn, 'query> = LibSqlRow;
479
480    fn load<'conn, 'query, T>(
481        &'conn mut self,
482        source: T,
483    ) -> QueryResult<Self::Cursor<'conn, 'query>>
484    where
485        T: Query + QueryFragment<Self::Backend> + QueryId + 'query,
486        Self::Backend: QueryMetadata<T::SqlType>,
487    {
488        let (sql, params) = build_query(&source, &mut self.metadata_lookup)?;
489
490        self.instrumentation
491            .on_connection_event(InstrumentationEvent::start_query(&StrQueryHelper::new(
492                &sql,
493            )));
494
495        let result = self.run_query(&sql, params);
496
497        self.instrumentation
498            .on_connection_event(InstrumentationEvent::finish_query(
499                &StrQueryHelper::new(&sql),
500                result.as_ref().err(),
501            ));
502
503        let rows = result?;
504        Ok(LibSqlCursor {
505            rows: rows.into_iter(),
506        })
507    }
508}
509
510impl diesel::migration::MigrationConnection for LibSqlConnection {
511    fn setup(&mut self) -> QueryResult<usize> {
512        use diesel::RunQueryDsl;
513        diesel::sql_query(diesel::migration::CREATE_MIGRATIONS_TABLE).execute(self)
514    }
515}
516
517impl WithMetadataLookup for LibSqlConnection {
518    fn metadata_lookup(&mut self) -> &mut <LibSql as TypeMetadata>::MetadataLookup {
519        &mut self.metadata_lookup
520    }
521}
522
523impl MultiConnectionHelper for LibSqlConnection {
524    fn to_any<'a>(
525        lookup: &mut <Self::Backend as TypeMetadata>::MetadataLookup,
526    ) -> &mut (dyn std::any::Any + 'a) {
527        lookup
528    }
529
530    fn from_any(
531        lookup: &mut dyn std::any::Any,
532    ) -> Option<&mut <Self::Backend as TypeMetadata>::MetadataLookup> {
533        lookup.downcast_mut()
534    }
535}
536
537/// Parse a remote URL into (url, auth_token).
538///
539/// The auth token is extracted from a `?authToken=TOKEN` query parameter if present,
540/// otherwise from the `LIBSQL_AUTH_TOKEN` environment variable.
541pub(crate) fn parse_remote_url(database_url: &str) -> ConnectionResult<(String, String)> {
542    // Check for ?authToken= query parameter
543    if let Some(idx) = database_url.find("?authToken=") {
544        let url = database_url[..idx].to_string();
545        let token_start = idx + "?authToken=".len();
546        // Token ends at next & or end of string
547        let token = if let Some(amp) = database_url[token_start..].find('&') {
548            &database_url[token_start..token_start + amp]
549        } else {
550            &database_url[token_start..]
551        };
552        if token.is_empty() {
553            return Err(ConnectionError::BadConnection(
554                "authToken query parameter is empty".to_string(),
555            ));
556        }
557        return Ok((url, token.to_string()));
558    }
559
560    // Also check for &authToken= in case it's not the first param
561    if let Some(idx) = database_url.find("&authToken=") {
562        let url = database_url[..database_url.find('?').unwrap_or(idx)].to_string();
563        let token_start = idx + "&authToken=".len();
564        let token = if let Some(amp) = database_url[token_start..].find('&') {
565            &database_url[token_start..token_start + amp]
566        } else {
567            &database_url[token_start..]
568        };
569        if token.is_empty() {
570            return Err(ConnectionError::BadConnection(
571                "authToken query parameter is empty".to_string(),
572            ));
573        }
574        return Ok((url, token.to_string()));
575    }
576
577    // Fall back to env var
578    match std::env::var("LIBSQL_AUTH_TOKEN") {
579        Ok(token) if !token.is_empty() => Ok((database_url.to_string(), token)),
580        _ => Err(ConnectionError::BadConnection(
581            "No auth token provided: use ?authToken=TOKEN in the URL or set LIBSQL_AUTH_TOKEN"
582                .to_string(),
583        )),
584    }
585}
586
/// Builder for embedded replica connections with advanced configuration.
///
/// Created via [`LibSqlConnection::replica_builder`]. Allows setting
/// `sync_interval` and `read_your_writes` before establishing the connection.
///
/// The `Debug` output redacts the auth token, so a builder can be logged
/// without leaking credentials.
pub struct ReplicaBuilder {
    /// Path of the local SQLite file that mirrors the remote database.
    local_path: String,
    /// URL of the remote primary.
    remote_url: String,
    /// Auth token for the remote primary; never printed by `Debug`.
    auth_token: String,
    /// Optional period for automatic pulls from the remote primary.
    sync_interval: Option<std::time::Duration>,
    /// Whether local writes are visible without waiting for a sync.
    read_your_writes: bool,
}

impl std::fmt::Debug for ReplicaBuilder {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ReplicaBuilder")
            .field("local_path", &self.local_path)
            .field("remote_url", &self.remote_url)
            // Deliberately hide the secret; a derived Debug would print it.
            .field("auth_token", &"<redacted>")
            .field("sync_interval", &self.sync_interval)
            .field("read_your_writes", &self.read_your_writes)
            .finish()
    }
}
598
599impl ReplicaBuilder {
600    /// Create a new replica builder.
601    pub fn new(
602        local_path: impl Into<String>,
603        remote_url: impl Into<String>,
604        auth_token: impl Into<String>,
605    ) -> Self {
606        Self {
607            local_path: local_path.into(),
608            remote_url: remote_url.into(),
609            auth_token: auth_token.into(),
610            sync_interval: None,
611            read_your_writes: true,
612        }
613    }
614
615    /// Set automatic sync interval. The replica will periodically pull
616    /// from the remote primary at this interval.
617    pub fn sync_interval(mut self, interval: std::time::Duration) -> Self {
618        self.sync_interval = Some(interval);
619        self
620    }
621
622    /// Enable or disable read-your-writes consistency (default: true).
623    ///
624    /// When enabled, after a successful write the local replica immediately
625    /// reflects the change without waiting for `sync()`.
626    pub fn read_your_writes(mut self, enabled: bool) -> Self {
627        self.read_your_writes = enabled;
628        self
629    }
630
631    /// Build and establish the replica connection.
632    pub fn establish(self) -> ConnectionResult<LibSqlConnection> {
633        let runtime = TokioRuntime::new();
634        let mut builder =
635            libsql::Builder::new_remote_replica(self.local_path, self.remote_url, self.auth_token)
636                .read_your_writes(self.read_your_writes);
637
638        if let Some(interval) = self.sync_interval {
639            builder = builder.sync_interval(interval);
640        }
641
642        let database = runtime
643            .block_on(builder.build())
644            .map_err(|e| ConnectionError::BadConnection(e.to_string()))?;
645
646        let connection = database
647            .connect()
648            .map_err(|e| ConnectionError::BadConnection(e.to_string()))?;
649
650        Ok(LibSqlConnection {
651            database,
652            connection,
653            runtime,
654            transaction_state: AnsiTransactionManager::default(),
655            metadata_lookup: (),
656            instrumentation: DynInstrumentation::none(),
657            is_replica: true,
658        })
659    }
660
661    /// Build and establish the replica connection asynchronously.
662    #[cfg(feature = "async")]
663    pub async fn establish_async(
664        self,
665    ) -> ConnectionResult<crate::async_conn::AsyncLibSqlConnection> {
666        let mut builder =
667            libsql::Builder::new_remote_replica(self.local_path, self.remote_url, self.auth_token)
668                .read_your_writes(self.read_your_writes);
669
670        if let Some(interval) = self.sync_interval {
671            builder = builder.sync_interval(interval);
672        }
673
674        let database = builder
675            .build()
676            .await
677            .map_err(|e| ConnectionError::BadConnection(e.to_string()))?;
678
679        let connection = database
680            .connect()
681            .map_err(|e| ConnectionError::BadConnection(e.to_string()))?;
682
683        Ok(crate::async_conn::AsyncLibSqlConnection::from_parts(
684            database, connection,
685        ))
686    }
687}
688
689/// Convert a `libsql::Value` to our owned `LibSqlValue`.
690pub(crate) fn libsql_value_to_owned(value: libsql::Value) -> LibSqlValue {
691    match value {
692        libsql::Value::Null => LibSqlValue::Null,
693        libsql::Value::Integer(i) => LibSqlValue::Integer(i),
694        libsql::Value::Real(f) => LibSqlValue::Real(f),
695        libsql::Value::Text(s) => LibSqlValue::Text(s),
696        libsql::Value::Blob(b) => LibSqlValue::Blob(b),
697    }
698}