use std::collections::VecDeque;
use std::pin::Pin;
use std::task::{Context, Poll};
use futures_core::Stream;
use crate::error::Error;
use crate::row::{Column, Row};
/// A forward-only stream over rows that are already buffered in memory.
///
/// Rows can be pulled synchronously with [`QueryStream::try_next`], through the
/// `Iterator` impl, or through the `futures_core::Stream` impl; all of them pop
/// from the same internal buffer.
///
/// The lifetime `'a` is carried only by a `PhantomData` marker; no borrowed data
/// is stored in the struct.
#[derive(Debug)]
#[must_use = "streams must be consumed; dropping a stream discards remaining rows"]
pub struct QueryStream<'a> {
    /// Metadata for every column in the result.
    columns: Vec<Column>,
    /// Rows not yet yielded, in result order.
    rows: VecDeque<Row>,
    /// `true` once the stream is known to be exhausted; pulls then return
    /// `None` without touching `rows`.
    finished: bool,
    _marker: std::marker::PhantomData<&'a ()>,
}
impl QueryStream<'_> {
pub(crate) fn new(columns: Vec<Column>, rows: Vec<Row>) -> Self {
Self {
columns,
rows: rows.into(),
finished: false,
_marker: std::marker::PhantomData,
}
}
#[allow(dead_code)]
pub(crate) fn empty() -> Self {
Self {
columns: Vec::new(),
rows: VecDeque::new(),
finished: true,
_marker: std::marker::PhantomData,
}
}
#[must_use]
pub fn columns(&self) -> &[Column] {
&self.columns
}
#[must_use]
pub fn is_finished(&self) -> bool {
self.finished
}
#[must_use]
pub fn rows_remaining(&self) -> usize {
self.rows.len()
}
pub async fn collect_all(mut self) -> Result<Vec<Row>, Error> {
let rows: Vec<Row> = self.rows.drain(..).collect();
self.finished = true;
Ok(rows)
}
pub fn try_next(&mut self) -> Option<Row> {
if self.finished {
return None;
}
match self.rows.pop_front() {
Some(row) => Some(row),
None => {
self.finished = true;
None
}
}
}
}
impl Stream for QueryStream<'_> {
    type Item = Result<Row, Error>;

    /// Yields the next buffered row wrapped in `Ok`, or `None` once exhausted.
    ///
    /// Every row is already in memory, so this always returns `Poll::Ready` and
    /// never registers the waker from `_cx`.
    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        // `get_mut` is usable because `QueryStream` is `Unpin`. After that the
        // synchronous cursor does all the work, including setting `finished`.
        let stream: &mut Self = self.get_mut();
        Poll::Ready(stream.try_next().map(Ok))
    }
}
impl ExactSizeIterator for QueryStream<'_> {}
impl Iterator for QueryStream<'_> {
    type Item = Result<Row, Error>;

    /// Pops the next buffered row and wraps it in `Ok`.
    ///
    /// Uses the same cursor as [`QueryStream::try_next`], so mixing the two
    /// never yields a row twice.
    fn next(&mut self) -> Option<Self::Item> {
        self.try_next().map(Ok)
    }

    /// Reports the exact number of buffered rows still to be yielded.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let left = self.rows_remaining();
        (left, Some(left))
    }
}
/// Outcome of an execution: an affected-row count plus any output parameters.
#[derive(Debug, Clone)]
#[non_exhaustive]
#[must_use]
pub struct ExecuteResult {
/// Number of rows affected, as supplied when the result was constructed.
pub rows_affected: u64,
/// Output parameters in the order they were supplied; see `get_output` for
/// case-insensitive lookup by name.
pub output_params: Vec<OutputParam>,
}
/// A single named output parameter attached to an [`ExecuteResult`].
// NOTE(review): `#[non_exhaustive]` plus the lack of a public constructor in this
// file means code outside the crate cannot build an `OutputParam` (and so cannot
// pass outputs to `ExecuteResult::with_outputs`). Confirm this is intended.
#[derive(Debug, Clone)]
#[non_exhaustive]
pub struct OutputParam {
/// Parameter name; `ExecuteResult::get_output` compares it ignoring ASCII case.
pub name: String,
/// The parameter's value.
pub value: mssql_types::SqlValue,
}
impl ExecuteResult {
    /// Builds a result that carries only an affected-row count and no outputs.
    pub fn new(rows_affected: u64) -> Self {
        Self::with_outputs(rows_affected, Vec::new())
    }

