1use std::sync::Arc;
16
17use async_trait::async_trait;
18use lash_core::llm::types::{ProviderReasoningReplay, ProviderReplayMeta, ResponseTextMeta};
19use lash_core::plugin::{
20 PluginError, PluginFactory, PluginRegistrar, PluginSessionContext, ProtocolDriverPlugin,
21 ProtocolSessionContext, ProtocolSessionPlugin, SessionPlugin,
22};
23use lash_core::sansio::{
24 CheckpointResumeAction, CompletedToolCall, PendingToolCall, ProtocolDriverHandle,
25 WaitingExecState, WaitingLlmState,
26};
27use lash_core::session_model::message::PartAttachment;
28use lash_core::session_model::{
29 ConversationRecord, Message, MessageRole, Part, PartKind, PruneState, SessionEvent,
30 SessionEventRecord, fresh_message_id, make_error_event, reassign_part_ids, shared_parts,
31};
32
33mod batch;
34pub mod scenario_contracts;
35use batch::batch_tool_definition;
36use lash_core::{
37 CheckpointKind, DriverAction, DriverContextView, LlmOutputPart, LlmResponse,
38 ProtocolBuildInput, SessionError, ToolCall, ToolContract, ToolInvocation, ToolManifest,
39 ToolProvider, ToolResult, TurnDriverConfig, TurnDriverPreamble, TurnFinish, TurnOutcome,
40 TurnStop, append_assistant_text_part, normalized_response_parts, reasoning_part,
41};
42use serde_json::Value;
43
/// Execution-protocol section of the system prompt for the standard protocol.
///
/// Installed as `execution_prompt` by `StandardProtocolDriver::build_preamble`. The "up to 25
/// calls" wording mirrors `BATCH_MAX_TOOL_CALLS`; keep the two in sync.
const STANDARD_EXECUTION_SECTION: &str = r#"Use direct tool calls.

- Use `batch` (up to 25 calls) for two or more independent tool calls. Serialize calls when later arguments depend on earlier results.
- For direct conversational requests that need no tools, respond in prose only.

Example — two independent reads in one `batch` call:

```json
{
  "tool_calls": [
    { "tool": "read_file", "parameters": { "path": "src/main.rs" } },
    { "tool": "grep", "parameters": { "query": "ToolProvider", "path": "crates/lash/src/" } }
  ]
}
```"#;

/// Maximum number of child calls a single `batch` invocation dispatches; any further entries
/// are reported back as failed results instead of being executed.
const BATCH_MAX_TOOL_CALLS: usize = 25;
/// Stable id shared by the plugin factory and the session plugin (pinned by a unit test).
const STANDARD_PROTOCOL_PLUGIN_ID: &str = "standard_protocol";
62
63#[derive(Default)]
66pub struct StandardProtocolPluginFactory;
67
68impl StandardProtocolPluginFactory {
69 pub fn new() -> Self {
70 Self
71 }
72}
73
74impl PluginFactory for StandardProtocolPluginFactory {
75 fn id(&self) -> &'static str {
76 STANDARD_PROTOCOL_PLUGIN_ID
77 }
78
79 fn build(&self, _ctx: &PluginSessionContext) -> Result<Arc<dyn SessionPlugin>, PluginError> {
80 Ok(Arc::new(StandardProtocolPlugin))
81 }
82}
83
/// Session plugin that wires the standard protocol into a session.
struct StandardProtocolPlugin;

impl SessionPlugin for StandardProtocolPlugin {
    /// Same id as the factory that builds this plugin.
    fn id(&self) -> &'static str {
        STANDARD_PROTOCOL_PLUGIN_ID
    }

    /// Registers the protocol session hook, the protocol driver, and the tool provider
    /// that supplies `batch`. The first registration error is propagated.
    fn register(&self, reg: &mut PluginRegistrar) -> Result<(), PluginError> {
        reg.protocol().session(Arc::new(StandardProtocolSession))?;
        reg.protocol()
            .protocol_driver(Arc::new(StandardProtocolDriver))?;
        reg.tools().provider(Arc::new(StandardProtocolTools))?;
        Ok(())
    }
}
99
/// Protocol session hook for the standard protocol.
struct StandardProtocolSession;

