1use crate::{EnumerativeError, EnumerativeResult};
8use num_rational::Rational64;
9use std::collections::HashMap;
10
11#[cfg(feature = "wasm")]
12use wasm_bindgen::prelude::*;
13
14#[cfg(feature = "wasm")]
15use web_sys::console;
16
17#[cfg(feature = "wasm")]
18use num_traits::ToPrimitive;
19
/// Tuning knobs shared by the batch computers in this module.
#[derive(Debug, Clone)]
pub struct WasmPerformanceConfig {
    /// Use the batched ("SIMD") CPU path instead of the scalar path.
    pub enable_simd: bool,
    /// Prefer the GPU path for intersection batches when a GPU context exists.
    pub enable_gpu: bool,
    /// Size of the curve-counting memory budget, in MiB.
    pub memory_pool_mb: usize,
    /// Number of operations/requests processed per chunk.
    pub batch_size: usize,
    /// Maximum number of workers. NOTE(review): not read by the visible code.
    pub max_workers: usize,
    /// Select the "with workers" curve-batch path.
    pub enable_workers: bool,
    /// Upper bound on the number of memoized intersection results.
    pub cache_size: usize,
}

impl Default for WasmPerformanceConfig {
    /// SIMD and workers on, GPU off, a 64 MiB pool, batches of 1024,
    /// 4 workers, and up to 10 000 cached results.
    fn default() -> Self {
        // GPU support is opt-in.
        let enable_gpu = false;
        Self {
            enable_simd: true,
            enable_gpu,
            memory_pool_mb: 64,
            batch_size: 1 << 10,
            max_workers: 4,
            enable_workers: true,
            cache_size: 10_000,
        }
    }
}
52
53#[derive(Debug)]
55pub struct FastIntersectionComputer {
56 config: WasmPerformanceConfig,
58 cache: HashMap<String, Rational64>,
60 coefficient_buffer: Vec<f64>,
62 #[cfg(feature = "wgpu")]
64 gpu_context: Option<GpuContext>,
65}
66
67impl FastIntersectionComputer {
68 pub fn new(config: WasmPerformanceConfig) -> Self {
70 let cache_capacity = config.cache_size;
71 let buffer_size = config.batch_size * 8; Self {
74 config,
75 cache: HashMap::with_capacity(cache_capacity),
76 coefficient_buffer: vec![0.0; buffer_size],
77 #[cfg(feature = "wgpu")]
78 gpu_context: None,
79 }
80 }
81
82 #[cfg(feature = "wgpu")]
84 pub async fn init_gpu(&mut self) -> EnumerativeResult<()> {
85 self.gpu_context = Some(GpuContext::new().await?);
86 Ok(())
87 }
88
89 pub fn fast_intersection_batch(
91 &mut self,
92 operations: &[(i64, i64, i64)],
93 ) -> EnumerativeResult<Vec<Rational64>> {
94 if operations.is_empty() {
95 return Ok(Vec::new());
96 }
97
98 let mut results = Vec::with_capacity(operations.len());
100 let mut uncached_ops = Vec::new();
101 let mut uncached_indices = Vec::new();
102
103 for (i, &(deg1, deg2, dim)) in operations.iter().enumerate() {
104 let cache_key = format!("{}:{}:{}", deg1, deg2, dim);
105 if let Some(&cached_result) = self.cache.get(&cache_key) {
106 results.push(cached_result);
107 } else {
108 results.push(Rational64::from(0)); uncached_ops.push((deg1, deg2, dim));
110 uncached_indices.push(i);
111 }
112 }
113
114 if uncached_ops.is_empty() {
115 return Ok(results);
116 }
117
118 let computed_results = if self.config.enable_gpu {
120 #[cfg(feature = "wgpu")]
121 {
122 if let Some(ref gpu) = self.gpu_context {
123 self.gpu_compute_batch(gpu, &uncached_ops)?
124 } else {
125 self.simd_compute_batch(&uncached_ops)?
126 }
127 }
128 #[cfg(not(feature = "wgpu"))]
129 {
130 self.simd_compute_batch(&uncached_ops)?
131 }
132 } else {
133 self.simd_compute_batch(&uncached_ops)?
134 };
135
136 for (i, &result) in computed_results.iter().enumerate() {
138 let result_idx = uncached_indices[i];
139 results[result_idx] = result;
140
141 let (deg1, deg2, dim) = uncached_ops[i];
142 let cache_key = format!("{}:{}:{}", deg1, deg2, dim);
143 if self.cache.len() < self.config.cache_size {
144 self.cache.insert(cache_key, result);
145 }
146 }
147
148 Ok(results)
149 }
150
151 fn simd_compute_batch(
153 &mut self,
154 operations: &[(i64, i64, i64)],
155 ) -> EnumerativeResult<Vec<Rational64>> {
156 let batch_size = self.config.batch_size.min(operations.len());
157 let mut results = Vec::with_capacity(operations.len());
158
159 for chunk in operations.chunks(batch_size) {
160 let chunk_results = if self.config.enable_simd {
161 self.simd_intersection_chunk(chunk)?
162 } else {
163 self.scalar_intersection_chunk(chunk)?
164 };
165 results.extend(chunk_results);
166 }
167
168 Ok(results)
169 }
170
171 fn simd_intersection_chunk(
173 &mut self,
174 chunk: &[(i64, i64, i64)],
175 ) -> EnumerativeResult<Vec<Rational64>> {
176 self.coefficient_buffer.clear();
178 self.coefficient_buffer.resize(chunk.len() * 8, 0.0);
179
180 for (i, &(deg1, deg2, dim)) in chunk.iter().enumerate() {
182 let base_idx = i * 8;
183
184 self.coefficient_buffer[base_idx] = deg1 as f64;
186 self.coefficient_buffer[base_idx + 1] = deg2 as f64;
187 self.coefficient_buffer[base_idx + 2] = dim as f64;
188
189 self.coefficient_buffer[base_idx + 3] = (deg1 * deg2) as f64;
191 self.coefficient_buffer[base_idx + 4] = if deg1 + deg2 > dim { 0.0 } else { 1.0 };
192
193 self.coefficient_buffer[base_idx + 5] = ((deg1 + deg2) - dim) as f64;
195 self.coefficient_buffer[base_idx + 6] = (deg1.max(deg2)) as f64;
196 self.coefficient_buffer[base_idx + 7] = (deg1.min(deg2)) as f64;
197 }
198
199 let results = self.vectorized_bezout_computation(chunk.len())?;
201
202 Ok(results)
203 }
204
205 fn vectorized_bezout_computation(&self, count: usize) -> EnumerativeResult<Vec<Rational64>> {
207 let mut results = Vec::with_capacity(count);
208
209 for i in 0..count {
210 let base_idx = i * 8;
211 let deg_product = self.coefficient_buffer[base_idx + 3] as i64;
212 let is_valid = self.coefficient_buffer[base_idx + 4] > 0.5;
213
214 let result = if is_valid {
215 Rational64::from(deg_product)
216 } else {
217 Rational64::from(0)
218 };
219
220 results.push(result);
221 }
222
223 Ok(results)
224 }
225
226 fn scalar_intersection_chunk(
228 &self,
229 chunk: &[(i64, i64, i64)],
230 ) -> EnumerativeResult<Vec<Rational64>> {
231 let mut results = Vec::with_capacity(chunk.len());
232
233 for &(deg1, deg2, dim) in chunk {
234 let result = if deg1 + deg2 > dim {
235 Rational64::from(0) } else {
237 Rational64::from(deg1 * deg2) };
239 results.push(result);
240 }
241
242 Ok(results)
243 }
244
245 pub fn clear_cache(&mut self) {
247 self.cache.clear();
248 }
249
250 pub fn cache_stats(&self) -> (usize, usize) {
252 (self.cache.len(), self.config.cache_size)
253 }
254}
255
/// GPU resources for the intersection compute shader (`wgpu` feature only).
#[cfg(feature = "wgpu")]
#[derive(Debug)]
#[allow(dead_code)] // fields are held for a GPU path that does not use them yet
pub struct GpuContext {
    device: wgpu::Device,
    queue: wgpu::Queue,
    compute_pipeline: wgpu::ComputePipeline,
}

