use std::sync::Arc;
use async_trait::async_trait;
use chrono::Utc;
use uuid::Uuid;
use cognee_database::{PipelineRunRepository, PipelineRunStatus as DbStatus};
use crate::pipeline::{
PipelineRunInfo, PipelineRunStatus as CoreStatus, PipelineStatus, PipelineWatcher, TaskStatus,
};
use super::types::{RunEvent, RunEventKind, RunPhase};
/// Bundles the two channel senders that carry a run's notifications: a
/// broadcast sender for `RunEvent`s and a watch sender for the `RunPhase`.
pub struct PerRunSink {
/// Broadcast sender that `publish` pushes `RunEvent`s into.
pub(crate) event_tx: tokio::sync::broadcast::Sender<RunEvent>,
/// Watch sender that `set_phase` pushes the `RunPhase` into.
pub(crate) phase_tx: tokio::sync::watch::Sender<RunPhase>,
}
impl PerRunSink {
pub fn from_parts(
event_tx: tokio::sync::broadcast::Sender<RunEvent>,
phase_tx: tokio::sync::watch::Sender<RunPhase>,
) -> Self {
Self { event_tx, phase_tx }
}
}
impl PerRunSink {
    /// Broadcasts `event` on the event channel.
    ///
    /// Delivery is best-effort: `broadcast::Sender::send` returns an error
    /// when there is no active receiver, and that error is deliberately
    /// ignored so a run without listeners keeps going.
    pub fn publish(&self, event: RunEvent) {
        let _ = self.event_tx.send(event);
    }

    /// Stores `phase` as the current phase and notifies every watch receiver.
    ///
    /// `send_replace` is used instead of `send` because `watch::Sender::send`
    /// fails *without storing the value* once every receiver has been dropped.
    /// A phase change made while nobody is watching would then be lost, and a
    /// receiver created later via `subscribe()` would observe a stale phase.
    /// `send_replace` always updates the stored value.
    pub fn set_phase(&self, phase: RunPhase) {
        // The previous phase is returned by `send_replace`; it is not needed.
        drop(self.phase_tx.send_replace(phase));
    }
}
/// `PipelineWatcher` implementation tied to one run id: it writes run status
/// through a `PipelineRunRepository` and forwards run lifecycle changes to a
/// `PerRunSink`.
pub struct ScopedRunWatcher {
/// Id stamped on the emitted `RunEvent`s and on the run-status warning logs.
run_id: Uuid,
/// Channel senders that receive the events and phase updates.
sink: PerRunSink,
/// Repository used for the pipeline-run and payload-field writes.
db: Arc<dyn PipelineRunRepository>,
}
impl ScopedRunWatcher {
pub fn new(run_id: Uuid, sink: PerRunSink, db: Arc<dyn PipelineRunRepository>) -> Self {
Self { run_id, sink, db }
}
}
/// Translates the pipeline layer's run status into the database crate's
/// status enum; every variant maps to the variant of the same name.
fn core_to_db_status(status: &CoreStatus) -> DbStatus {
    match *status {
        CoreStatus::Errored => DbStatus::Errored,
        CoreStatus::Completed => DbStatus::Completed,
        CoreStatus::Started => DbStatus::Started,
        CoreStatus::Initiated => DbStatus::Initiated,
    }
}
#[async_trait]
impl PipelineWatcher for ScopedRunWatcher {
async fn on_pipeline(&self, _pipeline_id: Uuid, _status: PipelineStatus) {}
async fn on_task(
&self,
_pipeline_id: Uuid,
_task_index: usize,
_task_name: Option<&str>,
_total_tasks: usize,
_status: TaskStatus,
) {
}
async fn on_pipeline_run_initiated(&self, run: &PipelineRunInfo) {
let run_info = Some(super::run_info_for_initiated());
let db_result = self
.db
.log_pipeline_run(
run.run_id,
run.pipeline_id,
&run.pipeline_name,
run.dataset_id,
DbStatus::Initiated,
run_info,
)
.await;
if let Err(e) = db_result {
tracing::warn!(
run_id = %self.run_id,
"ScopedRunWatcher: DB write for Initiated failed (non-fatal): {e}"
);
}
}
async fn on_pipeline_run_started(&self, run: &PipelineRunInfo) {
let run_info = Some(super::run_info_for_running(&run.data_ids));
let db_result = self
.db
.log_pipeline_run(
run.run_id,
run.pipeline_id,
&run.pipeline_name,
run.dataset_id,
core_to_db_status(&run.status),
run_info,
)
.await;
if let Err(e) = db_result {
tracing::warn!(
run_id = %self.run_id,
"ScopedRunWatcher: DB write for Started failed (non-fatal): {e}"
);
}
self.sink.set_phase(RunPhase::Running);
self.sink.publish(RunEvent {
run_id: self.run_id,
kind: RunEventKind::Started,
payload: serde_json::Value::Null,
at: Utc::now(),
});
}
async fn on_pipeline_run_completed(&self, run: &PipelineRunInfo, _output_count: usize) {
let run_info = Some(super::run_info_for_running(&run.data_ids));
let db_result = self
.db
.log_pipeline_run(
run.run_id,
run.pipeline_id,
&run.pipeline_name,
run.dataset_id,
DbStatus::Completed,
run_info,
)
.await;
if let Err(e) = db_result {
tracing::warn!(
run_id = %self.run_id,
"ScopedRunWatcher: DB write for Completed failed (non-fatal): {e}"
);
}
self.sink.set_phase(RunPhase::Completed);
self.sink.publish(RunEvent {
run_id: self.run_id,
kind: RunEventKind::Completed,
payload: serde_json::Value::Null,
at: Utc::now(),
});
}
async fn on_payload_field(&self, run_id: Uuid, key: &str, value: serde_json::Value) {
if let Err(e) = self.db.set_payload_field(run_id, key, value).await {
tracing::warn!(
run_id = %run_id,
key = %key,
"ScopedRunWatcher: DB write for payload field failed (non-fatal): {e}"
);
}
}
async fn on_pipeline_run_errored(&self, run: &PipelineRunInfo, error: &str) {
let run_info = Some(super::run_info_for_errored(&run.data_ids, error));
let db_result = self
.db
.log_pipeline_run(
run.run_id,
run.pipeline_id,
&run.pipeline_name,
run.dataset_id,
DbStatus::Errored,
run_info,
)
.await;
if let Err(e) = db_result {
tracing::warn!(
run_id = %self.run_id,
"ScopedRunWatcher: DB write for Errored failed (non-fatal): {e}"
);
}
self.sink.set_phase(RunPhase::Errored {
message: error.to_string(),
});
self.sink.publish(RunEvent {
run_id: self.run_id,
kind: RunEventKind::Errored {
message: error.to_string(),
},
payload: serde_json::Value::Null,
at: Utc::now(),
});
}
}