Skip to main content

teaql_sql/
executor.rs

1#![allow(async_fn_in_trait)]
2
3use std::time::SystemTime;
4use teaql_core::Record;
5use teaql_data_service::{
6    DataServiceCapabilities, DataServiceExecutor, DataServiceOperation, ExecutionMetadata,
7    MutationExecutor, MutationRequest, MutationResult, QueryExecutor, QueryRequest, QueryResult,
8};
9
10use crate::{CompiledQuery, SqlCompileError, SqlDialect};
11
12pub trait SqlTransport: Send + Sync {
13    type Error: std::error::Error + Send + Sync + 'static;
14
15    fn fetch_all_sql(&self, query: &CompiledQuery) -> impl std::future::Future<Output = Result<Vec<Record>, Self::Error>> + Send;
16    fn execute_sql(&self, query: &CompiledQuery) -> impl std::future::Future<Output = Result<u64, Self::Error>> + Send;
17}
18
19pub trait SqlTransactionTransport: SqlTransport {
20    type Tx<'a>: SqlTransport<Error = Self::Error> + SqlTransaction<Error = Self::Error> + Send + Sync + 'a
21    where
22        Self: 'a;
23
24    fn begin_sql(&self) -> impl std::future::Future<Output = Result<Self::Tx<'_>, Self::Error>> + Send;
25}
26
/// Completion operations for an open SQL transaction.
///
/// Both operations take `self` by value, so a handle can be finished only once.
pub trait SqlTransaction {
    /// Error type reported when finishing the transaction fails.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Commits the transaction, consuming the handle.
    fn commit_sql(self) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send;

