Skip to main content

lash_protocol_standard/
lib.rs

1//! Standard protocol stack: the model drives tools via the native
2//! function-calling envelope of its LLM transport.
3//!
4//! This crate owns:
5//!
6//! - [`StandardDriver`] — the [`ProtocolDriverHandle`] that dispatches
7//!   native tool calls and weaves reasoning parts into the assistant
8//!   message timeline.
9//! - The [`StandardProtocolPluginFactory`] plugin that claims the
10//!   protocol-driver slot so the runtime can run standard-protocol
11//!   sessions.
12//! - The `batch` tool that composes parallel native tool calls (only
13//!   exposed when this protocol stack is installed).
14
15use std::sync::Arc;
16
17use async_trait::async_trait;
18use lash_core::llm::types::{ProviderReasoningReplay, ProviderReplayMeta, ResponseTextMeta};
19use lash_core::plugin::{
20    PluginError, PluginFactory, PluginRegistrar, PluginSessionContext, ProtocolDriverPlugin,
21    ProtocolSessionContext, ProtocolSessionPlugin, SessionPlugin,
22};
23use lash_core::sansio::{
24    CheckpointResumeAction, CompletedToolCall, PendingToolCall, ProtocolDriverHandle,
25    WaitingExecState, WaitingLlmState,
26};
27use lash_core::session_model::message::PartAttachment;
28use lash_core::session_model::{
29    ConversationRecord, Message, MessageRole, Part, PartKind, PruneState, SessionEvent,
30    SessionEventRecord, fresh_message_id, make_error_event, reassign_part_ids, shared_parts,
31};
32
33mod batch;
34pub mod scenario_contracts;
35use batch::batch_tool_definition;
36use lash_core::{
37    CheckpointKind, DriverAction, DriverContextView, LlmOutputPart, LlmResponse,
38    ProtocolBuildInput, SessionError, ToolCall, ToolContract, ToolInvocation, ToolManifest,
39    ToolProvider, ToolResult, TurnDriverConfig, TurnDriverPreamble, TurnFinish, TurnOutcome,
40    TurnStop, append_assistant_text_part, normalized_response_parts, reasoning_part,
41};
42use serde_json::Value;
43
44const STANDARD_EXECUTION_SECTION: &str = r#"Use direct tool calls.
45
46- Use `batch` (up to 25 calls) for two or more independent tool calls. Serialize calls when later arguments depend on earlier results.
47- For direct conversational requests that need no tools, respond in prose only.
48
49Example — two independent reads in one `batch` call:
50
51```json
52{
53  "tool_calls": [
54    { "tool": "read_file", "parameters": { "path": "src/main.rs" } },
55    { "tool": "grep", "parameters": { "query": "ToolProvider", "path": "crates/lash/src/" } }
56  ]
57}
58```"#;
59
60const BATCH_MAX_TOOL_CALLS: usize = 25;
61const STANDARD_PROTOCOL_PLUGIN_ID: &str = "standard_protocol";
62
63/// Plugin factory that installs the standard-protocol driver,
64/// session plugin, and native tool catalog.
65#[derive(Default)]
66pub struct StandardProtocolPluginFactory;
67
68impl StandardProtocolPluginFactory {
69    pub fn new() -> Self {
70        Self
71    }
72}
73
74impl PluginFactory for StandardProtocolPluginFactory {
75    fn id(&self) -> &'static str {
76        STANDARD_PROTOCOL_PLUGIN_ID
77    }
78
79    fn build(&self, _ctx: &PluginSessionContext) -> Result<Arc<dyn SessionPlugin>, PluginError> {
80        Ok(Arc::new(StandardProtocolPlugin))
81    }
82}
83
84struct StandardProtocolPlugin;
85
86impl SessionPlugin for StandardProtocolPlugin {
87    fn id(&self) -> &'static str {
88        STANDARD_PROTOCOL_PLUGIN_ID
89    }
90
91    fn register(&self, reg: &mut PluginRegistrar) -> Result<(), PluginError> {
92        reg.protocol().session(Arc::new(StandardProtocolSession))?;
93        reg.protocol()
94            .protocol_driver(Arc::new(StandardProtocolDriver))?;
95        reg.tools().provider(Arc::new(StandardProtocolTools))?;
96        Ok(())
97    }
98}
99
100struct StandardProtocolSession;
101
102#[async_trait]
103impl ProtocolSessionPlugin for StandardProtocolSession {
104    async fn initialize_session(
105        &self,
106        _ctx: ProtocolSessionContext<'_>,
107    ) -> Result<(), SessionError> {
108        Ok(())
109    }
110}
111
112struct StandardProtocolDriver;
113
114impl ProtocolDriverPlugin for StandardProtocolDriver {
115    fn build_preamble(&self, input: ProtocolBuildInput) -> TurnDriverPreamble {
116        let tool_names = input.tool_catalog.tool_names();
117        let tool_names_fingerprint = input.tool_catalog.tool_names_fingerprint();
118        TurnDriverPreamble {
119            config: TurnDriverConfig::chat(
120                Arc::new(StandardDriver),
121                true,
122                Arc::new(turn_limit_exhausted_message),
123            ),
124            tool_specs: input.tool_catalog.model_tool_specs(),
125            tool_names,
126            tool_names_fingerprint,
127            execution_prompt: Arc::from(STANDARD_EXECUTION_SECTION),
128            prompt_contributions: input.extra_prompt_contributions,
129        }
130    }
131}
132
133fn turn_limit_exhausted_message(message_id: String, max_turns: usize) -> Message {
134    Message {
135        id: message_id.clone(),
136        role: MessageRole::System,
137        parts: shared_parts(vec![Part {
138            id: format!("{message_id}.p0"),
139            kind: PartKind::Error,
140            content: format!("Turn limit reached ({max_turns}) before a final assistant response."),
141            attachment: None,
142            tool_call_id: None,
143            tool_name: None,
144            tool_replay: None,
145            prune_state: PruneState::Intact,
146            reasoning_meta: None,
147            response_meta: None,
148        }]),
149        origin: None,
150    }
151}
152
153struct StandardProtocolTools;
154
155#[async_trait]
156impl ToolProvider for StandardProtocolTools {
157    fn tool_manifests(&self) -> Vec<ToolManifest> {
158        vec![batch_tool_definition().manifest()]
159    }
160
161    fn resolve_contract(&self, name: &str) -> Option<Arc<ToolContract>> {
162        (name == "batch").then(|| Arc::new(batch_tool_definition().contract()))
163    }
164
165    async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
166        match call.name {
167            "batch" => execute_batch_tool_call(call).await,
168            _ => ToolResult::err_fmt(format_args!("Unknown tool: {}", call.name)),
169        }
170    }
171}
172
/// One parsed entry of the `batch` tool's `tool_calls` array.
#[derive(Debug)]
struct BatchCallSpec {
    /// Position of the entry in the original `tool_calls` array. Results are
    /// reported under this index and sorted by it.
    index: usize,
    /// Requested tool name, with surrounding whitespace trimmed.
    tool: String,
    /// Arguments for the tool. Defaults to `{}` when the entry omits `parameters`.
    parameters: Value,
}
179
180async fn execute_batch_tool_call(call: ToolCall<'_>) -> ToolResult {
181    let args = call.args;
182    let specs = match parse_batch_specs(args) {
183        Ok(specs) => specs,
184        Err(err) => return err,
185    };
186
187    let mut immediate_outcomes = Vec::new();
188    let mut parallel_specs = Vec::new();
189    let dispatch = call.context.dispatch();
190
191    for spec in specs.into_iter().take(BATCH_MAX_TOOL_CALLS) {
192        if spec.tool == "batch" {
193            immediate_outcomes.push(serde_json::json!({
194                "index": spec.index,
195                "tool": spec.tool,
196                "success": false,
197                "duration_ms": 0,
198                "error": "Tool 'batch' is not allowed inside batch",
199            }));
200            continue;
201        }
202        let Some(manifest) = dispatch.callable_tool_manifest(&spec.tool) else {
203            let error = format!("Tool '{}' is unavailable in this session", spec.tool);
204            immediate_outcomes.push(serde_json::json!({
205                "index": spec.index,
206                "tool": spec.tool,
207                "success": false,
208                "duration_ms": 0,
209                "error": error,
210            }));
211            continue;
212        };
213        parallel_specs.push((
214            spec.index,
215            ToolInvocation::new(
216                format!(
217                    "{}:{:02}",
218                    call.context.tool_call_id().unwrap_or("batch"),
219                    spec.index
220                ),
221                manifest.id,
222                spec.parameters,
223            ),
224        ));
225    }
226
227    let mut parallel_outcomes = dispatch
228        .batch(
229            parallel_specs
230                .iter()
231                .map(|(_, invocation)| invocation.clone())
232                .collect(),
233        )
234        .await;
235    for ((index, invocation), outcome) in
236        parallel_specs.into_iter().zip(parallel_outcomes.drain(..))
237    {
238        let tool_label = invocation.label();
239        let tool_record = outcome.record.unwrap_or(lash_core::ToolCallRecord {
240            call_id: Some(invocation.id),
241            tool: tool_label,
242            args: invocation.args,
243            output: outcome.output,
244            duration_ms: 0,
245        });
246        let mut result_record = serde_json::Map::new();
247        result_record.insert("index".to_string(), serde_json::json!(index));
248        result_record.insert("tool".to_string(), serde_json::json!(tool_record.tool));
249        result_record.insert(
250            "success".to_string(),
251            serde_json::json!(tool_record.output.is_success()),
252        );
253        result_record.insert(
254            "duration_ms".to_string(),
255            serde_json::json!(tool_record.duration_ms),
256        );
257        result_record.insert(
258            if tool_record.output.is_success() {
259                "result".to_string()
260            } else {
261                "error".to_string()
262            },
263            tool_record.output.value_for_projection(),
264        );
265        immediate_outcomes.push(Value::Object(result_record));
266    }
267
268    for overflow_index in BATCH_MAX_TOOL_CALLS
269        ..args
270            .get("tool_calls")
271            .and_then(|value| value.as_array())
272            .map(|value| value.len())
273            .unwrap_or_default()
274    {
275        immediate_outcomes.push(serde_json::json!({
276            "index": overflow_index,
277            "tool": args
278                .get("tool_calls")
279                .and_then(|value| value.as_array())
280                .and_then(|items| items.get(overflow_index))
281                .and_then(|item| item.get("tool"))
282                .and_then(|value| value.as_str())
283                .unwrap_or("unknown"),
284            "success": false,
285            "duration_ms": 0,
286            "error": "Maximum of 25 tool calls allowed in batch",
287        }));
288    }
289
290    immediate_outcomes.sort_by_key(|outcome| {
291        outcome
292            .get("index")
293            .and_then(|value| value.as_u64())
294            .unwrap_or(u64::MAX)
295    });
296    ToolResult::ok(serde_json::json!({
297        "results": immediate_outcomes,
298    }))
299}
300
301#[allow(clippy::result_large_err)]
302fn parse_batch_specs(args: &Value) -> Result<Vec<BatchCallSpec>, ToolResult> {
303    let Some(raw_calls) = args.get("tool_calls").and_then(|value| value.as_array()) else {
304        return Err(ToolResult::err_fmt(
305            "Missing required parameter: tool_calls",
306        ));
307    };
308    if raw_calls.is_empty() {
309        return Err(ToolResult::err_fmt(
310            "Invalid tool_calls: expected at least one call",
311        ));
312    }
313
314    let mut specs = Vec::with_capacity(raw_calls.len());
315    for (index, item) in raw_calls.iter().enumerate() {
316        let Some(object) = item.as_object() else {
317            return Err(ToolResult::err_fmt(format_args!(
318                "Invalid tool_calls[{index}]: expected object with tool and parameters"
319            )));
320        };
321        let Some(tool) = object
322            .get("tool")
323            .and_then(|value| value.as_str())
324            .map(str::trim)
325            .filter(|tool| !tool.is_empty())
326        else {
327            return Err(ToolResult::err_fmt(format_args!(
328                "Invalid tool_calls[{index}].tool: expected non-empty string"
329            )));
330        };
331        let parameters = object
332            .get("parameters")
333            .cloned()
334            .unwrap_or_else(|| serde_json::json!({}));
335        specs.push(BatchCallSpec {
336            index,
337            tool: tool.to_string(),
338            parameters,
339        });
340    }
341
342    Ok(specs)
343}
344
345// ─────────────────────────────────────────────────────────────────────
346// Standard protocol driver
347// ─────────────────────────────────────────────────────────────────────
348
/// Protocol driver for the Standard protocol. Consumes native
/// tool-call envelopes from the LLM, dispatches them via
/// `DriverAction::StartTools`, and splices reasoning parts into the
/// assistant message so provider replay metadata preserves
/// chain-of-thought ordering.
pub struct StandardDriver;

