1#![doc = include_str!("../README.md")]
2
3use std::{marker::PhantomData, mem, time::{Duration, Instant}};
4
5use crate::error::{ReachMaxIteration, TimeOut};
6
7pub mod error;
8mod utils;
9
/// Per-iteration state manipulated by a [`Solver`].
///
/// A solve begins by turning the caller's starting value into a state
/// (`init_from_value`) and finishes by reading the answer back out of the
/// last state (`to_sol`).
pub trait IterState: Sized {
    /// What the caller supplies as the starting point of a solve.
    type Value;

    /// What a successful solve hands back to the caller.
    type Solution;

    /// Builds the first state of an iteration from the starting value.
    fn init_from_value(initial_point: Self::Value) -> Self;

    /// Reads the solution out of a state (normally the final one).
    fn to_sol(&self) -> Self::Solution;
}
32
// Implements `IterState` for `Copy` types that serve as their own starting
// value, state and solution: initialisation is the identity and `to_sol`
// copies the state out.
macro_rules! iterstate_impl {
    ($($scalar:ty),*) => {
        $(
            impl IterState for $scalar {
                type Value = Self;
                type Solution = Self;

                fn init_from_value(initial_point: Self::Value) -> Self {
                    initial_point
                }

                fn to_sol(&self) -> Self::Solution {
                    *self
                }
            }
        )*
    };
}
49
50iterstate_impl!(
51 u8,u16,u32,u64,u128,usize,
52 i8,i16,i32,i64,i128,isize,
53 f32, f64
54);
55
56
/// An iterative solver assembled from an update step and a termination test.
///
/// * `State` is the per-iteration state (see [`IterState`]).
/// * `Problem` is the problem description handed by reference to both closures.
/// * `IterFn` computes the next state from the current state and the problem.
/// * `TermFn` decides whether a state is accepted as the solution.
///
/// The struct itself places no bounds on its parameters. The `Fn`
/// requirements live on the `impl` blocks that actually call the closures, so
/// the type can be named in generic code without repeating them. `State` and
/// `Problem` are phantom parameters only: the solver stores neither a state
/// nor a problem.
#[derive(Debug)]
pub struct Solver<State, Problem, IterFn, TermFn> {
    /// Marker tying the solver to its state type.
    state: PhantomData<State>,
    /// Marker tying the solver to its problem type.
    problem: PhantomData<Problem>,
    /// Update step: `(current state, problem) -> next state`.
    iter_fn: IterFn,
    /// Termination test: `(state, problem) -> accepted?`.
    term_cond: TermFn,
}
78impl<State, Problem, IterFn, TermFn> Solver<State, Problem, IterFn, TermFn>
79where
80 State: IterState,
81 IterFn: Fn(&State, &Problem) -> State,
82 TermFn: Fn(&State, &Problem) -> bool
83{
84 pub fn new(
91 iter_fn: IterFn,
92 term_cond: TermFn
93 ) -> Self {
94 Self {
95 state: PhantomData::<State>,
96 problem: PhantomData::<Problem>,
97 iter_fn: iter_fn,
98 term_cond: term_cond
99 }
100 }
101
102 pub fn solve(&self, initial_point: State::Value, problem: &Problem) -> State::Solution {
109 let initial_state = State::init_from_value(initial_point);
111 let mut state = initial_state;
112
113 loop {
115 state = (self.iter_fn)(&state, problem);
118
119 if (self.term_cond)(&state, problem) {
121 break;
122 }
123 }
124
125 let final_state = state;
126
127 let sol = final_state.to_sol();
128
129 sol
132 }
133
134
135 pub fn solve_with_max_iteration(
150 &self,
151 initial_point: State::Value,
152 problem: &Problem,
153 max_iteration: u64
154 ) -> Result<State::Solution, error::ReachMaxIteration<State>> {
155 let mut reach_max = true;
156
157 let initial_state = State::init_from_value(initial_point);
159 let mut state = initial_state;
160
161 for _iteration in 0..max_iteration {
162 state = (self.iter_fn)(&state, problem);
165
166 if (self.term_cond)(&state, problem) {
168 reach_max = false;
169 break;
170 }
171 }
172
173 let ret_res = if !reach_max {
174 let final_state = state;
175 let sol = final_state.to_sol();
176 Ok(sol)
178 } else {
179 let final_state = state;
180 Err(ReachMaxIteration{
181 max_iteration: max_iteration,
182 final_state: final_state
183 })
184 };
185
186 ret_res
187 }
188
189 pub fn solve_with_timeout(
205 &self,
206 initial_point: State::Value,
207 problem: &Problem,
208 timeout: Duration
209 ) -> Result<State::Solution, TimeOut<State>> {
210 let start_time = Instant::now();
211 let mut is_timeout = true;
212
213 let initial_state = State::init_from_value(initial_point);
215 let mut state = initial_state;
216
217 loop {
219 state = (self.iter_fn)(&state, problem);
222
223 if start_time.elapsed() > timeout {
224 break;
225 }
226
227 if (self.term_cond)(&state, problem) {
229 is_timeout = false;
230 break;
231 }
232 }
233
234 if !is_timeout {
235 let final_state = state;
236
237 let sol = final_state.to_sol();
238
239 Ok(sol)
242 } else {
243 let final_state = state; Err(TimeOut { timeout: timeout, final_state: final_state })
245 }
246
247 }
248
249
250
251 pub fn into_iter<'prob>(self,
268 initial_point: State::Value,
269 problem: &'prob Problem
270 ) -> SolverIterater<'prob, State, Problem, IterFn, TermFn> {
271 let initial_state = State::init_from_value(initial_point);
273
274 SolverIterater {
275 problem: problem,
276 state: Some(initial_state),
277 iter_fn: self.iter_fn,
278 term_cond: self.term_cond,
279 }
281 }
282
283 pub fn change_term_cond<NewTermCond>(self, new_cond: NewTermCond) -> Solver<State, Problem, IterFn, NewTermCond>
285 where
286 NewTermCond: Fn(&State, &Problem) -> bool
287 {
288 Solver {
289 state: self.state,
290 problem: self.problem,
291 iter_fn: self.iter_fn,
292 term_cond: new_cond
293 }
294 }
295}
296
297impl<State, Problem, IterFn, TermFn> Clone for Solver<State, Problem, IterFn, TermFn>
298where
299 State: IterState,
300 IterFn: Fn(&State, &Problem) -> State + Clone,
301 TermFn: Fn(&State, &Problem) -> bool + Clone
302{
303 fn clone(&self) -> Self {
304 Self { state: PhantomData::<State>, problem: PhantomData::<Problem>, iter_fn: self.iter_fn.clone(), term_cond: self.term_cond.clone() }
305 }
306}
307
/// Iterator over the successive states of a solver, produced by
/// `Solver::into_iter`.
///
/// It owns the solver's closures and borrows the problem for `'prob`. Like
/// `Solver`, the struct carries no trait bounds itself. The `Fn` requirements
/// are stated on the `Iterator` impl that calls the closures.
#[derive(Debug)]
pub struct SolverIterater<'prob, State, Problem, IterFn, TermFn> {
    /// State to yield next; `None` once the termination condition has been met.
    state: Option<State>,
    /// Update step: `(current state, problem) -> next state`.
    iter_fn: IterFn,
    /// Termination test: `(state, problem) -> accepted?`.
    term_cond: TermFn,
    /// Problem passed by reference to both closures.
    problem: &'prob Problem,
}
332
333impl<'prob, State, Problem, IterFn, TermFn> Iterator for SolverIterater<'prob, State, Problem, IterFn, TermFn>
334where
335 State: IterState,
336 IterFn: Fn(&State, &Problem) -> State,
337 TermFn: Fn(&State, &Problem) -> bool
338{
339 type Item = State;
340
341 fn next(&mut self) -> Option<Self::Item> {
342 let state = if (&self.state).is_none() {
343 return None;
344 } else {
345 &self.state.as_ref().unwrap()
346 };
347
348 if (self.term_cond)(&state, &self.problem) {
351 let old_state = mem::replace(&mut self.state, None);
352
353 return old_state;
354 } else {
355 let next_state = (self.iter_fn)(state, &self.problem);
356 let old_state = mem::replace(&mut self.state, Some(next_state));
357
358 return old_state;
359 }
360 }
361}
362
363
364
365
366
367
#[cfg(test)]
mod test {
    mod newton {
        use crate::{IterState, Solver};

