1use std::sync::Arc;
16
17use async_trait::async_trait;
18use lash_core::llm::types::{ProviderReasoningReplay, ProviderReplayMeta, ResponseTextMeta};
19use lash_core::plugin::{
20 PluginError, PluginFactory, PluginRegistrar, PluginSessionContext, ProtocolDriverPlugin,
21 ProtocolSessionContext, ProtocolSessionPlugin, SessionPlugin,
22};
23use lash_core::sansio::{
24 CheckpointResumeAction, CompletedToolCall, PendingToolCall, ProtocolDriverHandle,
25 WaitingExecState, WaitingLlmState,
26};
27use lash_core::session_model::message::PartAttachment;
28use lash_core::session_model::{
29 ConversationRecord, Message, MessageRole, Part, PartKind, PruneState, SessionEvent,
30 SessionEventRecord, fresh_message_id, make_error_event, reassign_part_ids, shared_parts,
31};
32
33mod batch;
34use batch::batch_tool_definition;
35use lash_core::{
36 CheckpointKind, DriverAction, DriverContextView, LlmOutputPart, LlmResponse,
37 ProtocolBuildInput, SessionError, ToolCall, ToolContract, ToolInvocation, ToolManifest,
38 ToolProvider, ToolResult, TurnDriverConfig, TurnDriverPreamble, TurnFinish, TurnOutcome,
39 TurnStop, append_assistant_text_part, normalized_response_parts, reasoning_part,
40};
41use serde_json::Value;
42
/// Execution guidance handed to the model as the standard protocol's `execution_prompt`.
///
/// The "up to 25 calls" wording mirrors `BATCH_MAX_TOOL_CALLS`; keep the two in sync.
const STANDARD_EXECUTION_SECTION: &str = r#"Use direct tool calls.

- Use `batch` (up to 25 calls) for two or more independent tool calls. Serialize calls when later arguments depend on earlier results.
- For direct conversational requests that need no tools, respond in prose only.

Example — two independent reads in one `batch` call:

