Skip to main content

prax_cassandra/
engine.rs

//! Query execution engine.
//!
//! This module defines the public query API (query, execute, batch, and
//! lightweight transactions) and routes every statement through the
//! cdrs-tokio session held by the underlying [`CassandraPool`].
6
7use crate::error::{CassandraError, CassandraResult};
8use crate::pool::CassandraPool;
9use crate::row::{FromRow, Row};
10
11/// Aggregate result of a CQL query.
12#[derive(Debug, Default)]
13pub struct QueryResult {
14    /// Rows returned by the query. Empty for non-SELECT statements.
15    pub rows: Vec<Row>,
16    /// Whether a lightweight transaction applied.
17    pub applied: Option<bool>,
18}
19
20impl CassandraPool {
21    /// Execute a query returning rows.
22    pub async fn query(&self, cql: &str) -> CassandraResult<QueryResult> {
23        let envelope = self
24            .connection()
25            .session()
26            .query(cql)
27            .await
28            .map_err(|e| CassandraError::Query(format!("query failed: {e}")))?;
29
30        // Parse the response. SELECT responses carry a ResponseBody::Result
31        // with rows; INSERT/UPDATE/DELETE responses typically carry an
32        // empty result. LWT responses carry a single row with the
33        // `[applied]` boolean column first.
34        let body = envelope
35            .response_body()
36            .map_err(|e| CassandraError::Query(format!("response body parse: {e}")))?;
37
38        let (rows, applied) = if let Some(raw_rows) = body.into_rows() {
39            // LWT responses carry the applied-boolean as the first column
40            // of a single row. Detect that shape by checking whether the
41            // result set is exactly one row and the first column is a
42            // boolean named "[applied]".
43            let applied = raw_rows.first().and_then(|row| {
44                use cdrs_tokio::types::ByName;
45                row.by_name::<bool>("[applied]").ok().flatten()
46            });
47            let decoded: Vec<crate::row::Row> = raw_rows
48                .into_iter()
49                .map(|r| crate::row::Row::from_cdrs_row(&r))
50                .collect::<CassandraResult<_>>()?;
51            (decoded, applied)
52        } else {
53            (Vec::new(), None)
54        };
55
56        Ok(QueryResult { rows, applied })
57    }
58
59    /// Execute a statement not expecting rows (INSERT, UPDATE, DELETE, DDL).
60    pub async fn execute(&self, cql: &str) -> CassandraResult<()> {
61        self.connection()
62            .session()
63            .query(cql)
64            .await
65            .map_err(|e| CassandraError::Query(format!("execute failed: {e}")))?;
66        Ok(())
67    }
68
69    /// Query a single row, deserialized into T.
70    pub async fn query_one<T: FromRow>(&self, cql: &str) -> CassandraResult<T> {
71        let result = self.query(cql).await?;
72        let row = result
73            .rows
74            .into_iter()
75            .next()
76            .ok_or_else(|| CassandraError::Query("query_one: no rows returned".into()))?;
77        T::from_row(&row)
78    }
79
80    /// Query many rows.
81    pub async fn query_many<T: FromRow>(&self, cql: &str) -> CassandraResult<Vec<T>> {
82        let result = self.query(cql).await?;
83        result.rows.iter().map(|row| T::from_row(row)).collect()
84    }
85
86    /// Execute a lightweight transaction. Returns whether the CAS succeeded.
87    pub async fn execute_lwt(&self, cql: &str) -> CassandraResult<bool> {
88        let result = self.query(cql).await?;
89        Ok(result.applied.unwrap_or(false))
90    }
91
92    /// Build a batch of statements.
93    pub fn batch(&self) -> BatchBuilder<'_> {
94        BatchBuilder {
95            pool: self,
96            statements: Vec::new(),
97        }
98    }
99}
100
101/// Builder for a CQL batch.
102pub struct BatchBuilder<'a> {
103    pool: &'a CassandraPool,
104    statements: Vec<String>,
105}
106
107impl<'a> BatchBuilder<'a> {
108    /// Add a statement to the batch.
109    pub fn add_statement(mut self, cql: impl Into<String>) -> Self {
110        self.statements.push(cql.into());
111        self
112    }
113
114    /// Execute the batch as a LOGGED batch (default).
115    pub async fn execute(self) -> CassandraResult<()> {
116        self.execute_logged().await
117    }
118
119    /// Execute the batch as a LOGGED batch.
120    pub async fn execute_logged(self) -> CassandraResult<()> {
121        self.execute_with_type(cdrs_tokio::frame::message_batch::BatchType::Logged)
122            .await
123    }
124
125    /// Execute the batch as an UNLOGGED batch.
126    pub async fn execute_unlogged(self) -> CassandraResult<()> {
127        self.execute_with_type(cdrs_tokio::frame::message_batch::BatchType::Unlogged)
128            .await
129    }
130
131    /// Execute the batch as a COUNTER batch.
132    pub async fn execute_counter(self) -> CassandraResult<()> {
133        self.execute_with_type(cdrs_tokio::frame::message_batch::BatchType::Counter)
134            .await
135    }
136
137    async fn execute_with_type(
138        self,
139        batch_type: cdrs_tokio::frame::message_batch::BatchType,
140    ) -> CassandraResult<()> {
141        if self.statements.is_empty() {
142            return Err(CassandraError::Query("cannot execute empty batch".into()));
143        }
144        let mut builder = cdrs_tokio::query::BatchQueryBuilder::new().with_batch_type(batch_type);
145        for stmt in self.statements {
146            builder = builder.add_query(stmt, cdrs_tokio::query::QueryValues::SimpleValues(vec![]));
147        }
148        let batch = builder
149            .build()
150            .map_err(|e| CassandraError::Query(format!("batch build: {e}")))?;
151        self.pool
152            .connection()
153            .session()
154            .batch(batch)
155            .await
156            .map_err(|e| CassandraError::Query(format!("batch execute: {e}")))?;
157        Ok(())
158    }
159
160    /// Number of statements in the batch (for test/debug).
161    pub fn len(&self) -> usize {
162        self.statements.len()
163    }
164
165    /// True if the batch has no statements.
166    pub fn is_empty(&self) -> bool {
167        self.statements.is_empty()
168    }
169}
170
171/// Top-level query engine for the Cassandra driver.
172///
173/// Thin wrapper around [`CassandraPool`] that lets `#[derive(Model)]`-
174/// generated `Client<E>` target Cassandra through the same codegen
175/// pipeline the SQL drivers use. Routes SELECT/DELETE through the real
176/// cdrs-tokio session; `execute_update` runs the UPDATE then re-
177/// SELECTs rows matching the WHERE clause; `execute_insert` currently
178/// returns [`QueryError::unsupported`] — the pool's query/execute API
179/// doesn't accept bound params yet, so a safe PK-keyed follow-up
180/// SELECT isn't possible. Prefer [`prax_scylladb::ScyllaEngine`] for
181/// typed Client inserts against any CQL-compatible cluster.
182#[derive(Clone)]
183pub struct CassandraEngine {
184    pool: CassandraPool,
185}
186
187impl CassandraEngine {
188    /// Create a new engine wrapping the given pool.
189    pub fn new(pool: CassandraPool) -> Self {
190        Self { pool }
191    }
192
193    /// Borrow the underlying pool. Exposed for callers that need to
194    /// reach the raw query/execute/batch helpers directly.
195    pub fn pool(&self) -> &CassandraPool {
196        &self.pool
197    }
198}
199
200impl prax_query::traits::QueryEngine for CassandraEngine {
201    fn dialect(&self) -> &dyn prax_query::dialect::SqlDialect {
202        &prax_query::dialect::Cql
203    }
204
205    fn query_many<T: prax_query::traits::Model + prax_query::row::FromRow + Send + 'static>(
206        &self,
207        sql: &str,
208        _params: Vec<prax_query::filter::FilterValue>,
209    ) -> prax_query::traits::BoxFuture<'_, prax_query::QueryResult<Vec<T>>> {
210        let sql = sql.to_string();
211        let pool = self.pool.clone();
212        Box::pin(async move {
213            let result = pool
214                .query(&sql)
215                .await
216                .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?;
217            result
218                .rows
219                .iter()
220                .map(|r| r.as_cdrs())
221                .map(decode_row::<T>)
222                .collect()
223        })
224    }
225
226    fn query_one<T: prax_query::traits::Model + prax_query::row::FromRow + Send + 'static>(
227        &self,
228        sql: &str,
229        _params: Vec<prax_query::filter::FilterValue>,
230    ) -> prax_query::traits::BoxFuture<'_, prax_query::QueryResult<T>> {
231        let sql = sql.to_string();
232        let pool = self.pool.clone();
233        Box::pin(async move {
234            let result = pool
235                .query(&sql)
236                .await
237                .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?;
238            let cdrs_row = result
239                .rows
240                .iter()
241                .map(|r| r.as_cdrs())
242                .next()
243                .ok_or_else(|| prax_query::QueryError::not_found(T::MODEL_NAME))?;
244            decode_row::<T>(cdrs_row)
245        })
246    }
247
248    fn query_optional<T: prax_query::traits::Model + prax_query::row::FromRow + Send + 'static>(
249        &self,
250        sql: &str,
251        _params: Vec<prax_query::filter::FilterValue>,
252    ) -> prax_query::traits::BoxFuture<'_, prax_query::QueryResult<Option<T>>> {
253        let sql = sql.to_string();
254        let pool = self.pool.clone();
255        Box::pin(async move {
256            let result = pool
257                .query(&sql)
258                .await
259                .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?;
260            result
261                .rows
262                .iter()
263                .map(|r| r.as_cdrs())
264                .next()
265                .map(decode_row::<T>)
266                .transpose()
267        })
268    }
269
270    fn execute_insert<T: prax_query::traits::Model + prax_query::row::FromRow + Send + 'static>(
271        &self,
272        sql: &str,
273        _params: Vec<prax_query::filter::FilterValue>,
274    ) -> prax_query::traits::BoxFuture<'_, prax_query::QueryResult<T>> {
275        // CassandraPool::query/execute doesn't accept bound params yet —
276        // the prepared-statement integration is a follow-up task. Without
277        // real parameter binding, a PK-keyed follow-up SELECT can't be
278        // built safely (a LIMIT 1 with no WHERE would race concurrent
279        // writers and return the wrong row). Refuse rather than fabricate
280        // a result. The Scylla driver is feature-complete on this path
281        // and is the recommended CQL backend for typed Client inserts.
282        let _ = (sql, T::MODEL_NAME);
283        Box::pin(async move {
284            Err(prax_query::QueryError::unsupported(
285                "CassandraEngine::execute_insert requires prepared-statement \
286                 binding to safely re-fetch by PK; use ScyllaEngine or call \
287                 pool.execute + pool.query manually",
288            ))
289        })
290    }
291
292    fn execute_update<T: prax_query::traits::Model + prax_query::row::FromRow + Send + 'static>(
293        &self,
294        sql: &str,
295        _params: Vec<prax_query::filter::FilterValue>,
296    ) -> prax_query::traits::BoxFuture<'_, prax_query::QueryResult<Vec<T>>> {
297        let sql = sql.to_string();
298        let pool = self.pool.clone();
299        Box::pin(async move {
300            pool.execute(&sql)
301                .await
302                .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?;
303            // Recover the WHERE clause from the generated UPDATE so the
304            // follow-up SELECT touches the same rows. Refuse to SELECT
305            // everything on a WHERE-less UPDATE — that would be a
306            // worse failure mode than erroring.
307            let where_clause = extract_where_clause(&sql).ok_or_else(|| {
308                prax_query::QueryError::internal(
309                    "CassandraEngine::execute_update: UPDATE lacked a WHERE \
310                     clause; refusing to SELECT entire table",
311                )
312            })?;
313            let select_sql = format!(
314                "SELECT {} FROM {} WHERE {}",
315                T::COLUMNS.join(", "),
316                T::TABLE_NAME,
317                where_clause,
318            );
319            let result = pool
320                .query(&select_sql)
321                .await
322                .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?;
323            result
324                .rows
325                .iter()
326                .map(|r| r.as_cdrs())
327                .map(decode_row::<T>)
328                .collect()
329        })
330    }
331
332    fn execute_delete(
333        &self,
334        sql: &str,
335        _params: Vec<prax_query::filter::FilterValue>,
336    ) -> prax_query::traits::BoxFuture<'_, prax_query::QueryResult<u64>> {
337        let sql = sql.to_string();
338        let pool = self.pool.clone();
339        Box::pin(async move {
340            let _: () = pool
341                .execute(&sql)
342                .await
343                .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?;
344            Ok(0)
345        })
346    }
347
348    fn execute_raw(
349        &self,
350        sql: &str,
351        params: Vec<prax_query::filter::FilterValue>,
352    ) -> prax_query::traits::BoxFuture<'_, prax_query::QueryResult<u64>> {
353        self.execute_delete(sql, params)
354    }
355
356    fn count(
357        &self,
358        sql: &str,
359        _params: Vec<prax_query::filter::FilterValue>,
360    ) -> prax_query::traits::BoxFuture<'_, prax_query::QueryResult<u64>> {
361        let sql = sql.to_string();
362        let pool = self.pool.clone();
363        Box::pin(async move {
364            let _: QueryResult = pool
365                .query(&sql)
366                .await
367                .map_err(|e| prax_query::QueryError::database(e.to_string()).with_source(e))?;
368            Ok(0)
369        })
370    }
371}
372
373// WHERE-clause extraction lives in prax_query::sql::parse — import
374// here under the old name to minimise churn.
375use prax_query::sql::parse::extract_where_body as extract_where_clause;
376
377/// Decode one cdrs-tokio row into the caller's `T: Model + FromRow`.
378/// Shared by every QueryEngine method that hands back typed rows, so
379/// the column-list allocation and error-wrapping stay in one place.
380fn decode_row<T: prax_query::traits::Model + prax_query::row::FromRow>(
381    cdrs_row: &cdrs_tokio::types::rows::Row,
382) -> prax_query::QueryResult<T> {
383    let cols: Vec<String> = T::COLUMNS.iter().map(|s| s.to_string()).collect();
384    let rr = crate::row_ref::CassandraRowRef::from_cdrs_with_cols(cdrs_row, &cols);
385    T::from_row(&rr).map_err(|e| {
386        let msg = e.to_string();
387        prax_query::QueryError::deserialization(msg).with_source(e)
388    })
389}
390
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::CassandraConfig;

    #[tokio::test]
    async fn test_query_without_connection_returns_error() {
        // Compile-surface check only: a config can be built, but no pool is
        // created because connecting needs a live cluster.
        // NOTE(review): despite its name this test asserts nothing about
        // query errors; TODO cover that path with an integration test.
        let _config = CassandraConfig::builder()
            .known_nodes(["127.0.0.1:9042".to_string()])
            .build();
    }

    #[test]
    fn test_batch_builder_add_increments_len() {
        // Placeholder: a BatchBuilder needs a real pool, which needs a live
        // cluster, so this only checks a stand-in statement list. Live
        // integration tests cover the actual builder behaviour.
        let stmts: Vec<String> = vec![String::from("INSERT INTO t VALUES (1)")];
        assert_eq!(stmts.len(), 1);
    }

    #[test]
    fn test_query_result_default_is_empty() {
        let result = QueryResult::default();
        assert!(result.rows.is_empty());
        assert!(result.applied.is_none());
    }
}