use std::task::{Context, Poll};
use futures::future::BoxFuture;
use serde_json::Value;
use tower::{Layer, Service};
use crate::error::{Error, Result};
use crate::service::ToolInvocation;
use crate::tools::{ToolErrorKind, ToolErrorKindSet};
/// Selects which [`ToolErrorKind`]s are treated as terminal.
///
/// The policy is a thin, `Copy` wrapper around a [`ToolErrorKindSet`]; see
/// [`ToolErrorPolicy::classifies_terminal`] for the membership query.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct ToolErrorPolicy {
    /// Error kinds for which [`ToolErrorPolicy::classifies_terminal`] returns `true`.
    pub terminate_on: ToolErrorKindSet,
}
impl ToolErrorPolicy {
    /// Builds a policy with an empty terminal set: no error kind is classified
    /// as terminal.
    #[must_use]
    pub const fn new() -> Self {
        let terminate_on = ToolErrorKindSet::empty();
        Self { terminate_on }
    }

    /// Builds a policy that treats `Auth`, `Quota` and `Permanent` errors as
    /// terminal.
    #[must_use]
    pub const fn operator_safe() -> Self {
        Self::new()
            .add_terminal_kind(ToolErrorKind::Auth)
            .add_terminal_kind(ToolErrorKind::Quota)
            .add_terminal_kind(ToolErrorKind::Permanent)
    }

    /// Returns a copy of this policy with `kind` added to the terminal set.
    #[must_use]
    pub const fn add_terminal_kind(self, kind: ToolErrorKind) -> Self {
        Self {
            terminate_on: self.terminate_on.with(kind),
        }
    }

    /// Reports whether `kind` is a member of the terminal set.
    #[must_use]
    pub const fn classifies_terminal(self, kind: ToolErrorKind) -> bool {
        let Self { terminate_on } = self;
        terminate_on.contains(kind)
    }
}
/// Tower [`Layer`] that wraps a tool service in a [`ToolErrorPolicyService`]
/// configured with a fixed [`ToolErrorPolicy`].
#[derive(Debug, Clone, Copy)]
pub struct ToolErrorPolicyLayer {
    /// Policy copied into every service this layer produces.
    policy: ToolErrorPolicy,
}
impl ToolErrorPolicyLayer {
pub const NAME: &'static str = "tool_error_policy";
#[must_use]
pub const fn new(policy: ToolErrorPolicy) -> Self {
Self { policy }
}
#[must_use]
pub const fn policy(&self) -> &ToolErrorPolicy {
&self.policy
}
}
impl crate::NamedLayer for ToolErrorPolicyLayer {
fn layer_name(&self) -> &'static str {
Self::NAME
}
}
impl<S> Layer<S> for ToolErrorPolicyLayer
where
    S: Service<ToolInvocation, Response = Value, Error = Error> + Clone + Send + 'static,
    S::Future: Send + 'static,
{
    type Service = ToolErrorPolicyService<S>;

    /// Wraps `inner` in a [`ToolErrorPolicyService`] carrying a copy of this
    /// layer's policy.
    fn layer(&self, inner: S) -> Self::Service {
        // `ToolErrorPolicy` is `Copy`, so each produced service owns its own value.
        let policy = self.policy;
        ToolErrorPolicyService { inner, policy }
    }
}
/// Service produced by [`ToolErrorPolicyLayer`]: forwards each
/// [`ToolInvocation`] to `inner` and applies `policy` to any error it returns.
#[derive(Debug, Clone)]
pub struct ToolErrorPolicyService<Inner> {
    /// The wrapped tool service.
    inner: Inner,
    /// Policy consulted when the wrapped service fails.
    policy: ToolErrorPolicy,
}
impl<Inner> Service<ToolInvocation> for ToolErrorPolicyService<Inner>
where
    Inner: Service<ToolInvocation, Response = Value, Error = Error> + Clone + Send + 'static,
    Inner::Future: Send + 'static,
{
    type Response = Value;
    type Error = Error;
    type Future = BoxFuture<'static, Result<Value>>;

    /// Delegates readiness to the wrapped service.
    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<()>> {
        self.inner.poll_ready(cx)
    }

    /// Runs the wrapped service and applies the error policy to its result.
    ///
    /// * `Ok` values are returned untouched.
    /// * `Cancelled`, `DeadlineExceeded`, `Interrupted` and `ModelRetry` are
    ///   returned untouched without being classified.
    /// * Any other error is classified with [`ToolErrorKind::classify`]; when
    ///   the policy does not mark that kind terminal the error is returned
    ///   untouched.
    /// * For terminal kinds, the event is recorded on the invocation's audit
    ///   sink (when one is present) and the error is returned as
    ///   `Error::ToolErrorTerminal` — as-is if it already is one, otherwise
    ///   wrapped together with the kind and this invocation's tool name.
    fn call(&mut self, invocation: ToolInvocation) -> Self::Future {
        // `poll_ready` drove `self.inner` to readiness; a fresh clone carries no
        // such guarantee. Move the ready instance into the future and leave the
        // clone behind to be polled before the next call.
        let fresh = self.inner.clone();
        let mut inner = std::mem::replace(&mut self.inner, fresh);

        let policy = self.policy;
        // Captured up front because `invocation` is moved into the inner call.
        let tool_name = invocation.metadata.name.clone();
        let audit = invocation.ctx.audit_sink();

        Box::pin(async move {
            let err = match inner.call(invocation).await {
                Ok(value) => return Ok(value),
                Err(err) => err,
            };

            // Control-flow signals bypass classification entirely.
            if matches!(
                err,
                Error::Cancelled
                    | Error::DeadlineExceeded
                    | Error::Interrupted { .. }
                    | Error::ModelRetry { .. }
            ) {
                return Err(err);
            }

            let kind = ToolErrorKind::classify(&err);
            if !policy.classifies_terminal(kind) {
                return Err(err);
            }

            // An error that is already terminal keeps the tool name it was
            // created with; otherwise attribute it to this invocation's tool.
            let recorded_tool: &str = match &err {
                Error::ToolErrorTerminal {
                    tool_name: original,
                    ..
                } => original.as_str(),
                _ => tool_name.as_str(),
            };
            if let Some(handle) = &audit {
                handle
                    .as_sink()
                    .record_tool_error_terminal(kind, recorded_tool);
            }

            // Never double-wrap: an already-terminal error is forwarded as-is.
            if matches!(err, Error::ToolErrorTerminal { .. }) {
                return Err(err);
            }
            Err(Error::tool_error_terminal(kind, tool_name, err))
        })
    }
}
#[cfg(test)]
// Tests are expected to fail loudly on an unexpected `Err`/`None`, so
// unwrapping is the intended style in this module.
#[allow(clippy::unwrap_used)]
mod tests {
    use std::sync::Arc;
    use std::sync::atomic::{AtomicUsize, Ordering};

