use super::requests::{IoRequest, RequestState};
use super::{DEFAULT_URING_BLOCK_SIZE, DEFAULT_URING_IO_PARALLELISM, URING_BLOCK_SIZE};
use crate::local::to_local_path;
use crate::traits::Reader;
use crate::uring::DEFAULT_URING_QUEUE_DEPTH;
use crate::utils::tracking_store::IOTracker;
use bytes::{Bytes, BytesMut};
use deepsize::DeepSizeOf;
use futures::future::BoxFuture;
use futures::{FutureExt, TryFutureExt};
use io_uring::{IoUring, opcode, types};
use lance_core::{Error, Result};
use object_store::path::Path;
use std::cell::{LazyCell, RefCell};
use std::collections::HashMap;
use std::fs::File;
use std::future::Future;
use std::io::{self, ErrorKind};
use std::ops::Range;
use std::pin::Pin;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use tracing::instrument;
use super::reader::{CacheKey, CachedReaderData, HANDLE_CACHE, UringFileHandle};
/// Process-wide, monotonically increasing source of `user_data` tokens for
/// SQEs (incremented with a relaxed `fetch_add`). Starts at 1, so every
/// in-flight request gets a non-zero, unique token.
static USER_DATA_COUNTER: AtomicU64 = AtomicU64::new(1);
/// Per-thread io_uring state: the ring itself plus the requests whose SQEs
/// have been pushed but whose completions have not yet been reaped.
struct ThreadLocalUring {
    // The ring this thread submits read operations on.
    ring: IoUring,
    // In-flight requests keyed by the SQE `user_data` token; entries are
    // removed when their completion is processed.
    pending: HashMap<u64, Arc<IoRequest>>,
}
thread_local! {
    /// Lazily created per-thread ring plus its in-flight request table.
    static URING: LazyCell<RefCell<ThreadLocalUring>> = LazyCell::new(|| {
        // The queue depth may be overridden through the environment; a
        // missing or unparsable value falls back to the compiled default.
        let queue_depth = match std::env::var("LANCE_URING_QUEUE_DEPTH") {
            Ok(raw) => raw.parse().unwrap_or(DEFAULT_URING_QUEUE_DEPTH),
            Err(_) => DEFAULT_URING_QUEUE_DEPTH,
        };
        let ring = IoUring::builder()
            .setup_defer_taskrun()
            .setup_single_issuer()
            .build(queue_depth as u32)
            .expect("Failed to create io_uring");
        log::debug!(
            "Created thread-local io_uring with queue depth {}",
            queue_depth
        );
        let initial = ThreadLocalUring {
            ring,
            pending: HashMap::new(),
        };
        RefCell::new(initial)
    });
}
/// Queue a read for `request` on the calling thread's io_uring.
///
/// Builds a `Read` SQE covering the not-yet-read portion of the request
/// (`bytes_read..length`), tags it with a fresh `user_data` token, pushes it
/// onto the submission queue, and records the request in `pending` so the
/// completion can be matched later.
///
/// Note: this only enqueues the SQE; nothing reaches the kernel until
/// `submit_and_wait_thread_local` is called.
///
/// # Errors
/// Returns `WouldBlock` if the submission queue is full, or a generic error
/// if the push fails. In both cases the request is NOT tracked in `pending`;
/// the caller owns the failure.
pub(super) fn push_request(request: Arc<IoRequest>) -> io::Result<()> {
    URING.with(|cell| {
        let mut uring = cell.borrow_mut();
        let user_data = USER_DATA_COUNTER.fetch_add(1, Ordering::Relaxed);
        // Snapshot the resume point under the state lock: where in the buffer
        // the kernel should write, and which file range is still outstanding.
        let (buffer_ptr, read_offset, read_length) = {
            let state = request.state.lock().unwrap();
            let br = state.bytes_read;
            (
                // SAFETY: `bytes_read` starts at 0 and (per the completion
                // path) only grows by amounts the kernel reported, so `br`
                // stays within the buffer's length and the offset pointer is
                // in-bounds. The buffer itself is kept alive by the Arc
                // stored in `pending` until the completion is reaped.
                unsafe { state.buffer.as_ptr().add(br) as *mut u8 },
                request.offset + br as u64,
                (request.length - br) as u32,
            )
        };
        let read_op =
            opcode::Read::new(types::Fd(request.fd), buffer_ptr, read_length).offset(read_offset);
        let mut sq = uring.ring.submission();
        if sq.is_full() {
            // Drop the SQ borrow before returning so `uring` is released.
            drop(sq);
            return Err(io::Error::new(
                io::ErrorKind::WouldBlock,
                "io_uring submission queue full",
            ));
        }
        // SAFETY: the SQE's buffer pointer remains valid for the duration of
        // the operation (see the pointer snapshot above).
        unsafe {
            sq.push(&read_op.build().user_data(user_data))
                .map_err(|_| io::Error::other("Failed to push to SQ"))?;
        }
        drop(sq);
        // Track the request only after a successful push, keyed by the same
        // token the CQE will carry.
        uring.pending.insert(user_data, request);
        Ok(())
    })
}
/// Drain this thread's completion queue, finishing or retrying each request.
///
/// For every reaped CQE:
/// - a negative result records the corresponding OS error and completes the
///   request;
/// - a zero result is an unexpected EOF: the buffer is truncated to the
///   bytes read so far and an `UnexpectedEof` error is recorded;
/// - a positive result advances `bytes_read`; a fully satisfied request
///   completes, otherwise a new SQE for the remaining range is re-queued
///   (short-read handling).
///
/// Completed requests have their waker fired so the owning future gets
/// polled again. Returns the number of requests completed (successfully or
/// with an error); retried short reads do not count.
pub(super) fn process_thread_local_completions() -> io::Result<usize> {
    URING.with(|cell| {
        let mut uring = cell.borrow_mut();
        let mut completed = 0;
        let mut retries: Vec<Arc<IoRequest>> = Vec::new();
        // Collect the CQEs first: iterating `completion()` borrows the ring,
        // and `pending` / the submission queue are needed mutably below.
        let cqes: Vec<_> = uring
            .ring
            .completion()
            .map(|cqe| (cqe.user_data(), cqe.result()))
            .collect();
        for (user_data, result) in cqes {
            if let Some(request) = uring.pending.remove(&user_data) {
                let mut state = request.state.lock().unwrap();
                if result < 0 {
                    // Kernel error: `result` is a negated errno.
                    state.err = Some(io::Error::from_raw_os_error(-result));
                    state.completed = true;
                } else if result == 0 {
                    // EOF before the requested range was fully read.
                    let br = state.bytes_read;
                    state.err = Some(io::Error::new(
                        io::ErrorKind::UnexpectedEof,
                        format!("unexpected EOF: read {} of {} bytes", br, request.length),
                    ));
                    state.buffer.truncate(br);
                    state.completed = true;
                } else {
                    let n = result as usize;
                    state.bytes_read += n;
                    let br = state.bytes_read;
                    if br >= request.length {
                        state.buffer.truncate(br);
                        state.completed = true;
                    } else {
                        // Short read: release the lock and queue a follow-up
                        // SQE for the remaining bytes below.
                        drop(state);
                        retries.push(request);
                        continue;
                    }
                }
                // Wake the waiting future, dropping the lock first so the
                // woken task can acquire it immediately.
                if let Some(waker) = state.waker.take() {
                    drop(state);
                    waker.wake();
                }
                completed += 1;
            } else {
                log::warn!("Received completion for unknown user_data: {}", user_data);
            }
        }
        // Re-submit short reads. This mirrors `push_request` but is inlined
        // because the thread-local RefCell is already mutably borrowed here.
        for request in retries {
            let user_data = USER_DATA_COUNTER.fetch_add(1, Ordering::Relaxed);
            let (buffer_ptr, read_offset, read_length) = {
                let state = request.state.lock().unwrap();
                let br = state.bytes_read;
                (
                    // SAFETY: `br` stays within the buffer's length assuming
                    // the kernel never reports more bytes than were requested,
                    // so the offset pointer is in-bounds; the buffer is kept
                    // alive by the Arc re-inserted into `pending`.
                    unsafe { state.buffer.as_ptr().add(br) as *mut u8 },
                    request.offset + br as u64,
                    (request.length - br) as u32,
                )
            };
            let read_op = opcode::Read::new(types::Fd(request.fd), buffer_ptr, read_length)
                .offset(read_offset);
            let mut sq = uring.ring.submission();
            if sq.is_full() {
                drop(sq);
                // Cannot retry now; surface the condition to the future
                // rather than leaving the request stuck.
                request.fail(io::Error::new(
                    io::ErrorKind::WouldBlock,
                    "io_uring submission queue full during retry",
                ));
                continue;
            }
            // SAFETY: same buffer-validity argument as the snapshot above.
            unsafe {
                if sq.push(&read_op.build().user_data(user_data)).is_err() {
                    request.fail(io::Error::other("Failed to push short-read retry to SQ"));
                    continue;
                }
            }
            drop(sq);
            uring.pending.insert(user_data, request);
        }
        if completed > 0 {
            log::trace!("Processed {} completions", completed);
        }
        Ok(completed)
    })
}
/// Submit all queued SQEs on this thread's ring to the kernel and block
/// until at least one completion is available.
pub(super) fn submit_and_wait_thread_local() -> io::Result<()> {
    URING.with(|cell| {
        let guard = cell.borrow_mut();
        // `submit_and_wait` returns the number of submitted entries, which
        // callers here do not need.
        guard.ring.submit_and_wait(1).map(|_| ())
    })
}
/// A `Reader` for local files that performs reads through the calling
/// thread's io_uring instance.
#[derive(Debug)]
pub struct UringCurrentThreadReader {
    // Shared, cached handle (fd + path) for the underlying local file.
    handle: Arc<UringFileHandle>,
    // I/O granularity reported to callers via `block_size()`.
    block_size: usize,
    // Total file size in bytes, determined at open time (metadata or
    // caller-supplied `known_size`).
    size: usize,
    // Records each read (operation name, path, byte count, range) for I/O
    // accounting.
    io_tracker: Arc<IOTracker>,
}
impl DeepSizeOf for UringCurrentThreadReader {
    fn deep_size_of_children(&self, context: &mut deepsize::Context) -> usize {
        // Only the path string's heap data is attributed to this reader;
        // everything else is either inline or shared through Arcs.
        let path_str = self.handle.path.as_ref();
        path_str.deep_size_of_children(context)
    }
}
impl UringCurrentThreadReader {
    /// Open (or fetch from the process-wide `HANDLE_CACHE`) a reader for the
    /// local file at `path`.
    ///
    /// The effective block size is `URING_BLOCK_SIZE` when set, otherwise the
    /// larger of the caller's `block_size` and `DEFAULT_URING_BLOCK_SIZE`.
    /// On a cache miss the file is opened and stat'ed on a blocking thread
    /// and the resulting handle is inserted into the cache. `known_size`,
    /// when provided, is used instead of file metadata (and overrides the
    /// cached size on a cache hit).
    ///
    /// # Errors
    /// Returns `Error::not_found` when the file does not exist; other I/O
    /// and join errors are propagated.
    #[instrument(level = "debug")]
    pub(crate) async fn open(
        path: &Path,
        block_size: usize,
        known_size: Option<usize>,
        io_tracker: Arc<IOTracker>,
    ) -> Result<Box<dyn Reader>> {
        let block_size = URING_BLOCK_SIZE.unwrap_or(block_size.max(DEFAULT_URING_BLOCK_SIZE));
        let cache_key = CacheKey::new(path, block_size);
        // Fast path: reuse an already-open handle for this (path, block_size).
        if let Some(data) = HANDLE_CACHE.get(&cache_key).await {
            let size = known_size.unwrap_or(data.size);
            return Ok(Box::new(Self {
                handle: data.handle,
                block_size,
                size,
                io_tracker,
            }) as Box<dyn Reader>);
        }
        let path_clone = path.clone();
        let local_path = to_local_path(path);
        // `File::open` and `metadata` are blocking syscalls; keep them off
        // the async executor.
        let data = tokio::task::spawn_blocking(move || {
            let file = File::open(&local_path).map_err(|e| match e.kind() {
                ErrorKind::NotFound => Error::not_found(path_clone.to_string()),
                _ => e.into(),
            })?;
            let size = match known_size {
                Some(s) => s,
                None => file.metadata()?.len() as usize,
            };
            Ok::<_, Error>(CachedReaderData {
                handle: Arc::new(UringFileHandle::new(file, path_clone)),
                size,
            })
        })
        .await??;
        HANDLE_CACHE.insert(cache_key, data.clone()).await;
        Ok(Box::new(Self {
            handle: data.handle.clone(),
            block_size,
            size: data.size,
            io_tracker,
        }) as Box<dyn Reader>)
    }

