Skip to main content

lash_protocol_standard/
lib.rs

1//! Standard protocol stack: the model drives tools via the native
2//! function-calling envelope of its LLM transport.
3//!
4//! This crate owns:
5//!
6//! - [`StandardDriver`] — the [`ProtocolDriverHandle`] that dispatches
7//!   native tool calls and weaves reasoning parts into the assistant
8//!   message timeline.
9//! - The [`StandardProtocolPluginFactory`] plugin that claims the
10//!   protocol-driver slot so the runtime can run standard-protocol
11//!   sessions.
12//! - The `batch` tool that composes parallel native tool calls (only
13//!   exposed when this protocol stack is installed).
14
15use std::sync::Arc;
16
17use async_trait::async_trait;
18use lash_core::llm::types::{ProviderReasoningReplay, ProviderReplayMeta, ResponseTextMeta};
19use lash_core::plugin::{
20    PluginError, PluginFactory, PluginRegistrar, PluginSessionContext, ProtocolDriverPlugin,
21    ProtocolSessionContext, ProtocolSessionPlugin, SessionPlugin,
22};
23use lash_core::sansio::{
24    CheckpointResumeAction, CompletedToolCall, PendingToolCall, ProtocolDriverHandle,
25    WaitingExecState, WaitingLlmState,
26};
27use lash_core::session_model::message::PartAttachment;
28use lash_core::session_model::{
29    ConversationRecord, Message, MessageRole, Part, PartKind, PruneState, SessionEvent,
30    SessionEventRecord, fresh_message_id, make_error_event, reassign_part_ids, shared_parts,
31};
32
33mod batch;
34use batch::batch_tool_definition;
35use lash_core::{
36    CheckpointKind, DriverAction, DriverContextView, LlmOutputPart, LlmResponse,
37    ProtocolBuildInput, SessionError, ToolCall, ToolContract, ToolInvocation, ToolManifest,
38    ToolProvider, ToolResult, TurnDriverConfig, TurnDriverPreamble, TurnFinish, TurnOutcome,
39    TurnStop, append_assistant_text_part, normalized_response_parts, reasoning_part,
40};
41use serde_json::Value;
42
/// Execution guidance handed to the runtime as the `execution_prompt` of
/// every standard-protocol turn preamble (see `StandardProtocolDriver`).
///
/// NOTE: the "up to 25 calls" wording below is hand-written; keep it in sync
/// with [`BATCH_MAX_TOOL_CALLS`].
const STANDARD_EXECUTION_SECTION: &str = r#"Use direct tool calls.

- Use `batch` (up to 25 calls) for two or more independent tool calls. Serialize calls when later arguments depend on earlier results.
- For direct conversational requests that need no tools, respond in prose only.

Example — two independent reads in one `batch` call:

```json
{
  "tool_calls": [
    { "tool": "read_file", "parameters": { "path": "src/main.rs" } },
    { "tool": "grep", "parameters": { "query": "ToolProvider", "path": "crates/lash/src/" } }
  ]
}
```"#;