/// A native tool call extracted from one LLM response. The driver turns it
/// into an assistant `ToolCall` part and a `PendingToolCall`.
struct StandardToolCall {
    /// Provider call id. It becomes the part's `tool_call_id` and the pending
    /// call's `call_id`.
    call_id: String,
    /// Name of the tool the model requested.
    tool_name: String,
    /// Raw JSON argument text from the provider. It is stored verbatim in the
    /// assistant part and may fail to parse.
    input_json: String,
    /// Provider replay metadata attached to the call, if any.
    replay: Option<ProviderReplayMeta>,
}
362
363fn last_message_has_tool_result(ctx: &DriverContextView<'_>) -> bool {
364    ctx.messages().last().is_some_and(|message| {
365        matches!(message.role, MessageRole::User)
366            && message
367                .parts
368                .iter()
369                .any(|part| matches!(part.kind, PartKind::ToolResult))
370    })
371}
372
/// Host-turn protocol driven by native function calling.
///
/// Each protocol iteration is one LLM request. Tool calls in the response are
/// dispatched with `DriverAction::StartTools`, and their results are appended
/// as a user message before the next iteration.
impl ProtocolDriverHandle<lash_core::HostTurnProtocol> for StandardDriver {
    /// Starts every protocol iteration with a single LLM request projected from
    /// the current context. No driver state is attached.
    // NOTE(review): the meaning of the `true` flag is defined by
    // `DriverContextView::project_llm_request` in lash_core.
    fn prepare_protocol_iteration(&self, ctx: DriverContextView<'_>) -> Vec<DriverAction> {
        vec![DriverAction::StartLlm {
            request: ctx.project_llm_request(true),
            driver_state: None,
        }]
    }

