Skip to main content

lash_protocol_standard/
lib.rs

1//! Standard protocol stack: the model drives tools via the native
2//! function-calling envelope of its LLM transport.
3//!
4//! This crate owns:
5//!
6//! - [`StandardDriver`] — the [`ProtocolDriverHandle`] that dispatches
7//!   native tool calls and weaves reasoning parts into the assistant
8//!   message timeline.
9//! - The [`StandardProtocolPluginFactory`] plugin that claims the
10//!   protocol-driver slot so the runtime can run standard-protocol
11//!   sessions.
12//! - The `batch` tool that composes parallel native tool calls (only
13//!   exposed when this protocol stack is installed).
14
15use std::sync::Arc;
16
17use async_trait::async_trait;
18use lash_core::llm::types::{ProviderReasoningReplay, ProviderReplayMeta, ResponseTextMeta};
19use lash_core::plugin::{
20    PluginError, PluginFactory, PluginRegistrar, PluginSessionContext, ProtocolDriverPlugin,
21    ProtocolSessionContext, ProtocolSessionPlugin, SessionPlugin,
22};
23use lash_core::sansio::{
24    CheckpointResumeAction, CompletedToolCall, PendingToolCall, ProtocolDriverHandle,
25    WaitingExecState, WaitingLlmState,
26};
27use lash_core::session_model::message::PartAttachment;
28use lash_core::session_model::{
29    ConversationRecord, Message, MessageRole, Part, PartKind, PruneState, SessionEvent,
30    SessionEventRecord, fresh_message_id, make_error_event, reassign_part_ids, shared_parts,
31};
32
33mod batch;
34use batch::batch_tool_definition;
35use lash_core::{
36    CheckpointKind, DriverAction, DriverContextView, LlmOutputPart, LlmResponse,
37    ProtocolBuildInput, SessionError, ToolCall, ToolContract, ToolInvocation, ToolManifest,
38    ToolProvider, ToolResult, TurnDriverConfig, TurnDriverPreamble, TurnFinish, TurnOutcome,
39    TurnStop, append_assistant_text_part, normalized_response_parts, reasoning_part,
40};
41use serde_json::Value;
42
/// Protocol-specific execution guidance for the system prompt.
///
/// `StandardProtocolDriver::build_preamble` passes it as the preamble's
/// `execution_prompt`.
///
/// NOTE(review): the "(up to 25 calls)" wording duplicates
/// `BATCH_MAX_TOOL_CALLS` and must be kept in sync by hand; a raw string
/// literal cannot interpolate the constant.
const STANDARD_EXECUTION_SECTION: &str = r#"Use direct tool calls.

- Use `batch` (up to 25 calls) for two or more independent tool calls. Serialize calls when later arguments depend on earlier results.
- For direct conversational requests that need no tools, respond in prose only.

Example — two independent reads in one `batch` call:

```json
{
  "tool_calls": [
    { "tool": "read_file", "parameters": { "path": "src/main.rs" } },
    { "tool": "grep", "parameters": { "query": "ToolProvider", "path": "crates/lash/src/" } }
  ]
}
```"#;

