use crate::GpuError;
use amari_network::{Community, GeometricNetwork};
use bytemuck::{Pod, Zeroable};
use futures::channel::oneshot;
use thiserror::Error;
use wgpu::util::DeviceExt;

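/// Errors that can occur during GPU-accelerated network analysis.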
#[derive(Error, Debug)]
pub enum GpuNetworkError {
    #[error("GPU error: {0}")]
    Gpu(#[from] GpuError),

    #[error("Network error: {0}")]
    Network(#[from] amari_network::NetworkError),

    #[error("Invalid network size: {0}")]
    InvalidSize(usize),

    #[error("Buffer error: {0}")]
    BufferError(String),
}

pub type GpuNetworkResult<T> = Result<T, GpuNetworkError>;

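/// GPU-accelerated geometric network analysis backed by wgpu compute pipelines.
///
/// Holds the device, queue, and the compute pipelines used by the analysis
/// methods. The centrality and clustering pipelines currently reuse the
/// distance shader as placeholders.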
pub struct GpuGeometricNetwork {
    device: wgpu::Device,
    queue: wgpu::Queue,
    distance_pipeline: wgpu::ComputePipeline,
    #[allow(dead_code)]
    centrality_pipeline: wgpu::ComputePipeline,
    #[allow(dead_code)]
    clustering_pipeline: wgpu::ComputePipeline,
}

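/// Node position laid out for GPU upload: three coordinates plus padding for
/// 16-byte alignment.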
#[repr(C)]
#[derive(Copy, Clone, Pod, Zeroable)]
struct GpuNodePosition {
    x: f32,
    y: f32,
    z: f32,
    padding: f32,
}

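/// Edge record laid out for GPU upload (source, target, weight, plus padding
/// for alignment).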
#[repr(C)]
#[derive(Copy, Clone, Pod, Zeroable)]
struct GpuEdgeData {
    source: u32,
    target: u32,
    weight: f32,
    padding: f32,
}

impl GpuGeometricNetwork {
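    /// Creates a new GPU context: requests a high-performance adapter, a
    /// device and queue, and compiles the compute pipelines.
    ///
    /// Returns an error if no suitable GPU adapter is available.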
    pub async fn new() -> GpuNetworkResult<Self> {
        let instance = wgpu::Instance::default();

        let adapter = instance
            .request_adapter(&wgpu::RequestAdapterOptions {
                power_preference: wgpu::PowerPreference::HighPerformance,
                compatible_surface: None,
                force_fallback_adapter: false,
            })
            .await
            .ok_or_else(|| GpuError::InitializationError("No GPU adapter found".to_string()))?;

        let (device, queue) = adapter
            .request_device(
                &wgpu::DeviceDescriptor {
                    label: Some("Amari Network GPU Device"),
                    required_features: wgpu::Features::empty(),
                    required_limits: wgpu::Limits::default(),
                },
                None,
            )
            .await
            .map_err(|e| GpuError::InitializationError(e.to_string()))?;

        let distance_pipeline = Self::create_distance_pipeline(&device)?;
        let centrality_pipeline = Self::create_centrality_pipeline(&device)?;
        let clustering_pipeline = Self::create_clustering_pipeline(&device)?;

        Ok(Self {
            device,
            queue,
            distance_pipeline,
            centrality_pipeline,
            clustering_pipeline,
        })
    }

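    /// Computes the dense matrix of pairwise Euclidean distances between node
    /// positions on the GPU, using the first three vector components of each
    /// node's position. Returns an `n x n` matrix as `Vec<Vec<f64>>`, indexed
    /// `[i][j]`.
    ///
    /// A minimal usage sketch (assumes a populated `GeometricNetwork<3, 0, 0>`
    /// built elsewhere; `ignore`d because it needs a live GPU adapter):
    ///
    /// ```ignore
    /// let gpu = GpuGeometricNetwork::new().await?;
    /// let distances = gpu.compute_all_pairwise_distances(&network).await?;
    /// assert_eq!(distances.len(), network.num_nodes());
    /// ```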
    pub async fn compute_all_pairwise_distances<const P: usize, const Q: usize, const R: usize>(
        &self,
        network: &GeometricNetwork<P, Q, R>,
    ) -> GpuNetworkResult<Vec<Vec<f64>>> {
        let num_nodes = network.num_nodes();
        if num_nodes == 0 {
            return Ok(Vec::new());
        }

        let gpu_positions: Vec<GpuNodePosition> = (0..num_nodes)
            .map(|i| {
                let pos = network.get_node(i).unwrap();
                GpuNodePosition {
                    x: pos.vector_component(0) as f32,
                    y: pos.vector_component(1) as f32,
                    z: pos.vector_component(2) as f32,
                    padding: 0.0,
                }
            })
            .collect();

        let positions_buffer = self
            .device
            .create_buffer_init(&wgpu::util::BufferInitDescriptor {
                label: Some("Node Positions"),
                contents: bytemuck::cast_slice(&gpu_positions),
                usage: wgpu::BufferUsages::STORAGE,
            });

        let output_size = num_nodes * num_nodes * std::mem::size_of::<f32>();
        let output_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Distance Output"),
            size: output_size as u64,
            usage: wgpu::BufferUsages::STORAGE | wgpu::BufferUsages::COPY_SRC,
            mapped_at_creation: false,
        });

        let staging_buffer = self.device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Distance Staging"),
            size: output_size as u64,
            usage: wgpu::BufferUsages::MAP_READ | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: false,
        });

        let bind_group_layout = self.distance_pipeline.get_bind_group_layout(0);
        let bind_group = self.device.create_bind_group(&wgpu::BindGroupDescriptor {
            label: Some("Distance Compute Bind Group"),
            layout: &bind_group_layout,
            entries: &[
                wgpu::BindGroupEntry {
                    binding: 0,
                    resource: positions_buffer.as_entire_binding(),
                },
                wgpu::BindGroupEntry {
                    binding: 1,
                    resource: output_buffer.as_entire_binding(),
                },
            ],
        });

        let mut encoder = self
            .device
            .create_command_encoder(&wgpu::CommandEncoderDescriptor {
                label: Some("Distance Compute Encoder"),
            });

        {
            let mut compute_pass = encoder.begin_compute_pass(&wgpu::ComputePassDescriptor {
                label: Some("Distance Compute Pass"),
                timestamp_writes: None,
            });
            compute_pass.set_pipeline(&self.distance_pipeline);
            compute_pass.set_bind_group(0, &bind_group, &[]);
            // The shader declares @workgroup_size(8, 8), so each workgroup covers
            // an 8x8 tile of (i, j) pairs; dispatch enough tiles to cover all nodes.
            let workgroup_count = num_nodes.div_ceil(8);
            compute_pass.dispatch_workgroups(workgroup_count as u32, workgroup_count as u32, 1);
        }

        encoder.copy_buffer_to_buffer(&output_buffer, 0, &staging_buffer, 0, output_size as u64);

        self.queue.submit(std::iter::once(encoder.finish()));

        let buffer_slice = staging_buffer.slice(..);
        let (sender, receiver) = oneshot::channel();
        buffer_slice.map_async(wgpu::MapMode::Read, move |result| {
            let _ = sender.send(result);
        });

        self.device.poll(wgpu::Maintain::Wait);

        receiver
            .await
            .map_err(|_| {
                GpuNetworkError::BufferError("Failed to receive buffer mapping".to_string())
            })?
            .map_err(|e| GpuNetworkError::BufferError(format!("Buffer mapping failed: {:?}", e)))?;

        let data = buffer_slice.get_mapped_range();
        let result_f32: &[f32] = bytemuck::cast_slice(&data);

        let mut distances = vec![vec![0.0; num_nodes]; num_nodes];
        for i in 0..num_nodes {
            for j in 0..num_nodes {
                distances[i][j] = result_f32[i * num_nodes + j] as f64;
            }
        }

        drop(data);
        staging_buffer.unmap();

        Ok(distances)
    }
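    /// Computes closeness-style geometric centrality from the GPU distance
    /// matrix: for node `i`, `centrality[i] = (n - 1) / sum_j d(i, j)`, or 0.0
    /// when the distance sum is zero.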
    pub async fn compute_geometric_centrality<const P: usize, const Q: usize, const R: usize>(
        &self,
        network: &GeometricNetwork<P, Q, R>,
    ) -> GpuNetworkResult<Vec<f64>> {
        let distances = self.compute_all_pairwise_distances(network).await?;

        let num_nodes = network.num_nodes();
        let mut centrality = vec![0.0; num_nodes];

        for i in 0..num_nodes {
            let total_distance: f64 = distances[i].iter().sum();
            centrality[i] = if total_distance > 0.0 {
                (num_nodes as f64 - 1.0) / total_distance
            } else {
                0.0
            };
        }

        Ok(centrality)
    }

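    /// Clusters nodes with a k-medoids style algorithm over the GPU distance
    /// matrix: medoids start evenly spaced over the node indices, nodes are
    /// assigned to their nearest medoid, and each medoid moves to the cluster
    /// member minimizing total intra-cluster distance, for at most
    /// `max_iterations` rounds or until assignments stop changing.
    ///
    /// Returns `GpuNetworkError::InvalidSize` when `k == 0` or `k > num_nodes`.
    /// The `cohesion_score` on each returned `Community` is currently a
    /// placeholder value of 1.0.
    ///
    /// A minimal usage sketch (assumes a populated network; `ignore`d because
    /// it needs a live GPU adapter):
    ///
    /// ```ignore
    /// let communities = gpu.geometric_clustering(&network, 3, 100).await?;
    /// ```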
    pub async fn geometric_clustering<const P: usize, const Q: usize, const R: usize>(
        &self,
        network: &GeometricNetwork<P, Q, R>,
        k: usize,
        max_iterations: usize,
    ) -> GpuNetworkResult<Vec<Community<P, Q, R>>> {
        let num_nodes = network.num_nodes();
        if k > num_nodes || k == 0 {
            return Err(GpuNetworkError::InvalidSize(k));
        }

        let distances = self.compute_all_pairwise_distances(network).await?;

        let mut centroids = Vec::with_capacity(k);
        for i in 0..k {
            let centroid_idx = (i * num_nodes) / k;
            centroids.push(centroid_idx);
        }

        let mut assignments = vec![0; num_nodes];

        for _iteration in 0..max_iterations {
            let mut changed = false;

            for node in 0..num_nodes {
                let mut best_cluster = 0;
                let mut best_distance = f64::INFINITY;

                for (cluster, &centroid) in centroids.iter().enumerate().take(k) {
                    let distance = distances[node][centroid];

                    if distance < best_distance {
                        best_distance = distance;
                        best_cluster = cluster;
                    }
                }

                if assignments[node] != best_cluster {
                    assignments[node] = best_cluster;
                    changed = true;
                }
            }

            if !changed {
                break;
            }

            for (cluster, centroid) in centroids.iter_mut().enumerate().take(k) {
                let cluster_nodes: Vec<usize> = assignments
                    .iter()
                    .enumerate()
                    .filter(|(_, &c)| c == cluster)
                    .map(|(node, _)| node)
                    .collect();

                if !cluster_nodes.is_empty() {
                    let mut best_medoid = cluster_nodes[0];
                    let mut best_total_distance = f64::INFINITY;

                    for &candidate in &cluster_nodes {
                        let total_distance: f64 = cluster_nodes
                            .iter()
                            .map(|&other| distances[candidate][other])
                            .sum();

                        if total_distance < best_total_distance {
                            best_total_distance = total_distance;
                            best_medoid = candidate;
                        }
                    }

                    *centroid = best_medoid;
                }
            }
        }

        let mut communities = Vec::with_capacity(k);
        for (cluster, &centroid) in centroids.iter().enumerate().take(k) {
            let nodes: Vec<usize> = assignments
                .iter()
                .enumerate()
                .filter(|(_, &c)| c == cluster)
                .map(|(node, _)| node)
                .collect();

            if !nodes.is_empty() {
                let centroid_pos = network.get_node(centroid).unwrap().clone();
                communities.push(Community {
                    nodes,
                    geometric_centroid: centroid_pos,
                    cohesion_score: 1.0,
                });
            }
        }

        Ok(communities)
    }
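    /// Heuristic for choosing the GPU path: the dense distance matrix grows as
    /// n^2, so the GPU is preferred once the network has at least 100 nodes.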
    pub fn should_use_gpu(num_nodes: usize) -> bool {
        num_nodes >= 100
    }

    fn create_distance_pipeline(device: &wgpu::Device) -> Result<wgpu::ComputePipeline, GpuError> {
        let shader = device.create_shader_module(wgpu::ShaderModuleDescriptor {
            label: Some("Distance Compute Shader"),
            source: wgpu::ShaderSource::Wgsl(DISTANCE_COMPUTE_SHADER.into()),
        });

        let pipeline = device.create_compute_pipeline(&wgpu::ComputePipelineDescriptor {
            label: Some("Distance Compute Pipeline"),
            layout: None,
            module: &shader,
            entry_point: "main",
        });

        Ok(pipeline)
    }

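    // The centrality and clustering pipelines below currently reuse the
    // distance shader as a placeholder; dedicated kernels would slot in here.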
    fn create_centrality_pipeline(
        device: &wgpu::Device,
    ) -> Result<wgpu::ComputePipeline, GpuError> {
        Self::create_distance_pipeline(device)
    }

    fn create_clustering_pipeline(
        device: &wgpu::Device,
    ) -> Result<wgpu::ComputePipeline, GpuError> {
        Self::create_distance_pipeline(device)
    }
}

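/// WGSL kernel: one invocation per (i, j) pair computes the Euclidean distance
/// between the two node positions and writes it into a flat, row-major
/// `num_nodes * num_nodes` array.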
const DISTANCE_COMPUTE_SHADER: &str = r#"
struct NodePosition {
    x: f32,
    y: f32,
    z: f32,
    padding: f32,
}

@group(0) @binding(0)
var<storage, read> positions: array<NodePosition>;

@group(0) @binding(1)
var<storage, read_write> distances: array<f32>;

@compute @workgroup_size(8, 8)
fn main(@builtin(global_invocation_id) global_id: vec3<u32>) {
    let i = global_id.x;
    let j = global_id.y;
    let num_nodes = arrayLength(&positions);

    if (i >= num_nodes || j >= num_nodes) {
        return;
    }

    let idx = i * num_nodes + j;

    if (i == j) {
        distances[idx] = 0.0;
        return;
    }

    let pos_i = positions[i];
    let pos_j = positions[j];

    let dx = pos_i.x - pos_j.x;
    let dy = pos_i.y - pos_j.y;
    let dz = pos_i.z - pos_j.z;

    let distance = sqrt(dx * dx + dy * dy + dz * dz);
    distances[idx] = distance;
}
"#;

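/// Adaptive dispatcher that prefers the GPU backend for large networks and
/// falls back to the CPU implementations in `amari_network` when no GPU is
/// available or the network is small.
///
/// A minimal usage sketch (assumes a populated network built elsewhere):
///
/// ```ignore
/// let compute = AdaptiveNetworkCompute::new().await;
/// let centrality = compute.compute_geometric_centrality(&network).await?;
/// ```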
pub struct AdaptiveNetworkCompute {
    gpu: Option<GpuGeometricNetwork>,
}

impl AdaptiveNetworkCompute {
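    /// Probes for a GPU backend. Initialization runs inside `catch_unwind` so
    /// that adapter or driver panics degrade to the CPU-only configuration
    /// (`gpu: None`) instead of aborting.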
    pub async fn new() -> Self {
        let gpu = {
            let panic_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
                pollster::block_on(async { GpuGeometricNetwork::new().await.ok() })
            }));

            panic_result.unwrap_or_default()
        };

        Self { gpu }
    }

    pub async fn compute_all_pairwise_distances<const P: usize, const Q: usize, const R: usize>(
        &self,
        network: &GeometricNetwork<P, Q, R>,
    ) -> GpuNetworkResult<Vec<Vec<f64>>> {
        let num_nodes = network.num_nodes();

        if let Some(gpu) = &self.gpu {
            if GpuGeometricNetwork::should_use_gpu(num_nodes) {
                return gpu.compute_all_pairwise_distances(network).await;
            }
        }

        network
            .compute_all_pairs_shortest_paths()
            .map_err(GpuNetworkError::Network)
    }

    pub async fn compute_geometric_centrality<const P: usize, const Q: usize, const R: usize>(
        &self,
        network: &GeometricNetwork<P, Q, R>,
    ) -> GpuNetworkResult<Vec<f64>> {
        let num_nodes = network.num_nodes();

        if let Some(gpu) = &self.gpu {
            if GpuGeometricNetwork::should_use_gpu(num_nodes) {
                return gpu.compute_geometric_centrality(network).await;
            }
        }

        network
            .compute_geometric_centrality()
            .map_err(GpuNetworkError::Network)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_should_use_gpu() {
        assert!(!GpuGeometricNetwork::should_use_gpu(10));
        assert!(GpuGeometricNetwork::should_use_gpu(1000));
    }

    #[tokio::test]
    async fn test_adaptive_network_creation() {
        let adaptive = AdaptiveNetworkCompute::new().await;

        match &adaptive.gpu {
            Some(_) => {
                println!("✅ GPU network acceleration available");
            }
            None => {
                println!("✅ GPU not available, using CPU fallback for network operations");
            }
        }
    }
}