kronos_compute/implementation/
timeline_batching.rs

1//! Timeline semaphore batching for efficient submission
2//! 
3//! Implements:
4//! - One timeline semaphore per queue
5//! - Batch submissions with single fence
6//! - Target: 30-50% reduction in CPU submit time
7
8use std::collections::HashMap;
9use std::sync::Mutex;
10use crate::sys::*;
11use crate::core::*;
12use crate::ffi::*;
13use super::error::IcdError;
14
15/// Timeline semaphore state per queue
16pub struct TimelineState {
17    /// The timeline semaphore for this queue
18    semaphore: VkSemaphore,
19    /// Current timeline value
20    current_value: u64,
21    /// Pending submissions in current batch
22    pending_count: u32,
23}
24
25/// Batch submission context
26pub struct BatchSubmission {
27    /// Command buffers in this batch
28    command_buffers: Vec<VkCommandBuffer>,
29    /// Wait semaphores (from other queues)
30    wait_semaphores: Vec<VkSemaphore>,
31    wait_values: Vec<u64>,
32    wait_stages: Vec<VkPipelineStageFlags>,
33    /// Signal value for this batch
34    #[allow(dead_code)]
35    signal_value: u64,
36}
37
38impl BatchSubmission {
39    pub fn new() -> Self {
40        Self {
41            command_buffers: Vec::with_capacity(256),
42            wait_semaphores: Vec::new(),
43            wait_values: Vec::new(),
44            wait_stages: Vec::new(),
45            signal_value: 0,
46        }
47    }
48    
49    /// Add a command buffer to the batch
50    pub fn add_command_buffer(&mut self, cb: VkCommandBuffer) {
51        self.command_buffers.push(cb);
52    }
53    
54    /// Add a wait dependency from another queue
55    pub fn add_wait(&mut self, semaphore: VkSemaphore, value: u64, stage: VkPipelineStageFlags) {
56        self.wait_semaphores.push(semaphore);
57        self.wait_values.push(value);
58        self.wait_stages.push(stage);
59    }
60}
61
62/// Global timeline manager
63pub struct TimelineManager {
64    /// Queue -> Timeline state mapping
65    timelines: HashMap<u64, TimelineState>,
66    /// Active batch per queue
67    batches: HashMap<u64, BatchSubmission>,
68    /// Batch size threshold
69    batch_size: u32,
70}
71
72lazy_static::lazy_static! {
73    static ref TIMELINE_MANAGER: Mutex<TimelineManager> = Mutex::new(TimelineManager {
74        timelines: HashMap::new(),
75        batches: HashMap::new(),
76        batch_size: 16, // Default batch size
77    });
78}
79
80/// Create a timeline semaphore
81///
82/// # Safety
83///
84/// This function is unsafe because:
85/// - The device must be a valid VkDevice handle
86/// - Calls vkCreateSemaphore through ICD function pointer
87/// - The returned semaphore must be destroyed with vkDestroySemaphore
88/// - Timeline semaphores require Vulkan 1.2 or VK_KHR_timeline_semaphore
89/// - Invalid device handle will cause undefined behavior
90pub unsafe fn create_timeline_semaphore(
91    device: VkDevice,
92    initial_value: u64,
93) -> Result<VkSemaphore, IcdError> {
94    // Timeline semaphore create info
95    let timeline_info = VkSemaphoreTypeCreateInfo {
96        sType: VkStructureType::SemaphoreTypeCreateInfo,
97        pNext: std::ptr::null(),
98        semaphoreType: VkSemaphoreType::Timeline,
99        initialValue: initial_value,
100    };
101    
102    let create_info = VkSemaphoreCreateInfo {
103        sType: VkStructureType::SemaphoreCreateInfo,
104        pNext: &timeline_info as *const _ as *const std::ffi::c_void,
105        flags: 0,
106    };
107    
108    let mut semaphore = VkSemaphore::NULL;
109    
110    if let Some(icd) = super::icd_loader::get_icd() {
111        if let Some(create_fn) = icd.create_semaphore {
112            let result = create_fn(device, &create_info, std::ptr::null(), &mut semaphore);
113            if result == VkResult::Success {
114                return Ok(semaphore);
115            }
116            return Err(IcdError::VulkanError(result));
117        }
118    }
119    
120    Err(IcdError::MissingFunction("vkCreateSemaphore"))
121}
122
123/// Get or create timeline semaphore for a queue
124///
125/// # Safety
126///
127/// This function is unsafe because:
128/// - Both device and queue must be valid Vulkan handles
129/// - The queue must have been created from the device
130/// - Creates timeline semaphore if not exists (calls unsafe create_timeline_semaphore)
131/// - Thread safety is provided by the global TIMELINE_MANAGER mutex
132/// - The returned semaphore is managed by the timeline manager
133pub unsafe fn get_queue_timeline(
134    device: VkDevice,
135    queue: VkQueue,
136) -> Result<(VkSemaphore, u64), IcdError> {
137    let mut manager = TIMELINE_MANAGER.lock()?;
138    let queue_key = queue.as_raw();
139    
140    if let Some(state) = manager.timelines.get(&queue_key) {
141        Ok((state.semaphore, state.current_value))
142    } else {
143        // Create new timeline semaphore
144        let semaphore = create_timeline_semaphore(device, 0)?;
145        let state = TimelineState {
146            semaphore,
147            current_value: 0,
148            pending_count: 0,
149        };
150        
151        let current_value = state.current_value;
152        manager.timelines.insert(queue_key, state);
153        
154        Ok((semaphore, current_value))
155    }
156}
157
158/// Begin a batch submission
159pub fn begin_batch(queue: VkQueue) -> Result<(), IcdError> {
160    let mut manager = TIMELINE_MANAGER.lock()?;
161    let queue_key = queue.as_raw();
162    
163    if !manager.batches.contains_key(&queue_key) {
164        manager.batches.insert(queue_key, BatchSubmission::new());
165    }
166    
167    Ok(())
168}
169
170/// Add command buffer to current batch
171pub fn add_to_batch(
172    queue: VkQueue,
173    command_buffer: VkCommandBuffer,
174) -> Result<bool, IcdError> {
175    let mut manager = TIMELINE_MANAGER.lock()?;
176    let queue_key = queue.as_raw();
177    
178    let batch = manager.batches.get_mut(&queue_key)
179        .ok_or(IcdError::InvalidOperation("No active batch"))?;
180    
181    batch.add_command_buffer(command_buffer);
182    
183    // Check if batch is full
184    let should_submit = batch.command_buffers.len() >= manager.batch_size as usize;
185    
186    if let Some(timeline) = manager.timelines.get_mut(&queue_key) {
187        timeline.pending_count += 1;
188    }
189    
190    Ok(should_submit)
191}
192
193/// Submit the current batch
194///
195/// # Safety
196///
197/// This function is unsafe because:
198/// - The queue must be a valid VkQueue handle
199/// - The fence (if not NULL) must be a valid VkFence handle
200/// - All command buffers in the batch must be in executable state
201/// - Calls vkQueueSubmit with timeline semaphore info
202/// - The queue must not be in use by another thread
203/// - Timeline semaphore operations require proper synchronization
204pub unsafe fn submit_batch(
205    queue: VkQueue,
206    fence: VkFence,
207) -> Result<u64, IcdError> {
208    let mut manager = TIMELINE_MANAGER.lock()?;
209    let queue_key = queue.as_raw();
210    
211    let batch = manager.batches.remove(&queue_key)
212        .ok_or(IcdError::InvalidOperation("No active batch"))?;
213    
214    if batch.command_buffers.is_empty() {
215        return Ok(0); // Nothing to submit
216    }
217    
218    let timeline = manager.timelines.get_mut(&queue_key)
219        .ok_or(IcdError::InvalidOperation("No timeline for queue"))?;
220    
221    // Increment timeline value for this batch
222    timeline.current_value += 1;
223    let signal_value = timeline.current_value;
224    
225    // Build timeline submit info
226    let timeline_info = VkTimelineSemaphoreSubmitInfo {
227        sType: VkStructureType::TimelineSemaphoreSubmitInfo,
228        pNext: std::ptr::null(),
229        waitSemaphoreValueCount: batch.wait_values.len() as u32,
230        pWaitSemaphoreValues: if batch.wait_values.is_empty() {
231            std::ptr::null()
232        } else {
233            batch.wait_values.as_ptr()
234        },
235        signalSemaphoreValueCount: 1,
236        pSignalSemaphoreValues: &signal_value,
237    };
238    
239    // Build submit info
240    let submit_info = VkSubmitInfo {
241        sType: VkStructureType::SubmitInfo,
242        pNext: &timeline_info as *const _ as *const std::ffi::c_void,
243        waitSemaphoreCount: batch.wait_semaphores.len() as u32,
244        pWaitSemaphores: if batch.wait_semaphores.is_empty() {
245            std::ptr::null()
246        } else {
247            batch.wait_semaphores.as_ptr()
248        },
249        pWaitDstStageMask: if batch.wait_stages.is_empty() {
250            std::ptr::null()
251        } else {
252            batch.wait_stages.as_ptr()
253        },
254        commandBufferCount: batch.command_buffers.len() as u32,
255        pCommandBuffers: batch.command_buffers.as_ptr(),
256        signalSemaphoreCount: 1,
257        pSignalSemaphores: &timeline.semaphore,
258    };
259    
260    // Submit to queue
261    if let Some(icd) = super::icd_loader::get_icd() {
262        if let Some(submit_fn) = icd.queue_submit {
263            let result = submit_fn(queue, 1, &submit_info, fence);
264            if result != VkResult::Success {
265                return Err(IcdError::VulkanError(result));
266            }
267        } else {
268            return Err(IcdError::MissingFunction("vkQueueSubmit"));
269        }
270    } else {
271        return Err(IcdError::NoIcdLoaded);
272    }
273    
274    // Reset pending count
275    timeline.pending_count = 0;
276    
277    Ok(signal_value)
278}
279
280/// Wait for timeline value
281///
282/// # Safety
283///
284/// This function is unsafe because:
285/// - The device must be a valid VkDevice handle
286/// - The queue must be a valid VkQueue with associated timeline
287/// - Calls vkWaitSemaphores through ICD function pointer
288/// - The value must be reachable (not waiting for future unsubmitted work)
289/// - Timeout is in nanoseconds, UINT64_MAX means wait forever
290/// - May block the calling thread
291pub unsafe fn wait_timeline(
292    device: VkDevice,
293    queue: VkQueue,
294    value: u64,
295    timeout: u64,
296) -> Result<(), IcdError> {
297    let manager = TIMELINE_MANAGER.lock()?;
298    let queue_key = queue.as_raw();
299    
300    let timeline = manager.timelines.get(&queue_key)
301        .ok_or(IcdError::InvalidOperation("No timeline for queue"))?;
302    
303    let wait_info = VkSemaphoreWaitInfo {
304        sType: VkStructureType::SemaphoreWaitInfo,
305        pNext: std::ptr::null(),
306        flags: VkSemaphoreWaitFlags::empty(),
307        semaphoreCount: 1,
308        pSemaphores: &timeline.semaphore,
309        pValues: &value,
310    };
311    
312    if let Some(icd) = super::icd_loader::get_icd() {
313        if let Some(wait_fn) = icd.wait_semaphores {
314            let result = wait_fn(device, &wait_info, timeout);
315            if result != VkResult::Success && result != VkResult::Timeout {
316                return Err(IcdError::VulkanError(result));
317            }
318        } else {
319            // Fallback to fence if timeline semaphores not supported
320            return Err(IcdError::MissingFunction("vkWaitSemaphores"));
321        }
322    }
323    
324    Ok(())
325}
326
327/// Batch submission builder for convenient API
328pub struct BatchBuilder {
329    queue: VkQueue,
330    command_buffers: Vec<VkCommandBuffer>,
331}
332
333impl BatchBuilder {
334    pub fn new(queue: VkQueue) -> Self {
335        Self {
336            queue,
337            command_buffers: Vec::new(),
338        }
339    }
340    
341    /// Add command buffer to batch
342    pub fn add_command_buffer(mut self, cb: VkCommandBuffer) -> Self {
343        self.command_buffers.push(cb);
344        self
345    }
346    
347    /// Get the number of command buffers in the batch
348    pub fn len(&self) -> usize {
349        self.command_buffers.len()
350    }
351    
352    /// Check if the batch is empty
353    pub fn is_empty(&self) -> bool {
354        self.command_buffers.is_empty()
355    }
356    
357    /// Submit the batch
358    ///
359    /// # Safety
360    ///
361    /// This function is unsafe because:
362    /// - All command buffers must be valid and in executable state
363    /// - The queue must be valid and not in use by another thread
364    /// - Calls unsafe functions: begin_batch, add_to_batch, submit_batch
365    /// - Command buffers must not be reset or freed until submission completes
366    pub unsafe fn submit(self) -> Result<u64, IcdError> {
367        begin_batch(self.queue)?;
368        
369        for cb in self.command_buffers {
370            add_to_batch(self.queue, cb)?;
371        }
372        
373        submit_batch(self.queue, VkFence::NULL)
374    }
375}
376
/// Running counters describing batch submission activity.
#[derive(Default, Debug)]
pub struct BatchStats {
    /// Number of submissions recorded.
    pub total_submissions: u64,
    /// Sum of the batch sizes of all recorded submissions.
    pub total_command_buffers: u64,
    /// `total_command_buffers / total_submissions`, refreshed on every record.
    pub average_batch_size: f64,
    /// Counter for timeline waits (not modified by `record_submission`).
    pub timeline_waits: u64,
}

