use ghostflow_core::Tensor;

use crate::linear::Linear;
use crate::Module;

/// Configuration for a [`PointNet`] classifier.
#[derive(Debug, Clone)]
pub struct PointNetConfig {
    /// Number of input points per cloud.
    pub num_points: usize,
    /// Dimensionality of each input point (3 for xyz coordinates).
    pub input_dim: usize,
    /// Number of output classes.
    pub num_classes: usize,
    /// Whether to apply the STN3d input transform before the backbone.
    pub use_stn: bool,
    /// Global feature dimension.
    pub feature_dim: usize,
}

impl Default for PointNetConfig {
    fn default() -> Self {
        PointNetConfig {
            num_points: 1024,
            input_dim: 3,
            num_classes: 10,
            use_stn: true,
            feature_dim: 1024,
        }
    }
}

impl PointNetConfig {
    /// Smaller configuration: 512 points, 10 classes, no STN.
    pub fn small() -> Self {
        PointNetConfig {
            num_points: 512,
            input_dim: 3,
            num_classes: 10,
            use_stn: false,
            feature_dim: 512,
        }
    }

    /// Larger configuration: 2048 points, 40 classes, with STN.
    pub fn large() -> Self {
        PointNetConfig {
            num_points: 2048,
            input_dim: 3,
            num_classes: 40,
            use_stn: true,
            feature_dim: 2048,
        }
    }
}

/// Spatial transformer network that predicts a 3x3 input transform, returned
/// as a flattened `[batch, 9]` tensor biased toward the identity.
///
/// The `convN` fields are `Linear` layers applied to every point
/// independently (a shared MLP).
pub struct STN3d {
    conv1: Linear,
    conv2: Linear,
    conv3: Linear,
    fc1: Linear,
    fc2: Linear,
    fc3: Linear,
}

impl STN3d {
    pub fn new() -> Self {
        STN3d {
            conv1: Linear::new(3, 64),
            conv2: Linear::new(64, 128),
            conv3: Linear::new(128, 1024),
            fc1: Linear::new(1024, 512),
            fc2: Linear::new(512, 256),
            fc3: Linear::new(256, 9),
        }
    }

    /// Predicts a transform for `x` of shape `[batch, num_points, 3]`.
    pub fn forward(&self, x: &Tensor) -> Result<Tensor, String> {
        let dims = x.dims();
        let batch_size = dims[0];
        let num_points = dims[1];

        // Flatten to [batch * num_points, 3] so the shared MLP runs per point.
        let x_flat = self.reshape_points(x)?;

        let mut features = self.conv1.forward(&x_flat);
        features = features.relu();
        features = self.conv2.forward(&features);
        features = features.relu();
        features = self.conv3.forward(&features);
        features = features.relu();

        features = self.reshape_back(&features, batch_size, num_points, 1024)?;

        // Global max pooling over points gives one 1024-d vector per cloud.
        let global_features = self.max_pool_points(&features)?;

        let mut x = self.fc1.forward(&global_features);
        x = x.relu();
        x = self.fc2.forward(&x);
        x = x.relu();
        x = self.fc3.forward(&x);

        self.add_identity_bias(&x, batch_size)
    }

    fn reshape_points(&self, x: &Tensor) -> Result<Tensor, String> {
        let data = x.data_f32();
        let dims = x.dims();
        let new_dims = vec![dims[0] * dims[1], dims[2]];
        Tensor::from_slice(&data, &new_dims)
            .map_err(|e| format!("Failed to reshape: {:?}", e))
    }

    fn reshape_back(&self, x: &Tensor, batch: usize, points: usize, features: usize) -> Result<Tensor, String> {
        let data = x.data_f32();
        Tensor::from_slice(&data, &[batch, points, features])
            .map_err(|e| format!("Failed to reshape: {:?}", e))
    }

    fn max_pool_points(&self, x: &Tensor) -> Result<Tensor, String> {
        let data = x.data_f32();
        let dims = x.dims();
        let batch_size = dims[0];
        let num_points = dims[1];
        let feature_dim = dims[2];

        let mut result = Vec::with_capacity(batch_size * feature_dim);

        for b in 0..batch_size {
            for f in 0..feature_dim {
                let mut max_val = f32::NEG_INFINITY;
                for p in 0..num_points {
                    let idx = b * num_points * feature_dim + p * feature_dim + f;
                    max_val = max_val.max(data[idx]);
                }
                result.push(max_val);
            }
        }

        Tensor::from_slice(&result, &[batch_size, feature_dim])
            .map_err(|e| format!("Failed to pool: {:?}", e))
    }

    /// Adds a flattened 3x3 identity matrix to the raw `[batch, 9]` output so
    /// the predicted transform starts near the identity.
    fn add_identity_bias(&self, x: &Tensor, batch_size: usize) -> Result<Tensor, String> {
        let data = x.data_f32();
        let mut result = Vec::with_capacity(data.len());

        let identity = vec![1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0];

        for b in 0..batch_size {
            for i in 0..9 {
                result.push(data[b * 9 + i] + identity[i]);
            }
        }

        Tensor::from_slice(&result, &[batch_size, 9])
            .map_err(|e| format!("Failed to add bias: {:?}", e))
    }
}

/// PointNet feature backbone: a shared MLP that lifts every point to a
/// 1024-d feature, followed by global max pooling over the points.
pub struct PointNetBackbone {
    conv1: Linear,
    conv2: Linear,
    conv3: Linear,
    conv4: Linear,
    conv5: Linear,
}

impl PointNetBackbone {
    pub fn new() -> Self {
        PointNetBackbone {
            conv1: Linear::new(3, 64),
            conv2: Linear::new(64, 64),
            conv3: Linear::new(64, 64),
            conv4: Linear::new(64, 128),
            conv5: Linear::new(128, 1024),
        }
    }

    /// Maps `[batch, num_points, 3]` point clouds to `[batch, 1024]` global
    /// feature vectors.
    pub fn forward(&self, x: &Tensor) -> Result<Tensor, String> {
        let dims = x.dims();
        let batch_size = dims[0];
        let num_points = dims[1];

        // Flatten to [batch * num_points, 3] for the shared per-point MLP.
        let data = x.data_f32();
        let x_flat = Tensor::from_slice(&data, &[batch_size * num_points, 3])
            .map_err(|e| format!("Failed to reshape: {:?}", e))?;

        let mut features = self.conv1.forward(&x_flat);
        features = features.relu();
        features = self.conv2.forward(&features);
        features = features.relu();
        features = self.conv3.forward(&features);
        features = features.relu();
        features = self.conv4.forward(&features);
        features = features.relu();
        features = self.conv5.forward(&features);
        features = features.relu();

        // Restore the [batch, num_points, 1024] layout, then pool over points.
        let feat_data = features.data_f32();
        let features = Tensor::from_slice(&feat_data, &[batch_size, num_points, 1024])
            .map_err(|e| format!("Failed to reshape back: {:?}", e))?;

        self.max_pool_points(&features)
    }

    fn max_pool_points(&self, x: &Tensor) -> Result<Tensor, String> {
        let data = x.data_f32();
        let dims = x.dims();
        let batch_size = dims[0];
        let num_points = dims[1];
        let feature_dim = dims[2];

        let mut result = Vec::with_capacity(batch_size * feature_dim);

        for b in 0..batch_size {
            for f in 0..feature_dim {
                let mut max_val = f32::NEG_INFINITY;
                for p in 0..num_points {
                    let idx = b * num_points * feature_dim + p * feature_dim + f;
                    max_val = max_val.max(data[idx]);
                }
                result.push(max_val);
            }
        }

        Tensor::from_slice(&result, &[batch_size, feature_dim])
            .map_err(|e| format!("Failed to pool: {:?}", e))
    }
}

/// Full PointNet classifier: optional STN3d input transform, shared-MLP
/// backbone with global max pooling, and a three-layer classification head.
pub struct PointNet {
    config: PointNetConfig,
    stn: Option<STN3d>,
    backbone: PointNetBackbone,
    fc1: Linear,
    fc2: Linear,
    fc3: Linear,
}

impl PointNet {
    pub fn new(config: PointNetConfig) -> Self {
        let stn = if config.use_stn {
            Some(STN3d::new())
        } else {
            None
        };

        let backbone = PointNetBackbone::new();
        let fc1 = Linear::new(1024, 512);
        let fc2 = Linear::new(512, 256);
        let fc3 = Linear::new(256, config.num_classes);

        PointNet {
            config,
            stn,
            backbone,
            fc1,
            fc2,
            fc3,
        }
    }

    /// Classifies point clouds of shape `[batch, num_points, 3]`, returning
    /// logits of shape `[batch, num_classes]`.
    pub fn forward(&self, x: &Tensor) -> Result<Tensor, String> {
        // Optionally align the input using the transform predicted by the STN.
        let x = if let Some(ref stn) = self.stn {
            let transform = stn.forward(x)?;
            self.apply_transform(x, &transform)?
        } else {
            x.clone()
        };

        let global_features = self.backbone.forward(&x)?;

        // Classification head over the 1024-d global feature.
        let mut x = self.fc1.forward(&global_features);
        x = x.relu();
        x = self.fc2.forward(&x);
        x = x.relu();
        let logits = self.fc3.forward(&x);

        Ok(logits)
    }

    /// Placeholder: the predicted transform is currently not applied and the
    /// points are returned unchanged.
    fn apply_transform(&self, points: &Tensor, _transform: &Tensor) -> Result<Tensor, String> {
        Ok(points.clone())
    }
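
    // A minimal sketch (not wired into `forward` above) of how the flattened
    // `[batch, 9]` transform from STN3d could be applied to `[batch, n, 3]`
    // points, assuming a row-major 3x3 layout. The helper name
    // `apply_transform_matmul` is illustrative and not part of the original API.
    #[allow(dead_code)]
    fn apply_transform_matmul(&self, points: &Tensor, transform: &Tensor) -> Result<Tensor, String> {
        let p = points.data_f32();
        let t = transform.data_f32();
        let dims = points.dims();
        let (batch, n, d) = (dims[0], dims[1], dims[2]);

        let mut out = Vec::with_capacity(p.len());
        for b in 0..batch {
            for i in 0..n {
                // y[col] = sum over row of x[row] * T[row][col]
                for col in 0..d {
                    let mut acc = 0.0f32;
                    for row in 0..d {
                        acc += p[b * n * d + i * d + row] * t[b * d * d + row * d + col];
                    }
                    out.push(acc);
                }
            }
        }

        Tensor::from_slice(&out, &[batch, n, d])
            .map_err(|e| format!("Failed to apply transform: {:?}", e))
    }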
}

/// Farthest point sampling: greedily selects points that are maximally far
/// from the set already chosen, giving good spatial coverage of the cloud.
pub struct FarthestPointSampler;

impl FarthestPointSampler {
    /// Samples `num_samples` points from `points` of shape
    /// `[batch, num_points, dim]`, returning `[batch, num_samples, dim]`.
    pub fn sample(points: &Tensor, num_samples: usize) -> Result<Tensor, String> {
        let data = points.data_f32();
        let dims = points.dims();
        let batch_size = dims[0];
        let num_points = dims[1];
        let point_dim = dims[2];

        if num_samples > num_points {
            return Err(format!("Cannot sample {} points from {}", num_samples, num_points));
        }

        let mut result = Vec::with_capacity(batch_size * num_samples * point_dim);

        for b in 0..batch_size {
            let batch_offset = b * num_points * point_dim;
            let mut sampled_indices = Vec::new();
            let mut distances = vec![f32::INFINITY; num_points];

            // Start from the first point and track each point's distance to
            // its nearest already-sampled point.
            sampled_indices.push(0);

            for i in 0..num_points {
                distances[i] = Self::point_distance(
                    &data[batch_offset..],
                    0,
                    i,
                    point_dim,
                );
            }

            for _ in 1..num_samples {
                // Pick the point farthest from the current sample set.
                let mut max_dist = 0.0;
                let mut farthest_idx = 0;

                for i in 0..num_points {
                    if distances[i] > max_dist {
                        max_dist = distances[i];
                        farthest_idx = i;
                    }
                }

                sampled_indices.push(farthest_idx);

                // Update nearest-sampled-point distances with the new sample.
                for i in 0..num_points {
                    let dist = Self::point_distance(
                        &data[batch_offset..],
                        farthest_idx,
                        i,
                        point_dim,
                    );
                    distances[i] = distances[i].min(dist);
                }
            }

            // Gather the coordinates of the selected points.
            for &idx in &sampled_indices {
                let start = batch_offset + idx * point_dim;
                result.extend_from_slice(&data[start..start + point_dim]);
            }
        }

        Tensor::from_slice(&result, &[batch_size, num_samples, point_dim])
            .map_err(|e| format!("Failed to create sampled tensor: {:?}", e))
    }

    /// Euclidean distance between points `idx1` and `idx2` in a flat buffer.
    fn point_distance(data: &[f32], idx1: usize, idx2: usize, dim: usize) -> f32 {
        let mut dist_sq = 0.0;
        for d in 0..dim {
            let diff = data[idx1 * dim + d] - data[idx2 * dim + d];
            dist_sq += diff * diff;
        }
        dist_sq.sqrt()
    }
}

/// Groups the `k` nearest neighbours of each centroid, producing local
/// neighbourhoods of shape `[batch, num_centroids, k, dim]`.
pub struct KNNGrouper;

impl KNNGrouper {
    pub fn group(points: &Tensor, centroids: &Tensor, k: usize) -> Result<Tensor, String> {
        let points_data = points.data_f32();
        let centroids_data = centroids.data_f32();

        let points_dims = points.dims();
        let centroids_dims = centroids.dims();

        let batch_size = points_dims[0];
        let num_points = points_dims[1];
        let point_dim = points_dims[2];
        let num_centroids = centroids_dims[1];

        let mut result = Vec::with_capacity(batch_size * num_centroids * k * point_dim);

        for b in 0..batch_size {
            let points_offset = b * num_points * point_dim;
            let centroids_offset = b * num_centroids * point_dim;

            for c in 0..num_centroids {
                // Distance from every point to this centroid.
                let mut distances: Vec<(f32, usize)> = (0..num_points)
                    .map(|p| {
                        let dist = Self::point_distance(
                            &points_data[points_offset..],
                            &centroids_data[centroids_offset..],
                            p,
                            c,
                            point_dim,
                        );
                        (dist, p)
                    })
                    .collect();

                distances.sort_by(|a, b| a.0.partial_cmp(&b.0).unwrap());

                // Take the k closest points.
                for i in 0..k.min(num_points) {
                    let point_idx = distances[i].1;
                    let start = points_offset + point_idx * point_dim;
                    result.extend_from_slice(&points_data[start..start + point_dim]);
                }

                // Zero-pad when the cloud has fewer than k points.
                for _ in num_points..k {
                    for _ in 0..point_dim {
                        result.push(0.0);
                    }
                }
            }
        }

        Tensor::from_slice(&result, &[batch_size, num_centroids, k, point_dim])
            .map_err(|e| format!("Failed to create grouped tensor: {:?}", e))
    }

    fn point_distance(points: &[f32], centroids: &[f32], p_idx: usize, c_idx: usize, dim: usize) -> f32 {
        let mut dist_sq = 0.0;
        for d in 0..dim {
            let diff = points[p_idx * dim + d] - centroids[c_idx * dim + d];
            dist_sq += diff * diff;
        }
        dist_sq.sqrt()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_pointnet_config() {
        let config = PointNetConfig::default();
        assert_eq!(config.num_points, 1024);
        assert_eq!(config.input_dim, 3);

        let small = PointNetConfig::small();
        assert_eq!(small.num_points, 512);
    }

    #[test]
    fn test_stn3d() {
        let stn = STN3d::new();
        let points = Tensor::randn(&[2, 64, 3]);
        let transform = stn.forward(&points).unwrap();
        assert_eq!(transform.dims(), &[2, 9]);
    }

    #[test]
    fn test_pointnet_backbone() {
        let backbone = PointNetBackbone::new();
        let points = Tensor::randn(&[2, 128, 3]);
        let features = backbone.forward(&points).unwrap();
        assert_eq!(features.dims(), &[2, 1024]);
    }

    #[test]
    fn test_pointnet() {
        let config = PointNetConfig::small();
        let model = PointNet::new(config);

        let points = Tensor::randn(&[2, 512, 3]);
        let logits = model.forward(&points).unwrap();
        assert_eq!(logits.dims(), &[2, 10]);
    }

    #[test]
    fn test_farthest_point_sampling() {
        let points = Tensor::randn(&[1, 100, 3]);
        let sampled = FarthestPointSampler::sample(&points, 32).unwrap();
        assert_eq!(sampled.dims(), &[1, 32, 3]);
    }

    #[test]
    fn test_knn_grouper() {
        let points = Tensor::randn(&[1, 100, 3]);
        let centroids = Tensor::randn(&[1, 10, 3]);
        let grouped = KNNGrouper::group(&points, &centroids, 8).unwrap();
        assert_eq!(grouped.dims(), &[1, 10, 8, 3]);
    }
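
    // Illustrative composition of the two utilities above, as a
    // PointNet++-style set abstraction step would use them: sample centroids
    // with farthest point sampling, then gather each centroid's k nearest
    // neighbours. Only the output shapes are asserted here.
    #[test]
    fn test_sample_then_group() {
        let points = Tensor::randn(&[1, 100, 3]);
        let centroids = FarthestPointSampler::sample(&points, 16).unwrap();
        let grouped = KNNGrouper::group(&points, &centroids, 4).unwrap();
        assert_eq!(grouped.dims(), &[1, 16, 4, 3]);
    }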
}