use std::time::Duration;
use async_trait::async_trait;
use rust_decimal::Decimal;
use serde::{Deserialize, Serialize};
use thiserror::Error;
use crate::context::JobContext;
/// Execution domain a tool reports through [`Tool::domain`].
///
/// The trait's default implementation reports [`ToolDomain::Orchestrator`]; tools opt into
/// any other domain by overriding `domain`.
#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)]
pub enum ToolDomain {
    /// Domain returned by the default [`Tool::domain`] implementation.
    Orchestrator,
    /// Alternative domain a tool may declare.
    // NOTE(review): the name suggests execution inside a container/sandbox — confirm against
    // whatever dispatches on this value.
    Container,
}
#[derive(Debug, Error)]
pub enum ToolError {
#[error("Invalid parameters: {0}")]
InvalidParameters(String),
#[error("Execution failed: {0}")]
ExecutionFailed(String),
#[error("Timeout after {0:?}")]
Timeout(Duration),
#[error("Not authorized: {0}")]
NotAuthorized(String),
#[error("Rate limited, retry after {0:?}")]
RateLimited(Option<Duration>),
#[error("External service error: {0}")]
ExternalService(String),
#[error("Sandbox error: {0}")]
Sandbox(String),
}
/// Value returned by a successful [`Tool::execute`] call.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolOutput {
    /// JSON payload produced by the tool.
    pub result: serde_json::Value,
    /// Optional cost attributed to this call.
    // NOTE(review): the unit (currency, credits, …) is not defined in this file — confirm.
    pub cost: Option<Decimal>,
    /// Elapsed time, as supplied by whoever built this output.
    pub duration: Duration,
    /// Optional unprocessed output text; left out of the serialized form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub raw: Option<String>,
}

impl ToolOutput {
    /// Builds an output carrying an arbitrary JSON `result`, with no cost and no raw text.
    pub fn success(result: serde_json::Value, duration: Duration) -> Self {
        Self {
            result,
            duration,
            cost: None,
            raw: None,
        }
    }

    /// Builds an output whose `result` is a JSON string holding `text`.
    pub fn text(text: impl Into<String>, duration: Duration) -> Self {
        let payload = serde_json::Value::String(text.into());
        Self::success(payload, duration)
    }

    /// Returns this output with `cost` set, replacing any previous value.
    pub fn with_cost(self, cost: Decimal) -> Self {
        Self {
            cost: Some(cost),
            ..self
        }
    }

    /// Returns this output with `raw` set, replacing any previous value.
    pub fn with_raw(self, raw: impl Into<String>) -> Self {
        Self {
            raw: Some(raw.into()),
            ..self
        }
    }
}
/// Self-description of a tool: its name, a description, and a JSON parameter schema.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ToolSchema {
    /// Tool identifier.
    pub name: String,
    /// Human-readable description of the tool.
    pub description: String,
    /// JSON value describing accepted parameters (an object schema by default; see [`ToolSchema::new`]).
    pub parameters: serde_json::Value,
}

impl ToolSchema {
    /// Creates a schema whose `parameters` describe an object with no properties and
    /// no required fields. Use [`ToolSchema::with_parameters`] to replace it.
    pub fn new(name: impl Into<String>, description: impl Into<String>) -> Self {
        let name = name.into();
        let description = description.into();
        let empty_object = serde_json::json!({
            "type": "object",
            "properties": {},
            "required": []
        });
        Self {
            name,
            description,
            parameters: empty_object,
        }
    }

    /// Returns this schema with `parameters` replaced wholesale.
    pub fn with_parameters(self, parameters: serde_json::Value) -> Self {
        Self { parameters, ..self }
    }
}
/// A capability that can be invoked with JSON parameters in the context of a job.
///
/// Implementors must provide [`Tool::name`], [`Tool::description`],
/// [`Tool::parameters_schema`] and [`Tool::execute`]; every other method has a default
/// that individual tools may override.
#[async_trait]
pub trait Tool: Send + Sync {
    /// Identifier of the tool; copied into [`ToolSchema::name`] by [`Tool::schema`].
    fn name(&self) -> &str;

    /// Human-readable description; copied into [`ToolSchema::description`] by [`Tool::schema`].
    fn description(&self) -> &str;

    /// JSON description of the accepted parameters; copied into
    /// [`ToolSchema::parameters`] by [`Tool::schema`].
    fn parameters_schema(&self) -> serde_json::Value;

