pooly/services/queries/mod.rs

1use std::sync::Arc;
2
3use deadpool_postgres::Transaction;
4use postgres_types::ToSql;
5
6use crate::models::errors::QueryError;
7use crate::models::query::parameters::convert_params;
8use crate::models::payloads::{ErrorResponse, query_response, QueryRequest, QueryResponse, RowResponseGroup, tx_bulk_query_response, TxBulkQueryRequest, TxBulkQueryRequestBody, TxBulkQueryResponse, TxBulkQuerySuccessResponse, TxQuerySuccessResponse};
9use crate::models::payloads::QuerySuccessResponse;
10use crate::models::responses::ResponseWithCode;
11use crate::models::query::rows::convert_rows;
12use crate::services::auth::access::AccessControlService;
13use crate::services::connections::ConnectionService;
14
15pub struct QueryService {
16
17    access_control_service: Arc<AccessControlService>,
18    connection_service: ConnectionService
19
20}
21
22impl QueryService {
23
24    pub fn new(access_control_service: Arc<AccessControlService>,
25               connection_service: ConnectionService) -> Self {
26        QueryService {
27            access_control_service,
28            connection_service
29        }
30    }
31
32    pub async fn bulk_tx(&self,
33                         client_id: &str,
34                         request: &TxBulkQueryRequest,
35                         correlation_id: &str) -> ResponseWithCode<TxBulkQueryResponse> {
36        match self.do_bulk_tx(client_id, request).await {
37            Ok(ok) => ResponseWithCode::ok(ok.into()),
38            Err(err) => QueryService::build_response(err, correlation_id)
39        }
40    }
41
42    pub async fn query(&self,
43                       client_id: &str,
44                       request: &QueryRequest,
45                       correlation_id: &str) -> ResponseWithCode<QueryResponse> {
46        match self.do_query(client_id, request).await {
47            Ok(ok) => ResponseWithCode::ok(ok.into()),
48            Err(err) => QueryService::build_response(err, correlation_id)
49        }
50    }
51
52    async fn do_bulk_tx(&self,
53                        client_id: &str,
54                        request: &TxBulkQueryRequest) -> Result<Vec<TxQuerySuccessResponse>, QueryError> {
55        let connection_id: &str = &request.connection_id;
56
57        if !self.access_control_service.is_allowed(client_id, connection_id)? {
58            return Err(QueryError::ForbiddenConnectionId);
59        }
60
61        if request.queries.is_empty() {
62            return Ok(Vec::new());
63        }
64
65        match self.connection_service.get(connection_id).await {
66            Some(connection_result) => {
67                let mut connection = connection_result?;
68
69                let tx: Transaction = connection.transaction().await?;
70
71                let mut results = Vec::new();
72
73                for (i, query_request_body) in request.queries.iter().enumerate() {
74                    let query_response =
75                        QueryService::do_execute_bulk(&tx, &query_request_body,i)
76                            .await?;
77
78                    results.push(query_response);
79                }
80
81                tx.commit().await?;
82
83                Ok(results)
84            },
85            None => Err(QueryError::UnknownDatabaseConnection(connection_id.to_owned())),
86        }
87    }
88
89    async fn do_execute_bulk(tx: &Transaction<'_>,
90                             bulk_body: &TxBulkQueryRequestBody,
91                             ord_num: usize) -> Result<TxQuerySuccessResponse, QueryError> {
92        let stmt =
93            tx.prepare_cached(&bulk_body.query).await?;
94
95        let mut results = Vec::new();
96
97        for params_row in &bulk_body.params {
98            let param_values: Vec<&(dyn ToSql + Sync)> = convert_params(
99                stmt.params(),
100                &params_row.values
101            )?;
102
103            let query_results =
104                tx.query(&stmt, param_values.as_slice()).await?;
105
106            results.push(convert_rows(query_results)?);
107        }
108
109        let column_names =
110            results.first()
111                .map_or(Vec::new(), |cwr| cwr.1.clone());
112
113        let row_groups =
114            results.into_iter()
115                .map(|cwr| RowResponseGroup { rows: cwr.0 })
116                .collect();
117
118        Ok(TxQuerySuccessResponse {
119            ord_num: ord_num as i32,
120            column_names,
121            row_groups
122        })
123    }
124
125    async fn do_query(&self,
126                      client_id: &str,
127                      request: &QueryRequest) -> Result<QuerySuccessResponse, QueryError> {
128        let connection_id: &str = &request.connection_id;
129
130        if !self.access_control_service.is_allowed(client_id, connection_id)? {
131            return Err(QueryError::ForbiddenConnectionId);
132        }
133
134        match self.connection_service.get(connection_id).await {
135            Some(connection_result) => {
136                let connection = connection_result?;
137
138                let stmt = connection.prepare_cached(&request.query).await?;
139
140                let params: Vec<&(dyn ToSql + Sync)> =
141                    convert_params(stmt.params(), &request.params)?;
142
143                let results =
144                    connection.query(&stmt, params.as_slice()).await?;
145
146                let cwr = convert_rows(results)?;
147
148                Ok(
149                    QuerySuccessResponse {
150                        rows: cwr.0,
151                        column_names: cwr.1
152                    }
153                )
154            }
155            None => Err(QueryError::UnknownDatabaseConnection(connection_id.to_owned()))
156        }
157    }
158
159    fn build_response<T: From<ErrorResponse>>(err: QueryError,
160                                              correlation_id: &str) -> ResponseWithCode<T> {
161        let code = err.get_code();
162        ResponseWithCode(err, code)
163            .map(|err|
164                err.to_error_response(correlation_id.to_string()))
165            .map(|err_response| err_response.into())
166    }
167}
168
169impl From<QuerySuccessResponse> for QueryResponse {
170    fn from(success: QuerySuccessResponse) -> Self {
171        QueryResponse {
172            payload: Some(query_response::Payload::Success(success))
173        }
174    }
175}
176
177impl From<ErrorResponse> for QueryResponse {
178    fn from(err: ErrorResponse) -> Self {
179        QueryResponse {
180            payload: Some(query_response::Payload::Error(err))
181        }
182    }
183}
184
185impl From<Vec<TxQuerySuccessResponse>> for TxBulkQueryResponse {
186    fn from(responses: Vec<TxQuerySuccessResponse>) -> Self {
187        TxBulkQueryResponse {
188            payload: Some(tx_bulk_query_response::Payload::Success(
189                TxBulkQuerySuccessResponse {
190                    responses
191                }))
192        }
193    }
194}
195
196impl From<ErrorResponse> for TxBulkQueryResponse {
197    fn from(err: ErrorResponse) -> Self {
198        TxBulkQueryResponse {
199            payload: Some(tx_bulk_query_response::Payload::Error(err))
200        }
201    }
202}
203
204