// clarabel/solver/implementations/default/solver.rs
use super::*;
use crate::solver::core::callbacks::SolverCallbacks;
use crate::solver::traits::Settings;
use crate::{
    io::ConfigurablePrintTarget,
    solver::core::{
        cones::{CompositeCone, SupportedConeT},
        kktsolvers::HasLinearSolverInfo,
        traits::ProblemData,
        SettingsError, Solver,
    },
};
use thiserror::Error;

use crate::algebra::*;
use crate::timers::*;

18pub type DefaultSolver<T = f64> = Solver<
20 T,
21 DefaultProblemData<T>,
22 DefaultVariables<T>,
23 DefaultResiduals<T>,
24 DefaultKKTSystem<T>,
25 CompositeCone<T>,
26 DefaultInfo<T>,
27 DefaultSolution<T>,
28 DefaultSettings<T>,
29>;
30
31#[derive(Error, Debug)]
34pub enum SolverError {
36 #[error("Bad input data: {0}")]
38 BadInputData(&'static str),
39
40 #[error("Bad settings: {0}")]
42 SettingsError(#[from] SettingsError),
43
44 #[error("I/O error: {0}")]
46 IoError(#[from] std::io::Error),
47
48 #[error("JSON error: {0}")]
50 JsonError(#[from] serde_json::Error),
51}
52
53impl<T> DefaultSolver<T>
54where
55 T: FloatT,
56{
57 pub fn new(
58 P: &CscMatrix<T>,
59 q: &[T],
60 A: &CscMatrix<T>,
61 b: &[T],
62 cones: &[SupportedConeT<T>],
63 settings: DefaultSettings<T>,
64 ) -> Result<Self, SolverError> {
65 check_dimensions(P, q, A, b, cones)?;
67 settings.validate()?;
69
70 let mut timers = Timers::default();
71 let mut output;
72 let mut info = DefaultInfo::<T>::new();
73
74 timeit! {timers => "setup"; {
75
76 let solution = DefaultSolution::<T>::new(A.n, A.m);
78
79 let mut data;
82 timeit!{timers => "presolve"; {
83 data = DefaultProblemData::<T>::new(P,q,A,b,cones,&settings);
84 }}
85
86 let cones = CompositeCone::<T>::new(&data.cones);
87 assert_eq!(cones.numel, data.m);
88 let variables = DefaultVariables::<T>::new(data.n,data.m);
89 let residuals = DefaultResiduals::<T>::new(data.n,data.m);
90
91 timeit!{timers => "equilibration"; {
95 data.equilibrate(&cones,&settings);
96 }}
97
98 let kktsystem;
99 timeit!{timers => "kktinit"; {
100 kktsystem = DefaultKKTSystem::<T>::new(&data,&cones,&settings);
101 }}
102 info.linsolver = kktsystem.linear_solver_info();
103
104 let step_rhs = DefaultVariables::<T>::new(data.n,data.m);
106 let step_lhs = DefaultVariables::<T>::new(data.n,data.m);
107 let prev_vars = DefaultVariables::<T>::new(data.n,data.m);
108
109 output = Self{
112 data,variables,residuals,kktsystem,
113 step_lhs,step_rhs,prev_vars,info,
114 solution,cones,settings,
115 timers: None,
116 callbacks: SolverCallbacks::default(),
117 phantom: std::marker::PhantomData };
118
119 }} output.timers.replace(timers);
124
125 Ok(output)
126 }
127}
128
129fn check_dimensions<T: FloatT>(
130 P: &CscMatrix<T>,
131 q: &[T],
132 A: &CscMatrix<T>,
133 b: &[T],
134 cone_types: &[SupportedConeT<T>],
135) -> Result<(), SolverError> {
136 let m = b.len();
137 let n = q.len();
138 let p = cone_types.iter().fold(0, |acc, cone| acc + cone.nvars());
139
140 if m != A.nrows() {
141 return Err(SolverError::BadInputData("A and b incompatible dimensions"));
142 }
143 if p != m {
144 return Err(SolverError::BadInputData(
145 "Constraint dimensions inconsistent with size of cones",
146 ));
147 }
148 if n != A.ncols() {
149 return Err(SolverError::BadInputData("A and q incompatible dimensions"));
150 }
151 if n != P.ncols() {
152 return Err(SolverError::BadInputData("P and q incompatible dimensions"));
153 }
154 if !P.is_square() {
155 return Err(SolverError::BadInputData("P not square"));
156 }
157
158 Ok(())
159}
160
161impl<T> ConfigurablePrintTarget for DefaultSolver<T>
162where
163 T: FloatT,
164{
165 fn print_to_stdout(&mut self) {
166 self.info.print_to_stdout();
167 }
168 fn print_to_file(&mut self, file: std::fs::File) {
169 self.info.print_to_file(file)
170 }
171 fn print_to_stream(&mut self, stream: Box<dyn std::io::Write + Send + Sync>) {
172 self.info.print_to_stream(stream)
173 }
174 fn print_to_sink(&mut self) {
175 self.info.print_to_sink()
176 }
177 fn print_to_buffer(&mut self) {
178 self.info.print_to_buffer();
179 }
180 fn get_print_buffer(&mut self) -> std::io::Result<String> {
181 self.info.get_print_buffer()
182 }
183}