amari_gpu/relativistic.rs

//! GPU acceleration for relativistic physics computations
//!
//! This module provides GPU-accelerated implementations of relativistic physics
//! operations from amari-relativistic, including spacetime algebra operations,
//! geodesic integration, and particle trajectory calculations for spacecraft
//! orbital mechanics and plasma physics applications.
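//!
//! # Example
//!
//! A minimal usage sketch; it is marked `ignore` because it needs a GPU adapter at
//! runtime and assumes this module is exported as `amari_gpu::relativistic`:
//!
//! ```ignore
//! use amari_gpu::relativistic::{GpuRelativisticPhysics, GpuSpacetimeVector};
//!
//! async fn demo() {
//!     let gpu = GpuRelativisticPhysics::new().await.expect("no suitable GPU adapter");
//!     let vectors = vec![GpuSpacetimeVector::new(1.0, 0.5, 0.0, 0.0)];
//!     // Minkowski norm t² - x² - y² - z² = 0.75 for this vector
//!     let norms = gpu.compute_minkowski_products(&vectors).await.unwrap();
//!     assert!((norms[0] - 0.75).abs() < 1e-6);
//! }
//! ```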

use crate::GpuError;
use amari_relativistic::{particle::RelativisticParticle, spacetime::SpacetimeVector};
use bytemuck::{Pod, Zeroable};
use wgpu::util::DeviceExt;

/// Spacetime four-vector in a GPU-compatible layout, using the Cl(1,3) signature
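///
/// # Example
///
/// A small sketch of the CPU/GPU round-trip (f64 components are truncated to f32;
/// the import path for this crate is assumed):
///
/// ```ignore
/// use amari_gpu::relativistic::GpuSpacetimeVector;
/// use amari_relativistic::spacetime::SpacetimeVector;
///
/// let cpu = SpacetimeVector::new(1.0, 2.0, 3.0, 4.0);
/// let gpu = GpuSpacetimeVector::from_spacetime_vector(&cpu);
/// assert_eq!(gpu.to_spacetime_vector().x(), 2.0);
/// ```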
#[repr(C)]
#[derive(Copy, Clone, Debug, Pod, Zeroable)]
pub struct GpuSpacetimeVector {
    /// Temporal component (ct)
    pub t: f32,
    /// Spatial x component
    pub x: f32,
    /// Spatial y component
    pub y: f32,
    /// Spatial z component
    pub z: f32,
}

impl GpuSpacetimeVector {
    /// Create new GPU spacetime vector
    pub fn new(t: f32, x: f32, y: f32, z: f32) -> Self {
        Self { t, x, y, z }
    }

    /// Convert from CPU spacetime vector
    pub fn from_spacetime_vector(sv: &SpacetimeVector) -> Self {
        Self::new(
            sv.time() as f32,
            sv.x() as f32,
            sv.y() as f32,
            sv.z() as f32,
        )
    }

    /// Convert to CPU spacetime vector
    pub fn to_spacetime_vector(&self) -> SpacetimeVector {
        SpacetimeVector::new(self.t as f64, self.x as f64, self.y as f64, self.z as f64)
    }
}

/// Relativistic particle state in a GPU-compatible layout for trajectory calculations
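///
/// # Example
///
/// A direct construction sketch with illustrative (not physically calibrated) values;
/// only the time-like velocity component is non-zero here:
///
/// ```ignore
/// use amari_gpu::relativistic::{GpuRelativisticParticle, GpuSpacetimeVector};
///
/// let particle = GpuRelativisticParticle {
///     position: GpuSpacetimeVector::new(0.0, 7.0e6, 0.0, 0.0),
///     velocity: GpuSpacetimeVector::new(299_792_458.0, 0.0, 0.0, 0.0),
///     mass: 1.0,
///     charge: 0.0,
///     proper_time: 0.0,
///     _padding: 0.0,
/// };
/// ```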
#[repr(C)]
#[derive(Copy, Clone, Debug, Pod, Zeroable)]
pub struct GpuRelativisticParticle {
    /// Spacetime position
    pub position: GpuSpacetimeVector,
    /// Four-velocity
    pub velocity: GpuSpacetimeVector,
    /// Rest mass
    pub mass: f32,
    /// Electric charge
    pub charge: f32,
    /// Proper time
    pub proper_time: f32,
    /// Padding so the 48-byte stride matches the WGSL `Particle` struct
    pub _padding: f32,
}

/// Trajectory integration parameters uploaded to the GPU as a uniform buffer
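///
/// # Example
///
/// Illustrative values for a short integration run; the Schwarzschild radius and GM
/// correspond roughly to the Sun, the remaining numbers are placeholders:
///
/// ```ignore
/// use amari_gpu::relativistic::GpuTrajectoryParams;
///
/// let params = GpuTrajectoryParams {
///     dt: 1.0e-3,
///     steps: 1_000,
///     tolerance: 1.0e-6,
///     renorm_freq: 100,
///     schwarzschild_radius: 2.95e3, // metres
///     gm_parameter: 1.327e20,       // m³/s²
///     _padding: [0.0; 2],
/// };
/// ```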
#[repr(C)]
#[derive(Copy, Clone, Debug, Pod, Zeroable)]
pub struct GpuTrajectoryParams {
    /// Integration time step
    pub dt: f32,
    /// Number of integration steps
    pub steps: u32,
    /// Normalization tolerance
    pub tolerance: f32,
    /// Renormalization frequency
    pub renorm_freq: u32,
    /// Schwarzschild radius (for gravitational fields)
    pub schwarzschild_radius: f32,
    /// Central mass parameter (GM)
    pub gm_parameter: f32,
    /// Padding for alignment
    pub _padding: [f32; 2],
}

/// GPU compute context for relativistic physics
pub struct GpuRelativisticPhysics {
    device: wgpu::Device,
    queue: wgpu::Queue,
    spacetime_pipeline: wgpu::ComputePipeline,
    geodesic_pipeline: wgpu::ComputePipeline,
    #[allow(dead_code)]
    trajectory_pipeline: wgpu::ComputePipeline,
}

impl GpuRelativisticPhysics {
    /// Initialize GPU context for relativistic physics computations
    pub async fn new() -> Result<Self, GpuError> {
        let instance = wgpu::Instance::default();

        let adapter = instance
            .request_adapter(&wgpu::RequestAdapterOptions::default())
            .await
            .ok_or_else(|| {
                GpuError::InitializationError("No suitable GPU adapter found".to_string())
            })?;

        let (device, queue) = adapter
            .request_device(
                &wgpu::DeviceDescriptor {
                    label: Some("Relativistic Physics GPU"),
                    required_features: wgpu::Features::empty(),
                    required_limits: wgpu::Limits::default(),
                },
                None,
            )
            .await
            .map_err(|e| {
                GpuError::InitializationError(format!("Failed to create device: {}", e))
            })?;

        // Compile compute shaders for different operations
        let spacetime_pipeline = Self::create_spacetime_pipeline(&device)?;
        let geodesic_pipeline = Self::create_geodesic_pipeline(&device)?;
        let trajectory_pipeline = Self::create_trajectory_pipeline(&device)?;

        Ok(Self {
            device,
            queue,
            spacetime_pipeline,
            geodesic_pipeline,
            trajectory_pipeline,
        })
    }

    /// Create compute pipeline for spacetime algebra operations
    fn create_spacetime_pipeline(device: &wgpu::Device) -> Result<wgpu::ComputePipeline, GpuError> {
        let shader_source = r#"
            @group(0) @binding(0) var<storage, read_write> vectors: array<vec4<f32>>;
            @group(0) @binding(1) var<storage, read_write> results: array<f32>;

            @compute @workgroup_size(64)
            fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
                let index = global_id.x;
                if (index >= arrayLength(&vectors)) {
                    return;
                }

                // Component mapping follows GpuSpacetimeVector: (x, y, z, w) = (ct, x, y, z)
                let v = vectors[index];

                // Minkowski inner product: t² - x² - y² - z²
                let minkowski_norm_sq = v.x * v.x - v.y * v.y - v.z * v.z - v.w * v.w;
                results[index] = minkowski_norm_sq;

                // Renormalize in place to u·u = c² (only meaningful for four-velocity inputs)
                let c_sq = 299792458.0 * 299792458.0;
                if (abs(minkowski_norm_sq - c_sq) > 1e-6) {
                    let norm = sqrt(abs(minkowski_norm_sq));
                    if (norm > 1e-12) {
                        let factor = sqrt(c_sq) / norm;
                        vectors[index] = vec4<f32>(v.x * factor, v.y * factor, v.z * factor, v.w * factor);
                    }
                }
            }
        "#;

        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("Spacetime Algebra Compute Shader"),
            source: wgpu::ShaderSource::Wgsl(shader_source.into()),
        });

        let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some("Spacetime Bind Group Layout"),
            entries: &[
                wgpu::BindGroupLayoutEntry {
                    binding: 0,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: false },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                wgpu::BindGroupLayoutEntry {
                    binding: 1,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: false },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
            ],
        });

        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("Spacetime Pipeline Layout"),
            bind_group_layouts: &[&bind_group_layout],
            push_constant_ranges: &[],
        });

        Ok(
            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("Spacetime Compute Pipeline"),
                layout: Some(&pipeline_layout),
                module: &shader,
                entry_point: "main",
            }),
        )
    }

    /// Create compute pipeline for geodesic integration
    fn create_geodesic_pipeline(device: &wgpu::Device) -> Result<wgpu::ComputePipeline, GpuError> {
        let shader_source = r#"
            struct Particle {
                position: vec4<f32>, // (ct, x, y, z), mirrors GpuRelativisticParticle
                velocity: vec4<f32>, // four-velocity in the same component order
                mass: f32,
                charge: f32,
                proper_time: f32,
                padding: f32,
            };

            struct TrajectoryParams {
                dt: f32,
                steps: u32,
                tolerance: f32,
                renorm_freq: u32,
                rs: f32,
                gm: f32,
                padding: vec2<f32>,
            };

            @group(0) @binding(0) var<storage, read_write> particles: array<Particle>;
            @group(0) @binding(1) var<uniform> params: TrajectoryParams;

            // Schwarzschild metric Christoffel symbols (simplified)
            fn christoffel_t_rr(r: f32, rs: f32) -> f32 {
                let factor = rs / (2.0 * r * r);
                return factor * (1.0 - rs / r);
            }

            fn christoffel_r_tr(r: f32, rs: f32) -> f32 {
                return rs / (2.0 * r * r * (1.0 - rs / r));
            }

            fn christoffel_r_rr(r: f32, rs: f32) -> f32 {
                return -rs / (2.0 * r * (r - rs));
            }

            @compute @workgroup_size(64)
            fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
                let index = global_id.x;
                if (index >= arrayLength(&particles)) {
                    return;
                }

                var particle = particles[index];
                let pos = particle.position;
                let vel = particle.velocity;

                // Compute spatial radius
                let r = sqrt(pos.y * pos.y + pos.z * pos.z + pos.w * pos.w);

                if (r < params.rs * 1.1) {
                    // Too close to singularity, skip
                    return;
                }

                // Velocity Verlet step for geodesic equation
                // Simplified for Schwarzschild metric

                // Compute acceleration components
                let c_trr = christoffel_t_rr(r, params.rs);
                let c_rtr = christoffel_r_tr(r, params.rs);
                let c_rrr = christoffel_r_rr(r, params.rs);

                // Geodesic equation: d²x^μ/dτ² = -Γ^μ_αβ v^α v^β
                var accel = vec4<f32>(0.0, 0.0, 0.0, 0.0);

                // Simplified acceleration calculation
                accel.x = -c_trr * vel.y * vel.y; // dt component
                accel.y = -c_rtr * vel.x * vel.y - c_rrr * vel.y * vel.y; // dr component

                // Update position and velocity
                let dt = params.dt;
                particle.position = pos + vel * dt + 0.5 * accel * dt * dt;
                particle.velocity = vel + accel * dt;

                // Renormalize four-velocity periodically
                let step_mod = u32(particle.proper_time / dt) % params.renorm_freq;
                if (step_mod == 0u) {
                    let c_sq = 299792458.0 * 299792458.0;
                    let norm_sq = particle.velocity.x * particle.velocity.x -
                                  particle.velocity.y * particle.velocity.y -
                                  particle.velocity.z * particle.velocity.z -
                                  particle.velocity.w * particle.velocity.w;

                    if (abs(norm_sq - c_sq) > params.tolerance) {
                        let norm = sqrt(abs(norm_sq));
                        if (norm > 1e-12) {
                            let factor = sqrt(c_sq) / norm;
                            particle.velocity *= factor;
                        }
                    }
                }

                particle.proper_time += dt;
                particles[index] = particle;
            }
        "#;

        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("Geodesic Integration Compute Shader"),
            source: wgpu::ShaderSource::Wgsl(shader_source.into()),
        });

        let bind_group_layout = device.create_bind_group_layout(&wgpu::BindGroupLayoutDescriptor {
            label: Some("Geodesic Bind Group Layout"),
            entries: &[
                wgpu::BindGroupLayoutEntry {
                    binding: 0,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Storage { read_only: false },
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
                wgpu::BindGroupLayoutEntry {
                    binding: 1,
                    visibility: wgpu::ShaderStages::COMPUTE,
                    ty: wgpu::BindingType::Buffer {
                        ty: wgpu::BufferBindingType::Uniform,
                        has_dynamic_offset: false,
                        min_binding_size: None,
                    },
                    count: None,
                },
            ],
        });

        let pipeline_layout = device.create_pipeline_layout(&wgpu::PipelineLayoutDescriptor {
            label: Some("Geodesic Pipeline Layout"),
            bind_group_layouts: &[&bind_group_layout],
            push_constant_ranges: &[],
        });

        Ok(
            device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
                label: Some("Geodesic Compute Pipeline"),
                layout: Some(&pipeline_layout),
                module: &shader,
                entry_point: "main",
            }),
        )
    }

    /// Create compute pipeline for trajectory calculations
    fn create_trajectory_pipeline(
        device: &wgpu::Device,
    ) -> Result<wgpu::ComputePipeline, GpuError> {
        // For now, use the same pipeline as geodesic integration
        Self::create_geodesic_pipeline(device)
    }

    /// Compute Minkowski inner products for multiple spacetime vectors
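    ///
    /// # Example
    ///
    /// A hedged sketch assuming `gpu_physics` is an initialized `GpuRelativisticPhysics`
    /// (requires a working GPU adapter):
    ///
    /// ```ignore
    /// let vectors = vec![
    ///     GpuSpacetimeVector::new(1.0, 0.5, 0.0, 0.0), // norm 0.75
    ///     GpuSpacetimeVector::new(2.0, 1.0, 0.0, 0.0), // norm 3.0
    /// ];
    /// let norms = gpu_physics.compute_minkowski_products(&vectors).await.unwrap();
    /// assert_eq!(norms.len(), vectors.len());
    /// ```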
    pub async fn compute_minkowski_products(
        &self,
        vectors: &[GpuSpacetimeVector],
    ) -> Result<Vec<f32>, GpuError> {
        // Avoid creating zero-sized GPU buffers for an empty batch
        if vectors.is_empty() {
            return Ok(Vec::new());
        }

        let vectors_buffer = self
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("Spacetime Vectors Buffer"),
                contents: bytemuck::cast_slice(vectors),
                usage: wgpu::BufferUsages::STORAGE
                    | wgpu::BufferUsages::COPY_DST
                    | wgpu::BufferUsages::COPY_SRC,
            });

        let results_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Results Buffer"),
            size: (vectors.len() * std::mem::size_of::<f32>()) as u64,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });

        let bind_group_layout = self.spacetime_pipeline.get_bind_group_layout(0);
        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("Spacetime Bind Group"),
            layout: &bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: vectors_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: results_buffer.as_entire_binding(),
                },
            ],
        });

        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("Spacetime Compute Encoder"),
            });

        {
            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("Spacetime Compute Pass"),
                timestamp_writes: None,
            });

            compute_pass.set_pipeline(&self.spacetime_pipeline);
            compute_pass.set_bind_group(0, &bind_group, &[]);

            let workgroup_size = 64;
            let num_workgroups = vectors.len().div_ceil(workgroup_size);
            compute_pass.dispatch_workgroups(num_workgroups as u32, 1, 1);
        }

        // Read back results
        let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Staging Buffer"),
            size: results_buffer.size(),
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        encoder.copy_buffer_to_buffer(
            &results_buffer,
            0,
            &staging_buffer,
            0,
            results_buffer.size(),
        );

        self.queue.submit([encoder.finish()]);

        let buffer_slice = staging_buffer.slice(..);
        let (sender, receiver) = futures::channel::oneshot::channel();
        buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());

        self.device.poll(wgpu::Maintain::wait()).panic_on_timeout();
        receiver
            .await
            .unwrap()
            .map_err(|e| GpuError::BufferError(format!("Buffer mapping failed: {:?}", e)))?;

        let data = buffer_slice.get_mapped_range();
        let results: Vec<f32> = bytemuck::cast_slice(&data).to_vec();

        drop(data);
        staging_buffer.unmap();

        Ok(results)
    }

    /// Propagate multiple particles through spacetime using GPU acceleration
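    ///
    /// # Example
    ///
    /// A hedged sketch assuming `gpu_physics`, `particles`, and `params` were built as in
    /// the examples above:
    ///
    /// ```ignore
    /// let final_states = gpu_physics
    ///     .propagate_particles(&particles, &params)
    ///     .await
    ///     .unwrap();
    /// for p in &final_states {
    ///     println!("proper time elapsed: {}", p.proper_time);
    /// }
    /// ```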
    pub async fn propagate_particles(
        &self,
        particles: &[GpuRelativisticParticle],
        params: &GpuTrajectoryParams,
    ) -> Result<Vec<GpuRelativisticParticle>, GpuError> {
        // Avoid creating zero-sized GPU buffers for an empty batch
        if particles.is_empty() {
            return Ok(Vec::new());
        }

        let particles_buffer = self
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("Particles Buffer"),
                contents: bytemuck::cast_slice(particles),
                usage: wgpu::BufferUsages::STORAGE
                    | wgpu::BufferUsages::COPY_DST
                    | wgpu::BufferUsages::COPY_SRC,
            });

        let params_buffer = self
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("Trajectory Params Buffer"),
                contents: bytemuck::cast_slice(&[*params]),
                usage: wgpu::BufferUsages::UNIFORM,
            });

        let bind_group_layout = self.geodesic_pipeline.get_bind_group_layout(0);
        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("Geodesic Bind Group"),
            layout: &bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: particles_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: params_buffer.as_entire_binding(),
                },
            ],
        });

        // Execute integration steps
        for _ in 0..params.steps {
            let mut encoder = self
                .device
                .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                    label: Some("Geodesic Compute Encoder"),
                });

            {
                let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                    label: Some("Geodesic Compute Pass"),
                    timestamp_writes: None,
                });

                compute_pass.set_pipeline(&self.geodesic_pipeline);
                compute_pass.set_bind_group(0, &bind_group, &[]);

                let workgroup_size = 64;
                let num_workgroups = particles.len().div_ceil(workgroup_size);
                compute_pass.dispatch_workgroups(num_workgroups as u32, 1, 1);
            }

            self.queue.submit([encoder.finish()]);
            self.device.poll(wgpu::Maintain::wait()).panic_on_timeout();
        }

        // Read back final particle states
        let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Particles Staging Buffer"),
            size: particles_buffer.size(),
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("Copy Encoder"),
            });

        encoder.copy_buffer_to_buffer(
            &particles_buffer,
            0,
            &staging_buffer,
            0,
            particles_buffer.size(),
        );
        self.queue.submit([encoder.finish()]);

        let buffer_slice = staging_buffer.slice(..);
        let (sender, receiver) = futures::channel::oneshot::channel();
        buffer_slice.map_async(wgpu::MapMode::Read, move |v| sender.send(v).unwrap());

        self.device.poll(wgpu::Maintain::wait()).panic_on_timeout();
        receiver
            .await
            .unwrap()
            .map_err(|e| GpuError::BufferError(format!("Buffer mapping failed: {:?}", e)))?;

        let data = buffer_slice.get_mapped_range();
        let results: Vec<GpuRelativisticParticle> = bytemuck::cast_slice(&data).to_vec();

        drop(data);
        staging_buffer.unmap();

        Ok(results)
    }
}

/// Convert CPU relativistic particle to GPU format
impl From<&RelativisticParticle> for GpuRelativisticParticle {
    fn from(particle: &RelativisticParticle) -> Self {
        let pos = &particle.position;
        let vel = particle.four_velocity.as_spacetime_vector();

        Self {
            position: GpuSpacetimeVector::from_spacetime_vector(pos),
            velocity: GpuSpacetimeVector::from_spacetime_vector(vel),
            mass: particle.mass as f32,
            charge: particle.charge as f32,
            proper_time: 0.0, // Will be updated during integration
            _padding: 0.0,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_gpu_spacetime_vector_conversion() {
        let cpu_vector = SpacetimeVector::new(1.0, 2.0, 3.0, 4.0);
        let gpu_vector = GpuSpacetimeVector::from_spacetime_vector(&cpu_vector);
        let converted_back = gpu_vector.to_spacetime_vector();

        assert_eq!(converted_back.time(), 1.0);
        assert_eq!(converted_back.x(), 2.0);
        assert_eq!(converted_back.y(), 3.0);
        assert_eq!(converted_back.z(), 4.0);
    }

    #[tokio::test]
    #[ignore] // Skip in CI due to GPU hardware requirements
    async fn test_gpu_minkowski_products() {
        let gpu_physics = match GpuRelativisticPhysics::new().await {
            Ok(physics) => physics,
            Err(_) => {
                println!("GPU not available, skipping test");
                return;
            }
        };

        let vectors = vec![
            GpuSpacetimeVector::new(1.0, 0.5, 0.0, 0.0),
            GpuSpacetimeVector::new(2.0, 1.0, 0.0, 0.0),
        ];

        let results = gpu_physics
            .compute_minkowski_products(&vectors)
            .await
            .unwrap();

        // Check that we got results for each vector
        assert_eq!(results.len(), vectors.len());

        // Verify Minkowski inner product calculation (t² - x² - y² - z²)
        assert!((results[0] - (1.0 - 0.25)).abs() < 1e-6); // 1² - 0.5² = 0.75
        assert!((results[1] - (4.0 - 1.0)).abs() < 1e-6); // 2² - 1² = 3.0
    }
}