// faucet_sink_snowflake/sink.rs

1//! Snowflake SQL REST API sink.
2
3use crate::config::SnowflakeSinkConfig;
4use async_trait::async_trait;
5use faucet_core::util::quote_ident;
6use faucet_core::{AuthSpec, FaucetError, SharedAuthProvider};
7use faucet_common_snowflake::{
8    SnowflakeAuth, authorization_header, credential_to_auth, snowflake_token_type,
9};
10use reqwest::Client;
11use serde::Deserialize;
12use serde_json::{Value, json};
13
14/// A sink that writes JSON records to a Snowflake table using the
15/// SQL REST API.
16pub struct SnowflakeSink {
17    config: SnowflakeSinkConfig,
18    client: Client,
19    /// Optional explicit endpoint override. When `None`, the URL is derived
20    /// from `config.account`. Used by tests to point the sink at a mock
21    /// server, and useful for proxies / private-link deployments.
22    endpoint: Option<String>,
23    /// Optional shared auth provider. When set, takes precedence over inline
24    /// auth; the provider yields a `Bearer` or `Token` credential mapped onto
25    /// [`SnowflakeAuth::OAuth`]. Set via [`Self::with_auth_provider`].
26    auth_provider: Option<SharedAuthProvider>,
27}
28
29#[derive(Deserialize)]
30struct SnowflakeResponse {
31    message: Option<String>,
32    #[serde(default)]
33    code: Option<String>,
34    /// Present on an HTTP 202 (asynchronous execution) response — the
35    /// opaque handle used to poll the statement to completion.
36    #[serde(rename = "statementHandle", default)]
37    statement_handle: Option<String>,
38}
39
40/// Map a parsed statement response onto a success/error result. Code
41/// `090001` is "Statement executed successfully"; any other non-null code
42/// is a Snowflake-side error.
43fn check_statement_code(sf_resp: &SnowflakeResponse) -> Result<(), FaucetError> {
44    if let Some(code) = &sf_resp.code
45        && code != "090001"
46    {
47        return Err(FaucetError::Sink(format!(
48            "Snowflake error {}: {}",
49            code,
50            sf_resp.message.clone().unwrap_or_default()
51        )));
52    }
53    Ok(())
54}
55
56impl SnowflakeSink {
57    /// Create a new Snowflake sink.
58    ///
59    /// Returns [`FaucetError::Config`] if `batch_size` exceeds
60    /// `MAX_BATCH_SIZE` (#78/#44).
61    pub fn new(config: SnowflakeSinkConfig) -> Result<Self, FaucetError> {
62        faucet_core::validate_batch_size(config.batch_size)?;
63        Ok(Self {
64            config,
65            client: Client::new(),
66            endpoint: None,
67            auth_provider: None,
68        })
69    }
70
71    /// Attach a shared [`AuthProvider`](faucet_core::AuthProvider). When set,
72    /// the provider supplies the credential for every request (taking
73    /// precedence over inline auth), so several sinks can share one OAuth
74    /// token with single-flight refresh. Used by the CLI to resolve
75    /// `auth: { ref }`, and by library callers who inject a provider directly.
76    ///
77    /// The provider must yield a `Bearer` or `Token` credential, which maps
78    /// onto [`SnowflakeAuth::OAuth`]. Key-pair JWT cannot be supplied via a
79    /// provider (JWT is minted locally from the RSA key).
80    pub fn with_auth_provider(mut self, provider: SharedAuthProvider) -> Self {
81        self.auth_provider = Some(provider);
82        self
83    }
84
85    /// Override the API endpoint URL (full URL including
86    /// `/api/v2/statements`). When set, this URL is used verbatim instead
87    /// of the account-derived `https://{account}.snowflakecomputing.com/...`
88    /// URL. Intended for tests (wiremock) and proxy / private-link setups.
89    pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
90        self.endpoint = Some(endpoint.into());
91        self
92    }
93
94    /// Build the SQL REST API endpoint URL.
95    fn api_url(&self) -> String {
96        if let Some(endpoint) = &self.endpoint {
97            return endpoint.clone();
98        }
99        format!(
100            "https://{}.snowflakecomputing.com/api/v2/statements",
101            self.config.account
102        )
103    }
104
105    /// Resolve the effective [`SnowflakeAuth`] for this request.
106    ///
107    /// Resolution order:
108    /// 1. If a shared provider is attached, call it and map the credential.
109    /// 2. Otherwise, use the inline auth from the config.
110    /// 3. If the config holds an unresolved `Reference` with no provider,
111    ///    return [`FaucetError::Auth`].
112    async fn resolve_auth(&self) -> Result<SnowflakeAuth, FaucetError> {
113        if let Some(p) = &self.auth_provider {
114            return credential_to_auth(p.credential().await?);
115        }
116        match &self.config.auth {
117            AuthSpec::Inline(a) => Ok(a.clone()),
118            AuthSpec::Reference(r) => Err(FaucetError::Auth(format!(
119                "auth references provider '{}' but no provider was supplied",
120                r.name
121            ))),
122        }
123    }
124
125    /// Get the authorization header value.
126    async fn auth_header(&self) -> Result<(String, &'static str), FaucetError> {
127        let effective = self.resolve_auth().await?;
128        let header = authorization_header(&effective, &self.config.account)?;
129        let token_type = snowflake_token_type(&effective);
130        Ok((header, token_type))
131    }
132
133    /// Execute a SQL statement via the REST API, optionally with positional
134    /// bindings (`{"1": {"type": "TEXT", "value": ...}}`).
135    async fn execute_sql(&self, sql: &str, bindings: Option<Value>) -> Result<(), FaucetError> {
136        let url = self.api_url();
137        let (auth, token_type) = self.auth_header().await?;
138
139        let mut body = json!({
140            "statement": sql,
141            "timeout": 60,
142            "database": self.config.database,
143            "schema": self.config.schema,
144            "warehouse": self.config.warehouse,
145        });
146        if let Some(bindings) = bindings {
147            body["bindings"] = bindings;
148        }
149
150        let resp = self
151            .client
152            .post(&url)
153            .header("Authorization", &auth)
154            .header("Content-Type", "application/json")
155            .header("Accept", "application/json")
156            .header("X-Snowflake-Authorization-Token-Type", token_type)
157            .json(&body)
158            .send()
159            .await
160            .map_err(|e| FaucetError::Sink(format!("Snowflake request failed: {e}")))?;
161
162        let status = resp.status();
163        if !status.is_success() {
164            let body_text = resp.text().await.unwrap_or_default();
165            return Err(FaucetError::Sink(format!(
166                "Snowflake SQL API returned HTTP {status}: {body_text}"
167            )));
168        }
169
170        // HTTP 202 means Snowflake *accepted* the statement but has not yet
171        // executed it. Treating that as success would report rows as written
172        // before they are actually committed. Poll the returned handle until
173        // the statement completes (#78/#17).
174        let is_async = status.as_u16() == 202;
175
176        let sf_resp: SnowflakeResponse = resp
177            .json()
178            .await
179            .map_err(|e| FaucetError::Sink(format!("failed to parse Snowflake response: {e}")))?;
180
181        if is_async {
182            let handle = sf_resp.statement_handle.ok_or_else(|| {
183                FaucetError::Sink(
184                    "Snowflake returned HTTP 202 without a statementHandle to poll".into(),
185                )
186            })?;
187            return self.poll_until_complete(&handle).await;
188        }
189
190        check_statement_code(&sf_resp)
191    }
192
193    /// Poll `GET /api/v2/statements/{handle}` until the statement finishes
194    /// executing (HTTP 200 + code `090001`), bounded by `poll_timeout`.
195    async fn poll_until_complete(&self, handle: &str) -> Result<(), FaucetError> {
196        let url = format!("{}/{}", self.api_url(), handle);
197        let poll_timeout = self.config.poll_timeout;
198        let started = std::time::Instant::now();
199        loop {
200            // Re-resolve auth every iteration: a long-running async statement can
201            // outlive a short-lived OAuth token, so we re-ask the (single-flight,
202            // cached) provider for a current token rather than reusing the one
203            // minted at submit time — otherwise the poll 401s mid-run after a
204            // rotation (#146).
205            let (auth, token_type) = self.auth_header().await?;
206            let resp = self
207                .client
208                .get(&url)
209                .header("Authorization", &auth)
210                .header("Accept", "application/json")
211                .header("X-Snowflake-Authorization-Token-Type", token_type)
212                .send()
213                .await
214                .map_err(|e| FaucetError::Sink(format!("Snowflake poll request failed: {e}")))?;
215
216            let status = resp.status();
217            if status.as_u16() == 202 {
218                // `poll_timeout == 0` disables the cap (poll forever).
219                if !poll_timeout.is_zero() && started.elapsed() >= poll_timeout {
220                    return Err(FaucetError::Sink(format!(
221                        "Snowflake statement '{handle}' did not finish within poll_timeout ({}s); still HTTP 202",
222                        poll_timeout.as_secs()
223                    )));
224                }
225                tokio::time::sleep(std::time::Duration::from_millis(500)).await;
226                continue;
227            }
228            if !status.is_success() {
229                let body_text = resp.text().await.unwrap_or_default();
230                return Err(FaucetError::Sink(format!(
231                    "Snowflake poll returned HTTP {status}: {body_text}"
232                )));
233            }
234            let sf_resp: SnowflakeResponse = resp.json().await.map_err(|e| {
235                FaucetError::Sink(format!("failed to parse Snowflake poll response: {e}"))
236            })?;
237            return check_statement_code(&sf_resp);
238        }
239    }
240
241    /// Build an INSERT statement plus the JSON payload to bind to its single
242    /// `PARSE_JSON(?)` parameter.
243    ///
244    /// The record array travels as one bound `TEXT` parameter to
245    /// `PARSE_JSON(?)`, never interpolated into a SQL string literal:
246    /// interpolation was a SQL-injection vector and corrupted any value
247    /// containing an apostrophe (#78/#5). `FLATTEN` then yields one row per
248    /// array element, and each record field is projected into its matching
249    /// column.
250    ///
251    /// The projection is **per-column** — `value:"col"::string` for each key —
252    /// not `SELECT *`. `SELECT *` over `FLATTEN` returns FLATTEN's fixed
253    /// `SEQ, KEY, PATH, INDEX, VALUE, THIS` metadata columns, so the previous
254    /// statement inserted that metadata instead of the record's fields and was
255    /// non-functional for any normal table (audit #146 C2). The `::string` cast
256    /// strips the VARIANT's JSON quotes and lets Snowflake coerce the scalar
257    /// into the destination column's type on `INSERT` (text → number / boolean
258    /// / timestamp, etc.). The column set is taken from the first non-empty
259    /// record; a key missing from a later record projects to SQL `NULL`.
260    ///
261    /// Both the column identifiers and the JSON path keys are escaped via
262    /// [`quote_ident`] (double-quote doubling), so record keys cannot inject
263    /// SQL. Returns `(sql, json_payload)`.
264    ///
265    /// Note: a record key whose target column is semi-structured (`VARIANT` /
266    /// `OBJECT` / `ARRAY`) is stringified by the `::string` cast rather than
267    /// stored as structured JSON; this sink maps records to scalar columns.
268    fn build_insert(&self, records: &[Value]) -> Result<(String, String), FaucetError> {
269        // The column set is the keys of the first non-empty record (all rows in
270        // one INSERT must share a column list). Every record must be an object.
271        let mut columns: Option<Vec<String>> = None;
272        for record in records {
273            let obj = record.as_object().ok_or_else(|| {
274                FaucetError::Sink("Snowflake sink requires JSON object records".into())
275            })?;
276            if columns.is_none() && !obj.is_empty() {
277                columns = Some(obj.keys().cloned().collect());
278            }
279        }
280        let columns = columns.ok_or_else(|| {
281            FaucetError::Sink("Snowflake sink: records have no fields to insert".into())
282        })?;
283
284        // `quote_ident` produces a `"`-escaped quoted identifier, which is also
285        // the correct (injection-safe) form for a FLATTEN path key: `value:"k"`.
286        let col_list = columns
287            .iter()
288            .map(|c| quote_ident(c))
289            .collect::<Vec<_>>()
290            .join(", ");
291        let projection = columns
292            .iter()
293            .map(|c| format!("value:{}::string", quote_ident(c)))
294            .collect::<Vec<_>>()
295            .join(", ");
296
297        let payload = Value::Array(records.to_vec()).to_string();
298        let sql = format!(
299            "INSERT INTO {}.{}.{} ({}) SELECT {} FROM TABLE(FLATTEN(input => PARSE_JSON(?)))",
300            quote_ident(&self.config.database),
301            quote_ident(&self.config.schema),
302            quote_ident(&self.config.table),
303            col_list,
304            projection,
305        );
306        Ok((sql, payload))
307    }
308}
309
310#[async_trait]
311impl faucet_core::Sink for SnowflakeSink {
312    fn config_schema(&self) -> serde_json::Value {
313        serde_json::to_value(faucet_core::schema_for!(SnowflakeSinkConfig))
314            .expect("schema serialization")
315    }
316
317    /// Preflight check (`faucet doctor`).
318    ///
319    /// Runs a single read-only `SELECT 1` through the existing SQL REST API
320    /// request path (`execute_sql`), reusing the sink's
321    /// configured account/warehouse/auth. This resolves the effective
322    /// credential (inline or shared provider), builds the authorization
323    /// header, and confirms Snowflake accepts the session — without writing
324    /// any rows. Auth-resolution, network, and SQL-API errors surface as a
325    /// `Fail` probe with a hint. Tokens are never placed in the reason/hint.
326    async fn check(
327        &self,
328        ctx: &faucet_core::check::CheckContext,
329    ) -> Result<faucet_core::check::CheckReport, FaucetError> {
330        use faucet_core::check::{CheckReport, Probe};
331
332        let started = std::time::Instant::now();
333
334        let result = tokio::time::timeout(ctx.timeout, self.execute_sql("SELECT 1", None)).await;
335
336        let probe = match result {
337            Ok(Ok(())) => Probe::pass("auth", started.elapsed()),
338            Ok(Err(e)) => Probe::fail_hint(
339                "auth",
340                started.elapsed(),
341                format!("Snowflake SELECT 1 failed: {e}"),
342                "Verify the account identifier, warehouse, and credentials \
343                 (OAuth token or key-pair JWT) and that the role can use the \
344                 configured warehouse.",
345            ),
346            Err(_elapsed) => Probe::fail_hint(
347                "auth",
348                started.elapsed(),
349                format!("Snowflake SELECT 1 timed out after {:?}", ctx.timeout),
350                "Check network reachability to the Snowflake SQL REST API \
351                 endpoint and that the warehouse can resume within the timeout.",
352            ),
353        };
354
355        Ok(CheckReport::single(probe))
356    }
357
358    async fn write_batch(&self, records: &[Value]) -> Result<usize, FaucetError> {
359        if records.is_empty() {
360            return Ok(0);
361        }
362
363        // `batch_size = 0` is the "no batching" sentinel: forward whatever
364        // upstream handed us as a single INSERT, preserving `StreamPage`
365        // framing. Otherwise re-chunk into `batch_size` slices so each
366        // outbound REST request stays near Snowflake's documented sweet
367        // spot (~1000 rows).
368        let effective_chunk = if self.config.batch_size == 0 {
369            records.len()
370        } else {
371            self.config.batch_size
372        };
373
374        let mut total = 0;
375        for chunk in records.chunks(effective_chunk) {
376            let (sql, payload) = self.build_insert(chunk)?;
377            let bindings = json!({ "1": { "type": "TEXT", "value": payload } });
378            self.execute_sql(&sql, Some(bindings)).await?;
379            total += chunk.len();
380        }
381
382        tracing::info!(
383            table = %format!(
384                "{}.{}.{}",
385                self.config.database, self.config.schema, self.config.table
386            ),
387            rows = total,
388            "Snowflake write complete"
389        );
390        Ok(total)
391    }
392}
393
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::SnowflakeAuth;

    /// OAuth-authenticated config on warehouse `wh` for the given account,
    /// target table, and token.
    fn oauth_config(
        account: &'static str,
        database: &'static str,
        schema: &'static str,
        table: &'static str,
        token: &'static str,
    ) -> SnowflakeSinkConfig {
        SnowflakeSinkConfig::new(
            account,
            "wh",
            database,
            schema,
            table,
            SnowflakeAuth::OAuth {
                token: token.into(),
            },
        )
    }

    /// Sink for account `acct` (OAuth token `t`) writing to
    /// `database.schema.table`.
    fn sink_for(
        database: &'static str,
        schema: &'static str,
        table: &'static str,
    ) -> SnowflakeSink {
        SnowflakeSink::new(oauth_config("acct", database, schema, table, "t")).unwrap()
    }

    #[test]
    fn new_rejects_oversized_batch_size() {
        // Regression for #78/#44.
        let config = oauth_config("acct", "db", "schema", "tbl", "t")
            .with_batch_size(faucet_core::MAX_BATCH_SIZE + 1);
        assert!(SnowflakeSink::new(config).is_err());
    }

    #[test]
    fn api_url_format() {
        let config = oauth_config("xy12345.us-east-1", "db", "schema", "tbl", "tok");
        let sink = SnowflakeSink::new(config).unwrap();
        assert_eq!(
            sink.api_url(),
            "https://xy12345.us-east-1.snowflakecomputing.com/api/v2/statements"
        );
    }

    #[tokio::test]
    async fn oauth_auth_header() {
        let config = oauth_config("acct", "db", "schema", "tbl", "my-token");
        let sink = SnowflakeSink::new(config).unwrap();
        let (header, token_type) = sink.auth_header().await.unwrap();
        assert_eq!(header, "Snowflake Token=\"my-token\"");
        assert_eq!(token_type, "OAUTH");
    }

    #[test]
    fn api_url_honours_endpoint_override() {
        let sink = sink_for("db", "schema", "tbl")
            .with_endpoint("http://127.0.0.1:1234/api/v2/statements");
        assert_eq!(sink.api_url(), "http://127.0.0.1:1234/api/v2/statements");
    }

    #[test]
    fn build_insert_uses_quoted_identifiers() {
        let sink = sink_for("MY_DB", "PUBLIC", "events");
        let records = vec![serde_json::json!({"id": 1})];
        let (sql, _payload) = sink.build_insert(&records).unwrap();
        assert!(sql.contains("\"MY_DB\".\"PUBLIC\".\"events\""));
    }

    #[test]
    fn build_insert_binds_payload_instead_of_interpolating() {
        // Regression for #78/#5. The record JSON must travel as a bound TEXT
        // parameter to PARSE_JSON(?), never interpolated into a SQL string
        // literal — interpolation is a SQL-injection vector and breaks on any
        // value containing an apostrophe.
        let sink = sink_for("db", "schema", "tbl");
        let records = vec![
            serde_json::json!({"name": "O'Brien"}),
            serde_json::json!({"note": "'); DROP TABLE events;--"}),
        ];
        let (sql, payload) = sink.build_insert(&records).unwrap();

        // SQL is a parameterised placeholder — no record data, no literal.
        assert!(sql.contains("PARSE_JSON(?)"), "sql: {sql}");
        assert!(
            !sql.contains('\''),
            "sql must not embed a quoted literal: {sql}"
        );
        assert!(!sql.contains("O'Brien"));
        assert!(!sql.contains("DROP TABLE"));

        // The payload is the JSON array, carrying the apostrophe data intact.
        let parsed: Value = serde_json::from_str(&payload).unwrap();
        assert_eq!(parsed[0]["name"], "O'Brien");
        assert_eq!(parsed[1]["note"], "'); DROP TABLE events;--");
    }

    #[test]
    fn build_insert_maps_record_fields_to_columns_not_flatten_metadata() {
        // C2 regression (audit #146): the INSERT must project each record field
        // into its named column, NOT `SELECT *` over FLATTEN — `SELECT *` over
        // FLATTEN returns the fixed SEQ/KEY/PATH/INDEX/VALUE/THIS metadata
        // columns, so the old statement inserted metadata instead of the
        // record's own fields.
        let sink = sink_for("db", "schema", "events");
        let records = vec![serde_json::json!({"user_id": 1, "event": "click"})];
        let (sql, _payload) = sink.build_insert(&records).unwrap();

        // Named column list + per-column projection from the FLATTEN `value`.
        assert!(sql.contains("\"user_id\""), "sql: {sql}");
        assert!(sql.contains("\"event\""), "sql: {sql}");
        assert!(sql.contains("value:\"user_id\"::string"), "sql: {sql}");
        assert!(sql.contains("value:\"event\"::string"), "sql: {sql}");
        // Crucially, NOT a metadata-projecting `SELECT *`.
        assert!(
            !sql.contains("SELECT *"),
            "must not SELECT * over FLATTEN: {sql}"
        );
        assert!(
            sql.contains("FLATTEN(input => PARSE_JSON(?))"),
            "sql: {sql}"
        );
    }

    #[test]
    fn build_insert_escapes_record_keys_in_columns_and_paths() {
        // Record keys are user-controlled; a key containing a double quote must
        // be `"`-doubled in both the column list and the FLATTEN path so it
        // cannot break out of the identifier / path.
        let sink = sink_for("db", "schema", "events");
        let records = vec![serde_json::json!({"a\"b": 1})];
        let (sql, _payload) = sink.build_insert(&records).unwrap();
        // Column identifier and path key are both escaped as "a""b".
        assert!(sql.contains("\"a\"\"b\""), "sql: {sql}");
        assert!(sql.contains("value:\"a\"\"b\"::string"), "sql: {sql}");
    }

    #[test]
    fn build_insert_rejects_all_empty_records() {
        let sink = sink_for("db", "schema", "events");
        let records = vec![serde_json::json!({})];
        assert!(sink.build_insert(&records).is_err());
    }
}