#[async_trait]
impl ProtocolSessionPlugin for StandardProtocolSession {
    /// Does nothing and always succeeds; the context is ignored.
    async fn initialize_session(
        &self,
        _ctx: ProtocolSessionContext<'_>,
    ) -> Result<(), SessionError> {
        Ok(())
    }
}
111
112struct StandardProtocolDriver;
113
114impl ProtocolDriverPlugin for StandardProtocolDriver {
115 fn build_preamble(&self, input: ProtocolBuildInput) -> TurnDriverPreamble {
116 let tool_names = input.tool_catalog.tool_names();
117 let tool_names_fingerprint = input.tool_catalog.tool_names_fingerprint();
118 TurnDriverPreamble {
119 config: TurnDriverConfig::chat(
120 Arc::new(StandardDriver),
121 true,
122 Arc::new(turn_limit_exhausted_message),
123 ),
124 tool_specs: input.tool_catalog.model_tool_specs(),
125 tool_names,
126 tool_names_fingerprint,
127 execution_prompt: Arc::from(STANDARD_EXECUTION_SECTION),
128 prompt_contributions: input.extra_prompt_contributions,
129 }
130 }
131}
132
133fn turn_limit_exhausted_message(message_id: String, max_turns: usize) -> Message {
134 Message {
135 id: message_id.clone(),
136 role: MessageRole::System,
137 parts: shared_parts(vec![Part {
138 id: format!("{message_id}.p0"),
139 kind: PartKind::Error,
140 content: format!("Turn limit reached ({max_turns}) before a final assistant response."),
141 attachment: None,
142 tool_call_id: None,
143 tool_name: None,
144 tool_replay: None,
145 prune_state: PruneState::Intact,
146 reasoning_meta: None,
147 response_meta: None,
148 }]),
149 origin: None,
150 }
151}
152
153struct StandardProtocolTools;
154
155#[async_trait]
156impl ToolProvider for StandardProtocolTools {
157 fn tool_manifests(&self) -> Vec<ToolManifest> {
158 vec![batch_tool_definition().manifest()]
159 }
160
161 fn resolve_contract(&self, name: &str) -> Option<Arc<ToolContract>> {
162 (name == "batch").then(|| Arc::new(batch_tool_definition().contract()))
163 }
164
165 async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
166 match call.name {
167 "batch" => execute_batch_tool_call(call).await,
168 _ => ToolResult::err_fmt(format_args!("Unknown tool: {}", call.name)),
169 }
170 }
171}
172
/// One validated entry of a `batch` call's `tool_calls` array.
#[derive(Debug)]
struct BatchCallSpec {
    /// Zero-based position of the entry in `tool_calls`.
    index: usize,
    /// Requested tool name, trimmed and guaranteed non-empty by `parse_batch_specs`.
    tool: String,
    /// The entry's `parameters` value, or `{}` when the key was absent.
    parameters: Value,
}
179
180async fn execute_batch_tool_call(call: ToolCall<'_>) -> ToolResult {
181 let args = call.args;
182 let specs = match parse_batch_specs(args) {
183 Ok(specs) => specs,
184 Err(err) => return err,
185 };
186
187 let mut immediate_outcomes = Vec::new();
188 let mut parallel_specs = Vec::new();
189 let dispatch = call.context.dispatch();
190
191 for spec in specs.into_iter().take(BATCH_MAX_TOOL_CALLS) {
192 if spec.tool == "batch" {
193 immediate_outcomes.push(serde_json::json!({
194 "index": spec.index,
195 "tool": spec.tool,
196 "success": false,
197 "duration_ms": 0,
198 "error": "Tool 'batch' is not allowed inside batch",
199 }));
200 continue;
201 }
202 let Some(manifest) = dispatch.callable_tool_manifest(&spec.tool) else {
203 let error = format!("Tool '{}' is unavailable in this session", spec.tool);
204 immediate_outcomes.push(serde_json::json!({
205 "index": spec.index,
206 "tool": spec.tool,
207 "success": false,
208 "duration_ms": 0,
209 "error": error,
210 }));
211 continue;
212 };
213 parallel_specs.push((
214 spec.index,
215 ToolInvocation::new(
216 format!(
217 "{}:{:02}",
218 call.context.tool_call_id().unwrap_or("batch"),
219 spec.index
220 ),
221 manifest.id,
222 spec.parameters,
223 ),
224 ));
225 }
226
227 let mut parallel_outcomes = dispatch
228 .batch(
229 parallel_specs
230 .iter()
231 .map(|(_, invocation)| invocation.clone())
232 .collect(),
233 )
234 .await;
235 for ((index, invocation), outcome) in
236 parallel_specs.into_iter().zip(parallel_outcomes.drain(..))
237 {
238 let tool_label = invocation.label();
239 let tool_record = outcome.record.unwrap_or(lash_core::ToolCallRecord {
240 call_id: Some(invocation.id),
241 tool: tool_label,
242 args: invocation.args,
243 output: outcome.output,
244 duration_ms: 0,
245 });
246 let mut result_record = serde_json::Map::new();
247 result_record.insert("index".to_string(), serde_json::json!(index));
248 result_record.insert("tool".to_string(), serde_json::json!(tool_record.tool));
249 result_record.insert(
250 "success".to_string(),
251 serde_json::json!(tool_record.output.is_success()),
252 );
253 result_record.insert(
254 "duration_ms".to_string(),
255 serde_json::json!(tool_record.duration_ms),
256 );
257 result_record.insert(
258 if tool_record.output.is_success() {
259 "result".to_string()
260 } else {
261 "error".to_string()
262 },
263 tool_record.output.value_for_projection(),
264 );
265 immediate_outcomes.push(Value::Object(result_record));
266 }
267
268 for overflow_index in BATCH_MAX_TOOL_CALLS
269 ..args
270 .get("tool_calls")
271 .and_then(|value| value.as_array())
272 .map(|value| value.len())
273 .unwrap_or_default()
274 {
275 immediate_outcomes.push(serde_json::json!({
276 "index": overflow_index,
277 "tool": args
278 .get("tool_calls")
279 .and_then(|value| value.as_array())
280 .and_then(|items| items.get(overflow_index))
281 .and_then(|item| item.get("tool"))
282 .and_then(|value| value.as_str())
283 .unwrap_or("unknown"),
284 "success": false,
285 "duration_ms": 0,
286 "error": "Maximum of 25 tool calls allowed in batch",
287 }));
288 }
289
290 immediate_outcomes.sort_by_key(|outcome| {
291 outcome
292 .get("index")
293 .and_then(|value| value.as_u64())
294 .unwrap_or(u64::MAX)
295 });
296 ToolResult::ok(serde_json::json!({
297 "results": immediate_outcomes,
298 }))
299}
300
301#[allow(clippy::result_large_err)]
302fn parse_batch_specs(args: &Value) -> Result<Vec<BatchCallSpec>, ToolResult> {
303 let Some(raw_calls) = args.get("tool_calls").and_then(|value| value.as_array()) else {
304 return Err(ToolResult::err_fmt(
305 "Missing required parameter: tool_calls",
306 ));
307 };
308 if raw_calls.is_empty() {
309 return Err(ToolResult::err_fmt(
310 "Invalid tool_calls: expected at least one call",
311 ));
312 }
313
314 let mut specs = Vec::with_capacity(raw_calls.len());
315 for (index, item) in raw_calls.iter().enumerate() {
316 let Some(object) = item.as_object() else {
317 return Err(ToolResult::err_fmt(format_args!(
318 "Invalid tool_calls[{index}]: expected object with tool and parameters"
319 )));
320 };
321 let Some(tool) = object
322 .get("tool")
323 .and_then(|value| value.as_str())
324 .map(str::trim)
325 .filter(|tool| !tool.is_empty())
326 else {
327 return Err(ToolResult::err_fmt(format_args!(
328 "Invalid tool_calls[{index}].tool: expected non-empty string"
329 )));
330 };
331 let parameters = object
332 .get("parameters")
333 .cloned()
334 .unwrap_or_else(|| serde_json::json!({}));
335 specs.push(BatchCallSpec {
336 index,
337 tool: tool.to_string(),
338 parameters,
339 });
340 }
341
342 Ok(specs)
343}
344
/// Turn driver for the standard protocol: one LLM call per iteration, with any tool calls
/// it requests executed directly.
///
/// Installed by `StandardProtocolDriver::build_preamble`.
pub struct StandardDriver;

