Skip to main content

lash_protocol_standard/
lib.rs

1//! Standard protocol stack: the model drives tools via the native
2//! function-calling envelope of its LLM transport.
3//!
4//! This crate owns:
5//!
6//! - [`StandardDriver`] — the [`ProtocolDriverHandle`] that dispatches
7//!   native tool calls and weaves reasoning parts into the assistant
8//!   message timeline.
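//! - [`StandardProtocolPluginFactory`] — the plugin factory whose session
//!   plugin claims the protocol-driver slot (and also registers the
//!   protocol session hook and the native tool provider) so the runtime
//!   can run standard-protocol sessions.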
12//! - The `batch` tool that composes parallel native tool calls (only
13//!   exposed when this protocol stack is installed).
14
15use std::sync::Arc;
16
17use async_trait::async_trait;
18use lash_core::llm::types::{ProviderReasoningReplay, ProviderReplayMeta, ResponseTextMeta};
19use lash_core::plugin::{
20    PluginError, PluginFactory, PluginRegistrar, PluginSessionContext, ProtocolDriverPlugin,
21    ProtocolSessionContext, ProtocolSessionPlugin, SessionPlugin,
22};
23use lash_core::sansio::{
24    CheckpointResumeAction, CompletedToolCall, PendingToolCall, ProtocolDriverHandle,
25    WaitingExecState, WaitingLlmState,
26};
27use lash_core::session_model::message::PartAttachment;
28use lash_core::session_model::{
29    ConversationRecord, Message, MessageRole, Part, PartKind, PruneState, SessionEvent,
30    SessionEventRecord, fresh_message_id, make_error_event, reassign_part_ids, shared_parts,
31};
32
33mod batch;
34use batch::batch_tool_definition;
35use lash_core::{
36    CheckpointKind, DriverAction, DriverContextView, LlmOutputPart, LlmResponse,
37    ProtocolBuildInput, SessionError, ToolCall, ToolContract, ToolInvocation, ToolManifest,
38    ToolProvider, ToolResult, TurnDriverConfig, TurnDriverPreamble, TurnFinish, TurnOutcome,
39    TurnStop, append_assistant_text_part, normalized_response_parts, reasoning_part,
40};
41use serde_json::Value;
42
/// Execution-guidance prompt text for standard-protocol turns, supplied as
/// the preamble's `execution_prompt`.
///
/// NOTE(review): the "up to 25 calls" wording is a literal and must be kept
/// in sync with `BATCH_MAX_TOOL_CALLS` by hand.
const STANDARD_EXECUTION_SECTION: &str = r#"Use direct tool calls.

- Use `batch` (up to 25 calls) for two or more independent tool calls. Serialize calls when later arguments depend on earlier results.
- For direct conversational requests that need no tools, respond in prose only.

Example — two independent reads in one `batch` call:

```json
{
  "tool_calls": [
    { "tool": "read_file", "parameters": { "path": "src/main.rs" } },
    { "tool": "grep", "parameters": { "query": "ToolProvider", "path": "crates/lash/src/" } }
  ]
}
```"#;

/// Maximum number of `tool_calls` entries a single `batch` call dispatches.
/// Entries beyond this cap are not executed; each is reported as an error
/// entry in the batch result.
const BATCH_MAX_TOOL_CALLS: usize = 25;
60
61/// Plugin factory that installs the standard-protocol driver,
62/// session plugin, and native tool catalog.
63#[derive(Default)]
64pub struct StandardProtocolPluginFactory;
65
66impl StandardProtocolPluginFactory {
67    pub fn new() -> Self {
68        Self
69    }
70}
71
72impl PluginFactory for StandardProtocolPluginFactory {
73    fn id(&self) -> &'static str {
74        "standard_protocol"
75    }
76
77    fn build(&self, _ctx: &PluginSessionContext) -> Result<Arc<dyn SessionPlugin>, PluginError> {
78        Ok(Arc::new(StandardProtocolPlugin))
79    }
80}
81
82struct StandardProtocolPlugin;
83
84impl SessionPlugin for StandardProtocolPlugin {
85    fn id(&self) -> &'static str {
86        "standard_protocol"
87    }
88
89    fn register(&self, reg: &mut PluginRegistrar) -> Result<(), PluginError> {
90        reg.protocol().session(Arc::new(StandardProtocolSession))?;
91        reg.protocol()
92            .protocol_driver(Arc::new(StandardProtocolDriver))?;
93        reg.tools().provider(Arc::new(StandardProtocolTools))?;
94        Ok(())
95    }
96}
97
98struct StandardProtocolSession;
99
100#[async_trait]
101impl ProtocolSessionPlugin for StandardProtocolSession {
102    async fn initialize_session(
103        &self,
104        _ctx: ProtocolSessionContext<'_>,
105    ) -> Result<(), SessionError> {
106        Ok(())
107    }
108}
109
110struct StandardProtocolDriver;
111
112impl ProtocolDriverPlugin for StandardProtocolDriver {
113    fn build_preamble(&self, input: ProtocolBuildInput) -> TurnDriverPreamble {
114        let tool_names = input.tool_catalog.tool_names();
115        let tool_names_fingerprint = input.tool_catalog.tool_names_fingerprint();
116        TurnDriverPreamble {
117            config: TurnDriverConfig::chat(
118                Arc::new(StandardDriver),
119                true,
120                Arc::new(turn_limit_exhausted_message),
121            ),
122            tool_specs: input.tool_catalog.model_tool_specs(),
123            tool_names,
124            tool_names_fingerprint,
125            omitted_tool_count: 0,
126            execution_prompt: Arc::from(STANDARD_EXECUTION_SECTION),
127            prompt_contributions: input.extra_prompt_contributions,
128        }
129    }
130}
131
132fn turn_limit_exhausted_message(message_id: String, max_turns: usize) -> Message {
133    Message {
134        id: message_id.clone(),
135        role: MessageRole::System,
136        parts: shared_parts(vec![Part {
137            id: format!("{message_id}.p0"),
138            kind: PartKind::Error,
139            content: format!("Turn limit reached ({max_turns}) before a final assistant response."),
140            attachment: None,
141            tool_call_id: None,
142            tool_name: None,
143            tool_replay: None,
144            prune_state: PruneState::Intact,
145            reasoning_meta: None,
146            response_meta: None,
147        }]),
148        origin: None,
149    }
150}
151
152struct StandardProtocolTools;
153
154#[async_trait]
155impl ToolProvider for StandardProtocolTools {
156    fn tool_manifests(&self) -> Vec<ToolManifest> {
157        vec![batch_tool_definition().manifest()]
158    }
159
160    fn resolve_contract(&self, name: &str) -> Option<Arc<ToolContract>> {
161        (name == "batch").then(|| Arc::new(batch_tool_definition().contract()))
162    }
163
164    async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
165        match call.name {
166            "batch" => execute_batch_tool_call(call).await,
167            _ => ToolResult::err_fmt(format_args!("Unknown tool: {}", call.name)),
168        }
169    }
170}
171
172#[derive(Debug)]
173struct BatchCallSpec {
174    index: usize,
175    tool: String,
176    parameters: Value,
177}
178
179async fn execute_batch_tool_call(call: ToolCall<'_>) -> ToolResult {
180    let args = call.args;
181    let specs = match parse_batch_specs(args) {
182        Ok(specs) => specs,
183        Err(err) => return err,
184    };
185
186    let mut immediate_outcomes = Vec::new();
187    let mut parallel_specs = Vec::new();
188
189    for spec in specs.into_iter().take(BATCH_MAX_TOOL_CALLS) {
190        if spec.tool == "batch" {
191            immediate_outcomes.push(serde_json::json!({
192                "index": spec.index,
193                "tool": spec.tool,
194                "success": false,
195                "duration_ms": 0,
196                "error": "Tool 'batch' is not allowed inside batch",
197            }));
198            continue;
199        }
200        parallel_specs.push((
201            spec.index,
202            ToolInvocation::new(
203                format!(
204                    "{}:{:02}",
205                    call.context.tool_call_id().unwrap_or("batch"),
206                    spec.index
207                ),
208                spec.tool,
209                spec.parameters,
210            ),
211        ));
212    }
213
214    let mut parallel_outcomes = call
215        .context
216        .dispatch()
217        .batch(
218            parallel_specs
219                .iter()
220                .map(|(_, invocation)| invocation.clone())
221                .collect(),
222        )
223        .await;
224    for ((index, invocation), outcome) in
225        parallel_specs.into_iter().zip(parallel_outcomes.drain(..))
226    {
227        let tool_record = outcome.record.unwrap_or(lash_core::ToolCallRecord {
228            call_id: Some(invocation.id),
229            tool: invocation.name,
230            args: invocation.args,
231            output: outcome.output,
232            duration_ms: 0,
233        });
234        let mut result_record = serde_json::Map::new();
235        result_record.insert("index".to_string(), serde_json::json!(index));
236        result_record.insert("tool".to_string(), serde_json::json!(tool_record.tool));
237        result_record.insert(
238            "success".to_string(),
239            serde_json::json!(tool_record.output.is_success()),
240        );
241        result_record.insert(
242            "duration_ms".to_string(),
243            serde_json::json!(tool_record.duration_ms),
244        );
245        result_record.insert(
246            if tool_record.output.is_success() {
247                "result".to_string()
248            } else {
249                "error".to_string()
250            },
251            tool_record.output.value_for_projection(),
252        );
253        immediate_outcomes.push(Value::Object(result_record));
254    }
255
256    for overflow_index in BATCH_MAX_TOOL_CALLS
257        ..args
258            .get("tool_calls")
259            .and_then(|value| value.as_array())
260            .map(|value| value.len())
261            .unwrap_or_default()
262    {
263        immediate_outcomes.push(serde_json::json!({
264            "index": overflow_index,
265            "tool": args
266                .get("tool_calls")
267                .and_then(|value| value.as_array())
268                .and_then(|items| items.get(overflow_index))
269                .and_then(|item| item.get("tool"))
270                .and_then(|value| value.as_str())
271                .unwrap_or("unknown"),
272            "success": false,
273            "duration_ms": 0,
274            "error": "Maximum of 25 tool calls allowed in batch",
275        }));
276    }
277
278    immediate_outcomes.sort_by_key(|outcome| {
279        outcome
280            .get("index")
281            .and_then(|value| value.as_u64())
282            .unwrap_or(u64::MAX)
283    });
284    ToolResult::ok(serde_json::json!({
285        "results": immediate_outcomes,
286    }))
287}
288
289#[allow(clippy::result_large_err)]
290fn parse_batch_specs(args: &Value) -> Result<Vec<BatchCallSpec>, ToolResult> {
291    let Some(raw_calls) = args.get("tool_calls").and_then(|value| value.as_array()) else {
292        return Err(ToolResult::err_fmt(
293            "Missing required parameter: tool_calls",
294        ));
295    };
296    if raw_calls.is_empty() {
297        return Err(ToolResult::err_fmt(
298            "Invalid tool_calls: expected at least one call",
299        ));
300    }
301
302    let mut specs = Vec::with_capacity(raw_calls.len());
303    for (index, item) in raw_calls.iter().enumerate() {
304        let Some(object) = item.as_object() else {
305            return Err(ToolResult::err_fmt(format_args!(
306                "Invalid tool_calls[{index}]: expected object with tool and parameters"
307            )));
308        };
309        let Some(tool) = object
310            .get("tool")
311            .and_then(|value| value.as_str())
312            .map(str::trim)
313            .filter(|tool| !tool.is_empty())
314        else {
315            return Err(ToolResult::err_fmt(format_args!(
316                "Invalid tool_calls[{index}].tool: expected non-empty string"
317            )));
318        };
319        let parameters = object
320            .get("parameters")
321            .cloned()
322            .unwrap_or_else(|| serde_json::json!({}));
323        specs.push(BatchCallSpec {
324            index,
325            tool: tool.to_string(),
326            parameters,
327        });
328    }
329
330    Ok(specs)
331}
332
333// ─────────────────────────────────────────────────────────────────────
334// Standard protocol driver
335// ─────────────────────────────────────────────────────────────────────
336
337/// Protocol driver for the Standard protocol. Consumes native
338/// tool-call envelopes from the LLM, dispatches them via
339/// `DriverAction::StartTools`, and splices reasoning parts into the
340/// assistant message so provider replay metadata preserves
341/// chain-of-thought ordering.
342pub struct StandardDriver;
343
344struct StandardToolCall {
345    call_id: String,
346    tool_name: String,
347    input_json: String,
348    replay: Option<ProviderReplayMeta>,
349}
350
351fn last_message_has_tool_result(ctx: &DriverContextView<'_>) -> bool {
352    ctx.messages().last().is_some_and(|message| {
353        matches!(message.role, MessageRole::User)
354            && message
355                .parts
356                .iter()
357                .any(|part| matches!(part.kind, PartKind::ToolResult))
358    })
359}
360
/// Turn driver for the standard protocol. Each protocol iteration issues one
/// LLM request. Its response becomes conversation events and tool
/// dispatches, and after the tool results come back the driver decides
/// whether the turn finishes or loops.
impl ProtocolDriverHandle<lash_core::HostTurnProtocol> for StandardDriver {
    /// Starts every protocol iteration with a single LLM request projected
    /// from the current context; no driver state is carried.
    fn prepare_protocol_iteration(&self, ctx: DriverContextView<'_>) -> Vec<DriverAction> {
        vec![DriverAction::StartLlm {
            // NOTE(review): the meaning of the `true` flag is defined by
            // `DriverContextView::project_llm_request` — confirm there.
            request: ctx.project_llm_request(true),
            driver_state: None,
        }]
    }

    /// Converts one LLM response into driver actions.
    ///
    /// Text parts are accumulated, and emitted as `TextDelta` events when
    /// they were not already streamed. Reasoning parts are recorded together
    /// with the number of tool calls seen so far. Tool calls are collected.
    /// An `LlmResponse` event is always emitted. Then:
    ///
    /// - no tool calls and no non-blank text or reasoning: if the previous
    ///   message carried tool results, finish with an empty assistant message
    ///   via a `BeforeCompletion` checkpoint; otherwise emit an
    ///   `empty_response` error and stop with `TurnStop::ProviderError`;
    /// - no tool calls otherwise: go through a `BeforeCompletion` checkpoint
    ///   that finishes with the accumulated assistant text;
    /// - tool calls: append one assistant message (non-blank prose first,
    ///   then reasoning interleaved with tool-call parts in emission order)
    ///   and start all tool calls.
    fn handle_llm_success(
        &self,
        ctx: DriverContextView<'_>,
        _waiting: WaitingLlmState<lash_core::HostTurnProtocol>,
        llm_response: LlmResponse,
        text_streamed: bool,
    ) -> Vec<DriverAction> {
        let response_parts = normalized_response_parts(&llm_response);
        let mut assistant_text = String::new();
        let mut assistant_text_parts: Vec<(String, Option<ResponseTextMeta>)> = Vec::new();
        let mut tool_calls: Vec<StandardToolCall> = Vec::new();
        // Reasoning items captured with their position in the original
        // response. The `usize` is the index in `tool_calls` that this
        // reasoning item originally preceded, so we can interleave
        // reasoning → tool_call in the provider's original emission order.
        // `Option<ProviderReasoningReplay>` carries roundtrip payload
        // when present (fix 1.3b); when None, the item is display-only
        // (fix 1.3a) — still rendered in the UI but never re-fed.
        let mut reasoning_items: Vec<(usize, Option<ProviderReasoningReplay>, String)> = Vec::new();
        let mut actions = Vec::new();

        for part in response_parts {
            match part {
                LlmOutputPart::Text {
                    text,
                    response_meta,
                } => {
                    if !text.is_empty() {
                        // Record exactly the slice that
                        // `append_assistant_text_part` added, which may differ
                        // from `text` (e.g. if it inserts a separator — TODO
                        // confirm against that helper).
                        let previous_len = assistant_text.len();
                        append_assistant_text_part(&mut assistant_text, &text);
                        assistant_text_parts
                            .push((assistant_text[previous_len..].to_string(), response_meta));
                        // Streaming transports already emitted the deltas.
                        if !text_streamed {
                            actions.push(DriverAction::Emit(SessionEvent::TextDelta {
                                content: assistant_text[previous_len..].to_string(),
                            }));
                        }
                    }
                }
                LlmOutputPart::Reasoning { text, replay } => {
                    let trimmed = text.trim().to_string();
                    // Skip fully-empty reasoning items (no display text and
                    // no roundtrip payload).
                    if trimmed.is_empty() && replay.as_ref().is_none_or(|meta| meta.is_empty()) {
                        continue;
                    }
                    reasoning_items.push((tool_calls.len(), replay, trimmed));
                }
                LlmOutputPart::ToolCall {
                    call_id,
                    tool_name,
                    input_json,
                    replay,
                } => {
                    tool_calls.push(StandardToolCall {
                        call_id,
                        tool_name,
                        input_json,
                        replay,
                    });
                }
            }
        }

        // NOTE(review): duration is always reported as 0 here; presumably the
        // runtime measures LLM latency elsewhere — confirm.
        actions.push(DriverAction::Emit(SessionEvent::LlmResponse {
            protocol_iteration: ctx.protocol_iteration(),
            content: assistant_text.clone(),
            duration_ms: 0,
        }));

        if tool_calls.is_empty() {
            if assistant_text.trim().is_empty() && reasoning_items.is_empty() {
                if last_message_has_tool_result(&ctx) {
                    // A model can intentionally complete a tool-only request
                    // with an empty final answer, e.g. when the user says
                    // "do nothing else" after the tool action.
                    actions.push(DriverAction::StartCheckpoint {
                        checkpoint: CheckpointKind::BeforeCompletion,
                        on_empty: CheckpointResumeAction::Finish(TurnOutcome::Finished(
                            TurnFinish::AssistantMessage {
                                text: String::new(),
                            },
                        )),
                    });
                    return actions;
                }
                actions.push(DriverAction::Emit(make_error_event(
                    "llm_provider",
                    Some("empty_response"),
                    "Model returned no assistant text or tool calls.",
                    None,
                )));
                actions.push(DriverAction::Finish(TurnOutcome::Stopped(
                    TurnStop::ProviderError,
                )));
                return actions;
            }

            // NOTE(review): `parts_out` is only used for the emptiness check
            // below. It is never appended as a conversation event, and the
            // checkpoint finishes with `assistant_text` alone, so the
            // reasoning parts and `response_meta` built here are discarded in
            // this branch. Given the guard above (non-blank text or at least
            // one reasoning item), the `is_empty()` branch also looks
            // unreachable. Confirm whether the runtime records the final
            // assistant message itself, or whether these parts should be
            // appended as in the tool-call branch.
            let asst_id = fresh_message_id();
            let mut parts_out = Vec::new();
            for (_, meta, text) in reasoning_items {
                parts_out.push(reasoning_part(&asst_id, parts_out.len(), text, meta));
            }
            for (content, response_meta) in assistant_text_parts {
                if content.trim().is_empty() {
                    continue;
                }
                parts_out.push(Part {
                    id: format!("{}.p{}", asst_id, parts_out.len()),
                    kind: PartKind::Prose,
                    content,
                    attachment: None,
                    tool_call_id: None,
                    tool_name: None,
                    tool_replay: None,
                    prune_state: PruneState::Intact,
                    reasoning_meta: None,
                    response_meta,
                });
            }
            if parts_out.is_empty() {
                actions.push(DriverAction::Emit(make_error_event(
                    "llm_provider",
                    Some("empty_response"),
                    "Model returned no assistant text or tool calls.",
                    None,
                )));
                actions.push(DriverAction::Finish(TurnOutcome::Stopped(
                    TurnStop::ProviderError,
                )));
                return actions;
            }
            actions.push(DriverAction::StartCheckpoint {
                checkpoint: CheckpointKind::BeforeCompletion,
                on_empty: CheckpointResumeAction::Finish(TurnOutcome::Finished(
                    TurnFinish::AssistantMessage {
                        text: assistant_text.clone(),
                    },
                )),
            });
            return actions;
        }

        // Tool-call branch: all non-blank prose comes first in the assistant
        // message, followed by reasoning/tool-call parts.
        let asst_id = fresh_message_id();
        let mut assistant_parts = Vec::new();
        for (content, response_meta) in assistant_text_parts {
            if content.trim().is_empty() {
                continue;
            }
            assistant_parts.push(Part {
                id: format!("{}.p{}", asst_id, assistant_parts.len()),
                kind: PartKind::Prose,
                content,
                attachment: None,
                tool_call_id: None,
                tool_name: None,
                tool_replay: None,
                prune_state: PruneState::Intact,
                reasoning_meta: None,
                response_meta,
            });
        }

        let mut calls = Vec::new();
        // Interleave reasoning items with tool calls to preserve the
        // original emission order. Some provider replays expect the
        // sequence `reasoning → function_call` from the turn in which both
        // were produced; swapping them can drop the reasoning/tool pairing.
        let mut reasoning_iter = reasoning_items.into_iter().peekable();
        for (tool_index, tool_call) in tool_calls.into_iter().enumerate() {
            // Flush every reasoning item recorded before this tool call.
            while let Some((insert_index, _, _)) = reasoning_iter.peek() {
                if *insert_index > tool_index {
                    break;
                }
                let (_, meta, text) = reasoning_iter.next().expect("peek ok");
                assistant_parts.push(reasoning_part(&asst_id, assistant_parts.len(), text, meta));
            }
            assistant_parts.push(Part {
                id: format!("{}.p{}", asst_id, assistant_parts.len()),
                kind: PartKind::ToolCall,
                content: tool_call.input_json.clone(),
                attachment: None,
                tool_call_id: Some(tool_call.call_id.clone()),
                tool_name: Some(tool_call.tool_name.clone()),
                tool_replay: tool_call.replay.clone(),
                prune_state: PruneState::Intact,
                reasoning_meta: None,
                response_meta: None,
            });

            // Unparseable (or empty) argument text falls back to `{}`; the raw
            // text is still kept on the ToolCall part above.
            // NOTE(review): malformed JSON is silently replaced, so the tool
            // sees empty arguments rather than a parse error.
            let args = serde_json::from_str::<Value>(&tool_call.input_json)
                .unwrap_or_else(|_| serde_json::json!({}));
            calls.push(PendingToolCall {
                call_id: tool_call.call_id,
                tool_name: tool_call.tool_name,
                args,
                replay: tool_call.replay,
            });
        }
        // Reasoning emitted after the last tool call goes at the end.
        for (_, meta, text) in reasoning_iter {
            assistant_parts.push(reasoning_part(&asst_id, assistant_parts.len(), text, meta));
        }

        if !assistant_parts.is_empty() {
            actions.push(DriverAction::AppendEvents(vec![conversation_event(
                Message {
                    id: asst_id,
                    role: MessageRole::Assistant,
                    parts: shared_parts(assistant_parts),
                    origin: None,
                },
            )]));
        }

        actions.push(DriverAction::StartTools { calls });
        actions
    }

    /// Records tool results and decides how the turn continues.
    ///
    /// All model-facing return parts from `completed` are appended as one
    /// user message. Among successful outcomes, the first one whose control
    /// is a recognized `ToolControl` ends the turn with the matching
    /// outcome. Recognized controls are: an agent-frame switch with a
    /// non-blank frame id and task, `Finish`, or `Fail`.
    ///
    /// Otherwise the protocol iteration advances. If the next iteration
    /// would reach `protocol_run_offset() + max_turns`, a turn-limit system
    /// message is appended and the turn stops with `TurnStop::MaxTurns`.
    /// Otherwise an `AfterWork` checkpoint resumes by preparing the next
    /// iteration.
    fn handle_tool_results(
        &self,
        ctx: DriverContextView<'_>,
        completed: Vec<CompletedToolCall>,
    ) -> Vec<DriverAction> {
        let mut actions = Vec::new();
        let mut result_parts = Vec::new();
        let mut terminal_outcome = None;

        for outcome in completed {
            // Only the first successful outcome with a recognized control
            // decides the turn; results of every outcome are still recorded.
            if terminal_outcome.is_none() && outcome.output.is_success() {
                terminal_outcome = match outcome.output.control.as_ref() {
                    Some(lash_core::ToolControl::SwitchAgentFrame {
                        frame_id,
                        task: Some(task),
                        ..
                    }) if !frame_id.trim().is_empty() && !task.trim().is_empty() => {
                        Some(TurnOutcome::AgentFrameSwitch {
                            frame_id: frame_id.clone(),
                            task: task.clone(),
                        })
                    }
                    Some(lash_core::ToolControl::Finish { value }) => {
                        Some(TurnOutcome::Finished(TurnFinish::ToolValue {
                            tool_name: outcome.tool_name.clone(),
                            value: value.to_json_value(),
                        }))
                    }
                    Some(lash_core::ToolControl::Fail { failure }) => {
                        Some(TurnOutcome::Stopped(TurnStop::ToolError {
                            tool_name: outcome.tool_name.clone(),
                            value: failure.to_json_value(),
                        }))
                    }
                    _ => None,
                };
            }

            append_model_return_parts(&mut result_parts, outcome.model_return);
        }

        if !result_parts.is_empty() {
            // Parts are built with empty ids; assign ids under the new message.
            let user_id = fresh_message_id();
            reassign_part_ids(&user_id, &mut result_parts);
            actions.push(DriverAction::AppendEvents(vec![conversation_event(
                Message {
                    id: user_id,
                    role: MessageRole::User,
                    parts: shared_parts(result_parts),
                    origin: None,
                },
            )]));
        }

        if let Some(outcome) = terminal_outcome {
            actions.push(DriverAction::Finish(outcome));
            return actions;
        }

        actions.push(DriverAction::AdvanceProtocolIteration);
        let next_protocol_iteration = ctx.protocol_iteration() + 1;
        if let Some(max_turns) = ctx.max_turns()
            && next_protocol_iteration >= ctx.protocol_run_offset() + max_turns
        {
            let message_id = fresh_message_id();
            actions.push(DriverAction::AppendEvents(vec![conversation_event(
                turn_limit_exhausted_message(message_id, max_turns),
            )]));
            actions.push(DriverAction::Finish(TurnOutcome::Stopped(
                TurnStop::MaxTurns,
            )));
            return actions;
        }

        actions.push(DriverAction::StartCheckpoint {
            checkpoint: CheckpointKind::AfterWork,
            on_empty: CheckpointResumeAction::PrepareIteration,
        });
        actions
    }

    /// This driver never issues exec actions itself, so any exec result is
    /// ignored and produces no actions.
    fn handle_exec_result(
        &self,
        _ctx: DriverContextView<'_>,
        _waiting: WaitingExecState<lash_core::HostTurnProtocol>,
        _result: Result<lash_core::ExecResponse, String>,
    ) -> Vec<DriverAction> {
        Vec::new()
    }
}
677
678fn append_model_return_parts(parts: &mut Vec<Part>, model_return: lash_core::ModelToolReturn) {
679    for part in model_return.parts {
680        match part {
681            lash_core::ModelToolReturnPart::Text { text } => {
682                if text.is_empty() {
683                    continue;
684                }
685                parts.push(Part {
686                    id: String::new(),
687                    kind: PartKind::ToolResult,
688                    content: text,
689                    attachment: None,
690                    tool_call_id: Some(model_return.call_id.clone()),
691                    tool_name: Some(model_return.tool_name.clone()),
692                    tool_replay: None,
693                    prune_state: PruneState::Intact,
694                    reasoning_meta: None,
695                    response_meta: None,
696                });
697            }
698            lash_core::ModelToolReturnPart::Attachment(reference) => {
699                parts.push(Part {
700                    id: String::new(),
701                    kind: PartKind::Image,
702                    content: String::new(),
703                    attachment: Some(PartAttachment { reference }),
704                    tool_call_id: Some(model_return.call_id.clone()),
705                    tool_name: Some(model_return.tool_name.clone()),
706                    tool_replay: None,
707                    prune_state: PruneState::Intact,
708                    reasoning_meta: None,
709                    response_meta: None,
710                });
711            }
712        }
713    }
714}
715
716fn conversation_event(message: Message) -> SessionEventRecord {
717    SessionEventRecord::Conversation(ConversationRecord::from_message(message))
718}
719
#[cfg(test)]
mod tests {
    use super::*;
    use lash_core::{
        AttachmentId, AttachmentMeta, ImageMediaType, MediaType, ModelToolReturn, ToolCallOutput,
        ToolValue,
    };

    /// Builds a tiny PNG attachment reference for round-trip tests.
    fn image_ref(id: &str) -> lash_core::AttachmentRef {
        AttachmentMeta::new(
            AttachmentId::new(id),
            MediaType::Image(ImageMediaType::Png),
            4,
            Some(1),
            Some(1),
            Some("tiny".to_string()),
        )
        .as_ref()
    }

    #[test]
    fn tool_attachment_round_trips_to_part_kind_image() {
        let attachment = image_ref("att-1");
        let output = ToolCallOutput::success(ToolValue::Attachment(attachment.clone()));
        let model_return =
            ModelToolReturn::from_output("call-9".to_string(), "screenshot".to_string(), &output);

        let mut parts: Vec<Part> = Vec::new();
        append_model_return_parts(&mut parts, model_return);

        assert_eq!(parts.len(), 1, "single attachment yields single part");
        let part = &parts[0];
        assert!(matches!(part.kind, PartKind::Image));
        assert_eq!(part.content, "");
        assert_eq!(part.tool_call_id.as_deref(), Some("call-9"));
        assert_eq!(part.tool_name.as_deref(), Some("screenshot"));
        let part_attachment = part.attachment.as_ref().expect("attachment present");
        assert_eq!(part_attachment.reference.id, attachment.id);
    }

    #[test]
    fn tool_text_and_attachment_round_trip_preserves_order() {
        let attachment = image_ref("att-2");
        let output = ToolCallOutput::success(ToolValue::Array(vec![
            ToolValue::String("before".into()),
            ToolValue::Attachment(attachment.clone()),
            ToolValue::String("after".into()),
        ]));
        let model_return =
            ModelToolReturn::from_output("call-10".to_string(), "snap".to_string(), &output);

        let mut parts: Vec<Part> = Vec::new();
        append_model_return_parts(&mut parts, model_return);

        // The array projection emits compact JSON text fragments around the
        // attachment, preserving in-order position.
        assert_eq!(parts.len(), 3, "text + image + text yields three parts");
        assert!(matches!(parts[0].kind, PartKind::ToolResult));
        assert!(parts[0].content.starts_with("[\"before\""));
        assert!(matches!(parts[1].kind, PartKind::Image));
        assert_eq!(
            parts[1]
                .attachment
                .as_ref()
                .expect("attachment")
                .reference
                .id,
            attachment.id
        );
        assert!(matches!(parts[2].kind, PartKind::ToolResult));
        assert!(parts[2].content.ends_with("\"after\"]"));
    }

    #[test]
    fn parse_batch_specs_trims_tool_and_defaults_parameters() {
        let args = serde_json::json!({
            "tool_calls": [
                { "tool": "  read_file ", "parameters": { "path": "a.rs" } },
                { "tool": "grep" }
            ]
        });
        let Ok(specs) = parse_batch_specs(&args) else {
            panic!("valid batch arguments must parse");
        };

        assert_eq!(specs.len(), 2);
        assert_eq!(specs[0].index, 0);
        assert_eq!(specs[0].tool, "read_file");
        assert_eq!(specs[0].parameters, serde_json::json!({ "path": "a.rs" }));
        assert_eq!(specs[1].index, 1);
        assert_eq!(specs[1].tool, "grep");
        assert_eq!(specs[1].parameters, serde_json::json!({}));
    }

    #[test]
    fn parse_batch_specs_rejects_invalid_arguments() {
        assert!(parse_batch_specs(&serde_json::json!({})).is_err());
        assert!(parse_batch_specs(&serde_json::json!({ "tool_calls": "x" })).is_err());
        assert!(parse_batch_specs(&serde_json::json!({ "tool_calls": [] })).is_err());
        assert!(parse_batch_specs(&serde_json::json!({ "tool_calls": ["read_file"] })).is_err());
        assert!(
            parse_batch_specs(&serde_json::json!({ "tool_calls": [{ "tool": "   " }] })).is_err()
        );
        assert!(
            parse_batch_specs(&serde_json::json!({ "tool_calls": [{ "parameters": {} }] }))
                .is_err()
        );
    }

    #[test]
    fn turn_limit_message_is_single_system_error_part() {
        let message = turn_limit_exhausted_message("msg-1".to_string(), 7);

        assert_eq!(message.id, "msg-1");
        assert!(matches!(message.role, MessageRole::System));
        let parts: Vec<_> = message.parts.iter().collect();
        assert_eq!(parts.len(), 1);
        assert_eq!(parts[0].id, "msg-1.p0");
        assert!(matches!(parts[0].kind, PartKind::Error));
        assert_eq!(
            parts[0].content,
            "Turn limit reached (7) before a final assistant response."
        );
    }

    #[test]
    fn only_batch_contract_is_resolved() {
        assert!(StandardProtocolTools.resolve_contract("read_file").is_none());
        assert!(StandardProtocolTools.resolve_contract("").is_none());
    }

    #[test]
    fn factory_and_plugin_share_protocol_id() {
        assert_eq!(StandardProtocolPluginFactory::new().id(), "standard_protocol");
        assert_eq!(StandardProtocolPlugin.id(), "standard_protocol");
    }
}