1use super::traits::{Tool, ToolResult};
8use crate::providers::{self, Provider};
9use crate::security::SecurityPolicy;
10use crate::security::policy::ToolOperation;
11use async_trait::async_trait;
12use serde_json::json;
13use std::sync::Arc;
14
15pub struct LlmTaskTool {
19 security: Arc<SecurityPolicy>,
20 default_provider: String,
22 default_model: String,
24 default_temperature: f64,
26 api_key: Option<String>,
28 provider_runtime_options: providers::ProviderRuntimeOptions,
30}
31
32impl LlmTaskTool {
33 pub fn new(
34 security: Arc<SecurityPolicy>,
35 default_provider: String,
36 default_model: String,
37 default_temperature: f64,
38 api_key: Option<String>,
39 provider_runtime_options: providers::ProviderRuntimeOptions,
40 ) -> Self {
41 Self {
42 security,
43 default_provider,
44 default_model,
45 default_temperature,
46 api_key,
47 provider_runtime_options,
48 }
49 }
50}
51
52#[async_trait]
53impl Tool for LlmTaskTool {
54 fn name(&self) -> &str {
55 "llm_task"
56 }
57
58 fn description(&self) -> &str {
59 "Run a prompt through an LLM with no tool access and return the response. \
60 Optionally validates the output against a JSON Schema. Ideal for structured \
61 data extraction, classification, summarization, and transformation tasks."
62 }
63
64 fn parameters_schema(&self) -> serde_json::Value {
65 json!({
66 "type": "object",
67 "properties": {
68 "prompt": {
69 "type": "string",
70 "description": "The prompt to send to the LLM."
71 },
72 "schema": {
73 "type": "object",
74 "description": "Optional JSON Schema to validate the LLM response against. \
75 When provided, the LLM is instructed to return valid JSON \
76 matching this schema."
77 },
78 "model": {
79 "type": "string",
80 "description": "Optional model override (e.g. 'anthropic/claude-sonnet-4-6'). \
81 Defaults to the configured default model."
82 },
83 "temperature": {
84 "type": "number",
85 "description": "Optional temperature override (0.0-2.0). \
86 Defaults to the configured default temperature."
87 }
88 },
89 "required": ["prompt"]
90 })
91 }
92
93 async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
94 if let Err(error) = self
96 .security
97 .enforce_tool_operation(ToolOperation::Act, "llm_task")
98 {
99 return Ok(ToolResult {
100 success: false,
101 output: String::new(),
102 error: Some(error),
103 });
104 }
105
106 let prompt = match args.get("prompt").and_then(|v| v.as_str()) {
108 Some(p) if !p.trim().is_empty() => p,
109 _ => {
110 return Ok(ToolResult {
111 success: false,
112 output: String::new(),
113 error: Some("Missing or empty required parameter: prompt".to_string()),
114 });
115 }
116 };
117
118 let schema = args.get("schema").and_then(|v| v.as_object());
120 let model = args
121 .get("model")
122 .and_then(|v| v.as_str())
123 .unwrap_or(&self.default_model);
124 let temperature = args
125 .get("temperature")
126 .and_then(|v| v.as_f64())
127 .unwrap_or(self.default_temperature);
128
129 let effective_prompt = if let Some(schema_obj) = schema {
131 let schema_json =
132 serde_json::to_string_pretty(&serde_json::Value::Object(schema_obj.clone()))
133 .unwrap_or_else(|_| "{}".to_string());
134 format!(
135 "{prompt}\n\n\
136 IMPORTANT: You MUST respond with valid JSON that conforms to this schema:\n\
137 ```json\n{schema_json}\n```\n\
138 Respond ONLY with the JSON object, no explanation or markdown."
139 )
140 } else {
141 prompt.to_string()
142 };
143
144 let api_key_ref = self.api_key.as_deref();
146 let provider: Box<dyn Provider> = match providers::create_provider_with_options(
147 &self.default_provider,
148 api_key_ref,
149 &self.provider_runtime_options,
150 ) {
151 Ok(p) => p,
152 Err(e) => {
153 return Ok(ToolResult {
154 success: false,
155 output: String::new(),
156 error: Some(format!("Failed to create provider: {e}")),
157 });
158 }
159 };
160
161 let response = match provider
163 .simple_chat(&effective_prompt, model, temperature)
164 .await
165 {
166 Ok(text) => text,
167 Err(e) => {
168 return Ok(ToolResult {
169 success: false,
170 output: String::new(),
171 error: Some(format!("LLM call failed: {e}")),
172 });
173 }
174 };
175
176 if let Some(schema_obj) = schema {
178 let schema_value = serde_json::Value::Object(schema_obj.clone());
179 match validate_json_response(&response, &schema_value) {
180 Ok(validated_json) => Ok(ToolResult {
181 success: true,
182 output: validated_json,
183 error: None,
184 }),
185 Err(validation_error) => Ok(ToolResult {
186 success: false,
187 output: response,
188 error: Some(format!("Schema validation failed: {validation_error}")),
189 }),
190 }
191 } else {
192 Ok(ToolResult {
193 success: true,
194 output: response,
195 error: None,
196 })
197 }
198 }
199}
200
201fn validate_json_response(response: &str, schema: &serde_json::Value) -> Result<String, String> {
207 let trimmed = response.trim();
209 let json_str = if trimmed.starts_with("```") {
210 trimmed
211 .trim_start_matches("```json")
212 .trim_start_matches("```")
213 .trim_end_matches("```")
214 .trim()
215 } else {
216 trimmed
217 };
218
219 let parsed: serde_json::Value =
221 serde_json::from_str(json_str).map_err(|e| format!("Invalid JSON: {e}"))?;
222
223 if let Some(required) = schema.get("required").and_then(|v| v.as_array()) {
225 for req in required {
226 if let Some(field_name) = req.as_str() {
227 if parsed.get(field_name).is_none() {
228 return Err(format!("Missing required field: {field_name}"));
229 }
230 }
231 }
232 }
233
234 if let Some(properties) = schema.get("properties").and_then(|v| v.as_object()) {
236 for (prop_name, prop_schema) in properties {
237 if let Some(value) = parsed.get(prop_name) {
238 if let Some(expected_type) = prop_schema.get("type").and_then(|t| t.as_str()) {
239 if !type_matches(value, expected_type) {
240 return Err(format!(
241 "Field '{prop_name}' has wrong type: expected {expected_type}, \
242 got {}",
243 json_type_name(value)
244 ));
245 }
246 }
247 }
248 }
249 }
250
251 serde_json::to_string(&parsed).map_err(|e| format!("JSON serialization error: {e}"))
253}
254
255fn type_matches(value: &serde_json::Value, expected: &str) -> bool {
257 match expected {
258 "string" => value.is_string(),
259 "number" => value.is_number(),
260 "integer" => value.is_i64() || value.is_u64(),
261 "boolean" => value.is_boolean(),
262 "array" => value.is_array(),
263 "object" => value.is_object(),
264 "null" => value.is_null(),
265 _ => true, }
267}
268
269fn json_type_name(value: &serde_json::Value) -> &'static str {
271 match value {
272 serde_json::Value::Null => "null",
273 serde_json::Value::Bool(_) => "boolean",
274 serde_json::Value::Number(_) => "number",
275 serde_json::Value::String(_) => "string",
276 serde_json::Value::Array(_) => "array",
277 serde_json::Value::Object(_) => "object",
278 }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tool with the default security policy, model `test-model`,
    /// temperature 0.7, no API key, default runtime options, and the given provider.
    fn tool_with_provider(provider: &str) -> LlmTaskTool {
        LlmTaskTool::new(
            Arc::new(SecurityPolicy::default()),
            provider.to_string(),
            "test-model".to_string(),
            0.7,
            None,
            providers::ProviderRuntimeOptions::default(),
        )
    }

    #[test]
    fn validate_valid_json_against_schema() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "age": { "type": "integer" }
            },
            "required": ["name", "age"]
        });

        let validated = validate_json_response(r#"{"name": "Alice", "age": 30}"#, &schema)
            .expect("a conforming response should validate");

        let round_tripped: serde_json::Value = serde_json::from_str(&validated).unwrap();
        assert_eq!(round_tripped["name"], "Alice");
        assert_eq!(round_tripped["age"], 30);
    }

    #[test]
    fn validate_missing_required_field() {
        let schema = json!({
            "type": "object",
            "properties": {
                "title": { "type": "string" },
                "score": { "type": "number" }
            },
            "required": ["title", "score"]
        });

        let message = validate_json_response(r#"{"title": "Test"}"#, &schema)
            .expect_err("a response without `score` should be rejected");
        assert!(message.contains("Missing required field: score"));
    }

    #[test]
    fn validate_wrong_type() {
        let schema = json!({
            "type": "object",
            "properties": {
                "count": { "type": "integer" }
            },
            "required": ["count"]
        });

        let message = validate_json_response(r#"{"count": "not_a_number"}"#, &schema)
            .expect_err("a string in an integer field should be rejected");
        assert!(message.contains("wrong type"));
    }

    #[test]
    fn validate_strips_markdown_code_fences() {
        let schema = json!({
            "type": "object",
            "properties": {
                "result": { "type": "string" }
            },
            "required": ["result"]
        });

        let fenced = "```json\n{\"result\": \"ok\"}\n```";
        assert!(validate_json_response(fenced, &schema).is_ok());
    }

    #[test]
    fn validate_invalid_json() {
        let schema = json!({ "type": "object" });

        let message = validate_json_response("this is not json at all", &schema)
            .expect_err("non-JSON text should be rejected");
        assert!(message.contains("Invalid JSON"));
    }

    #[test]
    fn validate_optional_fields_accepted() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "bio": { "type": "string" }
            },
            "required": ["name"]
        });

        // `bio` is declared but optional, so omitting it must still validate.
        assert!(validate_json_response(r#"{"name": "Bob"}"#, &schema).is_ok());
    }

    #[test]
    fn validate_all_type_checks() {
        let cases = [
            (json!("hello"), "string", true),
            (json!(42), "string", false),
            (json!(2.72), "number", true),
            (json!(42), "number", true),
            (json!("42"), "number", false),
            (json!(42), "integer", true),
            (json!(2.72), "integer", false),
            (json!(true), "boolean", true),
            (json!(1), "boolean", false),
            (json!([1, 2]), "array", true),
            (json!({}), "array", false),
            (json!({}), "object", true),
            (json!([]), "object", false),
            (json!(null), "null", true),
            // Unknown type names are accepted permissively.
            (json!("anything"), "custom_type", true),
        ];

        for (value, type_name, expected) in &cases {
            assert_eq!(
                type_matches(value, type_name),
                *expected,
                "type_matches({value}, {type_name:?})"
            );
        }
    }

    #[test]
    fn tool_metadata() {
        let tool = tool_with_provider("openrouter");

        assert_eq!(tool.name(), "llm_task");
        assert!(tool.description().contains("LLM"));

        let schema = tool.parameters_schema();
        assert_eq!(schema["type"], "object");

        let properties = &schema["properties"];
        for key in ["prompt", "schema", "model", "temperature"] {
            assert!(properties[key].is_object(), "missing property: {key}");
        }

        let required = schema["required"].as_array().unwrap();
        assert_eq!(required.len(), 1);
        assert_eq!(required[0], "prompt");
    }

    #[tokio::test]
    async fn execute_missing_prompt_returns_error() {
        let tool = tool_with_provider("openrouter");

        let outcome = tool.execute(json!({})).await.unwrap();
        assert!(!outcome.success);
        assert!(outcome.error.as_deref().unwrap().contains("prompt"));
    }

    #[tokio::test]
    async fn execute_empty_prompt_returns_error() {
        let tool = tool_with_provider("openrouter");

        let outcome = tool.execute(json!({"prompt": " "})).await.unwrap();
        assert!(!outcome.success);
        assert!(outcome.error.as_deref().unwrap().contains("prompt"));
    }

    #[tokio::test]
    async fn execute_with_invalid_provider_returns_error() {
        let tool = tool_with_provider("nonexistent_provider_xyz");

        let outcome = tool
            .execute(json!({"prompt": "Hello world"}))
            .await
            .unwrap();
        assert!(!outcome.success);
        assert!(outcome.error.as_deref().unwrap().contains("provider"));
    }
}