1#![allow(missing_docs)]
7
8use crate::{EnumerativeError, EnumerativeResult};
9use num_rational::Rational64;
10use std::collections::HashMap;
11
12#[cfg(feature = "wasm")]
13use wasm_bindgen::prelude::*;
14
15#[cfg(feature = "wasm")]
16use web_sys::console;
17
18#[cfg(feature = "wasm")]
19use num_traits::ToPrimitive;
20
/// Tuning parameters shared by the WASM-oriented engines in this module.
///
/// Defaults (see the `Default` impl below): SIMD and workers on, GPU off, a 64 MiB memory
/// pool, 1024-operation batches, 4 workers and room for 10 000 cached results.
#[derive(Debug, Clone)]
pub struct WasmPerformanceConfig {
    /// Use the chunked "SIMD" kernel instead of the scalar loop for intersection numbers.
    pub enable_simd: bool,
    /// Prefer the GPU path for intersection numbers when a GPU context has been initialised.
    pub enable_gpu: bool,
    /// Size of the curve-counting memory pool, in MiB.
    pub memory_pool_mb: usize,
    /// Maximum number of operations / requests handled per chunk.
    pub batch_size: usize,
    /// Intended upper bound on worker count (not read anywhere in this file as written).
    pub max_workers: usize,
    /// Route curve-count chunks through the batch processor's worker entry point.
    pub enable_workers: bool,
    /// Maximum number of memoised intersection results.
    pub cache_size: usize,
}