        /// `f(x) = e^x - 1.5` together with its derivative `f'(x) = e^x`.
        fn f_and_df(x: f64) -> (f64, f64) {
            let fx = x.exp() - 1.5;
            let dfx = x.exp();
            (fx, dfx)
        }

        #[test]
        fn show() {
            let iter_fn = |state: &f64, problem: &fn(f64) -> (f64, f64)| {
                let x_n = *state;
                let (fx, dfx) = problem(x_n);
                x_n - (fx / dfx)
            };

            let term_cond = |state: &f64, problem: &fn(f64) -> (f64, f64)| {
                let (fx, _) = problem(*state);
                fx.abs() < 1e-6
            };

            let solver = Solver::new(iter_fn, term_cond);

            let solution = solver.solve(1.5, &(f_and_df as fn(f64) -> (f64, f64)));

            println!("solver's solution: {}", solution);
            println!("use std function ln: {}", 1.5_f64.ln());

            // Near the root f'(x) = e^x is about 1.5, so |f(x)| < 1e-6 puts x
            // within roughly 7e-7 of ln(1.5).
            assert!((solution - 1.5_f64.ln()).abs() < 1e-6);
        }

        #[derive(Clone)]
        enum Equation {
            /// `a * e^x - k`
            Exp { a: f64, k: f64 },