    /// Turns one LLM response into driver actions.
    ///
    /// Each kind of response part is handled as follows:
    ///
    /// - Text parts are concatenated into the assistant text. They are
    ///   re-emitted as `TextDelta` events when `text_streamed` is false.
    /// - Reasoning parts are trimmed and remember how many tool calls preceded
    ///   them.
    /// - Tool calls are collected in emission order.
    ///
    /// An `LlmResponse` event is always emitted. The turn then continues in one
    /// of two ways:
    ///
    /// - No tool calls: the turn goes to a `BeforeCompletion` checkpoint that
    ///   finishes with the assistant text. An entirely empty response is
    ///   accepted only right after tool results. Otherwise it stops with
    ///   `TurnStop::ProviderError`.
    /// - Tool calls: an assistant message with prose, reasoning and tool-call
    ///   parts is appended, then `StartTools` dispatches the calls.
    fn handle_llm_success(
        &self,
        ctx: DriverContextView<'_>,
        _waiting: WaitingLlmState<lash_core::HostTurnProtocol>,
        llm_response: LlmResponse,
        text_streamed: bool,
    ) -> Vec<DriverAction> {
        let response_parts = normalized_response_parts(&llm_response);
        let mut assistant_text = String::new();
        // One entry per non-empty text part: whatever `append_assistant_text_part`
        // appended for it (which may differ from the raw text), plus its metadata.
        let mut assistant_text_parts: Vec<(String, Option<ResponseTextMeta>)> = Vec::new();
        let mut tool_calls: Vec<StandardToolCall> = Vec::new();
        // Reasoning items captured with their position in the original
        // response. The `usize` is the index in `tool_calls` that this
        // reasoning item originally preceded, so we can interleave
        // reasoning → tool_call in the provider's original emission order.
        // `Option<ProviderReasoningReplay>` carries roundtrip payload
        // when present (fix 1.3b); when None, the item is display-only
        // (fix 1.3a) — still rendered in the UI but never re-fed.
        let mut reasoning_items: Vec<(usize, Option<ProviderReasoningReplay>, String)> = Vec::new();
        let mut actions = Vec::new();

        for part in response_parts {
            match part {
                LlmOutputPart::Text {
                    text,
                    response_meta,
                } => {
                    if !text.is_empty() {
                        let previous_len = assistant_text.len();
                        append_assistant_text_part(&mut assistant_text, &text);
                        assistant_text_parts
                            .push((assistant_text[previous_len..].to_string(), response_meta));
                        // Only emit deltas here when the transport did not
                        // stream the text (presumably already emitted otherwise).
                        if !text_streamed {
                            actions.push(DriverAction::Emit(SessionEvent::TextDelta {
                                content: assistant_text[previous_len..].to_string(),
                            }));
                        }
                    }
                }
                LlmOutputPart::Reasoning { text, replay } => {
                    let trimmed = text.trim().to_string();
                    // Skip fully-empty reasoning items (no display text and
                    // no roundtrip payload).
                    if trimmed.is_empty() && replay.as_ref().is_none_or(|meta| meta.is_empty()) {
                        continue;
                    }
                    // Record how many tool calls came before this item.
                    reasoning_items.push((tool_calls.len(), replay, trimmed));
                }
                LlmOutputPart::ToolCall {
                    call_id,
                    tool_name,
                    input_json,
                    replay,
                } => {
                    tool_calls.push(StandardToolCall {
                        call_id,
                        tool_name,
                        input_json,
                        replay,
                    });
                }
            }
        }

        // NOTE(review): duration is always reported as 0 from this driver.
        actions.push(DriverAction::Emit(SessionEvent::LlmResponse {
            protocol_iteration: ctx.protocol_iteration(),
            content: assistant_text.clone(),
            duration_ms: 0,
        }));

        if tool_calls.is_empty() {
            if assistant_text.trim().is_empty() && reasoning_items.is_empty() {
                if last_message_has_tool_result(&ctx) {
                    // A model can intentionally complete a tool-only request
                    // with an empty final answer, e.g. when the user says
                    // "do nothing else" after the tool action.
                    actions.push(DriverAction::StartCheckpoint {
                        checkpoint: CheckpointKind::BeforeCompletion,
                        on_empty: CheckpointResumeAction::Finish(TurnOutcome::Finished(
                            TurnFinish::AssistantMessage {
                                text: String::new(),
                            },
                        )),
                    });
                    return actions;
                }
                actions.push(DriverAction::Emit(make_error_event(
                    "llm_provider",
                    Some("empty_response"),
                    "Model returned no assistant text or tool calls.",
                    None,
                )));
                actions.push(DriverAction::Finish(TurnOutcome::Stopped(
                    TurnStop::ProviderError,
                )));
                return actions;
            }

            // NOTE(review): `asst_id` and `parts_out` are only used for the
            // emptiness check below. They are never appended as an event, so
            // reasoning parts and text `response_meta` from a final answer are
            // not recorded by this branch. Confirm the runtime persists the
            // final message elsewhere.
            // Reasoning parts come first here, followed by the prose parts.
            let asst_id = fresh_message_id();
            let mut parts_out = Vec::new();
            for (_, meta, text) in reasoning_items {
                parts_out.push(reasoning_part(&asst_id, parts_out.len(), text, meta));
            }
            for (content, response_meta) in assistant_text_parts {
                if content.trim().is_empty() {
                    continue;
                }
                parts_out.push(Part {
                    id: format!("{}.p{}", asst_id, parts_out.len()),
                    kind: PartKind::Prose,
                    content,
                    attachment: None,
                    tool_call_id: None,
                    tool_name: None,
                    tool_replay: None,
                    prune_state: PruneState::Intact,
                    reasoning_meta: None,
                    response_meta,
                });
            }
            // NOTE(review): this looks unreachable. Reaching here means either
            // reasoning items exist (which always add parts) or the assistant
            // text has non-whitespace content (so some text part does too).
            if parts_out.is_empty() {
                actions.push(DriverAction::Emit(make_error_event(
                    "llm_provider",
                    Some("empty_response"),
                    "Model returned no assistant text or tool calls.",
                    None,
                )));
                actions.push(DriverAction::Finish(TurnOutcome::Stopped(
                    TurnStop::ProviderError,
                )));
                return actions;
            }
            actions.push(DriverAction::StartCheckpoint {
                checkpoint: CheckpointKind::BeforeCompletion,
                on_empty: CheckpointResumeAction::Finish(TurnOutcome::Finished(
                    TurnFinish::AssistantMessage {
                        text: assistant_text.clone(),
                    },
                )),
            });
            return actions;
        }

        let asst_id = fresh_message_id();
        let mut assistant_parts = Vec::new();
        // Prose parts always lead the message, regardless of where the text
        // appeared relative to reasoning/tool calls in the response.
        for (content, response_meta) in assistant_text_parts {
            if content.trim().is_empty() {
                continue;
            }
            assistant_parts.push(Part {
                id: format!("{}.p{}", asst_id, assistant_parts.len()),
                kind: PartKind::Prose,
                content,
                attachment: None,
                tool_call_id: None,
                tool_name: None,
                tool_replay: None,
                prune_state: PruneState::Intact,
                reasoning_meta: None,
                response_meta,
            });
        }

        let mut calls = Vec::new();
        // Interleave reasoning items with tool calls to preserve the
        // original emission order. Some provider replays expect the
        // sequence `reasoning → function_call` from the turn in which both
        // were produced; swapping them can drop the reasoning/tool pairing.
        let mut reasoning_iter = reasoning_items.into_iter().peekable();
        for (tool_index, tool_call) in tool_calls.into_iter().enumerate() {
            while let Some((insert_index, _, _)) = reasoning_iter.peek() {
                if *insert_index > tool_index {
                    break;
                }
                let (_, meta, text) = reasoning_iter.next().expect("peek ok");
                assistant_parts.push(reasoning_part(&asst_id, assistant_parts.len(), text, meta));
            }
            assistant_parts.push(Part {
                id: format!("{}.p{}", asst_id, assistant_parts.len()),
                kind: PartKind::ToolCall,
                content: tool_call.input_json.clone(),
                attachment: None,
                tool_call_id: Some(tool_call.call_id.clone()),
                tool_name: Some(tool_call.tool_name.clone()),
                tool_replay: tool_call.replay.clone(),
                prune_state: PruneState::Intact,
                reasoning_meta: None,
                response_meta: None,
            });

            // NOTE(review): malformed argument JSON is silently replaced with
            // `{}`. The raw text survives only in the ToolCall part above.
            let args = serde_json::from_str::<Value>(&tool_call.input_json)
                .unwrap_or_else(|_| serde_json::json!({}));
            calls.push(PendingToolCall {
                call_id: tool_call.call_id,
                tool_name: tool_call.tool_name,
                args,
                replay: tool_call.replay,
            });
        }
        // Reasoning emitted after the last tool call goes at the end.
        for (_, meta, text) in reasoning_iter {
            assistant_parts.push(reasoning_part(&asst_id, assistant_parts.len(), text, meta));
        }

        // Always true here: every tool call above added a part.
        if !assistant_parts.is_empty() {
            actions.push(DriverAction::AppendEvents(vec![conversation_event(
                Message {
                    id: asst_id,
                    role: MessageRole::Assistant,
                    parts: shared_parts(assistant_parts),
                    origin: None,
                },
            )]));
        }

        actions.push(DriverAction::StartTools { calls });
        actions
    }

