use crate::Db;
use crate::types::{BatchData, DemandId, OutcomeRow, PortfolioId, ProductId};
use fts_core::models::{DateTimeRangeQuery, DateTimeRangeResponse, Map, OutcomeRecord};
use fts_core::{
models::{DemandCurve, DemandCurveDto, DemandGroup, ProductGroup},
ports::{BatchRepository, Solver},
};
// SQLite-backed implementation of `BatchRepository` for `Db`.
//
// All reads go through the `self.reader` pool and the single batch write goes
// through `self.writer` — presumably a read-replica/writer split; confirm
// against `Db`'s definition.
impl<T: Solver<DemandId, PortfolioId, ProductId>> BatchRepository<T> for Db
where
T: Send,
T::Error: Send,
T::State: Send,
// Outcome types are round-tripped through SQLite as JSON, hence the serde
// bounds; `Unpin` is required by sqlx's row-decoding machinery.
T::PortfolioOutcome: Unpin + Send + serde::Serialize + serde::de::DeserializeOwned,
T::ProductOutcome: Unpin + Send + serde::Serialize + serde::de::DeserializeOwned,
{
// Gather the demand curves and portfolios active at `timestamp`, hand them
// to `solver`, and persist the resulting outcomes into the `batch` table.
//
// Errors are two-layered: the outer `Err` is a database failure, while
// `Ok(Err(_))` is a solver failure — callers can distinguish the two.
async fn run_batch(
&self,
timestamp: Self::DateTime,
solver: T,
state: T::State,
) -> Result<Result<(), T::Error>, Self::Error> {
// `queries/gather_batch.sql` returns at most one row bundling the active
// demands and portfolios as JSON columns; `None` means nothing to solve.
let data = sqlx::query_file_as!(BatchData, "queries/gather_batch.sql", timestamp)
.fetch_optional(&self.reader)
.await?;
let (demands, portfolios) = if let Some(BatchData {
demands,
portfolios,
}) = data
{
// Either JSON column may be NULL; treat that as an empty group.
let demands = demands
.map(|x| x.0)
.unwrap_or_default()
.into_iter()
// SAFETY: curve values read back from the database are assumed to
// have been validated on the write path (e.g. via `DemandCurveDto`)
// before persistence, so they satisfy `DemandCurve`'s invariants.
// NOTE(review): that invariant is not visible from this file —
// confirm every write path validates before storing.
.map(|(key, value)| (key, unsafe { DemandCurve::new_unchecked(value) }))
.collect();
let portfolios = portfolios.map(|x| x.0).unwrap_or_default();
(demands, portfolios)
} else {
// No batch data at this timestamp: solve over empty inputs.
Default::default()
};
let outcome = solver.solve(demands, portfolios, state).await;
match outcome {
Ok((portfolio_outcomes, product_outcomes)) => {
// Wrap in `sqlx::types::Json` so sqlx binds the outcomes as JSON
// text, which the SQL converts to SQLite's binary JSONB form.
let portfolio_outcomes = sqlx::types::Json(portfolio_outcomes);
let product_outcomes = sqlx::types::Json(product_outcomes);
// NOTE(review): no WHERE clause — this updates every row of
// `batch`, which assumes `batch` is a single-row bookkeeping
// table; confirm against the schema.
sqlx::query!(
r#"
update
batch
set
as_of = $1,
portfolio_outcomes = jsonb($2),
product_outcomes = jsonb($3)
"#,
timestamp,
portfolio_outcomes,
product_outcomes
)
.execute(&self.writer)
.await?;
Ok(Ok(()))
}
// Solver failure: surface it in the inner Result, not as a DB error.
Err(error) => Ok(Err(error)),
}
}
// Paginated history of outcomes for one portfolio, newest `valid_from` first.
//
// Fetches `limit + 1` rows: the extra row (if any) is popped off and only
// signals that more data exists; its `as_of` seeds the `before` cursor of
// the follow-up query in the returned `DateTimeRangeQuery`.
async fn get_portfolio_outcomes(
&self,
portfolio_id: Self::PortfolioId,
query: DateTimeRangeQuery<Self::DateTime>,
limit: usize,
) -> Result<
DateTimeRangeResponse<OutcomeRecord<Self::DateTime, T::PortfolioOutcome>, Self::DateTime>,
Self::Error,
> {
// One extra row beyond `limit` to detect whether another page exists.
let limit_p1 = (limit + 1) as i64;
// NOTE(review): the `before` cursor ($3) is compared against
// `valid_until`, but the cursor emitted below is a `valid_from`
// (`extra.as_of`), and rows with NULL `valid_until` always pass the
// filter — pages could repeat or skip rows. Was `valid_from < $3`
// intended? Confirm the temporal-table semantics.
// NOTE(review): `group by valid_from` with the bare `value` column
// relies on SQLite picking an arbitrary row per group — confirm at most
// one row per (portfolio_id, valid_from) or that any row is acceptable.
let mut rows = sqlx::query_as!(
OutcomeRow::<T::PortfolioOutcome>,
r#"
select
valid_from as "as_of!: crate::types::DateTime",
json(value) as "outcome!: sqlx::types::Json<T::PortfolioOutcome>"
from
portfolio_outcome
where
portfolio_id = $1
and
($2 is null or valid_from >= $2)
and
($3 is null or valid_until is null or valid_until < $3)
group by
valid_from
order by
valid_from desc
limit $4
"#,
portfolio_id,
query.after,
query.before,
limit_p1,
)
.fetch_all(&self.reader)
.await?;
// If we got the sentinel (limit + 1)-th row, more data exists: drop it
// and build the cursor for the next page from its timestamp.
let more = if rows.len() == limit + 1 {
let extra = rows.pop().unwrap();
Some(DateTimeRangeQuery {
before: Some(extra.as_of),
after: query.after,
})
} else {
None
};
// Unwrap the sqlx Json wrapper into the plain outcome type.
Ok(DateTimeRangeResponse {
results: rows
.into_iter()
.map(|row| OutcomeRecord {
as_of: row.as_of,
outcome: row.outcome.0,
})
.collect(),
more,
})
}
// Paginated history of outcomes for one product, newest `valid_from` first.
//
// Mirrors `get_portfolio_outcomes` exactly, but reads `product_outcome`
// keyed by `product_id`; see the review notes there — the same cursor and
// GROUP BY caveats apply to this query.
async fn get_product_outcomes(
&self,
product_id: Self::ProductId,
query: DateTimeRangeQuery<Self::DateTime>,
limit: usize,
) -> Result<
DateTimeRangeResponse<OutcomeRecord<Self::DateTime, T::ProductOutcome>, Self::DateTime>,
Self::Error,
> {
// One extra row beyond `limit` to detect whether another page exists.
let limit_p1 = (limit + 1) as i64;
let mut rows = sqlx::query_as!(
OutcomeRow::<T::ProductOutcome>,
r#"
select
valid_from as "as_of!: crate::types::DateTime",
json(value) as "outcome!: sqlx::types::Json<T::ProductOutcome>"
from
product_outcome
where
product_id = $1
and
($2 is null or valid_from >= $2)
and
($3 is null or valid_until is null or valid_until < $3)
group by
valid_from
order by
valid_from desc
limit $4
"#,
product_id,
query.after,
query.before,
limit_p1,
)
.fetch_all(&self.reader)
.await?;
// If we got the sentinel (limit + 1)-th row, more data exists: drop it
// and build the cursor for the next page from its timestamp.
let more = if rows.len() == limit + 1 {
let extra = rows.pop().unwrap();
Some(DateTimeRangeQuery {
before: Some(extra.as_of),
after: query.after,
})
} else {
None
};
// Unwrap the sqlx Json wrapper into the plain outcome type.
Ok(DateTimeRangeResponse {
results: rows
.into_iter()
.map(|row| OutcomeRecord {
as_of: row.as_of,
outcome: row.outcome.0,
})
.collect(),
more,
})
}
}