#[cfg(feature = "wgpu")]
impl GpuContext {
    /// Sets up the GPU for intersection computations.
    ///
    /// Requests a default adapter and device, compiles
    /// `shaders/intersection.wgsl`, and builds a compute pipeline with entry
    /// point `main`.
    ///
    /// # Errors
    ///
    /// `ComputationError` when no adapter is found or device creation fails.
    pub async fn new() -> EnumerativeResult<Self> {
        let instance = wgpu::Instance::new(wgpu::InstanceDescriptor::default());

        let adapter_options = wgpu::RequestAdapterOptions::default();
        let adapter = match instance.request_adapter(&adapter_options).await {
            Some(found) => found,
            None => {
                return Err(EnumerativeError::ComputationError(
                    "No GPU adapter found".to_string(),
                ))
            }
        };

        let device_descriptor = wgpu::DeviceDescriptor::default();
        let (device, queue) = match adapter.request_device(&device_descriptor, None).await {
            Ok(pair) => pair,
            Err(e) => {
                return Err(EnumerativeError::ComputationError(format!(
                    "GPU device error: {}",
                    e
                )))
            }
        };

        let shader_descriptor = wgpu::ShaderModuleDescriptor {
            label: Some("Intersection Compute Shader"),
            source: wgpu::ShaderSource::Wgsl(include_str!("shaders/intersection.wgsl").into()),
        };
        let shader = device.create_shader_module(shader_descriptor);

        let pipeline_descriptor = wgpu::ComputePipelineDescriptor {
            label: Some("Intersection Pipeline"),
            layout: None,
            module: &shader,
            entry_point: "main",
        };
        let compute_pipeline = device.create_compute_pipeline(&pipeline_descriptor);

        Ok(Self {
            device,
            queue,
            compute_pipeline,
        })
    }
}
303
#[cfg(feature = "wgpu")]
impl FastIntersectionComputer {
    /// GPU batch evaluation of intersection numbers; not implemented yet.
    ///
    /// Always returns `ComputationError`, so callers need a CPU fallback.
    ///
    /// Intended input layout once a kernel dispatch exists: four `f32` lanes
    /// per operation, `[deg1, deg2, dim, 0.0]`. Note that `f32` represents
    /// integers exactly only up to 2^24.
    fn gpu_compute_batch(
        &self,
        _gpu: &GpuContext,
        _operations: &[(i64, i64, i64)],
    ) -> EnumerativeResult<Vec<Rational64>> {
        // No input buffer is packed here. Doing so would be O(n) work that is
        // thrown away, because this path cannot produce results yet.
        Err(EnumerativeError::ComputationError(
            "GPU functionality requires additional dependencies".to_string(),
        ))
    }
}
325
326#[derive(Debug)]
328pub struct SparseSchubertMatrix {
329 entries: Vec<(usize, usize, Rational64)>,
331 rows: usize,
333 cols: usize,
334 row_index: HashMap<usize, Vec<usize>>,
336}
337
338impl SparseSchubertMatrix {
339 pub fn new(rows: usize, cols: usize) -> Self {
341 Self {
342 entries: Vec::new(),
343 rows,
344 cols,
345 row_index: HashMap::new(),
346 }
347 }
348
349 pub fn set(&mut self, row: usize, col: usize, value: Rational64) {
351 if value != Rational64::from(0) {
352 let entry_idx = self.entries.len();
353 self.entries.push((row, col, value));
354 self.row_index.entry(row).or_default().push(entry_idx);
355 }
356 }
357
358 pub fn get(&self, row: usize, col: usize) -> Rational64 {
360 if let Some(indices) = self.row_index.get(&row) {
361 for &idx in indices {
362 let (_, entry_col, value) = self.entries[idx];
363 if entry_col == col {
364 return value;
365 }
366 }
367 }
368 Rational64::from(0)
369 }
370
371 pub fn multiply_vector(&self, vector: &[Rational64]) -> EnumerativeResult<Vec<Rational64>> {
373 if vector.len() != self.cols {
374 return Err(EnumerativeError::InvalidDimension(format!(
375 "Vector length {} != matrix cols {}",
376 vector.len(),
377 self.cols
378 )));
379 }
380
381 let mut result = vec![Rational64::from(0); self.rows];
382
383 for &(row, col, value) in &self.entries {
384 result[row] += value * vector[col];
385 }
386
387 Ok(result)
388 }
389
390 pub fn memory_usage(&self) -> usize {
392 self.entries.len() * std::mem::size_of::<(usize, usize, Rational64)>()
393 + self.row_index.len() * std::mem::size_of::<(usize, Vec<usize>)>()
394 }
395}
396
397#[derive(Debug)]
399pub struct WasmCurveCounting {
400 config: WasmPerformanceConfig,
402 batch_processor: CurveBatchProcessor,
404 memory_pool: MemoryPool,
406}
407
408impl WasmCurveCounting {
409 pub fn new(config: WasmPerformanceConfig) -> Self {
411 let memory_pool = MemoryPool::new(config.memory_pool_mb * 1024 * 1024);
412 let batch_processor = CurveBatchProcessor::new(config.clone());
413
414 Self {
415 config,
416 batch_processor,
417 memory_pool,
418 }
419 }
420
421 pub fn count_curves_batch(
423 &mut self,
424 requests: &[CurveCountRequest],
425 ) -> EnumerativeResult<Vec<i64>> {
426 if requests.is_empty() {
427 return Ok(Vec::new());
428 }
429
430 let _allocation = self.memory_pool.allocate(requests.len() * 64)?;
432
433 let batch_size = self.config.batch_size;
435 let mut results = Vec::with_capacity(requests.len());
436
437 for chunk in requests.chunks(batch_size) {
438 let chunk_results = if self.config.enable_workers {
439 self.batch_processor.process_with_workers(chunk)?
440 } else {
441 self.batch_processor.process_sequential(chunk)?
442 };
443 results.extend(chunk_results);
444 }
445
446 Ok(results)
447 }
448
449 pub fn performance_metrics(&self) -> PerformanceMetrics {
451 PerformanceMetrics {
452 memory_pool_usage: self.memory_pool.usage_percentage(),
453 cache_hit_rate: self.batch_processor.cache_hit_rate(),
454 batch_count: self.batch_processor.batch_count(),
455 worker_utilization: if self.config.enable_workers { 0.8 } else { 1.0 },
456 }
457 }
458}
459
/// One curve-counting query.
#[derive(Debug, Clone)]
pub struct CurveCountRequest {
    /// Name of the target space (the wasm facade uses `"P2"`).
    pub target_space: String,
    /// Curve degree.
    pub degree: i64,
    /// Curve genus.
    pub genus: usize,
    /// Number of constraints. Not read by the batch processor in this module.
    pub constraint_count: usize,
}
468
469#[derive(Debug)]
471pub struct CurveBatchProcessor {
472 #[allow(dead_code)]
473 config: WasmPerformanceConfig,
474 cache_hits: usize,
475 cache_misses: usize,
476 batch_count: usize,
477}
478
479impl CurveBatchProcessor {
480 pub fn new(config: WasmPerformanceConfig) -> Self {
481 Self {
482 config,
483 cache_hits: 0,
484 cache_misses: 0,
485 batch_count: 0,
486 }
487 }
488
489 pub fn process_with_workers(
490 &mut self,
491 requests: &[CurveCountRequest],
492 ) -> EnumerativeResult<Vec<i64>> {
493 self.batch_count += 1;
494 Ok(requests
496 .iter()
497 .map(|req| req.degree * (req.genus as i64 + 1))
498 .collect())
499 }
500
501 pub fn process_sequential(
502 &mut self,
503 requests: &[CurveCountRequest],
504 ) -> EnumerativeResult<Vec<i64>> {
505 self.batch_count += 1;
506 Ok(requests
508 .iter()
509 .map(|req| req.degree * (req.genus as i64 + 1))
510 .collect())
511 }
512
513 pub fn cache_hit_rate(&self) -> f64 {
514 if self.cache_hits + self.cache_misses == 0 {
515 0.0
516 } else {
517 self.cache_hits as f64 / (self.cache_hits + self.cache_misses) as f64
518 }
519 }
520
521 pub fn batch_count(&self) -> usize {
522 self.batch_count
523 }
524}
525
526#[derive(Debug)]
528pub struct MemoryPool {
529 total_size: usize,
530 allocated: usize,
531}
532
533impl MemoryPool {
534 pub fn new(size: usize) -> Self {
535 Self {
536 total_size: size,
537 allocated: 0,
538 }
539 }
540
541 pub fn allocate(&mut self, size: usize) -> EnumerativeResult<MemoryAllocation> {
542 if self.allocated + size > self.total_size {
543 return Err(EnumerativeError::ComputationError(
544 "Memory pool exhausted".to_string(),
545 ));
546 }
547
548 self.allocated += size;
549 Ok(MemoryAllocation { size })
550 }
551
552 pub fn usage_percentage(&self) -> f64 {
553 self.allocated as f64 / self.total_size as f64 * 100.0
554 }
555}
556
557#[derive(Debug)]
559pub struct MemoryAllocation {
560 #[allow(dead_code)]
561 size: usize,
562}
563
564impl Drop for MemoryAllocation {
565 fn drop(&mut self) {
566 }
568}
569
/// Snapshot of curve-counting performance counters.
#[derive(Debug)]
pub struct PerformanceMetrics {
    /// Reserved share of the memory budget, as a percentage (0-100).
    pub memory_pool_usage: f64,
    /// Cache hit fraction (0.0-1.0).
    pub cache_hit_rate: f64,
    /// Number of chunks processed.
    pub batch_count: usize,
    /// Worker utilization as a fraction (0.0-1.0).
    pub worker_utilization: f64,
}
578
/// Writes `message` to the browser console via `console.log` (`wasm` builds).
#[cfg(feature = "wasm")]
pub fn wasm_log(message: &str) {
    console::log_1(&message.into());
}

