use nalgebra::DVector;
use sindr_devices::bjt::{BjtKind, BjtParams};
use sindr_devices::diode::{self, temperature_scale_is, DiodeParams, GMIN};
use crate::circuit::{Circuit, CircuitElement};
use crate::error::SimError;
use crate::mna::MnaSystem;
use crate::node_map::NodeMap;
use crate::results::{self, SimulationResult};
use crate::stamp;
/// Upper bound on Newton–Raphson iterations for a single solve attempt.
pub const MAX_NR_ITERATIONS: usize = 100;
/// Absolute node-voltage tolerance (volts) used by the convergence test.
pub const V_ABSTOL: f64 = 1.0e-6;
/// Absolute current tolerance (amperes).
// The convergence check in this file compares node voltages only, so this
// constant is currently unreferenced here; the allow keeps it available.
#[allow(dead_code)]
pub const I_ABSTOL: f64 = 1.0e-12;
/// Relative tolerance, scaled by the larger of the old/new node-voltage magnitudes.
pub const RELTOL: f64 = 1.0e-3;
/// Solves the DC operating point of a circuit that contains nonlinear devices.
///
/// Starts Newton–Raphson from the default initial guess with no caller seeds;
/// on `SimError::ConvergenceFailed` the solver falls back to Gmin stepping.
pub fn solve_nonlinear(
    circuit: &Circuit,
    node_map: &NodeMap,
    num_nodes: usize,
    num_vsources: usize,
) -> Result<SimulationResult, SimError> {
    let no_seeds = None;
    solve_nonlinear_with_seeds(circuit, node_map, num_nodes, num_vsources, no_seeds)
}
/// Runs plain Newton–Raphson (no extra Gmin) with optional named-node seeds.
///
/// If that attempt fails with `SimError::ConvergenceFailed`, retries via
/// [`gmin_stepping`] using the same seeds. Any other error from the first
/// attempt is returned unchanged.
fn solve_nonlinear_with_seeds(
    circuit: &Circuit,
    node_map: &NodeMap,
    num_nodes: usize,
    num_vsources: usize,
    initial_voltages: Option<&std::collections::HashMap<String, f64>>,
) -> Result<SimulationResult, SimError> {
    let first_attempt = nr_inner(
        circuit,
        node_map,
        num_nodes,
        num_vsources,
        None,
        initial_voltages,
        0.0,
    );
    first_attempt
        .map(|(result, _)| result)
        .or_else(|err| match err {
            SimError::ConvergenceFailed { .. } => {
                gmin_stepping(circuit, node_map, num_nodes, num_vsources, initial_voltages)
            }
            other => Err(other),
        })
}
/// Gmin-stepping homotopy for circuits where plain Newton–Raphson diverges.
///
/// Every node gets an extra shunt conductance to ground, and the circuit is
/// re-solved while that conductance walks down `LADDER` to zero. Each successful
/// rung's solution warm-starts the next rung. Caller `initial_voltages` seeds are
/// applied only until the first rung converges. After that, the warm start
/// supersedes them.
///
/// A rung that fails with `SimError::ConvergenceFailed` does not abort the
/// homotopy. The error is recorded, and the next (smaller) rung is tried from
/// the last good warm start.
///
/// # Errors
/// - Any error other than `ConvergenceFailed` is returned immediately.
/// - If the final (zero extra Gmin) rung does not converge, the most recent
///   convergence error is returned.
fn gmin_stepping(
    circuit: &Circuit,
    node_map: &NodeMap,
    num_nodes: usize,
    num_vsources: usize,
    initial_voltages: Option<&std::collections::HashMap<String, f64>>,
) -> Result<SimulationResult, SimError> {
    // Extra shunt conductance (siemens) per rung; the last rung must be 0.0 so the
    // returned result is a solution of the unmodified circuit.
    const LADDER: &[f64] = &[1e-2, 1e-3, 1e-4, 1e-5, 1e-6, 1e-8, 1e-10, 0.0];
    let mut warm_start: Option<DVector<f64>> = None;
    // Reported only if no rung ever produces a more specific error.
    let mut last_err = SimError::ConvergenceFailed {
        iterations: 0,
        max_step_volts: f64::INFINITY,
    };
    for &gmin_extra in LADDER {
        // Caller seeds are only useful until we own a converged operating point.
        let seeds = if warm_start.is_some() {
            None
        } else {
            initial_voltages
        };
        match nr_inner(
            circuit,
            node_map,
            num_nodes,
            num_vsources,
            warm_start.as_ref(),
            seeds,
            gmin_extra,
        ) {
            Ok((result, v)) => {
                if gmin_extra == 0.0 {
                    return Ok(result);
                }
                warm_start = Some(v);
            }
            // A single stubborn rung should not doom the homotopy: keep the last
            // good warm start and try the next, smaller Gmin.
            Err(e @ SimError::ConvergenceFailed { .. }) => last_err = e,
            // Structural failures (e.g. a solve error) will not be cured by
            // changing Gmin; surface them straight away.
            Err(e) => return Err(e),
        }
    }
    Err(last_err)
}
/// Like [`solve_nonlinear`], but overwrites the starting guess for each node
/// named in `initial_voltages` before Newton–Raphson begins.
///
/// Names with no node index in `node_map` are ignored.
pub fn solve_nonlinear_with_initial_voltages(
    circuit: &Circuit,
    node_map: &NodeMap,
    num_nodes: usize,
    num_vsources: usize,
    initial_voltages: &std::collections::HashMap<String, f64>,
) -> Result<SimulationResult, SimError> {
    let seeds = Some(initial_voltages);
    solve_nonlinear_with_seeds(circuit, node_map, num_nodes, num_vsources, seeds)
}
/// One damped Newton–Raphson run on the MNA system.
///
/// The starting vector is `v_init` when given. Otherwise it is a zero vector
/// refined by `init_bjt_voltages`. Entries for nodes named in
/// `initial_voltages` are then overwritten.
///
/// Each iteration does the following:
/// 1. Re-stamps the circuit linearised around the current guess and adds the
///    baseline `GMIN` shunts, plus `gmin_extra` on every node diagonal when it
///    is positive.
/// 2. Solves the system and applies junction voltage limiting.
/// 3. Stops when `converged` accepts the node-voltage step.
///
/// Setting the `DEBUG_NR` environment variable prints the first 30 iterations
/// to stderr.
///
/// Returns the extracted results together with the raw solution vector, which
/// is suitable as a warm start.
///
/// # Errors
/// - Propagates stamping/solve errors.
/// - Returns `SimError::ConvergenceFailed` after `MAX_NR_ITERATIONS` without
///   convergence; it carries the size of the last node-voltage step.
fn nr_inner(
    circuit: &Circuit,
    node_map: &NodeMap,
    num_nodes: usize,
    num_vsources: usize,
    v_init: Option<&DVector<f64>>,
    initial_voltages: Option<&std::collections::HashMap<String, f64>>,
    gmin_extra: f64,
) -> Result<(SimulationResult, DVector<f64>), SimError> {
    let mut guess = if let Some(start) = v_init {
        start.clone()
    } else {
        let mut fresh = DVector::zeros(num_nodes + num_vsources);
        init_bjt_voltages(circuit, node_map, &mut fresh);
        fresh
    };

    // Named seeds override whatever the starting vector held for those nodes.
    for (name, &value) in initial_voltages.into_iter().flatten() {
        if let Some(slot) = node_map.index(name) {
            guess[slot] = value;
        }
    }

    let junctions = collect_diode_info(circuit, node_map);
    let trace = std::env::var("DEBUG_NR").is_ok();
    let mut final_step = f64::INFINITY;

    for pass in 0..MAX_NR_ITERATIONS {
        let mut system = MnaSystem::new(num_nodes, num_vsources);
        stamp::stamp_circuit(circuit, &mut system, node_map, Some(&guess))?;
        add_gmin_shunts(&mut system, num_nodes);
        // Homotopy shunt, added after (not folded into) the baseline GMIN.
        if gmin_extra > 0.0 {
            for node in 0..num_nodes {
                system.a[(node, node)] += gmin_extra;
            }
        }

        let raw = system.solve()?;
        let damped = apply_voltage_limiting(&raw, &guess, &junctions);

        if trace && pass < 30 {
            eprintln!(
                "NR iter {}: v_prev={:.4?} -> v_new={:.4?} -> v_limited={:.4?}",
                pass,
                guess.iter().take(num_nodes).collect::<Vec<_>>(),
                raw.iter().take(num_nodes).collect::<Vec<_>>(),
                damped.iter().take(num_nodes).collect::<Vec<_>>()
            );
        }

        if converged(&guess, &damped, num_nodes) {
            let result = results::extract_results(circuit, node_map, &damped, num_nodes);
            return Ok((result, damped));
        }

        final_step = max_node_step(&guess, &damped, num_nodes);
        guess = damped;
    }

    Err(SimError::ConvergenceFailed {
        iterations: MAX_NR_ITERATIONS,
        max_step_volts: final_step,
    })
}
/// Per-junction data used to damp Newton–Raphson voltage steps.
///
/// One entry per diode-like element; a BJT contributes two entries
/// (base–emitter and base–collector), both with `is_bjt` set.
#[derive(Debug, Clone)]
pub(crate) struct DiodeLimitInfo {
    /// MNA index of the anode; `None` means the node has no index and is read as 0 V.
    anode_idx: Option<usize>,
    /// MNA index of the cathode; `None` means the node has no index and is read as 0 V.
    cathode_idx: Option<usize>,
    /// Critical voltage handed to the junction limiter.
    v_crit: f64,
    /// Emission coefficient of the junction.
    n: f64,
    /// Terminal that absorbs the whole limiting correction; `None` splits it
    /// between both terminals.
    correction_node: Option<CorrectionTarget>,
    /// Marks BJT junctions, which receive an additional per-iteration step cap.
    is_bjt: bool,
}

