1pub use chrono;
2use data_manipulation::DataManipulationResult;
3use jwt::{KeyPairError, TokenFromFileError};
4use reqwest::header::{ACCEPT, AUTHORIZATION, CONTENT_TYPE, HeaderMap, USER_AGENT};
5use serde::{Deserialize, Serialize};
6pub use serde_json;
7use std::{collections::HashMap, path::Path, str::FromStr};
8
9use crate::bindings::{BindingType, BindingValue};
10
11pub mod bindings;
12pub mod data_manipulation;
13#[cfg(feature = "lazy")]
14pub mod lazy;
15#[cfg(feature = "multiple")]
16pub mod multiple;
17
18mod jwt;
19
20#[derive(Debug)]
21pub struct SnowflakeConnector {
22 host: String,
23 client: reqwest::Client,
24}
25
26impl SnowflakeConnector {
27 pub fn try_new(
28 public_key: &str,
29 private_key: &str,
30 host: &str,
31 account_identifier: &str,
32 user: &str,
33 ) -> Result<Self, NewSnowflakeConnectorError> {
34 let token = jwt::create_token(
35 public_key,
36 private_key,
37 &account_identifier.to_ascii_uppercase(),
38 &user.to_ascii_uppercase(),
39 )?;
40 let headers = Self::get_headers(&token);
41 let client = reqwest::Client::builder()
42 .default_headers(headers)
43 .build()?;
44 Ok(SnowflakeConnector {
45 host: format!("https://{host}.snowflakecomputing.com/api/v2/"),
46 client,
47 })
48 }
49 pub fn try_new_from_file<P: AsRef<Path>>(
50 public_key_path: P,
51 private_key_path: P,
52 host: &str,
53 account_identifier: &str,
54 user: &str,
55 ) -> Result<Self, NewSnowflakeConnectorFromFileError> {
56 let token = jwt::create_token_from_file(
57 public_key_path,
58 private_key_path,
59 &account_identifier.to_ascii_uppercase(),
60 &user.to_ascii_uppercase(),
61 )?;
62 let headers = Self::get_headers(&token);
63 let client = reqwest::Client::builder()
64 .default_headers(headers)
65 .build()?;
66 Ok(SnowflakeConnector {
67 host: format!("https://{host}.snowflakecomputing.com/api/v2/"),
68 client,
69 })
70 }
71
72 pub fn execute<D: ToString>(&self, database: D) -> SnowflakeExecutor<D> {
73 SnowflakeExecutor {
74 host: &self.host,
75 database,
76 client: &self.client,
77 }
78 }
79 fn get_headers(token: &str) -> HeaderMap {
80 let mut headers = HeaderMap::with_capacity(5);
81 headers.append(CONTENT_TYPE, "application/json".parse().unwrap());
82 headers.append(AUTHORIZATION, format!("Bearer {}", token).parse().unwrap());
83 headers.append(
84 "X-Snowflake-Authorization-Token-Type",
85 "KEYPAIR_JWT".parse().unwrap(),
86 );
87 headers.append(ACCEPT, "application/json".parse().unwrap());
88 headers.append(
89 USER_AGENT,
90 concat!(env!("CARGO_PKG_NAME"), '/', env!("CARGO_PKG_VERSION"))
91 .parse()
92 .unwrap(),
93 );
94 headers
95 }
96}
97
98#[derive(thiserror::Error, Debug)]
100pub enum NewSnowflakeConnectorError {
101 #[error(transparent)]
102 KeyPair(#[from] KeyPairError),
103 #[error(transparent)]
104 ClientBuildError(#[from] reqwest::Error),
105}
106
107#[derive(thiserror::Error, Debug)]
109pub enum NewSnowflakeConnectorFromFileError {
110 #[error(transparent)]
111 Token(#[from] TokenFromFileError),
112 #[error(transparent)]
113 ClientBuildError(#[from] reqwest::Error),
114}
115
116#[derive(Debug)]
117pub struct SnowflakeExecutor<'a, D: ToString> {
118 host: &'a str,
119 database: D,
120 client: &'a reqwest::Client,
121}
122
123impl<'a, D: ToString> SnowflakeExecutor<'a, D> {
124 pub fn sql(self, statement: &'a str) -> SnowflakeSQL<'a> {
125 SnowflakeSQL::new(
126 self.client,
127 self.host,
128 SnowflakeExecutorSQLJSON::new(statement, self.database.to_string()),
129 uuid::Uuid::new_v4(),
130 )
131 }
132}
133
134#[derive(Debug)]
135pub struct SnowflakeSQL<'a> {
136 client: &'a reqwest::Client,
137 host: &'a str,
138 statement: SnowflakeExecutorSQLJSON<'a>,
139 uuid: uuid::Uuid,
140}
141
142impl<'a> SnowflakeSQL<'a> {
143 pub(crate) fn new(
144 client: &'a reqwest::Client,
145 host: &'a str,
146 statement: SnowflakeExecutorSQLJSON<'a>,
147 uuid: uuid::Uuid,
148 ) -> Self {
149 SnowflakeSQL {
150 client,
151 host,
152 statement,
153 uuid,
154 }
155 }
156 pub async fn text(self) -> Result<String, SnowflakeSQLTextError> {
157 self.client
158 .post(self.get_url())
159 .json(&self.statement)
160 .send()
161 .await
162 .map_err(SnowflakeSQLTextError::Request)?
163 .text()
164 .await
165 .map_err(SnowflakeSQLTextError::ToText)
166 }
167 pub async fn select<T: SnowflakeDeserialize>(
169 self,
170 ) -> Result<StatementResult<'a, T>, SnowflakeSQLSelectError<T::Error>> {
171 let r = self
172 .client
173 .post(self.get_url())
174 .json(&self.statement)
175 .send()
176 .await
177 .map_err(SnowflakeSQLSelectError::Request)?;
178 let status_code = r.status();
179 match status_code {
180 reqwest::StatusCode::OK => Ok(StatementResult::Result(
181 r.json::<SnowflakeSQLResponse>()
182 .await
183 .map_err(SnowflakeSQLSelectError::Decode)?
184 .deserialize()
185 .map_err(SnowflakeSQLSelectError::Deserialize)?,
186 )),
187 reqwest::StatusCode::REQUEST_TIMEOUT | reqwest::StatusCode::ACCEPTED => {
188 Ok(StatementResult::Status(SnowflakeQueryStatus {
189 client: self.client,
190 host: self.host,
191 query_status: r
192 .json::<QueryStatus>()
193 .await
194 .map_err(SnowflakeSQLSelectError::Decode)?,
195 }))
196 }
197 reqwest::StatusCode::UNPROCESSABLE_ENTITY => Err(SnowflakeSQLSelectError::Query(
198 r.json().await.map_err(SnowflakeSQLSelectError::Decode)?,
199 )),
200 status_code => Err(SnowflakeSQLSelectError::Unknown(status_code)),
201 }
202 }
203 pub async fn manipulate(self) -> Result<DataManipulationResult, SnowflakeSQLManipulateError> {
205 self.client
206 .post(self.get_url())
207 .json(&self.statement)
208 .send()
209 .await
210 .map_err(SnowflakeSQLManipulateError::Request)?
211 .json()
212 .await
213 .map_err(SnowflakeSQLManipulateError::Decode)
214 }
215 pub fn with_timeout(mut self, timeout: u32) -> Self {
216 self.statement.timeout = Some(timeout);
217 self
218 }
219 pub fn with_role<R: ToString>(mut self, role: R) -> Self {
220 self.statement.role = Some(role.to_string());
221 self
222 }
223 pub fn with_warehouse<W: ToString>(mut self, warehouse: W) -> Self {
224 self.statement.warehouse = Some(warehouse.to_string());
225 self
226 }
227 pub fn add_binding<T: Into<BindingValue>>(mut self, value: T) -> Self {
228 let value: BindingValue = value.into();
229 let value_str = value.to_string();
230 let value_type: BindingType = value.into();
231 let binding = Binding {
232 value_type: value_type.to_string(),
233 value: value_str,
234 };
235 if let Some(bindings) = &mut self.statement.bindings {
236 bindings.insert((bindings.len() + 1).to_string(), binding);
237 } else {
238 self.statement.bindings = Some(HashMap::from([("1".into(), binding)]));
239 }
240 self
241 }
242 fn get_url(&self) -> String {
243 get_url(self.host, &self.uuid)
244 }
245}
246
247pub(crate) fn get_url(host: &str, uuid: &uuid::Uuid) -> String {
248 format!("{host}statements?nullable=false&requestId={uuid}")
250}
251
252#[derive(thiserror::Error, Debug)]
254#[error(transparent)]
255pub enum SnowflakeSQLTextError {
256 Request(reqwest::Error),
257 ToText(reqwest::Error),
258}
259
260#[derive(thiserror::Error, Debug)]
262pub enum SnowflakeSQLSelectError<DeserializeError> {
263 #[error(transparent)]
264 Request(reqwest::Error),
265 #[error(transparent)]
266 Decode(reqwest::Error),
267 #[error(transparent)]
268 Deserialize(DeserializeError),
269 #[error(transparent)]
270 Query(QueryFailureStatus),
271 #[error("unknown error with status code: {0}")]
272 Unknown(reqwest::StatusCode),
273}
274
275#[derive(thiserror::Error, Debug)]
277pub enum SnowflakeSQLManipulateError {
278 #[error(transparent)]
279 Request(reqwest::Error),
280 #[error(transparent)]
281 Decode(reqwest::Error),
282}
283
284#[derive(Serialize, Debug)]
285pub struct SnowflakeExecutorSQLJSON<'a> {
286 statement: &'a str,
287 timeout: Option<u32>,
288 database: String,
289 warehouse: Option<String>,
290 role: Option<String>,
291 bindings: Option<HashMap<String, Binding>>,
292}
293impl<'a> SnowflakeExecutorSQLJSON<'a> {
294 pub(crate) fn new(statement: &'a str, database: String) -> Self {
295 SnowflakeExecutorSQLJSON {
296 statement,
297 timeout: None,
298 database,
299 warehouse: None,
300 role: None,
301 bindings: None,
302 }
303 }
304}
305
306#[derive(Serialize, Debug)]
307pub struct Binding {
308 #[serde(rename = "type")]
309 value_type: String,
310 value: String,
311}
312
313pub trait SnowflakeDeserialize {
314 type Error;
315 fn snowflake_deserialize(
316 response: SnowflakeSQLResponse,
317 ) -> Result<SnowflakeSQLResult<Self>, Self::Error>
318 where
319 Self: Sized;
320}
321
322#[derive(Deserialize, Debug)]
323#[serde(rename_all = "camelCase")]
324pub struct SnowflakeSQLResponse {
325 pub result_set_meta_data: MetaData,
326 pub data: Vec<Vec<String>>,
327 pub code: String,
328 pub statement_status_url: String,
329 pub request_id: String,
330 pub sql_state: String,
331 pub message: String,
332 }
334
335impl SnowflakeSQLResponse {
336 pub fn deserialize<T: SnowflakeDeserialize>(self) -> Result<SnowflakeSQLResult<T>, T::Error> {
337 T::snowflake_deserialize(self)
338 }
339}
340
341#[derive(Deserialize, Debug)]
343#[serde(rename_all = "camelCase")]
344pub struct MetaData {
345 pub num_rows: usize,
346 pub format: String,
347 pub row_type: Vec<RowType>,
348}
349
350#[derive(Deserialize, Debug)]
352#[serde(rename_all = "camelCase")]
353pub struct RowType {
354 pub name: String,
355 pub database: String,
356 pub schema: String,
357 pub table: String,
358 pub precision: Option<u32>,
359 pub byte_length: Option<usize>,
360 #[serde(rename = "type")]
361 pub data_type: String,
362 pub scale: Option<i32>,
363 pub nullable: bool,
364 }
367
368#[derive(Debug)]
370pub enum StatementResult<'a, T> {
371 Status(SnowflakeQueryStatus<'a>),
373 Result(SnowflakeSQLResult<T>),
375}
376#[derive(Debug)]
377pub struct SnowflakeSQLResult<T> {
378 pub data: Vec<T>,
379}
380
381#[derive(Debug)]
382pub struct SnowflakeQueryStatus<'a> {
383 client: &'a reqwest::Client,
384 host: &'a str,
385 query_status: QueryStatus,
386}
387
388impl<'a> SnowflakeQueryStatus<'a> {
389 pub fn take_query_status(self) -> QueryStatus {
390 self.query_status
391 }
392 pub async fn cancel(&self) -> Result<(), QueryCancelError> {
393 let url = format!(
394 "{}statements/{}/cancel",
395 self.host, self.query_status.statement_handle
396 );
397 let response = self.client.post(url).send().await;
398 match response {
399 Ok(r) => match r.status() {
400 reqwest::StatusCode::OK => Ok(()),
401 status => Err(QueryCancelError::Unknown(status)),
402 },
403 Err(e) => Err(QueryCancelError::Request(e)),
404 }
405 }
406}
407
408#[derive(thiserror::Error, Debug)]
410pub enum QueryCancelError {
411 #[error(transparent)]
412 Request(reqwest::Error),
413 #[error("unknown error with status code: {0}")]
414 Unknown(reqwest::StatusCode),
415}
416
417#[derive(serde::Deserialize, Clone, Debug)]
419#[serde(transparent)]
420pub struct StatementHandle(String);
421impl StatementHandle {
422 pub fn handle(&self) -> &str {
423 &self.0
424 }
425}
426impl std::fmt::Display for StatementHandle {
427 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
428 self.0.fmt(f)
429 }
430}
431
432#[derive(serde::Deserialize, thiserror::Error, Debug)]
434#[serde(rename_all = "camelCase")]
435#[error("Error for statement {statement_handle}: {message}")]
436pub struct QueryStatus {
437 code: String,
438 sql_state: String,
439 message: String,
440 statement_handle: StatementHandle,
441 created_on: i64,
442 statement_status_url: String,
443}
444
445impl QueryStatus {
446 pub fn code(&self) -> &str {
447 &self.code
448 }
449 pub fn sql_state(&self) -> &str {
450 &self.sql_state
451 }
452 pub fn message(&self) -> &str {
453 &self.message
454 }
455 pub fn statement_handle(&self) -> &StatementHandle {
456 &self.statement_handle
457 }
458 pub fn created_on(&self) -> i64 {
459 self.created_on
460 }
461 pub fn statement_status_url(&self) -> &str {
462 &self.statement_status_url
463 }
464}
465
466#[derive(serde::Deserialize, thiserror::Error, Debug)]
468#[serde(rename_all = "camelCase")]
469#[error("Error for statement {statement_handle}: {message}")]
470pub struct QueryFailureStatus {
471 code: String,
472 sql_state: String,
473 message: String,
474 statement_handle: StatementHandle,
475 created_on: Option<i64>,
476 statement_status_url: Option<String>,
477}
478
479impl QueryFailureStatus {
480 pub fn code(&self) -> &str {
481 &self.code
482 }
483 pub fn sql_state(&self) -> &str {
484 &self.sql_state
485 }
486 pub fn message(&self) -> &str {
487 &self.message
488 }
489 pub fn statement_handle(&self) -> &StatementHandle {
490 &self.statement_handle
491 }
492 pub fn created_on(&self) -> Option<i64> {
493 self.created_on
494 }
495 pub fn statement_status_url(&self) -> Option<&str> {
496 self.statement_status_url.as_deref()
497 }
498}
499
/// Parses one textual cell value (response data arrives as `String`s) into a Rust value.
pub trait DeserializeFromStr {
    /// Error returned when `value` cannot be parsed.
    type Error;