    use serde_json::{Value, json};
    use tower::{Layer, Service, ServiceExt};

    use crate::LlmRenderable;
    use crate::context::ExecutionContext;
    use crate::error::Error;
    use crate::service::ToolInvocation;
    use crate::tools::{ToolErrorKind, ToolMetadata};

    use super::*;

    /// Inner service stand-in: counts calls and returns whatever `outcome`
    /// produces.
    #[derive(Clone)]
    struct StubTool {
        /// Number of `call` invocations, shared across clones.
        calls: Arc<AtomicUsize>,
        /// Factory for the result returned by each call.
        outcome: Arc<dyn Fn() -> std::result::Result<Value, Error> + Send + Sync>,
    }

    impl StubTool {
        /// Creates a stub with a zeroed call counter.
        fn new(
            outcome: impl Fn() -> std::result::Result<Value, Error> + Send + Sync + 'static,
        ) -> Self {
            Self {
                calls: Arc::new(AtomicUsize::new(0)),
                outcome: Arc::new(outcome),
            }
        }
    }

    impl Service<ToolInvocation> for StubTool {
        type Response = Value;
        type Error = Error;
        type Future = futures::future::BoxFuture<'static, std::result::Result<Value, Error>>;

        /// The stub is always ready.
        fn poll_ready(&mut self, _cx: &mut Context<'_>) -> Poll<std::result::Result<(), Error>> {
            Poll::Ready(Ok(()))
        }

        /// Bumps the call counter and resolves to the configured outcome.
        fn call(&mut self, _req: ToolInvocation) -> Self::Future {
            self.calls.fetch_add(1, Ordering::SeqCst);
            let outcome = Arc::clone(&self.outcome);
            Box::pin(async move { outcome() })
        }
    }

    /// Builds a minimal invocation of a tool named `"stub"`.
    fn invocation() -> ToolInvocation {
        let metadata = Arc::new(ToolMetadata::function(
            "stub",
            "stub tool for tests",
            json!({"type": "object"}),
        ));
        ToolInvocation::new("tu-1".into(), metadata, json!({}), ExecutionContext::new())
    }

    /// Runs one invocation through the layer over a stub that fails with
    /// `outcome`, returning the layered result and the stub's call count.
    async fn dispatch(
        policy: ToolErrorPolicy,
        outcome: Error,
    ) -> (std::result::Result<Value, Error>, usize) {
        let stub = StubTool::new(move || Err(clone_error(&outcome)));
        let layer = ToolErrorPolicyLayer::new(policy);
        // The clone shares the `calls` counter with `stub`.
        let mut svc = layer.layer(stub.clone());
        let result = svc.ready().await.unwrap().call(invocation()).await;
        (result, stub.calls.load(Ordering::SeqCst))
    }

