use std::{convert::Infallible, error::Error};
use twine_core::{
DerivativeOf, EquationProblem, Model, OdeProblem, OptimizationProblem, StepIntegrable,
};
use twine_observers::{PlotObserver, ShowConfig};
use twine_solvers::{equation::bisection, optimization::golden_section, transient::euler};
/// Entry point: selects which demo plot to run from the command line.
///
/// The first argument picks the mode (`bisect`, `maximize`, or `ode`) and defaults to
/// `bisect` when absent. In `ode` mode an optional second argument gives the step size,
/// defaulting to `0.05`. An unknown mode or an unparsable step size prints a message to
/// stderr and exits with status 1.
fn main() -> Result<(), Box<dyn Error>> {
    // Skip the program name; what remains is `[mode] [dt]`.
    let mut cli = std::env::args().skip(1);
    let mode = cli.next().unwrap_or_else(|| "bisect".into());

    match mode.as_str() {
        "bisect" => bisect(),
        "maximize" => maximize(),
        "ode" => {
            let dt = match cli.next() {
                // No step size supplied: fall back to the default.
                None => 0.05,
                Some(raw) => match raw.parse::<f64>() {
                    Ok(step) => step,
                    Err(_) => {
                        eprintln!("Invalid step size — expected a number, e.g. 0.1");
                        std::process::exit(1)
                    }
                },
            };
            ode(dt)
        }
        other => {
            eprintln!("Unknown mode: {other}");
            eprintln!("Usage: plot [bisect|maximize|ode [dt]]");
            std::process::exit(1);
        }
    }
}
/// Identity model: the output is a copy of the input.
struct Passthrough;

impl Model for Passthrough {
    type Input = f64;
    type Output = f64;
    type Error = Infallible;

    /// Returns `input` unchanged; this cannot fail.
    fn call(&self, input: &Self::Input) -> Result<Self::Output, Self::Error> {
        let value = *input;
        Ok(value)
    }
}
/// One-variable equation problem whose residual is `cos(x) - x`.
struct CosMinusX;

impl EquationProblem<1> for CosMinusX {
    type Input = f64;
    type Output = f64;
    type Error = Infallible;

    /// Maps the single solver variable directly to the model input.
    fn input(&self, x: &[f64; 1]) -> Result<Self::Input, Self::Error> {
        let [value] = *x;
        Ok(value)
    }

    /// Computes the residual `cos(input) - input`; the model output is ignored.
    fn residuals(
        &self,
        input: &Self::Input,
        _output: &Self::Output,
    ) -> Result<[f64; 1], Self::Error> {
        let x = *input;
        Ok([x.cos() - x])
    }
}
/// Runs the bisection solver on `cos(x) = x` and plots each event it reports.
///
/// For every solver event the plot receives the event index as the x-value, the event's
/// `x`, and the absolute residual when the evaluation succeeded (otherwise no residual
/// point is recorded for that event).
///
/// # Errors
///
/// Propagates any error from the solver or from showing the plot.
fn bisect() -> Result<(), Box<dyn Error>> {
    let mut plot = PlotObserver::<2>::new(["x", "Residual"]);
    // Number of events seen so far; doubles as the horizontal plot coordinate.
    let mut event_index = 0_u32;

    let solver_config = bisection::Config::default();
    bisection::solve(
        &Passthrough,
        &CosMinusX,
        [0.0, 2.0],
        &solver_config,
        |event: &bisection::Event<'_, Passthrough, CosMinusX>| {
            let abs_residual = event
                .result()
                .as_ref()
                .ok()
                .map(|evaluation| evaluation.residuals[0].abs());
            plot.record(f64::from(event_index), [Some(event.x()), abs_residual]);
            event_index += 1;
            // `None` is returned for every event, as in the other demos in this file.
            None
        },
    )?;

    plot.show(
        ShowConfig::new()
            .title("Bisection: cos(x) = x → Dottie number ≈ 0.7391")
            .legend(),
    )?;
    Ok(())
}
/// Model that evaluates `sin(x)` for a scalar input.
struct Sine;

impl Model for Sine {
    type Input = f64;
    type Output = f64;
    type Error = Infallible;

    /// Returns the sine of `input`; this cannot fail.
    fn call(&self, input: &Self::Input) -> Result<Self::Output, Self::Error> {
        Ok(f64::sin(*input))
    }
}
/// One-variable optimization problem that uses the model output itself as the objective.
struct DirectObjective;