/// A tool call collected from an LLM response. It is later turned into both an assistant
/// `ToolCall` part and a `PendingToolCall`.
struct StandardToolCall {
    call_id: String,
    tool_name: String,
    /// Raw argument JSON as emitted by the model. It is parsed when the pending call is
    /// built, falling back to `{}` if parsing fails.
    input_json: String,
    replay: Option<ProviderReplayMeta>,
}
362
363fn last_message_has_tool_result(ctx: &DriverContextView<'_>) -> bool {
364 ctx.messages().last().is_some_and(|message| {
365 matches!(message.role, MessageRole::User)
366 && message
367 .parts
368 .iter()
369 .any(|part| matches!(part.kind, PartKind::ToolResult))
370 })
371}
372
impl ProtocolDriverHandle<lash_core::HostTurnProtocol> for StandardDriver {
    /// Starts each protocol iteration with a single LLM request projected from the
    /// current driver context.
    ///
    /// NOTE(review): the meaning of the `true` argument to `project_llm_request` is
    /// defined outside this file — confirm before relying on it.
    fn prepare_protocol_iteration(&self, ctx: DriverContextView<'_>) -> Vec<DriverAction> {
        vec![DriverAction::StartLlm {
            request: ctx.project_llm_request(true),
            driver_state: None,
        }]
    }

