1use anyhow::Result;
33use log::{debug, error, info, warn};
34use std::sync::Arc;
35use tokio::sync::RwLock;
36use tracing::Instrument;
37
38use crate::bootstrap::{BootstrapContext, BootstrapProvider, BootstrapRequest};
39use crate::channels::*;
40use crate::component_graph::ComponentStatusHandle;
41use crate::context::SourceRuntimeContext;
42use crate::identity::IdentityProvider;
43use crate::profiling;
44use crate::state_store::StateStoreProvider;
45use drasi_core::models::SourceChange;
46
/// Construction parameters for [`SourceBase`], normally assembled with
/// `SourceBaseParams::new` followed by the `with_*` builder methods.
pub struct SourceBaseParams {
    /// Identifier of the source; becomes `SourceBase::id`.
    pub id: String,
    /// Dispatch mode; when `None`, `SourceBase::new` uses `DispatchMode::default()`.
    pub dispatch_mode: Option<DispatchMode>,
    /// Dispatcher buffer capacity; when `None`, `SourceBase::new` uses 1000.
    pub dispatch_buffer_capacity: Option<usize>,
    /// Optional provider used to serve bootstrap requests from subscribing queries.
    pub bootstrap_provider: Option<Box<dyn BootstrapProvider + 'static>>,
    /// Auto-start flag copied into `SourceBase::auto_start` (`true` when built via `new`).
    /// NOTE(review): presumably consulted by the host to start the source automatically.
    pub auto_start: bool,
}
76
77impl std::fmt::Debug for SourceBaseParams {
78 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
79 f.debug_struct("SourceBaseParams")
80 .field("id", &self.id)
81 .field("dispatch_mode", &self.dispatch_mode)
82 .field("dispatch_buffer_capacity", &self.dispatch_buffer_capacity)
83 .field(
84 "bootstrap_provider",
85 &self.bootstrap_provider.as_ref().map(|_| "<provider>"),
86 )
87 .field("auto_start", &self.auto_start)
88 .finish()
89 }
90}
91
92impl SourceBaseParams {
93 pub fn new(id: impl Into<String>) -> Self {
95 Self {
96 id: id.into(),
97 dispatch_mode: None,
98 dispatch_buffer_capacity: None,
99 bootstrap_provider: None,
100 auto_start: true,
101 }
102 }
103
104 pub fn with_dispatch_mode(mut self, mode: DispatchMode) -> Self {
106 self.dispatch_mode = Some(mode);
107 self
108 }
109
110 pub fn with_dispatch_buffer_capacity(mut self, capacity: usize) -> Self {
112 self.dispatch_buffer_capacity = Some(capacity);
113 self
114 }
115
116 pub fn with_bootstrap_provider(mut self, provider: impl BootstrapProvider + 'static) -> Self {
121 self.bootstrap_provider = Some(Box::new(provider));
122 self
123 }
124
125 pub fn with_auto_start(mut self, auto_start: bool) -> Self {
130 self.auto_start = auto_start;
131 self
132 }
133}
134
/// Shared state and helpers used by source implementations: fanning events out to
/// subscribers, serving bootstrap subscriptions, tracking lifecycle status, and
/// shutting down the source's background task.
///
/// Mutable state lives behind `Arc<RwLock<..>>`, so `SourceBase::clone_shared`
/// yields a second handle onto the same state (e.g. to move into a spawned task).
pub struct SourceBase {
    /// Source identifier.
    pub id: String,
    /// How events are fanned out to subscribers; fixed at construction.
    dispatch_mode: DispatchMode,
    /// Buffer capacity passed to every dispatcher this source creates.
    dispatch_buffer_capacity: usize,
    /// Auto-start flag copied from `SourceBaseParams`.
    pub auto_start: bool,
    /// Handle through which the component's lifecycle status is read and updated.
    status_handle: ComponentStatusHandle,
    /// Active dispatchers: one broadcast dispatcher created up front in `Broadcast`
    /// mode, or one channel dispatcher appended per subscriber in `Channel` mode.
    pub dispatchers: Arc<RwLock<Vec<Box<dyn ChangeDispatcher<SourceEventWrapper> + Send + Sync>>>>,
    /// Runtime context installed by `initialize`; `None` until then.
    context: Arc<RwLock<Option<SourceRuntimeContext>>>,
    /// State store taken from the runtime context, if it provided one.
    state_store: Arc<RwLock<Option<Arc<dyn StateStoreProvider>>>>,
    /// Handle of the source's background task; awaited (then aborted on timeout) by `stop_common`.
    pub task_handle: Arc<RwLock<Option<tokio::task::JoinHandle<()>>>>,
    /// One-shot sender `stop_common` fires to ask the background task to shut down.
    pub shutdown_tx: Arc<RwLock<Option<tokio::sync::oneshot::Sender<()>>>>,
    /// Provider used to serve bootstrap requests, if configured.
    bootstrap_provider: Arc<RwLock<Option<Arc<dyn BootstrapProvider>>>>,
    /// Identity provider, set explicitly or inherited from the runtime context.
    identity_provider: Arc<RwLock<Option<Arc<dyn IdentityProvider>>>>,
}
168
169impl SourceBase {
170 pub fn new(params: SourceBaseParams) -> Result<Self> {
178 let dispatch_mode = params.dispatch_mode.unwrap_or_default();
180 let dispatch_buffer_capacity = params.dispatch_buffer_capacity.unwrap_or(1000);
181
182 let mut dispatchers: Vec<Box<dyn ChangeDispatcher<SourceEventWrapper> + Send + Sync>> =
184 Vec::new();
185
186 if dispatch_mode == DispatchMode::Broadcast {
187 let dispatcher =
189 BroadcastChangeDispatcher::<SourceEventWrapper>::new(dispatch_buffer_capacity);
190 dispatchers.push(Box::new(dispatcher));
191 }
192 let bootstrap_provider = params
196 .bootstrap_provider
197 .map(|p| Arc::from(p) as Arc<dyn BootstrapProvider>);
198
199 Ok(Self {
200 id: params.id.clone(),
201 dispatch_mode,
202 dispatch_buffer_capacity,
203 auto_start: params.auto_start,
204 status_handle: ComponentStatusHandle::new(¶ms.id),
205 dispatchers: Arc::new(RwLock::new(dispatchers)),
206 context: Arc::new(RwLock::new(None)), state_store: Arc::new(RwLock::new(None)), task_handle: Arc::new(RwLock::new(None)),
209 shutdown_tx: Arc::new(RwLock::new(None)),
210 bootstrap_provider: Arc::new(RwLock::new(bootstrap_provider)),
211 identity_provider: Arc::new(RwLock::new(None)),
212 })
213 }
214
215 pub fn get_auto_start(&self) -> bool {
217 self.auto_start
218 }
219
220 pub async fn initialize(&self, context: SourceRuntimeContext) {
230 *self.context.write().await = Some(context.clone());
232
233 self.status_handle.wire(context.update_tx.clone()).await;
235
236 if let Some(state_store) = context.state_store.as_ref() {
237 *self.state_store.write().await = Some(state_store.clone());
238 }
239
240 if let Some(ip) = context.identity_provider.as_ref() {
242 let mut guard = self.identity_provider.write().await;
243 if guard.is_none() {
244 *guard = Some(ip.clone());
245 }
246 }
247 }
248
249 pub async fn context(&self) -> Option<SourceRuntimeContext> {
253 self.context.read().await.clone()
254 }
255
256 pub async fn state_store(&self) -> Option<Arc<dyn StateStoreProvider>> {
260 self.state_store.read().await.clone()
261 }
262
263 pub async fn identity_provider(&self) -> Option<Arc<dyn IdentityProvider>> {
269 self.identity_provider.read().await.clone()
270 }
271
272 pub async fn set_identity_provider(&self, provider: Arc<dyn IdentityProvider>) {
278 *self.identity_provider.write().await = Some(provider);
279 }
280
281 pub fn status_handle(&self) -> ComponentStatusHandle {
286 self.status_handle.clone()
287 }
288
289 pub fn clone_shared(&self) -> Self {
294 Self {
295 id: self.id.clone(),
296 dispatch_mode: self.dispatch_mode,
297 dispatch_buffer_capacity: self.dispatch_buffer_capacity,
298 auto_start: self.auto_start,
299 status_handle: self.status_handle.clone(),
300 dispatchers: self.dispatchers.clone(),
301 context: self.context.clone(),
302 state_store: self.state_store.clone(),
303 task_handle: self.task_handle.clone(),
304 shutdown_tx: self.shutdown_tx.clone(),
305 bootstrap_provider: self.bootstrap_provider.clone(),
306 identity_provider: self.identity_provider.clone(),
307 }
308 }
309
310 pub async fn set_bootstrap_provider(&self, provider: impl BootstrapProvider + 'static) {
321 *self.bootstrap_provider.write().await = Some(Arc::new(provider));
322 }
323
324 pub fn get_id(&self) -> &str {
326 &self.id
327 }
328
329 pub async fn create_streaming_receiver(
337 &self,
338 ) -> Result<Box<dyn ChangeReceiver<SourceEventWrapper>>> {
339 let receiver: Box<dyn ChangeReceiver<SourceEventWrapper>> = match self.dispatch_mode {
340 DispatchMode::Broadcast => {
341 let dispatchers = self.dispatchers.read().await;
343 if let Some(dispatcher) = dispatchers.first() {
344 dispatcher.create_receiver().await?
345 } else {
346 return Err(anyhow::anyhow!("No broadcast dispatcher available"));
347 }
348 }
349 DispatchMode::Channel => {
350 let dispatcher = ChannelChangeDispatcher::<SourceEventWrapper>::new(
352 self.dispatch_buffer_capacity,
353 );
354 let receiver = dispatcher.create_receiver().await?;
355
356 let mut dispatchers = self.dispatchers.write().await;
358 dispatchers.push(Box::new(dispatcher));
359
360 receiver
361 }
362 };
363
364 Ok(receiver)
365 }
366
367 pub async fn subscribe_with_bootstrap(
375 &self,
376 settings: &crate::config::SourceSubscriptionSettings,
377 source_type: &str,
378 ) -> Result<SubscriptionResponse> {
379 info!(
380 "Query '{}' subscribing to {} source '{}' (bootstrap: {})",
381 settings.query_id, source_type, self.id, settings.enable_bootstrap
382 );
383
384 let receiver = self.create_streaming_receiver().await?;
386
387 let query_id_for_response = settings.query_id.clone();
388
389 let bootstrap_receiver = if settings.enable_bootstrap {
391 self.handle_bootstrap_subscription(settings, source_type)
392 .await?
393 } else {
394 None
395 };
396
397 Ok(SubscriptionResponse {
398 query_id: query_id_for_response,
399 source_id: self.id.clone(),
400 receiver,
401 bootstrap_receiver,
402 })
403 }
404
405 async fn handle_bootstrap_subscription(
407 &self,
408 settings: &crate::config::SourceSubscriptionSettings,
409 source_type: &str,
410 ) -> Result<Option<BootstrapEventReceiver>> {
411 let provider_guard = self.bootstrap_provider.read().await;
412 if let Some(provider) = provider_guard.clone() {
413 drop(provider_guard); info!(
416 "Creating bootstrap for query '{}' on {} source '{}'",
417 settings.query_id, source_type, self.id
418 );
419
420 let context = BootstrapContext::new_minimal(
422 self.id.clone(), self.id.clone(), );
425
426 let (bootstrap_tx, bootstrap_rx) = tokio::sync::mpsc::channel(1000);
428
429 let node_labels: Vec<String> = settings.nodes.iter().cloned().collect();
431 let relation_labels: Vec<String> = settings.relations.iter().cloned().collect();
432
433 let request = BootstrapRequest {
435 query_id: settings.query_id.clone(),
436 node_labels,
437 relation_labels,
438 request_id: format!("{}-{}", settings.query_id, uuid::Uuid::new_v4()),
439 };
440
441 let settings_clone = settings.clone();
443 let source_id = self.id.clone();
444
445 let instance_id = self
447 .context()
448 .await
449 .map(|c| c.instance_id.clone())
450 .unwrap_or_default();
451
452 let span = tracing::info_span!(
454 "source_bootstrap",
455 instance_id = %instance_id,
456 component_id = %source_id,
457 component_type = "source"
458 );
459 tokio::spawn(
460 async move {
461 match provider
462 .bootstrap(request, &context, bootstrap_tx, Some(&settings_clone))
463 .await
464 {
465 Ok(count) => {
466 info!(
467 "Bootstrap completed successfully for query '{}', sent {count} events",
468 settings_clone.query_id
469 );
470 }
471 Err(e) => {
472 error!(
473 "Bootstrap failed for query '{}': {e}",
474 settings_clone.query_id
475 );
476 }
477 }
478 }
479 .instrument(span),
480 );
481
482 Ok(Some(bootstrap_rx))
483 } else {
484 info!(
485 "Bootstrap requested for query '{}' but no bootstrap provider configured for {} source '{}'",
486 settings.query_id, source_type, self.id
487 );
488 Ok(None)
489 }
490 }
491
492 pub async fn dispatch_source_change(&self, change: SourceChange) -> Result<()> {
500 let mut profiling = profiling::ProfilingMetadata::new();
502 profiling.source_send_ns = Some(profiling::timestamp_ns());
503
504 let wrapper = SourceEventWrapper::with_profiling(
506 self.id.clone(),
507 SourceEvent::Change(change),
508 chrono::Utc::now(),
509 profiling,
510 );
511
512 self.dispatch_event(wrapper).await
514 }
515
516 pub async fn dispatch_event(&self, wrapper: SourceEventWrapper) -> Result<()> {
522 debug!("[{}] Dispatching event: {:?}", self.id, &wrapper);
523
524 let arc_wrapper = Arc::new(wrapper);
526
527 let dispatchers = self.dispatchers.read().await;
529 for dispatcher in dispatchers.iter() {
530 if let Err(e) = dispatcher.dispatch_change(arc_wrapper.clone()).await {
531 debug!("[{}] Failed to dispatch event: {}", self.id, e);
532 }
533 }
534
535 Ok(())
536 }
537
538 pub async fn broadcast_control(&self, control: SourceControl) -> Result<()> {
540 let wrapper = SourceEventWrapper::new(
541 self.id.clone(),
542 SourceEvent::Control(control),
543 chrono::Utc::now(),
544 );
545 self.dispatch_event(wrapper).await
546 }
547
548 pub fn try_test_subscribe(
561 &self,
562 ) -> anyhow::Result<Box<dyn ChangeReceiver<SourceEventWrapper>>> {
563 tokio::task::block_in_place(|| {
564 tokio::runtime::Handle::current().block_on(self.create_streaming_receiver())
565 })
566 }
567
568 pub fn test_subscribe(&self) -> Box<dyn ChangeReceiver<SourceEventWrapper>> {
576 self.try_test_subscribe()
577 .expect("Failed to create test subscription receiver")
578 }
579
580 pub async fn dispatch_from_task(
592 dispatchers: Arc<RwLock<Vec<Box<dyn ChangeDispatcher<SourceEventWrapper> + Send + Sync>>>>,
593 wrapper: SourceEventWrapper,
594 source_id: &str,
595 ) -> Result<()> {
596 debug!(
597 "[{}] Dispatching event from task: {:?}",
598 source_id, &wrapper
599 );
600
601 let arc_wrapper = Arc::new(wrapper);
603
604 let dispatchers_guard = dispatchers.read().await;
606 for dispatcher in dispatchers_guard.iter() {
607 if let Err(e) = dispatcher.dispatch_change(arc_wrapper.clone()).await {
608 debug!("[{source_id}] Failed to dispatch event from task: {e}");
609 }
610 }
611
612 Ok(())
613 }
614
615 pub async fn stop_common(&self) -> Result<()> {
617 info!("Stopping source '{}'", self.id);
618
619 if let Some(tx) = self.shutdown_tx.write().await.take() {
621 let _ = tx.send(());
622 }
623
624 if let Some(mut handle) = self.task_handle.write().await.take() {
626 match tokio::time::timeout(std::time::Duration::from_secs(5), &mut handle).await {
627 Ok(Ok(())) => {
628 info!("Source '{}' task completed successfully", self.id);
629 }
630 Ok(Err(e)) => {
631 error!("Source '{}' task panicked: {}", self.id, e);
632 }
633 Err(_) => {
634 warn!(
635 "Source '{}' task did not complete within timeout, aborting",
636 self.id
637 );
638 handle.abort();
639 }
640 }
641 }
642
643 self.set_status(
644 ComponentStatus::Stopped,
645 Some(format!("Source '{}' stopped", self.id)),
646 )
647 .await;
648 info!("Source '{}' stopped", self.id);
649 Ok(())
650 }
651
652 pub async fn deprovision_common(&self) -> Result<()> {
658 info!("Deprovisioning source '{}'", self.id);
659 if let Some(store) = self.state_store().await {
660 let count = store.clear_store(&self.id).await.map_err(|e| {
661 anyhow::anyhow!(
662 "Failed to clear state store for source '{}': {}",
663 self.id,
664 e
665 )
666 })?;
667 info!(
668 "Cleared {} keys from state store for source '{}'",
669 count, self.id
670 );
671 }
672 Ok(())
673 }
674
675 pub async fn get_status(&self) -> ComponentStatus {
677 self.status_handle.get_status().await
678 }
679
680 pub async fn set_status(&self, status: ComponentStatus, message: Option<String>) {
684 self.status_handle.set_status(status, message).await;
685 }
686
687 pub async fn set_task_handle(&self, handle: tokio::task::JoinHandle<()>) {
689 *self.task_handle.write().await = Some(handle);
690 }
691
692 pub async fn set_shutdown_tx(&self, tx: tokio::sync::oneshot::Sender<()>) {
694 *self.shutdown_tx.write().await = Some(tx);
695 }
696}
697
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `SourceBase` from default params for `id`.
    fn make_base(id: &str) -> SourceBase {
        SourceBase::new(SourceBaseParams::new(id)).expect("SourceBase::new should succeed")
    }

    // --- SourceBaseParams builder -------------------------------------------------

    #[test]
    fn test_params_new_defaults() {
        let p = SourceBaseParams::new("test-source");

        assert_eq!(p.id, "test-source");
        assert!(p.dispatch_mode.is_none());
        assert!(p.dispatch_buffer_capacity.is_none());
        assert!(p.bootstrap_provider.is_none());
        assert!(p.auto_start);
    }

    #[test]
    fn test_params_with_dispatch_mode() {
        let p = SourceBaseParams::new("s1").with_dispatch_mode(DispatchMode::Broadcast);
        assert_eq!(p.dispatch_mode, Some(DispatchMode::Broadcast));
    }

    #[test]
    fn test_params_with_dispatch_buffer_capacity() {
        let p = SourceBaseParams::new("s1").with_dispatch_buffer_capacity(50000);
        assert_eq!(p.dispatch_buffer_capacity, Some(50000));
    }

    #[test]
    fn test_params_with_auto_start_false() {
        let p = SourceBaseParams::new("s1").with_auto_start(false);
        assert!(!p.auto_start);
    }

    #[test]
    fn test_params_builder_chaining() {
        let p = SourceBaseParams::new("chained")
            .with_dispatch_mode(DispatchMode::Broadcast)
            .with_dispatch_buffer_capacity(2000)
            .with_auto_start(false);

        assert_eq!(p.id, "chained");
        assert_eq!(p.dispatch_mode, Some(DispatchMode::Broadcast));
        assert_eq!(p.dispatch_buffer_capacity, Some(2000));
        assert!(!p.auto_start);
    }

    // --- SourceBase ---------------------------------------------------------------

    #[tokio::test]
    async fn test_new_defaults() {
        let base = make_base("my-source");

        assert_eq!(base.id, "my-source");
        assert!(base.auto_start);
        assert_eq!(base.get_status().await, ComponentStatus::Stopped);
    }

    #[tokio::test]
    async fn test_get_id() {
        assert_eq!(make_base("id-check").get_id(), "id-check");
    }

    #[tokio::test]
    async fn test_get_auto_start() {
        assert!(make_base("a").get_auto_start());

        let disabled = SourceBase::new(SourceBaseParams::new("b").with_auto_start(false))
            .expect("SourceBase::new should succeed");
        assert!(!disabled.get_auto_start());
    }

    #[tokio::test]
    async fn test_get_status_initial() {
        let base = make_base("s");
        assert_eq!(base.get_status().await, ComponentStatus::Stopped);
    }

    #[tokio::test]
    async fn test_set_status() {
        let base = make_base("s");

        base.set_status(ComponentStatus::Running, None).await;
        assert_eq!(base.get_status().await, ComponentStatus::Running);

        base.set_status(ComponentStatus::Error, Some(String::from("oops")))
            .await;
        assert_eq!(base.get_status().await, ComponentStatus::Error);
    }

    #[tokio::test]
    async fn test_status_handle_returns_handle() {
        let base = make_base("s");
        let status_handle = base.status_handle();

        // The handle starts out reporting the same status as the base...
        assert_eq!(status_handle.get_status().await, ComponentStatus::Stopped);

        // ...and updates made through it are visible via the base.
        status_handle.set_status(ComponentStatus::Starting, None).await;
        assert_eq!(base.get_status().await, ComponentStatus::Starting);
    }
}