/// Maximum number of child calls a single `batch` invocation dispatches.
///
/// Entries at or beyond this index are not executed. They are reported back
/// as per-entry errors in the batch result.
const BATCH_MAX_TOOL_CALLS: usize = 25;
60
61/// Plugin factory that installs the standard-protocol driver,
62/// session plugin, and native tool catalog.
63#[derive(Default)]
64pub struct StandardProtocolPluginFactory;
65
66impl StandardProtocolPluginFactory {
67    pub fn new() -> Self {
68        Self
69    }
70}
71
72impl PluginFactory for StandardProtocolPluginFactory {
73    fn id(&self) -> &'static str {
74        "standard_protocol"
75    }
76
77    fn build(&self, _ctx: &PluginSessionContext) -> Result<Arc<dyn SessionPlugin>, PluginError> {
78        Ok(Arc::new(StandardProtocolPlugin))
79    }
80}
81
82struct StandardProtocolPlugin;
83
84impl SessionPlugin for StandardProtocolPlugin {
85    fn id(&self) -> &'static str {
86        "standard_protocol"
87    }
88
89    fn register(&self, reg: &mut PluginRegistrar) -> Result<(), PluginError> {
90        reg.protocol().session(Arc::new(StandardProtocolSession))?;
91        reg.protocol()
92            .protocol_driver(Arc::new(StandardProtocolDriver))?;
93        reg.tools().provider(Arc::new(StandardProtocolTools))?;
94        Ok(())
95    }
96}
97
98struct StandardProtocolSession;
99
100#[async_trait]
101impl ProtocolSessionPlugin for StandardProtocolSession {
102    async fn initialize_session(
103        &self,
104        _ctx: ProtocolSessionContext<'_>,
105    ) -> Result<(), SessionError> {
106        Ok(())
107    }
108}
109
110struct StandardProtocolDriver;
111
112impl ProtocolDriverPlugin for StandardProtocolDriver {
113    fn build_preamble(&self, input: ProtocolBuildInput) -> TurnDriverPreamble {
114        let tool_names = input.tool_catalog.tool_names();
115        let tool_names_fingerprint = input.tool_catalog.tool_names_fingerprint();
116        TurnDriverPreamble {
117            config: TurnDriverConfig::chat(
118                Arc::new(StandardDriver),
119                true,
120                Arc::new(turn_limit_exhausted_message),
121            ),
122            tool_specs: input.tool_catalog.model_tool_specs(),
123            tool_names,
124            tool_names_fingerprint,
125            execution_prompt: Arc::from(STANDARD_EXECUTION_SECTION),
126            prompt_contributions: input.extra_prompt_contributions,
127        }
128    }
129}
130
131fn turn_limit_exhausted_message(message_id: String, max_turns: usize) -> Message {
132    Message {
133        id: message_id.clone(),
134        role: MessageRole::System,
135        parts: shared_parts(vec![Part {
136            id: format!("{message_id}.p0"),
137            kind: PartKind::Error,
138            content: format!("Turn limit reached ({max_turns}) before a final assistant response."),
139            attachment: None,
140            tool_call_id: None,
141            tool_name: None,
142            tool_replay: None,
143            prune_state: PruneState::Intact,
144            reasoning_meta: None,
145            response_meta: None,
146        }]),
147        origin: None,
148    }
149}
150
151struct StandardProtocolTools;
152
153#[async_trait]
154impl ToolProvider for StandardProtocolTools {
155    fn tool_manifests(&self) -> Vec<ToolManifest> {
156        vec![batch_tool_definition().manifest()]
157    }
158
159    fn resolve_contract(&self, name: &str) -> Option<Arc<ToolContract>> {
160        (name == "batch").then(|| Arc::new(batch_tool_definition().contract()))
161    }
162
163    async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
164        match call.name {
165            "batch" => execute_batch_tool_call(call).await,
166            _ => ToolResult::err_fmt(format_args!("Unknown tool: {}", call.name)),
167        }
168    }
169}
170
/// One validated entry of a `batch` call's `tool_calls` array, as produced
/// by `parse_batch_specs`.
#[derive(Debug)]
struct BatchCallSpec {
    /// Position of the entry in the original `tool_calls` array.
    /// It is echoed back as `index` in that entry's result record.
    index: usize,
    /// Requested tool name, with surrounding whitespace trimmed.
    tool: String,
    /// Arguments for the tool; `{}` when the entry omitted `parameters`.
    parameters: Value,
}
177
178async fn execute_batch_tool_call(call: ToolCall<'_>) -> ToolResult {
179    let args = call.args;
180    let specs = match parse_batch_specs(args) {
181        Ok(specs) => specs,
182        Err(err) => return err,
183    };
184
185    let mut immediate_outcomes = Vec::new();
186    let mut parallel_specs = Vec::new();
187    let dispatch = call.context.dispatch();
188
189    for spec in specs.into_iter().take(BATCH_MAX_TOOL_CALLS) {
190        if spec.tool == "batch" {
191            immediate_outcomes.push(serde_json::json!({
192                "index": spec.index,
193                "tool": spec.tool,
194                "success": false,
195                "duration_ms": 0,
196                "error": "Tool 'batch' is not allowed inside batch",
197            }));
198            continue;
199        }
200        let Some(manifest) = dispatch.callable_tool_manifest(&spec.tool) else {
201            let error = format!("Tool '{}' is unavailable in this session", spec.tool);
202            immediate_outcomes.push(serde_json::json!({
203                "index": spec.index,
204                "tool": spec.tool,
205                "success": false,
206                "duration_ms": 0,
207                "error": error,
208            }));
209            continue;
210        };
211        parallel_specs.push((
212            spec.index,
213            ToolInvocation::new(
214                format!(
215                    "{}:{:02}",
216                    call.context.tool_call_id().unwrap_or("batch"),
217                    spec.index
218                ),
219                manifest.id,
220                spec.parameters,
221            ),
222        ));
223    }
224
225    let mut parallel_outcomes = dispatch
226        .batch(
227            parallel_specs
228                .iter()
229                .map(|(_, invocation)| invocation.clone())
230                .collect(),
231        )
232        .await;
233    for ((index, invocation), outcome) in
234        parallel_specs.into_iter().zip(parallel_outcomes.drain(..))
235    {
236        let tool_label = invocation.label();
237        let tool_record = outcome.record.unwrap_or(lash_core::ToolCallRecord {
238            call_id: Some(invocation.id),
239            tool: tool_label,
240            args: invocation.args,
241            output: outcome.output,
242            duration_ms: 0,
243        });
244        let mut result_record = serde_json::Map::new();
245        result_record.insert("index".to_string(), serde_json::json!(index));
246        result_record.insert("tool".to_string(), serde_json::json!(tool_record.tool));
247        result_record.insert(
248            "success".to_string(),
249            serde_json::json!(tool_record.output.is_success()),
250        );
251        result_record.insert(
252            "duration_ms".to_string(),
253            serde_json::json!(tool_record.duration_ms),
254        );
255        result_record.insert(
256            if tool_record.output.is_success() {
257                "result".to_string()
258            } else {
259                "error".to_string()
260            },
261            tool_record.output.value_for_projection(),
262        );
263        immediate_outcomes.push(Value::Object(result_record));
264    }
265
266    for overflow_index in BATCH_MAX_TOOL_CALLS
267        ..args
268            .get("tool_calls")
269            .and_then(|value| value.as_array())
270            .map(|value| value.len())
271            .unwrap_or_default()
272    {
273        immediate_outcomes.push(serde_json::json!({
274            "index": overflow_index,
275            "tool": args
276                .get("tool_calls")
277                .and_then(|value| value.as_array())
278                .and_then(|items| items.get(overflow_index))
279                .and_then(|item| item.get("tool"))
280                .and_then(|value| value.as_str())
281                .unwrap_or("unknown"),
282            "success": false,
283            "duration_ms": 0,
284            "error": "Maximum of 25 tool calls allowed in batch",
285        }));
286    }
287
288    immediate_outcomes.sort_by_key(|outcome| {
289        outcome
290            .get("index")
291            .and_then(|value| value.as_u64())
292            .unwrap_or(u64::MAX)
293    });
294    ToolResult::ok(serde_json::json!({
295        "results": immediate_outcomes,
296    }))
297}
298
299#[allow(clippy::result_large_err)]
300fn parse_batch_specs(args: &Value) -> Result<Vec<BatchCallSpec>, ToolResult> {
301    let Some(raw_calls) = args.get("tool_calls").and_then(|value| value.as_array()) else {
302        return Err(ToolResult::err_fmt(
303            "Missing required parameter: tool_calls",
304        ));
305    };
306    if raw_calls.is_empty() {
307        return Err(ToolResult::err_fmt(
308            "Invalid tool_calls: expected at least one call",
309        ));
310    }
311
312    let mut specs = Vec::with_capacity(raw_calls.len());
313    for (index, item) in raw_calls.iter().enumerate() {
314        let Some(object) = item.as_object() else {
315            return Err(ToolResult::err_fmt(format_args!(
316                "Invalid tool_calls[{index}]: expected object with tool and parameters"
317            )));
318        };
319        let Some(tool) = object
320            .get("tool")
321            .and_then(|value| value.as_str())
322            .map(str::trim)
323            .filter(|tool| !tool.is_empty())
324        else {
325            return Err(ToolResult::err_fmt(format_args!(
326                "Invalid tool_calls[{index}].tool: expected non-empty string"
327            )));
328        };
329        let parameters = object
330            .get("parameters")
331            .cloned()
332            .unwrap_or_else(|| serde_json::json!({}));
333        specs.push(BatchCallSpec {
334            index,
335            tool: tool.to_string(),
336            parameters,
337        });
338    }
339
340    Ok(specs)
341}
342
343// ─────────────────────────────────────────────────────────────────────
344// Standard protocol driver
345// ─────────────────────────────────────────────────────────────────────
346
/// Protocol driver for the Standard protocol. Consumes native
/// tool-call envelopes from the LLM, dispatches them via
/// `DriverAction::StartTools`, and splices reasoning parts into the
/// assistant message so provider replay metadata preserves
/// chain-of-thought ordering.
pub struct StandardDriver;