impl Default for WasmPerformanceConfig {
    fn default() -> Self {
        // Initialisers are listed alphabetically; `Debug` still prints the declaration order.
        WasmPerformanceConfig {
            batch_size: 1024,
            cache_size: 10_000,
            enable_gpu: false,
            enable_simd: true,
            enable_workers: true,
            max_workers: 4,
            memory_pool_mb: 64,
        }
    }
}
53
54#[derive(Debug)]
56pub struct FastIntersectionComputer {
57 config: WasmPerformanceConfig,
59 cache: HashMap<String, Rational64>,
61 coefficient_buffer: Vec<f64>,
63 #[cfg(feature = "wgpu")]
65 gpu_context: Option<GpuContext>,
66}
67
68impl FastIntersectionComputer {
69 pub fn new(config: WasmPerformanceConfig) -> Self {
71 let cache_capacity = config.cache_size;
72 let buffer_size = config.batch_size * 8; Self {
75 config,
76 cache: HashMap::with_capacity(cache_capacity),
77 coefficient_buffer: vec![0.0; buffer_size],
78 #[cfg(feature = "wgpu")]
79 gpu_context: None,
80 }
81 }
82
83 #[cfg(feature = "wgpu")]
85 pub async fn init_gpu(&mut self) -> EnumerativeResult<()> {
86 self.gpu_context = Some(GpuContext::new().await?);
87 Ok(())
88 }
89
90 pub fn fast_intersection_batch(
92 &mut self,
93 operations: &[(i64, i64, i64)],
94 ) -> EnumerativeResult<Vec<Rational64>> {
95 if operations.is_empty() {
96 return Ok(Vec::new());
97 }
98
99 let mut results = Vec::with_capacity(operations.len());
101 let mut uncached_ops = Vec::new();
102 let mut uncached_indices = Vec::new();
103
104 for (i, &(deg1, deg2, dim)) in operations.iter().enumerate() {
105 let cache_key = format!("{}:{}:{}", deg1, deg2, dim);
106 if let Some(&cached_result) = self.cache.get(&cache_key) {
107 results.push(cached_result);
108 } else {
109 results.push(Rational64::from(0)); uncached_ops.push((deg1, deg2, dim));
111 uncached_indices.push(i);
112 }
113 }
114
115 if uncached_ops.is_empty() {
116 return Ok(results);
117 }
118
119 let computed_results = if self.config.enable_gpu {
121 #[cfg(feature = "wgpu")]
122 {
123 if let Some(ref gpu) = self.gpu_context {
124 self.gpu_compute_batch(gpu, &uncached_ops)?
125 } else {
126 self.simd_compute_batch(&uncached_ops)?
127 }
128 }
129 #[cfg(not(feature = "wgpu"))]
130 {
131 self.simd_compute_batch(&uncached_ops)?
132 }
133 } else {
134 self.simd_compute_batch(&uncached_ops)?
135 };
136
137 for (i, &result) in computed_results.iter().enumerate() {
139 let result_idx = uncached_indices[i];
140 results[result_idx] = result;
141
142 let (deg1, deg2, dim) = uncached_ops[i];
143 let cache_key = format!("{}:{}:{}", deg1, deg2, dim);
144 if self.cache.len() < self.config.cache_size {
145 self.cache.insert(cache_key, result);
146 }
147 }
148
149 Ok(results)
150 }
151
152 fn simd_compute_batch(
154 &mut self,
155 operations: &[(i64, i64, i64)],
156 ) -> EnumerativeResult<Vec<Rational64>> {
157 let batch_size = self.config.batch_size.min(operations.len());
158 let mut results = Vec::with_capacity(operations.len());
159
160 for chunk in operations.chunks(batch_size) {
161 let chunk_results = if self.config.enable_simd {
162 self.simd_intersection_chunk(chunk)?
163 } else {
164 self.scalar_intersection_chunk(chunk)?
165 };
166 results.extend(chunk_results);
167 }
168
169 Ok(results)
170 }
171
172 fn simd_intersection_chunk(
174 &mut self,
175 chunk: &[(i64, i64, i64)],
176 ) -> EnumerativeResult<Vec<Rational64>> {
177 self.coefficient_buffer.clear();
179 self.coefficient_buffer.resize(chunk.len() * 8, 0.0);
180
181 for (i, &(deg1, deg2, dim)) in chunk.iter().enumerate() {
183 let base_idx = i * 8;
184
185 self.coefficient_buffer[base_idx] = deg1 as f64;
187 self.coefficient_buffer[base_idx + 1] = deg2 as f64;
188 self.coefficient_buffer[base_idx + 2] = dim as f64;
189
190 self.coefficient_buffer[base_idx + 3] = (deg1 * deg2) as f64;
192 self.coefficient_buffer[base_idx + 4] = if deg1 + deg2 > dim { 0.0 } else { 1.0 };
193
194 self.coefficient_buffer[base_idx + 5] = ((deg1 + deg2) - dim) as f64;
196 self.coefficient_buffer[base_idx + 6] = (deg1.max(deg2)) as f64;
197 self.coefficient_buffer[base_idx + 7] = (deg1.min(deg2)) as f64;
198 }
199
200 let results = self.vectorized_bezout_computation(chunk.len())?;
202
203 Ok(results)
204 }
205
206 fn vectorized_bezout_computation(&self, count: usize) -> EnumerativeResult<Vec<Rational64>> {
208 let mut results = Vec::with_capacity(count);
209
210 for i in 0..count {
211 let base_idx = i * 8;
212 let deg_product = self.coefficient_buffer[base_idx + 3] as i64;
213 let is_valid = self.coefficient_buffer[base_idx + 4] > 0.5;
214
215 let result = if is_valid {
216 Rational64::from(deg_product)
217 } else {
218 Rational64::from(0)
219 };
220
221 results.push(result);
222 }
223
224 Ok(results)
225 }
226
227 fn scalar_intersection_chunk(
229 &self,
230 chunk: &[(i64, i64, i64)],
231 ) -> EnumerativeResult<Vec<Rational64>> {
232 let mut results = Vec::with_capacity(chunk.len());
233
234 for &(deg1, deg2, dim) in chunk {
235 let result = if deg1 + deg2 > dim {
236 Rational64::from(0) } else {
238 Rational64::from(deg1 * deg2) };
240 results.push(result);
241 }
242
243 Ok(results)
244 }
245
246 pub fn clear_cache(&mut self) {
248 self.cache.clear();
249 }
250
251 pub fn cache_stats(&self) -> (usize, usize) {
253 (self.cache.len(), self.config.cache_size)
254 }
255}
256
/// wgpu handles for the intersection compute shader: the device, its submission queue and the
/// compute pipeline built from `shaders/intersection.wgsl`.
#[cfg(feature = "wgpu")]
#[derive(Debug)]
// No code in this file reads these fields yet.
#[allow(dead_code)]
pub struct GpuContext {
    device: wgpu::Device,
    queue: wgpu::Queue,
    compute_pipeline: wgpu::ComputePipeline,
}

