snowflake_api/
lib.rs

1#![doc(
2    issue_tracker_base_url = "https://github.com/mycelial/snowflake-rs/issues",
3    test(no_crate_inject)
4)]
5#![doc = include_str!("../README.md")]
6#![warn(clippy::all, clippy::pedantic)]
7#![allow(
8clippy::must_use_candidate,
9clippy::missing_errors_doc,
10clippy::module_name_repetitions,
11clippy::struct_field_names,
12clippy::future_not_send, // This one seems like something we should eventually fix
13clippy::missing_panics_doc
14)]
15
16use std::fmt::{Display, Formatter};
17use std::io;
18use std::sync::Arc;
19
20use arrow_ipc::reader::StreamReader;
21use base64::Engine;
22use bytes::{Buf, Bytes};
23use futures::future::try_join_all;
24use regex::Regex;
25use reqwest_middleware::ClientWithMiddleware;
26use thiserror::Error;
27
28// Part of public interface
29pub use arrow_array::RecordBatch;
30pub use arrow_schema::ArrowError;
31
32use responses::ExecResponse;
33use session::{AuthError, Session};
34
35use crate::connection::QueryType;
36use crate::connection::{Connection, ConnectionError};
37use crate::requests::ExecRequest;
38use crate::responses::{ExecResponseRowType, SnowflakeType};
39use crate::session::AuthError::MissingEnvArgument;
40
41pub mod connection;
42#[cfg(feature = "polars")]
43mod polars;
44mod put;
45mod requests;
46mod responses;
47mod session;
48
49#[derive(Error, Debug)]
50pub enum SnowflakeApiError {
51    #[error(transparent)]
52    RequestError(#[from] ConnectionError),
53
54    #[error(transparent)]
55    AuthError(#[from] AuthError),
56
57    #[error(transparent)]
58    ResponseDeserializationError(#[from] base64::DecodeError),
59
60    #[error(transparent)]
61    ArrowError(#[from] ArrowError),
62
63    #[error("S3 bucket path in PUT request is invalid: `{0}`")]
64    InvalidBucketPath(String),
65
66    #[error("Couldn't extract filename from the local path: `{0}`")]
67    InvalidLocalPath(String),
68
69    #[error(transparent)]
70    LocalIoError(#[from] io::Error),
71
72    #[error(transparent)]
73    ObjectStoreError(#[from] object_store::Error),
74
75    #[error(transparent)]
76    ObjectStorePathError(#[from] object_store::path::Error),
77
78    #[error(transparent)]
79    TokioTaskJoinError(#[from] tokio::task::JoinError),
80
81    #[error("Snowflake API error. Code: `{0}`. Message: `{1}`")]
82    ApiError(String, String),
83
84    #[error("Snowflake API empty response could mean that query wasn't executed correctly or API call was faulty")]
85    EmptyResponse,
86
87    #[error("No usable rowsets were included in the response")]
88    BrokenResponse,
89
90    #[error("Following feature is not implemented yet: {0}")]
91    Unimplemented(String),
92
93    #[error("Unexpected API response")]
94    UnexpectedResponse,
95
96    #[error(transparent)]
97    GlobPatternError(#[from] glob::PatternError),
98
99    #[error(transparent)]
100    GlobError(#[from] glob::GlobError),
101}
102
103/// Even if Arrow is specified as a return type non-select queries
104/// will return Json array of arrays: `[[42, "answer"], [43, "non-answer"]]`.
105pub struct JsonResult {
106    // todo: can it _only_ be a json array of arrays or something else too?
107    pub value: serde_json::Value,
108    /// Field ordering matches the array ordering
109    pub schema: Vec<FieldSchema>,
110}
111
112impl Display for JsonResult {
113    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
114        write!(f, "{}", self.value)
115    }
116}
117
118/// Based on the [`ExecResponseRowType`]
119pub struct FieldSchema {
120    pub name: String,
121    // todo: is it a good idea to expose internal response struct to the user?
122    pub type_: SnowflakeType,
123    pub scale: Option<i64>,
124    pub precision: Option<i64>,
125    pub nullable: bool,
126}
127
128impl From<ExecResponseRowType> for FieldSchema {
129    fn from(value: ExecResponseRowType) -> Self {
130        FieldSchema {
131            name: value.name,
132            type_: value.type_,
133            scale: value.scale,
134            precision: value.precision,
135            nullable: value.nullable,
136        }
137    }
138}
139
140/// Container for query result.
141/// Arrow is returned by-default for all SELECT statements,
142/// unless there is session configuration issue or it's a different statement type.
143pub enum QueryResult {
144    Arrow(Vec<RecordBatch>),
145    Json(JsonResult),
146    Empty,
147}
148
149/// Raw query result
150/// Can be transformed into [`QueryResult`]
151pub enum RawQueryResult {
152    /// Arrow IPC chunks
153    /// see: <https://arrow.apache.org/docs/format/Columnar.html#serialization-and-interprocess-communication-ipc>
154    Bytes(Vec<Bytes>),
155    /// Json payload is deserialized,
156    /// as it's already a part of REST response
157    Json(JsonResult),
158    Empty,
159}
160
161impl RawQueryResult {
162    pub fn deserialize_arrow(self) -> Result<QueryResult, ArrowError> {
163        match self {
164            RawQueryResult::Bytes(bytes) => {
165                Self::flat_bytes_to_batches(bytes).map(QueryResult::Arrow)
166            }
167            RawQueryResult::Json(j) => Ok(QueryResult::Json(j)),
168            RawQueryResult::Empty => Ok(QueryResult::Empty),
169        }
170    }
171
172    fn flat_bytes_to_batches(bytes: Vec<Bytes>) -> Result<Vec<RecordBatch>, ArrowError> {
173        let mut res = vec![];
174        for b in bytes {
175            let mut batches = Self::bytes_to_batches(b)?;
176            res.append(&mut batches);
177        }
178        Ok(res)
179    }
180
181    fn bytes_to_batches(bytes: Bytes) -> Result<Vec<RecordBatch>, ArrowError> {
182        let record_batches = StreamReader::try_new(bytes.reader(), None)?;
183        record_batches.into_iter().collect()
184    }
185}
186
187pub struct AuthArgs {
188    pub account_identifier: String,
189    pub warehouse: Option<String>,
190    pub database: Option<String>,
191    pub schema: Option<String>,
192    pub username: String,
193    pub role: Option<String>,
194    pub auth_type: AuthType,
195}
196
197impl AuthArgs {
198    pub fn from_env() -> Result<AuthArgs, SnowflakeApiError> {
199        let auth_type = if let Ok(password) = std::env::var("SNOWFLAKE_PASSWORD") {
200            Ok(AuthType::Password(PasswordArgs { password }))
201        } else if let Ok(private_key_pem) = std::env::var("SNOWFLAKE_PRIVATE_KEY") {
202            Ok(AuthType::Certificate(CertificateArgs { private_key_pem }))
203        } else {
204            Err(MissingEnvArgument(
205                "SNOWFLAKE_PASSWORD or SNOWFLAKE_PRIVATE_KEY".to_owned(),
206            ))
207        };
208
209        Ok(AuthArgs {
210            account_identifier: std::env::var("SNOWFLAKE_ACCOUNT")
211                .map_err(|_| MissingEnvArgument("SNOWFLAKE_ACCOUNT".to_owned()))?,
212            warehouse: std::env::var("SNOWLFLAKE_WAREHOUSE").ok(),
213            database: std::env::var("SNOWFLAKE_DATABASE").ok(),
214            schema: std::env::var("SNOWFLAKE_SCHEMA").ok(),
215            username: std::env::var("SNOWFLAKE_USER")
216                .map_err(|_| MissingEnvArgument("SNOWFLAKE_USER".to_owned()))?,
217            role: std::env::var("SNOWFLAKE_ROLE").ok(),
218            auth_type: auth_type?,
219        })
220    }
221}
222
223pub enum AuthType {
224    Password(PasswordArgs),
225    Certificate(CertificateArgs),
226}
227
/// Credentials for password authentication.
pub struct PasswordArgs {
    /// The user's password.
    pub password: String,
}
231
/// Credentials for private key (certificate) authentication.
pub struct CertificateArgs {
    /// The private key as PEM text.
    pub private_key_pem: String,
}
235
236#[must_use]
237pub struct SnowflakeApiBuilder {
238    pub auth: AuthArgs,
239    client: Option<ClientWithMiddleware>,
240}
241
242impl SnowflakeApiBuilder {
243    pub fn new(auth: AuthArgs) -> Self {
244        Self { auth, client: None }
245    }
246
247    pub fn with_client(mut self, client: ClientWithMiddleware) -> Self {
248        self.client = Some(client);
249        self
250    }
251
252    pub fn build(self) -> Result<SnowflakeApi, SnowflakeApiError> {
253        let connection = match self.client {
254            Some(client) => Arc::new(Connection::new_with_middware(client)),
255            None => Arc::new(Connection::new()?),
256        };
257
258        let session = match self.auth.auth_type {
259            AuthType::Password(args) => Session::password_auth(
260                Arc::clone(&connection),
261                &self.auth.account_identifier,
262                self.auth.warehouse.as_deref(),
263                self.auth.database.as_deref(),
264                self.auth.schema.as_deref(),
265                &self.auth.username,
266                self.auth.role.as_deref(),
267                &args.password,
268            ),
269            AuthType::Certificate(args) => Session::cert_auth(
270                Arc::clone(&connection),
271                &self.auth.account_identifier,
272                self.auth.warehouse.as_deref(),
273                self.auth.database.as_deref(),
274                self.auth.schema.as_deref(),
275                &self.auth.username,
276                self.auth.role.as_deref(),
277                &args.private_key_pem,
278            ),
279        };
280
281        let account_identifier = self.auth.account_identifier.to_uppercase();
282
283        Ok(SnowflakeApi::new(
284            Arc::clone(&connection),
285            session,
286            account_identifier,
287        ))
288    }
289}
290
291/// Snowflake API, keeps connection pool and manages session for you
292pub struct SnowflakeApi {
293    connection: Arc<Connection>,
294    session: Session,
295    account_identifier: String,
296}
297
298impl SnowflakeApi {
299    /// Create a new `SnowflakeApi` object with an existing connection and session.
300    pub fn new(connection: Arc<Connection>, session: Session, account_identifier: String) -> Self {
301        Self {
302            connection,
303            session,
304            account_identifier,
305        }
306    }
307    /// Initialize object with password auth. Authentication happens on the first request.
308    pub fn with_password_auth(
309        account_identifier: &str,
310        warehouse: Option<&str>,
311        database: Option<&str>,
312        schema: Option<&str>,
313        username: &str,
314        role: Option<&str>,
315        password: &str,
316    ) -> Result<Self, SnowflakeApiError> {
317        let connection = Arc::new(Connection::new()?);
318
319        let session = Session::password_auth(
320            Arc::clone(&connection),
321            account_identifier,
322            warehouse,
323            database,
324            schema,
325            username,
326            role,
327            password,
328        );
329
330        let account_identifier = account_identifier.to_uppercase();
331        Ok(Self::new(
332            Arc::clone(&connection),
333            session,
334            account_identifier,
335        ))
336    }
337
338    /// Initialize object with private certificate auth. Authentication happens on the first request.
339    pub fn with_certificate_auth(
340        account_identifier: &str,
341        warehouse: Option<&str>,
342        database: Option<&str>,
343        schema: Option<&str>,
344        username: &str,
345        role: Option<&str>,
346        private_key_pem: &str,
347    ) -> Result<Self, SnowflakeApiError> {
348        let connection = Arc::new(Connection::new()?);
349
350        let session = Session::cert_auth(
351            Arc::clone(&connection),
352            account_identifier,
353            warehouse,
354            database,
355            schema,
356            username,
357            role,
358            private_key_pem,
359        );
360
361        let account_identifier = account_identifier.to_uppercase();
362        Ok(Self::new(
363            Arc::clone(&connection),
364            session,
365            account_identifier,
366        ))
367    }
368
369    pub fn from_env() -> Result<Self, SnowflakeApiError> {
370        SnowflakeApiBuilder::new(AuthArgs::from_env()?).build()
371    }
372
373    /// Closes the current session, this is necessary to clean up temporary objects (tables, functions, etc)
374    /// which are Snowflake session dependent.
375    /// If another request is made the new session will be initiated.
376    pub async fn close_session(&mut self) -> Result<(), SnowflakeApiError> {
377        self.session.close().await?;
378        Ok(())
379    }
380
381    /// Execute a single query against API.
382    /// If statement is PUT, then file will be uploaded to the Snowflake-managed storage
383    pub async fn exec(&self, sql: &str) -> Result<QueryResult, SnowflakeApiError> {
384        let raw = self.exec_raw(sql).await?;
385        let res = raw.deserialize_arrow()?;
386        Ok(res)
387    }
388
389    /// Executes a single query against API.
390    /// If statement is PUT, then file will be uploaded to the Snowflake-managed storage
391    /// Returns raw bytes in the Arrow response
392    pub async fn exec_raw(&self, sql: &str) -> Result<RawQueryResult, SnowflakeApiError> {
393        let put_re = Regex::new(r"(?i)^(?:/\*.*\*/\s*)*put\s+").unwrap();
394
395        // put commands go through a different flow and result is side-effect
396        if put_re.is_match(sql) {
397            log::info!("Detected PUT query");
398            self.exec_put(sql).await.map(|()| RawQueryResult::Empty)
399        } else {
400            self.exec_arrow_raw(sql).await
401        }
402    }
403
404    async fn exec_put(&self, sql: &str) -> Result<(), SnowflakeApiError> {
405        let resp = self
406            .run_sql::<ExecResponse>(sql, QueryType::JsonQuery)
407            .await?;
408        log::debug!("Got PUT response: {resp:?}");
409
410        match resp {
411            ExecResponse::Query(_) => Err(SnowflakeApiError::UnexpectedResponse),
412            ExecResponse::PutGet(pg) => put::put(pg).await,
413            ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError(
414                e.data.error_code,
415                e.message.unwrap_or_default(),
416            )),
417        }
418    }
419
420    /// Useful for debugging to get the straight query response
421    #[cfg(debug_assertions)]
422    pub async fn exec_response(&mut self, sql: &str) -> Result<ExecResponse, SnowflakeApiError> {
423        self.run_sql::<ExecResponse>(sql, QueryType::ArrowQuery)
424            .await
425    }
426
427    /// Useful for debugging to get raw JSON response
428    #[cfg(debug_assertions)]
429    pub async fn exec_json(&mut self, sql: &str) -> Result<serde_json::Value, SnowflakeApiError> {
430        self.run_sql::<serde_json::Value>(sql, QueryType::JsonQuery)
431            .await
432    }
433
434    async fn exec_arrow_raw(&self, sql: &str) -> Result<RawQueryResult, SnowflakeApiError> {
435        let resp = self
436            .run_sql::<ExecResponse>(sql, QueryType::ArrowQuery)
437            .await?;
438        log::debug!("Got query response: {resp:?}");
439
440        let resp = match resp {
441            // processable response
442            ExecResponse::Query(qr) => Ok(qr),
443            ExecResponse::PutGet(_) => Err(SnowflakeApiError::UnexpectedResponse),
444            ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError(
445                e.data.error_code,
446                e.message.unwrap_or_default(),
447            )),
448        }?;
449
450        // if response was empty, base64 data is empty string
451        // todo: still return empty arrow batch with proper schema? (schema always included)
452        if resp.data.returned == 0 {
453            log::debug!("Got response with 0 rows");
454            Ok(RawQueryResult::Empty)
455        } else if let Some(value) = resp.data.rowset {
456            log::debug!("Got JSON response");
457            // NOTE: json response could be chunked too. however, go clients should receive arrow by-default,
458            // unless user sets session variable to return json. This case was added for debugging and status
459            // information being passed through that fields.
460            Ok(RawQueryResult::Json(JsonResult {
461                value,
462                schema: resp.data.rowtype.into_iter().map(Into::into).collect(),
463            }))
464        } else if let Some(base64) = resp.data.rowset_base64 {
465            // fixme: is it possible to give streaming interface?
466            let mut chunks = try_join_all(resp.data.chunks.iter().map(|chunk| {
467                self.connection
468                    .get_chunk(&chunk.url, &resp.data.chunk_headers)
469            }))
470            .await?;
471
472            // fixme: should base64 chunk go first?
473            // fixme: if response is chunked is it both base64 + chunks or just chunks?
474            if !base64.is_empty() {
475                log::debug!("Got base64 encoded response");
476                let bytes = Bytes::from(base64::engine::general_purpose::STANDARD.decode(base64)?);
477                chunks.push(bytes);
478            }
479
480            Ok(RawQueryResult::Bytes(chunks))
481        } else {
482            Err(SnowflakeApiError::BrokenResponse)
483        }
484    }
485
486    async fn run_sql<R: serde::de::DeserializeOwned>(
487        &self,
488        sql_text: &str,
489        query_type: QueryType,
490    ) -> Result<R, SnowflakeApiError> {
491        log::debug!("Executing: {sql_text}");
492
493        let parts = self.session.get_token().await?;
494
495        let body = ExecRequest {
496            sql_text: sql_text.to_string(),
497            async_exec: false,
498            sequence_id: parts.sequence_id,
499            is_internal: false,
500        };
501
502        let resp = self
503            .connection
504            .request::<R>(
505                query_type,
506                &self.account_identifier,
507                &[],
508                Some(&parts.session_token_auth_header),
509                body,
510            )
511            .await?;
512
513        Ok(resp)
514    }
515}