use crate::agents::{DeliberationPhase, PendingToolCall, ToolCallStatus, UserToolHandlerTrait};
use crate::nats_utils::ensure_kv_bucket;
use anyhow::Result;
use async_trait::async_trait;
use futures_util::StreamExt;
use std::time::{Duration, Instant, SystemTime, UNIX_EPOCH};
use tracing::{info, warn};
use uuid::Uuid;
/// Escapes text so it can be embedded as XML element content.
///
/// Rewrites `&`, `<` and `>` to their entity references. Quotes are left
/// untouched because they carry no meaning inside element content.
///
/// The input is scanned once, so an `&` already present in the input is
/// escaped exactly once and the `&` of a freshly emitted entity is never
/// re-escaped.
fn escape_xml(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            other => out.push(other),
        }
    }
    out
}
/// Escapes text so it can be embedded inside a quoted XML attribute value.
///
/// Rewrites all five XML special characters (`&`, `<`, `>`, `"`, `'`) to
/// their predefined entity references, so the value cannot terminate the
/// surrounding quotes or open a new tag.
///
/// The input is scanned once, so no character is escaped twice.
fn escape_xml_attr(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for ch in s.chars() {
        match ch {
            '&' => out.push_str("&amp;"),
            '<' => out.push_str("&lt;"),
            '>' => out.push_str("&gt;"),
            '"' => out.push_str("&quot;"),
            '\'' => out.push_str("&apos;"),
            other => out.push(other),
        }
    }
    out
}
/// Computes how much of a phase budget is held back for finalization.
///
/// The reserve is the smaller of two candidates:
/// * a fixed amount of `reserve_secs` seconds, and
/// * `reserve_ratio` (capped at `1.0`) of `phase_budget`.
///
/// A non-finite, zero or negative `reserve_secs` or `reserve_ratio` counts
/// as zero, which makes the whole reserve zero.
///
/// Values too large for a `Duration` saturate to `Duration::MAX` instead of
/// panicking, so an oversized candidate simply loses the `min`.
fn compute_finalization_reserve(
    phase_budget: Duration,
    reserve_secs: f64,
    reserve_ratio: f64,
) -> Duration {
    // Inputs reaching this helper are already finite and non-negative, so
    // the only conversion failure left is overflow.
    fn saturating_from_secs(secs: f64) -> Duration {
        Duration::try_from_secs_f64(secs).unwrap_or(Duration::MAX)
    }

    let safe_secs = if reserve_secs.is_finite() && reserve_secs > 0.0 {
        reserve_secs
    } else {
        0.0
    };
    let safe_ratio = if reserve_ratio.is_finite() && reserve_ratio > 0.0 {
        reserve_ratio.min(1.0)
    } else {
        0.0
    };

    let ratio_based = saturating_from_secs(phase_budget.as_secs_f64() * safe_ratio);
    let fixed = saturating_from_secs(safe_secs);
    ratio_based.min(fixed)
}
/// Handles user-facing tool calls for one agent within one session.
///
/// A call is stored in a JetStream KV bucket, announced via a NATS event,
/// and then awaited until it is answered or the phase budget (minus a
/// finalization reserve) runs out.
#[derive(Clone, Debug)]
pub struct UserToolHandler {
/// Core NATS client used to publish the `tool_call_*` events.
nats_client: async_nats::Client,
/// JetStream context used to open or create the tool-call KV bucket.
js_context: async_nats::jetstream::Context,
/// Session identifier; part of the bucket name and event subject, and
/// stored as `job_id` on every call record.
session_id: String,
/// Agent identifier; stored on every call record and used to count this
/// agent's pending calls.
agent_id: String,
/// Prefix for the bucket name and event subject (`"nsed"` unless overridden).
subject_prefix: String,
/// Instant at which this handler was constructed; budget is measured from here.
phase_start: Instant,
/// Total budget available from `phase_start` onwards.
phase_budget: Duration,
/// Maximum number of simultaneously pending calls for this agent (3 by default).
max_pending_per_agent: usize,
/// Fixed finalization reserve in seconds (30.0 by default).
finalization_reserve_secs: f64,
/// Finalization reserve as a fraction of `phase_budget` (0.15 by default).
finalization_reserve_ratio: f64,
}
/// Outcome of waiting on a stored tool call.
enum WaitResult {
/// The call reached `Responded`; carries its result text (empty when the record had none).
Responded(String),
/// The wait deadline elapsed before a response was observed.
Timeout,
/// The KV watcher could not be created or its stream ended; carries a description.
Error(String),
}
impl UserToolHandler {
/// Creates a handler for `agent_id` within `session_id`.
///
/// `phase_budget_remaining_secs` is the time left in the current phase,
/// measured from the moment of construction. NaN and infinities are treated
/// as zero, negative values clamp to zero, and finite values too large for
/// a `Duration` saturate to `Duration::MAX` instead of panicking.
///
/// Defaults: subject prefix `"nsed"`, at most 3 pending calls per agent, and
/// a finalization reserve of 30 seconds or 15% of the budget.
pub fn new(
    nats_client: async_nats::Client,
    js_context: async_nats::jetstream::Context,
    session_id: String,
    agent_id: String,
    phase_budget_remaining_secs: f64,
) -> Self {
    let safe_budget = if phase_budget_remaining_secs.is_finite() {
        phase_budget_remaining_secs.max(0.0)
    } else {
        0.0
    };
    // `safe_budget` is finite and non-negative here, so overflow is the only
    // way the conversion can fail.
    let phase_budget = Duration::try_from_secs_f64(safe_budget).unwrap_or(Duration::MAX);
    Self {
        nats_client,
        js_context,
        session_id,
        agent_id,
        subject_prefix: "nsed".to_string(),
        phase_start: Instant::now(),
        phase_budget,
        max_pending_per_agent: 3,
        finalization_reserve_secs: 30.0,
        finalization_reserve_ratio: 0.15,
    }
}
pub fn with_subject_prefix(mut self, prefix: String) -> Self {
self.subject_prefix = prefix;
self
}
pub fn with_max_pending_per_agent(mut self, max: usize) -> Self {
self.max_pending_per_agent = max;
self
}
pub fn with_finalization_reserve(mut self, secs: f64, ratio: f64) -> Self {
self.finalization_reserve_secs = secs;
self.finalization_reserve_ratio = ratio;
self
}
/// Returns the budget still left, or zero once it has been used up.
fn remaining_budget(&self) -> Duration {
    let spent = self.phase_start.elapsed();
    self.phase_budget
        .checked_sub(spent)
        .unwrap_or(Duration::ZERO)
}
fn finalization_reserve(&self) -> Duration {
compute_finalization_reserve(
self.phase_budget,
self.finalization_reserve_secs,
self.finalization_reserve_ratio,
)
}
/// Returns the current wall-clock time as milliseconds since the Unix epoch.
///
/// A system clock set before the epoch yields `0`.
fn now_epoch_millis() -> u64 {
    let since_epoch = match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(elapsed) => elapsed,
        Err(_) => Duration::ZERO,
    };
    since_epoch.as_millis() as u64
}
/// Builds the KV bucket name from the sanitized subject prefix and session id.
fn bucket_name(&self) -> String {
    let prefix = crate::nats_utils::sanitize_subject_component(&self.subject_prefix);
    let session = crate::nats_utils::sanitize_subject_component(&self.session_id);
    format!("{}_toolcalls_{}", prefix, session)
}
/// Registers a user tool call, waits for the answer and renders the outcome as text.
///
/// Steps, in order:
/// 1. Parse `arguments_json`.
/// 2. Open (or create) the session's tool-call bucket.
/// 3. Refuse the call when this agent already has `max_pending_per_agent`
///    pending calls. A failure of the count itself is only logged.
/// 4. Store a `PendingToolCall` under `call_<uuid>` and publish `tool_call_pending`.
/// 5. Wait for a response for at most the remaining budget minus the
///    finalization reserve.
///
/// The function never fails. Every outcome is a string:
/// * a `<user_tool_result …>` element with the escaped response,
/// * a bracketed notice when the budget is exhausted or the wait timed out,
/// * an `Error…` message for invalid input or infrastructure failures.
///
/// Whenever the handler stops waiting without a response (budget exhausted,
/// timeout or wait error) the stored call is marked `Expired`, so it no
/// longer counts against the per-agent pending limit.
pub async fn handle_call(
    &self,
    tool_name: &str,
    arguments_json: &str,
    round: u32,
    phase: DeliberationPhase,
) -> String {
    let arguments: serde_json::Value = match serde_json::from_str(arguments_json) {
        Ok(v) => v,
        Err(e) => {
            return format!("Error: Invalid JSON arguments: {}", e);
        }
    };

    let bucket_name = self.bucket_name();
    let toolcall_store = match self.get_or_create_bucket(&bucket_name).await {
        Ok(store) => store,
        Err(e) => {
            warn!("Failed to access toolcall bucket: {}", e);
            return format!("Error: Failed to register tool call: {}", e);
        }
    };

    // Best effort: if the scan itself fails the call is still allowed through.
    match self.count_pending_for_agent(&toolcall_store).await {
        Ok(count) if count >= self.max_pending_per_agent => {
            return format!(
                "Error: Maximum pending tool calls ({}) reached for this agent. \
                 Wait for existing calls to be answered before making new ones.",
                self.max_pending_per_agent
            );
        }
        Err(e) => {
            warn!("Failed to check pending count: {}", e);
        }
        _ => {}
    }

    let call_id = Uuid::new_v4().to_string();
    let pending_call = PendingToolCall {
        call_id: call_id.clone(),
        job_id: self.session_id.clone(),
        agent_id: self.agent_id.clone(),
        tool_name: tool_name.to_string(),
        arguments,
        round,
        phase,
        status: ToolCallStatus::Pending,
        created_at: Self::now_epoch_millis(),
        responded_at: None,
        result: None,
    };
    let key = format!("call_{}", call_id);
    let data = match serde_json::to_vec(&pending_call) {
        Ok(d) => d,
        Err(e) => return format!("Error: Failed to serialize tool call: {}", e),
    };
    if let Err(e) = toolcall_store.put(&key, data.into()).await {
        return format!("Error: Failed to store pending tool call: {}", e);
    }

    self.publish_sse_event(
        "tool_call_pending",
        &serde_json::json!({
            "call_id": &call_id,
            "agent_id": &self.agent_id,
            "tool_name": tool_name,
            "arguments": &pending_call.arguments,
            "round": round,
            "phase": &phase,
        }),
    )
    .await;
    info!(
        agent = %self.agent_id,
        tool = %tool_name,
        call_id = %call_id,
        "User tool call published. Waiting for response..."
    );

    let remaining = self.remaining_budget();
    let reserve = self.finalization_reserve();
    if remaining <= reserve {
        // Nothing left beyond the reserve: do not wait at all.
        self.expire_call(&toolcall_store, &key, &call_id, tool_name)
            .await;
        return "[No response — phase budget exhausted. Proceed immediately.]".to_string();
    }
    let deadline = remaining.saturating_sub(reserve);

    let result = self
        .wait_for_response(&toolcall_store, &key, deadline)
        .await;
    match result {
        WaitResult::Responded(response_text) => {
            self.publish_sse_event(
                "tool_call_responded",
                &serde_json::json!({
                    "call_id": &call_id,
                    "agent_id": &self.agent_id,
                    "tool_name": tool_name,
                }),
            )
            .await;
            info!(
                agent = %self.agent_id,
                call_id = %call_id,
                "User tool call responded."
            );
            // Everything interpolated into the wrapper is escaped so the
            // response cannot close the element or inject attributes.
            let escaped = escape_xml(&response_text);
            let safe_tool = escape_xml_attr(tool_name);
            let safe_call = escape_xml_attr(&call_id);
            format!(
                "<user_tool_result tool=\"{}\" call_id=\"{}\">{}</user_tool_result>",
                safe_tool, safe_call, escaped
            )
        }
        WaitResult::Timeout => {
            self.expire_call(&toolcall_store, &key, &call_id, tool_name)
                .await;
            let remaining_after = self.remaining_budget();
            format!(
                "[No response yet — you have {:.0}s remaining to finalize your proposal \
                 with your best judgment. The user may respond later and the result will \
                 be available next round.]",
                remaining_after.as_secs_f64()
            )
        }
        WaitResult::Error(e) => {
            warn!(call_id = %call_id, error = %e, "Error waiting for tool call response");
            // Nobody is waiting for this call any more. Expire it so it does
            // not stay `Pending` and eat into the per-agent limit. A call
            // that was answered in the meantime is left untouched.
            self.expire_call(&toolcall_store, &key, &call_id, tool_name)
                .await;
            format!("Error waiting for user response: {}", e)
        }
    }
}
/// Opens the named KV bucket through `ensure_kv_bucket`, passing the configuration to use.
///
/// The configuration is file-backed, keeps 5 revisions per key and uses a
/// maximum age of three days.
async fn get_or_create_bucket(
    &self,
    bucket_name: &str,
) -> Result<async_nats::jetstream::kv::Store> {
    const THREE_DAYS: Duration = Duration::from_secs(3 * 86400);

    let config = async_nats::jetstream::kv::Config {
        bucket: bucket_name.to_owned(),
        history: 5,
        max_age: THREE_DAYS,
        storage: async_nats::jetstream::stream::StorageType::File,
        ..Default::default()
    };
    ensure_kv_bucket(&self.js_context, config).await
}
/// Counts this agent's calls that are still `Pending` by scanning every `call_*` key.
///
/// Keys that cannot be listed, fetched or decoded are skipped. A warning is
/// logged when more than 50 call keys were scanned or the scan took longer
/// than 100 ms.
///
/// # Errors
/// Fails only when the key listing cannot be started.
async fn count_pending_for_agent(
    &self,
    store: &async_nats::jetstream::kv::Store,
) -> Result<usize> {
    let started = Instant::now();
    let mut pending_count: usize = 0;
    let mut total_keys: u32 = 0;

    let mut key_stream = store.keys().await?;
    while let Some(next_key) = key_stream.next().await {
        let key = match next_key {
            Ok(k) if k.starts_with("call_") => k,
            _ => continue,
        };
        total_keys += 1;

        let decoded = match store.get(&key).await {
            Ok(Some(bytes)) => serde_json::from_slice::<PendingToolCall>(&bytes),
            _ => continue,
        };
        let ours_and_pending = matches!(
            &decoded,
            Ok(call) if call.agent_id == self.agent_id && call.status == ToolCallStatus::Pending
        );
        if ours_and_pending {
            pending_count += 1;
        }
    }

    let scan_ms = started.elapsed().as_millis();
    let scan_is_large = total_keys > 50;
    let scan_is_slow = scan_ms > 100;
    if scan_is_large || scan_is_slow {
        warn!(
            total_keys = total_keys,
            pending = pending_count,
            agent = %self.agent_id,
            scan_ms = scan_ms,
            "Tool call bucket scan is growing — consider secondary counter if this persists"
        );
    }
    Ok(pending_count)
}
async fn wait_for_response(
&self,
store: &async_nats::jetstream::kv::Store,
key: &str,
timeout_duration: Duration,
) -> WaitResult {
let mut watcher = match store.watch_with_history(key).await {
Ok(w) => w,
Err(e) => return WaitResult::Error(format!("Failed to create KV watcher: {}", e)),
};
tokio::select! {
result = async {
while let Some(entry) = watcher.next().await {
let Ok(entry) = entry else { continue };
let Ok(call) = serde_json::from_slice::<PendingToolCall>(&entry.value) else {
continue;
};
if call.status == ToolCallStatus::Responded {
return WaitResult::Responded(call.result.unwrap_or_default());
}
}
WaitResult::Error("KV watcher stream ended unexpectedly".to_string())
} => result,
_ = tokio::time::sleep(timeout_duration) => {
WaitResult::Timeout
}
}
}
/// Marks the stored call as `Expired` and publishes `tool_call_expired`.
///
/// Does nothing when the entry is missing, cannot be decoded, or is no
/// longer `Pending`. The write is conditional on the revision that was
/// read, so a concurrent modification is logged rather than overwritten.
/// The event is published only after a successful write.
async fn expire_call(
    &self,
    store: &async_nats::jetstream::kv::Store,
    key: &str,
    call_id: &str,
    tool_name: &str,
) {
    let Ok(Some(entry)) = store.entry(key).await else {
        return;
    };
    let Ok(mut call) = serde_json::from_slice::<PendingToolCall>(&entry.value) else {
        return;
    };
    if call.status != ToolCallStatus::Pending {
        return;
    }

    call.status = ToolCallStatus::Expired;
    let data = match serde_json::to_vec(&call) {
        Ok(bytes) => bytes,
        Err(e) => {
            warn!(
                call_id = %call_id,
                error = %e,
                "Failed to serialize expired tool call"
            );
            return;
        }
    };

    if let Err(e) = store.update(key, data.into(), entry.revision).await {
        warn!(
            call_id = %call_id,
            error = %e,
            "CAS update failed for expire_call (concurrent modification?)"
        );
        return;
    }

    let event = serde_json::json!({
        "call_id": call_id,
        "agent_id": &self.agent_id,
        "tool_name": tool_name,
        "timeout_secs": self.phase_budget.as_secs_f64(),
    });
    self.publish_sse_event("tool_call_expired", &event).await;
}
/// Publishes `payload` as JSON on `<prefix>.<session>.result.event.<suffix>`.
///
/// Prefix and session id are sanitized before being placed in the subject.
/// Serialization and publish failures are logged and otherwise ignored.
async fn publish_sse_event<T: serde::Serialize>(&self, suffix: &str, payload: &T) {
    let body = match serde_json::to_vec(payload) {
        Err(e) => {
            warn!("Failed to serialize SSE event: {}", e);
            return;
        }
        Ok(bytes) => bytes,
    };

    let safe_session = crate::nats_utils::sanitize_subject_component(&self.session_id);
    let safe_prefix = crate::nats_utils::sanitize_subject_component(&self.subject_prefix);
    let subject = format!("{}.{}.result.event.{}", safe_prefix, safe_session, suffix);

    let outcome = self
        .nats_client
        .publish(subject.clone(), body.into())
        .await;
    if let Err(e) = outcome {
        warn!("Failed to publish SSE event to {}: {}", subject, e);
    }
}
}
#[async_trait]
impl UserToolHandlerTrait for UserToolHandler {
async fn handle_call(
&self,
tool_name: &str,
arguments_json: &str,
round: u32,
phase: DeliberationPhase,
) -> String {
self.handle_call(tool_name, arguments_json, round, phase)
.await
}
}
/// Stateless factory that builds NATS-backed [`UserToolHandler`] instances.
#[derive(Debug)]
pub struct NatsUserToolHandlerFactory;
impl crate::workers::UserToolHandlerFactory for NatsUserToolHandlerFactory {
    /// Builds a `UserToolHandler` with the given subject prefix and returns it as a trait object.
    fn create(
        &self,
        nats: async_nats::Client,
        js: async_nats::jetstream::Context,
        session_id: String,
        agent_id: String,
        budget_remaining_secs: f64,
        subject_prefix: String,
    ) -> std::sync::Arc<dyn UserToolHandlerTrait> {
        let handler = UserToolHandler::new(nats, js, session_id, agent_id, budget_remaining_secs);
        let handler = handler.with_subject_prefix(subject_prefix);
        std::sync::Arc::new(handler)
    }
}
/// Unit tests for the serde models, the escaping helpers and the reserve computation.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agents::{PendingToolCall, ToolCallStatus, UserToolDefinition};

    // ---- serde models -----------------------------------------------------

    #[test]
    fn test_user_tool_definition_serde_roundtrip() {
        let def = UserToolDefinition {
            name: "dm_user".to_string(),
            description: "Send a DM to the user".to_string(),
            parameters: Some(serde_json::json!({
                "type": "object",
                "properties": {
                    "message": { "type": "string" }
                },
                "required": ["message"]
            })),
            strict: Some(true),
        };
        let json = serde_json::to_string(&def).unwrap();
        let parsed: UserToolDefinition = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.name, "dm_user");
        assert_eq!(parsed.strict, Some(true));
        assert!(parsed.parameters.is_some());
    }

    #[test]
    fn test_user_tool_definition_minimal() {
        let json = r#"{"name": "ping", "description": "Ping the user"}"#;
        let parsed: UserToolDefinition = serde_json::from_str(json).unwrap();
        assert_eq!(parsed.name, "ping");
        assert!(parsed.parameters.is_none());
        assert!(parsed.strict.is_none());
    }

    #[test]
    fn test_pending_tool_call_serde_roundtrip() {
        let call = PendingToolCall {
            call_id: "abc-123".to_string(),
            job_id: "job-1".to_string(),
            agent_id: "agent-1".to_string(),
            tool_name: "user_dm_user".to_string(),
            arguments: serde_json::json!({"message": "hello"}),
            round: 1,
            phase: DeliberationPhase::Proposing,
            status: ToolCallStatus::Pending,
            created_at: 1234567890,
            responded_at: None,
            result: None,
        };
        let json = serde_json::to_string(&call).unwrap();
        let parsed: PendingToolCall = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.call_id, "abc-123");
        assert_eq!(parsed.status, ToolCallStatus::Pending);
        assert!(parsed.responded_at.is_none());
        assert!(parsed.result.is_none());
    }

    #[test]
    fn test_pending_tool_call_responded() {
        let call = PendingToolCall {
            call_id: "abc-123".to_string(),
            job_id: "job-1".to_string(),
            agent_id: "agent-1".to_string(),
            tool_name: "user_dm_user".to_string(),
            arguments: serde_json::json!({}),
            round: 2,
            phase: DeliberationPhase::Evaluating,
            status: ToolCallStatus::Responded,
            created_at: 1234567890,
            responded_at: Some(1234567900),
            result: Some("The answer is 42".to_string()),
        };
        let json = serde_json::to_string(&call).unwrap();
        let parsed: PendingToolCall = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed.status, ToolCallStatus::Responded);
        assert_eq!(parsed.result, Some("The answer is 42".to_string()));
    }

    #[test]
    fn test_tool_call_status_all_variants() {
        for (status, expected) in [
            (ToolCallStatus::Pending, "\"Pending\""),
            (ToolCallStatus::Responded, "\"Responded\""),
            (ToolCallStatus::Expired, "\"Expired\""),
        ] {
            let json = serde_json::to_string(&status).unwrap();
            assert_eq!(json, expected);
            let parsed: ToolCallStatus = serde_json::from_str(&json).unwrap();
            assert_eq!(parsed, status);
        }
    }

    // ---- finalization reserve: basic cases --------------------------------

    #[test]
    fn test_finalization_reserve_computation() {
        assert_eq!(
            compute_finalization_reserve(Duration::from_secs(200), 30.0, 0.15),
            Duration::from_secs(30)
        );
    }

    #[test]
    fn test_finalization_reserve_small_budget() {
        assert_eq!(
            compute_finalization_reserve(Duration::from_secs(60), 30.0, 0.15),
            Duration::from_secs(9)
        );
    }

    // ---- escape_xml / escape_xml_attr -------------------------------------
    // Expected values spell out the entity references literally.

    #[test]
    fn test_escape_xml_basic_entities() {
        assert_eq!(escape_xml("hello"), "hello");
        assert_eq!(escape_xml("<script>"), "&lt;script&gt;");
        assert_eq!(escape_xml("a & b"), "a &amp; b");
        assert_eq!(escape_xml(""), "");
    }

    #[test]
    fn test_escape_xml_preserves_wrapper_integrity() {
        let malicious = "</user_tool_result><injected>evil</injected>";
        let escaped = escape_xml(malicious);
        let wrapped = format!(
            "<user_tool_result tool=\"test\" call_id=\"c1\">{}</user_tool_result>",
            escaped
        );
        assert!(wrapped.starts_with("<user_tool_result tool=\"test\" call_id=\"c1\">"));
        assert!(wrapped.ends_with("</user_tool_result>"));
        // The raw tag must be gone and only its escaped form may remain.
        assert!(!wrapped.contains("<injected>"));
        assert!(wrapped.contains("&lt;injected&gt;"));
    }

    #[test]
    fn test_escape_xml_combined() {
        let input = "x < 5 & y > 3";
        let expected = "x &lt; 5 &amp; y &gt; 3";
        assert_eq!(escape_xml(input), expected);
    }

    #[test]
    fn test_escape_xml_ampersand_first() {
        // The `&` of an entity produced for `<` must not be escaped again.
        assert_eq!(escape_xml("<"), "&lt;");
        // An `&` already present in the input is escaped exactly once.
        let input = "&lt;";
        let expected = "&amp;lt;";
        assert_eq!(escape_xml(input), expected);
    }

    #[test]
    fn test_escape_xml_attr_quotes() {
        assert_eq!(
            escape_xml_attr(r#"he said "hi""#),
            "he said &quot;hi&quot;"
        );
        assert_eq!(escape_xml_attr("it's"), "it&apos;s");
    }

    #[test]
    fn test_escape_xml_attr_prevents_attribute_injection() {
        let malicious_tool = r#"evil" onclick="alert(1)"#;
        let safe = escape_xml_attr(malicious_tool);
        assert!(!safe.contains('"'));
        assert!(safe.contains("&quot;"));
    }

    // ---- finalization reserve: invalid inputs -----------------------------

    #[test]
    fn test_finalization_reserve_nan_inputs() {
        let result = compute_finalization_reserve(Duration::from_secs(100), f64::NAN, 0.15);
        assert_eq!(result, Duration::ZERO);
        let result = compute_finalization_reserve(Duration::from_secs(100), 30.0, f64::NAN);
        assert_eq!(result, Duration::ZERO);
    }

    #[test]
    fn test_finalization_reserve_negative_inputs() {
        let result = compute_finalization_reserve(Duration::from_secs(100), -10.0, 0.15);
        assert_eq!(result, Duration::ZERO);
        let result = compute_finalization_reserve(Duration::from_secs(100), 30.0, -0.5);
        assert_eq!(result, Duration::ZERO);
    }

    #[test]
    fn test_finalization_reserve_infinite_inputs() {
        let result = compute_finalization_reserve(Duration::from_secs(100), f64::INFINITY, 0.15);
        assert_eq!(result, Duration::ZERO);
        let result =
            compute_finalization_reserve(Duration::from_secs(100), 30.0, f64::NEG_INFINITY);
        assert_eq!(result, Duration::ZERO);
    }

    #[test]
    fn test_finalization_reserve_ratio_capped() {
        let result = compute_finalization_reserve(Duration::from_secs(100), 200.0, 2.0);
        assert_eq!(result, Duration::from_secs(100));
    }

    // ---- phase budget sanitization ----------------------------------------
    // These mirror the clamping expression used by `UserToolHandler::new`.

    #[test]
    fn test_phase_budget_sanitization_nan() {
        let val: f64 = f64::NAN;
        let safe = if val.is_finite() { val.max(0.0) } else { 0.0 };
        assert_eq!(safe, 0.0);
        let _ = Duration::from_secs_f64(safe);
    }

    #[test]
    fn test_phase_budget_sanitization_infinity() {
        let val: f64 = f64::INFINITY;
        let safe = if val.is_finite() { val.max(0.0) } else { 0.0 };
        assert_eq!(safe, 0.0);
        let _ = Duration::from_secs_f64(safe);
    }

    #[test]
    fn test_phase_budget_sanitization_neg_infinity() {
        let val: f64 = f64::NEG_INFINITY;
        let safe = if val.is_finite() { val.max(0.0) } else { 0.0 };
        assert_eq!(safe, 0.0);
        let _ = Duration::from_secs_f64(safe);
    }

    #[test]
    fn test_phase_budget_sanitization_negative() {
        let val: f64 = -100.0;
        let safe = if val.is_finite() { val.max(0.0) } else { 0.0 };
        assert_eq!(safe, 0.0);
        let _ = Duration::from_secs_f64(safe);
    }

    #[test]
    fn test_phase_budget_sanitization_valid() {
        let val: f64 = 42.5;
        let safe = if val.is_finite() { val.max(0.0) } else { 0.0 };
        assert_eq!(safe, 42.5);
        assert_eq!(Duration::from_secs_f64(safe), Duration::from_millis(42500));
    }

    // ---- subject sanitization ---------------------------------------------

    #[test]
    fn test_publish_sse_event_sanitizes_session_id() {
        let raw = "my.session>with*wildcards";
        let safe = crate::nats_utils::sanitize_subject_component(raw);
        assert!(!safe.contains('.'), "Dots should be removed");
        assert!(!safe.contains('>'), "Greater-than should be removed");
        assert!(!safe.contains('*'), "Wildcards should be removed");
        assert!(!safe.is_empty(), "Sanitized result should not be empty");
    }

    // ---- escape_xml: edge cases -------------------------------------------

    #[test]
    fn test_escape_xml_empty_string() {
        assert_eq!(escape_xml(""), "");
    }

    #[test]
    fn test_escape_xml_no_special_chars() {
        let input = "Hello, world! 123 test";
        assert_eq!(escape_xml(input), input);
    }

    #[test]
    fn test_escape_xml_all_special_chars() {
        let input = "&<>";
        assert_eq!(escape_xml(input), "&amp;&lt;&gt;");
    }

    #[test]
    fn test_escape_xml_preserves_quotes() {
        // Quotes are only significant inside attribute values.
        let input = "He said \"hello\" and it's fine";
        assert_eq!(escape_xml(input), "He said \"hello\" and it's fine");
    }

    #[test]
    fn test_escape_xml_already_escaped() {
        // Escaping is not idempotent by design: existing entities are escaped again.
        let input = "&amp; &lt; &gt;";
        let result = escape_xml(input);
        assert_eq!(result, "&amp;amp; &amp;lt; &amp;gt;");
    }

    #[test]
    fn test_escape_xml_multiline() {
        let input = "line1 <b>bold</b>\nline2 & more\nline3 > end";
        let expected = "line1 &lt;b&gt;bold&lt;/b&gt;\nline2 &amp; more\nline3 &gt; end";
        assert_eq!(escape_xml(input), expected);
    }

    // ---- escape_xml_attr: edge cases --------------------------------------

    #[test]
    fn test_escape_xml_attr_empty() {
        assert_eq!(escape_xml_attr(""), "");
    }

    #[test]
    fn test_escape_xml_attr_all_five_special_chars() {
        let input = "&<>\"'";
        assert_eq!(escape_xml_attr(input), "&amp;&lt;&gt;&quot;&apos;");
    }

    #[test]
    fn test_escape_xml_attr_no_special_chars() {
        let input = "simple text 123";
        assert_eq!(escape_xml_attr(input), input);
    }

    #[test]
    fn test_escape_xml_attr_mixed_quotes_and_entities() {
        let input = "tool_name=\"bad\" & 'evil' <injected>";
        let expected =
            "tool_name=&quot;bad&quot; &amp; &apos;evil&apos; &lt;injected&gt;";
        assert_eq!(escape_xml_attr(input), expected);
    }

    // ---- finalization reserve: boundaries ---------------------------------

    #[test]
    fn test_finalization_reserve_zero_budget() {
        let result = compute_finalization_reserve(Duration::ZERO, 30.0, 0.15);
        assert_eq!(result, Duration::ZERO);
    }

    #[test]
    fn test_finalization_reserve_both_params_zero() {
        let result = compute_finalization_reserve(Duration::from_secs(100), 0.0, 0.0);
        assert_eq!(result, Duration::ZERO);
    }

    #[test]
    fn test_finalization_reserve_ratio_exactly_one() {
        let result = compute_finalization_reserve(Duration::from_secs(100), 200.0, 1.0);
        assert_eq!(result, Duration::from_secs(100));
    }

    #[test]
    fn test_finalization_reserve_fixed_smaller_than_ratio() {
        let result = compute_finalization_reserve(Duration::from_secs(100), 10.0, 0.5);
        assert_eq!(result, Duration::from_secs(10));
    }

    #[test]
    fn test_finalization_reserve_very_large_budget() {
        let result = compute_finalization_reserve(Duration::from_secs(10000), 60.0, 0.01);
        assert_eq!(result, Duration::from_secs(60));
    }

    #[test]
    fn test_finalization_reserve_very_small_budget() {
        let result = compute_finalization_reserve(Duration::from_millis(100), 30.0, 0.15);
        assert_eq!(result, Duration::from_millis(15));
    }

    #[test]
    fn test_finalization_reserve_both_nan() {
        let result = compute_finalization_reserve(Duration::from_secs(100), f64::NAN, f64::NAN);
        assert_eq!(result, Duration::ZERO);
    }

    #[test]
    fn test_finalization_reserve_both_infinite() {
        let result =
            compute_finalization_reserve(Duration::from_secs(100), f64::INFINITY, f64::INFINITY);
        assert_eq!(result, Duration::ZERO);
    }

    #[test]
    fn test_finalization_reserve_neg_infinity_secs() {
        let result =
            compute_finalization_reserve(Duration::from_secs(100), f64::NEG_INFINITY, 0.15);
        assert_eq!(result, Duration::ZERO);
    }

    #[test]
    fn test_finalization_reserve_fractional_duration() {
        let result = compute_finalization_reserve(Duration::from_millis(500), 1.0, 0.1);
        assert_eq!(result, Duration::from_millis(50));
    }
}