```json
{
 "tool_calls": [
 { "tool": "read_file", "parameters": { "path": "src/main.rs" } },
 { "tool": "grep", "parameters": { "query": "ToolProvider", "path": "crates/lash/src/" } }
 ]
}
```"#;
58
/// Maximum number of child calls a single `batch` invocation dispatches; entries beyond this
/// are reported back as rejected. The prompt in `STANDARD_EXECUTION_SECTION` advertises the same
/// limit.
const BATCH_MAX_TOOL_CALLS: usize = 25;
60
/// Factory for the `"standard_protocol"` plugin, which registers the direct-tool-call protocol
/// session hook, its turn driver, and the `batch` tool provider.
#[derive(Default)]
pub struct StandardProtocolPluginFactory;
65
66impl StandardProtocolPluginFactory {
67 pub fn new() -> Self {
68 Self
69 }
70}
71
72impl PluginFactory for StandardProtocolPluginFactory {
73 fn id(&self) -> &'static str {
74 "standard_protocol"
75 }
76
77 fn build(&self, _ctx: &PluginSessionContext) -> Result<Arc<dyn SessionPlugin>, PluginError> {
78 Ok(Arc::new(StandardProtocolPlugin))
79 }
80}
81
82struct StandardProtocolPlugin;
83
84impl SessionPlugin for StandardProtocolPlugin {
85 fn id(&self) -> &'static str {
86 "standard_protocol"
87 }
88
89 fn register(&self, reg: &mut PluginRegistrar) -> Result<(), PluginError> {
90 reg.protocol().session(Arc::new(StandardProtocolSession))?;
91 reg.protocol()
92 .protocol_driver(Arc::new(StandardProtocolDriver))?;
93 reg.tools().provider(Arc::new(StandardProtocolTools))?;
94 Ok(())
95 }
96}
97
98struct StandardProtocolSession;
99
100#[async_trait]
101impl ProtocolSessionPlugin for StandardProtocolSession {
102 async fn initialize_session(
103 &self,
104 _ctx: ProtocolSessionContext<'_>,
105 ) -> Result<(), SessionError> {
106 Ok(())
107 }
108}
109
110struct StandardProtocolDriver;
111
112impl ProtocolDriverPlugin for StandardProtocolDriver {
113 fn build_preamble(&self, input: ProtocolBuildInput) -> TurnDriverPreamble {
114 let tool_names = input.tool_catalog.tool_names();
115 let tool_names_fingerprint = input.tool_catalog.tool_names_fingerprint();
116 TurnDriverPreamble {
117 config: TurnDriverConfig::chat(
118 Arc::new(StandardDriver),
119 true,
120 Arc::new(turn_limit_exhausted_message),
121 ),
122 tool_specs: input.tool_catalog.model_tool_specs(),
123 tool_names,
124 tool_names_fingerprint,
125 execution_prompt: Arc::from(STANDARD_EXECUTION_SECTION),
126 prompt_contributions: input.extra_prompt_contributions,
127 }
128 }
129}
130
131fn turn_limit_exhausted_message(message_id: String, max_turns: usize) -> Message {
132 Message {
133 id: message_id.clone(),
134 role: MessageRole::System,
135 parts: shared_parts(vec![Part {
136 id: format!("{message_id}.p0"),
137 kind: PartKind::Error,
138 content: format!("Turn limit reached ({max_turns}) before a final assistant response."),
139 attachment: None,
140 tool_call_id: None,
141 tool_name: None,
142 tool_replay: None,
143 prune_state: PruneState::Intact,
144 reasoning_meta: None,
145 response_meta: None,
146 }]),
147 origin: None,
148 }
149}
150
151struct StandardProtocolTools;
152
153#[async_trait]
154impl ToolProvider for StandardProtocolTools {
155 fn tool_manifests(&self) -> Vec<ToolManifest> {
156 vec![batch_tool_definition().manifest()]
157 }
158
159 fn resolve_contract(&self, name: &str) -> Option<Arc<ToolContract>> {
160 (name == "batch").then(|| Arc::new(batch_tool_definition().contract()))
161 }
162
163 async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
164 match call.name {
165 "batch" => execute_batch_tool_call(call).await,
166 _ => ToolResult::err_fmt(format_args!("Unknown tool: {}", call.name)),
167 }
168 }
169}
170
/// One validated entry of a `batch` call's `tool_calls` array (produced by `parse_batch_specs`).
#[derive(Debug)]
struct BatchCallSpec {
    /// Position of the entry within `tool_calls`; echoed back as `index` in the batch result.
    index: usize,
    /// Tool name, trimmed and guaranteed non-empty by the parser.
    tool: String,
    /// Child-call arguments; `{}` when the entry omitted `parameters`.
    parameters: Value,
}
177
178async fn execute_batch_tool_call(call: ToolCall<'_>) -> ToolResult {
179 let args = call.args;
180 let specs = match parse_batch_specs(args) {
181 Ok(specs) => specs,
182 Err(err) => return err,
183 };
184
185 let mut immediate_outcomes = Vec::new();
186 let mut parallel_specs = Vec::new();
187 let dispatch = call.context.dispatch();
188
189 for spec in specs.into_iter().take(BATCH_MAX_TOOL_CALLS) {
190 if spec.tool == "batch" {
191 immediate_outcomes.push(serde_json::json!({
192 "index": spec.index,
193 "tool": spec.tool,
194 "success": false,
195 "duration_ms": 0,
196 "error": "Tool 'batch' is not allowed inside batch",
197 }));
198 continue;
199 }
200 let Some(manifest) = dispatch.callable_tool_manifest(&spec.tool) else {
201 let error = format!("Tool '{}' is unavailable in this session", spec.tool);
202 immediate_outcomes.push(serde_json::json!({
203 "index": spec.index,
204 "tool": spec.tool,
205 "success": false,
206 "duration_ms": 0,
207 "error": error,
208 }));
209 continue;
210 };
211 parallel_specs.push((
212 spec.index,
213 ToolInvocation::new(
214 format!(
215 "{}:{:02}",
216 call.context.tool_call_id().unwrap_or("batch"),
217 spec.index
218 ),
219 manifest.id,
220 spec.parameters,
221 ),
222 ));
223 }
224
225 let mut parallel_outcomes = dispatch
226 .batch(
227 parallel_specs
228 .iter()
229 .map(|(_, invocation)| invocation.clone())
230 .collect(),
231 )
232 .await;
233 for ((index, invocation), outcome) in
234 parallel_specs.into_iter().zip(parallel_outcomes.drain(..))
235 {
236 let tool_label = invocation.label();
237 let tool_record = outcome.record.unwrap_or(lash_core::ToolCallRecord {
238 call_id: Some(invocation.id),
239 tool: tool_label,
240 args: invocation.args,
241 output: outcome.output,
242 duration_ms: 0,
243 });
244 let mut result_record = serde_json::Map::new();
245 result_record.insert("index".to_string(), serde_json::json!(index));
246 result_record.insert("tool".to_string(), serde_json::json!(tool_record.tool));
247 result_record.insert(
248 "success".to_string(),
249 serde_json::json!(tool_record.output.is_success()),
250 );
251 result_record.insert(
252 "duration_ms".to_string(),
253 serde_json::json!(tool_record.duration_ms),
254 );
255 result_record.insert(
256 if tool_record.output.is_success() {
257 "result".to_string()
258 } else {
259 "error".to_string()
260 },
261 tool_record.output.value_for_projection(),
262 );
263 immediate_outcomes.push(Value::Object(result_record));
264 }
265
266 for overflow_index in BATCH_MAX_TOOL_CALLS
267 ..args
268 .get("tool_calls")
269 .and_then(|value| value.as_array())
270 .map(|value| value.len())
271 .unwrap_or_default()
272 {
273 immediate_outcomes.push(serde_json::json!({
274 "index": overflow_index,
275 "tool": args
276 .get("tool_calls")
277 .and_then(|value| value.as_array())
278 .and_then(|items| items.get(overflow_index))
279 .and_then(|item| item.get("tool"))
280 .and_then(|value| value.as_str())
281 .unwrap_or("unknown"),
282 "success": false,
283 "duration_ms": 0,
284 "error": "Maximum of 25 tool calls allowed in batch",
285 }));
286 }
287
288 immediate_outcomes.sort_by_key(|outcome| {
289 outcome
290 .get("index")
291 .and_then(|value| value.as_u64())
292 .unwrap_or(u64::MAX)
293 });
294 ToolResult::ok(serde_json::json!({
295 "results": immediate_outcomes,
296 }))
297}
298
299#[allow(clippy::result_large_err)]
300fn parse_batch_specs(args: &Value) -> Result<Vec<BatchCallSpec>, ToolResult> {
301 let Some(raw_calls) = args.get("tool_calls").and_then(|value| value.as_array()) else {
302 return Err(ToolResult::err_fmt(
303 "Missing required parameter: tool_calls",
304 ));
305 };
306 if raw_calls.is_empty() {
307 return Err(ToolResult::err_fmt(
308 "Invalid tool_calls: expected at least one call",
309 ));
310 }
311
312 let mut specs = Vec::with_capacity(raw_calls.len());
313 for (index, item) in raw_calls.iter().enumerate() {
314 let Some(object) = item.as_object() else {
315 return Err(ToolResult::err_fmt(format_args!(
316 "Invalid tool_calls[{index}]: expected object with tool and parameters"
317 )));
318 };
319 let Some(tool) = object
320 .get("tool")
321 .and_then(|value| value.as_str())
322 .map(str::trim)
323 .filter(|tool| !tool.is_empty())
324 else {
325 return Err(ToolResult::err_fmt(format_args!(
326 "Invalid tool_calls[{index}].tool: expected non-empty string"
327 )));
328 };
329 let parameters = object
330 .get("parameters")
331 .cloned()
332 .unwrap_or_else(|| serde_json::json!({}));
333 specs.push(BatchCallSpec {
334 index,
335 tool: tool.to_string(),
336 parameters,
337 });
338 }
339
340 Ok(specs)
341}
342
/// Turn driver for the standard protocol.
///
/// Each protocol iteration issues one LLM request. Tool calls in the response are started,
/// and their results are appended as a user message before the next iteration. A response
/// without tool calls moves the turn toward completion.
pub struct StandardDriver;
353
/// A tool call collected from an LLM response, kept until the assistant message and the
/// pending tool calls are built.
struct StandardToolCall {
    /// Provider-assigned call id, reused as the tool-call part's `tool_call_id`.
    call_id: String,
    /// Name of the tool the model asked to run.
    tool_name: String,
    /// Raw JSON argument text as produced by the model.
    input_json: String,
    /// Provider replay metadata carried through to the part and the pending call.
    replay: Option<ProviderReplayMeta>,
}
360
361fn last_message_has_tool_result(ctx: &DriverContextView<'_>) -> bool {
362 ctx.messages().last().is_some_and(|message| {
363 matches!(message.role, MessageRole::User)
364 && message
365 .parts
366 .iter()
367 .any(|part| matches!(part.kind, PartKind::ToolResult))
368 })
369}
370
impl ProtocolDriverHandle<lash_core::HostTurnProtocol> for StandardDriver {
    /// Starts every protocol iteration with a single LLM request projected from the current
    /// driver context (passing `true` to `project_llm_request`) and no driver state.
    fn prepare_protocol_iteration(&self, ctx: DriverContextView<'_>) -> Vec<DriverAction> {
        vec![DriverAction::StartLlm {
            request: ctx.project_llm_request(true),
            driver_state: None,
        }]
    }

