use std::any::{Any, TypeId};
use std::collections::{HashMap, HashSet};
use std::future::Future;
use std::marker::PhantomData;
use std::ops::Deref;
use std::pin::Pin;
use crate::dialect::{render_select, Dialect};
use crate::executor::Executor;
use crate::model::Model;
use crate::query::ast::{SelectItem, SelectStatement};
use crate::query::expr::Expr;
use crate::query::QuerySet;
use crate::relation::Relation;
use crate::row::Row;
use crate::value::Value;
/// Boxed, pinned, `Send` future borrowing for `'a`; the return type of the
/// async-style methods on the `dyn`-used traits in this file.
type BoxFuture<'a, T> = Pin<Box<dyn Future<Output = T> + Send + 'a>>;
/// Type-erased view of an executor exposing only what the preload steps use.
///
/// Preload steps receive it as `&dyn QueryRunner`, so they do not have to be
/// generic over the concrete executor type.
trait QueryRunner: Sync {
/// Returns the dialect that statements for this runner are rendered with.
fn dialect(&self) -> &dyn Dialect;
/// Runs `sql` with the bound `params` and resolves to the fetched rows.
fn fetch_all<'a>(
&'a self,
sql: String,
params: Vec<Value>,
) -> BoxFuture<'a, crate::Result<Vec<Row>>>;
}
/// Every `Sync` executor can act as a [`QueryRunner`] by delegating to its
/// [`Executor`] methods.
impl<E: Executor + Sync> QueryRunner for E {
    /// Forwards to `Executor::dialect`.
    fn dialect(&self) -> &dyn Dialect {
        // Path-qualified on purpose: a plain `self.dialect()` could resolve
        // back to this very method.
        Executor::dialect(self)
    }

    /// Forwards to `Executor::fetch_all` and boxes the returned future.
    fn fetch_all<'a>(
        &'a self,
        sql: String,
        params: Vec<Value>,
    ) -> BoxFuture<'a, crate::Result<Vec<Row>>> {
        let pending = Executor::fetch_all(self, sql, params);
        Box::pin(pending)
    }
}
/// Identifies one preloaded relation: the child type plus the pair of columns
/// that joins parent and child.
#[derive(Clone, PartialEq, Eq, Hash)]
struct RelationKey {
/// `TypeId` of the child model type.
type_id: TypeId,
/// Column whose value is read from the parent.
from_column: &'static str,
/// Column on the child table that is matched against the parent value.
to_column: &'static str,
}
impl RelationKey {
    /// Builds the key for `relation` from its child type `C` and its
    /// `from_column` / `to_column` pair.
    fn of<P, C: 'static>(relation: &Relation<P, C>) -> Self {
        let type_id = TypeId::of::<C>();
        let from_column = relation.from_column();
        let to_column = relation.to_column();
        Self {
            type_id,
            from_column,
            to_column,
        }
    }
}
/// A parent model bundled with the child collections preloaded for it.
///
/// Dereferences to the parent; children are read back with `get` / `get_via`.
pub struct Preloaded<M> {
/// The parent model itself.
parent: M,
/// Preloaded children per relation; each value is a boxed `Vec<C>` where `C`
/// is that relation's child type.
relations: HashMap<RelationKey, Box<dyn Any + Send + Sync>>,
}
impl<M> Preloaded<M> {
    /// Returns the preloaded children of type `C`, looked up by child type only.
    ///
    /// If several preloaded relations share the child type `C`, the relation
    /// with the smallest `(from_column, to_column)` pair is returned, so the
    /// answer does not depend on hash-map iteration order. Use
    /// [`get_via`](Self::get_via) to name the relation explicitly.
    ///
    /// Returns an empty slice when no relation with child type `C` was
    /// preloaded.
    pub fn get<C: Send + Sync + 'static>(&self) -> &[C] {
        let type_id = TypeId::of::<C>();
        self.relations
            .iter()
            .filter(|(key, _)| key.type_id == type_id)
            // Map iteration order is arbitrary; tie-break on the column pair
            // so repeated calls (and repeated runs) agree.
            .min_by_key(|(key, _)| (key.from_column, key.to_column))
            .and_then(|(_, boxed)| boxed.downcast_ref::<Vec<C>>())
            .map(Vec::as_slice)
            .unwrap_or(&[])
    }

    /// Returns the children preloaded through exactly `relation`.
    ///
    /// Returns an empty slice when that relation was not preloaded.
    pub fn get_via<C: Send + Sync + 'static>(&self, relation: &Relation<M, C>) -> &[C] {
        self.relations
            .get(&RelationKey::of(relation))
            .and_then(|boxed| boxed.downcast_ref::<Vec<C>>())
            .map(Vec::as_slice)
            .unwrap_or(&[])
    }

    /// Borrows the parent model.
    pub fn parent(&self) -> &M {
        &self.parent
    }

    /// Consumes the wrapper and returns the parent, dropping the preloaded
    /// children.
    pub fn into_parent(self) -> M {
        self.parent
    }
}
/// Lets a `Preloaded<M>` be read as if it were the parent `M`.
impl<M> Deref for Preloaded<M> {
    type Target = M;

