use std::collections::HashSet;
use std::sync::Arc;
use serde_json::{json, Value};
use crate::tool::{ToolCall, ToolRegistry};
use crate::types::AgentMessage;
/// Borrowed inputs passed by value to
/// [`ProtocolPolicy::plain_text_recovery_prompt`].
pub struct PlainTextRecoveryContext<'a> {
    /// Messages made visible to the policy.
    /// NOTE(review): presumably the conversation so far — confirm with the caller.
    pub messages: &'a [AgentMessage],
    /// Iteration counter supplied by the caller.
    /// NOTE(review): presumably the agent-loop turn index — confirm.
    pub iteration: usize,
    /// Names of the tools reported as available.
    pub available_tool_names: &'a [&'a str],
    /// Optional name of a terminal tool the policy may fall back to.
    /// NOTE(review): exact selection rules live with the caller — confirm.
    pub terminal_fallback_tool: Option<&'a str>,
}
/// Borrowed inputs passed by value to [`ProtocolPolicy::hidden_tool_error`].
pub struct HiddenToolContext<'a> {
    /// Name of the tool that was requested.
    pub requested_tool: &'a str,
    /// Set of tool names that are currently allowed.
    pub allowlist: &'a HashSet<String>,
    /// Messages made visible to the policy.
    /// NOTE(review): presumably the conversation so far — confirm with the caller.
    pub messages: &'a [AgentMessage],
}
/// Error payload a policy may return from [`ProtocolPolicy::hidden_tool_error`].
#[derive(Clone, Debug)]
pub struct HiddenToolError {
    /// Human-readable error text.
    pub message: String,
    /// Structured JSON payload accompanying the message.
    pub details: Value,
}
/// Overridable hooks for protocol-level behavior.
///
/// Every method has a default, so an implementor only overrides the hooks it
/// cares about. The defaults carry no tool vocabulary: they return an empty
/// set, `None`, or `0`.
pub trait ProtocolPolicy: Send + Sync + 'static {
    /// Static identifier for this policy; `"default_protocol"` unless overridden.
    fn name(&self) -> &'static str {
        "default_protocol"
    }

    /// Tool names this policy treats as terminal. The default is the empty set.
    fn terminal_tool_names(&self) -> HashSet<String> {
        HashSet::default()
    }

    /// Optionally supplies a recovery prompt for the given context.
    ///
    /// The default returns `None`.
    /// NOTE(review): what the caller does on `None` is not visible here — confirm.
    fn plain_text_recovery_prompt(&self, _ctx: PlainTextRecoveryContext<'_>) -> Option<String> {
        None
    }

    /// Lets the policy rewrite `_calls` in place and returns a count.
    ///
    /// The default touches nothing and returns `0`.
    /// NOTE(review): the count presumably reports how many calls were
    /// rewritten — confirm against the caller.
    fn normalize_tool_calls(&self, _calls: &mut [ToolCall], _registry: &ToolRegistry) -> usize {
        0
    }

    /// Optionally supplies a policy-specific error for a requested tool.
    ///
    /// The default returns `None`.
    /// NOTE(review): the fallback used on `None` is not visible here — confirm.
    fn hidden_tool_error(&self, _ctx: HiddenToolContext<'_>) -> Option<HiddenToolError> {
        None
    }
}
/// Zero-sized policy that accepts every [`ProtocolPolicy`] default unchanged.
#[derive(Clone, Copy, Debug, Default)]
pub struct DefaultProtocolPolicy;