/// Maximum number of child calls a single `batch` invocation dispatches.
/// Entries beyond this limit are not run; they are reported back as failed
/// results in the batch output.
const BATCH_MAX_TOOL_CALLS: usize = 25;
60
61/// Plugin factory that installs the standard-protocol driver,
62/// session plugin, and native tool catalog.
63#[derive(Default)]
64pub struct StandardProtocolPluginFactory;
65
66impl StandardProtocolPluginFactory {
67    pub fn new() -> Self {
68        Self
69    }
70}
71
72impl PluginFactory for StandardProtocolPluginFactory {
73    fn id(&self) -> &'static str {
74        "standard_protocol"
75    }
76
77    fn build(&self, _ctx: &PluginSessionContext) -> Result<Arc<dyn SessionPlugin>, PluginError> {
78        Ok(Arc::new(StandardProtocolPlugin))
79    }
80}
81
82struct StandardProtocolPlugin;
83
84impl SessionPlugin for StandardProtocolPlugin {
85    fn id(&self) -> &'static str {
86        "standard_protocol"
87    }
88
89    fn register(&self, reg: &mut PluginRegistrar) -> Result<(), PluginError> {
90        reg.protocol().session(Arc::new(StandardProtocolSession))?;
91        reg.protocol()
92            .protocol_driver(Arc::new(StandardProtocolDriver))?;
93        reg.tools().provider(Arc::new(StandardProtocolTools))?;
94        Ok(())
95    }
96}
97
98struct StandardProtocolSession;
99
100#[async_trait]
101impl ProtocolSessionPlugin for StandardProtocolSession {
102    async fn initialize_session(
103        &self,
104        _ctx: ProtocolSessionContext<'_>,
105    ) -> Result<(), SessionError> {
106        Ok(())
107    }
108}
109
110struct StandardProtocolDriver;
111
112impl ProtocolDriverPlugin for StandardProtocolDriver {
113    fn build_preamble(&self, input: ProtocolBuildInput) -> TurnDriverPreamble {
114        let tool_names = input.tool_catalog.tool_names();
115        let tool_names_fingerprint = input.tool_catalog.tool_names_fingerprint();
116        TurnDriverPreamble {
117            config: TurnDriverConfig::chat(
118                Arc::new(StandardDriver),
119                true,
120                Arc::new(turn_limit_exhausted_message),
121            ),
122            tool_specs: input.tool_catalog.model_tool_specs(),
123            tool_names,
124            tool_names_fingerprint,
125            omitted_tool_count: 0,
126            execution_prompt: Arc::from(STANDARD_EXECUTION_SECTION),
127            prompt_contributions: input.extra_prompt_contributions,
128        }
129    }
130}
131
132fn turn_limit_exhausted_message(message_id: String, max_turns: usize) -> Message {
133    Message {
134        id: message_id.clone(),
135        role: MessageRole::System,
136        parts: shared_parts(vec![Part {
137            id: format!("{message_id}.p0"),
138            kind: PartKind::Error,
139            content: format!("Turn limit reached ({max_turns}) before a final assistant response."),
140            attachment: None,
141            tool_call_id: None,
142            tool_name: None,
143            tool_replay: None,
144            prune_state: PruneState::Intact,
145            reasoning_meta: None,
146            response_meta: None,
147        }]),
148        origin: None,
149    }
150}
151
152struct StandardProtocolTools;
153
154#[async_trait]
155impl ToolProvider for StandardProtocolTools {
156    fn tool_manifests(&self) -> Vec<ToolManifest> {
157        vec![batch_tool_definition().manifest()]
158    }
159
160    fn resolve_contract(&self, name: &str) -> Option<Arc<ToolContract>> {
161        (name == "batch").then(|| Arc::new(batch_tool_definition().contract()))
162    }
163
164    async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
165        match call.name {
166            "batch" => execute_batch_tool_call(call).await,
167            _ => ToolResult::err_fmt(format_args!("Unknown tool: {}", call.name)),
168        }
169    }
170}
171
/// One validated entry of a `batch` call's `tool_calls` array, as produced
/// by `parse_batch_specs`.
#[derive(Debug)]
struct BatchCallSpec {
    /// Position of the entry within `tool_calls`; batch results are keyed
    /// and ordered by this index.
    index: usize,
    /// Requested tool name, trimmed and non-empty.
    tool: String,
    /// The entry's `parameters` value, or `{}` when it was omitted.
    parameters: Value,
}
178
179async fn execute_batch_tool_call(call: ToolCall<'_>) -> ToolResult {
180    let args = call.args;
181    let specs = match parse_batch_specs(args) {
182        Ok(specs) => specs,
183        Err(err) => return err,
184    };
185
186    let mut immediate_outcomes = Vec::new();
187    let mut parallel_specs = Vec::new();
188    let dispatch = call.context.dispatch();
189
190    for spec in specs.into_iter().take(BATCH_MAX_TOOL_CALLS) {
191        if spec.tool == "batch" {
192            immediate_outcomes.push(serde_json::json!({
193                "index": spec.index,
194                "tool": spec.tool,
195                "success": false,
196                "duration_ms": 0,
197                "error": "Tool 'batch' is not allowed inside batch",
198            }));
199            continue;
200        }
201        let Some(manifest) = dispatch.callable_tool_manifest(&spec.tool) else {
202            let error = format!("Tool '{}' is unavailable in this session", spec.tool);
203            immediate_outcomes.push(serde_json::json!({
204                "index": spec.index,
205                "tool": spec.tool,
206                "success": false,
207                "duration_ms": 0,
208                "error": error,
209            }));
210            continue;
211        };
212        parallel_specs.push((
213            spec.index,
214            ToolInvocation::new(
215                format!(
216                    "{}:{:02}",
217                    call.context.tool_call_id().unwrap_or("batch"),
218                    spec.index
219                ),
220                manifest.id,
221                spec.parameters,
222            ),
223        ));
224    }
225
226    let mut parallel_outcomes = dispatch
227        .batch(
228            parallel_specs
229                .iter()
230                .map(|(_, invocation)| invocation.clone())
231                .collect(),
232        )
233        .await;
234    for ((index, invocation), outcome) in
235        parallel_specs.into_iter().zip(parallel_outcomes.drain(..))
236    {
237        let tool_label = invocation.label();
238        let tool_record = outcome.record.unwrap_or(lash_core::ToolCallRecord {
239            call_id: Some(invocation.id),
240            tool: tool_label,
241            args: invocation.args,
242            output: outcome.output,
243            duration_ms: 0,
244        });
245        let mut result_record = serde_json::Map::new();
246        result_record.insert("index".to_string(), serde_json::json!(index));
247        result_record.insert("tool".to_string(), serde_json::json!(tool_record.tool));
248        result_record.insert(
249            "success".to_string(),
250            serde_json::json!(tool_record.output.is_success()),
251        );
252        result_record.insert(
253            "duration_ms".to_string(),
254            serde_json::json!(tool_record.duration_ms),
255        );
256        result_record.insert(
257            if tool_record.output.is_success() {
258                "result".to_string()
259            } else {
260                "error".to_string()
261            },
262            tool_record.output.value_for_projection(),
263        );
264        immediate_outcomes.push(Value::Object(result_record));
265    }
266
267    for overflow_index in BATCH_MAX_TOOL_CALLS
268        ..args
269            .get("tool_calls")
270            .and_then(|value| value.as_array())
271            .map(|value| value.len())
272            .unwrap_or_default()
273    {
274        immediate_outcomes.push(serde_json::json!({
275            "index": overflow_index,
276            "tool": args
277                .get("tool_calls")
278                .and_then(|value| value.as_array())
279                .and_then(|items| items.get(overflow_index))
280                .and_then(|item| item.get("tool"))
281                .and_then(|value| value.as_str())
282                .unwrap_or("unknown"),
283            "success": false,
284            "duration_ms": 0,
285            "error": "Maximum of 25 tool calls allowed in batch",
286        }));
287    }
288
289    immediate_outcomes.sort_by_key(|outcome| {
290        outcome
291            .get("index")
292            .and_then(|value| value.as_u64())
293            .unwrap_or(u64::MAX)
294    });
295    ToolResult::ok(serde_json::json!({
296        "results": immediate_outcomes,
297    }))
298}
299
300#[allow(clippy::result_large_err)]
301fn parse_batch_specs(args: &Value) -> Result<Vec<BatchCallSpec>, ToolResult> {
302    let Some(raw_calls) = args.get("tool_calls").and_then(|value| value.as_array()) else {
303        return Err(ToolResult::err_fmt(
304            "Missing required parameter: tool_calls",
305        ));
306    };
307    if raw_calls.is_empty() {
308        return Err(ToolResult::err_fmt(
309            "Invalid tool_calls: expected at least one call",
310        ));
311    }
312
313    let mut specs = Vec::with_capacity(raw_calls.len());
314    for (index, item) in raw_calls.iter().enumerate() {
315        let Some(object) = item.as_object() else {
316            return Err(ToolResult::err_fmt(format_args!(
317                "Invalid tool_calls[{index}]: expected object with tool and parameters"
318            )));
319        };
320        let Some(tool) = object
321            .get("tool")
322            .and_then(|value| value.as_str())
323            .map(str::trim)
324            .filter(|tool| !tool.is_empty())
325        else {
326            return Err(ToolResult::err_fmt(format_args!(
327                "Invalid tool_calls[{index}].tool: expected non-empty string"
328            )));
329        };
330        let parameters = object
331            .get("parameters")
332            .cloned()
333            .unwrap_or_else(|| serde_json::json!({}));
334        specs.push(BatchCallSpec {
335            index,
336            tool: tool.to_string(),
337            parameters,
338        });
339    }
340
341    Ok(specs)
342}
343
344// ─────────────────────────────────────────────────────────────────────
345// Standard protocol driver
346// ─────────────────────────────────────────────────────────────────────
347
/// Protocol driver for the Standard protocol. Consumes native
/// tool-call envelopes from the LLM, dispatches them via
/// `DriverAction::StartTools`, and splices reasoning parts into the
/// assistant message so provider replay metadata preserves
/// chain-of-thought ordering.
pub struct StandardDriver;