    /// Turns a completed LLM response into driver actions.
    ///
    /// Text, reasoning and tool-call parts are collected in response order and an
    /// `LlmResponse` event is always emitted.
    ///
    /// Without tool calls:
    /// - an empty response finishes via a `BeforeCompletion` checkpoint when the previous
    ///   message carried tool results;
    /// - otherwise an empty response is reported as a provider error;
    /// - a non-empty response goes to a `BeforeCompletion` checkpoint that finishes with the
    ///   assistant text.
    ///
    /// With tool calls, an assistant message holding the prose, reasoning and tool-call parts
    /// is appended and the calls are started.
    fn handle_llm_success(
        &self,
        ctx: DriverContextView<'_>,
        _waiting: WaitingLlmState<lash_core::HostTurnProtocol>,
        llm_response: LlmResponse,
        text_streamed: bool,
    ) -> Vec<DriverAction> {
        let response_parts = normalized_response_parts(&llm_response);
        let mut assistant_text = String::new();
        // Each entry is exactly the slice its text part appended to `assistant_text`.
        let mut assistant_text_parts: Vec<(String, Option<ResponseTextMeta>)> = Vec::new();
        let mut tool_calls: Vec<StandardToolCall> = Vec::new();
        // (number of tool calls seen before this item, replay metadata, trimmed text); the
        // count is used to interleave reasoning back before the matching tool call.
        let mut reasoning_items: Vec<(usize, Option<ProviderReasoningReplay>, String)> = Vec::new();
        let mut actions = Vec::new();

        for part in response_parts {
            match part {
                LlmOutputPart::Text {
                    text,
                    response_meta,
                } => {
                    if !text.is_empty() {
                        let previous_len = assistant_text.len();
                        append_assistant_text_part(&mut assistant_text, &text);
                        assistant_text_parts
                            .push((assistant_text[previous_len..].to_string(), response_meta));
                        // Only emit deltas here when the text was not already streamed.
                        if !text_streamed {
                            actions.push(DriverAction::Emit(SessionEvent::TextDelta {
                                content: assistant_text[previous_len..].to_string(),
                            }));
                        }
                    }
                }
                LlmOutputPart::Reasoning { text, replay } => {
                    let trimmed = text.trim().to_string();
                    // Drop reasoning that has neither visible text nor replay metadata.
                    if trimmed.is_empty() && replay.as_ref().is_none_or(|meta| meta.is_empty()) {
                        continue;
                    }
                    reasoning_items.push((tool_calls.len(), replay, trimmed));
                }
                LlmOutputPart::ToolCall {
                    call_id,
                    tool_name,
                    input_json,
                    replay,
                } => {
                    tool_calls.push(StandardToolCall {
                        call_id,
                        tool_name,
                        input_json,
                        replay,
                    });
                }
            }
        }

        // NOTE(review): duration is always reported as 0 from this driver.
        actions.push(DriverAction::Emit(SessionEvent::LlmResponse {
            protocol_iteration: ctx.protocol_iteration(),
            content: assistant_text.clone(),
            duration_ms: 0,
        }));

        if tool_calls.is_empty() {
            if assistant_text.trim().is_empty() && reasoning_items.is_empty() {
                // An empty reply right after tool results is treated as a (silent) completion.
                if last_message_has_tool_result(&ctx) {
                    actions.push(DriverAction::StartCheckpoint {
                        checkpoint: CheckpointKind::BeforeCompletion,
                        on_empty: CheckpointResumeAction::Finish(TurnOutcome::Finished(
                            TurnFinish::AssistantMessage {
                                text: String::new(),
                            },
                        )),
                    });
                    return actions;
                }
                actions.push(DriverAction::Emit(make_error_event(
                    "llm_provider",
                    Some("empty_response"),
                    "Model returned no assistant text or tool calls.",
                    None,
                )));
                actions.push(DriverAction::Finish(TurnOutcome::Stopped(
                    TurnStop::ProviderError,
                )));
                return actions;
            }

            // NOTE(review): `parts_out` (reasoning + prose with response metadata) is only used
            // for the emptiness check below and is never appended as a message. Presumably the
            // runtime materializes the final assistant message from `TurnFinish` text. Confirm
            // this; otherwise reasoning parts and response metadata are dropped on this path.
            let asst_id = fresh_message_id();
            let mut parts_out = Vec::new();
            for (_, meta, text) in reasoning_items {
                parts_out.push(reasoning_part(&asst_id, parts_out.len(), text, meta));
            }
            for (content, response_meta) in assistant_text_parts {
                if content.trim().is_empty() {
                    continue;
                }
                parts_out.push(Part {
                    id: format!("{}.p{}", asst_id, parts_out.len()),
                    kind: PartKind::Prose,
                    content,
                    attachment: None,
                    tool_call_id: None,
                    tool_name: None,
                    tool_replay: None,
                    prune_state: PruneState::Intact,
                    reasoning_meta: None,
                    response_meta,
                });
            }
            // NOTE(review): given the early return above, either some text part is non-blank
            // or at least one reasoning part was pushed, so this looks unreachable (defensive).
            if parts_out.is_empty() {
                actions.push(DriverAction::Emit(make_error_event(
                    "llm_provider",
                    Some("empty_response"),
                    "Model returned no assistant text or tool calls.",
                    None,
                )));
                actions.push(DriverAction::Finish(TurnOutcome::Stopped(
                    TurnStop::ProviderError,
                )));
                return actions;
            }
            actions.push(DriverAction::StartCheckpoint {
                checkpoint: CheckpointKind::BeforeCompletion,
                on_empty: CheckpointResumeAction::Finish(TurnOutcome::Finished(
                    TurnFinish::AssistantMessage {
                        text: assistant_text.clone(),
                    },
                )),
            });
            return actions;
        }

        // Tool-call path: prose parts first, then reasoning and tool calls interleaved in the
        // order they appeared in the response.
        let asst_id = fresh_message_id();
        let mut assistant_parts = Vec::new();
        for (content, response_meta) in assistant_text_parts {
            if content.trim().is_empty() {
                continue;
            }
            assistant_parts.push(Part {
                id: format!("{}.p{}", asst_id, assistant_parts.len()),
                kind: PartKind::Prose,
                content,
                attachment: None,
                tool_call_id: None,
                tool_name: None,
                tool_replay: None,
                prune_state: PruneState::Intact,
                reasoning_meta: None,
                response_meta,
            });
        }

        let mut calls = Vec::new();
        let mut reasoning_iter = reasoning_items.into_iter().peekable();
        for (tool_index, tool_call) in tool_calls.into_iter().enumerate() {
            // Emit every reasoning item that preceded this tool call in the response.
            while let Some((insert_index, _, _)) = reasoning_iter.peek() {
                if *insert_index > tool_index {
                    break;
                }
                let (_, meta, text) = reasoning_iter.next().expect("peek ok");
                assistant_parts.push(reasoning_part(&asst_id, assistant_parts.len(), text, meta));
            }
            assistant_parts.push(Part {
                id: format!("{}.p{}", asst_id, assistant_parts.len()),
                kind: PartKind::ToolCall,
                content: tool_call.input_json.clone(),
                attachment: None,
                tool_call_id: Some(tool_call.call_id.clone()),
                tool_name: Some(tool_call.tool_name.clone()),
                tool_replay: tool_call.replay.clone(),
                prune_state: PruneState::Intact,
                reasoning_meta: None,
                response_meta: None,
            });

            // NOTE(review): malformed tool-call JSON silently becomes `{}` arguments rather
            // than surfacing a parse error to the model.
            let args = serde_json::from_str::<Value>(&tool_call.input_json)
                .unwrap_or_else(|_| serde_json::json!({}));
            calls.push(PendingToolCall {
                call_id: tool_call.call_id,
                tool_name: tool_call.tool_name,
                args,
                replay: tool_call.replay,
            });
        }
        // Reasoning that followed the last tool call goes at the end.
        for (_, meta, text) in reasoning_iter {
            assistant_parts.push(reasoning_part(&asst_id, assistant_parts.len(), text, meta));
        }

        if !assistant_parts.is_empty() {
            actions.push(DriverAction::AppendEvents(vec![conversation_event(
                Message {
                    id: asst_id,
                    role: MessageRole::Assistant,
                    parts: shared_parts(assistant_parts),
                    origin: None,
                },
            )]));
        }

        actions.push(DriverAction::StartTools { calls });
        actions
    }

