use super::reactions::*;
use crate::constants::*;
use std::collections::HashMap;
/// Snapshot of a reaction network at a single instant: a time, a temperature
/// and one abundance value per nuclide.
#[derive(Debug, Clone)]
pub struct AbundanceState {
    /// Time coordinate of this snapshot.
    /// NOTE(review): units are not fixed in this file (presumably seconds) — confirm.
    pub time: f64,
    /// Temperature of this snapshot. `AbundanceState::initial` stores its
    /// `temp_k` argument here, so this is presumably in kelvin — confirm.
    pub temperature: f64,
    /// Abundance of each nuclide, keyed by species. Species absent from the
    /// map are treated as zero by the readers in this file.
    pub abundances: HashMap<Nuclide, f64>,
}
impl AbundanceState {
    /// Builds the starting state for a run.
    ///
    /// `baryon_density` is split between protons and neutrons using the ratio
    /// returned by `neutron_proton_equilibrium(temp_k)`; every other tracked
    /// nuclide starts at zero. The snapshot time is set to `0.0` and the
    /// temperature to `temp_k`.
    pub fn initial(temp_k: f64, baryon_density: f64) -> Self {
        let neutron_per_proton = neutron_proton_equilibrium(temp_k);
        // n_p + n_n = baryon_density with n_n = ratio * n_p.
        let protons = baryon_density / (1.0 + neutron_per_proton);
        let neutrons = neutron_per_proton * protons;

        let mut abundances = HashMap::new();
        abundances.insert(Nuclide::Proton, protons);
        abundances.insert(Nuclide::Neutron, neutrons);
        for composite in [
            Nuclide::Deuterium,
            Nuclide::Tritium,
            Nuclide::Helium3,
            Nuclide::Helium4,
            Nuclide::Lithium7,
            Nuclide::Beryllium7,
        ] {
            abundances.insert(composite, 0.0);
        }

        Self {
            time: 0.0,
            temperature: temp_k,
            abundances,
        }
    }

    /// Sum over all stored nuclides of `abundance * mass_number`.
    pub fn total_baryon_number(&self) -> f64 {
        self.abundances
            .iter()
            .map(|(species, &count)| count * species.mass_number() as f64)
            .sum()
    }

    /// Fraction of the mass-number-weighted total carried by `nuclide`.
    ///
    /// Returns `0.0` when the weighted total is not strictly positive, and
    /// treats a nuclide missing from the map as having zero abundance.
    pub fn mass_fraction(&self, nuclide: Nuclide) -> f64 {
        let weighted_total = self.total_baryon_number();
        if weighted_total > 0.0 {
            let count = self.abundances.get(&nuclide).map_or(0.0, |&value| value);
            let mass_number = nuclide.mass_number() as f64;
            count * mass_number / weighted_total
        } else {
            0.0
        }
    }

    /// Mass fraction of `Nuclide::Helium4`.
    pub fn helium_mass_fraction(&self) -> f64 {
        self.mass_fraction(Nuclide::Helium4)
    }
}
/// Returns `exp(-Δ / T)`, where `Δ = (M_N - M_P) * C * C` and `T = K_B * temp_k`
/// are both scaled by `J_TO_EV * 1e-6` before the ratio is taken.
///
/// NOTE(review): the constant names suggest this is the Boltzmann factor for
/// the neutron–proton mass difference with both energies expressed in MeV;
/// the constants live in `crate::constants`, so confirm their units there.
fn neutron_proton_equilibrium(temp_k: f64) -> f64 {
    // Same scaling applied to both energies (name suggests joules -> MeV).
    let scale_energy = |energy: f64| energy * J_TO_EV * 1e-6;

    let mass_gap_mev = scale_energy((M_N - M_P) * C * C);
    let thermal_mev = scale_energy(K_B * temp_k);

    (-mass_gap_mev / thermal_mev).exp()
}
/// Integrator for a fixed set of reactions acting on an `AbundanceState`.
pub struct NetworkSolver {
    /// Reactions whose forward and reverse fluxes are summed into the
    /// abundance derivatives.
    reactions: Vec<Reaction>,
}
impl NetworkSolver {
    /// Creates a solver that integrates the given reactions.
    pub fn new(reactions: Vec<Reaction>) -> Self {
        NetworkSolver { reactions }
    }

