#![allow(dead_code)]
use crate::{Result, VisionError};
use scirs2_core::ndarray::{arr1, arr2, Array1, Array2, ArrayView2, Axis};
use scirs2_spatial::kdtree::KDTree;
use scirs2_spatial::octree::Octree;
use scirs2_spatial::quadtree::Quadtree;
use scirs2_spatial::rtree::RTree;
use std::collections::HashMap;
use torsh_tensor::Tensor;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct PointId(pub usize);
#[derive(Debug, Clone)]
pub struct BoundingBox {
pub min: Array1<f64>,
pub max: Array1<f64>,
}
impl BoundingBox {
pub fn new(min: Array1<f64>, max: Array1<f64>) -> Result<Self> {
if min.len() != max.len() {
return Err(VisionError::InvalidArgument(
"Min and max coordinates must have same dimension".to_string(),
));
}
for (i, (&min_val, &max_val)) in min.iter().zip(max.iter()).enumerate() {
if min_val > max_val {
return Err(VisionError::InvalidArgument(format!(
"Min coordinate {} is greater than max coordinate at dimension {}",
min_val, i
)));
}
}
Ok(Self { min, max })
}
pub fn from_points(points: &Array2<f64>) -> Result<Self> {
if points.is_empty() {
return Err(VisionError::InvalidArgument(
"Cannot create bounding box from empty points".to_string(),
));
}
let dims = points.ncols();
let mut min = Array1::from_elem(dims, f64::INFINITY);
let mut max = Array1::from_elem(dims, f64::NEG_INFINITY);
for point in points.outer_iter() {
for (i, &coord) in point.iter().enumerate() {
min[i] = min[i].min(coord);
max[i] = max[i].max(coord);
}
}
Self::new(min, max)
}
pub fn contains(&self, point: &ArrayView2<f64>) -> bool {
if point.len() != self.min.len() {
return false;
}
for (i, &coord) in point.iter().enumerate() {
if coord < self.min[i] || coord > self.max[i] {
return false;
}
}
true
}
pub fn volume(&self) -> f64 {
self.max
.iter()
.zip(self.min.iter())
.map(|(&max_val, &min_val)| max_val - min_val)
.product()
}
}
pub struct SpatialObjectTracker {
spatial_index: Option<RTree<ObjectId>>,
object_data: HashMap<ObjectId, ObjectMetadata>,
frame_history: Vec<FrameData>,
max_history: usize,
}
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct ObjectId(pub usize);
#[derive(Debug, Clone)]
pub struct ObjectMetadata {
pub bbox: BoundingBox,
pub confidence: f64,
pub class_id: usize,
pub last_seen_frame: usize,
}
#[derive(Debug, Clone)]
pub struct FrameData {
pub frame_id: usize,
pub detections: Vec<Detection>,
pub timestamp: f64,
}
#[derive(Debug, Clone)]
pub struct Detection {
pub bbox: BoundingBox,
pub confidence: f64,
pub class_id: usize,
pub features: Option<Array1<f64>>,
}
#[derive(Debug, Clone)]
pub struct TrackedObject {
pub object_id: ObjectId,
pub bbox: BoundingBox,
pub confidence: f64,
pub class_id: usize,
pub track_length: usize,
}
impl SpatialObjectTracker {
pub fn new(max_history: usize) -> Self {
Self {
spatial_index: None,
object_data: HashMap::new(),
frame_history: Vec::new(),
max_history,
}
}
pub fn track_objects(
&mut self,
detections: &[Detection],
frame_id: usize,
) -> Result<Vec<TrackedObject>> {
let frame_data = FrameData {
frame_id,
detections: detections.to_vec(),
timestamp: std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.expect("system time should be after UNIX_EPOCH")
.as_secs_f64(),
};
self.frame_history.push(frame_data);
if self.frame_history.len() > self.max_history {
self.frame_history.remove(0);
}
let mut tracked_objects = Vec::new();
for (i, detection) in detections.iter().enumerate() {
let object_id = ObjectId(i);
let tracked_object = TrackedObject {
object_id,
bbox: detection.bbox.clone(),
confidence: detection.confidence,
class_id: detection.class_id,
track_length: 1,
};
tracked_objects.push(tracked_object);
let metadata = ObjectMetadata {
bbox: detection.bbox.clone(),
confidence: detection.confidence,
class_id: detection.class_id,
last_seen_frame: frame_id,
};
self.object_data.insert(object_id, metadata);
}
Ok(tracked_objects)
}
pub fn query_region(&self, region: &BoundingBox) -> Result<Vec<ObjectId>> {
let mut objects_in_region = Vec::new();
for (&object_id, metadata) in &self.object_data {
if self.bboxes_overlap(&metadata.bbox, region) {
objects_in_region.push(object_id);
}
}
Ok(objects_in_region)
}
fn bboxes_overlap(&self, bbox1: &BoundingBox, bbox2: &BoundingBox) -> bool {
if bbox1.min.len() != bbox2.min.len() {
return false;
}
for i in 0..bbox1.min.len() {
if bbox1.max[i] < bbox2.min[i] || bbox1.min[i] > bbox2.max[i] {
return false;
}
}
true
}
pub fn get_trajectory(&self, object_id: ObjectId) -> Vec<BoundingBox> {
if let Some(metadata) = self.object_data.get(&object_id) {
vec![metadata.bbox.clone()]
} else {
Vec::new()
}
}
}
pub struct PointCloudProcessor {
octree: Option<Octree>,
points: Array2<f64>,
point_metadata: HashMap<PointId, PointMetadata>,
}
#[derive(Debug, Clone)]
pub struct PointMetadata {
pub color: Option<Array1<f32>>,
pub normal: Option<Array1<f64>>,
pub intensity: Option<f64>,
}
impl PointCloudProcessor {
pub fn new() -> Self {
Self {
octree: None,
points: Array2::zeros((0, 3)),
point_metadata: HashMap::new(),
}
}
pub fn build_index(&mut self, points: Array2<f64>) -> Result<()> {
if points.ncols() != 3 {
return Err(VisionError::InvalidArgument(
"Point cloud must have 3D coordinates".to_string(),
));
}
self.points = points;
let _bbox = BoundingBox::from_points(&self.points)?;
Ok(())
}
pub fn query_region(&self, region: &BoundingBox) -> Result<Vec<PointId>> {
let mut points_in_region = Vec::new();
for (i, point) in self.points.outer_iter().enumerate() {
if region.contains(&point.view().insert_axis(Axis(1))) {
points_in_region.push(PointId(i));
}
}
Ok(points_in_region)
}
pub fn find_neighbors(&self, query_point: &Array1<f64>, k: usize) -> Result<Vec<PointId>> {
if query_point.len() != 3 {
return Err(VisionError::InvalidArgument(
"Query point must be 3D".to_string(),
));
}
let mut distances: Vec<(PointId, f64)> = Vec::new();
for (i, point) in self.points.outer_iter().enumerate() {
let diff = &point - query_point;
let distance = (diff.mapv(|x| x * x).sum()).sqrt();
distances.push((PointId(i), distance));
}
distances.sort_by(|a, b| a.1.partial_cmp(&b.1).expect("comparison should succeed"));
distances.truncate(k);
Ok(distances.into_iter().map(|(id, _)| id).collect())
}
pub fn segment_regions(&self, _region_size: f64) -> Result<Vec<Vec<PointId>>> {
let mut regions = Vec::new();
if !self.points.is_empty() {
let bbox = BoundingBox::from_points(&self.points)?;
let _dims = bbox.max.len();
let all_points: Vec<PointId> = (0..self.points.nrows()).map(PointId).collect();
regions.push(all_points);
}
Ok(regions)
}
}
impl Default for PointCloudProcessor {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_bounding_box_creation() {
let min = arr1(&[0.0, 0.0]);
let max = arr1(&[1.0, 1.0]);
let bbox = BoundingBox::new(min, max);
assert!(bbox.is_ok());
}
#[test]
fn test_bounding_box_invalid() {
let min = arr1(&[1.0, 1.0]);
let max = arr1(&[0.0, 0.0]);
let bbox = BoundingBox::new(min, max);
assert!(bbox.is_err());
}
#[test]
fn test_bounding_box_from_points() {
let points = arr2(&[[0.0, 0.0], [1.0, 1.0], [0.5, 0.5]]);
let bbox = BoundingBox::from_points(&points);
assert!(bbox.is_ok());
let bbox = bbox.expect("operation should succeed");
assert_eq!(bbox.min[0], 0.0);
assert_eq!(bbox.max[0], 1.0);
assert_eq!(bbox.volume(), 1.0);
}
#[test]
fn test_spatial_object_tracker() {
let mut tracker = SpatialObjectTracker::new(10);
let detection = Detection {
bbox: BoundingBox::new(arr1(&[0.0, 0.0]), arr1(&[1.0, 1.0]))
.expect("operation should succeed"),
confidence: 0.9,
class_id: 1,
features: None,
};
let result = tracker.track_objects(&[detection], 0);
assert!(result.is_ok());
assert_eq!(result.expect("operation should succeed").len(), 1);
}
#[test]
fn test_point_cloud_processor() {
let mut processor = PointCloudProcessor::new();
let points = arr2(&[[0.0, 0.0, 0.0], [1.0, 1.0, 1.0]]);
let result = processor.build_index(points);
assert!(result.is_ok());
}
}