    /// Build an `IoRequest` for `length` bytes at `offset`, push it onto the
    /// calling thread's ring, and return a future that resolves with the
    /// bytes once the read completes.
    ///
    /// If the push fails (e.g. the submission queue is full), the returned
    /// future resolves immediately with a generic `object_store` error.
    fn submit_read(
        &self,
        offset: u64,
        length: usize,
    ) -> Pin<Box<dyn Future<Output = object_store::Result<Bytes>> + Send>> {
        let mut buffer = BytesMut::with_capacity(length);
        // SAFETY: the buffer's contents are uninitialized here; it is only
        // handed to the kernel to be filled, and the completion path
        // truncates it to the bytes actually read before the data is
        // exposed to the caller.
        unsafe {
            buffer.set_len(length);
        }
        let request = Arc::new(IoRequest {
            fd: self.handle.fd,
            offset,
            length,
            thread_id: std::thread::current().id(),
            state: Mutex::new(RequestState {
                completed: false,
                waker: None,
                err: None,
                buffer,
                bytes_read: 0,
            }),
        });
        match push_request(request.clone()) {
            Ok(()) => Box::pin(super::current_thread_future::UringCurrentThreadFuture::new(
                request,
            )),
            Err(e) => Box::pin(async move {
                Err(object_store::Error::Generic {
                    store: "io_uring_ct",
                    source: Box::new(e),
                })
            }),
        }
    }
}
impl Reader for UringCurrentThreadReader {
    fn path(&self) -> &Path {
        &self.handle.path
    }