    /// Creates a solver over the reaction list returned by
    /// `standard_bbn_reactions()`.
    pub fn standard_bbn() -> Self {
        NetworkSolver::new(standard_bbn_reactions())
    }

    /// Computes the time derivative of every abundance for `state`.
    ///
    /// For each reaction the forward flux is `rate_coefficient(T)` multiplied
    /// by the abundance of every reactant, and the reverse flux is
    /// `reverse_rate(T)` multiplied by the abundance of every product. The
    /// forward flux is subtracted from each reactant and added to each
    /// product; the reverse flux does the opposite. A species listed more
    /// than once in a reaction is multiplied in and updated once per listing.
    ///
    /// The returned map always contains the eight standard nuclides. A
    /// reaction that names any other species adds an entry for it instead of
    /// panicking.
    pub fn derivatives(&self, state: &AbundanceState) -> HashMap<Nuclide, f64> {
        let mut dndt = HashMap::new();
        for nuclide in [
            Nuclide::Neutron,
            Nuclide::Proton,
            Nuclide::Deuterium,
            Nuclide::Tritium,
            Nuclide::Helium3,
            Nuclide::Helium4,
            Nuclide::Lithium7,
            Nuclide::Beryllium7,
        ] {
            dndt.insert(nuclide, 0.0);
        }

        // Species missing from the state count as zero abundance.
        let abundance_of =
            |nuclide: &Nuclide| state.abundances.get(nuclide).copied().unwrap_or(0.0);

        for reaction in &self.reactions {
            // Forward direction: reactants are consumed, products are created.
            let mut forward = reaction.rate_coefficient(state.temperature);
            for reactant in &reaction.reactants {
                forward *= abundance_of(reactant);
            }
            for reactant in &reaction.reactants {
                *dndt.entry(*reactant).or_insert(0.0) -= forward;
            }
            for product in &reaction.products {
                *dndt.entry(*product).or_insert(0.0) += forward;
            }

            // Reverse direction: products are consumed, reactants are created.
            let mut reverse = reaction.reverse_rate(state.temperature);
            for product in &reaction.products {
                reverse *= abundance_of(product);
            }
            for product in &reaction.products {
                *dndt.entry(*product).or_insert(0.0) -= reverse;
            }
            for reactant in &reaction.reactants {
                *dndt.entry(*reactant).or_insert(0.0) += reverse;
            }
        }

        dndt
    }

    /// Returns `base` advanced by `dt` along `rates`, clamping every updated
    /// abundance at zero. Species in `base` without a rate are carried over.
    fn advanced_abundances(
        base: &HashMap<Nuclide, f64>,
        rates: &HashMap<Nuclide, f64>,
        dt: f64,
    ) -> HashMap<Nuclide, f64> {
        let mut next = base.clone();
        for (nuclide, &rate) in rates {
            let old = base.get(nuclide).copied().unwrap_or(0.0);
            next.insert(*nuclide, (old + rate * dt).max(0.0));
        }
        next
    }

