1use crate::config::{GraphqlAuth, GraphqlStreamConfig};
4use async_trait::async_trait;
5use base64::Engine as _;
6use faucet_core::util::{self, DEFAULT_ERROR_BODY_MAX_LEN};
7use faucet_core::{AuthSpec, Credential, FaucetError, SharedAuthProvider, Stream, StreamPage};
8use jsonpath_rust::JsonPath;
9use reqwest::Client;
10use serde_json::{Value, json};
11use std::collections::HashMap;
12use std::pin::Pin;
13use std::time::Duration;
14
/// Attempt count handed to `faucet_core::execute_with_retry` for each
/// GraphQL request.
const RETRY_MAX_ATTEMPTS: u32 = 3;

/// Base backoff handed to `faucet_core::execute_with_retry`.
// NOTE(review): how the delay grows between attempts is decided by
// `execute_with_retry`, not here — confirm against faucet_core.
const RETRY_BASE_BACKOFF: Duration = Duration::from_millis(500);
19
20pub struct GraphqlStream {
22 config: GraphqlStreamConfig,
23 client: Client,
24 auth_provider: Option<SharedAuthProvider>,
28}
29
30fn credential_to_auth(cred: Credential) -> GraphqlAuth {
33 match cred {
34 Credential::Bearer(token) => GraphqlAuth::Bearer { token },
35 Credential::Token(token) => GraphqlAuth::Custom {
36 headers: HashMap::from([("Authorization".into(), token)]),
37 },
38 Credential::Header { name, value } => GraphqlAuth::Custom {
39 headers: HashMap::from([(name, value)]),
40 },
41 Credential::Basic { username, password } => GraphqlAuth::Custom {
42 headers: HashMap::from([(
43 "Authorization".into(),
44 format!(
45 "Basic {}",
46 base64::engine::general_purpose::STANDARD
47 .encode(format!("{username}:{password}"))
48 ),
49 )]),
50 },
51 }
52}
53
54impl GraphqlStream {
55 pub fn new(config: GraphqlStreamConfig) -> Self {
57 Self {
58 config,
59 client: Client::new(),
60 auth_provider: None,
61 }
62 }
63
64 pub fn with_auth_provider(mut self, provider: SharedAuthProvider) -> Self {
70 self.auth_provider = Some(provider);
71 self
72 }
73
74 pub async fn fetch_all(&self) -> Result<Vec<Value>, FaucetError> {
76 self.fetch_all_with_context(&std::collections::HashMap::new())
77 .await
78 }
79
80 async fn fetch_all_with_context(
82 &self,
83 context: &std::collections::HashMap<String, Value>,
84 ) -> Result<Vec<Value>, FaucetError> {
85 let mut all_records = Vec::new();
86 let mut cursor: Option<String> = None;
87 let mut pages_fetched = 0usize;
88
89 loop {
90 if let Some(max) = self.config.max_pages
91 && pages_fetched >= max
92 {
93 tracing::warn!("max pages ({max}) reached");
94 break;
95 }
96
97 let body = self.execute_query(&cursor, context).await?;
98 let records = self.extract_records(&body)?;
99 all_records.extend(records);
100 pages_fetched += 1;
101
102 match &self.config.pagination {
104 Some(pag) => {
105 let has_next = extract_bool(&body, &pag.has_next_page_path).unwrap_or(false);
106 if !has_next {
107 break;
108 }
109 let next_cursor = extract_string(&body, &pag.cursor_path);
110 if next_cursor.is_none() {
111 break;
112 }
113 if next_cursor == cursor {
119 tracing::warn!("cursor loop detected, stopping pagination");
120 break;
121 }
122 cursor = next_cursor;
123 }
124 None => break,
125 }
126 }
127
128 tracing::info!(
129 records = all_records.len(),
130 pages = pages_fetched,
131 "GraphQL fetch complete"
132 );
133 Ok(all_records)
134 }
135
136 async fn execute_query(
138 &self,
139 cursor: &Option<String>,
140 context: &std::collections::HashMap<String, Value>,
141 ) -> Result<Value, FaucetError> {
142 let mut variables = self.config.variables.clone();
143
144 if !context.is_empty()
146 && let Value::Object(ref mut map) = variables
147 {
148 for (key, value) in context {
149 map.insert(key.clone(), value.clone());
150 }
151 }
152
153 if let (Some(pag), Some(cursor_val)) = (&self.config.pagination, cursor)
155 && let Value::Object(ref mut map) = variables
156 {
157 map.insert(pag.cursor_variable.clone(), json!(cursor_val));
158 }
159 if let Some(pag) = &self.config.pagination
163 && self.config.batch_size != 0
164 && let Value::Object(map) = &mut variables
165 {
166 map.insert(
167 pag.page_size_variable.clone(),
168 json!(self.config.batch_size),
169 );
170 }
171
172 let payload = json!({
173 "query": self.config.query,
174 "variables": variables,
175 });
176
177 let mut req = self
178 .client
179 .post(&self.config.endpoint)
180 .headers(self.config.headers.clone())
181 .json(&payload);
182
183 let effective_auth: GraphqlAuth = if let Some(provider) = &self.auth_provider {
187 credential_to_auth(provider.credential().await?)
188 } else {
189 match &self.config.auth {
190 AuthSpec::Inline(a) => a.clone(),
191 AuthSpec::Reference(r) => {
192 return Err(FaucetError::Auth(format!(
193 "auth references provider '{}' but no provider was supplied; \
194 set one via the CLI `auth:` catalog or `with_auth_provider`",
195 r.name
196 )));
197 }
198 }
199 };
200
201 match effective_auth {
203 GraphqlAuth::None => {}
204 GraphqlAuth::Bearer { token } => {
205 req = req.bearer_auth(token);
206 }
207 GraphqlAuth::Custom { headers } => {
208 let mut hm = reqwest::header::HeaderMap::new();
209 for (name, value) in &headers {
210 let n =
211 reqwest::header::HeaderName::from_bytes(name.as_bytes()).map_err(|e| {
212 FaucetError::Auth(format!("invalid custom header name {name:?}: {e}"))
213 })?;
214 let v = reqwest::header::HeaderValue::from_str(value).map_err(|e| {
215 FaucetError::Auth(format!("invalid custom header value for {name:?}: {e}"))
216 })?;
217 hm.insert(n, v);
218 }
219 req = req.headers(hm);
220 }
221 }
222
223 let body: Value =
228 faucet_core::execute_with_retry(RETRY_MAX_ATTEMPTS, RETRY_BASE_BACKOFF, || {
229 let attempt = req.try_clone();
230 async move {
231 let req = attempt.ok_or_else(|| {
232 FaucetError::Source("graphql: request is not cloneable for retry".into())
233 })?;
234 let resp = req.send().await.map_err(FaucetError::Http)?;
235 let resp = util::check_http_response(resp, DEFAULT_ERROR_BODY_MAX_LEN).await?;
236 resp.json().await.map_err(FaucetError::Http)
237 }
238 })
239 .await?;
240
241 if let Some(errors) = body.get("errors")
243 && let Some(arr) = errors.as_array()
244 && !arr.is_empty()
245 {
246 let msg = arr
247 .iter()
248 .filter_map(|e| e.get("message").and_then(|m| m.as_str()))
249 .collect::<Vec<_>>()
250 .join("; ");
251 let lower = msg.to_lowercase();
257 if self.config.batch_size == 0
258 && let Some(pag) = &self.config.pagination
259 {
260 let var_name = pag.page_size_variable.to_lowercase();
261 if lower.contains(&var_name)
262 && (lower.contains("non-null")
263 || lower.contains("non null")
264 || lower.contains("must not be null")
265 || lower.contains("cannot be null")
266 || lower.contains("required"))
267 {
268 return Err(FaucetError::Config(format!(
269 "batch_size = 0 requires the upstream to accept a null {}: argument \
270 (GraphQL errors: {msg})",
271 pag.page_size_variable
272 )));
273 }
274 }
275 return Err(FaucetError::HttpStatus {
276 status: 200,
277 url: self.config.endpoint.clone(),
278 body: format!("GraphQL errors: {msg}"),
279 });
280 }
281
282 Ok(body)
283 }
284
285 fn extract_records(&self, body: &Value) -> Result<Vec<Value>, FaucetError> {
287 match &self.config.records_path {
288 Some(path) => util::extract_records(body, Some(path)),
289 None => {
290 match body.get("data") {
296 Some(Value::Null) | None => Ok(Vec::new()),
297 Some(data) => Ok(vec![data.clone()]),
298 }
299 }
300 }
301 }
302
303 fn stream_pages_inner(
316 &self,
317 context: &std::collections::HashMap<String, Value>,
318 ) -> Pin<Box<dyn Stream<Item = Result<StreamPage, FaucetError>> + Send + '_>> {
319 let owned_context: std::collections::HashMap<String, Value> = context.clone();
321
322 Box::pin(async_stream::try_stream! {
323 let mut cursor: Option<String> = None;
324 let mut pages_fetched = 0usize;
325 let running_max: Option<Value> = None;
330 let mut bookmark_emitted = false;
331
332 loop {
333 if let Some(max) = self.config.max_pages
334 && pages_fetched >= max
335 {
336 tracing::warn!("max pages ({max}) reached");
337 break;
338 }
339
340 let body = self.execute_query(&cursor, &owned_context).await?;
341 let records = self.extract_records(&body)?;
342 pages_fetched += 1;
343
344 let has_next = match &self.config.pagination {
347 Some(pag) => {
348 let next = extract_bool(&body, &pag.has_next_page_path).unwrap_or(false);
349 if next {
350 let next_cursor = extract_string(&body, &pag.cursor_path);
351 match next_cursor {
352 None => false,
353 Some(next_cursor) => {
354 if Some(&next_cursor) == cursor.as_ref() {
361 tracing::warn!("cursor loop detected, stopping pagination");
362 false
363 } else {
364 cursor = Some(next_cursor);
365 true
366 }
367 }
368 }
369 } else {
370 false
371 }
372 }
373 None => false,
374 };
375
376 if has_next {
377 yield StreamPage { records, bookmark: None };
379 } else {
380 bookmark_emitted = running_max.is_some();
383 yield StreamPage {
384 records,
385 bookmark: running_max.clone(),
386 };
387 break;
388 }
389 }
390
391 if !bookmark_emitted && running_max.is_some() {
398 yield StreamPage {
399 records: Vec::new(),
400 bookmark: running_max,
401 };
402 }
403
404 tracing::info!(
405 pages = pages_fetched,
406 batch_size = self.config.batch_size,
407 "GraphQL source stream complete",
408 );
409 })
410 }
411}
412
413#[async_trait]
414impl faucet_core::Source for GraphqlStream {
415 async fn fetch_with_context(
416 &self,
417 context: &std::collections::HashMap<String, serde_json::Value>,
418 ) -> Result<Vec<Value>, FaucetError> {
419 self.fetch_all_with_context(context).await
420 }
421
422 fn stream_pages<'a>(
429 &'a self,
430 context: &'a std::collections::HashMap<String, Value>,
431 _batch_size: usize,
432 ) -> Pin<Box<dyn Stream<Item = Result<StreamPage, FaucetError>> + Send + 'a>> {
433 self.stream_pages_inner(context)
434 }
435
436 fn config_schema(&self) -> serde_json::Value {
437 serde_json::to_value(faucet_core::schema_for!(GraphqlStreamConfig))
438 .expect("schema serialization")
439 }
440}
441
442fn extract_string(body: &Value, path: &str) -> Option<String> {
443 let results = body.query(path).ok()?;
444 match results.first()? {
445 Value::String(s) => Some(s.clone()),
446 _ => None,
447 }
448}
449
450fn extract_bool(body: &Value, path: &str) -> Option<bool> {
451 let results = body.query(path).ok()?;
452 results.first()?.as_bool()
453}
454
#[cfg(test)]
mod tests {
    use super::*;

