1use nalgebra::constraint::{SameNumberOfRows, ShapeConstraint};
4use nalgebra::storage::Storage;
5use nalgebra::{Const, Dim, SVector, Vector};
6use rand::distributions::{uniform::SampleUniform, Distribution};
7use rand::{thread_rng, SeedableRng};
8use serde::{de::DeserializeOwned, Deserialize, Serialize};
9
10use crate::cspace::CSpace;
11use crate::error::{InvalidParamError, Result};
12use crate::params::FromParams;
13use crate::rng::{LinearCoordinates, RNG};
14use crate::scalar::Scalar;
15use crate::trajectories::{FullTraj, LinearTrajectory};
16use crate::util::{bounds::Bounds, norm::NormCost};
17
18#[derive(Copy, Clone, Debug, PartialEq, Serialize, Deserialize)]
19#[serde(bound(
20 serialize = "X: Serialize",
21 deserialize = "X: DeserializeOwned"
22))]
23pub struct LinearSpaceParams<X: Scalar, const N: usize> {
24 pub bounds: Bounds<X, N>,
25 pub norm: NormCost,
26 pub seed: Option<u64>,
27}
28
29pub struct LinearSpace<X, const N: usize>
31where
32 X: SampleUniform,
33{
34 volume: X,
35 norm: NormCost,
36 rng: RNG,
37 distribution: LinearCoordinates<X, N>,
38}
39
40impl<X, const N: usize> LinearSpace<X, N>
41where
42 X: Scalar + SampleUniform,
43{
44 pub fn new(bounds: Bounds<X, N>, norm: NormCost, rng: RNG) -> Result<Self> {
45 if !bounds.is_valid() {
46 Err(InvalidParamError {
47 parameter_name: "bounds",
48 parameter_value: format!("{:?}", bounds),
49 })?;
50 }
51
52 let volume = bounds.volume();
53 let distribution = bounds.into();
54
55 Ok(Self {
56 volume,
57 norm,
58 rng,
59 distribution,
60 })
61 }
62}
63
64impl<X, const N: usize> CSpace<X, N> for LinearSpace<X, N>
65where
66 X: Scalar + SampleUniform,
67{
68 type Traj = LinearTrajectory<X, NormCost, N>;
69
70 fn volume(&self) -> X {
71 self.volume
72 }
73
74 fn cost<R1, R2, S1, S2>(
75 &self,
76 a: &Vector<X, R1, S1>,
77 b: &Vector<X, R2, S2>,
78 ) -> X
79 where
80 X: Scalar,
81 R1: Dim,
82 R2: Dim,
83 S1: Storage<X, R1>,
84 S2: Storage<X, R2>,
85 ShapeConstraint: SameNumberOfRows<R1, R2>
86 + SameNumberOfRows<R1, Const<N>>
87 + SameNumberOfRows<R2, Const<N>>,
88 {
89 self.norm.cost(a, b)
90 }
91
92 fn trajectory<S1, S2>(
93 &self,
94 start: Vector<X, Const<N>, S1>,
95 end: Vector<X, Const<N>, S2>,
96 ) -> Option<FullTraj<X, Self::Traj, S1, S2, N>>
97 where
98 X: Scalar,
99 S1: Storage<X, Const<N>>,
100 S2: Storage<X, Const<N>>,
101 {
102 Some(FullTraj::new(start, end, LinearTrajectory::new(self.norm)))
104 }
105
106 fn is_free<S>(&self, _: &Vector<X, Const<N>, S>) -> bool
107 where
108 S: Storage<X, Const<N>>,
109 {
110 true
112 }
113
114 fn saturate(&self, a: &mut SVector<X, N>, b: &SVector<X, N>, delta: X) {
115 let scale = delta / self.norm.cost(a, b);
116 *a -= *b;
117 *a *= scale;
118 *a += *b;
119 }
120
121 fn sample(&mut self) -> SVector<X, N> {
122 self.distribution.sample(&mut self.rng)
123 }
124}
125
126impl<X, const N: usize> FromParams for LinearSpace<X, N>
127where
128 X: Scalar + SampleUniform,
129{
130 type Params = LinearSpaceParams<X, N>;
131 fn from_params(params: Self::Params) -> Result<Self> {
132 let rng = match params.seed {
133 Some(seed) => RNG::seed_from_u64(seed),
134 None => RNG::from_rng(thread_rng())?,
135 };
136 LinearSpace::new(params.bounds, params.norm, rng)
137 }
138}
139
#[cfg(test)]
mod tests {

    use rand::SeedableRng;
    use rayon::iter::{IntoParallelIterator, ParallelIterator};

    use super::*;

    /// Fixed seed for the RNG handed to the space under test.
    const SEED: u64 = 0xe580e2e93fd6b040;

    /// Shares one `LinearSpace` across rayon tasks, calls `is_free` from each
    /// of them, and checks that every call reported the point as free.
    #[test]
    fn test_parallel_sample() {
        let mins: [f32; 2] = [-2.0, -2.0];
        let maxs = [2.0, 2.0];
        let rng = RNG::seed_from_u64(SEED);

        let bounds = Bounds::new(mins.into(), maxs.into());
        let space = LinearSpace::new(bounds, NormCost::TwoNorm, rng).unwrap();

        let samples = (0..1000)
            .into_par_iter()
            .map(|_| {
                let point = [-1.0, -1.0].into();
                space.is_free(&point)
            })
            .collect::<Vec<_>>();

        assert_eq!(samples.len(), 1000);
        // Check the collected answers themselves, not only how many there are.
        assert!(samples.iter().all(|&free| free));

        let point = [-1.0, -1.0].into();
        assert!(space.is_free(&point));
    }
}