1use std::collections::VecDeque;
18use std::pin::Pin;
19use std::task::{Context, Poll};
20
21use futures_core::Stream;
22
23use crate::error::Error;
24use crate::row::{Column, Row};
25
26#[must_use = "streams must be consumed; dropping a stream discards remaining rows"]
44pub struct QueryStream<'a> {
45 columns: Vec<Column>,
47 rows: VecDeque<Row>,
49 finished: bool,
51 _marker: std::marker::PhantomData<&'a ()>,
53}
54
55impl QueryStream<'_> {
56 pub(crate) fn new(columns: Vec<Column>, rows: Vec<Row>) -> Self {
58 Self {
59 columns,
60 rows: rows.into(),
61 finished: false,
62 _marker: std::marker::PhantomData,
63 }
64 }
65
66 #[allow(dead_code)]
68 pub(crate) fn empty() -> Self {
69 Self {
70 columns: Vec::new(),
71 rows: VecDeque::new(),
72 finished: true,
73 _marker: std::marker::PhantomData,
74 }
75 }
76
77 #[must_use]
79 pub fn columns(&self) -> &[Column] {
80 &self.columns
81 }
82
83 #[must_use]
85 pub fn is_finished(&self) -> bool {
86 self.finished
87 }
88
89 #[must_use]
91 pub fn rows_remaining(&self) -> usize {
92 self.rows.len()
93 }
94
95 pub async fn collect_all(mut self) -> Result<Vec<Row>, Error> {
100 let rows: Vec<Row> = self.rows.drain(..).collect();
102 self.finished = true;
103 Ok(rows)
104 }
105
106 pub fn try_next(&mut self) -> Option<Row> {
110 if self.finished {
111 return None;
112 }
113
114 match self.rows.pop_front() {
115 Some(row) => Some(row),
116 None => {
117 self.finished = true;
118 None
119 }
120 }
121 }
122}
123
124impl Stream for QueryStream<'_> {
125 type Item = Result<Row, Error>;
126
127 fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
128 let this = self.get_mut();
129
130 if this.finished {
131 return Poll::Ready(None);
132 }
133
134 match this.rows.pop_front() {
136 Some(row) => Poll::Ready(Some(Ok(row))),
137 None => {
138 this.finished = true;
139 Poll::Ready(None)
140 }
141 }
142 }
143}
144
145impl ExactSizeIterator for QueryStream<'_> {}
146
147impl Iterator for QueryStream<'_> {
148 type Item = Result<Row, Error>;
149
150 fn next(&mut self) -> Option<Self::Item> {
151 if self.finished {
152 return None;
153 }
154
155 match self.rows.pop_front() {
156 Some(row) => Some(Ok(row)),
157 None => {
158 self.finished = true;
159 None
160 }
161 }
162 }
163
164 fn size_hint(&self) -> (usize, Option<usize>) {
165 let remaining = self.rows.len();
166 (remaining, Some(remaining))
167 }
168}
169
/// Result of executing a command: an affected-row count plus any output parameters.
///
/// Because of `#[non_exhaustive]`, code outside this crate must use
/// [`ExecuteResult::new`] or [`ExecuteResult::with_outputs`] rather than a struct literal.
#[derive(Debug, Clone)]
#[non_exhaustive]
#[must_use]
pub struct ExecuteResult {
    /// Number of rows affected.
    pub rows_affected: u64,
    /// Output parameters carried by this result, in the order they were supplied.
    pub output_params: Vec<OutputParam>,
}
182
/// A named output-parameter value.
///
/// Because of `#[non_exhaustive]`, code outside this crate cannot build it with a struct literal.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct OutputParam {
    /// Parameter name as stored; it may include a leading `@` (e.g. `"@Total"`).
    pub name: String,
    /// The parameter's value.
    pub value: mssql_types::SqlValue,
}
192
193impl ExecuteResult {
194 pub fn new(rows_affected: u64) -> Self {
196 Self {
197 rows_affected,
198 output_params: Vec::new(),
199 }
200 }
201
202 pub fn with_outputs(rows_affected: u64, output_params: Vec<OutputParam>) -> Self {
204 Self {
205 rows_affected,
206 output_params,
207 }
208 }
209
210 #[must_use]
212 pub fn get_output(&self, name: &str) -> Option<&OutputParam> {
213 self.output_params
214 .iter()
215 .find(|p| p.name.eq_ignore_ascii_case(name))
216 }
217}
218
/// Everything collected from a procedure call: its integer return value, affected-row
/// count, output parameters and any result sets.
#[derive(Debug, Clone)]
#[non_exhaustive]
#[must_use]
pub struct ProcedureResult {
    /// Integer return value (0 in a freshly created result).
    pub return_value: i32,
    /// Number of rows affected.
    pub rows_affected: u64,
    /// Output parameters; look them up with `get_output`, which ignores a leading `@`.
    pub output_params: Vec<OutputParam>,
    /// Result sets, in the order they were added.
    pub result_sets: Vec<ResultSet>,
}
255
256impl ProcedureResult {
257 pub(crate) fn new() -> Self {
259 Self {
260 return_value: 0,
261 rows_affected: 0,
262 output_params: Vec::new(),
263 result_sets: Vec::new(),
264 }
265 }
266
267 #[must_use]
272 pub fn get_return_value(&self) -> i32 {
273 self.return_value
274 }
275
276 #[must_use]
295 pub fn get_output(&self, name: &str) -> Option<&OutputParam> {
296 let search = name.strip_prefix('@').unwrap_or(name);
297 self.output_params.iter().find(|p| {
298 let stored = p.name.strip_prefix('@').unwrap_or(&p.name);
299 stored.eq_ignore_ascii_case(search)
300 })
301 }
302
303 #[must_use]
307 pub fn first_result_set(&self) -> Option<&ResultSet> {
308 self.result_sets.first()
309 }
310
311 #[must_use]
313 pub fn has_result_sets(&self) -> bool {
314 !self.result_sets.is_empty()
315 }
316}
317
/// One fully buffered result set: column metadata plus rows consumed front to back.
#[derive(Debug, Clone)]
#[must_use]
pub struct ResultSet {
    /// Column metadata for this result set.
    columns: Vec<Column>,
    /// Rows not yet taken, in order.
    rows: VecDeque<Row>,
}
327
328impl ResultSet {
329 pub fn new(columns: Vec<Column>, rows: Vec<Row>) -> Self {
331 Self {
332 columns,
333 rows: rows.into(),
334 }
335 }
336
337 #[must_use]
339 pub fn columns(&self) -> &[Column] {
340 &self.columns
341 }
342
343 #[must_use]
345 pub fn rows_remaining(&self) -> usize {
346 self.rows.len()
347 }
348
349 pub fn next_row(&mut self) -> Option<Row> {
351 self.rows.pop_front()
352 }
353
354 #[must_use]
356 pub fn is_empty(&self) -> bool {
357 self.rows.is_empty()
358 }
359
360 pub fn collect_all(&mut self) -> Vec<Row> {
362 self.rows.drain(..).collect()
363 }
364}
365
366#[must_use = "streams must be consumed; dropping a stream discards remaining results"]
390pub struct MultiResultStream<'a> {
391 result_sets: Vec<ResultSet>,
393 current_result: usize,
395 _marker: std::marker::PhantomData<&'a ()>,
397}
398
399impl<'a> MultiResultStream<'a> {
400 pub(crate) fn new(result_sets: Vec<ResultSet>) -> Self {
402 Self {
403 result_sets,
404 current_result: 0,
405 _marker: std::marker::PhantomData,
406 }
407 }
408
409 #[allow(dead_code)]
411 pub(crate) fn empty() -> Self {
412 Self {
413 result_sets: Vec::new(),
414 current_result: 0,
415 _marker: std::marker::PhantomData,
416 }
417 }
418
419 #[must_use]
421 pub fn current_result_index(&self) -> usize {
422 self.current_result
423 }
424
425 #[must_use]
427 pub fn result_count(&self) -> usize {
428 self.result_sets.len()
429 }
430
431 #[must_use]
433 pub fn has_more_results(&self) -> bool {
434 self.current_result + 1 < self.result_sets.len()
435 }
436
437 #[must_use]
441 pub fn columns(&self) -> Option<&[Column]> {
442 self.result_sets
443 .get(self.current_result)
444 .map(|rs| rs.columns())
445 }
446
447 pub async fn next_result(&mut self) -> Result<bool, Error> {
451 if self.current_result + 1 < self.result_sets.len() {
452 self.current_result += 1;
453 Ok(true)
454 } else {
455 Ok(false)
456 }
457 }
458
459 pub async fn next_row(&mut self) -> Result<Option<Row>, Error> {
464 if let Some(result_set) = self.result_sets.get_mut(self.current_result) {
465 Ok(result_set.next_row())
466 } else {
467 Ok(None)
468 }
469 }
470
471 #[must_use]
473 pub fn current_result_set(&mut self) -> Option<&mut ResultSet> {
474 self.result_sets.get_mut(self.current_result)
475 }
476
477 pub fn collect_current(&mut self) -> Vec<Row> {
483 self.result_sets
484 .get_mut(self.current_result)
485 .map(|rs| rs.collect_all())
486 .unwrap_or_default()
487 }
488
489 pub fn into_query_streams(self) -> Vec<QueryStream<'a>> {
491 self.result_sets
492 .into_iter()
493 .map(|rs| QueryStream::new(rs.columns, rs.rows.into()))
494 .collect()
495 }
496}
497
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds a non-nullable `INT` column at index 0 with the given name.
    fn int_column(name: &str) -> Column {
        Column {
            name: name.to_string(),
            index: 0,
            type_name: "INT".to_string(),
            nullable: false,
            max_length: Some(4),
            precision: None,
            scale: None,
            collation: None,
        }
    }

    #[test]
    fn test_execute_result() {
        let result = ExecuteResult::new(42);
        assert_eq!(result.rows_affected, 42);
        assert!(result.output_params.is_empty());
    }

    #[test]
    fn test_procedure_result_defaults() {
        let result = ProcedureResult::new();
        assert_eq!(result.return_value, 0);
        assert_eq!(result.rows_affected, 0);
        assert!(result.output_params.is_empty());
        assert!(result.result_sets.is_empty());
        assert!(!result.has_result_sets());
        assert!(result.first_result_set().is_none());
    }

    #[test]
    fn test_procedure_result_get_output() {
        let mut result = ProcedureResult::new();
        result.output_params.push(OutputParam {
            name: "@Total".to_string(),
            value: mssql_types::SqlValue::Int(42),
        });
        result.output_params.push(OutputParam {
            name: "@Message".to_string(),
            value: mssql_types::SqlValue::String("ok".to_string()),
        });

        // Exact and case-insensitive matches with the '@' prefix.
        assert!(result.get_output("@Total").is_some());
        assert!(result.get_output("@total").is_some());
        assert!(result.get_output("@TOTAL").is_some());

        // The '@' prefix is optional when looking up.
        assert!(result.get_output("Total").is_some());
        assert!(result.get_output("total").is_some());

        // Unknown names are not found, with or without '@'.
        assert!(result.get_output("@NotHere").is_none());
        assert!(result.get_output("NotHere").is_none());
    }

    #[test]
    fn test_procedure_result_with_result_sets() {
        use mssql_types::SqlValue;

        let columns = vec![Column {
            name: "id".to_string(),
            index: 0,
            type_name: "INT".to_string(),
            nullable: false,
            max_length: Some(4),
            precision: None,
            scale: None,
            collation: None,
        }];
        let rows = vec![Row::from_values(columns.clone(), vec![SqlValue::Int(1)])];
        let rs = ResultSet::new(columns, rows);

        let mut result = ProcedureResult::new();
        result.result_sets.push(rs);
        result.return_value = 7;
        result.rows_affected = 5;

        assert!(result.has_result_sets());
        assert_eq!(result.get_return_value(), 7);
        assert_eq!(result.first_result_set().unwrap().columns().len(), 1);
    }

    #[test]
    fn test_execute_result_with_outputs() {
        let outputs = vec![OutputParam {
            name: "ReturnValue".to_string(),
            value: mssql_types::SqlValue::Int(100),
        }];

        let result = ExecuteResult::with_outputs(10, outputs);
        assert_eq!(result.rows_affected, 10);
        assert!(result.get_output("ReturnValue").is_some());
        assert!(result.get_output("returnvalue").is_some());
        assert!(result.get_output("NotFound").is_none());
    }

    #[test]
    fn test_query_stream_columns() {
        let columns = vec![Column {
            name: "id".to_string(),
            index: 0,
            type_name: "INT".to_string(),
            nullable: false,
            max_length: Some(4),
            precision: Some(0),
            scale: Some(0),
            collation: None,
        }];

        let stream = QueryStream::new(columns, Vec::new());
        assert_eq!(stream.columns().len(), 1);
        assert_eq!(stream.columns()[0].name, "id");
        assert!(!stream.is_finished());
    }

    #[test]
    fn test_query_stream_with_rows() {
        use mssql_types::SqlValue;

        let columns = vec![
            Column {
                name: "id".to_string(),
                index: 0,
                type_name: "INT".to_string(),
                nullable: false,
                max_length: Some(4),
                precision: None,
                scale: None,
                collation: None,
            },
            Column {
                name: "name".to_string(),
                index: 1,
                type_name: "NVARCHAR".to_string(),
                nullable: true,
                max_length: Some(100),
                precision: None,
                scale: None,
                collation: None,
            },
        ];

        let rows = vec![
            Row::from_values(
                columns.clone(),
                vec![SqlValue::Int(1), SqlValue::String("Alice".to_string())],
            ),
            Row::from_values(
                columns.clone(),
                vec![SqlValue::Int(2), SqlValue::String("Bob".to_string())],
            ),
        ];

        let mut stream = QueryStream::new(columns, rows);
        assert_eq!(stream.columns().len(), 2);
        assert_eq!(stream.rows_remaining(), 2);
        assert!(!stream.is_finished());

        let row1 = stream.try_next().unwrap();
        assert_eq!(row1.get::<i32>(0).unwrap(), 1);
        assert_eq!(row1.get_by_name::<String>("name").unwrap(), "Alice");

        let row2 = stream.try_next().unwrap();
        assert_eq!(row2.get::<i32>(0).unwrap(), 2);
        assert_eq!(row2.get_by_name::<String>("name").unwrap(), "Bob");

        // Reading past the end returns None and latches the finished flag.
        assert!(stream.try_next().is_none());
        assert!(stream.is_finished());
    }

    #[test]
    fn test_query_stream_iterator() {
        use mssql_types::SqlValue;

        let columns = vec![Column {
            name: "val".to_string(),
            index: 0,
            type_name: "INT".to_string(),
            nullable: false,
            max_length: None,
            precision: None,
            scale: None,
            collation: None,
        }];

        let rows = vec![
            Row::from_values(columns.clone(), vec![SqlValue::Int(10)]),
            Row::from_values(columns.clone(), vec![SqlValue::Int(20)]),
            Row::from_values(columns.clone(), vec![SqlValue::Int(30)]),
        ];

        let mut stream = QueryStream::new(columns, rows);

        let values: Vec<i32> = stream
            .by_ref()
            .map(|r| r.unwrap().get::<i32>(0).unwrap())
            .collect();

        assert_eq!(values, vec![10, 20, 30]);
        assert!(stream.is_finished());
    }

    #[test]
    fn test_query_stream_exact_size() {
        use mssql_types::SqlValue;

        let columns = vec![int_column("val")];
        let rows = vec![
            Row::from_values(columns.clone(), vec![SqlValue::Int(1)]),
            Row::from_values(columns.clone(), vec![SqlValue::Int(2)]),
        ];

        let mut stream = QueryStream::new(columns, rows);
        assert_eq!(stream.len(), 2);

        assert!(stream.next().unwrap().is_ok());
        assert_eq!(stream.len(), 1);
        assert_eq!(stream.rows_remaining(), 1);
    }

    #[test]
    fn test_query_stream_empty() {
        let stream = QueryStream::empty();
        assert!(stream.columns().is_empty());
        assert_eq!(stream.rows_remaining(), 0);
        assert!(stream.is_finished());
    }

    #[test]
    fn test_result_set_navigation() {
        use mssql_types::SqlValue;

        let columns = vec![int_column("id")];
        let rows = vec![
            Row::from_values(columns.clone(), vec![SqlValue::Int(1)]),
            Row::from_values(columns.clone(), vec![SqlValue::Int(2)]),
            Row::from_values(columns.clone(), vec![SqlValue::Int(3)]),
        ];

        let mut rs = ResultSet::new(columns, rows);
        assert_eq!(rs.columns().len(), 1);
        assert_eq!(rs.rows_remaining(), 3);
        assert!(!rs.is_empty());

        let first = rs.next_row().unwrap();
        assert_eq!(first.get::<i32>(0).unwrap(), 1);

        let rest = rs.collect_all();
        assert_eq!(rest.len(), 2);
        assert_eq!(rest[0].get::<i32>(0).unwrap(), 2);
        assert!(rs.is_empty());
        assert!(rs.next_row().is_none());
    }

    #[test]
    fn test_multi_result_stream_navigation() {
        use mssql_types::SqlValue;

        let columns = vec![int_column("n")];
        let first = ResultSet::new(
            columns.clone(),
            vec![Row::from_values(columns.clone(), vec![SqlValue::Int(1)])],
        );
        let second = ResultSet::new(
            columns.clone(),
            vec![
                Row::from_values(columns.clone(), vec![SqlValue::Int(2)]),
                Row::from_values(columns.clone(), vec![SqlValue::Int(3)]),
            ],
        );

        let mut multi = MultiResultStream::new(vec![first, second]);
        assert_eq!(multi.result_count(), 2);
        assert_eq!(multi.current_result_index(), 0);
        assert!(multi.has_more_results());
        assert_eq!(multi.columns().unwrap()[0].name, "n");

        // Draining the active set empties it without moving the cursor.
        assert_eq!(multi.collect_current().len(), 1);
        assert!(multi.collect_current().is_empty());
        assert_eq!(multi.current_result_index(), 0);

        // Conversion keeps every set, including the drained one.
        let streams = multi.into_query_streams();
        assert_eq!(streams.len(), 2);
        assert_eq!(streams[0].rows_remaining(), 0);
        assert_eq!(streams[1].rows_remaining(), 2);
    }

    #[test]
    fn test_multi_result_stream_empty() {
        let multi = MultiResultStream::empty();
        assert_eq!(multi.result_count(), 0);
        assert_eq!(multi.current_result_index(), 0);
        assert!(!multi.has_more_results());
        assert!(multi.columns().is_none());
    }
}