use std::sync::Arc;
use chrono::Utc;
use tokio::task;
use crate::{
application::{
audit::AuditUseCase,
dto::{
AuditQueryRequest, AuditQueryResponse, IngestClaimRequest, IngestClaimResponse,
QueryHistoryRequest, QueryHistoryResponse, QueryMemoryRequest, QueryMemoryResponse,
ReconcileRequest, ReconcileResponse,
},
ingest_claim::IngestClaimUseCase,
query_history::QueryHistoryUseCase,
query_memory::QueryMemoryUseCase,
reconcile::ReconcileUseCase,
submit_adjudication::SubmitAdjudicationUseCase,
sweep_adjudications::SweepAdjudicationsUseCase,
},
concurrency::agent_lock::AgentWriteLockMap,
config::EngineConfig,
error::MemError,
ports::{OraclePort, PendingAdjudicationPort, PersistencePort, VectorPort},
};
/// Object-safe view of a [`PendingAdjudicationPort`] whose errors are boxed,
/// so a store can be held as `Arc<dyn ErasedPendingStore>` without naming the
/// concrete port type.
///
/// Every method is the boxed-error counterpart of the port method with the
/// same name minus the `_erased` suffix; [`ErasedPendingStoreAdapter`]
/// provides the forwarding implementation.
// NOTE(review): every item below now carries a doc comment, so this lint
// allowance may no longer be needed — confirm before removing it.
#[allow(missing_docs)]
pub trait ErasedPendingStore: Send + Sync + 'static {
    /// Boxed-error counterpart of the port's `insert_pending` for `row`.
    fn insert_pending_erased(
        &self,
        row: &crate::ports::pending_adjudication::PendingAdjudicationRow,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;

    /// Boxed-error counterpart of the port's `get_pending`; `Ok(None)` is
    /// passed through unchanged when the port returns no row for `handle_id`.
    fn get_pending_erased(
        &self,
        handle_id: uuid::Uuid,
    ) -> Result<
        Option<crate::ports::pending_adjudication::PendingAdjudicationRow>,
        Box<dyn std::error::Error + Send + Sync>,
    >;

    /// Boxed-error counterpart of the port's `list_pending`, with the same
    /// optional agent filter.
    fn list_pending_erased(
        &self,
        agent_id: Option<&mempill_types::AgentId>,
    ) -> Result<
        Vec<crate::ports::pending_adjudication::PendingAdjudicationRow>,
        Box<dyn std::error::Error + Send + Sync>,
    >;

    /// Boxed-error counterpart of the port's `list_expired`, evaluated
    /// against the caller-supplied instant `now`.
    fn list_expired_erased(
        &self,
        now: chrono::DateTime<chrono::Utc>,
    ) -> Result<
        Vec<crate::ports::pending_adjudication::PendingAdjudicationRow>,
        Box<dyn std::error::Error + Send + Sync>,
    >;

    /// Boxed-error counterpart of the port's `mark_resolved`.
    fn mark_resolved_erased(
        &self,
        handle_id: uuid::Uuid,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;

    /// Boxed-error counterpart of the port's `mark_expired`.
    fn mark_expired_erased(
        &self,
        handle_id: uuid::Uuid,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>>;

    /// Boxed-error counterpart of the port's `list_queued_orphan_claims`.
    fn list_queued_orphan_claims_erased(
        &self,
    ) -> Result<
        Vec<crate::ports::pending_adjudication::OrphanedQueuedClaim>,
        Box<dyn std::error::Error + Send + Sync>,
    >;
}
/// Wraps a concrete [`PendingAdjudicationPort`] so it can be used through the
/// object-safe [`ErasedPendingStore`] trait.
pub struct ErasedPendingStoreAdapter<S>
where
    S: PendingAdjudicationPort,
{
    /// The wrapped port that every erased call is forwarded to.
    inner: S,
}
impl<S: PendingAdjudicationPort> ErasedPendingStoreAdapter<S> {
pub fn new(inner: S) -> Self {
Self { inner }
}
}
/// Boxes a concrete port error into the erased error type returned by the
/// [`ErasedPendingStore`] methods.
fn erase_port_error<E>(err: E) -> Box<dyn std::error::Error + Send + Sync + 'static>
where
    E: std::error::Error + Send + Sync + 'static,
{
    Box::new(err)
}

/// Forwards each erased call to the wrapped port and boxes the port's error.
impl<S> ErasedPendingStore for ErasedPendingStoreAdapter<S>
where
    S: PendingAdjudicationPort,
{
    fn insert_pending_erased(
        &self,
        row: &crate::ports::pending_adjudication::PendingAdjudicationRow,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
        self.inner.insert_pending(row).map_err(erase_port_error)
    }

    fn get_pending_erased(
        &self,
        handle_id: uuid::Uuid,
    ) -> Result<
        Option<crate::ports::pending_adjudication::PendingAdjudicationRow>,
        Box<dyn std::error::Error + Send + Sync + 'static>,
    > {
        self.inner.get_pending(handle_id).map_err(erase_port_error)
    }

    fn list_pending_erased(
        &self,
        agent_id: Option<&mempill_types::AgentId>,
    ) -> Result<
        Vec<crate::ports::pending_adjudication::PendingAdjudicationRow>,
        Box<dyn std::error::Error + Send + Sync + 'static>,
    > {
        self.inner.list_pending(agent_id).map_err(erase_port_error)
    }

    fn list_expired_erased(
        &self,
        now: chrono::DateTime<chrono::Utc>,
    ) -> Result<
        Vec<crate::ports::pending_adjudication::PendingAdjudicationRow>,
        Box<dyn std::error::Error + Send + Sync + 'static>,
    > {
        self.inner.list_expired(now).map_err(erase_port_error)
    }

    fn mark_resolved_erased(
        &self,
        handle_id: uuid::Uuid,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
        self.inner.mark_resolved(handle_id).map_err(erase_port_error)
    }

    fn mark_expired_erased(
        &self,
        handle_id: uuid::Uuid,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync + 'static>> {
        self.inner.mark_expired(handle_id).map_err(erase_port_error)
    }

    fn list_queued_orphan_claims_erased(
        &self,
    ) -> Result<
        Vec<crate::ports::pending_adjudication::OrphanedQueuedClaim>,
        Box<dyn std::error::Error + Send + Sync + 'static>,
    > {
        self.inner
            .list_queued_orphan_claims()
            .map_err(erase_port_error)
    }
}
/// Cloneable async entry point to the engine.
///
/// Each public method builds the matching use case from the shared
/// collaborators stored here and runs it on tokio's blocking pool.
pub struct EngineHandle<P, O, V>
where
P: PersistencePort + Send + Sync + 'static,
O: OraclePort + Send + Sync + 'static,
V: VectorPort + Send + Sync + 'static,
{
/// Shared persistence port; handed to every use case and asked whether
/// writes need store-wide serialization.
persistence: Arc<P>,
/// Optional oracle, passed to the ingest and reconcile use cases.
oracle: Option<Arc<O>>,
/// Optional vector port, passed to the memory and history query use cases.
vector: Option<Arc<V>>,
/// Optional type-erased pending-adjudication store; `None` when the handle
/// was built with `new`.
pending_store: Option<Arc<dyn ErasedPendingStore>>,
/// Engine configuration, cloned into the use cases that take one.
config: EngineConfig,
/// Per-agent write locks, acquired by the write paths before running.
write_locks: AgentWriteLockMap,
/// Store-wide mutex, taken by write paths only when the persistence port
/// reports `requires_global_write_serialization()`.
store_write_lock: Arc<tokio::sync::Mutex<()>>,
}
impl<P, O, V> EngineHandle<P, O, V>
where
P: PersistencePort + Send + Sync + 'static,
O: OraclePort + Send + Sync + 'static,
V: VectorPort + Send + Sync + 'static,
{
pub fn new(
persistence: Arc<P>,
oracle: Option<Arc<O>>,
vector: Option<Arc<V>>,
config: EngineConfig,
) -> Self {
Self {
persistence,
oracle,
vector,
pending_store: None,
config,
write_locks: AgentWriteLockMap::new(),
store_write_lock: Arc::new(tokio::sync::Mutex::new(())),
}
}
pub fn new_with_pending_store<S>(
persistence: Arc<P>,
oracle: Option<Arc<O>>,
vector: Option<Arc<V>>,
pending_store: Arc<dyn ErasedPendingStore>,
config: EngineConfig,
) -> Self {
Self {
persistence,
oracle,
vector,
pending_store: Some(pending_store),
config,
write_locks: AgentWriteLockMap::new(),
store_write_lock: Arc::new(tokio::sync::Mutex::new(())),
}
}
/// Runs the ingest-claim use case for `req` on the blocking pool.
///
/// The timestamp handed to the use case is sampled on entry, before any lock
/// is awaited. The store-wide mutex is taken first (only when the persistence
/// port reports `requires_global_write_serialization()`), then the write lock
/// for `req.agent_id`; both guards are held across the await on the blocking
/// task.
///
/// # Errors
///
/// Returns the use case's own error, or [`MemError::SpawnBlocking`] when the
/// blocking task cannot be joined.
pub async fn ingest_claim(
    &self,
    req: IngestClaimRequest,
) -> Result<IngestClaimResponse, MemError> {
    let now = Utc::now();

    // Lock order: store-wide mutex first, per-agent lock second.
    let _serialized = if self.persistence.requires_global_write_serialization() {
        Some(self.store_write_lock.lock().await)
    } else {
        None
    };
    let _agent_guard = self.write_locks.acquire(&req.agent_id).await;

    let persistence = Arc::clone(&self.persistence);
    let oracle = self.oracle.clone();
    let pending_store = self.pending_store.clone();
    let config = self.config.clone();
    let use_case = IngestClaimUseCase::new(persistence, oracle, pending_store, config);

    let joined = task::spawn_blocking(move || use_case.execute_with_time(req, now)).await;
    match joined {
        Ok(outcome) => outcome,
        Err(join_err) => Err(MemError::SpawnBlocking {
            reason: join_err.to_string(),
        }),
    }
}
/// Runs the query-memory use case for `req` on the blocking pool, using a
/// timestamp sampled on entry. No write lock is taken.
///
/// # Errors
///
/// Returns the use case's own error, or [`MemError::SpawnBlocking`] when the
/// blocking task cannot be joined.
pub async fn query_memory(
    &self,
    req: QueryMemoryRequest,
) -> Result<QueryMemoryResponse, MemError> {
    let now = Utc::now();
    let persistence = Arc::clone(&self.persistence);
    let vector = self.vector.clone();
    let config = self.config.clone();
    let use_case = QueryMemoryUseCase::new(persistence, vector, config);

    let joined = task::spawn_blocking(move || use_case.execute_with_time(req, now)).await;
    match joined {
        Ok(outcome) => outcome,
        Err(join_err) => Err(MemError::SpawnBlocking {
            reason: join_err.to_string(),
        }),
    }
}
/// Runs the query-history use case for `req` on the blocking pool, using a
/// timestamp sampled on entry. No write lock is taken.
///
/// # Errors
///
/// Returns the use case's own error, or [`MemError::SpawnBlocking`] when the
/// blocking task cannot be joined.
pub async fn query_history(
    &self,
    req: QueryHistoryRequest,
) -> Result<QueryHistoryResponse, MemError> {
    let now = Utc::now();
    let use_case = {
        let persistence = Arc::clone(&self.persistence);
        let vector = self.vector.clone();
        let config = self.config.clone();
        QueryHistoryUseCase::new(persistence, vector, config)
    };

    task::spawn_blocking(move || use_case.execute_with_time(req, now))
        .await
        .unwrap_or_else(|join_err| {
            Err(MemError::SpawnBlocking {
                reason: join_err.to_string(),
            })
        })
}
/// Runs the reconcile use case for `req` on the blocking pool.
///
/// The store-wide mutex is taken first (only when the persistence port
/// reports `requires_global_write_serialization()`), then the write lock for
/// `req.agent_id`; both guards are held across the await on the blocking
/// task.
///
/// # Errors
///
/// Returns the use case's own error, or [`MemError::SpawnBlocking`] when the
/// blocking task cannot be joined.
pub async fn reconcile(
    &self,
    req: ReconcileRequest,
) -> Result<ReconcileResponse, MemError> {
    // Lock order: store-wide mutex first, per-agent lock second.
    let _serialized = if self.persistence.requires_global_write_serialization() {
        Some(self.store_write_lock.lock().await)
    } else {
        None
    };
    let _agent_guard = self.write_locks.acquire(&req.agent_id).await;

    let persistence = Arc::clone(&self.persistence);
    let oracle = self.oracle.clone();
    let config = self.config.clone();
    let use_case = ReconcileUseCase::new(persistence, oracle, config);

    let joined = task::spawn_blocking(move || use_case.execute(req)).await;
    match joined {
        Ok(outcome) => outcome,
        Err(join_err) => Err(MemError::SpawnBlocking {
            reason: join_err.to_string(),
        }),
    }
}
/// Runs the audit use case for `req` on the blocking pool. Only the
/// persistence port is involved and no lock is taken.
///
/// # Errors
///
/// Returns the use case's own error, or [`MemError::SpawnBlocking`] when the
/// blocking task cannot be joined.
pub async fn query_audit(
    &self,
    req: AuditQueryRequest,
) -> Result<AuditQueryResponse, MemError> {
    let persistence = Arc::clone(&self.persistence);
    let use_case = AuditUseCase::new(persistence);

    match task::spawn_blocking(move || use_case.execute(req)).await {
        Ok(outcome) => outcome,
        Err(join_err) => Err(MemError::SpawnBlocking {
            reason: join_err.to_string(),
        }),
    }
}
/// Applies `response` to the pending adjudication identified by `handle_id`.
///
/// The pending row is first read without any lock, solely to learn which
/// agent it belongs to. The store-wide mutex (when the persistence port
/// requires it) and that agent's write lock are then taken, in that order,
/// before the submit use case runs on the blocking pool. The timestamp passed
/// to the use case is sampled on entry.
///
/// # Errors
///
/// * [`MemError::AdjudicationHandleNotFound`] when no pending store is
///   configured or the store has no row for `handle_id`.
/// * [`MemError::PendingStore`] when the lookup in the pending store fails.
/// * [`MemError::SpawnBlocking`] when a blocking task cannot be joined.
/// * Any error returned by the submit use case itself.
pub async fn submit_adjudication(
    &self,
    handle_id: uuid::Uuid,
    response: mempill_types::AdjudicationResponse,
) -> Result<mempill_types::AdjudicationOutcome, MemError> {
    let now = Utc::now();

    let pending_store = match self.pending_store.as_ref() {
        Some(store) => store,
        None => return Err(MemError::AdjudicationHandleNotFound { handle_id }),
    };

    // Unlocked read: only the owning agent id is needed from the row.
    let lookup_store = Arc::clone(pending_store);
    let lookup = task::spawn_blocking(move || {
        match lookup_store.get_pending_erased(handle_id) {
            Ok(Some(found)) => Ok(found),
            Ok(None) => Err(MemError::AdjudicationHandleNotFound { handle_id }),
            Err(source) => Err(MemError::PendingStore { source }),
        }
    })
    .await;
    let row = match lookup {
        Ok(found) => found?,
        Err(join_err) => {
            return Err(MemError::SpawnBlocking {
                reason: join_err.to_string(),
            })
        }
    };
    let agent_id = row.agent_id.clone();

    // Lock order: store-wide mutex first, per-agent lock second.
    let _serialized = if self.persistence.requires_global_write_serialization() {
        Some(self.store_write_lock.lock().await)
    } else {
        None
    };
    let _agent_guard = self.write_locks.acquire(&agent_id).await;

    let use_case = SubmitAdjudicationUseCase::new(
        Arc::clone(&self.persistence),
        Arc::clone(pending_store),
    );
    let joined = task::spawn_blocking(move || use_case.execute(handle_id, response, now)).await;
    match joined {
        Ok(outcome) => outcome,
        Err(join_err) => Err(MemError::SpawnBlocking {
            reason: join_err.to_string(),
        }),
    }
}
/// Lists pending adjudication rows, optionally restricted to `agent_id`.
///
/// Returns an empty list when no pending store is configured. The store is
/// queried on the blocking pool and no lock is taken.
///
/// # Errors
///
/// * [`MemError::PendingStore`] when the store query fails.
/// * [`MemError::SpawnBlocking`] when the blocking task cannot be joined.
pub async fn list_pending_adjudications(
    &self,
    agent_id: Option<mempill_types::AgentId>,
) -> Result<Vec<crate::ports::pending_adjudication::PendingAdjudicationRow>, MemError> {
    let store = match self.pending_store.as_ref() {
        None => return Ok(vec![]),
        Some(configured) => Arc::clone(configured),
    };

    let listing = task::spawn_blocking(move || {
        store
            .list_pending_erased(agent_id.as_ref())
            .map_err(|source| MemError::PendingStore { source })
    })
    .await;

    match listing {
        Ok(rows) => rows,
        Err(join_err) => Err(MemError::SpawnBlocking {
            reason: join_err.to_string(),
        }),
    }
}
/// Reverts expired pending adjudications and orphaned queued claims,
/// returning how many items the sweep use case reported as reverted.
///
/// Returns `Ok(0)` when no pending store is configured. Otherwise the sweep
/// runs in two phases, both using one timestamp sampled on entry:
///
/// 1. rows returned by `list_expired_erased(now)` are passed one by one to
///    `revert_expired_row`;
/// 2. claims returned by `list_queued_orphan_claims_erased()` are passed one
///    by one to `revert_orphan`.
///
/// The orphan list is fetched only after phase 1 has finished. For every
/// item the store-wide mutex (when the persistence port requires it) and the
/// owning agent's write lock are taken, in that order, and released again
/// before the next item is processed.
///
/// # Errors
///
/// The first error aborts the sweep immediately and the count accumulated so
/// far is discarded:
///
/// * [`MemError::PendingStore`] when listing expired rows or orphans fails.
/// * [`MemError::SpawnBlocking`] when a blocking task cannot be joined.
/// * Any error returned by the sweep use case for an individual item.
pub async fn sweep_expired_adjudications(&self) -> Result<usize, MemError> {
    let now = Utc::now();
    let pending_store = match &self.pending_store {
        Some(ps) => Arc::clone(ps),
        None => return Ok(0),
    };

    // Phase 1: expired pending rows.
    let ps_for_list = Arc::clone(&pending_store);
    let expired_rows = task::spawn_blocking(move || {
        ps_for_list
            .list_expired_erased(now)
            .map_err(|e| MemError::PendingStore { source: e })
    })
    .await
    .map_err(|e| MemError::SpawnBlocking { reason: e.to_string() })??;

    let mut swept = 0usize;
    for row in expired_rows {
        let agent_id = row.agent_id.clone();
        // Lock order: store-wide mutex first, per-agent lock second.
        let _store_lock = if self.persistence.requires_global_write_serialization() {
            Some(self.store_write_lock.lock().await)
        } else {
            None
        };
        let _guard = self.write_locks.acquire(&agent_id).await;
        let persistence = Arc::clone(&self.persistence);
        let ps = Arc::clone(&pending_store);
        // `row` is owned by this iteration and unused afterwards, so it is
        // moved into the blocking task instead of being cloned.
        let reverted = task::spawn_blocking(move || {
            let uc = SweepAdjudicationsUseCase::new(persistence, ps);
            uc.revert_expired_row(&row, now)
        })
        .await
        .map_err(|e| MemError::SpawnBlocking { reason: e.to_string() })??;
        if reverted {
            swept += 1;
        }
    }

    // Phase 2: orphaned queued claims, listed only after phase 1 completed.
    let ps_for_orphans = Arc::clone(&pending_store);
    let orphans = task::spawn_blocking(move || {
        ps_for_orphans
            .list_queued_orphan_claims_erased()
            .map_err(|e| MemError::PendingStore { source: e })
    })
    .await
    .map_err(|e| MemError::SpawnBlocking { reason: e.to_string() })??;

    for orphan in orphans {
        let agent_id = orphan.agent_id.clone();
        let _store_lock = if self.persistence.requires_global_write_serialization() {
            Some(self.store_write_lock.lock().await)
        } else {
            None
        };
        let _guard = self.write_locks.acquire(&agent_id).await;
        let persistence = Arc::clone(&self.persistence);
        let ps = Arc::clone(&pending_store);
        // As above: the owned `orphan` is moved, not cloned.
        let reverted = task::spawn_blocking(move || {
            let uc = SweepAdjudicationsUseCase::new(persistence, ps);
            uc.revert_orphan(&orphan, now)
        })
        .await
        .map_err(|e| MemError::SpawnBlocking { reason: e.to_string() })??;
        if reverted {
            swept += 1;
        }
    }

    Ok(swept)
}
}
/// Manual `Clone` that places no `Clone` bound on `P`, `O` or `V`: the ports
/// are held behind `Arc`, so only the reference counts are bumped.
impl<P, O, V> Clone for EngineHandle<P, O, V>
where
    P: PersistencePort + Send + Sync + 'static,
    O: OraclePort + Send + Sync + 'static,
    V: VectorPort + Send + Sync + 'static,
{
    /// Returns a handle that points at the same persistence port and the
    /// same store-wide write mutex; the remaining fields are cloned through
    /// their own `Clone` implementations.
    fn clone(&self) -> Self {
        // Exhaustive destructuring: adding a field to the struct without
        // handling it here becomes a compile error.
        let Self {
            persistence,
            oracle,
            vector,
            pending_store,
            config,
            write_locks,
            store_write_lock,
        } = self;

        Self {
            persistence: Arc::clone(persistence),
            oracle: oracle.clone(),
            vector: vector.clone(),
            pending_store: pending_store.clone(),
            config: config.clone(),
            write_locks: write_locks.clone(),
            store_write_lock: Arc::clone(store_write_lock),
        }
    }
}