1use std::sync::Arc;
16
17use async_trait::async_trait;
18use lash_core::llm::types::{ProviderReasoningReplay, ProviderReplayMeta, ResponseTextMeta};
19use lash_core::plugin::{
20 PluginError, PluginFactory, PluginRegistrar, PluginSessionContext, ProtocolDriverPlugin,
21 ProtocolSessionContext, ProtocolSessionPlugin, SessionPlugin,
22};
23use lash_core::sansio::{
24 CheckpointResumeAction, CompletedToolCall, PendingToolCall, ProtocolDriverHandle,
25 WaitingExecState, WaitingLlmState,
26};
27use lash_core::session_model::message::PartAttachment;
28use lash_core::session_model::{
29 ConversationRecord, Message, MessageRole, Part, PartKind, PruneState, SessionEvent,
30 SessionEventRecord, fresh_message_id, make_error_event, reassign_part_ids, shared_parts,
31};
32
33mod batch;
34use batch::batch_tool_definition;
35use lash_core::{
36 CheckpointKind, DriverAction, DriverContextView, LlmOutputPart, LlmResponse,
37 ProtocolBuildInput, SessionError, ToolCall, ToolContract, ToolInvocation, ToolManifest,
38 ToolProvider, ToolResult, TurnDriverConfig, TurnDriverPreamble, TurnFinish, TurnOutcome,
39 TurnStop, append_assistant_text_part, normalized_response_parts, reasoning_part,
40};
41use serde_json::Value;
42
/// Execution-mode prompt section for the standard protocol.
///
/// Supplied verbatim as the turn preamble's `execution_prompt` by
/// `StandardProtocolDriver::build_preamble`. It tells the model to call tools
/// directly and to group independent calls with the built-in `batch` tool.
/// The "up to 25 calls" wording must be kept in sync with
/// [`BATCH_MAX_TOOL_CALLS`] by hand, because the value cannot be interpolated
/// into a raw string constant.
const STANDARD_EXECUTION_SECTION: &str = r#"Use direct tool calls.

- Use `batch` (up to 25 calls) for two or more independent tool calls. Serialize calls when later arguments depend on earlier results.
- For direct conversational requests that need no tools, respond in prose only.

Example — two independent reads in one `batch` call:

```json
{
  "tool_calls": [
    { "tool": "read_file", "parameters": { "path": "src/main.rs" } },
    { "tool": "grep", "parameters": { "query": "ToolProvider", "path": "crates/lash/src/" } }
  ]
}
```"#;

/// Maximum number of entries from one `batch` call's `tool_calls` that are
/// dispatched; entries past this limit are reported as failed results
/// without being executed.
const BATCH_MAX_TOOL_CALLS: usize = 25;
60
/// Plugin factory that installs the standard protocol (direct tool calls plus
/// the built-in `batch` tool) into a session.
#[derive(Default)]
pub struct StandardProtocolPluginFactory;

