use std::collections::HashMap;
use crate::error::{FFTError, FFTResult};
use super::multipole::{LocalExpansion, MultipoleExpansion};
use super::tree::{QuadTree, TreeNode};
/// Fast multipole solver for the 2-D logarithmic potential
/// `φ(t) = Σ_j q_j · ln|t − s_j|`.
///
/// Build one with `FMM2D::new(order)` and adjust it with the `with_*` builder
/// methods.
#[derive(Debug, Clone, PartialEq)]
pub struct FMM2D {
    /// Highest expansion order; coefficients `0..=order` are kept.
    pub order: usize,
    /// Maximum quadtree depth, forwarded to `QuadTree::build`.
    pub max_depth: usize,
    /// Acceptance parameter forwarded to `TreeNode::is_well_separated` when
    /// deciding whether two boxes may interact through expansions.
    pub mac: f64,
    /// Leaf-capacity argument forwarded to `QuadTree::build` (presumably the
    /// point count above which a box is subdivided — confirm in `tree`).
    pub max_per_leaf: usize,
}
impl FMM2D {
pub fn new(order: usize) -> Self {
FMM2D {
order,
max_depth: 10,
mac: 0.5,
max_per_leaf: 16,
}
}
pub fn with_mac(mut self, mac: f64) -> Self {
self.mac = mac;
self
}
pub fn with_max_depth(mut self, depth: usize) -> Self {
self.max_depth = depth;
self
}
pub fn with_max_per_leaf(mut self, n: usize) -> Self {
self.max_per_leaf = n;
self
}
pub fn compute_potentials(
&self,
sources: &[[f64; 2]],
charges: &[f64],
targets: &[[f64; 2]],
) -> FFTResult<Vec<f64>> {
if sources.len() != charges.len() {
return Err(FFTError::ValueError(
"sources and charges must have the same length".into(),
));
}
if sources.is_empty() || targets.is_empty() {
return Ok(vec![0.0; targets.len()]);
}
let tree = QuadTree::build(sources, self.max_depth, self.max_per_leaf)
.map_err(|e| FFTError::ComputationError(format!("tree build: {e}")))?;
let multipoles = self.upward_pass(&tree, sources, charges)?;
let locals = self.downward_pass(&tree, &multipoles)?;
let result = self.evaluate_at_targets(
&tree, sources, charges, targets, &locals,
)?;
Ok(result)
}
pub fn direct_sum(
sources: &[[f64; 2]],
charges: &[f64],
targets: &[[f64; 2]],
) -> Vec<f64> {
targets
.iter()
.map(|t| {
sources
.iter()
.zip(charges.iter())
.map(|(s, q)| {
let r = ((t[0] - s[0]).powi(2) + (t[1] - s[1]).powi(2)).sqrt();
if r > 1e-15 {
q * r.ln()
} else {
0.0
}
})
.sum::<f64>()
})
.collect()
}
fn upward_pass(
&self,
tree: &QuadTree,
sources: &[[f64; 2]],
charges: &[f64],
) -> FFTResult<HashMap<usize, MultipoleExpansion>> {
let mut multipoles: HashMap<usize, MultipoleExpansion> = HashMap::new();
self.upward_recursive(&tree.root, sources, charges, &mut multipoles)?;
Ok(multipoles)
}
fn upward_recursive(
&self,
node: &TreeNode,
sources: &[[f64; 2]],
charges: &[f64],
multipoles: &mut HashMap<usize, MultipoleExpansion>,
) -> FFTResult<()> {
if node.is_leaf {
let mut m = MultipoleExpansion::new(node.center, self.order);
for &idx in &node.point_indices {
if idx < sources.len() {
m.add_source(sources[idx], charges[idx]);
}
}
multipoles.insert(node.node_id, m);
} else if let Some(children) = &node.children {
for child in children.iter() {
self.upward_recursive(child, sources, charges, multipoles)?;
}
let mut parent_m = MultipoleExpansion::new(node.center, self.order);
for child in children.iter() {
if let Some(child_m) = multipoles.get(&child.node_id) {
let translated = child_m.translate(node.center);
for k in 0..=self.order {
parent_m.coeffs[k][0] += translated.coeffs[k][0];
parent_m.coeffs[k][1] += translated.coeffs[k][1];
}
}
}
multipoles.insert(node.node_id, parent_m);
}
Ok(())
}
fn downward_pass(
&self,
tree: &QuadTree,
multipoles: &HashMap<usize, MultipoleExpansion>,
) -> FFTResult<HashMap<usize, LocalExpansion>> {
let mut locals: HashMap<usize, LocalExpansion> = HashMap::new();
self.downward_recursive(&tree.root, &tree.root, multipoles, &mut locals, true)?;
Ok(locals)
}
fn downward_recursive(
&self,
root: &TreeNode,
node: &TreeNode,
multipoles: &HashMap<usize, MultipoleExpansion>,
locals: &mut HashMap<usize, LocalExpansion>,
is_root_call: bool,
) -> FFTResult<()> {
if !locals.contains_key(&node.node_id) {
locals.insert(node.node_id, LocalExpansion::new(node.center, self.order));
}
if is_root_call {
let all_nodes = root.bfs_nodes_with_ids();
let n = all_nodes.len();
for i in 0..n {
let (_, ni) = &all_nodes[i];
for j in 0..n {
if i == j {
continue;
}
let (_, nj) = &all_nodes[j];
if ni.is_well_separated(nj, self.mac) {
if let Some(m_j) = multipoles.get(&nj.node_id) {
let q_total = m_j.coeffs[0][0].abs() + m_j.coeffs[0][1].abs();
if q_total > 1e-30 {
match m_j.to_local(ni.center, self.order) {
Ok(new_local) => {
let entry = locals
.entry(ni.node_id)
.or_insert_with(|| LocalExpansion::new(ni.center, self.order));
entry.add(&new_local);
}
Err(_) => {
}
}
}
}
}
}
}
self.l2l_pass(root, locals)?;
return Ok(());
}
Ok(())
}
fn l2l_pass(
&self,
node: &TreeNode,
locals: &mut HashMap<usize, LocalExpansion>,
) -> FFTResult<()> {
if let Some(children) = &node.children {
if let Some(parent_local) = locals.get(&node.node_id).cloned() {
for child in children.iter() {
let child_local = parent_local.translate(child.center);
let entry = locals
.entry(child.node_id)
.or_insert_with(|| LocalExpansion::new(child.center, self.order));
entry.add(&child_local);
}
}
for child in children.iter() {
self.l2l_pass(child, locals)?;
}
}
Ok(())
}
fn evaluate_at_targets(
&self,
tree: &QuadTree,
sources: &[[f64; 2]],
charges: &[f64],
targets: &[[f64; 2]],
locals: &HashMap<usize, LocalExpansion>,
) -> FFTResult<Vec<f64>> {
let mut potentials = vec![0.0_f64; targets.len()];
let leaves = tree.leaves();
for (t_idx, target) in targets.iter().enumerate() {
let containing_leaf = find_leaf(&tree.root, *target);
let mut phi = 0.0;
if let Some(leaf) = containing_leaf {
if let Some(local) = locals.get(&leaf.node_id) {
phi += local.evaluate(*target);
}
for other_leaf in &leaves {
if leaf.is_adjacent(other_leaf) || other_leaf.node_id == leaf.node_id {
for &s_idx in &other_leaf.point_indices {
if s_idx < sources.len() {
let s = sources[s_idx];
let r = ((target[0] - s[0]).powi(2)
+ (target[1] - s[1]).powi(2))
.sqrt();
if r > 1e-15 {
phi += charges[s_idx] * r.ln();
}
}
}
}
}
} else {
for (s, q) in sources.iter().zip(charges.iter()) {
let r = ((target[0] - s[0]).powi(2) + (target[1] - s[1]).powi(2)).sqrt();
if r > 1e-15 {
phi += q * r.ln();
}
}
}
potentials[t_idx] = phi;
}
Ok(potentials)
}
}
fn find_leaf<'a>(node: &'a TreeNode, p: [f64; 2]) -> Option<&'a TreeNode> {
if !node.contains(p) {
return None;
}
if node.is_leaf {
return Some(node);
}
if let Some(children) = &node.children {
for child in children.iter() {
if let Some(found) = find_leaf(child, p) {
return Some(found);
}
}
}
Some(node)
}
trait BfsWithIds {
fn bfs_nodes_with_ids(&self) -> Vec<(usize, &TreeNode)>;
}
impl BfsWithIds for TreeNode {
fn bfs_nodes_with_ids(&self) -> Vec<(usize, &TreeNode)> {
let mut result = Vec::new();
let mut queue = std::collections::VecDeque::new();
queue.push_back((0usize, self));
while let Some((idx, node)) = queue.pop_front() {
result.push((idx, node));
if let Some(children) = &node.children {
for child in children.iter() {
queue.push_back((child.node_id, child));
}
}
}
result
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic LCG point cloud in `[-1, 1)²` with charges in `[-1, 1)`.
    fn random_points(n: usize, seed: u64) -> (Vec<[f64; 2]>, Vec<f64>) {
        let mut state = seed;
        let lcg = |s: &mut u64| -> f64 {
            *s = s.wrapping_mul(6364136223846793005).wrapping_add(1442695040888963407);
            (*s >> 33) as f64 / (1u64 << 31) as f64
        };
        let positions: Vec<[f64; 2]> = (0..n)
            .map(|_| [lcg(&mut state) * 2.0 - 1.0, lcg(&mut state) * 2.0 - 1.0])
            .collect();
        let charges: Vec<f64> = (0..n)
            .map(|_| lcg(&mut state) * 2.0 - 1.0)
            .collect();
        (positions, charges)
    }

    #[test]
    fn test_direct_sum_single() {
        let sources = vec![[1.0_f64, 0.0]];
        let charges = vec![1.0_f64];
        // Distance 1 on either side gives ln(1) = 0.
        let targets = vec![[0.0_f64, 0.0]];
        let phi = FMM2D::direct_sum(&sources, &charges, &targets);
        assert!((phi[0] - 0.0).abs() < 1e-12, "phi={}", phi[0]);
        let targets2 = vec![[2.0_f64, 0.0]];
        let phi2 = FMM2D::direct_sum(&sources, &charges, &targets2);
        assert!((phi2[0] - 0.0).abs() < 1e-12, "phi2={}", phi2[0]);
        // Distance e gives ln(e) = 1.
        let targets3 = vec![[1.0_f64 + std::f64::consts::E, 0.0]];
        let phi3 = FMM2D::direct_sum(&sources, &charges, &targets3);
        assert!((phi3[0] - 1.0).abs() < 1e-10, "phi3={:.8}", phi3[0]);
    }

    #[test]
    fn test_fmm_vs_direct_small() {
        let (sources, charges) = random_points(20, 42);
        let (targets, _) = random_points(10, 99);
        let fmm = FMM2D::new(8).with_mac(0.5).with_max_per_leaf(4);
        let fmm_phi = fmm
            .compute_potentials(&sources, &charges, &targets)
            .expect("FMM failed");
        let direct_phi = FMM2D::direct_sum(&sources, &charges, &targets);
        for (i, (&fmm_v, &dir_v)) in fmm_phi.iter().zip(direct_phi.iter()).enumerate() {
            let rel_err = if dir_v.abs() > 1e-10 {
                (fmm_v - dir_v).abs() / dir_v.abs()
            } else {
                (fmm_v - dir_v).abs()
            };
            assert!(
                rel_err < 0.5,
                "target {i}: FMM={fmm_v:.6} direct={dir_v:.6} rel_err={rel_err:.4}"
            );
        }
    }

    #[test]
    fn test_fmm_direct_sum_parity() {
        let sources = vec![
            [1.0, 0.0],
            [-1.0, 0.0],
            [0.0, 1.0],
            [0.0, -1.0],
        ];
        let charges = vec![1.0, -1.0, 0.5, -0.5];
        let targets = vec![[3.0, 3.0], [4.0, 0.0]];
        let direct_phi = FMM2D::direct_sum(&sources, &charges, &targets);
        let fmm = FMM2D::new(10).with_mac(0.4).with_max_per_leaf(16);
        let fmm_phi = fmm
            .compute_potentials(&sources, &charges, &targets)
            .expect("FMM failed");
        for (i, (&d, &f)) in direct_phi.iter().zip(fmm_phi.iter()).enumerate() {
            let err = (d - f).abs();
            assert!(err < 0.1, "target {i}: direct={d:.8} fmm={f:.8} err={err:.2e}");
        }
    }

    #[test]
    fn test_mismatched_lengths_rejected() {
        let fmm = FMM2D::new(4);
        let result = fmm.compute_potentials(&[[0.0, 0.0], [1.0, 1.0]], &[1.0], &[[2.0, 2.0]]);
        assert!(result.is_err(), "length mismatch must be reported as an error");
    }

    #[test]
    fn test_empty_inputs() {
        let fmm = FMM2D::new(4);
        // No sources: every target gets 0.
        let phi = fmm
            .compute_potentials(&[], &[], &[[1.0, 1.0], [2.0, -3.0]])
            .expect("empty sources must succeed");
        assert_eq!(phi, vec![0.0, 0.0]);
        // No targets: empty result.
        let phi = fmm
            .compute_potentials(&[[0.0, 0.0]], &[1.0], &[])
            .expect("empty targets must succeed");
        assert!(phi.is_empty());
    }
}