faucet_source_elasticsearch/
stream.rs1use crate::config::{ElasticsearchAuth, ElasticsearchSourceConfig};
4use async_trait::async_trait;
5use faucet_core::util::{DEFAULT_ERROR_BODY_MAX_LEN, check_http_response};
6use faucet_core::{AuthSpec, FaucetError, SharedAuthProvider, Stream, StreamPage};
7use reqwest::Client;
8use serde_json::{Value, json};
9use std::pin::Pin;
10
/// Page size used when the configured `batch_size` is 0 ("no batching").
///
/// `fetch_with_context` uses it as the `size` of every scroll page, so it still
/// reads all documents. `stream_pages` uses it as the `size` of a single
/// non-scroll search, which therefore returns at most this many documents.
// NOTE(review): presumably chosen to match Elasticsearch's default
// `index.max_result_window` (10,000) — confirm before changing.
pub(crate) const NO_BATCHING_SEARCH_SIZE: usize = 10000;
15
16pub struct ElasticsearchSource {
18 config: ElasticsearchSourceConfig,
19 client: Client,
20 auth_provider: Option<SharedAuthProvider>,
24}
25
26impl ElasticsearchSource {
27 pub fn new(config: ElasticsearchSourceConfig) -> Result<Self, FaucetError> {
31 faucet_core::validate_batch_size(config.batch_size)?;
32 Ok(Self {
33 config,
34 client: Client::new(),
35 auth_provider: None,
36 })
37 }
38
39 pub fn with_auth_provider(mut self, provider: SharedAuthProvider) -> Self {
45 self.auth_provider = Some(provider);
46 self
47 }
48
49 async fn resolve_auth(&self) -> Result<ElasticsearchAuth, FaucetError> {
56 if let Some(p) = &self.auth_provider {
57 return faucet_common_elasticsearch::credential_to_auth(p.credential().await?);
58 }
59 match &self.config.auth {
60 AuthSpec::Inline(a) => Ok(a.clone()),
61 AuthSpec::Reference(r) => Err(FaucetError::Auth(format!(
62 "auth references provider '{}' but no provider was supplied",
63 r.name
64 ))),
65 }
66 }
67
68 fn apply_auth_value(
70 req: reqwest::RequestBuilder,
71 auth: &ElasticsearchAuth,
72 ) -> reqwest::RequestBuilder {
73 match auth {
74 ElasticsearchAuth::None => req,
75 ElasticsearchAuth::Basic { username, password } => {
76 req.basic_auth(username, Some(password))
77 }
78 ElasticsearchAuth::Bearer { token } => req.bearer_auth(token),
79 ElasticsearchAuth::ApiKey { key } => {
80 req.header("Authorization", format!("ApiKey {key}"))
81 }
82 }
83 }
84
85 fn extract_hits(body: &Value) -> Vec<Value> {
87 body.get("hits")
88 .and_then(|h| h.get("hits"))
89 .and_then(|h| h.as_array())
90 .map(|hits| {
91 hits.iter()
92 .filter_map(|hit| hit.get("_source").cloned())
93 .collect()
94 })
95 .unwrap_or_default()
96 }
97
98 fn extract_scroll_id(body: &Value) -> Option<String> {
100 body.get("_scroll_id")
101 .and_then(|v| v.as_str())
102 .map(|s| s.to_string())
103 }
104
105 async fn clear_scroll(&self, scroll_id: &str) {
107 let url = format!("{}/_search/scroll", self.config.base_url);
108 let req = self
109 .client
110 .delete(&url)
111 .json(&json!({"scroll_id": scroll_id}));
112 let auth = match self.resolve_auth().await {
113 Ok(a) => a,
114 Err(e) => {
115 tracing::warn!(error = %e, "failed to resolve auth for scroll cleanup");
116 return;
117 }
118 };
119 let req = Self::apply_auth_value(req, &auth);
120
121 if let Err(e) = req.send().await {
122 tracing::warn!(error = %e, "failed to clear Elasticsearch scroll context");
123 }
124 }
125
126 fn resolve_index_and_query(
129 &self,
130 context: &std::collections::HashMap<String, Value>,
131 ) -> Result<(String, Value), FaucetError> {
132 let index = if context.is_empty() {
133 self.config.index.clone()
134 } else {
135 faucet_core::util::substitute_context(&self.config.index, context)
136 };
137 let query = if context.is_empty() {
138 self.config.query.clone()
139 } else {
140 let s = serde_json::to_string(&self.config.query)
141 .map_err(|e| FaucetError::Config(format!("failed to serialize query: {e}")))?;
142 let s = faucet_core::util::substitute_context_json(&s, context);
143 serde_json::from_str(&s).map_err(|e| {
144 FaucetError::Config(format!("failed to parse substituted query: {e}"))
145 })?
146 };
147 Ok((index, query))
148 }
149}
150
151#[async_trait]
152impl faucet_core::Source for ElasticsearchSource {
153 async fn fetch_with_context(
154 &self,
155 context: &std::collections::HashMap<String, serde_json::Value>,
156 ) -> Result<Vec<Value>, FaucetError> {
157 let (index, query) = self.resolve_index_and_query(context)?;
158 let auth = self.resolve_auth().await?;
160
161 let mut all_records = Vec::new();
162
163 let page_size = if self.config.batch_size == 0 {
167 NO_BATCHING_SEARCH_SIZE
168 } else {
169 self.config.batch_size
170 };
171
172 let url = format!(
174 "{}/{}/_search?scroll={}&size={}",
175 self.config.base_url, index, self.config.scroll_timeout, page_size
176 );
177 let req = self.client.post(&url).json(&json!({"query": query}));
178 let req = Self::apply_auth_value(req, &auth);
179
180 let resp = req.send().await?;
181 let resp = check_http_response(resp, DEFAULT_ERROR_BODY_MAX_LEN).await?;
182 let body: Value = resp.json().await?;
183
184 let mut records = Self::extract_hits(&body);
185 let mut scroll_id = Self::extract_scroll_id(&body);
186 let mut pages_fetched: usize = 1;
187
188 tracing::debug!(
189 records = records.len(),
190 page = pages_fetched,
191 "Elasticsearch initial search"
192 );
193
194 all_records.append(&mut records);
195
196 while let Some(ref sid) = scroll_id {
198 if let Some(max) = self.config.max_pages
200 && pages_fetched >= max
201 {
202 tracing::debug!(max_pages = max, "max_pages reached, stopping scroll");
203 break;
204 }
205
206 let scroll_url = format!("{}/_search/scroll", self.config.base_url);
207 let req = self.client.post(&scroll_url).json(&json!({
208 "scroll": self.config.scroll_timeout,
209 "scroll_id": sid,
210 }));
211 let req = Self::apply_auth_value(req, &auth);
212
213 let resp = req.send().await?;
214 let resp = check_http_response(resp, DEFAULT_ERROR_BODY_MAX_LEN).await?;
215 let body: Value = resp.json().await?;
216
217 let mut page_records = Self::extract_hits(&body);
218 pages_fetched += 1;
219
220 tracing::debug!(
221 records = page_records.len(),
222 page = pages_fetched,
223 "Elasticsearch scroll page"
224 );
225
226 if page_records.is_empty() {
228 break;
229 }
230
231 scroll_id = Self::extract_scroll_id(&body);
233 all_records.append(&mut page_records);
234 }
235
236 if let Some(ref sid) = scroll_id {
238 self.clear_scroll(sid).await;
239 }
240
241 tracing::debug!(
242 total_records = all_records.len(),
243 pages = pages_fetched,
244 "Elasticsearch fetch complete"
245 );
246
247 Ok(all_records)
248 }
249
250 fn stream_pages<'a>(
275 &'a self,
276 context: &'a std::collections::HashMap<String, Value>,
277 _batch_size: usize,
278 ) -> Pin<Box<dyn Stream<Item = Result<StreamPage, FaucetError>> + Send + 'a>> {
279 let batch_size = self.config.batch_size;
280
281 Box::pin(async_stream::try_stream! {
282 let (index, query) = self.resolve_index_and_query(context)?;
283 let auth = self.resolve_auth().await?;
285
286 if batch_size == 0 {
288 let url = format!(
289 "{}/{}/_search?size={}",
290 self.config.base_url, index, NO_BATCHING_SEARCH_SIZE
291 );
292 let req = self.client.post(&url).json(&json!({"query": query}));
293 let req = Self::apply_auth_value(req, &auth);
294 let resp = req.send().await?;
295 let resp = check_http_response(resp, DEFAULT_ERROR_BODY_MAX_LEN).await?;
296 let body: Value = resp.json().await?;
297 let records = Self::extract_hits(&body);
298 tracing::info!(
299 docs = records.len(),
300 batch_size = 0,
301 "Elasticsearch source stream complete (no-batching path)",
302 );
303 yield StreamPage { records, bookmark: None };
304 return;
305 }
306
307 let mut guard = ScrollGuard::new(
312 self.config.base_url.clone(),
313 self.client.clone(),
314 auth.clone(),
315 );
316
317 let url = format!(
318 "{}/{}/_search?scroll={}&size={}",
319 self.config.base_url, index, self.config.scroll_timeout, batch_size
320 );
321 let req = self.client.post(&url).json(&json!({"query": query}));
322 let req = Self::apply_auth_value(req, &auth);
323 let resp = req.send().await?;
324 let resp = check_http_response(resp, DEFAULT_ERROR_BODY_MAX_LEN).await?;
325 let body: Value = resp.json().await?;
326
327 let records = Self::extract_hits(&body);
328 guard.update(Self::extract_scroll_id(&body));
329 let mut pages_emitted: usize = 0;
330 let mut total = records.len();
331
332 pages_emitted += 1;
335 let is_final = records.is_empty()
336 || guard.scroll_id().is_none()
337 || matches!(self.config.max_pages, Some(max) if pages_emitted >= max);
338 yield StreamPage { records, bookmark: None };
339 if is_final {
340 guard.disarm_if_done();
341 tracing::info!(
342 docs = total,
343 pages = pages_emitted,
344 batch_size,
345 "Elasticsearch source stream complete",
346 );
347 return;
348 }
349
350 while let Some(sid) = guard.scroll_id().map(|s| s.to_string()) {
352 let scroll_url = format!("{}/_search/scroll", self.config.base_url);
353 let req = self.client.post(&scroll_url).json(&json!({
354 "scroll": self.config.scroll_timeout,
355 "scroll_id": sid,
356 }));
357 let req = Self::apply_auth_value(req, &auth);
358 let resp = req.send().await?;
359 let resp = check_http_response(resp, DEFAULT_ERROR_BODY_MAX_LEN).await?;
360 let body: Value = resp.json().await?;
361
362 let records = Self::extract_hits(&body);
363 guard.update(Self::extract_scroll_id(&body));
364 pages_emitted += 1;
365 total += records.len();
366
367 let is_empty = records.is_empty();
368 let hit_cap = matches!(self.config.max_pages, Some(max) if pages_emitted >= max);
369
370 if is_empty {
371 break;
374 }
375
376 yield StreamPage { records, bookmark: None };
377
378 if hit_cap {
379 tracing::debug!(
380 max_pages = self.config.max_pages.unwrap_or(0),
381 "max_pages reached, stopping scroll"
382 );
383 break;
384 }
385 }
386
387 tracing::info!(
388 docs = total,
389 pages = pages_emitted,
390 batch_size,
391 "Elasticsearch source stream complete",
392 );
393
394 guard.disarm_if_done();
396 })
397 }
398
399 fn config_schema(&self) -> serde_json::Value {
400 serde_json::to_value(faucet_core::schema_for!(ElasticsearchSourceConfig))
401 .expect("schema serialization")
402 }
403}
404
405struct ScrollGuard {
410 base_url: String,
411 client: Client,
412 auth: ElasticsearchAuth,
413 scroll_id: Option<String>,
414}
415
416impl ScrollGuard {
417 fn new(base_url: String, client: Client, auth: ElasticsearchAuth) -> Self {
418 Self {
419 base_url,
420 client,
421 auth,
422 scroll_id: None,
423 }
424 }
425
426 fn scroll_id(&self) -> Option<&str> {
427 self.scroll_id.as_deref()
428 }
429
430 fn update(&mut self, new_id: Option<String>) {
431 if let Some(id) = new_id {
432 self.scroll_id = Some(id);
433 }
434 }
435
436 fn disarm_if_done(&mut self) {
439 if let Some(sid) = self.scroll_id.take() {
440 let base_url = self.base_url.clone();
441 let auth = self.auth.clone();
442 let client = self.client.clone();
443 tokio::spawn(async move {
444 let url = format!("{base_url}/_search/scroll");
445 let req = client.delete(&url).json(&json!({"scroll_id": sid}));
446 let req = apply_auth_to(req, &auth);
447 if let Err(e) = req.send().await {
448 tracing::warn!(error = %e, "failed to clear Elasticsearch scroll context");
449 }
450 });
451 }
452 }
453}
454
455impl Drop for ScrollGuard {
456 fn drop(&mut self) {
457 if let Some(sid) = self.scroll_id.take() {
458 let base_url = self.base_url.clone();
461 let auth = self.auth.clone();
462 let client = self.client.clone();
463 tokio::spawn(async move {
464 let url = format!("{base_url}/_search/scroll");
465 let req = client.delete(&url).json(&json!({"scroll_id": sid}));
466 let req = apply_auth_to(req, &auth);
467 if let Err(e) = req.send().await {
468 tracing::warn!(
469 error = %e,
470 "failed to clear Elasticsearch scroll context (drop path)",
471 );
472 }
473 });
474 }
475 }
476}
477
478fn apply_auth_to(
481 req: reqwest::RequestBuilder,
482 auth: &ElasticsearchAuth,
483) -> reqwest::RequestBuilder {
484 match auth {
485 ElasticsearchAuth::None => req,
486 ElasticsearchAuth::Basic { username, password } => req.basic_auth(username, Some(password)),
487 ElasticsearchAuth::Bearer { token } => req.bearer_auth(token),
488 ElasticsearchAuth::ApiKey { key } => req.header("Authorization", format!("ApiKey {key}")),
489 }
490}
491
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn new_rejects_out_of_range_batch_size() {
        let mut config = ElasticsearchSourceConfig::new("http://localhost:9200", "idx");
        config.batch_size = faucet_core::MAX_BATCH_SIZE + 1;
        match ElasticsearchSource::new(config) {
            Err(FaucetError::Config(m)) => assert!(m.contains("batch_size"), "got: {m}"),
            _ => panic!("expected a batch_size Config error"),
        }
    }

    /// Only `_source` payloads are returned; hits without one are skipped.
    #[test]
    fn extract_hits_returns_only_source_documents() {
        let body = json!({
            "hits": {"hits": [
                {"_id": "1", "_source": {"a": 1}},
                {"_id": "2"},
                {"_id": "3", "_source": {"b": 2}}
            ]}
        });
        assert_eq!(
            ElasticsearchSource::extract_hits(&body),
            vec![json!({"a": 1}), json!({"b": 2})]
        );
    }

    /// Missing or malformed `hits.hits` yields no records rather than an error.
    #[test]
    fn extract_hits_tolerates_missing_or_malformed_hits() {
        assert!(ElasticsearchSource::extract_hits(&json!({})).is_empty());
        assert!(ElasticsearchSource::extract_hits(&json!({"hits": {}})).is_empty());
        assert!(ElasticsearchSource::extract_hits(&json!({"hits": {"hits": "x"}})).is_empty());
    }

    #[test]
    fn extract_scroll_id_requires_a_string() {
        assert_eq!(
            ElasticsearchSource::extract_scroll_id(&json!({"_scroll_id": "abc"})),
            Some("abc".to_string())
        );
        assert_eq!(
            ElasticsearchSource::extract_scroll_id(&json!({"_scroll_id": 5})),
            None
        );
        assert_eq!(ElasticsearchSource::extract_scroll_id(&json!({})), None);
    }

    /// With no context, the configured index and query pass through untouched.
    #[test]
    fn empty_context_uses_configured_index_and_query() {
        let config = ElasticsearchSourceConfig::new("http://localhost:9200", "idx");
        let expected_query = config.query.clone();
        let Ok(source) = ElasticsearchSource::new(config) else {
            panic!("default config should be valid");
        };
        let Ok((index, query)) =
            source.resolve_index_and_query(&std::collections::HashMap::new())
        else {
            panic!("empty context should not fail");
        };
        assert_eq!(index, "idx");
        assert_eq!(query, expected_query);
    }
}