use async_trait::async_trait;
use deadpool_postgres::Pool;
use tracing::debug;
use smooth_operator::agent_config::{AgentBehaviorConfig, AgentConfigResolver};
/// Resolves per-agent behavior configuration by reading the `agents` table
/// through a `deadpool_postgres` connection pool.
#[derive(Clone)]
pub struct PgAgentConfigResolver {
/// Pool from which a client is checked out for each lookup.
pool: Pool,
}
impl PgAgentConfigResolver {
#[must_use]
pub fn new(pool: Pool) -> Self {
Self { pool }
}
async fn fetch(&self, agent_id: &str) -> Option<AgentBehaviorConfig> {
let id = match uuid::Uuid::parse_str(agent_id) {
Ok(id) => id,
Err(_) => {
debug!(agent_id, "agent_id is not a uuid; no per-agent config");
return None;
}
};
let client = match self.pool.get().await {
Ok(c) => c,
Err(e) => {
debug!(error = %e, "agent config: pool.get failed; falling back to org default");
return None;
}
};
let row = match client
.query_opt(
"SELECT instructions, personality, greeting, conversation_workflow, tool_config, visibility \
FROM agents WHERE id = $1",
&[&id],
)
.await
{
Ok(row) => row?,
Err(e) => {
debug!(error = %e, agent_id, "agent config query failed; falling back to org default");
return None;
}
};
let instructions: Option<serde_json::Value> = row.try_get("instructions").ok().flatten();
let personality: Option<serde_json::Value> = row.try_get("personality").ok().flatten();
let greeting: Option<String> = row.try_get("greeting").ok().flatten();
let workflow: Option<serde_json::Value> =
row.try_get("conversation_workflow").ok().flatten();
let tool_config: Option<serde_json::Value> = row.try_get("tool_config").ok().flatten();
let visibility: Option<String> = row.try_get("visibility").ok().flatten();
let config = AgentBehaviorConfig::from_row_values(
instructions,
personality,
greeting,
workflow,
tool_config,
visibility,
);
if config.is_empty() {
None
} else {
Some(config)
}
}
}
// Trait implementation that exposes the Postgres lookup through the
// `AgentConfigResolver` interface.
#[async_trait]
impl AgentConfigResolver for PgAgentConfigResolver {
/// Resolves the behavior configuration for `agent_id` by delegating to the
/// inherent `fetch` method; returns whatever `fetch` returns.
async fn resolve(&self, agent_id: &str) -> Option<AgentBehaviorConfig> {
self.fetch(agent_id).await
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// An id that does not parse as a UUID must resolve to `None`.
    ///
    /// The pool is configured for 127.0.0.1 port 1, where no database is
    /// expected to listen, so the test does not rely on a reachable server.
    #[tokio::test]
    async fn non_uuid_agent_id_is_none_without_touching_db() {
        let mut pool_config = deadpool_postgres::Config::new();
        pool_config.user = Some(String::from("nobody"));
        pool_config.dbname = Some(String::from("nope"));
        pool_config.port = Some(1);
        pool_config.host = Some(String::from("127.0.0.1"));

        let runtime = Some(deadpool_postgres::Runtime::Tokio1);
        let pool = pool_config
            .create_pool(runtime, tokio_postgres::NoTls)
            .expect("build pool");

        let resolver = PgAgentConfigResolver::new(pool);
        let resolved = resolver.resolve("not-a-uuid").await;
        assert!(resolved.is_none());
    }
}