    /// Parses `value` into `Self`.
    fn deserialize_from_str(value: &str) -> Result<Self, Self::Error>
    where
        Self: Sized;
}
510
511impl DeserializeFromStr for chrono::NaiveDate {
512 type Error = chrono::ParseError;
513
514 fn deserialize_from_str(s: &str) -> Result<Self, Self::Error> {
515 chrono::NaiveDate::parse_from_str(s, "%Y-%m-%d")
516 }
517}
518
519impl DeserializeFromStr for chrono::NaiveDateTime {
520 type Error = chrono::ParseError;
521
522 fn deserialize_from_str(s: &str) -> Result<Self, Self::Error> {
523 chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S%.f")
524 .or_else(|_| chrono::NaiveDateTime::parse_from_str(s, "%Y-%m-%d %H:%M:%S"))
525 }
526}
527
528impl DeserializeFromStr for chrono::DateTime<chrono::Utc> {
529 type Error = chrono::ParseError;
530
531 fn deserialize_from_str(value: &str) -> Result<Self, Self::Error> {
532 chrono::DateTime::parse_from_rfc3339(value).map(|dt| dt.with_timezone(&chrono::Utc))
534 }
535}
536
537impl DeserializeFromStr for chrono::DateTime<chrono::FixedOffset> {
538 type Error = chrono::ParseError;
539
540 fn deserialize_from_str(value: &str) -> Result<Self, Self::Error> {
541 chrono::DateTime::parse_from_rfc3339(value)
542 }
543}
544
545impl<T: DeserializeFromStr> DeserializeFromStr for Option<T> {
546 type Error = <T as DeserializeFromStr>::Error;
547 fn deserialize_from_str(value: &str) -> Result<Self, Self::Error>
548 where
549 Self: Sized,
550 {
551 if value == "NULL" {
552 Ok(None)
553 } else {
554 <T as DeserializeFromStr>::deserialize_from_str(value).map(|f| Some(f))
555 }
556 }
557}
558macro_rules! impl_deserialize_from_str {
559 ($ty: ty) => {
560 impl DeserializeFromStr for $ty {
561 type Error = <$ty as FromStr>::Err;
562 fn deserialize_from_str(value: &str) -> Result<Self, Self::Error> {
563 <$ty>::from_str(value)
564 }
565 }
566 };
567}
568
569impl_deserialize_from_str!(bool);
570impl_deserialize_from_str!(usize);
571impl_deserialize_from_str!(isize);
572impl_deserialize_from_str!(u8);
573impl_deserialize_from_str!(u16);
574impl_deserialize_from_str!(u32);
575impl_deserialize_from_str!(u64);
576impl_deserialize_from_str!(u128);
577impl_deserialize_from_str!(i16);
578impl_deserialize_from_str!(i32);
579impl_deserialize_from_str!(i64);
580impl_deserialize_from_str!(i128);
581impl_deserialize_from_str!(f32);
582impl_deserialize_from_str!(f64);
583impl_deserialize_from_str!(String);
584
#[cfg(test)]
mod tests {
    use super::*;

    /// `add_binding` must key bindings by 1-based position in call order.
    ///
    /// NOTE(review): this builds a real connector from the key files under
    /// `./environment_variables/local/`. Any error from that is propagated
    /// with `?`, so the test fails when those files are unavailable.
    #[test]
    fn sql() -> Result<(), anyhow::Error> {
        let connector = SnowflakeConnector::try_new_from_file(
            "./environment_variables/local/rsa_key.pub",
            "./environment_variables/local/rsa_key.p8",
            "HOST",
            "ACCOUNT",
            "USER",
        )?;
        let sql = connector
            .execute("DB")
            .sql("SELECT * FROM TEST_TABLE WHERE id = ? AND name = ?")
            .add_binding(69);
        let bindings = sql
            .statement
            .bindings
            .as_ref()
            .expect("first add_binding must create the bindings map");
        assert_eq!(bindings.len(), 1);
        assert!(bindings.contains_key("1"));

        let sql = sql.add_binding("JoMama");
        let bindings = sql
            .statement
            .bindings
            .as_ref()
            .expect("bindings map must persist across add_binding calls");
        assert_eq!(bindings.len(), 2);
        assert!(bindings.contains_key("1"));
        assert!(bindings.contains_key("2"));
        Ok(())
    }
}