    /// Records completed tool calls and decides how the turn continues.
    ///
    /// The first successful output carrying a recognized `ToolControl` picks the terminal
    /// outcome; later control signals are ignored. All model-facing tool returns are appended
    /// as one user message. Without a terminal outcome the iteration advances. The turn then
    /// either stops on the `max_turns` budget or enters an `AfterWork` checkpoint.
    fn handle_tool_results(
        &self,
        ctx: DriverContextView<'_>,
        completed: Vec<CompletedToolCall>,
    ) -> Vec<DriverAction> {
        let mut actions = Vec::new();
        let mut result_parts = Vec::new();
        let mut terminal_outcome = None;

        for outcome in completed {
            // NOTE(review): control signals are only inspected on successful outputs, so the
            // `Fail` arm below is reachable only if a successful output carries
            // `ToolControl::Fail`. Confirm that this is the intended contract.
            if terminal_outcome.is_none() && outcome.output.is_success() {
                terminal_outcome = match outcome.output.control.as_ref() {
                    // Frame switches need a non-blank frame id and task.
                    Some(lash_core::ToolControl::SwitchAgentFrame {
                        frame_id,
                        task: Some(task),
                        ..
                    }) if !frame_id.trim().is_empty() && !task.trim().is_empty() => {
                        Some(TurnOutcome::AgentFrameSwitch {
                            frame_id: frame_id.clone(),
                            task: task.clone(),
                        })
                    }
                    Some(lash_core::ToolControl::Finish { value }) => {
                        Some(TurnOutcome::Finished(TurnFinish::ToolValue {
                            tool_name: outcome.tool_name.clone(),
                            value: value.to_json_value(),
                        }))
                    }
                    Some(lash_core::ToolControl::Fail { failure }) => {
                        Some(TurnOutcome::Stopped(TurnStop::ToolError {
                            tool_name: outcome.tool_name.clone(),
                            value: failure.to_json_value(),
                        }))
                    }
                    _ => None,
                };
            }

            append_model_return_parts(&mut result_parts, outcome.model_return);
        }

        if !result_parts.is_empty() {
            let user_id = fresh_message_id();
            // Parts are built with empty ids; assign ids under the new message.
            reassign_part_ids(&user_id, &mut result_parts);
            actions.push(DriverAction::AppendEvents(vec![conversation_event(
                Message {
                    id: user_id,
                    role: MessageRole::User,
                    parts: shared_parts(result_parts),
                    origin: None,
                },
            )]));
        }

        if let Some(outcome) = terminal_outcome {
            actions.push(DriverAction::Finish(outcome));
            return actions;
        }

        actions.push(DriverAction::AdvanceProtocolIteration);
        // The budget is counted from `protocol_run_offset`; stop once the next iteration would
        // reach it.
        let next_protocol_iteration = ctx.protocol_iteration() + 1;
        if let Some(max_turns) = ctx.max_turns()
            && next_protocol_iteration >= ctx.protocol_run_offset() + max_turns
        {
            let message_id = fresh_message_id();
            actions.push(DriverAction::AppendEvents(vec![conversation_event(
                turn_limit_exhausted_message(message_id, max_turns),
            )]));
            actions.push(DriverAction::Finish(TurnOutcome::Stopped(
                TurnStop::MaxTurns,
            )));
            return actions;
        }

        actions.push(DriverAction::StartCheckpoint {
            checkpoint: CheckpointKind::AfterWork,
            on_empty: CheckpointResumeAction::PrepareIteration,
        });
        actions
    }

