lash-runtime 0.1.0-alpha.44

Durable agent runtime for Rust: sessions, turns, tools, plugins. Embeddable facade over lash-core.
Documentation
use super::super::*;
use super::contracts::{
    GraphContract, assert_all_processes_terminal, assert_completed_process_graph,
    assert_labeled_node, assert_labeled_resource_operation,
    assert_min_completed_child_session_exec_graphs, assert_min_completed_process_graphs,
    assert_no_duplicate_label_step, assert_session_turn_child_graph, assert_successful_turn_case,
};
use std::collections::VecDeque;

/// Graph-level expectations verified after a scripted end-to-end turn has run.
///
/// `Default` yields "expect nothing": every title/entry list is empty and
/// both minimum counts are zero.
#[derive(Default)]
pub(super) struct ExpectedContracts {
    /// Titles of labeled nodes that must be recorded with status `Completed`;
    /// each is also checked with `assert_no_duplicate_label_step`.
    pub(super) labeled_node_titles: Vec<&'static str>,
    /// Titles of labeled resource operations that must be recorded with
    /// status `Completed`; each is also checked with
    /// `assert_no_duplicate_label_step`.
    pub(super) labeled_resource_titles: Vec<&'static str>,
    /// Process entry names that must each have a completed process graph.
    pub(super) completed_process_entries: Vec<&'static str>,
    /// Lower bound on the number of completed process graphs.
    pub(super) min_completed_process_graphs: usize,
    /// Lower bound on the number of completed child-session exec graphs
    /// recorded under the case's root session.
    pub(super) min_completed_child_session_exec_graphs: usize,
}

/// One scripted end-to-end scenario driven through a single root-session turn.
///
/// Groups the scenario inputs (session, prompt, canned provider replies), the
/// optional runtime pieces to install on the core builder, and the
/// expectations checked once the turn and all of its processes have finished.
pub(super) struct LashE2eCase {
    /// Scenario name, interpolated into assertion failure messages.
    pub(super) name: &'static str,
    /// Id of the root session the turn is run in.
    pub(super) session_id: &'static str,
    /// Text sent as the root session's turn input.
    pub(super) root_prompt: &'static str,
    /// Provider replies, handed out in order, one per provider request.
    pub(super) scripted_provider_responses: Vec<&'static str>,
    /// Extra tools registered on the core builder, if any.
    pub(super) tool_provider: Option<Arc<dyn ToolProvider>>,
    /// Whether the subagents plugin is installed on the core builder.
    pub(super) install_subagents: bool,
    /// Optional value passed to the core builder's `max_turns`.
    pub(super) max_turns: Option<usize>,
    /// Value the turn must submit; `None` skips the submitted-value check.
    pub(super) expected_submitted_value: Option<serde_json::Value>,
    /// Graph-level expectations checked after the run.
    pub(super) expected_contracts: ExpectedContracts,
}

/// Everything recorded while running a scenario, kept for later assertions.
pub(super) struct LashE2eRun {
    /// Result of the root-session turn; `None` when the scenario started
    /// processes directly instead of running a root turn.
    pub(super) turn_output: Option<TurnResult>,
    /// Activity events streamed while the root turn ran.
    pub(super) streamed_events: Vec<TurnActivity>,
    /// Every request the scripted provider received, in arrival order.
    pub(super) prompt_captures: Vec<LlmRequest>,
    /// Lashlang execution graphs collected by the trace sink.
    pub(super) graph_snapshots: Vec<crate::tracing::TraceLashlangGraph>,
    /// Process list taken after all processes were awaited.
    pub(super) final_process_list: Vec<lash_core::ProcessHandleSummary>,
}

/// Runs `case` and then applies the generic success checks to the recording.
///
/// This is [`run_turn_case_without_success_assertions`] followed by
/// `assert_successful_turn_case` on the resulting run.
///
/// # Errors
///
/// Returns any error produced while running the case.
///
/// # Panics
///
/// Panics if a case-specific expectation or a generic success check fails.
pub(super) async fn run_turn_case(case: LashE2eCase) -> Result<LashE2eRun> {
    run_turn_case_without_success_assertions(case)
        .await
        .map(|recorded| {
            assert_successful_turn_case(&recorded);
            recorded
        })
}

