Skip to main content

drasi_bootstrap_http/
provider.rs

1// Copyright 2025 The Drasi Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! HTTP Bootstrap Provider implementation.
16//!
17//! Fetches initial state from HTTP REST APIs and emits graph elements
18//! through the bootstrap event channel.
19
20use anyhow::{Context, Result};
21use async_trait::async_trait;
22use chrono::Utc;
23use log::{debug, error, info, warn};
24use reqwest::Client;
25use std::collections::HashSet;
26use std::time::Duration;
27
28use drasi_core::models::SourceChange;
29use drasi_lib::bootstrap::{
30    BootstrapContext, BootstrapProvider, BootstrapRequest, BootstrapResult,
31};
32use drasi_lib::channels::BootstrapEvent;
33
34use crate::auth::{self, ResolvedAuth};
35use crate::config::{EndpointConfig, HttpBootstrapConfig, HttpMethod};
36use crate::content_parser::{self, ContentType};
37use crate::pagination::{self, NextPage, Paginator};
38use crate::response;
39use crate::template_engine::TemplateEngine;
40
/// Upper bound on the number of pages fetched from a single endpoint when no
/// `maxPages` value is configured; stops paginators that never report an end.
const DEFAULT_MAX_PAGES: u64 = 10000;

/// Ceiling (one minute) applied to the exponential retry backoff so that a
/// large retry count can never turn into an unbounded sleep.
const MAX_RETRY_DELAY: Duration = Duration::from_millis(60_000);
46
/// Truncate `s` to at most `max_chars` **bytes** without splitting a
/// multi-byte UTF-8 character (slicing mid-character would panic).
///
/// Despite the parameter name, the limit is measured in bytes (`str::len`),
/// not in characters. When the byte limit falls inside a character, the cut
/// backs off to the preceding character boundary, so the result may be
/// slightly shorter than `max_chars` bytes. Strings already within the limit
/// are returned unchanged.
fn safe_truncate(s: &str, max_chars: usize) -> &str {
    if s.len() <= max_chars {
        return s;
    }
    // `max_chars < s.len()` here, so every candidate index is in range.
    // Index 0 is always a char boundary, so the search always succeeds.
    let end = (0..=max_chars)
        .rev()
        .find(|&i| s.is_char_boundary(i))
        .unwrap_or(0);
    &s[..end]
}
59
60/// A resolved endpoint bundling config with its resolved auth.
61struct ResolvedEndpoint {
62    config: EndpointConfig,
63    auth: Option<ResolvedAuth>,
64}
65
66/// HTTP Bootstrap Provider that fetches data from REST APIs.
67pub struct HttpBootstrapProvider {
68    config: HttpBootstrapConfig,
69    client: Client,
70    endpoints: Vec<ResolvedEndpoint>,
71    engine: TemplateEngine,
72}
73
74impl HttpBootstrapProvider {
75    /// Create a new HTTP bootstrap provider from configuration.
76    pub fn new(config: HttpBootstrapConfig) -> Result<Self> {
77        let timeout = Duration::from_secs(config.timeout_seconds);
78        let client = Client::builder()
79            .timeout(timeout)
80            // Force HTTP/1.1 to avoid HTTP/2 multiplexing issues when multiple
81            // bootstrap calls share this client concurrently. H2 frame reassembly
82            // under heavy multiplexing can corrupt large response bodies.
83            .http1_only()
84            .build()
85            .context("Failed to build HTTP client")?;
86
87        // Resolve authentication for each endpoint and bundle them together
88        let mut endpoints = Vec::new();
89        for endpoint in &config.endpoints {
90            let auth = match &endpoint.auth {
91                Some(auth_config) => Some(
92                    auth::resolve_auth(auth_config, &client)
93                        .context("Failed to resolve authentication")?,
94                ),
95                None => None,
96            };
97            endpoints.push(ResolvedEndpoint {
98                config: endpoint.clone(),
99                auth,
100            });
101        }
102
103        let engine = TemplateEngine::new();
104
105        Ok(Self {
106            config,
107            client,
108            endpoints,
109            engine,
110        })
111    }
112
113    /// Fetch all pages from a single endpoint and emit bootstrap events.
114    async fn fetch_endpoint(
115        &self,
116        endpoint: &EndpointConfig,
117        auth: &Option<ResolvedAuth>,
118        context: &BootstrapContext,
119        request: &BootstrapRequest,
120        event_tx: &drasi_lib::channels::BootstrapEventSender,
121    ) -> Result<u64> {
122        let mut total_sent: u64 = 0;
123
124        // Determine content type override
125        let content_type_override = endpoint
126            .response
127            .content_type
128            .as_ref()
129            .map(ContentType::from_override);
130
131        // Set up pagination with SSRF protection (origin host validation)
132        let origin_host = pagination::extract_origin_host(&endpoint.url).unwrap_or_default();
133        let mut paginator: Option<Box<dyn Paginator>> = endpoint
134            .pagination
135            .as_ref()
136            .map(|p| pagination::create_paginator(p, origin_host));
137
138        // Get initial pagination params
139        let initial_params: Vec<(String, String)> = paginator
140            .as_ref()
141            .map(|p| p.initial_params())
142            .unwrap_or_default();
143
144        let mut current_url = endpoint.url.clone();
145        let mut current_params = initial_params;
146        let mut page_num = 0u64;
147        let mut seen_requests: HashSet<(String, Vec<(String, String)>)> = HashSet::new();
148
149        loop {
150            page_num += 1;
151
152            // Prevent infinite loops
153            let max_pages = self.config.max_pages.unwrap_or(DEFAULT_MAX_PAGES);
154            if page_num > max_pages {
155                error!(
156                    "Reached maximum page limit ({max_pages}) for endpoint '{}', stopping pagination. \
157                     Configure 'maxPages' to increase this limit.",
158                    endpoint.url
159                );
160                break;
161            }
162
163            // Detect cycles: same URL + same params seen before
164            let current_key = (current_url.clone(), current_params.clone());
165            if !seen_requests.insert(current_key) {
166                warn!(
167                    "Pagination cycle detected for endpoint '{}', stopping",
168                    endpoint.url
169                );
170                break;
171            }
172
173            debug!("Fetching page {page_num} from endpoint: {}", endpoint.url);
174
175            // Make the HTTP request with retries
176            let (response_text, response_headers) = self
177                .fetch_with_retry(&current_url, endpoint, auth, &current_params)
178                .await
179                .with_context(|| {
180                    format!(
181                        "Failed to fetch from endpoint '{}' (page {page_num})",
182                        endpoint.url
183                    )
184                })?;
185
186            // Determine content type from override or response header
187            let ct = content_type_override.clone().unwrap_or_else(|| {
188                let header_value = response_headers
189                    .get(reqwest::header::CONTENT_TYPE)
190                    .and_then(|v| v.to_str().ok());
191                ContentType::from_header(header_value)
192            });
193
194            // Parse response body using the correct content type
195            let parsed_body = match content_parser::parse_body(&response_text, &ct) {
196                Ok(body) => body,
197                Err(e) => {
198                    error!(
199                        "Failed to parse response from '{}' as {ct:?}. Body length: {}, first 200 chars: {:?}",
200                        endpoint.url,
201                        response_text.len(),
202                        safe_truncate(&response_text, 200)
203                    );
204                    return Err(e.context(format!(
205                        "Failed to parse response from '{}' as {ct:?}",
206                        endpoint.url
207                    )));
208                }
209            };
210
211            // Extract items
212            let items = response::extract_items(&parsed_body, &endpoint.response.items_path)?;
213            let items_count = items.len();
214
215            debug!(
216                "Extracted {items_count} items from page {page_num} of {}",
217                endpoint.url
218            );
219
220            // If no items, we're done
221            if items_count == 0 {
222                break;
223            }
224
225            // Map items to elements and emit events
226            let element_results = response::map_items_to_elements(
227                &items,
228                &endpoint.response.mappings,
229                &context.source_id,
230                &self.engine,
231            );
232
233            for result in element_results {
234                match result {
235                    Ok(element) => {
236                        // Filter by requested labels
237                        if !should_include_element(&element, request) {
238                            continue;
239                        }
240
241                        let source_change = SourceChange::Insert { element };
242                        let sequence = context.next_sequence();
243
244                        let bootstrap_event = BootstrapEvent {
245                            source_id: context.source_id.clone(),
246                            change: source_change,
247                            timestamp: Utc::now(),
248                            sequence,
249                        };
250
251                        event_tx
252                            .send(bootstrap_event)
253                            .await
254                            .context("Failed to send bootstrap event")?;
255
256                        total_sent += 1;
257                    }
258                    Err(e) => {
259                        warn!("Failed to map item to element: {e}");
260                    }
261                }
262            }
263
264            // Check pagination for next page
265            match paginator.as_mut() {
266                Some(ref mut pag) => {
267                    match pag.next_page(&parsed_body, &response_headers, items_count)? {
268                        Some(NextPage::QueryParams(params)) => {
269                            current_params = params;
270                        }
271                        Some(NextPage::NewUrl(url)) => {
272                            current_url = url;
273                            current_params = Vec::new();
274                        }
275                        None => break,
276                    }
277                }
278                None => break, // No pagination, single page only
279            }
280        }
281
282        info!(
283            "Completed fetching from endpoint '{}': {} pages, {} elements",
284            endpoint.url, page_num, total_sent
285        );
286
287        Ok(total_sent)
288    }
289
290    /// Fetch a URL with retry logic.
291    async fn fetch_with_retry(
292        &self,
293        url: &str,
294        endpoint: &EndpointConfig,
295        auth: &Option<ResolvedAuth>,
296        query_params: &[(String, String)],
297    ) -> Result<(String, reqwest::header::HeaderMap)> {
298        let max_retries = self.config.max_retries;
299        let retry_delay = Duration::from_millis(self.config.retry_delay_ms);
300
301        let mut last_error = None;
302
303        for attempt in 0..=max_retries {
304            if attempt > 0 {
305                let factor = 1u64.checked_shl(attempt - 1).unwrap_or(u64::MAX);
306                let delay = retry_delay
307                    .saturating_mul(factor.min(u32::MAX as u64) as u32)
308                    .min(MAX_RETRY_DELAY);
309                debug!("Retry attempt {attempt} after {delay:?} delay");
310                tokio::time::sleep(delay).await;
311            }
312
313            match self.make_request(url, endpoint, auth, query_params).await {
314                Ok(result) => return Ok(result),
315                Err(e) => {
316                    warn!(
317                        "Request to endpoint failed (attempt {}/{}): {}",
318                        attempt + 1,
319                        max_retries + 1,
320                        e
321                    );
322                    last_error = Some(e);
323                }
324            }
325        }
326
327        Err(last_error.unwrap_or_else(|| anyhow::anyhow!("Request failed with no error details")))
328    }
329
330    /// Make a single HTTP request.
331    /// Returns raw response text and headers for proper content-type handling.
332    async fn make_request(
333        &self,
334        url: &str,
335        endpoint: &EndpointConfig,
336        auth: &Option<ResolvedAuth>,
337        query_params: &[(String, String)],
338    ) -> Result<(String, reqwest::header::HeaderMap)> {
339        // Build request
340        let mut builder = match endpoint.method {
341            HttpMethod::Get => self.client.get(url),
342            HttpMethod::Post => self.client.post(url),
343            HttpMethod::Put => self.client.put(url),
344        };
345
346        // Add headers
347        for (key, value) in &endpoint.headers {
348            builder = builder.header(key.as_str(), value.as_str());
349        }
350
351        // Add query parameters
352        if !query_params.is_empty() {
353            builder = builder.query(query_params);
354        }
355
356        // Add request body
357        if let Some(ref body) = endpoint.body {
358            builder = builder.json(body);
359        }
360
361        // Apply auth
362        if let Some(ref resolved_auth) = auth {
363            builder = auth::apply_auth(builder, resolved_auth).await?;
364        }
365
366        // Send request
367        let response = builder.send().await.context("HTTP request failed")?;
368
369        if !response.status().is_success() {
370            let status = response.status();
371            let body = response
372                .text()
373                .await
374                .unwrap_or_else(|_| "Unable to read error response".to_string());
375            let truncated = if body.len() > 256 {
376                format!("{}... (truncated)", safe_truncate(&body, 256))
377            } else {
378                body
379            };
380            return Err(anyhow::anyhow!(
381                "HTTP request returned error status {status}: {truncated}"
382            ));
383        }
384
385        let headers = response.headers().clone();
386        let body_text = response
387            .text()
388            .await
389            .context("Failed to read response body")?;
390
391        Ok((body_text, headers))
392    }
393}
394
395/// Check if an element should be included based on the bootstrap request's label filters.
396fn should_include_element(
397    element: &drasi_core::models::Element,
398    request: &BootstrapRequest,
399) -> bool {
400    match element {
401        drasi_core::models::Element::Node { metadata, .. } => {
402            if request.node_labels.is_empty() {
403                return true;
404            }
405            metadata
406                .labels
407                .iter()
408                .any(|l| request.node_labels.iter().any(|nl| nl.as_str() == &**l))
409        }
410        drasi_core::models::Element::Relation { metadata, .. } => {
411            if request.relation_labels.is_empty() {
412                return true;
413            }
414            metadata
415                .labels
416                .iter()
417                .any(|l| request.relation_labels.iter().any(|rl| rl.as_str() == &**l))
418        }
419    }
420}
421
422#[async_trait]
423impl BootstrapProvider for HttpBootstrapProvider {
424    async fn bootstrap(
425        &self,
426        request: BootstrapRequest,
427        context: &BootstrapContext,
428        event_tx: drasi_lib::channels::BootstrapEventSender,
429        _settings: Option<&drasi_lib::config::SourceSubscriptionSettings>,
430    ) -> Result<BootstrapResult> {
431        info!(
432            "Starting HTTP bootstrap for query {} from source {}",
433            request.query_id, context.source_id
434        );
435
436        let mut total_events: u64 = 0;
437
438        for resolved in &self.endpoints {
439            match self
440                .fetch_endpoint(
441                    &resolved.config,
442                    &resolved.auth,
443                    context,
444                    &request,
445                    &event_tx,
446                )
447                .await
448            {
449                Ok(count) => {
450                    total_events += count;
451                }
452                Err(e) => {
453                    error!(
454                        "Failed to bootstrap from endpoint '{}': {}",
455                        resolved.config.url, e
456                    );
457                    return Err(e);
458                }
459            }
460        }
461
462        info!(
463            "Completed HTTP bootstrap for query {}: {} total elements",
464            request.query_id, total_events
465        );
466
467        Ok(BootstrapResult {
468            event_count: total_events as usize,
469            last_sequence: None,
470            sequences_aligned: false,
471            source_position: None,
472        })
473    }
474}