impl StandardProtocolPluginFactory {
    /// Creates the factory. This is the same as [`Default::default`]; the
    /// type carries no configuration.
    pub fn new() -> Self {
        Self::default()
    }
}
71
72impl PluginFactory for StandardProtocolPluginFactory {
73 fn id(&self) -> &'static str {
74 "standard_protocol"
75 }
76
77 fn build(&self, _ctx: &PluginSessionContext) -> Result<Arc<dyn SessionPlugin>, PluginError> {
78 Ok(Arc::new(StandardProtocolPlugin))
79 }
80}
81
82struct StandardProtocolPlugin;
83
84impl SessionPlugin for StandardProtocolPlugin {
85 fn id(&self) -> &'static str {
86 "standard_protocol"
87 }
88
89 fn register(&self, reg: &mut PluginRegistrar) -> Result<(), PluginError> {
90 reg.protocol().session(Arc::new(StandardProtocolSession))?;
91 reg.protocol()
92 .protocol_driver(Arc::new(StandardProtocolDriver))?;
93 reg.tools().provider(Arc::new(StandardProtocolTools))?;
94 Ok(())
95 }
96}
97
98struct StandardProtocolSession;
99
100#[async_trait]
101impl ProtocolSessionPlugin for StandardProtocolSession {
102 async fn initialize_session(
103 &self,
104 _ctx: ProtocolSessionContext<'_>,
105 ) -> Result<(), SessionError> {
106 Ok(())
107 }
108}
109
110struct StandardProtocolDriver;
111
112impl ProtocolDriverPlugin for StandardProtocolDriver {
113 fn build_preamble(&self, input: ProtocolBuildInput) -> TurnDriverPreamble {
114 let tool_names = input.tool_catalog.tool_names();
115 let tool_names_fingerprint = input.tool_catalog.tool_names_fingerprint();
116 TurnDriverPreamble {
117 config: TurnDriverConfig::chat(
118 Arc::new(StandardDriver),
119 true,
120 Arc::new(turn_limit_exhausted_message),
121 ),
122 tool_specs: input.tool_catalog.model_tool_specs(),
123 tool_names,
124 tool_names_fingerprint,
125 omitted_tool_count: 0,
126 execution_prompt: Arc::from(STANDARD_EXECUTION_SECTION),
127 prompt_contributions: input.extra_prompt_contributions,
128 }
129 }
130}
131
132fn turn_limit_exhausted_message(message_id: String, max_turns: usize) -> Message {
133 Message {
134 id: message_id.clone(),
135 role: MessageRole::System,
136 parts: shared_parts(vec![Part {
137 id: format!("{message_id}.p0"),
138 kind: PartKind::Error,
139 content: format!("Turn limit reached ({max_turns}) before a final assistant response."),
140 attachment: None,
141 tool_call_id: None,
142 tool_name: None,
143 tool_replay: None,
144 prune_state: PruneState::Intact,
145 reasoning_meta: None,
146 response_meta: None,
147 }]),
148 origin: None,
149 }
150}
151
152struct StandardProtocolTools;
153
154#[async_trait]
155impl ToolProvider for StandardProtocolTools {
156 fn tool_manifests(&self) -> Vec<ToolManifest> {
157 vec![batch_tool_definition().manifest()]
158 }
159
160 fn resolve_contract(&self, name: &str) -> Option<Arc<ToolContract>> {
161 (name == "batch").then(|| Arc::new(batch_tool_definition().contract()))
162 }
163
164 async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
165 match call.name {
166 "batch" => execute_batch_tool_call(call).await,
167 _ => ToolResult::err_fmt(format_args!("Unknown tool: {}", call.name)),
168 }
169 }
170}
171
/// One validated entry of a `batch` call's `tool_calls` array.
#[derive(Debug)]
struct BatchCallSpec {
    // Position of the entry within `tool_calls`; results are ordered by it.
    index: usize,
    // Requested tool name, with surrounding whitespace trimmed.
    tool: String,
    // Arguments forwarded to the tool; `{}` when the entry omits `parameters`.
    parameters: Value,
}
178
179async fn execute_batch_tool_call(call: ToolCall<'_>) -> ToolResult {
180 let args = call.args;
181 let specs = match parse_batch_specs(args) {
182 Ok(specs) => specs,
183 Err(err) => return err,
184 };
185
186 let mut immediate_outcomes = Vec::new();
187 let mut parallel_specs = Vec::new();
188 let dispatch = call.context.dispatch();
189
190 for spec in specs.into_iter().take(BATCH_MAX_TOOL_CALLS) {
191 if spec.tool == "batch" {
192 immediate_outcomes.push(serde_json::json!({
193 "index": spec.index,
194 "tool": spec.tool,
195 "success": false,
196 "duration_ms": 0,
197 "error": "Tool 'batch' is not allowed inside batch",
198 }));
199 continue;
200 }
201 let Some(manifest) = dispatch.callable_tool_manifest(&spec.tool) else {
202 let error = format!("Tool '{}' is unavailable in this session", spec.tool);
203 immediate_outcomes.push(serde_json::json!({
204 "index": spec.index,
205 "tool": spec.tool,
206 "success": false,
207 "duration_ms": 0,
208 "error": error,
209 }));
210 continue;
211 };
212 parallel_specs.push((
213 spec.index,
214 ToolInvocation::new(
215 format!(
216 "{}:{:02}",
217 call.context.tool_call_id().unwrap_or("batch"),
218 spec.index
219 ),
220 manifest.id,
221 spec.parameters,
222 ),
223 ));
224 }
225
226 let mut parallel_outcomes = dispatch
227 .batch(
228 parallel_specs
229 .iter()
230 .map(|(_, invocation)| invocation.clone())
231 .collect(),
232 )
233 .await;
234 for ((index, invocation), outcome) in
235 parallel_specs.into_iter().zip(parallel_outcomes.drain(..))
236 {
237 let tool_label = invocation.label();
238 let tool_record = outcome.record.unwrap_or(lash_core::ToolCallRecord {
239 call_id: Some(invocation.id),
240 tool: tool_label,
241 args: invocation.args,
242 output: outcome.output,
243 duration_ms: 0,
244 });
245 let mut result_record = serde_json::Map::new();
246 result_record.insert("index".to_string(), serde_json::json!(index));
247 result_record.insert("tool".to_string(), serde_json::json!(tool_record.tool));
248 result_record.insert(
249 "success".to_string(),
250 serde_json::json!(tool_record.output.is_success()),
251 );
252 result_record.insert(
253 "duration_ms".to_string(),
254 serde_json::json!(tool_record.duration_ms),
255 );
256 result_record.insert(
257 if tool_record.output.is_success() {
258 "result".to_string()
259 } else {
260 "error".to_string()
261 },
262 tool_record.output.value_for_projection(),
263 );
264 immediate_outcomes.push(Value::Object(result_record));
265 }
266
267 for overflow_index in BATCH_MAX_TOOL_CALLS
268 ..args
269 .get("tool_calls")
270 .and_then(|value| value.as_array())
271 .map(|value| value.len())
272 .unwrap_or_default()
273 {
274 immediate_outcomes.push(serde_json::json!({
275 "index": overflow_index,
276 "tool": args
277 .get("tool_calls")
278 .and_then(|value| value.as_array())
279 .and_then(|items| items.get(overflow_index))
280 .and_then(|item| item.get("tool"))
281 .and_then(|value| value.as_str())
282 .unwrap_or("unknown"),
283 "success": false,
284 "duration_ms": 0,
285 "error": "Maximum of 25 tool calls allowed in batch",
286 }));
287 }
288
289 immediate_outcomes.sort_by_key(|outcome| {
290 outcome
291 .get("index")
292 .and_then(|value| value.as_u64())
293 .unwrap_or(u64::MAX)
294 });
295 ToolResult::ok(serde_json::json!({
296 "results": immediate_outcomes,
297 }))
298}
299
300#[allow(clippy::result_large_err)]
301fn parse_batch_specs(args: &Value) -> Result<Vec<BatchCallSpec>, ToolResult> {
302 let Some(raw_calls) = args.get("tool_calls").and_then(|value| value.as_array()) else {
303 return Err(ToolResult::err_fmt(
304 "Missing required parameter: tool_calls",
305 ));
306 };
307 if raw_calls.is_empty() {
308 return Err(ToolResult::err_fmt(
309 "Invalid tool_calls: expected at least one call",
310 ));
311 }
312
313 let mut specs = Vec::with_capacity(raw_calls.len());
314 for (index, item) in raw_calls.iter().enumerate() {
315 let Some(object) = item.as_object() else {
316 return Err(ToolResult::err_fmt(format_args!(
317 "Invalid tool_calls[{index}]: expected object with tool and parameters"
318 )));
319 };
320 let Some(tool) = object
321 .get("tool")
322 .and_then(|value| value.as_str())
323 .map(str::trim)
324 .filter(|tool| !tool.is_empty())
325 else {
326 return Err(ToolResult::err_fmt(format_args!(
327 "Invalid tool_calls[{index}].tool: expected non-empty string"
328 )));
329 };
330 let parameters = object
331 .get("parameters")
332 .cloned()
333 .unwrap_or_else(|| serde_json::json!({}));
334 specs.push(BatchCallSpec {
335 index,
336 tool: tool.to_string(),
337 parameters,
338 });
339 }
340
341 Ok(specs)
342}
343
/// Turn driver for the standard protocol. Each protocol iteration makes one
/// LLM request, and requested tool calls are handed out via
/// `DriverAction::StartTools`.
pub struct StandardDriver;

