use nalgebra::{Point3, Point2};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use crate::core::image::ImageId;
/// Identifier of a reconstructed 3D point; used as the key of `Reconstruction::points3d`.
///
/// `Reconstruction::add_point3d` treats the value `0` as "no id assigned yet".
pub type Point3dId = u64;
/// Identifier of a feature track; used as the key of `Reconstruction::tracks`.
///
/// `Reconstruction::add_track` treats the value `0` as "no id assigned yet".
pub type TrackId = u64;
/// An RGB color with one 8-bit channel per component.
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
pub struct Color {
/// Red channel, `0..=255`.
pub r: u8,
/// Green channel, `0..=255`.
pub g: u8,
/// Blue channel, `0..=255`.
pub b: u8,
}
impl Color {
    /// Creates a color from its 8-bit red, green and blue channels.
    pub fn new(r: u8, g: u8, b: u8) -> Self {
        Self { r, g, b }
    }

    /// Pure black, `(0, 0, 0)`.
    pub fn black() -> Self {
        Self::new(0, 0, 0)
    }

    /// Pure white, `(255, 255, 255)`.
    pub fn white() -> Self {
        Self::new(255, 255, 255)
    }

    /// Pure red, `(255, 0, 0)`.
    pub fn red() -> Self {
        Self::new(255, 0, 0)
    }

    /// Pure green, `(0, 255, 0)`.
    pub fn green() -> Self {
        Self::new(0, 255, 0)
    }

    /// Pure blue, `(0, 0, 255)`.
    pub fn blue() -> Self {
        Self::new(0, 0, 255)
    }

    /// Returns the channels as `[r, g, b]` floats in `[0.0, 1.0]`, each channel
    /// divided by 255.
    pub fn to_f32_array(&self) -> [f32; 3] {
        [
            f32::from(self.r) / 255.0,
            f32::from(self.g) / 255.0,
            f32::from(self.b) / 255.0,
        ]
    }

    /// Builds a color from normalized `[r, g, b]` float channels.
    ///
    /// Each component is clamped to `[0.0, 1.0]`, scaled by 255 and rounded to
    /// the nearest integer, so `Color::from_f32_array(c.to_f32_array()) == c`
    /// holds for every color `c`. A `NaN` component maps to `0`.
    pub fn from_f32_array(rgb: [f32; 3]) -> Self {
        let [r, g, b] = rgb;
        Self::new(
            Self::quantize_channel(r),
            Self::quantize_channel(g),
            Self::quantize_channel(b),
        )
    }

