1use crate::Db;
2use crate::types::{BatchData, DemandId, OutcomeRow, PortfolioId, ProductId};
3use fts_core::models::{DateTimeRangeQuery, DateTimeRangeResponse, Map, OutcomeRecord};
4use fts_core::{
5 models::{DemandCurve, DemandCurveDto, DemandGroup, ProductGroup},
6 ports::{BatchRepository, Solver},
7};
8
9impl<T: Solver<DemandId, PortfolioId, ProductId>> BatchRepository<T> for Db
10where
11 T: Send,
12 T::Error: Send,
13 T::State: Send,
14 T::PortfolioOutcome: Unpin + Send + serde::Serialize + serde::de::DeserializeOwned,
15 T::ProductOutcome: Unpin + Send + serde::Serialize + serde::de::DeserializeOwned,
16{
17 async fn run_batch(
18 &self,
19 timestamp: Self::DateTime,
20 solver: T,
21 state: T::State,
22 ) -> Result<Result<(), T::Error>, Self::Error> {
23 let data = sqlx::query_file_as!(BatchData, "queries/gather_batch.sql", timestamp)
28 .fetch_optional(&self.reader)
29 .await?;
30
31 let (demands, portfolios) = if let Some(BatchData {
32 demands,
33 portfolios,
34 }) = data
35 {
36 let demands = demands
37 .map(|x| x.0)
38 .unwrap_or_default()
39 .into_iter()
40 .map(|(key, value)| (key, unsafe { DemandCurve::new_unchecked(value) }))
41 .collect();
42
43 let portfolios = portfolios.map(|x| x.0).unwrap_or_default();
44
45 (demands, portfolios)
46 } else {
47 Default::default()
48 };
49
50 let outcome = solver.solve(demands, portfolios, state).await;
51
52 match outcome {
53 Ok((portfolio_outcomes, product_outcomes)) => {
54 let portfolio_outcomes = sqlx::types::Json(portfolio_outcomes);
55 let product_outcomes = sqlx::types::Json(product_outcomes);
56 sqlx::query!(
57 r#"
58 update
59 batch
60 set
61 as_of = $1,
62 portfolio_outcomes = jsonb($2),
63 product_outcomes = jsonb($3)
64 "#,
65 timestamp,
66 portfolio_outcomes,
67 product_outcomes
68 )
69 .execute(&self.writer)
70 .await?;
71 Ok(Ok(()))
72 }
73 Err(error) => Ok(Err(error)),
74 }
75 }
76
77 async fn get_portfolio_outcomes(
83 &self,
84 portfolio_id: Self::PortfolioId,
85 query: DateTimeRangeQuery<Self::DateTime>,
86 limit: usize,
87 ) -> Result<
88 DateTimeRangeResponse<OutcomeRecord<Self::DateTime, T::PortfolioOutcome>, Self::DateTime>,
89 Self::Error,
90 > {
91 let limit_p1 = (limit + 1) as i64;
92 let mut rows = sqlx::query_as!(
93 OutcomeRow::<T::PortfolioOutcome>,
94 r#"
95 select
96 valid_from as "as_of!: crate::types::DateTime",
97 json(value) as "outcome!: sqlx::types::Json<T::PortfolioOutcome>"
98 from
99 portfolio_outcome
100 where
101 portfolio_id = $1
102 and
103 ($2 is null or valid_from >= $2)
104 and
105 ($3 is null or valid_until is null or valid_until < $3)
106 group by
107 valid_from
108 order by
109 valid_from desc
110 limit $4
111 "#,
112 portfolio_id,
113 query.after,
114 query.before,
115 limit_p1,
116 )
117 .fetch_all(&self.reader)
118 .await?;
119
120 let more = if rows.len() == limit + 1 {
121 let extra = rows.pop().unwrap();
122 Some(DateTimeRangeQuery {
123 before: Some(extra.as_of),
124 after: query.after,
125 })
126 } else {
127 None
128 };
129
130 Ok(DateTimeRangeResponse {
131 results: rows
132 .into_iter()
133 .map(|row| OutcomeRecord {
134 as_of: row.as_of,
135 outcome: row.outcome.0,
136 })
137 .collect(),
138 more,
139 })
140 }
141
142 async fn get_product_outcomes(
148 &self,
149 product_id: Self::ProductId,
150 query: DateTimeRangeQuery<Self::DateTime>,
151 limit: usize,
152 ) -> Result<
153 DateTimeRangeResponse<OutcomeRecord<Self::DateTime, T::ProductOutcome>, Self::DateTime>,
154 Self::Error,
155 > {
156 let limit_p1 = (limit + 1) as i64;
157 let mut rows = sqlx::query_as!(
158 OutcomeRow::<T::ProductOutcome>,
159 r#"
160 select
161 valid_from as "as_of!: crate::types::DateTime",
162 json(value) as "outcome!: sqlx::types::Json<T::ProductOutcome>"
163 from
164 product_outcome
165 where
166 product_id = $1
167 and
168 ($2 is null or valid_from >= $2)
169 and
170 ($3 is null or valid_until is null or valid_until < $3)
171 group by
172 valid_from
173 order by
174 valid_from desc
175 limit $4
176 "#,
177 product_id,
178 query.after,
179 query.before,
180 limit_p1,
181 )
182 .fetch_all(&self.reader)
183 .await?;
184
185 let more = if rows.len() == limit + 1 {
186 let extra = rows.pop().unwrap();
187 Some(DateTimeRangeQuery {
188 before: Some(extra.as_of),
189 after: query.after,
190 })
191 } else {
192 None
193 };
194
195 Ok(DateTimeRangeResponse {
196 results: rows
197 .into_iter()
198 .map(|row| OutcomeRecord {
199 as_of: row.as_of,
200 outcome: row.outcome.0,
201 })
202 .collect(),
203 more,
204 })
205 }
206}