//! Memory bandwidth optimization for state-vector simulation: buffer
//! pooling, cache-aware layout recommendations, software prefetching,
//! and coalesced memory access patterns.

use crate::error::{QuantRS2Error, QuantRS2Result};
use crate::platform::PlatformCapabilities;
use scirs2_core::Complex64;
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, RwLock};
use std::time::{Duration, Instant};

/// Configuration for memory bandwidth optimization.
#[derive(Debug, Clone)]
pub struct MemoryBandwidthConfig {
    /// Enable software prefetching of soon-to-be-used amplitudes.
    pub enable_prefetching: bool,
    /// How many amplitude pairs ahead to prefetch.
    pub prefetch_distance: usize,
    /// Enable coalescing (sorted, grouped ordering) of memory accesses.
    pub enable_coalescing: bool,
    /// Width of a coalesced access window, in bytes.
    pub coalescing_width: usize,
    /// Enable reuse of allocations through the buffer pool.
    pub enable_buffer_pooling: bool,
    /// Maximum number of bytes the buffer pool may retain.
    pub max_pool_size: usize,
    /// Enable cache-aware memory layout recommendations.
    pub enable_cache_aware_layout: bool,
    /// Cache line size in bytes.
    pub cache_line_size: usize,
}

impl Default for MemoryBandwidthConfig {
    fn default() -> Self {
        let capabilities = PlatformCapabilities::detect();
        // Fall back to 64 bytes when the platform probe cannot report a size.
        let cache_line_size = capabilities.cpu.cache.line_size.unwrap_or(64);

        Self {
            enable_prefetching: true,
            prefetch_distance: 8,
            enable_coalescing: true,
            coalescing_width: 128,
            enable_buffer_pooling: true,
            max_pool_size: 512 * 1024 * 1024, // 512 MiB
            enable_cache_aware_layout: true,
            cache_line_size,
        }
    }
}

/// Runtime metrics describing host/device transfer behavior.
#[derive(Debug, Clone, Default)]
pub struct MemoryBandwidthMetrics {
    /// Total bytes transferred to the device.
    pub bytes_to_device: usize,
    /// Total bytes transferred from the device.
    pub bytes_from_device: usize,
    /// Number of recorded transfers.
    pub transfer_count: usize,
    /// Cumulative wall-clock time spent in transfers.
    pub total_transfer_time: Duration,
    /// Average achieved bandwidth, in gigabits per second.
    pub average_bandwidth_gbps: f64,
    /// Estimated cache hit rate (0.0..=1.0).
    pub cache_hit_rate: f64,
    /// Estimated memory utilization (0.0..=1.0).
    pub memory_utilization: f64,
    /// Fraction of accesses that were coalesced (0.0..=1.0).
    pub coalescing_efficiency: f64,
}

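/// A pool of reusable `Vec<Complex64>` buffers, keyed by cache-line
/// aligned length, that avoids repeated allocation of large scratch
/// vectors.
///
/// A minimal usage sketch (illustrative only; not compiled as a doc
/// test):
///
/// ```ignore
/// let pool = MemoryBufferPool::new(MemoryBandwidthConfig::default());
/// let buf = pool.acquire(1 << 10); // zero-initialized, cache-line aligned
/// // ... use `buf` as scratch space ...
/// pool.release(buf); // released buffers are reused by later acquires
/// ```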
pub struct MemoryBufferPool {
    /// Free buffers, keyed by aligned element count.
    free_buffers: RwLock<HashMap<usize, Vec<Vec<Complex64>>>>,
    /// Total bytes currently allocated through this pool.
    allocated_bytes: AtomicUsize,
    /// Pool configuration.
    config: MemoryBandwidthConfig,
    /// Number of acquisitions served from the pool.
    pool_hits: AtomicUsize,
    /// Number of acquisitions that required a fresh allocation.
    pool_misses: AtomicUsize,
}

impl MemoryBufferPool {
    /// Creates an empty pool with the given configuration.
    pub fn new(config: MemoryBandwidthConfig) -> Self {
        Self {
            free_buffers: RwLock::new(HashMap::new()),
            allocated_bytes: AtomicUsize::new(0),
            config,
            pool_hits: AtomicUsize::new(0),
            pool_misses: AtomicUsize::new(0),
        }
    }

    /// Acquires a zero-initialized buffer of at least `size` elements,
    /// rounded up to a whole number of cache lines. Reuses a pooled
    /// buffer when one of the right size is available.
    pub fn acquire(&self, size: usize) -> Vec<Complex64> {
        let aligned_size = self.align_to_cache_line(size);

        // Fast path: hand out a previously released buffer of this size.
        if let Ok(mut buffers) = self.free_buffers.write() {
            if let Some(buffer_list) = buffers.get_mut(&aligned_size) {
                if let Some(buffer) = buffer_list.pop() {
                    self.pool_hits.fetch_add(1, Ordering::Relaxed);
                    return buffer;
                }
            }
        }

        // Slow path: allocate fresh and account for the new bytes.
        self.pool_misses.fetch_add(1, Ordering::Relaxed);
        let buffer_bytes = aligned_size * std::mem::size_of::<Complex64>();
        self.allocated_bytes
            .fetch_add(buffer_bytes, Ordering::Relaxed);

        vec![Complex64::new(0.0, 0.0); aligned_size]
    }

    /// Returns a buffer to the pool (zeroed so later acquires see fresh
    /// memory), or drops it when pooling is disabled or the pool is at
    /// capacity.
    pub fn release(&self, mut buffer: Vec<Complex64>) {
        let size = buffer.len();
        let buffer_bytes = size * std::mem::size_of::<Complex64>();

        if self.config.enable_buffer_pooling
            && self.allocated_bytes.load(Ordering::Relaxed) <= self.config.max_pool_size
        {
            // Zero the buffer so acquire() can return it as if freshly allocated.
            for elem in &mut buffer {
                *elem = Complex64::new(0.0, 0.0);
            }

            if let Ok(mut buffers) = self.free_buffers.write() {
                buffers.entry(size).or_default().push(buffer);
            }
        } else {
            // The buffer is dropped: remove its bytes from the accounting.
            self.allocated_bytes
                .fetch_sub(buffer_bytes, Ordering::Relaxed);
        }
    }

    /// Rounds `size` up to a whole number of cache lines, measured in
    /// `Complex64` elements.
    fn align_to_cache_line(&self, size: usize) -> usize {
        let elem_size = std::mem::size_of::<Complex64>();
        // Guard against configured cache lines smaller than one element.
        let elems_per_line = (self.config.cache_line_size / elem_size).max(1);
        ((size + elems_per_line - 1) / elems_per_line) * elems_per_line
    }

    /// Returns a snapshot of pool hit/miss statistics.
    pub fn get_statistics(&self) -> PoolStatistics {
        let hits = self.pool_hits.load(Ordering::Relaxed);
        let misses = self.pool_misses.load(Ordering::Relaxed);
        let total = hits + misses;

        PoolStatistics {
            allocated_bytes: self.allocated_bytes.load(Ordering::Relaxed),
            pool_hit_rate: if total > 0 {
                hits as f64 / total as f64
            } else {
                0.0
            },
            total_acquisitions: total,
        }
    }

    /// Drops all pooled buffers and deducts their bytes from the
    /// allocation counter.
    pub fn clear(&self) {
        if let Ok(mut buffers) = self.free_buffers.write() {
            for (size, buffer_list) in buffers.drain() {
                let freed_bytes = size * std::mem::size_of::<Complex64>() * buffer_list.len();
                self.allocated_bytes
                    .fetch_sub(freed_bytes, Ordering::Relaxed);
            }
        }
    }
}

/// Snapshot of buffer pool statistics.
#[derive(Debug, Clone)]
pub struct PoolStatistics {
    /// Bytes currently allocated through the pool.
    pub allocated_bytes: usize,
    /// Fraction of acquisitions served from the pool (0.0..=1.0).
    pub pool_hit_rate: f64,
    /// Total number of acquisitions observed.
    pub total_acquisitions: usize,
}

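/// Coordinates the buffer pool, transfer metrics, and access-pattern
/// optimizations behind a single facade.
///
/// A minimal sketch of the intended flow (illustrative only; not
/// compiled as a doc test):
///
/// ```ignore
/// let optimizer = MemoryBandwidthOptimizer::new(MemoryBandwidthConfig::default());
/// let scratch = optimizer.acquire_buffer(1 << 12);
/// // ... perform a transfer, then record it (16 bytes per Complex64) ...
/// optimizer.record_transfer(scratch.len() * 16, true, Duration::from_micros(50));
/// optimizer.release_buffer(scratch);
/// println!("{:?}", optimizer.get_metrics());
/// ```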
pub struct MemoryBandwidthOptimizer {
    /// Active configuration.
    config: MemoryBandwidthConfig,
    /// Shared buffer pool.
    buffer_pool: Arc<MemoryBufferPool>,
    /// Accumulated transfer metrics.
    metrics: RwLock<MemoryBandwidthMetrics>,
}

impl MemoryBandwidthOptimizer {
    /// Creates an optimizer (and its buffer pool) from a configuration.
    pub fn new(config: MemoryBandwidthConfig) -> Self {
        let buffer_pool = Arc::new(MemoryBufferPool::new(config.clone()));

        Self {
            config,
            buffer_pool,
            metrics: RwLock::new(MemoryBandwidthMetrics::default()),
        }
    }

    /// Recommends a memory layout for an `n_qubits` state vector of
    /// `2^n_qubits` amplitudes, switching to a tiled layout at 10 or
    /// more qubits.
    pub fn get_optimal_layout(&self, n_qubits: usize) -> MemoryLayout {
        let state_size = 1 << n_qubits;
        let elem_size = std::mem::size_of::<Complex64>();
        let total_bytes = state_size * elem_size;

        let elems_per_line = self.config.cache_line_size / elem_size;

        MemoryLayout {
            total_elements: state_size,
            total_bytes,
            cache_line_elements: elems_per_line,
            recommended_alignment: self.config.cache_line_size,
            use_tiled_layout: n_qubits >= 10,
            tile_size: if n_qubits >= 10 { 256 } else { 0 },
        }
    }

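    /// Applies `operation` to the elements named by `access_pattern`,
    /// visiting indices in sorted order (grouped into windows of
    /// `coalescing_width` bytes) so that neighboring accesses land on
    /// the same cache lines.
    ///
    /// Note: when coalescing is enabled the visit order is the *sorted*
    /// index order, so `operation` must not depend on the order of the
    /// original pattern. A small sketch (illustrative only; not
    /// compiled as a doc test):
    ///
    /// ```ignore
    /// let mut state = vec![Complex64::new(0.0, 0.0); 64];
    /// optimizer.optimize_coalesced_access(&mut state, &[40, 3, 17], |amp, idx| {
    ///     *amp = Complex64::new(idx as f64, 0.0);
    ///     Ok(())
    /// })?;
    /// ```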
    pub fn optimize_coalesced_access<F>(
        &self,
        data: &mut [Complex64],
        access_pattern: &[usize],
        operation: F,
    ) -> QuantRS2Result<()>
    where
        F: Fn(&mut Complex64, usize) -> QuantRS2Result<()>,
    {
        if !self.config.enable_coalescing {
            // Coalescing disabled: honor the caller's original order.
            for &idx in access_pattern {
                if idx >= data.len() {
                    return Err(QuantRS2Error::InvalidInput(
                        "Index out of bounds".to_string(),
                    ));
                }
                operation(&mut data[idx], idx)?;
            }
            return Ok(());
        }

        // Sort the indices so nearby elements are touched together.
        let mut sorted_indices: Vec<_> = access_pattern.to_vec();
        sorted_indices.sort_unstable();

        // Guard against a configured width smaller than one element.
        let coalescing_elements =
            (self.config.coalescing_width / std::mem::size_of::<Complex64>()).max(1);

        for chunk in sorted_indices.chunks(coalescing_elements) {
            for &idx in chunk {
                if idx >= data.len() {
                    return Err(QuantRS2Error::InvalidInput(
                        "Index out of bounds".to_string(),
                    ));
                }
                operation(&mut data[idx], idx)?;
            }
        }

        Ok(())
    }

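    /// Issues prefetch hints for the first few amplitude pairs touched
    /// by a single-qubit gate on `qubit`. For pair `i`, the lower index
    /// is formed by inserting a 0 bit at position `qubit`
    /// (`idx0 = ((i >> qubit) << (qubit + 1)) | (i & (mask - 1))`) and
    /// the upper index sets that bit (`idx1 = idx0 | mask`).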
    pub fn prefetch_for_gate_application(
        &self,
        state: &[Complex64],
        qubit: usize,
        n_qubits: usize,
    ) {
        if !self.config.enable_prefetching {
            return;
        }

        let state_size = 1usize << n_qubits;
        let qubit_mask = 1usize << qubit;

        for i in 0..(state_size / 2).min(self.config.prefetch_distance * 2) {
            // Insert a 0 bit at position `qubit` to get the lower index of pair `i`.
            let idx0 = ((i >> qubit) << (qubit + 1)) | (i & (qubit_mask - 1));
            let idx1 = idx0 | qubit_mask;

            if idx0 < state.len() && idx1 < state.len() {
                #[cfg(target_arch = "x86_64")]
                unsafe {
                    use std::arch::x86_64::{_mm_prefetch, _MM_HINT_T0};
                    // Request both amplitudes into all cache levels (T0 hint).
                    let ptr0 = state.as_ptr().add(idx0);
                    let ptr1 = state.as_ptr().add(idx1);
                    _mm_prefetch::<_MM_HINT_T0>(ptr0 as *const i8);
                    _mm_prefetch::<_MM_HINT_T0>(ptr1 as *const i8);
                }

                #[cfg(target_arch = "aarch64")]
                {
                    // No stable prefetch intrinsic here: touch the values through
                    // `black_box` so the reads are not optimized away.
                    std::hint::black_box((state[idx0], state[idx1]));
                }
            }
        }
    }

    /// Acquires a scratch buffer from the shared pool.
    pub fn acquire_buffer(&self, size: usize) -> Vec<Complex64> {
        self.buffer_pool.acquire(size)
    }

    /// Returns a scratch buffer to the shared pool.
    pub fn release_buffer(&self, buffer: Vec<Complex64>) {
        self.buffer_pool.release(buffer);
    }

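    /// Records a host/device transfer and refreshes the running
    /// bandwidth estimate. `bytes` is the payload size and `to_device`
    /// selects the direction. Bandwidth is reported in gigabits per
    /// second (bytes * 8 / seconds / 1e9).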
    pub fn record_transfer(&self, bytes: usize, to_device: bool, duration: Duration) {
        if let Ok(mut metrics) = self.metrics.write() {
            if to_device {
                metrics.bytes_to_device += bytes;
            } else {
                metrics.bytes_from_device += bytes;
            }
            metrics.transfer_count += 1;
            metrics.total_transfer_time += duration;

            let total_bytes = metrics.bytes_to_device + metrics.bytes_from_device;
            let total_secs = metrics.total_transfer_time.as_secs_f64();
            if total_secs > 0.0 {
                // Convert bytes to bits to match the `gbps` (gigabits/s) field.
                metrics.average_bandwidth_gbps = (total_bytes as f64 * 8.0) / total_secs / 1e9;
            }
        }
    }

    /// Returns a snapshot of the transfer metrics.
    pub fn get_metrics(&self) -> MemoryBandwidthMetrics {
        // Fall back to defaults instead of panicking on a poisoned lock.
        self.metrics
            .read()
            .map(|m| m.clone())
            .unwrap_or_default()
    }

    /// Returns buffer pool hit/miss statistics.
    pub fn get_pool_statistics(&self) -> PoolStatistics {
        self.buffer_pool.get_statistics()
    }

    /// Drops all pooled buffers.
    pub fn clear_pool(&self) {
        self.buffer_pool.clear();
    }

    /// Produces human-readable tuning suggestions based on the observed
    /// metrics and pool statistics.
    pub fn get_optimization_recommendations(&self) -> Vec<String> {
        let metrics = self.get_metrics();
        let pool_stats = self.get_pool_statistics();
        let mut recommendations = Vec::new();

        if metrics.average_bandwidth_gbps < 10.0 && metrics.transfer_count > 100 {
            recommendations.push(
                "Consider batching memory transfers to improve bandwidth utilization".to_string(),
            );
        }

        if pool_stats.pool_hit_rate < 0.5 && pool_stats.total_acquisitions > 100 {
            recommendations.push(format!(
                "Pool hit rate is {:.1}%. Consider increasing the pool size for better reuse",
                pool_stats.pool_hit_rate * 100.0
            ));
        }

        if metrics.coalescing_efficiency < 0.7 {
            recommendations.push(
                "Memory access pattern has low coalescing efficiency. Consider reordering accesses"
                    .to_string(),
            );
        }

        if metrics.cache_hit_rate < 0.8 && metrics.transfer_count > 50 {
            recommendations.push(
                "Cache hit rate is low. Consider using cache-aware memory layouts".to_string(),
            );
        }

        if recommendations.is_empty() {
            recommendations.push("Memory bandwidth utilization is optimal".to_string());
        }

        recommendations
    }
}

/// Recommended memory layout for a state vector.
#[derive(Debug, Clone)]
pub struct MemoryLayout {
    /// Number of `Complex64` amplitudes (`2^n_qubits`).
    pub total_elements: usize,
    /// Total size of the state vector in bytes.
    pub total_bytes: usize,
    /// Number of amplitudes per cache line.
    pub cache_line_elements: usize,
    /// Recommended allocation alignment, in bytes.
    pub recommended_alignment: usize,
    /// Whether a tiled layout is recommended for this size.
    pub use_tiled_layout: bool,
    /// Tile size in elements (0 when tiling is not recommended).
    pub tile_size: usize,
}

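/// Streams large state vectors to or from a device in fixed-size
/// chunks, so the host never needs a second full-size staging copy.
///
/// A minimal sketch with a no-op transfer callback (illustrative only;
/// not compiled as a doc test):
///
/// ```ignore
/// let state = vec![Complex64::new(0.0, 0.0); 1 << 16];
/// let pool = Arc::new(MemoryBufferPool::new(MemoryBandwidthConfig::default()));
/// let streamer = StreamingTransfer::new(4096, pool);
/// let elapsed = streamer.stream_to_device(&state, |chunk, offset| {
///     // copy `chunk` to the device starting at element `offset`
///     Ok(())
/// })?;
/// ```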
pub struct StreamingTransfer {
    /// Elements per streamed chunk.
    chunk_size: usize,
    /// Number of in-flight transfers (reserved for double buffering).
    #[allow(dead_code)]
    concurrent_transfers: usize,
    /// Pool for staging buffers (reserved for future use).
    #[allow(dead_code)]
    buffer_pool: Arc<MemoryBufferPool>,
}

impl StreamingTransfer {
    /// Creates a streamer that moves `chunk_size` elements per transfer.
    pub fn new(chunk_size: usize, buffer_pool: Arc<MemoryBufferPool>) -> Self {
        Self {
            chunk_size,
            concurrent_transfers: 2,
            buffer_pool,
        }
    }

    /// Streams `data` to the device chunk by chunk. `transfer_fn`
    /// receives each chunk together with its element offset; the total
    /// elapsed time is returned.
    pub fn stream_to_device<F>(
        &self,
        data: &[Complex64],
        transfer_fn: F,
    ) -> QuantRS2Result<Duration>
    where
        F: Fn(&[Complex64], usize) -> QuantRS2Result<()>,
    {
        let start = Instant::now();
        let mut offset = 0;

        while offset < data.len() {
            let chunk_end = (offset + self.chunk_size).min(data.len());
            let chunk = &data[offset..chunk_end];

            transfer_fn(chunk, offset)?;
            offset = chunk_end;
        }

        Ok(start.elapsed())
    }

    /// Streams data from the device into `data` chunk by chunk,
    /// mirroring `stream_to_device`.
    pub fn stream_from_device<F>(
        &self,
        data: &mut [Complex64],
        transfer_fn: F,
    ) -> QuantRS2Result<Duration>
    where
        F: Fn(&mut [Complex64], usize) -> QuantRS2Result<()>,
    {
        let start = Instant::now();
        let mut offset = 0;

        while offset < data.len() {
            let chunk_end = (offset + self.chunk_size).min(data.len());
            let chunk = &mut data[offset..chunk_end];

            transfer_fn(chunk, offset)?;
            offset = chunk_end;
        }

        Ok(start.elapsed())
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_bandwidth_config_default() {
        let config = MemoryBandwidthConfig::default();
        assert!(config.enable_prefetching);
        assert!(config.enable_coalescing);
        assert!(config.enable_buffer_pooling);
        assert!(config.cache_line_size > 0);
    }

    #[test]
    fn test_buffer_pool_acquire_release() {
        let config = MemoryBandwidthConfig::default();
        let pool = MemoryBufferPool::new(config);

        let buffer = pool.acquire(100);
        assert!(buffer.len() >= 100);

        let size = buffer.len();
        pool.release(buffer);

        let buffer2 = pool.acquire(100);
        assert_eq!(buffer2.len(), size);

        let stats = pool.get_statistics();
        assert!(stats.pool_hit_rate > 0.0);
    }
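
    // Added check (based on the cache-line rounding in `align_to_cache_line`):
    // acquired buffers hold a whole number of cache lines' worth of elements.
    #[test]
    fn test_acquire_is_cache_line_aligned() {
        let config = MemoryBandwidthConfig::default();
        let elems_per_line =
            (config.cache_line_size / std::mem::size_of::<Complex64>()).max(1);
        let pool = MemoryBufferPool::new(config);

        let buffer = pool.acquire(100);
        assert_eq!(buffer.len() % elems_per_line, 0);
        assert!(buffer.len() >= 100);
    }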

    #[test]
    fn test_memory_layout_computation() {
        let config = MemoryBandwidthConfig::default();
        let optimizer = MemoryBandwidthOptimizer::new(config);

        let layout = optimizer.get_optimal_layout(4);
        assert_eq!(layout.total_elements, 16);
        assert!(!layout.use_tiled_layout);

        let layout_large = optimizer.get_optimal_layout(12);
        assert_eq!(layout_large.total_elements, 4096);
        assert!(layout_large.use_tiled_layout);
    }

    #[test]
    fn test_coalesced_access_optimization() {
        let config = MemoryBandwidthConfig::default();
        let optimizer = MemoryBandwidthOptimizer::new(config);

        let mut data = vec![Complex64::new(0.0, 0.0); 100];
        let pattern = vec![50, 10, 30, 70, 90];

        let result = optimizer.optimize_coalesced_access(&mut data, &pattern, |elem, idx| {
            *elem = Complex64::new(idx as f64, 0.0);
            Ok(())
        });

        assert!(result.is_ok());
        assert_eq!(data[10], Complex64::new(10.0, 0.0));
        assert_eq!(data[50], Complex64::new(50.0, 0.0));
    }
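
    // Added check: out-of-bounds indices in the access pattern are
    // rejected with an error rather than causing a panic.
    #[test]
    fn test_coalesced_access_rejects_out_of_bounds() {
        let config = MemoryBandwidthConfig::default();
        let optimizer = MemoryBandwidthOptimizer::new(config);

        let mut data = vec![Complex64::new(0.0, 0.0); 10];
        let pattern = vec![5, 99];

        let result = optimizer.optimize_coalesced_access(&mut data, &pattern, |_, _| Ok(()));
        assert!(result.is_err());
    }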

    #[test]
    fn test_transfer_metrics_recording() {
        let config = MemoryBandwidthConfig::default();
        let optimizer = MemoryBandwidthOptimizer::new(config);

        optimizer.record_transfer(1024, true, Duration::from_micros(100));
        optimizer.record_transfer(1024, false, Duration::from_micros(100));

        let metrics = optimizer.get_metrics();
        assert_eq!(metrics.bytes_to_device, 1024);
        assert_eq!(metrics.bytes_from_device, 1024);
        assert_eq!(metrics.transfer_count, 2);
    }

    #[test]
    fn test_optimization_recommendations() {
        let config = MemoryBandwidthConfig::default();
        let optimizer = MemoryBandwidthOptimizer::new(config);

        let recommendations = optimizer.get_optimization_recommendations();
        assert!(!recommendations.is_empty());
    }

    #[test]
    fn test_streaming_transfer() {
        let config = MemoryBandwidthConfig::default();
        let pool = Arc::new(MemoryBufferPool::new(config));
        let streamer = StreamingTransfer::new(32, pool);

        let data = vec![Complex64::new(1.0, 0.0); 100];
        let result = streamer.stream_to_device(&data, |_chunk, _offset| Ok(()));
        assert!(result.is_ok());
    }
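
    // Added check: streaming visits the whole input in order, with
    // offsets advancing by the configured chunk size.
    #[test]
    fn test_streaming_transfer_chunking() {
        let config = MemoryBandwidthConfig::default();
        let pool = Arc::new(MemoryBufferPool::new(config));
        let streamer = StreamingTransfer::new(32, pool);

        let data = vec![Complex64::new(1.0, 0.0); 100];
        let seen = std::sync::Mutex::new(Vec::new());
        let result = streamer.stream_to_device(&data, |chunk, offset| {
            seen.lock().unwrap().push((offset, chunk.len()));
            Ok(())
        });

        assert!(result.is_ok());
        assert_eq!(
            *seen.lock().unwrap(),
            vec![(0, 32), (32, 32), (64, 32), (96, 4)]
        );
    }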

    #[test]
    fn test_pool_clear() {
        let config = MemoryBandwidthConfig::default();
        let pool = MemoryBufferPool::new(config);

        for _ in 0..10 {
            let buffer = pool.acquire(100);
            pool.release(buffer);
        }

        pool.clear();

        let stats = pool.get_statistics();
        assert_eq!(stats.allocated_bytes, 0);
    }
}