fts_sqlite/impl/batch.rs

use crate::Db;
use crate::types::{BatchData, DemandId, OutcomeRow, PortfolioId, ProductId};
use fts_core::models::{DateTimeRangeQuery, DateTimeRangeResponse, Map, OutcomeRecord};
use fts_core::{
    models::{DemandCurve, DemandCurveDto, DemandGroup, ProductGroup},
    ports::{BatchRepository, Solver},
};

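// Outcomes are persisted to SQLite as JSON (see the `jsonb(...)` calls and the
// `sqlx::types::Json` wrappers below), hence the `Serialize`/`DeserializeOwned`
// bounds on the solver's outcome types.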
impl<T: Solver<DemandId, PortfolioId, ProductId>> BatchRepository<T> for Db
where
    T: Send,
    T::Error: Send,
    T::State: Send,
    T::PortfolioOutcome: Unpin + Send + serde::Serialize + serde::de::DeserializeOwned,
    T::ProductOutcome: Unpin + Send + serde::Serialize + serde::de::DeserializeOwned,
{
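    /// Run a single batch as of `timestamp`.
    ///
    /// Gathers the batch inputs (demand curves and portfolios) for `timestamp`, hands
    /// them to `solver` together with `state`, and on success persists the resulting
    /// portfolio and product outcomes. A solver failure is returned in the inner
    /// `Result`; a storage failure in the outer one.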
    async fn run_batch(
        &self,
        timestamp: Self::DateTime,
        solver: T,
        state: T::State,
    ) -> Result<Result<(), T::Error>, Self::Error> {
        // TODO: we may wish to filter the portfolios we include for administrative reasons.
        // One appealing option is to make this the solver state's responsibility: the state
        // could carry a HashSet of "suspended" portfolio ids, and the solver would be
        // responsible for skipping them.
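        //
        // A rough sketch of that approach (illustrative only; `MyState` and its
        // `suspended` field are not part of the current `Solver` contract):
        //
        //     struct MyState {
        //         suspended: std::collections::HashSet<PortfolioId>,
        //         // ...whatever else the solver needs
        //     }
        //
        //     // inside the solver, before building the problem:
        //     portfolios.retain(|id, _| !state.suspended.contains(id));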
        let data = sqlx::query_file_as!(BatchData, "queries/gather_batch.sql", timestamp)
            .fetch_optional(&self.reader)
            .await?;

        let (demands, portfolios) = if let Some(BatchData {
            demands,
            portfolios,
        }) = data
        {
            let demands = demands
                .map(|x| x.0)
                .unwrap_or_default()
                .into_iter()
                .map(|(key, value)| (key, unsafe { DemandCurve::new_unchecked(value) }))
                .collect();

            let portfolios = portfolios.map(|x| x.0).unwrap_or_default();

            (demands, portfolios)
        } else {
            Default::default()
        };

        let outcome = solver.solve(demands, portfolios, state).await;

        match outcome {
            Ok((portfolio_outcomes, product_outcomes)) => {
                let portfolio_outcomes = sqlx::types::Json(portfolio_outcomes);
                let product_outcomes = sqlx::types::Json(product_outcomes);
                sqlx::query!(
                    r#"
                    update
                        batch
                    set
                        as_of = $1,
                        portfolio_outcomes = jsonb($2),
                        product_outcomes = jsonb($3)
                    "#,
                    timestamp,
                    portfolio_outcomes,
                    product_outcomes
                )
                .execute(&self.writer)
                .await?;
                Ok(Ok(()))
            }
            Err(error) => Ok(Err(error)),
        }
    }

    /// Get the portfolio's outcomes
    ///
    /// This returns a list of outcomes, each corresponding to a specific point in time.
    /// The records are ordered by `valid_from` in descending order
    /// and are grouped by `valid_from`.
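    ///
    /// A minimal usage sketch (`db`, `MySolver`, and `portfolio_id` are illustrative
    /// assumptions, not fixtures provided by this crate):
    ///
    /// ```ignore
    /// let page = <Db as BatchRepository<MySolver>>::get_portfolio_outcomes(
    ///     &db,
    ///     portfolio_id,
    ///     DateTimeRangeQuery { before: None, after: None },
    ///     10,
    /// )
    /// .await?;
    /// for record in &page.results {
    ///     // `as_of` is the `valid_from` timestamp of the stored outcome
    ///     println!("{:?}: {:?}", record.as_of, record.outcome);
    /// }
    /// ```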
    async fn get_portfolio_outcomes(
        &self,
        portfolio_id: Self::PortfolioId,
        query: DateTimeRangeQuery<Self::DateTime>,
        limit: usize,
    ) -> Result<
        DateTimeRangeResponse<OutcomeRecord<Self::DateTime, T::PortfolioOutcome>, Self::DateTime>,
        Self::Error,
    > {
        let limit_p1 = (limit + 1) as i64;
        let mut rows = sqlx::query_as!(
            OutcomeRow::<T::PortfolioOutcome>,
            r#"
                select
                    valid_from as "as_of!: crate::types::DateTime",
                    json(value) as "outcome!: sqlx::types::Json<T::PortfolioOutcome>"
                from
                    portfolio_outcome
                where
                    portfolio_id = $1
                and
                    ($2 is null or valid_from >= $2)
                and
                    ($3 is null or valid_until is null or valid_until < $3)
                group by
                    valid_from
                order by
                    valid_from desc
                limit $4
            "#,
            portfolio_id,
            query.after,
            query.before,
            limit_p1,
        )
        .fetch_all(&self.reader)
        .await?;

        let more = if rows.len() == limit + 1 {
            let extra = rows.pop().unwrap();
            Some(DateTimeRangeQuery {
                before: Some(extra.as_of),
                after: query.after,
            })
        } else {
            None
        };

        Ok(DateTimeRangeResponse {
            results: rows
                .into_iter()
                .map(|row| OutcomeRecord {
                    as_of: row.as_of,
                    outcome: row.outcome.0,
                })
                .collect(),
            more,
        })
    }

    /// Get the product's outcomes
    ///
    /// This returns a list of outcomes, each corresponding to a specific point in time.
    /// The records are ordered by `valid_from` in descending order
    /// and are grouped by `valid_from`.
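    ///
    /// A sketch of paging through every record via the `more` cursor (`db`, `MySolver`,
    /// `product_id`, and `handle` are illustrative assumptions):
    ///
    /// ```ignore
    /// let mut query = DateTimeRangeQuery { before: None, after: None };
    /// loop {
    ///     let page = <Db as BatchRepository<MySolver>>::get_product_outcomes(
    ///         &db,
    ///         product_id.clone(),
    ///         query,
    ///         100,
    ///     )
    ///     .await?;
    ///     for record in page.results {
    ///         handle(record.as_of, record.outcome); // placeholder consumer
    ///     }
    ///     match page.more {
    ///         Some(next) => query = next, // fetch the next page with the returned cursor
    ///         None => break,
    ///     }
    /// }
    /// ```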
    async fn get_product_outcomes(
        &self,
        product_id: Self::ProductId,
        query: DateTimeRangeQuery<Self::DateTime>,
        limit: usize,
    ) -> Result<
        DateTimeRangeResponse<OutcomeRecord<Self::DateTime, T::ProductOutcome>, Self::DateTime>,
        Self::Error,
    > {
        let limit_p1 = (limit + 1) as i64;
        let mut rows = sqlx::query_as!(
            OutcomeRow::<T::ProductOutcome>,
            r#"
                select
                    valid_from as "as_of!: crate::types::DateTime",
                    json(value) as "outcome!: sqlx::types::Json<T::ProductOutcome>"
                from
                    product_outcome
                where
                    product_id = $1
                and
                    ($2 is null or valid_from >= $2)
                and
                    ($3 is null or valid_until is null or valid_until < $3)
                group by
                    valid_from
                order by
                    valid_from desc
                limit $4
            "#,
            product_id,
            query.after,
            query.before,
            limit_p1,
        )
        .fetch_all(&self.reader)
        .await?;

        let more = if rows.len() == limit + 1 {
            let extra = rows.pop().unwrap();
            Some(DateTimeRangeQuery {
                before: Some(extra.as_of),
                after: query.after,
            })
        } else {
            None
        };

        Ok(DateTimeRangeResponse {
            results: rows
                .into_iter()
                .map(|row| OutcomeRecord {
                    as_of: row.as_of,
                    outcome: row.outcome.0,
                })
                .collect(),
            more,
        })
    }
}