use std::sync::Arc;
use async_trait::async_trait;
use lash_core::llm::types::{ProviderReasoningReplay, ProviderReplayMeta, ResponseTextMeta};
use lash_core::plugin::{
PluginError, PluginFactory, PluginRegistrar, PluginSessionContext, ProtocolDriverPlugin,
ProtocolSessionContext, ProtocolSessionPlugin, SessionPlugin,
};
use lash_core::sansio::{
CheckpointResumeAction, CompletedToolCall, PendingToolCall, ProtocolDriverHandle,
WaitingExecState, WaitingLlmState,
};
use lash_core::session_model::message::PartAttachment;
use lash_core::session_model::{
ConversationRecord, Message, MessageRole, Part, PartKind, PruneState, SessionEvent,
SessionEventRecord, fresh_message_id, make_error_event, reassign_part_ids, shared_parts,
};
mod batch;
use batch::batch_tool_definition;
use lash_core::{
CheckpointKind, DriverAction, DriverContextView, LlmOutputPart, LlmResponse,
ProtocolBuildInput, SessionError, ToolCall, ToolContract, ToolInvocation, ToolManifest,
ToolProvider, ToolResult, TurnDriverConfig, TurnDriverPreamble, TurnFinish, TurnOutcome,
TurnStop, append_assistant_text_part, normalized_response_parts, reasoning_part,
};
use serde_json::Value;
const STANDARD_EXECUTION_SECTION: &str = r#"Use direct tool calls.
- Use `batch` (up to 25 calls) for two or more independent tool calls. Serialize calls when later arguments depend on earlier results.
- For direct conversational requests that need no tools, respond in prose only.
Example — two independent reads in one `batch` call:
```json
{
"tool_calls": [
{ "tool": "read_file", "parameters": { "path": "src/main.rs" } },
{ "tool": "grep", "parameters": { "query": "ToolProvider", "path": "crates/lash/src/" } }
]
}
```"#;
const BATCH_MAX_TOOL_CALLS: usize = 25;
#[derive(Default)]
pub struct StandardProtocolPluginFactory;
impl StandardProtocolPluginFactory {
pub fn new() -> Self {
Self
}
}
impl PluginFactory for StandardProtocolPluginFactory {
fn id(&self) -> &'static str {
"standard_protocol"
}
fn build(&self, _ctx: &PluginSessionContext) -> Result<Arc<dyn SessionPlugin>, PluginError> {
Ok(Arc::new(StandardProtocolPlugin))
}
}
struct StandardProtocolPlugin;
impl SessionPlugin for StandardProtocolPlugin {
fn id(&self) -> &'static str {
"standard_protocol"
}
fn register(&self, reg: &mut PluginRegistrar) -> Result<(), PluginError> {
reg.protocol().session(Arc::new(StandardProtocolSession))?;
reg.protocol()
.protocol_driver(Arc::new(StandardProtocolDriver))?;
reg.tools().provider(Arc::new(StandardProtocolTools))?;
Ok(())
}
}
struct StandardProtocolSession;
#[async_trait]
impl ProtocolSessionPlugin for StandardProtocolSession {
async fn initialize_session(
&self,
_ctx: ProtocolSessionContext<'_>,
) -> Result<(), SessionError> {
Ok(())
}
}
struct StandardProtocolDriver;
impl ProtocolDriverPlugin for StandardProtocolDriver {
fn build_preamble(&self, input: ProtocolBuildInput) -> TurnDriverPreamble {
let tool_names = input.tool_catalog.tool_names();
let tool_names_fingerprint = input.tool_catalog.tool_names_fingerprint();
TurnDriverPreamble {
config: TurnDriverConfig::chat(
Arc::new(StandardDriver),
true,
Arc::new(turn_limit_exhausted_message),
),
tool_specs: input.tool_catalog.model_tool_specs(),
tool_names,
tool_names_fingerprint,
omitted_tool_count: 0,
execution_prompt: Arc::from(STANDARD_EXECUTION_SECTION),
prompt_contributions: input.extra_prompt_contributions,
}
}
}
fn turn_limit_exhausted_message(message_id: String, max_turns: usize) -> Message {
Message {
id: message_id.clone(),
role: MessageRole::System,
parts: shared_parts(vec![Part {
id: format!("{message_id}.p0"),
kind: PartKind::Error,
content: format!("Turn limit reached ({max_turns}) before a final assistant response."),
attachment: None,
tool_call_id: None,
tool_name: None,
tool_replay: None,
prune_state: PruneState::Intact,
reasoning_meta: None,
response_meta: None,
}]),
origin: None,
}
}
struct StandardProtocolTools;
#[async_trait]
impl ToolProvider for StandardProtocolTools {
fn tool_manifests(&self) -> Vec<ToolManifest> {
vec![batch_tool_definition().manifest()]
}
fn resolve_contract(&self, name: &str) -> Option<Arc<ToolContract>> {
(name == "batch").then(|| Arc::new(batch_tool_definition().contract()))
}
async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
match call.name {
"batch" => execute_batch_tool_call(call).await,
_ => ToolResult::err_fmt(format_args!("Unknown tool: {}", call.name)),
}
}
}
#[derive(Debug)]
struct BatchCallSpec {
index: usize,
tool: String,
parameters: Value,
}
async fn execute_batch_tool_call(call: ToolCall<'_>) -> ToolResult {
let args = call.args;
let specs = match parse_batch_specs(args) {
Ok(specs) => specs,
Err(err) => return err,
};
let mut immediate_outcomes = Vec::new();
let mut parallel_specs = Vec::new();
let dispatch = call.context.dispatch();
for spec in specs.into_iter().take(BATCH_MAX_TOOL_CALLS) {
if spec.tool == "batch" {
immediate_outcomes.push(serde_json::json!({
"index": spec.index,
"tool": spec.tool,
"success": false,
"duration_ms": 0,
"error": "Tool 'batch' is not allowed inside batch",
}));
continue;
}
let Some(manifest) = dispatch.callable_tool_manifest(&spec.tool) else {
let error = format!("Tool '{}' is unavailable in this session", spec.tool);
immediate_outcomes.push(serde_json::json!({
"index": spec.index,
"tool": spec.tool,
"success": false,
"duration_ms": 0,
"error": error,
}));
continue;
};
parallel_specs.push((
spec.index,
ToolInvocation::new(
format!(
"{}:{:02}",
call.context.tool_call_id().unwrap_or("batch"),
spec.index
),
manifest.id,
spec.parameters,
),
));
}
let mut parallel_outcomes = dispatch
.batch(
parallel_specs
.iter()
.map(|(_, invocation)| invocation.clone())
.collect(),
)
.await;
for ((index, invocation), outcome) in
parallel_specs.into_iter().zip(parallel_outcomes.drain(..))
{
let tool_label = invocation.label();
let tool_record = outcome.record.unwrap_or(lash_core::ToolCallRecord {
call_id: Some(invocation.id),
tool: tool_label,
args: invocation.args,
output: outcome.output,
duration_ms: 0,
});
let mut result_record = serde_json::Map::new();
result_record.insert("index".to_string(), serde_json::json!(index));
result_record.insert("tool".to_string(), serde_json::json!(tool_record.tool));
result_record.insert(
"success".to_string(),
serde_json::json!(tool_record.output.is_success()),
);
result_record.insert(
"duration_ms".to_string(),
serde_json::json!(tool_record.duration_ms),
);
result_record.insert(
if tool_record.output.is_success() {
"result".to_string()
} else {
"error".to_string()
},
tool_record.output.value_for_projection(),
);
immediate_outcomes.push(Value::Object(result_record));
}
for overflow_index in BATCH_MAX_TOOL_CALLS
..args
.get("tool_calls")
.and_then(|value| value.as_array())
.map(|value| value.len())
.unwrap_or_default()
{
immediate_outcomes.push(serde_json::json!({
"index": overflow_index,
"tool": args
.get("tool_calls")
.and_then(|value| value.as_array())
.and_then(|items| items.get(overflow_index))
.and_then(|item| item.get("tool"))
.and_then(|value| value.as_str())
.unwrap_or("unknown"),
"success": false,
"duration_ms": 0,
"error": "Maximum of 25 tool calls allowed in batch",
}));
}
immediate_outcomes.sort_by_key(|outcome| {
outcome
.get("index")
.and_then(|value| value.as_u64())
.unwrap_or(u64::MAX)
});
ToolResult::ok(serde_json::json!({
"results": immediate_outcomes,
}))
}
#[allow(clippy::result_large_err)]
fn parse_batch_specs(args: &Value) -> Result<Vec<BatchCallSpec>, ToolResult> {
let Some(raw_calls) = args.get("tool_calls").and_then(|value| value.as_array()) else {
return Err(ToolResult::err_fmt(
"Missing required parameter: tool_calls",
));
};
if raw_calls.is_empty() {
return Err(ToolResult::err_fmt(
"Invalid tool_calls: expected at least one call",
));
}
let mut specs = Vec::with_capacity(raw_calls.len());
for (index, item) in raw_calls.iter().enumerate() {
let Some(object) = item.as_object() else {
return Err(ToolResult::err_fmt(format_args!(
"Invalid tool_calls[{index}]: expected object with tool and parameters"
)));
};
let Some(tool) = object
.get("tool")
.and_then(|value| value.as_str())
.map(str::trim)
.filter(|tool| !tool.is_empty())
else {
return Err(ToolResult::err_fmt(format_args!(
"Invalid tool_calls[{index}].tool: expected non-empty string"
)));
};
let parameters = object
.get("parameters")
.cloned()
.unwrap_or_else(|| serde_json::json!({}));
specs.push(BatchCallSpec {
index,
tool: tool.to_string(),
parameters,
});
}
Ok(specs)
}
pub struct StandardDriver;
struct StandardToolCall {
call_id: String,
tool_name: String,
input_json: String,
replay: Option<ProviderReplayMeta>,
}
fn last_message_has_tool_result(ctx: &DriverContextView<'_>) -> bool {
ctx.messages().last().is_some_and(|message| {
matches!(message.role, MessageRole::User)
&& message
.parts
.iter()
.any(|part| matches!(part.kind, PartKind::ToolResult))
})
}
impl ProtocolDriverHandle<lash_core::HostTurnProtocol> for StandardDriver {
fn prepare_protocol_iteration(&self, ctx: DriverContextView<'_>) -> Vec<DriverAction> {
vec![DriverAction::StartLlm {
request: ctx.project_llm_request(true),
driver_state: None,
}]
}
fn handle_llm_success(
&self,
ctx: DriverContextView<'_>,
_waiting: WaitingLlmState<lash_core::HostTurnProtocol>,
llm_response: LlmResponse,
text_streamed: bool,
) -> Vec<DriverAction> {
let response_parts = normalized_response_parts(&llm_response);
let mut assistant_text = String::new();
let mut assistant_text_parts: Vec<(String, Option<ResponseTextMeta>)> = Vec::new();
let mut tool_calls: Vec<StandardToolCall> = Vec::new();
let mut reasoning_items: Vec<(usize, Option<ProviderReasoningReplay>, String)> = Vec::new();
let mut actions = Vec::new();
for part in response_parts {
match part {
LlmOutputPart::Text {
text,
response_meta,
} => {
if !text.is_empty() {
let previous_len = assistant_text.len();
append_assistant_text_part(&mut assistant_text, &text);
assistant_text_parts
.push((assistant_text[previous_len..].to_string(), response_meta));
if !text_streamed {
actions.push(DriverAction::Emit(SessionEvent::TextDelta {
content: assistant_text[previous_len..].to_string(),
}));
}
}
}
LlmOutputPart::Reasoning { text, replay } => {
let trimmed = text.trim().to_string();
if trimmed.is_empty() && replay.as_ref().is_none_or(|meta| meta.is_empty()) {
continue;
}
reasoning_items.push((tool_calls.len(), replay, trimmed));
}
LlmOutputPart::ToolCall {
call_id,
tool_name,
input_json,
replay,
} => {
tool_calls.push(StandardToolCall {
call_id,
tool_name,
input_json,
replay,
});
}
}
}
actions.push(DriverAction::Emit(SessionEvent::LlmResponse {
protocol_iteration: ctx.protocol_iteration(),
content: assistant_text.clone(),
duration_ms: 0,
}));
if tool_calls.is_empty() {
if assistant_text.trim().is_empty() && reasoning_items.is_empty() {
if last_message_has_tool_result(&ctx) {
actions.push(DriverAction::StartCheckpoint {
checkpoint: CheckpointKind::BeforeCompletion,
on_empty: CheckpointResumeAction::Finish(TurnOutcome::Finished(
TurnFinish::AssistantMessage {
text: String::new(),
},
)),
});
return actions;
}
actions.push(DriverAction::Emit(make_error_event(
"llm_provider",
Some("empty_response"),
"Model returned no assistant text or tool calls.",
None,
)));
actions.push(DriverAction::Finish(TurnOutcome::Stopped(
TurnStop::ProviderError,
)));
return actions;
}
let asst_id = fresh_message_id();
let mut parts_out = Vec::new();
for (_, meta, text) in reasoning_items {
parts_out.push(reasoning_part(&asst_id, parts_out.len(), text, meta));
}
for (content, response_meta) in assistant_text_parts {
if content.trim().is_empty() {
continue;
}
parts_out.push(Part {
id: format!("{}.p{}", asst_id, parts_out.len()),
kind: PartKind::Prose,
content,
attachment: None,
tool_call_id: None,
tool_name: None,
tool_replay: None,
prune_state: PruneState::Intact,
reasoning_meta: None,
response_meta,
});
}
if parts_out.is_empty() {
actions.push(DriverAction::Emit(make_error_event(
"llm_provider",
Some("empty_response"),
"Model returned no assistant text or tool calls.",
None,
)));
actions.push(DriverAction::Finish(TurnOutcome::Stopped(
TurnStop::ProviderError,
)));
return actions;
}
actions.push(DriverAction::StartCheckpoint {
checkpoint: CheckpointKind::BeforeCompletion,
on_empty: CheckpointResumeAction::Finish(TurnOutcome::Finished(
TurnFinish::AssistantMessage {
text: assistant_text.clone(),
},
)),
});
return actions;
}
let asst_id = fresh_message_id();
let mut assistant_parts = Vec::new();
for (content, response_meta) in assistant_text_parts {
if content.trim().is_empty() {
continue;
}
assistant_parts.push(Part {
id: format!("{}.p{}", asst_id, assistant_parts.len()),
kind: PartKind::Prose,
content,
attachment: None,
tool_call_id: None,
tool_name: None,
tool_replay: None,
prune_state: PruneState::Intact,
reasoning_meta: None,
response_meta,
});
}
let mut calls = Vec::new();
let mut reasoning_iter = reasoning_items.into_iter().peekable();
for (tool_index, tool_call) in tool_calls.into_iter().enumerate() {
while let Some((insert_index, _, _)) = reasoning_iter.peek() {
if *insert_index > tool_index {
break;
}
let (_, meta, text) = reasoning_iter.next().expect("peek ok");
assistant_parts.push(reasoning_part(&asst_id, assistant_parts.len(), text, meta));
}
assistant_parts.push(Part {
id: format!("{}.p{}", asst_id, assistant_parts.len()),
kind: PartKind::ToolCall,
content: tool_call.input_json.clone(),
attachment: None,
tool_call_id: Some(tool_call.call_id.clone()),
tool_name: Some(tool_call.tool_name.clone()),
tool_replay: tool_call.replay.clone(),
prune_state: PruneState::Intact,
reasoning_meta: None,
response_meta: None,
});
let args = serde_json::from_str::<Value>(&tool_call.input_json)
.unwrap_or_else(|_| serde_json::json!({}));
calls.push(PendingToolCall {
call_id: tool_call.call_id,
tool_name: tool_call.tool_name,
args,
replay: tool_call.replay,
});
}
for (_, meta, text) in reasoning_iter {
assistant_parts.push(reasoning_part(&asst_id, assistant_parts.len(), text, meta));
}
if !assistant_parts.is_empty() {
actions.push(DriverAction::AppendEvents(vec![conversation_event(
Message {
id: asst_id,
role: MessageRole::Assistant,
parts: shared_parts(assistant_parts),
origin: None,
},
)]));
}
actions.push(DriverAction::StartTools { calls });
actions
}
fn handle_tool_results(
&self,
ctx: DriverContextView<'_>,
completed: Vec<CompletedToolCall>,
) -> Vec<DriverAction> {
let mut actions = Vec::new();
let mut result_parts = Vec::new();
let mut terminal_outcome = None;
for outcome in completed {
if terminal_outcome.is_none() && outcome.output.is_success() {
terminal_outcome = match outcome.output.control.as_ref() {
Some(lash_core::ToolControl::SwitchAgentFrame {
frame_id,
task: Some(task),
..
}) if !frame_id.trim().is_empty() && !task.trim().is_empty() => {
Some(TurnOutcome::AgentFrameSwitch {
frame_id: frame_id.clone(),
task: task.clone(),
})
}
Some(lash_core::ToolControl::Finish { value }) => {
Some(TurnOutcome::Finished(TurnFinish::ToolValue {
tool_name: outcome.tool_name.clone(),
value: value.to_json_value(),
}))
}
Some(lash_core::ToolControl::Fail { failure }) => {
Some(TurnOutcome::Stopped(TurnStop::ToolError {
tool_name: outcome.tool_name.clone(),
value: failure.to_json_value(),
}))
}
_ => None,
};
}
append_model_return_parts(&mut result_parts, outcome.model_return);
}
if !result_parts.is_empty() {
let user_id = fresh_message_id();
reassign_part_ids(&user_id, &mut result_parts);
actions.push(DriverAction::AppendEvents(vec![conversation_event(
Message {
id: user_id,
role: MessageRole::User,
parts: shared_parts(result_parts),
origin: None,
},
)]));
}
if let Some(outcome) = terminal_outcome {
actions.push(DriverAction::Finish(outcome));
return actions;
}
actions.push(DriverAction::AdvanceProtocolIteration);
let next_protocol_iteration = ctx.protocol_iteration() + 1;
if let Some(max_turns) = ctx.max_turns()
&& next_protocol_iteration >= ctx.protocol_run_offset() + max_turns
{
let message_id = fresh_message_id();
actions.push(DriverAction::AppendEvents(vec![conversation_event(
turn_limit_exhausted_message(message_id, max_turns),
)]));
actions.push(DriverAction::Finish(TurnOutcome::Stopped(
TurnStop::MaxTurns,
)));
return actions;
}
actions.push(DriverAction::StartCheckpoint {
checkpoint: CheckpointKind::AfterWork,
on_empty: CheckpointResumeAction::PrepareIteration,
});
actions
}
fn handle_exec_result(
&self,
_ctx: DriverContextView<'_>,
_waiting: WaitingExecState<lash_core::HostTurnProtocol>,
_result: Result<lash_core::ExecResponse, String>,
) -> Vec<DriverAction> {
Vec::new()
}
}
fn append_model_return_parts(parts: &mut Vec<Part>, model_return: lash_core::ModelToolReturn) {
for part in model_return.parts {
match part {
lash_core::ModelToolReturnPart::Text { text } => {
if text.is_empty() {
continue;
}
parts.push(Part {
id: String::new(),
kind: PartKind::ToolResult,
content: text,
attachment: None,
tool_call_id: Some(model_return.call_id.clone()),
tool_name: Some(model_return.tool_name.clone()),
tool_replay: None,
prune_state: PruneState::Intact,
reasoning_meta: None,
response_meta: None,
});
}
lash_core::ModelToolReturnPart::Attachment(reference) => {
parts.push(Part {
id: String::new(),
kind: PartKind::Image,
content: String::new(),
attachment: Some(PartAttachment { reference }),
tool_call_id: Some(model_return.call_id.clone()),
tool_name: Some(model_return.tool_name.clone()),
tool_replay: None,
prune_state: PruneState::Intact,
reasoning_meta: None,
response_meta: None,
});
}
}
}
}
fn conversation_event(message: Message) -> SessionEventRecord {
SessionEventRecord::Conversation(ConversationRecord::from_message(message))
}
#[cfg(test)]
mod tests {
use super::*;
use lash_core::{
AttachmentId, AttachmentMeta, ImageMediaType, MediaType, ModelToolReturn, ToolCallOutput,
ToolValue,
};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use tokio::sync::Barrier;
use tokio::time::{Duration, timeout};
fn image_ref(id: &str) -> lash_core::AttachmentRef {
AttachmentMeta::new(
AttachmentId::new(id),
MediaType::Image(ImageMediaType::Png),
4,
Some(1),
Some(1),
Some("tiny".to_string()),
)
.as_ref()
}
#[derive(Clone, Debug)]
struct BatchRuntimeProvider {
calls: Arc<AtomicUsize>,
saw_batch_result: Arc<AtomicBool>,
}
#[async_trait::async_trait]
impl lash_core::Provider for BatchRuntimeProvider {
fn kind(&self) -> &'static str {
"stub"
}
fn options(&self) -> lash_core::ProviderOptions {
lash_core::ProviderOptions::default()
}
fn set_options(&mut self, _options: lash_core::ProviderOptions) {}
fn serialize_config(&self) -> serde_json::Value {
serde_json::json!({})
}
async fn complete(
&mut self,
request: lash_core::LlmRequest,
) -> Result<lash_core::LlmResponse, lash_core::LlmTransportError> {
let call_index = self.calls.fetch_add(1, Ordering::SeqCst);
if call_index == 0 {
return Ok(lash_core::LlmResponse {
parts: vec![lash_core::LlmOutputPart::ToolCall {
call_id: "batch-call".to_string(),
tool_name: "batch".to_string(),
input_json: serde_json::json!({
"tool_calls": [
{"tool": "alpha", "parameters": {}},
{"tool": "beta", "parameters": {"value": "fail"}}
]
})
.to_string(),
replay: None,
}],
..lash_core::LlmResponse::default()
});
}
let projected_messages = format!("{:?}", request.messages);
if projected_messages.contains("alpha") && projected_messages.contains("beta failed") {
self.saw_batch_result.store(true, Ordering::SeqCst);
}
Ok(lash_core::LlmResponse {
full_text: "done".to_string(),
parts: vec![lash_core::LlmOutputPart::Text {
text: "done".to_string(),
response_meta: None,
}],
..lash_core::LlmResponse::default()
})
}
fn clone_boxed(&self) -> Box<dyn lash_core::Provider> {
Box::new(self.clone())
}
}
#[derive(Debug)]
struct BatchRuntimeTools {
barrier: Arc<Barrier>,
started: Arc<AtomicUsize>,
}
fn runtime_test_tool(name: &str) -> lash_core::ToolDefinition {
lash_core::ToolDefinition::raw(
format!("tool:{name}"),
name,
"",
serde_json::json!({
"type": "object",
"properties": {
"value": { "type": "string" }
},
"additionalProperties": true
}),
serde_json::json!({ "type": "string" }),
)
.with_scheduling(lash_core::ToolScheduling::Parallel)
}
#[async_trait::async_trait]
impl ToolProvider for BatchRuntimeTools {
fn tool_manifests(&self) -> Vec<ToolManifest> {
vec![
runtime_test_tool("alpha").manifest(),
runtime_test_tool("beta").manifest(),
]
}
fn resolve_contract(&self, name: &str) -> Option<Arc<ToolContract>> {
match name {
"alpha" | "beta" => Some(Arc::new(runtime_test_tool(name).contract())),
_ => None,
}
}
async fn execute(&self, call: ToolCall<'_>) -> ToolResult {
self.started.fetch_add(1, Ordering::SeqCst);
if timeout(Duration::from_millis(100), self.barrier.wait())
.await
.is_err()
{
return ToolResult::err_fmt("batch child tools did not run concurrently");
}
if call.name == "beta"
&& call.args.get("value").and_then(|value| value.as_str()) == Some("fail")
{
return ToolResult::err_fmt("beta failed");
}
ToolResult::ok(serde_json::json!(call.name))
}
}
#[derive(Clone, Default)]
struct CountingEffectController {
kinds: Arc<std::sync::Mutex<Vec<lash_core::RuntimeEffectKind>>>,
}
impl CountingEffectController {
fn count(&self, kind: lash_core::RuntimeEffectKind) -> usize {
self.kinds
.lock()
.expect("effect kinds")
.iter()
.filter(|candidate| **candidate == kind)
.count()
}
}
#[derive(Default)]
struct DurableMemoryAttachmentStore {
inner: lash_core::InMemoryAttachmentStore,
}
#[async_trait::async_trait]
impl lash_core::AttachmentStore for DurableMemoryAttachmentStore {
fn persistence(&self) -> lash_core::AttachmentStorePersistence {
lash_core::AttachmentStorePersistence::Durable
}
async fn put(
&self,
bytes: Vec<u8>,
meta: lash_core::AttachmentCreateMeta,
) -> Result<lash_core::AttachmentRef, lash_core::AttachmentStoreError> {
self.inner.put(bytes, meta).await
}
async fn get(
&self,
id: &lash_core::AttachmentId,
) -> Result<lash_core::StoredAttachment, lash_core::AttachmentStoreError> {
self.inner.get(id).await
}
}
#[derive(Default)]
struct DurableMemoryProcessEnvStore {
inner: lash_core::InMemoryProcessExecutionEnvStore,
}
#[async_trait::async_trait]
impl lash_core::ProcessExecutionEnvStore for DurableMemoryProcessEnvStore {
fn durability_tier(&self) -> lash_core::DurabilityTier {
lash_core::DurabilityTier::Durable
}
async fn put_process_execution_env(
&self,
env_ref: &lash_core::ProcessExecutionEnvRef,
bytes: &[u8],
) -> Result<(), lash_core::PluginError> {
self.inner.put_process_execution_env(env_ref, bytes).await
}
async fn get_process_execution_env(
&self,
env_ref: &lash_core::ProcessExecutionEnvRef,
) -> Result<Option<Vec<u8>>, lash_core::PluginError> {
self.inner.get_process_execution_env(env_ref).await
}
}
#[async_trait::async_trait]
impl lash_core::RuntimeEffectController for CountingEffectController {
fn durability_tier(&self) -> lash_core::DurabilityTier {
lash_core::DurabilityTier::Durable
}
async fn execute_effect(
&self,
envelope: lash_core::RuntimeEffectEnvelope,
local_executor: lash_core::RuntimeEffectLocalExecutor<'_>,
) -> Result<lash_core::RuntimeEffectOutcome, lash_core::RuntimeEffectControllerError>
{
self.kinds
.lock()
.expect("effect kinds")
.push(envelope.command.kind());
local_executor.execute(envelope).await
}
}
#[tokio::test]
async fn standard_batch_tool_uses_core_tool_batch_under_durable_boundary() {
let provider_calls = Arc::new(AtomicUsize::new(0));
let saw_batch_result = Arc::new(AtomicBool::new(false));
let provider = BatchRuntimeProvider {
calls: Arc::clone(&provider_calls),
saw_batch_result: Arc::clone(&saw_batch_result),
};
let provider_handle = lash_core::ProviderHandle::new(lash_core::ProviderComponents::new(
Box::new(provider),
Arc::new(lash_core::StaticModelPolicy::new()),
));
let mut host = lash_core::RuntimeHostConfig::in_memory();
host.providers.provider_resolver =
Arc::new(lash_core::SingleProviderResolver::new(provider_handle));
host.durability.attachment_store = Arc::new(DurableMemoryAttachmentStore::default());
host.durability.process_env_store = Arc::new(DurableMemoryProcessEnvStore::default());
let started = Arc::new(AtomicUsize::new(0));
let factories: Vec<Arc<dyn lash_core::PluginFactory>> = vec![
Arc::new(StandardProtocolPluginFactory::new()),
Arc::new(lash_core::plugin::StaticPluginFactory::new(
"standard-batch-test-tools",
lash_core::PluginSpec::new().with_tool_provider(Arc::new(BatchRuntimeTools {
barrier: Arc::new(Barrier::new(2)),
started: Arc::clone(&started),
})),
)),
];
let policy = lash_core::SessionPolicy {
provider_id: "stub".to_string(),
model: lash_core::ModelSpec::from_token_limits("mock-model", None, 200_000, None)
.expect("valid model"),
..lash_core::SessionPolicy::default()
};
let controller = CountingEffectController::default();
let scoped_controller = lash_core::ScopedEffectController::shared(
Arc::new(controller.clone()),
lash_core::ExecutionScope::turn("standard-batch-session", "turn-1"),
)
.expect("scoped controller");
let mut runtime = lash_core::LashRuntime::builder()
.with_session_id("standard-batch-session")
.with_policy(policy)
.with_runtime_host(host)
.with_plugin_factories(factories)
.build()
.await
.expect("runtime");
let turn = runtime
.stream_turn(
lash_core::TurnInput::text("run the batch"),
lash_core::TurnOptions::new(
tokio_util::sync::CancellationToken::new(),
scoped_controller,
),
)
.await
.expect("turn");
assert!(matches!(turn.outcome, lash_core::TurnOutcome::Finished(_)));
assert_eq!(provider_calls.load(Ordering::SeqCst), 2);
assert_eq!(started.load(Ordering::SeqCst), 2);
assert!(saw_batch_result.load(Ordering::SeqCst));
assert_eq!(controller.count(lash_core::RuntimeEffectKind::ToolBatch), 1);
}
#[test]
fn tool_attachment_round_trips_to_part_kind_image() {
let attachment = image_ref("att-1");
let output = ToolCallOutput::success(ToolValue::Attachment(attachment.clone()));
let model_return =
ModelToolReturn::from_output("call-9".to_string(), "screenshot".to_string(), &output);
let mut parts: Vec<Part> = Vec::new();
append_model_return_parts(&mut parts, model_return);
assert_eq!(parts.len(), 1, "single attachment yields single part");
let part = &parts[0];
assert!(matches!(part.kind, PartKind::Image));
assert_eq!(part.content, "");
assert_eq!(part.tool_call_id.as_deref(), Some("call-9"));
assert_eq!(part.tool_name.as_deref(), Some("screenshot"));
let part_attachment = part.attachment.as_ref().expect("attachment present");
assert_eq!(part_attachment.reference.id, attachment.id);
}
#[test]
fn tool_text_and_attachment_round_trip_preserves_order() {
let attachment = image_ref("att-2");
let output = ToolCallOutput::success(ToolValue::Array(vec![
ToolValue::String("before".into()),
ToolValue::Attachment(attachment.clone()),
ToolValue::String("after".into()),
]));
let model_return =
ModelToolReturn::from_output("call-10".to_string(), "snap".to_string(), &output);
let mut parts: Vec<Part> = Vec::new();
append_model_return_parts(&mut parts, model_return);
assert_eq!(parts.len(), 3, "text + image + text yields three parts");
assert!(matches!(parts[0].kind, PartKind::ToolResult));
assert!(parts[0].content.starts_with("[\"before\""));
assert!(matches!(parts[1].kind, PartKind::Image));
assert_eq!(
parts[1]
.attachment
.as_ref()
.expect("attachment")
.reference
.id,
attachment.id
);
assert!(matches!(parts[2].kind, PartKind::ToolResult));
assert!(parts[2].content.ends_with("\"after\"]"));
}
}