use std::collections::HashMap;
use std::sync::Mutex;
use crate::sys::*;
use crate::core::*;
use crate::ffi::*;
use super::error::IcdError;
/// Per-queue timeline semaphore bookkeeping, stored in `TimelineManager::timelines`.
pub struct TimelineState {
/// Semaphore handle that `submit_batch` signals and `wait_timeline` waits on.
semaphore: VkSemaphore,
/// Counter that `submit_batch` advances by one to derive each signal value.
current_value: u64,
/// Incremented by `add_to_batch` for every queued command buffer and reset
/// to zero by `submit_batch`.
pending_count: u32,
}
/// Command buffers and wait dependencies accumulated for one queue submission.
pub struct BatchSubmission {
/// Command buffers queued for submission, in insertion order.
command_buffers: Vec<VkCommandBuffer>,
/// Semaphores to wait on; `add_wait` keeps this index-aligned with
/// `wait_values` and `wait_stages`.
wait_semaphores: Vec<VkSemaphore>,
/// Value to wait for on the semaphore at the same index.
wait_values: Vec<u64>,
/// Stage mask for the wait at the same index (passed as `pWaitDstStageMask`).
wait_stages: Vec<VkPipelineStageFlags>,
// Initialised to 0 by `new` and never read by the code visible here
// (`submit_batch` derives its signal value from the queue's timeline
// instead), hence the lint allowance.
#[allow(dead_code)]
signal_value: u64,
}
impl BatchSubmission {
    /// Creates an empty batch.
    ///
    /// Room for 256 command buffers is reserved up front; the wait lists start
    /// without any allocation.
    pub fn new() -> Self {
        Self {
            command_buffers: Vec::with_capacity(256),
            wait_semaphores: Vec::new(),
            wait_values: Vec::new(),
            wait_stages: Vec::new(),
            signal_value: 0,
        }
    }

    /// Appends `cb` to the list of command buffers to submit.
    pub fn add_command_buffer(&mut self, cb: VkCommandBuffer) {
        self.command_buffers.push(cb);
    }

    /// Registers a wait on `semaphore` reaching `value` at pipeline `stage`.
    ///
    /// The three wait vectors are pushed together so that they always have the
    /// same length and stay index-aligned.
    pub fn add_wait(&mut self, semaphore: VkSemaphore, value: u64, stage: VkPipelineStageFlags) {
        self.wait_semaphores.push(semaphore);
        self.wait_values.push(value);
        self.wait_stages.push(stage);
    }
}

