1use std::collections::VecDeque;
18use std::pin::Pin;
19use std::task::{Context, Poll};
20
21use futures_core::Stream;
22
23use crate::error::Error;
24use crate::row::{Column, Row};
25
26pub struct QueryStream<'a> {
44 columns: Vec<Column>,
46 rows: VecDeque<Row>,
48 finished: bool,
50 _marker: std::marker::PhantomData<&'a ()>,
52}
53
54impl QueryStream<'_> {
55 pub(crate) fn new(columns: Vec<Column>, rows: Vec<Row>) -> Self {
57 Self {
58 columns,
59 rows: rows.into(),
60 finished: false,
61 _marker: std::marker::PhantomData,
62 }
63 }
64
65 #[allow(dead_code)]
67 pub(crate) fn empty() -> Self {
68 Self {
69 columns: Vec::new(),
70 rows: VecDeque::new(),
71 finished: true,
72 _marker: std::marker::PhantomData,
73 }
74 }
75
76 #[must_use]
78 pub fn columns(&self) -> &[Column] {
79 &self.columns
80 }
81
82 #[must_use]
84 pub fn is_finished(&self) -> bool {
85 self.finished
86 }
87
88 #[must_use]
90 pub fn rows_remaining(&self) -> usize {
91 self.rows.len()
92 }
93
94 pub async fn collect_all(mut self) -> Result<Vec<Row>, Error> {
99 let rows: Vec<Row> = self.rows.drain(..).collect();
101 self.finished = true;
102 Ok(rows)
103 }
104
105 pub fn try_next(&mut self) -> Option<Row> {
109 if self.finished {
110 return None;
111 }
112
113 match self.rows.pop_front() {
114 Some(row) => Some(row),
115 None => {
116 self.finished = true;
117 None
118 }
119 }
120 }
121}
122
123impl Stream for QueryStream<'_> {
124 type Item = Result<Row, Error>;
125
126 fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
127 let this = self.get_mut();
128
129 if this.finished {
130 return Poll::Ready(None);
131 }
132
133 match this.rows.pop_front() {
135 Some(row) => Poll::Ready(Some(Ok(row))),
136 None => {
137 this.finished = true;
138 Poll::Ready(None)
139 }
140 }
141 }
142}
143
144impl ExactSizeIterator for QueryStream<'_> {}
145
146impl Iterator for QueryStream<'_> {
147 type Item = Result<Row, Error>;
148
149 fn next(&mut self) -> Option<Self::Item> {
150 if self.finished {
151 return None;
152 }
153
154 match self.rows.pop_front() {
155 Some(row) => Some(Ok(row)),
156 None => {
157 self.finished = true;
158 None
159 }
160 }
161 }
162
163 fn size_hint(&self) -> (usize, Option<usize>) {
164 let remaining = self.rows.len();
165 (remaining, Some(remaining))
166 }
167}
168
/// Result of executing a statement: an affected-row count plus any output
/// parameters.
#[derive(Debug, Clone)]
pub struct ExecuteResult {
    /// Number of rows affected by the statement.
    pub rows_affected: u64,
    /// Output parameter values; empty when none were supplied.
    pub output_params: Vec<OutputParam>,
}
179
/// A single named output parameter and its value.
#[derive(Debug, Clone)]
pub struct OutputParam {
    /// Parameter name. [`ExecuteResult::get_output`] compares it ignoring ASCII case.
    pub name: String,
    /// Value of the parameter.
    pub value: mssql_types::SqlValue,
}
188
189impl ExecuteResult {
190 pub fn new(rows_affected: u64) -> Self {
192 Self {
193 rows_affected,
194 output_params: Vec::new(),
195 }
196 }
197
198 pub fn with_outputs(rows_affected: u64, output_params: Vec<OutputParam>) -> Self {
200 Self {
201 rows_affected,
202 output_params,
203 }
204 }
205
206 #[must_use]
208 pub fn get_output(&self, name: &str) -> Option<&OutputParam> {
209 self.output_params
210 .iter()
211 .find(|p| p.name.eq_ignore_ascii_case(name))
212 }
213}
214
/// A buffered result set: column metadata plus the rows not yet read.
#[derive(Debug)]
pub struct ResultSet {
    /// Column metadata for this result set.
    columns: Vec<Column>,
    /// Rows not yet read, in the order they were passed to [`ResultSet::new`].
    rows: VecDeque<Row>,
}
223
224impl ResultSet {
225 pub fn new(columns: Vec<Column>, rows: Vec<Row>) -> Self {
227 Self {
228 columns,
229 rows: rows.into(),
230 }
231 }
232
233 #[must_use]
235 pub fn columns(&self) -> &[Column] {
236 &self.columns
237 }
238
239 #[must_use]
241 pub fn rows_remaining(&self) -> usize {
242 self.rows.len()
243 }
244
245 pub fn next_row(&mut self) -> Option<Row> {
247 self.rows.pop_front()
248 }
249
250 #[must_use]
252 pub fn is_empty(&self) -> bool {
253 self.rows.is_empty()
254 }
255
256 pub fn collect_all(&mut self) -> Vec<Row> {
258 self.rows.drain(..).collect()
259 }
260}
261
262pub struct MultiResultStream<'a> {
286 result_sets: Vec<ResultSet>,
288 current_result: usize,
290 _marker: std::marker::PhantomData<&'a ()>,
292}
293
294impl<'a> MultiResultStream<'a> {
295 pub(crate) fn new(result_sets: Vec<ResultSet>) -> Self {
297 Self {
298 result_sets,
299 current_result: 0,
300 _marker: std::marker::PhantomData,
301 }
302 }
303
304 #[allow(dead_code)]
306 pub(crate) fn empty() -> Self {
307 Self {
308 result_sets: Vec::new(),
309 current_result: 0,
310 _marker: std::marker::PhantomData,
311 }
312 }
313
314 #[must_use]
316 pub fn current_result_index(&self) -> usize {
317 self.current_result
318 }
319
320 #[must_use]
322 pub fn result_count(&self) -> usize {
323 self.result_sets.len()
324 }
325
326 #[must_use]
328 pub fn has_more_results(&self) -> bool {
329 self.current_result + 1 < self.result_sets.len()
330 }
331
332 #[must_use]
336 pub fn columns(&self) -> Option<&[Column]> {
337 self.result_sets
338 .get(self.current_result)
339 .map(|rs| rs.columns())
340 }
341
342 pub async fn next_result(&mut self) -> Result<bool, Error> {
346 if self.current_result + 1 < self.result_sets.len() {
347 self.current_result += 1;
348 Ok(true)
349 } else {
350 Ok(false)
351 }
352 }
353
354 pub async fn next_row(&mut self) -> Result<Option<Row>, Error> {
359 if let Some(result_set) = self.result_sets.get_mut(self.current_result) {
360 Ok(result_set.next_row())
361 } else {
362 Ok(None)
363 }
364 }
365
366 #[must_use]
368 pub fn current_result_set(&mut self) -> Option<&mut ResultSet> {
369 self.result_sets.get_mut(self.current_result)
370 }
371
372 pub fn collect_current(&mut self) -> Vec<Row> {
374 self.result_sets
375 .get_mut(self.current_result)
376 .map(|rs| rs.collect_all())
377 .unwrap_or_default()
378 }
379
380 pub fn into_query_streams(self) -> Vec<QueryStream<'a>> {
382 self.result_sets
383 .into_iter()
384 .map(|rs| QueryStream::new(rs.columns, rs.rows.into()))
385 .collect()
386 }
387}
388
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds a non-nullable INT column at index 0 for test fixtures.
    fn int_column(name: &str) -> Column {
        Column {
            name: name.to_string(),
            index: 0,
            type_name: "INT".to_string(),
            nullable: false,
            max_length: None,
            precision: None,
            scale: None,
            collation: None,
        }
    }

    /// Polls `future` exactly once with a no-op waker.
    ///
    /// The async methods under test contain no `.await`, so one poll is
    /// expected to complete them; `None` means the future was still pending.
    fn poll_once<F: std::future::Future>(future: F) -> Option<F::Output> {
        struct NoopWaker;

        impl std::task::Wake for NoopWaker {
            fn wake(self: std::sync::Arc<Self>) {}
        }

        let waker = std::task::Waker::from(std::sync::Arc::new(NoopWaker));
        let mut cx = std::task::Context::from_waker(&waker);
        let mut future = Box::pin(future);
        match std::future::Future::poll(future.as_mut(), &mut cx) {
            std::task::Poll::Ready(output) => Some(output),
            std::task::Poll::Pending => None,
        }
    }

    #[test]
    fn test_execute_result() {
        let result = ExecuteResult::new(42);
        assert_eq!(result.rows_affected, 42);
        assert!(result.output_params.is_empty());
    }

    #[test]
    fn test_execute_result_with_outputs() {
        let outputs = vec![OutputParam {
            name: "ReturnValue".to_string(),
            value: mssql_types::SqlValue::Int(100),
        }];

        let result = ExecuteResult::with_outputs(10, outputs);
        assert_eq!(result.rows_affected, 10);
        assert!(result.get_output("ReturnValue").is_some());
        // Lookup ignores ASCII case.
        assert!(result.get_output("returnvalue").is_some());
        assert!(result.get_output("NotFound").is_none());
    }

    #[test]
    fn test_execute_result_without_outputs() {
        let result = ExecuteResult::new(0);
        assert!(result.get_output("ReturnValue").is_none());
    }

    #[test]
    fn test_query_stream_columns() {
        let columns = vec![Column {
            name: "id".to_string(),
            index: 0,
            type_name: "INT".to_string(),
            nullable: false,
            max_length: Some(4),
            precision: Some(0),
            scale: Some(0),
            collation: None,
        }];

        let stream = QueryStream::new(columns, Vec::new());
        assert_eq!(stream.columns().len(), 1);
        assert_eq!(stream.columns()[0].name, "id");
        assert!(!stream.is_finished());
    }

    #[test]
    fn test_query_stream_with_rows() {
        use mssql_types::SqlValue;

        let columns = vec![
            Column {
                name: "id".to_string(),
                index: 0,
                type_name: "INT".to_string(),
                nullable: false,
                max_length: Some(4),
                precision: None,
                scale: None,
                collation: None,
            },
            Column {
                name: "name".to_string(),
                index: 1,
                type_name: "NVARCHAR".to_string(),
                nullable: true,
                max_length: Some(100),
                precision: None,
                scale: None,
                collation: None,
            },
        ];

        let rows = vec![
            Row::from_values(
                columns.clone(),
                vec![SqlValue::Int(1), SqlValue::String("Alice".to_string())],
            ),
            Row::from_values(
                columns.clone(),
                vec![SqlValue::Int(2), SqlValue::String("Bob".to_string())],
            ),
        ];

        let mut stream = QueryStream::new(columns, rows);
        assert_eq!(stream.columns().len(), 2);
        assert_eq!(stream.rows_remaining(), 2);
        assert!(!stream.is_finished());

        let row1 = stream.try_next().unwrap();
        assert_eq!(row1.get::<i32>(0).unwrap(), 1);
        assert_eq!(row1.get_by_name::<String>("name").unwrap(), "Alice");

        let row2 = stream.try_next().unwrap();
        assert_eq!(row2.get::<i32>(0).unwrap(), 2);
        assert_eq!(row2.get_by_name::<String>("name").unwrap(), "Bob");

        assert!(stream.try_next().is_none());
        assert!(stream.is_finished());
    }

    #[test]
    fn test_query_stream_iterator() {
        use mssql_types::SqlValue;

        let columns = vec![Column {
            name: "val".to_string(),
            index: 0,
            type_name: "INT".to_string(),
            nullable: false,
            max_length: None,
            precision: None,
            scale: None,
            collation: None,
        }];

        let rows = vec![
            Row::from_values(columns.clone(), vec![SqlValue::Int(10)]),
            Row::from_values(columns.clone(), vec![SqlValue::Int(20)]),
            Row::from_values(columns.clone(), vec![SqlValue::Int(30)]),
        ];

        let mut stream = QueryStream::new(columns, rows);

        let values: Vec<i32> = stream
            .by_ref()
            .filter_map(|r| r.ok())
            .map(|r| r.get::<i32>(0).unwrap())
            .collect();

        assert_eq!(values, vec![10, 20, 30]);
        assert!(stream.is_finished());
    }

    #[test]
    fn test_query_stream_exact_size() {
        use mssql_types::SqlValue;

        let columns = vec![int_column("val")];
        let rows = vec![
            Row::from_values(columns.clone(), vec![SqlValue::Int(1)]),
            Row::from_values(columns.clone(), vec![SqlValue::Int(2)]),
        ];

        let mut stream = QueryStream::new(columns, rows);
        assert_eq!(stream.len(), 2);
        // Fully qualified: `Stream` also defines `size_hint`.
        assert_eq!(Iterator::size_hint(&stream), (2, Some(2)));

        assert!(stream.next().is_some());
        assert_eq!(stream.len(), 1);
        assert!(stream.next().is_some());
        assert_eq!(stream.len(), 0);

        // The last row has been returned, but the stream is only marked
        // finished once a read observes the empty buffer.
        assert!(!stream.is_finished());
        assert!(stream.next().is_none());
        assert!(stream.is_finished());
        assert!(stream.next().is_none());
    }

    #[test]
    fn test_query_stream_collect_all() {
        use mssql_types::SqlValue;

        let columns = vec![int_column("val")];
        let rows = vec![
            Row::from_values(columns.clone(), vec![SqlValue::Int(1)]),
            Row::from_values(columns.clone(), vec![SqlValue::Int(2)]),
            Row::from_values(columns.clone(), vec![SqlValue::Int(3)]),
        ];

        let mut stream = QueryStream::new(columns, rows);
        assert!(stream.try_next().is_some());

        let rest = poll_once(stream.collect_all()).unwrap().unwrap();
        let values: Vec<i32> = rest.iter().map(|r| r.get::<i32>(0).unwrap()).collect();
        assert_eq!(values, vec![2, 3]);
    }

    #[test]
    fn test_query_stream_empty() {
        let stream = QueryStream::empty();
        assert!(stream.columns().is_empty());
        assert_eq!(stream.rows_remaining(), 0);
        assert!(stream.is_finished());
    }

    #[test]
    fn test_result_set_reading() {
        use mssql_types::SqlValue;

        let columns = vec![int_column("val")];
        let rows = vec![
            Row::from_values(columns.clone(), vec![SqlValue::Int(1)]),
            Row::from_values(columns.clone(), vec![SqlValue::Int(2)]),
            Row::from_values(columns.clone(), vec![SqlValue::Int(3)]),
        ];

        let mut result_set = ResultSet::new(columns, rows);
        assert_eq!(result_set.columns().len(), 1);
        assert_eq!(result_set.rows_remaining(), 3);
        assert!(!result_set.is_empty());

        let first = result_set.next_row().unwrap();
        assert_eq!(first.get::<i32>(0).unwrap(), 1);

        let rest = result_set.collect_all();
        let values: Vec<i32> = rest.iter().map(|r| r.get::<i32>(0).unwrap()).collect();
        assert_eq!(values, vec![2, 3]);

        assert!(result_set.is_empty());
        assert!(result_set.next_row().is_none());
        assert!(result_set.collect_all().is_empty());
    }

    #[test]
    fn test_multi_result_stream_navigation() {
        use mssql_types::SqlValue;

        let first_columns = vec![int_column("a")];
        let first_rows = vec![Row::from_values(first_columns.clone(), vec![SqlValue::Int(1)])];
        let second_columns = vec![int_column("b")];

        let mut multi = MultiResultStream::new(vec![
            ResultSet::new(first_columns, first_rows),
            ResultSet::new(second_columns, Vec::new()),
        ]);

        assert_eq!(multi.result_count(), 2);
        assert_eq!(multi.current_result_index(), 0);
        assert!(multi.has_more_results());
        assert_eq!(multi.columns().unwrap()[0].name, "a");

        assert_eq!(multi.current_result_set().unwrap().rows_remaining(), 1);
        assert_eq!(multi.collect_current().len(), 1);
        assert!(multi.collect_current().is_empty());

        let streams = multi.into_query_streams();
        assert_eq!(streams.len(), 2);
        assert_eq!(streams[0].rows_remaining(), 0);
        assert_eq!(streams[1].columns()[0].name, "b");
    }

    #[test]
    fn test_multi_result_stream_async_navigation() {
        use mssql_types::SqlValue;

        let first_columns = vec![int_column("a")];
        let second_columns = vec![int_column("b")];
        let second_rows = vec![Row::from_values(second_columns.clone(), vec![SqlValue::Int(7)])];

        let mut multi = MultiResultStream::new(vec![
            ResultSet::new(first_columns, Vec::new()),
            ResultSet::new(second_columns, second_rows),
        ]);

        // The first result set has no rows.
        assert!(poll_once(multi.next_row()).unwrap().unwrap().is_none());

        assert!(poll_once(multi.next_result()).unwrap().unwrap());
        assert_eq!(multi.current_result_index(), 1);
        assert!(!multi.has_more_results());

        let row = poll_once(multi.next_row()).unwrap().unwrap().unwrap();
        assert_eq!(row.get::<i32>(0).unwrap(), 7);

        // There is no third result set, so the cursor stays on the last one.
        assert!(!poll_once(multi.next_result()).unwrap().unwrap());
        assert_eq!(multi.current_result_index(), 1);
    }

    #[test]
    fn test_multi_result_stream_empty() {
        let mut multi = MultiResultStream::empty();
        assert_eq!(multi.result_count(), 0);
        assert_eq!(multi.current_result_index(), 0);
        assert!(!multi.has_more_results());
        assert!(multi.columns().is_none());
        assert!(multi.current_result_set().is_none());
        assert!(multi.collect_current().is_empty());
        assert!(multi.into_query_streams().is_empty());
    }
}