/// A native tool call collected from one LLM response. It is held until the
/// assistant message parts and the `PendingToolCall` list are built.
struct StandardToolCall {
    /// Call id taken from the `LlmOutputPart::ToolCall` part.
    call_id: String,
    /// Name of the tool the model requested.
    tool_name: String,
    /// Raw JSON argument string as emitted by the model. It is stored verbatim
    /// on the assistant tool-call part. It is parsed into a `Value` when the
    /// `PendingToolCall` is built; malformed JSON falls back to `{}`.
    input_json: String,
    /// Provider replay metadata attached to the call, if any.
    replay: Option<ProviderReplayMeta>,
}
361
362fn last_message_has_tool_result(ctx: &DriverContextView<'_>) -> bool {
363    ctx.messages().last().is_some_and(|message| {
364        matches!(message.role, MessageRole::User)
365            && message
366                .parts
367                .iter()
368                .any(|part| matches!(part.kind, PartKind::ToolResult))
369    })
370}
371
/// Turn-loop callbacks for the standard (native function-calling) protocol.
impl ProtocolDriverHandle<lash_core::HostTurnProtocol> for StandardDriver {
    /// Starts every protocol iteration with a single LLM request projected
    /// from the current context. No driver state is carried across the call.
    // NOTE(review): the meaning of the `true` flag is defined by
    // `project_llm_request` in lash_core — confirm there.
    fn prepare_protocol_iteration(&self, ctx: DriverContextView<'_>) -> Vec<DriverAction> {
        vec![DriverAction::StartLlm {
            request: ctx.project_llm_request(true),
            driver_state: None,
        }]
    }

