kronos_compute/implementation/
timeline_batching.rs1use std::collections::HashMap;
9use std::sync::Mutex;
10use crate::sys::*;
11use crate::core::*;
12use crate::ffi::*;
13use super::error::IcdError;
14
15pub struct TimelineState {
17 semaphore: VkSemaphore,
19 current_value: u64,
21 pending_count: u32,
23}
24
25pub struct BatchSubmission {
27 command_buffers: Vec<VkCommandBuffer>,
29 wait_semaphores: Vec<VkSemaphore>,
31 wait_values: Vec<u64>,
32 wait_stages: Vec<VkPipelineStageFlags>,
33 #[allow(dead_code)]
35 signal_value: u64,
36}
37
38impl BatchSubmission {
39 pub fn new() -> Self {
40 Self {
41 command_buffers: Vec::with_capacity(256),
42 wait_semaphores: Vec::new(),
43 wait_values: Vec::new(),
44 wait_stages: Vec::new(),
45 signal_value: 0,
46 }
47 }
48
49 pub fn add_command_buffer(&mut self, cb: VkCommandBuffer) {
51 self.command_buffers.push(cb);
52 }
53
54 pub fn add_wait(&mut self, semaphore: VkSemaphore, value: u64, stage: VkPipelineStageFlags) {
56 self.wait_semaphores.push(semaphore);
57 self.wait_values.push(value);
58 self.wait_stages.push(stage);
59 }
60}
61
62pub struct TimelineManager {
64 timelines: HashMap<u64, TimelineState>,
66 batches: HashMap<u64, BatchSubmission>,
68 batch_size: u32,
70}
71
72lazy_static::lazy_static! {
73 static ref TIMELINE_MANAGER: Mutex<TimelineManager> = Mutex::new(TimelineManager {
74 timelines: HashMap::new(),
75 batches: HashMap::new(),
76 batch_size: 16, });
78}
79
80pub unsafe fn create_timeline_semaphore(
91 device: VkDevice,
92 initial_value: u64,
93) -> Result<VkSemaphore, IcdError> {
94 let timeline_info = VkSemaphoreTypeCreateInfo {
96 sType: VkStructureType::SemaphoreTypeCreateInfo,
97 pNext: std::ptr::null(),
98 semaphoreType: VkSemaphoreType::Timeline,
99 initialValue: initial_value,
100 };
101
102 let create_info = VkSemaphoreCreateInfo {
103 sType: VkStructureType::SemaphoreCreateInfo,
104 pNext: &timeline_info as *const _ as *const std::ffi::c_void,
105 flags: 0,
106 };
107
108 let mut semaphore = VkSemaphore::NULL;
109
110 if let Some(icd) = super::icd_loader::get_icd() {
111 if let Some(create_fn) = icd.create_semaphore {
112 let result = create_fn(device, &create_info, std::ptr::null(), &mut semaphore);
113 if result == VkResult::Success {
114 return Ok(semaphore);
115 }
116 return Err(IcdError::VulkanError(result));
117 }
118 }
119
120 Err(IcdError::MissingFunction("vkCreateSemaphore"))
121}
122
123pub unsafe fn get_queue_timeline(
134 device: VkDevice,
135 queue: VkQueue,
136) -> Result<(VkSemaphore, u64), IcdError> {
137 let mut manager = TIMELINE_MANAGER.lock()?;
138 let queue_key = queue.as_raw();
139
140 if let Some(state) = manager.timelines.get(&queue_key) {
141 Ok((state.semaphore, state.current_value))
142 } else {
143 let semaphore = create_timeline_semaphore(device, 0)?;
145 let state = TimelineState {
146 semaphore,
147 current_value: 0,
148 pending_count: 0,
149 };
150
151 let current_value = state.current_value;
152 manager.timelines.insert(queue_key, state);
153
154 Ok((semaphore, current_value))
155 }
156}
157
158pub fn begin_batch(queue: VkQueue) -> Result<(), IcdError> {
160 let mut manager = TIMELINE_MANAGER.lock()?;
161 let queue_key = queue.as_raw();
162
163 if !manager.batches.contains_key(&queue_key) {
164 manager.batches.insert(queue_key, BatchSubmission::new());
165 }
166
167 Ok(())
168}
169
170pub fn add_to_batch(
172 queue: VkQueue,
173 command_buffer: VkCommandBuffer,
174) -> Result<bool, IcdError> {
175 let mut manager = TIMELINE_MANAGER.lock()?;
176 let queue_key = queue.as_raw();
177
178 let batch = manager.batches.get_mut(&queue_key)
179 .ok_or(IcdError::InvalidOperation("No active batch"))?;
180
181 batch.add_command_buffer(command_buffer);
182
183 let should_submit = batch.command_buffers.len() >= manager.batch_size as usize;
185
186 if let Some(timeline) = manager.timelines.get_mut(&queue_key) {
187 timeline.pending_count += 1;
188 }
189
190 Ok(should_submit)
191}
192
193pub unsafe fn submit_batch(
205 queue: VkQueue,
206 fence: VkFence,
207) -> Result<u64, IcdError> {
208 let mut manager = TIMELINE_MANAGER.lock()?;
209 let queue_key = queue.as_raw();
210
211 let batch = manager.batches.remove(&queue_key)
212 .ok_or(IcdError::InvalidOperation("No active batch"))?;
213
214 if batch.command_buffers.is_empty() {
215 return Ok(0); }
217
218 let timeline = manager.timelines.get_mut(&queue_key)
219 .ok_or(IcdError::InvalidOperation("No timeline for queue"))?;
220
221 timeline.current_value += 1;
223 let signal_value = timeline.current_value;
224
225 let timeline_info = VkTimelineSemaphoreSubmitInfo {
227 sType: VkStructureType::TimelineSemaphoreSubmitInfo,
228 pNext: std::ptr::null(),
229 waitSemaphoreValueCount: batch.wait_values.len() as u32,
230 pWaitSemaphoreValues: if batch.wait_values.is_empty() {
231 std::ptr::null()
232 } else {
233 batch.wait_values.as_ptr()
234 },
235 signalSemaphoreValueCount: 1,
236 pSignalSemaphoreValues: &signal_value,
237 };
238
239 let submit_info = VkSubmitInfo {
241 sType: VkStructureType::SubmitInfo,
242 pNext: &timeline_info as *const _ as *const std::ffi::c_void,
243 waitSemaphoreCount: batch.wait_semaphores.len() as u32,
244 pWaitSemaphores: if batch.wait_semaphores.is_empty() {
245 std::ptr::null()
246 } else {
247 batch.wait_semaphores.as_ptr()
248 },
249 pWaitDstStageMask: if batch.wait_stages.is_empty() {
250 std::ptr::null()
251 } else {
252 batch.wait_stages.as_ptr()
253 },
254 commandBufferCount: batch.command_buffers.len() as u32,
255 pCommandBuffers: batch.command_buffers.as_ptr(),
256 signalSemaphoreCount: 1,
257 pSignalSemaphores: &timeline.semaphore,
258 };
259
260 if let Some(icd) = super::icd_loader::get_icd() {
262 if let Some(submit_fn) = icd.queue_submit {
263 let result = submit_fn(queue, 1, &submit_info, fence);
264 if result != VkResult::Success {
265 return Err(IcdError::VulkanError(result));
266 }
267 } else {
268 return Err(IcdError::MissingFunction("vkQueueSubmit"));
269 }
270 } else {
271 return Err(IcdError::NoIcdLoaded);
272 }
273
274 timeline.pending_count = 0;
276
277 Ok(signal_value)
278}
279
280pub unsafe fn wait_timeline(
292 device: VkDevice,
293 queue: VkQueue,
294 value: u64,
295 timeout: u64,
296) -> Result<(), IcdError> {
297 let manager = TIMELINE_MANAGER.lock()?;
298 let queue_key = queue.as_raw();
299
300 let timeline = manager.timelines.get(&queue_key)
301 .ok_or(IcdError::InvalidOperation("No timeline for queue"))?;
302
303 let wait_info = VkSemaphoreWaitInfo {
304 sType: VkStructureType::SemaphoreWaitInfo,
305 pNext: std::ptr::null(),
306 flags: VkSemaphoreWaitFlags::empty(),
307 semaphoreCount: 1,
308 pSemaphores: &timeline.semaphore,
309 pValues: &value,
310 };
311
312 if let Some(icd) = super::icd_loader::get_icd() {
313 if let Some(wait_fn) = icd.wait_semaphores {
314 let result = wait_fn(device, &wait_info, timeout);
315 if result != VkResult::Success && result != VkResult::Timeout {
316 return Err(IcdError::VulkanError(result));
317 }
318 } else {
319 return Err(IcdError::MissingFunction("vkWaitSemaphores"));
321 }
322 }
323
324 Ok(())
325}
326
327pub struct BatchBuilder {
329 queue: VkQueue,
330 command_buffers: Vec<VkCommandBuffer>,
331}
332
333impl BatchBuilder {
334 pub fn new(queue: VkQueue) -> Self {
335 Self {
336 queue,
337 command_buffers: Vec::new(),
338 }
339 }
340
341 pub fn add_command_buffer(mut self, cb: VkCommandBuffer) -> Self {
343 self.command_buffers.push(cb);
344 self
345 }
346
347 pub fn len(&self) -> usize {
349 self.command_buffers.len()
350 }
351
352 pub fn is_empty(&self) -> bool {
354 self.command_buffers.is_empty()
355 }
356
357 pub unsafe fn submit(self) -> Result<u64, IcdError> {
367 begin_batch(self.queue)?;
368
369 for cb in self.command_buffers {
370 add_to_batch(self.queue, cb)?;
371 }
372
373 submit_batch(self.queue, VkFence::NULL)
374 }
375}
376
/// Aggregate counters describing batched submissions.
#[derive(Default, Debug)]
pub struct BatchStats {
    pub total_submissions: u64,
    pub total_command_buffers: u64,
    pub average_batch_size: f64,
    pub timeline_waits: u64,
}