    /// Borrows the wrapped parent.
    fn deref(&self) -> &M {
        let Self { parent, .. } = self;
        parent
    }
}
/// Result of running one preload step over a batch of parents.
struct PlanOutput {
/// Relation the children belong to; becomes the map key inside each `Preloaded`.
key: RelationKey,
/// One boxed `Vec` of children per parent, in the order the parents were passed in.
per_parent: Vec<Box<dyn Any + Send + Sync>>,
}
/// One relation to preload, with the child type erased so that steps for
/// different child types can be stored together as `Box<dyn PreloadStep<M>>`.
trait PreloadStep<M>: Send + Sync {
/// Loads this step's children for every entry of `parents` through `runner`.
///
/// The returned `PlanOutput::per_parent` is consumed index-for-index against
/// `parents` by the caller.
fn load<'a>(
&'a self,
parents: &'a [M],
runner: &'a dyn QueryRunner,
) -> BoxFuture<'a, crate::Result<PlanOutput>>;
}
/// `PreloadStep` for one concrete `Relation<M, C>`.
struct RelationPreload<M, C> {
/// The relation whose children are fetched.
relation: Relation<M, C>,
}
impl<M: Model, C: Model> PreloadStep<M> for RelationPreload<M, C> {
fn load<'a>(
&'a self,
parents: &'a [M],
runner: &'a dyn QueryRunner,
) -> BoxFuture<'a, crate::Result<PlanOutput>> {
Box::pin(async move {
let from_column = self.relation.from_column();
let to_column = self.relation.to_column();
let mut keys: Vec<Value> = Vec::new();
let mut seen: HashSet<String> = HashSet::new();
for parent in parents {
if let Some(value) = column_value(parent, from_column) {
if seen.insert(value_key(&value)) {
keys.push(value);
}
}
}
let relation_key = RelationKey::of(&self.relation);
if keys.is_empty() {
return Ok(PlanOutput {
key: relation_key,
per_parent: parents.iter().map(|_| empty_children::<C>()).collect(),
});
}
const FILTER_PARAM_MARGIN: usize = 16;
let chunk_size = runner
.dialect()
.max_bind_params()
.saturating_sub(FILTER_PARAM_MARGIN)
.max(1);
let mut groups: HashMap<String, Vec<Row>> = HashMap::new();
for key_chunk in keys.chunks(chunk_size) {
let projection = C::COLUMNS
.iter()
.map(|column| SelectItem::Column {
table: C::TABLE,
column: column.name,
})
.collect();
let mut statement = SelectStatement::new(C::TABLE, projection);
statement.filters.push(Expr::in_list(
Expr::column(C::TABLE, to_column),
key_chunk.to_vec(),
));
statement
.filters
.extend(self.relation.preload_filters().iter().cloned());
statement
.order_by
.extend(self.relation.preload_order_by().iter().cloned());
statement.limit = self.relation.preload_limit();
let (sql, params) = render_select(runner.dialect(), &statement);
let rows = runner.fetch_all(sql, params).await?;
for row in rows {
let key = value_key(&row.get::<Value>(to_column)?);
groups.entry(key).or_default().push(row);
}
}
let mut per_parent: Vec<Box<dyn Any + Send + Sync>> = Vec::with_capacity(parents.len());
for parent in parents {
let children: Vec<C> = match column_value(parent, from_column) {
Some(value) => match groups.get(&value_key(&value)) {
Some(rows) => rows
.iter()
.map(C::from_row)
.collect::<crate::Result<Vec<C>>>()?,
None => Vec::new(),
},
None => Vec::new(),
};
per_parent.push(Box::new(children));
}
Ok(PlanOutput {
key: relation_key,
per_parent,
})
})
}
}
/// Builder that runs a parent query and then preloads the registered relations
/// for the parents it returned.
pub struct Preloader<M: Model> {
/// Query selecting the parents; the builder methods forward to it.
base: QuerySet<M>,
/// Relations to preload, run in registration order.
plans: Vec<Box<dyn PreloadStep<M>>>,
/// Marker tying the builder to `M` without owning an `M`.
_marker: PhantomData<fn() -> M>,
}
impl<M: Model> Preloader<M> {
pub(crate) fn new(base: QuerySet<M>) -> Self {
Self {
base,
plans: Vec::new(),
_marker: PhantomData,
}
}
pub fn preload<C: Model>(mut self, relation: Relation<M, C>) -> Self {
self.plans.push(Box::new(RelationPreload { relation }));
self
}
pub fn filter(mut self, predicate: Expr) -> Self {
self.base = self.base.filter(predicate);
self
}
pub fn filter_any(mut self, predicates: impl IntoIterator<Item = Expr>) -> Self {
self.base = self.base.filter_any(predicates);
self
}
pub fn filter_all(mut self, predicates: impl IntoIterator<Item = Expr>) -> Self {
self.base = self.base.filter_all(predicates);
self
}
pub fn filter_not(mut self, predicate: Expr) -> Self {
self.base = self.base.filter_not(predicate);
self
}
pub fn order_by(mut self, term: crate::query::ast::OrderTerm) -> Self {
self.base = self.base.order_by(term);
self
}
pub fn limit(mut self, limit: u64) -> Self {
self.base = self.base.limit(limit);
self
}
pub fn offset(mut self, offset: u64) -> Self {
self.base = self.base.offset(offset);
self
}
pub fn distinct(mut self) -> Self {
self.base = self.base.distinct();
self
}
pub async fn all<E: Executor + Sync>(self, executor: E) -> crate::Result<Vec<Preloaded<M>>> {
let parents = self.base.all(&executor).await?;
let mut relation_maps: Vec<HashMap<RelationKey, Box<dyn Any + Send + Sync>>> =
(0..parents.len()).map(|_| HashMap::new()).collect();
for plan in &self.plans {
let output = plan.load(&parents, &executor).await?;
for (index, children) in output.per_parent.into_iter().enumerate() {
relation_maps[index].insert(output.key.clone(), children);
}
}
Ok(parents
.into_iter()
.zip(relation_maps)
.map(|(parent, relations)| Preloaded { parent, relations })
.collect())
}
pub async fn first<E: Executor + Sync>(
self,
executor: E,
) -> crate::Result<Option<Preloaded<M>>> {
Ok(self.limit(1).all(executor).await?.into_iter().next())
}
}
/// Reads the value of `column` from `model`.
///
/// The primary-key column is answered through `primary_key_value`; any other
/// name is searched for in `column_values`. Returns `None` when the model has
/// no column with that name.
fn column_value<M: Model>(model: &M, column: &str) -> Option<Value> {
    if column == M::PRIMARY_KEY {
        return Some(model.primary_key_value());
    }
    for (name, value) in model.column_values() {
        if name == column {
            return Some(value);
        }
    }
    None
}
/// Turns a [`Value`] into a string usable as a hash key for grouping parents
/// and child rows by join value.
///
/// Every variant gets its own prefix, so values of different variants never
/// share a key. Values that compare equal are meant to share one:
/// timestamps are keyed by their Unix nanosecond instant (independent of the
/// UTC offset they carry), and negative zero is keyed like positive zero.
fn value_key(value: &Value) -> String {
    match value {
        Value::Null => "null".to_string(),
        Value::Bool(b) => format!("b:{b}"),
        Value::Int(i) => format!("i:{i}"),
        Value::Real(r) => {
            // `-0.0 == 0.0`, yet they print as "-0" and "0"; fold them together
            // so equal join values land in the same group.
            let normalized = if *r == 0.0 { 0.0 } else { *r };
            format!("r:{normalized}")
        }
        Value::Text(s) => format!("t:{s}"),
        Value::Blob(bytes) => format!("x:{bytes:?}"),
        Value::Timestamp(ts) => format!("ts:{}", ts.unix_timestamp_nanos()),
        Value::Uuid(u) => format!("u:{u}"),
        Value::Json(j) => format!("j:{j}"),
        Value::Array(items) => format!("a:{items:?}"),
    }
}
/// Returns a type-erased, empty `Vec<C>`: the "no children" value stored for a
/// parent when nothing was loaded for it.
fn empty_children<C: Send + Sync + 'static>() -> Box<dyn Any + Send + Sync> {
    let none: Vec<C> = Vec::new();
    Box::new(none)
}
#[cfg(test)]
mod tests {
    use super::*;
    use time::{OffsetDateTime, UtcOffset};

    /// The same instant expressed at two UTC offsets must map to one key.
    #[test]
    fn timestamp_keys_are_offset_independent() {
        let utc = OffsetDateTime::from_unix_timestamp(1_700_000_000).unwrap();
        let plus_0530 = UtcOffset::from_hms(5, 30, 0).unwrap();
        let local = utc.to_offset(plus_0530);
        assert_ne!(utc.offset(), local.offset(), "offsets must differ");

        let utc_key = value_key(&Value::Timestamp(utc));
        let local_key = value_key(&Value::Timestamp(local));
        assert_eq!(
            utc_key, local_key,
            "equal instants must produce the same grouping key"
        );
    }

    /// Instants one second apart must not collide.
    #[test]
    fn distinct_instants_get_distinct_keys() {
        let earlier = OffsetDateTime::from_unix_timestamp(1_700_000_000).unwrap();
        let later = OffsetDateTime::from_unix_timestamp(1_700_000_001).unwrap();
        let earlier_key = value_key(&Value::Timestamp(earlier));
        let later_key = value_key(&Value::Timestamp(later));
        assert_ne!(earlier_key, later_key);
    }
}