    /// Turns one LLM response into driver actions.
    ///
    /// Non-empty text parts are accumulated. They are also emitted as
    /// `TextDelta` events unless `text_streamed` is set. Reasoning parts are
    /// remembered together with the number of tool calls seen before them,
    /// and tool calls are collected. An `LlmResponse` event is always emitted.
    /// After that:
    ///
    /// - No tool calls and no text or reasoning: if the previous message
    ///   carried tool results, start the `BeforeCompletion` checkpoint whose
    ///   `on_empty` action finishes with an empty assistant message.
    ///   Otherwise emit an `empty_response` error and stop with
    ///   `TurnStop::ProviderError`.
    /// - No tool calls but some text or reasoning: start the
    ///   `BeforeCompletion` checkpoint whose `on_empty` action finishes with
    ///   the accumulated assistant text.
    /// - Tool calls: append one assistant message and start the tools. The
    ///   message holds prose parts first, then reasoning and tool-call parts
    ///   in emission order.
    fn handle_llm_success(
        &self,
        ctx: DriverContextView<'_>,
        _waiting: WaitingLlmState<lash_core::HostTurnProtocol>,
        llm_response: LlmResponse,
        text_streamed: bool,
    ) -> Vec<DriverAction> {
        let response_parts = normalized_response_parts(&llm_response);
        let mut assistant_text = String::new();
        let mut assistant_text_parts: Vec<(String, Option<ResponseTextMeta>)> = Vec::new();
        let mut tool_calls: Vec<StandardToolCall> = Vec::new();
        // Reasoning items captured with their position in the original
        // response. The `usize` is the index in `tool_calls` that this
        // reasoning item originally preceded, so we can interleave
        // reasoning → tool_call in the provider's original emission order.
        // `Option<ProviderReasoningReplay>` carries roundtrip payload
        // when present (fix 1.3b); when None, the item is display-only
        // (fix 1.3a) — still rendered in the UI but never re-fed.
        let mut reasoning_items: Vec<(usize, Option<ProviderReasoningReplay>, String)> = Vec::new();
        let mut actions = Vec::new();

        for part in response_parts {
            match part {
                LlmOutputPart::Text {
                    text,
                    response_meta,
                } => {
                    if !text.is_empty() {
                        // Record exactly the slice that `append_assistant_text_part`
                        // added to the running text, so the per-part contents
                        // and the deltas match the concatenated text.
                        let previous_len = assistant_text.len();
                        append_assistant_text_part(&mut assistant_text, &text);
                        assistant_text_parts
                            .push((assistant_text[previous_len..].to_string(), response_meta));
                        if !text_streamed {
                            actions.push(DriverAction::Emit(SessionEvent::TextDelta {
                                content: assistant_text[previous_len..].to_string(),
                            }));
                        }
                    }
                }
                LlmOutputPart::Reasoning { text, replay } => {
                    let trimmed = text.trim().to_string();
                    // Skip fully-empty reasoning items (no display text and
                    // no roundtrip payload).
                    if trimmed.is_empty() && replay.as_ref().is_none_or(|meta| meta.is_empty()) {
                        continue;
                    }
                    reasoning_items.push((tool_calls.len(), replay, trimmed));
                }
                LlmOutputPart::ToolCall {
                    call_id,
                    tool_name,
                    input_json,
                    replay,
                } => {
                    tool_calls.push(StandardToolCall {
                        call_id,
                        tool_name,
                        input_json,
                        replay,
                    });
                }
            }
        }

        actions.push(DriverAction::Emit(SessionEvent::LlmResponse {
            protocol_iteration: ctx.protocol_iteration(),
            content: assistant_text.clone(),
            duration_ms: 0,
        }));

        if tool_calls.is_empty() {
            if assistant_text.trim().is_empty() && reasoning_items.is_empty() {
                if last_message_has_tool_result(&ctx) {
                    // A model can intentionally complete a tool-only request
                    // with an empty final answer, e.g. when the user says
                    // "do nothing else" after the tool action.
                    actions.push(DriverAction::StartCheckpoint {
                        checkpoint: CheckpointKind::BeforeCompletion,
                        on_empty: CheckpointResumeAction::Finish(TurnOutcome::Finished(
                            TurnFinish::AssistantMessage {
                                text: String::new(),
                            },
                        )),
                    });
                    return actions;
                }
                actions.push(DriverAction::Emit(make_error_event(
                    "llm_provider",
                    Some("empty_response"),
                    "Model returned no assistant text or tool calls.",
                    None,
                )));
                actions.push(DriverAction::Finish(TurnOutcome::Stopped(
                    TurnStop::ProviderError,
                )));
                return actions;
            }

            // Final-answer path: reasoning parts first, then non-blank prose.
            let asst_id = fresh_message_id();
            let mut parts_out = Vec::new();
            for (_, meta, text) in reasoning_items {
                parts_out.push(reasoning_part(&asst_id, parts_out.len(), text, meta));
            }
            for (content, response_meta) in assistant_text_parts {
                if content.trim().is_empty() {
                    continue;
                }
                parts_out.push(Part {
                    id: format!("{}.p{}", asst_id, parts_out.len()),
                    kind: PartKind::Prose,
                    content,
                    attachment: None,
                    tool_call_id: None,
                    tool_name: None,
                    tool_replay: None,
                    prune_state: PruneState::Intact,
                    reasoning_meta: None,
                    response_meta,
                });
            }
            // NOTE(review): `parts_out` (reasoning + prose, including any
            // reasoning replay metadata) is only used for this emptiness check.
            // Unlike the tool-call branch below, nothing here appends it as an
            // assistant message. Confirm that the runtime persists the final
            // assistant message (and its reasoning replay) through the
            // checkpoint / `TurnFinish::AssistantMessage` path; otherwise that
            // metadata is lost.
            if parts_out.is_empty() {
                actions.push(DriverAction::Emit(make_error_event(
                    "llm_provider",
                    Some("empty_response"),
                    "Model returned no assistant text or tool calls.",
                    None,
                )));
                actions.push(DriverAction::Finish(TurnOutcome::Stopped(
                    TurnStop::ProviderError,
                )));
                return actions;
            }
            actions.push(DriverAction::StartCheckpoint {
                checkpoint: CheckpointKind::BeforeCompletion,
                on_empty: CheckpointResumeAction::Finish(TurnOutcome::Finished(
                    TurnFinish::AssistantMessage {
                        text: assistant_text.clone(),
                    },
                )),
            });
            return actions;
        }

        // Tool-call path. Prose parts go first. NOTE(review): the relative
        // order of text vs. reasoning in the original response is not tracked,
        // and the final-answer branch above puts reasoning before prose
        // instead. Confirm that both orderings replay correctly for every
        // provider.
        let asst_id = fresh_message_id();
        let mut assistant_parts = Vec::new();
        for (content, response_meta) in assistant_text_parts {
            if content.trim().is_empty() {
                continue;
            }
            assistant_parts.push(Part {
                id: format!("{}.p{}", asst_id, assistant_parts.len()),
                kind: PartKind::Prose,
                content,
                attachment: None,
                tool_call_id: None,
                tool_name: None,
                tool_replay: None,
                prune_state: PruneState::Intact,
                reasoning_meta: None,
                response_meta,
            });
        }

        let mut calls = Vec::new();
        // Interleave reasoning items with tool calls to preserve the
        // original emission order. Some provider replays expect the
        // sequence `reasoning → function_call` from the turn in which both
        // were produced; swapping them can drop the reasoning/tool pairing.
        let mut reasoning_iter = reasoning_items.into_iter().peekable();
        for (tool_index, tool_call) in tool_calls.into_iter().enumerate() {
            // Flush every reasoning item recorded before this tool call
            // (its stored index is <= the tool call's index).
            while let Some((insert_index, _, _)) = reasoning_iter.peek() {
                if *insert_index > tool_index {
                    break;
                }
                let (_, meta, text) = reasoning_iter.next().expect("peek ok");
                assistant_parts.push(reasoning_part(&asst_id, assistant_parts.len(), text, meta));
            }
            assistant_parts.push(Part {
                id: format!("{}.p{}", asst_id, assistant_parts.len()),
                kind: PartKind::ToolCall,
                content: tool_call.input_json.clone(),
                attachment: None,
                tool_call_id: Some(tool_call.call_id.clone()),
                tool_name: Some(tool_call.tool_name.clone()),
                tool_replay: tool_call.replay.clone(),
                prune_state: PruneState::Intact,
                reasoning_meta: None,
                response_meta: None,
            });

            // Malformed (or empty) argument JSON falls back to `{}`. The parse
            // error itself is not reported to the model here, so the tool
            // sees empty arguments instead.
            let args = serde_json::from_str::<Value>(&tool_call.input_json)
                .unwrap_or_else(|_| serde_json::json!({}));
            calls.push(PendingToolCall {
                call_id: tool_call.call_id,
                tool_name: tool_call.tool_name,
                args,
                replay: tool_call.replay,
            });
        }
        // Reasoning emitted after the last tool call goes at the end.
        for (_, meta, text) in reasoning_iter {
            assistant_parts.push(reasoning_part(&asst_id, assistant_parts.len(), text, meta));
        }

        // Always non-empty on this path (there is at least one tool-call part).
        if !assistant_parts.is_empty() {
            actions.push(DriverAction::AppendEvents(vec![conversation_event(
                Message {
                    id: asst_id,
                    role: MessageRole::Assistant,
                    parts: shared_parts(assistant_parts),
                    origin: None,
                },
            )]));
        }

        actions.push(DriverAction::StartTools { calls });
        actions
    }