    /// Turns one successful LLM response into driver actions.
    ///
    /// Text, reasoning, and tool-call parts are collected in stream order. Text is
    /// re-emitted as `TextDelta` events only when it was not already streamed, and an
    /// `LlmResponse` event is always emitted. The response is then handled as follows:
    ///
    /// - No tool calls, no text, no reasoning: if the last message carried a tool result,
    ///   run the `BeforeCompletion` checkpoint and finish with empty text. Otherwise
    ///   report an `empty_response` provider error and stop.
    /// - No tool calls otherwise: run the `BeforeCompletion` checkpoint and finish with
    ///   the accumulated assistant text.
    /// - Tool calls present: append one assistant message and start the tools. The
    ///   message holds all prose parts first, then each tool call preceded by the
    ///   reasoning that arrived before it.
    fn handle_llm_success(
        &self,
        ctx: DriverContextView<'_>,
        _waiting: WaitingLlmState<lash_core::HostTurnProtocol>,
        llm_response: LlmResponse,
        text_streamed: bool,
    ) -> Vec<DriverAction> {
        let response_parts = normalized_response_parts(&llm_response);
        let mut assistant_text = String::new();
        let mut assistant_text_parts: Vec<(String, Option<ResponseTextMeta>)> = Vec::new();
        let mut tool_calls: Vec<StandardToolCall> = Vec::new();
        // Each entry records how many tool calls had been seen when the reasoning arrived,
        // so it can be re-inserted ahead of the matching tool call below.
        let mut reasoning_items: Vec<(usize, Option<ProviderReasoningReplay>, String)> = Vec::new();
        let mut actions = Vec::new();

        for part in response_parts {
            match part {
                LlmOutputPart::Text {
                    text,
                    response_meta,
                } => {
                    if !text.is_empty() {
                        let previous_len = assistant_text.len();
                        append_assistant_text_part(&mut assistant_text, &text);
                        // Store exactly what `append_assistant_text_part` appended to the
                        // running text. NOTE(review): presumably this can differ from `text`
                        // (e.g. separators) — confirm.
                        assistant_text_parts
                            .push((assistant_text[previous_len..].to_string(), response_meta));
                        if !text_streamed {
                            actions.push(DriverAction::Emit(SessionEvent::TextDelta {
                                content: assistant_text[previous_len..].to_string(),
                            }));
                        }
                    }
                }
                LlmOutputPart::Reasoning { text, replay } => {
                    let trimmed = text.trim().to_string();
                    // Drop reasoning that has neither visible text nor replay metadata.
                    if trimmed.is_empty() && replay.as_ref().is_none_or(|meta| meta.is_empty()) {
                        continue;
                    }
                    reasoning_items.push((tool_calls.len(), replay, trimmed));
                }
                LlmOutputPart::ToolCall {
                    call_id,
                    tool_name,
                    input_json,
                    replay,
                } => {
                    tool_calls.push(StandardToolCall {
                        call_id,
                        tool_name,
                        input_json,
                        replay,
                    });
                }
            }
        }

        actions.push(DriverAction::Emit(SessionEvent::LlmResponse {
            protocol_iteration: ctx.protocol_iteration(),
            content: assistant_text.clone(),
            duration_ms: 0,
        }));

        if tool_calls.is_empty() {
            if assistant_text.trim().is_empty() && reasoning_items.is_empty() {
                // An empty reply right after tool output is accepted as a (silent) finish.
                if last_message_has_tool_result(&ctx) {
                    actions.push(DriverAction::StartCheckpoint {
                        checkpoint: CheckpointKind::BeforeCompletion,
                        on_empty: CheckpointResumeAction::Finish(TurnOutcome::Finished(
                            TurnFinish::AssistantMessage {
                                text: String::new(),
                            },
                        )),
                    });
                    return actions;
                }
                actions.push(DriverAction::Emit(make_error_event(
                    "llm_provider",
                    Some("empty_response"),
                    "Model returned no assistant text or tool calls.",
                    None,
                )));
                actions.push(DriverAction::Finish(TurnOutcome::Stopped(
                    TurnStop::ProviderError,
                )));
                return actions;
            }

            // NOTE(review): `parts_out` (reasoning first, then non-blank prose) is only used
            // for the emptiness check below; it is not appended as an event in this branch.
            // Presumably the runtime records the final message from
            // `TurnFinish::AssistantMessage` — confirm reasoning replay metadata is not lost.
            let asst_id = fresh_message_id();
            let mut parts_out = Vec::new();
            for (_, meta, text) in reasoning_items {
                parts_out.push(reasoning_part(&asst_id, parts_out.len(), text, meta));
            }
            for (content, response_meta) in assistant_text_parts {
                if content.trim().is_empty() {
                    continue;
                }
                parts_out.push(Part {
                    id: format!("{}.p{}", asst_id, parts_out.len()),
                    kind: PartKind::Prose,
                    content,
                    attachment: None,
                    tool_call_id: None,
                    tool_name: None,
                    tool_replay: None,
                    prune_state: PruneState::Intact,
                    reasoning_meta: None,
                    response_meta,
                });
            }
            if parts_out.is_empty() {
                actions.push(DriverAction::Emit(make_error_event(
                    "llm_provider",
                    Some("empty_response"),
                    "Model returned no assistant text or tool calls.",
                    None,
                )));
                actions.push(DriverAction::Finish(TurnOutcome::Stopped(
                    TurnStop::ProviderError,
                )));
                return actions;
            }
            actions.push(DriverAction::StartCheckpoint {
                checkpoint: CheckpointKind::BeforeCompletion,
                on_empty: CheckpointResumeAction::Finish(TurnOutcome::Finished(
                    TurnFinish::AssistantMessage {
                        text: assistant_text.clone(),
                    },
                )),
            });
            return actions;
        }

        // Tool-call path: non-blank prose parts come first in the assistant message.
        let asst_id = fresh_message_id();
        let mut assistant_parts = Vec::new();
        for (content, response_meta) in assistant_text_parts {
            if content.trim().is_empty() {
                continue;
            }
            assistant_parts.push(Part {
                id: format!("{}.p{}", asst_id, assistant_parts.len()),
                kind: PartKind::Prose,
                content,
                attachment: None,
                tool_call_id: None,
                tool_name: None,
                tool_replay: None,
                prune_state: PruneState::Intact,
                reasoning_meta: None,
                response_meta,
            });
        }

        let mut calls = Vec::new();
        let mut reasoning_iter = reasoning_items.into_iter().peekable();
        for (tool_index, tool_call) in tool_calls.into_iter().enumerate() {
            // Emit every reasoning item that arrived before this tool call.
            while let Some((insert_index, _, _)) = reasoning_iter.peek() {
                if *insert_index > tool_index {
                    break;
                }
                let (_, meta, text) = reasoning_iter.next().expect("peek ok");
                assistant_parts.push(reasoning_part(&asst_id, assistant_parts.len(), text, meta));
            }
            assistant_parts.push(Part {
                id: format!("{}.p{}", asst_id, assistant_parts.len()),
                kind: PartKind::ToolCall,
                content: tool_call.input_json.clone(),
                attachment: None,
                tool_call_id: Some(tool_call.call_id.clone()),
                tool_name: Some(tool_call.tool_name.clone()),
                tool_replay: tool_call.replay.clone(),
                prune_state: PruneState::Intact,
                reasoning_meta: None,
                response_meta: None,
            });

            // Malformed argument JSON degrades to an empty object rather than failing the turn.
            let args = serde_json::from_str::<Value>(&tool_call.input_json)
                .unwrap_or_else(|_| serde_json::json!({}));
            calls.push(PendingToolCall {
                call_id: tool_call.call_id,
                tool_name: tool_call.tool_name,
                args,
                replay: tool_call.replay,
            });
        }
        // Reasoning that arrived after the last tool call trails the message.
        for (_, meta, text) in reasoning_iter {
            assistant_parts.push(reasoning_part(&asst_id, assistant_parts.len(), text, meta));
        }

        if !assistant_parts.is_empty() {
            actions.push(DriverAction::AppendEvents(vec![conversation_event(
                Message {
                    id: asst_id,
                    role: MessageRole::Assistant,
                    parts: shared_parts(assistant_parts),
                    origin: None,
                },
            )]));
        }

        actions.push(DriverAction::StartTools { calls });
        actions
    }