/// A tool call collected from an LLM response while it is being turned into
/// driver actions.
struct StandardToolCall {
    // Provider-assigned call id; copied onto the tool-call part and the pending call.
    call_id: String,
    tool_name: String,
    // Raw JSON arguments from the model; parsed later, falling back to `{}` if invalid.
    input_json: String,
    // Provider replay metadata, carried onto both the recorded part and the pending call.
    replay: Option<ProviderReplayMeta>,
}
361
362fn last_message_has_tool_result(ctx: &DriverContextView<'_>) -> bool {
363 ctx.messages().last().is_some_and(|message| {
364 matches!(message.role, MessageRole::User)
365 && message
366 .parts
367 .iter()
368 .any(|part| matches!(part.kind, PartKind::ToolResult))
369 })
370}
371
/// Drives standard-protocol turns.
///
/// The cycle is: LLM request, then either a completion checkpoint (no tool
/// calls) or tool execution. After tools run, their results are appended and
/// the next iteration starts, until a terminal outcome or the turn limit.
impl ProtocolDriverHandle<lash_core::HostTurnProtocol> for StandardDriver {
    /// Starts every protocol iteration with a single LLM request projected
    /// from the current driver context.
    fn prepare_protocol_iteration(&self, ctx: DriverContextView<'_>) -> Vec<DriverAction> {
        // NOTE(review): the meaning of the `true` flag passed to
        // `project_llm_request` is defined in lash_core — confirm there.
        vec![DriverAction::StartLlm {
            request: ctx.project_llm_request(true),
            driver_state: None,
        }]
    }