impl BatchStats {
    /// Account for one submission that carried `batch_size` command buffers
    /// and refresh the running average.
    pub fn record_submission(&mut self, batch_size: usize) {
        let submissions = self.total_submissions + 1;
        let command_buffers = self.total_command_buffers + batch_size as u64;

        self.total_submissions = submissions;
        self.total_command_buffers = command_buffers;
        self.average_batch_size = command_buffers as f64 / submissions as f64;
    }
}
393
394/// Get batch statistics
395pub fn get_batch_stats() -> BatchStats {
396    // In a real implementation, we'd track these
397    BatchStats::default()
398}
399
400/// Set batch size threshold
401pub fn set_batch_size(size: u32) -> Result<(), IcdError> {
402    let mut manager = TIMELINE_MANAGER.lock()?;
403    manager.batch_size = size;
404    Ok(())
405}
406
#[cfg(test)]
mod tests {
    use super::*;

    /// Chained `add_command_buffer` calls accumulate in the builder and are
    /// reflected by `len` / `is_empty`. Nothing is submitted.
    #[test]
    fn test_batch_builder() {
        let builder = BatchBuilder::new(VkQueue::from_raw(0x1234))
            .add_command_buffer(VkCommandBuffer::from_raw(0x5678))
            .add_command_buffer(VkCommandBuffer::from_raw(0x9ABC));

        assert!(!builder.is_empty());
        assert_eq!(builder.len(), 2);
    }
}