    /// Records tool results and decides whether the turn continues.
    ///
    /// All model-facing returns are appended as a single user message. The first
    /// *successful* result that carries a recognised control directive ends the turn:
    /// - `SwitchAgentFrame` with a non-blank frame id and task;
    /// - `Finish`;
    /// - `Fail`.
    ///
    /// Otherwise the iteration advances. If the run's `max_turns` budget is then used up,
    /// a turn-limit message is appended and the turn stops; if not, the `AfterWork`
    /// checkpoint runs before the next iteration.
    fn handle_tool_results(
        &self,
        ctx: DriverContextView<'_>,
        completed: Vec<CompletedToolCall>,
    ) -> Vec<DriverAction> {
        let mut actions = Vec::new();
        let mut result_parts = Vec::new();
        let mut terminal_outcome = None;

        for outcome in completed {
            if terminal_outcome.is_none() && outcome.output.is_success() {
                terminal_outcome = match outcome.output.control.as_ref() {
                    Some(lash_core::ToolControl::SwitchAgentFrame {
                        frame_id,
                        task: Some(task),
                        ..
                    }) if !frame_id.trim().is_empty() && !task.trim().is_empty() => {
                        Some(TurnOutcome::AgentFrameSwitch {
                            frame_id: frame_id.clone(),
                            task: task.clone(),
                        })
                    }
                    Some(lash_core::ToolControl::Finish { value }) => {
                        Some(TurnOutcome::Finished(TurnFinish::ToolValue {
                            tool_name: outcome.tool_name.clone(),
                            value: value.to_json_value(),
                        }))
                    }
                    Some(lash_core::ToolControl::Fail { failure }) => {
                        Some(TurnOutcome::Stopped(TurnStop::ToolError {
                            tool_name: outcome.tool_name.clone(),
                            value: failure.to_json_value(),
                        }))
                    }
                    _ => None,
                };
            }

            append_model_return_parts(&mut result_parts, outcome.model_return);
        }

        if !result_parts.is_empty() {
            // Parts are created with empty ids; give them ids under the new message.
            let user_id = fresh_message_id();
            reassign_part_ids(&user_id, &mut result_parts);
            actions.push(DriverAction::AppendEvents(vec![conversation_event(
                Message {
                    id: user_id,
                    role: MessageRole::User,
                    parts: shared_parts(result_parts),
                    origin: None,
                },
            )]));
        }

        if let Some(outcome) = terminal_outcome {
            actions.push(DriverAction::Finish(outcome));
            return actions;
        }

        actions.push(DriverAction::AdvanceProtocolIteration);
        let next_protocol_iteration = ctx.protocol_iteration() + 1;
        // Stop once the next iteration would reach the turn budget for this run.
        if let Some(max_turns) = ctx.max_turns()
            && next_protocol_iteration >= ctx.protocol_run_offset() + max_turns
        {
            let message_id = fresh_message_id();
            actions.push(DriverAction::AppendEvents(vec![conversation_event(
                turn_limit_exhausted_message(message_id, max_turns),
            )]));
            actions.push(DriverAction::Finish(TurnOutcome::Stopped(
                TurnStop::MaxTurns,
            )));
            return actions;
        }

        actions.push(DriverAction::StartCheckpoint {
            checkpoint: CheckpointKind::AfterWork,
            on_empty: CheckpointResumeAction::PrepareIteration,
        });
        actions
    }