    /// Converts a successful LLM response into driver actions.
    ///
    /// Text is joined into one assistant string, and each text part's slice
    /// of it is remembered. A `TextDelta` is emitted for each non-empty text
    /// part when the text was not already streamed. An `LlmResponse` event is
    /// always emitted.
    ///
    /// - Without tool calls: a blank response (no text, no reasoning) either
    ///   finishes with empty text via a `BeforeCompletion` checkpoint, when the
    ///   last message carried tool results, or stops with `ProviderError`.
    ///   Otherwise it finishes with the assistant text via a `BeforeCompletion`
    ///   checkpoint.
    /// - With tool calls: the assistant message (prose, reasoning, tool-call
    ///   parts) is appended and the calls are started.
    fn handle_llm_success(
        &self,
        ctx: DriverContextView<'_>,
        _waiting: WaitingLlmState<lash_core::HostTurnProtocol>,
        llm_response: LlmResponse,
        text_streamed: bool,
    ) -> Vec<DriverAction> {
        let response_parts = normalized_response_parts(&llm_response);
        let mut assistant_text = String::new();
        let mut assistant_text_parts: Vec<(String, Option<ResponseTextMeta>)> = Vec::new();
        let mut tool_calls: Vec<StandardToolCall> = Vec::new();
        // Each reasoning item stores how many tool calls preceded it, so it can
        // later be placed just before the next tool call.
        let mut reasoning_items: Vec<(usize, Option<ProviderReasoningReplay>, String)> = Vec::new();
        let mut actions = Vec::new();

        for part in response_parts {
            match part {
                LlmOutputPart::Text {
                    text,
                    response_meta,
                } => {
                    if !text.is_empty() {
                        let previous_len = assistant_text.len();
                        append_assistant_text_part(&mut assistant_text, &text);
                        // Keep exactly the slice this part added to the joined text.
                        assistant_text_parts
                            .push((assistant_text[previous_len..].to_string(), response_meta));
                        // Text that was not streamed live is surfaced as a delta here.
                        if !text_streamed {
                            actions.push(DriverAction::Emit(SessionEvent::TextDelta {
                                content: assistant_text[previous_len..].to_string(),
                            }));
                        }
                    }
                }
                LlmOutputPart::Reasoning { text, replay } => {
                    let trimmed = text.trim().to_string();
                    // Drop reasoning that has neither visible text nor replay metadata.
                    if trimmed.is_empty() && replay.as_ref().is_none_or(|meta| meta.is_empty()) {
                        continue;
                    }
                    reasoning_items.push((tool_calls.len(), replay, trimmed));
                }
                LlmOutputPart::ToolCall {
                    call_id,
                    tool_name,
                    input_json,
                    replay,
                } => {
                    tool_calls.push(StandardToolCall {
                        call_id,
                        tool_name,
                        input_json,
                        replay,
                    });
                }
            }
        }

        actions.push(DriverAction::Emit(SessionEvent::LlmResponse {
            protocol_iteration: ctx.protocol_iteration(),
            content: assistant_text.clone(),
            duration_ms: 0,
        }));

        if tool_calls.is_empty() {
            if assistant_text.trim().is_empty() && reasoning_items.is_empty() {
                // A blank reply right after tool results is accepted as an
                // empty final answer rather than treated as a provider error.
                if last_message_has_tool_result(&ctx) {
                    actions.push(DriverAction::StartCheckpoint {
                        checkpoint: CheckpointKind::BeforeCompletion,
                        on_empty: CheckpointResumeAction::Finish(TurnOutcome::Finished(
                            TurnFinish::AssistantMessage {
                                text: String::new(),
                            },
                        )),
                    });
                    return actions;
                }
                actions.push(DriverAction::Emit(make_error_event(
                    "llm_provider",
                    Some("empty_response"),
                    "Model returned no assistant text or tool calls.",
                    None,
                )));
                actions.push(DriverAction::Finish(TurnOutcome::Stopped(
                    TurnStop::ProviderError,
                )));
                return actions;
            }

            // NOTE(review): `parts_out` (and `asst_id`) are built here but never
            // appended or returned; only `parts_out.is_empty()` is read. The
            // reasoning parts and prose `response_meta` of a tool-free reply
            // are therefore discarded by this method — confirm that is intended.
            let asst_id = fresh_message_id();
            let mut parts_out = Vec::new();
            for (_, meta, text) in reasoning_items {
                parts_out.push(reasoning_part(&asst_id, parts_out.len(), text, meta));
            }
            for (content, response_meta) in assistant_text_parts {
                if content.trim().is_empty() {
                    continue;
                }
                parts_out.push(Part {
                    id: format!("{}.p{}", asst_id, parts_out.len()),
                    kind: PartKind::Prose,
                    content,
                    attachment: None,
                    tool_call_id: None,
                    tool_name: None,
                    tool_replay: None,
                    prune_state: PruneState::Intact,
                    reasoning_meta: None,
                    response_meta,
                });
            }
            // NOTE(review): this branch looks unreachable. The early return
            // above leaves either some reasoning or non-blank joined text, and
            // the text parts are slices of that joined text (assuming
            // `append_assistant_text_part` only appends) — verify.
            if parts_out.is_empty() {
                actions.push(DriverAction::Emit(make_error_event(
                    "llm_provider",
                    Some("empty_response"),
                    "Model returned no assistant text or tool calls.",
                    None,
                )));
                actions.push(DriverAction::Finish(TurnOutcome::Stopped(
                    TurnStop::ProviderError,
                )));
                return actions;
            }
            actions.push(DriverAction::StartCheckpoint {
                checkpoint: CheckpointKind::BeforeCompletion,
                on_empty: CheckpointResumeAction::Finish(TurnOutcome::Finished(
                    TurnFinish::AssistantMessage {
                        text: assistant_text.clone(),
                    },
                )),
            });
            return actions;
        }

        // Tool-call path: prose parts go first. Each reasoning item is then
        // inserted just before the first tool call that followed it, and
        // trailing reasoning is appended at the end.
        // NOTE(review): prose always precedes every reasoning part here, even
        // when the provider emitted reasoning before the text — confirm this
        // ordering is acceptable for reasoning replay.
        let asst_id = fresh_message_id();
        let mut assistant_parts = Vec::new();
        for (content, response_meta) in assistant_text_parts {
            if content.trim().is_empty() {
                continue;
            }
            assistant_parts.push(Part {
                id: format!("{}.p{}", asst_id, assistant_parts.len()),
                kind: PartKind::Prose,
                content,
                attachment: None,
                tool_call_id: None,
                tool_name: None,
                tool_replay: None,
                prune_state: PruneState::Intact,
                reasoning_meta: None,
                response_meta,
            });
        }

        let mut calls = Vec::new();
        let mut reasoning_iter = reasoning_items.into_iter().peekable();
        for (tool_index, tool_call) in tool_calls.into_iter().enumerate() {
            // Flush every reasoning item recorded before this tool call.
            while let Some((insert_index, _, _)) = reasoning_iter.peek() {
                if *insert_index > tool_index {
                    break;
                }
                let (_, meta, text) = reasoning_iter.next().expect("peek ok");
                assistant_parts.push(reasoning_part(&asst_id, assistant_parts.len(), text, meta));
            }
            assistant_parts.push(Part {
                id: format!("{}.p{}", asst_id, assistant_parts.len()),
                kind: PartKind::ToolCall,
                content: tool_call.input_json.clone(),
                attachment: None,
                tool_call_id: Some(tool_call.call_id.clone()),
                tool_name: Some(tool_call.tool_name.clone()),
                tool_replay: tool_call.replay.clone(),
                prune_state: PruneState::Intact,
                reasoning_meta: None,
                response_meta: None,
            });

            // Arguments that are not valid JSON are replaced with `{}` for
            // execution; the recorded part keeps the raw text.
            let args = serde_json::from_str::<Value>(&tool_call.input_json)
                .unwrap_or_else(|_| serde_json::json!({}));
            calls.push(PendingToolCall {
                call_id: tool_call.call_id,
                tool_name: tool_call.tool_name,
                args,
                replay: tool_call.replay,
            });
        }
        // Reasoning that came after the last tool call.
        for (_, meta, text) in reasoning_iter {
            assistant_parts.push(reasoning_part(&asst_id, assistant_parts.len(), text, meta));
        }

        if !assistant_parts.is_empty() {
            actions.push(DriverAction::AppendEvents(vec![conversation_event(
                Message {
                    id: asst_id,
                    role: MessageRole::Assistant,
                    parts: shared_parts(assistant_parts),
                    origin: None,
                },
            )]));
        }

        actions.push(DriverAction::StartTools { calls });
        actions
    }