    fn block_size(&self) -> usize {
        self.block_size
    }

    /// Parallelism hint, overridable via `LANCE_URING_IO_PARALLELISM`;
    /// unset or unparsable values fall back to the default.
    fn io_parallelism(&self) -> usize {
        match std::env::var("LANCE_URING_IO_PARALLELISM") {
            Ok(raw) => raw.parse().unwrap_or(DEFAULT_URING_IO_PARALLELISM),
            Err(_) => DEFAULT_URING_IO_PARALLELISM,
        }
    }

    fn size(&self) -> BoxFuture<'_, object_store::Result<usize>> {
        let size = self.size;
        Box::pin(async move { Ok(size) })
    }

    #[instrument(level = "debug", skip(self))]
    fn get_range(&self, range: Range<usize>) -> BoxFuture<'static, object_store::Result<Bytes>> {
        let tracker = self.io_tracker.clone();
        let path = self.handle.path.clone();
        let len = range.len();
        let tracked_range = (range.start as u64)..(range.end as u64);
        let read = self.submit_read(range.start as u64, len);
        async move {
            let bytes = read.await?;
            // Only successful reads are recorded.
            tracker.record_read("get_range", path, len as u64, Some(tracked_range));
            Ok(bytes)
        }
        .boxed()
    }

    #[instrument(level = "debug", skip(self))]
    fn get_all(&self) -> BoxFuture<'static, object_store::Result<Bytes>> {
        let tracker = self.io_tracker.clone();
        let path = self.handle.path.clone();
        let read = self.submit_read(0, self.size);
        async move {
            let bytes = read.await?;
            // Record the byte count actually returned.
            tracker.record_read("get_all", path, bytes.len() as u64, None);
            Ok(bytes)
        }
        .boxed()
    }
}