pub(super) async fn run_turn_case_without_success_assertions(
    case: LashE2eCase,
) -> Result<LashE2eRun> {
    let graph_store = Arc::new(crate::tracing::TraceLashlangGraphStore::default());
    let process_registry = Arc::new(TestLocalProcessRegistry::default());
    let prompt_captures = Arc::new(StdMutex::new(Vec::new()));
    let provider = scripted_provider(
        case.scripted_provider_responses.clone(),
        Arc::clone(&prompt_captures),
    );
    let mut builder = explicit_ephemeral_facets(LashCore::rlm())
        .provider(provider)
        .model(mock_model_spec())
        .store_factory(Arc::new(lash_core::InMemorySessionStoreFactory::new()))
        .process_registry(Arc::clone(&process_registry) as Arc<dyn ProcessRegistry>)
        .lashlang_execution_sink(Arc::clone(&graph_store) as Arc<dyn crate::tracing::TraceSink>);
    if let Some(tools) = case.tool_provider.clone() {
        builder = builder.tools(tools);
    }
    if case.install_subagents {
        builder = builder.plugin(subagents_plugin());
    }
    if let Some(max_turns) = case.max_turns {
        builder = builder.max_turns(max_turns);
    }
    let core = builder.build()?;
    let session = core.session(case.session_id).open().await?;
    let events = Arc::new(RecordingEvents::default());

    let turn_output = session
        .turn(TurnInput::text(case.root_prompt))
        .stream_to(events.as_ref())
        .await?;
    session.process_control().await_all().await?;
    let prompt_captures_snapshot = prompt_captures.lock().expect("prompt captures").clone();
    let final_process_list = session.process_control().list_all().await?;
    let run = LashE2eRun {
        turn_output: Some(turn_output),
        streamed_events: events.snapshot().await,
        graph_snapshots: graph_store.graphs(),
        prompt_captures: prompt_captures_snapshot,
        final_process_list,
    };

    if let Some(expected) = &case.expected_submitted_value {
        let Some(output) = run.turn_output.as_ref() else {
            panic!("{} did not run a turn", case.name);
        };
        assert_eq!(
            output.submitted_value(),
            Some(expected),
            "{} submitted value mismatch",
            case.name
        );
    }

    let contract = GraphContract::from_graphs(&run.graph_snapshots);
    for title in case.expected_contracts.labeled_resource_titles {
        assert_labeled_resource_operation(
            &contract,
            title,
            crate::tracing::TraceLashlangNodeStatus::Completed,
        );
        assert_no_duplicate_label_step(&contract, title);
    }
    for title in case.expected_contracts.labeled_node_titles {
        assert_labeled_node(
            &contract,
            title,
            crate::tracing::TraceLashlangNodeStatus::Completed,
        );
        assert_no_duplicate_label_step(&contract, title);
    }
    for entry_name in case.expected_contracts.completed_process_entries {
        assert_completed_process_graph(&contract, entry_name);
    }
    assert_min_completed_process_graphs(
        &contract,
        case.expected_contracts.min_completed_process_graphs,
    );
    assert_min_completed_child_session_exec_graphs(
        &run,
        case.session_id,
        case.expected_contracts
            .min_completed_child_session_exec_graphs,
    );
    Ok(run)
}