    /// Records tool results and decides how the turn continues.
    ///
    /// The first successful output whose control maps to a terminal outcome
    /// ends the turn: an agent-frame switch (non-blank frame id and task), a
    /// finish, or a fail. All results are appended as one user message, even
    /// when the turn ends. Otherwise the iteration advances. It then either
    /// stops on the turn limit or runs an `AfterWork` checkpoint that prepares
    /// the next iteration.
    fn handle_tool_results(
        &self,
        ctx: DriverContextView<'_>,
        completed: Vec<CompletedToolCall>,
    ) -> Vec<DriverAction> {
        let mut actions = Vec::new();
        let mut result_parts = Vec::new();
        let mut terminal_outcome = None;

        for outcome in completed {
            // Once a terminal outcome is chosen, later controls are ignored.
            if terminal_outcome.is_none() && outcome.output.is_success() {
                terminal_outcome = match outcome.output.control.as_ref() {
                    Some(lash_core::ToolControl::SwitchAgentFrame {
                        frame_id,
                        task: Some(task),
                        ..
                    }) if !frame_id.trim().is_empty() && !task.trim().is_empty() => {
                        Some(TurnOutcome::AgentFrameSwitch {
                            frame_id: frame_id.clone(),
                            task: task.clone(),
                        })
                    }
                    Some(lash_core::ToolControl::Finish { value }) => {
                        Some(TurnOutcome::Finished(TurnFinish::ToolValue {
                            tool_name: outcome.tool_name.clone(),
                            value: value.to_json_value(),
                        }))
                    }
                    Some(lash_core::ToolControl::Fail { failure }) => {
                        Some(TurnOutcome::Stopped(TurnStop::ToolError {
                            tool_name: outcome.tool_name.clone(),
                            value: failure.to_json_value(),
                        }))
                    }
                    _ => None,
                };
            }

            append_model_return_parts(&mut result_parts, outcome.model_return);
        }

        if !result_parts.is_empty() {
            let user_id = fresh_message_id();
            // Parts are created with empty ids; give them ids under the new message.
            reassign_part_ids(&user_id, &mut result_parts);
            actions.push(DriverAction::AppendEvents(vec![conversation_event(
                Message {
                    id: user_id,
                    role: MessageRole::User,
                    parts: shared_parts(result_parts),
                    origin: None,
                },
            )]));
        }

        if let Some(outcome) = terminal_outcome {
            actions.push(DriverAction::Finish(outcome));
            return actions;
        }

        actions.push(DriverAction::AdvanceProtocolIteration);
        let next_protocol_iteration = ctx.protocol_iteration() + 1;
        // The turn limit counts iterations from this run's starting offset.
        if let Some(max_turns) = ctx.max_turns()
            && next_protocol_iteration >= ctx.protocol_run_offset() + max_turns
        {
            let message_id = fresh_message_id();
            actions.push(DriverAction::AppendEvents(vec![conversation_event(
                turn_limit_exhausted_message(message_id, max_turns),
            )]));
            actions.push(DriverAction::Finish(TurnOutcome::Stopped(
                TurnStop::MaxTurns,
            )));
            return actions;
        }

        actions.push(DriverAction::StartCheckpoint {
            checkpoint: CheckpointKind::AfterWork,
            on_empty: CheckpointResumeAction::PrepareIteration,
        });
        actions
    }