    /// This driver never issues exec work in this file, so exec results produce no actions.
    fn handle_exec_result(
        &self,
        _ctx: DriverContextView<'_>,
        _waiting: WaitingExecState<lash_core::HostTurnProtocol>,
        _result: Result<lash_core::ExecResponse, String>,
    ) -> Vec<DriverAction> {
        Vec::new()
    }
}
687
688fn append_model_return_parts(parts: &mut Vec<Part>, model_return: lash_core::ModelToolReturn) {
689 for part in model_return.parts {
690 match part {
691 lash_core::ModelToolReturnPart::Text { text } => {
692 if text.is_empty() {
693 continue;
694 }
695 parts.push(Part {
696 id: String::new(),
697 kind: PartKind::ToolResult,
698 content: text,
699 attachment: None,
700 tool_call_id: Some(model_return.call_id.clone()),
701 tool_name: Some(model_return.tool_name.clone()),
702 tool_replay: None,
703 prune_state: PruneState::Intact,
704 reasoning_meta: None,
705 response_meta: None,
706 });
707 }
708 lash_core::ModelToolReturnPart::Attachment(reference) => {
709 parts.push(Part {
710 id: String::new(),
711 kind: PartKind::Image,
712 content: String::new(),
713 attachment: Some(PartAttachment { reference }),
714 tool_call_id: Some(model_return.call_id.clone()),
715 tool_name: Some(model_return.tool_name.clone()),
716 tool_replay: None,
717 prune_state: PruneState::Intact,
718 reasoning_meta: None,
719 response_meta: None,
720 });
721 }
722 }
723 }
724}
725
726fn conversation_event(message: Message) -> SessionEventRecord {
727 SessionEventRecord::Conversation(ConversationRecord::from_message(message))
728}
729
#[cfg(test)]
mod tests {
    use super::*;
    use lash_core::{
        AttachmentId, AttachmentMeta, ImageMediaType, MediaType, ModelToolReturn, ToolCallOutput,
        ToolValue,
    };
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use tokio::sync::Barrier;
    use tokio::time::{Duration, timeout};