    /// Rebuilds an equivalent error for the variants these tests use; any
    /// other variant degrades to an `InvalidRequest` describing it.
    fn clone_error(err: &Error) -> Error {
        match err {
            Error::Provider { kind, message, .. } => match kind {
                crate::error::ProviderErrorKind::Http(s) => {
                    Error::provider_http(*s, message.clone())
                }
                crate::error::ProviderErrorKind::Network => {
                    Error::provider_network(message.clone())
                }
                crate::error::ProviderErrorKind::Tls => Error::provider_tls(message.clone()),
                crate::error::ProviderErrorKind::Dns => Error::provider_dns(message.clone()),
            },
            Error::Cancelled => Error::Cancelled,
            Error::DeadlineExceeded => Error::DeadlineExceeded,
            Error::InvalidRequest(s) => Error::invalid_request(s.clone()),
            Error::Config(s) => Error::config(s.clone()),
            Error::ModelRetry { hint, attempt } => Error::model_retry(hint.clone(), *attempt),
            other => Error::invalid_request(format!("unrepeatable in test: {other}")),
        }
    }

    /// An HTTP 401 from the inner tool is wrapped as a terminal `Auth` error
    /// under the operator-safe policy.
    #[tokio::test]
    async fn auth_failure_escalates_under_operator_safe_default() {
        let (result, calls) = dispatch(
            ToolErrorPolicy::operator_safe(),
            Error::provider_http(401, "unauthorized"),
        )
        .await;
        let err = result.unwrap_err();
        assert!(
            matches!(&err, Error::ToolErrorTerminal { kind, .. } if *kind == ToolErrorKind::Auth),
            "expected ToolErrorTerminal{{Auth}}, got {err:?}"
        );
        assert_eq!(calls, 1, "inner service must be invoked exactly once");
    }

    /// An `InvalidRequest` error is not wrapped under the operator-safe policy.
    #[tokio::test]
    async fn validation_failure_passes_through_under_operator_safe() {
        let (result, calls) = dispatch(
            ToolErrorPolicy::operator_safe(),
            Error::invalid_request("bad input"),
        )
        .await;
        let err = result.unwrap_err();
        assert!(
            matches!(err, Error::InvalidRequest(_)),
            "Validation must NOT wrap under operator_safe"
        );
        assert_eq!(calls, 1);
    }

    /// With an empty policy even an auth failure is returned unwrapped.
    #[tokio::test]
    async fn empty_policy_is_full_passthrough() {
        let (result, _) = dispatch(
            ToolErrorPolicy::new(),
            Error::provider_http(401, "unauthorized"),
        )
        .await;
        assert!(matches!(result.unwrap_err(), Error::Provider { .. }));
    }

    /// `Cancelled`, `DeadlineExceeded` and `ModelRetry` come back unchanged
    /// under a policy that marks `Internal` as terminal.
    #[tokio::test]
    async fn control_signals_short_circuit_classification() {
        let policy = ToolErrorPolicy::new().add_terminal_kind(ToolErrorKind::Internal);

        let (result, _) = dispatch(policy, Error::Cancelled).await;
        assert!(
            matches!(result.unwrap_err(), Error::Cancelled),
            "Cancelled must pass through unchanged"
        );

        let (result, _) = dispatch(policy, Error::DeadlineExceeded).await;
        assert!(matches!(result.unwrap_err(), Error::DeadlineExceeded));

        let model_retry = Error::model_retry("re-check the shape".to_owned().for_llm(), 0);
        let (result, _) = dispatch(policy, model_retry).await;
        assert!(matches!(result.unwrap_err(), Error::ModelRetry { .. }));
    }

    /// An error that is already `ToolErrorTerminal` is neither re-wrapped nor
    /// re-attributed: kind, tool name and leaf source all survive.
    #[tokio::test]
    async fn already_terminal_passes_through_unchanged() {
        let stub = StubTool::new(|| {
            Err(Error::tool_error_terminal(
                ToolErrorKind::Auth,
                "sub_tool",
                Error::provider_http(401, "no auth"),
            ))
        });
        let layer = ToolErrorPolicyLayer::new(ToolErrorPolicy::operator_safe());
        let mut svc = layer.layer(stub);
        let err = svc
            .ready()
            .await
            .unwrap()
            .call(invocation())
            .await
            .unwrap_err();
        match err {
            Error::ToolErrorTerminal {
                kind,
                tool_name,
                source,
                ..
            } => {
                assert!(
                    kind == ToolErrorKind::Auth,
                    "terminal kind must be preserved"
                );
                // The invocation's tool is "stub"; the original attribution must win.
                assert_eq!(
                    tool_name.as_str(),
                    "sub_tool",
                    "original tool name must be preserved"
                );
                match *source {
                    Error::Provider { .. } => {}
                    other => panic!("expected leaf Provider source, got {other:?}"),
                }
            }
            other => panic!("expected ToolErrorTerminal, got {other:?}"),
        }
    }

    /// Successful results are returned as produced by the inner service.
    #[tokio::test]
    async fn success_is_uninstrumented() {
        let stub = StubTool::new(|| Ok(json!({"ok": true})));
        let layer = ToolErrorPolicyLayer::new(ToolErrorPolicy::operator_safe());
        let mut svc = layer.layer(stub);
        let out = svc.ready().await.unwrap().call(invocation()).await.unwrap();
        assert_eq!(out, json!({"ok": true}));
    }
}