    /// No-op: this driver never issues exec actions itself, so it returns no
    /// follow-up actions.
    fn handle_exec_result(
        &self,
        _ctx: DriverContextView<'_>,
        _waiting: WaitingExecState<lash_core::HostTurnProtocol>,
        _result: Result<lash_core::ExecResponse, String>,
    ) -> Vec<DriverAction> {
        Vec::new()
    }
}
688
689fn append_model_return_parts(parts: &mut Vec<Part>, model_return: lash_core::ModelToolReturn) {
690 for part in model_return.parts {
691 match part {
692 lash_core::ModelToolReturnPart::Text { text } => {
693 if text.is_empty() {
694 continue;
695 }
696 parts.push(Part {
697 id: String::new(),
698 kind: PartKind::ToolResult,
699 content: text,
700 attachment: None,
701 tool_call_id: Some(model_return.call_id.clone()),
702 tool_name: Some(model_return.tool_name.clone()),
703 tool_replay: None,
704 prune_state: PruneState::Intact,
705 reasoning_meta: None,
706 response_meta: None,
707 });
708 }
709 lash_core::ModelToolReturnPart::Attachment(reference) => {
710 parts.push(Part {
711 id: String::new(),
712 kind: PartKind::Image,
713 content: String::new(),
714 attachment: Some(PartAttachment { reference }),
715 tool_call_id: Some(model_return.call_id.clone()),
716 tool_name: Some(model_return.tool_name.clone()),
717 tool_replay: None,
718 prune_state: PruneState::Intact,
719 reasoning_meta: None,
720 response_meta: None,
721 });
722 }
723 }
724 }
725}
726
727fn conversation_event(message: Message) -> SessionEventRecord {
728 SessionEventRecord::Conversation(ConversationRecord::from_message(message))
729}
730
#[cfg(test)]
mod tests {
    use super::*;
    use lash_core::{
        AttachmentId, AttachmentMeta, ImageMediaType, MediaType, ModelToolReturn, ToolCallOutput,
        ToolValue,
    };
    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
    use tokio::sync::Barrier;
    use tokio::time::{Duration, timeout};

    /// Builds a small PNG attachment reference with the given id for tests.
    fn image_ref(id: &str) -> lash_core::AttachmentRef {
        AttachmentMeta::new(
            AttachmentId::new(id),
            MediaType::Image(ImageMediaType::Png),
            4,
            Some(1),
            Some(1),
            Some("tiny".to_string()),
        )
        .as_ref()
    }

    /// Scripted provider.
    ///
    /// The first completion requests a `batch` of `alpha` and a failing
    /// `beta`. Later completions check that the projected request mentions
    /// both results, then answer "done".
    #[derive(Clone, Debug)]
    struct BatchRuntimeProvider {
        calls: Arc<AtomicUsize>,
        saw_batch_result: Arc<AtomicBool>,
    }

    #[async_trait::async_trait]
    impl lash_core::Provider for BatchRuntimeProvider {
        fn kind(&self) -> &'static str {
            "stub"
        }

        fn options(&self) -> lash_core::ProviderOptions {
            lash_core::ProviderOptions::default()
        }

        fn set_options(&mut self, _options: lash_core::ProviderOptions) {}

        fn serialize_config(&self) -> serde_json::Value {
            serde_json::json!({})
        }

        async fn complete(
            &mut self,
            request: lash_core::LlmRequest,
        ) -> Result<lash_core::LlmResponse, lash_core::LlmTransportError> {
            let call_index = self.calls.fetch_add(1, Ordering::SeqCst);
            if call_index == 0 {
                return Ok(lash_core::LlmResponse {
                    parts: vec![lash_core::LlmOutputPart::ToolCall {
                        call_id: "batch-call".to_string(),
                        tool_name: "batch".to_string(),
                        input_json: serde_json::json!({
                            "tool_calls": [
                                {"tool": "alpha", "parameters": {}},
                                {"tool": "beta", "parameters": {"value": "fail"}}
                            ]
                        })
                        .to_string(),
                        replay: None,
                    }],
                    ..lash_core::LlmResponse::default()
                });
            }

            // Crude check on the Debug output: both child results must have
            // been fed back to the model.
            let projected_messages = format!("{:?}", request.messages);
            if projected_messages.contains("alpha") && projected_messages.contains("beta failed") {
                self.saw_batch_result.store(true, Ordering::SeqCst);
            }
            Ok(lash_core::LlmResponse {
                full_text: "done".to_string(),
                parts: vec![lash_core::LlmOutputPart::Text {
                    text: "done".to_string(),
                    response_meta: None,
                }],
                ..lash_core::LlmResponse::default()
            })
        }

