use std::ffi::c_void;
use std::io;
use std::pin::Pin;
use std::sync::OnceLock;
use std::task::{Context, Poll};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
use crate::error::Error;
/// Access mode requested when opening a FILESTREAM handle.
///
/// Each variant is translated to the driver's numeric access code by
/// `FileStreamAccess::to_raw`. The enum is `#[non_exhaustive]`, so downstream
/// `match` expressions need a wildcard arm.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[non_exhaustive]
pub enum FileStreamAccess {
/// Read-only access; maps to `SQL_FILESTREAM_READ`.
Read,
/// Write-only access; maps to `SQL_FILESTREAM_WRITE`.
Write,
/// Combined read and write access; maps to `SQL_FILESTREAM_READWRITE`.
ReadWrite,
}
impl FileStreamAccess {
fn to_raw(self) -> u32 {
match self {
Self::Read => SQL_FILESTREAM_READ,
Self::Write => SQL_FILESTREAM_WRITE,
Self::ReadWrite => SQL_FILESTREAM_READWRITE,
}
}
}
/// Bit flags accepted by the `options` argument of
/// `FileStream::open_with_options`; they are forwarded unchanged as the
/// `open_options` argument of `OpenSqlFilestream`. Combine flags with `|`.
pub mod open_options {
    /// No flags set.
    pub const NONE: u32 = 0;
    /// Bit 0. Presumably the driver's `SQL_FILESTREAM_OPEN_FLAG_ASYNC`
    /// (open for asynchronous I/O) — verify against the driver headers.
    pub const ASYNC: u32 = 1 << 0;
    /// Bit 3. Presumably the driver's `SQL_FILESTREAM_OPEN_FLAG_SEQUENTIAL_SCAN`
    /// (sequential-access hint) — verify against the driver headers.
    pub const SEQUENTIAL_SCAN: u32 = 1 << 3;
}
/// Asynchronous stream over a handle obtained from `OpenSqlFilestream`.
///
/// Implements tokio's `AsyncRead` and `AsyncWrite` by delegating to the
/// wrapped `tokio::fs::File`.
pub struct FileStream {
/// Tokio file that owns the handle returned by `OpenSqlFilestream`.
inner: tokio::fs::File,
}
impl FileStream {
pub fn open(path: &str, access: FileStreamAccess, txn_context: &[u8]) -> Result<Self, Error> {
Self::open_with_options(path, access, txn_context, open_options::NONE)
}
pub fn open_with_options(
path: &str,
access: FileStreamAccess,
txn_context: &[u8],
options: u32,
) -> Result<Self, Error> {
let open_fn = load_open_sql_filestream()?;
let path_wide: Vec<u16> = path.encode_utf16().chain(std::iter::once(0)).collect();
let handle = unsafe {
open_fn(
path_wide.as_ptr(),
access.to_raw(),
options,
txn_context.as_ptr(),
txn_context.len(),
std::ptr::null(), )
};
if handle == INVALID_HANDLE_VALUE || handle.is_null() {
let err_code = unsafe { GetLastError() };
let message = format_win32_error(err_code);
return Err(Error::FileStream(format!(
"OpenSqlFilestream failed: {message} (Win32 error {err_code})"
)));
}
let file = unsafe {
use std::os::windows::io::{FromRawHandle, OwnedHandle, RawHandle};
let owned = OwnedHandle::from_raw_handle(handle as RawHandle);
let std_file = std::fs::File::from(owned);
tokio::fs::File::from_std(std_file)
};
Ok(Self { inner: file })
}
#[must_use]
pub fn as_tokio_file(&self) -> &tokio::fs::File {
&self.inner
}
pub fn as_tokio_file_mut(&mut self) -> &mut tokio::fs::File {
&mut self.inner
}
#[must_use]
pub fn into_tokio_file(self) -> tokio::fs::File {
self.inner
}
}
/// Reads are forwarded unchanged to the wrapped tokio file.
impl AsyncRead for FileStream {
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> Poll<io::Result<()>> {
        // The wrapper adds no state of its own, so unpin it and re-pin the
        // inner file for the delegated call.
        let file = Pin::new(&mut self.get_mut().inner);
        file.poll_read(cx, buf)
    }
}
/// Writes, flushes and shutdown are forwarded unchanged to the wrapped tokio file.
impl AsyncWrite for FileStream {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<io::Result<usize>> {
        let file = Pin::new(&mut self.get_mut().inner);
        file.poll_write(cx, buf)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let file = Pin::new(&mut self.get_mut().inner);
        file.poll_flush(cx)
    }

