// mssql_client: stream.rs

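//! Streaming query result support.
//!
//! This module provides result-set types that hand rows to the caller one
//! at a time. The rows are currently buffered in memory (see below), so the
//! streaming API enables incremental processing rather than lower memory use.
//!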
//! ## Buffered vs True Streaming
//!
//! The current implementation uses a buffered approach: all rows from the
//! TDS response are parsed up front. This works well because:
//!
//! 1. TDS responses arrive as complete messages (reassembled by mssql-codec).
//! 2. Memory is shared via the `Arc<Bytes>` pattern per ADR-004.
//! 3. There are no complex lifetime/borrow issues with the connection.
//!
//! For very large result sets, consider OFFSET/FETCH pagination instead.
16
17use std::collections::VecDeque;
18use std::pin::Pin;
19use std::task::{Context, Poll};
20
21use futures_core::Stream;
22
23use crate::error::Error;
24use crate::row::{Column, Row};
25
26/// A streaming result set from a query.
27///
28/// This stream yields rows one at a time, allowing processing of
29/// large result sets without loading everything into memory.
30///
31/// # Example
32///
33/// ```rust,ignore
34/// use futures::StreamExt;
35///
36/// let mut stream = client.query("SELECT * FROM large_table", &[]).await?;
37///
38/// while let Some(row) = stream.next().await {
39///     let row = row?;
40///     process_row(&row);
41/// }
42/// ```
43#[must_use = "streams must be consumed; dropping a stream discards remaining rows"]
44pub struct QueryStream<'a> {
45    /// Column metadata for the result set.
46    columns: Vec<Column>,
47    /// Buffered rows from the response.
48    rows: VecDeque<Row>,
49    /// Whether the stream has completed.
50    finished: bool,
51    /// Lifetime tied to the connection.
52    _marker: std::marker::PhantomData<&'a ()>,
53}
54
55impl QueryStream<'_> {
56    /// Create a new query stream with columns and buffered rows.
57    pub(crate) fn new(columns: Vec<Column>, rows: Vec<Row>) -> Self {
58        Self {
59            columns,
60            rows: rows.into(),
61            finished: false,
62            _marker: std::marker::PhantomData,
63        }
64    }
65
66    /// Create an empty query stream (no results).
67    #[allow(dead_code)]
68    pub(crate) fn empty() -> Self {
69        Self {
70            columns: Vec::new(),
71            rows: VecDeque::new(),
72            finished: true,
73            _marker: std::marker::PhantomData,
74        }
75    }
76
77    /// Get the column metadata for this result set.
78    #[must_use]
79    pub fn columns(&self) -> &[Column] {
80        &self.columns
81    }
82
83    /// Check if the stream has finished.
84    #[must_use]
85    pub fn is_finished(&self) -> bool {
86        self.finished
87    }
88
89    /// Get the number of rows remaining in the buffer.
90    #[must_use]
91    pub fn rows_remaining(&self) -> usize {
92        self.rows.len()
93    }
94
95    /// Collect all remaining rows into a vector.
96    ///
97    /// This consumes the stream and loads all rows into memory.
98    /// For large result sets, consider iterating with the stream instead.
99    pub async fn collect_all(mut self) -> Result<Vec<Row>, Error> {
100        // Drain all remaining rows from the buffer
101        let rows: Vec<Row> = self.rows.drain(..).collect();
102        self.finished = true;
103        Ok(rows)
104    }
105
106    /// Try to get the next row synchronously (without async).
107    ///
108    /// Returns `None` when no more rows are available.
109    pub fn try_next(&mut self) -> Option<Row> {
110        if self.finished {
111            return None;
112        }
113
114        match self.rows.pop_front() {
115            Some(row) => Some(row),
116            None => {
117                self.finished = true;
118                None
119            }
120        }
121    }
122}
123
124impl Stream for QueryStream<'_> {
125    type Item = Result<Row, Error>;
126
127    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
128        let this = self.get_mut();
129
130        if this.finished {
131            return Poll::Ready(None);
132        }
133
134        // Pop the next row from the buffer
135        match this.rows.pop_front() {
136            Some(row) => Poll::Ready(Some(Ok(row))),
137            None => {
138                this.finished = true;
139                Poll::Ready(None)
140            }
141        }
142    }
143}
144
145impl ExactSizeIterator for QueryStream<'_> {}
146
147impl Iterator for QueryStream<'_> {
148    type Item = Result<Row, Error>;
149
150    fn next(&mut self) -> Option<Self::Item> {
151        if self.finished {
152            return None;
153        }
154
155        match self.rows.pop_front() {
156            Some(row) => Some(Ok(row)),
157            None => {
158                self.finished = true;
159                None
160            }
161        }
162    }
163
164    fn size_hint(&self) -> (usize, Option<usize>) {
165        let remaining = self.rows.len();
166        (remaining, Some(remaining))
167    }
168}
169
170/// Result of a non-query execution.
171///
172/// Contains the number of affected rows and any output parameters.
173#[derive(Debug, Clone)]
174#[non_exhaustive]
175#[must_use]
176pub struct ExecuteResult {
177    /// Number of rows affected by the statement.
178    pub rows_affected: u64,
179    /// Output parameters from stored procedures.
180    pub output_params: Vec<OutputParam>,
181}
182
183/// An output parameter from a stored procedure call.
184#[derive(Debug, Clone)]
185#[non_exhaustive]
186pub struct OutputParam {
187    /// Parameter name.
188    pub name: String,
189    /// Parameter value.
190    pub value: mssql_types::SqlValue,
191}
192
193impl ExecuteResult {
194    /// Create a new execute result.
195    pub fn new(rows_affected: u64) -> Self {
196        Self {
197            rows_affected,
198            output_params: Vec::new(),
199        }
200    }
201
202    /// Create a result with output parameters.
203    pub fn with_outputs(rows_affected: u64, output_params: Vec<OutputParam>) -> Self {
204        Self {
205            rows_affected,
206            output_params,
207        }
208    }
209
210    /// Get an output parameter by name.
211    #[must_use]
212    pub fn get_output(&self, name: &str) -> Option<&OutputParam> {
213        self.output_params
214            .iter()
215            .find(|p| p.name.eq_ignore_ascii_case(name))
216    }
217}
218
219/// Result of a stored procedure execution.
220///
221/// Contains the return value, affected row count, output parameters,
222/// and any result sets produced by the procedure.
223///
224/// # Example
225///
226/// ```rust,ignore
227/// let result = client.call_procedure("dbo.GetUser", &[&1i32]).await?;
228///
229/// // Check the return value (RETURN statement in the proc)
230/// assert_eq!(result.return_value, 0);
231///
232/// // Process result sets
233/// for mut rs in result.result_sets {
234///     while let Some(row) = rs.next_row() {
235///         println!("{:?}", row);
236///     }
237/// }
238/// ```
239#[derive(Debug, Clone)]
240#[non_exhaustive]
241#[must_use]
242pub struct ProcedureResult {
243    /// Return value from the stored procedure's RETURN statement.
244    ///
245    /// Defaults to 0 if the procedure does not explicitly return a value,
246    /// which matches SQL Server's default behavior.
247    pub return_value: i32,
248    /// Total number of rows affected by statements within the procedure.
249    pub rows_affected: u64,
250    /// Output parameters returned by the procedure.
251    pub output_params: Vec<OutputParam>,
252    /// Result sets produced by SELECT statements within the procedure.
253    pub result_sets: Vec<ResultSet>,
254}
255
256impl ProcedureResult {
257    /// Create a new empty procedure result.
258    pub(crate) fn new() -> Self {
259        Self {
260            return_value: 0,
261            rows_affected: 0,
262            output_params: Vec::new(),
263            result_sets: Vec::new(),
264        }
265    }
266
267    /// Get the return value from the stored procedure.
268    ///
269    /// This is the value from the procedure's `RETURN` statement.
270    /// Defaults to 0 if not explicitly set by the procedure.
271    #[must_use]
272    pub fn get_return_value(&self) -> i32 {
273        self.return_value
274    }
275
276    /// Get an output parameter by name (case-insensitive).
277    ///
278    /// Strips the `@` prefix from both the search name and stored names
279    /// before comparing, so `get_output("result")` and `get_output("@result")`
280    /// are equivalent.
281    ///
282    /// # Example
283    ///
284    /// ```rust,ignore
285    /// let result = client.procedure("dbo.CalculateSum")?
286    ///     .input("@a", &10i32)
287    ///     .input("@b", &20i32)
288    ///     .output_int("@result")
289    ///     .execute().await?;
290    ///
291    /// let output = result.get_output("@result").expect("output param exists");
292    /// assert_eq!(output.value, SqlValue::Int(30));
293    /// ```
294    #[must_use]
295    pub fn get_output(&self, name: &str) -> Option<&OutputParam> {
296        let search = name.strip_prefix('@').unwrap_or(name);
297        self.output_params.iter().find(|p| {
298            let stored = p.name.strip_prefix('@').unwrap_or(&p.name);
299            stored.eq_ignore_ascii_case(search)
300        })
301    }
302
303    /// Get the first result set, if any.
304    ///
305    /// Convenience method for procedures that return a single result set.
306    #[must_use]
307    pub fn first_result_set(&self) -> Option<&ResultSet> {
308        self.result_sets.first()
309    }
310
311    /// Check if the procedure produced any result sets.
312    #[must_use]
313    pub fn has_result_sets(&self) -> bool {
314        !self.result_sets.is_empty()
315    }
316}
317
318/// A single result set within a multi-result batch.
319#[derive(Debug, Clone)]
320#[must_use]
321pub struct ResultSet {
322    /// Column metadata for this result set.
323    columns: Vec<Column>,
324    /// Rows in this result set.
325    rows: VecDeque<Row>,
326}
327
328impl ResultSet {
329    /// Create a new result set.
330    pub fn new(columns: Vec<Column>, rows: Vec<Row>) -> Self {
331        Self {
332            columns,
333            rows: rows.into(),
334        }
335    }
336
337    /// Get the column metadata.
338    #[must_use]
339    pub fn columns(&self) -> &[Column] {
340        &self.columns
341    }
342
343    /// Get the number of rows remaining.
344    #[must_use]
345    pub fn rows_remaining(&self) -> usize {
346        self.rows.len()
347    }
348
349    /// Get the next row from this result set.
350    pub fn next_row(&mut self) -> Option<Row> {
351        self.rows.pop_front()
352    }
353
354    /// Check if this result set is empty.
355    #[must_use]
356    pub fn is_empty(&self) -> bool {
357        self.rows.is_empty()
358    }
359
360    /// Collect all remaining rows into a vector.
361    pub fn collect_all(&mut self) -> Vec<Row> {
362        self.rows.drain(..).collect()
363    }
364}
365
366/// Multiple result sets from a batch or stored procedure.
367///
368/// Some queries return multiple result sets (e.g., stored procedures
369/// with multiple SELECT statements, or batches with multiple queries).
370///
371/// # Example
372///
373/// ```rust,ignore
374/// // Execute a batch with multiple SELECTs
375/// let mut results = client.query_multiple("SELECT 1 AS a; SELECT 2 AS b, 3 AS c;", &[]).await?;
376///
377/// // Process first result set
378/// while let Some(row) = results.next_row().await? {
379///     println!("Result 1: {:?}", row);
380/// }
381///
382/// // Move to second result set
383/// if results.next_result().await? {
384///     while let Some(row) = results.next_row().await? {
385///         println!("Result 2: {:?}", row);
386///     }
387/// }
388/// ```
389#[must_use = "streams must be consumed; dropping a stream discards remaining results"]
390pub struct MultiResultStream<'a> {
391    /// All result sets from the batch.
392    result_sets: Vec<ResultSet>,
393    /// Current result set index (0-based).
394    current_result: usize,
395    /// Lifetime tied to the connection.
396    _marker: std::marker::PhantomData<&'a ()>,
397}
398
399impl<'a> MultiResultStream<'a> {
400    /// Create a new multi-result stream from parsed result sets.
401    pub(crate) fn new(result_sets: Vec<ResultSet>) -> Self {
402        Self {
403            result_sets,
404            current_result: 0,
405            _marker: std::marker::PhantomData,
406        }
407    }
408
409    /// Create an empty multi-result stream.
410    #[allow(dead_code)]
411    pub(crate) fn empty() -> Self {
412        Self {
413            result_sets: Vec::new(),
414            current_result: 0,
415            _marker: std::marker::PhantomData,
416        }
417    }
418
419    /// Get the current result set index (0-based).
420    #[must_use]
421    pub fn current_result_index(&self) -> usize {
422        self.current_result
423    }
424
425    /// Get the total number of result sets.
426    #[must_use]
427    pub fn result_count(&self) -> usize {
428        self.result_sets.len()
429    }
430
431    /// Check if there are more result sets after the current one.
432    #[must_use]
433    pub fn has_more_results(&self) -> bool {
434        self.current_result + 1 < self.result_sets.len()
435    }
436
437    /// Get the column metadata for the current result set.
438    ///
439    /// Returns `None` if there are no result sets or we've moved past all of them.
440    #[must_use]
441    pub fn columns(&self) -> Option<&[Column]> {
442        self.result_sets
443            .get(self.current_result)
444            .map(|rs| rs.columns())
445    }
446
447    /// Move to the next result set.
448    ///
449    /// Returns `true` if there is another result set, `false` if no more.
450    pub async fn next_result(&mut self) -> Result<bool, Error> {
451        if self.current_result + 1 < self.result_sets.len() {
452            self.current_result += 1;
453            Ok(true)
454        } else {
455            Ok(false)
456        }
457    }
458
459    /// Get the next row from the current result set.
460    ///
461    /// Returns `None` when no more rows in the current result set.
462    /// Call `next_result()` to move to the next result set.
463    pub async fn next_row(&mut self) -> Result<Option<Row>, Error> {
464        if let Some(result_set) = self.result_sets.get_mut(self.current_result) {
465            Ok(result_set.next_row())
466        } else {
467            Ok(None)
468        }
469    }
470
471    /// Get a mutable reference to the current result set.
472    #[must_use]
473    pub fn current_result_set(&mut self) -> Option<&mut ResultSet> {
474        self.result_sets.get_mut(self.current_result)
475    }
476
477    /// Collect all rows from the current result set.
478    ///
479    /// Returns an empty `Vec` if the current result index is out of range
480    /// (e.g., all result sets have been consumed). The `unwrap_or_default`
481    /// below is on `Option`, not `Result` — no errors are being swallowed.
482    pub fn collect_current(&mut self) -> Vec<Row> {
483        self.result_sets
484            .get_mut(self.current_result)
485            .map(|rs| rs.collect_all())
486            .unwrap_or_default()
487    }
488
489    /// Consume the stream and return all result sets as `QueryStream`s.
490    pub fn into_query_streams(self) -> Vec<QueryStream<'a>> {
491        self.result_sets
492            .into_iter()
493            .map(|rs| QueryStream::new(rs.columns, rs.rows.into()))
494            .collect()
495    }
496}
497
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Build a non-nullable `INT` column named `name` at index 0.
    fn int_column(name: &str) -> Column {
        Column {
            name: name.to_string(),
            index: 0,
            type_name: "INT".to_string(),
            nullable: false,
            max_length: Some(4),
            precision: None,
            scale: None,
            collation: None,
        }
    }

    #[test]
    fn test_execute_result() {
        let result = ExecuteResult::new(42);
        assert_eq!(result.rows_affected, 42);
        assert!(result.output_params.is_empty());
    }

    #[test]
    fn test_procedure_result_defaults() {
        let result = ProcedureResult::new();
        assert_eq!(result.return_value, 0);
        assert_eq!(result.rows_affected, 0);
        assert!(result.output_params.is_empty());
        assert!(result.result_sets.is_empty());
        assert!(!result.has_result_sets());
        assert!(result.first_result_set().is_none());
    }

    #[test]
    fn test_procedure_result_get_output() {
        let mut result = ProcedureResult::new();
        result.output_params.push(OutputParam {
            name: "@Total".to_string(),
            value: mssql_types::SqlValue::Int(42),
        });
        result.output_params.push(OutputParam {
            name: "@Message".to_string(),
            value: mssql_types::SqlValue::String("ok".to_string()),
        });

        // Exact match (case-insensitive)
        assert!(result.get_output("@Total").is_some());
        assert!(result.get_output("@total").is_some());
        assert!(result.get_output("@TOTAL").is_some());

        // @ prefix stripping
        assert!(result.get_output("Total").is_some());
        assert!(result.get_output("total").is_some());

        // Non-existent
        assert!(result.get_output("@NotHere").is_none());
        assert!(result.get_output("NotHere").is_none());
    }

    #[test]
    fn test_procedure_result_with_result_sets() {
        use mssql_types::SqlValue;

        let columns = vec![Column {
            name: "id".to_string(),
            index: 0,
            type_name: "INT".to_string(),
            nullable: false,
            max_length: Some(4),
            precision: None,
            scale: None,
            collation: None,
        }];
        let rows = vec![Row::from_values(columns.clone(), vec![SqlValue::Int(1)])];
        let rs = ResultSet::new(columns, rows);

        let mut result = ProcedureResult::new();
        result.result_sets.push(rs);
        result.return_value = 7;
        result.rows_affected = 5;

        assert!(result.has_result_sets());
        assert_eq!(result.get_return_value(), 7);
        assert_eq!(result.first_result_set().unwrap().columns().len(), 1);
    }

    #[test]
    fn test_execute_result_with_outputs() {
        let outputs = vec![OutputParam {
            name: "ReturnValue".to_string(),
            value: mssql_types::SqlValue::Int(100),
        }];

        let result = ExecuteResult::with_outputs(10, outputs);
        assert_eq!(result.rows_affected, 10);
        assert!(result.get_output("ReturnValue").is_some());
        assert!(result.get_output("returnvalue").is_some()); // case-insensitive
        assert!(result.get_output("NotFound").is_none());
    }

    #[test]
    fn test_query_stream_columns() {
        let columns = vec![Column {
            name: "id".to_string(),
            index: 0,
            type_name: "INT".to_string(),
            nullable: false,
            max_length: Some(4),
            precision: Some(0),
            scale: Some(0),
            collation: None,
        }];

        let stream = QueryStream::new(columns, Vec::new());
        assert_eq!(stream.columns().len(), 1);
        assert_eq!(stream.columns()[0].name, "id");
        assert!(!stream.is_finished());
    }

    #[test]
    fn test_query_stream_with_rows() {
        use mssql_types::SqlValue;

        let columns = vec![
            Column {
                name: "id".to_string(),
                index: 0,
                type_name: "INT".to_string(),
                nullable: false,
                max_length: Some(4),
                precision: None,
                scale: None,
                collation: None,
            },
            Column {
                name: "name".to_string(),
                index: 1,
                type_name: "NVARCHAR".to_string(),
                nullable: true,
                max_length: Some(100),
                precision: None,
                scale: None,
                collation: None,
            },
        ];

        let rows = vec![
            Row::from_values(
                columns.clone(),
                vec![SqlValue::Int(1), SqlValue::String("Alice".to_string())],
            ),
            Row::from_values(
                columns.clone(),
                vec![SqlValue::Int(2), SqlValue::String("Bob".to_string())],
            ),
        ];

        let mut stream = QueryStream::new(columns, rows);
        assert_eq!(stream.columns().len(), 2);
        assert_eq!(stream.rows_remaining(), 2);
        assert!(!stream.is_finished());

        // First row
        let row1 = stream.try_next().unwrap();
        assert_eq!(row1.get::<i32>(0).unwrap(), 1);
        assert_eq!(row1.get_by_name::<String>("name").unwrap(), "Alice");

        // Second row
        let row2 = stream.try_next().unwrap();
        assert_eq!(row2.get::<i32>(0).unwrap(), 2);
        assert_eq!(row2.get_by_name::<String>("name").unwrap(), "Bob");

        // No more rows
        assert!(stream.try_next().is_none());
        assert!(stream.is_finished());
    }

    #[test]
    fn test_query_stream_iterator() {
        use mssql_types::SqlValue;

        let columns = vec![Column {
            name: "val".to_string(),
            index: 0,
            type_name: "INT".to_string(),
            nullable: false,
            max_length: None,
            precision: None,
            scale: None,
            collation: None,
        }];

        let rows = vec![
            Row::from_values(columns.clone(), vec![SqlValue::Int(10)]),
            Row::from_values(columns.clone(), vec![SqlValue::Int(20)]),
            Row::from_values(columns.clone(), vec![SqlValue::Int(30)]),
        ];

        let mut stream = QueryStream::new(columns, rows);

        // Use iterator — unwrap each Result so test failures are visible
        // (QueryStream's Iterator impl always yields Ok, but we should
        // not silently swallow errors if that ever changes)
        let values: Vec<i32> = stream
            .by_ref()
            .map(|r| r.unwrap().get::<i32>(0).unwrap())
            .collect();

        assert_eq!(values, vec![10, 20, 30]);
        assert!(stream.is_finished());
    }

    #[test]
    fn test_query_stream_empty() {
        let stream = QueryStream::empty();
        assert!(stream.columns().is_empty());
        assert_eq!(stream.rows_remaining(), 0);
        assert!(stream.is_finished());
    }

    #[test]
    fn test_query_stream_exact_size_and_end_is_sticky() {
        use mssql_types::SqlValue;

        let columns = vec![int_column("n")];
        let rows = vec![
            Row::from_values(columns.clone(), vec![SqlValue::Int(1)]),
            Row::from_values(columns.clone(), vec![SqlValue::Int(2)]),
        ];
        let mut stream = QueryStream::new(columns, rows);

        // `Stream` may also be in scope via `super::*`, so name the trait.
        assert_eq!(Iterator::size_hint(&stream), (2, Some(2)));
        assert_eq!(stream.len(), 2);

        assert!(stream.next().is_some());
        assert_eq!(stream.len(), 1);
        assert!(stream.next().is_some());
        assert!(stream.next().is_none());
        assert!(stream.is_finished());
        // Once exhausted, the stream keeps reporting the end.
        assert!(stream.next().is_none());
        assert_eq!(stream.len(), 0);
    }

    #[test]
    fn test_result_set_consumption() {
        use mssql_types::SqlValue;

        let columns = vec![int_column("id")];
        let rows = vec![
            Row::from_values(columns.clone(), vec![SqlValue::Int(1)]),
            Row::from_values(columns.clone(), vec![SqlValue::Int(2)]),
            Row::from_values(columns.clone(), vec![SqlValue::Int(3)]),
        ];
        let mut rs = ResultSet::new(columns, rows);
        assert_eq!(rs.columns().len(), 1);
        assert_eq!(rs.rows_remaining(), 3);
        assert!(!rs.is_empty());

        let first = rs.next_row().unwrap();
        assert_eq!(first.get::<i32>(0).unwrap(), 1);

        let rest = rs.collect_all();
        assert_eq!(rest.len(), 2);
        assert_eq!(rest[1].get::<i32>(0).unwrap(), 3);

        // `is_empty` tracks the rows that are left, not the original count.
        assert!(rs.is_empty());
        assert!(rs.next_row().is_none());
    }

    #[test]
    fn test_multi_result_stream_sync_api() {
        use mssql_types::SqlValue;

        let first_cols = vec![int_column("a")];
        let second_cols = vec![int_column("b")];
        let first = ResultSet::new(
            first_cols.clone(),
            vec![Row::from_values(first_cols, vec![SqlValue::Int(1)])],
        );
        let second = ResultSet::new(
            second_cols.clone(),
            vec![
                Row::from_values(second_cols.clone(), vec![SqlValue::Int(2)]),
                Row::from_values(second_cols, vec![SqlValue::Int(3)]),
            ],
        );

        let mut results = MultiResultStream::new(vec![first, second]);
        assert_eq!(results.result_count(), 2);
        assert_eq!(results.current_result_index(), 0);
        assert!(results.has_more_results());
        assert_eq!(results.columns().unwrap()[0].name, "a");

        let current = results.collect_current();
        assert_eq!(current.len(), 1);
        assert!(results.collect_current().is_empty());

        let streams = results.into_query_streams();
        assert_eq!(streams.len(), 2);
        assert_eq!(streams[0].rows_remaining(), 0);
        assert_eq!(streams[1].rows_remaining(), 2);
        assert_eq!(streams[1].columns()[0].name, "b");
    }

    #[test]
    fn test_multi_result_stream_empty() {
        let mut results = MultiResultStream::empty();
        assert_eq!(results.result_count(), 0);
        assert!(!results.has_more_results());
        assert!(results.columns().is_none());
        assert!(results.current_result_set().is_none());
        assert!(results.collect_current().is_empty());
        assert!(results.into_query_streams().is_empty());
    }
}