            /// `a * x^2 + b * x + c`
            Square { a: f64, b: f64, c: f64 },
        }

        impl Equation {
            /// Evaluates the equation's left-hand side at `val`.
            fn calc(&self, val: f64) -> f64 {
                match self {
                    Self::Exp { a, k } => a * val.exp() - k,

                    Self::Square { a, b, c } => {
                        let x2 = a * val * val;
                        let x1 = b * val;
                        let x0 = c;
                        x2 + x1 + x0
                    }
                }
            }

            /// Evaluates the derivative of the left-hand side at `val`.
            fn diff(&self, val: f64) -> f64 {
                match self {
                    Self::Exp { a, k: _ } => a * val.exp(),

                    Self::Square { a, b, c: _ } => ((2. * a) * val) + b,
                }
            }
        }

        #[derive(Debug, Clone)]
        struct NewtonState(f64);

        impl IterState for NewtonState {
            type Value = f64;

            type Solution = f64;

            fn init_from_value(initial_point: Self::Value) -> Self {
                Self(initial_point)
            }

            fn to_sol(&self) -> Self::Solution {
                self.0
            }
        }

        #[test]
        fn test() {
            let iter_fn = |state: &NewtonState, problem: &Equation| {
                let x = state.0;
                let dx = problem.diff(x);
                let fx = problem.calc(x);

                let next_x = x - (fx / dx);

                NewtonState(next_x)
            };

            // Compare the residual's magnitude. Without `abs`, any negative
            // residual, however large, would count as converged.
            let term_cond = |state: &NewtonState, problem: &Equation| {
                let epsilon = 1e-6;
                problem.calc(state.0).abs() < epsilon
            };

            let solver = Solver::new(iter_fn, term_cond);

            let solver1 = solver
                .clone()
                .change_term_cond(|state, equation| equation.calc(state.0).abs() < 1e-9);

            let prob1 = (Equation::Exp { a: 2., k: 3. }, 2.);

            let cloned_and_change_cond_sol = solver1.solve(prob1.1, &prob1.0.clone());

            let prob2 = (Equation::Square { a: 2., b: -5., c: 3. }, 6.);

            let prob1_sol = solver.solve(prob1.1, &prob1.0);
            let prob2_sol = solver.solve(prob2.1, &prob2.0);

            println!("the numerical solution of $2e^x - 3 = 0$ is: {}", prob1_sol);
            println!("with direct calc: {}", (1.5_f64).ln());
            println!("the numerical solution of $2x^2 - 5x + 3 = 0$ is: {}", prob2_sol);
            println!("with direct calc: {} or {}", ((5. + 1.)/4.) , (3./4.));

            println!("cloned sol: {}", cloned_and_change_cond_sol);

            assert!(prob1.0.calc(prob1_sol).abs() < 1e-6);
            assert!(prob2.0.calc(prob2_sol).abs() < 1e-6);
            // The cloned solver must honour its own, stricter condition.
            assert!(prob1.0.calc(cloned_and_change_cond_sol).abs() < 1e-9);
        }