impl BatchStats {
    /// Records one submission of `batch_size` command buffers and refreshes the running mean
    /// (`total_command_buffers / total_submissions`).
    pub fn record_submission(&mut self, batch_size: usize) {
        self.total_submissions += 1;
        self.total_command_buffers += batch_size as u64;

        let buffers = self.total_command_buffers as f64;
        let submissions = self.total_submissions as f64;
        self.average_batch_size = buffers / submissions;
    }
}
393
394pub fn get_batch_stats() -> BatchStats {
396 BatchStats::default()
398}
399
400pub fn set_batch_size(size: u32) -> Result<(), IcdError> {
402 let mut manager = TIMELINE_MANAGER.lock()?;
403 manager.batch_size = size;
404 Ok(())
405}
406
#[cfg(test)]
mod tests {
    use super::*;

    /// Chained `add_command_buffer` calls accumulate in the builder.
    #[test]
    fn test_batch_builder() {
        let queue = VkQueue::from_raw(0x1234);
        let cb1 = VkCommandBuffer::from_raw(0x5678);
        let cb2 = VkCommandBuffer::from_raw(0x9ABC);

        let builder = BatchBuilder::new(queue)
            .add_command_buffer(cb1)
            .add_command_buffer(cb2);

        assert_eq!(builder.len(), 2);
        assert!(!builder.is_empty());
    }

    /// A freshly created builder reports itself as empty.
    #[test]
    fn test_batch_builder_empty() {
        let builder = BatchBuilder::new(VkQueue::from_raw(0x1234));

        assert_eq!(builder.len(), 0);
        assert!(builder.is_empty());
    }

    /// `record_submission` keeps the totals and the running mean consistent.
    #[test]
    fn test_batch_stats_average() {
        let mut stats = BatchStats::default();
        stats.record_submission(4);
        stats.record_submission(2);

        assert_eq!(stats.total_submissions, 2);
        assert_eq!(stats.total_command_buffers, 6);
        assert!((stats.average_batch_size - 3.0).abs() < 1e-12);
        assert_eq!(stats.timeline_waits, 0);
    }
}