#[cfg(feature = "wgpu")]
impl GpuContext {
    /// Requests the default adapter and device, then compiles the intersection pipeline. The
    /// WGSL entry point is `main`, and `layout: None` leaves the pipeline layout to be derived
    /// by wgpu.
    ///
    /// # Errors
    /// Returns `ComputationError` when no adapter is found or the device request fails.
    pub async fn new() -> EnumerativeResult<Self> {
        let gpu_instance = wgpu::Instance::new(wgpu::InstanceDescriptor::default());

        let adapter_options = wgpu::RequestAdapterOptions::default();
        let adapter = match gpu_instance.request_adapter(&adapter_options).await {
            Some(found) => found,
            None => {
                return Err(EnumerativeError::ComputationError(
                    "No GPU adapter found".to_string(),
                ))
            }
        };

        let device_descriptor = wgpu::DeviceDescriptor::default();
        let (device, queue) = match adapter.request_device(&device_descriptor, None).await {
            Ok(pair) => pair,
            Err(err) => {
                return Err(EnumerativeError::ComputationError(format!(
                    "GPU device error: {}",
                    err
                )))
            }
        };

        let wgsl_source = include_str!("shaders/intersection.wgsl");
        let shader_module = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("Intersection Compute Shader"),
            source: wgpu::ShaderSource::Wgsl(wgsl_source.into()),
        });

        let pipeline_descriptor = wgpu::ComputePipelineDescriptor {
            label: Some("Intersection Pipeline"),
            layout: None,
            module: &shader_module,
            entry_point: "main",
        };
        let compute_pipeline = device.create_compute_pipeline(&pipeline_descriptor);

        Ok(GpuContext {
            device,
            queue,
            compute_pipeline,
        })
    }
}
304
#[cfg(feature = "wgpu")]
impl FastIntersectionComputer {
    /// GPU batch path for `(deg1, deg2, dim)` operations.
    ///
    /// Dispatch is not implemented, so this returns an error immediately. Earlier versions
    /// built and then discarded an upload buffer first; that work is no longer done. The
    /// intended upload layout is four `f32` lanes per operation, `[deg1, deg2, dim, 0.0]`.
    /// Note that `f32` represents integers exactly only up to 2^24.
    ///
    /// # Errors
    /// Always returns `ComputationError`.
    fn gpu_compute_batch(
        &self,
        _gpu: &GpuContext,
        _operations: &[(i64, i64, i64)],
    ) -> EnumerativeResult<Vec<Rational64>> {
        Err(EnumerativeError::ComputationError(
            "GPU functionality requires additional dependencies".to_string(),
        ))
    }
}
326
327#[derive(Debug)]
329pub struct SparseSchubertMatrix {
330 entries: Vec<(usize, usize, Rational64)>,
332 rows: usize,
334 cols: usize,
335 row_index: HashMap<usize, Vec<usize>>,
337}
338
339impl SparseSchubertMatrix {
340 pub fn new(rows: usize, cols: usize) -> Self {
342 Self {
343 entries: Vec::new(),
344 rows,
345 cols,
346 row_index: HashMap::new(),
347 }
348 }
349
350 pub fn set(&mut self, row: usize, col: usize, value: Rational64) {
352 if value != Rational64::from(0) {
353 let entry_idx = self.entries.len();
354 self.entries.push((row, col, value));
355 self.row_index.entry(row).or_default().push(entry_idx);
356 }
357 }
358
359 pub fn get(&self, row: usize, col: usize) -> Rational64 {
361 if let Some(indices) = self.row_index.get(&row) {
362 for &idx in indices {
363 let (_, entry_col, value) = self.entries[idx];
364 if entry_col == col {
365 return value;
366 }
367 }
368 }
369 Rational64::from(0)
370 }
371
372 pub fn multiply_vector(&self, vector: &[Rational64]) -> EnumerativeResult<Vec<Rational64>> {
374 if vector.len() != self.cols {
375 return Err(EnumerativeError::InvalidDimension(format!(
376 "Vector length {} != matrix cols {}",
377 vector.len(),
378 self.cols
379 )));
380 }
381
382 let mut result = vec![Rational64::from(0); self.rows];
383
384 for &(row, col, value) in &self.entries {
385 result[row] += value * vector[col];
386 }
387
388 Ok(result)
389 }
390
391 pub fn memory_usage(&self) -> usize {
393 self.entries.len() * std::mem::size_of::<(usize, usize, Rational64)>()
394 + self.row_index.len() * std::mem::size_of::<(usize, Vec<usize>)>()
395 }
396}
397
398#[derive(Debug)]
400pub struct WasmCurveCounting {
401 config: WasmPerformanceConfig,
403 batch_processor: CurveBatchProcessor,
405 memory_pool: MemoryPool,
407}
408
409impl WasmCurveCounting {
410 pub fn new(config: WasmPerformanceConfig) -> Self {
412 let memory_pool = MemoryPool::new(config.memory_pool_mb * 1024 * 1024);
413 let batch_processor = CurveBatchProcessor::new(config.clone());
414
415 Self {
416 config,
417 batch_processor,
418 memory_pool,
419 }
420 }
421
422 pub fn count_curves_batch(
424 &mut self,
425 requests: &[CurveCountRequest],
426 ) -> EnumerativeResult<Vec<i64>> {
427 if requests.is_empty() {
428 return Ok(Vec::new());
429 }
430
431 let _allocation = self.memory_pool.allocate(requests.len() * 64)?;
433
434 let batch_size = self.config.batch_size;
436 let mut results = Vec::with_capacity(requests.len());
437
438 for chunk in requests.chunks(batch_size) {
439 let chunk_results = if self.config.enable_workers {
440 self.batch_processor.process_with_workers(chunk)?
441 } else {
442 self.batch_processor.process_sequential(chunk)?
443 };
444 results.extend(chunk_results);
445 }
446
447 Ok(results)
448 }
449
450 pub fn performance_metrics(&self) -> PerformanceMetrics {
452 PerformanceMetrics {
453 memory_pool_usage: self.memory_pool.usage_percentage(),
454 cache_hit_rate: self.batch_processor.cache_hit_rate(),
455 batch_count: self.batch_processor.batch_count(),
456 worker_utilization: if self.config.enable_workers { 0.8 } else { 1.0 },
457 }
458 }
459}
460
461#[derive(Debug, Clone)]
463pub struct CurveCountRequest {
464 pub target_space: String,
466 pub degree: i64,
468 pub genus: usize,
470 pub constraint_count: usize,
472}
473
474#[derive(Debug)]
476pub struct CurveBatchProcessor {
477 #[allow(dead_code)]
478 config: WasmPerformanceConfig,
479 cache_hits: usize,
480 cache_misses: usize,
481 batch_count: usize,
482}
483
484impl CurveBatchProcessor {
485 pub fn new(config: WasmPerformanceConfig) -> Self {
486 Self {
487 config,
488 cache_hits: 0,
489 cache_misses: 0,
490 batch_count: 0,
491 }
492 }
493
494 pub fn process_with_workers(
495 &mut self,
496 requests: &[CurveCountRequest],
497 ) -> EnumerativeResult<Vec<i64>> {
498 self.batch_count += 1;
499 Ok(requests
501 .iter()
502 .map(|req| req.degree * (req.genus as i64 + 1))
503 .collect())
504 }
505
506 pub fn process_sequential(
507 &mut self,
508 requests: &[CurveCountRequest],
509 ) -> EnumerativeResult<Vec<i64>> {
510 self.batch_count += 1;
511 Ok(requests
513 .iter()
514 .map(|req| req.degree * (req.genus as i64 + 1))
515 .collect())
516 }
517
518 pub fn cache_hit_rate(&self) -> f64 {
519 if self.cache_hits + self.cache_misses == 0 {
520 0.0
521 } else {
522 self.cache_hits as f64 / (self.cache_hits + self.cache_misses) as f64
523 }
524 }
525
526 pub fn batch_count(&self) -> usize {
527 self.batch_count
528 }
529}
530
531#[derive(Debug)]
533pub struct MemoryPool {
534 total_size: usize,
535 allocated: usize,
536}
537
538impl MemoryPool {
539 pub fn new(size: usize) -> Self {
540 Self {
541 total_size: size,
542 allocated: 0,
543 }
544 }
545
546 pub fn allocate(&mut self, size: usize) -> EnumerativeResult<MemoryAllocation> {
547 if self.allocated + size > self.total_size {
548 return Err(EnumerativeError::ComputationError(
549 "Memory pool exhausted".to_string(),
550 ));
551 }
552
553 self.allocated += size;
554 Ok(MemoryAllocation { size })
555 }
556
557 pub fn usage_percentage(&self) -> f64 {
558 self.allocated as f64 / self.total_size as f64 * 100.0
559 }
560}
561
562#[derive(Debug)]
564pub struct MemoryAllocation {
565 #[allow(dead_code)]
566 size: usize,
567}
568
569impl Drop for MemoryAllocation {
570 fn drop(&mut self) {
571 }
573}
574
/// Snapshot of curve-counting statistics, as built by `WasmCurveCounting::performance_metrics`.
#[derive(Debug)]
pub struct PerformanceMetrics {
    /// Percentage of the memory pool currently reserved (`MemoryPool::usage_percentage`).
    pub memory_pool_usage: f64,
    /// Batch-processor cache hit rate as a fraction in `[0, 1]`.
    pub cache_hit_rate: f64,
    /// Number of batches processed so far.
    pub batch_count: usize,
    /// Worker utilisation as a fraction; `performance_summary` scales it by 100 for display.
    pub worker_utilization: f64,
}
583
/// Writes `message` to the browser console via `console.log`.
#[cfg(feature = "wasm")]
pub fn wasm_log(message: &str) {
    let js_message = JsValue::from_str(message);
    console::log_1(&js_message);
}

