Skip to main content

iter_solver/
lib.rs

1#![doc = include_str!("../README.md")]
2#![cfg_attr(not(feature = "std"), no_std)]
3#![cfg_attr(docsrs, feature(doc_cfg))]
4
5extern crate self as iter_solver;
6
7pub use iter_state_derive::IterState;
8
9use core::marker::PhantomData;
10
11#[cfg(feature = "std")]
12use std::time::Duration;
13
14#[cfg(feature = "std")]
15use std::time::Instant;
16
17#[cfg(feature = "std")]
18use crate::error::TimeOut;
19
20use crate::error::MaxIterationReached;
21
22pub mod error;
23
24#[cfg(test)]
25mod build_test;
26
/// The intermediate state carried from one iteration of an algorithm to the next.
///
/// Many iterative algorithms need a richer working state than either the value they
/// start from or the answer they finally hand back. For instance, after the iteration
/// stops it is common to run a little extra computation, or to read some metadata that
/// the algorithm accumulated along the way.
///
/// `IterState` therefore lets you define your own intermediate state. It keeps the
/// starting value ([`IterState::Value`]) and the produced result ([`IterState::Solution`])
/// as two separate abstractions.
///
/// For the simplest use cases, this crate already implements `IterState` for the
/// primitive numeric types `i*`, `u*` and `f*`. For each of them, both associated types
/// are the type itself.
///
/// `IterState` can also be derived with `#[derive(IterState)]`. The derived impl sets
/// [`IterState::Value`] and [`IterState::Solution`] to `Self`, just like the primitive
/// impls. [`IterState::init_from_value`] returns its argument unchanged, and
/// [`IterState::into_sol`] returns `self`.
pub trait IterState: Sized {
    /// The type of the value an iteration is started from.
    type Value;

    /// The type of the result produced once iteration stops.
    type Solution;

    /// Builds the initial state out of the starting value.
    fn init_from_value(initial_point: Self::Value) -> Self;

