use super::ConnectionState;
use crate::{error::Error, SqliteConnection, SqliteError};
use libsqlite3_sys::{
sqlite3_deserialize, sqlite3_free, sqlite3_malloc64, sqlite3_serialize,
SQLITE_DESERIALIZE_FREEONCLOSE, SQLITE_DESERIALIZE_READONLY, SQLITE_DESERIALIZE_RESIZEABLE,
SQLITE_NOMEM, SQLITE_OK,
};
use std::ffi::c_char;
use std::fmt::Debug;
use std::{
ops::{Deref, DerefMut},
ptr,
ptr::NonNull,
};
impl SqliteConnection {
    /// Serializes a database schema of this connection into an owned byte buffer.
    ///
    /// `schema` names the schema to serialize. When it is `None`, no name is
    /// forwarded and the lower-level routine passes a null schema pointer to
    /// SQLite (reported as `"main"` in error messages).
    ///
    /// # Errors
    ///
    /// Returns an error if `schema` contains a zero byte, or if the connection
    /// worker reports a failure while serializing.
    pub async fn serialize(&mut self, schema: Option<&str>) -> Result<SqliteOwnedBuf, Error> {
        // Validate and NUL-terminate the schema name before handing it to the worker.
        let schema_name = match schema {
            Some(name) => Some(SchemaName::try_from(name)?),
            None => None,
        };

        self.worker.serialize(schema_name).await
    }

    /// Replaces the contents of a database schema with a previously serialized image.
    ///
    /// `data` is consumed: ownership of the buffer moves to the worker. When
    /// `read_only` is `true` the database is opened read-only; otherwise it is
    /// opened as resizeable.
    ///
    /// # Errors
    ///
    /// Returns an error if `schema` contains a zero byte, or if the connection
    /// worker reports a failure while deserializing.
    pub async fn deserialize(
        &mut self,
        schema: Option<&str>,
        data: SqliteOwnedBuf,
        read_only: bool,
    ) -> Result<(), Error> {
        // Validate and NUL-terminate the schema name before handing it to the worker.
        let schema_name = match schema {
            Some(name) => Some(SchemaName::try_from(name)?),
            None => None,
        };

        self.worker
            .deserialize(schema_name, data, read_only)
            .await
    }
}
pub(crate) fn serialize(
conn: &mut ConnectionState,
schema: Option<SchemaName>,
) -> Result<SqliteOwnedBuf, Error> {
let mut size = 0;
let buf = unsafe {
let ptr = sqlite3_serialize(
conn.handle.as_ptr(),
schema.as_ref().map_or(ptr::null(), SchemaName::as_ptr),
&mut size,
0,
);
usize::try_from(size)
.ok()
.and_then(|size| SqliteOwnedBuf::from_raw(ptr, size))
};
if let Some(buf) = buf {
return Ok(buf);
}
if let Some(error) = conn.handle.last_error() {
return Err(error.into());
}
if size > 0 {
return Err(SqliteError::from_code(SQLITE_NOMEM).into());
}
Err(SqliteError::generic(format!(
"database {} does not exist",
schema.as_ref().map_or("main", SchemaName::as_str)
))
.into())
}
pub(crate) fn deserialize(
conn: &mut ConnectionState,
schema: Option<SchemaName>,
data: SqliteOwnedBuf,
read_only: bool,
) -> Result<(), Error> {
let mut flags = SQLITE_DESERIALIZE_FREEONCLOSE;
if read_only {
flags |= SQLITE_DESERIALIZE_READONLY;
} else {
flags |= SQLITE_DESERIALIZE_RESIZEABLE;
}
let (buf, size) = data.into_raw();
let rc = unsafe {
sqlite3_deserialize(
conn.handle.as_ptr(),
schema.as_ref().map_or(ptr::null(), SchemaName::as_ptr),
buf,
i64::try_from(size).unwrap(),
i64::try_from(size).unwrap(),
flags,
)
};
match rc {
SQLITE_OK => Ok(()),
SQLITE_NOMEM => Err(SqliteError::from_code(SQLITE_NOMEM).into()),
_ => Err(SqliteError::generic("an error occurred during deserialization").into()),
}
}
/// An owned, non-null byte buffer whose allocation is released with
/// `sqlite3_free` when the value is dropped.
///
/// Instances are created either from a pointer returned by `sqlite3_serialize`
/// or from a fresh `sqlite3_malloc64` allocation, and can be viewed as a
/// `[u8]` slice through `Deref`/`DerefMut`.
#[derive(Debug)]
pub struct SqliteOwnedBuf {
// Start of the allocation; never null.
ptr: NonNull<u8>,
// Length of the buffer in bytes.
size: usize,
}
// NOTE(review): presumably sound because the buffer is uniquely owned heap memory
// with no thread affinity and no interior mutability — confirm.
unsafe impl Send for SqliteOwnedBuf {}
unsafe impl Sync for SqliteOwnedBuf {}
impl Drop for SqliteOwnedBuf {
    /// Hands the allocation back through `sqlite3_free`.
    fn drop(&mut self) {
        let raw = self.ptr.as_ptr();
        // SAFETY: `raw` is the non-null pointer this buffer owns.
        // NOTE(review): assumes it was allocated by SQLite's allocator — confirm for
        // every `from_raw` caller.
        unsafe { sqlite3_free(raw.cast()) }
    }
}
impl SqliteOwnedBuf {
    /// Allocates `size` bytes with `sqlite3_malloc64`.
    ///
    /// Returns `None` when the allocator returns a null pointer.
    ///
    /// # Safety
    ///
    /// The returned buffer's contents are uninitialized; the caller must fill
    /// all `size` bytes before the buffer is read (e.g. through `Deref`).
    ///
    /// # Panics
    ///
    /// Panics if `size` does not fit in a `u64`.
    unsafe fn with_capacity(size: usize) -> Option<SqliteOwnedBuf> {
        let requested = u64::try_from(size).unwrap();
        let raw = sqlite3_malloc64(requested).cast::<u8>();
        Self::from_raw(raw, size)
    }

