Skip to main content

faucet_source_elasticsearch/
stream.rs

1//! Elasticsearch scroll-based search source.
2
3use crate::config::{ElasticsearchAuth, ElasticsearchSourceConfig};
4use async_trait::async_trait;
5use faucet_core::util::{DEFAULT_ERROR_BODY_MAX_LEN, check_http_response};
6use faucet_core::{AuthSpec, FaucetError, SharedAuthProvider, Stream, StreamPage};
7use reqwest::Client;
8use serde_json::{Value, json};
9use std::pin::Pin;
10
/// Page `size` used when `batch_size = 0` ("no batching").
///
/// [`faucet_core::Source::stream_pages`] sends it as the `size` of a single
/// non-scroll `_search`, and [`faucet_core::Source::fetch_with_context`] uses
/// it as the scroll page size. A literal `size=0` would make Elasticsearch
/// return zero hits. The value mirrors Elasticsearch's default
/// `index.max_result_window` so the request stays within ES's out-of-the-box
/// cap.
pub(crate) const NO_BATCHING_SEARCH_SIZE: usize = 10_000;
15
16/// A source that reads documents from an Elasticsearch index using the scroll API.
17pub struct ElasticsearchSource {
18    config: ElasticsearchSourceConfig,
19    client: Client,
20    /// Optional shared auth provider. When set it takes precedence over inline
21    /// auth. Injected by the CLI (to resolve `auth: { ref }`) or directly by
22    /// library callers who want to share one token across multiple sources.
23    auth_provider: Option<SharedAuthProvider>,
24}
25
26impl ElasticsearchSource {
27    /// Create a new Elasticsearch source from the given configuration.
28    /// Construction does no I/O; it fails only on an invalid config (an
29    /// out-of-range `batch_size`).
30    pub fn new(config: ElasticsearchSourceConfig) -> Result<Self, FaucetError> {
31        faucet_core::validate_batch_size(config.batch_size)?;
32        Ok(Self {
33            config,
34            client: Client::new(),
35            auth_provider: None,
36        })
37    }
38
39    /// Attach a shared [`AuthProvider`](faucet_core::AuthProvider). When set,
40    /// the provider supplies the credential for every request (taking precedence
41    /// over inline auth), so several sources can share one token with
42    /// single-flight refresh. Used by the CLI to resolve `auth: { ref }`, and
43    /// by library callers who inject one provider into many sources.
44    pub fn with_auth_provider(mut self, provider: SharedAuthProvider) -> Self {
45        self.auth_provider = Some(provider);
46        self
47    }
48
49    /// Resolve the effective [`ElasticsearchAuth`] for the current request.
50    ///
51    /// Resolution order:
52    /// 1. If a shared provider is attached, call it and map the credential.
53    /// 2. Otherwise use the inline auth from config.
54    /// 3. If the config is a `Reference` with no provider, return an error.
55    async fn resolve_auth(&self) -> Result<ElasticsearchAuth, FaucetError> {
56        if let Some(p) = &self.auth_provider {
57            return faucet_common_elasticsearch::credential_to_auth(p.credential().await?);
58        }
59        match &self.config.auth {
60            AuthSpec::Inline(a) => Ok(a.clone()),
61            AuthSpec::Reference(r) => Err(FaucetError::Auth(format!(
62                "auth references provider '{}' but no provider was supplied",
63                r.name
64            ))),
65        }
66    }
67
68    /// Apply an [`ElasticsearchAuth`] to a request builder.
69    fn apply_auth_value(
70        req: reqwest::RequestBuilder,
71        auth: &ElasticsearchAuth,
72    ) -> reqwest::RequestBuilder {
73        match auth {
74            ElasticsearchAuth::None => req,
75            ElasticsearchAuth::Basic { username, password } => {
76                req.basic_auth(username, Some(password))
77            }
78            ElasticsearchAuth::Bearer { token } => req.bearer_auth(token),
79            ElasticsearchAuth::ApiKey { key } => {
80                req.header("Authorization", format!("ApiKey {key}"))
81            }
82        }
83    }
84
85    /// Extract `hits.hits[*]._source` from an Elasticsearch search response.
86    fn extract_hits(body: &Value) -> Vec<Value> {
87        body.get("hits")
88            .and_then(|h| h.get("hits"))
89            .and_then(|h| h.as_array())
90            .map(|hits| {
91                hits.iter()
92                    .filter_map(|hit| hit.get("_source").cloned())
93                    .collect()
94            })
95            .unwrap_or_default()
96    }
97
98    /// Extract the `_scroll_id` from an Elasticsearch response.
99    fn extract_scroll_id(body: &Value) -> Option<String> {
100        body.get("_scroll_id")
101            .and_then(|v| v.as_str())
102            .map(|s| s.to_string())
103    }
104
105    /// Clear a scroll context. Best-effort: errors are logged but not propagated.
106    async fn clear_scroll(&self, scroll_id: &str) {
107        let url = format!("{}/_search/scroll", self.config.base_url);
108        let req = self
109            .client
110            .delete(&url)
111            .json(&json!({"scroll_id": scroll_id}));
112        let auth = match self.resolve_auth().await {
113            Ok(a) => a,
114            Err(e) => {
115                tracing::warn!(error = %e, "failed to resolve auth for scroll cleanup");
116                return;
117            }
118        };
119        let req = Self::apply_auth_value(req, &auth);
120
121        if let Err(e) = req.send().await {
122            tracing::warn!(error = %e, "failed to clear Elasticsearch scroll context");
123        }
124    }
125
126    /// Resolve the index and query under the supplied parent context. Returns
127    /// `(index, query)`.
128    fn resolve_index_and_query(
129        &self,
130        context: &std::collections::HashMap<String, Value>,
131    ) -> Result<(String, Value), FaucetError> {
132        let index = if context.is_empty() {
133            self.config.index.clone()
134        } else {
135            faucet_core::util::substitute_context(&self.config.index, context)
136        };
137        let query = if context.is_empty() {
138            self.config.query.clone()
139        } else {
140            let s = serde_json::to_string(&self.config.query)
141                .map_err(|e| FaucetError::Config(format!("failed to serialize query: {e}")))?;
142            let s = faucet_core::util::substitute_context_json(&s, context);
143            serde_json::from_str(&s).map_err(|e| {
144                FaucetError::Config(format!("failed to parse substituted query: {e}"))
145            })?
146        };
147        Ok((index, query))
148    }
149}
150
151#[async_trait]
152impl faucet_core::Source for ElasticsearchSource {
153    async fn fetch_with_context(
154        &self,
155        context: &std::collections::HashMap<String, serde_json::Value>,
156    ) -> Result<Vec<Value>, FaucetError> {
157        let (index, query) = self.resolve_index_and_query(context)?;
158        // Resolve auth once; reuse the same credential across all scroll pages.
159        let auth = self.resolve_auth().await?;
160
161        let mut all_records = Vec::new();
162
163        // `batch_size = 0` is the "no batching" sentinel. Interpolating it
164        // directly as `size=0` would make Elasticsearch return zero hits, so
165        // map it to the same large page size the streaming path uses (#78/#33).
166        let page_size = if self.config.batch_size == 0 {
167            NO_BATCHING_SEARCH_SIZE
168        } else {
169            self.config.batch_size
170        };
171
172        // Initial search request with scroll.
173        let url = format!(
174            "{}/{}/_search?scroll={}&size={}",
175            self.config.base_url, index, self.config.scroll_timeout, page_size
176        );
177        let req = self.client.post(&url).json(&json!({"query": query}));
178        let req = Self::apply_auth_value(req, &auth);
179
180        let resp = req.send().await?;
181        let resp = check_http_response(resp, DEFAULT_ERROR_BODY_MAX_LEN).await?;
182        let body: Value = resp.json().await?;
183
184        let mut records = Self::extract_hits(&body);
185        let mut scroll_id = Self::extract_scroll_id(&body);
186        let mut pages_fetched: usize = 1;
187
188        tracing::debug!(
189            records = records.len(),
190            page = pages_fetched,
191            "Elasticsearch initial search"
192        );
193
194        all_records.append(&mut records);
195
196        // Scroll loop.
197        while let Some(ref sid) = scroll_id {
198            // Check max_pages limit.
199            if let Some(max) = self.config.max_pages
200                && pages_fetched >= max
201            {
202                tracing::debug!(max_pages = max, "max_pages reached, stopping scroll");
203                break;
204            }
205
206            let scroll_url = format!("{}/_search/scroll", self.config.base_url);
207            let req = self.client.post(&scroll_url).json(&json!({
208                "scroll": self.config.scroll_timeout,
209                "scroll_id": sid,
210            }));
211            let req = Self::apply_auth_value(req, &auth);
212
213            let resp = req.send().await?;
214            let resp = check_http_response(resp, DEFAULT_ERROR_BODY_MAX_LEN).await?;
215            let body: Value = resp.json().await?;
216
217            let mut page_records = Self::extract_hits(&body);
218            pages_fetched += 1;
219
220            tracing::debug!(
221                records = page_records.len(),
222                page = pages_fetched,
223                "Elasticsearch scroll page"
224            );
225
226            // Stop when no more hits are returned.
227            if page_records.is_empty() {
228                break;
229            }
230
231            // Update scroll_id for the next iteration.
232            scroll_id = Self::extract_scroll_id(&body);
233            all_records.append(&mut page_records);
234        }
235
236        // Clear the scroll context (best-effort).
237        if let Some(ref sid) = scroll_id {
238            self.clear_scroll(sid).await;
239        }
240
241        tracing::debug!(
242            total_records = all_records.len(),
243            pages = pages_fetched,
244            "Elasticsearch fetch complete"
245        );
246
247        Ok(all_records)
248    }
249
250    /// Stream documents from Elasticsearch as scroll pages, one
251    /// [`StreamPage`] per scroll response. Bounds client-side memory at
252    /// O(batch_size) regardless of the index's total document count.
253    ///
254    /// The trait-level `batch_size` argument is ignored in favour of
255    /// [`ElasticsearchSourceConfig::batch_size`] — the config is the
256    /// user-facing knob the README documents, and routing the
257    /// pipeline-supplied hint through it would silently override an explicit
258    /// config value.
259    ///
260    /// When `batch_size = 0` the source issues a single non-scroll
261    /// `_search?size=10_000` and emits exactly one page. The scroll API is
262    /// not used and no scroll context needs to be cleared.
263    ///
264    /// The Elasticsearch search source has no incremental-replication mode
265    /// today, so every emitted page carries `bookmark: None`.
266    ///
267    /// **Scroll-context cleanup is mandatory.** On every exit path — clean
268    /// drain, `max_pages` truncation, mid-stream HTTP error, or consumer
269    /// dropping the stream — the open `_scroll_id` is sent to
270    /// `DELETE _search/scroll` so the cluster does not leak server-side
271    /// state. Cleanup runs inside a guard whose `Drop` impl spawns the
272    /// delete request, so even cancellation at any `.await` point still
273    /// releases the context.
274    fn stream_pages<'a>(
275        &'a self,
276        context: &'a std::collections::HashMap<String, Value>,
277        _batch_size: usize,
278    ) -> Pin<Box<dyn Stream<Item = Result<StreamPage, FaucetError>> + Send + 'a>> {
279        let batch_size = self.config.batch_size;
280
281        Box::pin(async_stream::try_stream! {
282            let (index, query) = self.resolve_index_and_query(context)?;
283            // Resolve auth once; reuse across all scroll pages and cleanup.
284            let auth = self.resolve_auth().await?;
285
286            // batch_size == 0: single non-scroll _search with size = max_result_window default.
287            if batch_size == 0 {
288                let url = format!(
289                    "{}/{}/_search?size={}",
290                    self.config.base_url, index, NO_BATCHING_SEARCH_SIZE
291                );
292                let req = self.client.post(&url).json(&json!({"query": query}));
293                let req = Self::apply_auth_value(req, &auth);
294                let resp = req.send().await?;
295                let resp = check_http_response(resp, DEFAULT_ERROR_BODY_MAX_LEN).await?;
296                let body: Value = resp.json().await?;
297                let records = Self::extract_hits(&body);
298                tracing::info!(
299                    docs = records.len(),
300                    batch_size = 0,
301                    "Elasticsearch source stream complete (no-batching path)",
302                );
303                yield StreamPage { records, bookmark: None };
304                return;
305            }
306
307            // Scroll path. Wire up a guard so the scroll context is always
308            // cleared, even on early-return / error / drop.
309            // Pass the already-resolved auth so the guard's spawned cleanup
310            // tasks never need to call async auth resolution.
311            let mut guard = ScrollGuard::new(
312                self.config.base_url.clone(),
313                self.client.clone(),
314                auth.clone(),
315            );
316
317            let url = format!(
318                "{}/{}/_search?scroll={}&size={}",
319                self.config.base_url, index, self.config.scroll_timeout, batch_size
320            );
321            let req = self.client.post(&url).json(&json!({"query": query}));
322            let req = Self::apply_auth_value(req, &auth);
323            let resp = req.send().await?;
324            let resp = check_http_response(resp, DEFAULT_ERROR_BODY_MAX_LEN).await?;
325            let body: Value = resp.json().await?;
326
327            let records = Self::extract_hits(&body);
328            guard.update(Self::extract_scroll_id(&body));
329            let mut pages_emitted: usize = 0;
330            let mut total = records.len();
331
332            // The initial search always counts as page 1, even when it
333            // returns zero hits — emit it and move on.
334            pages_emitted += 1;
335            let is_final = records.is_empty()
336                || guard.scroll_id().is_none()
337                || matches!(self.config.max_pages, Some(max) if pages_emitted >= max);
338            yield StreamPage { records, bookmark: None };
339            if is_final {
340                guard.disarm_if_done();
341                tracing::info!(
342                    docs = total,
343                    pages = pages_emitted,
344                    batch_size,
345                    "Elasticsearch source stream complete",
346                );
347                return;
348            }
349
350            // Scroll loop.
351            while let Some(sid) = guard.scroll_id().map(|s| s.to_string()) {
352                let scroll_url = format!("{}/_search/scroll", self.config.base_url);
353                let req = self.client.post(&scroll_url).json(&json!({
354                    "scroll": self.config.scroll_timeout,
355                    "scroll_id": sid,
356                }));
357                let req = Self::apply_auth_value(req, &auth);
358                let resp = req.send().await?;
359                let resp = check_http_response(resp, DEFAULT_ERROR_BODY_MAX_LEN).await?;
360                let body: Value = resp.json().await?;
361
362                let records = Self::extract_hits(&body);
363                guard.update(Self::extract_scroll_id(&body));
364                pages_emitted += 1;
365                total += records.len();
366
367                let is_empty = records.is_empty();
368                let hit_cap = matches!(self.config.max_pages, Some(max) if pages_emitted >= max);
369
370                if is_empty {
371                    // Final empty page — ES uses an empty hits array as the
372                    // end-of-scroll sentinel. Drop it; nothing to emit.
373                    break;
374                }
375
376                yield StreamPage { records, bookmark: None };
377
378                if hit_cap {
379                    tracing::debug!(
380                        max_pages = self.config.max_pages.unwrap_or(0),
381                        "max_pages reached, stopping scroll"
382                    );
383                    break;
384                }
385            }
386
387            tracing::info!(
388                docs = total,
389                pages = pages_emitted,
390                batch_size,
391                "Elasticsearch source stream complete",
392            );
393
394            // Successful drain — let the guard clean up the scroll id (if any).
395            guard.disarm_if_done();
396        })
397    }
398
399    fn config_schema(&self) -> serde_json::Value {
400        serde_json::to_value(faucet_core::schema_for!(ElasticsearchSourceConfig))
401            .expect("schema serialization")
402    }
403}
404
405/// RAII guard that owns the active scroll id and clears it on drop.
406///
407/// Holds a pre-resolved [`ElasticsearchAuth`] (not `AuthSpec`) so the drop-path
408/// spawned cleanup tasks never need to perform async auth resolution.
409struct ScrollGuard {
410    base_url: String,
411    client: Client,
412    auth: ElasticsearchAuth,
413    scroll_id: Option<String>,
414}
415
416impl ScrollGuard {
417    fn new(base_url: String, client: Client, auth: ElasticsearchAuth) -> Self {
418        Self {
419            base_url,
420            client,
421            auth,
422            scroll_id: None,
423        }
424    }
425
426    fn scroll_id(&self) -> Option<&str> {
427        self.scroll_id.as_deref()
428    }
429
430    fn update(&mut self, new_id: Option<String>) {
431        if let Some(id) = new_id {
432            self.scroll_id = Some(id);
433        }
434    }
435
436    /// Called when the stream drained cleanly. Spawns cleanup as a detached
437    /// task and disarms the drop fallback.
438    fn disarm_if_done(&mut self) {
439        if let Some(sid) = self.scroll_id.take() {
440            let base_url = self.base_url.clone();
441            let auth = self.auth.clone();
442            let client = self.client.clone();
443            tokio::spawn(async move {
444                let url = format!("{base_url}/_search/scroll");
445                let req = client.delete(&url).json(&json!({"scroll_id": sid}));
446                let req = apply_auth_to(req, &auth);
447                if let Err(e) = req.send().await {
448                    tracing::warn!(error = %e, "failed to clear Elasticsearch scroll context");
449                }
450            });
451        }
452    }
453}
454
455impl Drop for ScrollGuard {
456    fn drop(&mut self) {
457        if let Some(sid) = self.scroll_id.take() {
458            // Error / cancellation path. Spawn so cleanup survives the
459            // stream future being dropped mid-await.
460            let base_url = self.base_url.clone();
461            let auth = self.auth.clone();
462            let client = self.client.clone();
463            tokio::spawn(async move {
464                let url = format!("{base_url}/_search/scroll");
465                let req = client.delete(&url).json(&json!({"scroll_id": sid}));
466                let req = apply_auth_to(req, &auth);
467                if let Err(e) = req.send().await {
468                    tracing::warn!(
469                        error = %e,
470                        "failed to clear Elasticsearch scroll context (drop path)",
471                    );
472                }
473            });
474        }
475    }
476}
477
478/// Apply an [`ElasticsearchAuth`] to a request builder. Standalone so
479/// spawned cleanup tasks can use it without holding a source reference.
480fn apply_auth_to(
481    req: reqwest::RequestBuilder,
482    auth: &ElasticsearchAuth,
483) -> reqwest::RequestBuilder {
484    match auth {
485        ElasticsearchAuth::None => req,
486        ElasticsearchAuth::Basic { username, password } => req.basic_auth(username, Some(password)),
487        ElasticsearchAuth::Bearer { token } => req.bearer_auth(token),
488        ElasticsearchAuth::ApiKey { key } => req.header("Authorization", format!("ApiKey {key}")),
489    }
490}
491
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_rejects_out_of_range_batch_size() {
        let mut config = ElasticsearchSourceConfig::new("http://localhost:9200", "idx");
        config.batch_size = faucet_core::MAX_BATCH_SIZE + 1;
        match ElasticsearchSource::new(config) {
            Err(FaucetError::Config(m)) => assert!(m.contains("batch_size"), "got: {m}"),
            _ => panic!("expected a batch_size Config error"),
        }
    }

    #[test]
    fn extract_hits_returns_sources_and_skips_hits_without_one() {
        let body = json!({
            "hits": {
                "hits": [
                    {"_id": "1", "_source": {"a": 1}},
                    {"_id": "2"},
                    {"_id": "3", "_source": {"a": 3}}
                ]
            }
        });
        assert_eq!(
            ElasticsearchSource::extract_hits(&body),
            vec![json!({"a": 1}), json!({"a": 3})]
        );
    }

    #[test]
    fn extract_hits_tolerates_missing_or_malformed_hits() {
        assert!(ElasticsearchSource::extract_hits(&json!({})).is_empty());
        assert!(ElasticsearchSource::extract_hits(&json!({"hits": {}})).is_empty());
        assert!(ElasticsearchSource::extract_hits(&json!({"hits": {"hits": 7}})).is_empty());
    }

    #[test]
    fn extract_scroll_id_reads_only_string_ids() {
        assert_eq!(
            ElasticsearchSource::extract_scroll_id(&json!({"_scroll_id": "abc"})).as_deref(),
            Some("abc")
        );
        assert_eq!(ElasticsearchSource::extract_scroll_id(&json!({})), None);
        assert_eq!(
            ElasticsearchSource::extract_scroll_id(&json!({"_scroll_id": 5})),
            None
        );
    }
}