snowflakedb-rs 1.1.6

A lightweight, comprehensive, and familiar database driver for SnowflakeDB, written natively in Rust.
Documentation
use std::{
    collections::VecDeque,
    io::{Cursor, Read},
    sync::{Arc, Weak},
};

use async_stream::try_stream;

use crate::{
    SnowflakeError,
    auth::session::Session,
    connection::Connection,
    driver::{
        Protocol,
        base::{
            BinaryQueryBuilder,
            bindings::{BindMetadata, Bindings},
            response::RawQueryResponse,
        },
        primitives::{column::Column, row::Row},
        query::{DescribeResult, Query, QueryResult},
    },
    error,
    http::client::SnowflakeHttpClient,
    this_errors,
};

use futures_util::{TryStreamExt, lock::Mutex};

/// Stateless marker type that selects the JSON variant of the driver protocol.
///
/// Its [`Protocol`] implementation maps the protocol's query type to
/// [`JsonQuery`]. `Default` is derived because the type has no fields to
/// initialise.
#[derive(Clone, Default)]
pub struct JsonProtocol {}

impl Protocol for JsonProtocol {
    /// Queries issued through this protocol are [`JsonQuery`] values.
    type Query<C>
        = JsonQuery<C>
    where
        C: SnowflakeHttpClient;
}

/// A SQL statement together with its bound parameters, run with an
/// `application/json` accept header.
///
/// The session is only held weakly: describing or executing the query fails
/// when the session can no longer be upgraded.
pub struct JsonQuery<C>
where
    C: SnowflakeHttpClient,
{
    /// Weak handle to the mutex-guarded session the query runs against.
    session: Weak<Mutex<Session<C>>>,
    /// Parameter rows collected through `bind_row` / `bind_row_named`.
    bindings: Bindings,
    /// SQL text handed to the query builder.
    query: String,
}

impl<C: SnowflakeHttpClient> Query<C> for JsonQuery<C> {
    type Result = JsonQueryResult<C>;
    type Describe = JsonDescribeResult;

    fn bind_row(&mut self, params: Vec<impl crate::driver::primitives::cell::ToCellValue>) {
        self.bindings.bind_row(params);
    }

    fn bind_row_named(
        &mut self,
        params: Vec<(
            impl ToString,
            impl crate::driver::primitives::cell::ToCellValue,
        )>,
    ) {
        self.bindings.bind_row_named(params);
    }

    fn new(query: impl ToString, session: Weak<Mutex<Session<C>>>) -> Self {
        Self {
            session,
            bindings: Bindings::new(),
            query: query.to_string(),
        }
    }

    async fn describe(self) -> Result<JsonDescribeResult, SnowflakeError> {
        let query = this_errors!(
            "failed to build underlying binary query",
            BinaryQueryBuilder::default()
                .accept_header("application/json")
                .sql_text(self.query)
                .is_describe_only(true)
                .bindings(self.bindings)
                .build()
        );

        let session = self
            .session
            .upgrade()
            .ok_or(error!("The surrounding connection for this query is dead."))?;

        let mut session = session.lock().await;

        let raw = query.run(&mut session).await?;

        let cols = raw
            .rowtype
            .clone()
            .into_iter()
            .map(|x| Arc::new(x))
            .collect::<Vec<Arc<Column>>>();

        Ok(JsonDescribeResult { columns: cols, raw })
    }

    async fn execute(self) -> Result<Self::Result, SnowflakeError> {
        let query = this_errors!(
            "failed to build underlying binary query",
            BinaryQueryBuilder::default()
                .accept_header("application/json")
                .sql_text(self.query)
                .is_describe_only(false)
                .bindings(self.bindings)
                .build()
        );

        let session = self
            .session
            .upgrade()
            .ok_or(error!("The surrounding connection for this query is dead."))?;

        let mut session = session.lock().await;

        let raw = query.run(&mut session).await?;

        let cols = raw
            .rowtype
            .clone()
            .into_iter()
            .map(|x| Arc::new(x))
            .collect::<Vec<Arc<Column>>>();

        Ok(JsonQueryResult {
            conn: session.get_conn(),
            raw,
            cols,
        })
    }
}

/// Result of executing a [`JsonQuery`]: the raw response plus what is needed
/// to stream its rows.
pub struct JsonQueryResult<C>
where
    C: SnowflakeHttpClient + Clone,
{
    /// Connection taken from the session at execution time; handed to
    /// `stream_chunks` when rows are streamed.
    conn: Connection<C>,
    /// Response returned by running the query.
    raw: RawQueryResponse,
    /// Result columns built from the response's `rowtype`, each in an `Arc`.
    cols: Vec<Arc<Column>>,
}

impl<C: SnowflakeHttpClient + Clone> QueryResult for JsonQueryResult<C> {
    /// Returns the `total` field reported by the response.
    fn expected_result_length(&self) -> i64 {
        self.raw.total
    }