    /// Consumes the state and turns it into the solution.
    fn into_sol(self) -> Self::Solution;
}
51
52macro_rules! iterstate_impl {
53    ($($ty:ty),*) => {
54        $(
55            impl IterState for $ty {
56                type Value = $ty;
57                type Solution = $ty;
58                fn init_from_value(initial_point: Self::Value) -> Self {
59                    initial_point
60                }
61                fn into_sol(self) -> Self::Solution {
62                    self
63                }
64            }
65        )*
66    };
67}
68
69iterstate_impl!(
70    u8,u16,u32,u64,u128,usize,
71    i8,i16,i32,i64,i128,isize,
72    f32, f64
73);
74
75
/// Solver type.
///
/// Note: `Solver` does not implement any particular algorithm. It bundles an iteration
/// rule (`iter_fn`) and a termination condition (`term_cond`) that you supply, and the
/// `solve*` methods drive them.
///
/// The `Problem` generic parameter has no constraints. You can use it to carry whatever
/// data or interface the iteration rule and the termination condition need; both receive
/// it by shared reference.
///
/// Trait bounds are declared on the `impl` blocks that need them rather than on the
/// struct itself, so the type can be named without repeating them.
#[derive(Debug)]
pub struct Solver<State, Problem, IterFn, TermFn> {
    /// Marker for the state type. No state is stored in the solver; each `solve*` call
    /// builds its own state from the initial value it is given.
    state: PhantomData<State>,
    /// Marker for the problem type. No problem value is stored in the solver.
    problem: PhantomData<Problem>,
    /// Iteration rule: consumes the current state and returns the next one.
    iter_fn: IterFn,
    /// Termination condition: iteration stops once this returns `true`.
    term_cond: TermFn,
}
97impl<State, Problem, IterFn, TermFn> Solver<State, Problem, IterFn, TermFn>
98where 
99    State: IterState,
100    IterFn: Fn(State, &Problem) -> State,
101    TermFn: Fn(&State, &Problem) -> bool
102{
103    /// Creates a Solver instance with the specified methods.
104    ///
105    /// Parameter `iter_fn` defines the state iteration rule for the Solver.
106    ///
107    /// Parameter `term_cond` specifies that the Solver stops iterating if and only if
108    /// the condition is met: `term_cond` returns `true` in the current state.
109    pub fn new(
110        iter_fn: IterFn,
111        term_cond: TermFn
112    ) -> Self {
113        Self { 
114            state: PhantomData::<State>, 
115            problem: PhantomData::<Problem>, 
116            iter_fn: iter_fn, 
117            term_cond: term_cond 
118        }
119    }
120
121    /// Solves the problem by executing iterative logic with the given initial value and specific problem.
122    ///
123    /// # Note
124    /// If the algorithm defined by the solver contains logical errors, 
125    /// the solve function may enter an infinite loop. 
126    /// To avoid this, try [`Solver::solve_with_max_iterations`] or [`Solver::solve_with_timeout`]. If you need more flexible error handling, try [`Solver::solve_with_error`].
127    pub fn solve(&self, initial_point: State::Value, problem: &Problem) -> State::Solution {
128        // init state
129        let initial_state = State::init_from_value(initial_point);
130        let mut state = initial_state;
131
132        if (self.term_cond)(&state, problem) {
133            return state.into_sol();
134        }
135        
136        // do iter
137        loop {
138            //let state = unsafe { self.state.assume_init_mut() };
139
140            state = (self.iter_fn)(state, problem);
141
142            // check termination cond
143            if (self.term_cond)(&state, problem) {
144                break;
145            }
146        }
147        
148        let final_state = state;
149
150        let sol = final_state.into_sol();
151
152        //unsafe { self.state.assume_init_drop(); }
153
154        sol
155    }
156
157
158    /// A solution method with a maximum iteration limit.
159    /// If the termination condition is met before reaching the maximum number of iterations, it returns an [`Ok`] value with the type [`IterState::Solution`].
160    /// Otherwise, it returns an [`Err`] with [`error::MaxIterationReached`].
161    /// 
162    /// # Example
163    /// ```
164    /// use iter_solver::Solver;
165    /// 
166    /// // define a never stop solver
167    /// let loop_solver = Solver::new(|_state: f64, _: &()| {_state}, |_: &f64, _: &()| {false});
168    /// let try_solve = loop_solver.solve_with_max_iterations(0.0, &(), 10);
169    /// 
170    /// assert!(try_solve.is_err());
171    /// ```
172    /// 
173    /// If you need more flexible error handling, try [`Solver::solve_with_error`].
174    pub fn solve_with_max_iterations(
175        &self,
176        initial_point: State::Value,
177        problem: &Problem,
178        max_iteration: u64
179    ) -> Result<State::Solution, MaxIterationReached<State>> {
180        let mut reach_max = true;
181
182        // init state
183        let initial_state = State::init_from_value(initial_point);
184        let mut state = initial_state;     
185
186        if (self.term_cond)(&state, problem) {
187            return Ok(state.into_sol());
188        }   
189        
190        for _iteration in 0..max_iteration {
191            //state = unsafe { self.state.assume_init_mut() };
192
193            state = (self.iter_fn)(state, problem);
194
195            // check termination cond
196            if (self.term_cond)(&state, problem) {
197                reach_max = false;
198                break;
199            }
200        }
201
202        let ret_res = if !reach_max {
203            let final_state = state;
204            let sol = final_state.into_sol();
205            //unsafe { self.state.assume_init_drop(); }
206            Ok(sol)
207        } else {
208            let final_state = state;
209            Err(MaxIterationReached{
210                max_iteration: max_iteration, 
211                final_state: final_state
212            })
213        };
214
215        ret_res
216    }
217
218    #[cfg(feature = "std")]
219    /// A solution method with a time limit.
220    /// If the termination condition is met before reaching the timeout duration elapses, it returns an [`Ok`] value with the type [`IterState::Solution`].
221    /// Otherwise, it returns an [`Err`] with [`error::TimeOut`].
222    /// 
223    /// # Example
224    /// ```
225    /// use iter_solver::Solver;
226    /// use std::time::Duration;
227    /// 
228    /// // define a never stop solver
229    /// let loop_solver = Solver::new(|_state: f64, _: &()| {_state}, |_: &f64, _: &()| {false});
230    /// let try_solve = loop_solver.solve_with_timeout(0.0, &(), Duration::from_secs(1));
231    /// 
232    /// assert!(try_solve.is_err());
233    /// ```
234    /// 
235    /// If you need more flexible error handling, try [`Solver::solve_with_error`].
236    pub fn solve_with_timeout(
237        &self,
238        initial_point: State::Value,
239        problem: &Problem,
240        timeout: Duration        
241    ) -> Result<State::Solution, TimeOut<State>> {
242        let start_time = Instant::now();
243        let mut is_timeout = true;
244        
245        // init state
246        let initial_state = State::init_from_value(initial_point);
247        let mut state = initial_state;
248
249        if (self.term_cond)(&state, problem) {
250            return Ok(state.into_sol());
251        }
252        
253        // do iter
254        loop {
255            //state = unsafe { self.state.assume_init_mut() };
256
257            state = (self.iter_fn)(state, problem);
258
259            if start_time.elapsed() > timeout {
260                break;
261            }
262
263            // check termination cond
264            if (self.term_cond)(&state, problem) {
265                is_timeout = false;
266                break;
267            }
268        }
269
270        if !is_timeout {
271            let final_state = state;
272
273            let sol = final_state.into_sol();
274
275            //unsafe { self.state.assume_init_drop(); }
276
277            Ok(sol)                
278        } else {
279            let final_state = state; //unsafe { self.state.assume_init_read() };
280            Err(TimeOut { timeout: timeout, final_state: final_state })            
281        }
282
283    }
284
285    /// Performs iterative solving with custom error handling, allowing early termination.
286    ///
287    /// This method executes an iterative solving process, but before each iteration,
288    /// it invokes the provided `check_fn` with the current state and the problem reference.
289    /// If `check_fn` returns `Ok(())`, the iteration continues until the stopping criteria
290    /// are met, returning [`Ok`] with the final solution. If `check_fn` returns `Err(e)`, the
291    /// iteration stops immediately and returns `Err(e)`.
292    ///
293    /// The key feature is **flexible error type customization** – `E` can be any type
294    /// that suits your error-handling needs (e.g., a simple `&'static str`, a custom enum,
295    /// or a structured error type). This allows you to:
296    /// - Embed domain-specific failure semantics directly into the solving flow.
297    /// - Propagate rich error information without boxing or trait objects.
298    /// - Maintain full control over error kinds and context.
299    /// 
300    /// # Example
301    /// ```
302    /// use iter_solver::Solver;
303    ///         
304    /// let check_fn = |float: &f64, _: &()| {
305    ///     if float.is_infinite() {
306    ///         return Err("Inf Error");
307    ///     } else if float.is_nan() {
308    ///        return Err("NaN Error");
309    ///     }
310    ///     Ok(())
311    /// };
312    ///
313    /// let solver = Solver::new(
314    ///    |f, _| f * 2.0, // 2^n -> Inf
315    ///    |_,_| {false} // never stop
316    /// );
317    ///
318    /// let result = solver.solve_with_error(1.0, &(), check_fn);
319    ///
320    /// assert!(result.is_err());
321    /// println!("{}", result.unwrap_err()) // print "Inf Error"
322    /// ```
323    pub fn solve_with_error<E, F>(
324        &self,
325        initial_point: State::Value,
326        problem: &Problem,
327        check_fn: F
328    ) -> Result<State::Solution, E>
329    where 
330        F: Fn(&State, &Problem) -> Result<(), E>
331    {
332        // init state
333        let initial_state = State::init_from_value(initial_point);
334        let mut state = initial_state;
335
336        check_fn(&state, problem)?;
337        if (self.term_cond)(&state, problem) {
338            return Ok(state.into_sol());
339        }
340        
341        // do iter
342        loop {
343            //let state = unsafe { self.state.assume_init_mut() };
344
345            state = (self.iter_fn)(state, problem);
346
347            check_fn(&state, problem)?;
348            // check termination cond
349            if (self.term_cond)(&state, problem) {
350                break;
351            }
352        }
353        
354        let final_state = state;
355
356        let sol = final_state.into_sol();
357
358        //unsafe { self.state.assume_init_drop(); }
359
360        Ok(sol)
361    }
362
363
364
365    /// Consumes `self` and returns a new [`Solver`] with the given new termination condition.
366    pub fn with_term_cond<NewTermCond>(self, new_cond: NewTermCond) -> Solver<State, Problem, IterFn, NewTermCond> 
367    where 
368        NewTermCond: Fn(&State, &Problem) -> bool
369    {
370        Solver { 
371            state: self.state, 
372            problem: self.problem, 
373            iter_fn: self.iter_fn,
374            term_cond: new_cond 
375        }
376    }
377}
378
379impl<State, Problem, IterFn, TermFn> Clone for Solver<State, Problem, IterFn, TermFn>
380where 
381   State: IterState,
382   IterFn: Fn(State, &Problem) -> State + Clone,
383   TermFn: Fn(&State, &Problem) -> bool + Clone
384{
385    fn clone(&self) -> Self {
386        Self { state: PhantomData::<State>, problem: PhantomData::<Problem>, iter_fn: self.iter_fn.clone(), term_cond: self.term_cond.clone() }
387    }
388}
389
390
391
392
393
394
#[cfg(test)]
mod test {
    use crate::Solver;

