batuta/agent/tool/
inference.rs1use async_trait::async_trait;
8use std::sync::Arc;
9
10use super::{Tool, ToolResult};
11use crate::agent::capability::Capability;
12use crate::agent::driver::{CompletionRequest, LlmDriver, Message, ToolDefinition};
13
14pub struct InferenceTool {
16 driver: Arc<dyn LlmDriver>,
17 max_tokens: u32,
18}
19
20impl InferenceTool {
21 pub fn new(driver: Arc<dyn LlmDriver>, max_tokens: u32) -> Self {
23 Self { driver, max_tokens }
24 }
25}
26
27#[async_trait]
28impl Tool for InferenceTool {
29 fn name(&self) -> &'static str {
30 "inference"
31 }
32
33 fn definition(&self) -> ToolDefinition {
34 ToolDefinition {
35 name: "inference".into(),
36 description: "Run a sub-inference completion for \
37 delegation or chain-of-thought reasoning"
38 .into(),
39 input_schema: serde_json::json!({
40 "type": "object",
41 "properties": {
42 "prompt": {
43 "type": "string",
44 "description": "The prompt to send for completion"
45 },
46 "system_prompt": {
47 "type": "string",
48 "description": "Optional system prompt override"
49 }
50 },
51 "required": ["prompt"]
52 }),
53 }
54 }
55
56 #[cfg_attr(
57 feature = "agents-contracts",
58 provable_contracts_macros::contract("agent-loop-v1", equation = "inference_timeout")
59 )]
60 async fn execute(&self, input: serde_json::Value) -> ToolResult {
61 let Some(prompt) = input.get("prompt").and_then(|p| p.as_str()) else {
62 return ToolResult::error("missing required field: prompt");
63 };
64
65 let system = input.get("system_prompt").and_then(|s| s.as_str()).map(String::from);
66
67 let request = CompletionRequest {
68 model: String::new(),
69 messages: vec![Message::User(prompt.into())],
70 max_tokens: self.max_tokens,
71 temperature: 0.0,
72 tools: vec![],
73 system,
74 };
75
76 match self.driver.complete(request).await {
77 Ok(response) => {
78 if response.text.is_empty() {
79 ToolResult::error("inference returned empty response")
80 } else {
81 ToolResult::success(response.text)
82 }
83 }
84 Err(e) => ToolResult::error(format!("inference error: {e}")),
85 }
86 }
87
88 fn required_capability(&self) -> Capability {
89 Capability::Inference
90 }
91
92 fn timeout(&self) -> std::time::Duration {
93 std::time::Duration::from_secs(300)
94 }
95}
96
#[cfg(test)]
mod tests {
    use super::*;
    use crate::agent::driver::mock::MockDriver;

    /// Builds an `InferenceTool` over a mock driver scripted with the single
    /// response `reply`, using a 256-token budget.
    fn tool_replying(reply: &'static str) -> InferenceTool {
        InferenceTool::new(Arc::new(MockDriver::single_response(reply)), 256)
    }

    #[test]
    fn test_inference_tool_definition() {
        let def = tool_replying("ok").definition();
        assert_eq!(def.name, "inference");
        assert!(def.description.contains("sub-inference"));

        let props = def.input_schema.get("properties").expect("schema properties");
        for key in ["prompt", "system_prompt"] {
            assert!(props.get(key).is_some(), "schema lacks property `{key}`");
        }
    }

    #[test]
    fn test_inference_tool_capability() {
        assert_eq!(tool_replying("ok").required_capability(), Capability::Inference);
    }

    #[test]
    fn test_inference_tool_timeout() {
        assert_eq!(tool_replying("ok").timeout(), std::time::Duration::from_secs(300));
    }

    #[tokio::test]
    async fn test_inference_missing_prompt() {
        let outcome = tool_replying("ok").execute(serde_json::json!({})).await;
        assert!(outcome.is_error);
        assert!(outcome.content.contains("missing"));
    }

    #[tokio::test]
    async fn test_inference_executes() {
        let args = serde_json::json!({ "prompt": "What is the meaning of life?" });
        let outcome = tool_replying("The answer is 42.").execute(args).await;
        assert!(!outcome.is_error);
        assert!(outcome.content.contains("42"));
    }

    #[tokio::test]
    async fn test_inference_with_system_prompt() {
        let args = serde_json::json!({
            "prompt": "Help me with algebra",
            "system_prompt": "You are a math tutor."
        });
        let outcome = tool_replying("I am a math tutor.").execute(args).await;
        assert!(!outcome.is_error);
        assert!(outcome.content.contains("math tutor"));
    }
}