    /// Records tool results and decides how the turn continues.
    ///
    /// All model-facing tool returns are appended as a single user message.
    /// The first successful result carrying a terminal control ends the turn
    /// with the matching outcome. Terminal controls are `SwitchAgentFrame`
    /// with a non-blank frame id and task, `Finish`, or `Fail`.
    ///
    /// Otherwise the protocol iteration advances. If the next iteration
    /// reaches `protocol_run_offset() + max_turns`, a turn-limit message is
    /// appended and the turn stops with `TurnStop::MaxTurns`. If not, the
    /// `AfterWork` checkpoint starts and its `on_empty` action prepares the
    /// next iteration.
    fn handle_tool_results(
        &self,
        ctx: DriverContextView<'_>,
        completed: Vec<CompletedToolCall>,
    ) -> Vec<DriverAction> {
        let mut actions = Vec::new();
        let mut result_parts = Vec::new();
        let mut terminal_outcome = None;

        for outcome in completed {
            // Only the first terminal control wins, and controls on failed
            // outputs are ignored. Results are still collected for every call.
            if terminal_outcome.is_none() && outcome.output.is_success() {
                terminal_outcome = match outcome.output.control.as_ref() {
                    Some(lash_core::ToolControl::SwitchAgentFrame {
                        frame_id,
                        task: Some(task),
                        ..
                    }) if !frame_id.trim().is_empty() && !task.trim().is_empty() => {
                        Some(TurnOutcome::AgentFrameSwitch {
                            frame_id: frame_id.clone(),
                            task: task.clone(),
                        })
                    }
                    Some(lash_core::ToolControl::Finish { value }) => {
                        Some(TurnOutcome::Finished(TurnFinish::ToolValue {
                            tool_name: outcome.tool_name.clone(),
                            value: value.to_json_value(),
                        }))
                    }
                    Some(lash_core::ToolControl::Fail { failure }) => {
                        Some(TurnOutcome::Stopped(TurnStop::ToolError {
                            tool_name: outcome.tool_name.clone(),
                            value: failure.to_json_value(),
                        }))
                    }
                    _ => None,
                };
            }

            append_model_return_parts(&mut result_parts, outcome.model_return);
        }

        if !result_parts.is_empty() {
            // Parts are created with empty ids; give them ids under the new message.
            let user_id = fresh_message_id();
            reassign_part_ids(&user_id, &mut result_parts);
            actions.push(DriverAction::AppendEvents(vec![conversation_event(
                Message {
                    id: user_id,
                    role: MessageRole::User,
                    parts: shared_parts(result_parts),
                    origin: None,
                },
            )]));
        }

        if let Some(outcome) = terminal_outcome {
            actions.push(DriverAction::Finish(outcome));
            return actions;
        }

        actions.push(DriverAction::AdvanceProtocolIteration);
        // The advance above is only queued; presumably `ctx` still reports the
        // current iteration, hence the explicit +1.
        let next_protocol_iteration = ctx.protocol_iteration() + 1;
        if let Some(max_turns) = ctx.max_turns()
            && next_protocol_iteration >= ctx.protocol_run_offset() + max_turns
        {
            let message_id = fresh_message_id();
            actions.push(DriverAction::AppendEvents(vec![conversation_event(
                turn_limit_exhausted_message(message_id, max_turns),
            )]));
            actions.push(DriverAction::Finish(TurnOutcome::Stopped(
                TurnStop::MaxTurns,
            )));
            return actions;
        }

        actions.push(DriverAction::StartCheckpoint {
            checkpoint: CheckpointKind::AfterWork,
            on_empty: CheckpointResumeAction::PrepareIteration,
        });
        actions
    }