    // Builds a reference to a tiny (4-byte, 1x1) PNG attachment with the given id.
    fn image_ref(id: &str) -> lash_core::AttachmentRef {
        AttachmentMeta::new(
            AttachmentId::new(id),
            MediaType::Image(ImageMediaType::Png),
            4,
            Some(1),
            Some(1),
            Some("tiny".to_string()),
        )
        .as_ref()
    }

    // Stub provider: the first completion requests a `batch` of `alpha` and a failing `beta`;
    // later completions reply "done" and record whether the projected request mentions both
    // the `alpha` result and the "beta failed" error.
    #[derive(Clone, Debug)]
    struct BatchRuntimeProvider {
        calls: Arc<AtomicUsize>,
        saw_batch_result: Arc<AtomicBool>,
    }

    #[async_trait::async_trait]
    impl lash_core::Provider for BatchRuntimeProvider {
        fn kind(&self) -> &'static str {
            "stub"
        }

        fn options(&self) -> lash_core::ProviderOptions {
            lash_core::ProviderOptions::default()
        }

        fn set_options(&mut self, _options: lash_core::ProviderOptions) {}

        fn serialize_config(&self) -> serde_json::Value {
            serde_json::json!({})
        }

        async fn complete(
            &mut self,
            request: lash_core::LlmRequest,
        ) -> Result<lash_core::LlmResponse, lash_core::LlmTransportError> {
            let call_index = self.calls.fetch_add(1, Ordering::SeqCst);
            if call_index == 0 {
                return Ok(lash_core::LlmResponse {
                    parts: vec![lash_core::LlmOutputPart::ToolCall {
                        call_id: "batch-call".to_string(),
                        tool_name: "batch".to_string(),
                        input_json: serde_json::json!({
                            "tool_calls": [
                                {"tool": "alpha", "parameters": {}},
                                {"tool": "beta", "parameters": {"value": "fail"}}
                            ]
                        })
                        .to_string(),
                        replay: None,
                    }],
                    ..lash_core::LlmResponse::default()
                });
            }

            // Coarse check via the Debug rendering of the projected messages.
            let projected_messages = format!("{:?}", request.messages);
            if projected_messages.contains("alpha") && projected_messages.contains("beta failed") {
                self.saw_batch_result.store(true, Ordering::SeqCst);
            }
            Ok(lash_core::LlmResponse {
                full_text: "done".to_string(),
                parts: vec![lash_core::LlmOutputPart::Text {
                    text: "done".to_string(),
                    response_meta: None,
                }],
                ..lash_core::LlmResponse::default()
            })
        }

