1use std::sync::Arc;
2
3use async_trait::async_trait;
4use lash_core::plugin::{PluginError, PluginFactory, PluginSessionContext};
5use lash_core::{
6 DirectJsonSchema, DirectMessage, DirectOutputSpec, DirectPart, DirectRequest, DirectRole,
7 PluginSpec, PluginSpecFactory, ToolCall, ToolContext, ToolDefinition, ToolProvider, ToolResult,
8 ToolScheduling,
9};
10use lash_tool_support::{StaticToolExecute, StaticToolProvider};
11use serde_json::{Value, json};
12
/// Factory for the `llm_tools` plugin.
///
/// Both fields are optional overrides that are cloned into the tool provider
/// each time the plugin is built.
#[derive(Debug, Default, Clone)]
pub struct LlmToolsPluginFactory {
    /// Model name forwarded to the provider; `None` leaves the choice to the provider.
    model: Option<String>,
    /// Model variant forwarded to the provider; `None` leaves the choice to the provider.
    model_variant: Option<String>,
}
18
19impl LlmToolsPluginFactory {
20 pub fn with_model(mut self, model: impl Into<String>, model_variant: Option<String>) -> Self {
21 self.model = Some(model.into());
22 self.model_variant = model_variant;
23 self
24 }
25
26 pub fn with_model_variant(mut self, model_variant: impl Into<String>) -> Self {
27 self.model_variant = Some(model_variant.into());
28 self
29 }
30}
31
32impl PluginFactory for LlmToolsPluginFactory {
33 fn id(&self) -> &'static str {
34 "llm_tools"
35 }
36
37 fn build(
38 &self,
39 ctx: &PluginSessionContext,
40 ) -> Result<Arc<dyn lash_core::SessionPlugin>, PluginError> {
41 let provider: Arc<dyn ToolProvider> = Arc::new(llm_query_provider(
42 self.model.clone(),
43 self.model_variant.clone(),
44 ));
45
46 PluginSpecFactory::new(
47 "llm_tools",
48 Arc::new(move |_ctx| Ok(PluginSpec::new().with_tool_provider(Arc::clone(&provider)))),
49 )
50 .build(ctx)
51 }
52}
53
/// Executor behind the `llm_query` tool.
///
/// Holds the optional model overrides that were configured when the provider
/// was created; a `None` field means the value is taken from the current
/// session at call time.
#[derive(Clone, Debug)]
pub struct LlmToolsProvider {
    /// Model name override applied to every `llm_query` call.
    model: Option<String>,
    /// Model variant override applied to every `llm_query` call.
    model_variant: Option<String>,
}
58
59pub fn llm_query_provider(
61 model: Option<String>,
62 model_variant: Option<String>,
63) -> StaticToolProvider<LlmToolsProvider> {
64 StaticToolProvider::new(
65 vec![llm_query_tool_definition()],
66 LlmToolsProvider {
67 model,
68 model_variant,
69 },
70 )
71}
72
73impl LlmToolsProvider {
74 async fn llm_query(&self, args: &Value, context: &ToolContext<'_>) -> Result<Value, String> {
75 let task = required_string(args, "task")?;
76 let inputs = args.get("inputs").cloned().unwrap_or(Value::Null);
77 let output_schema = parse_output_schema(args.get("output"))?;
78 let session_model = context
79 .sessions()
80 .model()
81 .await
82 .map_err(|err| format!("failed to read current session model: {err}"))?;
83 let model = self.model.clone().unwrap_or(session_model.model);
84 let model_variant = self.model_variant.clone().or(session_model.model_variant);
85 let response_schema = llm_query_response_schema(output_schema.as_ref());
86 let prompt = llm_query_prompt(&task, &inputs, output_schema.as_ref());
87
88 let output = DirectOutputSpec::JsonSchema(DirectJsonSchema {
89 name: "llm_query_result".to_string(),
90 schema: response_schema.clone(),
91 strict: true,
92 });
93
94 let completion = context
95 .direct_completions()
96 .complete(
97 DirectRequest {
98 model,
99 model_variant,
100 messages: vec![
101 DirectMessage {
102 role: DirectRole::System,
103 parts: vec![DirectPart::Text(
104 "Answer the focused sub-question using only the supplied task and inputs. Return only JSON matching the requested result wrapper. Use kind=\"error\" with a concise error only when the task cannot be answered from the supplied inputs."
105 .to_string(),
106 )],
107 },
108 DirectMessage {
109 role: DirectRole::User,
110 parts: vec![DirectPart::Text(prompt)],
111 },
112 ],
113 attachments: Vec::new(),
114 output,
115 stream_events: None,
116 generation: lash_core::GenerationOptions::default(),
117 session_id: Some(format!("{}-llm-query", context.session_id())),
118 caused_by: None,
119 replay: None,
120 },
121 "llm_query",
122 )
123 .await
124 .map_err(|err| format!("llm_query failed: {err}"))?;
125
126 parse_llm_query_result(&completion.text, &response_schema)
127 }
128}
129
130#[async_trait]
131impl StaticToolExecute for LlmToolsProvider {
132 async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
133 let result = match call.name {
134 "llm_query" => self.llm_query(call.args, call.context).await,
135 _ => Err(format!("Unknown tool: {}", call.name)),
136 };
137 finalise_tool_result(result)
138 }
139}
140
141pub fn llm_query_tool_definition() -> ToolDefinition {
142 tool_definition(
143 "llm_query",
144 "Run a one-shot LLM prompt over supplied data and return its result. The `task` plus everything in `inputs` is rendered into that single prompt; the call cannot use tools, inspect files, or gather more context beyond what you pass it. Use this for extracting information, classification, summarization, judging, or transformation over data already in your variables. `inputs` can be any structured value. `output` is optional and defaults to a string; when present, it requests structured output using record descriptors or `Type { ... }` literals.",
145 llm_query_input_schema(),
146 vec![
147 r#"summary = await llm.query({ task: "Summarize the supplied notes in three bullets", inputs: { notes: notes } })?"#.into(),
148 r#"claims = await llm.query({ task: "Extract the key claim from each supplied chunk", inputs: { chunks: chunks }, output: { claims: "list[str]" } })?"#.into(),
149 ],
150 ToolScheduling::Parallel,
151 )
152 .with_lashlang_binding(lash_core::LashlangToolBinding::new(["llm"], "query"))
153 .with_output_from_input_schema("output", Some(json!({ "type": "string" })))
154}
155
156pub fn parse_output_schema(value: Option<&Value>) -> Result<Option<Value>, String> {
157 let Some(value) = value else {
158 return Ok(None);
159 };
160 if value.is_null() {
161 return Ok(None);
162 }
163 let output = value.as_object().ok_or_else(|| {
164 "invalid `output`: expected a record describing the typed shape".to_string()
165 })?;
166 if output.is_empty() {
167 return Err("at least one output field is required".to_string());
168 }
169
170 if output.len() == 1
171 && let Some(schema) = output.get(lashlang::LASH_TYPE_KEY)
172 {
173 validate_schema(schema)?;
174 return Ok(Some(schema.clone()));
175 }
176
177 let mut properties = serde_json::Map::new();
178 let mut required = Vec::new();
179 for (name, descriptor) in output {
180 let type_str = descriptor
181 .as_str()
182 .ok_or_else(|| format!("field `{name}`: type descriptor must be a string"))?;
183 properties.insert(name.clone(), type_descriptor_to_json_schema(type_str)?);
184 required.push(Value::String(name.clone()));
185 }
186 Ok(Some(json!({
187 "type": "object",
188 "properties": properties,
189 "required": required,
190 "additionalProperties": false,
191 })))
192}
193
194fn llm_query_input_schema() -> Value {
195 json!({
196 "type": "object",
197 "properties": {
198 "task": { "type": "string" },
199 "inputs": {},
200 "output": { "type": "object", "additionalProperties": true }
201 },
202 "required": ["task"],
203 "additionalProperties": false
204 })
205}
206
207fn llm_query_prompt(task: &str, inputs: &Value, output_schema: Option<&Value>) -> String {
208 let mut sections = Vec::new();
209 sections.push(format!("Task:\n{task}"));
210 sections.push(format!(
211 "Inputs:\n```json\n{}\n```",
212 serde_json::to_string_pretty(inputs).unwrap_or_else(|_| inputs.to_string())
213 ));
214 if let Some(schema) = output_schema {
215 sections.push(format!(
216 "Return `kind=\"value\"` with `value` matching this JSON Schema, or `kind=\"error\"` with a concise error if the task cannot be answered from the supplied inputs:\n```json\n{}\n```",
217 serde_json::to_string_pretty(schema).unwrap_or_else(|_| schema.to_string())
218 ));
219 } else {
220 sections.push("Return `kind=\"value\"` with a concise string `value`, or `kind=\"error\"` with a concise error if the task cannot be answered from the supplied inputs.".to_string());
221 }
222 sections.join("\n\n")
223}
224
225fn llm_query_response_schema(output_schema: Option<&Value>) -> Value {
226 let value_schema = output_schema
227 .cloned()
228 .unwrap_or_else(|| json!({"type": "string"}));
229 json!({
230 "type": "object",
231 "additionalProperties": false,
232 "required": ["kind", "value", "error"],
233 "properties": {
234 "kind": { "type": "string", "enum": ["value", "error"] },
235 "value": {
236 "anyOf": [
237 value_schema,
238 { "type": "null" }
239 ]
240 },
241 "error": {
242 "anyOf": [
243 { "type": "string" },
244 { "type": "null" }
245 ]
246 }
247 }
248 })
249}
250
251fn parse_llm_query_result(text: &str, schema: &Value) -> Result<Value, String> {
252 let trimmed = text.trim();
253 if trimmed.is_empty() {
254 return Err("llm_query returned empty output".to_string());
255 }
256 let value = serde_json::from_str::<Value>(trimmed).or_else(|err| {
257 let Some(start) = trimmed.find(['{', '[', '"']) else {
258 return Err(format!("llm_query returned non-JSON output: {err}"));
259 };
260 let end = trimmed
261 .rfind(['}', ']', '"'])
262 .ok_or_else(|| format!("llm_query returned malformed JSON output: {err}"))?;
263 if end < start {
264 return Err(format!("llm_query returned malformed JSON output: {err}"));
265 }
266 serde_json::from_str::<Value>(&trimmed[start..=end])
267 .map_err(|parse_err| format!("llm_query returned malformed JSON output: {parse_err}"))
268 })?;
269 let compiled = jsonschema::JSONSchema::compile(schema)
270 .map_err(|err| format!("llm_query output schema is invalid: {err}"))?;
271 if let Err(errors) = compiled.validate(&value) {
272 let message = errors
273 .map(|err| err.to_string())
274 .collect::<Vec<_>>()
275 .join("; ");
276 return Err(format!("llm_query output did not match schema: {message}"));
277 }
278 match value.get("kind").and_then(Value::as_str) {
279 Some("value") => value
280 .get("value")
281 .cloned()
282 .filter(|value| !value.is_null())
283 .ok_or_else(|| "llm_query returned value result without value".to_string()),
284 Some("error") => Err(value
285 .get("error")
286 .and_then(Value::as_str)
287 .map(str::trim)
288 .filter(|message| !message.is_empty())
289 .unwrap_or("llm_query returned an error")
290 .to_string()),
291 Some(other) => Err(format!("llm_query returned unknown result kind `{other}`")),
292 None => Err("llm_query returned result without kind field".to_string()),
293 }
294}
295
296fn tool_definition(
297 name: &str,
298 description: impl Into<String>,
299 input_schema: Value,
300 examples: Vec<String>,
301 execution_mode: ToolScheduling,
302) -> ToolDefinition {
303 ToolDefinition::raw(
304 format!("tool:{name}"),
305 name,
306 description,
307 input_schema,
308 json!({ "type": "object", "additionalProperties": true }),
309 )
310 .with_examples(examples)
311 .with_scheduling(execution_mode)
312}
313
314fn required_string(args: &Value, key: &str) -> Result<String, String> {
315 args.get(key)
316 .and_then(Value::as_str)
317 .map(str::trim)
318 .filter(|value| !value.is_empty())
319 .map(ToOwned::to_owned)
320 .ok_or_else(|| format!("missing required parameter: {key}"))
321}
322
323fn validate_schema(schema: &Value) -> Result<(), String> {
324 let object = schema
325 .as_object()
326 .ok_or_else(|| "Type schema must be a JSON object".to_string())?;
327 let kind = object
328 .get("type")
329 .and_then(Value::as_str)
330 .ok_or_else(|| "Type schema missing `type` field".to_string())?;
331 match kind {
332 "object" | "array" | "string" | "integer" | "number" | "boolean" => Ok(()),
333 other => Err(format!("unsupported Type schema kind `{other}`")),
334 }
335}
336
337fn type_descriptor_to_json_schema(descriptor: &str) -> Result<Value, String> {
338 let scalar = |ty: &str| -> Result<Value, String> {
339 match ty {
340 "str" | "string" => Ok(json!({"type": "string"})),
341 "int" | "integer" => Ok(json!({"type": "integer"})),
342 "float" | "number" => Ok(json!({"type": "number"})),
343 "bool" | "boolean" => Ok(json!({"type": "boolean"})),
344 "record" | "dict" | "object" => {
345 Ok(json!({"type": "object", "additionalProperties": true}))
346 }
347 other => Err(format!("unknown scalar type `{other}`")),
348 }
349 };
350 let trimmed = descriptor.trim();
351 if let Some(inner) = trimmed
352 .strip_prefix("list[")
353 .and_then(|rest| rest.strip_suffix(']'))
354 {
355 return Ok(json!({
356 "type": "array",
357 "items": scalar(inner.trim())?,
358 }));
359 }
360 scalar(trimmed)
361}
362
363fn finalise_tool_result(result: Result<Value, String>) -> ToolResult {
364 match result {
365 Ok(value) => ToolResult::ok(value),
366 Err(err) => ToolResult::err(json!(err)),
367 }
368}
369
// Unit tests for the `llm_query` tool: tool manifest, `output` schema parsing,
// and end-to-end execution against a mocked host and completion client.
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Mutex;

    use async_trait::async_trait;
    use lash_core::plugin::runtime_host::{
        SessionGraphService, SessionLifecycleService, SessionStateService,
    };
    use lash_core::plugin::{PluginError, SessionHandle};
    use lash_core::runtime::RuntimeSessionState;
    use lash_core::{SessionCreateRequest, SessionSnapshot, ToolCall};

    /// Builds a `ModelSpec` for tests through `from_token_limits`, passing
    /// `200_000` and `None` as the limit arguments; panics if that fails.
    fn model_spec(model: &str, variant: Option<&str>) -> lash_core::ModelSpec {
        lash_core::ModelSpec::from_token_limits(model, variant.map(str::to_string), 200_000, None)
            .expect("valid test model spec")
    }

    /// Test double for the session host.
    ///
    /// Serves a fixed session snapshot, records every direct-completion
    /// request together with its usage source, and answers each completion
    /// with `response_text`.
    #[derive(Default)]
    struct DirectCompletionManager {
        snapshot: RuntimeSessionState,
        requests: Mutex<Vec<(lash_core::DirectRequest, String)>>,
        response_text: String,
    }

    // State queries always return the canned snapshot, whatever session is asked for.
    #[async_trait]
    impl SessionStateService for DirectCompletionManager {
        async fn snapshot_current(&self) -> Result<SessionSnapshot, PluginError> {
            Ok(self.snapshot.to_snapshot())
        }

        async fn snapshot_session(
            &self,
            _session_id: &str,
        ) -> Result<SessionSnapshot, PluginError> {
            Ok(self.snapshot.to_snapshot())
        }
        async fn tool_catalog(
            &self,
            _session_id: &str,
        ) -> Result<Vec<serde_json::Value>, PluginError> {
            Ok(Vec::new())
        }
    }

    // Lifecycle operations are not exercised by these tests: creating a
    // session fails and closing one is a no-op.
    #[async_trait]
    impl SessionLifecycleService for DirectCompletionManager {
        async fn create_session(
            &self,
            _request: SessionCreateRequest,
        ) -> Result<SessionHandle, PluginError> {
            Err(PluginError::Session("not used".to_string()))
        }

        async fn close_session(&self, _session_id: &str) -> Result<(), PluginError> {
            Ok(())
        }
    }

    // No methods are overridden here; the trait's defaults are used.
    #[async_trait]
    impl SessionGraphService for DirectCompletionManager {}

    /// Creates a mock tool context whose host is `manager` and whose
    /// completion client records each request on `manager` before replying
    /// with `manager.response_text` and default token usage.
    fn direct_completion_context(
        manager: Arc<DirectCompletionManager>,
    ) -> lash_core::ToolContext<'static> {
        let completions = lash_core::DirectCompletionClient::from_fn({
            let manager = Arc::clone(&manager);
            move |request, usage_source| {
                manager
                    .requests
                    .lock()
                    .expect("requests")
                    .push((request, usage_source));
                Ok(lash_core::DirectCompletion {
                    text: manager.response_text.clone(),
                    usage: lash_core::TokenUsage::default(),
                })
            }
        });
        lash_core::testing::mock_tool_context_with_host_and_direct_completions(manager, completions)
    }

    /// The provider exposes exactly one tool, `llm_query`, with `Showcased` availability.
    #[test]
    fn llm_definitions_include_llm_query_only() {
        let provider = llm_query_provider(None, None);
        let manifests = provider.tool_manifests();
        let names = manifests
            .iter()
            .map(|tool| tool.name.clone())
            .collect::<Vec<_>>();
        assert_eq!(names, vec!["llm_query"]);
        assert_eq!(
            manifests[0].effective_availability(),
            lash_core::ToolAvailability::Showcased
        );
    }

    /// Record descriptors with scalar and `list[...]` types map to the matching JSON Schema types.
    #[test]
    fn output_schema_supports_scalars_and_lists() {
        let schema = parse_output_schema(Some(&json!({
            "answer": "str",
            "count": "int",
            "items": "list[str]"
        })))
        .expect("schema")
        .expect("present");
        assert_eq!(schema["properties"]["answer"]["type"], json!("string"));
        assert_eq!(schema["properties"]["count"]["type"], json!("integer"));
        assert_eq!(schema["properties"]["items"]["type"], json!("array"));
    }

    /// A schema wrapped under `LASH_TYPE_KEY` is returned unchanged.
    #[test]
    fn output_schema_passes_through_lash_type_wrapper() {
        let inner_schema = json!({
            "type": "object",
            "properties": {
                "name": { "type": "string" },
                "tags": { "type": "array", "items": { "type": "string" } },
                "status": { "type": "string", "enum": ["ok", "err"] }
            },
            "required": ["name", "tags", "status"],
            "additionalProperties": false
        });
        let wrapped = json!({ lashlang::LASH_TYPE_KEY: inner_schema.clone() });
        let schema = parse_output_schema(Some(&wrapped))
            .expect("schema")
            .expect("present");
        assert_eq!(schema, inner_schema);
    }

    /// A wrapped schema without a `type` field is rejected with an error mentioning `type`.
    #[test]
    fn output_schema_rejects_lash_type_without_type_field() {
        let wrapped = json!({ lashlang::LASH_TYPE_KEY: {"properties": {}} });
        let err = parse_output_schema(Some(&wrapped)).expect_err("missing type");
        assert!(err.contains("type"), "error: {err}");
    }

    /// A wrapped schema may have `array` as its top-level type.
    #[test]
    fn output_schema_accepts_array_top_level_type() {
        let wrapped = json!({
            lashlang::LASH_TYPE_KEY: {
                "type": "array",
                "items": {"type": "string"}
            }
        });
        let schema = parse_output_schema(Some(&wrapped))
            .expect("schema")
            .expect("present");
        assert_eq!(schema["type"], json!("array"));
    }

    /// Without overrides, the request uses the session's model and variant,
    /// asks for JSON-schema output, and carries the task and inputs in its prompt.
    #[tokio::test]
    async fn llm_query_uses_current_policy_and_direct_completion() {
        let manager = Arc::new(DirectCompletionManager {
            snapshot: RuntimeSessionState {
                policy: lash_core::SessionPolicy {
                    model: model_spec("root-model", Some("fast")),
                    ..lash_core::SessionPolicy::default()
                },
                ..RuntimeSessionState::default()
            },
            requests: Mutex::new(Vec::new()),
            response_text:
                r#"{"kind":"value","value":{"root_cause":"missing config","confidence":0.8},"error":null}"#
                    .to_string(),
        });
        let provider = llm_query_provider(None, None);
        let context = direct_completion_context(manager.clone());

        let args = json!({
            "task": "extract root cause",
            "inputs": { "log": "failed" },
            "output": { "root_cause": "str", "confidence": "float" }
        });
        let result = provider
            .execute(ToolCall {
                name: "llm_query",
                args: &args,
                context: &context,
                progress: None,
            })
            .await;

        assert!(result.is_success(), "{:?}", result.value_for_projection());
        assert_eq!(
            result.value_for_projection()["root_cause"],
            json!("missing config")
        );
        assert_eq!(result.value_for_projection()["confidence"], json!(0.8));

        let requests = manager.requests.lock().expect("requests");
        assert_eq!(requests.len(), 1);
        let (request, usage_source) = &requests[0];
        assert_eq!(usage_source, "llm_query");
        assert_eq!(request.model, "root-model");
        assert_eq!(request.model_variant.as_deref(), Some("fast"));
        assert!(matches!(
            request.output,
            lash_core::DirectOutputSpec::JsonSchema(_)
        ));
        // Concatenate every text part of every message to inspect the rendered prompt.
        let prompt = request
            .messages
            .iter()
            .flat_map(|message| message.parts.iter())
            .filter_map(|part| match part {
                lash_core::DirectPart::Text(text) => Some(text.as_str()),
                lash_core::DirectPart::Image(_) => None,
            })
            .collect::<Vec<_>>()
            .join("\n");
        assert!(prompt.contains("extract root cause"));
        assert!(prompt.contains("\"log\": \"failed\""));
    }

    /// Configured model and variant overrides take precedence over the session's model.
    #[tokio::test]
    async fn llm_query_uses_configured_model_override() {
        let manager = Arc::new(DirectCompletionManager {
            snapshot: RuntimeSessionState {
                policy: lash_core::SessionPolicy {
                    model: model_spec("root-model", Some("medium")),
                    ..lash_core::SessionPolicy::default()
                },
                ..RuntimeSessionState::default()
            },
            requests: Mutex::new(Vec::new()),
            response_text: r#"{"kind":"value","value":"done","error":null}"#.to_string(),
        });
        let provider = llm_query_provider(Some("gpt-5.5".to_string()), Some("low".to_string()));
        let context = direct_completion_context(manager.clone());

        let args = json!({ "task": "answer directly" });
        let result = provider
            .execute(ToolCall {
                name: "llm_query",
                args: &args,
                context: &context,
                progress: None,
            })
            .await;

        assert!(result.is_success(), "{:?}", result.value_for_projection());
        let requests = manager.requests.lock().expect("requests");
        assert_eq!(requests.len(), 1);
        let (request, usage_source) = &requests[0];
        assert_eq!(usage_source, "llm_query");
        assert_eq!(request.model, "gpt-5.5");
        assert_eq!(request.model_variant.as_deref(), Some("low"));
    }

    /// A `kind: "error"` response fails the tool call and surfaces the model's error text.
    #[tokio::test]
    async fn llm_query_error_result_fails_tool_call() {
        let manager = Arc::new(DirectCompletionManager {
            snapshot: RuntimeSessionState {
                policy: lash_core::SessionPolicy::default(),
                ..RuntimeSessionState::default()
            },
            requests: Mutex::new(Vec::new()),
            response_text: r#"{"kind":"error","value":null,"error":"missing required evidence"}"#
                .to_string(),
        });
        let provider = llm_query_provider(None, None);
        let context = direct_completion_context(manager);

        let args = json!({ "task": "answer from missing evidence" });
        let result = provider
            .execute(ToolCall {
                name: "llm_query",
                args: &args,
                context: &context,
                progress: None,
            })
            .await;

        assert!(!result.is_success());
        assert_eq!(
            result.value_for_projection(),
            json!("missing required evidence")
        );
    }
}