    /// Runs the tool with `params`, given the job context `ctx`.
    ///
    /// # Errors
    ///
    /// Returns a [`ToolError`] variant describing why the invocation failed.
    async fn execute(
        &self,
        params: serde_json::Value,
        ctx: &JobContext,
    ) -> Result<ToolOutput, ToolError>;

    /// Up-front cost estimate for `params`. The default returns `None` (no estimate).
    fn estimated_cost(&self, _params: &serde_json::Value) -> Option<Decimal> {
        None
    }

    /// Up-front duration estimate for `params`. The default returns `None` (no estimate).
    fn estimated_duration(&self, _params: &serde_json::Value) -> Option<Duration> {
        None
    }

    /// Whether this tool needs sanitization. The default is `true`, so tools must opt out
    /// explicitly.
    ///
    /// NOTE(review): what is sanitized (inputs vs. outputs) is decided by callers not
    /// visible here — confirm.
    fn requires_sanitization(&self) -> bool {
        true
    }

    /// Whether every invocation of this tool needs approval. The default is `false`.
    fn requires_approval(&self) -> bool {
        false
    }

    /// Whether the invocation with these specific `params` needs approval. The default is
    /// `false`.
    ///
    /// The default does not consult [`Tool::requires_approval`].
    /// NOTE(review): callers presumably check both methods — confirm.
    fn requires_approval_for(&self, _params: &serde_json::Value) -> bool {
        false
    }

    /// Maximum time an invocation is allowed to run. The default is 60 seconds.
    fn execution_timeout(&self) -> Duration {
        Duration::from_secs(60)
    }

    /// Execution domain of this tool. The default is [`ToolDomain::Orchestrator`].
    fn domain(&self) -> ToolDomain {
        ToolDomain::Orchestrator
    }