    /// Rolls the transaction back, consuming the handle.
    fn rollback_sql(self) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send;
}
32
33#[derive(Debug)]
34pub enum SqlExecutorError<E: std::error::Error + Send + Sync + 'static> {
35    Compile(SqlCompileError),
36    Transport(E),
37}
38
39impl<E: std::error::Error + Send + Sync + 'static> std::fmt::Display for SqlExecutorError<E> {
40    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
41        match self {
42            SqlExecutorError::Compile(e) => write!(f, "SQL compile error: {}", e),
43            SqlExecutorError::Transport(e) => write!(f, "Transport error: {}", e),
44        }
45    }
46}
47
48impl<E: std::error::Error + Send + Sync + 'static> std::error::Error for SqlExecutorError<E> {}
49
/// SQL-backed implementation of the data-service executor traits.
///
/// Combines three pieces:
/// * a dialect that compiles queries and mutations into SQL,
/// * a transport that executes the compiled SQL, and
/// * a schema provider used to resolve entity descriptions by name.
#[derive(Clone)]
pub struct SqlDataServiceExecutor<D, T, S> {
    /// Compiles selects and mutations into SQL.
    pub dialect: D,
    /// Executes compiled SQL.
    pub transport: T,
    /// Resolves entity names to entity descriptions.
    pub schema_provider: S,
}
56
57impl<D, T, S> SqlDataServiceExecutor<D, T, S> {
58    pub fn new(dialect: D, transport: T, schema_provider: S) -> Self {
59        Self { dialect, transport, schema_provider }
60    }
61}
62
63impl<D: SqlDialect + Send + Sync, T: SqlTransport + Send + Sync, S: teaql_data_service::SchemaProvider + Send + Sync> DataServiceExecutor
64    for SqlDataServiceExecutor<D, T, S>
65{
66    type Error = SqlExecutorError<T::Error>;
67
68    fn capabilities(&self) -> DataServiceCapabilities {
69        DataServiceCapabilities {
70            query: true,
71            mutation: true,
72            transaction: false, // Override if T implements SqlTransactionTransport
73            schema: false,
74            id_generation: false,
75            batch_mutation: true,
76            returning: false,
77        }
78    }
79}
80
81impl<D: SqlDialect + Send + Sync, T: SqlTransport + Send + Sync, S: teaql_data_service::SchemaProvider + Send + Sync> QueryExecutor
82    for SqlDataServiceExecutor<D, T, S>
83{
84    fn query(&self, request: QueryRequest) -> impl std::future::Future<Output = Result<QueryResult, Self::Error>> + Send { async move {
85        let entity_desc = self.schema_provider.get_entity(&request.query.entity)
86            .ok_or_else(|| SqlExecutorError::Compile(SqlCompileError::UnknownEntity(request.query.entity.clone())))?;
87
88        let compiled = self.dialect.compile_select(&entity_desc, &request.query).map_err(SqlExecutorError::Compile)?;
89        let start = SystemTime::now();
90        let rows = self.transport.fetch_all_sql(&compiled).await.map_err(SqlExecutorError::Transport)?;
91        let end = SystemTime::now();
92
93        let metadata = ExecutionMetadata {
94            backend: "sql".to_string(),
95            operation: DataServiceOperation::Query,
96            started_at: start,
97            ended_at: end,
98            affected_rows: None,
99            result_count: Some(rows.len()),
100            trace_chain: request.trace_chain,
101            comment: request.comment,
102            backend_request_id: None,
103            debug_query: Some(compiled.debug_sql(self.dialect.kind())),
104        };
105
106        Ok(QueryResult { rows, metadata })
107    } }
108}
109
110impl<D: SqlDialect + Send + Sync, T: SqlTransport + Send + Sync, S: teaql_data_service::SchemaProvider + Send + Sync> MutationExecutor
111    for SqlDataServiceExecutor<D, T, S>
112{
113    fn mutate(&self, request: MutationRequest) -> impl std::future::Future<Output = Result<MutationResult, Self::Error>> + Send { async move {
114        let entity_name = match &request {
115            MutationRequest::Insert(cmd) => &cmd.entity,
116            MutationRequest::Update(cmd) => &cmd.entity,
117            MutationRequest::Delete(cmd) => &cmd.entity,
118            MutationRequest::Recover(cmd) => &cmd.entity,
119            MutationRequest::Batch(mutations) => {
120                let mut total_affected = 0;
121                let start = SystemTime::now();
122                for req in mutations {
123                    let res = Box::pin(self.mutate(req.clone())).await?;
124                    total_affected += res.affected_rows;
125                }
126                let end = SystemTime::now();
127                return Ok(MutationResult {
128                    affected_rows: total_affected,
129                    generated_values: Record::default(),
130                    metadata: ExecutionMetadata {
131                        backend: "sql".to_string(),
132                        operation: DataServiceOperation::Batch,
133                        started_at: start,
134                        ended_at: end,
135                        affected_rows: Some(total_affected),
136                        result_count: None,
137                        trace_chain: Vec::new(),
138                        comment: None,
139                        backend_request_id: None,
140                        debug_query: None,
141                    },
142                });
143            }
144        };
145        
146        let entity_desc = self.schema_provider.get_entity(entity_name)
147            .ok_or_else(|| SqlExecutorError::Compile(SqlCompileError::UnknownEntity(entity_name.clone())))?;
148
149        let compiled = match &request {
150            MutationRequest::Insert(cmd) => self.dialect.compile_insert(&entity_desc, cmd).map_err(SqlExecutorError::Compile)?,
151            MutationRequest::Update(cmd) => self.dialect.compile_update(&entity_desc, cmd).map_err(SqlExecutorError::Compile)?,
152            MutationRequest::Delete(cmd) => self.dialect.compile_delete(&entity_desc, cmd).map_err(SqlExecutorError::Compile)?,
153            MutationRequest::Recover(cmd) => self.dialect.compile_recover(&entity_desc, cmd).map_err(SqlExecutorError::Compile)?,
154            MutationRequest::Batch(_) => unreachable!(),
155        };
156
157        let start = SystemTime::now();
158        let affected_rows = self.transport.execute_sql(&compiled).await.map_err(SqlExecutorError::Transport)?;
159        let end = SystemTime::now();
160
161        let operation = match &request {
162            MutationRequest::Insert(_) => DataServiceOperation::Insert,
163            MutationRequest::Update(_) => DataServiceOperation::Update,
164            MutationRequest::Delete(_) => DataServiceOperation::Delete,
165            MutationRequest::Recover(_) => DataServiceOperation::Recover,
166            MutationRequest::Batch(_) => DataServiceOperation::Batch,
167        };
168
169        let metadata = ExecutionMetadata {
170            backend: "sql".to_string(),
171            operation,
172            started_at: start,
173            ended_at: end,
174            affected_rows: Some(affected_rows),
175            result_count: None,
176            trace_chain: request.trace_chain().to_vec(),
177            comment: request.comment().map(|s| s.to_owned()),
178            backend_request_id: None,
179            debug_query: Some(compiled.debug_sql(self.dialect.kind())),
180        };
181
182        Ok(MutationResult {
183            affected_rows,
184            generated_values: Record::default(),
185            metadata,
186        })
187    } }
188}
189
/// Executor bound to a single open transaction.
///
/// It is normally created by `TransactionExecutor::begin` on
/// [`SqlDataServiceExecutor`]. It borrows the parent's dialect and schema
/// provider and owns the transaction handle `Tx`, through which all of its
/// queries and mutations are executed. It is finished with `commit` or
/// `rollback`.
///
/// Trait bounds are placed on the `impl` blocks rather than on the struct, so
/// naming the type (or deriving `Clone`) does not require `Tx` to be a transport.
#[derive(Clone)]
pub struct SqlDataServiceTransaction<'a, D, Tx, S> {
    /// Dialect borrowed from the parent executor.
    pub dialect: &'a D,
    /// The open transaction handle; statements run through it.
    pub transport: Tx,
    /// Schema provider borrowed from the parent executor.
    pub schema_provider: &'a S,
}
196
197impl<'a, D: SqlDialect + Send + Sync, Tx: SqlTransport + SqlTransaction<Error = <Tx as SqlTransport>::Error> + Send + Sync, S: teaql_data_service::SchemaProvider + Send + Sync> DataServiceExecutor
198    for SqlDataServiceTransaction<'a, D, Tx, S>
199{
200    type Error = SqlExecutorError<<Tx as SqlTransport>::Error>;
201
202    fn capabilities(&self) -> DataServiceCapabilities {
203        DataServiceCapabilities {
204            query: true,
205            mutation: true,
206            transaction: false,
207            schema: false,
208            id_generation: false,
209            batch_mutation: true,
210            returning: false,
211        }
212    }
213}
214
215impl<'a, D: SqlDialect + Send + Sync, Tx: SqlTransport + SqlTransaction<Error = <Tx as SqlTransport>::Error> + Send + Sync, S: teaql_data_service::SchemaProvider + Send + Sync> QueryExecutor
216    for SqlDataServiceTransaction<'a, D, Tx, S>
217{
218    fn query(&self, request: QueryRequest) -> impl std::future::Future<Output = Result<QueryResult, Self::Error>> + Send { async move {
219        let entity_desc = self.schema_provider.get_entity(&request.query.entity)
220            .ok_or_else(|| SqlExecutorError::Compile(SqlCompileError::UnknownEntity(request.query.entity.clone())))?;
221
222        let compiled = self.dialect.compile_select(&entity_desc, &request.query).map_err(SqlExecutorError::Compile)?;
223        let start = SystemTime::now();
224        let rows = self.transport.fetch_all_sql(&compiled).await.map_err(SqlExecutorError::Transport)?;
225        let end = SystemTime::now();
226
227        let metadata = ExecutionMetadata {
228            backend: "sql".to_string(),
229            operation: DataServiceOperation::Query,
230            started_at: start,
231            ended_at: end,
232            affected_rows: None,
233            result_count: Some(rows.len()),
234            trace_chain: request.trace_chain,
235            comment: request.comment,
236            backend_request_id: None,
237            debug_query: Some(compiled.debug_sql(self.dialect.kind())),
238        };
239
240        Ok(QueryResult { rows, metadata })
241    } }
242}
243
244impl<'a, D: SqlDialect + Send + Sync, Tx: SqlTransport + SqlTransaction<Error = <Tx as SqlTransport>::Error> + Send + Sync, S: teaql_data_service::SchemaProvider + Send + Sync> MutationExecutor
245    for SqlDataServiceTransaction<'a, D, Tx, S>
246{
247    fn mutate(&self, request: MutationRequest) -> impl std::future::Future<Output = Result<MutationResult, Self::Error>> + Send { async move {
248        let entity_name = match &request {
249            MutationRequest::Insert(cmd) => &cmd.entity,
250            MutationRequest::Update(cmd) => &cmd.entity,
251            MutationRequest::Delete(cmd) => &cmd.entity,
252            MutationRequest::Recover(cmd) => &cmd.entity,
253            MutationRequest::Batch(mutations) => {
254                let mut total_affected = 0;
255                let start = SystemTime::now();
256                for req in mutations {
257                    let res = Box::pin(self.mutate(req.clone())).await?;
258                    total_affected += res.affected_rows;
259                }
260                let end = SystemTime::now();
261                return Ok(MutationResult {
262                    affected_rows: total_affected,
263                    generated_values: Record::default(),
264                    metadata: ExecutionMetadata {
265                        backend: "sql".to_string(),
266                        operation: DataServiceOperation::Batch,
267                        started_at: start,
268                        ended_at: end,
269                        affected_rows: Some(total_affected),
270                        result_count: None,
271                        trace_chain: Vec::new(),
272                        comment: None,
273                        backend_request_id: None,
274                        debug_query: None,
275                    },
276                });
277            }
278        };
279        
280        let entity_desc = self.schema_provider.get_entity(entity_name)
281            .ok_or_else(|| SqlExecutorError::Compile(SqlCompileError::UnknownEntity(entity_name.clone())))?;
282
283        let compiled = match &request {
284            MutationRequest::Insert(cmd) => self.dialect.compile_insert(&entity_desc, cmd).map_err(SqlExecutorError::Compile)?,
285            MutationRequest::Update(cmd) => self.dialect.compile_update(&entity_desc, cmd).map_err(SqlExecutorError::Compile)?,
286            MutationRequest::Delete(cmd) => self.dialect.compile_delete(&entity_desc, cmd).map_err(SqlExecutorError::Compile)?,
287            MutationRequest::Recover(cmd) => self.dialect.compile_recover(&entity_desc, cmd).map_err(SqlExecutorError::Compile)?,
288            MutationRequest::Batch(_) => unreachable!("batch handled above"),
289        };
290
291        let start = SystemTime::now();
292        let affected_rows = self.transport.execute_sql(&compiled).await.map_err(SqlExecutorError::Transport)?;
293        let end = SystemTime::now();
294
295        let operation = match &request {
296            MutationRequest::Insert(_) => DataServiceOperation::Insert,
297            MutationRequest::Update(_) => DataServiceOperation::Update,
298            MutationRequest::Delete(_) => DataServiceOperation::Delete,
299            MutationRequest::Recover(_) => DataServiceOperation::Recover,
300            MutationRequest::Batch(_) => DataServiceOperation::Batch,
301        };
302
303        let metadata = ExecutionMetadata {
304            backend: "sql".to_string(),
305            operation,
306            started_at: start,
307            ended_at: end,
308            affected_rows: Some(affected_rows),
309            result_count: None,
310            trace_chain: request.trace_chain().to_vec(),
311            comment: request.comment().map(|s| s.to_owned()),
312            backend_request_id: None,
313            debug_query: Some(compiled.debug_sql(self.dialect.kind())),
314        };
315
316        Ok(MutationResult {
317            affected_rows,
318            generated_values: Record::default(),
319            metadata,
320        })
321    } }
322}
323
324impl<'a, D: SqlDialect + Send + Sync, Tx: SqlTransport + SqlTransaction<Error = <Tx as SqlTransport>::Error> + Send + Sync, S: teaql_data_service::SchemaProvider + Send + Sync> teaql_data_service::Transaction
325    for SqlDataServiceTransaction<'a, D, Tx, S>
326{
327    type Error = SqlExecutorError<<Tx as SqlTransport>::Error>;
328
329    fn commit(self) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send { async move {
330        self.transport.commit_sql().await.map_err(SqlExecutorError::Transport)
331    } }
332
333    fn rollback(self) -> impl std::future::Future<Output = Result<(), Self::Error>> + Send { async move {
334        self.transport.rollback_sql().await.map_err(SqlExecutorError::Transport)
335    } }
336}
337
338impl<D: SqlDialect + Send + Sync, T: SqlTransactionTransport + Send + Sync, S: teaql_data_service::SchemaProvider + Send + Sync> teaql_data_service::TransactionExecutor
339    for SqlDataServiceExecutor<D, T, S>
340{
341    type Tx<'a> = SqlDataServiceTransaction<'a, D, T::Tx<'a>, S> where Self: 'a;
342
343    fn begin(&self) -> impl std::future::Future<Output = Result<Self::Tx<'_>, Self::Error>> + Send { async move {
344        let tx = self.transport.begin_sql().await.map_err(SqlExecutorError::Transport)?;
345        Ok(SqlDataServiceTransaction {
346            dialect: &self.dialect,
347            transport: tx,
348            schema_provider: &self.schema_provider,
349        })
350    } }
351}