fts_sqlite/impl/demand.rs

1use crate::{
2    Db,
3    types::{BidderId, DateTime, DemandHistoryRow, DemandId, DemandRow, PortfolioId},
4};
5use fts_core::{
6    models::{
7        DateTimeRangeQuery, DateTimeRangeResponse, DemandCurve, DemandCurveDto, DemandRecord, Map,
8        ValueRecord,
9    },
10    ports::DemandRepository,
11};
12
13impl<DemandData: Send + Unpin + serde::Serialize + serde::de::DeserializeOwned>
14    DemandRepository<DemandData> for Db
15{
16    async fn get_demand_bidder_id(
17        &self,
18        demand_id: Self::DemandId,
19    ) -> Result<Option<Self::BidderId>, Self::Error> {
20        sqlx::query_scalar!(
21            r#"
22            select
23                bidder_id as "id!: BidderId"
24            from
25                demand
26            where
27                id = $1
28            "#,
29            demand_id
30        )
31        .fetch_optional(&self.reader)
32        .await
33    }
34
35    async fn query_demand(
36        &self,
37        bidder_ids: &[Self::BidderId],
38        as_of: Self::DateTime,
39    ) -> Result<Vec<Self::DemandId>, Self::Error> {
40        if bidder_ids.len() == 0 {
41            Ok(Vec::new())
42        } else {
43            let bidder_ids = sqlx::types::Json(bidder_ids);
44            sqlx::query_scalar!(
45                r#"
46                select
47                    demand.id as "id!: DemandId"
48                from
49                    demand
50                join
51                    curve_data
52                on
53                    demand.id = curve_data.demand_id
54                join
55                    json_each($1) as bidder_ids
56                on
57                    demand.bidder_id = bidder_ids.atom
58                where
59                    curve_data.value is not null
60                and
61                    valid_from <= $2
62                and
63                    ($2 < valid_until or valid_until is null) 
64                "#,
65                bidder_ids,
66                as_of,
67            )
68            .fetch_all(&self.reader)
69            .await
70        }
71    }
72
73    async fn create_demand(
74        &self,
75        demand_id: Self::DemandId,
76        bidder_id: Self::BidderId,
77        app_data: DemandData,
78        curve_data: Option<DemandCurve>,
79        as_of: Self::DateTime,
80    ) -> Result<(), Self::Error> {
81        let app_data = sqlx::types::Json(app_data);
82        // Important: If curve_data is None, we insert NULL into the database
83        // Else, this propagates into a [0] value in the JSONB column
84        let curve_data = curve_data.map(|x| sqlx::types::Json(x));
85        sqlx::query!(
86            r#"
87            insert into
88                demand (id, as_of, bidder_id, app_data, curve_data)
89            values
90                (?, ?, ?, jsonb(?), jsonb(?))
91            "#,
92            demand_id,
93            as_of,
94            bidder_id,
95            app_data,
96            curve_data,
97        )
98        .execute(&self.writer)
99        .await?;
100        Ok(())
101    }
102
103    async fn update_demand(
104        &self,
105        demand_id: Self::DemandId,
106        curve_data: Option<DemandCurve>,
107        as_of: Self::DateTime,
108    ) -> Result<bool, Self::Error> {
109        let curve_data = curve_data.map(|x| sqlx::types::Json(x));
110        let query = sqlx::query!(
111            r#"
112            update
113                demand
114            set
115                as_of = $2,
116                curve_data = jsonb($3)
117            where
118                id = $1
119            "#,
120            demand_id,
121            as_of,
122            curve_data,
123        )
124        .execute(&self.writer)
125        .await?;
126
127        Ok(query.rows_affected() > 0)
128    }
129
130    async fn get_demand(
131        &self,
132        demand_id: Self::DemandId,
133        as_of: Self::DateTime,
134    ) -> Result<
135        Option<
136            DemandRecord<
137                Self::DateTime,
138                Self::BidderId,
139                Self::DemandId,
140                Self::PortfolioId,
141                DemandData,
142            >,
143        >,
144        Self::Error,
145    > {
146        let query =
147            sqlx::query_file_as!(DemandRow, "queries/get_demand_by_id.sql", demand_id, as_of)
148                .fetch_optional(&self.reader)
149                .await?;
150
151        Ok(query.map(|row| DemandRecord {
152            id: demand_id,
153            as_of,
154            bidder_id: row.bidder_id,
155            app_data: row.app_data.0,
156            curve_data: row
157                .curve_data
158                // SAFETY: `curve_data` was necessarily serialized from a valid curve, so we can safely skip the validation
159                .map(|data| unsafe { DemandCurve::new_unchecked(data.0) }),
160            portfolio_group: row.portfolio_group.map(|data| data.0).unwrap_or_default(),
161        }))
162    }
163
164    async fn get_demand_history(
165        &self,
166        demand_id: Self::DemandId,
167        query: DateTimeRangeQuery<Self::DateTime>,
168        limit: usize,
169    ) -> Result<
170        DateTimeRangeResponse<ValueRecord<Self::DateTime, DemandCurve>, Self::DateTime>,
171        Self::Error,
172    > {
173        let limit_p1 = (limit + 1) as i64;
174        let mut rows = sqlx::query_as!(
175            DemandHistoryRow,
176            r#"
177                select
178                    valid_from as "valid_from!: DateTime",
179                    valid_until as "valid_until?: DateTime",
180                    json(value) as "curve_data!: sqlx::types::Json<DemandCurveDto>"
181                from
182                    curve_data
183                where
184                    demand_id = $1
185                and
186                    ($2 is null or valid_from >= $2)
187                and
188                    ($3 is null or valid_until is null or valid_until < $3)
189                and
190                    value is not null
191                order by
192                    valid_from desc
193                limit $4
194            "#,
195            demand_id,
196            query.after,
197            query.before,
198            limit_p1, // +1 to check if there are more results
199        )
200        .fetch_all(&self.reader)
201        .await?;
202
203        // We paginate by adding 1 to the limit, popping the result of, and
204        // using it to adjust the query object
205        let more = if rows.len() == limit + 1 {
206            let extra = rows.pop().unwrap();
207            Some(DateTimeRangeQuery {
208                before: Some(extra.valid_from),
209                after: query.after,
210            })
211        } else {
212            None
213        };
214
215        Ok(DateTimeRangeResponse {
216            results: rows.into_iter().map(Into::into).collect(),
217            more,
218        })
219    }
220}