pooly/services/queries/
mod.rs1use std::sync::Arc;
2
3use deadpool_postgres::Transaction;
4use postgres_types::ToSql;
5
6use crate::models::errors::QueryError;
7use crate::models::query::parameters::convert_params;
8use crate::models::payloads::{ErrorResponse, query_response, QueryRequest, QueryResponse, RowResponseGroup, tx_bulk_query_response, TxBulkQueryRequest, TxBulkQueryRequestBody, TxBulkQueryResponse, TxBulkQuerySuccessResponse, TxQuerySuccessResponse};
9use crate::models::payloads::QuerySuccessResponse;
10use crate::models::responses::ResponseWithCode;
11use crate::models::query::rows::convert_rows;
12use crate::services::auth::access::AccessControlService;
13use crate::services::connections::ConnectionService;
14
15pub struct QueryService {
16
17 access_control_service: Arc<AccessControlService>,
18 connection_service: ConnectionService
19
20}
21
22impl QueryService {
23
24 pub fn new(access_control_service: Arc<AccessControlService>,
25 connection_service: ConnectionService) -> Self {
26 QueryService {
27 access_control_service,
28 connection_service
29 }
30 }
31
32 pub async fn bulk_tx(&self,
33 client_id: &str,
34 request: &TxBulkQueryRequest,
35 correlation_id: &str) -> ResponseWithCode<TxBulkQueryResponse> {
36 match self.do_bulk_tx(client_id, request).await {
37 Ok(ok) => ResponseWithCode::ok(ok.into()),
38 Err(err) => QueryService::build_response(err, correlation_id)
39 }
40 }
41
42 pub async fn query(&self,
43 client_id: &str,
44 request: &QueryRequest,
45 correlation_id: &str) -> ResponseWithCode<QueryResponse> {
46 match self.do_query(client_id, request).await {
47 Ok(ok) => ResponseWithCode::ok(ok.into()),
48 Err(err) => QueryService::build_response(err, correlation_id)
49 }
50 }
51
52 async fn do_bulk_tx(&self,
53 client_id: &str,
54 request: &TxBulkQueryRequest) -> Result<Vec<TxQuerySuccessResponse>, QueryError> {
55 let connection_id: &str = &request.connection_id;
56
57 if !self.access_control_service.is_allowed(client_id, connection_id)? {
58 return Err(QueryError::ForbiddenConnectionId);
59 }
60
61 if request.queries.is_empty() {
62 return Ok(Vec::new());
63 }
64
65 match self.connection_service.get(connection_id).await {
66 Some(connection_result) => {
67 let mut connection = connection_result?;
68
69 let tx: Transaction = connection.transaction().await?;
70
71 let mut results = Vec::new();
72
73 for (i, query_request_body) in request.queries.iter().enumerate() {
74 let query_response =
75 QueryService::do_execute_bulk(&tx, &query_request_body,i)
76 .await?;
77
78 results.push(query_response);
79 }
80
81 tx.commit().await?;
82
83 Ok(results)
84 },
85 None => Err(QueryError::UnknownDatabaseConnection(connection_id.to_owned())),
86 }
87 }
88
89 async fn do_execute_bulk(tx: &Transaction<'_>,
90 bulk_body: &TxBulkQueryRequestBody,
91 ord_num: usize) -> Result<TxQuerySuccessResponse, QueryError> {
92 let stmt =
93 tx.prepare_cached(&bulk_body.query).await?;
94
95 let mut results = Vec::new();
96
97 for params_row in &bulk_body.params {
98 let param_values: Vec<&(dyn ToSql + Sync)> = convert_params(
99 stmt.params(),
100 ¶ms_row.values
101 )?;
102
103 let query_results =
104 tx.query(&stmt, param_values.as_slice()).await?;
105
106 results.push(convert_rows(query_results)?);
107 }
108
109 let column_names =
110 results.first()
111 .map_or(Vec::new(), |cwr| cwr.1.clone());
112
113 let row_groups =
114 results.into_iter()
115 .map(|cwr| RowResponseGroup { rows: cwr.0 })
116 .collect();
117
118 Ok(TxQuerySuccessResponse {
119 ord_num: ord_num as i32,
120 column_names,
121 row_groups
122 })
123 }
124
125 async fn do_query(&self,
126 client_id: &str,
127 request: &QueryRequest) -> Result<QuerySuccessResponse, QueryError> {
128 let connection_id: &str = &request.connection_id;
129
130 if !self.access_control_service.is_allowed(client_id, connection_id)? {
131 return Err(QueryError::ForbiddenConnectionId);
132 }
133
134 match self.connection_service.get(connection_id).await {
135 Some(connection_result) => {
136 let connection = connection_result?;
137
138 let stmt = connection.prepare_cached(&request.query).await?;
139
140 let params: Vec<&(dyn ToSql + Sync)> =
141 convert_params(stmt.params(), &request.params)?;
142
143 let results =
144 connection.query(&stmt, params.as_slice()).await?;
145
146 let cwr = convert_rows(results)?;
147
148 Ok(
149 QuerySuccessResponse {
150 rows: cwr.0,
151 column_names: cwr.1
152 }
153 )
154 }
155 None => Err(QueryError::UnknownDatabaseConnection(connection_id.to_owned()))
156 }
157 }
158
159 fn build_response<T: From<ErrorResponse>>(err: QueryError,
160 correlation_id: &str) -> ResponseWithCode<T> {
161 let code = err.get_code();
162 ResponseWithCode(err, code)
163 .map(|err|
164 err.to_error_response(correlation_id.to_string()))
165 .map(|err_response| err_response.into())
166 }
167}
168
169impl From<QuerySuccessResponse> for QueryResponse {
170 fn from(success: QuerySuccessResponse) -> Self {
171 QueryResponse {
172 payload: Some(query_response::Payload::Success(success))
173 }
174 }
175}
176
177impl From<ErrorResponse> for QueryResponse {
178 fn from(err: ErrorResponse) -> Self {
179 QueryResponse {
180 payload: Some(query_response::Payload::Error(err))
181 }
182 }
183}
184
185impl From<Vec<TxQuerySuccessResponse>> for TxBulkQueryResponse {
186 fn from(responses: Vec<TxQuerySuccessResponse>) -> Self {
187 TxBulkQueryResponse {
188 payload: Some(tx_bulk_query_response::Payload::Success(
189 TxBulkQuerySuccessResponse {
190 responses
191 }))
192 }
193 }
194}
195
196impl From<ErrorResponse> for TxBulkQueryResponse {
197 fn from(err: ErrorResponse) -> Self {
198 TxBulkQueryResponse {
199 payload: Some(tx_bulk_query_response::Payload::Error(err))
200 }
201 }
202}
203
204