Skip to main content

lash_llm_tools/
lib.rs

1use std::sync::Arc;
2
3use async_trait::async_trait;
4use lash_core::plugin::{PluginError, PluginFactory, PluginSessionContext};
5use lash_core::{
6    DirectJsonSchema, DirectMessage, DirectOutputSpec, DirectPart, DirectRequest, DirectRole,
7    PluginSpec, PluginSpecFactory, ToolCall, ToolContext, ToolDefinition, ToolProvider, ToolResult,
8    ToolScheduling,
9};
10use lash_tool_support::{StaticToolExecute, StaticToolProvider};
11use serde_json::{Value, json};
12
/// Plugin factory for the `llm_tools` plugin.
///
/// Holds two optional overrides that `build` clones and hands to
/// `llm_query_provider`. Both start out as `None` when the factory is created
/// through `Default`.
#[derive(Clone, Debug, Default)]
pub struct LlmToolsPluginFactory {
    /// Optional model name override; set by `with_model`.
    model: Option<String>,
    /// Optional model variant override; set by `with_model` or `with_model_variant`.
    model_variant: Option<String>,
}
18
19impl LlmToolsPluginFactory {
20    pub fn with_model(mut self, model: impl Into<String>, model_variant: Option<String>) -> Self {
21        self.model = Some(model.into());
22        self.model_variant = model_variant;
23        self
24    }
25
26    pub fn with_model_variant(mut self, model_variant: impl Into<String>) -> Self {
27        self.model_variant = Some(model_variant.into());
28        self
29    }
30}
31
32impl PluginFactory for LlmToolsPluginFactory {
33    fn id(&self) -> &'static str {
34        "llm_tools"
35    }
36
37    fn build(
38        &self,
39        ctx: &PluginSessionContext,
40    ) -> Result<Arc<dyn lash_core::SessionPlugin>, PluginError> {
41        let provider: Arc<dyn ToolProvider> = Arc::new(llm_query_provider(
42            self.model.clone(),
43            self.model_variant.clone(),
44        ));
45
46        PluginSpecFactory::new(
47            "llm_tools",
48            Arc::new(move |_ctx| Ok(PluginSpec::new().with_tool_provider(Arc::clone(&provider)))),
49        )
50        .build(ctx)
51    }
52}
53
/// Executor behind the `llm_query` tool.
///
/// Both fields are optional overrides: when one is `None`, `llm_query` falls
/// back to the corresponding value of the current session's model.
#[derive(Clone, Debug)]
pub struct LlmToolsProvider {
    /// Model to use instead of the session's model, if set.
    model: Option<String>,
    /// Model variant to use instead of the session's variant, if set.
    model_variant: Option<String>,
}
58
59/// Build the `llm_query` tool provider for the given optional model override.
60pub fn llm_query_provider(
61    model: Option<String>,
62    model_variant: Option<String>,
63) -> StaticToolProvider<LlmToolsProvider> {
64    StaticToolProvider::new(
65        vec![llm_query_tool_definition()],
66        LlmToolsProvider {
67            model,
68            model_variant,
69        },
70    )
71}
72
73impl LlmToolsProvider {
74    async fn llm_query(&self, args: &Value, context: &ToolContext<'_>) -> Result<Value, String> {
75        let task = required_string(args, "task")?;
76        let inputs = args.get("inputs").cloned().unwrap_or(Value::Null);
77        let output_schema = parse_output_schema(args.get("output"))?;
78        let session_model = context
79            .sessions()
80            .model()
81            .await
82            .map_err(|err| format!("failed to read current session model: {err}"))?;
83        let model = self.model.clone().unwrap_or(session_model.model);
84        let model_variant = self.model_variant.clone().or(session_model.model_variant);
85        let response_schema = llm_query_response_schema(output_schema.as_ref());
86        let prompt = llm_query_prompt(&task, &inputs, output_schema.as_ref());
87
88        let output = DirectOutputSpec::JsonSchema(DirectJsonSchema {
89            name: "llm_query_result".to_string(),
90            schema: response_schema.clone(),
91            strict: true,
92        });
93
94        let completion = context
95            .direct_completions()
96            .complete(
97                DirectRequest {
98                    model,
99                    model_variant,
100                    messages: vec![
101                        DirectMessage {
102                            role: DirectRole::System,
103                            parts: vec![DirectPart::Text(
104                                "Answer the focused sub-question using only the supplied task and inputs. Return only JSON matching the requested result wrapper. Use kind=\"error\" with a concise error only when the task cannot be answered from the supplied inputs."
105                                    .to_string(),
106                            )],
107                        },
108                        DirectMessage {
109                            role: DirectRole::User,
110                            parts: vec![DirectPart::Text(prompt)],
111                        },
112                    ],
113                    attachments: Vec::new(),
114                    output,
115                    stream_events: None,
116                    generation: lash_core::GenerationOptions::default(),
117                    session_id: Some(format!("{}-llm-query", context.session_id())),
118                    caused_by: None,
119                    replay: None,
120                },
121                "llm_query",
122            )
123            .await
124            .map_err(|err| format!("llm_query failed: {err}"))?;
125
126        parse_llm_query_result(&completion.text, &response_schema)
127    }
128}
129
130#[async_trait]
131impl StaticToolExecute for LlmToolsProvider {
132    async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
133        let result = match call.name {
134            "llm_query" => self.llm_query(call.args, call.context).await,
135            _ => Err(format!("Unknown tool: {}", call.name)),
136        };
137        finalise_tool_result(result)
138    }
139}
140
141pub fn llm_query_tool_definition() -> ToolDefinition {
142    tool_definition(
143        "llm_query",
144        "Run a one-shot LLM prompt over supplied data and return its result. The `task` plus everything in `inputs` is rendered into that single prompt; the call cannot use tools, inspect files, or gather more context beyond what you pass it. Use this for extracting information, classification, summarization, judging, or transformation over data already in your variables. `inputs` can be any structured value. `output` is optional and defaults to a string; when present, it requests structured output using record descriptors or `Type { ... }` literals.",
145        llm_query_input_schema(),
146        vec![
147            r#"summary = await llm.query({ task: "Summarize the supplied notes in three bullets", inputs: { notes: notes } })?"#.into(),
148            r#"claims = await llm.query({ task: "Extract the key claim from each supplied chunk", inputs: { chunks: chunks }, output: { claims: "list[str]" } })?"#.into(),
149        ],
150        ToolScheduling::Parallel,
151    )
152    .with_lashlang_binding(lash_core::LashlangToolBinding::new(["llm"], "query"))
153    .with_output_from_input_schema("output", Some(json!({ "type": "string" })))
154}
155
156pub fn parse_output_schema(value: Option<&Value>) -> Result<Option<Value>, String> {
157    let Some(value) = value else {
158        return Ok(None);
159    };
160    if value.is_null() {
161        return Ok(None);
162    }
163    let output = value.as_object().ok_or_else(|| {
164        "invalid `output`: expected a record describing the typed shape".to_string()
165    })?;
166    if output.is_empty() {
167        return Err("at least one output field is required".to_string());
168    }
169
170    if output.len() == 1
171        && let Some(schema) = output.get(lashlang::LASH_TYPE_KEY)
172    {
173        validate_schema(schema)?;
174        return Ok(Some(schema.clone()));
175    }
176
177    let mut properties = serde_json::Map::new();
178    let mut required = Vec::new();
179    for (name, descriptor) in output {
180        let type_str = descriptor
181            .as_str()
182            .ok_or_else(|| format!("field `{name}`: type descriptor must be a string"))?;
183        properties.insert(name.clone(), type_descriptor_to_json_schema(type_str)?);
184        required.push(Value::String(name.clone()));
185    }
186    Ok(Some(json!({
187        "type": "object",
188        "properties": properties,
189        "required": required,
190        "additionalProperties": false,
191    })))
192}
193
194fn llm_query_input_schema() -> Value {
195    json!({
196        "type": "object",
197        "properties": {
198            "task": { "type": "string" },
199            "inputs": {},
200            "output": { "type": "object", "additionalProperties": true }
201        },
202        "required": ["task"],
203        "additionalProperties": false
204    })
205}
206
207fn llm_query_prompt(task: &str, inputs: &Value, output_schema: Option<&Value>) -> String {
208    let mut sections = Vec::new();
209    sections.push(format!("Task:\n{task}"));
210    sections.push(format!(
211        "Inputs:\n```json\n{}\n```",
212        serde_json::to_string_pretty(inputs).unwrap_or_else(|_| inputs.to_string())
213    ));
214    if let Some(schema) = output_schema {
215        sections.push(format!(
216            "Return `kind=\"value\"` with `value` matching this JSON Schema, or `kind=\"error\"` with a concise error if the task cannot be answered from the supplied inputs:\n```json\n{}\n```",
217            serde_json::to_string_pretty(schema).unwrap_or_else(|_| schema.to_string())
218        ));
219    } else {
220        sections.push("Return `kind=\"value\"` with a concise string `value`, or `kind=\"error\"` with a concise error if the task cannot be answered from the supplied inputs.".to_string());
221    }
222    sections.join("\n\n")
223}
224
225fn llm_query_response_schema(output_schema: Option<&Value>) -> Value {
226    let value_schema = output_schema
227        .cloned()
228        .unwrap_or_else(|| json!({"type": "string"}));
229    json!({
230        "type": "object",
231        "additionalProperties": false,
232        "required": ["kind", "value", "error"],
233        "properties": {
234            "kind": { "type": "string", "enum": ["value", "error"] },
235            "value": {
236                "anyOf": [
237                    value_schema,
238                    { "type": "null" }
239                ]
240            },
241            "error": {
242                "anyOf": [
243                    { "type": "string" },
244                    { "type": "null" }
245                ]
246            }
247        }
248    })
249}
250
251fn parse_llm_query_result(text: &str, schema: &Value) -> Result<Value, String> {
252    let trimmed = text.trim();
253    if trimmed.is_empty() {
254        return Err("llm_query returned empty output".to_string());
255    }
256    let value = serde_json::from_str::<Value>(trimmed).or_else(|err| {
257        let Some(start) = trimmed.find(['{', '[', '"']) else {
258            return Err(format!("llm_query returned non-JSON output: {err}"));
259        };
260        let end = trimmed
261            .rfind(['}', ']', '"'])
262            .ok_or_else(|| format!("llm_query returned malformed JSON output: {err}"))?;
263        if end < start {
264            return Err(format!("llm_query returned malformed JSON output: {err}"));
265        }
266        serde_json::from_str::<Value>(&trimmed[start..=end])
267            .map_err(|parse_err| format!("llm_query returned malformed JSON output: {parse_err}"))
268    })?;
269    let compiled = jsonschema::JSONSchema::compile(schema)
270        .map_err(|err| format!("llm_query output schema is invalid: {err}"))?;
271    if let Err(errors) = compiled.validate(&value) {
272        let message = errors
273            .map(|err| err.to_string())
274            .collect::<Vec<_>>()
275            .join("; ");
276        return Err(format!("llm_query output did not match schema: {message}"));
277    }
278    match value.get("kind").and_then(Value::as_str) {
279        Some("value") => value
280            .get("value")
281            .cloned()
282            .filter(|value| !value.is_null())
283            .ok_or_else(|| "llm_query returned value result without value".to_string()),
284        Some("error") => Err(value
285            .get("error")
286            .and_then(Value::as_str)
287            .map(str::trim)
288            .filter(|message| !message.is_empty())
289            .unwrap_or("llm_query returned an error")
290            .to_string()),
291        Some(other) => Err(format!("llm_query returned unknown result kind `{other}`")),
292        None => Err("llm_query returned result without kind field".to_string()),
293    }
294}
295
296fn tool_definition(
297    name: &str,
298    description: impl Into<String>,
299    input_schema: Value,
300    examples: Vec<String>,
301    execution_mode: ToolScheduling,
302) -> ToolDefinition {
303    ToolDefinition::raw(
304        format!("tool:{name}"),
305        name,
306        description,
307        input_schema,
308        json!({ "type": "object", "additionalProperties": true }),
309    )
310    .with_examples(examples)
311    .with_scheduling(execution_mode)
312}
313
314fn required_string(args: &Value, key: &str) -> Result<String, String> {
315    args.get(key)
316        .and_then(Value::as_str)
317        .map(str::trim)
318        .filter(|value| !value.is_empty())
319        .map(ToOwned::to_owned)
320        .ok_or_else(|| format!("missing required parameter: {key}"))
321}
322
323fn validate_schema(schema: &Value) -> Result<(), String> {
324    let object = schema
325        .as_object()
326        .ok_or_else(|| "Type schema must be a JSON object".to_string())?;
327    let kind = object
328        .get("type")
329        .and_then(Value::as_str)
330        .ok_or_else(|| "Type schema missing `type` field".to_string())?;
331    match kind {
332        "object" | "array" | "string" | "integer" | "number" | "boolean" => Ok(()),
333        other => Err(format!("unsupported Type schema kind `{other}`")),
334    }
335}
336
337fn type_descriptor_to_json_schema(descriptor: &str) -> Result<Value, String> {
338    let scalar = |ty: &str| -> Result<Value, String> {
339        match ty {
340            "str" | "string" => Ok(json!({"type": "string"})),
341            "int" | "integer" => Ok(json!({"type": "integer"})),
342            "float" | "number" => Ok(json!({"type": "number"})),
343            "bool" | "boolean" => Ok(json!({"type": "boolean"})),
344            "record" | "dict" | "object" => {
345                Ok(json!({"type": "object", "additionalProperties": true}))
346            }
347            other => Err(format!("unknown scalar type `{other}`")),
348        }
349    };
350    let trimmed = descriptor.trim();
351    if let Some(inner) = trimmed
352        .strip_prefix("list[")
353        .and_then(|rest| rest.strip_suffix(']'))
354    {
355        return Ok(json!({
356            "type": "array",
357            "items": scalar(inner.trim())?,
358        }));
359    }
360    scalar(trimmed)
361}
362
363fn finalise_tool_result(result: Result<Value, String>) -> ToolResult {
364    match result {
365        Ok(value) => ToolResult::ok(value),
366        Err(err) => ToolResult::err(json!(err)),
367    }
368}
369
// Unit tests for the `llm_query` tool: tool manifest, `output` schema parsing,
// and end-to-end execution against a mocked host and direct-completion client.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    use async_trait::async_trait;
    use lash_core::plugin::runtime_host::{
        SessionGraphService, SessionLifecycleService, SessionStateService,
    };
    use lash_core::plugin::{PluginError, SessionHandle};
    use lash_core::runtime::RuntimeSessionState;
    use lash_core::{SessionCreateRequest, SessionSnapshot, ToolCall};

    /// Builds a model spec for tests with the given name and optional variant.
    fn model_spec(model: &str, variant: Option<&str>) -> lash_core::ModelSpec {
        lash_core::ModelSpec::from_token_limits(model, variant.map(str::to_string), 200_000, None)
            .expect("valid test model spec")
    }

    /// Test double used as the session host.
    ///
    /// Serves snapshots of a fixed session state, records every direct
    /// completion request together with its usage source, and answers each
    /// request with `response_text`.
    #[derive(Default)]
    struct DirectCompletionManager {
        snapshot: RuntimeSessionState,
        requests: Mutex<Vec<(lash_core::DirectRequest, String)>>,
        response_text: String,
    }

    #[async_trait]
    impl SessionStateService for DirectCompletionManager {
        // Both snapshot methods return the same fixed state.
        async fn snapshot_current(&self) -> Result<SessionSnapshot, PluginError> {
            Ok(self.snapshot.to_snapshot())
        }

        async fn snapshot_session(
            &self,
            _session_id: &str,
        ) -> Result<SessionSnapshot, PluginError> {
            Ok(self.snapshot.to_snapshot())
        }
        // The mock exposes an empty tool catalog.
        async fn tool_catalog(
            &self,
            _session_id: &str,
        ) -> Result<Vec<serde_json::Value>, PluginError> {
            Ok(Vec::new())
        }
    }

    #[async_trait]
    impl SessionLifecycleService for DirectCompletionManager {
        // Creating a session always fails in this mock.
        async fn create_session(
            &self,
            _request: SessionCreateRequest,
        ) -> Result<SessionHandle, PluginError> {
            Err(PluginError::Session("not used".to_string()))
        }

        async fn close_session(&self, _session_id: &str) -> Result<(), PluginError> {
            Ok(())
        }
    }

    // No methods are overridden for the graph service.
    #[async_trait]
    impl SessionGraphService for DirectCompletionManager {}

    /// Builds a tool context whose host is `manager` and whose direct
    /// completion client records each request on the manager before replying
    /// with the manager's canned response text.
    fn direct_completion_context(
        manager: Arc<DirectCompletionManager>,
    ) -> lash_core::ToolContext<'static> {
        let completions = lash_core::DirectCompletionClient::from_fn({
            let manager = Arc::clone(&manager);
            move |request, usage_source| {
                manager
                    .requests
                    .lock()
                    .expect("requests")
                    .push((request, usage_source));
                Ok(lash_core::DirectCompletion {
                    text: manager.response_text.clone(),
                    usage: lash_core::TokenUsage::default(),
                })
            }
        });
        lash_core::testing::mock_tool_context_with_host_and_direct_completions(manager, completions)
    }

    /// The provider advertises exactly one tool, `llm_query`, and its
    /// effective availability is `Showcased`.
    #[test]
    fn llm_definitions_include_llm_query_only() {
        let provider = llm_query_provider(None, None);
        let manifests = provider.tool_manifests();
        let names = manifests
            .iter()
            .map(|tool| tool.name.clone())
            .collect::<Vec<_>>();
        assert_eq!(names, vec!["llm_query"]);
        assert_eq!(
            manifests[0].effective_availability(),
            lash_core::ToolAvailability::Showcased
        );
    }

    /// Record descriptors map `str`, `int` and `list[str]` to the string,
    /// integer and array schema types.
    #[test]
    fn output_schema_supports_scalars_and_lists() {
        let schema = parse_output_schema(Some(&json!({
            "answer": "str",
            "count": "int",
            "items": "list[str]"
        })))
        .expect("schema")
        .expect("present");
        assert_eq!(schema["properties"]["answer"]["type"], json!("string"));
        assert_eq!(schema["properties"]["count"]["type"], json!("integer"));
        assert_eq!(schema["properties"]["items"]["type"], json!("array"));
    }

    /// A schema wrapped under `LASH_TYPE_KEY` is returned unchanged.
    #[test]
    fn output_schema_passes_through_lash_type_wrapper() {
        let inner_schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "tags": { "type": "array", "items": { "type": "string" } },
                "status": { "type": "string", "enum": ["ok", "err"] }
            },
            "required": ["name", "tags", "status"],
            "additionalProperties": false
        });
        let wrapped = json!({ lashlang::LASH_TYPE_KEY: inner_schema.clone() });
        let schema = parse_output_schema(Some(&wrapped))
            .expect("schema")
            .expect("present");
        assert_eq!(schema, inner_schema);
    }

    /// A wrapped schema without a `type` field is rejected with an error that
    /// mentions `type`.
    #[test]
    fn output_schema_rejects_lash_type_without_type_field() {
        let wrapped = json!({ lashlang::LASH_TYPE_KEY: {"properties": {}} });
        let err = parse_output_schema(Some(&wrapped)).expect_err("missing type");
        assert!(err.contains("type"), "error: {err}");
    }

    /// A wrapped schema may have `array` as its top-level type.
    #[test]
    fn output_schema_accepts_array_top_level_type() {
        let wrapped = json!({
            lashlang::LASH_TYPE_KEY: {
                "type": "array",
                "items": {"type": "string"}
            }
        });
        let schema = parse_output_schema(Some(&wrapped))
            .expect("schema")
            .expect("present");
        assert_eq!(schema["type"], json!("array"));
    }

    /// Without overrides, the request uses the session's model and variant,
    /// asks for JSON-schema output, carries the task and inputs in its prompt,
    /// and the tool result is the unwrapped `value`.
    #[tokio::test]
    async fn llm_query_uses_current_policy_and_direct_completion() {
        let manager = Arc::new(DirectCompletionManager {
            snapshot: RuntimeSessionState {
                policy: lash_core::SessionPolicy {
                    model: model_spec("root-model", Some("fast")),
                    ..lash_core::SessionPolicy::default()
                },
                ..RuntimeSessionState::default()
            },
            requests: Mutex::new(Vec::new()),
            response_text:
                r#"{"kind":"value","value":{"root_cause":"missing config","confidence":0.8},"error":null}"#
                    .to_string(),
        });
        let provider = llm_query_provider(None, None);
        let context = direct_completion_context(manager.clone());

        let args = json!({
            "task": "extract root cause",
            "inputs": { "log": "failed" },
            "output": { "root_cause": "str", "confidence": "float" }
        });
        let result = provider
            .execute(ToolCall {
                name: "llm_query",
                args: &args,
                context: &context,
                progress: None,
            })
            .await;

        assert!(result.is_success(), "{:?}", result.value_for_projection());
        assert_eq!(
            result.value_for_projection()["root_cause"],
            json!("missing config")
        );
        assert_eq!(result.value_for_projection()["confidence"], json!(0.8));

        let requests = manager.requests.lock().expect("requests");
        assert_eq!(requests.len(), 1);
        let (request, usage_source) = &requests[0];
        assert_eq!(usage_source, "llm_query");
        assert_eq!(request.model, "root-model");
        assert_eq!(request.model_variant.as_deref(), Some("fast"));
        assert!(matches!(
            request.output,
            lash_core::DirectOutputSpec::JsonSchema(_)
        ));
        // Join the text parts of all messages so the task and inputs can be
        // searched for regardless of which message carries them.
        let prompt = request
            .messages
            .iter()
            .flat_map(|message| message.parts.iter())
            .filter_map(|part| match part {
                lash_core::DirectPart::Text(text) => Some(text.as_str()),
                lash_core::DirectPart::Image(_) => None,
            })
            .collect::<Vec<_>>()
            .join("\n");
        assert!(prompt.contains("extract root cause"));
        assert!(prompt.contains("\"log\": \"failed\""));
    }

    /// A configured model and variant take precedence over the session's.
    #[tokio::test]
    async fn llm_query_uses_configured_model_override() {
        let manager = Arc::new(DirectCompletionManager {
            snapshot: RuntimeSessionState {
                policy: lash_core::SessionPolicy {
                    model: model_spec("root-model", Some("medium")),
                    ..lash_core::SessionPolicy::default()
                },
                ..RuntimeSessionState::default()
            },
            requests: Mutex::new(Vec::new()),
            response_text: r#"{"kind":"value","value":"done","error":null}"#.to_string(),
        });
        let provider = llm_query_provider(Some("gpt-5.5".to_string()), Some("low".to_string()));
        let context = direct_completion_context(manager.clone());

        let args = json!({ "task": "answer directly" });
        let result = provider
            .execute(ToolCall {
                name: "llm_query",
                args: &args,
                context: &context,
                progress: None,
            })
            .await;

        assert!(result.is_success(), "{:?}", result.value_for_projection());
        let requests = manager.requests.lock().expect("requests");
        assert_eq!(requests.len(), 1);
        let (request, usage_source) = &requests[0];
        assert_eq!(usage_source, "llm_query");
        assert_eq!(request.model, "gpt-5.5");
        assert_eq!(request.model_variant.as_deref(), Some("low"));
    }

    /// A `kind: "error"` reply makes the tool call fail with the reported
    /// error text as its value.
    #[tokio::test]
    async fn llm_query_error_result_fails_tool_call() {
        let manager = Arc::new(DirectCompletionManager {
            snapshot: RuntimeSessionState {
                policy: lash_core::SessionPolicy::default(),
                ..RuntimeSessionState::default()
            },
            requests: Mutex::new(Vec::new()),
            response_text: r#"{"kind":"error","value":null,"error":"missing required evidence"}"#
                .to_string(),
        });
        let provider = llm_query_provider(None, None);
        let context = direct_completion_context(manager);

        let args = json!({ "task": "answer from missing evidence" });
        let result = provider
            .execute(ToolCall {
                name: "llm_query",
                args: &args,
                context: &context,
                progress: None,
            })
            .await;

        assert!(!result.is_success());
        assert_eq!(
            result.value_for_projection(),
            json!("missing required evidence")
        );
    }
}
645            result.value_for_projection(),
646            json!("missing required evidence")
647        );
648    }
649}