/// `Default` mirrors [`BatchSubmission::new`], so the type works with
/// `or_default()`, `mem::take`, and `..Default::default()`.
impl Default for BatchSubmission {
    fn default() -> Self {
        Self::new()
    }
}
/// Process-wide state for timeline semaphores and pending batches.
///
/// Both maps are keyed by the raw queue handle (`VkQueue::as_raw()`).
pub struct TimelineManager {
/// Timeline bookkeeping per queue, inserted by `get_queue_timeline`.
timelines: HashMap<u64, TimelineState>,
/// Batch currently being accumulated per queue, created by `begin_batch`
/// and removed by `submit_batch`.
batches: HashMap<u64, BatchSubmission>,
/// Command-buffer count at which `add_to_batch` reports that the batch
/// should be submitted.
batch_size: u32,
}
// Single global manager guarded by a mutex. It starts with no timelines, no
// batches, and a batch size of 16 (adjustable through `set_batch_size`).
lazy_static::lazy_static! {
static ref TIMELINE_MANAGER: Mutex<TimelineManager> = Mutex::new(TimelineManager {
timelines: HashMap::new(),
batches: HashMap::new(),
batch_size: 16, });
}
/// Creates a timeline semaphore on `device` whose counter starts at
/// `initial_value`.
///
/// # Errors
///
/// * [`IcdError::MissingFunction`] when no ICD is loaded or the ICD does not
///   expose `vkCreateSemaphore`.
/// * [`IcdError::VulkanError`] when the driver returns anything other than
///   `VkResult::Success`.
///
/// # Safety
///
/// Calls the ICD's raw `vkCreateSemaphore` entry point; `device` must be a
/// handle that entry point accepts.
pub unsafe fn create_timeline_semaphore(
    device: VkDevice,
    initial_value: u64,
) -> Result<VkSemaphore, IcdError> {
    // Resolve the entry point first; both lookup failures map to the same error.
    let icd = match super::icd_loader::get_icd() {
        Some(loaded) => loaded,
        None => return Err(IcdError::MissingFunction("vkCreateSemaphore")),
    };
    let create_fn = match icd.create_semaphore {
        Some(entry_point) => entry_point,
        None => return Err(IcdError::MissingFunction("vkCreateSemaphore")),
    };

    // The semaphore type is selected through a struct chained via `pNext`.
    let type_info = VkSemaphoreTypeCreateInfo {
        sType: VkStructureType::SemaphoreTypeCreateInfo,
        pNext: std::ptr::null(),
        semaphoreType: VkSemaphoreType::Timeline,
        initialValue: initial_value,
    };
    let semaphore_info = VkSemaphoreCreateInfo {
        sType: VkStructureType::SemaphoreCreateInfo,
        pNext: &type_info as *const _ as *const std::ffi::c_void,
        flags: 0,
    };

    let mut handle = VkSemaphore::NULL;
    let status = create_fn(device, &semaphore_info, std::ptr::null(), &mut handle);
    if status == VkResult::Success {
        Ok(handle)
    } else {
        Err(IcdError::VulkanError(status))
    }
}
/// Returns the timeline semaphore and current counter value for `queue`,
/// creating and registering a new timeline (starting at 0) on first use.
///
/// The manager lock is held for the whole call, including semaphore creation.
///
/// # Errors
///
/// Propagates lock failures and any error from
/// `create_timeline_semaphore`; nothing is registered in that case.
///
/// # Safety
///
/// May call `create_timeline_semaphore`, so `device` must satisfy that
/// function's requirements.
pub unsafe fn get_queue_timeline(
    device: VkDevice,
    queue: VkQueue,
) -> Result<(VkSemaphore, u64), IcdError> {
    let mut manager = TIMELINE_MANAGER.lock()?;
    let queue_key = queue.as_raw();

    // Fast path: this queue already has a timeline.
    if let Some(existing) = manager.timelines.get(&queue_key) {
        return Ok((existing.semaphore, existing.current_value));
    }

    // First request for this queue: create the semaphore and remember it.
    let semaphore = create_timeline_semaphore(device, 0)?;
    manager.timelines.insert(
        queue_key,
        TimelineState {
            semaphore,
            current_value: 0,
            pending_count: 0,
        },
    );
    Ok((semaphore, 0))
}
/// Ensures `queue` has an active batch, creating an empty one if needed.
///
/// An already-active batch is left untouched, so the call is idempotent.
///
/// # Errors
///
/// Fails only when the manager lock cannot be acquired.
pub fn begin_batch(queue: VkQueue) -> Result<(), IcdError> {
    let mut manager = TIMELINE_MANAGER.lock()?;
    // The entry API hashes the key once, where `contains_key` + `insert` hashed it twice.
    manager
        .batches
        .entry(queue.as_raw())
        .or_insert_with(BatchSubmission::new);
    Ok(())
}
/// Appends `command_buffer` to the active batch of `queue`.
///
/// Returns `Ok(true)` once the batch holds at least the configured
/// `batch_size` command buffers, signalling that the caller should submit.
/// If the queue has a registered timeline, its pending counter is bumped.
///
/// # Errors
///
/// [`IcdError::InvalidOperation`] when `begin_batch` has not been called for
/// this queue; lock failures are propagated.
pub fn add_to_batch(
    queue: VkQueue,
    command_buffer: VkCommandBuffer,
) -> Result<bool, IcdError> {
    let mut guard = TIMELINE_MANAGER.lock()?;
    // Reborrow once so the individual fields can be borrowed independently.
    let state = &mut *guard;
    let queue_key = queue.as_raw();

    let batch = match state.batches.get_mut(&queue_key) {
        Some(active) => active,
        None => return Err(IcdError::InvalidOperation("No active batch")),
    };
    batch.add_command_buffer(command_buffer);
    let threshold_reached = batch.command_buffers.len() >= state.batch_size as usize;

    match state.timelines.get_mut(&queue_key) {
        Some(timeline) => timeline.pending_count += 1,
        None => {}
    }

    Ok(threshold_reached)
}
/// Submits the active batch of `queue` with a single `vkQueueSubmit` call that
/// signals the queue's timeline semaphore.
///
/// Returns the timeline value the submission signals, or `Ok(0)` when the
/// batch contained no command buffers (nothing is submitted in that case).
///
/// The batch is removed from the manager as soon as it is found, so it is
/// consumed whether or not the submission succeeds; call `begin_batch` again
/// before adding more work. The timeline counter is advanced only after the
/// driver accepts the submission, so a failed call never leaves the counter
/// pointing at a value that no submission will signal.
///
/// # Errors
///
/// * [`IcdError::InvalidOperation`] when the queue has no active batch, or no
///   timeline has been registered for it (see `get_queue_timeline`).
/// * [`IcdError::NoIcdLoaded`] / [`IcdError::MissingFunction`] when the
///   `vkQueueSubmit` entry point cannot be resolved.
/// * [`IcdError::VulkanError`] when the driver rejects the submission.
///
/// # Safety
///
/// Calls the ICD's raw `vkQueueSubmit` entry point with `queue`, `fence`, and
/// the handles collected in the batch; the caller must guarantee these are
/// valid for that driver.
pub unsafe fn submit_batch(
    queue: VkQueue,
    fence: VkFence,
) -> Result<u64, IcdError> {
    let mut manager = TIMELINE_MANAGER.lock()?;
    let queue_key = queue.as_raw();

    let batch = manager
        .batches
        .remove(&queue_key)
        .ok_or(IcdError::InvalidOperation("No active batch"))?;
    if batch.command_buffers.is_empty() {
        return Ok(0);
    }

    let timeline = manager
        .timelines
        .get_mut(&queue_key)
        .ok_or(IcdError::InvalidOperation("No timeline for queue"))?;
    // The batch has been consumed, so none of its command buffers remain
    // pending, whichever way the submission goes.
    timeline.pending_count = 0;
    let semaphore = timeline.semaphore;
    // Tentative value: committed to the timeline only after a successful submit.
    let signal_value = timeline.current_value + 1;

    let icd = match super::icd_loader::get_icd() {
        Some(loaded) => loaded,
        None => return Err(IcdError::NoIcdLoaded),
    };
    let submit_fn = match icd.queue_submit {
        Some(entry_point) => entry_point,
        None => return Err(IcdError::MissingFunction("vkQueueSubmit")),
    };

    // Timeline values travel in a struct chained onto the submit info.
    let timeline_info = VkTimelineSemaphoreSubmitInfo {
        sType: VkStructureType::TimelineSemaphoreSubmitInfo,
        pNext: std::ptr::null(),
        waitSemaphoreValueCount: batch.wait_values.len() as u32,
        pWaitSemaphoreValues: if batch.wait_values.is_empty() {
            std::ptr::null()
        } else {
            batch.wait_values.as_ptr()
        },
        signalSemaphoreValueCount: 1,
        pSignalSemaphoreValues: &signal_value,
    };
    let submit_info = VkSubmitInfo {
        sType: VkStructureType::SubmitInfo,
        pNext: &timeline_info as *const _ as *const std::ffi::c_void,
        waitSemaphoreCount: batch.wait_semaphores.len() as u32,
        pWaitSemaphores: if batch.wait_semaphores.is_empty() {
            std::ptr::null()
        } else {
            batch.wait_semaphores.as_ptr()
        },
        pWaitDstStageMask: if batch.wait_stages.is_empty() {
            std::ptr::null()
        } else {
            batch.wait_stages.as_ptr()
        },
        commandBufferCount: batch.command_buffers.len() as u32,
        pCommandBuffers: batch.command_buffers.as_ptr(),
        signalSemaphoreCount: 1,
        pSignalSemaphores: &semaphore,
    };

    let result = submit_fn(queue, 1, &submit_info, fence);
    if result != VkResult::Success {
        return Err(IcdError::VulkanError(result));
    }

    // Only now is the new value guaranteed to be signalled eventually.
    timeline.current_value = signal_value;
    Ok(signal_value)
}
/// Blocks until the timeline semaphore of `queue` reaches `value` or
/// `timeout` elapses.
///
/// The manager lock is held only long enough to look up the semaphore handle.
/// It is released before the blocking driver call so that other threads can
/// keep batching and submitting (possibly the very submission being waited
/// on) in the meantime.
///
/// NOTE(review): a `VkResult::Timeout` from the driver is reported as
/// `Ok(())`, so callers cannot tell a timeout from a completed wait. This
/// matches the previous behaviour; confirm it is intended.
///
/// # Errors
///
/// * [`IcdError::InvalidOperation`] when no timeline is registered for `queue`.
/// * [`IcdError::NoIcdLoaded`] / [`IcdError::MissingFunction`] when the
///   `vkWaitSemaphores` entry point cannot be resolved.
/// * [`IcdError::VulkanError`] for any driver result other than success or
///   timeout.
///
/// # Safety
///
/// Calls the ICD's raw `vkWaitSemaphores` entry point; `device` must be a
/// handle that entry point accepts.
pub unsafe fn wait_timeline(
    device: VkDevice,
    queue: VkQueue,
    value: u64,
    timeout: u64,
) -> Result<(), IcdError> {
    let manager = TIMELINE_MANAGER.lock()?;
    let semaphore = manager
        .timelines
        .get(&queue.as_raw())
        .ok_or(IcdError::InvalidOperation("No timeline for queue"))?
        .semaphore;
    // Release the global lock before a potentially long blocking wait.
    drop(manager);

    let icd = match super::icd_loader::get_icd() {
        Some(loaded) => loaded,
        // Without an ICD nothing was waited on, so this must not look like success.
        None => return Err(IcdError::NoIcdLoaded),
    };
    let wait_fn = match icd.wait_semaphores {
        Some(entry_point) => entry_point,
        None => return Err(IcdError::MissingFunction("vkWaitSemaphores")),
    };

    let wait_info = VkSemaphoreWaitInfo {
        sType: VkStructureType::SemaphoreWaitInfo,
        pNext: std::ptr::null(),
        flags: VkSemaphoreWaitFlags::empty(),
        semaphoreCount: 1,
        pSemaphores: &semaphore,
        pValues: &value,
    };

    let result = wait_fn(device, &wait_info, timeout);
    if result != VkResult::Success && result != VkResult::Timeout {
        return Err(IcdError::VulkanError(result));
    }
    Ok(())
}
/// Chainable helper that collects command buffers for one queue and submits
/// them as a single batch via `BatchBuilder::submit`.
pub struct BatchBuilder {
/// Queue the collected command buffers will be submitted to.
queue: VkQueue,
/// Command buffers gathered so far, in the order they were added.
command_buffers: Vec<VkCommandBuffer>,
}
impl BatchBuilder {
    /// Starts an empty builder targeting `queue`.
    pub fn new(queue: VkQueue) -> Self {
        Self {
            queue,
            command_buffers: Vec::new(),
        }
    }