    /// Feeds completed tool calls back to the model and decides how the turn
    /// continues.
    ///
    /// A terminal outcome comes from the first successful outcome whose
    /// `control` is one of these:
    ///
    /// - an agent-frame switch with non-blank `frame_id` and `task`;
    /// - a finish;
    /// - a failure.
    ///
    /// Later controls are ignored. All model-facing return parts are appended
    /// as one user message, even when a terminal outcome ends the turn.
    ///
    /// Without a terminal outcome, the driver advances the iteration. It stops
    /// with `TurnStop::MaxTurns` once the next iteration reaches
    /// `protocol_run_offset + max_turns`. Otherwise it runs an `AfterWork`
    /// checkpoint that resumes by preparing the next iteration.
    fn handle_tool_results(
        &self,
        ctx: DriverContextView<'_>,
        completed: Vec<CompletedToolCall>,
    ) -> Vec<DriverAction> {
        let mut actions = Vec::new();
        let mut result_parts = Vec::new();
        let mut terminal_outcome = None;

        for outcome in completed {
            if terminal_outcome.is_none() && outcome.output.is_success() {
                terminal_outcome = match outcome.output.control.as_ref() {
                    Some(lash_core::ToolControl::SwitchAgentFrame {
                        frame_id,
                        task: Some(task),
                        ..
                    }) if !frame_id.trim().is_empty() && !task.trim().is_empty() => {
                        Some(TurnOutcome::AgentFrameSwitch {
                            frame_id: frame_id.clone(),
                            task: task.clone(),
                        })
                    }
                    Some(lash_core::ToolControl::Finish { value }) => {
                        Some(TurnOutcome::Finished(TurnFinish::ToolValue {
                            tool_name: outcome.tool_name.clone(),
                            value: value.to_json_value(),
                        }))
                    }
                    Some(lash_core::ToolControl::Fail { failure }) => {
                        Some(TurnOutcome::Stopped(TurnStop::ToolError {
                            tool_name: outcome.tool_name.clone(),
                            value: failure.to_json_value(),
                        }))
                    }
                    _ => None,
                };
            }

            append_model_return_parts(&mut result_parts, outcome.model_return);
        }

        if !result_parts.is_empty() {
            // Parts are created with blank ids; give them ids under the new message.
            let user_id = fresh_message_id();
            reassign_part_ids(&user_id, &mut result_parts);
            actions.push(DriverAction::AppendEvents(vec![conversation_event(
                Message {
                    id: user_id,
                    role: MessageRole::User,
                    parts: shared_parts(result_parts),
                    origin: None,
                },
            )]));
        }

        if let Some(outcome) = terminal_outcome {
            actions.push(DriverAction::Finish(outcome));
            return actions;
        }

        actions.push(DriverAction::AdvanceProtocolIteration);
        let next_protocol_iteration = ctx.protocol_iteration() + 1;
        if let Some(max_turns) = ctx.max_turns()
            && next_protocol_iteration >= ctx.protocol_run_offset() + max_turns
        {
            let message_id = fresh_message_id();
            actions.push(DriverAction::AppendEvents(vec![conversation_event(
                turn_limit_exhausted_message(message_id, max_turns),
            )]));
            actions.push(DriverAction::Finish(TurnOutcome::Stopped(
                TurnStop::MaxTurns,
            )));
            return actions;
        }

        actions.push(DriverAction::StartCheckpoint {
            checkpoint: CheckpointKind::AfterWork,
            on_empty: CheckpointResumeAction::PrepareIteration,
        });
        actions
    }

