1#![allow(dead_code)]
2use std::collections::VecDeque;
9
/// A contiguous range inside a buffer, described by a start offset and a length.
///
/// The range is half-open: it covers `offset..end()`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct BufferRegion {
    /// Offset at which the region starts.
    pub offset: u64,
    /// Length of the region.
    pub size: u64,
}

impl BufferRegion {
    /// Creates a region starting at `offset` and spanning `size`.
    #[must_use]
    pub fn new(offset: u64, size: u64) -> Self {
        Self { offset, size }
    }

    /// Creates a region of the given `size` that starts at offset 0.
    #[must_use]
    pub fn from_start(size: u64) -> Self {
        Self { offset: 0, size }
    }

    /// Returns the exclusive end offset of the region (`offset + size`).
    ///
    /// The sum saturates at `u64::MAX` instead of overflowing, so a region
    /// whose end would not fit in a `u64` is treated as reaching `u64::MAX`.
    /// This keeps every comparison built on top of `end()` panic-free in
    /// debug builds and prevents wrap-around in release builds.
    #[must_use]
    pub fn end(&self) -> u64 {
        self.offset.saturating_add(self.size)
    }

    /// Returns `true` when each region starts before the other one ends.
    ///
    /// Regions that merely touch (`a.end() == b.offset`) do not overlap.
    #[must_use]
    pub fn overlaps(&self, other: &BufferRegion) -> bool {
        self.offset < other.end() && other.offset < self.end()
    }

    /// Returns `true` when this region lies entirely within `other`
    /// (sharing a boundary with `other` is allowed).
    #[must_use]
    pub fn contained_in(&self, other: &BufferRegion) -> bool {
        self.offset >= other.offset && self.end() <= other.end()
    }

    /// Returns the range shared by both regions, or `None` when the shared
    /// range would be empty.
    #[must_use]
    pub fn intersection(&self, other: &BufferRegion) -> Option<BufferRegion> {
        let start = self.offset.max(other.offset);
        let end = self.end().min(other.end());
        // `start < end` guarantees the subtraction below cannot underflow.
        (start < end).then(|| BufferRegion::new(start, end - start))
    }
}
62
/// An axis-aligned rectangle described by its origin and extent.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ImageRegion {
    /// Horizontal coordinate of the origin corner.
    pub x: u32,
    /// Vertical coordinate of the origin corner.
    pub y: u32,
    /// Horizontal extent.
    pub width: u32,
    /// Vertical extent.
    pub height: u32,
}

impl ImageRegion {
    /// Creates a region with origin `(x, y)` and the given extent.
    #[must_use]
    pub fn new(x: u32, y: u32, width: u32, height: u32) -> Self {
        Self {
            x,
            y,
            width,
            height,
        }
    }

    /// Creates a region of the given extent whose origin is `(0, 0)`.
    #[must_use]
    pub fn from_size(width: u32, height: u32) -> Self {
        Self {
            x: 0,
            y: 0,
            width,
            height,
        }
    }

    /// Returns `width * height`, computed in `u64` so the product of two
    /// `u32` values cannot overflow.
    #[must_use]
    pub fn pixel_count(&self) -> u64 {
        u64::from(self.width) * u64::from(self.height)
    }

    /// Returns `true` when `(px, py)` lies inside the region.
    ///
    /// The origin edges are inclusive and the far edges are exclusive.
    #[must_use]
    pub fn contains_point(&self, px: u32, py: u32) -> bool {
        // Compare the distance from the origin against the extent instead of
        // computing `x + width` / `y + height`, which can overflow `u32` for
        // regions placed near `u32::MAX`. The `>=` guards run first, so the
        // subtractions cannot underflow.
        px >= self.x && px - self.x < self.width && py >= self.y && py - self.y < self.height
    }
}
111
/// Names the two endpoints of a copy operation.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum CopyDirection {
    /// Displayed as `Host -> Device`.
    HostToDevice,
    /// Displayed as `Device -> Host`.
    DeviceToHost,
    /// Displayed as `Device -> Device`.
    DeviceToDevice,
    /// Displayed as `Peer -> Peer`.
    PeerToPeer,
}

