use std::collections::BTreeSet;
use std::sync::Arc;
use lash_core::plugin::{CodeExecutorPlugin, ProtocolSessionContext};
use lash_core::{PromptContribution, SessionError, SessionEventRecord};
use lash_rlm_types::{RlmGlobalsPatchPluginBody, RlmProtocolEvent};
use super::RlmProtocolPluginConfig;
use crate::executor::{RlmExecutionState, execute_code};
use crate::projection::{
ProjectionResolver, RlmProjectedBindings, RlmProjectionExtension, decode_rlm_protocol_event,
};
use crate::rlm_support::{BoundVariableRenderCache, render_bound_variables};
pub(super) struct RlmRuntimeState {
config: RlmProtocolPluginConfig,
projection_resolver: Arc<dyn ProjectionResolver>,
session_projected_bindings: tokio::sync::Mutex<RlmProjectedBindings>,
execution: tokio::sync::Mutex<Option<RlmExecutionState>>,
active_agent_frame_id: tokio::sync::Mutex<Option<String>>,
bound_variable_render_cache: tokio::sync::Mutex<BoundVariableRenderCache>,
}
impl RlmRuntimeState {
pub(super) fn new(
config: RlmProtocolPluginConfig,
projection_resolver: Arc<dyn ProjectionResolver>,
) -> Result<Self, SessionError> {
Ok(Self {
execution: tokio::sync::Mutex::new(Some(RlmExecutionState::new(
config.observe_projection.clone(),
)?)),
config,
projection_resolver,
session_projected_bindings: tokio::sync::Mutex::new(RlmProjectedBindings::new()),
active_agent_frame_id: tokio::sync::Mutex::new(None),
bound_variable_render_cache: tokio::sync::Mutex::new(
BoundVariableRenderCache::default(),
),
})
}
pub(super) async fn projected_binding_prompt_contributions(
&self,
) -> Vec<lash_core::PromptContribution> {
let bindings = self.session_projected_bindings.lock().await;
RlmProjectionExtension::prompt_contributions_for(&bindings)
}
pub(super) async fn bound_variables_prompt_contribution(
&self,
history_len: usize,
) -> PromptContribution {
let globals = self.bound_variable_values().await;
let mut cache = self.bound_variable_render_cache.lock().await;
render_bound_variables(
&mut cache,
&globals,
history_len,
self.config.bound_variables_inline_char_limit,
)
}
async fn bound_variable_values(&self) -> serde_json::Map<String, serde_json::Value> {
let exclude = self.protected_projected_binding_names().await;
self.execution
.lock()
.await
.as_ref()
.map(|execution| execution.bound_variable_values(&exclude))
.unwrap_or_default()
}
async fn protected_projected_binding_names(&self) -> BTreeSet<String> {
self.session_projected_bindings
.lock()
.await
.names()
.collect()
}
pub(super) async fn apply_session_extension(
&self,
extension: lash_core::ProtocolSessionExtensionHandle,
) -> Result<(), SessionError> {
let extension = extension
.as_any()
.downcast_ref::<RlmProjectionExtension>()
.ok_or_else(|| {
SessionError::Protocol(
"RLM protocol received an unsupported session extension".to_string(),
)
})?;
reject_reserved_projected_binding_names(&extension.bindings)?;
let mut guard = self.session_projected_bindings.lock().await;
let merged = guard
.clone()
.merge(extension.bindings.clone())
.map_err(|err| SessionError::Protocol(err.to_string()))?;
*guard = merged;
Ok(())
}
pub(super) async fn validate_turn_extension(
&self,
extension: &lash_core::ProtocolTurnExtensionHandle,
) -> Result<(), SessionError> {
let extension = extension
.as_any()
.downcast_ref::<RlmProjectionExtension>()
.ok_or_else(|| {
SessionError::Protocol(
"RLM protocol received an unsupported turn extension".to_string(),
)
})?;
reject_reserved_projected_binding_names(&extension.bindings)?;
self.session_projected_bindings
.lock()
.await
.clone()
.merge(extension.bindings.clone())
.map(|_| ())
.map_err(|err| SessionError::Protocol(err.to_string()))
}
pub(super) async fn restore_runtime_session_state(
&self,
state: &lash_core::runtime::RuntimeSessionState,
) -> Result<(), SessionError> {
let mut active_agent_frame_id = self.active_agent_frame_id.lock().await;
let mut execution = self.execution.lock().await;
let execution = execution
.as_mut()
.ok_or_else(|| SessionError::Protocol("RLM execution state is busy".to_string()))?;
if active_agent_frame_id.as_deref() != Some(state.current_agent_frame_id.as_str()) {
*execution = RlmExecutionState::new(self.config.observe_projection.clone())?;
*self.session_projected_bindings.lock().await = RlmProjectedBindings::new();
*self.bound_variable_render_cache.lock().await = BoundVariableRenderCache::default();
*active_agent_frame_id = Some(state.current_agent_frame_id.clone());
}
let protected_names = self.protected_projected_binding_names().await;
if let Some(snapshot) = state.execution_state_snapshot().map(|bytes| bytes.to_vec()) {
execution.restore_execution_state(&snapshot)?;
execution.prune_protected_globals(&protected_names);
}
for event in state.read_view().active_events() {
if let SessionEventRecord::Protocol(event) = event
&& let Some(event) = decode_rlm_protocol_event(event)
{
self.apply_seed_or_globals_event(execution, event, &protected_names)
.await?;
}
}
Ok(())
}
pub(super) async fn append_session_nodes(
&self,
nodes: &[lash_core::SessionAppendNode],
) -> Result<(), SessionError> {
let mut execution = self.execution.lock().await;
let execution = execution
.as_mut()
.ok_or_else(|| SessionError::Protocol("RLM execution state is busy".to_string()))?;
let protected_names = self.protected_projected_binding_names().await;
execution.prune_protected_globals(&protected_names);
for node in nodes {
if let lash_core::SessionAppendNode::ProtocolEvent { event, .. } = node
&& let Some(event) = decode_rlm_protocol_event(event)
{
self.apply_seed_or_globals_event(execution, event, &protected_names)
.await?;
}
}
Ok(())
}
pub(super) async fn execute_code(
&self,
ctx: lash_core::RuntimeExecutionContext<'_>,
request: lash_core::ExecRequest,
) -> Result<lash_core::ExecResponse, SessionError> {
let session_projected_bindings = self.session_projected_bindings.lock().await.clone();
let mut guard = self.execution.lock().await;
let state = guard
.take()
.ok_or_else(|| SessionError::Protocol("RLM execution state is busy".to_string()))?;
let result = execute_code(
state,
ctx,
request,
session_projected_bindings,
Arc::clone(&self.projection_resolver),
)
.await;
match result {
Ok((state, response)) => {
*guard = Some(state);
Ok(response)
}
Err(err) => {
*guard = Some(RlmExecutionState::new(
self.config.observe_projection.clone(),
)?);
Err(err)
}
}
}
pub(super) fn execution_state_dirty(&self) -> bool {
self.execution
.try_lock()
.map(|execution| {
execution
.as_ref()
.map(|execution| execution.execution_state_dirty())
.unwrap_or(true)
})
.unwrap_or(true)
}
pub(super) async fn snapshot_execution_state(&self) -> Result<Option<Vec<u8>>, SessionError> {
self.execution
.lock()
.await
.as_mut()
.ok_or_else(|| SessionError::Protocol("RLM execution state is busy".to_string()))?
.snapshot_execution_state()
}
pub(super) async fn restore_execution_state(&self, data: &[u8]) -> Result<(), SessionError> {
self.execution
.lock()
.await
.as_mut()
.ok_or_else(|| SessionError::Protocol("RLM execution state is busy".to_string()))?
.restore_execution_state(data)
}
async fn apply_seed_or_globals_event(
&self,
execution: &mut RlmExecutionState,
event: RlmProtocolEvent,
protected_names: &BTreeSet<String>,
) -> Result<(), SessionError> {
match event {
RlmProtocolEvent::RlmGlobalsPatch(patch) => {
execution.patch_globals(&patch, protected_names)?;
}
RlmProtocolEvent::RlmSeed(seed) => {
let mut protected_names = protected_names.clone();
if !seed.projected.is_empty() {
self.install_initial_projected_seed(seed.projected)?;
protected_names = self.protected_projected_binding_names().await;
}
if !seed.globals.is_empty() {
execution.patch_globals(
&RlmGlobalsPatchPluginBody {
set_default: seed.globals,
},
&protected_names,
)?;
}
}
RlmProtocolEvent::RlmTrajectoryEntry(_) | RlmProtocolEvent::RlmDiagnostic(_) => {}
}
Ok(())
}
fn install_initial_projected_seed(
&self,
snapshot: lash_rlm_types::RlmProjectedSeedSnapshot,
) -> Result<(), SessionError> {
let bindings = match RlmProjectedBindings::from_snapshot(&snapshot) {
Ok(bindings) => bindings,
Err(err) => {
return Err(SessionError::Protocol(format!(
"rlm projected seed snapshot rejected: {err}"
)));
}
};
reject_reserved_projected_binding_names(&bindings)?;
let mut guard = match self.session_projected_bindings.try_lock() {
Ok(guard) => guard,
Err(_) => return Err(SessionError::Protocol(
"rlm projected seed snapshot could not be installed because session bindings were contended".to_string(),
)),
};
let merged = guard
.clone()
.merge(bindings)
.map_err(|err| SessionError::Protocol(err.to_string()))?;
*guard = merged;
Ok(())
}
}
pub(super) struct RlmCodeExecutor {
state: Arc<RlmRuntimeState>,
}
impl RlmCodeExecutor {
pub(super) fn new(state: Arc<RlmRuntimeState>) -> Self {
Self { state }
}
}
#[async_trait::async_trait]
impl CodeExecutorPlugin for RlmCodeExecutor {
async fn execute_code(
&self,
ctx: lash_core::RuntimeExecutionContext<'_>,
request: lash_core::ExecRequest,
) -> Result<lash_core::ExecResponse, SessionError> {
self.state.execute_code(ctx, request).await
}
fn execution_state_dirty(&self) -> bool {
self.state.execution_state_dirty()
}
async fn snapshot_execution_state(
&self,
_ctx: ProtocolSessionContext<'_>,
) -> Result<Option<Vec<u8>>, SessionError> {
self.state.snapshot_execution_state().await
}
async fn restore_execution_state(
&self,
_ctx: ProtocolSessionContext<'_>,
data: &[u8],
) -> Result<(), SessionError> {
self.state.restore_execution_state(data).await
}
}
pub(super) fn reject_reserved_projected_binding_names(
bindings: &RlmProjectedBindings,
) -> Result<(), SessionError> {
if bindings.names().any(|name| name == "history") {
return Err(SessionError::Protocol(
"`history` is reserved as an RLM built-in binding".to_string(),
));
}
Ok(())
}