    /// Adds `cb` to the builder and hands the builder back for chaining.
    pub fn add_command_buffer(mut self, cb: VkCommandBuffer) -> Self {
        self.command_buffers.push(cb);
        self
    }

    /// Feeds every collected command buffer into the queue's batch and submits
    /// it with a null fence, returning the result of `submit_batch`.
    ///
    /// The "batch is full" flag returned by `add_to_batch` is ignored: the
    /// whole set is always submitted in one go. `submit_batch` requires a
    /// timeline to be registered for the queue, which this builder does not do.
    ///
    /// # Safety
    ///
    /// Calls `submit_batch`; the queue and command buffers must satisfy that
    /// function's requirements.
    pub unsafe fn submit(self) -> Result<u64, IcdError> {
        let Self {
            queue,
            command_buffers,
        } = self;

        begin_batch(queue)?;
        command_buffers
            .into_iter()
            .try_for_each(|cb| add_to_batch(queue, cb).map(|_batch_full| ()))?;
        submit_batch(queue, VkFence::NULL)
    }
}
/// Running totals describing batched submissions.
#[derive(Default, Debug)]
pub struct BatchStats {
    /// Number of submissions recorded through `record_submission`.
    pub total_submissions: u64,
    /// Sum of the batch sizes of all recorded submissions.
    pub total_command_buffers: u64,
    /// `total_command_buffers / total_submissions`, refreshed on every
    /// recorded submission.
    pub average_batch_size: f64,
    /// Counter for timeline waits; `record_submission` does not touch it.
    pub timeline_waits: u64,
}

