use std::collections::HashMap;
use async_trait::async_trait;
use chrono::{DateTime, Utc};
use deadpool_postgres::Pool;
use rust_decimal::Decimal;
use uuid::Uuid;
use crate::agent::BrokenTool;
use crate::agent::routine::{Routine, RoutineRun, RunStatus};
use crate::config::DatabaseConfig;
use crate::context::{ActionRecord, JobContext, JobState};
use crate::db::{
ConversationStore, Database, JobStore, RoutineStore, SandboxStore, SettingsStore,
ToolFailureStore, WorkspaceStore,
};
use crate::error::{DatabaseError, WorkspaceError};
use crate::history::{
AgentJobRecord, AgentJobSummary, ConversationMessage, ConversationSummary, JobEventRecord,
LlmCallRecord, SandboxJobRecord, SandboxJobSummary, SettingRow, Store,
};
use crate::workspace::{
MemoryChunk, MemoryDocument, Repository, SearchConfig, SearchResult, WorkspaceEntry,
};
pub struct PgBackend {
store: Store,
repo: Repository,
}
impl PgBackend {
pub async fn new(config: &DatabaseConfig) -> Result<Self, DatabaseError> {
let store = Store::new(config).await?;
let repo = Repository::new(store.pool());
Ok(Self { store, repo })
}
pub fn pool(&self) -> Pool {
self.store.pool()
}
}
#[async_trait]
impl Database for PgBackend {
async fn run_migrations(&self) -> Result<(), DatabaseError> {
self.store.run_migrations().await
}
}
#[async_trait]
impl ConversationStore for PgBackend {
async fn create_conversation(
&self,
channel: &str,
user_id: &str,
thread_id: Option<&str>,
) -> Result<Uuid, DatabaseError> {
self.store
.create_conversation(channel, user_id, thread_id)
.await
}
async fn touch_conversation(&self, id: Uuid) -> Result<(), DatabaseError> {
self.store.touch_conversation(id).await
}
async fn add_conversation_message(
&self,
conversation_id: Uuid,
role: &str,
content: &str,
) -> Result<Uuid, DatabaseError> {
self.store
.add_conversation_message(conversation_id, role, content)
.await
}
async fn ensure_conversation(
&self,
id: Uuid,
channel: &str,
user_id: &str,
thread_id: Option<&str>,
) -> Result<bool, DatabaseError> {
self.store
.ensure_conversation(id, channel, user_id, thread_id)
.await
}
async fn list_conversations_with_preview(
&self,
user_id: &str,
channel: &str,
limit: i64,
) -> Result<Vec<ConversationSummary>, DatabaseError> {
self.store
.list_conversations_with_preview(user_id, channel, limit)
.await
}
async fn list_conversations_all_channels(
&self,
user_id: &str,
limit: i64,
) -> Result<Vec<ConversationSummary>, DatabaseError> {
self.store
.list_conversations_all_channels(user_id, limit)
.await
}
async fn get_or_create_routine_conversation(
&self,
routine_id: Uuid,
routine_name: &str,
user_id: &str,
) -> Result<Uuid, DatabaseError> {
self.store
.get_or_create_routine_conversation(routine_id, routine_name, user_id)
.await
}
async fn get_or_create_heartbeat_conversation(
&self,
user_id: &str,
) -> Result<Uuid, DatabaseError> {
self.store
.get_or_create_heartbeat_conversation(user_id)
.await
}
async fn get_or_create_assistant_conversation(
&self,
user_id: &str,
channel: &str,
) -> Result<Uuid, DatabaseError> {
self.store
.get_or_create_assistant_conversation(user_id, channel)
.await
}
async fn create_conversation_with_metadata(
&self,
channel: &str,
user_id: &str,
metadata: &serde_json::Value,
) -> Result<Uuid, DatabaseError> {
self.store
.create_conversation_with_metadata(channel, user_id, metadata)
.await
}
async fn list_conversation_messages_paginated(
&self,
conversation_id: Uuid,
before: Option<DateTime<Utc>>,
limit: i64,
) -> Result<(Vec<ConversationMessage>, bool), DatabaseError> {
self.store
.list_conversation_messages_paginated(conversation_id, before, limit)
.await
}
async fn update_conversation_metadata_field(
&self,
id: Uuid,
key: &str,
value: &serde_json::Value,
) -> Result<(), DatabaseError> {
self.store
.update_conversation_metadata_field(id, key, value)
.await
}
async fn get_conversation_metadata(
&self,
id: Uuid,
) -> Result<Option<serde_json::Value>, DatabaseError> {
self.store.get_conversation_metadata(id).await
}
async fn list_conversation_messages(
&self,
conversation_id: Uuid,
) -> Result<Vec<ConversationMessage>, DatabaseError> {
self.store.list_conversation_messages(conversation_id).await
}
async fn conversation_belongs_to_user(
&self,
conversation_id: Uuid,
user_id: &str,
) -> Result<bool, DatabaseError> {
self.store
.conversation_belongs_to_user(conversation_id, user_id)
.await
}
}
#[async_trait]
impl JobStore for PgBackend {
async fn save_job(&self, ctx: &JobContext) -> Result<(), DatabaseError> {
self.store.save_job(ctx).await
}
async fn get_job(&self, id: Uuid) -> Result<Option<JobContext>, DatabaseError> {
self.store.get_job(id).await
}
async fn update_job_status(
&self,
id: Uuid,
status: JobState,
failure_reason: Option<&str>,
) -> Result<(), DatabaseError> {
self.store
.update_job_status(id, status, failure_reason)
.await
}
async fn mark_job_stuck(&self, id: Uuid) -> Result<(), DatabaseError> {
self.store.mark_job_stuck(id).await
}
async fn get_stuck_jobs(&self) -> Result<Vec<Uuid>, DatabaseError> {
self.store.get_stuck_jobs().await
}
async fn list_agent_jobs(&self) -> Result<Vec<AgentJobRecord>, DatabaseError> {
self.store.list_agent_jobs().await
}
async fn list_agent_jobs_for_user(
&self,
user_id: &str,
) -> Result<Vec<AgentJobRecord>, DatabaseError> {
self.store.list_agent_jobs_for_user(user_id).await
}
async fn agent_job_summary(&self) -> Result<AgentJobSummary, DatabaseError> {
self.store.agent_job_summary().await
}
async fn agent_job_summary_for_user(
&self,
user_id: &str,
) -> Result<AgentJobSummary, DatabaseError> {
self.store.agent_job_summary_for_user(user_id).await
}
async fn get_agent_job_failure_reason(
&self,
id: Uuid,
) -> Result<Option<String>, DatabaseError> {
self.store.get_agent_job_failure_reason(id).await
}
async fn save_action(&self, job_id: Uuid, action: &ActionRecord) -> Result<(), DatabaseError> {
self.store.save_action(job_id, action).await
}
async fn get_job_actions(&self, job_id: Uuid) -> Result<Vec<ActionRecord>, DatabaseError> {
self.store.get_job_actions(job_id).await
}
async fn record_llm_call(&self, record: &LlmCallRecord<'_>) -> Result<Uuid, DatabaseError> {
self.store.record_llm_call(record).await
}
async fn save_estimation_snapshot(
&self,
job_id: Uuid,
category: &str,
tool_names: &[String],
estimated_cost: Decimal,
estimated_time_secs: i32,
estimated_value: Decimal,
) -> Result<Uuid, DatabaseError> {
self.store
.save_estimation_snapshot(
job_id,
category,
tool_names,
estimated_cost,
estimated_time_secs,
estimated_value,
)
.await
}
async fn update_estimation_actuals(
&self,
id: Uuid,
actual_cost: Decimal,
actual_time_secs: i32,
actual_value: Option<Decimal>,
) -> Result<(), DatabaseError> {
self.store
.update_estimation_actuals(id, actual_cost, actual_time_secs, actual_value)
.await
}
}
#[async_trait]
impl SandboxStore for PgBackend {
async fn save_sandbox_job(&self, job: &SandboxJobRecord) -> Result<(), DatabaseError> {
self.store.save_sandbox_job(job).await
}
async fn get_sandbox_job(&self, id: Uuid) -> Result<Option<SandboxJobRecord>, DatabaseError> {
self.store.get_sandbox_job(id).await
}
async fn list_sandbox_jobs(&self) -> Result<Vec<SandboxJobRecord>, DatabaseError> {
self.store.list_sandbox_jobs().await
}
async fn update_sandbox_job_status(
&self,
id: Uuid,
status: &str,
success: Option<bool>,
message: Option<&str>,
started_at: Option<DateTime<Utc>>,
completed_at: Option<DateTime<Utc>>,
) -> Result<(), DatabaseError> {
self.store
.update_sandbox_job_status(id, status, success, message, started_at, completed_at)
.await
}
async fn cleanup_stale_sandbox_jobs(&self) -> Result<u64, DatabaseError> {
self.store.cleanup_stale_sandbox_jobs().await
}
async fn sandbox_job_summary(&self) -> Result<SandboxJobSummary, DatabaseError> {
self.store.sandbox_job_summary().await
}
async fn list_sandbox_jobs_for_user(
&self,
user_id: &str,
) -> Result<Vec<SandboxJobRecord>, DatabaseError> {
self.store.list_sandbox_jobs_for_user(user_id).await
}
async fn sandbox_job_summary_for_user(
&self,
user_id: &str,
) -> Result<SandboxJobSummary, DatabaseError> {
self.store.sandbox_job_summary_for_user(user_id).await
}
async fn sandbox_job_belongs_to_user(
&self,
job_id: Uuid,
user_id: &str,
) -> Result<bool, DatabaseError> {
self.store
.sandbox_job_belongs_to_user(job_id, user_id)
.await
}
async fn update_sandbox_job_mode(&self, id: Uuid, mode: &str) -> Result<(), DatabaseError> {
self.store.update_sandbox_job_mode(id, mode).await
}
async fn get_sandbox_job_mode(&self, id: Uuid) -> Result<Option<String>, DatabaseError> {
self.store.get_sandbox_job_mode(id).await
}
async fn save_job_event(
&self,
job_id: Uuid,
event_type: &str,
data: &serde_json::Value,
) -> Result<(), DatabaseError> {
self.store.save_job_event(job_id, event_type, data).await
}
async fn list_job_events(
&self,
job_id: Uuid,
limit: Option<i64>,
) -> Result<Vec<JobEventRecord>, DatabaseError> {
self.store.list_job_events(job_id, limit).await
}
}
#[async_trait]
impl RoutineStore for PgBackend {
async fn create_routine(&self, routine: &Routine) -> Result<(), DatabaseError> {
self.store.create_routine(routine).await
}
async fn get_routine(&self, id: Uuid) -> Result<Option<Routine>, DatabaseError> {
self.store.get_routine(id).await
}
async fn get_routine_by_name(
&self,
user_id: &str,
name: &str,
) -> Result<Option<Routine>, DatabaseError> {
self.store.get_routine_by_name(user_id, name).await
}
async fn list_routines(&self, user_id: &str) -> Result<Vec<Routine>, DatabaseError> {
self.store.list_routines(user_id).await
}
async fn list_all_routines(&self) -> Result<Vec<Routine>, DatabaseError> {
self.store.list_all_routines().await
}
async fn list_event_routines(&self) -> Result<Vec<Routine>, DatabaseError> {
self.store.list_event_routines().await
}
async fn list_due_cron_routines(&self) -> Result<Vec<Routine>, DatabaseError> {
self.store.list_due_cron_routines().await
}
async fn update_routine(&self, routine: &Routine) -> Result<(), DatabaseError> {
self.store.update_routine(routine).await
}
async fn update_routine_runtime(
&self,
id: Uuid,
last_run_at: DateTime<Utc>,
next_fire_at: Option<DateTime<Utc>>,
run_count: u64,
consecutive_failures: u32,
state: &serde_json::Value,
) -> Result<(), DatabaseError> {
self.store
.update_routine_runtime(
id,
last_run_at,
next_fire_at,
run_count,
consecutive_failures,
state,
)
.await
}
async fn delete_routine(&self, id: Uuid) -> Result<bool, DatabaseError> {
self.store.delete_routine(id).await
}
async fn create_routine_run(&self, run: &RoutineRun) -> Result<(), DatabaseError> {
self.store.create_routine_run(run).await
}
async fn complete_routine_run(
&self,
id: Uuid,
status: RunStatus,
result_summary: Option<&str>,
tokens_used: Option<i32>,
) -> Result<(), DatabaseError> {
self.store
.complete_routine_run(id, status, result_summary, tokens_used)
.await
}
async fn list_routine_runs(
&self,
routine_id: Uuid,
limit: i64,
) -> Result<Vec<RoutineRun>, DatabaseError> {
self.store.list_routine_runs(routine_id, limit).await
}
async fn count_running_routine_runs(&self, routine_id: Uuid) -> Result<i64, DatabaseError> {
self.store.count_running_routine_runs(routine_id).await
}
async fn count_running_routine_runs_batch(
&self,
routine_ids: &[Uuid],
) -> Result<std::collections::HashMap<Uuid, i64>, DatabaseError> {
self.store
.count_running_routine_runs_batch(routine_ids)
.await
}
async fn batch_get_last_run_status(
&self,
routine_ids: &[Uuid],
) -> Result<std::collections::HashMap<Uuid, crate::agent::routine::RunStatus>, DatabaseError>
{
self.store.batch_get_last_run_status(routine_ids).await
}
async fn link_routine_run_to_job(
&self,
run_id: Uuid,
job_id: Uuid,
) -> Result<(), DatabaseError> {
self.store.link_routine_run_to_job(run_id, job_id).await
}
async fn get_webhook_routine_by_path(
&self,
path: &str,
) -> Result<Option<Routine>, DatabaseError> {
self.store.get_webhook_routine_by_path(path).await
}
async fn list_dispatched_routine_runs(&self) -> Result<Vec<RoutineRun>, DatabaseError> {
self.store.list_dispatched_routine_runs().await
}
}
#[async_trait]
impl ToolFailureStore for PgBackend {
async fn record_tool_failure(
&self,
tool_name: &str,
error_message: &str,
) -> Result<(), DatabaseError> {
self.store
.record_tool_failure(tool_name, error_message)
.await
}
async fn get_broken_tools(&self, threshold: i32) -> Result<Vec<BrokenTool>, DatabaseError> {
self.store.get_broken_tools(threshold).await
}
async fn mark_tool_repaired(&self, tool_name: &str) -> Result<(), DatabaseError> {
self.store.mark_tool_repaired(tool_name).await
}
async fn increment_repair_attempts(&self, tool_name: &str) -> Result<(), DatabaseError> {
self.store.increment_repair_attempts(tool_name).await
}
}
#[async_trait]
impl SettingsStore for PgBackend {
async fn get_setting(
&self,
user_id: &str,
key: &str,
) -> Result<Option<serde_json::Value>, DatabaseError> {
self.store.get_setting(user_id, key).await
}
async fn get_setting_full(
&self,
user_id: &str,
key: &str,
) -> Result<Option<SettingRow>, DatabaseError> {
self.store.get_setting_full(user_id, key).await
}
async fn set_setting(
&self,
user_id: &str,
key: &str,
value: &serde_json::Value,
) -> Result<(), DatabaseError> {
self.store.set_setting(user_id, key, value).await
}
async fn delete_setting(&self, user_id: &str, key: &str) -> Result<bool, DatabaseError> {
self.store.delete_setting(user_id, key).await
}
async fn list_settings(&self, user_id: &str) -> Result<Vec<SettingRow>, DatabaseError> {
self.store.list_settings(user_id).await
}
async fn get_all_settings(
&self,
user_id: &str,
) -> Result<HashMap<String, serde_json::Value>, DatabaseError> {
self.store.get_all_settings(user_id).await
}
async fn set_all_settings(
&self,
user_id: &str,
settings: &HashMap<String, serde_json::Value>,
) -> Result<(), DatabaseError> {
self.store.set_all_settings(user_id, settings).await
}
async fn has_settings(&self, user_id: &str) -> Result<bool, DatabaseError> {
self.store.has_settings(user_id).await
}
}
#[async_trait]
impl WorkspaceStore for PgBackend {
async fn get_document_by_path(
&self,
user_id: &str,
agent_id: Option<Uuid>,
path: &str,
) -> Result<MemoryDocument, WorkspaceError> {
self.repo
.get_document_by_path(user_id, agent_id, path)
.await
}
async fn get_document_by_id(&self, id: Uuid) -> Result<MemoryDocument, WorkspaceError> {
self.repo.get_document_by_id(id).await
}
async fn get_or_create_document_by_path(
&self,
user_id: &str,
agent_id: Option<Uuid>,
path: &str,
) -> Result<MemoryDocument, WorkspaceError> {
self.repo
.get_or_create_document_by_path(user_id, agent_id, path)
.await
}
async fn update_document(&self, id: Uuid, content: &str) -> Result<(), WorkspaceError> {
self.repo.update_document(id, content).await
}
async fn delete_document_by_path(
&self,
user_id: &str,
agent_id: Option<Uuid>,
path: &str,
) -> Result<(), WorkspaceError> {
self.repo
.delete_document_by_path(user_id, agent_id, path)
.await
}
async fn list_directory(
&self,
user_id: &str,
agent_id: Option<Uuid>,
directory: &str,
) -> Result<Vec<WorkspaceEntry>, WorkspaceError> {
self.repo.list_directory(user_id, agent_id, directory).await
}
async fn list_all_paths(
&self,
user_id: &str,
agent_id: Option<Uuid>,
) -> Result<Vec<String>, WorkspaceError> {
self.repo.list_all_paths(user_id, agent_id).await
}
async fn list_documents(
&self,
user_id: &str,
agent_id: Option<Uuid>,
) -> Result<Vec<MemoryDocument>, WorkspaceError> {
self.repo.list_documents(user_id, agent_id).await
}
async fn delete_chunks(&self, document_id: Uuid) -> Result<(), WorkspaceError> {
self.repo.delete_chunks(document_id).await
}
async fn insert_chunk(
&self,
document_id: Uuid,
chunk_index: i32,
content: &str,
embedding: Option<&[f32]>,
) -> Result<Uuid, WorkspaceError> {
self.repo
.insert_chunk(document_id, chunk_index, content, embedding)
.await
}
async fn update_chunk_embedding(
&self,
chunk_id: Uuid,
embedding: &[f32],
) -> Result<(), WorkspaceError> {
self.repo.update_chunk_embedding(chunk_id, embedding).await
}
async fn get_chunks_without_embeddings(
&self,
user_id: &str,
agent_id: Option<Uuid>,
limit: usize,
) -> Result<Vec<MemoryChunk>, WorkspaceError> {
self.repo
.get_chunks_without_embeddings(user_id, agent_id, limit)
.await
}
async fn hybrid_search(
&self,
user_id: &str,
agent_id: Option<Uuid>,
query: &str,
embedding: Option<&[f32]>,
config: &SearchConfig,
) -> Result<Vec<SearchResult>, WorkspaceError> {
self.repo
.hybrid_search(user_id, agent_id, query, embedding, config)
.await
}
async fn hybrid_search_multi(
&self,
user_ids: &[String],
agent_id: Option<Uuid>,
query: &str,
embedding: Option<&[f32]>,
config: &SearchConfig,
) -> Result<Vec<SearchResult>, WorkspaceError> {
self.repo
.hybrid_search_multi(user_ids, agent_id, query, embedding, config)
.await
}
async fn list_all_paths_multi(
&self,
user_ids: &[String],
agent_id: Option<Uuid>,
) -> Result<Vec<String>, WorkspaceError> {
self.repo.list_all_paths_multi(user_ids, agent_id).await
}
async fn get_document_by_path_multi(
&self,
user_ids: &[String],
agent_id: Option<Uuid>,
path: &str,
) -> Result<MemoryDocument, WorkspaceError> {
self.repo
.get_document_by_path_multi(user_ids, agent_id, path)
.await
}
async fn list_directory_multi(
&self,
user_ids: &[String],
agent_id: Option<Uuid>,
directory: &str,
) -> Result<Vec<WorkspaceEntry>, WorkspaceError> {
self.repo
.list_directory_multi(user_ids, agent_id, directory)
.await
}
}