impl OptimizationProblem<1> for DirectObjective {
    type Input = f64;
    type Output = f64;
    type Error = Infallible;

    /// Maps the single decision variable directly to the model input.
    fn input(&self, x: &[f64; 1]) -> Result<Self::Input, Self::Error> {
        let [value] = *x;
        Ok(value)
    }

    /// Returns the model output as the objective value; the input is ignored.
    fn objective(
        &self,
        _input: &Self::Input,
        output: &Self::Output,
    ) -> Result<f64, Self::Error> {
        let value = *output;
        Ok(value)
    }
}
/// Maximizes `sin(x)` on `[0, π]` with golden-section search and plots the evaluated points
/// on top of the curve.
///
/// The curve is sampled at 2001 points whose spacing follows a cubic map, so samples are
/// densest around `π/2`. Every `Evaluated` event from the solver is recorded as a point and
/// labelled with its 1-based evaluation number.
///
/// # Errors
///
/// Propagates any error from the solver or from showing the plot.
///
/// # Panics
///
/// Panics if the hard-coded golden-section configuration is rejected.
fn maximize() -> Result<(), Box<dyn Error>> {
    use std::f64::consts::{FRAC_PI_2, PI};

    let mut plot = PlotObserver::<2>::new(["sin(x)", "Evaluated points"]);
    plot.label_size(16.0);

    // Background curve: u runs over [-1, 1]; x = π/2 · (1 + u³) covers [0, π].
    (0_u32..=2000)
        .map(|i| {
            let u = f64::from(i) / 1000.0 - 1.0;
            FRAC_PI_2 * (1.0 + u.powi(3))
        })
        .for_each(|x| {
            plot.record(x, [Some(x.sin()), None]);
        });

    // 1-based counter used to label evaluations in the order they occur.
    let mut evaluation_no = 1_u32;
    let config = golden_section::Config::new(100, 1e-5, 1e-5).unwrap();

    golden_section::maximize(
        &Sine,
        &DirectObjective,
        [0.0, PI],
        &config,
        |event: &golden_section::Event<'_, Sine, DirectObjective>| {
            if let golden_section::Event::Evaluated { point, .. } = event {
                let (x, value) = (point.x, point.objective);
                plot.record(x, [None, Some(value)]);
                plot.label(x, value, evaluation_no.to_string());
                evaluation_no += 1;
            }
            None
        },
    )?;

    plot.show(
        ShowConfig::new()
            .title("Maximize: sin(x) on [0, π] → maximum at (π/2, 1) ≈ (1.571, 1)")
            .legend(),
    )?;
    Ok(())
}
/// Integrated state of the oscillator.
#[derive(Clone, Debug)]
struct OscState {
// Displacement of the oscillator.
position: f64,
// Rate of change of `position`.
velocity: f64,
}
/// Time derivative of an `OscState`, one component per state field.
#[derive(Clone, Debug)]
struct OscDerivative {
// Rate of change of `OscState::position`.
d_position: f64,
// Rate of change of `OscState::velocity`.
d_velocity: f64,
}
impl StepIntegrable<f64> for OscState {
type Derivative = OscDerivative;
fn step(&self, deriv: OscDerivative, dt: f64) -> Self {
OscState {
position: self.position + deriv.d_position * dt,
velocity: self.velocity + deriv.d_velocity * dt,
}
}
}
/// Input to the oscillator model: the current state together with the current time.
#[derive(Clone, Debug)]
struct OscInput {
// Oscillator state at time `t`.
state: OscState,
// Simulation time; `OscProblem::build_input` advances it by the step size.
t: f64,
}
/// Output of the oscillator model: the time derivatives of position and velocity.
#[derive(Clone, Debug)]
struct OscOutput {
// Derivative of position (set to the input velocity by `OscModel::call`).
d_position: f64,
// Derivative of velocity (the acceleration computed by `OscModel::call`).
d_velocity: f64,
}
/// Damped harmonic oscillator model, `x'' = -2·zeta·omega0·x' - omega0²·x`.
struct OscModel {
// Damping coefficient ζ in the equation above (the plot title calls it ζ).
zeta: f64,
// Frequency parameter ω₀ in the equation above.
omega0: f64,
}
impl Model for OscModel {
    type Input = OscInput;
    type Output = OscOutput;
    type Error = Infallible;