impl BatchStats {
    /// Records one submission of `batch_size` command buffers and recomputes
    /// the running average.
    pub fn record_submission(&mut self, batch_size: usize) {
        let Self {
            total_submissions: submissions,
            total_command_buffers: command_buffers,
            average_batch_size: average,
            ..
        } = self;

        *submissions += 1;
        *command_buffers += batch_size as u64;
        // `submissions` was just incremented, so the divisor is never zero.
        *average = *command_buffers as f64 / *submissions as f64;
    }
}
pub fn get_batch_stats() -> BatchStats {
BatchStats::default()
}
/// Sets the command-buffer count at which `add_to_batch` reports that a batch
/// should be submitted.
///
/// The value is stored as-is; with `0`, `add_to_batch` reports "submit" after
/// every addition.
///
/// # Errors
///
/// Fails only when the manager lock cannot be acquired.
pub fn set_batch_size(size: u32) -> Result<(), IcdError> {
    // The guard is a temporary, so the lock is released at the end of this statement.
    TIMELINE_MANAGER.lock()?.batch_size = size;
    Ok(())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Chained `add_command_buffer` calls accumulate every handle in the builder.
    #[test]
    fn test_batch_builder() {
        let target_queue = VkQueue::from_raw(0x1234);
        let first = VkCommandBuffer::from_raw(0x5678);
        let second = VkCommandBuffer::from_raw(0x9ABC);

        let empty = BatchBuilder::new(target_queue);
        let with_one = empty.add_command_buffer(first);
        let with_two = with_one.add_command_buffer(second);

        assert_eq!(with_two.command_buffers.len(), 2);
    }
}