1#![doc(
2 issue_tracker_base_url = "https://github.com/mycelial/snowflake-rs/issues",
3 test(no_crate_inject)
4)]
5#![doc = include_str!("../README.md")]
6#![warn(clippy::all, clippy::pedantic)]
7#![allow(
8clippy::must_use_candidate,
9clippy::missing_errors_doc,
10clippy::module_name_repetitions,
11clippy::struct_field_names,
12clippy::future_not_send, clippy::missing_panics_doc
14)]
15
16use std::fmt::{Display, Formatter};
17use std::io;
18use std::sync::Arc;
19
20use arrow_ipc::reader::StreamReader;
21use base64::Engine;
22use bytes::{Buf, Bytes};
23use futures::future::try_join_all;
24use regex::Regex;
25use reqwest_middleware::ClientWithMiddleware;
26use thiserror::Error;
27
28pub use arrow_array::RecordBatch;
30pub use arrow_schema::ArrowError;
31
32use responses::ExecResponse;
33use session::{AuthError, Session};
34
35use crate::connection::QueryType;
36use crate::connection::{Connection, ConnectionError};
37use crate::requests::ExecRequest;
38use crate::responses::{ExecResponseRowType, SnowflakeType};
39use crate::session::AuthError::MissingEnvArgument;
40
41pub mod connection;
42#[cfg(feature = "polars")]
43mod polars;
44mod put;
45mod requests;
46mod responses;
47mod session;
48
49#[derive(Error, Debug)]
50pub enum SnowflakeApiError {
51 #[error(transparent)]
52 RequestError(#[from] ConnectionError),
53
54 #[error(transparent)]
55 AuthError(#[from] AuthError),
56
57 #[error(transparent)]
58 ResponseDeserializationError(#[from] base64::DecodeError),
59
60 #[error(transparent)]
61 ArrowError(#[from] ArrowError),
62
63 #[error("S3 bucket path in PUT request is invalid: `{0}`")]
64 InvalidBucketPath(String),
65
66 #[error("Couldn't extract filename from the local path: `{0}`")]
67 InvalidLocalPath(String),
68
69 #[error(transparent)]
70 LocalIoError(#[from] io::Error),
71
72 #[error(transparent)]
73 ObjectStoreError(#[from] object_store::Error),
74
75 #[error(transparent)]
76 ObjectStorePathError(#[from] object_store::path::Error),
77
78 #[error(transparent)]
79 TokioTaskJoinError(#[from] tokio::task::JoinError),
80
81 #[error("Snowflake API error. Code: `{0}`. Message: `{1}`")]
82 ApiError(String, String),
83
84 #[error("Snowflake API empty response could mean that query wasn't executed correctly or API call was faulty")]
85 EmptyResponse,
86
87 #[error("No usable rowsets were included in the response")]
88 BrokenResponse,
89
90 #[error("Following feature is not implemented yet: {0}")]
91 Unimplemented(String),
92
93 #[error("Unexpected API response")]
94 UnexpectedResponse,
95
96 #[error(transparent)]
97 GlobPatternError(#[from] glob::PatternError),
98
99 #[error(transparent)]
100 GlobError(#[from] glob::GlobError),
101}
102
103pub struct JsonResult {
106 pub value: serde_json::Value,
108 pub schema: Vec<FieldSchema>,
110}
111
112impl Display for JsonResult {
113 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
114 write!(f, "{}", self.value)
115 }
116}
117
118pub struct FieldSchema {
120 pub name: String,
121 pub type_: SnowflakeType,
123 pub scale: Option<i64>,
124 pub precision: Option<i64>,
125 pub nullable: bool,
126}
127
128impl From<ExecResponseRowType> for FieldSchema {
129 fn from(value: ExecResponseRowType) -> Self {
130 FieldSchema {
131 name: value.name,
132 type_: value.type_,
133 scale: value.scale,
134 precision: value.precision,
135 nullable: value.nullable,
136 }
137 }
138}
139
140pub enum QueryResult {
144 Arrow(Vec<RecordBatch>),
145 Json(JsonResult),
146 Empty,
147}
148
149pub enum RawQueryResult {
152 Bytes(Vec<Bytes>),
155 Json(JsonResult),
158 Empty,
159}
160
161impl RawQueryResult {
162 pub fn deserialize_arrow(self) -> Result<QueryResult, ArrowError> {
163 match self {
164 RawQueryResult::Bytes(bytes) => {
165 Self::flat_bytes_to_batches(bytes).map(QueryResult::Arrow)
166 }
167 RawQueryResult::Json(j) => Ok(QueryResult::Json(j)),
168 RawQueryResult::Empty => Ok(QueryResult::Empty),
169 }
170 }
171
172 fn flat_bytes_to_batches(bytes: Vec<Bytes>) -> Result<Vec<RecordBatch>, ArrowError> {
173 let mut res = vec![];
174 for b in bytes {
175 let mut batches = Self::bytes_to_batches(b)?;
176 res.append(&mut batches);
177 }
178 Ok(res)
179 }
180
181 fn bytes_to_batches(bytes: Bytes) -> Result<Vec<RecordBatch>, ArrowError> {
182 let record_batches = StreamReader::try_new(bytes.reader(), None)?;
183 record_batches.into_iter().collect()
184 }
185}
186
187pub struct AuthArgs {
188 pub account_identifier: String,
189 pub warehouse: Option<String>,
190 pub database: Option<String>,
191 pub schema: Option<String>,
192 pub username: String,
193 pub role: Option<String>,
194 pub auth_type: AuthType,
195}
196
197impl AuthArgs {
198 pub fn from_env() -> Result<AuthArgs, SnowflakeApiError> {
199 let auth_type = if let Ok(password) = std::env::var("SNOWFLAKE_PASSWORD") {
200 Ok(AuthType::Password(PasswordArgs { password }))
201 } else if let Ok(private_key_pem) = std::env::var("SNOWFLAKE_PRIVATE_KEY") {
202 Ok(AuthType::Certificate(CertificateArgs { private_key_pem }))
203 } else {
204 Err(MissingEnvArgument(
205 "SNOWFLAKE_PASSWORD or SNOWFLAKE_PRIVATE_KEY".to_owned(),
206 ))
207 };
208
209 Ok(AuthArgs {
210 account_identifier: std::env::var("SNOWFLAKE_ACCOUNT")
211 .map_err(|_| MissingEnvArgument("SNOWFLAKE_ACCOUNT".to_owned()))?,
212 warehouse: std::env::var("SNOWLFLAKE_WAREHOUSE").ok(),
213 database: std::env::var("SNOWFLAKE_DATABASE").ok(),
214 schema: std::env::var("SNOWFLAKE_SCHEMA").ok(),
215 username: std::env::var("SNOWFLAKE_USER")
216 .map_err(|_| MissingEnvArgument("SNOWFLAKE_USER".to_owned()))?,
217 role: std::env::var("SNOWFLAKE_ROLE").ok(),
218 auth_type: auth_type?,
219 })
220 }
221}
222
223pub enum AuthType {
224 Password(PasswordArgs),
225 Certificate(CertificateArgs),
226}
227
/// Credentials for password authentication.
pub struct PasswordArgs {
    /// The password of the user.
    pub password: String,
}
231
/// Credentials for private-key authentication.
pub struct CertificateArgs {
    /// The private key in PEM form.
    pub private_key_pem: String,
}
235
236#[must_use]
237pub struct SnowflakeApiBuilder {
238 pub auth: AuthArgs,
239 client: Option<ClientWithMiddleware>,
240}
241
242impl SnowflakeApiBuilder {
243 pub fn new(auth: AuthArgs) -> Self {
244 Self { auth, client: None }
245 }
246
247 pub fn with_client(mut self, client: ClientWithMiddleware) -> Self {
248 self.client = Some(client);
249 self
250 }
251
252 pub fn build(self) -> Result<SnowflakeApi, SnowflakeApiError> {
253 let connection = match self.client {
254 Some(client) => Arc::new(Connection::new_with_middware(client)),
255 None => Arc::new(Connection::new()?),
256 };
257
258 let session = match self.auth.auth_type {
259 AuthType::Password(args) => Session::password_auth(
260 Arc::clone(&connection),
261 &self.auth.account_identifier,
262 self.auth.warehouse.as_deref(),
263 self.auth.database.as_deref(),
264 self.auth.schema.as_deref(),
265 &self.auth.username,
266 self.auth.role.as_deref(),
267 &args.password,
268 ),
269 AuthType::Certificate(args) => Session::cert_auth(
270 Arc::clone(&connection),
271 &self.auth.account_identifier,
272 self.auth.warehouse.as_deref(),
273 self.auth.database.as_deref(),
274 self.auth.schema.as_deref(),
275 &self.auth.username,
276 self.auth.role.as_deref(),
277 &args.private_key_pem,
278 ),
279 };
280
281 let account_identifier = self.auth.account_identifier.to_uppercase();
282
283 Ok(SnowflakeApi::new(
284 Arc::clone(&connection),
285 session,
286 account_identifier,
287 ))
288 }
289}
290
291pub struct SnowflakeApi {
293 connection: Arc<Connection>,
294 session: Session,
295 account_identifier: String,
296}
297
298impl SnowflakeApi {
299 pub fn new(connection: Arc<Connection>, session: Session, account_identifier: String) -> Self {
301 Self {
302 connection,
303 session,
304 account_identifier,
305 }
306 }
307 pub fn with_password_auth(
309 account_identifier: &str,
310 warehouse: Option<&str>,
311 database: Option<&str>,
312 schema: Option<&str>,
313 username: &str,
314 role: Option<&str>,
315 password: &str,
316 ) -> Result<Self, SnowflakeApiError> {
317 let connection = Arc::new(Connection::new()?);
318
319 let session = Session::password_auth(
320 Arc::clone(&connection),
321 account_identifier,
322 warehouse,
323 database,
324 schema,
325 username,
326 role,
327 password,
328 );
329
330 let account_identifier = account_identifier.to_uppercase();
331 Ok(Self::new(
332 Arc::clone(&connection),
333 session,
334 account_identifier,
335 ))
336 }
337
338 pub fn with_certificate_auth(
340 account_identifier: &str,
341 warehouse: Option<&str>,
342 database: Option<&str>,
343 schema: Option<&str>,
344 username: &str,
345 role: Option<&str>,
346 private_key_pem: &str,
347 ) -> Result<Self, SnowflakeApiError> {
348 let connection = Arc::new(Connection::new()?);
349
350 let session = Session::cert_auth(
351 Arc::clone(&connection),
352 account_identifier,
353 warehouse,
354 database,
355 schema,
356 username,
357 role,
358 private_key_pem,
359 );
360
361 let account_identifier = account_identifier.to_uppercase();
362 Ok(Self::new(
363 Arc::clone(&connection),
364 session,
365 account_identifier,
366 ))
367 }
368
369 pub fn from_env() -> Result<Self, SnowflakeApiError> {
370 SnowflakeApiBuilder::new(AuthArgs::from_env()?).build()
371 }
372
373 pub async fn close_session(&mut self) -> Result<(), SnowflakeApiError> {
377 self.session.close().await?;
378 Ok(())
379 }
380
381 pub async fn exec(&self, sql: &str) -> Result<QueryResult, SnowflakeApiError> {
384 let raw = self.exec_raw(sql).await?;
385 let res = raw.deserialize_arrow()?;
386 Ok(res)
387 }
388
389 pub async fn exec_raw(&self, sql: &str) -> Result<RawQueryResult, SnowflakeApiError> {
393 let put_re = Regex::new(r"(?i)^(?:/\*.*\*/\s*)*put\s+").unwrap();
394
395 if put_re.is_match(sql) {
397 log::info!("Detected PUT query");
398 self.exec_put(sql).await.map(|()| RawQueryResult::Empty)
399 } else {
400 self.exec_arrow_raw(sql).await
401 }
402 }
403
404 async fn exec_put(&self, sql: &str) -> Result<(), SnowflakeApiError> {
405 let resp = self
406 .run_sql::<ExecResponse>(sql, QueryType::JsonQuery)
407 .await?;
408 log::debug!("Got PUT response: {resp:?}");
409
410 match resp {
411 ExecResponse::Query(_) => Err(SnowflakeApiError::UnexpectedResponse),
412 ExecResponse::PutGet(pg) => put::put(pg).await,
413 ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError(
414 e.data.error_code,
415 e.message.unwrap_or_default(),
416 )),
417 }
418 }
419
420 #[cfg(debug_assertions)]
422 pub async fn exec_response(&mut self, sql: &str) -> Result<ExecResponse, SnowflakeApiError> {
423 self.run_sql::<ExecResponse>(sql, QueryType::ArrowQuery)
424 .await
425 }
426
427 #[cfg(debug_assertions)]
429 pub async fn exec_json(&mut self, sql: &str) -> Result<serde_json::Value, SnowflakeApiError> {
430 self.run_sql::<serde_json::Value>(sql, QueryType::JsonQuery)
431 .await
432 }
433
434 async fn exec_arrow_raw(&self, sql: &str) -> Result<RawQueryResult, SnowflakeApiError> {
435 let resp = self
436 .run_sql::<ExecResponse>(sql, QueryType::ArrowQuery)
437 .await?;
438 log::debug!("Got query response: {resp:?}");
439
440 let resp = match resp {
441 ExecResponse::Query(qr) => Ok(qr),
443 ExecResponse::PutGet(_) => Err(SnowflakeApiError::UnexpectedResponse),
444 ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError(
445 e.data.error_code,
446 e.message.unwrap_or_default(),
447 )),
448 }?;
449
450 if resp.data.returned == 0 {
453 log::debug!("Got response with 0 rows");
454 Ok(RawQueryResult::Empty)
455 } else if let Some(value) = resp.data.rowset {
456 log::debug!("Got JSON response");
457 Ok(RawQueryResult::Json(JsonResult {
461 value,
462 schema: resp.data.rowtype.into_iter().map(Into::into).collect(),
463 }))
464 } else if let Some(base64) = resp.data.rowset_base64 {
465 let mut chunks = try_join_all(resp.data.chunks.iter().map(|chunk| {
467 self.connection
468 .get_chunk(&chunk.url, &resp.data.chunk_headers)
469 }))
470 .await?;
471
472 if !base64.is_empty() {
475 log::debug!("Got base64 encoded response");
476 let bytes = Bytes::from(base64::engine::general_purpose::STANDARD.decode(base64)?);
477 chunks.push(bytes);
478 }
479
480 Ok(RawQueryResult::Bytes(chunks))
481 } else {
482 Err(SnowflakeApiError::BrokenResponse)
483 }
484 }
485
486 async fn run_sql<R: serde::de::DeserializeOwned>(
487 &self,
488 sql_text: &str,
489 query_type: QueryType,
490 ) -> Result<R, SnowflakeApiError> {
491 log::debug!("Executing: {sql_text}");
492
493 let parts = self.session.get_token().await?;
494
495 let body = ExecRequest {
496 sql_text: sql_text.to_string(),
497 async_exec: false,
498 sequence_id: parts.sequence_id,
499 is_internal: false,
500 };
501
502 let resp = self
503 .connection
504 .request::<R>(
505 query_type,
506 &self.account_identifier,
507 &[],
508 Some(&parts.session_token_auth_header),
509 body,
510 )
511 .await?;
512
513 Ok(resp)
514 }
515}