        #[test]
        fn solver_iter() {
            let iter_fn = |state: &NewtonState, problem: &Equation| {
                let x = state.0;
                let dx = problem.diff(x);
                let fx = problem.calc(x);

                let next_x = x - (fx / dx);

                NewtonState(next_x)
            };

            let term_cond = |state: &NewtonState, problem: &Equation| {
                let epsilon = 1e-6;
                problem.calc(state.0).abs() < epsilon
            };

            let solver = Solver::new(iter_fn, term_cond);

            let prob1 = (Equation::Exp { a: 2., k: 3. }, 2.);

            let prob2 = (Equation::Square { a: 2., b: -5., c: 3. }, 6.);

            {
                let iter = solver.clone().into_iter(prob1.1, &prob1.0);
                let mut counter = 0usize;

                for state in iter {
                    counter += 1;
                    if counter == 5 {
                        println!("after 5 times iter, the f(x) = {}", (prob1.0).calc(state.0));
                        break;
                    }
                }
            }

            {
                let iter = solver.into_iter(prob2.1, &prob2.0);
                let mut counter = 0usize;

                for state in iter {
                    counter += 1;
                    if counter == 5 {
                        println!("after 5 times iter, the f(x) = {}", (prob2.0).calc(state.0));
                        break;
                    }
                }
            }
        }
    }

    mod test_leak {
        use std::time::Duration;

        use crate::{utils, Solver};

        #[test]
        fn solve() {
            let iter_fn = |state: &utils::debug::VisibleDrop, _: &()| {
                utils::debug::VisibleDrop::new(state.get().wrapping_add(1))
            };

            let term_cond = |state: &utils::debug::VisibleDrop, _: &()| state.get() == 2;

            let solver = Solver::new(iter_fn, term_cond);

            println!("solve");

            solver.solve(0, &());

            println!("iter");
            for _ in solver.clone().into_iter(0, &()) {
                println!("do iter")
            }

            println!("solve with timeout");
            let loop_solver = solver.change_term_cond(|_, _| false);

            let timeout = Duration::from_nanos(1000);

            let _ = loop_solver.solve_with_timeout(0, &(), timeout);
        }
    }

    mod guard {
        use std::time::Duration;

        use crate::Solver;

        #[test]
        fn test() {
            // The termination condition never holds, so both guarded solves must fail.
            let loop_solver = Solver::new(|state: &f64, _: &()| *state, |_: &f64, _: &()| false);

            // A short budget is enough to exercise the timeout path. A longer one
            // only slows the test suite down with a busy loop.
            let try_solve = loop_solver.solve_with_timeout(0.0, &(), Duration::from_millis(10));
            assert!(try_solve.is_err());

            let try_solve = loop_solver.solve_with_max_iteration(0.0, &(), 10);
            assert!(try_solve.is_err());
        }
    }
}