impl std::fmt::Display for CopyDirection {
    /// Writes the human-readable `Source -> Destination` label of the variant.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match *self {
            Self::HostToDevice => "Host -> Device",
            Self::DeviceToHost => "Device -> Host",
            Self::DeviceToDevice => "Device -> Device",
            Self::PeerToPeer => "Peer -> Peer",
        };
        f.write_str(label)
    }
}
135
136#[derive(Debug, Clone, PartialEq, Eq)]
138pub struct CopyCommand {
139 pub src_id: u64,
141 pub dst_id: u64,
143 pub src_region: BufferRegion,
145 pub dst_offset: u64,
147 pub direction: CopyDirection,
149}
150
151impl CopyCommand {
152 #[must_use]
154 pub fn new(
155 src_id: u64,
156 dst_id: u64,
157 src_region: BufferRegion,
158 dst_offset: u64,
159 direction: CopyDirection,
160 ) -> Self {
161 Self {
162 src_id,
163 dst_id,
164 src_region,
165 dst_offset,
166 direction,
167 }
168 }
169
170 #[must_use]
172 pub fn dst_region(&self) -> BufferRegion {
173 BufferRegion::new(self.dst_offset, self.src_region.size)
174 }
175
176 #[must_use]
178 pub fn aliases_with(&self, other: &CopyCommand) -> bool {
179 if self.dst_id == other.dst_id {
181 let my_dst = self.dst_region();
182 let other_dst = other.dst_region();
183 if my_dst.overlaps(&other_dst) {
184 return true;
185 }
186 }
187 if self.src_id == other.dst_id {
189 let other_dst = other.dst_region();
190 if self.src_region.overlaps(&other_dst) {
191 return true;
192 }
193 }
194 if self.dst_id == other.src_id {
195 let my_dst = self.dst_region();
196 if my_dst.overlaps(&other.src_region) {
197 return true;
198 }
199 }
200 false
201 }
202}
203
204#[derive(Debug, Default)]
206pub struct CopyBatch {
207 commands: VecDeque<CopyCommand>,
209 total_bytes: u64,
211}
212
213impl CopyBatch {
214 #[must_use]
216 pub fn new() -> Self {
217 Self::default()
218 }
219
220 pub fn push(&mut self, cmd: CopyCommand) {
222 self.total_bytes += cmd.src_region.size;
223 self.commands.push_back(cmd);
224 }
225
226 #[must_use]
228 pub fn len(&self) -> usize {
229 self.commands.len()
230 }
231
232 #[must_use]
234 pub fn is_empty(&self) -> bool {
235 self.commands.is_empty()
236 }
237
238 #[must_use]
240 pub fn total_bytes(&self) -> u64 {
241 self.total_bytes
242 }
243
244 pub fn drain(&mut self) -> Vec<CopyCommand> {
246 self.total_bytes = 0;
247 self.commands.drain(..).collect()
248 }
249
250 #[must_use]
252 pub fn has_hazards(&self) -> bool {
253 let cmds: Vec<_> = self.commands.iter().collect();
254 for i in 0..cmds.len() {
255 for j in (i + 1)..cmds.len() {
256 if cmds[i].aliases_with(cmds[j]) {
257 return true;
258 }
259 }
260 }
261 false
262 }
263
264 #[must_use]
266 pub fn split_independent(mut self) -> Vec<Vec<CopyCommand>> {
267 let all = self.drain();
268 if all.is_empty() {
269 return Vec::new();
270 }
271
272 let mut batches: Vec<Vec<CopyCommand>> = Vec::new();
273
274 for cmd in all {
275 let mut placed = false;
276 for batch in &mut batches {
277 let conflicts = batch.iter().any(|existing| existing.aliases_with(&cmd));
278 if !conflicts {
279 batch.push(cmd.clone());
280 placed = true;
281 break;
282 }
283 }
284 if !placed {
285 batches.push(vec![cmd]);
286 }
287 }
288
289 batches
290 }
291}
292
293#[derive(Debug, Clone, Default, PartialEq)]
295pub struct CopyStats {
296 pub copy_count: u64,
298 pub total_bytes: u64,
300 pub h2d_count: u64,
302 pub d2h_count: u64,
304 pub d2d_count: u64,
306}
307
308impl CopyStats {
309 #[must_use]
311 pub fn new() -> Self {
312 Self::default()
313 }
314
315 pub fn record(&mut self, cmd: &CopyCommand) {
317 self.copy_count += 1;
318 self.total_bytes += cmd.src_region.size;
319 match cmd.direction {
320 CopyDirection::HostToDevice => self.h2d_count += 1,
321 CopyDirection::DeviceToHost => self.d2h_count += 1,
322 CopyDirection::DeviceToDevice | CopyDirection::PeerToPeer => self.d2d_count += 1,
323 }
324 }
325
326 pub fn reset(&mut self) {
328 *self = Self::default();
329 }
330}
331
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a device-to-device copy command.
    fn d2d(src_id: u64, dst_id: u64, src_region: BufferRegion, dst_offset: u64) -> CopyCommand {
        CopyCommand::new(
            src_id,
            dst_id,
            src_region,
            dst_offset,
            CopyDirection::DeviceToDevice,
        )
    }

    /// Builds a batch containing `cmds` in the given order.
    fn batch_of(cmds: Vec<CopyCommand>) -> CopyBatch {
        let mut batch = CopyBatch::new();
        for cmd in cmds {
            batch.push(cmd);
        }
        batch
    }

    // --- BufferRegion ---------------------------------------------------

    #[test]
    fn test_buffer_region_basic() {
        let region = BufferRegion::new(10, 20);
        assert_eq!(region.offset, 10);
        assert_eq!(region.size, 20);
        assert_eq!(region.end(), 30);
    }

    #[test]
    fn test_buffer_region_from_start() {
        let region = BufferRegion::from_start(100);
        assert_eq!(region.offset, 0);
        assert_eq!(region.size, 100);
    }

    #[test]
    fn test_buffer_region_overlaps() {
        let first = BufferRegion::new(0, 10);
        let second = BufferRegion::new(5, 10);
        // Overlap must be reported from either side.
        assert!(first.overlaps(&second));
        assert!(second.overlaps(&first));
    }

    #[test]
    fn test_buffer_region_no_overlap() {
        // Regions that only touch at offset 10 do not overlap.
        let first = BufferRegion::new(0, 10);
        let second = BufferRegion::new(10, 10);
        assert!(!first.overlaps(&second));
    }

    #[test]
    fn test_buffer_region_contained() {
        let outer = BufferRegion::new(0, 20);
        let inner = BufferRegion::new(5, 5);
        assert!(inner.contained_in(&outer));
        assert!(!outer.contained_in(&inner));
    }

    #[test]
    fn test_buffer_region_intersection() {
        let first = BufferRegion::new(0, 10);
        let second = BufferRegion::new(5, 10);
        let shared = first
            .intersection(&second)
            .expect("intersection should succeed");
        assert_eq!(shared.offset, 5);
        assert_eq!(shared.size, 5);
    }

    #[test]
    fn test_buffer_region_no_intersection() {
        let first = BufferRegion::new(0, 5);
        let second = BufferRegion::new(10, 5);
        assert!(first.intersection(&second).is_none());
    }

    // --- ImageRegion ----------------------------------------------------

    #[test]
    fn test_image_region_basic() {
        let region = ImageRegion::new(10, 20, 100, 50);
        assert_eq!(region.pixel_count(), 5000);
    }

    #[test]
    fn test_image_region_contains_point() {
        let region = ImageRegion::from_size(100, 100);
        assert!(region.contains_point(50, 50));
        // The far corner is exclusive, the origin is inclusive.
        assert!(!region.contains_point(100, 100));
        assert!(region.contains_point(0, 0));
    }

    // --- CopyDirection --------------------------------------------------

    #[test]
    fn test_copy_direction_display() {
        assert_eq!(format!("{}", CopyDirection::HostToDevice), "Host -> Device");
        assert_eq!(format!("{}", CopyDirection::DeviceToHost), "Device -> Host");
    }

    // --- CopyCommand ----------------------------------------------------

    #[test]
    fn test_copy_command_dst_region() {
        let cmd = d2d(1, 2, BufferRegion::new(0, 1024), 512);
        let written = cmd.dst_region();
        assert_eq!(written.offset, 512);
        assert_eq!(written.size, 1024);
    }

    #[test]
    fn test_copy_command_aliases() {
        // Both commands write overlapping ranges of resource 2.
        let first = d2d(1, 2, BufferRegion::new(0, 100), 0);
        let second = d2d(3, 2, BufferRegion::new(0, 100), 50);
        assert!(first.aliases_with(&second));
    }

    #[test]
    fn test_copy_command_no_alias() {
        // The commands share no resource at all.
        let first = d2d(1, 2, BufferRegion::new(0, 100), 0);
        let second = d2d(3, 4, BufferRegion::new(0, 100), 0);
        assert!(!first.aliases_with(&second));
    }

    // --- CopyBatch ------------------------------------------------------

    #[test]
    fn test_copy_batch_push_and_drain() {
        let mut batch = CopyBatch::new();
        assert!(batch.is_empty());

        let upload = CopyCommand::new(
            1,
            2,
            BufferRegion::from_start(256),
            0,
            CopyDirection::HostToDevice,
        );
        let download = CopyCommand::new(
            3,
            4,
            BufferRegion::from_start(512),
            0,
            CopyDirection::DeviceToHost,
        );
        batch.push(upload);
        batch.push(download);

        assert_eq!(batch.len(), 2);
        assert_eq!(batch.total_bytes(), 768);

        let drained = batch.drain();
        assert_eq!(drained.len(), 2);
        assert!(batch.is_empty());
        assert_eq!(batch.total_bytes(), 0);
    }

    #[test]
    fn test_copy_batch_no_hazards() {
        let batch = batch_of(vec![
            d2d(1, 2, BufferRegion::from_start(100), 0),
            d2d(3, 4, BufferRegion::from_start(100), 0),
        ]);
        assert!(!batch.has_hazards());
    }

    #[test]
    fn test_copy_batch_with_hazards() {
        let batch = batch_of(vec![
            d2d(1, 2, BufferRegion::from_start(100), 0),
            d2d(3, 2, BufferRegion::from_start(100), 50),
        ]);
        assert!(batch.has_hazards());
    }

    #[test]
    fn test_copy_batch_split_independent() {
        // The first two commands write overlapping ranges of resource 2; the
        // third command touches unrelated resources.
        let batch = batch_of(vec![
            d2d(1, 2, BufferRegion::from_start(100), 0),
            d2d(3, 2, BufferRegion::from_start(100), 50),
            d2d(5, 6, BufferRegion::from_start(100), 0),
        ]);

        let groups = batch.split_independent();
        assert!(groups.len() >= 2);
    }

    // --- CopyStats ------------------------------------------------------

    #[test]
    fn test_copy_stats() {
        let upload = CopyCommand::new(
            1,
            2,
            BufferRegion::from_start(1024),
            0,
            CopyDirection::HostToDevice,
        );

        let mut stats = CopyStats::new();
        stats.record(&upload);
        assert_eq!(stats.copy_count, 1);
        assert_eq!(stats.total_bytes, 1024);
        assert_eq!(stats.h2d_count, 1);
        assert_eq!(stats.d2h_count, 0);

        stats.reset();
        assert_eq!(stats.copy_count, 0);
    }

    #[test]
    fn test_copy_stats_multiple_directions() {
        let recorded = [
            CopyCommand::new(
                1,
                2,
                BufferRegion::from_start(100),
                0,
                CopyDirection::HostToDevice,
            ),
            CopyCommand::new(
                2,
                1,
                BufferRegion::from_start(200),
                0,
                CopyDirection::DeviceToHost,
            ),
            CopyCommand::new(
                2,
                3,
                BufferRegion::from_start(300),
                0,
                CopyDirection::DeviceToDevice,
            ),
        ];

        let mut stats = CopyStats::new();
        for cmd in &recorded {
            stats.record(cmd);
        }

        assert_eq!(stats.copy_count, 3);
        assert_eq!(stats.total_bytes, 600);
        assert_eq!(stats.h2d_count, 1);
        assert_eq!(stats.d2h_count, 1);
        assert_eq!(stats.d2d_count, 1);
    }
}