/// Native fallback: writes `message` and a trailing newline to standard output.
#[cfg(not(feature = "wasm"))]
pub fn wasm_log(message: &str) {
    println!("{message}");
}
594
595pub fn benchmark_intersection_computation(
597 config: WasmPerformanceConfig,
598 operation_count: usize,
599) -> EnumerativeResult<f64> {
600 let start = std::time::Instant::now();
601
602 let mut computer = FastIntersectionComputer::new(config);
603
604 let operations: Vec<(i64, i64, i64)> = (0..operation_count)
606 .map(|i| ((i % 10 + 1) as i64, ((i + 1) % 10 + 1) as i64, 3))
607 .collect();
608
609 let _results = computer.fast_intersection_batch(&operations)?;
611
612 let duration = start.elapsed();
613 let operations_per_second = operation_count as f64 / duration.as_secs_f64();
614
615 Ok(operations_per_second)
616}
617
/// JavaScript-facing facade over a curve counter and an intersection computer. Both are built
/// from `WasmPerformanceConfig::default()`.
#[cfg(feature = "wasm")]
#[wasm_bindgen]
pub struct WasmEnumerativeAPI {
    curve_counter: WasmCurveCounting,
    intersection_computer: FastIntersectionComputer,
}

#[cfg(feature = "wasm")]
#[wasm_bindgen]
impl WasmEnumerativeAPI {
    /// Creates the facade with default performance settings.
    #[wasm_bindgen(constructor)]
    pub fn new() -> Self {
        let config = WasmPerformanceConfig::default();
        Self {
            curve_counter: WasmCurveCounting::new(config.clone()),
            intersection_computer: FastIntersectionComputer::new(config),
        }
    }

