1use crate::types::Point3D;
4use ndarray::Array2;
5use serde::{Deserialize, Serialize};
6
/// Tabulated direction-dependent gain of a source.
///
/// All angles are in degrees. `magnitude[[v, h]]` holds the gain toward the
/// polar angle `vertical_angles[v]` and the azimuth `horizontal_angles[h]`, so
/// the array is laid out as `(vertical_angles.len(), horizontal_angles.len())`.
/// `Source::amplitude_towards` measures the polar angle from +z and the azimuth
/// from +x toward +y. See `interpolate` for the grid layout it expects.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DirectivityPattern {
    /// Azimuth sample angles, in degrees (column axis of `magnitude`).
    pub horizontal_angles: Vec<f64>,
    /// Polar sample angles, in degrees (row axis of `magnitude`).
    pub vertical_angles: Vec<f64>,
    /// Gain per (vertical, horizontal) sample. It is multiplied into the source
    /// amplitude by `Source::amplitude_towards`.
    pub magnitude: Array2<f64>,
}
18
19impl DirectivityPattern {
20 pub fn omnidirectional() -> Self {
22 let horizontal_angles: Vec<f64> = (0..36).map(|i| i as f64 * 10.0).collect();
23 let vertical_angles: Vec<f64> = (0..19).map(|i| i as f64 * 10.0).collect();
24
25 let magnitude = Array2::ones((vertical_angles.len(), horizontal_angles.len()));
26
27 Self {
28 horizontal_angles,
29 vertical_angles,
30 magnitude,
31 }
32 }
33
34 pub fn cardioid() -> Self {
36 let horizontal_angles: Vec<f64> = (0..36).map(|i| i as f64 * 10.0).collect();
37 let vertical_angles: Vec<f64> = (0..19).map(|i| i as f64 * 10.0).collect();
38
39 let mut magnitude = Array2::zeros((vertical_angles.len(), horizontal_angles.len()));
40
41 for (v_idx, &v_angle) in vertical_angles.iter().enumerate() {
42 for (h_idx, &h_angle) in horizontal_angles.iter().enumerate() {
43 let theta_rad = v_angle.to_radians();
45 let phi_rad = h_angle.to_radians();
46 let forward_dot = theta_rad.sin() * phi_rad.sin();
48 magnitude[[v_idx, h_idx]] = 0.5 * (1.0 + forward_dot).max(0.0);
49 }
50 }
51
52 Self {
53 horizontal_angles,
54 vertical_angles,
55 magnitude,
56 }
57 }
58
59 pub fn interpolate(&self, theta: f64, phi: f64) -> f64 {
61 let theta_deg = theta.to_degrees();
63 let mut phi_deg = phi.to_degrees();
64
65 while phi_deg < 0.0 {
67 phi_deg += 360.0;
68 }
69 while phi_deg >= 360.0 {
70 phi_deg -= 360.0;
71 }
72
73 let h_idx = (phi_deg / 10.0).floor() as usize;
75 let v_idx = (theta_deg / 10.0).floor() as usize;
76
77 let h_idx = h_idx.min(self.horizontal_angles.len() - 1);
78 let v_idx = v_idx.min(self.vertical_angles.len() - 1);
79
80 let h_next = (h_idx + 1) % self.horizontal_angles.len();
81 let v_next = (v_idx + 1).min(self.vertical_angles.len() - 1);
82
83 let h_frac = (phi_deg / 10.0) - h_idx as f64;
85 let v_frac = (theta_deg / 10.0) - v_idx as f64;
86
87 let m00 = self.magnitude[[v_idx, h_idx]];
88 let m01 = self.magnitude[[v_idx, h_next]];
89 let m10 = self.magnitude[[v_next, h_idx]];
90 let m11 = self.magnitude[[v_next, h_next]];
91
92 let m0 = m00 * (1.0 - h_frac) + m01 * h_frac;
93 let m1 = m10 * (1.0 - h_frac) + m11 * h_frac;
94
95 m0 * (1.0 - v_frac) + m1 * v_frac
96 }
97}
98
/// Frequency-dependent gain stage applied by a `Source`.
///
/// Frequencies and cutoffs only ever appear as ratios, so they just need to
/// share units (presumably Hz; confirm with callers). At each cutoff the gain
/// of the corresponding filter stage is `1/√2`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub enum CrossoverFilter {
    /// No filtering: unity gain at every frequency. This is the `Default`.
    #[default]
    FullRange,
    /// Butterworth-magnitude lowpass: gain falls off above `cutoff_freq`.
    Lowpass {
        /// Frequency at which the gain is `1/√2`.
        cutoff_freq: f64,
        /// Filter order. The response is `1 / sqrt(1 + (f / cutoff_freq)^(2·order))`.
        order: u32,
    },
    /// Butterworth-magnitude highpass: gain falls off below `cutoff_freq`.
    Highpass {
        /// Frequency at which the gain is `1/√2`.
        cutoff_freq: f64,
        /// Filter order. The response is `1 / sqrt(1 + (cutoff_freq / f)^(2·order))`.
        order: u32,
    },
    /// Product of a highpass at `low_cutoff` and a lowpass at `high_cutoff`.
    Bandpass {
        /// Cutoff of the highpass stage (lower band edge).
        low_cutoff: f64,
        /// Cutoff of the lowpass stage (upper band edge).
        high_cutoff: f64,
        /// Order shared by both stages.
        order: u32,
    },
}
129
130impl CrossoverFilter {
131 pub fn amplitude_at_frequency(&self, frequency: f64) -> f64 {
133 match self {
134 CrossoverFilter::FullRange => 1.0,
135 CrossoverFilter::Lowpass { cutoff_freq, order } => {
136 let ratio = frequency / cutoff_freq;
137 1.0 / (1.0 + ratio.powi(*order as i32 * 2)).sqrt()
138 }
139 CrossoverFilter::Highpass { cutoff_freq, order } => {
140 let ratio = cutoff_freq / frequency;
141 1.0 / (1.0 + ratio.powi(*order as i32 * 2)).sqrt()
142 }
143 CrossoverFilter::Bandpass {
144 low_cutoff,
145 high_cutoff,
146 order,
147 } => {
148 let high_ratio = low_cutoff / frequency;
149 let low_ratio = frequency / high_cutoff;
150 let hp_response = 1.0 / (1.0 + high_ratio.powi(*order as i32 * 2)).sqrt();
151 let lp_response = 1.0 / (1.0 + low_ratio.powi(*order as i32 * 2)).sqrt();
152 hp_response * lp_response
153 }
154 }
155 }
156}
157
/// An emitter located at `position`, with a directivity pattern, an overall
/// amplitude and a crossover filter.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Source {
    /// Location of the source. Directions in `amplitude_towards` are measured from here.
    pub position: Point3D,
    /// Direction-dependent gain, sampled by `amplitude_towards`.
    pub directivity: DirectivityPattern,
    /// Overall scale factor applied in every direction and at every frequency.
    pub amplitude: f64,
    /// Frequency-dependent gain. `Source::new` sets it to `FullRange`.
    pub crossover: CrossoverFilter,
    /// Label for the source. `Source::new` sets it to `"Source"`.
    pub name: String,
}
172
173impl Source {
174 pub fn new(position: Point3D, directivity: DirectivityPattern, amplitude: f64) -> Self {
176 Self {
177 position,
178 directivity,
179 amplitude,
180 crossover: CrossoverFilter::FullRange,
181 name: String::from("Source"),
182 }
183 }
184
185 pub fn omnidirectional(position: Point3D, amplitude: f64) -> Self {
187 Self::new(position, DirectivityPattern::omnidirectional(), amplitude)
188 }
189
190 pub fn with_crossover(mut self, crossover: CrossoverFilter) -> Self {
192 self.crossover = crossover;
193 self
194 }
195
196 pub fn with_name(mut self, name: String) -> Self {
198 self.name = name;
199 self
200 }
201
202 pub fn amplitude_towards(&self, point: &Point3D, frequency: f64) -> f64 {
204 let dx = point.x - self.position.x;
205 let dy = point.y - self.position.y;
206 let dz = point.z - self.position.z;
207
208 let r = (dx * dx + dy * dy + dz * dz).sqrt();
209 if r < 1e-10 {
210 return self.amplitude * self.crossover.amplitude_at_frequency(frequency);
211 }
212
213 let theta = (dz / r).acos();
214 let phi = dy.atan2(dx);
215
216 let directivity_factor = self.directivity.interpolate(theta, phi);
217 let crossover_factor = self.crossover.amplitude_at_frequency(frequency);
218 self.amplitude * directivity_factor * crossover_factor
219 }
220}
221
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// Asserts `|actual - expected| < tol` and reports both values on failure.
    fn assert_close(actual: f64, expected: f64, tol: f64) {
        assert!(
            (actual - expected).abs() < tol,
            "expected {expected}, got {actual} (tolerance {tol})"
        );
    }

    #[test]
    fn test_omnidirectional_pattern() {
        let pattern = DirectivityPattern::omnidirectional();
        assert_close(pattern.interpolate(0.0, 0.0), 1.0, 1e-6);
        assert_close(pattern.interpolate(PI / 2.0, PI), 1.0, 1e-6);
        assert_close(pattern.interpolate(PI, 0.0), 1.0, 1e-6);
    }

    #[test]
    fn test_cardioid_pattern_points_along_positive_y() {
        let pattern = DirectivityPattern::cardioid();
        // On-axis (+y), rear (-y) and side (+x) directions in the horizontal plane.
        assert_close(pattern.interpolate(PI / 2.0, PI / 2.0), 1.0, 1e-6);
        assert_close(pattern.interpolate(PI / 2.0, -PI / 2.0), 0.0, 1e-6);
        assert_close(pattern.interpolate(PI / 2.0, 0.0), 0.5, 1e-6);
        // Azimuths just either side of the 0/360 degree seam agree.
        assert_close(
            pattern.interpolate(PI / 2.0, -1e-9),
            pattern.interpolate(PI / 2.0, 0.0),
            1e-6,
        );
    }

    #[test]
    fn test_crossover_lowpass() {
        let crossover = CrossoverFilter::Lowpass {
            cutoff_freq: 100.0,
            order: 2,
        };
        assert_close(crossover.amplitude_at_frequency(10.0), 1.0, 0.1);
        // A Butterworth stage is 1/sqrt(2) (~0.707) at its cutoff.
        let at_cutoff = crossover.amplitude_at_frequency(100.0);
        assert!(at_cutoff > 0.6 && at_cutoff < 0.8);
        assert!(crossover.amplitude_at_frequency(1000.0) < 0.1);
    }

    #[test]
    fn test_crossover_highpass_bandpass_and_default() {
        assert_eq!(CrossoverFilter::default().amplitude_at_frequency(1234.0), 1.0);

        let highpass = CrossoverFilter::Highpass {
            cutoff_freq: 100.0,
            order: 2,
        };
        assert!(highpass.amplitude_at_frequency(10.0) < 0.1);
        assert_close(highpass.amplitude_at_frequency(100.0), 1.0 / 2.0_f64.sqrt(), 1e-9);
        assert_close(highpass.amplitude_at_frequency(10_000.0), 1.0, 1e-3);

        let bandpass = CrossoverFilter::Bandpass {
            low_cutoff: 100.0,
            high_cutoff: 10_000.0,
            order: 2,
        };
        assert_close(bandpass.amplitude_at_frequency(1_000.0), 1.0, 1e-3);
        assert!(bandpass.amplitude_at_frequency(10.0) < 0.1);
        assert!(bandpass.amplitude_at_frequency(100_000.0) < 0.1);
    }

    #[test]
    fn test_source_amplitude() {
        let source = Source::omnidirectional(Point3D::new(0.0, 0.0, 0.0), 1.0);
        let amp = source.amplitude_towards(&Point3D::new(1.0, 0.0, 0.0), 1000.0);
        assert_close(amp, 1.0, 1e-6);
    }

    #[test]
    fn test_source_directivity_crossover_and_builders() {
        let source = Source::new(
            Point3D::new(0.0, 0.0, 0.0),
            DirectivityPattern::cardioid(),
            2.0,
        )
        .with_crossover(CrossoverFilter::Lowpass {
            cutoff_freq: 100.0,
            order: 2,
        })
        .with_name(String::from("sub"));
        assert_eq!(source.name, "sub");

        let half_power = 1.0 / 2.0_f64.sqrt();
        let on_axis = source.amplitude_towards(&Point3D::new(0.0, 1.0, 0.0), 100.0);
        assert_close(on_axis, 2.0 * half_power, 1e-6);
        let rear = source.amplitude_towards(&Point3D::new(0.0, -1.0, 0.0), 100.0);
        assert_close(rear, 0.0, 1e-6);
        // Coincident point: directivity is skipped, but the crossover still applies.
        let coincident = source.amplitude_towards(&Point3D::new(0.0, 0.0, 0.0), 100.0);
        assert_close(coincident, 2.0 * half_power, 1e-6);
    }
}