1use crate::config::SnowflakeSinkConfig;
4use async_trait::async_trait;
5use faucet_core::util::quote_ident;
6use faucet_core::{AuthSpec, FaucetError, SharedAuthProvider};
7use faucet_common_snowflake::{
8 SnowflakeAuth, authorization_header, credential_to_auth, snowflake_token_type,
9};
10use reqwest::Client;
11use serde::Deserialize;
12use serde_json::{Value, json};
13
14pub struct SnowflakeSink {
17 config: SnowflakeSinkConfig,
18 client: Client,
19 endpoint: Option<String>,
23 auth_provider: Option<SharedAuthProvider>,
27}
28
29#[derive(Deserialize)]
30struct SnowflakeResponse {
31 message: Option<String>,
32 #[serde(default)]
33 code: Option<String>,
34 #[serde(rename = "statementHandle", default)]
37 statement_handle: Option<String>,
38}
39
40fn check_statement_code(sf_resp: &SnowflakeResponse) -> Result<(), FaucetError> {
44 if let Some(code) = &sf_resp.code
45 && code != "090001"
46 {
47 return Err(FaucetError::Sink(format!(
48 "Snowflake error {}: {}",
49 code,
50 sf_resp.message.clone().unwrap_or_default()
51 )));
52 }
53 Ok(())
54}
55
56impl SnowflakeSink {
57 pub fn new(config: SnowflakeSinkConfig) -> Result<Self, FaucetError> {
62 faucet_core::validate_batch_size(config.batch_size)?;
63 Ok(Self {
64 config,
65 client: Client::new(),
66 endpoint: None,
67 auth_provider: None,
68 })
69 }
70
71 pub fn with_auth_provider(mut self, provider: SharedAuthProvider) -> Self {
81 self.auth_provider = Some(provider);
82 self
83 }
84
85 pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
90 self.endpoint = Some(endpoint.into());
91 self
92 }
93
94 fn api_url(&self) -> String {
96 if let Some(endpoint) = &self.endpoint {
97 return endpoint.clone();
98 }
99 format!(
100 "https://{}.snowflakecomputing.com/api/v2/statements",
101 self.config.account
102 )
103 }
104
105 async fn resolve_auth(&self) -> Result<SnowflakeAuth, FaucetError> {
113 if let Some(p) = &self.auth_provider {
114 return credential_to_auth(p.credential().await?);
115 }
116 match &self.config.auth {
117 AuthSpec::Inline(a) => Ok(a.clone()),
118 AuthSpec::Reference(r) => Err(FaucetError::Auth(format!(
119 "auth references provider '{}' but no provider was supplied",
120 r.name
121 ))),
122 }
123 }
124
125 async fn auth_header(&self) -> Result<(String, &'static str), FaucetError> {
127 let effective = self.resolve_auth().await?;
128 let header = authorization_header(&effective, &self.config.account)?;
129 let token_type = snowflake_token_type(&effective);
130 Ok((header, token_type))
131 }
132
133 async fn execute_sql(&self, sql: &str, bindings: Option<Value>) -> Result<(), FaucetError> {
136 let url = self.api_url();
137 let (auth, token_type) = self.auth_header().await?;
138
139 let mut body = json!({
140 "statement": sql,
141 "timeout": 60,
142 "database": self.config.database,
143 "schema": self.config.schema,
144 "warehouse": self.config.warehouse,
145 });
146 if let Some(bindings) = bindings {
147 body["bindings"] = bindings;
148 }
149
150 let resp = self
151 .client
152 .post(&url)
153 .header("Authorization", &auth)
154 .header("Content-Type", "application/json")
155 .header("Accept", "application/json")
156 .header("X-Snowflake-Authorization-Token-Type", token_type)
157 .json(&body)
158 .send()
159 .await
160 .map_err(|e| FaucetError::Sink(format!("Snowflake request failed: {e}")))?;
161
162 let status = resp.status();
163 if !status.is_success() {
164 let body_text = resp.text().await.unwrap_or_default();
165 return Err(FaucetError::Sink(format!(
166 "Snowflake SQL API returned HTTP {status}: {body_text}"
167 )));
168 }
169
170 let is_async = status.as_u16() == 202;
175
176 let sf_resp: SnowflakeResponse = resp
177 .json()
178 .await
179 .map_err(|e| FaucetError::Sink(format!("failed to parse Snowflake response: {e}")))?;
180
181 if is_async {
182 let handle = sf_resp.statement_handle.ok_or_else(|| {
183 FaucetError::Sink(
184 "Snowflake returned HTTP 202 without a statementHandle to poll".into(),
185 )
186 })?;
187 return self.poll_until_complete(&handle).await;
188 }
189
190 check_statement_code(&sf_resp)
191 }
192
193 async fn poll_until_complete(&self, handle: &str) -> Result<(), FaucetError> {
196 let url = format!("{}/{}", self.api_url(), handle);
197 let poll_timeout = self.config.poll_timeout;
198 let started = std::time::Instant::now();
199 loop {
200 let (auth, token_type) = self.auth_header().await?;
206 let resp = self
207 .client
208 .get(&url)
209 .header("Authorization", &auth)
210 .header("Accept", "application/json")
211 .header("X-Snowflake-Authorization-Token-Type", token_type)
212 .send()
213 .await
214 .map_err(|e| FaucetError::Sink(format!("Snowflake poll request failed: {e}")))?;
215
216 let status = resp.status();
217 if status.as_u16() == 202 {
218 if !poll_timeout.is_zero() && started.elapsed() >= poll_timeout {
220 return Err(FaucetError::Sink(format!(
221 "Snowflake statement '{handle}' did not finish within poll_timeout ({}s); still HTTP 202",
222 poll_timeout.as_secs()
223 )));
224 }
225 tokio::time::sleep(std::time::Duration::from_millis(500)).await;
226 continue;
227 }
228 if !status.is_success() {
229 let body_text = resp.text().await.unwrap_or_default();
230 return Err(FaucetError::Sink(format!(
231 "Snowflake poll returned HTTP {status}: {body_text}"
232 )));
233 }
234 let sf_resp: SnowflakeResponse = resp.json().await.map_err(|e| {
235 FaucetError::Sink(format!("failed to parse Snowflake poll response: {e}"))
236 })?;
237 return check_statement_code(&sf_resp);
238 }
239 }
240
241 fn build_insert(&self, records: &[Value]) -> Result<(String, String), FaucetError> {
269 let mut columns: Option<Vec<String>> = None;
272 for record in records {
273 let obj = record.as_object().ok_or_else(|| {
274 FaucetError::Sink("Snowflake sink requires JSON object records".into())
275 })?;
276 if columns.is_none() && !obj.is_empty() {
277 columns = Some(obj.keys().cloned().collect());
278 }
279 }
280 let columns = columns.ok_or_else(|| {
281 FaucetError::Sink("Snowflake sink: records have no fields to insert".into())
282 })?;
283
284 let col_list = columns
287 .iter()
288 .map(|c| quote_ident(c))
289 .collect::<Vec<_>>()
290 .join(", ");
291 let projection = columns
292 .iter()
293 .map(|c| format!("value:{}::string", quote_ident(c)))
294 .collect::<Vec<_>>()
295 .join(", ");
296
297 let payload = Value::Array(records.to_vec()).to_string();
298 let sql = format!(
299 "INSERT INTO {}.{}.{} ({}) SELECT {} FROM TABLE(FLATTEN(input => PARSE_JSON(?)))",
300 quote_ident(&self.config.database),
301 quote_ident(&self.config.schema),
302 quote_ident(&self.config.table),
303 col_list,
304 projection,
305 );
306 Ok((sql, payload))
307 }
308}
309
310#[async_trait]
311impl faucet_core::Sink for SnowflakeSink {
312 fn config_schema(&self) -> serde_json::Value {
313 serde_json::to_value(faucet_core::schema_for!(SnowflakeSinkConfig))
314 .expect("schema serialization")
315 }
316
317 async fn check(
327 &self,
328 ctx: &faucet_core::check::CheckContext,
329 ) -> Result<faucet_core::check::CheckReport, FaucetError> {
330 use faucet_core::check::{CheckReport, Probe};
331
332 let started = std::time::Instant::now();
333
334 let result = tokio::time::timeout(ctx.timeout, self.execute_sql("SELECT 1", None)).await;
335
336 let probe = match result {
337 Ok(Ok(())) => Probe::pass("auth", started.elapsed()),
338 Ok(Err(e)) => Probe::fail_hint(
339 "auth",
340 started.elapsed(),
341 format!("Snowflake SELECT 1 failed: {e}"),
342 "Verify the account identifier, warehouse, and credentials \
343 (OAuth token or key-pair JWT) and that the role can use the \
344 configured warehouse.",
345 ),
346 Err(_elapsed) => Probe::fail_hint(
347 "auth",
348 started.elapsed(),
349 format!("Snowflake SELECT 1 timed out after {:?}", ctx.timeout),
350 "Check network reachability to the Snowflake SQL REST API \
351 endpoint and that the warehouse can resume within the timeout.",
352 ),
353 };
354
355 Ok(CheckReport::single(probe))
356 }
357
358 async fn write_batch(&self, records: &[Value]) -> Result<usize, FaucetError> {
359 if records.is_empty() {
360 return Ok(0);
361 }
362
363 let effective_chunk = if self.config.batch_size == 0 {
369 records.len()
370 } else {
371 self.config.batch_size
372 };
373
374 let mut total = 0;
375 for chunk in records.chunks(effective_chunk) {
376 let (sql, payload) = self.build_insert(chunk)?;
377 let bindings = json!({ "1": { "type": "TEXT", "value": payload } });
378 self.execute_sql(&sql, Some(bindings)).await?;
379 total += chunk.len();
380 }
381
382 tracing::info!(
383 table = %format!(
384 "{}.{}.{}",
385 self.config.database, self.config.schema, self.config.table
386 ),
387 rows = total,
388 "Snowflake write complete"
389 );
390 Ok(total)
391 }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::SnowflakeAuth;

    /// Builds a config with warehouse `"wh"` and OAuth token auth.
    fn oauth_config(
        account: &'static str,
        database: &'static str,
        schema: &'static str,
        table: &'static str,
        token: &'static str,
    ) -> SnowflakeSinkConfig {
        SnowflakeSinkConfig::new(
            account,
            "wh",
            database,
            schema,
            table,
            SnowflakeAuth::OAuth {
                token: token.into(),
            },
        )
    }

    /// Builds a sink for account `"acct"` with token `"t"` targeting the
    /// given database, schema and table.
    fn sink_for(
        database: &'static str,
        schema: &'static str,
        table: &'static str,
    ) -> SnowflakeSink {
        SnowflakeSink::new(oauth_config("acct", database, schema, table, "t")).unwrap()
    }

    /// A batch size above `MAX_BATCH_SIZE` must be rejected at construction.
    #[test]
    fn new_rejects_oversized_batch_size() {
        let config = oauth_config("acct", "db", "schema", "tbl", "t")
            .with_batch_size(faucet_core::MAX_BATCH_SIZE + 1);
        assert!(SnowflakeSink::new(config).is_err());
    }

    /// Without an override the URL is derived from the account identifier.
    #[test]
    fn api_url_format() {
        let config = oauth_config("xy12345.us-east-1", "db", "schema", "tbl", "tok");
        let sink = SnowflakeSink::new(config).unwrap();
        assert_eq!(
            sink.api_url(),
            "https://xy12345.us-east-1.snowflakecomputing.com/api/v2/statements"
        );
    }

    /// Inline OAuth auth produces the Snowflake token header and token type.
    #[tokio::test]
    async fn oauth_auth_header() {
        let config = oauth_config("acct", "db", "schema", "tbl", "my-token");
        let sink = SnowflakeSink::new(config).unwrap();
        let (header, token_type) = sink.auth_header().await.unwrap();
        assert_eq!(header, "Snowflake Token=\"my-token\"");
        assert_eq!(token_type, "OAUTH");
    }

    /// An endpoint override is returned verbatim.
    #[test]
    fn api_url_honours_endpoint_override() {
        let sink =
            sink_for("db", "schema", "tbl").with_endpoint("http://127.0.0.1:1234/api/v2/statements");
        assert_eq!(sink.api_url(), "http://127.0.0.1:1234/api/v2/statements");
    }

    /// Database, schema and table are emitted as quoted identifiers.
    #[test]
    fn build_insert_uses_quoted_identifiers() {
        let sink = sink_for("MY_DB", "PUBLIC", "events");
        let records = vec![serde_json::json!({"id": 1})];
        let (sql, _payload) = sink.build_insert(&records).unwrap();
        assert!(sql.contains("\"MY_DB\".\"PUBLIC\".\"events\""));
    }

    /// Record values travel in the bound payload and never in the SQL text.
    #[test]
    fn build_insert_binds_payload_instead_of_interpolating() {
        let sink = sink_for("db", "schema", "tbl");
        let records = vec![
            serde_json::json!({"name": "O'Brien"}),
            serde_json::json!({"note": "'); DROP TABLE events;--"}),
        ];
        let (sql, payload) = sink.build_insert(&records).unwrap();

        assert!(sql.contains("PARSE_JSON(?)"), "sql: {sql}");
        assert!(
            !sql.contains('\''),
            "sql must not embed a quoted literal: {sql}"
        );
        assert!(!sql.contains("O'Brien"));
        assert!(!sql.contains("DROP TABLE"));

        // The payload must round-trip to the original values.
        let parsed: Value = serde_json::from_str(&payload).unwrap();
        assert_eq!(parsed[0]["name"], "O'Brien");
        assert_eq!(parsed[1]["note"], "'); DROP TABLE events;--");
    }

    /// Columns come from record keys and are projected out of `value`,
    /// rather than selecting every FLATTEN output column.
    #[test]
    fn build_insert_maps_record_fields_to_columns_not_flatten_metadata() {
        let sink = sink_for("db", "schema", "events");
        let records = vec![serde_json::json!({"user_id": 1, "event": "click"})];
        let (sql, _payload) = sink.build_insert(&records).unwrap();

        assert!(sql.contains("\"user_id\""), "sql: {sql}");
        assert!(sql.contains("\"event\""), "sql: {sql}");
        assert!(sql.contains("value:\"user_id\"::string"), "sql: {sql}");
        assert!(sql.contains("value:\"event\"::string"), "sql: {sql}");
        assert!(
            !sql.contains("SELECT *"),
            "must not SELECT * over FLATTEN: {sql}"
        );
        assert!(
            sql.contains("FLATTEN(input => PARSE_JSON(?))"),
            "sql: {sql}"
        );
    }

    /// A double quote inside a record key is doubled in both the column
    /// list and the projection path.
    #[test]
    fn build_insert_escapes_record_keys_in_columns_and_paths() {
        let sink = sink_for("db", "schema", "events");
        let records = vec![serde_json::json!({"a\"b": 1})];
        let (sql, _payload) = sink.build_insert(&records).unwrap();
        assert!(sql.contains("\"a\"\"b\""), "sql: {sql}");
        assert!(sql.contains("value:\"a\"\"b\"::string"), "sql: {sql}");
    }

    /// A batch containing only empty objects has no columns to insert.
    #[test]
    fn build_insert_rejects_all_empty_records() {
        let sink = sink_for("db", "schema", "events");
        let records = vec![serde_json::json!({})];
        assert!(sink.build_insert(&records).is_err());
    }
}