/// Exercises a `ProcessInput::SessionTurn` process started directly through
/// process control, without running a root turn.
///
/// Opens root session `lash-e2e-session-turn-root` with the subagents plugin
/// installed. It then starts process `lash-e2e-session-turn-process`, which
/// creates child session `lash-e2e-session-turn-child` (empty start point,
/// `max_turns: Some(2)`) and runs one turn in it.
///
/// Once all processes have finished, it checks that:
/// - the process succeeded;
/// - it reported the child session id;
/// - it reported a turn that finished with the scripted submitted value;
/// - exactly one provider request was made;
/// - every process is terminal;
/// - `assert_session_turn_child_graph` holds.
///
/// # Errors
///
/// Returns any error from building the core, opening the session, or
/// starting, awaiting or listing processes.
///
/// # Panics
///
/// Panics when any of the checks above fails.
pub(super) async fn run_session_turn_process_case() -> Result<()> {
    let root_session_id = "lash-e2e-session-turn-root";
    let child_session_id = "lash-e2e-session-turn-child";
    let process_id = "lash-e2e-session-turn-process";

    let graphs = Arc::new(crate::tracing::TraceLashlangGraphStore::default());
    let registry = Arc::new(TestLocalProcessRegistry::default());
    let captured_requests = Arc::new(StdMutex::new(Vec::new()));
    // The only scripted reply: the child turn submits this value directly.
    let scripted_reply = "```lashlang\nsubmit { child: \"done\", scoped: true }\n```";
    let provider = scripted_provider(vec![scripted_reply], Arc::clone(&captured_requests));
    let core = explicit_ephemeral_facets(LashCore::rlm())
        .provider(provider)
        .model(mock_model_spec())
        .plugin(subagents_plugin())
        .store_factory(Arc::new(lash_core::InMemorySessionStoreFactory::new()))
        .process_registry(Arc::clone(&registry) as Arc<dyn ProcessRegistry>)
        .lashlang_execution_sink(Arc::clone(&graphs) as Arc<dyn crate::tracing::TraceSink>)
        .build()?;
    let session = core.session(root_session_id).open().await?;

    let policy = lash_core::SessionPolicy {
        model: mock_model_spec(),
        max_turns: Some(2),
        ..lash_core::SessionPolicy::default()
    };
    let child_request = lash_core::SessionCreateRequest::child(
        root_session_id,
        lash_core::SessionStartPoint::Empty,
        policy,
        lash_core::PluginOptions::default(),
        "e2e-session-turn",
    )
    .with_session_id(child_session_id);
    let start_request = lash_core::ProcessStartRequest::new(
        process_id,
        lash_core::ProcessInput::SessionTurn {
            create_request: Box::new(child_request),
            turn_input: Box::new(TurnInput::text("run child session turn")),
            output_contract: lash_core::ToolOutputContract::Static,
        },
        lash_core::ProcessHandleDescriptor::new(Some("session_turn"), Some("child turn")),
    );

    let started = session
        .process_control()
        .start(start_request, inline_scope(lash_core::EffectScope::process(process_id)))
        .await?;
    assert_eq!(started.process_id, process_id);
    session.process_control().await_all().await?;

    let output = match registry.await_process(process_id).await? {
        lash_core::ProcessAwaitOutput::Success { value, .. } => value,
        _ => panic!("session-turn process did not succeed"),
    };
    assert_eq!(
        output.get("child_session_id"),
        Some(&serde_json::json!(child_session_id))
    );
    let turn: lash_core::AssembledTurn = match output.get("turn") {
        Some(encoded) => serde_json::from_value(encoded.clone())
            .expect("session-turn output should decode"),
        None => panic!("session-turn output should contain a turn"),
    };
    let expected_outcome = TurnOutcome::Finished(lash_core::TurnFinish::SubmittedValue {
        value: serde_json::json!({ "child": "done", "scoped": true }),
    });
    assert_eq!(turn.outcome, expected_outcome);

    // Separate statement so the std MutexGuard is released before awaiting.
    let prompts = captured_requests.lock().expect("prompt captures").clone();
    let processes = session.process_control().list_all().await?;
    let run = LashE2eRun {
        turn_output: None,
        streamed_events: Vec::new(),
        graph_snapshots: graphs.graphs(),
        prompt_captures: prompts,
        final_process_list: processes,
    };
    assert_eq!(run.prompt_captures.len(), 1);
    assert_all_processes_terminal(&run.final_process_list);
    assert_session_turn_child_graph(&run, child_session_id, process_id);
    Ok(())
}

/// Builds a test provider that replays `responses` in order, one per request.
///
/// Each incoming request is cloned into `prompt_captures` before a reply is
/// chosen. Each reply is returned as a single text part whose text also fills
/// `full_text`.
///
/// # Panics
///
/// The returned provider panics if it receives more requests than there are
/// scripted responses.
fn scripted_provider(
    responses: Vec<&'static str>,
    prompt_captures: Arc<StdMutex<Vec<LlmRequest>>>,
) -> ProviderHandle {
    let pending: VecDeque<String> = responses.into_iter().map(str::to_owned).collect();
    let pending = Arc::new(TokioMutex::new(pending));
    crate::testing::TestProvider::builder()
        .kind("lash-e2e")
        .complete(move |request| {
            let pending = Arc::clone(&pending);
            let captures = Arc::clone(&prompt_captures);
            async move {
                // Record first, so the request is captured even if the script runs dry.
                captures
                    .lock()
                    .expect("prompt captures")
                    .push(request.clone());
                let reply = pending
                    .lock()
                    .await
                    .pop_front()
                    .expect("no scripted e2e provider response left");
                let parts = vec![LlmOutputPart::Text {
                    text: reply.clone(),
                    response_meta: None,
                }];
                Ok(LlmResponse {
                    full_text: reply,
                    parts,
                    ..LlmResponse::default()
                })
            }
        })
        .build()
        .into_handle()
}

/// Builds the subagents plugin factory with a single `"default"` capability
/// whose session spec is `SessionSpec::inherit()`.
fn subagents_plugin() -> Arc<dyn PluginFactory> {
    let default_capability =
        lash_subagents::StaticCapability::new("default", SessionSpec::inherit());
    let registry = lash_subagents::CapabilityRegistry::new().with(Arc::new(default_capability));
    Arc::new(lash_subagents::SubagentsPluginFactory::new(Arc::new(registry)))
}