cratestack_sqlx/query/batch/
get.rs1use std::collections::HashMap;
6use std::hash::Hash;
7
8use cratestack_core::{BatchResponse, CoolContext, CoolError};
9
10use crate::query::support::push_action_policy_query;
11use crate::{ModelDescriptor, ModelPrimaryKey, SqlxRuntime, sqlx};
12
13use super::validate::{reject_duplicate_pks, validate_batch_size};
14
15#[derive(Debug, Clone)]
16pub struct BatchGet<'a, M: 'static, PK: 'static> {
17 pub(crate) runtime: &'a SqlxRuntime,
18 pub(crate) descriptor: &'static ModelDescriptor<M, PK>,
19 pub(crate) ids: Vec<PK>,
20}
21
22impl<'a, M: 'static, PK: 'static> BatchGet<'a, M, PK> {
23 pub async fn run(self, ctx: &CoolContext) -> Result<BatchResponse<M>, CoolError>
24 where
25 for<'r> M:
26 Send + Unpin + sqlx::FromRow<'r, sqlx::postgres::PgRow> + ModelPrimaryKey<PK>,
27 PK: Clone
28 + Eq
29 + Hash
30 + Send
31 + sqlx::Type<sqlx::Postgres>
32 + for<'q> sqlx::Encode<'q, sqlx::Postgres>,
33 {
34 validate_batch_size(self.ids.len())?;
35 reject_duplicate_pks(&self.ids)?;
36 if self.ids.is_empty() {
37 return Ok(BatchResponse::from_results(vec![]));
38 }
39
40 let mut query = sqlx::QueryBuilder::<sqlx::Postgres>::new("SELECT ");
42 query.push(self.descriptor.select_projection());
43 query.push(" FROM ").push(self.descriptor.table_name);
44 query.push(" WHERE ");
45 if let Some(col) = self.descriptor.soft_delete_column {
46 query.push(col).push(" IS NULL AND ");
47 }
48 query.push(self.descriptor.primary_key).push(" IN (");
49 for (index, id) in self.ids.iter().enumerate() {
50 if index > 0 {
51 query.push(", ");
52 }
53 query.push_bind(id.clone());
54 }
55 query.push(") AND ");
56 push_action_policy_query(
57 &mut query,
58 self.descriptor.read_allow_policies,
59 self.descriptor.read_deny_policies,
60 ctx,
61 );
62
63 let rows: Vec<M> = query
64 .build_query_as::<M>()
65 .fetch_all(self.runtime.pool())
66 .await
67 .map_err(|error| CoolError::Database(error.to_string()))?;
68
69 let mut by_pk: HashMap<PK, M> =
72 rows.into_iter().map(|m| (m.primary_key(), m)).collect();
73 let per_item: Vec<Result<M, CoolError>> = self
74 .ids
75 .into_iter()
76 .map(|id| {
77 by_pk
78 .remove(&id)
79 .ok_or_else(|| CoolError::NotFound("no row matched".to_owned()))
80 })
81 .collect();
82
83 Ok(BatchResponse::from_results(per_item))
84 }
85}