    mod newton {
        use crate::{IterState, Solver};

        /// `f(x) = e^x - 1.5` together with its derivative `f'(x) = e^x`.
        fn f_and_df(x: f64) -> (f64, f64) {
            let fx = x.exp() - 1.5;
            let dfx = x.exp();
            (fx, dfx)
        }

        /// Newton's method on a plain `f64` state; the root of `e^x - 1.5` is `ln(1.5)`.
        #[test]
        fn show() {
            let iter_fn = |state: f64, problem: &fn(f64) -> (f64, f64)| {
                let x_n = state;
                let (fx, dfx) = problem(x_n);
                x_n - (fx / dfx)
            };

            let term_cond = |state: &f64, problem: &fn(f64) -> (f64, f64)| {
                let (fx, _) = problem(*state);
                fx.abs() < 1e-6
            };

            let solver = Solver::new(iter_fn, term_cond);

            let solution = solver.solve(1.5, &(f_and_df as fn(f64) -> (f64, f64)));

            println!("solver's solution: {}", solution);
            println!("use std function ln: {}", 1.5_f64.ln());

            // solver's solution: 0.4054651081202111
            // use std function ln: 0.4054651081081644
            assert!((solution - 1.5_f64.ln()).abs() < 1e-6);
        }

        #[derive(Clone)]
        enum Equation {
            Exp {
                // $ae^x - k = 0$
                a: f64,
                k: f64,
            },