/// A native tool call extracted from a single LLM response.
///
/// It is held until it is recorded as a `ToolCall` part on the assistant
/// message and turned into a `PendingToolCall` for dispatch.
struct StandardToolCall {
    /// Call id taken from the response's tool-call part.
    call_id: String,
    /// Name of the tool the model asked for.
    tool_name: String,
    /// Raw argument JSON exactly as emitted by the model. It may be
    /// malformed; the driver falls back to `{}` when parsing fails.
    input_json: String,
    /// Provider replay metadata attached to the call, if any.
    replay: Option<ProviderReplayMeta>,
}
360
361fn last_message_has_tool_result(ctx: &DriverContextView<'_>) -> bool {
362    ctx.messages().last().is_some_and(|message| {
363        matches!(message.role, MessageRole::User)
364            && message
365                .parts
366                .iter()
367                .any(|part| matches!(part.kind, PartKind::ToolResult))
368    })
369}
370
371impl ProtocolDriverHandle<lash_core::HostTurnProtocol> for StandardDriver {
372    fn prepare_protocol_iteration(&self, ctx: DriverContextView<'_>) -> Vec<DriverAction> {
373        vec![DriverAction::StartLlm {
374            request: ctx.project_llm_request(true),
375            driver_state: None,
376        }]
377    }
378
379    fn handle_llm_success(
380        &self,
381        ctx: DriverContextView<'_>,
382        _waiting: WaitingLlmState<lash_core::HostTurnProtocol>,
383        llm_response: LlmResponse,
384        text_streamed: bool,
385    ) -> Vec<DriverAction> {
386        let response_parts = normalized_response_parts(&llm_response);
387        let mut assistant_text = String::new();
388        let mut assistant_text_parts: Vec<(String, Option<ResponseTextMeta>)> = Vec::new();
389        let mut tool_calls: Vec<StandardToolCall> = Vec::new();
390        // Reasoning items captured with their position in the original
391        // response. The `usize` is the index in `tool_calls` that this
392        // reasoning item originally preceded, so we can interleave
393        // reasoning → tool_call in the provider's original emission order.
394        // `Option<ProviderReasoningReplay>` carries roundtrip payload
395        // when present (fix 1.3b); when None, the item is display-only
396        // (fix 1.3a) — still rendered in the UI but never re-fed.
397        let mut reasoning_items: Vec<(usize, Option<ProviderReasoningReplay>, String)> = Vec::new();
398        let mut actions = Vec::new();
399
400        for part in response_parts {
401            match part {
402                LlmOutputPart::Text {
403                    text,
404                    response_meta,
405                } => {
406                    if !text.is_empty() {
407                        let previous_len = assistant_text.len();
408                        append_assistant_text_part(&mut assistant_text, &text);
409                        assistant_text_parts
410                            .push((assistant_text[previous_len..].to_string(), response_meta));
411                        if !text_streamed {
412                            actions.push(DriverAction::Emit(SessionEvent::TextDelta {
413                                content: assistant_text[previous_len..].to_string(),
414                            }));
415                        }
416                    }
417                }
418                LlmOutputPart::Reasoning { text, replay } => {
419                    let trimmed = text.trim().to_string();
420                    // Skip fully-empty reasoning items (no display text and
421                    // no roundtrip payload).
422                    if trimmed.is_empty() && replay.as_ref().is_none_or(|meta| meta.is_empty()) {
423                        continue;
424                    }
425                    reasoning_items.push((tool_calls.len(), replay, trimmed));
426                }
427                LlmOutputPart::ToolCall {
428                    call_id,
429                    tool_name,
430                    input_json,
431                    replay,
432                } => {
433                    tool_calls.push(StandardToolCall {
434                        call_id,
435                        tool_name,
436                        input_json,
437                        replay,
438                    });
439                }
440            }
441        }
442
443        actions.push(DriverAction::Emit(SessionEvent::LlmResponse {
444            protocol_iteration: ctx.protocol_iteration(),
445            content: assistant_text.clone(),
446            duration_ms: 0,
447        }));
448
449        if tool_calls.is_empty() {
450            if assistant_text.trim().is_empty() && reasoning_items.is_empty() {
451                if last_message_has_tool_result(&ctx) {
452                    // A model can intentionally complete a tool-only request
453                    // with an empty final answer, e.g. when the user says
454                    // "do nothing else" after the tool action.
455                    actions.push(DriverAction::StartCheckpoint {
456                        checkpoint: CheckpointKind::BeforeCompletion,
457                        on_empty: CheckpointResumeAction::Finish(TurnOutcome::Finished(
458                            TurnFinish::AssistantMessage {
459                                text: String::new(),
460                            },
461                        )),
462                    });
463                    return actions;
464                }
465                actions.push(DriverAction::Emit(make_error_event(
466                    "llm_provider",
467                    Some("empty_response"),
468                    "Model returned no assistant text or tool calls.",
469                    None,
470                )));
471                actions.push(DriverAction::Finish(TurnOutcome::Stopped(
472                    TurnStop::ProviderError,
473                )));
474                return actions;
475            }
476
477            let asst_id = fresh_message_id();
478            let mut parts_out = Vec::new();
479            for (_, meta, text) in reasoning_items {
480                parts_out.push(reasoning_part(&asst_id, parts_out.len(), text, meta));
481            }
482            for (content, response_meta) in assistant_text_parts {
483                if content.trim().is_empty() {
484                    continue;
485                }
486                parts_out.push(Part {
487                    id: format!("{}.p{}", asst_id, parts_out.len()),
488                    kind: PartKind::Prose,
489                    content,
490                    attachment: None,
491                    tool_call_id: None,
492                    tool_name: None,
493                    tool_replay: None,
494                    prune_state: PruneState::Intact,
495                    reasoning_meta: None,
496                    response_meta,
497                });
498            }
499            if parts_out.is_empty() {
500                actions.push(DriverAction::Emit(make_error_event(
501                    "llm_provider",
502                    Some("empty_response"),
503                    "Model returned no assistant text or tool calls.",
504                    None,
505                )));
506                actions.push(DriverAction::Finish(TurnOutcome::Stopped(
507                    TurnStop::ProviderError,
508                )));
509                return actions;
510            }
511            actions.push(DriverAction::StartCheckpoint {
512                checkpoint: CheckpointKind::BeforeCompletion,
513                on_empty: CheckpointResumeAction::Finish(TurnOutcome::Finished(
514                    TurnFinish::AssistantMessage {
515                        text: assistant_text.clone(),
516                    },
517                )),
518            });
519            return actions;
520        }
521
522        let asst_id = fresh_message_id();
523        let mut assistant_parts = Vec::new();
524        for (content, response_meta) in assistant_text_parts {
525            if content.trim().is_empty() {
526                continue;
527            }
528            assistant_parts.push(Part {
529                id: format!("{}.p{}", asst_id, assistant_parts.len()),
530                kind: PartKind::Prose,
531                content,
532                attachment: None,
533                tool_call_id: None,
534                tool_name: None,
535                tool_replay: None,
536                prune_state: PruneState::Intact,
537                reasoning_meta: None,
538                response_meta,
539            });
540        }
541
542        let mut calls = Vec::new();
543        // Interleave reasoning items with tool calls to preserve the
544        // original emission order. Some provider replays expect the
545        // sequence `reasoning → function_call` from the turn in which both
546        // were produced; swapping them can drop the reasoning/tool pairing.
547        let mut reasoning_iter = reasoning_items.into_iter().peekable();
548        for (tool_index, tool_call) in tool_calls.into_iter().enumerate() {
549            while let Some((insert_index, _, _)) = reasoning_iter.peek() {
550                if *insert_index > tool_index {
551                    break;
552                }
553                let (_, meta, text) = reasoning_iter.next().expect("peek ok");
554                assistant_parts.push(reasoning_part(&asst_id, assistant_parts.len(), text, meta));
555            }
556            assistant_parts.push(Part {
557                id: format!("{}.p{}", asst_id, assistant_parts.len()),
558                kind: PartKind::ToolCall,
559                content: tool_call.input_json.clone(),
560                attachment: None,
561                tool_call_id: Some(tool_call.call_id.clone()),
562                tool_name: Some(tool_call.tool_name.clone()),
563                tool_replay: tool_call.replay.clone(),
564                prune_state: PruneState::Intact,
565                reasoning_meta: None,
566                response_meta: None,
567            });
568
569            let args = serde_json::from_str::<Value>(&tool_call.input_json)
570                .unwrap_or_else(|_| serde_json::json!({}));
571            calls.push(PendingToolCall {
572                call_id: tool_call.call_id,
573                tool_name: tool_call.tool_name,
574                args,
575                replay: tool_call.replay,
576            });
577        }
578        for (_, meta, text) in reasoning_iter {
579            assistant_parts.push(reasoning_part(&asst_id, assistant_parts.len(), text, meta));
580        }
581
582        if !assistant_parts.is_empty() {
583            actions.push(DriverAction::AppendEvents(vec![conversation_event(
584                Message {
585                    id: asst_id,
586                    role: MessageRole::Assistant,
587                    parts: shared_parts(assistant_parts),
588                    origin: None,
589                },
590            )]));
591        }
592
593        actions.push(DriverAction::StartTools { calls });
594        actions
595    }
596
597    fn handle_tool_results(
598        &self,
599        ctx: DriverContextView<'_>,
600        completed: Vec<CompletedToolCall>,
601    ) -> Vec<DriverAction> {
602        let mut actions = Vec::new();
603        let mut result_parts = Vec::new();
604        let mut terminal_outcome = None;
605
606        for outcome in completed {
607            if terminal_outcome.is_none() && outcome.output.is_success() {
608                terminal_outcome = match outcome.output.control.as_ref() {
609                    Some(lash_core::ToolControl::SwitchAgentFrame {
610                        frame_id,
611                        task: Some(task),
612                        ..
613                    }) if !frame_id.trim().is_empty() && !task.trim().is_empty() => {
614                        Some(TurnOutcome::AgentFrameSwitch {
615                            frame_id: frame_id.clone(),
616                            task: task.clone(),
617                        })
618                    }
619                    Some(lash_core::ToolControl::Finish { value }) => {
620                        Some(TurnOutcome::Finished(TurnFinish::ToolValue {
621                            tool_name: outcome.tool_name.clone(),
622                            value: value.to_json_value(),
623                        }))
624                    }
625                    Some(lash_core::ToolControl::Fail { failure }) => {
626                        Some(TurnOutcome::Stopped(TurnStop::ToolError {
627                            tool_name: outcome.tool_name.clone(),
628                            value: failure.to_json_value(),
629                        }))
630                    }
631                    _ => None,
632                };
633            }
634
635            append_model_return_parts(&mut result_parts, outcome.model_return);
636        }
637
638        if !result_parts.is_empty() {
639            let user_id = fresh_message_id();
640            reassign_part_ids(&user_id, &mut result_parts);
641            actions.push(DriverAction::AppendEvents(vec![conversation_event(
642                Message {
643                    id: user_id,
644                    role: MessageRole::User,
645                    parts: shared_parts(result_parts),
646                    origin: None,
647                },
648            )]));
649        }
650
651        if let Some(outcome) = terminal_outcome {
652            actions.push(DriverAction::Finish(outcome));
653            return actions;
654        }
655
656        actions.push(DriverAction::AdvanceProtocolIteration);
657        let next_protocol_iteration = ctx.protocol_iteration() + 1;
658        if let Some(max_turns) = ctx.max_turns()
659            && next_protocol_iteration >= ctx.protocol_run_offset() + max_turns
660        {
661            let message_id = fresh_message_id();
662            actions.push(DriverAction::AppendEvents(vec![conversation_event(
663                turn_limit_exhausted_message(message_id, max_turns),
664            )]));
665            actions.push(DriverAction::Finish(TurnOutcome::Stopped(
666                TurnStop::MaxTurns,
667            )));
668            return actions;
669        }
670
671        actions.push(DriverAction::StartCheckpoint {
672            checkpoint: CheckpointKind::AfterWork,
673            on_empty: CheckpointResumeAction::PrepareIteration,
674        });
675        actions
676    }
677
678    fn handle_exec_result(
679        &self,
680        _ctx: DriverContextView<'_>,
681        _waiting: WaitingExecState<lash_core::HostTurnProtocol>,
682        _result: Result<lash_core::ExecResponse, String>,
683    ) -> Vec<DriverAction> {
684        Vec::new()
685    }
686}
687
688fn append_model_return_parts(parts: &mut Vec<Part>, model_return: lash_core::ModelToolReturn) {
689    for part in model_return.parts {
690        match part {
691            lash_core::ModelToolReturnPart::Text { text } => {
692                if text.is_empty() {
693                    continue;
694                }
695                parts.push(Part {
696                    id: String::new(),
697                    kind: PartKind::ToolResult,
698                    content: text,
699                    attachment: None,
700                    tool_call_id: Some(model_return.call_id.clone()),
701                    tool_name: Some(model_return.tool_name.clone()),
702                    tool_replay: None,
703                    prune_state: PruneState::Intact,
704                    reasoning_meta: None,
705                    response_meta: None,
706                });
707            }
708            lash_core::ModelToolReturnPart::Attachment(reference) => {
709                parts.push(Part {
710                    id: String::new(),
711                    kind: PartKind::Image,
712                    content: String::new(),
713                    attachment: Some(PartAttachment { reference }),
714                    tool_call_id: Some(model_return.call_id.clone()),
715                    tool_name: Some(model_return.tool_name.clone()),
716                    tool_replay: None,
717                    prune_state: PruneState::Intact,
718                    reasoning_meta: None,
719                    response_meta: None,
720                });
721            }
722        }
723    }
724}
725
726fn conversation_event(message: Message) -> SessionEventRecord {
727    SessionEventRecord::Conversation(ConversationRecord::from_message(message))
728}
729
730#[cfg(test)]
731mod tests {
732    use super::*;
733    use lash_core::{
734        AttachmentId, AttachmentMeta, ImageMediaType, MediaType, ModelToolReturn, ToolCallOutput,
735        ToolValue,
736    };
737    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
738    use tokio::sync::Barrier;
739    use tokio::time::{Duration, timeout};
740
741    fn image_ref(id: &str) -> lash_core::AttachmentRef {
742        AttachmentMeta::new(
743            AttachmentId::new(id),
744            MediaType::Image(ImageMediaType::Png),
745            4,
746            Some(1),
747            Some(1),
748            Some("tiny".to_string()),
749        )
750        .as_ref()
751    }
752
753    #[derive(Clone, Debug)]
754    struct BatchRuntimeProvider {
755        calls: Arc<AtomicUsize>,
756        saw_batch_result: Arc<AtomicBool>,
757    }
758
759    #[async_trait::async_trait]
760    impl lash_core::Provider for BatchRuntimeProvider {
761        fn kind(&self) -> &'static str {
762            "stub"
763        }
764
765        fn options(&self) -> lash_core::ProviderOptions {
766            lash_core::ProviderOptions::default()
767        }
768
769        fn set_options(&mut self, _options: lash_core::ProviderOptions) {}
770
771        fn serialize_config(&self) -> serde_json::Value {
772            serde_json::json!({})
773        }
774
775        async fn complete(
776            &mut self,
777            request: lash_core::LlmRequest,
778        ) -> Result<lash_core::LlmResponse, lash_core::LlmTransportError> {
779            let call_index = self.calls.fetch_add(1, Ordering::SeqCst);
780            if call_index == 0 {
781                return Ok(lash_core::LlmResponse {
782                    parts: vec![lash_core::LlmOutputPart::ToolCall {
783                        call_id: "batch-call".to_string(),
784                        tool_name: "batch".to_string(),
785                        input_json: serde_json::json!({
786                            "tool_calls": [
787                                {"tool": "alpha", "parameters": {}},
788                                {"tool": "beta", "parameters": {"value": "fail"}}
789                            ]
790                        })
791                        .to_string(),
792                        replay: None,
793                    }],
794                    ..lash_core::LlmResponse::default()
795                });
796            }
797
798            let projected_messages = format!("{:?}", request.messages);
799            if projected_messages.contains("alpha") && projected_messages.contains("beta failed") {
800                self.saw_batch_result.store(true, Ordering::SeqCst);
801            }
802            Ok(lash_core::LlmResponse {
803                full_text: "done".to_string(),
804                parts: vec![lash_core::LlmOutputPart::Text {
805                    text: "done".to_string(),
806                    response_meta: None,
807                }],
808                ..lash_core::LlmResponse::default()
809            })
810        }
811
812        fn clone_boxed(&self) -> Box<dyn lash_core::Provider> {
813            Box::new(self.clone())
814        }
815    }
816
    /// Test tool provider for the `alpha` and `beta` tools used by the batch
    /// runtime test.
    #[derive(Debug)]
    struct BatchRuntimeTools {
        // Rendezvous point that `execute` waits on (with a timeout) to check
        // that child tool calls run concurrently.
        barrier: Arc<Barrier>,
        // Incremented once per `execute` invocation.
        started: Arc<AtomicUsize>,
    }