    /// Counts degree-`degree`, genus-`genus` curves in `"P2"` with 3 constraints.
    ///
    /// The JS signature returns a plain number, so a failure is logged through `wasm_log` and
    /// reported as `0`.
    #[wasm_bindgen]
    pub fn count_curves(&mut self, degree: i64, genus: u32) -> i64 {
        let request = CurveCountRequest {
            target_space: "P2".to_string(),
            degree,
            genus: genus as usize,
            constraint_count: 3,
        };

        match self.curve_counter.count_curves_batch(&[request]) {
            Ok(counts) => counts.first().copied().unwrap_or(0),
            Err(err) => {
                wasm_log(&format!(
                    "count_curves({}, {}) failed: {:?}",
                    degree, genus, err
                ));
                0
            }
        }
    }

    /// Returns the intersection number of `(deg1, deg2, dim)` as an `f64`.
    ///
    /// A failure is logged through `wasm_log` and reported as `0.0`, matching `count_curves`.
    #[wasm_bindgen]
    pub fn intersection_number(&mut self, deg1: i64, deg2: i64, dim: i64) -> f64 {
        match self
            .intersection_computer
            .fast_intersection_batch(&[(deg1, deg2, dim)])
        {
            Ok(values) => values.first().and_then(|v| v.to_f64()).unwrap_or(0.0),
            Err(err) => {
                wasm_log(&format!(
                    "intersection_number({}, {}, {}) failed: {:?}",
                    deg1, deg2, dim, err
                ));
                0.0
            }
        }
    }

    /// Returns a human-readable summary of the curve counter's metrics. The cache and worker
    /// fractions are shown as percentages.
    #[wasm_bindgen]
    pub fn performance_summary(&self) -> String {
        let metrics = self.curve_counter.performance_metrics();
        format!(
            "Memory: {:.1}%, Cache: {:.1}%, Batches: {}, Workers: {:.1}%",
            metrics.memory_pool_usage,
            metrics.cache_hit_rate * 100.0,
            metrics.batch_count,
            metrics.worker_utilization * 100.0
        )
    }
}

#[cfg(feature = "wasm")]
impl Default for WasmEnumerativeAPI {
    fn default() -> Self {
        Self::new()
    }
}