    /// Returns no actions. Nothing in this driver issues an exec action; it
    /// only emits LLM, tool, checkpoint, event and finish actions.
    fn handle_exec_result(
        &self,
        _ctx: DriverContextView<'_>,
        _waiting: WaitingExecState<lash_core::HostTurnProtocol>,
        _result: Result<lash_core::ExecResponse, String>,
    ) -> Vec<DriverAction> {
        Vec::new()
    }
}
689
690fn append_model_return_parts(parts: &mut Vec<Part>, model_return: lash_core::ModelToolReturn) {
691    for part in model_return.parts {
692        match part {
693            lash_core::ModelToolReturnPart::Text { text } => {
694                if text.is_empty() {
695                    continue;
696                }
697                parts.push(Part {
698                    id: String::new(),
699                    kind: PartKind::ToolResult,
700                    content: text,
701                    attachment: None,
702                    tool_call_id: Some(model_return.call_id.clone()),
703                    tool_name: Some(model_return.tool_name.clone()),
704                    tool_replay: None,
705                    prune_state: PruneState::Intact,
706                    reasoning_meta: None,
707                    response_meta: None,
708                });
709            }
710            lash_core::ModelToolReturnPart::Attachment(reference) => {
711                parts.push(Part {
712                    id: String::new(),
713                    kind: PartKind::Image,
714                    content: String::new(),
715                    attachment: Some(PartAttachment { reference }),
716                    tool_call_id: Some(model_return.call_id.clone()),
717                    tool_name: Some(model_return.tool_name.clone()),
718                    tool_replay: None,
719                    prune_state: PruneState::Intact,
720                    reasoning_meta: None,
721                    response_meta: None,
722                });
723            }
724        }
725    }
726}
727
728fn conversation_event(message: Message) -> SessionEventRecord {
729    SessionEventRecord::Conversation(ConversationRecord::from_message(message))
730}
731
732#[cfg(test)]
733mod tests {
734    use super::*;
735    use lash_core::{
736        AttachmentId, AttachmentMeta, ImageMediaType, MediaType, ModelToolReturn, ToolCallOutput,
737        ToolValue,
738    };
739    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
740    use tokio::sync::Barrier;
741    use tokio::time::{Duration, timeout};
742
743    fn image_ref(id: &str) -> lash_core::AttachmentRef {
744        AttachmentMeta::new(
745            AttachmentId::new(id),
746            MediaType::Image(ImageMediaType::Png),
747            4,
748            Some(1),
749            Some(1),
750            Some("tiny".to_string()),
751        )
752        .as_ref()
753    }
754
755    #[test]
756    fn standard_protocol_factory_id_is_stable_plugin_contract() {
757        let factory = StandardProtocolPluginFactory::new();
758
759        assert_eq!(factory.id(), STANDARD_PROTOCOL_PLUGIN_ID);
760        assert_eq!(factory.id(), "standard_protocol");
761    }
762
763    #[derive(Clone, Debug)]
764    struct BatchRuntimeProvider {
765        calls: Arc<AtomicUsize>,
766        saw_batch_result: Arc<AtomicBool>,
767    }
768
769    #[async_trait::async_trait]
770    impl lash_core::Provider for BatchRuntimeProvider {
771        fn kind(&self) -> &'static str {
772            "stub"
773        }
774
775        fn options(&self) -> lash_core::ProviderOptions {
776            lash_core::ProviderOptions::default()
777        }
778
779        fn set_options(&mut self, _options: lash_core::ProviderOptions) {}
780
781        fn serialize_config(&self) -> serde_json::Value {
782            serde_json::json!({})
783        }
784
785        async fn complete(
786            &mut self,
787            request: lash_core::LlmRequest,
788        ) -> Result<lash_core::LlmResponse, lash_core::LlmTransportError> {
789            let call_index = self.calls.fetch_add(1, Ordering::SeqCst);
790            if call_index == 0 {
791                return Ok(lash_core::LlmResponse {
792                    parts: vec![lash_core::LlmOutputPart::ToolCall {
793                        call_id: "batch-call".to_string(),
794                        tool_name: "batch".to_string(),
795                        input_json: serde_json::json!({
796                            "tool_calls": [
797                                {"tool": "alpha", "parameters": {}},
798                                {"tool": "beta", "parameters": {"value": "fail"}}
799                            ]
800                        })
801                        .to_string(),
802                        replay: None,
803                    }],
804                    ..lash_core::LlmResponse::default()
805                });
806            }
807
808            let projected_messages = format!("{:?}", request.messages);
809            if projected_messages.contains("alpha") && projected_messages.contains("beta failed") {
810                self.saw_batch_result.store(true, Ordering::SeqCst);
811            }
812            Ok(lash_core::LlmResponse {
813                full_text: "done".to_string(),
814                parts: vec![lash_core::LlmOutputPart::Text {
815                    text: "done".to_string(),
816                    response_meta: None,
817                }],
818                ..lash_core::LlmResponse::default()
819            })
820        }
821
822        fn clone_boxed(&self) -> Box<dyn lash_core::Provider> {
823            Box::new(self.clone())
824        }
825    }
826
    /// Tool provider that backs the `alpha` and `beta` tools for the batch
    /// runtime test.
    #[derive(Debug)]
    struct BatchRuntimeTools {
        /// Rendezvous point awaited by every `execute` call, with a 100 ms
        /// timeout. The batch test constructs it with `Barrier::new(2)`, so a
        /// call only succeeds if a sibling call arrives concurrently.
        barrier: Arc<Barrier>,
        /// Number of `execute` calls that have begun.
        started: Arc<AtomicUsize>,
    }
