use std::collections::HashMap;
use std::sync::Arc;
use chrono::{DateTime, Utc};
use rust_decimal::Decimal;
use tokio::sync::{Semaphore, SemaphorePermit};
use uuid::Uuid;
use crate::agent::BrokenTool;
use crate::agent::cost_guard::{CostGuard, CostLimitExceeded};
use crate::agent::routine::{Routine, RoutineRun, RunStatus};
use crate::context::{ActionRecord, JobContext, JobState};
use crate::db::Database;
use crate::error::DatabaseError;
use crate::history::{
AgentJobRecord, AgentJobSummary, ConversationMessage, ConversationSummary, LlmCallRecord,
SandboxJobRecord, SandboxJobSummary, SettingRow,
};
use crate::workspace::Workspace;
#[derive(Clone)]
pub struct TenantScope {
user_id: String,
inner: Arc<dyn Database>,
}
impl TenantScope {
pub fn new(user_id: impl Into<String>, db: Arc<dyn Database>) -> Self {
Self {
user_id: user_id.into(),
inner: db,
}
}
pub fn user_id(&self) -> &str {
&self.user_id
}
pub async fn list_agent_jobs(&self) -> Result<Vec<AgentJobRecord>, DatabaseError> {
self.inner.list_agent_jobs_for_user(&self.user_id).await
}
pub async fn agent_job_summary(&self) -> Result<AgentJobSummary, DatabaseError> {
self.inner.agent_job_summary_for_user(&self.user_id).await
}
pub async fn get_job(&self, id: Uuid) -> Result<Option<JobContext>, DatabaseError> {
match self.inner.get_job(id).await? {
Some(ctx) if ctx.user_id == self.user_id => Ok(Some(ctx)),
_ => Ok(None),
}
}
pub async fn get_agent_job_failure_reason(
&self,
id: Uuid,
) -> Result<Option<String>, DatabaseError> {
if self.get_job(id).await?.is_none() {
return Ok(None);
}
self.inner.get_agent_job_failure_reason(id).await
}
pub async fn update_job_status(
&self,
id: Uuid,
status: JobState,
failure_reason: Option<&str>,
) -> Result<(), DatabaseError> {
if self.get_job(id).await?.is_none() {
return Err(DatabaseError::NotFound {
entity: "job".to_string(),
id: id.to_string(),
});
}
self.inner
.update_job_status(id, status, failure_reason)
.await
}
pub async fn list_sandbox_jobs(&self) -> Result<Vec<SandboxJobRecord>, DatabaseError> {
self.inner.list_sandbox_jobs_for_user(&self.user_id).await
}
pub async fn sandbox_job_summary(&self) -> Result<SandboxJobSummary, DatabaseError> {
self.inner.sandbox_job_summary_for_user(&self.user_id).await
}
pub async fn get_sandbox_job(
&self,
id: Uuid,
) -> Result<Option<SandboxJobRecord>, DatabaseError> {
match self.inner.get_sandbox_job(id).await? {
Some(job) if job.user_id == self.user_id => Ok(Some(job)),
_ => Ok(None),
}
}
pub async fn sandbox_job_belongs_to_user(&self, job_id: Uuid) -> Result<bool, DatabaseError> {
self.inner
.sandbox_job_belongs_to_user(job_id, &self.user_id)
.await
}
pub async fn list_routines(&self) -> Result<Vec<Routine>, DatabaseError> {
self.inner.list_routines(&self.user_id).await
}
pub async fn get_routine_by_name(&self, name: &str) -> Result<Option<Routine>, DatabaseError> {
self.inner.get_routine_by_name(&self.user_id, name).await
}
pub async fn get_routine(&self, id: Uuid) -> Result<Option<Routine>, DatabaseError> {
match self.inner.get_routine(id).await? {
Some(r) if r.user_id == self.user_id => Ok(Some(r)),
_ => Ok(None),
}
}
pub async fn create_routine(&self, routine: &Routine) -> Result<(), DatabaseError> {
debug_assert_eq!(
routine.user_id, self.user_id,
"routine.user_id must match TenantScope user"
);
self.inner.create_routine(routine).await
}
pub async fn update_routine(&self, routine: &Routine) -> Result<(), DatabaseError> {
if self.get_routine(routine.id).await?.is_none() {
return Err(DatabaseError::NotFound {
entity: "routine".to_string(),
id: routine.id.to_string(),
});
}
self.inner.update_routine(routine).await
}
pub async fn delete_routine(&self, id: Uuid) -> Result<bool, DatabaseError> {
if self.get_routine(id).await?.is_none() {
return Err(DatabaseError::NotFound {
entity: "routine".to_string(),
id: id.to_string(),
});
}
self.inner.delete_routine(id).await
}
pub async fn list_routine_runs(
&self,
routine_id: Uuid,
limit: i64,
) -> Result<Vec<RoutineRun>, DatabaseError> {
if self.get_routine(routine_id).await?.is_none() {
return Err(DatabaseError::NotFound {
entity: "routine".to_string(),
id: routine_id.to_string(),
});
}
self.inner.list_routine_runs(routine_id, limit).await
}
pub async fn get_webhook_routine_by_path(
&self,
path: &str,
) -> Result<Option<Routine>, DatabaseError> {
self.inner
.get_webhook_routine_by_path(path, Some(&self.user_id))
.await
}
pub async fn record_llm_call(&self, record: &LlmCallRecord<'_>) -> Result<Uuid, DatabaseError> {
self.inner.record_llm_call(record).await
}
pub async fn get_setting(&self, key: &str) -> Result<Option<serde_json::Value>, DatabaseError> {
self.inner.get_setting(&self.user_id, key).await
}
pub async fn get_setting_full(&self, key: &str) -> Result<Option<SettingRow>, DatabaseError> {
self.inner.get_setting_full(&self.user_id, key).await
}
pub async fn set_setting(
&self,
key: &str,
value: &serde_json::Value,
) -> Result<(), DatabaseError> {
self.inner.set_setting(&self.user_id, key, value).await
}
pub async fn delete_setting(&self, key: &str) -> Result<bool, DatabaseError> {
self.inner.delete_setting(&self.user_id, key).await
}
pub async fn list_settings(&self) -> Result<Vec<SettingRow>, DatabaseError> {
self.inner.list_settings(&self.user_id).await
}
pub async fn get_all_settings(
&self,
) -> Result<HashMap<String, serde_json::Value>, DatabaseError> {
self.inner.get_all_settings(&self.user_id).await
}
pub async fn set_all_settings(
&self,
settings: &HashMap<String, serde_json::Value>,
) -> Result<(), DatabaseError> {
self.inner.set_all_settings(&self.user_id, settings).await
}
pub async fn has_settings(&self) -> Result<bool, DatabaseError> {
self.inner.has_settings(&self.user_id).await
}
pub async fn create_conversation(
&self,
channel: &str,
thread_id: Option<&str>,
) -> Result<Uuid, DatabaseError> {
self.inner
.create_conversation(channel, &self.user_id, thread_id)
.await
}
pub async fn ensure_conversation(
&self,
id: Uuid,
channel: &str,
thread_id: Option<&str>,
) -> Result<bool, DatabaseError> {
self.inner
.ensure_conversation(id, channel, &self.user_id, thread_id)
.await
}
pub async fn list_conversations_with_preview(
&self,
channel: &str,
limit: i64,
) -> Result<Vec<ConversationSummary>, DatabaseError> {
self.inner
.list_conversations_with_preview(&self.user_id, channel, limit)
.await
}
pub async fn list_conversations_all_channels(
&self,
limit: i64,
) -> Result<Vec<ConversationSummary>, DatabaseError> {
self.inner
.list_conversations_all_channels(&self.user_id, limit)
.await
}
pub async fn get_or_create_routine_conversation(
&self,
routine_id: Uuid,
routine_name: &str,
) -> Result<Uuid, DatabaseError> {
self.inner
.get_or_create_routine_conversation(routine_id, routine_name, &self.user_id)
.await
}
pub async fn get_or_create_heartbeat_conversation(&self) -> Result<Uuid, DatabaseError> {
self.inner
.get_or_create_heartbeat_conversation(&self.user_id)
.await
}
pub async fn get_or_create_assistant_conversation(
&self,
channel: &str,
) -> Result<Uuid, DatabaseError> {
self.inner
.get_or_create_assistant_conversation(&self.user_id, channel)
.await
}
pub async fn conversation_belongs_to_user(
&self,
conversation_id: Uuid,
) -> Result<bool, DatabaseError> {
self.inner
.conversation_belongs_to_user(conversation_id, &self.user_id)
.await
}
pub async fn add_conversation_message(
&self,
conversation_id: Uuid,
role: &str,
content: &str,
) -> Result<Uuid, DatabaseError> {
if !self.conversation_belongs_to_user(conversation_id).await? {
return Err(DatabaseError::NotFound {
entity: "conversation".to_string(),
id: conversation_id.to_string(),
});
}
self.inner
.add_conversation_message(conversation_id, role, content)
.await
}
pub async fn touch_conversation(&self, id: Uuid) -> Result<(), DatabaseError> {
if !self.conversation_belongs_to_user(id).await? {
return Err(DatabaseError::NotFound {
entity: "conversation".to_string(),
id: id.to_string(),
});
}
self.inner.touch_conversation(id).await
}
pub async fn list_conversation_messages(
&self,
conversation_id: Uuid,
) -> Result<Vec<ConversationMessage>, DatabaseError> {
if !self.conversation_belongs_to_user(conversation_id).await? {
return Err(DatabaseError::NotFound {
entity: "conversation".to_string(),
id: conversation_id.to_string(),
});
}
self.inner.list_conversation_messages(conversation_id).await
}
pub async fn list_conversation_messages_paginated(
&self,
conversation_id: Uuid,
before: Option<DateTime<Utc>>,
limit: i64,
) -> Result<(Vec<ConversationMessage>, bool), DatabaseError> {
if !self.conversation_belongs_to_user(conversation_id).await? {
return Err(DatabaseError::NotFound {
entity: "conversation".to_string(),
id: conversation_id.to_string(),
});
}
self.inner
.list_conversation_messages_paginated(conversation_id, before, limit)
.await
}
pub async fn create_conversation_with_metadata(
&self,
channel: &str,
metadata: &serde_json::Value,
) -> Result<Uuid, DatabaseError> {
self.inner
.create_conversation_with_metadata(channel, &self.user_id, metadata)
.await
}
pub async fn update_conversation_metadata_field(
&self,
id: Uuid,
key: &str,
value: &serde_json::Value,
) -> Result<(), DatabaseError> {
if !self.conversation_belongs_to_user(id).await? {
return Err(DatabaseError::NotFound {
entity: "conversation".to_string(),
id: id.to_string(),
});
}
self.inner
.update_conversation_metadata_field(id, key, value)
.await
}
pub async fn get_conversation_metadata(
&self,
id: Uuid,
) -> Result<Option<serde_json::Value>, DatabaseError> {
if !self.conversation_belongs_to_user(id).await? {
return Err(DatabaseError::NotFound {
entity: "conversation".to_string(),
id: id.to_string(),
});
}
self.inner.get_conversation_metadata(id).await
}
}
#[derive(Clone)]
pub struct AdminScope {
inner: Arc<dyn Database>,
}
impl AdminScope {
pub fn new(db: Arc<dyn Database>) -> Self {
Self { inner: db }
}
pub fn db(&self) -> &Arc<dyn Database> {
&self.inner
}
pub async fn list_all_routines(&self) -> Result<Vec<Routine>, DatabaseError> {
self.inner.list_all_routines().await
}
pub async fn list_event_routines(&self) -> Result<Vec<Routine>, DatabaseError> {
self.inner.list_event_routines().await
}
pub async fn list_due_cron_routines(&self) -> Result<Vec<Routine>, DatabaseError> {
self.inner.list_due_cron_routines().await
}
pub async fn list_dispatched_routine_runs(&self) -> Result<Vec<RoutineRun>, DatabaseError> {
self.inner.list_dispatched_routine_runs().await
}
pub async fn count_running_routine_runs_batch(
&self,
routine_ids: &[Uuid],
) -> Result<HashMap<Uuid, i64>, DatabaseError> {
self.inner
.count_running_routine_runs_batch(routine_ids)
.await
}
pub async fn batch_get_last_run_status(
&self,
routine_ids: &[Uuid],
) -> Result<HashMap<Uuid, RunStatus>, DatabaseError> {
self.inner.batch_get_last_run_status(routine_ids).await
}
pub async fn count_running_routine_runs(&self, routine_id: Uuid) -> Result<i64, DatabaseError> {
self.inner.count_running_routine_runs(routine_id).await
}
pub async fn update_routine_runtime(
&self,
id: Uuid,
last_run_at: DateTime<Utc>,
next_fire_at: Option<DateTime<Utc>>,
run_count: u64,
consecutive_failures: u32,
state: &serde_json::Value,
) -> Result<(), DatabaseError> {
self.inner
.update_routine_runtime(
id,
last_run_at,
next_fire_at,
run_count,
consecutive_failures,
state,
)
.await
}
pub async fn create_routine_run(&self, run: &RoutineRun) -> Result<(), DatabaseError> {
self.inner.create_routine_run(run).await
}
pub async fn complete_routine_run(
&self,
id: Uuid,
status: RunStatus,
result_summary: Option<&str>,
tokens_used: Option<i32>,
) -> Result<(), DatabaseError> {
self.inner
.complete_routine_run(id, status, result_summary, tokens_used)
.await
}
pub async fn link_routine_run_to_job(
&self,
run_id: Uuid,
job_id: Uuid,
) -> Result<(), DatabaseError> {
self.inner.link_routine_run_to_job(run_id, job_id).await
}
pub async fn get_routine(&self, id: Uuid) -> Result<Option<Routine>, DatabaseError> {
self.inner.get_routine(id).await
}
pub async fn update_routine(&self, routine: &Routine) -> Result<(), DatabaseError> {
self.inner.update_routine(routine).await
}
pub async fn get_stuck_jobs(&self) -> Result<Vec<Uuid>, DatabaseError> {
self.inner.get_stuck_jobs().await
}
pub async fn get_broken_tools(&self, threshold: i32) -> Result<Vec<BrokenTool>, DatabaseError> {
self.inner.get_broken_tools(threshold).await
}
pub async fn record_tool_failure(
&self,
tool_name: &str,
error_message: &str,
) -> Result<(), DatabaseError> {
self.inner
.record_tool_failure(tool_name, error_message)
.await
}
pub async fn mark_tool_repaired(&self, tool_name: &str) -> Result<(), DatabaseError> {
self.inner.mark_tool_repaired(tool_name).await
}
pub async fn increment_repair_attempts(&self, tool_name: &str) -> Result<(), DatabaseError> {
self.inner.increment_repair_attempts(tool_name).await
}
pub async fn cleanup_stale_sandbox_jobs(&self) -> Result<u64, DatabaseError> {
self.inner.cleanup_stale_sandbox_jobs().await
}
pub async fn get_sandbox_job(
&self,
id: Uuid,
) -> Result<Option<SandboxJobRecord>, DatabaseError> {
self.inner.get_sandbox_job(id).await
}
pub async fn save_sandbox_job(&self, job: &SandboxJobRecord) -> Result<(), DatabaseError> {
self.inner.save_sandbox_job(job).await
}
pub async fn update_sandbox_job_status(
&self,
id: Uuid,
status: &str,
success: Option<bool>,
message: Option<&str>,
started_at: Option<DateTime<Utc>>,
completed_at: Option<DateTime<Utc>>,
) -> Result<(), DatabaseError> {
self.inner
.update_sandbox_job_status(id, status, success, message, started_at, completed_at)
.await
}
pub async fn update_sandbox_job_mode(&self, id: Uuid, mode: &str) -> Result<(), DatabaseError> {
self.inner.update_sandbox_job_mode(id, mode).await
}
pub async fn get_sandbox_job_mode(&self, id: Uuid) -> Result<Option<String>, DatabaseError> {
self.inner.get_sandbox_job_mode(id).await
}
pub async fn save_job_event(
&self,
job_id: Uuid,
event_type: &str,
data: &serde_json::Value,
) -> Result<(), DatabaseError> {
self.inner.save_job_event(job_id, event_type, data).await
}
pub async fn list_job_events(
&self,
job_id: Uuid,
limit: Option<i64>,
) -> Result<Vec<crate::history::JobEventRecord>, DatabaseError> {
self.inner.list_job_events(job_id, limit).await
}
pub async fn get_job(&self, id: Uuid) -> Result<Option<JobContext>, DatabaseError> {
self.inner.get_job(id).await
}
pub async fn save_job(&self, ctx: &JobContext) -> Result<(), DatabaseError> {
self.inner.save_job(ctx).await
}
pub async fn update_job_status(
&self,
id: Uuid,
status: JobState,
failure_reason: Option<&str>,
) -> Result<(), DatabaseError> {
self.inner
.update_job_status(id, status, failure_reason)
.await
}
pub async fn mark_job_stuck(&self, id: Uuid) -> Result<(), DatabaseError> {
self.inner.mark_job_stuck(id).await
}
pub async fn list_agent_jobs(&self) -> Result<Vec<AgentJobRecord>, DatabaseError> {
self.inner.list_agent_jobs().await
}
pub async fn get_agent_job_failure_reason(
&self,
id: Uuid,
) -> Result<Option<String>, DatabaseError> {
self.inner.get_agent_job_failure_reason(id).await
}
pub async fn record_llm_call(&self, record: &LlmCallRecord<'_>) -> Result<Uuid, DatabaseError> {
self.inner.record_llm_call(record).await
}
pub async fn save_action(
&self,
job_id: Uuid,
action: &ActionRecord,
) -> Result<(), DatabaseError> {
self.inner.save_action(job_id, action).await
}
pub async fn get_job_actions(&self, job_id: Uuid) -> Result<Vec<ActionRecord>, DatabaseError> {
self.inner.get_job_actions(job_id).await
}
pub async fn save_estimation_snapshot(
&self,
job_id: Uuid,
category: &str,
tool_names: &[String],
estimated_cost: Decimal,
estimated_time_secs: i32,
estimated_value: Decimal,
) -> Result<Uuid, DatabaseError> {
self.inner
.save_estimation_snapshot(
job_id,
category,
tool_names,
estimated_cost,
estimated_time_secs,
estimated_value,
)
.await
}
pub async fn update_estimation_actuals(
&self,
id: Uuid,
actual_cost: Decimal,
actual_time_secs: i32,
actual_value: Option<Decimal>,
) -> Result<(), DatabaseError> {
self.inner
.update_estimation_actuals(id, actual_cost, actual_time_secs, actual_value)
.await
}
pub async fn add_conversation_message(
&self,
conversation_id: Uuid,
role: &str,
content: &str,
) -> Result<Uuid, DatabaseError> {
self.inner
.add_conversation_message(conversation_id, role, content)
.await
}
pub async fn get_or_create_routine_conversation(
&self,
routine_id: Uuid,
routine_name: &str,
user_id: &str,
) -> Result<Uuid, DatabaseError> {
self.inner
.get_or_create_routine_conversation(routine_id, routine_name, user_id)
.await
}
pub async fn get_or_create_heartbeat_conversation(
&self,
user_id: &str,
) -> Result<Uuid, DatabaseError> {
self.inner
.get_or_create_heartbeat_conversation(user_id)
.await
}
}
pub struct TenantRateState {
pub llm_semaphore: Arc<Semaphore>,
pub job_semaphore: Arc<Semaphore>,
}
impl TenantRateState {
pub fn new(max_llm_concurrent: usize, max_job_concurrent: usize) -> Self {
Self {
llm_semaphore: Arc::new(Semaphore::new(max_llm_concurrent)),
job_semaphore: Arc::new(Semaphore::new(max_job_concurrent)),
}
}
}
pub struct TenantRateRegistry {
state: tokio::sync::RwLock<HashMap<String, Arc<TenantRateState>>>,
max_llm_concurrent: usize,
max_job_concurrent: usize,
}
impl TenantRateRegistry {
pub fn new(max_llm_concurrent: usize, max_job_concurrent: usize) -> Self {
Self {
state: tokio::sync::RwLock::new(HashMap::new()),
max_llm_concurrent,
max_job_concurrent,
}
}
pub async fn get_or_create(&self, user_id: &str) -> Arc<TenantRateState> {
{
let map = self.state.read().await;
if let Some(s) = map.get(user_id) {
return Arc::clone(s);
}
}
let mut map = self.state.write().await;
if let Some(s) = map.get(user_id) {
return Arc::clone(s);
}
let s = Arc::new(TenantRateState::new(
self.max_llm_concurrent,
self.max_job_concurrent,
));
map.insert(user_id.to_string(), Arc::clone(&s));
s
}
}
#[derive(Clone)]
pub struct TenantCtx {
user_id: String,
store: Option<TenantScope>,
workspace: Option<Arc<Workspace>>,
cost_guard: Arc<CostGuard>,
rate: Arc<TenantRateState>,
}
impl TenantCtx {
pub fn new(
user_id: impl Into<String>,
store: Option<TenantScope>,
workspace: Option<Arc<Workspace>>,
cost_guard: Arc<CostGuard>,
rate: Arc<TenantRateState>,
) -> Self {
Self {
user_id: user_id.into(),
store,
workspace,
cost_guard,
rate,
}
}
pub fn user_id(&self) -> &str {
&self.user_id
}
pub fn store(&self) -> Option<&TenantScope> {
self.store.as_ref()
}
pub fn workspace(&self) -> Option<&Arc<Workspace>> {
self.workspace.as_ref()
}
pub fn cost_guard(&self) -> &CostGuard {
&self.cost_guard
}
pub async fn check_cost_allowed(&self) -> Result<(), CostLimitExceeded> {
self.cost_guard.check_allowed_for_user(&self.user_id).await
}
#[allow(clippy::too_many_arguments)]
pub async fn record_llm_call(
&self,
model: &str,
input_tokens: u32,
output_tokens: u32,
cache_read_input_tokens: u32,
cache_creation_input_tokens: u32,
cache_read_discount: Decimal,
cache_write_multiplier: Decimal,
cost_per_token: Option<(Decimal, Decimal)>,
) -> Decimal {
self.cost_guard
.record_llm_call_for_user(
&self.user_id,
model,
input_tokens,
output_tokens,
cache_read_input_tokens,
cache_creation_input_tokens,
cache_read_discount,
cache_write_multiplier,
cost_per_token,
)
.await
}
pub async fn acquire_llm_permit(&self) -> Result<SemaphorePermit<'_>, crate::error::Error> {
self.rate.llm_semaphore.acquire().await.map_err(|_| {
crate::error::Error::Config(crate::error::ConfigError::InvalidValue {
key: "llm_semaphore".to_string(),
message: "semaphore closed".to_string(),
})
})
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_rate_registry_returns_same_state_for_same_user() {
let registry = TenantRateRegistry::new(4, 3);
let a1 = registry.get_or_create("alice").await;
let a2 = registry.get_or_create("alice").await;
assert!(Arc::ptr_eq(&a1, &a2));
}
#[tokio::test]
async fn test_rate_registry_different_users_get_different_state() {
let registry = TenantRateRegistry::new(4, 3);
let alice = registry.get_or_create("alice").await;
let bob = registry.get_or_create("bob").await;
assert!(!Arc::ptr_eq(&alice, &bob));
}
}