use anyhow::{Context, Result};
use async_trait::async_trait;
use chrono::Utc;
use futures::StreamExt;
use log::{debug, info, warn};
use reqwest::Client;
use serde::{Deserialize, Serialize};
use serde_json::Map;
use std::sync::Arc;
use std::time::Duration;
use drasi_core::models::{Element, ElementMetadata, ElementReference, SourceChange};
use drasi_lib::bootstrap::{BootstrapContext, BootstrapProvider, BootstrapRequest};
use drasi_lib::sources::manager::convert_json_to_element_properties;
/// JSON body posted to the Query API `/subscription` endpoint to start a
/// bootstrap stream. Field names are serialized in camelCase.
#[derive(Debug, Clone, Serialize)]
#[serde(rename_all = "camelCase")]
struct SubscriptionRequest {
/// Identifier of the query being bootstrapped (copied from the bootstrap request).
query_id: String,
/// Populated from the bootstrap context's `server_id`.
query_node_id: String,
/// Node labels copied from the bootstrap request.
node_labels: Vec<String>,
/// Relation labels copied from the bootstrap request's `relation_labels`.
rel_labels: Vec<String>,
}
/// One element decoded from a single JSON line of the bootstrap stream.
///
/// Keys are read in camelCase (`startId`, `endId`). An element is handled as
/// a relation only when both `start_id` and `end_id` are present; otherwise
/// it is handled as a node.
#[derive(Debug, Clone, Deserialize)]
#[serde(rename_all = "camelCase")]
struct BootstrapElement {
    /// Element identifier as reported by the Query API.
    id: String,
    /// Labels attached to the element; used for client-side label filtering.
    labels: Vec<String>,
    /// Raw JSON properties, converted to element properties during transform.
    properties: Map<String, serde_json::Value>,
    // This struct is only ever deserialized, so a serialization-only
    // attribute would be inert here. `default` states the actual intent:
    // a missing key decodes to `None`.
    /// Start node id; present only on relations.
    #[serde(default)]
    start_id: Option<String>,
    /// End node id; present only on relations.
    #[serde(default)]
    end_id: Option<String>,
}
use drasi_lib::bootstrap::PlatformBootstrapConfig;
/// Bootstrap provider that obtains initial data by POSTing a subscription
/// request to a Query API over HTTP and reading the streamed response.
pub struct PlatformBootstrapProvider {
/// Base URL of the Query API; `/subscription` is appended when requesting.
query_api_url: String,
/// HTTP client built once at construction with the configured timeout.
client: Client,
}
impl PlatformBootstrapProvider {
/// Creates a provider from a [`PlatformBootstrapConfig`].
///
/// # Errors
///
/// Fails when `config.query_api_url` is `None`, or when `create_internal`
/// rejects the URL or cannot build the HTTP client.
pub fn new(config: PlatformBootstrapConfig) -> Result<Self> {
    match config.query_api_url {
        Some(url) => Self::create_internal(url, config.timeout_seconds),
        None => Err(anyhow::anyhow!(
            "query_api_url is required for PlatformBootstrapProvider"
        )),
    }
}
/// Creates a provider directly from a base URL and a timeout in seconds.
///
/// # Errors
///
/// Propagates any error from `create_internal` (invalid URL or HTTP client
/// construction failure).
pub fn with_url(query_api_url: impl Into<String>, timeout_seconds: u64) -> Result<Self> {
    let base_url: String = query_api_url.into();
    Self::create_internal(base_url, timeout_seconds)
}
/// Returns a builder in its initial state, for configuring a provider
/// step by step.
pub fn builder() -> PlatformBootstrapProviderBuilder {
    // The builder's `Default` impl delegates to its `new()`.
    PlatformBootstrapProviderBuilder::default()
}
/// Validates the URL and builds the HTTP client; shared by every constructor.
///
/// # Errors
///
/// Fails when `query_api_url` does not parse as a URL or when the HTTP
/// client cannot be built.
fn create_internal(query_api_url: String, timeout_seconds: u64) -> Result<Self> {
    // Parsed only for validation; the original string is what gets stored.
    // `with_context` builds the message lazily, so nothing is allocated on
    // the success path.
    reqwest::Url::parse(&query_api_url)
        .with_context(|| format!("Invalid query_api_url: {query_api_url}"))?;

    // NOTE(review): reqwest documents the client timeout as covering the
    // whole request, including reading the body — confirm that large
    // bootstrap streams fit within `timeout_seconds`.
    let client = Client::builder()
        .timeout(Duration::from_secs(timeout_seconds))
        .build()
        .context("Failed to build HTTP client")?;

    Ok(Self {
        query_api_url,
        client,
    })
}
/// POSTs a subscription request to `<query_api_url>/subscription` and
/// returns the response once the status has been checked.
///
/// The body carries the query id, the context's `server_id` (as
/// `queryNodeId`) and the requested node and relation labels.
///
/// # Errors
///
/// Fails when the request cannot be sent, or when the Query API answers
/// with a non-success status; the response body is included in the error
/// when it can be read.
async fn make_subscription_request(
    &self,
    request: &BootstrapRequest,
    context: &BootstrapContext,
) -> Result<reqwest::Response> {
    let subscription_req = SubscriptionRequest {
        query_id: request.query_id.clone(),
        query_node_id: context.server_id.clone(),
        node_labels: request.node_labels.clone(),
        rel_labels: request.relation_labels.clone(),
    };

    // Tolerate a configured base URL that ends in '/', which would
    // otherwise yield ".../​/subscription" with a doubled slash.
    let url = format!(
        "{}/subscription",
        self.query_api_url.trim_end_matches('/')
    );
    debug!(
        "Making bootstrap subscription request to {} for query {}",
        url, request.query_id
    );

    let response = self
        .client
        .post(&url)
        .json(&subscription_req)
        .send()
        .await
        // Lazy context: the message is only formatted when sending fails.
        .with_context(|| format!("Failed to connect to Query API at {url}"))?;

    let status = response.status();
    if !status.is_success() {
        // Best effort: a body that cannot be read must not hide the status.
        let error_text = response
            .text()
            .await
            .unwrap_or_else(|_| "Unable to read error response".to_string());
        return Err(anyhow::anyhow!(
            "Query API returned error status {status}: {error_text}"
        ));
    }

    debug!(
        "Successfully connected to Query API, preparing to stream bootstrap data for query {}",
        request.query_id
    );
    Ok(response)
}
/// Reads the newline-delimited JSON response body and decodes every line
/// into a [`BootstrapElement`].
///
/// Blank lines are skipped. A line that is not valid JSON for a
/// `BootstrapElement` is logged with `warn!` and skipped rather than
/// failing the bootstrap. Trailing data without a final newline is
/// treated as one last line.
///
/// # Errors
///
/// Fails when a chunk cannot be read from the stream or when a line is
/// not valid UTF-8.
async fn process_bootstrap_stream(
    &self,
    response: reqwest::Response,
) -> Result<Vec<BootstrapElement>> {
    let mut elements = Vec::new();
    // Raw bytes are buffered and decoded per complete line. Chunk
    // boundaries are arbitrary, so decoding each chunk on its own would
    // reject a multi-byte UTF-8 character split across two chunks. The
    // byte b'\n' never occurs inside a multi-byte sequence, so splitting
    // on it is safe.
    let mut pending: Vec<u8> = Vec::new();
    let mut byte_stream = response.bytes_stream();

    while let Some(chunk_result) = byte_stream.next().await {
        let chunk = chunk_result.context("Error reading stream chunk")?;
        pending.extend_from_slice(&chunk);

        // Offset of the first byte not yet handed out as a complete line.
        let mut consumed = 0;
        while let Some(offset) = pending[consumed..].iter().position(|&b| b == b'\n') {
            let line_end = consumed + offset;
            let line = std::str::from_utf8(&pending[consumed..line_end])
                .context("Invalid UTF-8 in stream response")?
                .trim();
            consumed = line_end + 1;

            if line.is_empty() {
                continue;
            }
            match serde_json::from_str::<BootstrapElement>(line) {
                Ok(element) => {
                    elements.push(element);
                    let element_count = elements.len();
                    if element_count % 1000 == 0 {
                        debug!("Received {element_count} bootstrap elements from stream");
                    }
                }
                Err(e) => {
                    warn!("Failed to parse bootstrap element from JSON: {e} - Line: {line}");
                }
            }
        }
        // Drop all consumed lines in one pass instead of reallocating the
        // remainder after every line.
        pending.drain(..consumed);
    }

    // Whatever is left had no terminating newline; treat it as a final line.
    let remaining = std::str::from_utf8(&pending)
        .context("Invalid UTF-8 in stream response")?
        .trim();
    if !remaining.is_empty() {
        match serde_json::from_str::<BootstrapElement>(remaining) {
            Ok(element) => elements.push(element),
            Err(e) => {
                warn!(
                    "Failed to parse final bootstrap element from JSON: {e} - Line: {remaining}"
                );
            }
        }
    }

    let element_count = elements.len();
    info!("Received total of {element_count} bootstrap elements from Query API stream");
    Ok(elements)
}
}
/// Step-by-step builder for [`PlatformBootstrapProvider`].
pub struct PlatformBootstrapProviderBuilder {
/// Base URL of the Query API; must be set before `build` succeeds.
query_api_url: Option<String>,
/// HTTP client timeout in seconds; `new()` initializes it to 300.
timeout_seconds: u64,
}
impl PlatformBootstrapProviderBuilder {
    /// Timeout, in seconds, used when `with_timeout_seconds` is never called.
    const DEFAULT_TIMEOUT_SECONDS: u64 = 300;