832
833    fn runtime_test_tool(name: &str) -> lash_core::ToolDefinition {
834        lash_core::ToolDefinition::raw(
835            format!("tool:{name}"),
836            name,
837            "",
838            serde_json::json!({
839                "type": "object",
840                "properties": {
841                    "value": { "type": "string" }
842                },
843                "additionalProperties": true
844            }),
845            serde_json::json!({ "type": "string" }),
846        )
847        .with_scheduling(lash_core::ToolScheduling::Parallel)
848    }
849
850    #[async_trait::async_trait]
851    impl ToolProvider for BatchRuntimeTools {
852        fn tool_manifests(&self) -> Vec<ToolManifest> {
853            vec![
854                runtime_test_tool("alpha").manifest(),
855                runtime_test_tool("beta").manifest(),
856            ]
857        }
858
859        fn resolve_contract(&self, name: &str) -> Option<Arc<ToolContract>> {
860            match name {
861                "alpha" | "beta" => Some(Arc::new(runtime_test_tool(name).contract())),
862                _ => None,
863            }
864        }
865
866        async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
867            self.started.fetch_add(1, Ordering::SeqCst);
868            if timeout(Duration::from_millis(100), self.barrier.wait())
869                .await
870                .is_err()
871            {
872                return ToolResult::err_fmt("batch child tools did not run concurrently");
873            }
874            if call.name == "beta"
875                && call.args.get("value").and_then(|value| value.as_str()) == Some("fail")
876            {
877                return ToolResult::err_fmt("beta failed");
878            }
879            ToolResult::ok(serde_json::json!(call.name))
880        }
881    }
882
883    #[derive(Clone, Default)]
884    struct CountingEffectController {
885        kinds: Arc<std::sync::Mutex<Vec<lash_core::RuntimeEffectKind>>>,
886    }
887
888    impl CountingEffectController {
889        fn count(&self, kind: lash_core::RuntimeEffectKind) -> usize {
890            self.kinds
891                .lock()
892                .expect("effect kinds")
893                .iter()
894                .filter(|candidate| **candidate == kind)
895                .count()
896        }
897    }
898
899    #[derive(Default)]
900    struct DurableMemoryAttachmentStore {
901        inner: lash_core::InMemoryAttachmentStore,
902    }
903
904    #[async_trait::async_trait]
905    impl lash_core::AttachmentStore for DurableMemoryAttachmentStore {
906        fn persistence(&self) -> lash_core::AttachmentStorePersistence {
907            lash_core::AttachmentStorePersistence::Durable
908        }
909
910        async fn put(
911            &self,
912            bytes: Vec<u8>,
913            meta: lash_core::AttachmentCreateMeta,
914        ) -> Result<lash_core::AttachmentRef, lash_core::AttachmentStoreError> {
915            self.inner.put(bytes, meta).await
916        }
917
918        async fn get(
919            &self,
920            id: &lash_core::AttachmentId,
921        ) -> Result<lash_core::StoredAttachment, lash_core::AttachmentStoreError> {
922            self.inner.get(id).await
923        }
924    }
925
926    #[derive(Default)]
927    struct DurableMemoryProcessEnvStore {
928        inner: lash_core::InMemoryProcessExecutionEnvStore,
929    }
930
931    #[async_trait::async_trait]
932    impl lash_core::ProcessExecutionEnvStore for DurableMemoryProcessEnvStore {
933        fn durability_tier(&self) -> lash_core::DurabilityTier {
934            lash_core::DurabilityTier::Durable
935        }
936
937        async fn put_process_execution_env(
938            &self,
939            env_ref: &lash_core::ProcessExecutionEnvRef,
940            bytes: &[u8],
941        ) -> Result<(), lash_core::PluginError> {
942            self.inner.put_process_execution_env(env_ref, bytes).await
943        }
944
945        async fn get_process_execution_env(
946            &self,
947            env_ref: &lash_core::ProcessExecutionEnvRef,
948        ) -> Result<Option<Vec<u8>>, lash_core::PluginError> {
949            self.inner.get_process_execution_env(env_ref).await
950        }
951    }
952
953    #[async_trait::async_trait]
954    impl lash_core::RuntimeEffectController for CountingEffectController {
955        fn durability_tier(&self) -> lash_core::DurabilityTier {
956            lash_core::DurabilityTier::Durable
957        }
958
959        async fn execute_effect(
960            &self,
961            envelope: lash_core::RuntimeEffectEnvelope,
962            local_executor: lash_core::RuntimeEffectLocalExecutor<'_>,
963        ) -> Result<lash_core::RuntimeEffectOutcome, lash_core::RuntimeEffectControllerError>
964        {
965            self.kinds
966                .lock()
967                .expect("effect kinds")
968                .push(envelope.command.kind());
969            local_executor.execute(envelope).await
970        }
971    }
972
    /// End-to-end check of the `batch` tool under a durable effect controller.
    ///
    /// The scripted provider first asks for a `batch` of `alpha` and a failing
    /// `beta`, then answers `"done"`. The host gets attachment and process-env
    /// stores that report durability. The turn runs with a
    /// `CountingEffectController` wrapped in a turn-scoped controller.
    ///
    /// The assertions require all of the following:
    /// - the turn finishes after exactly two provider calls;
    /// - no child tool ever starts;
    /// - the second request never mentions both `alpha` and `beta failed`;
    /// - exactly one `ToolBatch` and one `ToolAttempt` effect are recorded.
    ///
    /// NOTE(review): going by the test name, the batch is presumably rejected
    /// because running its children would nest a batch inside the durable
    /// attempt of the `batch` call itself. Confirm which effect the runtime
    /// records as `ToolBatch`.
    #[tokio::test]
    async fn standard_batch_tool_rejects_nested_batch_inside_durable_attempt() {
        // Shared handles let the assertions observe the provider after it has
        // been boxed into the provider handle.
        let provider_calls = Arc::new(AtomicUsize::new(0));
        let saw_batch_result = Arc::new(AtomicBool::new(false));
        let provider = BatchRuntimeProvider {
            calls: Arc::clone(&provider_calls),
            saw_batch_result: Arc::clone(&saw_batch_result),
        };
        let provider_handle = lash_core::ProviderHandle::new(lash_core::ProviderComponents::new(
            Box::new(provider),
            Arc::new(lash_core::StaticModelPolicy::new()),
        ));
        // Start from an in-memory host, then install the scripted provider and
        // the stores that report themselves as durable.
        let mut host = lash_core::RuntimeHostConfig::in_memory();
        host.providers.provider_resolver =
            Arc::new(lash_core::SingleProviderResolver::new(provider_handle));
        host.durability.attachment_store = Arc::new(DurableMemoryAttachmentStore::default());
        host.durability.process_env_store = Arc::new(DurableMemoryProcessEnvStore::default());
        // Counts how many child tool executions begin; expected to stay 0.
        let started = Arc::new(AtomicUsize::new(0));
        // Standard protocol plugin (which supplies `batch`) plus a static
        // plugin exposing `alpha` and `beta`. The two-party barrier means a
        // child only succeeds if its sibling runs concurrently.
        let factories: Vec<Arc<dyn lash_core::PluginFactory>> = vec![
            Arc::new(StandardProtocolPluginFactory::new()),
            Arc::new(lash_core::plugin::StaticPluginFactory::new(
                "standard-batch-test-tools",
                lash_core::PluginSpec::new().with_tool_provider(Arc::new(BatchRuntimeTools {
                    barrier: Arc::new(Barrier::new(2)),
                    started: Arc::clone(&started),
                })),
            )),
        ];
        let policy = lash_core::SessionPolicy {
            provider_id: "stub".to_string(),
            model: lash_core::ModelSpec::from_token_limits("mock-model", None, 200_000, None)
                .expect("valid model"),
            ..lash_core::SessionPolicy::default()
        };
        // `controller` keeps a clone that shares the recorded-kinds `Arc`, so
        // counts can be read after the scoped wrapper is moved into the turn.
        let controller = CountingEffectController::default();
        let scoped_controller = lash_core::ScopedEffectController::shared(
            Arc::new(controller.clone()),
            lash_core::ExecutionScope::turn("standard-batch-session", "turn-1"),
        )
        .expect("scoped controller");
        let mut runtime = lash_core::LashRuntime::builder()
            .with_session_id("standard-batch-session")
            .with_policy(policy)
            .with_runtime_host(host)
            .with_plugin_factories(factories)
            .build()
            .await
            .expect("runtime");

        let turn = runtime
            .stream_turn(
                lash_core::TurnInput::text("run the batch"),
                lash_core::TurnOptions::new(
                    tokio_util::sync::CancellationToken::new(),
                    scoped_controller,
                ),
            )
            .await
            .expect("turn");

        // The turn still finishes normally after the batch call.
        assert!(matches!(turn.outcome, lash_core::TurnOutcome::Finished(_)));
        // One provider call for the batch request, one for the final answer.
        assert_eq!(provider_calls.load(Ordering::SeqCst), 2);
        // Neither child tool was executed.
        assert_eq!(started.load(Ordering::SeqCst), 0);
        // The final request did not carry both children's results.
        assert!(!saw_batch_result.load(Ordering::SeqCst));
        assert_eq!(controller.count(lash_core::RuntimeEffectKind::ToolBatch), 1);
        assert_eq!(
            controller.count(lash_core::RuntimeEffectKind::ToolAttempt),
            1
        );
    }
