1use crate::{
2 Db,
3 types::{BidderId, DateTime, DemandHistoryRow, DemandId, DemandRow, PortfolioId},
4};
5use fts_core::{
6 models::{
7 DateTimeRangeQuery, DateTimeRangeResponse, DemandCurve, DemandCurveDto, DemandRecord, Map,
8 ValueRecord,
9 },
10 ports::DemandRepository,
11};
12
13impl<DemandData: Send + Unpin + serde::Serialize + serde::de::DeserializeOwned>
14 DemandRepository<DemandData> for Db
15{
16 async fn get_demand_bidder_id(
17 &self,
18 demand_id: Self::DemandId,
19 ) -> Result<Option<Self::BidderId>, Self::Error> {
20 sqlx::query_scalar!(
21 r#"
22 select
23 bidder_id as "id!: BidderId"
24 from
25 demand
26 where
27 id = $1
28 "#,
29 demand_id
30 )
31 .fetch_optional(&self.reader)
32 .await
33 }
34
35 async fn query_demand(
36 &self,
37 bidder_ids: &[Self::BidderId],
38 as_of: Self::DateTime,
39 ) -> Result<Vec<Self::DemandId>, Self::Error> {
40 if bidder_ids.len() == 0 {
41 Ok(Vec::new())
42 } else {
43 let bidder_ids = sqlx::types::Json(bidder_ids);
44 sqlx::query_scalar!(
45 r#"
46 select
47 demand.id as "id!: DemandId"
48 from
49 demand
50 join
51 curve_data
52 on
53 demand.id = curve_data.demand_id
54 join
55 json_each($1) as bidder_ids
56 on
57 demand.bidder_id = bidder_ids.atom
58 where
59 curve_data.value is not null
60 and
61 valid_from <= $2
62 and
63 ($2 < valid_until or valid_until is null)
64 "#,
65 bidder_ids,
66 as_of,
67 )
68 .fetch_all(&self.reader)
69 .await
70 }
71 }
72
73 async fn create_demand(
74 &self,
75 demand_id: Self::DemandId,
76 bidder_id: Self::BidderId,
77 app_data: DemandData,
78 curve_data: Option<DemandCurve>,
79 as_of: Self::DateTime,
80 ) -> Result<(), Self::Error> {
81 let app_data = sqlx::types::Json(app_data);
82 let curve_data = curve_data.map(|x| sqlx::types::Json(x));
85 sqlx::query!(
86 r#"
87 insert into
88 demand (id, as_of, bidder_id, app_data, curve_data)
89 values
90 (?, ?, ?, jsonb(?), jsonb(?))
91 "#,
92 demand_id,
93 as_of,
94 bidder_id,
95 app_data,
96 curve_data,
97 )
98 .execute(&self.writer)
99 .await?;
100 Ok(())
101 }
102
103 async fn update_demand(
104 &self,
105 demand_id: Self::DemandId,
106 curve_data: Option<DemandCurve>,
107 as_of: Self::DateTime,
108 ) -> Result<bool, Self::Error> {
109 let curve_data = curve_data.map(|x| sqlx::types::Json(x));
110 let query = sqlx::query!(
111 r#"
112 update
113 demand
114 set
115 as_of = $2,
116 curve_data = jsonb($3)
117 where
118 id = $1
119 "#,
120 demand_id,
121 as_of,
122 curve_data,
123 )
124 .execute(&self.writer)
125 .await?;
126
127 Ok(query.rows_affected() > 0)
128 }
129
130 async fn get_demand(
131 &self,
132 demand_id: Self::DemandId,
133 as_of: Self::DateTime,
134 ) -> Result<
135 Option<
136 DemandRecord<
137 Self::DateTime,
138 Self::BidderId,
139 Self::DemandId,
140 Self::PortfolioId,
141 DemandData,
142 >,
143 >,
144 Self::Error,
145 > {
146 let query =
147 sqlx::query_file_as!(DemandRow, "queries/get_demand_by_id.sql", demand_id, as_of)
148 .fetch_optional(&self.reader)
149 .await?;
150
151 Ok(query.map(|row| DemandRecord {
152 id: demand_id,
153 as_of,
154 bidder_id: row.bidder_id,
155 app_data: row.app_data.0,
156 curve_data: row
157 .curve_data
158 .map(|data| unsafe { DemandCurve::new_unchecked(data.0) }),
160 portfolio_group: row.portfolio_group.map(|data| data.0).unwrap_or_default(),
161 }))
162 }
163
164 async fn get_demand_history(
165 &self,
166 demand_id: Self::DemandId,
167 query: DateTimeRangeQuery<Self::DateTime>,
168 limit: usize,
169 ) -> Result<
170 DateTimeRangeResponse<ValueRecord<Self::DateTime, DemandCurve>, Self::DateTime>,
171 Self::Error,
172 > {
173 let limit_p1 = (limit + 1) as i64;
174 let mut rows = sqlx::query_as!(
175 DemandHistoryRow,
176 r#"
177 select
178 valid_from as "valid_from!: DateTime",
179 valid_until as "valid_until?: DateTime",
180 json(value) as "curve_data!: sqlx::types::Json<DemandCurveDto>"
181 from
182 curve_data
183 where
184 demand_id = $1
185 and
186 ($2 is null or valid_from >= $2)
187 and
188 ($3 is null or valid_until is null or valid_until < $3)
189 and
190 value is not null
191 order by
192 valid_from desc
193 limit $4
194 "#,
195 demand_id,
196 query.after,
197 query.before,
198 limit_p1, )
200 .fetch_all(&self.reader)
201 .await?;
202
203 let more = if rows.len() == limit + 1 {
206 let extra = rows.pop().unwrap();
207 Some(DateTimeRangeQuery {
208 before: Some(extra.valid_from),
209 after: query.after,
210 })
211 } else {
212 None
213 };
214
215 Ok(DateTimeRangeResponse {
216 results: rows.into_iter().map(Into::into).collect(),
217 more,
218 })
219 }
220}