use std::pin::Pin;
use rig::tool::{ToolDyn, ToolError};
use serde_json::Value;
use super::result::LoopToolResult;
use super::schema_flatten::{FlattenDecision, analyze_schema, flatten_schema, nest_arguments};
use super::tool::{AbortSignal, LoopTool, LoopToolUpdate};
use super::types::ToolExecutionMode;
#[cfg(test)]
use std::sync::Arc;
/// Adapts a rig [`ToolDyn`] so it can be driven as a [`LoopTool`].
///
/// When built via `RigToolAdapter::new`, the tool's definition (name,
/// description, parameter schema) is fetched once and cached in the fields
/// below, so the `LoopTool` accessors never re-query the wrapped tool.
pub struct RigToolAdapter {
    /// The wrapped rig tool; `execute` forwards calls to it.
    inner: Box<dyn ToolDyn>,
    /// Tool name; the adapter also reports it as the tool's label.
    name: String,
    /// Description text taken from the tool definition.
    description: String,
    /// The tool's original (unflattened) parameter schema.
    parameters: Value,
    /// Flattened copy of `parameters`, present only when `analyze_schema`
    /// recommended flattening. When set, `prepare_arguments` re-nests incoming
    /// arguments with `nest_arguments` before they reach the tool.
    flat_parameters: Option<Value>,
    /// Execution-mode override; `None` unless set via `with_execution_mode`.
    execution_mode: Option<ToolExecutionMode>,
}
// Hand-written Debug: the boxed `ToolDyn` is not `Debug`, and the schemas are
// too noisy for logs, so only the name, the execution mode and a flag telling
// whether a flattened schema exists are printed.
impl std::fmt::Debug for RigToolAdapter {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            name,
            execution_mode,
            flat_parameters,
            ..
        } = self;
        let has_flat_schema = flat_parameters.is_some();
        f.debug_struct("RigToolAdapter")
            .field("name", name)
            .field("execution_mode", execution_mode)
            .field("has_flat_schema", &has_flat_schema)
            .finish()
    }
}
impl RigToolAdapter {
    /// Wraps `inner`, fetching its definition once and caching the name,
    /// description and parameter schema.
    ///
    /// If `analyze_schema` recommends flattening the parameter schema, a
    /// flattened copy is precomputed with `flatten_schema`; otherwise no flat
    /// schema is stored. The execution mode starts unset.
    pub async fn new(inner: Box<dyn ToolDyn>) -> Self {
        // The definition is fetched once, independent of any particular prompt,
        // so an empty prompt string is passed.
        let definition = inner.definition(String::new()).await;
        let FlattenDecision { should_flatten, .. } = analyze_schema(&definition.parameters);
        let flat_parameters = should_flatten.then(|| flatten_schema(&definition.parameters));
        Self {
            name: definition.name,
            description: definition.description,
            parameters: definition.parameters,
            flat_parameters,
            execution_mode: None,
            inner,
        }
    }

    /// Returns this adapter with its execution mode overridden to `mode`.
    pub fn with_execution_mode(self, mode: ToolExecutionMode) -> Self {
        Self {
            execution_mode: Some(mode),
            ..self
        }
    }