1043
1044    #[test]
1045    fn tool_attachment_round_trips_to_part_kind_image() {
1046        let attachment = image_ref("att-1");
1047        let output = ToolCallOutput::success(ToolValue::Attachment(attachment.clone()));
1048        let model_return =
1049            ModelToolReturn::from_output("call-9".to_string(), "screenshot".to_string(), &output);
1050
1051        let mut parts: Vec<Part> = Vec::new();
1052        append_model_return_parts(&mut parts, model_return);
1053
1054        assert_eq!(parts.len(), 1, "single attachment yields single part");
1055        let part = &parts[0];
1056        assert!(matches!(part.kind, PartKind::Image));
1057        assert_eq!(part.content, "");
1058        assert_eq!(part.tool_call_id.as_deref(), Some("call-9"));
1059        assert_eq!(part.tool_name.as_deref(), Some("screenshot"));
1060        let part_attachment = part.attachment.as_ref().expect("attachment present");
1061        assert_eq!(part_attachment.reference.id, attachment.id);
1062    }
1063
1064    #[test]
1065    fn tool_text_and_attachment_round_trip_preserves_order() {
1066        let attachment = image_ref("att-2");
1067        let output = ToolCallOutput::success(ToolValue::Array(vec![
1068            ToolValue::String("before".into()),
1069            ToolValue::Attachment(attachment.clone()),
1070            ToolValue::String("after".into()),
1071        ]));
1072        let model_return =
1073            ModelToolReturn::from_output("call-10".to_string(), "snap".to_string(), &output);
1074
1075        let mut parts: Vec<Part> = Vec::new();
1076        append_model_return_parts(&mut parts, model_return);
1077
1078        // The array projection emits compact JSON text fragments around the
1079        // attachment, preserving in-order position.
1080        assert_eq!(parts.len(), 3, "text + image + text yields three parts");
1081        assert!(matches!(parts[0].kind, PartKind::ToolResult));
1082        assert!(parts[0].content.starts_with("[\"before\""));
1083        assert!(matches!(parts[1].kind, PartKind::Image));
1084        assert_eq!(
1085            parts[1]
1086                .attachment
1087                .as_ref()
1088                .expect("attachment")
1089                .reference
1090                .id,
1091            attachment.id
1092        );
1093        assert!(matches!(parts[2].kind, PartKind::ToolResult));
1094        assert!(parts[2].content.ends_with("\"after\"]"));
1095    }
1096}