/// Which terminal of a junction receives the full voltage-limiting correction.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum CorrectionTarget {
    Anode,
    Cathode,
}
pub(crate) fn collect_diode_info(circuit: &Circuit, node_map: &NodeMap) -> Vec<DiodeLimitInfo> {
let mut info = Vec::new();
for component in &circuit.components {
match component {
CircuitElement::Diode {
nodes, temperature, ..
} => {
let mut params = DiodeParams::silicon();
if (*temperature - 300.15).abs() > 1e-6 {
params.is = temperature_scale_is(params.is, *temperature, 300.15, 1.11, 2.0);
}
info.push(DiodeLimitInfo {
anode_idx: node_map.index(&nodes[0]),
cathode_idx: node_map.index(&nodes[1]),
v_crit: diode::critical_voltage(params.is, params.n),
n: params.n,
correction_node: None,
is_bjt: false,
});
}
CircuitElement::Led {
nodes,
color,
temperature,
..
} => {
let mut params = DiodeParams::for_led_color(color);
if (*temperature - 300.15).abs() > 1e-6 {
params.is = temperature_scale_is(params.is, *temperature, 300.15, 1.11, 2.0);
}
info.push(DiodeLimitInfo {
anode_idx: node_map.index(&nodes[0]),
cathode_idx: node_map.index(&nodes[1]),
v_crit: diode::critical_voltage(params.is, params.n),
n: params.n,
correction_node: None,
is_bjt: false,
});
}
CircuitElement::ZenerDiode { nodes, .. } => {
let params = DiodeParams::silicon();
info.push(DiodeLimitInfo {
anode_idx: node_map.index(&nodes[0]),
cathode_idx: node_map.index(&nodes[1]),
v_crit: diode::critical_voltage(params.is, params.n),
n: params.n,
correction_node: None,
is_bjt: false,
});
}
CircuitElement::SchottkyDiode { nodes, .. } => {
let schottky_params = sindr_devices::schottky::SchottkyParams::default();
info.push(DiodeLimitInfo {
anode_idx: node_map.index(&nodes[0]),
cathode_idx: node_map.index(&nodes[1]),
v_crit: diode::critical_voltage(schottky_params.is, schottky_params.n),
n: schottky_params.n,
correction_node: None,
is_bjt: false,
});
}
CircuitElement::Photodiode { nodes, .. } => {
let photo_params = sindr_devices::photodiode::PhotodiodeParams::default();
info.push(DiodeLimitInfo {
anode_idx: node_map.index(&nodes[0]),
cathode_idx: node_map.index(&nodes[1]),
v_crit: diode::critical_voltage(photo_params.is, photo_params.n),
n: photo_params.n,
correction_node: None,
is_bjt: false,
});
}
CircuitElement::Bjt {
nodes,
kind,
bf,
temperature,
..
} => {
let mut params = BjtParams::new(*bf);
if (*temperature - 300.15).abs() > 1e-6 {
params.is = temperature_scale_is(params.is, *temperature, 300.15, 1.11, 3.0);
}
let (be_anode, be_cathode) = match kind {
BjtKind::Npn => (&nodes[0], &nodes[2]),
BjtKind::Pnp => (&nodes[2], &nodes[0]),
};
info.push(DiodeLimitInfo {
anode_idx: node_map.index(be_anode),
cathode_idx: node_map.index(be_cathode),
v_crit: diode::critical_voltage(params.is, params.nf),
n: params.nf,
correction_node: None,
is_bjt: true,
});
let (bc_anode, bc_cathode) = match kind {
BjtKind::Npn => (&nodes[0], &nodes[1]),
BjtKind::Pnp => (&nodes[1], &nodes[0]),
};
let bc_correction = match kind {
BjtKind::Npn => CorrectionTarget::Cathode,
BjtKind::Pnp => CorrectionTarget::Anode,
};
info.push(DiodeLimitInfo {
anode_idx: node_map.index(bc_anode),
cathode_idx: node_map.index(bc_cathode),
v_crit: diode::critical_voltage(params.is, params.nr),
n: params.nr,
correction_node: Some(bc_correction),
is_bjt: true,
});
}
_ => {}
}
}
info
}
/// Damps a freshly solved vector `v_new` so that no junction voltage jumps too
/// far from its value in `v_prev`.
///
/// Junctions are processed in order, and each one reads the already-corrected
/// vector, so later junctions see earlier corrections.
///
/// For each junction:
/// - The proposed voltage is passed through `diode::limit_voltage`.
/// - BJT junctions are additionally capped to a step of `10 · n · V_T`.
/// - The resulting correction is written to the terminal named by
///   `correction_node`, or split evenly between both terminals when that is
///   `None`. If the targeted terminal has no node index, no correction is applied.
///
/// Corrections of magnitude ≤ 1e-15 are skipped.
pub(crate) fn apply_voltage_limiting(
    v_new: &DVector<f64>,
    v_prev: &DVector<f64>,
    diode_info: &[DiodeLimitInfo],
) -> DVector<f64> {
    /// Voltage at an optional node index; `None` reads as 0 V.
    fn at(v: &DVector<f64>, idx: Option<usize>) -> f64 {
        idx.map_or(0.0, |i| v[i])
    }

    let mut out = v_new.clone();
    for j in diode_info {
        let proposed = at(&out, j.anode_idx) - at(&out, j.cathode_idx);
        let previous = at(v_prev, j.anode_idx) - at(v_prev, j.cathode_idx);

        let mut target = diode::limit_voltage(proposed, previous, j.v_crit, j.n);
        if j.is_bjt {
            let cap = 10.0 * j.n * sindr_devices::diode::V_T;
            let step = target - previous;
            if step.abs() > cap {
                target = previous + step.signum() * cap;
            }
        }

        let delta = target - proposed;
        if delta.abs() <= 1e-15 {
            continue;
        }

        match j.correction_node.as_ref() {
            Some(CorrectionTarget::Anode) => {
                if let Some(a) = j.anode_idx {
                    out[a] += delta;
                }
            }
            Some(CorrectionTarget::Cathode) => {
                if let Some(c) = j.cathode_idx {
                    out[c] -= delta;
                }
            }
            None => {
                // Split the correction when both terminals are free; otherwise the
                // single indexed terminal takes all of it.
                let share = if j.anode_idx.is_some() && j.cathode_idx.is_some() {
                    0.5
                } else {
                    1.0
                };
                if let Some(a) = j.anode_idx {
                    out[a] += delta * share;
                }
                if let Some(c) = j.cathode_idx {
                    out[c] -= delta * share;
                }
            }
        }
    }
    out
}
/// Adds the baseline `GMIN` conductance to ground on each of the first
/// `num_nodes` diagonal entries of the MNA matrix.
pub(crate) fn add_gmin_shunts(system: &mut MnaSystem, num_nodes: usize) {
    (0..num_nodes).for_each(|node| system.a[(node, node)] += GMIN);
}
/// Returns `true` when every node voltage (indices `0..num_nodes`) changed by
/// at most `V_ABSTOL + RELTOL * max(|old|, |new|)` between iterates.
///
/// A non-finite step (NaN, or an infinite jump) never counts as converged. A
/// plain `diff > tol` test is false for NaN and would otherwise let a NaN
/// solution be accepted. Likewise, an infinite value makes `tol` infinite and
/// would let an infinite step pass.
pub(crate) fn converged(v_prev: &DVector<f64>, v_new: &DVector<f64>, num_nodes: usize) -> bool {
    (0..num_nodes).all(|i| {
        let (old, new) = (v_prev[i], v_new[i]);
        let step = (new - old).abs();
        let tol = V_ABSTOL + RELTOL * new.abs().max(old.abs());
        step.is_finite() && step <= tol
    })
}
/// Largest absolute change among the first `num_nodes` entries (node voltages)
/// of `v_prev` → `v_new`; branch-current entries beyond `num_nodes` are ignored.
///
/// Returns 0.0 for `num_nodes == 0`. NaN differences are ignored.
pub(crate) fn max_node_step(v_prev: &DVector<f64>, v_new: &DVector<f64>, num_nodes: usize) -> f64 {
    (0..num_nodes)
        .map(|k| (v_new[k] - v_prev[k]).abs())
        .fold(0.0f64, f64::max)
}
/// Fills `v_prev` with a heuristic starting point for Newton–Raphson.
///
/// Does nothing unless the circuit contains a BJT, MOSFET, Zener, Schottky
/// diode or photodiode. Plain diodes and LEDs alone do not trigger it.
/// Otherwise it runs three passes:
/// 1. Solves the linear sub-circuit (positive resistors, independent voltage and
///    current sources, `GMIN` shunts) and copies the result into `v_prev`. If
///    that solve fails, `v_prev` keeps its incoming contents.
/// 2. For each BJT (base = `nodes[0]`, emitter = `nodes[2]`) whose current base
///    guess is outside the window (0.3..=1.0 V for NPN, -1.0..=-0.3 V for PNP),
///    moves the base to `VBE_GUESS` volts from the emitter with the device's
///    polarity.
/// 3. Clamps each Zener's anode–cathode voltage into `[-(vz + 0.1), 0.4]`.
///    A Zener whose `vz` makes that interval empty or NaN is left alone
///    instead of panicking.
fn init_bjt_voltages(circuit: &Circuit, node_map: &NodeMap, v_prev: &mut DVector<f64>) {
    /// Forward |V_BE| used to seed a BJT whose current guess looks implausible.
    const VBE_GUESS: f64 = 0.65;
    /// Largest forward bias a Zener is allowed to start from.
    const V_INIT_FORWARD: f64 = 0.4;

    let has_nonlinear = circuit.components.iter().any(|c| {
        matches!(
            c,
            CircuitElement::Bjt { .. }
                | CircuitElement::Mosfet { .. }
                | CircuitElement::ZenerDiode { .. }
                | CircuitElement::SchottkyDiode { .. }
                | CircuitElement::Photodiode { .. }
        )
    });
    if !has_nonlinear {
        return;
    }

    // Pass 1: linear pre-solve.
    // NOTE(review): the system is sized by `count_voltage_sources()`, but only
    // `VoltageSource` elements are stamped here. If that count includes other
    // elements, their branch rows stay empty, the solve fails, and this pass is
    // silently skipped — confirm.
    let num_nodes = node_map.num_nodes();
    let num_vsources = circuit.count_voltage_sources();
    let mut system = MnaSystem::new(num_nodes, num_vsources);
    let mut vsource_index: usize = 0;
    for component in &circuit.components {
        match component {
            CircuitElement::Resistor {
                nodes, resistance, ..
            } if *resistance > 0.0 => {
                stamp::stamp_resistor(&mut system, node_map, nodes, *resistance);
            }
            CircuitElement::VoltageSource { nodes, voltage, .. } => {
                let branch = num_nodes + vsource_index;
                stamp::stamp_voltage_source(&mut system, node_map, nodes, *voltage, branch);
                vsource_index += 1;
            }
            CircuitElement::CurrentSource { nodes, current, .. } => {
                stamp::stamp_current_source(&mut system, node_map, nodes, *current);
            }
            // Every other element is left out of the pre-solve.
            _ => {}
        }
    }
    add_gmin_shunts(&mut system, num_nodes);
    if let Ok(linear_solution) = system.solve() {
        for i in 0..v_prev.len().min(linear_solution.len()) {
            v_prev[i] = linear_solution[i];
        }
    }

    // Pass 2: put each BJT's base-emitter junction near forward bias.
    for component in &circuit.components {
        if let CircuitElement::Bjt { nodes, kind, .. } = component {
            let base_idx = node_map.index(&nodes[0]);
            let emitter_idx = node_map.index(&nodes[2]);
            let ve = emitter_idx.map_or(0.0, |i| v_prev[i]);
            match kind {
                BjtKind::Npn => {
                    if let Some(bi) = base_idx {
                        let current_vbe = v_prev[bi] - ve;
                        if !(0.3..=1.0).contains(&current_vbe) {
                            v_prev[bi] = ve + VBE_GUESS;
                        }
                    }
                }
                BjtKind::Pnp => {
                    if let Some(bi) = base_idx {
                        let current_vbe = v_prev[bi] - ve;
                        if !(-1.0..=-0.3).contains(&current_vbe) {
                            v_prev[bi] = ve - VBE_GUESS;
                        }
                    }
                }
            }
        }
    }

    // Pass 3: keep each Zener out of deep forward or deep breakdown bias.
    for component in &circuit.components {
        if let CircuitElement::ZenerDiode { nodes, vz, .. } = component {
            let v_crit_zener = vz + 0.1;
            // `f64::clamp` panics when min > max or a bound is NaN; skip the
            // pre-clamp rather than crash on such a `vz`.
            if v_crit_zener.is_nan() || -v_crit_zener > V_INIT_FORWARD {
                continue;
            }
            let anode_idx = node_map.index(&nodes[0]);
            let cathode_idx = node_map.index(&nodes[1]);
            let v_anode = anode_idx.map_or(0.0, |i| v_prev[i]);
            let v_cathode = cathode_idx.map_or(0.0, |i| v_prev[i]);
            let v_d = v_anode - v_cathode;
            let v_d_clamped = v_d.clamp(-v_crit_zener, V_INIT_FORWARD);
            if (v_d_clamped - v_d).abs() > 1e-9 {
                if let Some(ai) = anode_idx {
                    v_prev[ai] = v_cathode + v_d_clamped;
                } else if let Some(ci) = cathode_idx {
                    // Anode has no index (read as 0 V), so V_d = -V_cathode.
                    v_prev[ci] = -v_d_clamped;
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::circuit::CircuitElement;

    /// `V1 (vsource volts) -> n1 -> R1 (100 Ω) -> n2 -> D1 -> ground`.
    fn diode_resistor_circuit(vsource: f64) -> Circuit {
        let v1 = CircuitElement::VoltageSource {
            id: "V1".into(),
            nodes: ["n1".into(), "0".into()],
            voltage: vsource,
            waveform: None,
        };
        let r1 = CircuitElement::Resistor {
            id: "R1".into(),
            nodes: ["n1".into(), "n2".into()],
            resistance: 100.0,
        };
        let d1 = CircuitElement::Diode {
            id: "D1".into(),
            nodes: ["n2".into(), "0".into()],
            temperature: 300.15,
        };
        Circuit {
            ground_node: "0".into(),
            components: vec![v1, r1, d1],
        }
    }

    /// Node map plus the node and voltage-source counts the solvers expect.
    fn mna_dims(circuit: &Circuit) -> (NodeMap, usize, usize) {
        let node_map = NodeMap::from_circuit(circuit);
        let num_nodes = node_map.num_nodes();
        let num_vsources = circuit.count_voltage_sources();
        (node_map, num_nodes, num_vsources)
    }

    #[test]
    fn max_node_step_picks_largest_diff() {
        let before = DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let after = DVector::from_vec(vec![1.05, 1.9, 3.5, 4.0]);
        let r = max_node_step(&before, &after, 4);
        assert!((r - 0.5).abs() < 1e-12, "got {r}");
    }

    #[test]
    fn max_node_step_only_considers_node_range() {
        // Entries past index 1 change a lot but lie outside the node range.
        let before = DVector::from_vec(vec![1.0, 2.0, 3.0, 4.0]);
        let after = DVector::from_vec(vec![1.01, 2.01, 99.0, 99.0]);
        let r = max_node_step(&before, &after, 2);
        assert!((r - 0.01).abs() < 1e-12, "got {r}");
    }

    #[test]
    fn gmin_stepping_solves_diode_resistor_directly() {
        let circuit = diode_resistor_circuit(5.0);
        let (node_map, num_nodes, num_vsources) = mna_dims(&circuit);
        let homotopy = gmin_stepping(&circuit, &node_map, num_nodes, num_vsources, None)
            .expect("gmin stepping should converge on a diode-R circuit");
        let v_n2 = homotopy.node_voltages["n2"];
        assert!(
            (0.5..=0.9).contains(&v_n2),
            "expected forward-bias diode drop, got V(n2) = {v_n2}"
        );

        let plain = crate::solve_circuit(&circuit).expect("plain solve should converge");
        let v_n2_plain = plain.node_voltages["n2"];
        assert!(
            (v_n2 - v_n2_plain).abs() < 1e-3,
            "gmin path ({v_n2}) and plain NR ({v_n2_plain}) should agree to ~mV"
        );
    }

    #[test]
    fn gmin_stepping_rescues_pathological_initial_seed() {
        let circuit = diode_resistor_circuit(5.0);
        let (node_map, num_nodes, num_vsources) = mna_dims(&circuit);
        let bad_seed: std::collections::HashMap<String, f64> =
            [("n1".to_string(), 1e6), ("n2".to_string(), -1e6)]
                .into_iter()
                .collect();
        let rescued = gmin_stepping(
            &circuit,
            &node_map,
            num_nodes,
            num_vsources,
            Some(&bad_seed),
        )
        .expect("gmin stepping should rescue absurd initial seed");
        let v_n2 = rescued.node_voltages["n2"];
        assert!(
            (0.5..=0.9).contains(&v_n2),
            "rescued solution should still be physical, got V(n2) = {v_n2}"
        );
    }

    #[test]
    fn solve_nonlinear_still_correct_for_easy_circuit() {
        let circuit = diode_resistor_circuit(5.0);
        let v_n2 = crate::solve_circuit(&circuit)
            .expect("should converge")
            .node_voltages["n2"];
        assert!((0.5..=0.9).contains(&v_n2), "V(n2) = {v_n2}");
    }

    #[test]
    fn gmin_stepping_handles_low_bias() {
        let circuit = diode_resistor_circuit(0.3);
        let (node_map, num_nodes, num_vsources) = mna_dims(&circuit);
        let low = gmin_stepping(&circuit, &node_map, num_nodes, num_vsources, None)
            .expect("low-bias diode should still converge via homotopy");
        let v_n2 = low.node_voltages["n2"];
        assert!(
            (0.0..=0.35).contains(&v_n2),
            "low-bias diode should sit below knee, got V(n2) = {v_n2}"
        );
    }
}