    /// Builds a result from an affected-row count and its output parameters.
    pub fn with_outputs(rows_affected: u64, output_params: Vec<OutputParam>) -> Self {
        Self {
            rows_affected,
            output_params,
        }
    }

    /// Looks up an output parameter by name, ignoring ASCII case.
    ///
    /// If several parameters match, the first one in `output_params` order is
    /// returned.
    #[must_use]
    pub fn get_output(&self, name: &str) -> Option<&OutputParam> {
        let mut candidates = self.output_params.iter();
        candidates.find(|param| name.eq_ignore_ascii_case(&param.name))
    }
}
/// One buffered result set: column metadata plus rows handed out front to back.
#[derive(Debug)]
#[must_use]
pub struct ResultSet {
/// Column metadata describing every row in this set.
columns: Vec<Column>,
/// Rows not yet handed out; `next_row` pops from the front.
rows: VecDeque<Row>,
}
impl ResultSet {
pub fn new(columns: Vec<Column>, rows: Vec<Row>) -> Self {
Self {
columns,
rows: rows.into(),
}
}
#[must_use]
pub fn columns(&self) -> &[Column] {
&self.columns
}
#[must_use]
pub fn rows_remaining(&self) -> usize {
self.rows.len()
}
pub fn next_row(&mut self) -> Option<Row> {
self.rows.pop_front()
}
#[must_use]
pub fn is_empty(&self) -> bool {
self.rows.is_empty()
}
pub fn collect_all(&mut self) -> Vec<Row> {
self.rows.drain(..).collect()
}
}
/// A sequence of buffered [`ResultSet`]s read one at a time through a cursor.
///
/// The lifetime `'a` is carried only by a `PhantomData` marker; no borrowed data
/// is stored in the struct.
#[derive(Debug)]
#[must_use = "streams must be consumed; dropping a stream discards remaining results"]
pub struct MultiResultStream<'a> {
    /// Every result set, in order.
    result_sets: Vec<ResultSet>,
    /// Index into `result_sets` of the set currently being read.
    current_result: usize,
    _marker: std::marker::PhantomData<&'a ()>,
}
impl<'a> MultiResultStream<'a> {
pub(crate) fn new(result_sets: Vec<ResultSet>) -> Self {
Self {
result_sets,
current_result: 0,
_marker: std::marker::PhantomData,
}
}
#[allow(dead_code)]
pub(crate) fn empty() -> Self {
Self {
result_sets: Vec::new(),
current_result: 0,
_marker: std::marker::PhantomData,
}
}
#[must_use]
pub fn current_result_index(&self) -> usize {
self.current_result
}
#[must_use]
pub fn result_count(&self) -> usize {
self.result_sets.len()
}
#[must_use]
pub fn has_more_results(&self) -> bool {
self.current_result + 1 < self.result_sets.len()
}
#[must_use]
pub fn columns(&self) -> Option<&[Column]> {
self.result_sets
.get(self.current_result)
.map(|rs| rs.columns())
}
pub async fn next_result(&mut self) -> Result<bool, Error> {
if self.current_result + 1 < self.result_sets.len() {
self.current_result += 1;
Ok(true)
} else {
Ok(false)
}
}
pub async fn next_row(&mut self) -> Result<Option<Row>, Error> {
if let Some(result_set) = self.result_sets.get_mut(self.current_result) {
Ok(result_set.next_row())
} else {
Ok(None)
}
}
#[must_use]
pub fn current_result_set(&mut self) -> Option<&mut ResultSet> {
self.result_sets.get_mut(self.current_result)
}
pub fn collect_current(&mut self) -> Vec<Row> {
self.result_sets
.get_mut(self.current_result)
.map(|rs| rs.collect_all())
.unwrap_or_default()
}
pub fn into_query_streams(self) -> Vec<QueryStream<'a>> {
self.result_sets
.into_iter()
.map(|rs| QueryStream::new(rs.columns, rs.rows.into()))
.collect()
}
}
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// Builds a non-nullable `INT` column at ordinal 0 for single-column fixtures.
    fn int_column(name: &str) -> Column {
        Column {
            name: name.to_string(),
            index: 0,
            type_name: "INT".to_string(),
            nullable: false,
            max_length: None,
            precision: None,
            scale: None,
            collation: None,
        }
    }

    #[test]
    fn test_execute_result() {
        let result = ExecuteResult::new(42);
        assert_eq!(result.rows_affected, 42);
        assert!(result.output_params.is_empty());
    }

    #[test]
    fn test_execute_result_with_outputs() {
        let outputs = vec![OutputParam {
            name: "ReturnValue".to_string(),
            value: mssql_types::SqlValue::Int(100),
        }];
        let result = ExecuteResult::with_outputs(10, outputs);
        assert_eq!(result.rows_affected, 10);
        assert!(result.get_output("ReturnValue").is_some());
        // Lookup ignores ASCII case.
        assert!(result.get_output("returnvalue").is_some());
        assert!(result.get_output("NotFound").is_none());
    }

    #[test]
    fn test_query_stream_columns() {
        let columns = vec![Column {
            name: "id".to_string(),
            index: 0,
            type_name: "INT".to_string(),
            nullable: false,
            max_length: Some(4),
            precision: Some(0),
            scale: Some(0),
            collation: None,
        }];
        let stream = QueryStream::new(columns, Vec::new());
        assert_eq!(stream.columns().len(), 1);
        assert_eq!(stream.columns()[0].name, "id");
        assert!(!stream.is_finished());
    }

    #[test]
    fn test_query_stream_with_rows() {
        use mssql_types::SqlValue;
        let columns = vec![
            Column {
                name: "id".to_string(),
                index: 0,
                type_name: "INT".to_string(),
                nullable: false,
                max_length: Some(4),
                precision: None,
                scale: None,
                collation: None,
            },
            Column {
                name: "name".to_string(),
                index: 1,
                type_name: "NVARCHAR".to_string(),
                nullable: true,
                max_length: Some(100),
                precision: None,
                scale: None,
                collation: None,
            },
        ];
        let rows = vec![
            Row::from_values(
                columns.clone(),
                vec![SqlValue::Int(1), SqlValue::String("Alice".to_string())],
            ),
            Row::from_values(
                columns.clone(),
                vec![SqlValue::Int(2), SqlValue::String("Bob".to_string())],
            ),
        ];
        let mut stream = QueryStream::new(columns, rows);
        assert_eq!(stream.columns().len(), 2);
        assert_eq!(stream.rows_remaining(), 2);
        assert!(!stream.is_finished());
        let row1 = stream.try_next().unwrap();
        assert_eq!(row1.get::<i32>(0).unwrap(), 1);
        assert_eq!(row1.get_by_name::<String>("name").unwrap(), "Alice");
        let row2 = stream.try_next().unwrap();
        assert_eq!(row2.get::<i32>(0).unwrap(), 2);
        assert_eq!(row2.get_by_name::<String>("name").unwrap(), "Bob");
        assert!(stream.try_next().is_none());
        assert!(stream.is_finished());
    }

    #[test]
    fn test_query_stream_iterator() {
        use mssql_types::SqlValue;
        let columns = vec![Column {
            name: "val".to_string(),
            index: 0,
            type_name: "INT".to_string(),
            nullable: false,
            max_length: None,
            precision: None,
            scale: None,
            collation: None,
        }];
        let rows = vec![
            Row::from_values(columns.clone(), vec![SqlValue::Int(10)]),
            Row::from_values(columns.clone(), vec![SqlValue::Int(20)]),
            Row::from_values(columns.clone(), vec![SqlValue::Int(30)]),
        ];
        let mut stream = QueryStream::new(columns, rows);
        let values: Vec<i32> = stream
            .by_ref()
            .filter_map(|r| r.ok())
            .map(|r| r.get::<i32>(0).unwrap())
            .collect();
        assert_eq!(values, vec![10, 20, 30]);
        assert!(stream.is_finished());
    }

    #[test]
    fn test_query_stream_empty() {
        let stream = QueryStream::empty();
        assert!(stream.columns().is_empty());
        assert_eq!(stream.rows_remaining(), 0);
        assert!(stream.is_finished());
    }

    #[test]
    fn test_query_stream_exact_size() {
        use mssql_types::SqlValue;
        let columns = vec![int_column("n")];
        let rows = vec![
            Row::from_values(columns.clone(), vec![SqlValue::Int(1)]),
            Row::from_values(columns.clone(), vec![SqlValue::Int(2)]),
        ];
        let mut stream = QueryStream::new(columns, rows);
        assert_eq!(stream.len(), 2);
        assert!(stream.next().unwrap().is_ok());
        assert_eq!(stream.len(), 1);
        assert!(stream.next().is_some());
        assert_eq!(stream.len(), 0);
        // Once exhausted, the stream keeps returning `None`.
        assert!(stream.next().is_none());
        assert!(stream.next().is_none());
        assert!(stream.is_finished());
    }

    #[test]
    fn test_result_set_consumption() {
        use mssql_types::SqlValue;
        let columns = vec![int_column("v")];
        let rows = vec![
            Row::from_values(columns.clone(), vec![SqlValue::Int(10)]),
            Row::from_values(columns.clone(), vec![SqlValue::Int(20)]),
            Row::from_values(columns.clone(), vec![SqlValue::Int(30)]),
        ];
        let mut rs = ResultSet::new(columns, rows);
        assert_eq!(rs.columns().len(), 1);
        assert_eq!(rs.rows_remaining(), 3);
        assert!(!rs.is_empty());
        assert_eq!(rs.next_row().unwrap().get::<i32>(0).unwrap(), 10);
        let rest = rs.collect_all();
        assert_eq!(rest.len(), 2);
        assert_eq!(rest[1].get::<i32>(0).unwrap(), 30);
        assert!(rs.is_empty());
        assert!(rs.next_row().is_none());
    }

    #[test]
    fn test_multi_result_stream_navigation() {
        use mssql_types::SqlValue;
        let columns = vec![int_column("v")];
        let first = ResultSet::new(
            columns.clone(),
            vec![Row::from_values(columns.clone(), vec![SqlValue::Int(1)])],
        );
        let second = ResultSet::new(
            columns.clone(),
            vec![
                Row::from_values(columns.clone(), vec![SqlValue::Int(2)]),
                Row::from_values(columns.clone(), vec![SqlValue::Int(3)]),
            ],
        );
        let mut multi = MultiResultStream::new(vec![first, second]);
        assert_eq!(multi.result_count(), 2);
        assert_eq!(multi.current_result_index(), 0);
        assert!(multi.has_more_results());
        assert_eq!(multi.columns().unwrap().len(), 1);

        let drained = multi.collect_current();
        assert_eq!(drained.len(), 1);
        assert_eq!(drained[0].get::<i32>(0).unwrap(), 1);
        assert!(multi.collect_current().is_empty());

        // Every set becomes a stream, including the one already drained.
        let streams = multi.into_query_streams();
        assert_eq!(streams.len(), 2);
        assert_eq!(streams[0].rows_remaining(), 0);
        assert_eq!(streams[1].rows_remaining(), 2);
    }

    #[test]
    fn test_multi_result_stream_empty() {
        let mut multi = MultiResultStream::empty();
        assert_eq!(multi.result_count(), 0);
        assert_eq!(multi.current_result_index(), 0);
        assert!(!multi.has_more_results());
        assert!(multi.columns().is_none());
        assert!(multi.current_result_set().is_none());
        assert!(multi.collect_current().is_empty());
        assert!(multi.into_query_streams().is_empty());
    }
}