    /// Wraps a raw pointer and length, returning `None` if `ptr` is null.
    ///
    /// # Safety
    ///
    /// `ptr` must be valid for `size` bytes and must be an allocation that may
    /// be released with `sqlite3_free`, because `Drop` does exactly that.
    unsafe fn from_raw(ptr: *mut u8, size: usize) -> Option<Self> {
        NonNull::new(ptr).map(|ptr| Self { ptr, size })
    }

    /// Gives up ownership and returns the raw pointer and length.
    ///
    /// The destructor is suppressed, so the caller becomes responsible for
    /// freeing the allocation.
    fn into_raw(self) -> (*mut u8, usize) {
        let this = std::mem::ManuallyDrop::new(self);
        (this.ptr.as_ptr(), this.size)
    }
}
impl TryFrom<&[u8]> for SqliteOwnedBuf {
type Error = Error;
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
unsafe {
let mut buf = Self::with_capacity(bytes.len()).ok_or_else(|| {
Error::InvalidArgument("SQLite owned buffer cannot be empty".to_string())
})?;
ptr::copy_nonoverlapping(bytes.as_ptr(), buf.ptr.as_mut(), buf.size);
Ok(buf)
}
}
}
impl Deref for SqliteOwnedBuf {
    type Target = [u8];

    /// Views the buffer as an immutable byte slice of `size` bytes.
    fn deref(&self) -> &[u8] {
        let (start, len) = (self.ptr.as_ptr(), self.size);
        // SAFETY: `start` is non-null.
        // NOTE(review): relies on the constructors guaranteeing `len` valid,
        // initialized bytes behind the pointer — confirm.
        unsafe { std::slice::from_raw_parts(start, len) }
    }
}
impl DerefMut for SqliteOwnedBuf {
    /// Views the buffer as a mutable byte slice of `size` bytes.
    fn deref_mut(&mut self) -> &mut [u8] {
        let (start, len) = (self.ptr.as_ptr(), self.size);
        // SAFETY: `start` is non-null and `&mut self` gives exclusive access.
        // NOTE(review): relies on the constructors guaranteeing `len` valid,
        // initialized bytes behind the pointer — confirm.
        unsafe { std::slice::from_raw_parts_mut(start, len) }
    }
}
impl AsRef<[u8]> for SqliteOwnedBuf {
fn as_ref(&self) -> &[u8] {
self.deref()
}
}
impl AsMut<[u8]> for SqliteOwnedBuf {
    /// Mutably borrows the buffer contents; same slice as the `DerefMut` impl.
    fn as_mut(&mut self) -> &mut [u8] {
        &mut **self
    }
}
/// A schema name stored with a trailing NUL byte so it can be passed to C.
///
/// The boxed string always ends in `'\0'`; `as_str` hides that terminator
/// while `as_ptr` exposes the full terminated string.
#[derive(Debug)]
pub(crate) struct SchemaName(Box<str>);

impl SchemaName {
    /// Returns the name without its trailing NUL terminator.
    pub fn as_str(&self) -> &str {
        let terminated: &str = &self.0;
        // The last byte is the terminator, so stop one byte short.
        let end = terminated.len() - 1;
        &terminated[..end]
    }

    /// Returns a pointer to the start of the NUL-terminated name.
    ///
    /// The pointer is only valid while `self` is alive.
    pub fn as_ptr(&self) -> *const c_char {
        self.0.as_ptr().cast::<c_char>()
    }
}
impl<'a> TryFrom<&'a str> for SchemaName {
type Error = Error;
fn try_from(name: &'a str) -> Result<Self, Self::Error> {
if let Some(pos) = name.as_bytes().iter().position(|&b| b == 0) {
return Err(Error::InvalidArgument(format!(
"schema name {name:?} contains a zero byte at index {pos}"
)));
}
let capacity = name.len().checked_add(1).unwrap();
let mut s = String::new();
s.reserve_exact(capacity);
s.push_str(name);
s.push('\0');
Ok(SchemaName(s.into()))
}
}