use std::collections::HashMap;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{RwLock, mpsc, oneshot};
use tokio::task::JoinHandle;
use uuid::Uuid;
use crate::agent::task::{Task, TaskContext, TaskOutput};
use crate::config::AgentConfig;
use crate::context::{ContextManager, JobContext, JobState};
use crate::db::Database;
use crate::error::{Error, JobError};
use crate::extensions::ExtensionManager;
use crate::hooks::HookRegistry;
use crate::llm::LlmProvider;
use crate::safety::SafetyLayer;
use crate::tools::{
ApprovalContext, ToolRegistry, autonomous_allowed_tool_names, autonomous_unavailable_error,
prepare_tool_params,
};
use crate::worker::job::{Worker, WorkerDeps};
/// Control messages sent from the `Scheduler` to a running `Worker`
/// over its per-job mpsc channel.
#[derive(Debug)]
pub enum WorkerMessage {
/// Begin executing the job (sent immediately after the worker is spawned).
Start,
/// Request a graceful shutdown; sent by `Scheduler::stop`.
Stop,
/// Liveness probe — no scheduler-side sender visible in this file; the
/// worker presumably answers it (TODO confirm in worker code).
Ping,
/// Forward a user-supplied chat message into the running job
/// (sent by `Scheduler::send_message`).
UserMessage(String),
}
/// Bookkeeping for one running job: the worker's task handle (used for
/// `is_finished()`/`abort()`) and the control-message sender.
#[derive(Debug)]
pub struct ScheduledJob {
pub handle: JoinHandle<()>,
pub tx: mpsc::Sender<WorkerMessage>,
}
/// Bookkeeping for one spawned subtask. Only `is_finished()`/`abort()` are
/// called on the handle; the task's real result is delivered to the caller
/// via a oneshot channel in `spawn_subtask`.
struct ScheduledSubtask {
handle: JoinHandle<Result<TaskOutput, Error>>,
}
/// Dependency bundle for `Scheduler::new`, grouping the optional/shared
/// collaborators so the constructor signature stays manageable.
pub struct SchedulerDeps {
pub tools: Arc<ToolRegistry>,
pub extension_manager: Option<Arc<ExtensionManager>>,
pub store: Option<Arc<dyn Database>>,
pub hooks: Arc<HookRegistry>,
}
/// Orchestrates job workers and subtasks: enforces the parallel-job limit,
/// owns per-job control channels, and tracks spawned task handles so they
/// can be reaped or aborted.
pub struct Scheduler {
config: AgentConfig,
context_manager: Arc<ContextManager>,
llm: Arc<dyn LlmProvider>,
safety: Arc<SafetyLayer>,
tools: Arc<ToolRegistry>,
extension_manager: Option<Arc<ExtensionManager>>,
store: Option<Arc<dyn Database>>,
hooks: Arc<HookRegistry>,
// Optional SSE broadcaster handed to each worker for event streaming.
sse_tx: Option<Arc<crate::channels::web::sse::SseManager>>,
// Optional LLM HTTP interceptor (crate::llm::recording), passed to workers.
http_interceptor: Option<Arc<dyn crate::llm::recording::HttpInterceptor>>,
// job_id -> running worker (handle + control channel).
jobs: Arc<RwLock<HashMap<Uuid, ScheduledJob>>>,
// subtask_id -> handle, for cleanup and abort-on-shutdown.
subtasks: Arc<RwLock<HashMap<Uuid, ScheduledSubtask>>>,
}
impl Scheduler {
pub fn new(
config: AgentConfig,
context_manager: Arc<ContextManager>,
llm: Arc<dyn LlmProvider>,
safety: Arc<SafetyLayer>,
deps: SchedulerDeps,
) -> Self {
Self {
config,
context_manager,
llm,
safety,
tools: deps.tools,
extension_manager: deps.extension_manager,
store: deps.store,
hooks: deps.hooks,
sse_tx: None,
http_interceptor: None,
jobs: Arc::new(RwLock::new(HashMap::new())),
subtasks: Arc::new(RwLock::new(HashMap::new())),
}
}
/// Injects the SSE manager that workers use to stream events to web clients.
pub fn set_sse_sender(&mut self, sse: Arc<crate::channels::web::sse::SseManager>) {
self.sse_tx = Some(sse);
}
/// Injects the HTTP interceptor (from `crate::llm::recording`) that is
/// forwarded to each worker's dependencies.
pub fn set_http_interceptor(
&mut self,
interceptor: Arc<dyn crate::llm::recording::HttpInterceptor>,
) {
self.http_interceptor = Some(interceptor);
}
/// Creates and schedules a new job for `user_id`, deriving an autonomous
/// approval context from the tools that user is allowed to run.
///
/// Returns the new job's id, or an error if creation/scheduling fails.
pub async fn dispatch_job(
    &self,
    user_id: &str,
    title: &str,
    description: &str,
    metadata: Option<serde_json::Value>,
) -> Result<Uuid, JobError> {
    let approval = self.autonomous_approval_context(user_id).await;
    self.dispatch_job_inner(user_id, title, description, metadata, Some(approval))
        .await
}
/// Same as `dispatch_job`, but the caller supplies the approval context
/// instead of deriving the autonomous one.
pub async fn dispatch_job_with_context(
    &self,
    user_id: &str,
    title: &str,
    description: &str,
    metadata: Option<serde_json::Value>,
    approval_context: ApprovalContext,
) -> Result<Uuid, JobError> {
    self.dispatch_job_inner(user_id, title, description, metadata, Some(approval_context))
        .await
}
/// Shared dispatch path: creates the job, applies the token budget and
/// metadata to its context, persists it when a store is configured, and
/// finally schedules it for execution.
async fn dispatch_job_inner(
    &self,
    user_id: &str,
    title: &str,
    description: &str,
    metadata: Option<serde_json::Value>,
    approval_context: Option<ApprovalContext>,
) -> Result<Uuid, JobError> {
    let job_id = self
        .context_manager
        .create_job_for_user(user_id, title, description)
        .await?;
    // The user may request a budget via metadata.max_tokens; it is capped
    // by the configured per-job limit, where 0 means "unlimited".
    let requested = metadata
        .as_ref()
        .and_then(|m| m.get("max_tokens"))
        .and_then(|v| v.as_u64());
    let cap = self.config.max_tokens_per_job;
    let max_tokens = match requested {
        Some(user_val) if cap > 0 => std::cmp::min(user_val, cap),
        Some(user_val) => user_val,
        None => cap,
    };
    // Apply metadata and budget in one atomic context update; skip the
    // update entirely when there is nothing to write.
    let ctx = match (metadata, max_tokens > 0) {
        (Some(meta), has_budget) => {
            self.context_manager
                .update_context_and_get(job_id, |ctx| {
                    ctx.metadata = meta;
                    if has_budget {
                        ctx.max_tokens = max_tokens;
                    }
                })
                .await?
        }
        (None, true) => {
            self.context_manager
                .update_context_and_get(job_id, |ctx| {
                    ctx.max_tokens = max_tokens;
                })
                .await?
        }
        (None, false) => self.context_manager.get_context(job_id).await?,
    };
    if let Some(ref store) = self.store {
        store.save_job(&ctx).await.map_err(|e| JobError::Failed {
            id: job_id,
            reason: format!("failed to persist job: {e}"),
        })?;
    }
    self.schedule_with_context(job_id, approval_context).await?;
    Ok(job_id)
}
/// Builds the approval context for autonomous runs: the set of tool names
/// this user may invoke, resolved from the registry and extension manager.
async fn autonomous_approval_context(&self, user_id: &str) -> ApprovalContext {
ApprovalContext::autonomous_with_tools(
autonomous_allowed_tool_names(&self.tools, self.extension_manager.as_ref(), user_id)
.await,
)
}
/// Schedules an existing job with no explicit approval context; gated
/// tools then take whatever default `ApprovalContext::is_blocked_or_default`
/// applies for `None`.
pub async fn schedule(&self, job_id: Uuid) -> Result<(), JobError> {
self.schedule_with_context(job_id, None).await
}
/// Core scheduling path: under one write-lock critical section it checks
/// for duplicates and the parallelism cap, transitions the job to
/// InProgress, spawns the worker task, and registers it; afterwards it
/// starts a reaper task that removes the map entry once the worker ends.
async fn schedule_with_context(
&self,
job_id: Uuid,
approval_context: Option<ApprovalContext>,
) -> Result<(), JobError> {
{
// Hold the write lock across check-and-insert so two concurrent calls
// cannot both pass the duplicate/capacity checks.
let mut jobs = self.jobs.write().await;
if jobs.contains_key(&job_id) {
// Already scheduled — idempotent no-op.
return Ok(());
}
if jobs.len() >= self.config.max_parallel_jobs {
return Err(JobError::MaxJobsExceeded {
max: self.config.max_parallel_jobs,
});
}
// Mark InProgress before the worker starts. The outer `?` is the
// context-manager error; the inner Result carries an invalid state
// transition, mapped to ContextError.
self.context_manager
.update_context(job_id, |ctx| {
ctx.transition_to(
JobState::InProgress,
Some("Scheduled for execution".to_string()),
)
})
.await?
.map_err(|s| JobError::ContextError {
id: job_id,
reason: s,
})?;
let (tx, rx) = mpsc::channel(16);
let deps = WorkerDeps {
context_manager: self.context_manager.clone(),
llm: self.llm.clone(),
safety: self.safety.clone(),
tools: self.tools.clone(),
store: self.store.clone(),
hooks: self.hooks.clone(),
timeout: self.config.job_timeout,
use_planning: self.config.use_planning,
sse_tx: self.sse_tx.clone(),
approval_context,
http_interceptor: self.http_interceptor.clone(),
};
let worker = Worker::new(job_id, deps);
let handle = tokio::spawn(async move {
if let Err(e) = worker.run(rx).await {
tracing::error!("Worker for job {} failed: {}", job_id, e);
}
});
// With 16 channel slots this send only fails if the worker task has
// already exited; log it but still register the job.
if tx.send(WorkerMessage::Start).await.is_err() {
tracing::error!(job_id = %job_id, "Worker died before receiving Start message");
}
jobs.insert(job_id, ScheduledJob { handle, tx });
}
// Reaper: poll once per second and drop the bookkeeping entry when the
// worker handle is finished (or the entry was already removed elsewhere).
let jobs = Arc::clone(&self.jobs);
tokio::spawn(async move {
loop {
let finished = {
let jobs_read = jobs.read().await;
match jobs_read.get(&job_id) {
Some(scheduled) => scheduled.handle.is_finished(),
None => true,
}
};
if finished {
jobs.write().await.remove(&job_id);
break;
}
tokio::time::sleep(Duration::from_secs(1)).await;
}
});
tracing::info!("Scheduled job {} for execution", job_id);
Ok(())
}
/// Spawns a non-Job task under `parent_id` and returns a oneshot receiver
/// for its result. `Task::Job` is rejected — those must go through
/// `schedule()`.
pub async fn spawn_subtask(
&self,
parent_id: Uuid,
task: Task,
) -> Result<oneshot::Receiver<Result<TaskOutput, Error>>, JobError> {
let task_id = Uuid::new_v4();
// The caller-visible result travels over this oneshot channel.
let (result_tx, result_rx) = oneshot::channel();
let handle = match task {
Task::Job { .. } => {
return Err(JobError::ContextError {
id: parent_id,
reason: "Use schedule() for Job tasks, not spawn_subtask()".to_string(),
});
}
Task::ToolExec {
parent_id: tool_parent_id,
tool_name,
params,
} => {
let tools = self.tools.clone();
let context_manager = self.context_manager.clone();
let safety = self.safety.clone();
tokio::spawn(async move {
// No approval context here: gated tools take the default
// (blocking) path in execute_tool_task.
let result = Self::execute_tool_task(
tools,
context_manager,
safety,
None,
tool_parent_id,
&tool_name,
params,
)
.await;
let _ = result_tx.send(result);
})
}
Task::Background { id: _, handler } => {
let ctx = TaskContext::new(task_id).with_parent(parent_id);
tokio::spawn(async move {
let result = handler.run(ctx).await;
let _ = result_tx.send(result);
})
}
};
// Bookkeeping entry. The raw handle is wrapped in another task so the
// stored JoinHandle has the uniform Result type ScheduledSubtask expects.
// NOTE(review): the wrapper's return value is never read anywhere in this
// file — only is_finished()/abort() are called on stored handles — so
// both arms returning Err is inert; the real result goes via result_tx.
self.subtasks.write().await.insert(
task_id,
ScheduledSubtask {
handle: tokio::spawn(async move {
match handle.await {
Ok(()) => Err(Error::Job(JobError::ContextError {
id: task_id,
reason: "Subtask completed but result not captured".to_string(),
})),
Err(e) => Err(Error::Job(JobError::ContextError {
id: task_id,
reason: format!("Subtask panicked: {}", e),
})),
}
}),
},
);
// Reaper: poll once per second and remove the bookkeeping entry once the
// wrapper handle finishes (or the entry is already gone).
let subtasks = Arc::clone(&self.subtasks);
tokio::spawn(async move {
loop {
let finished = {
let subtasks_read = subtasks.read().await;
match subtasks_read.get(&task_id) {
Some(scheduled) => scheduled.handle.is_finished(),
None => true,
}
};
if finished {
subtasks.write().await.remove(&task_id);
break;
}
tokio::time::sleep(Duration::from_secs(1)).await;
}
});
tracing::debug!(
parent_id = %parent_id,
task_id = %task_id,
"Spawned subtask"
);
Ok(result_rx)
}
pub async fn spawn_batch(
&self,
parent_id: Uuid,
tasks: Vec<Task>,
) -> Vec<Result<TaskOutput, Error>> {
if tasks.is_empty() {
return Vec::new();
}
let mut receivers = Vec::with_capacity(tasks.len());
for task in tasks {
match self.spawn_subtask(parent_id, task).await {
Ok(rx) => receivers.push(Some(rx)),
Err(e) => {
receivers.push(None);
tracing::warn!(
parent_id = %parent_id,
error = %e,
"Failed to spawn subtask in batch"
);
}
}
}
let mut results = Vec::with_capacity(receivers.len());
for rx in receivers {
let result = match rx {
Some(receiver) => match receiver.await {
Ok(task_result) => task_result,
Err(_) => Err(Error::Job(JobError::ContextError {
id: parent_id,
reason: "Subtask channel closed unexpectedly".to_string(),
})),
},
None => Err(Error::Job(JobError::ContextError {
id: parent_id,
reason: "Subtask failed to spawn".to_string(),
})),
};
results.push(result);
}
results
}
/// Runs one tool invocation on behalf of `job_id`: registry lookup →
/// cancellation check → approval check (on schema-normalized params) →
/// safety-wrapped execution → JSON-parse of the output.
///
/// Returns the parsed tool output with the elapsed execution time, or a
/// `ToolError`/`Error` describing which stage failed.
async fn execute_tool_task(
    tools: Arc<ToolRegistry>,
    context_manager: Arc<ContextManager>,
    safety: Arc<SafetyLayer>,
    approval_context: Option<ApprovalContext>,
    job_id: Uuid,
    tool_name: &str,
    params: serde_json::Value,
) -> Result<TaskOutput, Error> {
    let start = std::time::Instant::now();
    let tool = tools.get(tool_name).await.ok_or_else(|| {
        Error::Tool(crate::error::ToolError::NotFound {
            name: tool_name.to_string(),
        })
    })?;
    let job_ctx: JobContext = context_manager.get_context(job_id).await?;
    // Refuse to run tools for a job that has already been cancelled.
    if job_ctx.state == JobState::Cancelled {
        return Err(crate::error::ToolError::ExecutionFailed {
            name: tool_name.to_string(),
            reason: "Job is cancelled".to_string(),
        }
        .into());
    }
    // The approval decision is made on normalized params (e.g. "true" -> true)
    // so schema-coercible values do not trip the approval gate.
    // (Fixed mojibake here: the source had `¶ms` for `&params`.)
    let normalized_params = prepare_tool_params(tool.as_ref(), &params);
    let requirement = tool.requires_approval(&normalized_params);
    let blocked =
        ApprovalContext::is_blocked_or_default(&approval_context, tool_name, requirement);
    if blocked {
        return Err(autonomous_unavailable_error(tool_name, &job_ctx.user_id).into());
    }
    // NOTE(review): execution receives the *raw* params, not normalized_params
    // — presumably execute_tool_with_safety normalizes internally; confirm.
    let output_str = crate::tools::execute::execute_tool_with_safety(
        &tools, &safety, tool_name, params, &job_ctx,
    )
    .await?;
    // Tool outputs are expected to be JSON-serialized strings.
    let result_value: serde_json::Value = serde_json::from_str(&output_str).map_err(|e| {
        Error::Tool(crate::error::ToolError::ExecutionFailed {
            name: tool_name.to_string(),
            reason: format!("Failed to parse tool output as JSON: {}", e),
        })
    })?;
    Ok(TaskOutput::new(result_value, start.elapsed()))
}
pub async fn stop(&self, job_id: Uuid) -> Result<(), JobError> {
let mut jobs = self.jobs.write().await;
if let Some(scheduled) = jobs.remove(&job_id) {
let _ = scheduled.tx.send(WorkerMessage::Stop).await;
tokio::time::sleep(tokio::time::Duration::from_millis(100)).await;
if !scheduled.handle.is_finished() {
scheduled.handle.abort();
}
self.context_manager
.update_context(job_id, |ctx| {
if let Err(e) = ctx.transition_to(
JobState::Cancelled,
Some("Stopped by scheduler".to_string()),
) {
tracing::warn!(
job_id = %job_id,
error = %e,
"Failed to transition job to Cancelled state"
);
}
})
.await?;
if let Some(ref store) = self.store {
let store = store.clone();
tokio::spawn(async move {
if let Err(e) = store
.update_job_status(
job_id,
JobState::Cancelled,
Some("Stopped by scheduler"),
)
.await
{
tracing::warn!("Failed to persist cancellation for job {}: {}", job_id, e);
}
});
}
tracing::info!("Stopped job {}", job_id);
}
Ok(())
}
/// Forwards a user message to the worker handling `job_id`.
///
/// Errors with `NotFound` if the job is not scheduled, or `Failed` if the
/// worker's channel has closed.
pub async fn send_message(&self, job_id: Uuid, content: String) -> Result<(), JobError> {
    // Clone the sender inside the statement so the read guard is released
    // before we await the (potentially slow) channel send.
    let tx = self
        .jobs
        .read()
        .await
        .get(&job_id)
        .ok_or(JobError::NotFound { id: job_id })?
        .tx
        .clone();
    tx.send(WorkerMessage::UserMessage(content))
        .await
        .map_err(|_| JobError::Failed {
            id: job_id,
            reason: "Worker channel closed".to_string(),
        })
}
/// True if a worker entry currently exists for `job_id`.
pub async fn is_running(&self, job_id: Uuid) -> bool {
    self.jobs.read().await.contains_key(&job_id)
}
/// Number of currently scheduled jobs.
pub async fn running_count(&self) -> usize {
    self.jobs.read().await.len()
}
/// Number of currently tracked subtasks.
pub async fn subtask_count(&self) -> usize {
    self.subtasks.read().await.len()
}
/// IDs of all currently scheduled jobs.
pub async fn running_jobs(&self) -> Vec<Uuid> {
    // `Uuid` is `Copy`; `copied()` is the idiomatic equivalent of `cloned()`.
    self.jobs.read().await.keys().copied().collect()
}
/// Drops bookkeeping entries whose worker/subtask handles have finished,
/// logging each removal. Uses `HashMap::retain` to filter in place.
pub async fn cleanup_finished(&self) {
    self.jobs.write().await.retain(|id, scheduled| {
        let done = scheduled.handle.is_finished();
        if done {
            tracing::debug!("Cleaned up finished job {}", id);
        }
        !done
    });
    self.subtasks.write().await.retain(|id, scheduled| {
        let done = scheduled.handle.is_finished();
        if done {
            tracing::trace!("Cleaned up finished subtask {}", id);
        }
        !done
    });
}
/// Stops every running job via `stop()` (ignoring individual failures),
/// then aborts whatever subtasks remain.
pub async fn stop_all(&self) {
    // Snapshot the ids first so `stop()` can take its own locks.
    let ids: Vec<Uuid> = self.jobs.read().await.keys().copied().collect();
    for id in ids {
        let _ = self.stop(id).await;
    }
    for (_, scheduled) in self.subtasks.write().await.drain() {
        scheduled.handle.abort();
    }
}
/// Shared tool-registry handle.
pub fn tools(&self) -> &Arc<ToolRegistry> {
&self.tools
}
/// Shared context-manager handle.
pub fn context_manager(&self) -> &Arc<ContextManager> {
&self.context_manager
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::config::SafetyConfig;
use crate::llm::{
CompletionRequest, CompletionResponse, LlmError, LlmProvider, ToolCompletionRequest,
ToolCompletionResponse,
};
use crate::safety::SafetyLayer;
use crate::tools::{ApprovalRequirement, Tool, ToolError, ToolOutput};
use rust_decimal_macros::dec;
// Minimal LLM stub: every completion request fails, so tests never depend
// on real model behavior.
struct StubLlm;
#[async_trait::async_trait]
impl LlmProvider for StubLlm {
fn model_name(&self) -> &str {
"stub"
}
// Zero cost per token so budget accounting is inert in tests.
fn cost_per_token(&self) -> (rust_decimal::Decimal, rust_decimal::Decimal) {
(dec!(0), dec!(0))
}
async fn complete(&self, _req: CompletionRequest) -> Result<CompletionResponse, LlmError> {
Err(LlmError::RequestFailed {
provider: "stub".into(),
reason: "not implemented".into(),
})
}
async fn complete_with_tools(
&self,
_req: ToolCompletionRequest,
) -> Result<ToolCompletionResponse, LlmError> {
Err(LlmError::RequestFailed {
provider: "stub".into(),
reason: "not implemented".into(),
})
}
}
// Builds a Scheduler with a stub LLM, empty tool registry, no store/extension
// manager, and the given per-job token cap (0 = unlimited).
fn make_test_scheduler(max_tokens_per_job: u64) -> Scheduler {
let config = AgentConfig {
name: "test".to_string(),
max_parallel_jobs: 5,
job_timeout: std::time::Duration::from_secs(30),
stuck_threshold: std::time::Duration::from_secs(300),
repair_check_interval: std::time::Duration::from_secs(3600),
max_repair_attempts: 0,
use_planning: false,
session_idle_timeout: std::time::Duration::from_secs(3600),
allow_local_tools: true,
max_cost_per_day_cents: None,
max_actions_per_hour: None,
max_tool_iterations: 10,
auto_approve_tools: true,
default_timezone: "UTC".to_string(),
max_tokens_per_job,
};
let cm = Arc::new(ContextManager::new(5));
let llm: Arc<dyn LlmProvider> = Arc::new(StubLlm);
// Safety layer with injection checks off so tests exercise only scheduling.
let safety = Arc::new(SafetyLayer::new(&SafetyConfig {
max_output_length: 100_000,
injection_check_enabled: false,
}));
let tools = Arc::new(ToolRegistry::new());
let hooks = Arc::new(HookRegistry::default());
Scheduler::new(
config,
cm,
llm,
safety,
SchedulerDeps {
tools,
extension_manager: None,
store: None,
hooks,
},
)
}
// A user-requested budget larger than the configured cap must be clamped
// down to the cap.
#[tokio::test]
async fn test_dispatch_job_caps_user_max_tokens() {
let sched = make_test_scheduler(1000);
let meta = serde_json::json!({ "max_tokens": 5000 });
let job_id = sched
.dispatch_job("user1", "test", "desc", Some(meta))
.await
.unwrap();
let ctx = sched.context_manager.get_context(job_id).await.unwrap();
assert_eq!(ctx.max_tokens, 1000, "should cap at configured limit");
}
// A configured cap of 0 means "unlimited": the user-requested budget must
// pass through untouched. (Reformatted: two statements were jammed on one
// line.)
#[tokio::test]
async fn test_dispatch_job_unlimited_config_preserves_user_tokens() {
    let sched = make_test_scheduler(0);
    let meta = serde_json::json!({ "max_tokens": 5000 });
    let job_id = sched
        .dispatch_job("user1", "test", "desc", Some(meta))
        .await
        .unwrap();
    let ctx = sched.context_manager.get_context(job_id).await.unwrap();
    assert_eq!(
        ctx.max_tokens, 5000,
        "unlimited config should preserve user value"
    );
}
// With no user-requested budget, the configured cap becomes the budget.
#[tokio::test]
async fn test_dispatch_job_no_user_tokens_uses_config() {
let sched = make_test_scheduler(2000);
let job_id = sched
.dispatch_job("user1", "test", "desc", None)
.await
.unwrap();
let ctx = sched.context_manager.get_context(job_id).await.unwrap();
assert_eq!(
ctx.max_tokens, 2000,
"should use config default when no user value"
);
}
// Metadata and the token budget must land in the context together (the
// dispatch path applies both in a single update).
#[tokio::test]
async fn test_dispatch_job_atomic_metadata_and_tokens() {
let sched = make_test_scheduler(10_000);
let meta = serde_json::json!({
"max_tokens": 3000,
"custom_key": "custom_value"
});
let job_id = sched
.dispatch_job("user1", "test", "desc", Some(meta))
.await
.unwrap();
let ctx = sched.context_manager.get_context(job_id).await.unwrap();
assert_eq!(ctx.max_tokens, 3000, "should use user value within limit");
assert_eq!(
ctx.metadata.get("custom_key").and_then(|v| v.as_str()),
Some("custom_value"),
"metadata should be set atomically with token budget"
);
}
// No metadata and an unlimited (0) cap: dispatch must leave the context at
// its creation defaults. (Reformatted: several statements were jammed
// together on single lines.)
#[tokio::test]
async fn test_dispatch_job_no_metadata_no_user_tokens_edge_case() {
    let sched = make_test_scheduler(0);
    let job_id = sched
        .dispatch_job("user1", "test", "desc", None)
        .await
        .unwrap();
    let ctx = sched.context_manager.get_context(job_id).await.unwrap();
    assert!(ctx.metadata.is_null() || ctx.metadata == serde_json::json!({}));
    assert_eq!(ctx.max_tokens, 0, "unlimited config");
}
// Previously an empty body that passed trivially. At minimum, construction
// must succeed and the optional collaborators must start unset.
#[test]
fn test_scheduler_creation() {
    let sched = make_test_scheduler(100);
    assert!(sched.sse_tx.is_none(), "SSE sender starts unset");
    assert!(sched.http_interceptor.is_none(), "interceptor starts unset");
}
#[tokio::test]
async fn test_spawn_batch_empty() {
}
// Tool requiring soft approval (UnlessAutoApproved): passes when the
// approval context explicitly allows it or auto-approval applies.
struct SoftApprovalTool;
#[async_trait::async_trait]
impl Tool for SoftApprovalTool {
fn name(&self) -> &str {
"soft_gate"
}
fn description(&self) -> &str {
"needs soft approval"
}
fn parameters_schema(&self) -> serde_json::Value {
serde_json::json!({"type": "object", "properties": {}})
}
async fn execute(
&self,
_params: serde_json::Value,
_ctx: &JobContext,
) -> Result<ToolOutput, ToolError> {
Ok(ToolOutput::text(
"soft_ok",
std::time::Instant::now().elapsed(),
))
}
fn requires_approval(&self, _params: &serde_json::Value) -> ApprovalRequirement {
ApprovalRequirement::UnlessAutoApproved
}
fn requires_sanitization(&self) -> bool {
false
}
}
// Tool requiring hard approval (Always): blocked unless the approval
// context grants it explicitly.
struct HardApprovalTool;
#[async_trait::async_trait]
impl Tool for HardApprovalTool {
fn name(&self) -> &str {
"hard_gate"
}
fn description(&self) -> &str {
"needs hard approval"
}
fn parameters_schema(&self) -> serde_json::Value {
serde_json::json!({"type": "object", "properties": {}})
}
async fn execute(
&self,
_params: serde_json::Value,
_ctx: &JobContext,
) -> Result<ToolOutput, ToolError> {
Ok(ToolOutput::text(
"hard_ok",
std::time::Instant::now().elapsed(),
))
}
fn requires_approval(&self, _params: &serde_json::Value) -> ApprovalRequirement {
ApprovalRequirement::Always
}
fn requires_sanitization(&self) -> bool {
false
}
}
// Registers both gate tools, creates an InProgress job, and returns the
// pieces execute_tool_task needs.
async fn setup_tools_and_job() -> (
Arc<ToolRegistry>,
Arc<ContextManager>,
Arc<SafetyLayer>,
Uuid,
) {
let registry = ToolRegistry::new();
registry.register(Arc::new(SoftApprovalTool)).await;
registry.register(Arc::new(HardApprovalTool)).await;
let cm = Arc::new(ContextManager::new(5));
let job_id = cm.create_job("test", "approval test").await.unwrap();
cm.update_context(job_id, |ctx| ctx.transition_to(JobState::InProgress, None))
.await
.unwrap()
.unwrap();
let safety = Arc::new(SafetyLayer::new(&SafetyConfig {
max_output_length: 100_000,
injection_check_enabled: false,
}));
(Arc::new(registry), cm, safety, job_id)
}
// Without any approval context, both soft- and hard-gated tools are blocked.
#[tokio::test]
async fn test_execute_tool_task_blocks_without_context() {
let (tools, cm, safety, job_id) = setup_tools_and_job().await;
let result = Scheduler::execute_tool_task(
tools.clone(),
cm.clone(),
safety.clone(),
None,
job_id,
"soft_gate",
serde_json::json!({}),
)
.await;
assert!(
result.is_err(),
"soft_gate should be blocked without context"
);
let result = Scheduler::execute_tool_task(
tools,
cm,
safety,
None,
job_id,
"hard_gate",
serde_json::json!({}),
)
.await;
assert!(
result.is_err(),
"hard_gate should be blocked without context"
);
}
// An autonomous context listing a soft-gated tool unblocks it, but a bare
// autonomous context still blocks hard-gated tools.
#[tokio::test]
async fn test_execute_tool_task_autonomous_unblocks_soft() {
let (tools, cm, safety, job_id) = setup_tools_and_job().await;
let result = Scheduler::execute_tool_task(
tools.clone(),
cm.clone(),
safety.clone(),
Some(ApprovalContext::autonomous_with_tools([
"soft_gate".to_string()
])),
job_id,
"soft_gate",
serde_json::json!({}),
)
.await;
assert!(
result.is_ok(),
"soft_gate should pass with autonomous context"
);
let result = Scheduler::execute_tool_task(
tools,
cm,
safety,
Some(ApprovalContext::autonomous()),
job_id,
"hard_gate",
serde_json::json!({}),
)
.await;
assert!(
result.is_err(),
"hard_gate should still be blocked without explicit permission"
);
}
// Explicitly listing both tools in the autonomous context allows both the
// soft- and hard-gated tools to run.
#[tokio::test]
async fn test_execute_tool_task_autonomous_with_permissions() {
let (tools, cm, safety, job_id) = setup_tools_and_job().await;
let ctx = ApprovalContext::autonomous_with_tools([
"soft_gate".to_string(),
"hard_gate".to_string(),
]);
let result = Scheduler::execute_tool_task(
tools.clone(),
cm.clone(),
safety.clone(),
Some(ctx.clone()),
job_id,
"soft_gate",
serde_json::json!({}),
)
.await;
assert!(result.is_ok(), "soft_gate should pass");
let result = Scheduler::execute_tool_task(
tools,
cm,
safety,
Some(ctx),
job_id,
"hard_gate",
serde_json::json!({}),
)
.await;
assert!(
result.is_ok(),
"hard_gate should pass with explicit permission"
);
}
// Tool whose approval requirement depends on its (normalized) params:
// `safe == true` needs no approval, anything else always does.
struct NormalizedApprovalTool;
#[async_trait::async_trait]
impl Tool for NormalizedApprovalTool {
fn name(&self) -> &str {
"normalized_gate"
}
fn description(&self) -> &str {
"approval depends on normalized params"
}
fn parameters_schema(&self) -> serde_json::Value {
serde_json::json!({
"type": "object",
"properties": {
"safe": { "type": "boolean" }
}
})
}
async fn execute(
&self,
_params: serde_json::Value,
_ctx: &JobContext,
) -> Result<ToolOutput, ToolError> {
Ok(ToolOutput::text(
"normalized_ok",
std::time::Instant::now().elapsed(),
))
}
fn requires_approval(&self, params: &serde_json::Value) -> ApprovalRequirement {
// Only a real boolean `true` counts — a string "true" would require
// approval unless normalized first.
if params.get("safe").and_then(|v| v.as_bool()) == Some(true) {
ApprovalRequirement::Never
} else {
ApprovalRequirement::Always
}
}
fn requires_sanitization(&self) -> bool {
false
}
}
// Passing the stringified boolean "true" must still satisfy the
// params-dependent approval gate, proving normalization happens before the
// approval check. (Reformatted: several statements were jammed on single
// lines and the final assert carried a #[rustfmt::skip].)
#[tokio::test]
async fn test_execute_tool_task_normalizes_params_before_approval() {
    let registry = ToolRegistry::new();
    registry.register(Arc::new(NormalizedApprovalTool)).await;
    let cm = Arc::new(ContextManager::new(5));
    let job_id = cm.create_job("test", "normalized approval").await.unwrap();
    cm.update_context(job_id, |ctx| ctx.transition_to(JobState::InProgress, None))
        .await
        .unwrap()
        .unwrap();
    let safety = Arc::new(SafetyLayer::new(&SafetyConfig {
        max_output_length: 100_000,
        injection_check_enabled: false,
    }));
    let result = Scheduler::execute_tool_task(
        Arc::new(registry),
        cm,
        safety,
        None,
        job_id,
        "normalized_gate",
        serde_json::json!({"safe": "true"}),
    )
    .await;
    assert!(
        result.is_ok(),
        "stringified boolean should normalize before approval: {result:?}"
    );
}
}