            Square {
                // $ax^2 + bx + c = 0$
                a: f64,
                b: f64,
                c: f64,
            },
        }

        impl Equation {
            /// Evaluates the left-hand side of the equation at `val`.
            fn calc(&self, val: f64) -> f64 {
                match self {
                    Self::Exp { a, k } => a * val.exp() - k,

                    Self::Square { a, b, c } => {
                        let x2 = a * val * val;
                        let x1 = b * val;
                        let x0 = c;
                        x2 + x1 + x0
                    }
                }
            }

            /// Evaluates the derivative of the left-hand side at `val`.
            fn diff(&self, val: f64) -> f64 {
                match self {
                    Self::Exp { a, k: _ } => a * val.exp(),

                    Self::Square { a, b, c: _ } => ((2. * a) * val) + b,
                }
            }
        }

        #[derive(Debug, Clone)]
        struct NewtonState(f64);

        impl IterState for NewtonState {
            type Value = f64;

            type Solution = f64;

            fn init_from_value(initial_point: Self::Value) -> Self {
                Self(initial_point)
            }

            fn into_sol(self) -> Self::Solution {
                self.0
            }
        }

        #[test]
        fn test() {
            let iter_fn = |state: NewtonState, problem: &Equation| {
                let x = state.0;
                let dx = problem.diff(x);
                let fx = problem.calc(x);

                let next_x = x - (fx / dx);

                NewtonState(next_x)
            };

            // Compare the residual's magnitude: a plain `< epsilon` would accept any
            // negative residual, however far from the root.
            let term_cond = |state: &NewtonState, problem: &Equation| {
                let epsilon = 1e-6;
                problem.calc(state.0).abs() < epsilon
            };

            let solver = Solver::new(iter_fn, term_cond);

            let solver1 = solver
                .clone()
                .with_term_cond(|state, equation| equation.calc(state.0).abs() < 1e-9);

            let prob1 = (Equation::Exp { a: 2., k: 3. }, 2.);

            let cloned_and_change_cond_sol = solver1.solve(prob1.1, &prob1.0.clone());

            let prob2 = (Equation::Square { a: 2., b: -5., c: 3. }, 6.);

            let prob1_sol = solver.solve(prob1.1, &prob1.0);
            let prob2_sol = solver.solve(prob2.1, &prob2.0);

            println!("the numerical solution of $2e^x - 3 = 0$ is: {}", prob1_sol);
            println!("with direct calc: {}", (1.5_f64).ln());
            println!("the numerical solution of $2x^2 - 5x + 3 = 0$ is: {}", prob2_sol);
            println!("with direct calc: {} or {}", ((5. + 1.) / 4.), (3. / 4.));

            println!("cloned sol: {}", cloned_and_change_cond_sol);

            assert!(prob1.0.calc(prob1_sol).abs() < 1e-6);
            assert!(prob2.0.calc(prob2_sol).abs() < 1e-6);
            assert!(prob1.0.calc(cloned_and_change_cond_sol).abs() < 1e-9);
            // Starting to the right of both roots, Newton converges to the larger one.
            assert!((prob2_sol - 1.5).abs() < 1e-3);
        }
    }

