// batuta/agent/tool/inference.rs
//! `InferenceTool` — sub-model invocation for agent delegation.
//!
//! Allows an agent to run a secondary LLM completion via the
//! same driver, useful for chain-of-thought delegation or
//! specialized reasoning sub-tasks.

7use async_trait::async_trait;
8use std::sync::Arc;
9
10use super::{Tool, ToolResult};
11use crate::agent::capability::Capability;
12use crate::agent::driver::{CompletionRequest, LlmDriver, Message, ToolDefinition};
13
/// Tool that runs a sub-inference via the agent's LLM driver.
///
/// Holds a shared handle to an [`LlmDriver`], so sub-completions are
/// issued through the same driver instance the agent itself uses.
pub struct InferenceTool {
    // Shared driver used to issue the secondary completion request.
    driver: Arc<dyn LlmDriver>,
    // Token budget applied to every sub-inference request.
    max_tokens: u32,
}
19
20impl InferenceTool {
21    /// Create a new `InferenceTool` with the given driver.
22    pub fn new(driver: Arc<dyn LlmDriver>, max_tokens: u32) -> Self {
23        Self { driver, max_tokens }
24    }
25}
26
27#[async_trait]
28impl Tool for InferenceTool {
29    fn name(&self) -> &'static str {
30        "inference"
31    }
32
33    fn definition(&self) -> ToolDefinition {
34        ToolDefinition {
35            name: "inference".into(),
36            description: "Run a sub-inference completion for \
37                          delegation or chain-of-thought reasoning"
38                .into(),
39            input_schema: serde_json::json!({
40                "type": "object",
41                "properties": {
42                    "prompt": {
43                        "type": "string",
44                        "description": "The prompt to send for completion"
45                    },
46                    "system_prompt": {
47                        "type": "string",
48                        "description": "Optional system prompt override"
49                    }
50                },
51                "required": ["prompt"]
52            }),
53        }
54    }
55
56    #[cfg_attr(
57        feature = "agents-contracts",
58        provable_contracts_macros::contract("agent-loop-v1", equation = "inference_timeout")
59    )]
60    async fn execute(&self, input: serde_json::Value) -> ToolResult {
61        let Some(prompt) = input.get("prompt").and_then(|p| p.as_str()) else {
62            return ToolResult::error("missing required field: prompt");
63        };
64
65        let system = input.get("system_prompt").and_then(|s| s.as_str()).map(String::from);
66
67        let request = CompletionRequest {
68            model: String::new(),
69            messages: vec![Message::User(prompt.into())],
70            max_tokens: self.max_tokens,
71            temperature: 0.0,
72            tools: vec![],
73            system,
74        };
75
76        match self.driver.complete(request).await {
77            Ok(response) => {
78                if response.text.is_empty() {
79                    ToolResult::error("inference returned empty response")
80                } else {
81                    ToolResult::success(response.text)
82                }
83            }
84            Err(e) => ToolResult::error(format!("inference error: {e}")),
85        }
86    }
87
88    fn required_capability(&self) -> Capability {
89        Capability::Inference
90    }
91
92    fn timeout(&self) -> std::time::Duration {
93        std::time::Duration::from_secs(300)
94    }
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::driver::mock::MockDriver;

    /// Build an `InferenceTool` (budget 256 tokens) whose mock driver
    /// always replies with `reply`.
    fn tool_replying(reply: &str) -> InferenceTool {
        InferenceTool::new(Arc::new(MockDriver::single_response(reply)), 256)
    }

    #[test]
    fn test_inference_tool_definition() {
        let def = tool_replying("ok").definition();
        assert_eq!(def.name, "inference");
        assert!(def.description.contains("sub-inference"));
        let props = def.input_schema.get("properties").expect("schema properties");
        assert!(props.get("prompt").is_some());
        assert!(props.get("system_prompt").is_some());
    }

    #[test]
    fn test_inference_tool_capability() {
        assert_eq!(tool_replying("ok").required_capability(), Capability::Inference);
    }

    #[test]
    fn test_inference_tool_timeout() {
        assert_eq!(
            tool_replying("ok").timeout(),
            std::time::Duration::from_secs(300)
        );
    }

    #[tokio::test]
    async fn test_inference_missing_prompt() {
        // No "prompt" field at all: execute must fail with a clear message.
        let result = tool_replying("ok").execute(serde_json::json!({})).await;
        assert!(result.is_error);
        assert!(result.content.contains("missing"));
    }

    #[tokio::test]
    async fn test_inference_executes() {
        let input = serde_json::json!({
            "prompt": "What is the meaning of life?"
        });
        let result = tool_replying("The answer is 42.").execute(input).await;
        assert!(!result.is_error);
        assert!(result.content.contains("42"));
    }

    #[tokio::test]
    async fn test_inference_with_system_prompt() {
        let input = serde_json::json!({
            "prompt": "Help me with algebra",
            "system_prompt": "You are a math tutor."
        });
        let result = tool_replying("I am a math tutor.").execute(input).await;
        assert!(!result.is_error);
        assert!(result.content.contains("math tutor"));
    }
}