    /// Creates a builder with no URL and the default timeout.
    pub fn new() -> Self {
        Self {
            query_api_url: None,
            timeout_seconds: Self::DEFAULT_TIMEOUT_SECONDS,
        }
    }

    /// Sets the base URL of the Query API.
    pub fn with_query_api_url(self, url: impl Into<String>) -> Self {
        Self {
            query_api_url: Some(url.into()),
            ..self
        }
    }

    /// Sets the HTTP client timeout, in seconds.
    pub fn with_timeout_seconds(self, seconds: u64) -> Self {
        Self {
            timeout_seconds: seconds,
            ..self
        }
    }

    /// Builds the provider.
    ///
    /// # Errors
    ///
    /// Fails when no URL was supplied, or when
    /// `PlatformBootstrapProvider::create_internal` rejects the URL or
    /// cannot build the HTTP client.
    pub fn build(self) -> Result<PlatformBootstrapProvider> {
        match self.query_api_url {
            Some(url) => PlatformBootstrapProvider::create_internal(url, self.timeout_seconds),
            None => Err(anyhow::anyhow!("query_api_url is required")),
        }
    }
}
impl Default for PlatformBootstrapProviderBuilder {
fn default() -> Self {
Self::new()
}
}
#[async_trait]
impl BootstrapProvider for PlatformBootstrapProvider {
    /// Fetches bootstrap data from the Query API and forwards every element
    /// whose labels match the request as an insert event on `event_tx`.
    ///
    /// An element with both a start and an end id is checked against
    /// `request.relation_labels`; any other element is checked against
    /// `request.node_labels`. Elements that do not match are counted and
    /// skipped.
    ///
    /// Returns the number of events sent.
    ///
    /// # Errors
    ///
    /// Fails when the subscription request or stream processing fails, when
    /// an element cannot be transformed, or when sending on the channel
    /// fails. The first such failure aborts the bootstrap.
    async fn bootstrap(
        &self,
        request: BootstrapRequest,
        context: &BootstrapContext,
        event_tx: drasi_lib::channels::BootstrapEventSender,
        _settings: Option<&drasi_lib::config::SourceSubscriptionSettings>,
    ) -> Result<usize> {
        info!(
            "Starting platform bootstrap for query {} from source {}",
            request.query_id, context.source_id
        );

        let http_response = self
            .make_subscription_request(&request, context)
            .await
            .context("Failed to make subscription request to Query API")?;
        let received = self
            .process_bootstrap_stream(http_response)
            .await
            .context("Failed to process bootstrap stream from Query API")?;

        debug!(
            "Processing {} bootstrap elements for query {}",
            received.len(),
            request.query_id
        );

        let mut sent_count = 0;
        let mut filtered_nodes = 0;
        let mut filtered_relations = 0;

        for candidate in received {
            // Pick the label set to test against together with the counter
            // to bump if the element ends up being skipped.
            let (wanted_labels, skipped) = match (&candidate.start_id, &candidate.end_id) {
                (Some(_), Some(_)) => (&request.relation_labels, &mut filtered_relations),
                _ => (&request.node_labels, &mut filtered_nodes),
            };
            if !matches_labels(&candidate.labels, wanted_labels) {
                *skipped += 1;
                continue;
            }

            let element = transform_element(&context.source_id, candidate)
                .context("Failed to transform bootstrap element")?;
            let sequence = context.next_sequence();
            let event = drasi_lib::channels::BootstrapEvent {
                source_id: context.source_id.clone(),
                change: SourceChange::Insert { element },
                timestamp: Utc::now(),
                sequence,
            };
            event_tx
                .send(event)
                .await
                .context("Failed to send bootstrap element via channel")?;
            sent_count += 1;
        }

        debug!(
            "Filtered {filtered_nodes} nodes and {filtered_relations} relations based on requested labels"
        );
        info!(
            "Completed platform bootstrap for query {}: sent {} elements",
            request.query_id, sent_count
        );
        Ok(sent_count)
    }
}
/// Returns `true` when an element with `element_labels` passes the filter
/// described by `requested_labels`.
///
/// An empty `requested_labels` means "no filter" and accepts every element,
/// including one without labels. Otherwise the element passes when it
/// shares at least one label with the requested set.
fn matches_labels(element_labels: &[String], requested_labels: &[String]) -> bool {
    if requested_labels.is_empty() {
        return true;
    }
    requested_labels
        .iter()
        .any(|wanted| element_labels.contains(wanted))
}
fn transform_element(source_id: &str, bootstrap_elem: BootstrapElement) -> Result<Element> {
let properties = convert_json_to_element_properties(&bootstrap_elem.properties)
.context("Failed to convert element properties")?;
let labels: Arc<[Arc<str>]> = bootstrap_elem
.labels
.iter()
.map(|l| Arc::from(l.as_str()))
.collect::<Vec<_>>()
.into();
if let (Some(start_id), Some(end_id)) = (&bootstrap_elem.start_id, &bootstrap_elem.end_id) {
let in_node = ElementReference::new(source_id, start_id);
let out_node = ElementReference::new(source_id, end_id);
Ok(Element::Relation {
metadata: ElementMetadata {
reference: ElementReference::new(source_id, &bootstrap_elem.id),
labels,
effective_from: 0,
},
properties,
in_node,
out_node,
})
} else {
Ok(Element::Node {
metadata: ElementMetadata {
reference: ElementReference::new(source_id, &bootstrap_elem.id),
labels,
effective_from: 0,
},
properties,
})
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Turns string literals into the owned label vectors the code expects.
    fn labels(names: &[&str]) -> Vec<String> {
        names.iter().map(|&name| name.to_owned()).collect()
    }

    /// Builds a JSON property map from `(key, value)` pairs, in order.
    fn props(entries: Vec<(&str, serde_json::Value)>) -> Map<String, serde_json::Value> {
        entries
            .into_iter()
            .map(|(key, value)| (key.to_owned(), value))
            .collect()
    }

    /// Builds a node-shaped element (no start/end ids) with a single label.
    fn node_element(
        id: &str,
        label: &str,
        properties: Map<String, serde_json::Value>,
    ) -> BootstrapElement {
        BootstrapElement {
            id: id.to_owned(),
            labels: labels(&[label]),
            properties,
            start_id: None,
            end_id: None,
        }
    }

    /// Asserts that `element` is a node whose element id is `expected_id`.
    fn assert_node_with_id(element: Element, expected_id: &str) {
        if let Element::Node { metadata, .. } = element {
            assert_eq!(metadata.reference.element_id.as_ref(), expected_id);
        } else {
            panic!("Expected Node element");
        }
    }

    #[test]
    fn test_matches_labels_empty_requested() {
        let element_labels = labels(&["Person", "Employee"]);
        assert!(matches_labels(&element_labels, &labels(&[])));
    }

    #[test]
    fn test_matches_labels_matching() {
        let element_labels = labels(&["Person", "Employee"]);
        assert!(matches_labels(&element_labels, &labels(&["Person"])));
    }

    #[test]
    fn test_matches_labels_non_matching() {
        let element_labels = labels(&["Person", "Employee"]);
        assert!(!matches_labels(&element_labels, &labels(&["Company"])));
    }

    #[test]
    fn test_matches_labels_partial_overlap() {
        let element_labels = labels(&["Person", "Employee"]);
        assert!(matches_labels(
            &element_labels,
            &labels(&["Employee", "Company"])
        ));
    }

    #[test]
    fn test_matches_labels_empty_element() {
        assert!(!matches_labels(&labels(&[]), &labels(&["Person"])));
    }

    #[test]
    fn test_matches_labels_both_empty() {
        assert!(matches_labels(&labels(&[]), &labels(&[])));
    }

    #[test]
    fn test_transform_element_node() {
        let properties = props(vec![
            ("name", serde_json::json!("Alice")),
            ("age", serde_json::json!(30)),
        ]);
        let element =
            transform_element("test-source", node_element("1", "Person", properties)).unwrap();

        if let Element::Node { metadata, .. } = element {
            assert_eq!(metadata.reference.element_id.as_ref(), "1");
            assert_eq!(metadata.labels.len(), 1);
            assert_eq!(metadata.labels[0].as_ref(), "Person");
        } else {
            panic!("Expected Node element");
        }
    }

    #[test]
    fn test_transform_element_relation() {
        let bootstrap_elem = BootstrapElement {
            id: "r1".to_owned(),
            labels: labels(&["WORKS_FOR"]),
            properties: props(vec![("since", serde_json::json!("2020"))]),
            start_id: Some("1".to_owned()),
            end_id: Some("2".to_owned()),
        };
        let element = transform_element("test-source", bootstrap_elem).unwrap();

        if let Element::Relation {
            metadata,
            in_node,
            out_node,
            ..
        } = element
        {
            assert_eq!(metadata.reference.element_id.as_ref(), "r1");
            assert_eq!(metadata.labels.len(), 1);
            assert_eq!(metadata.labels[0].as_ref(), "WORKS_FOR");
            assert_eq!(in_node.element_id.as_ref(), "1");
            assert_eq!(out_node.element_id.as_ref(), "2");
        } else {
            panic!("Expected Relation element");
        }
    }

    #[test]
    fn test_transform_element_various_property_types() {
        let properties = props(vec![
            ("string_prop", serde_json::json!("text")),
            ("number_prop", serde_json::json!(42)),
            ("float_prop", serde_json::json!(1.23456)),
            ("bool_prop", serde_json::json!(true)),
            ("null_prop", serde_json::json!(null)),
        ]);
        let element =
            transform_element("test-source", node_element("1", "Test", properties)).unwrap();
        assert_node_with_id(element, "1");
    }

    #[test]
    fn test_transform_element_empty_properties() {
        let element =
            transform_element("test-source", node_element("1", "Empty", Map::new())).unwrap();
        assert_node_with_id(element, "1");
    }
}