    /// Returns a copy of the list of shared column handles.
    fn columns(&self) -> Vec<Arc<Column>> {
        self.cols.clone()
    }

    /// Consumes the result and streams its rows in order.
    ///
    /// Each chunk produced by `stream_chunks` is parsed as a JSON array of
    /// rows, where a row is an array of optional strings. Rows are numbered
    /// from 0, and the numbering continues across chunks.
    ///
    /// # Errors
    ///
    /// The stream yields a single error and ends when the statement is DML.
    /// It also yields an error and ends when fetching a chunk fails or a
    /// chunk is not valid JSON of the expected shape.
    fn rows(
        self,
    ) -> futures_util::stream::BoxStream<
        'static,
        Result<crate::driver::primitives::row::Row, SnowflakeError>,
    > {
        // `self` is consumed, so move the column list into the stream rather
        // than cloning it.
        let Self { conn, raw, cols } = self;
        let is_dml = raw.is_dml();
        let mut raw_stream = raw.stream_chunks(conn);

        let stream = try_stream! {
            if is_dml {
                // `?` on an `Err` yields the error and terminates the stream.
                Err(error!("there are no rows to retrieve"))?;
            }

            let mut retrieved_first_chunk = false;
            let mut cursor = 0i64;

            while let Some((_row_count, chunk)) = raw_stream.try_next().await? {
                let parsed_rows: VecDeque<Vec<Option<String>>> = if retrieved_first_chunk {
                    // Every chunk after the first is wrapped in `[` .. `]`
                    // before parsing.
                    // NOTE(review): presumably follow-up chunks arrive as a bare,
                    // comma-separated list of rows — confirm against `stream_chunks`.
                    let wrapped = Cursor::new(b"[")
                        .chain(Cursor::new(chunk))
                        .chain(Cursor::new(b"]"));

                    serde_json::from_reader(wrapped).map_err(|e| error!("failed to parse chunk data as json", e))?
                } else {
                    // The first chunk is parsed as-is.
                    retrieved_first_chunk = true;
                    serde_json::from_reader(chunk.as_slice()).map_err(|e| error!("failed to parse chunk data as json", e))?
                };

                for row in parsed_rows {
                    yield Row::new_from_strings(cols.clone(), row, cursor);
                    cursor += 1;
                }
            }
        };

        Box::pin(stream)
    }

    /// Reports whether the response classifies the statement as DML.
    fn is_dml(&self) -> bool {
        self.raw.is_dml()
    }

    /// Reports whether the response classifies the statement as DQL.
    fn is_dql(&self) -> bool {
        self.raw.is_dql()
    }

    /// Sum of the updated, DML-duplicate, deleted and inserted row counters
    /// from the response stats, or 0 when the response carries no stats.
    fn rows_affected(&self) -> i64 {
        self.raw.stats.as_ref().map_or(0, |stats| {
            stats.num_rows_updated
                + stats.num_dml_duplicates
                + stats.num_rows_deleted
                + stats.num_rows_inserted
        })
    }

    /// Number of updated rows from the response stats, or 0 without stats.
    fn rows_updated(&self) -> i64 {
        self.raw
            .stats
            .as_ref()
            .map_or(0, |stats| stats.num_rows_updated)
    }

    /// Number of deleted rows from the response stats, or 0 without stats.
    fn rows_deleted(&self) -> i64 {
        self.raw
            .stats
            .as_ref()
            .map_or(0, |stats| stats.num_rows_deleted)
    }

    /// Number of inserted rows from the response stats, or 0 without stats.
    fn rows_inserted(&self) -> i64 {
        self.raw
            .stats
            .as_ref()
            .map_or(0, |stats| stats.num_rows_inserted)
    }
}

/// Outcome of running a [`JsonQuery`] in describe-only mode.
#[derive(Debug)]
pub struct JsonDescribeResult {
    /// Result columns built from the response's `rowtype`, each in an `Arc`.
    columns: Vec<Arc<Column>>,
    /// Response returned by the describe request; source of the bind
    /// metadata, bind count and statement classification.
    raw: RawQueryResponse,
}

impl DescribeResult for JsonDescribeResult {
    /// Returns a new list holding another handle to each described column.
    fn columns(&self) -> Vec<Arc<Column>> {
        self.columns.iter().map(Arc::clone).collect()
    }

    /// Returns a copy of the response's bind metadata, if it carries any.
    fn bind_metadata(&self) -> Option<Vec<BindMetadata>> {
        let metadata = &self.raw.meta_data_of_binds;
        metadata.clone()
    }

    /// Returns the number of binds reported by the response.
    fn bind_count(&self) -> i32 {
        let response = &self.raw;
        response.number_of_binds
    }

    /// Reports whether the response classifies the statement as DML.
    fn is_dml(&self) -> bool {
        let response = &self.raw;
        response.is_dml()
    }

    /// Reports whether the response classifies the statement as DQL.
    fn is_dql(&self) -> bool {
        let response = &self.raw;
        response.is_dql()
    }
}