1use crate::{PortfolioOutcome, ProductOutcome, disaggregate};
2use clarabel::{algebra::*, solver::*};
3use fts_core::{
4 models::{DemandCurve, DemandGroup, Map, ProductGroup},
5 ports::Solver,
6};
7use std::{hash::Hash, marker::PhantomData};
8
9pub struct ClarabelSolver<DemandId, PortfolioId, ProductId>(
15 DefaultSettings<f64>,
16 PhantomData<(DemandId, PortfolioId, ProductId)>,
17);
18
19impl<A, B, C> ClarabelSolver<A, B, C> {
20 pub fn new(settings: DefaultSettings<f64>) -> Self {
22 Self(settings, PhantomData::default())
23 }
24}
25
26impl<A, B, C> Default for ClarabelSolver<A, B, C> {
27 fn default() -> Self {
28 let mut settings = DefaultSettings::default();
29 settings.verbose = false;
30 Self(settings, PhantomData::default())
31 }
32}
33
34impl<
35 DemandId: Clone + Eq + Hash,
36 PortfolioId: Clone + Eq + Hash,
37 ProductId: Clone + Eq + Hash + Ord,
38> ClarabelSolver<DemandId, PortfolioId, ProductId>
39{
40 fn solve(
41 settings: DefaultSettings<f64>,
42 demand_curves: Map<DemandId, DemandCurve>,
43 portfolios: Map<PortfolioId, (DemandGroup<DemandId>, ProductGroup<ProductId>)>,
44 ) -> Result<
45 (
46 Map<PortfolioId, PortfolioOutcome>,
47 Map<ProductId, ProductOutcome>,
48 ),
49 SolverStatus,
50 > {
51 let (demand_curves, portfolios, mut portfolio_outcomes, mut product_outcomes) =
53 super::prepare(demand_curves, portfolios);
54
55 if portfolio_outcomes.len() == 0 || product_outcomes.len() == 0 {
57 return Ok((portfolio_outcomes, product_outcomes));
58 }
59
60 let nproducts = product_outcomes.len();
64 let ndemands = demand_curves.len();
65
66 let mut p = Vec::new();
68 let mut q = Vec::new();
70
71 let mut b = vec![0.0; nproducts + ndemands];
74 let mut s = vec![ZeroConeT(b.len()), NonnegativeConeT(0)];
75
76 let mut a_nzval = Vec::new();
79 let mut a_rowval = Vec::new();
80 let mut a_colptr = Vec::new();
81
82 for (demand_group, product_group) in portfolios.values() {
84 if product_group.len() == 0 || demand_group.len() == 0 {
87 continue;
88 }
89
90 p.push(0.0);
92 q.push(0.0);
93
94 a_colptr.push(a_nzval.len());
96
97 for (product_id, &weight) in product_group.iter() {
99 let idx = product_outcomes.get_index_of(product_id).unwrap();
101 a_nzval.push(weight);
102 a_rowval.push(idx);
103 }
104
105 for (demand_id, &weight) in demand_group.iter() {
107 let idx = demand_curves.get_index_of(demand_id).unwrap();
109 a_nzval.push(weight);
110 a_rowval.push(nproducts + idx);
111 }
112 }
113
114 for (offset, (_, demand_curve)) in demand_curves.into_iter().enumerate() {
116 let row = nproducts + offset;
117 let (min, max) = demand_curve.domain();
118 let points = demand_curve.points();
119
120 if let Some(segments) = disaggregate(points.into_iter(), min, max) {
121 for segment in segments {
122 let segment = segment.unwrap();
124 let (m, pzero) = segment.slope_intercept();
125
126 p.push(-m);
128 q.push(-pzero);
129
130 a_colptr.push(a_nzval.len());
132
133 a_nzval.push(-1.0);
135 a_rowval.push(row);
136
137 if segment.q0.is_finite() {
140 a_nzval.push(-1.0);
141 a_rowval.push(b.len());
142 b.push(-segment.q0);
143 }
144 if segment.q1.is_finite() {
145 a_nzval.push(1.0);
146 a_rowval.push(b.len());
147 b.push(segment.q1);
148 }
149 }
150 }
151 }
152
153 a_colptr.push(a_nzval.len());
155
156 let m = b.len();
157 let n = p.len();
158
159 let a_matrix = CscMatrix {
160 m,
161 n,
162 colptr: a_colptr,
163 rowval: a_rowval,
164 nzval: a_nzval,
165 };
166
167 assert!(a_matrix.check_format().is_ok()); s[1] = NonnegativeConeT(b.len() - nproducts - ndemands);
171
172 let p_matrix = {
174 CscMatrix {
175 m: n,
176 n,
177 colptr: (0..=n).collect(),
178 rowval: (0..n).collect(),
179 nzval: p,
180 }
181 };
182
183 let mut solver = DefaultSolver::new(&p_matrix, &q, &a_matrix, &b, &s, settings)
185 .expect("valid solver config");
186 solver.solve();
187 match solver.solution.status {
188 SolverStatus::Solved => {}
189 SolverStatus::AlmostSolved => {
190 tracing::warn!(status = ?solver.solution.status, "convergence issues");
191 }
192 status => {
193 return Err(status);
194 }
195 };
196
197 super::finalize(
199 solver.solution.x.iter(),
200 solver.solution.z.iter(),
201 &portfolios,
202 &mut portfolio_outcomes,
203 &mut product_outcomes,
204 );
205
206 Ok((portfolio_outcomes, product_outcomes))
221 }
222}
223
224impl<
225 DemandId: Clone + Eq + Hash + Ord + Send + Sync + 'static,
226 PortfolioId: Clone + Eq + Hash + Ord + Send + Sync + 'static,
227 ProductId: Clone + Eq + Hash + Ord + Send + Sync + 'static,
228> Solver<DemandId, PortfolioId, ProductId> for ClarabelSolver<DemandId, PortfolioId, ProductId>
229{
230 type Error = tokio::task::JoinError;
231 type PortfolioOutcome = PortfolioOutcome;
232 type ProductOutcome = ProductOutcome;
233
234 type State = ();
235
236 async fn solve(
237 &self,
238 demand_curves: Map<DemandId, DemandCurve>,
239 portfolios: Map<PortfolioId, (DemandGroup<DemandId>, ProductGroup<ProductId>)>,
240 _state: Self::State,
241 ) -> Result<
242 (
243 Map<PortfolioId, Self::PortfolioOutcome>,
244 Map<ProductId, Self::ProductOutcome>,
245 ),
246 Self::Error,
247 > {
248 let settings = self.0.clone();
249 let solution =
250 tokio::spawn(async move { Self::solve(settings, demand_curves, portfolios) }).await?;
251
252 Ok(solution.expect("failed to solve"))
255 }
256}