    #[test]
    fn test_with_error() {
        let check_fn = |float: &f64, _: &()| {
            if float.is_infinite() {
                return Err("Inf Error");
            } else if float.is_nan() {
                return Err("NaN Error");
            }
            Ok(())
        };

        let solver = Solver::new(
            |f, _| f * 2.0, // 2^n -> Inf
            |_, _| false,   // never stop
        );

        let result = solver.solve_with_error(1.0, &(), check_fn);

        // Repeated doubling overflows to +Inf long before anything could become NaN.
        assert_eq!(result, Err("Inf Error"));
    }

    mod guard {
        use std::time::Duration;

        use crate::Solver;

        #[test]
        fn test() {
            // define a never stop solver
            let loop_solver =
                Solver::new(|_state: f64, _: &()| _state, |_: &f64, _: &()| false);

            // A short budget keeps the test fast; the solver can never converge anyway.
            let try_solve =
                loop_solver.solve_with_timeout(0.0, &(), Duration::from_millis(10));
            assert!(try_solve.is_err());

            let try_solve = loop_solver.solve_with_max_iterations(0.0, &(), 10);
            assert!(try_solve.is_err());
        }
    }

    mod derive_test {
        use crate::IterState;

        #[derive(PartialEq, Eq, IterState, Debug)]
        struct State(Vec<u8>, Box<String>);

        #[test]
        fn test_derive() {
            let vec1 = vec![0, 12, 39];
            let boxed_str = Box::new("some str".to_string());
            let value = State(vec1.clone(), boxed_str.clone());
            let state = State::init_from_value(value);
            assert_eq!(vec1, state.0);
            let final_s = state.into_sol();
            assert_eq!(final_s.1, boxed_str);
        }
    }
}