use super::DEFAULT_URING_QUEUE_DEPTH;
use super::requests::IoRequest;
use io_uring::{IoUring, opcode, types};
use std::collections::HashMap;
use std::io;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::mpsc::{Receiver, RecvTimeoutError, SyncSender, sync_channel};
use std::sync::{Arc, LazyLock};
use std::time::{Duration, Instant};
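
/// Handle to a single io_uring worker thread: the sending half of its
/// bounded request channel.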
pub(super) struct UringThreadHandle {
pub request_tx: SyncSender<Arc<IoRequest>>,
}
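
/// Global pool of io_uring worker threads, spawned lazily on first use.
/// Each worker owns its own ring and receives requests over a bounded
/// channel sized to the ring's queue depth.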
pub(super) static URING_THREADS: LazyLock<Vec<UringThreadHandle>> = LazyLock::new(|| {
let queue_depth = get_queue_depth();
let thread_count = get_thread_count();
let mut threads = Vec::with_capacity(thread_count);
for i in 0..thread_count {
let (tx, rx) = sync_channel(queue_depth);
std::thread::Builder::new()
.name(format!("lance-uring-{}", i))
.spawn(move || run_uring_thread(rx, queue_depth, i))
.expect("Failed to spawn io_uring thread");
threads.push(UringThreadHandle { request_tx: tx });
}
log::info!(
"io_uring thread pool spawned ({} threads, queue_depth={})",
thread_count,
queue_depth
);
threads
});
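
/// Round-robin counter used to spread requests across the worker threads.
/// A hypothetical dispatch sketch (the real dispatch lives elsewhere in
/// this module):
///
/// ```ignore
/// let idx = THREAD_SELECTOR.fetch_add(1, Ordering::Relaxed) as usize % URING_THREADS.len();
/// URING_THREADS[idx].request_tx.send(request).expect("io_uring worker exited");
/// ```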
pub(super) static THREAD_SELECTOR: AtomicU64 = AtomicU64::new(0);
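
/// Monotonically increasing source of `user_data` values, used to match
/// each completion (CQE) back to its originating request.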
static USER_DATA_COUNTER: AtomicU64 = AtomicU64::new(1);
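
/// Count of requests queued on the worker channels but not yet picked up by
/// a worker thread (decremented here; presumably incremented by the
/// submitting side when a request is enqueued).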
pub(super) static SUBMITTED_COUNTER: AtomicU64 = AtomicU64::new(0);
const DEFAULT_SUBMIT_BATCH_SIZE: usize = 128;
const DEFAULT_URING_THREAD_COUNT: usize = 2;
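
// Each knob below has a compile-time default and can be overridden at
// runtime via the corresponding LANCE_URING_* environment variable.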
fn get_queue_depth() -> usize {
std::env::var("LANCE_URING_QUEUE_DEPTH")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(DEFAULT_URING_QUEUE_DEPTH)
}
fn get_poll_timeout() -> Duration {
let timeout_ms = std::env::var("LANCE_URING_POLL_TIMEOUT_MS")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(10);
Duration::from_millis(timeout_ms)
}
fn get_submit_batch_size() -> usize {
std::env::var("LANCE_URING_SUBMIT_BATCH_SIZE")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(DEFAULT_SUBMIT_BATCH_SIZE)
}
fn get_thread_count() -> usize {
std::env::var("LANCE_URING_THREAD_COUNT")
.ok()
.and_then(|s| s.parse().ok())
.unwrap_or(DEFAULT_URING_THREAD_COUNT)
}
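
/// Worker loop for one io_uring instance: reaps completions, resubmits short
/// reads, pulls new requests from the channel in batches of up to
/// `submit_batch_size`, and submits them to the kernel. Exits when the
/// request channel disconnects.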
fn run_uring_thread(request_rx: Receiver<Arc<IoRequest>>, queue_depth: usize, thread_id: usize) {
let mut ring = IoUring::builder()
.build(queue_depth as u32)
.expect("Failed to create io_uring");
let mut pending: HashMap<u64, Arc<IoRequest>> = HashMap::with_capacity(queue_depth);
let poll_timeout = get_poll_timeout();
let submit_batch_size = get_submit_batch_size();
let mut last_log = Instant::now();
let log_interval = Duration::from_millis(100);
let mut completed_iops = 0usize;
let mut completed_sectors = 0usize;
let mut min_in_flight = usize::MAX;
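    // Main loop: log stats periodically, reap completions, re-queue short
    // reads, then batch and submit new requests.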
loop {
let in_flight = pending.len();
min_in_flight = min_in_flight.min(in_flight);
let now = Instant::now();
if now.duration_since(last_log) >= log_interval {
let submitted = SUBMITTED_COUNTER.load(Ordering::Relaxed);
log::info!(
"io_uring[{}]: {} submitted, {} in flight (min {}), {} iops completed, {} sectors completed",
thread_id,
submitted,
in_flight,
min_in_flight,
completed_iops,
completed_sectors
);
last_log = now;
            completed_iops = 0;
            completed_sectors = 0;
            min_in_flight = usize::MAX;
        }
let mut needs_submit = false;
let completions = process_completions(&mut ring, &mut pending);
match completions {
Ok(result) => {
completed_iops += result.iops;
completed_sectors += result.sectors;
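                // Short reads come back as retries; push them straight back
                // onto the SQ so they resume at the current offset.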
for request in result.retries {
if let Err(e) = push_to_sq(&mut ring, &mut pending, request) {
log::error!("Failed to resubmit short read: {}", e);
} else {
needs_submit = true;
}
}
}
Err(e) => {
log::error!("Error processing io_uring completions: {}", e);
}
}
min_in_flight = min_in_flight.min(pending.len());
let mut batch_count = 0;
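        // Pull up to `submit_batch_size` requests. Block (with a timeout)
        // only when the ring is completely idle and nothing has been batched
        // yet; otherwise use try_recv so in-flight completions keep getting
        // serviced promptly.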
loop {
let recv_result = if pending.is_empty() && batch_count == 0 {
request_rx.recv_timeout(poll_timeout).map_err(|e| match e {
RecvTimeoutError::Timeout => std::sync::mpsc::TryRecvError::Empty,
RecvTimeoutError::Disconnected => std::sync::mpsc::TryRecvError::Disconnected,
})
} else {
request_rx.try_recv()
};
match recv_result {
Ok(request) => {
SUBMITTED_COUNTER.fetch_sub(1, Ordering::Relaxed);
if let Err(e) = push_to_sq(&mut ring, &mut pending, request) {
log::error!("Failed to push to io_uring SQ: {}", e);
} else {
batch_count += 1;
}
if batch_count >= submit_batch_size {
break;
}
}
Err(std::sync::mpsc::TryRecvError::Empty) => {
break;
}
Err(std::sync::mpsc::TryRecvError::Disconnected) => {
if batch_count > 0
&& let Err(e) = ring.submit()
{
log::error!(
"io_uring[{}]: Failed to submit io_uring batch: {}",
thread_id,
e
);
}
log::info!(
"io_uring thread {} shutting down (channel disconnected)",
thread_id
);
return;
}
}
}
if (batch_count > 0 || needs_submit)
&& let Err(e) = ring.submit()
{
log::error!(
"Failed to submit io_uring batch of {} requests: {}",
batch_count,
e
);
}
}
}
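
/// Builds a read SQE for `request` (starting at the current `bytes_read`, so
/// retries of short reads resume where they left off) and pushes it onto the
/// submission queue. On a full queue or failed push, the request is failed
/// and an error returned; the caller remains responsible for `ring.submit()`.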
pub(super) fn push_to_sq(
ring: &mut IoUring,
pending: &mut HashMap<u64, Arc<IoRequest>>,
request: Arc<IoRequest>,
) -> io::Result<()> {
let user_data = USER_DATA_COUNTER.fetch_add(1, Ordering::Relaxed);
    let (buffer_ptr, read_offset, read_length) = {
        let mut state = request.state.lock().unwrap();
        let br = state.bytes_read;
        (
            // The kernel writes into this buffer, so take the pointer
            // mutably; offset past any bytes already read on a retry.
            unsafe { state.buffer.as_mut_ptr().add(br) },
            request.offset + br as u64,
            (request.length - br) as u32,
        )
    };
let read_op =
opcode::Read::new(types::Fd(request.fd), buffer_ptr, read_length).offset(read_offset);
let mut sq = ring.submission();
    if sq.is_full() {
        drop(sq);
        const MSG: &str = "io_uring submission queue full";
        request.fail(io::Error::new(io::ErrorKind::WouldBlock, MSG));
        return Err(io::Error::new(io::ErrorKind::WouldBlock, MSG));
    }
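    // SAFETY: the SQE's buffer pointer stays valid because `pending` keeps
    // the request (and its buffer) alive until the matching CQE is reaped.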
unsafe {
if sq.push(&read_op.build().user_data(user_data)).is_err() {
drop(sq);
request.fail(io::Error::other("Failed to push to SQ"));
return Err(io::Error::other("Failed to push to SQ"));
}
}
drop(sq);
pending.insert(user_data, request);
Ok(())
}
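
/// Result of one pass over the completion queue: counts for the periodic
/// stats log, plus short reads that need to be resubmitted.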
struct CompletionResult {
iops: usize,
sectors: usize,
retries: Vec<Arc<IoRequest>>,
}
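
/// Drains the completion queue, finishing or failing each matching request
/// and waking its waiter; short reads are collected for retry instead.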
fn process_completions(
ring: &mut IoUring,
pending: &mut HashMap<u64, Arc<IoRequest>>,
) -> io::Result<CompletionResult> {
let mut iops = 0;
let mut sectors = 0;
let mut retries = Vec::new();
for cqe in ring.completion() {
let user_data = cqe.user_data();
let result = cqe.result();
if let Some(request) = pending.remove(&user_data) {
let mut state = request.state.lock().unwrap();
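            // A negative CQE result is a negated errno value.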
if result < 0 {
state.err = Some(io::Error::from_raw_os_error(-result));
state.completed = true;
} else if result == 0 {
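                // A zero-byte read before the request was satisfied is EOF.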
let br = state.bytes_read;
state.err = Some(io::Error::new(
io::ErrorKind::UnexpectedEof,
format!("unexpected EOF: read {} of {} bytes", br, request.length),
));
state.buffer.truncate(br);
state.completed = true;
} else {
let n = result as usize;
state.bytes_read += n;
let br = state.bytes_read;
if br >= request.length {
state.buffer.truncate(br);
state.completed = true;
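                    // Count the 4096-byte blocks this read touched, for the
                    // periodic throughput log ("sectors completed").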
if request.length > 0 {
let first_sector = request.offset / 4096;
let last_sector = (request.offset + request.length as u64 - 1) / 4096;
let num_sectors = (last_sector - first_sector + 1) as usize;
sectors += num_sectors;
}
} else {
drop(state);
retries.push(request);
continue;
}
}
            if let Some(waker) = state.waker.take() {
                // Release the state lock before waking so the awakened task
                // can immediately acquire it.
                drop(state);
                waker.wake();
            }
iops += 1;
} else {
log::warn!("Received completion for unknown user_data: {}", user_data);
}
}
Ok(CompletionResult {
iops,
sectors,
retries,
})
}