snowflake_deserializer/
lib.rs

1pub use chrono;
2use data_manipulation::DataManipulationResult;
3use jwt::{KeyPairError, TokenFromFileError};
4use reqwest::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE, HeaderMap, USER_AGENT};
5use serde::{Deserialize, Serialize};
6pub use serde_json;
7use std::{collections::HashMap, path::Path, str::FromStr};
8
9use crate::bindings::{BindingType, BindingValue};
10
11pub mod bindings;
12pub mod data_manipulation;
13#[cfg(feature = "lazy")]
14pub mod lazy;
15#[cfg(feature = "multiple")]
16pub mod multiple;
17
18mod jwt;
19
20#[derive(Debug)]
21pub struct SnowflakeConnector {
22    host: String,
23    client: reqwest::Client,
24}
25
26impl SnowflakeConnector {
27    pub fn try_new(
28        public_key: &str,
29        private_key: &str,
30        host: &str,
31        account_identifier: &str,
32        user: &str,
33    ) -> Result<Self, NewSnowflakeConnectorError> {
34        let token = jwt::create_token(
35            public_key,
36            private_key,
37            &account_identifier.to_ascii_uppercase(),
38            &user.to_ascii_uppercase(),
39        )?;
40        let headers = Self::get_headers(&token);
41        let client = reqwest::Client::builder()
42            .default_headers(headers)
43            .build()?;
44        Ok(SnowflakeConnector {
45            host: format!("https://{host}.snowflakecomputing.com/api/v2/"),
46            client,
47        })
48    }
49    pub fn try_new_from_file<P: AsRef<Path>>(
50        public_key_path: P,
51        private_key_path: P,
52        host: &str,
53        account_identifier: &str,
54        user: &str,
55    ) -> Result<Self, NewSnowflakeConnectorFromFileError> {
56        let token = jwt::create_token_from_file(
57            public_key_path,
58            private_key_path,
59            &account_identifier.to_ascii_uppercase(),
60            &user.to_ascii_uppercase(),
61        )?;
62        let headers = Self::get_headers(&token);
63        let client = reqwest::Client::builder()
64            .default_headers(headers)
65            .build()?;
66        Ok(SnowflakeConnector {
67            host: format!("https://{host}.snowflakecomputing.com/api/v2/"),
68            client,
69        })
70    }
71
72    pub fn execute<D: ToString>(&self, database: D) -> SnowflakeExecutor<D> {
73        SnowflakeExecutor {
74            host: &self.host,
75            database,
76            client: &self.client,
77        }
78    }
79    fn get_headers(token: &str) -> HeaderMap {
80        let mut headers = HeaderMap::with_capacity(5);
81        headers.append(CONTENT_TYPE, "application/json".parse().unwrap());
82        headers.append(AUTHORIZATION, format!("Bearer {}", token).parse().unwrap());
83        headers.append(
84            "X-Snowflake-Authorization-Token-Type",
85            "KEYPAIR_JWT".parse().unwrap(),
86        );
87        headers.append(ACCEPT, "application/json".parse().unwrap());
88        headers.append(
89            USER_AGENT,
90            concat!(env!("CARGO_PKG_NAME"), '/', env!("CARGO_PKG_VERSION"))
91                .parse()
92                .unwrap(),
93        );
94        headers
95    }
96}
97
98/// Error creating a new [SnowflakeConnector]
99#[derive(thiserror::Error, Debug)]
100pub enum NewSnowflakeConnectorError {
101    #[error(transparent)]
102    KeyPair(#[from] KeyPairError),
103    #[error(transparent)]
104    ClientBuildError(#[from] reqwest::Error),
105}
106
107/// Error creating a new [SnowflakeConnector] from key paths
108#[derive(thiserror::Error, Debug)]
109pub enum NewSnowflakeConnectorFromFileError {
110    #[error(transparent)]
111    Token(#[from] TokenFromFileError),
112    #[error(transparent)]
113    ClientBuildError(#[from] reqwest::Error),
114}
115
116#[derive(Debug)]
117pub struct SnowflakeExecutor<'a, D: ToString> {
118    host: &'a str,
119    database: D,
120    client: &'a reqwest::Client,
121}
122
123impl<'a, D: ToString> SnowflakeExecutor<'a, D> {
124    pub fn sql(self, statement: &'a str) -> SnowflakeSQL<'a> {
125        SnowflakeSQL::new(
126            self.client,
127            self.host,
128            SnowflakeExecutorSQLJSON::new(statement, self.database.to_string()),
129            uuid::Uuid::new_v4(),
130        )
131    }
132}
133
134#[derive(Debug)]
135pub struct SnowflakeSQL<'a> {
136    client: &'a reqwest::Client,
137    host: &'a str,
138    statement: SnowflakeExecutorSQLJSON<'a>,
139    uuid: uuid::Uuid,
140}
141
142impl<'a> SnowflakeSQL<'a> {
143    pub(crate) fn new(
144        client: &'a reqwest::Client,
145        host: &'a str,
146        statement: SnowflakeExecutorSQLJSON<'a>,
147        uuid: uuid::Uuid,
148    ) -> Self {
149        SnowflakeSQL {
150            client,
151            host,
152            statement,
153            uuid,
154        }
155    }
156    pub async fn text(self) -> Result<String, SnowflakeSQLTextError> {
157        self.client
158            .post(self.get_url())
159            .json(&self.statement)
160            .send()
161            .await
162            .map_err(SnowflakeSQLTextError::Request)?
163            .text()
164            .await
165            .map_err(SnowflakeSQLTextError::ToText)
166    }
167    /// Use with `SELECT` queries.
168    pub async fn select<T: SnowflakeDeserialize>(
169        self,
170    ) -> Result<StatementResult<'a, T>, SnowflakeSQLSelectError<T::Error>> {
171        let r = self
172            .client
173            .post(self.get_url())
174            .json(&self.statement)
175            .send()
176            .await
177            .map_err(SnowflakeSQLSelectError::Request)?;
178        let status_code = r.status();
179        match status_code {
180            reqwest::StatusCode::OK => Ok(StatementResult::Result(
181                r.json::<SnowflakeSQLResponse>()
182                    .await
183                    .map_err(SnowflakeSQLSelectError::Decode)?
184                    .deserialize()
185                    .map_err(SnowflakeSQLSelectError::Deserialize)?,
186            )),
187            reqwest::StatusCode::REQUEST_TIMEOUT | reqwest::StatusCode::ACCEPTED => {
188                Ok(StatementResult::Status(SnowflakeQueryStatus {
189                    client: self.client,
190                    host: self.host,
191                    query_status: r
192                        .json::<QueryStatus>()
193                        .await
194                        .map_err(SnowflakeSQLSelectError::Decode)?,
195                }))
196            }
197            reqwest::StatusCode::UNPROCESSABLE_ENTITY => Err(SnowflakeSQLSelectError::Query(
198                r.json().await.map_err(SnowflakeSQLSelectError::Decode)?,
199            )),
200            status_code => Err(SnowflakeSQLSelectError::Unknown(status_code)),
201        }
202    }
203    /// Use with `DELETE`, `INSERT`, `UPDATE` queries.
204    pub async fn manipulate(self) -> Result<DataManipulationResult, SnowflakeSQLManipulateError> {
205        self.client
206            .post(self.get_url())
207            .json(&self.statement)
208            .send()
209            .await
210            .map_err(SnowflakeSQLManipulateError::Request)?
211            .json()
212            .await
213            .map_err(SnowflakeSQLManipulateError::Decode)
214    }
215    pub fn with_timeout(mut self, timeout: u32) -> Self {
216        self.statement.timeout = Some(timeout);
217        self
218    }
219    pub fn with_role<R: ToString>(mut self, role: R) -> Self {
220        self.statement.role = Some(role.to_string());
221        self
222    }
223    pub fn with_warehouse<W: ToString>(mut self, warehouse: W) -> Self {
224        self.statement.warehouse = Some(warehouse.to_string());
225        self
226    }
227    pub fn add_binding<T: Into<BindingValue>>(mut self, value: T) -> Self {
228        let value: BindingValue = value.into();
229        let value_str = value.to_string();
230        let value_type: BindingType = value.into();
231        let binding = Binding {
232            value_type: value_type.to_string(),
233            value: value_str,
234        };
235        if let Some(bindings) = &mut self.statement.bindings {
236            bindings.insert((bindings.len() + 1).to_string(), binding);
237        } else {
238            self.statement.bindings = Some(HashMap::from([("1".into(), binding)]));
239        }
240        self
241    }
242    fn get_url(&self) -> String {
243        get_url(self.host, &self.uuid)
244    }
245}
246
247pub(crate) fn get_url(host: &str, uuid: &uuid::Uuid) -> String {
248    // TODO: make another return type that allows retrying by calling same statement again with retry flag!
249    format!("{host}statements?nullable=false&requestId={uuid}")
250}
251
252/// Error retrieving results of SQL statement as text
253#[derive(thiserror::Error, Debug)]
254#[error(transparent)]
255pub enum SnowflakeSQLTextError {
256    Request(reqwest::Error),
257    ToText(reqwest::Error),
258}
259
260/// Error retrieving results of SQL selection
261#[derive(thiserror::Error, Debug)]
262pub enum SnowflakeSQLSelectError<DeserializeError> {
263    #[error(transparent)]
264    Request(reqwest::Error),
265    #[error(transparent)]
266    Decode(reqwest::Error),
267    #[error(transparent)]
268    Deserialize(DeserializeError),
269    #[error(transparent)]
270    Query(QueryFailureStatus),
271    #[error("unknown error with status code: {0}")]
272    Unknown(reqwest::StatusCode),
273}
274
275/// Error retrieving results of SQL manipulation
276#[derive(thiserror::Error, Debug)]
277pub enum SnowflakeSQLManipulateError {
278    #[error(transparent)]
279    Request(reqwest::Error),
280    #[error(transparent)]
281    Decode(reqwest::Error),
282}
283
284#[derive(Serialize, Debug)]
285pub struct SnowflakeExecutorSQLJSON<'a> {
286    statement: &'a str,
287    timeout: Option<u32>,
288    database: String,
289    warehouse: Option<String>,
290    role: Option<String>,
291    bindings: Option<HashMap<String, Binding>>,
292}
293impl<'a> SnowflakeExecutorSQLJSON<'a> {
294    pub(crate) fn new(statement: &'a str, database: String) -> Self {
295        SnowflakeExecutorSQLJSON {
296            statement,
297            timeout: None,
298            database,
299            warehouse: None,
300            role: None,
301            bindings: None,
302        }
303    }
304}
305
306#[derive(Serialize, Debug)]
307pub struct Binding {
308    #[serde(rename = "type")]
309    value_type: String,
310    value: String,
311}
312
313pub trait SnowflakeDeserialize {
314    type Error;
315    fn snowflake_deserialize(
316        response: SnowflakeSQLResponse,
317    ) -> Result<SnowflakeSQLResult<Self>, Self::Error>
318    where
319        Self: Sized;
320}
321
322#[derive(Deserialize, Debug)]
323#[serde(rename_all = "camelCase")]
324pub struct SnowflakeSQLResponse {
325    pub result_set_meta_data: MetaData,
326    pub data: Vec<Vec<String>>,
327    pub code: String,
328    pub statement_status_url: String,
329    pub request_id: String,
330    pub sql_state: String,
331    pub message: String,
332    //pub created_on: u64,
333}
334
335impl SnowflakeSQLResponse {
336    pub fn deserialize<T: SnowflakeDeserialize>(self) -> Result<SnowflakeSQLResult<T>, T::Error> {
337        T::snowflake_deserialize(self)
338    }
339}
340
341/// [ResultSetMetaData](https://docs.snowflake.com/en/developer-guide/sql-api/reference#label-sql-api-reference-resultset-resultsetmetadata)
342#[derive(Deserialize, Debug)]
343#[serde(rename_all = "camelCase")]
344pub struct MetaData {
345    pub num_rows: usize,
346    pub format: String,
347    pub row_type: Vec<RowType>,
348}
349
350/// [RowType](https://docs.snowflake.com/en/developer-guide/sql-api/reference#label-sql-api-reference-resultset-resultsetmetadata-rowtype)
351#[derive(Deserialize, Debug)]
352#[serde(rename_all = "camelCase")]
353pub struct RowType {
354    pub name: String,
355    pub database: String,
356    pub schema: String,
357    pub table: String,
358    pub precision: Option<u32>,
359    pub byte_length: Option<usize>,
360    #[serde(rename = "type")]
361    pub data_type: String,
362    pub scale: Option<i32>,
363    pub nullable: bool,
364    //pub collation: ???,
365    //pub length: ???,
366}
367
368/// Whether the query is running or finished
369#[derive(Debug)]
370pub enum StatementResult<'a, T> {
371    /// Query still in progress...
372    Status(SnowflakeQueryStatus<'a>),
373    /// Query finished!
374    Result(SnowflakeSQLResult<T>),
375}
376#[derive(Debug)]
377pub struct SnowflakeSQLResult<T> {
378    pub data: Vec<T>,
379}
380
381#[derive(Debug)]
382pub struct SnowflakeQueryStatus<'a> {
383    client: &'a reqwest::Client,
384    host: &'a str,
385    query_status: QueryStatus,
386}
387
388impl<'a> SnowflakeQueryStatus<'a> {
389    pub fn take_query_status(self) -> QueryStatus {
390        self.query_status
391    }
392    pub async fn cancel(&self) -> Result<(), QueryCancelError> {
393        let url = format!(
394            "{}statements/{}/cancel",
395            self.host, self.query_status.statement_handle
396        );
397        let response = self.client.post(url).send().await;
398        match response {
399            Ok(r) => match r.status() {
400                reqwest::StatusCode::OK => Ok(()),
401                status => Err(QueryCancelError::Unknown(status)),
402            },
403            Err(e) => Err(QueryCancelError::Request(e)),
404        }
405    }
406}
407
408/// Error canceling a query
409#[derive(thiserror::Error, Debug)]
410pub enum QueryCancelError {
411    #[error(transparent)]
412    Request(reqwest::Error),
413    #[error("unknown error with status code: {0}")]
414    Unknown(reqwest::StatusCode),
415}
416
417/// A unique tag that identifies a SQL statement request
418#[derive(serde::Deserialize, Clone, Debug)]
419#[serde(transparent)]
420pub struct StatementHandle(String);
421impl StatementHandle {
422    pub fn handle(&self) -> &str {
423        &self.0
424    }
425}
426impl std::fmt::Display for StatementHandle {
427    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
428        self.0.fmt(f)
429    }
430}
431
432/// [QueryStatus](https://docs.snowflake.com/en/developer-guide/sql-api/reference#label-sql-api-reference-querystatus)
433#[derive(serde::Deserialize, thiserror::Error, Debug)]
434#[serde(rename_all = "camelCase")]
435#[error("Error for statement {statement_handle}: {message}")]
436pub struct QueryStatus {
437    code: String,
438    sql_state: String,
439    message: String,
440    statement_handle: StatementHandle,
441    created_on: i64,
442    statement_status_url: String,
443}
444
445impl QueryStatus {
446    pub fn code(&self) -> &str {
447        &self.code
448    }
449    pub fn sql_state(&self) -> &str {
450        &self.sql_state
451    }
452    pub fn message(&self) -> &str {
453        &self.message
454    }
455    pub fn statement_handle(&self) -> &StatementHandle {
456        &self.statement_handle
457    }
458    pub fn created_on(&self) -> i64 {
459        self.created_on
460    }
461    pub fn statement_status_url(&self) -> &str {
462        &self.statement_status_url
463    }
464}
465
466/// [QueryFailureStatus](https://docs.snowflake.com/en/developer-guide/sql-api/reference#label-sql-api-reference-queryfailurestatus)
467#[derive(serde::Deserialize, thiserror::Error, Debug)]
468#[serde(rename_all = "camelCase")]
469#[error("Error for statement {statement_handle}: {message}")]
470pub struct QueryFailureStatus {
471    code: String,
472    sql_state: String,
473    message: String,
474    statement_handle: StatementHandle,
475    created_on: Option<i64>,
476    statement_status_url: Option<String>,
477}
478
479impl QueryFailureStatus {
480    pub fn code(&self) -> &str {
481        &self.code
482    }
483    pub fn sql_state(&self) -> &str {
484        &self.sql_state
485    }
486    pub fn message(&self) -> &str {
487        &self.message
488    }
489    pub fn statement_handle(&self) -> &StatementHandle {
490        &self.statement_handle
491    }
492    pub fn created_on(&self) -> Option<i64> {
493        self.created_on
494    }
495    pub fn statement_status_url(&self) -> Option<&str> {
496        self.statement_status_url.as_deref()
497    }
498}
499
500/// For custom data parsing,
501/// ex. you want to convert the retrieved data (strings) to enums
502///
503/// Data in cells are not their type, they are simply strings that need to be converted.
504pub trait DeserializeFromStr {
505    type Error;
506    fn deserialize_from_str(value: &str) -> Result<Self, Self::Error>
507    where
508        Self: Sized;
509}
510
511impl DeserializeFromStr for chrono::NaiveDate {
512    type Error = chrono::ParseError;
513
514    fn deserialize_from_str(s: &str) -> Result<Self, Self::Error> {
515        chrono::NaiveDate::parse_from_str(s, "%Y-%m-%d")
516    }
517}
518
519impl DeserializeFromStr for chrono::NaiveDateTime {
520    type Error = chrono::ParseError;
521
522    fn deserialize_from_str(s: &str) -> Result<Self, Self::Error> {
523        chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S%.f")
524            .or_else(|_| chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S"))
525    }
526}
527
528impl DeserializeFromStr for chrono::DateTime<chrono::Utc> {
529    type Error = chrono::ParseError;
530
531    fn deserialize_from_str(value: &str) -> Result<Self, Self::Error> {
532        // Parse any ISO 8601 / RFC3339 style string and convert to UTC
533        chrono::DateTime::parse_from_rfc3339(value).map(|dt| dt.with_timezone(&chrono::Utc))
534    }
535}
536
537impl DeserializeFromStr for chrono::DateTime<chrono::FixedOffset> {
538    type Error = chrono::ParseError;
539
540    fn deserialize_from_str(value: &str) -> Result<Self, Self::Error> {
541        chrono::DateTime::parse_from_rfc3339(value)
542    }
543}
544
545impl<T: DeserializeFromStr> DeserializeFromStr for Option<T> {
546    type Error = <T as DeserializeFromStr>::Error;
547    fn deserialize_from_str(value: &str) -> Result<Self, Self::Error>
548    where
549        Self: Sized,
550    {
551        if value == "NULL" {
552            Ok(None)
553        } else {
554            <T as DeserializeFromStr>::deserialize_from_str(value).map(|f| Some(f))
555        }
556    }
557}
558macro_rules! impl_deserialize_from_str {
559    ($ty: ty) => {
560        impl DeserializeFromStr for $ty {
561            type Error = <$ty as FromStr>::Err;
562            fn deserialize_from_str(value: &str) -> Result<Self, Self::Error> {
563                <$ty>::from_str(value)
564            }
565        }
566    };
567}
568
569impl_deserialize_from_str!(bool);
570impl_deserialize_from_str!(usize);
571impl_deserialize_from_str!(isize);
572impl_deserialize_from_str!(u8);
573impl_deserialize_from_str!(u16);
574impl_deserialize_from_str!(u32);
575impl_deserialize_from_str!(u64);
576impl_deserialize_from_str!(u128);
577impl_deserialize_from_str!(i16);
578impl_deserialize_from_str!(i32);
579impl_deserialize_from_str!(i64);
580impl_deserialize_from_str!(i128);
581impl_deserialize_from_str!(f32);
582impl_deserialize_from_str!(f64);
583impl_deserialize_from_str!(String);
584
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a statement directly, without key files or network access.
    /// Binding bookkeeping is purely local, so no credentials are needed.
    fn statement<'a>(client: &'a reqwest::Client, sql: &'a str) -> SnowflakeSQL<'a> {
        SnowflakeSQL::new(
            client,
            "https://HOST.snowflakecomputing.com/api/v2/",
            SnowflakeExecutorSQLJSON::new(sql, "DB".to_string()),
            uuid::Uuid::new_v4(),
        )
    }

    #[test]
    fn sql() {
        let client = reqwest::Client::new();
        let sql = statement(&client, "SELECT * FROM TEST_TABLE WHERE id = ? AND name = ?")
            .add_binding(69);
        let bindings = sql
            .statement
            .bindings
            .as_ref()
            .expect("the first binding creates the bindings map");
        assert_eq!(bindings.len(), 1);
        assert!(bindings.contains_key("1"));

        let sql = sql.add_binding("JoMama");
        let bindings = sql
            .statement
            .bindings
            .as_ref()
            .expect("bindings are kept across calls");
        assert_eq!(bindings.len(), 2);
        assert!(bindings.contains_key("2"));
    }

    /// Exercises key loading end to end; needs real keys on disk.
    #[test]
    #[ignore = "requires key files under ./environment_variables/local"]
    fn connector_from_files() -> Result<(), anyhow::Error> {
        SnowflakeConnector::try_new_from_file(
            "./environment_variables/local/rsa_key.pub",
            "./environment_variables/local/rsa_key.p8",
            "HOST",
            "ACCOUNT",
            "USER",
        )?;
        Ok(())
    }
}