/// Writes `message` followed by a newline to standard output (non-`wasm` builds).
#[cfg(not(feature = "wasm"))]
pub fn wasm_log(message: &str) {
    println!("{message}");
}
589
590pub fn benchmark_intersection_computation(
592 config: WasmPerformanceConfig,
593 operation_count: usize,
594) -> EnumerativeResult<f64> {
595 let start = std::time::Instant::now();
596
597 let mut computer = FastIntersectionComputer::new(config);
598
599 let operations: Vec<(i64, i64, i64)> = (0..operation_count)
601 .map(|i| ((i % 10 + 1) as i64, ((i + 1) % 10 + 1) as i64, 3))
602 .collect();
603
604 let _results = computer.fast_intersection_batch(&operations)?;
606
607 let duration = start.elapsed();
608 let operations_per_second = operation_count as f64 / duration.as_secs_f64();
609
610 Ok(operations_per_second)
611}
612
/// JavaScript-facing facade bundling curve counting and intersection numbers.
///
/// Both components are built from `WasmPerformanceConfig::default()`.
#[cfg(feature = "wasm")]
#[wasm_bindgen]
pub struct WasmEnumerativeAPI {
    curve_counter: WasmCurveCounting,
    intersection_computer: FastIntersectionComputer,
}

#[cfg(feature = "wasm")]
#[wasm_bindgen]
impl WasmEnumerativeAPI {
    /// Creates the facade with default performance settings.
    #[wasm_bindgen(constructor)]
    pub fn new() -> Self {
        let config = WasmPerformanceConfig::default();
        let curve_counter = WasmCurveCounting::new(config.clone());
        let intersection_computer = FastIntersectionComputer::new(config);
        Self {
            curve_counter,
            intersection_computer,
        }
    }