    /// Converts one normalized channel to 8 bits.
    ///
    /// Rounding (rather than truncating) keeps the conversion unbiased and makes
    /// values such as `0.999` land on 255 instead of 254.
    fn quantize_channel(value: f32) -> u8 {
        // After clamping and rounding the value is within 0..=255 (or NaN, which
        // the saturating float-to-int cast turns into 0), so `as` cannot truncate.
        (value.clamp(0.0, 1.0) * 255.0).round() as u8
    }
}
/// A single 2D observation: a feature in one image and its 2D location.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Observation {
/// Image in which the feature was observed.
pub image_id: ImageId,
/// Index of the feature.
// NOTE(review): presumably an index into that image's feature list — confirm against callers.
pub feature_idx: usize,
/// 2D location of the observation.
// NOTE(review): presumably pixel coordinates — confirm against callers.
pub point2d: Point2<f64>,
}
impl Observation {
pub fn new(image_id: ImageId, feature_idx: usize, point2d: Point2<f64>) -> Self {
Self {
image_id,
feature_idx,
point2d,
}
}
}
/// A reconstructed 3D point together with the 2D observations attached to it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Point3d {
/// Identifier of this point.
pub id: Point3dId,
/// Position of the point in 3D space.
pub position: Point3<f64>,
/// Color of the point; `Point3d::new` initializes it to black.
pub color: Color,
/// Error value associated with the point; `Point3d::new` initializes it to `0.0`.
// NOTE(review): presumably a reprojection error — confirm the unit with the code that sets it.
pub error: f64,
/// Observations recorded for this point.
pub observations: Vec<Observation>,
/// Track this point is linked to, if any.
pub track_id: Option<TrackId>,
}
impl Point3d {
pub fn new(id: Point3dId, position: Point3<f64>) -> Self {
Self {
id,
position,
color: Color::black(),
error: 0.0,
observations: Vec::new(),
track_id: None,
}
}
pub fn add_observation(&mut self, observation: Observation) {
self.observations.push(observation);
}
pub fn num_observations(&self) -> usize {
self.observations.len()
}
pub fn has_sufficient_observations(&self, min_observations: usize) -> bool {
self.observations.len() >= min_observations
}
pub fn image_ids(&self) -> Vec<ImageId> {
self.observations.iter().map(|obs| obs.image_id).collect()
}
pub fn is_observed_by(&self, image_id: ImageId) -> bool {
self.observations.iter().any(|obs| obs.image_id == image_id)
}
pub fn get_observation_in_image(&self, image_id: ImageId) -> Option<&Observation> {
self.observations.iter().find(|obs| obs.image_id == image_id)
}
pub fn set_color(&mut self, color: Color) {
self.color = color;
}
pub fn set_error(&mut self, error: f64) {
self.error = error;
}
pub fn mean_observation(&self) -> Option<Point2<f64>> {
if self.observations.is_empty() {
return None;
}
let sum = self.observations.iter()
.fold(Point2::new(0.0, 0.0), |acc, obs| acc + obs.point2d.coords);
Some(Point2::from(sum / self.observations.len() as f64))
}
}
/// A feature track: a list of observations, optionally linked to a 3D point.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Track {
/// Identifier of this track.
pub id: TrackId,
/// Observations that belong to this track.
pub observations: Vec<Observation>,
/// Id of the 3D point linked to this track, if any.
pub point3d_id: Option<Point3dId>,
/// Number of observations, refreshed by `Track::add_observation`.
// NOTE: both this field and `observations` are public, so pushing to
// `observations` directly leaves `length` stale.
pub length: usize,
}
impl Track {
pub fn new(id: TrackId) -> Self {
Self {
id,
observations: Vec::new(),
point3d_id: None,
length: 0,
}
}
pub fn add_observation(&mut self, observation: Observation) {
self.observations.push(observation);
self.length = self.observations.len();
}
pub fn is_valid(&self, min_length: usize) -> bool {
self.length >= min_length
}
pub fn is_triangulated(&self) -> bool {
self.point3d_id.is_some()
}
pub fn set_point3d(&mut self, point3d_id: Point3dId) {
self.point3d_id = Some(point3d_id);
}
pub fn image_ids(&self) -> HashSet<ImageId> {
self.observations.iter().map(|obs| obs.image_id).collect()
}
}
/// Container for 3D points and tracks, each keyed by its id.
///
/// NOTE(review): the derived `Default` initializes both id counters to `0`,
/// whereas `Reconstruction::new` starts them at `1`.
#[derive(Debug, Default, Serialize, Deserialize)]
pub struct Reconstruction {
/// All 3D points, keyed by point id.
pub points3d: HashMap<Point3dId, Point3d>,
/// All tracks, keyed by track id.
pub tracks: HashMap<TrackId, Track>,
// Counter used by `add_point3d` when assigning an id to a point.
next_point_id: Point3dId,
// Counter used by `add_track` when assigning an id to a track.
next_track_id: TrackId,
}
impl Reconstruction {
    /// Creates an empty reconstruction whose id counters start at 1.
    pub fn new() -> Self {
        Self {
            points3d: HashMap::new(),
            tracks: HashMap::new(),
            next_point_id: 1,
            next_track_id: 1,
        }
    }

    /// Picks the id an element is stored under and advances `next_id`.
    ///
    /// A `requested` id of `0` means "assign one for me"; any other value is
    /// kept as is and the counter is moved past it.
    fn resolve_id(next_id: &mut u64, requested: u64) -> u64 {
        if requested == 0 {
            // `0` is the "unassigned" sentinel, so it is never handed out, even
            // when the counter itself is 0 (as in a `Default`-constructed value).
            let assigned = (*next_id).max(1);
            *next_id = assigned + 1;
            assigned
        } else {
            // Keep the counter ahead of explicit ids. `u64::MAX` has no
            // successor; the counter is then left untouched instead of
            // overflowing.
            if let Some(successor) = requested.checked_add(1) {
                *next_id = (*next_id).max(successor);
            }
            requested
        }
    }

    /// Inserts `point` and returns the id it is stored under.
    ///
    /// A point whose `id` is `0` receives the next free id; otherwise its id is
    /// kept. An existing point stored under the same id is replaced.
    pub fn add_point3d(&mut self, mut point: Point3d) -> Point3dId {
        let id = Self::resolve_id(&mut self.next_point_id, point.id);
        point.id = id;
        self.points3d.insert(id, point);
        id
    }

    /// Inserts `track` and returns the id it is stored under.
    ///
    /// A track whose `id` is `0` receives the next free id; otherwise its id is
    /// kept. An existing track stored under the same id is replaced.
    pub fn add_track(&mut self, mut track: Track) -> TrackId {
        let id = Self::resolve_id(&mut self.next_track_id, track.id);
        track.id = id;
        self.tracks.insert(id, track);
        id
    }

    /// Returns the point stored under `id`, if any.
    pub fn get_point3d(&self, id: Point3dId) -> Option<&Point3d> {
        self.points3d.get(&id)
    }

    /// Returns a mutable reference to the point stored under `id`, if any.
    pub fn get_point3d_mut(&mut self, id: Point3dId) -> Option<&mut Point3d> {
        self.points3d.get_mut(&id)
    }

    /// Returns the track stored under `id`, if any.
    pub fn get_track(&self, id: TrackId) -> Option<&Track> {
        self.tracks.get(&id)
    }