    /// This driver never starts an exec (its actions are limited to LLM,
    /// tool, checkpoint, event, and finish actions), so exec results are
    /// ignored.
    // NOTE(review): returning no actions would leave the turn without a next
    // step if this were ever invoked — confirm the runtime never routes exec
    // results to the standard driver.
    fn handle_exec_result(
        &self,
        _ctx: DriverContextView<'_>,
        _waiting: WaitingExecState<lash_core::HostTurnProtocol>,
        _result: Result<lash_core::ExecResponse, String>,
    ) -> Vec<DriverAction> {
        Vec::new()
    }
}
688
689fn append_model_return_parts(parts: &mut Vec<Part>, model_return: lash_core::ModelToolReturn) {
690    for part in model_return.parts {
691        match part {
692            lash_core::ModelToolReturnPart::Text { text } => {
693                if text.is_empty() {
694                    continue;
695                }
696                parts.push(Part {
697                    id: String::new(),
698                    kind: PartKind::ToolResult,
699                    content: text,
700                    attachment: None,
701                    tool_call_id: Some(model_return.call_id.clone()),
702                    tool_name: Some(model_return.tool_name.clone()),
703                    tool_replay: None,
704                    prune_state: PruneState::Intact,
705                    reasoning_meta: None,
706                    response_meta: None,
707                });
708            }
709            lash_core::ModelToolReturnPart::Attachment(reference) => {
710                parts.push(Part {
711                    id: String::new(),
712                    kind: PartKind::Image,
713                    content: String::new(),
714                    attachment: Some(PartAttachment { reference }),
715                    tool_call_id: Some(model_return.call_id.clone()),
716                    tool_name: Some(model_return.tool_name.clone()),
717                    tool_replay: None,
718                    prune_state: PruneState::Intact,
719                    reasoning_meta: None,
720                    response_meta: None,
721                });
722            }
723        }
724    }
725}
726
727fn conversation_event(message: Message) -> SessionEventRecord {
728    SessionEventRecord::Conversation(ConversationRecord::from_message(message))
729}
730
731#[cfg(test)]
732mod tests {
733    use super::*;
734    use lash_core::{
735        AttachmentId, AttachmentMeta, ImageMediaType, MediaType, ModelToolReturn, ToolCallOutput,
736        ToolValue,
737    };
738    use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
739    use tokio::sync::Barrier;
740    use tokio::time::{Duration, timeout};
741
742    fn image_ref(id: &str) -> lash_core::AttachmentRef {
743        AttachmentMeta::new(
744            AttachmentId::new(id),
745            MediaType::Image(ImageMediaType::Png),
746            4,
747            Some(1),
748            Some(1),
749            Some("tiny".to_string()),
750        )
751        .as_ref()
752    }
753
754    #[derive(Clone, Debug)]
755    struct BatchRuntimeProvider {
756        calls: Arc<AtomicUsize>,
757        saw_batch_result: Arc<AtomicBool>,
758    }
759
760    #[async_trait::async_trait]
761    impl lash_core::Provider for BatchRuntimeProvider {
762        fn kind(&self) -> &'static str {
763            "stub"
764        }
765
766        fn options(&self) -> lash_core::ProviderOptions {
767            lash_core::ProviderOptions::default()
768        }
769
770        fn set_options(&mut self, _options: lash_core::ProviderOptions) {}
771
772        fn serialize_config(&self) -> serde_json::Value {
773            serde_json::json!({})
774        }
775
776        async fn complete(
777            &mut self,
778            request: lash_core::LlmRequest,
779        ) -> Result<lash_core::LlmResponse, lash_core::LlmTransportError> {
780            let call_index = self.calls.fetch_add(1, Ordering::SeqCst);
781            if call_index == 0 {
782                return Ok(lash_core::LlmResponse {
783                    parts: vec![lash_core::LlmOutputPart::ToolCall {
784                        call_id: "batch-call".to_string(),
785                        tool_name: "batch".to_string(),
786                        input_json: serde_json::json!({
787                            "tool_calls": [
788                                {"tool": "alpha", "parameters": {}},
789                                {"tool": "beta", "parameters": {"value": "fail"}}
790                            ]
791                        })
792                        .to_string(),
793                        replay: None,
794                    }],
795                    ..lash_core::LlmResponse::default()
796                });
797            }
798
799            let projected_messages = format!("{:?}", request.messages);
800            if projected_messages.contains("alpha") && projected_messages.contains("beta failed") {
801                self.saw_batch_result.store(true, Ordering::SeqCst);
802            }
803            Ok(lash_core::LlmResponse {
804                full_text: "done".to_string(),
805                parts: vec![lash_core::LlmOutputPart::Text {
806                    text: "done".to_string(),
807                    response_meta: None,
808                }],
809                ..lash_core::LlmResponse::default()
810            })
811        }
812
813        fn clone_boxed(&self) -> Box<dyn lash_core::Provider> {
814            Box::new(self.clone())
815        }
816    }
817
    /// Tool provider supplying the `alpha` and `beta` child tools for the batch
    /// runtime test.
    #[derive(Debug)]
    struct BatchRuntimeTools {
        /// Shared rendezvous that every `execute` call waits on; the test sizes
        /// it to the number of child calls, so it opens once every child has
        /// reached it.
        barrier: Arc<Barrier>,
        /// Incremented once per `execute` call, before waiting on `barrier`.
        started: Arc<AtomicUsize>,
    }