    /// Test-only constructor: uses the given metadata verbatim, performs no
    /// definition lookup and no schema analysis (so no flat schema is stored).
    #[cfg(test)]
    fn from_parts(
        inner: Box<dyn ToolDyn>,
        name: String,
        description: String,
        parameters: Value,
    ) -> Self {
        Self {
            execution_mode: None,
            flat_parameters: None,
            parameters,
            description,
            name,
            inner,
        }
    }
}
impl LoopTool for RigToolAdapter {
fn name(&self) -> &str {
&self.name
}
fn description(&self) -> &str {
&self.description
}
fn label(&self) -> &str {
&self.name
}
fn parameters(&self) -> &Value {
&self.parameters
}
fn flat_parameters(&self) -> Option<&Value> {
self.flat_parameters.as_ref()
}
fn execution_mode(&self) -> Option<ToolExecutionMode> {
self.execution_mode
}
fn prepare_arguments(&self, args: Value) -> Value {
if self.flat_parameters.is_some() {
nest_arguments(&args)
} else {
args
}
}
fn execute<'a>(
&'a self,
_tool_call_id: &'a str,
args: Value,
_signal: AbortSignal,
_on_update: LoopToolUpdate,
) -> Pin<Box<dyn Future<Output = Result<LoopToolResult, String>> + Send + 'a>> {
Box::pin(async move {
let args_string = match serde_json::to_string(&args) {
Ok(s) => s,
Err(e) => return Err(format!("rig adapter: arg serialization failed: {e}")),
};
match self.inner.call(args_string).await {
Ok(output_text) => {
Ok(LoopToolResult {
content: vec![serde_json::json!({
"type": "text",
"text": output_text,
})],
details: Value::String(output_text),
terminate: None,
})
}
Err(err) => Err(format_tool_error(err)),
}
})
}
}
fn format_tool_error(err: ToolError) -> String {
let raw = err.to_string();
if raw.contains("missing field") || raw.contains("expected") || raw.contains("invalid type") {
format!(
"Tool input rejected: the arguments did not match the tool's schema.\n\
Try: re-check the tool's required fields and types, then retry.\n\
Details: {raw}"
)
} else {
raw
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use rig::completion::ToolDefinition;
    use rig::tool::Tool;
    use serde::{Deserialize, Serialize};

    /// Minimal rig tool that echoes its required `value` argument.
    #[derive(Debug, Clone)]
    struct EchoTool;

    #[derive(Debug, Deserialize, Serialize)]
    struct EchoArgs {
        value: String,
    }

    #[derive(Debug, thiserror::Error)]
    enum EchoError {
        #[error("echo failed: {0}")]
        Generic(String),
    }

    impl Tool for EchoTool {
        const NAME: &'static str = "echo";
        type Error = EchoError;
        type Args = EchoArgs;
        type Output = String;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "echo".to_string(),
                description: "Echo the input back".to_string(),
                parameters: serde_json::json!({
                    "type": "object",
                    "properties": {
                        "value": {"type": "string"}
                    },
                    "required": ["value"]
                }),
            }
        }

        async fn call(&self, args: Self::Args) -> Result<Self::Output, Self::Error> {
            Ok(format!("echoed: {}", args.value))
        }
    }

    /// Rig tool whose `call` always returns an error.
    #[derive(Debug, Clone)]
    struct FailingTool;

    impl Tool for FailingTool {
        const NAME: &'static str = "failing";
        type Error = EchoError;
        type Args = EchoArgs;
        type Output = String;

        async fn definition(&self, _prompt: String) -> ToolDefinition {
            ToolDefinition {
                name: "failing".to_string(),
                description: "Always fails".to_string(),
                parameters: serde_json::json!({"type": "object"}),
            }
        }

        async fn call(&self, _args: Self::Args) -> Result<Self::Output, Self::Error> {
            Err(EchoError::Generic("synthetic failure".to_string()))
        }
    }

    /// Removes a temporary directory when dropped, so fixtures are cleaned up
    /// even when an assertion in the owning test panics.
    struct TempDirGuard(std::path::PathBuf);

    impl Drop for TempDirGuard {
        fn drop(&mut self) {
            // Best-effort: a failed cleanup must not mask the test's own result.
            let _ = std::fs::remove_dir_all(&self.0);
        }
    }

    fn dummy_update() -> LoopToolUpdate {
        Arc::new(|_partial: &LoopToolResult| {})
    }

    #[tokio::test]
    async fn adapter_caches_definition_at_construction() {
        let adapter = RigToolAdapter::new(Box::new(EchoTool)).await;
        assert_eq!(adapter.name(), "echo");
        assert_eq!(adapter.description(), "Echo the input back");
        assert_eq!(adapter.label(), "echo");
        assert_eq!(adapter.parameters()["type"], "object");
        assert!(adapter.execution_mode().is_none());
    }

    #[tokio::test]
    async fn adapter_with_execution_mode_overrides_default() {
        let adapter = RigToolAdapter::new(Box::new(EchoTool))
            .await
            .with_execution_mode(ToolExecutionMode::Sequential);
        assert_eq!(
            adapter.execution_mode(),
            Some(ToolExecutionMode::Sequential)
        );
    }

    #[tokio::test]
    async fn execute_happy_path_wraps_output() {
        let adapter = RigToolAdapter::new(Box::new(EchoTool)).await;
        let result = adapter
            .execute(
                "call-1",
                serde_json::json!({"value": "hello"}),
                AbortSignal::new(),
                dummy_update(),
            )
            .await
            .expect("execute should succeed");
        assert_eq!(result.content.len(), 1);
        assert_eq!(result.content[0]["type"], "text");
        assert_eq!(result.content[0]["text"], "echoed: hello");
        assert_eq!(result.details, Value::String("echoed: hello".to_string()));
        assert!(result.terminate.is_none());
    }

    #[tokio::test]
    async fn execute_propagates_tool_error() {
        let adapter = RigToolAdapter::new(Box::new(FailingTool)).await;
        let result = adapter
            .execute(
                "call-1",
                serde_json::json!({"value": "x"}),
                AbortSignal::new(),
                dummy_update(),
            )
            .await;
        let err_string = result.expect_err("execute should fail");
        assert!(
            err_string.contains("synthetic failure"),
            "expected error string to mention the underlying message: {err_string}"
        );
    }

    #[tokio::test]
    async fn execute_with_malformed_args_returns_error() {
        let adapter = RigToolAdapter::new(Box::new(EchoTool)).await;
        let result = adapter
            .execute(
                "call-1",
                serde_json::json!({}),
                AbortSignal::new(),
                dummy_update(),
            )
            .await;
        assert!(result.is_err(), "missing required arg should produce error");
    }

    #[tokio::test]
    async fn prepare_arguments_is_identity() {
        let adapter = RigToolAdapter::new(Box::new(EchoTool)).await;
        let input = serde_json::json!({"value": "x", "extra": "y"});
        let output = adapter.prepare_arguments(input.clone());
        assert_eq!(output, input);
    }

    #[tokio::test]
    async fn from_parts_builds_adapter() {
        let adapter = RigToolAdapter::from_parts(
            Box::new(EchoTool),
            "custom_name".to_string(),
            "custom desc".to_string(),
            serde_json::json!({"type": "object"}),
        );
        assert_eq!(adapter.name(), "custom_name");
        assert_eq!(adapter.description(), "custom desc");
    }

    #[tokio::test]
    async fn adapter_matches_rig_path_for_real_dirge_tool() {
        use crate::agent::tools::ReadTool;
        use rig::tool::ToolDyn;
        let dir = std::env::temp_dir().join(format!(
            "dirge_rig_tool_test_{}_{}",
            std::process::id(),
            crate::time_util::now_unix_nanos(),
        ));
        std::fs::create_dir_all(&dir).unwrap();
        // Cleans up on every exit path, including the panics raised by the
        // `expect`/`assert_eq!` calls below.
        let _cleanup = TempDirGuard(dir.clone());
        let target = dir.join("sample.txt");
        std::fs::write(&target, b"hello from integration test\n").unwrap();
        let path_str = target.to_string_lossy().into_owned();

        let tool_a = ReadTool::new(None, None);
        let rig_args = serde_json::json!({"path": path_str}).to_string();
        let rig_output = <ReadTool as ToolDyn>::call(&tool_a, rig_args)
            .await
            .expect("rig direct call should succeed");

        let tool_b = ReadTool::new(None, None);
        let adapter = RigToolAdapter::new(Box::new(tool_b)).await;
        let adapter_result = adapter
            .execute(
                "call-1",
                serde_json::json!({"path": path_str}),
                AbortSignal::new(),
                dummy_update(),
            )
            .await
            .expect("adapter execute should succeed");
        assert_eq!(adapter_result.content.len(), 1);
        let adapter_text = adapter_result.content[0]["text"]
            .as_str()
            .expect("text field");
        assert_eq!(
            adapter_text, rig_output,
            "adapter must produce the same text as the rig direct path"
        );
    }
}