822
823    fn runtime_test_tool(name: &str) -> lash_core::ToolDefinition {
824        lash_core::ToolDefinition::raw(
825            format!("tool:{name}"),
826            name,
827            "",
828            serde_json::json!({
829                "type": "object",
830                "properties": {
831                    "value": { "type": "string" }
832                },
833                "additionalProperties": true
834            }),
835            serde_json::json!({ "type": "string" }),
836        )
837        .with_scheduling(lash_core::ToolScheduling::Parallel)
838    }
839
840    #[async_trait::async_trait]
841    impl ToolProvider for BatchRuntimeTools {
842        fn tool_manifests(&self) -> Vec<ToolManifest> {
843            vec![
844                runtime_test_tool("alpha").manifest(),
845                runtime_test_tool("beta").manifest(),
846            ]
847        }
848
849        fn resolve_contract(&self, name: &str) -> Option<Arc<ToolContract>> {
850            match name {
851                "alpha" | "beta" => Some(Arc::new(runtime_test_tool(name).contract())),
852                _ => None,
853            }
854        }
855
856        async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
857            self.started.fetch_add(1, Ordering::SeqCst);
858            if timeout(Duration::from_millis(100), self.barrier.wait())
859                .await
860                .is_err()
861            {
862                return ToolResult::err_fmt("batch child tools did not run concurrently");
863            }
864            if call.name == "beta"
865                && call.args.get("value").and_then(|value| value.as_str()) == Some("fail")
866            {
867                return ToolResult::err_fmt("beta failed");
868            }
869            ToolResult::ok(serde_json::json!(call.name))
870        }
871    }
872
873    #[derive(Clone, Default)]
874    struct CountingEffectController {
875        kinds: Arc<std::sync::Mutex<Vec<lash_core::RuntimeEffectKind>>>,
876    }
877
878    impl CountingEffectController {
879        fn count(&self, kind: lash_core::RuntimeEffectKind) -> usize {
880            self.kinds
881                .lock()
882                .expect("effect kinds")
883                .iter()
884                .filter(|candidate| **candidate == kind)
885                .count()
886        }
887    }
888
889    #[derive(Default)]
890    struct DurableMemoryAttachmentStore {
891        inner: lash_core::InMemoryAttachmentStore,
892    }
893
894    #[async_trait::async_trait]
895    impl lash_core::AttachmentStore for DurableMemoryAttachmentStore {
896        fn persistence(&self) -> lash_core::AttachmentStorePersistence {
897            lash_core::AttachmentStorePersistence::Durable
898        }
899
900        async fn put(
901            &self,
902            bytes: Vec<u8>,
903            meta: lash_core::AttachmentCreateMeta,
904        ) -> Result<lash_core::AttachmentRef, lash_core::AttachmentStoreError> {
905            self.inner.put(bytes, meta).await
906        }
907
908        async fn get(
909            &self,
910            id: &lash_core::AttachmentId,
911        ) -> Result<lash_core::StoredAttachment, lash_core::AttachmentStoreError> {
912            self.inner.get(id).await
913        }
914    }
915
916    #[derive(Default)]
917    struct DurableMemoryProcessEnvStore {
918        inner: lash_core::InMemoryProcessExecutionEnvStore,
919    }
920
921    #[async_trait::async_trait]
922    impl lash_core::ProcessExecutionEnvStore for DurableMemoryProcessEnvStore {
923        fn durability_tier(&self) -> lash_core::DurabilityTier {
924            lash_core::DurabilityTier::Durable
925        }
926
927        async fn put_process_execution_env(
928            &self,
929            env_ref: &lash_core::ProcessExecutionEnvRef,
930            bytes: &[u8],
931        ) -> Result<(), lash_core::PluginError> {
932            self.inner.put_process_execution_env(env_ref, bytes).await
933        }
934
935        async fn get_process_execution_env(
936            &self,
937            env_ref: &lash_core::ProcessExecutionEnvRef,
938        ) -> Result<Option<Vec<u8>>, lash_core::PluginError> {
939            self.inner.get_process_execution_env(env_ref).await
940        }
941    }
942
943    #[async_trait::async_trait]
944    impl lash_core::RuntimeEffectController for CountingEffectController {
945        fn durability_tier(&self) -> lash_core::DurabilityTier {
946            lash_core::DurabilityTier::Durable
947        }
948
949        async fn execute_effect(
950            &self,
951            envelope: lash_core::RuntimeEffectEnvelope,
952            local_executor: lash_core::RuntimeEffectLocalExecutor<'_>,
953        ) -> Result<lash_core::RuntimeEffectOutcome, lash_core::RuntimeEffectControllerError>
954        {
955            self.kinds
956                .lock()
957                .expect("effect kinds")
958                .push(envelope.command.kind());
959            local_executor.execute(envelope).await
960        }
961    }
962
963    #[tokio::test]
964    async fn standard_batch_tool_rejects_nested_batch_inside_durable_attempt() {
965        let provider_calls = Arc::new(AtomicUsize::new(0));
966        let saw_batch_result = Arc::new(AtomicBool::new(false));
967        let provider = BatchRuntimeProvider {
968            calls: Arc::clone(&provider_calls),
969            saw_batch_result: Arc::clone(&saw_batch_result),
970        };
971        let provider_handle = lash_core::ProviderHandle::new(lash_core::ProviderComponents::new(
972            Box::new(provider),
973            Arc::new(lash_core::StaticModelPolicy::new()),
974        ));
975        let mut host = lash_core::RuntimeHostConfig::in_memory();
976        host.providers.provider_resolver =
977            Arc::new(lash_core::SingleProviderResolver::new(provider_handle));
978        host.durability.attachment_store = Arc::new(DurableMemoryAttachmentStore::default());
979        host.durability.process_env_store = Arc::new(DurableMemoryProcessEnvStore::default());
980        let started = Arc::new(AtomicUsize::new(0));
981        let factories: Vec<Arc<dyn lash_core::PluginFactory>> = vec![
982            Arc::new(StandardProtocolPluginFactory::new()),
983            Arc::new(lash_core::plugin::StaticPluginFactory::new(
984                "standard-batch-test-tools",
985                lash_core::PluginSpec::new().with_tool_provider(Arc::new(BatchRuntimeTools {
986                    barrier: Arc::new(Barrier::new(2)),
987                    started: Arc::clone(&started),
988                })),
989            )),
990        ];
991        let policy = lash_core::SessionPolicy {
992            provider_id: "stub".to_string(),
993            model: lash_core::ModelSpec::from_token_limits("mock-model", None, 200_000, None)
994                .expect("valid model"),
995            ..lash_core::SessionPolicy::default()
996        };
997        let controller = CountingEffectController::default();
998        let scoped_controller = lash_core::ScopedEffectController::shared(
999            Arc::new(controller.clone()),
1000            lash_core::ExecutionScope::turn("standard-batch-session", "turn-1"),
1001        )
1002        .expect("scoped controller");
1003        let mut runtime = lash_core::LashRuntime::builder()
1004            .with_session_id("standard-batch-session")
1005            .with_policy(policy)
1006            .with_runtime_host(host)
1007            .with_plugin_factories(factories)
1008            .build()
1009            .await
1010            .expect("runtime");
1011
1012        let turn = runtime
1013            .stream_turn(
1014                lash_core::TurnInput::text("run the batch"),
1015                lash_core::TurnOptions::new(
1016                    tokio_util::sync::CancellationToken::new(),
1017                    scoped_controller,
1018                ),
1019            )
1020            .await
1021            .expect("turn");
1022
1023        assert!(matches!(turn.outcome, lash_core::TurnOutcome::Finished(_)));
1024        assert_eq!(provider_calls.load(Ordering::SeqCst), 2);
1025        assert_eq!(started.load(Ordering::SeqCst), 0);
1026        assert!(!saw_batch_result.load(Ordering::SeqCst));
1027        assert_eq!(controller.count(lash_core::RuntimeEffectKind::ToolBatch), 1);
1028        assert_eq!(
1029            controller.count(lash_core::RuntimeEffectKind::ToolAttempt),
1030            1
1031        );
1032    }
1033
1034    #[test]
1035    fn tool_attachment_round_trips_to_part_kind_image() {
1036        let attachment = image_ref("att-1");
1037        let output = ToolCallOutput::success(ToolValue::Attachment(attachment.clone()));
1038        let model_return =
1039            ModelToolReturn::from_output("call-9".to_string(), "screenshot".to_string(), &output);
1040
1041        let mut parts: Vec<Part> = Vec::new();
1042        append_model_return_parts(&mut parts, model_return);
1043
1044        assert_eq!(parts.len(), 1, "single attachment yields single part");
1045        let part = &parts[0];
1046        assert!(matches!(part.kind, PartKind::Image));
1047        assert_eq!(part.content, "");
1048        assert_eq!(part.tool_call_id.as_deref(), Some("call-9"));
1049        assert_eq!(part.tool_name.as_deref(), Some("screenshot"));
1050        let part_attachment = part.attachment.as_ref().expect("attachment present");
1051        assert_eq!(part_attachment.reference.id, attachment.id);
1052    }
1053
1054    #[test]
1055    fn tool_text_and_attachment_round_trip_preserves_order() {
1056        let attachment = image_ref("att-2");
1057        let output = ToolCallOutput::success(ToolValue::Array(vec![
1058            ToolValue::String("before".into()),
1059            ToolValue::Attachment(attachment.clone()),
1060            ToolValue::String("after".into()),
1061        ]));
1062        let model_return =
1063            ModelToolReturn::from_output("call-10".to_string(), "snap".to_string(), &output);
1064
1065        let mut parts: Vec<Part> = Vec::new();
1066        append_model_return_parts(&mut parts, model_return);
1067
1068        // The array projection emits compact JSON text fragments around the
1069        // attachment, preserving in-order position.
1070        assert_eq!(parts.len(), 3, "text + image + text yields three parts");
1071        assert!(matches!(parts[0].kind, PartKind::ToolResult));
1072        assert!(parts[0].content.starts_with("[\"before\""));
1073        assert!(matches!(parts[1].kind, PartKind::Image));
1074        assert_eq!(
1075            parts[1]
1076                .attachment
1077                .as_ref()
1078                .expect("attachment")
1079                .reference
1080                .id,
1081            attachment.id
1082        );
1083        assert!(matches!(parts[2].kind, PartKind::ToolResult));
1084        assert!(parts[2].content.ends_with("\"after\"]"));
1085    }
1086}