    /// Counts curves of the given degree and genus in `"P2"` with 3 constraints.
    ///
    /// Errors from the batch call are reported as `0`.
    #[wasm_bindgen]
    pub fn count_curves(&mut self, degree: i64, genus: u32) -> i64 {
        let request = CurveCountRequest {
            target_space: "P2".to_string(),
            degree,
            genus: genus as usize,
            constraint_count: 3,
        };

        match self.curve_counter.count_curves_batch(&[request]) {
            Ok(counts) => counts[0],
            Err(_) => 0,
        }
    }

    /// Intersection number for `(deg1, deg2, dim)` as an `f64`.
    ///
    /// Errors and non-representable values are reported as `0.0`.
    #[wasm_bindgen]
    pub fn intersection_number(&mut self, deg1: i64, deg2: i64, dim: i64) -> f64 {
        let value = match self
            .intersection_computer
            .fast_intersection_batch(&[(deg1, deg2, dim)])
        {
            Ok(values) => values[0],
            Err(_) => Rational64::from(0),
        };
        value.to_f64().unwrap_or(0.0)
    }

    /// Human-readable summary of the curve counter's performance metrics.
    #[wasm_bindgen]
    pub fn performance_summary(&self) -> String {
        let metrics = self.curve_counter.performance_metrics();
        let cache_percent = metrics.cache_hit_rate * 100.0;
        let worker_percent = metrics.worker_utilization * 100.0;
        format!(
            "Memory: {:.1}%, Cache: {:.1}%, Batches: {}, Workers: {:.1}%",
            metrics.memory_pool_usage, cache_percent, metrics.batch_count, worker_percent
        )
    }
}

#[cfg(feature = "wasm")]
impl Default for WasmEnumerativeAPI {
    fn default() -> Self {
        Self::new()
    }
}