    /// Returns a mutable reference to the track stored under `id`, if any.
    pub fn get_track_mut(&mut self, id: TrackId) -> Option<&mut Track> {
        self.tracks.get_mut(&id)
    }

    /// Removes and returns the point stored under `id`, if any.
    ///
    /// Tracks whose `point3d_id` refers to the removed point are not updated.
    pub fn remove_point3d(&mut self, id: Point3dId) -> Option<Point3d> {
        self.points3d.remove(&id)
    }

    /// Removes and returns the track stored under `id`, if any.
    ///
    /// Points whose `track_id` refers to the removed track are not updated.
    pub fn remove_track(&mut self, id: TrackId) -> Option<Track> {
        self.tracks.remove(&id)
    }

    /// Number of stored 3D points.
    pub fn num_points3d(&self) -> usize {
        self.points3d.len()
    }

    /// Number of stored tracks.
    pub fn num_tracks(&self) -> usize {
        self.tracks.len()
    }

    /// Number of stored tracks that have a 3D point id set.
    pub fn num_triangulated_tracks(&self) -> usize {
        self.tracks
            .values()
            .filter(|track| track.is_triangulated())
            .count()
    }

    /// Axis-aligned bounding box of all point positions as `(min, max)`
    /// corners, or `None` when there are no points.
    ///
    /// `NaN` coordinates are skipped by `f64::min` / `f64::max`.
    pub fn bounding_box(&self) -> Option<(Point3<f64>, Point3<f64>)> {
        if self.points3d.is_empty() {
            return None;
        }

        // Start from an inverted box so the first point initializes both corners.
        let mut min_point = Point3::new(f64::INFINITY, f64::INFINITY, f64::INFINITY);
        let mut max_point =
            Point3::new(f64::NEG_INFINITY, f64::NEG_INFINITY, f64::NEG_INFINITY);

        for point in self.points3d.values() {
            let pos = &point.position;
            min_point.x = min_point.x.min(pos.x);
            min_point.y = min_point.y.min(pos.y);
            min_point.z = min_point.z.min(pos.z);
            max_point.x = max_point.x.max(pos.x);
            max_point.y = max_point.y.max(pos.y);
            max_point.z = max_point.z.max(pos.z);
        }

        Some((min_point, max_point))
    }

    /// Center of the bounding box (midpoint of its two corners), or `None`
    /// when there are no points.
    pub fn center(&self) -> Option<Point3<f64>> {
        self.bounding_box()
            .map(|(min_point, max_point)| {
                Point3::from((min_point.coords + max_point.coords) / 2.0)
            })
    }

    /// Removes all points and tracks and resets both id counters to 1.
    pub fn clear(&mut self) {
        self.points3d.clear();
        self.tracks.clear();
        self.next_point_id = 1;
        self.next_track_id = 1;
    }
}
// Unit tests for the color, point, track and reconstruction types above.
#[cfg(test)]
mod tests {
use super::*;
// A color survives the conversion to a float array and back.
#[test]
fn test_color() {
let color = Color::new(128, 64, 192);
assert_eq!(color.r, 128);
assert_eq!(color.g, 64);
assert_eq!(color.b, 192);
let float_array = color.to_f32_array();
let back_color = Color::from_f32_array(float_array);
assert_eq!(color, back_color);
}
// Adding an observation updates the count and the per-image lookup.
#[test]
fn test_point3d() {
let mut point = Point3d::new(1, Point3::new(1.0, 2.0, 3.0));
assert_eq!(point.num_observations(), 0);
let obs = Observation::new(1, 0, Point2::new(100.0, 200.0));
point.add_observation(obs);
assert_eq!(point.num_observations(), 1);
assert!(point.is_observed_by(1));
assert!(!point.is_observed_by(2));
}
// A track becomes valid once it holds enough observations and reports
// the images it spans.
#[test]
fn test_track() {
let mut track = Track::new(1);
assert!(!track.is_valid(2));
assert!(!track.is_triangulated());
track.add_observation(Observation::new(1, 0, Point2::new(100.0, 200.0)));
track.add_observation(Observation::new(2, 1, Point2::new(150.0, 250.0)));
assert!(track.is_valid(2));
assert_eq!(track.length, 2);
let image_ids = track.image_ids();
assert!(image_ids.contains(&1));
assert!(image_ids.contains(&2));
}
// Elements added with id 0 receive the first auto-assigned id, which is 1
// for a reconstruction created with `new`.
#[test]
fn test_reconstruction() {
let mut recon = Reconstruction::new();
let point = Point3d::new(0, Point3::new(1.0, 2.0, 3.0));
let point_id = recon.add_point3d(point);
assert_eq!(point_id, 1);
assert_eq!(recon.num_points3d(), 1);
let track = Track::new(0);
let track_id = recon.add_track(track);
assert_eq!(track_id, 1);
assert_eq!(recon.num_tracks(), 1);
}
}