823
824    fn runtime_test_tool(name: &str) -> lash_core::ToolDefinition {
825        lash_core::ToolDefinition::raw(
826            format!("tool:{name}"),
827            name,
828            "",
829            serde_json::json!({
830                "type": "object",
831                "properties": {
832                    "value": { "type": "string" }
833                },
834                "additionalProperties": true
835            }),
836            serde_json::json!({ "type": "string" }),
837        )
838        .with_scheduling(lash_core::ToolScheduling::Parallel)
839    }
840
841    #[async_trait::async_trait]
842    impl ToolProvider for BatchRuntimeTools {
843        fn tool_manifests(&self) -> Vec<ToolManifest> {
844            vec![
845                runtime_test_tool("alpha").manifest(),
846                runtime_test_tool("beta").manifest(),
847            ]
848        }
849
850        fn resolve_contract(&self, name: &str) -> Option<Arc<ToolContract>> {
851            match name {
852                "alpha" | "beta" => Some(Arc::new(runtime_test_tool(name).contract())),
853                _ => None,
854            }
855        }
856
857        async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
858            self.started.fetch_add(1, Ordering::SeqCst);
859            if timeout(Duration::from_millis(100), self.barrier.wait())
860                .await
861                .is_err()
862            {
863                return ToolResult::err_fmt("batch child tools did not run concurrently");
864            }
865            if call.name == "beta"
866                && call.args.get("value").and_then(|value| value.as_str()) == Some("fail")
867            {
868                return ToolResult::err_fmt("beta failed");
869            }
870            ToolResult::ok(serde_json::json!(call.name))
871        }
872    }
873
874    #[derive(Clone, Default)]
875    struct CountingEffectController {
876        kinds: Arc<std::sync::Mutex<Vec<lash_core::RuntimeEffectKind>>>,
877    }
878
879    impl CountingEffectController {
880        fn count(&self, kind: lash_core::RuntimeEffectKind) -> usize {
881            self.kinds
882                .lock()
883                .expect("effect kinds")
884                .iter()
885                .filter(|candidate| **candidate == kind)
886                .count()
887        }
888    }
889
890    #[derive(Default)]
891    struct DurableMemoryAttachmentStore {
892        inner: lash_core::InMemoryAttachmentStore,
893    }
894
895    #[async_trait::async_trait]
896    impl lash_core::AttachmentStore for DurableMemoryAttachmentStore {
897        fn persistence(&self) -> lash_core::AttachmentStorePersistence {
898            lash_core::AttachmentStorePersistence::Durable
899        }
900
901        async fn put(
902            &self,
903            bytes: Vec<u8>,
904            meta: lash_core::AttachmentCreateMeta,
905        ) -> Result<lash_core::AttachmentRef, lash_core::AttachmentStoreError> {
906            self.inner.put(bytes, meta).await
907        }
908
909        async fn get(
910            &self,
911            id: &lash_core::AttachmentId,
912        ) -> Result<lash_core::StoredAttachment, lash_core::AttachmentStoreError> {
913            self.inner.get(id).await
914        }
915    }
916
917    #[derive(Default)]
918    struct DurableMemoryProcessEnvStore {
919        inner: lash_core::InMemoryProcessExecutionEnvStore,
920    }
921
922    #[async_trait::async_trait]
923    impl lash_core::ProcessExecutionEnvStore for DurableMemoryProcessEnvStore {
924        fn durability_tier(&self) -> lash_core::DurabilityTier {
925            lash_core::DurabilityTier::Durable
926        }
927
928        async fn put_process_execution_env(
929            &self,
930            env_ref: &lash_core::ProcessExecutionEnvRef,
931            bytes: &[u8],
932        ) -> Result<(), lash_core::PluginError> {
933            self.inner.put_process_execution_env(env_ref, bytes).await
934        }
935
936        async fn get_process_execution_env(
937            &self,
938            env_ref: &lash_core::ProcessExecutionEnvRef,
939        ) -> Result<Option<Vec<u8>>, lash_core::PluginError> {
940            self.inner.get_process_execution_env(env_ref).await
941        }
942    }
943
944    #[async_trait::async_trait]
945    impl lash_core::RuntimeEffectController for CountingEffectController {
946        fn durability_tier(&self) -> lash_core::DurabilityTier {
947            lash_core::DurabilityTier::Durable
948        }
949
950        async fn execute_effect(
951            &self,
952            envelope: lash_core::RuntimeEffectEnvelope,
953            local_executor: lash_core::RuntimeEffectLocalExecutor<'_>,
954        ) -> Result<lash_core::RuntimeEffectOutcome, lash_core::RuntimeEffectControllerError>
955        {
956            self.kinds
957                .lock()
958                .expect("effect kinds")
959                .push(envelope.command.kind());
960            local_executor.execute(envelope).await
961        }
962    }
963
964    #[tokio::test]
965    async fn standard_batch_tool_uses_core_tool_batch_under_durable_boundary() {
966        let provider_calls = Arc::new(AtomicUsize::new(0));
967        let saw_batch_result = Arc::new(AtomicBool::new(false));
968        let provider = BatchRuntimeProvider {
969            calls: Arc::clone(&provider_calls),
970            saw_batch_result: Arc::clone(&saw_batch_result),
971        };
972        let provider_handle = lash_core::ProviderHandle::new(lash_core::ProviderComponents::new(
973            Box::new(provider),
974            Arc::new(lash_core::StaticModelPolicy::new()),
975        ));
976        let mut host = lash_core::RuntimeHostConfig::in_memory();
977        host.providers.provider_resolver =
978            Arc::new(lash_core::SingleProviderResolver::new(provider_handle));
979        host.durability.attachment_store = Arc::new(DurableMemoryAttachmentStore::default());
980        host.durability.process_env_store = Arc::new(DurableMemoryProcessEnvStore::default());
981        let started = Arc::new(AtomicUsize::new(0));
982        let factories: Vec<Arc<dyn lash_core::PluginFactory>> = vec![
983            Arc::new(StandardProtocolPluginFactory::new()),
984            Arc::new(lash_core::plugin::StaticPluginFactory::new(
985                "standard-batch-test-tools",
986                lash_core::PluginSpec::new().with_tool_provider(Arc::new(BatchRuntimeTools {
987                    barrier: Arc::new(Barrier::new(2)),
988                    started: Arc::clone(&started),
989                })),
990            )),
991        ];
992        let policy = lash_core::SessionPolicy {
993            provider_id: "stub".to_string(),
994            model: lash_core::ModelSpec::from_token_limits("mock-model", None, 200_000, None)
995                .expect("valid model"),
996            ..lash_core::SessionPolicy::default()
997        };
998        let controller = CountingEffectController::default();
999        let scoped_controller = lash_core::ScopedEffectController::shared(
1000            Arc::new(controller.clone()),
1001            lash_core::ExecutionScope::turn("standard-batch-session", "turn-1"),
1002        )
1003        .expect("scoped controller");
1004        let mut runtime = lash_core::LashRuntime::builder()
1005            .with_session_id("standard-batch-session")
1006            .with_policy(policy)
1007            .with_runtime_host(host)
1008            .with_plugin_factories(factories)
1009            .build()
1010            .await
1011            .expect("runtime");
1012
1013        let turn = runtime
1014            .stream_turn(
1015                lash_core::TurnInput::text("run the batch"),
1016                lash_core::TurnOptions::new(
1017                    tokio_util::sync::CancellationToken::new(),
1018                    scoped_controller,
1019                ),
1020            )
1021            .await
1022            .expect("turn");
1023
1024        assert!(matches!(turn.outcome, lash_core::TurnOutcome::Finished(_)));
1025        assert_eq!(provider_calls.load(Ordering::SeqCst), 2);
1026        assert_eq!(started.load(Ordering::SeqCst), 2);
1027        assert!(saw_batch_result.load(Ordering::SeqCst));
1028        assert_eq!(controller.count(lash_core::RuntimeEffectKind::ToolBatch), 1);
1029    }
1030
1031    #[test]
1032    fn tool_attachment_round_trips_to_part_kind_image() {
1033        let attachment = image_ref("att-1");
1034        let output = ToolCallOutput::success(ToolValue::Attachment(attachment.clone()));
1035        let model_return =
1036            ModelToolReturn::from_output("call-9".to_string(), "screenshot".to_string(), &output);
1037
1038        let mut parts: Vec<Part> = Vec::new();
1039        append_model_return_parts(&mut parts, model_return);
1040
1041        assert_eq!(parts.len(), 1, "single attachment yields single part");
1042        let part = &parts[0];
1043        assert!(matches!(part.kind, PartKind::Image));
1044        assert_eq!(part.content, "");
1045        assert_eq!(part.tool_call_id.as_deref(), Some("call-9"));
1046        assert_eq!(part.tool_name.as_deref(), Some("screenshot"));
1047        let part_attachment = part.attachment.as_ref().expect("attachment present");
1048        assert_eq!(part_attachment.reference.id, attachment.id);
1049    }
1050
1051    #[test]
1052    fn tool_text_and_attachment_round_trip_preserves_order() {
1053        let attachment = image_ref("att-2");
1054        let output = ToolCallOutput::success(ToolValue::Array(vec![
1055            ToolValue::String("before".into()),
1056            ToolValue::Attachment(attachment.clone()),
1057            ToolValue::String("after".into()),
1058        ]));
1059        let model_return =
1060            ModelToolReturn::from_output("call-10".to_string(), "snap".to_string(), &output);
1061
1062        let mut parts: Vec<Part> = Vec::new();
1063        append_model_return_parts(&mut parts, model_return);
1064
1065        // The array projection emits compact JSON text fragments around the
1066        // attachment, preserving in-order position.
1067        assert_eq!(parts.len(), 3, "text + image + text yields three parts");
1068        assert!(matches!(parts[0].kind, PartKind::ToolResult));
1069        assert!(parts[0].content.starts_with("[\"before\""));
1070        assert!(matches!(parts[1].kind, PartKind::Image));
1071        assert_eq!(
1072            parts[1]
1073                .attachment
1074                .as_ref()
1075                .expect("attachment")
1076                .reference
1077                .id,
1078            attachment.id
1079        );
1080        assert!(matches!(parts[2].kind, PartKind::ToolResult));
1081        assert!(parts[2].content.ends_with("\"after\"]"));
1082    }
1083}