fts_solver/impls/
clarabel.rs

1use crate::{PortfolioOutcome, ProductOutcome, disaggregate};
2use clarabel::{algebra::*, solver::*};
3use fts_core::{
4    models::{DemandCurve, DemandGroup, Map, ProductGroup},
5    ports::Solver,
6};
7use std::{hash::Hash, marker::PhantomData};
8
9/// A solver implementation that uses the Clarabel interior point method
10/// for quadratic programming to solve the market clearing problem.
11///
12/// This solver is generally more accurate but can be slower than ADMM-based
13/// solvers for large problems. It's a good choice when high precision is needed.
14pub struct ClarabelSolver<DemandId, PortfolioId, ProductId>(
15    DefaultSettings<f64>,
16    PhantomData<(DemandId, PortfolioId, ProductId)>,
17);
18
19impl<A, B, C> ClarabelSolver<A, B, C> {
20    /// create a new solver with the given settings
21    pub fn new(settings: DefaultSettings<f64>) -> Self {
22        Self(settings, PhantomData::default())
23    }
24}
25
26impl<A, B, C> Default for ClarabelSolver<A, B, C> {
27    fn default() -> Self {
28        let mut settings = DefaultSettings::default();
29        settings.verbose = false;
30        Self(settings, PhantomData::default())
31    }
32}
33
34impl<
35    DemandId: Clone + Eq + Hash,
36    PortfolioId: Clone + Eq + Hash,
37    ProductId: Clone + Eq + Hash + Ord,
38> ClarabelSolver<DemandId, PortfolioId, ProductId>
39{
40    fn solve(
41        settings: DefaultSettings<f64>,
42        demand_curves: Map<DemandId, DemandCurve>,
43        portfolios: Map<PortfolioId, (DemandGroup<DemandId>, ProductGroup<ProductId>)>,
44    ) -> Result<
45        (
46            Map<PortfolioId, PortfolioOutcome>,
47            Map<ProductId, ProductOutcome>,
48        ),
49        SolverStatus,
50    > {
51        // This prepare method canonicalizes the input in a manner appropriate for naive CSC construction
52        let (demand_curves, portfolios, mut portfolio_outcomes, mut product_outcomes) =
53            super::prepare(demand_curves, portfolios);
54
55        // If there are no portfolios or products, there is nothing to do.
56        if portfolio_outcomes.len() == 0 || product_outcomes.len() == 0 {
57            return Ok((portfolio_outcomes, product_outcomes));
58        }
59
60        // The trade and bid constraints are all (something) = 0, we need to
61        // know how many of these there are in order to handle the box
62        // constraints for each decision variable
63        let nproducts = product_outcomes.len();
64        let ndemands = demand_curves.len();
65
66        // Our quadratic term is diagonal, so we build the matrix by defining its diagonal
67        let mut p = Vec::new();
68        // and these are the linear terms
69        let mut q = Vec::new();
70
71        // Clarabel handles constraints via a cone specification, e.g. Ax + s = b, where s is a cone.
72        // The first `nzero` of b and s are just =0, so we do that work upfront.
73        let mut b = vec![0.0; nproducts + ndemands];
74        let mut s = vec![ZeroConeT(b.len()), NonnegativeConeT(0)];
75
76        // Clarabel's matrix input is in the form of CSC, so we handle the memory representation
77        // carefully.
78        let mut a_nzval = Vec::new();
79        let mut a_rowval = Vec::new();
80        let mut a_colptr = Vec::new();
81
82        // We begin by setting up the portfolio variables.
83        for (demand_group, product_group) in portfolios.values() {
84            // We can skip any portfolio variable that does not have associated products or demands
85            // (This is because our outcomes are preloaded with zero solutions)
86            if product_group.len() == 0 || demand_group.len() == 0 {
87                continue;
88            }
89
90            // portfolio variables contribute nothing to the objective
91            p.push(0.0);
92            q.push(0.0);
93
94            // start a new column in the constraint matrix
95            a_colptr.push(a_nzval.len());
96
97            // We copy the product weights into the matrix
98            for (product_id, &weight) in product_group.iter() {
99                // SAFETY: this unwrap() is guaranteed by the logic in prepare()
100                let idx = product_outcomes.get_index_of(product_id).unwrap();
101                a_nzval.push(weight);
102                a_rowval.push(idx);
103            }
104
105            // We copy the demand weights into the matrix as well
106            for (demand_id, &weight) in demand_group.iter() {
107                // SAFETY: this unwrap() is guaranteed by the logic in prepare()
108                let idx = demand_curves.get_index_of(demand_id).unwrap();
109                a_nzval.push(weight);
110                a_rowval.push(nproducts + idx);
111            }
112        }
113
114        // Now we setup the segment variables
115        for (offset, (_, demand_curve)) in demand_curves.into_iter().enumerate() {
116            let row = nproducts + offset;
117            let (min, max) = demand_curve.domain();
118            let points = demand_curve.points();
119
120            if let Some(segments) = disaggregate(points.into_iter(), min, max) {
121                for segment in segments {
122                    // TODO: propagate the error upwards
123                    let segment = segment.unwrap();
124                    let (m, pzero) = segment.slope_intercept();
125
126                    // Setup the contributions to the objective
127                    p.push(-m);
128                    q.push(-pzero);
129
130                    // Insert a new column
131                    a_colptr.push(a_nzval.len());
132
133                    // Ensure it counts towards the group
134                    a_nzval.push(-1.0);
135                    a_rowval.push(row);
136
137                    // Setup the box constraints
138                    // x0 <= y <= x1 ==> -y + s == -x0 and y + s == x1
139                    if segment.q0.is_finite() {
140                        a_nzval.push(-1.0);
141                        a_rowval.push(b.len());
142                        b.push(-segment.q0);
143                    }
144                    if segment.q1.is_finite() {
145                        a_nzval.push(1.0);
146                        a_rowval.push(b.len());
147                        b.push(segment.q1);
148                    }
149                }
150            }
151        }
152
153        // We need to polish off the CSC matrix
154        a_colptr.push(a_nzval.len());
155
156        let m = b.len();
157        let n = p.len();
158
159        let a_matrix = CscMatrix {
160            m,
161            n,
162            colptr: a_colptr,
163            rowval: a_rowval,
164            nzval: a_nzval,
165        };
166
167        assert!(a_matrix.check_format().is_ok()); // TODO: maybe remove this
168
169        // We also need to cleanup the cone specification
170        s[1] = NonnegativeConeT(b.len() - nproducts - ndemands);
171
172        // Finally, we need to convert our p spec into a csc matrix
173        let p_matrix = {
174            CscMatrix {
175                m: n,
176                n,
177                colptr: (0..=n).collect(),
178                rowval: (0..n).collect(),
179                nzval: p,
180            }
181        };
182
183        // Now we can solve!
184        let mut solver = DefaultSolver::new(&p_matrix, &q, &a_matrix, &b, &s, settings)
185            .expect("valid solver config");
186        solver.solve();
187        match solver.solution.status {
188            SolverStatus::Solved => {}
189            SolverStatus::AlmostSolved => {
190                tracing::warn!(status = ?solver.solution.status, "convergence issues");
191            }
192            status => {
193                return Err(status);
194            }
195        };
196
197        // Now we copy the solution back
198        super::finalize(
199            solver.solution.x.iter(),
200            solver.solution.z.iter(),
201            &portfolios,
202            &mut portfolio_outcomes,
203            &mut product_outcomes,
204        );
205
206        // TODO:
207        // We have assigned the products prices straight from the solver
208        // (and computed the portfolio prices from those).
209        // Under pathological circumstances, the price may not be unique
210        // (either when there is no trade, or the supply exactly matches the demand).
211        // We should think about injecting an auxiliary solve for choosing a canonical
212        // price and/or for detecting when there is such a degeneracy.
213
214        // TODO:
215        // When there are "flat" demand curves, it is possible for nonuniqueness
216        // in the traded outcomes. The convex regularization is to minimize the L2 norm
217        // of the trades as a tie-break. We should think about the best way to regularize
218        // the solve accordingly.
219
220        Ok((portfolio_outcomes, product_outcomes))
221    }
222}
223
224impl<
225    DemandId: Clone + Eq + Hash + Ord + Send + Sync + 'static,
226    PortfolioId: Clone + Eq + Hash + Ord + Send + Sync + 'static,
227    ProductId: Clone + Eq + Hash + Ord + Send + Sync + 'static,
228> Solver<DemandId, PortfolioId, ProductId> for ClarabelSolver<DemandId, PortfolioId, ProductId>
229{
230    type Error = tokio::task::JoinError;
231    type PortfolioOutcome = PortfolioOutcome;
232    type ProductOutcome = ProductOutcome;
233
234    type State = ();
235
236    async fn solve(
237        &self,
238        demand_curves: Map<DemandId, DemandCurve>,
239        portfolios: Map<PortfolioId, (DemandGroup<DemandId>, ProductGroup<ProductId>)>,
240        _state: Self::State,
241    ) -> Result<
242        (
243            Map<PortfolioId, Self::PortfolioOutcome>,
244            Map<ProductId, Self::ProductOutcome>,
245        ),
246        Self::Error,
247    > {
248        let settings = self.0.clone();
249        let solution =
250            tokio::spawn(async move { Self::solve(settings, demand_curves, portfolios) }).await?;
251
252        // TODO: The JoinError happens when we panic inside.
253        // We can change this later, for now we just assume the solver worked.
254        Ok(solution.expect("failed to solve"))
255    }
256}