    /// Assembles a [`ToolSchema`] from `name`, `description` and `parameters_schema`,
    /// calling them in that order.
    fn schema(&self) -> ToolSchema {
        let name = self.name().to_owned();
        let description = self.description().to_owned();
        let parameters = self.parameters_schema();
        ToolSchema {
            name,
            description,
            parameters,
        }
    }
}
/// Borrows the string value stored under key `name` in `params`.
///
/// # Errors
///
/// Returns [`ToolError::InvalidParameters`] with the message "missing '<name>' parameter"
/// when the key is absent, when it holds a non-string value, or when `params` is not an
/// object.
pub fn require_str<'a>(params: &'a serde_json::Value, name: &str) -> Result<&'a str, ToolError> {
    let found = params.get(name).and_then(serde_json::Value::as_str);
    match found {
        Some(text) => Ok(text),
        None => Err(ToolError::InvalidParameters(format!(
            "missing '{}' parameter",
            name
        ))),
    }
}
/// Borrows the value stored under key `name` in `params`, whatever its JSON type
/// (including `null`).
///
/// # Errors
///
/// Returns [`ToolError::InvalidParameters`] with the message "missing '<name>' parameter"
/// when the key is absent or `params` is not an object.
pub fn require_param<'a>(
    params: &'a serde_json::Value,
    name: &str,
) -> Result<&'a serde_json::Value, ToolError> {
    if let Some(value) = params.get(name) {
        return Ok(value);
    }
    Err(ToolError::InvalidParameters(format!(
        "missing '{}' parameter",
        name
    )))
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal tool used by the tests: returns its `message` parameter as a JSON string.
    #[derive(Debug)]
    pub struct EchoTool;

    #[async_trait]
    impl Tool for EchoTool {
        fn name(&self) -> &str {
            "echo"
        }

        fn description(&self) -> &str {
            "Echoes back the input message. Useful for testing."
        }

        fn parameters_schema(&self) -> serde_json::Value {
            serde_json::json!({
                "type": "object",
                "properties": {
                    "message": {
                        "type": "string",
                        "description": "The message to echo back"
                    }
                },
                "required": ["message"]
            })
        }

        async fn execute(
            &self,
            params: serde_json::Value,
            _ctx: &JobContext,
        ) -> Result<ToolOutput, ToolError> {
            let message = require_str(&params, "message")?;
            Ok(ToolOutput::text(message, Duration::from_millis(1)))
        }

        fn requires_sanitization(&self) -> bool {
            false
        }
    }

    #[tokio::test]
    async fn test_echo_tool() {
        let tool = EchoTool;
        let ctx = JobContext::default();
        let result = tool
            .execute(serde_json::json!({"message": "hello"}), &ctx)
            .await
            .unwrap();
        assert_eq!(result.result, serde_json::json!("hello"));
    }

    #[tokio::test]
    async fn test_echo_tool_missing_message() {
        let tool = EchoTool;
        let ctx = JobContext::default();
        let err = tool
            .execute(serde_json::json!({}), &ctx)
            .await
            .unwrap_err();
        assert!(matches!(err, ToolError::InvalidParameters(_)));
    }

    #[test]
    fn test_tool_schema() {
        let tool = EchoTool;
        let schema = tool.schema();
        assert_eq!(schema.name, "echo");
        assert!(!schema.description.is_empty());
        assert_eq!(schema.parameters, tool.parameters_schema());
    }

    #[test]
    fn test_tool_schema_new_defaults() {
        let schema = ToolSchema::new("n", "d");
        assert_eq!(schema.name, "n");
        assert_eq!(schema.description, "d");
        assert_eq!(schema.parameters["type"], "object");
        assert_eq!(schema.parameters["required"], serde_json::json!([]));

        let replaced = schema.with_parameters(serde_json::json!({"x": 1}));
        assert_eq!(replaced.parameters, serde_json::json!({"x": 1}));
    }

    #[test]
    fn test_tool_output_builders() {
        let out = ToolOutput::text("hi", Duration::from_millis(5))
            .with_cost(Decimal::new(125, 2))
            .with_raw("raw hi");
        assert_eq!(out.result, serde_json::json!("hi"));
        assert_eq!(out.cost, Some(Decimal::new(125, 2)));
        assert_eq!(out.duration, Duration::from_millis(5));
        assert_eq!(out.raw.as_deref(), Some("raw hi"));
    }

    #[test]
    fn test_tool_output_raw_skipped_when_none() {
        let out = ToolOutput::success(serde_json::json!({"ok": true}), Duration::from_secs(0));
        let json = serde_json::to_value(&out).unwrap();
        assert!(json.get("raw").is_none());

        let with_raw = serde_json::to_value(&out.with_raw("r")).unwrap();
        assert_eq!(with_raw["raw"], "r");
    }

    #[test]
    fn test_execution_timeout_default() {
        let tool = EchoTool;
        assert_eq!(tool.execution_timeout(), Duration::from_secs(60));
    }

    #[test]
    fn test_trait_defaults() {
        let tool = EchoTool;
        let params = serde_json::json!({"message": "hi"});
        assert_eq!(tool.domain(), ToolDomain::Orchestrator);
        assert!(!tool.requires_approval());
        assert!(tool.estimated_cost(&params).is_none());
        assert!(tool.estimated_duration(&params).is_none());
        assert!(!tool.requires_sanitization());
    }

    #[test]
    fn test_require_str_present() {
        let params = serde_json::json!({"name": "alice"});
        assert_eq!(require_str(&params, "name").unwrap(), "alice");
    }

    #[test]
    fn test_require_str_missing() {
        let params = serde_json::json!({});
        let err = require_str(&params, "name").unwrap_err();
        assert!(err.to_string().contains("missing 'name'"));
    }

    #[test]
    fn test_require_str_wrong_type() {
        let params = serde_json::json!({"name": 42});
        let err = require_str(&params, "name").unwrap_err();
        assert!(err.to_string().contains("missing 'name'"));
    }

    #[test]
    fn test_require_param_present() {
        let params = serde_json::json!({"data": [1, 2, 3]});
        assert_eq!(
            require_param(&params, "data").unwrap(),
            &serde_json::json!([1, 2, 3])
        );
    }

    #[test]
    fn test_require_param_present_null() {
        // A key that is present but null counts as present.
        let params = serde_json::json!({"data": null});
        assert_eq!(
            require_param(&params, "data").unwrap(),
            &serde_json::Value::Null
        );
    }

    #[test]
    fn test_require_param_missing() {
        let params = serde_json::json!({});
        let err = require_param(&params, "data").unwrap_err();
        assert!(err.to_string().contains("missing 'data'"));
    }

    #[test]
    fn test_requires_approval_for_default() {
        let tool = EchoTool;
        assert!(!tool.requires_approval_for(&serde_json::json!({"message": "hi"})));
    }
}