    /// Returns no actions for exec results.
    ///
    /// NOTE(review): this driver never issues exec actions itself, so this is presumably
    /// unreachable — confirm.
    fn handle_exec_result(
        &self,
        _ctx: DriverContextView<'_>,
        _waiting: WaitingExecState<lash_core::HostTurnProtocol>,
        _result: Result<lash_core::ExecResponse, String>,
    ) -> Vec<DriverAction> {
        Vec::new()
    }
}
689
690fn append_model_return_parts(parts: &mut Vec<Part>, model_return: lash_core::ModelToolReturn) {
691 for part in model_return.parts {
692 match part {
693 lash_core::ModelToolReturnPart::Text { text } => {
694 if text.is_empty() {
695 continue;
696 }
697 parts.push(Part {
698 id: String::new(),
699 kind: PartKind::ToolResult,
700 content: text,
701 attachment: None,
702 tool_call_id: Some(model_return.call_id.clone()),
703 tool_name: Some(model_return.tool_name.clone()),
704 tool_replay: None,
705 prune_state: PruneState::Intact,
706 reasoning_meta: None,
707 response_meta: None,
708 });
709 }
710 lash_core::ModelToolReturnPart::Attachment(reference) => {
711 parts.push(Part {
712 id: String::new(),
713 kind: PartKind::Image,
714 content: String::new(),
715 attachment: Some(PartAttachment { reference }),
716 tool_call_id: Some(model_return.call_id.clone()),
717 tool_name: Some(model_return.tool_name.clone()),
718 tool_replay: None,
719 prune_state: PruneState::Intact,
720 reasoning_meta: None,
721 response_meta: None,
722 });
723 }
724 }
725 }
726}
727
728fn conversation_event(message: Message) -> SessionEventRecord {
729 SessionEventRecord::Conversation(ConversationRecord::from_message(message))
730}
731
#[cfg(test)]
mod tests {
    //! Unit tests for the standard protocol: the plugin id contract, `batch` execution
    //! inside a durable runtime, and the mapping of tool returns to conversation parts.

    use super::*;
    use lash_core::{
        AttachmentId, AttachmentMeta, ImageMediaType, MediaType, ModelToolReturn, ToolCallOutput,
        ToolValue,
    };
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use tokio::sync::Barrier;
    use tokio::time::{Duration, timeout};

    /// Builds a PNG attachment reference with the given id.
    /// NOTE(review): the `4`, `Some(1)`, `Some(1)` arguments presumably mean a 4-byte,
    /// 1x1 image — confirm against `AttachmentMeta::new`.
    fn image_ref(id: &str) -> lash_core::AttachmentRef {
        AttachmentMeta::new(
            AttachmentId::new(id),
            MediaType::Image(ImageMediaType::Png),
            4,
            Some(1),
            Some(1),
            Some("tiny".to_string()),
        )
        .as_ref()
    }

    #[test]
    fn standard_protocol_factory_id_is_stable_plugin_contract() {
        let factory = StandardProtocolPluginFactory::new();

        assert_eq!(factory.id(), STANDARD_PROTOCOL_PLUGIN_ID);
        assert_eq!(factory.id(), "standard_protocol");
    }

    /// Scripted provider.
    ///
    /// - First call: returns one `batch` tool call (`alpha`, plus a failing `beta`).
    /// - Later calls: set `saw_batch_result` if the projected messages mention both
    ///   "alpha" and "beta failed", then reply with the text "done".
    #[derive(Clone, Debug)]
    struct BatchRuntimeProvider {
        calls: Arc<AtomicUsize>,
        saw_batch_result: Arc<AtomicBool>,
    }

    #[async_trait::async_trait]
    impl lash_core::Provider for BatchRuntimeProvider {
        fn kind(&self) -> &'static str {
            "stub"
        }

        fn options(&self) -> lash_core::ProviderOptions {
            lash_core::ProviderOptions::default()
        }

        fn set_options(&mut self, _options: lash_core::ProviderOptions) {}

        fn serialize_config(&self) -> serde_json::Value {
            serde_json::json!({})
        }

        async fn complete(
            &mut self,
            request: lash_core::LlmRequest,
        ) -> Result<lash_core::LlmResponse, lash_core::LlmTransportError> {
            let call_index = self.calls.fetch_add(1, Ordering::SeqCst);
            if call_index == 0 {
                return Ok(lash_core::LlmResponse {
                    parts: vec![lash_core::LlmOutputPart::ToolCall {
                        call_id: "batch-call".to_string(),
                        tool_name: "batch".to_string(),
                        input_json: serde_json::json!({
                            "tool_calls": [
                                {"tool": "alpha", "parameters": {}},
                                {"tool": "beta", "parameters": {"value": "fail"}}
                            ]
                        })
                        .to_string(),
                        replay: None,
                    }],
                    ..lash_core::LlmResponse::default()
                });
            }

            // Detect whether the batch results were fed back to the model.
            let projected_messages = format!("{:?}", request.messages);
            if projected_messages.contains("alpha") && projected_messages.contains("beta failed") {
                self.saw_batch_result.store(true, Ordering::SeqCst);
            }
            Ok(lash_core::LlmResponse {
                full_text: "done".to_string(),
                parts: vec![lash_core::LlmOutputPart::Text {
                    text: "done".to_string(),
                    response_meta: None,
                }],
                ..lash_core::LlmResponse::default()
            })
        }