        fn clone_boxed(&self) -> Box<dyn lash_core::Provider> {
            Box::new(self.clone())
        }
    }

    // Test tools `alpha` and `beta`. Every execution waits on a shared barrier (100 ms timeout)
    // and returns an error if its peer never arrives, i.e. unless both run concurrently.
    #[derive(Debug)]
    struct BatchRuntimeTools {
        barrier: Arc<Barrier>,
        started: Arc<AtomicUsize>,
    }

    // Parallel-scheduled raw tool definition with an optional string `value` parameter.
    fn runtime_test_tool(name: &str) -> lash_core::ToolDefinition {
        lash_core::ToolDefinition::raw(
            format!("tool:{name}"),
            name,
            "",
            serde_json::json!({
                "type": "object",
                "properties": {
                    "value": { "type": "string" }
                },
                "additionalProperties": true
            }),
            serde_json::json!({ "type": "string" }),
        )
        .with_scheduling(lash_core::ToolScheduling::Parallel)
    }

    #[async_trait::async_trait]
    impl ToolProvider for BatchRuntimeTools {
        fn tool_manifests(&self) -> Vec<ToolManifest> {
            vec![
                runtime_test_tool("alpha").manifest(),
                runtime_test_tool("beta").manifest(),
            ]
        }

        fn resolve_contract(&self, name: &str) -> Option<Arc<ToolContract>> {
            match name {
                "alpha" | "beta" => Some(Arc::new(runtime_test_tool(name).contract())),
                _ => None,
            }
        }

        async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
            self.started.fetch_add(1, Ordering::SeqCst);
            if timeout(Duration::from_millis(100), self.barrier.wait())
                .await
                .is_err()
            {
                return ToolResult::err_fmt("batch child tools did not run concurrently");
            }
            if call.name == "beta"
                && call.args.get("value").and_then(|value| value.as_str()) == Some("fail")
            {
                return ToolResult::err_fmt("beta failed");
            }
            ToolResult::ok(serde_json::json!(call.name))
        }
    }

    // Effect controller that records the kind of every effect it executes, then runs it locally.
    #[derive(Clone, Default)]
    struct CountingEffectController {
        kinds: Arc<std::sync::Mutex<Vec<lash_core::RuntimeEffectKind>>>,
    }

    impl CountingEffectController {
        fn count(&self, kind: lash_core::RuntimeEffectKind) -> usize {
            self.kinds
                .lock()
                .expect("effect kinds")
                .iter()
                .filter(|candidate| **candidate == kind)
                .count()
        }
    }

    // In-memory attachment store that reports itself as durable.
    #[derive(Default)]
    struct DurableMemoryAttachmentStore {
        inner: lash_core::InMemoryAttachmentStore,
    }

    #[async_trait::async_trait]
    impl lash_core::AttachmentStore for DurableMemoryAttachmentStore {
        fn persistence(&self) -> lash_core::AttachmentStorePersistence {
            lash_core::AttachmentStorePersistence::Durable
        }

        async fn put(
            &self,
            bytes: Vec<u8>,
            meta: lash_core::AttachmentCreateMeta,
        ) -> Result<lash_core::AttachmentRef, lash_core::AttachmentStoreError> {
            self.inner.put(bytes, meta).await
        }

        async fn get(
            &self,
            id: &lash_core::AttachmentId,
        ) -> Result<lash_core::StoredAttachment, lash_core::AttachmentStoreError> {
            self.inner.get(id).await
        }
    }

    // In-memory process-env store that reports the durable tier.
    #[derive(Default)]
    struct DurableMemoryProcessEnvStore {
        inner: lash_core::InMemoryProcessExecutionEnvStore,
    }

    #[async_trait::async_trait]
    impl lash_core::ProcessExecutionEnvStore for DurableMemoryProcessEnvStore {
        fn durability_tier(&self) -> lash_core::DurabilityTier {
            lash_core::DurabilityTier::Durable
        }

        async fn put_process_execution_env(
            &self,
            env_ref: &lash_core::ProcessExecutionEnvRef,
            bytes: &[u8],
        ) -> Result<(), lash_core::PluginError> {
            self.inner.put_process_execution_env(env_ref, bytes).await
        }

        async fn get_process_execution_env(
            &self,
            env_ref: &lash_core::ProcessExecutionEnvRef,
        ) -> Result<Option<Vec<u8>>, lash_core::PluginError> {
            self.inner.get_process_execution_env(env_ref).await
        }
    }

    #[async_trait::async_trait]
    impl lash_core::RuntimeEffectController for CountingEffectController {
        fn durability_tier(&self) -> lash_core::DurabilityTier {
            lash_core::DurabilityTier::Durable
        }

        async fn execute_effect(
            &self,
            envelope: lash_core::RuntimeEffectEnvelope,
            local_executor: lash_core::RuntimeEffectLocalExecutor<'_>,
        ) -> Result<lash_core::RuntimeEffectOutcome, lash_core::RuntimeEffectControllerError>
        {
            self.kinds
                .lock()
                .expect("effect kinds")
                .push(envelope.command.kind());
            local_executor.execute(envelope).await
        }
    }

    // Runs a turn whose first LLM response requests a `batch` of `alpha`/`beta`, using durable
    // stores and a durable scoped effect controller. Asserts that:
    // - the turn finishes after two provider calls;
    // - no child tool started;
    // - the follow-up request did not contain the child results;
    // - exactly one ToolBatch effect and one ToolAttempt effect were recorded.
    #[tokio::test]
    async fn standard_batch_tool_rejects_nested_batch_inside_durable_attempt() {
        let provider_calls = Arc::new(AtomicUsize::new(0));
        let saw_batch_result = Arc::new(AtomicBool::new(false));
        let provider = BatchRuntimeProvider {
            calls: Arc::clone(&provider_calls),
            saw_batch_result: Arc::clone(&saw_batch_result),
        };
        let provider_handle = lash_core::ProviderHandle::new(lash_core::ProviderComponents::new(
            Box::new(provider),
            Arc::new(lash_core::StaticModelPolicy::new()),
        ));
        let mut host = lash_core::RuntimeHostConfig::in_memory();
        host.providers.provider_resolver =
            Arc::new(lash_core::SingleProviderResolver::new(provider_handle));
        host.durability.attachment_store = Arc::new(DurableMemoryAttachmentStore::default());
        host.durability.process_env_store = Arc::new(DurableMemoryProcessEnvStore::default());
        let started = Arc::new(AtomicUsize::new(0));
        let factories: Vec<Arc<dyn lash_core::PluginFactory>> = vec![
            Arc::new(StandardProtocolPluginFactory::new()),
            Arc::new(lash_core::plugin::StaticPluginFactory::new(
                "standard-batch-test-tools",
                lash_core::PluginSpec::new().with_tool_provider(Arc::new(BatchRuntimeTools {
                    barrier: Arc::new(Barrier::new(2)),
                    started: Arc::clone(&started),
                })),
            )),
        ];
        let policy = lash_core::SessionPolicy {
            provider_id: "stub".to_string(),
            model: lash_core::ModelSpec::from_token_limits("mock-model", None, 200_000, None)
                .expect("valid model"),
            ..lash_core::SessionPolicy::default()
        };
        let controller = CountingEffectController::default();
        let scoped_controller = lash_core::ScopedEffectController::shared(
            Arc::new(controller.clone()),
            lash_core::ExecutionScope::turn("standard-batch-session", "turn-1"),
        )
        .expect("scoped controller");
        let mut runtime = lash_core::LashRuntime::builder()
            .with_session_id("standard-batch-session")
            .with_policy(policy)
            .with_runtime_host(host)
            .with_plugin_factories(factories)
            .build()
            .await
            .expect("runtime");

        let turn = runtime
            .stream_turn(
                lash_core::TurnInput::text("run the batch"),
                lash_core::TurnOptions::new(
                    tokio_util::sync::CancellationToken::new(),
                    scoped_controller,
                ),
            )
            .await
            .expect("turn");

        assert!(matches!(turn.outcome, lash_core::TurnOutcome::Finished(_)));
        assert_eq!(provider_calls.load(Ordering::SeqCst), 2);
        assert_eq!(started.load(Ordering::SeqCst), 0);
        assert!(!saw_batch_result.load(Ordering::SeqCst));
        assert_eq!(controller.count(lash_core::RuntimeEffectKind::ToolBatch), 1);
        assert_eq!(
            controller.count(lash_core::RuntimeEffectKind::ToolAttempt),
            1
        );
    }

    // A lone attachment output becomes exactly one Image part tagged with the call id/name.
    #[test]
    fn tool_attachment_round_trips_to_part_kind_image() {
        let attachment = image_ref("att-1");
        let output = ToolCallOutput::success(ToolValue::Attachment(attachment.clone()));
        let model_return =
            ModelToolReturn::from_output("call-9".to_string(), "screenshot".to_string(), &output);

        let mut parts: Vec<Part> = Vec::new();
        append_model_return_parts(&mut parts, model_return);

        assert_eq!(parts.len(), 1, "single attachment yields single part");
        let part = &parts[0];
        assert!(matches!(part.kind, PartKind::Image));
        assert_eq!(part.content, "");
        assert_eq!(part.tool_call_id.as_deref(), Some("call-9"));
        assert_eq!(part.tool_name.as_deref(), Some("screenshot"));
        let part_attachment = part.attachment.as_ref().expect("attachment present");
        assert_eq!(part_attachment.reference.id, attachment.id);
    }

    // Mixed text/attachment output keeps its order: text, image, text.
    #[test]
    fn tool_text_and_attachment_round_trip_preserves_order() {
        let attachment = image_ref("att-2");
        let output = ToolCallOutput::success(ToolValue::Array(vec![
            ToolValue::String("before".into()),
            ToolValue::Attachment(attachment.clone()),
            ToolValue::String("after".into()),
        ]));
        let model_return =
            ModelToolReturn::from_output("call-10".to_string(), "snap".to_string(), &output);

        let mut parts: Vec<Part> = Vec::new();
        append_model_return_parts(&mut parts, model_return);

        assert_eq!(parts.len(), 3, "text + image + text yields three parts");
        // Per these assertions, the text fragments around the attachment carry the enclosing
        // JSON array syntax.
        assert!(matches!(parts[0].kind, PartKind::ToolResult));
        assert!(parts[0].content.starts_with("[\"before\""));
        assert!(matches!(parts[1].kind, PartKind::Image));
        assert_eq!(
            parts[1]
                .attachment
                .as_ref()
                .expect("attachment")
                .reference
                .id,
            attachment.id
        );
        assert!(matches!(parts[2].kind, PartKind::ToolResult));
        assert!(parts[2].content.ends_with("\"after\"]"));
    }
}
1086}