    /// Endpoint handed to every test config; these tests never send a request.
    const ENDPOINT: &str = "https://api.example.com/graphql";

    /// Builds a stream with no `records_path`, so record extraction falls
    /// back to the top-level `data` value.
    fn stream_without_records_path() -> GraphqlStream {
        GraphqlStream::new(GraphqlStreamConfig::new(ENDPOINT, "query { user { id } }"))
    }

    #[test]
    fn extract_string_from_json() {
        let body = json!({"data": {"users": {"pageInfo": {"endCursor": "abc123"}}}});
        let cursor = extract_string(&body, "$.data.users.pageInfo.endCursor");
        assert_eq!(cursor, Some("abc123".into()));
    }

    #[test]
    fn extract_bool_from_json() {
        let body = json!({"data": {"users": {"pageInfo": {"hasNextPage": true}}}});
        let has_next = extract_bool(&body, "$.data.users.pageInfo.hasNextPage");
        assert_eq!(has_next, Some(true));
    }

    #[test]
    fn extract_records_with_path() {
        let config = GraphqlStreamConfig::new(ENDPOINT, "query { users { id } }")
            .records_path("$.data.users[*]");
        let stream = GraphqlStream::new(config);
        let body = json!({"data": {"users": [{"id": 1}, {"id": 2}]}});

        let records = stream.extract_records(&body).unwrap();

        assert_eq!(records.len(), 2);
        assert_eq!(records[0]["id"], 1);
    }

    #[test]
    fn extract_records_without_path_returns_data() {
        let body = json!({"data": {"user": {"id": 1}}});

        let records = stream_without_records_path()
            .extract_records(&body)
            .unwrap();

        assert_eq!(records.len(), 1);
        assert_eq!(records[0]["user"]["id"], 1);
    }

    #[test]
    fn extract_records_without_path_null_data_yields_empty() {
        let body = json!({ "data": null });

        let records = stream_without_records_path()
            .extract_records(&body)
            .unwrap();

        assert!(
            records.is_empty(),
            "expected empty Vec for null `data`, got {records:?}"
        );
    }

    #[test]
    fn extract_records_without_path_absent_data_yields_empty() {
        let body = json!({ "extensions": { "foo": 1 } });

        let records = stream_without_records_path()
            .extract_records(&body)
            .unwrap();

        assert!(
            records.is_empty(),
            "expected empty Vec when `data` is absent, got {records:?}"
        );
    }
}