1use anyhow::Result;
23use async_trait::async_trait;
24use serde::{Deserialize, Serialize};
25use std::collections::HashMap;
26use std::sync::atomic::{AtomicU64, Ordering};
27use std::sync::Arc;
28
29#[derive(Debug, Clone, Serialize, Deserialize)]
31pub struct BootstrapRequest {
32 pub query_id: String,
33 pub node_labels: Vec<String>,
34 pub relation_labels: Vec<String>,
35 pub request_id: String,
36}
37
38#[derive(Clone)]
50pub struct BootstrapContext {
51 pub server_id: String,
53 pub source_id: String,
55 pub sequence_counter: Arc<AtomicU64>,
57 properties: Arc<HashMap<String, serde_json::Value>>,
59}
60
61impl BootstrapContext {
62 pub fn new_minimal(server_id: String, source_id: String) -> Self {
67 Self {
68 server_id,
69 source_id,
70 sequence_counter: Arc::new(AtomicU64::new(0)),
71 properties: Arc::new(HashMap::new()),
72 }
73 }
74
75 pub fn with_properties(
81 server_id: String,
82 source_id: String,
83 properties: HashMap<String, serde_json::Value>,
84 ) -> Self {
85 Self {
86 server_id,
87 source_id,
88 sequence_counter: Arc::new(AtomicU64::new(0)),
89 properties: Arc::new(properties),
90 }
91 }
92
93 pub fn next_sequence(&self) -> u64 {
95 self.sequence_counter.fetch_add(1, Ordering::SeqCst)
96 }
97
98 pub fn get_property(&self, key: &str) -> Option<serde_json::Value> {
100 self.properties.get(key).cloned()
101 }
102
103 pub fn get_typed_property<T>(&self, key: &str) -> Result<Option<T>>
105 where
106 T: for<'de> Deserialize<'de>,
107 {
108 match self.get_property(key) {
109 Some(value) => Ok(Some(serde_json::from_value(value.clone())?)),
110 None => Ok(None),
111 }
112 }
113}
114
115use crate::channels::BootstrapEventSender;
116
117#[async_trait]
120pub trait BootstrapProvider: Send + Sync {
121 async fn bootstrap(
131 &self,
132 request: BootstrapRequest,
133 context: &BootstrapContext,
134 event_tx: BootstrapEventSender,
135 settings: Option<&crate::config::SourceSubscriptionSettings>,
136 ) -> Result<usize>;
137}
138
139#[async_trait]
142impl BootstrapProvider for Box<dyn BootstrapProvider> {
143 async fn bootstrap(
144 &self,
145 request: BootstrapRequest,
146 context: &BootstrapContext,
147 event_tx: BootstrapEventSender,
148 settings: Option<&crate::config::SourceSubscriptionSettings>,
149 ) -> Result<usize> {
150 (**self)
151 .bootstrap(request, context, event_tx, settings)
152 .await
153 }
154}
155
156#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
171pub struct PostgresBootstrapConfig {
172 }
175
176#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Default)]
189pub struct ApplicationBootstrapConfig {
190 }
193
194#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
210pub struct ScriptFileBootstrapConfig {
211 pub file_paths: Vec<String>,
213}
214
215#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
229pub struct PlatformBootstrapConfig {
230 #[serde(skip_serializing_if = "Option::is_none")]
233 pub query_api_url: Option<String>,
234
235 #[serde(default = "default_platform_timeout")]
237 pub timeout_seconds: u64,
238}
239
/// Default for `PlatformBootstrapConfig::timeout_seconds`: five minutes, expressed in seconds.
///
/// Used both as the serde field default and by `PlatformBootstrapConfig::default()`.
fn default_platform_timeout() -> u64 {
    const FIVE_MINUTES_IN_SECONDS: u64 = 5 * 60;
    FIVE_MINUTES_IN_SECONDS
}
243
244impl Default for PlatformBootstrapConfig {
245 fn default() -> Self {
246 Self {
247 query_api_url: None,
248 timeout_seconds: default_platform_timeout(),
249 }
250 }
251}
252
253#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
255#[serde(tag = "type", rename_all = "lowercase")]
256pub enum BootstrapProviderConfig {
257 Postgres(PostgresBootstrapConfig),
259 Application(ApplicationBootstrapConfig),
261 ScriptFile(ScriptFileBootstrapConfig),
263 Platform(PlatformBootstrapConfig),
266 Noop,
268}
269
/// Stateless entry point for turning a [`BootstrapProviderConfig`] into a provider.
pub struct BootstrapProviderFactory;
276
277impl BootstrapProviderFactory {
278 pub fn create_provider(config: &BootstrapProviderConfig) -> Result<Box<dyn BootstrapProvider>> {
288 match config {
289 BootstrapProviderConfig::Postgres(_) => {
290 Err(anyhow::anyhow!(
291 "PostgreSQL bootstrap provider is available in the drasi-bootstrap-postgres crate. \
292 Use PostgresBootstrapProvider::builder().with_host(...).build() to create it."
293 ))
294 }
295 BootstrapProviderConfig::Application(_) => {
296 Err(anyhow::anyhow!(
297 "Application bootstrap provider is available in the drasi-bootstrap-application crate. \
298 Use ApplicationBootstrapProvider::builder().build() to create it."
299 ))
300 }
301 BootstrapProviderConfig::ScriptFile(config) => {
302 Err(anyhow::anyhow!(
303 "ScriptFile bootstrap provider is available in the drasi-bootstrap-scriptfile crate. \
304 Use ScriptFileBootstrapProvider::builder().with_file(...).build() to create it. \
305 File paths: {:?}",
306 config.file_paths
307 ))
308 }
309 BootstrapProviderConfig::Platform(config) => {
310 Err(anyhow::anyhow!(
311 "Platform bootstrap provider is available in the drasi-bootstrap-platform crate. \
312 Use PlatformBootstrapProvider::builder().with_query_api_url(...).build() to create it. \
313 Config: {config:?}"
314 ))
315 }
316 BootstrapProviderConfig::Noop => {
317 Err(anyhow::anyhow!(
318 "No-op bootstrap provider is available in the drasi-bootstrap-noop crate. \
319 Use NoOpBootstrapProvider::builder().build() or NoOpBootstrapProvider::new() to create it."
320 ))
321 }
322 }
323 }
324}
325
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_platform_bootstrap_config_defaults() {
        let config = PlatformBootstrapConfig {
            query_api_url: Some("http://test:8080".to_string()),
            ..Default::default()
        };
        assert_eq!(config.timeout_seconds, 300);
        assert_eq!(config.query_api_url, Some("http://test:8080".to_string()));
    }

    #[test]
    fn test_postgres_bootstrap_config_defaults() {
        let config = PostgresBootstrapConfig::default();
        assert_eq!(config, PostgresBootstrapConfig {});
    }

    #[test]
    fn test_application_bootstrap_config_defaults() {
        let config = ApplicationBootstrapConfig::default();
        assert_eq!(config, ApplicationBootstrapConfig {});
    }

    #[test]
    fn test_platform_bootstrap_config_serialization() {
        let config = BootstrapProviderConfig::Platform(PlatformBootstrapConfig {
            query_api_url: Some("http://test:8080".to_string()),
            timeout_seconds: 600,
        });

        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"platform\""));
        assert!(json.contains("\"query_api_url\":\"http://test:8080\""));
        assert!(json.contains("\"timeout_seconds\":600"));

        let deserialized: BootstrapProviderConfig = serde_json::from_str(&json).unwrap();
        match deserialized {
            BootstrapProviderConfig::Platform(cfg) => {
                assert_eq!(cfg.query_api_url, Some("http://test:8080".to_string()));
                assert_eq!(cfg.timeout_seconds, 600);
            }
            _ => panic!("Expected Platform variant"),
        }
    }

    #[test]
    fn test_scriptfile_bootstrap_config() {
        let config = BootstrapProviderConfig::ScriptFile(ScriptFileBootstrapConfig {
            file_paths: vec![
                "/path/to/file1.jsonl".to_string(),
                "/path/to/file2.jsonl".to_string(),
            ],
        });

        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"scriptfile\""));
        assert!(json.contains("\"file_paths\""));

        let deserialized: BootstrapProviderConfig = serde_json::from_str(&json).unwrap();
        match deserialized {
            BootstrapProviderConfig::ScriptFile(cfg) => {
                assert_eq!(cfg.file_paths.len(), 2);
                assert_eq!(cfg.file_paths[0], "/path/to/file1.jsonl");
                assert_eq!(cfg.file_paths[1], "/path/to/file2.jsonl");
            }
            _ => panic!("Expected ScriptFile variant"),
        }
    }

    #[test]
    fn test_noop_bootstrap_config() {
        let config = BootstrapProviderConfig::Noop;

        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"noop\""));

        let deserialized: BootstrapProviderConfig = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, BootstrapProviderConfig::Noop));
    }

    #[test]
    fn test_postgres_bootstrap_config_serialization() {
        let config = BootstrapProviderConfig::Postgres(PostgresBootstrapConfig::default());

        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"postgres\""));

        let deserialized: BootstrapProviderConfig = serde_json::from_str(&json).unwrap();
        assert!(matches!(deserialized, BootstrapProviderConfig::Postgres(_)));
    }

    #[test]
    fn test_application_bootstrap_config_serialization() {
        let config = BootstrapProviderConfig::Application(ApplicationBootstrapConfig::default());

        let json = serde_json::to_string(&config).unwrap();
        assert!(json.contains("\"type\":\"application\""));

        let deserialized: BootstrapProviderConfig = serde_json::from_str(&json).unwrap();
        assert!(matches!(
            deserialized,
            BootstrapProviderConfig::Application(_)
        ));
    }

    #[test]
    fn test_yaml_deserialization_platform() {
        let yaml = r#"
type: platform
query_api_url: "http://remote:8080" # DevSkim: ignore DS137138
timeout_seconds: 300
"#;

        let config: BootstrapProviderConfig = serde_yaml::from_str(yaml).unwrap();
        match config {
            BootstrapProviderConfig::Platform(cfg) => {
                assert_eq!(cfg.query_api_url, Some("http://remote:8080".to_string()));
                assert_eq!(cfg.timeout_seconds, 300);
            }
            _ => panic!("Expected Platform variant"),
        }
    }

    #[test]
    fn test_yaml_deserialization_scriptfile() {
        let yaml = r#"
type: scriptfile
file_paths:
  - "/data/file1.jsonl"
  - "/data/file2.jsonl"
"#;

        let config: BootstrapProviderConfig = serde_yaml::from_str(yaml).unwrap();
        match config {
            BootstrapProviderConfig::ScriptFile(cfg) => {
                assert_eq!(cfg.file_paths.len(), 2);
                assert_eq!(cfg.file_paths[0], "/data/file1.jsonl");
            }
            _ => panic!("Expected ScriptFile variant"),
        }
    }

    #[test]
    fn test_platform_config_with_defaults() {
        let yaml = r#"
type: platform
query_api_url: "http://test:8080" # DevSkim: ignore DS137138
"#;

        let config: BootstrapProviderConfig = serde_yaml::from_str(yaml).unwrap();
        match config {
            BootstrapProviderConfig::Platform(cfg) => {
                assert_eq!(cfg.timeout_seconds, 300);
            }
            _ => panic!("Expected Platform variant"),
        }
    }

    #[test]
    fn test_bootstrap_config_equality() {
        let config1 = BootstrapProviderConfig::Platform(PlatformBootstrapConfig {
            query_api_url: Some("http://test:8080".to_string()),
            timeout_seconds: 300,
        });

        let config2 = BootstrapProviderConfig::Platform(PlatformBootstrapConfig {
            query_api_url: Some("http://test:8080".to_string()),
            timeout_seconds: 300,
        });

        assert_eq!(config1, config2);
    }

    #[test]
    fn test_backward_compatibility_yaml() {
        let yaml = r#"
type: postgres
"#;

        let config: BootstrapProviderConfig = serde_yaml::from_str(yaml).unwrap();
        assert!(matches!(config, BootstrapProviderConfig::Postgres(_)));
    }

    // --- BootstrapContext -------------------------------------------------------------

    #[test]
    fn test_context_sequence_is_shared_across_clones() {
        let context = BootstrapContext::new_minimal("server".to_string(), "source".to_string());
        let clone = context.clone();

        // Clones share one `Arc<AtomicU64>`, so they draw from a single sequence.
        assert_eq!(context.next_sequence(), 0);
        assert_eq!(clone.next_sequence(), 1);
        assert_eq!(context.next_sequence(), 2);
    }

    #[test]
    fn test_context_minimal_has_no_properties() {
        let context = BootstrapContext::new_minimal("server".to_string(), "source".to_string());

        assert_eq!(context.server_id, "server");
        assert_eq!(context.source_id, "source");
        assert!(context.get_property("anything").is_none());
        assert!(context
            .get_typed_property::<String>("anything")
            .unwrap()
            .is_none());
    }

    #[test]
    fn test_context_typed_properties() {
        let mut properties = HashMap::new();
        properties.insert("port".to_string(), serde_json::json!(5432));
        properties.insert("tables".to_string(), serde_json::json!(["a", "b"]));
        let context = BootstrapContext::with_properties(
            "server".to_string(),
            "source".to_string(),
            properties,
        );

        assert_eq!(context.get_property("port"), Some(serde_json::json!(5432)));
        assert_eq!(
            context.get_typed_property::<u16>("port").unwrap(),
            Some(5432)
        );
        assert_eq!(
            context.get_typed_property::<Vec<String>>("tables").unwrap(),
            Some(vec!["a".to_string(), "b".to_string()])
        );
        assert_eq!(context.get_typed_property::<u16>("missing").unwrap(), None);

        // A type mismatch is an error, and the serde_json cause stays recoverable.
        let err = context
            .get_typed_property::<String>("port")
            .expect_err("a JSON number must not deserialize into a String");
        assert!(err.downcast_ref::<serde_json::Error>().is_some());
    }

    // --- BootstrapProviderFactory -----------------------------------------------------

    #[test]
    fn test_factory_points_every_variant_at_its_crate() {
        let cases = vec![
            (
                BootstrapProviderConfig::Postgres(PostgresBootstrapConfig::default()),
                "drasi-bootstrap-postgres",
            ),
            (
                BootstrapProviderConfig::Application(ApplicationBootstrapConfig::default()),
                "drasi-bootstrap-application",
            ),
            (
                BootstrapProviderConfig::ScriptFile(ScriptFileBootstrapConfig {
                    file_paths: vec!["/data/seed.jsonl".to_string()],
                }),
                "drasi-bootstrap-scriptfile",
            ),
            (
                BootstrapProviderConfig::Platform(PlatformBootstrapConfig::default()),
                "drasi-bootstrap-platform",
            ),
            (BootstrapProviderConfig::Noop, "drasi-bootstrap-noop"),
        ];

        for (config, crate_name) in cases {
            let err = match BootstrapProviderFactory::create_provider(&config) {
                Ok(_) => panic!("expected an error for {config:?}"),
                Err(err) => err,
            };
            let message = err.to_string();
            assert!(
                message.contains(crate_name),
                "message for {config:?} should mention {crate_name}: {message}"
            );
        }
    }

    #[test]
    fn test_factory_scriptfile_error_lists_paths() {
        let config = BootstrapProviderConfig::ScriptFile(ScriptFileBootstrapConfig {
            file_paths: vec!["/data/seed.jsonl".to_string()],
        });
        let message = match BootstrapProviderFactory::create_provider(&config) {
            Ok(_) => panic!("expected an error"),
            Err(err) => err.to_string(),
        };
        assert!(message.contains("/data/seed.jsonl"));
    }
}