use std::sync::Arc;
use async_trait::async_trait;
use uuid::Uuid;
use cognee_database::{PipelineRunRepository, PipelineRunStatus as DbStatus};
use crate::pipeline::{
PipelineRunInfo, PipelineRunStatus as CoreStatus, PipelineStatus, PipelineWatcher, TaskStatus,
};
use super::{run_info_for_errored, run_info_for_initiated, run_info_for_running};
/// Pipeline watcher that records pipeline-run lifecycle events through a
/// [`PipelineRunRepository`].
pub struct DbPipelineWatcher {
    /// Repository that receives every DB write issued by this watcher.
    repo: Arc<dyn PipelineRunRepository>,
}

impl DbPipelineWatcher {
    /// Creates a watcher that writes through the given repository handle.
    pub fn new(repo: Arc<dyn PipelineRunRepository>) -> Self {
        DbPipelineWatcher { repo }
    }

    /// Translates the core pipeline run status into the database status enum.
    ///
    /// The match deliberately has no catch-all arm: adding a variant to
    /// `CoreStatus` makes this function fail to compile instead of silently
    /// mapping the new variant to something wrong.
    fn core_to_db_status(status: &CoreStatus) -> DbStatus {
        match *status {
            CoreStatus::Initiated => DbStatus::Initiated,
            CoreStatus::Started => DbStatus::Started,
            CoreStatus::Completed => DbStatus::Completed,
            CoreStatus::Errored => DbStatus::Errored,
        }
    }
}
#[async_trait]
impl PipelineWatcher for DbPipelineWatcher {
async fn on_pipeline(&self, _pipeline_id: Uuid, _status: PipelineStatus) {}
async fn on_task(
&self,
_pipeline_id: Uuid,
_task_index: usize,
_task_name: Option<&str>,
_total_tasks: usize,
_status: TaskStatus,
) {
}
async fn on_pipeline_run_initiated(&self, run: &PipelineRunInfo) {
let run_info = Some(run_info_for_initiated());
if let Err(e) = self
.repo
.log_pipeline_run(
run.run_id,
run.pipeline_id,
&run.pipeline_name,
run.dataset_id,
DbStatus::Initiated,
run_info,
)
.await
{
tracing::warn!(
run_id = %run.run_id,
"DbPipelineWatcher: DB write for Initiated failed (non-fatal): {e}"
);
}
}
async fn on_pipeline_run_started(&self, run: &PipelineRunInfo) {
let run_info = Some(run_info_for_running(&run.data_ids));
if let Err(e) = self
.repo
.log_pipeline_run(
run.run_id,
run.pipeline_id,
&run.pipeline_name,
run.dataset_id,
Self::core_to_db_status(&run.status),
run_info,
)
.await
{
tracing::warn!(
run_id = %run.run_id,
"DbPipelineWatcher: DB write for Started failed (non-fatal): {e}"
);
}
}
async fn on_pipeline_run_completed(&self, run: &PipelineRunInfo, _output_count: usize) {
let run_info = Some(run_info_for_running(&run.data_ids));
if let Err(e) = self
.repo
.log_pipeline_run(
run.run_id,
run.pipeline_id,
&run.pipeline_name,
run.dataset_id,
DbStatus::Completed,
run_info,
)
.await
{
tracing::warn!(
run_id = %run.run_id,
"DbPipelineWatcher: DB write for Completed failed (non-fatal): {e}"
);
}
}
async fn on_pipeline_run_errored(&self, run: &PipelineRunInfo, error: &str) {
let run_info = Some(run_info_for_errored(&run.data_ids, error));
if let Err(e) = self
.repo
.log_pipeline_run(
run.run_id,
run.pipeline_id,
&run.pipeline_name,
run.dataset_id,
DbStatus::Errored,
run_info,
)
.await
{
tracing::warn!(
run_id = %run.run_id,
"DbPipelineWatcher: DB write for Errored failed (non-fatal): {e}"
);
}
}
async fn on_payload_field(&self, run_id: Uuid, key: &str, value: serde_json::Value) {
if let Err(e) = self.repo.set_payload_field(run_id, key, value).await {
tracing::warn!(
run_id = %run_id,
key = %key,
"DbPipelineWatcher: DB write for payload field failed (non-fatal): {e}"
);
}
}
}