        fn clone_boxed(&self) -> Box<dyn lash_core::Provider> {
            Box::new(self.clone())
        }
    }

    /// Test tools `alpha` and `beta`. Each execution counts itself in `started` and waits on
    /// a shared barrier (100 ms timeout) to check that the batch children run concurrently.
    #[derive(Debug)]
    struct BatchRuntimeTools {
        barrier: Arc<Barrier>,
        started: Arc<AtomicUsize>,
    }

    /// Defines a parallel-schedulable test tool taking an optional string `value`.
    fn runtime_test_tool(name: &str) -> lash_core::ToolDefinition {
        lash_core::ToolDefinition::raw(
            format!("tool:{name}"),
            name,
            "",
            serde_json::json!({
                "type": "object",
                "properties": {
                    "value": { "type": "string" }
                },
                "additionalProperties": true
            }),
            serde_json::json!({ "type": "string" }),
        )
        .with_scheduling(lash_core::ToolScheduling::Parallel)
    }

    #[async_trait::async_trait]
    impl ToolProvider for BatchRuntimeTools {
        fn tool_manifests(&self) -> Vec<ToolManifest> {
            vec![
                runtime_test_tool("alpha").manifest(),
                runtime_test_tool("beta").manifest(),
            ]
        }

        fn resolve_contract(&self, name: &str) -> Option<Arc<ToolContract>> {
            match name {
                "alpha" | "beta" => Some(Arc::new(runtime_test_tool(name).contract())),
                _ => None,
            }
        }

        async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
            self.started.fetch_add(1, Ordering::SeqCst);
            if timeout(Duration::from_millis(100), self.barrier.wait())
                .await
                .is_err()
            {
                return ToolResult::err_fmt("batch child tools did not run concurrently");
            }
            if call.name == "beta"
                && call.args.get("value").and_then(|value| value.as_str()) == Some("fail")
            {
                return ToolResult::err_fmt("beta failed");
            }
            ToolResult::ok(serde_json::json!(call.name))
        }
    }

    /// Effect controller that records the kind of every effect it runs, then delegates
    /// execution to the local executor.
    #[derive(Clone, Default)]
    struct CountingEffectController {
        kinds: Arc<std::sync::Mutex<Vec<lash_core::RuntimeEffectKind>>>,
    }

    impl CountingEffectController {
        /// Number of recorded effects of the given kind.
        fn count(&self, kind: lash_core::RuntimeEffectKind) -> usize {
            self.kinds
                .lock()
                .expect("effect kinds")
                .iter()
                .filter(|candidate| **candidate == kind)
                .count()
        }
    }

    /// In-memory attachment store that reports `Durable` persistence.
    #[derive(Default)]
    struct DurableMemoryAttachmentStore {
        inner: lash_core::InMemoryAttachmentStore,
    }

    #[async_trait::async_trait]
    impl lash_core::AttachmentStore for DurableMemoryAttachmentStore {
        fn persistence(&self) -> lash_core::AttachmentStorePersistence {
            lash_core::AttachmentStorePersistence::Durable
        }

        async fn put(
            &self,
            bytes: Vec<u8>,
            meta: lash_core::AttachmentCreateMeta,
        ) -> Result<lash_core::AttachmentRef, lash_core::AttachmentStoreError> {
            self.inner.put(bytes, meta).await
        }

        async fn get(
            &self,
            id: &lash_core::AttachmentId,
        ) -> Result<lash_core::StoredAttachment, lash_core::AttachmentStoreError> {
            self.inner.get(id).await
        }

        async fn delete(
            &self,
            id: &lash_core::AttachmentId,
        ) -> Result<(), lash_core::AttachmentStoreError> {
            self.inner.delete(id).await
        }
    }

    /// In-memory process-environment store that reports the `Durable` tier.
    #[derive(Default)]
    struct DurableMemoryProcessEnvStore {
        inner: lash_core::InMemoryProcessExecutionEnvStore,
    }

    #[async_trait::async_trait]
    impl lash_core::ProcessExecutionEnvStore for DurableMemoryProcessEnvStore {
        fn durability_tier(&self) -> lash_core::DurabilityTier {
            lash_core::DurabilityTier::Durable
        }

        async fn put_process_execution_env(
            &self,
            env_ref: &lash_core::ProcessExecutionEnvRef,
            bytes: &[u8],
        ) -> Result<(), lash_core::PluginError> {
            self.inner.put_process_execution_env(env_ref, bytes).await
        }

        async fn get_process_execution_env(
            &self,
            env_ref: &lash_core::ProcessExecutionEnvRef,
        ) -> Result<Option<Vec<u8>>, lash_core::PluginError> {
            self.inner.get_process_execution_env(env_ref).await
        }
    }

    impl lash_core::AwaitEventResolver for CountingEffectController {
        fn durability_tier(&self) -> lash_core::DurabilityTier {
            lash_core::DurabilityTier::Durable
        }
    }

    #[async_trait::async_trait]
    impl lash_core::RuntimeEffectController for CountingEffectController {
        async fn execute_effect(
            &self,
            envelope: lash_core::RuntimeEffectEnvelope,
            local_executor: lash_core::RuntimeEffectLocalExecutor<'_>,
        ) -> Result<lash_core::RuntimeEffectOutcome, lash_core::RuntimeEffectControllerError>
        {
            self.kinds
                .lock()
                .expect("effect kinds")
                .push(envelope.command.kind());
            local_executor.execute(envelope).await
        }
    }

    /// Runs one turn on a fully durable runtime in which the model requests a `batch`.
    ///
    /// The assertions pin the following:
    /// - the turn finishes after two provider calls;
    /// - no child tool is started and the batch results never reach the model;
    /// - exactly one `ToolBatch` effect and one `ToolAttempt` effect are recorded.
    ///
    /// NOTE(review): confirm these assertions match the "rejects nested batch" intent in
    /// the test name.
    #[tokio::test]
    async fn standard_batch_tool_rejects_nested_batch_inside_durable_attempt() {
        let provider_calls = Arc::new(AtomicUsize::new(0));
        let saw_batch_result = Arc::new(AtomicBool::new(false));
        let provider = BatchRuntimeProvider {
            calls: Arc::clone(&provider_calls),
            saw_batch_result: Arc::clone(&saw_batch_result),
        };
        let provider_handle = lash_core::ProviderHandle::new(lash_core::ProviderComponents::new(
            Box::new(provider),
            Arc::new(lash_core::StaticModelPolicy::new()),
        ));
        let mut host = lash_core::RuntimeHostConfig::in_memory();
        host.providers.provider_resolver =
            Arc::new(lash_core::SingleProviderResolver::new(provider_handle));
        host.durability.attachment_store = Arc::new(DurableMemoryAttachmentStore::default());
        host.durability.process_env_store = Arc::new(DurableMemoryProcessEnvStore::default());
        let started = Arc::new(AtomicUsize::new(0));
        let factories: Vec<Arc<dyn lash_core::PluginFactory>> = vec![
            Arc::new(StandardProtocolPluginFactory::new()),
            Arc::new(lash_core::plugin::StaticPluginFactory::new(
                "standard-batch-test-tools",
                lash_core::PluginSpec::new().with_tool_provider(Arc::new(BatchRuntimeTools {
                    barrier: Arc::new(Barrier::new(2)),
                    started: Arc::clone(&started),
                })),
            )),
        ];
        let policy = lash_core::SessionPolicy {
            provider_id: "stub".to_string(),
            model: lash_core::ModelSpec::from_token_limits("mock-model", None, 200_000, None)
                .expect("valid model"),
            ..lash_core::SessionPolicy::default()
        };
        let controller = CountingEffectController::default();
        let scoped_controller = lash_core::ScopedEffectController::shared(
            Arc::new(controller.clone()),
            lash_core::ExecutionScope::turn("standard-batch-session", "turn-1"),
        )
        .expect("scoped controller");
        let mut runtime = lash_core::LashRuntime::builder()
            .with_session_id("standard-batch-session")
            .with_policy(policy)
            .with_runtime_host(host)
            .with_plugin_factories(factories)
            .build()
            .await
            .expect("runtime");

        let turn = runtime
            .stream_turn(
                lash_core::TurnInput::text("run the batch"),
                lash_core::TurnOptions::new(
                    tokio_util::sync::CancellationToken::new(),
                    scoped_controller,
                ),
            )
            .await
            .expect("turn");

        assert!(matches!(turn.outcome, lash_core::TurnOutcome::Finished(_)));
        assert_eq!(provider_calls.load(Ordering::SeqCst), 2);
        assert_eq!(started.load(Ordering::SeqCst), 0);
        assert!(!saw_batch_result.load(Ordering::SeqCst));
        assert_eq!(controller.count(lash_core::RuntimeEffectKind::ToolBatch), 1);
        assert_eq!(
            controller.count(lash_core::RuntimeEffectKind::ToolAttempt),
            1
        );
    }

    #[test]
    fn tool_attachment_round_trips_to_part_kind_image() {
        let attachment = image_ref("att-1");
        let output = ToolCallOutput::success(ToolValue::Attachment(attachment.clone()));
        let model_return =
            ModelToolReturn::from_output("call-9".to_string(), "screenshot".to_string(), &output);

        let mut parts: Vec<Part> = Vec::new();
        append_model_return_parts(&mut parts, model_return);

        assert_eq!(parts.len(), 1, "single attachment yields single part");
        let part = &parts[0];
        assert!(matches!(part.kind, PartKind::Image));
        assert_eq!(part.content, "");
        assert_eq!(part.tool_call_id.as_deref(), Some("call-9"));
        assert_eq!(part.tool_name.as_deref(), Some("screenshot"));
        let part_attachment = part.attachment.as_ref().expect("attachment present");
        assert_eq!(part_attachment.reference.id, attachment.id);
    }

    #[test]
    fn tool_text_and_attachment_round_trip_preserves_order() {
        let attachment = image_ref("att-2");
        let output = ToolCallOutput::success(ToolValue::Array(vec![
            ToolValue::String("before".into()),
            ToolValue::Attachment(attachment.clone()),
            ToolValue::String("after".into()),
        ]));
        let model_return =
            ModelToolReturn::from_output("call-10".to_string(), "snap".to_string(), &output);

        let mut parts: Vec<Part> = Vec::new();
        append_model_return_parts(&mut parts, model_return);

        assert_eq!(parts.len(), 3, "text + image + text yields three parts");
        // The text parts are fragments of the array's JSON rendering around the attachment.
        assert!(matches!(parts[0].kind, PartKind::ToolResult));
        assert!(parts[0].content.starts_with("[\"before\""));
        assert!(matches!(parts[1].kind, PartKind::Image));
        assert_eq!(
            parts[1]
                .attachment
                .as_ref()
                .expect("attachment")
                .reference
                .id,
            attachment.id
        );
        assert!(matches!(parts[2].kind, PartKind::ToolResult));
        assert!(parts[2].content.ends_with("\"after\"]"));
    }
}