    fn poll_shutdown(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
        let file = Pin::new(&mut self.get_mut().inner);
        file.poll_shutdown(cx)
    }
}
/// Formats as `FileStream { inner: .. }`, delegating to the tokio file's `Debug`.
impl std::fmt::Debug for FileStream {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Destructure so a newly added field cannot be forgotten here.
        let Self { inner } = self;
        let mut out = f.debug_struct("FileStream");
        out.field("inner", inner);
        out.finish()
    }
}
/// Raw `desired_access` code for read-only access.
const SQL_FILESTREAM_READ: u32 = 0;
/// Raw `desired_access` code for write-only access.
const SQL_FILESTREAM_WRITE: u32 = 1;
/// Raw `desired_access` code for read/write access.
const SQL_FILESTREAM_READWRITE: u32 = 2;
/// Failure sentinel compared against the handle returned by `OpenSqlFilestream`:
/// the all-ones pointer value (`-1`).
const INVALID_HANDLE_VALUE: *mut c_void = -1_isize as *mut c_void;
/// Signature used to call the driver's `OpenSqlFilestream` export.
///
/// Arguments, in order: NUL-terminated UTF-16 path, access code, open-option
/// flags, transaction-context pointer, transaction-context length, and an
/// optional allocation size (this module always passes null). Returns the
/// opened handle; the caller treats null and `INVALID_HANDLE_VALUE` as failure.
type OpenSqlFilestreamFn = unsafe extern "system" fn(
filestream_path: *const u16, desired_access: u32, open_options: u32, filestream_txn_context: *const u8, filestream_txn_context_len: usize, allocation_size: *const i64, ) -> *mut c_void;
// Hand-written declarations of the Win32 functions this module calls.
unsafe extern "system" {
// Loads the DLL named by a NUL-terminated UTF-16 string; the caller treats null as failure.
fn LoadLibraryW(name: *const u16) -> *mut c_void;
// Looks up an export by NUL-terminated name; the caller treats null as "not found".
fn GetProcAddress(module: *mut c_void, name: *const u8) -> *mut c_void;
// Returns the calling thread's last Win32 error code.
fn GetLastError() -> u32;
// Writes the message text for `message_id` into `buffer`; the caller treats a
// zero return as failure and otherwise uses it as the number of UTF-16 units written.
fn FormatMessageW(
flags: u32,
source: *const c_void,
message_id: u32,
language_id: u32,
buffer: *mut u16,
size: u32,
arguments: *const c_void,
) -> u32;
}
/// Turns a Win32 error code into the system's message text.
///
/// Trailing whitespace (system messages normally end in a line break) is
/// stripped. When the system has no text for the code, a placeholder of the
/// form `Unknown error (0x0000002A)` is returned instead.
fn format_win32_error(error_code: u32) -> String {
    const FORMAT_MESSAGE_FROM_SYSTEM: u32 = 0x0000_1000;
    const FORMAT_MESSAGE_IGNORE_INSERTS: u32 = 0x0000_0200;
    const FLAGS: u32 = FORMAT_MESSAGE_FROM_SYSTEM | FORMAT_MESSAGE_IGNORE_INSERTS;
    const CAPACITY: usize = 512;

    let mut wide = [0u16; CAPACITY];
    // SAFETY: `wide` is a live, writable buffer and the size passed is its
    // exact length in UTF-16 units; the null source/arguments pointers are
    // not dereferenced with the flags used here.
    let written = unsafe {
        FormatMessageW(
            FLAGS,
            std::ptr::null(),
            error_code,
            0, // language id 0: let the system choose
            wide.as_mut_ptr(),
            wide.len() as u32,
            std::ptr::null(),
        )
    };

    match written {
        0 => format!("Unknown error (0x{error_code:08X})"),
        units => {
            let text = String::from_utf16_lossy(&wide[..units as usize]);
            text.trim_end().to_string()
        }
    }
}
/// Process-wide cache of the `OpenSqlFilestream` lookup, filled on first use by
/// `load_open_sql_filestream`. Both success (the function pointer) and failure
/// (the error text) are cached, so the DLL search runs at most once.
static OPEN_SQL_FILESTREAM: OnceLock<Result<OpenSqlFilestreamFn, String>> = OnceLock::new();
/// Driver DLL names probed for the `OpenSqlFilestream` export, in priority order.
const DLL_SEARCH_ORDER: &[&str] = &["msoledbsql19.dll", "msoledbsql.dll", "sqlncli11.dll"];
fn load_open_sql_filestream() -> Result<OpenSqlFilestreamFn, Error> {
OPEN_SQL_FILESTREAM
.get_or_init(|| {
for dll_name in DLL_SEARCH_ORDER {
let dll_wide: Vec<u16> =
dll_name.encode_utf16().chain(std::iter::once(0)).collect();
let module = unsafe { LoadLibraryW(dll_wide.as_ptr()) };
if module.is_null() {
continue;
}
let proc = unsafe { GetProcAddress(module, c"OpenSqlFilestream".as_ptr().cast()) };
if proc.is_null() {
continue;
}
tracing::debug!(dll = dll_name, "Loaded OpenSqlFilestream");
let func: OpenSqlFilestreamFn = unsafe { std::mem::transmute(proc) };
return Ok(func);
}
Err(format!(
"FILESTREAM driver not found. Install the Microsoft OLE DB Driver for SQL Server. \
Searched: {}",
DLL_SEARCH_ORDER.join(", ")
))
})
.clone()
.map_err(Error::FileStream)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Each access mode must map to its expected raw driver code.
    #[test]
    fn test_filestream_access_raw_values() {
        let expected = [
            (FileStreamAccess::Read, 0),
            (FileStreamAccess::Write, 1),
            (FileStreamAccess::ReadWrite, 2),
        ];
        for (access, raw) in expected {
            assert_eq!(access.to_raw(), raw);
        }
    }

    /// The first three probed DLLs keep their priority order.
    #[test]
    fn test_dll_search_order() {
        assert_eq!(
            DLL_SEARCH_ORDER[..3],
            ["msoledbsql19.dll", "msoledbsql.dll", "sqlncli11.dll"]
        );
    }

    /// Passes whether or not a driver is installed: with a driver the cached
    /// second lookup must also succeed, without one the error must say so.
    #[test]
    fn test_load_open_sql_filestream() {
        if let Err(e) = load_open_sql_filestream() {
            let msg = e.to_string();
            assert!(
                msg.contains("FILESTREAM driver not found"),
                "Error should indicate missing driver: {msg}"
            );
            return;
        }

        let result2 = load_open_sql_filestream();
        assert!(result2.is_ok(), "Second call should also succeed (cached)");
    }
}