// Intentionally empty: all hooks fall through to the trait's default bodies.
impl ProtocolPolicy for DefaultProtocolPolicy {}
/// Generic recovery prompt text telling the model that its previous reply was
/// plain text and that it must now call exactly one tool.
///
/// Layout: a bracketed header line, one explanatory line, a blank line, then
/// the instruction line. There is no trailing newline.
pub const DEFAULT_PLAIN_TEXT_RECOVERY_PROMPT: &str = concat!(
    "[runtime context — protocol recovery, not user instruction]\n",
    "Your previous response was plain text with no tool call. ",
    "This runtime advances only through structured tool calls — ",
    "every turn must select exactly one tool.\n",
    "\n",
    "Re-read the latest user request and call exactly one tool now. ",
    "If the answer is ready, call your final-response / delivery tool. ",
    "Do not reply with a clarifying question unless a tool is genuinely blocked on input ",
    "only the user can supply.",
);
/// Builds the vocabulary-free message for a tool that is not in the allowlist.
///
/// The text names `tool_name` and ends with the bracketed, sorted preview of
/// `allowlist` produced by [`allowed_tools_preview`].
pub(crate) fn generic_hidden_tool_message(tool_name: &str, allowlist: &HashSet<String>) -> String {
    let preview = allowed_tools_preview(allowlist);
    format!(
        "Tool `{tool_name}` is not available in this turn — the active tool gate narrowed the \
         catalog. Call one of the tools available now instead. Available now: [{preview}]."
    )
}
/// Builds the structured JSON payload for a tool that is not in the allowlist.
///
/// The object has four keys: `runtime_block` (always `true`),
/// `requested_tool`, `allowed_tools` (the allowlist sorted ascending, in full),
/// and `gate`, which falls back to `"tool_gate"` when `gate` is `None`.
pub(crate) fn generic_hidden_tool_details(
    tool_name: &str,
    allowlist: &HashSet<String>,
    gate: Option<&str>,
) -> Value {
    let gate_label = match gate {
        Some(label) => label,
        None => "tool_gate",
    };

    // HashSet iteration order is unspecified; sort so the payload is stable.
    let mut sorted_names = allowlist.iter().map(String::as_str).collect::<Vec<&str>>();
    sorted_names.sort_unstable();

    json!({
        "runtime_block": true,
        "requested_tool": tool_name,
        "allowed_tools": sorted_names,
        "gate": gate_label,
    })
}
/// Renders the allowlist as a sorted, comma-separated preview string.
///
/// Up to twelve names are listed in full. Beyond that, only the first twelve
/// (in sorted order) are shown, followed by `… (N total)` where `N` is the
/// full allowlist size. An empty allowlist yields an empty string.
pub(crate) fn allowed_tools_preview(allowlist: &HashSet<String>) -> String {
    /// Maximum number of names listed before the preview is truncated.
    const PREVIEW_LIMIT: usize = 12;

    let mut names: Vec<&str> = Vec::with_capacity(allowlist.len());
    names.extend(allowlist.iter().map(String::as_str));
    // HashSet iteration order is unspecified; sort for deterministic output.
    names.sort_unstable();

    let total = names.len();
    if total <= PREVIEW_LIMIT {
        return names.join(", ");
    }

    let shown = names[..PREVIEW_LIMIT].join(", ");
    format!("{shown}, … ({total} total)")
}
/// Returns a shared handle to a [`DefaultProtocolPolicy`] as a trait object.
pub(crate) fn default_policy() -> Arc<dyn ProtocolPolicy> {
    let policy: Arc<dyn ProtocolPolicy> = Arc::new(DefaultProtocolPolicy);
    policy
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The default policy must expose no terminal tools and no custom
    /// prompt or error.
    #[test]
    fn default_policy_is_vocabulary_free() {
        let p = DefaultProtocolPolicy;
        assert!(p.terminal_tool_names().is_empty());
        assert!(p
            .plain_text_recovery_prompt(PlainTextRecoveryContext {
                messages: &[],
                iteration: 0,
                available_tool_names: &[],
                terminal_fallback_tool: None,
            })
            .is_none());
        assert!(p
            .hidden_tool_error(HiddenToolContext {
                requested_tool: "anything",
                allowlist: &HashSet::new(),
                messages: &[],
            })
            .is_none());
    }

    /// The default `normalize_tool_calls` reports zero and leaves calls untouched.
    #[test]
    fn default_normalize_is_noop() {
        let p = DefaultProtocolPolicy;
        let registry = ToolRegistry::new();
        let mut calls = vec![ToolCall {
            id: "1".into(),
            name: "advance".into(),
            arguments: Value::Null,
        }];
        assert_eq!(p.normalize_tool_calls(&mut calls, &registry), 0);
        assert_eq!(calls[0].name, "advance", "default policy must not rewrite");
    }

    /// The generic message names the requested tool and lists the allowlist,
    /// without any product-specific wording.
    #[test]
    fn generic_message_names_requested_and_available() {
        let allow: HashSet<String> = ["a", "b"].iter().map(|s| s.to_string()).collect();
        let msg = generic_hidden_tool_message("zzz", &allow);
        assert!(msg.contains("zzz"), "{msg}");
        // Check the rendered list itself: the fixed message text already
        // contains the letters 'a' and 'b', so per-character checks would
        // pass even if the allowlist were dropped.
        assert!(msg.contains("[a, b]"), "{msg}");
        assert!(!msg.contains("plan("), "{msg}");
        assert!(!msg.contains("capability profile"), "{msg}");
    }

    /// The generic details payload carries only the four shape keys.
    #[test]
    fn generic_details_are_shape_only() {
        let allow: HashSet<String> = ["x"].iter().map(|s| s.to_string()).collect();
        let details = generic_hidden_tool_details("y", &allow, Some("my_gate"));
        assert_eq!(details.get("runtime_block"), Some(&json!(true)));
        assert_eq!(details.get("requested_tool"), Some(&json!("y")));
        assert_eq!(details.get("allowed_tools"), Some(&json!(["x"])));
        assert_eq!(details.get("gate"), Some(&json!("my_gate")));
        assert!(details.get("kind").is_none());
        assert!(details.get("repair_actions").is_none());
    }

    /// Every hook on the trait can be overridden by a custom policy.
    #[test]
    fn custom_policy_can_override_everything() {
        struct ProductPolicy;

        impl ProtocolPolicy for ProductPolicy {
            fn name(&self) -> &'static str {
                "product"
            }

            fn terminal_tool_names(&self) -> HashSet<String> {
                ["deliver", "ask"].iter().map(|s| s.to_string()).collect()
            }

            fn plain_text_recovery_prompt(
                &self,
                _ctx: PlainTextRecoveryContext<'_>,
            ) -> Option<String> {
                Some("call deliver(...) now".to_string())
            }

            fn normalize_tool_calls(
                &self,
                calls: &mut [ToolCall],
                _registry: &ToolRegistry,
            ) -> usize {
                // Rename every `go` call to `deliver` and report how many changed.
                let mut n = 0;
                for c in calls.iter_mut() {
                    if c.name == "go" {
                        c.name = "deliver".into();
                        n += 1;
                    }
                }
                n
            }

            fn hidden_tool_error(&self, ctx: HiddenToolContext<'_>) -> Option<HiddenToolError> {
                Some(HiddenToolError {
                    message: format!("`{}` is gated; call deliver(...)", ctx.requested_tool),
                    details: json!({ "product": true }),
                })
            }
        }

        let p = ProductPolicy;
        assert_eq!(p.name(), "product");
        assert_eq!(p.terminal_tool_names().len(), 2);
        assert!(p
            .plain_text_recovery_prompt(PlainTextRecoveryContext {
                messages: &[],
                iteration: 0,
                available_tool_names: &[],
                terminal_fallback_tool: Some("deliver"),
            })
            .is_some());

        let registry = ToolRegistry::new();
        let mut calls = vec![ToolCall {
            id: "1".into(),
            name: "go".into(),
            arguments: Value::Null,
        }];
        assert_eq!(p.normalize_tool_calls(&mut calls, &registry), 1);
        assert_eq!(calls[0].name, "deliver");

        let err = p
            .hidden_tool_error(HiddenToolContext {
                requested_tool: "shell",
                allowlist: &HashSet::new(),
                messages: &[],
            })
            .expect("custom policy returns an error");
        assert!(err.message.contains("shell"));
        assert_eq!(err.details, json!({ "product": true }));
    }
}