        fn clone_boxed(&self) -> Box<dyn lash_core::Provider> {
            Box::new(self.clone())
        }
    }

    /// Test tools `alpha` and `beta`. Each call waits on a shared barrier, so
    /// both children must be in flight at the same time for either to succeed.
    #[derive(Debug)]
    struct BatchRuntimeTools {
        barrier: Arc<Barrier>,
        started: Arc<AtomicUsize>,
    }

    /// Defines a parallel-scheduled test tool `name` with id `tool:{name}`.
    fn runtime_test_tool(name: &str) -> lash_core::ToolDefinition {
        lash_core::ToolDefinition::raw(
            format!("tool:{name}"),
            name,
            "",
            serde_json::json!({
                "type": "object",
                "properties": {
                    "value": { "type": "string" }
                },
                "additionalProperties": true
            }),
            serde_json::json!({ "type": "string" }),
        )
        .with_scheduling(lash_core::ToolScheduling::Parallel)
    }

    #[async_trait::async_trait]
    impl ToolProvider for BatchRuntimeTools {
        fn tool_manifests(&self) -> Vec<ToolManifest> {
            vec![
                runtime_test_tool("alpha").manifest(),
                runtime_test_tool("beta").manifest(),
            ]
        }

        fn resolve_contract(&self, name: &str) -> Option<Arc<ToolContract>> {
            match name {
                "alpha" | "beta" => Some(Arc::new(runtime_test_tool(name).contract())),
                _ => None,
            }
        }

        async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
            self.started.fetch_add(1, Ordering::SeqCst);
            // If the sibling call does not reach the barrier within 100 ms,
            // the children were run sequentially.
            if timeout(Duration::from_millis(100), self.barrier.wait())
                .await
                .is_err()
            {
                return ToolResult::err_fmt("batch child tools did not run concurrently");
            }
            if call.name == "beta"
                && call.args.get("value").and_then(|value| value.as_str()) == Some("fail")
            {
                return ToolResult::err_fmt("beta failed");
            }
            ToolResult::ok(serde_json::json!(call.name))
        }
    }

    /// Effect controller that records the kind of every effect it sees and
    /// then runs it with the local executor.
    #[derive(Clone, Default)]
    struct CountingEffectController {
        kinds: Arc<std::sync::Mutex<Vec<lash_core::RuntimeEffectKind>>>,
    }

    impl CountingEffectController {
        /// Number of recorded effects of `kind`.
        fn count(&self, kind: lash_core::RuntimeEffectKind) -> usize {
            self.kinds
                .lock()
                .expect("effect kinds")
                .iter()
                .filter(|candidate| **candidate == kind)
                .count()
        }
    }

    /// In-memory attachment store that reports itself as durable.
    #[derive(Default)]
    struct DurableMemoryAttachmentStore {
        inner: lash_core::InMemoryAttachmentStore,
    }

    #[async_trait::async_trait]
    impl lash_core::AttachmentStore for DurableMemoryAttachmentStore {
        fn persistence(&self) -> lash_core::AttachmentStorePersistence {
            lash_core::AttachmentStorePersistence::Durable
        }

        async fn put(
            &self,
            bytes: Vec<u8>,
            meta: lash_core::AttachmentCreateMeta,
        ) -> Result<lash_core::AttachmentRef, lash_core::AttachmentStoreError> {
            self.inner.put(bytes, meta).await
        }

        async fn get(
            &self,
            id: &lash_core::AttachmentId,
        ) -> Result<lash_core::StoredAttachment, lash_core::AttachmentStoreError> {
            self.inner.get(id).await
        }
    }

    /// In-memory process-environment store that reports the durable tier.
    #[derive(Default)]
    struct DurableMemoryProcessEnvStore {
        inner: lash_core::InMemoryProcessExecutionEnvStore,
    }

    #[async_trait::async_trait]
    impl lash_core::ProcessExecutionEnvStore for DurableMemoryProcessEnvStore {
        fn durability_tier(&self) -> lash_core::DurabilityTier {
            lash_core::DurabilityTier::Durable
        }

        async fn put_process_execution_env(
            &self,
            env_ref: &lash_core::ProcessExecutionEnvRef,
            bytes: &[u8],
        ) -> Result<(), lash_core::PluginError> {
            self.inner.put_process_execution_env(env_ref, bytes).await
        }

        async fn get_process_execution_env(
            &self,
            env_ref: &lash_core::ProcessExecutionEnvRef,
        ) -> Result<Option<Vec<u8>>, lash_core::PluginError> {
            self.inner.get_process_execution_env(env_ref).await
        }
    }

    #[async_trait::async_trait]
    impl lash_core::RuntimeEffectController for CountingEffectController {
        fn durability_tier(&self) -> lash_core::DurabilityTier {
            lash_core::DurabilityTier::Durable
        }

        async fn execute_effect(
            &self,
            envelope: lash_core::RuntimeEffectEnvelope,
            local_executor: lash_core::RuntimeEffectLocalExecutor<'_>,
        ) -> Result<lash_core::RuntimeEffectOutcome, lash_core::RuntimeEffectControllerError>
        {
            // Record before delegating; the guard is a temporary dropped at
            // the end of this statement, so it is not held across the await.
            self.kinds
                .lock()
                .expect("effect kinds")
                .push(envelope.command.kind());
            local_executor.execute(envelope).await
        }
    }

    /// End-to-end check of the standard `batch` tool under durable stores and
    /// a durable effect controller.
    ///
    /// The turn must finish after two provider calls. Both children must start
    /// and run concurrently (via the barrier), and their results must reach
    /// the model. The children must also go through exactly one `ToolBatch`
    /// effect.
    #[tokio::test]
    async fn standard_batch_tool_uses_core_tool_batch_under_durable_boundary() {
        let provider_calls = Arc::new(AtomicUsize::new(0));
        let saw_batch_result = Arc::new(AtomicBool::new(false));
        let provider = BatchRuntimeProvider {
            calls: Arc::clone(&provider_calls),
            saw_batch_result: Arc::clone(&saw_batch_result),
        };
        let provider_handle = lash_core::ProviderHandle::new(lash_core::ProviderComponents::new(
            Box::new(provider),
            Arc::new(lash_core::StaticModelPolicy::new()),
        ));
        let mut host = lash_core::RuntimeHostConfig::in_memory();
        host.providers.provider_resolver =
            Arc::new(lash_core::SingleProviderResolver::new(provider_handle));
        host.durability.attachment_store = Arc::new(DurableMemoryAttachmentStore::default());
        host.durability.process_env_store = Arc::new(DurableMemoryProcessEnvStore::default());
        let started = Arc::new(AtomicUsize::new(0));
        let factories: Vec<Arc<dyn lash_core::PluginFactory>> = vec![
            Arc::new(StandardProtocolPluginFactory::new()),
            Arc::new(lash_core::plugin::StaticPluginFactory::new(
                "standard-batch-test-tools",
                lash_core::PluginSpec::new().with_tool_provider(Arc::new(BatchRuntimeTools {
                    // Two parties: one per child call in the scripted batch.
                    barrier: Arc::new(Barrier::new(2)),
                    started: Arc::clone(&started),
                })),
            )),
        ];
        let policy = lash_core::SessionPolicy {
            provider_id: "stub".to_string(),
            model: lash_core::ModelSpec::from_token_limits("mock-model", None, 200_000, None)
                .expect("valid model"),
            ..lash_core::SessionPolicy::default()
        };
        let controller = CountingEffectController::default();
        let scoped_controller = lash_core::ScopedEffectController::shared(
            Arc::new(controller.clone()),
            lash_core::ExecutionScope::turn("standard-batch-session", "turn-1"),
        )
        .expect("scoped controller");
        let mut runtime = lash_core::LashRuntime::builder()
            .with_session_id("standard-batch-session")
            .with_policy(policy)
            .with_runtime_host(host)
            .with_plugin_factories(factories)
            .build()
            .await
            .expect("runtime");

        let turn = runtime
            .stream_turn(
                lash_core::TurnInput::text("run the batch"),
                lash_core::TurnOptions::new(
                    tokio_util::sync::CancellationToken::new(),
                    scoped_controller,
                ),
            )
            .await
            .expect("turn");

        assert!(matches!(turn.outcome, lash_core::TurnOutcome::Finished(_)));
        assert_eq!(provider_calls.load(Ordering::SeqCst), 2);
        assert_eq!(started.load(Ordering::SeqCst), 2);
        assert!(saw_batch_result.load(Ordering::SeqCst));
        assert_eq!(controller.count(lash_core::RuntimeEffectKind::ToolBatch), 1);
    }

    /// A lone attachment output becomes one `Image` part tagged with the call
    /// id and tool name, with empty content.
    #[test]
    fn tool_attachment_round_trips_to_part_kind_image() {
        let attachment = image_ref("att-1");
        let output = ToolCallOutput::success(ToolValue::Attachment(attachment.clone()));
        let model_return =
            ModelToolReturn::from_output("call-9".to_string(), "screenshot".to_string(), &output);

        let mut parts: Vec<Part> = Vec::new();
        append_model_return_parts(&mut parts, model_return);

        assert_eq!(parts.len(), 1, "single attachment yields single part");
        let part = &parts[0];
        assert!(matches!(part.kind, PartKind::Image));
        assert_eq!(part.content, "");
        assert_eq!(part.tool_call_id.as_deref(), Some("call-9"));
        assert_eq!(part.tool_name.as_deref(), Some("screenshot"));
        let part_attachment = part.attachment.as_ref().expect("attachment present");
        assert_eq!(part_attachment.reference.id, attachment.id);
    }

    /// A text/attachment/text array output keeps its order: tool-result text,
    /// then image, then tool-result text.
    #[test]
    fn tool_text_and_attachment_round_trip_preserves_order() {
        let attachment = image_ref("att-2");
        let output = ToolCallOutput::success(ToolValue::Array(vec![
            ToolValue::String("before".into()),
            ToolValue::Attachment(attachment.clone()),
            ToolValue::String("after".into()),
        ]));
        let model_return =
            ModelToolReturn::from_output("call-10".to_string(), "snap".to_string(), &output);

        let mut parts: Vec<Part> = Vec::new();
        append_model_return_parts(&mut parts, model_return);

        assert_eq!(parts.len(), 3, "text + image + text yields three parts");
        assert!(matches!(parts[0].kind, PartKind::ToolResult));
        assert!(parts[0].content.starts_with("[\"before\""));
        assert!(matches!(parts[1].kind, PartKind::Image));
        assert_eq!(
            parts[1]
                .attachment
                .as_ref()
                .expect("attachment")
                .reference
                .id,
            attachment.id
        );
        assert!(matches!(parts[2].kind, PartKind::ToolResult));
        assert!(parts[2].content.ends_with("\"after\"]"));
    }
}