    /// Integrates the network from `t_start` to `t_end` and returns the
    /// recorded states, beginning with `initial_state`.
    ///
    /// Each step uses the explicit midpoint scheme: derivatives at the current
    /// state drive a half step, derivatives at that midpoint drive the full
    /// step. The temperature is not integrated; it is set to
    /// `initial_state.temperature * sqrt(t_start / t)` at the midpoint and at
    /// the end of every step, so `t_start` should be positive for the
    /// temperature to stay finite and non-zero.
    ///
    /// The step size is chosen per step: at most 10.0, at least 1e-6, limited
    /// so that no abundance above 1e10 changes by more than about 10% per
    /// step, and shortened so the final step does not pass `t_end`.
    ///
    /// `n_steps` controls how often states are recorded, not how many steps
    /// are taken: one state is kept every `MAX_STEPS / n_steps` steps (at
    /// least every step), plus the state that reaches `t_end`. `n_steps == 0`
    /// is treated as 1. Integration stops after `MAX_STEPS` (1,000,000) steps
    /// even if `t_end` has not been reached; the last state is then appended.
    pub fn evolve(
        &self,
        initial_state: AbundanceState,
        t_start: f64,
        t_end: f64,
        n_steps: usize,
    ) -> Vec<AbundanceState> {
        const MAX_STEPS: usize = 1_000_000;
        const DT_MAX: f64 = 10.0;
        const DT_MIN: f64 = 1e-6;

        // Guard the division: `n_steps == 0` used to panic here.
        let record_stride = (MAX_STEPS / n_steps.max(1)).max(1);

        let initial_temp = initial_state.temperature;
        let temperature_at = |time: f64| initial_temp * (t_start / time).sqrt();

        let mut states = vec![initial_state.clone()];
        let mut current = initial_state;
        let mut t = t_start;
        let mut step_count: usize = 0;

        while t < t_end && step_count < MAX_STEPS {
            let derivs = self.derivatives(&current);

            // Limit the step so well-populated species change by at most ~10%.
            let mut dt = DT_MAX;
            for (nuclide, &rate) in &derivs {
                let n = current.abundances.get(nuclide).copied().unwrap_or(0.0);
                if n > 1e10 && rate.abs() > 1e-10 {
                    dt = dt.min(0.1 * n / rate.abs());
                }
            }
            dt = dt.max(DT_MIN).min(DT_MAX);
            if t + dt > t_end {
                dt = t_end - t;
            }

            // Half step to the midpoint using the derivatives at the start.
            let mid_time = t + dt / 2.0;
            let mid_state = AbundanceState {
                time: mid_time,
                temperature: temperature_at(mid_time),
                abundances: Self::advanced_abundances(&current.abundances, &derivs, dt / 2.0),
            };

            // Full step from the start using the derivatives at the midpoint.
            let mid_derivs = self.derivatives(&mid_state);
            let new_abundances = Self::advanced_abundances(&current.abundances, &mid_derivs, dt);

            t += dt;
            current = AbundanceState {
                time: t,
                temperature: temperature_at(t),
                abundances: new_abundances,
            };

            if step_count % record_stride == 0 || t >= t_end {
                states.push(current.clone());
            }
            step_count += 1;
        }

        // Make sure the last computed state is present when the loop ended
        // without recording a state at or beyond `t_end`.
        if states.last().map(|s| s.time).unwrap_or(0.0) < t_end {
            states.push(current);
        }
        states
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Runs the standard network over a short interval and checks that the
    /// number of recorded states stays within the expected bounds.
    #[test]
    fn test_network_evolution() {
        let start = AbundanceState::initial(1e9, 1e18);
        let history = NetworkSolver::standard_bbn().evolve(start, 0.1, 100.0, 100);
        let recorded = history.len();

        assert!(recorded >= 2, "Evolution should have at least initial and final states");
        assert!(recorded <= 100000, "Evolution should not exceed safety limit");
    }

    /// The initial state must contain a positive amount of both protons and
    /// neutrons.
    #[test]
    fn test_initial_state() {
        let state = AbundanceState::initial(1e10, 1e18);

        let protons = *state.abundances.get(&Nuclide::Proton).unwrap();
        let neutrons = *state.abundances.get(&Nuclide::Neutron).unwrap();

        assert!(protons > 0.0);
        assert!(neutrons > 0.0);
    }
}