    /// Evaluates the oscillator's right-hand side at the given state.
    ///
    /// Position changes at the current velocity; velocity changes at
    /// `-2·zeta·omega0·velocity - omega0²·position`. The time in `input` is not used.
    fn call(&self, input: &Self::Input) -> Result<Self::Output, Self::Error> {
        let position = input.state.position;
        let velocity = input.state.velocity;

        let damping = -2.0 * self.zeta * self.omega0 * velocity;
        let restoring = self.omega0.powi(2) * position;

        Ok(OscOutput {
            d_position: velocity,
            d_velocity: damping - restoring,
        })
    }
}
struct OscProblem;
impl OdeProblem for OscProblem {
type Input = OscInput;
type Output = OscOutput;
type Delta = f64;
type State = OscState;
type Error = Infallible;
fn state(&self, input: &OscInput) -> Result<OscState, Infallible> {
Ok(input.state.clone())
}
fn derivative(
&self,
_input: &OscInput,
output: &OscOutput,
) -> Result<DerivativeOf<OscState, f64>, Infallible> {
Ok(OscDerivative {
d_position: output.d_position,
d_velocity: output.d_velocity,
})
}
fn build_input(
&self,
base: &OscInput,
state: &OscState,
dt: &f64,
) -> Result<OscInput, Infallible> {
Ok(OscInput {
state: state.clone(),
t: base.t + dt,
})
}
}
/// Integrates an underdamped harmonic oscillator with the Euler solver and plots the
/// numerical position against the closed-form solution.
///
/// The system is `x'' = -2ζω₀·x' - ω₀²·x` with `x(0) = 1` and `x'(0) = 0`, simulated for
/// 30 time units in steps of `dt`.
///
/// # Errors
///
/// Returns an error if `dt` is not a positive, finite number, if the solver fails, or if
/// the plot cannot be shown.
fn ode(dt: f64) -> Result<(), Box<dyn Error>> {
    // Total simulated time; the number of Euler steps is `DURATION / dt`, rounded.
    const DURATION: f64 = 30.0;

    // Reject step sizes the step-count computation below cannot handle. A float → `usize`
    // `as` cast saturates, so a NaN or negative `dt` would silently produce 0 steps and
    // `dt == 0` would request `usize::MAX` steps.
    if !dt.is_finite() || dt <= 0.0 {
        return Err(format!("step size must be a positive, finite number, got {dt}").into());
    }

    let zeta = 0.1_f64;
    let omega0 = 1.0_f64;
    // Closed-form underdamped solution for x(0) = 1, x'(0) = 0:
    //   x(t) = e^(-σt) · (cos(ω_d·t) + (σ/ω_d)·sin(ω_d·t)),
    // with decay rate σ = ζω₀ and damped frequency ω_d = ω₀·√(1 - ζ²).
    // Writing it in terms of σ keeps the reference curve correct if `omega0` is changed.
    let sigma = zeta * omega0;
    let omega_d = omega0 * (1.0 - zeta.powi(2)).sqrt();

    let model = OscModel { zeta, omega0 };
    let initial = OscInput {
        state: OscState {
            position: 1.0,
            velocity: 0.0,
        },
        t: 0.0,
    };
    let analytical = move |t: f64| {
        (-sigma * t).exp() * ((omega_d * t).cos() + (sigma / omega_d) * (omega_d * t).sin())
    };

    let mut obs = PlotObserver::<2>::new(["Euler (numerical)", "Analytical"]);

    // `dt` is validated above as positive and finite, so the rounded quotient is
    // non-negative; any out-of-range value saturates rather than wrapping.
    #[allow(clippy::cast_possible_truncation, clippy::cast_sign_loss)]
    let steps = (DURATION / dt).round() as usize;

    euler::solve(
        &model,
        &OscProblem,
        initial,
        dt,
        steps,
        |event: &euler::Event<OscInput, OscOutput>| {
            // Plot the numerical position and the exact solution at the same time value.
            let t = event.snapshot.input.t;
            obs.record(
                t,
                [
                    Some(event.snapshot.input.state.position),
                    Some(analytical(t)),
                ],
            );
            None
        },
    )?;

    obs.show(
        ShowConfig::new()
            .title(format!(
                "ODE: Damped oscillator (ζ={zeta}, dt={dt}) — Euler vs. analytical"
            ))
            .legend(),
    )?;
    Ok(())
}