Skip to main content

construct/tools/
llm_task.rs

1//! Lightweight LLM task tool for structured JSON-only sub-calls.
2//!
3//! Runs a single prompt through an LLM provider with no tool access and
4//! optionally validates the response against a caller-supplied JSON Schema.
5//! Ideal for structured data extraction in workflows.
6
7use super::traits::{Tool, ToolResult};
8use crate::providers::{self, Provider};
9use crate::security::SecurityPolicy;
10use crate::security::policy::ToolOperation;
11use async_trait::async_trait;
12use serde_json::json;
13use std::sync::Arc;
14
15/// Tool that runs a single prompt through an LLM and optionally validates
16/// the response against a JSON Schema. No tools are provided to the LLM —
17/// this is a pure text-in, text-out (or JSON-out) call.
18pub struct LlmTaskTool {
19    security: Arc<SecurityPolicy>,
20    /// Default provider name from root config (e.g. "openrouter").
21    default_provider: String,
22    /// Default model from root config.
23    default_model: String,
24    /// Default temperature from root config.
25    default_temperature: f64,
26    /// API key for provider authentication.
27    api_key: Option<String>,
28    /// Provider runtime options inherited from root config.
29    provider_runtime_options: providers::ProviderRuntimeOptions,
30}
31
32impl LlmTaskTool {
33    pub fn new(
34        security: Arc<SecurityPolicy>,
35        default_provider: String,
36        default_model: String,
37        default_temperature: f64,
38        api_key: Option<String>,
39        provider_runtime_options: providers::ProviderRuntimeOptions,
40    ) -> Self {
41        Self {
42            security,
43            default_provider,
44            default_model,
45            default_temperature,
46            api_key,
47            provider_runtime_options,
48        }
49    }
50}
51
52#[async_trait]
53impl Tool for LlmTaskTool {
54    fn name(&self) -> &str {
55        "llm_task"
56    }
57
58    fn description(&self) -> &str {
59        "Run a prompt through an LLM with no tool access and return the response. \
60         Optionally validates the output against a JSON Schema. Ideal for structured \
61         data extraction, classification, summarization, and transformation tasks."
62    }
63
64    fn parameters_schema(&self) -> serde_json::Value {
65        json!({
66            "type": "object",
67            "properties": {
68                "prompt": {
69                    "type": "string",
70                    "description": "The prompt to send to the LLM."
71                },
72                "schema": {
73                    "type": "object",
74                    "description": "Optional JSON Schema to validate the LLM response against. \
75                                    When provided, the LLM is instructed to return valid JSON \
76                                    matching this schema."
77                },
78                "model": {
79                    "type": "string",
80                    "description": "Optional model override (e.g. 'anthropic/claude-sonnet-4-6'). \
81                                    Defaults to the configured default model."
82                },
83                "temperature": {
84                    "type": "number",
85                    "description": "Optional temperature override (0.0-2.0). \
86                                    Defaults to the configured default temperature."
87                }
88            },
89            "required": ["prompt"]
90        })
91    }
92
93    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
94        // Security gate
95        if let Err(error) = self
96            .security
97            .enforce_tool_operation(ToolOperation::Act, "llm_task")
98        {
99            return Ok(ToolResult {
100                success: false,
101                output: String::new(),
102                error: Some(error),
103            });
104        }
105
106        // Extract required prompt
107        let prompt = match args.get("prompt").and_then(|v| v.as_str()) {
108            Some(p) if !p.trim().is_empty() => p,
109            _ => {
110                return Ok(ToolResult {
111                    success: false,
112                    output: String::new(),
113                    error: Some("Missing or empty required parameter: prompt".to_string()),
114                });
115            }
116        };
117
118        // Extract optional overrides
119        let schema = args.get("schema").and_then(|v| v.as_object());
120        let model = args
121            .get("model")
122            .and_then(|v| v.as_str())
123            .unwrap_or(&self.default_model);
124        let temperature = args
125            .get("temperature")
126            .and_then(|v| v.as_f64())
127            .unwrap_or(self.default_temperature);
128
129        // Build the effective prompt, adding JSON schema instructions when needed
130        let effective_prompt = if let Some(schema_obj) = schema {
131            let schema_json =
132                serde_json::to_string_pretty(&serde_json::Value::Object(schema_obj.clone()))
133                    .unwrap_or_else(|_| "{}".to_string());
134            format!(
135                "{prompt}\n\n\
136                 IMPORTANT: You MUST respond with valid JSON that conforms to this schema:\n\
137                 ```json\n{schema_json}\n```\n\
138                 Respond ONLY with the JSON object, no explanation or markdown."
139            )
140        } else {
141            prompt.to_string()
142        };
143
144        // Create provider
145        let api_key_ref = self.api_key.as_deref();
146        let provider: Box<dyn Provider> = match providers::create_provider_with_options(
147            &self.default_provider,
148            api_key_ref,
149            &self.provider_runtime_options,
150        ) {
151            Ok(p) => p,
152            Err(e) => {
153                return Ok(ToolResult {
154                    success: false,
155                    output: String::new(),
156                    error: Some(format!("Failed to create provider: {e}")),
157                });
158            }
159        };
160
161        // Make the LLM call (no tools, no agent loop)
162        let response = match provider
163            .simple_chat(&effective_prompt, model, temperature)
164            .await
165        {
166            Ok(text) => text,
167            Err(e) => {
168                return Ok(ToolResult {
169                    success: false,
170                    output: String::new(),
171                    error: Some(format!("LLM call failed: {e}")),
172                });
173            }
174        };
175
176        // If schema was provided, validate the response
177        if let Some(schema_obj) = schema {
178            let schema_value = serde_json::Value::Object(schema_obj.clone());
179            match validate_json_response(&response, &schema_value) {
180                Ok(validated_json) => Ok(ToolResult {
181                    success: true,
182                    output: validated_json,
183                    error: None,
184                }),
185                Err(validation_error) => Ok(ToolResult {
186                    success: false,
187                    output: response,
188                    error: Some(format!("Schema validation failed: {validation_error}")),
189                }),
190            }
191        } else {
192            Ok(ToolResult {
193                success: true,
194                output: response,
195                error: None,
196            })
197        }
198    }
199}
200
201/// Validate a JSON response string against a JSON Schema value.
202///
203/// Performs lightweight validation: parses the response as JSON, checks that
204/// required fields exist, and verifies basic type constraints (string, number,
205/// integer, boolean, array, object) for each declared property.
206fn validate_json_response(response: &str, schema: &serde_json::Value) -> Result<String, String> {
207    // Strip markdown code fences if the LLM wrapped the response
208    let trimmed = response.trim();
209    let json_str = if trimmed.starts_with("```") {
210        trimmed
211            .trim_start_matches("```json")
212            .trim_start_matches("```")
213            .trim_end_matches("```")
214            .trim()
215    } else {
216        trimmed
217    };
218
219    // Parse as JSON
220    let parsed: serde_json::Value =
221        serde_json::from_str(json_str).map_err(|e| format!("Invalid JSON: {e}"))?;
222
223    // Check required fields
224    if let Some(required) = schema.get("required").and_then(|v| v.as_array()) {
225        for req in required {
226            if let Some(field_name) = req.as_str() {
227                if parsed.get(field_name).is_none() {
228                    return Err(format!("Missing required field: {field_name}"));
229                }
230            }
231        }
232    }
233
234    // Check property types
235    if let Some(properties) = schema.get("properties").and_then(|v| v.as_object()) {
236        for (prop_name, prop_schema) in properties {
237            if let Some(value) = parsed.get(prop_name) {
238                if let Some(expected_type) = prop_schema.get("type").and_then(|t| t.as_str()) {
239                    if !type_matches(value, expected_type) {
240                        return Err(format!(
241                            "Field '{prop_name}' has wrong type: expected {expected_type}, \
242                             got {}",
243                            json_type_name(value)
244                        ));
245                    }
246                }
247            }
248        }
249    }
250
251    // Return the cleaned, re-serialized JSON
252    serde_json::to_string(&parsed).map_err(|e| format!("JSON serialization error: {e}"))
253}
254
255/// Check whether a JSON value matches an expected JSON Schema type string.
256fn type_matches(value: &serde_json::Value, expected: &str) -> bool {
257    match expected {
258        "string" => value.is_string(),
259        "number" => value.is_number(),
260        "integer" => value.is_i64() || value.is_u64(),
261        "boolean" => value.is_boolean(),
262        "array" => value.is_array(),
263        "object" => value.is_object(),
264        "null" => value.is_null(),
265        _ => true, // Unknown type — accept
266    }
267}
268
269/// Return a human-readable type name for a JSON value.
270fn json_type_name(value: &serde_json::Value) -> &'static str {
271    match value {
272        serde_json::Value::Null => "null",
273        serde_json::Value::Bool(_) => "boolean",
274        serde_json::Value::Number(_) => "number",
275        serde_json::Value::String(_) => "string",
276        serde_json::Value::Array(_) => "array",
277        serde_json::Value::Object(_) => "object",
278    }
279}
280
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tool for `provider` using the fixed defaults shared by these tests.
    fn tool_for(provider: &str) -> LlmTaskTool {
        LlmTaskTool::new(
            Arc::new(SecurityPolicy::default()),
            provider.to_string(),
            "test-model".to_string(),
            0.7,
            None,
            providers::ProviderRuntimeOptions::default(),
        )
    }

    // ── Schema validation tests ──────────────────────────────────────

    #[test]
    fn validate_valid_json_against_schema() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "age": { "type": "integer" }
            },
            "required": ["name", "age"]
        });

        let cleaned = validate_json_response(r#"{"name": "Alice", "age": 30}"#, &schema)
            .expect("a conforming response should validate");

        let round_trip: serde_json::Value = serde_json::from_str(&cleaned).unwrap();
        assert_eq!(round_trip["name"], "Alice");
        assert_eq!(round_trip["age"], 30);
    }

    #[test]
    fn validate_missing_required_field() {
        let schema = json!({
            "type": "object",
            "properties": {
                "title": { "type": "string" },
                "score": { "type": "number" }
            },
            "required": ["title", "score"]
        });

        let message = validate_json_response(r#"{"title": "Test"}"#, &schema)
            .expect_err("a response lacking `score` must be rejected");
        assert!(message.contains("Missing required field: score"));
    }

    #[test]
    fn validate_wrong_type() {
        let schema = json!({
            "type": "object",
            "properties": {
                "count": { "type": "integer" }
            },
            "required": ["count"]
        });

        let message = validate_json_response(r#"{"count": "not_a_number"}"#, &schema)
            .expect_err("a string is not an integer");
        assert!(message.contains("wrong type"));
    }

    #[test]
    fn validate_strips_markdown_code_fences() {
        let schema = json!({
            "type": "object",
            "properties": {
                "result": { "type": "string" }
            },
            "required": ["result"]
        });

        let fenced = "```json\n{\"result\": \"ok\"}\n```";
        assert!(validate_json_response(fenced, &schema).is_ok());
    }

    #[test]
    fn validate_invalid_json() {
        let schema = json!({ "type": "object" });

        let message = validate_json_response("this is not json at all", &schema)
            .expect_err("plain prose is not JSON");
        assert!(message.contains("Invalid JSON"));
    }

    #[test]
    fn validate_optional_fields_accepted() {
        let schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "bio": { "type": "string" }
            },
            "required": ["name"]
        });

        // `bio` is declared but not required, so omitting it is fine.
        assert!(validate_json_response(r#"{"name": "Bob"}"#, &schema).is_ok());
    }

    #[test]
    fn validate_all_type_checks() {
        let cases = [
            (json!("hello"), "string", true),
            (json!(42), "string", false),
            (json!(2.72), "number", true),
            (json!(42), "number", true),
            (json!("42"), "number", false),
            (json!(42), "integer", true),
            (json!(2.72), "integer", false),
            (json!(true), "boolean", true),
            (json!(1), "boolean", false),
            (json!([1, 2]), "array", true),
            (json!({}), "array", false),
            (json!({}), "object", true),
            (json!([]), "object", false),
            (json!(null), "null", true),
            // Unknown types are accepted.
            (json!("anything"), "custom_type", true),
        ];

        for (value, expected_type, should_match) in &cases {
            assert_eq!(
                type_matches(value, expected_type),
                *should_match,
                "type_matches({value}, {expected_type:?})"
            );
        }
    }

    // ── Tool trait tests ─────────────────────────────────────────────

    #[test]
    fn tool_metadata() {
        let tool = tool_for("openrouter");

        assert_eq!(tool.name(), "llm_task");
        assert!(tool.description().contains("LLM"));

        let schema = tool.parameters_schema();
        assert_eq!(schema["type"], "object");
        for param in ["prompt", "schema", "model", "temperature"].iter() {
            assert!(
                schema["properties"][*param].is_object(),
                "missing parameter {param}"
            );
        }

        let required = schema["required"].as_array().unwrap();
        assert_eq!(required, &vec![json!("prompt")]);
    }

    #[tokio::test]
    async fn execute_missing_prompt_returns_error() {
        let outcome = tool_for("openrouter").execute(json!({})).await.unwrap();

        assert!(!outcome.success);
        assert!(outcome.error.as_deref().unwrap().contains("prompt"));
    }

    #[tokio::test]
    async fn execute_empty_prompt_returns_error() {
        let outcome = tool_for("openrouter")
            .execute(json!({"prompt": "  "}))
            .await
            .unwrap();

        assert!(!outcome.success);
        assert!(outcome.error.as_deref().unwrap().contains("prompt"));
    }

    #[tokio::test]
    async fn execute_with_invalid_provider_returns_error() {
        let outcome = tool_for("nonexistent_provider_xyz")
            .execute(json!({"prompt": "Hello world"}))
            .await
            .unwrap();

        assert!(!outcome.success);
        assert!(outcome.error.as_deref().unwrap().contains("provider"));
    }
}