Skip to main content

lash_core/runtime/
session_ops.rs

1//! `LashRuntime` session-graph and execution-state operations.
2//!
3//! Extracted from `runtime/mod.rs`. This file re-opens `impl LashRuntime`;
4//! no types live here and no public API is changed.
5
6use std::sync::Arc;
7
8use crate::{PluginActionInvokeError, SessionError};
9
10use super::LashRuntime;
11use super::state::{RuntimeSessionState, append_session_nodes_to_state, normalize_session_graph};
12
13impl LashRuntime {
14    /// Replace the host-owned state envelope.
15    pub fn set_persisted_state(&mut self, state: RuntimeSessionState) -> Result<(), SessionError> {
16        let mut state = state;
17        normalize_session_graph(&mut state);
18        if let Some(session) = self.session.as_ref() {
19            session.invalidate_runtime_caches();
20            // Restore the persisted tool surface so the live registry matches the
21            // state being installed (mirrors `from_host_state`). Without this the
22            // registry keeps its prior generation/tools and silently diverges from
23            // `state`. `restore_state` adopts the snapshot's generation, so a
24            // surface that reached generation >= 2 restores cleanly.
25            if let Some(tool_state) = state.tool_state_snapshot.clone() {
26                session
27                    .plugins()
28                    .tool_registry()
29                    .restore_state(tool_state)
30                    .map_err(|err| SessionError::Protocol(err.to_string()))?;
31            }
32            let snapshot = state.plugin_snapshot.clone().unwrap_or_default();
33            session
34                .plugins()
35                .restore(&snapshot)
36                .map_err(|err| SessionError::Protocol(err.to_string()))?;
37            state.plugin_snapshot_revision =
38                Some(session.plugins().snapshot_revision_fingerprint());
39        }
40        self.policy = state.policy.clone();
41        self.protocol_turn_options = state.protocol_turn_options.clone();
42        self.state = state;
43        Ok(())
44    }
45
46    pub async fn append_session_nodes(
47        &mut self,
48        request: crate::AppendSessionNodesRequest,
49    ) -> Result<crate::AppendSessionNodesResult, SessionError> {
50        self.refresh_session_graph_from_store().await?;
51        if let Some(required) = request.requires_ancestor_node_id.as_deref()
52            && !self.state.session_graph.active_path_contains(required)
53        {
54            return Ok(crate::AppendSessionNodesResult::StaleBranch {
55                current_leaf_node_id: self.state.session_graph.leaf_node_id.clone(),
56            });
57        }
58        let node_ids = append_session_nodes_to_state(&mut self.state, &request.nodes);
59        if let Some(session) = self.session.as_mut() {
60            let protocol_session = Arc::clone(session.plugins().protocol_session());
61            let session_id = self.state.session_id.clone();
62            protocol_session
63                .append_session_nodes(
64                    crate::plugin::ProtocolSessionContext::new(session, &session_id),
65                    &request.nodes,
66                )
67                .await?;
68        }
69        self.stamp_live_plugin_state();
70        if let Some(store) = self
71            .session
72            .as_ref()
73            .and_then(|session| session.history_store())
74        {
75            let graph = crate::store::GraphCommitDelta::Append {
76                nodes: node_ids
77                    .iter()
78                    .filter_map(|id| self.state.session_graph.find_node(id).cloned())
79                    .collect(),
80                leaf_node_id: self.state.session_graph.leaf_node_id.clone(),
81            };
82            let commit = crate::store::RuntimeCommit::persisted_state_with_graph_commit(
83                &self.state,
84                graph,
85                &[],
86            );
87            match store.commit_runtime_state(commit).await {
88                Ok(result) => self.state.apply_persisted_commit_result(result),
89                Err(err) => tracing::warn!("failed to persist runtime state: {err}"),
90            }
91        }
92        Ok(crate::AppendSessionNodesResult::Appended {
93            node_ids,
94            leaf_node_id: self
95                .state
96                .session_graph
97                .leaf_node_id
98                .clone()
99                .unwrap_or_default(),
100        })
101    }
102
103    pub async fn apply_protocol_session_extension(
104        &mut self,
105        extension: crate::ProtocolSessionExtensionHandle,
106    ) -> Result<(), SessionError> {
107        let Some(session) = self.session.as_ref() else {
108            return Err(SessionError::Protocol(
109                "runtime session is not available".to_string(),
110            ));
111        };
112        let protocol_session = Arc::clone(session.plugins().protocol_session());
113        protocol_session.apply_session_extension(extension).await
114    }
115
116    pub async fn validate_protocol_turn_extension(
117        &mut self,
118        extension: &crate::ProtocolTurnExtensionHandle,
119    ) -> Result<(), SessionError> {
120        let Some(session) = self.session.as_ref() else {
121            return Err(SessionError::Protocol(
122                "runtime session is not available".to_string(),
123            ));
124        };
125        let protocol_session = Arc::clone(session.plugins().protocol_session());
126        protocol_session.validate_turn_extension(extension).await
127    }
128
129    pub async fn branch_to_node(
130        &mut self,
131        node_id: Option<String>,
132    ) -> Result<crate::SessionSnapshot, SessionError> {
133        let mut state = self.export_state();
134        state.session_graph.branch_to(node_id);
135        let mut persisted_state = RuntimeSessionState::from_snapshot(state);
136        normalize_session_graph(&mut persisted_state);
137
138        let policy = persisted_state.policy.clone();
139        let host = self.host.clone();
140        let services = self.services.clone();
141        let managed_sessions = Arc::clone(&self.managed_sessions);
142        let managed_turns = Arc::clone(&self.managed_turns);
143        let process_sync_needed = Arc::clone(&self.process_sync_needed);
144        let runtime_scope_id = Arc::clone(&self.runtime_scope_id);
145        let turn_phase_probe = self.turn_phase_probe.clone();
146
147        let mut rebuilt = Self::from_host_state(policy, host, services, persisted_state).await?;
148        rebuilt.managed_sessions = managed_sessions;
149        rebuilt.managed_turns = managed_turns;
150        rebuilt.process_sync_needed = process_sync_needed;
151        rebuilt.runtime_scope_id = runtime_scope_id;
152        rebuilt.turn_phase_probe = turn_phase_probe;
153
154        let exported = rebuilt.export_state();
155        *self = rebuilt;
156        Ok(exported)
157    }
158
159    /// Promote a managed child session into the foreground runtime.
160    ///
161    /// Child sessions created through `SessionLifecycleService::create_session` are real
162    /// runtimes, not serialized placeholders. Foreground activation must therefore
163    /// claim that runtime instead of reconstructing a new empty state in the UI.
164    pub async fn activate_managed_session(&mut self, session_id: &str) -> Result<(), SessionError> {
165        let child = {
166            let mut registry = self.managed_sessions.lock().await;
167            registry.remove(session_id).ok_or_else(|| {
168                SessionError::Protocol(format!("unknown managed session `{session_id}`"))
169            })?
170        };
171        let child = child.try_into_runtime().map_err(|_| {
172            SessionError::Protocol(format!("managed session `{session_id}` is still in use"))
173        })?;
174        *self = child;
175        Ok(())
176    }
177
178    /// Explicitly snapshot protocol-local execution state, if any.
179    pub async fn snapshot_execution_state(&mut self) -> Result<Option<Vec<u8>>, SessionError> {
180        let Some(session) = self.session.as_mut() else {
181            return Err(SessionError::Protocol(
182                "runtime session not available".to_string(),
183            ));
184        };
185        let code_executor = session
186            .plugins()
187            .code_executor()
188            .ok_or(SessionError::CodeExecutionUnavailable)?;
189        let session_id = self.state.session_id.clone();
190        let blob = code_executor
191            .snapshot_execution_state(crate::plugin::ProtocolSessionContext::new(
192                session,
193                &session_id,
194            ))
195            .await?;
196        self.state.execution_state_snapshot = blob.clone();
197        Ok(blob)
198    }
199
200    /// Explicitly restore protocol-local execution state from an opaque snapshot blob.
201    pub async fn restore_execution_state(&mut self, snapshot: &[u8]) -> Result<(), SessionError> {
202        let Some(session) = self.session.as_mut() else {
203            return Err(SessionError::Protocol(
204                "runtime session not available".to_string(),
205            ));
206        };
207        let code_executor = session
208            .plugins()
209            .code_executor()
210            .ok_or(SessionError::CodeExecutionUnavailable)?;
211        let session_id = self.state.session_id.clone();
212        code_executor
213            .restore_execution_state(
214                crate::plugin::ProtocolSessionContext::new(session, &session_id),
215                snapshot,
216            )
217            .await?;
218        self.state.execution_state_snapshot = Some(snapshot.to_vec());
219        Ok(())
220    }
221
222    pub async fn list_lashlang_trigger_registrations(
223        &self,
224    ) -> Result<Vec<crate::TriggerRegistration>, SessionError> {
225        let store = self.host.host_event_store.as_ref().ok_or_else(|| {
226            SessionError::Protocol("host event store is unavailable in this runtime".to_string())
227        })?;
228        let records = store
229            .list_subscriptions(crate::TriggerSubscriptionFilter::for_session(
230                self.state.session_id.clone(),
231            ))
232            .await
233            .map_err(|err| SessionError::Protocol(err.to_string()))?;
234        Ok(records
235            .iter()
236            .map(crate::TriggerRegistration::from)
237            .collect())
238    }
239
240    pub async fn lashlang_trigger_registrations_by_source_type(
241        &self,
242        source_type: impl Into<crate::TriggerSourceType>,
243    ) -> Result<Vec<crate::TriggerRegistration>, SessionError> {
244        let store = self.host.host_event_store.as_ref().ok_or_else(|| {
245            SessionError::Protocol("host event store is unavailable in this runtime".to_string())
246        })?;
247        let mut filter =
248            crate::TriggerSubscriptionFilter::for_session(self.state.session_id.clone());
249        filter.source_type = Some(source_type.into().to_string());
250        let records = store
251            .list_subscriptions(filter)
252            .await
253            .map_err(|err| SessionError::Protocol(err.to_string()))?;
254        Ok(records
255            .iter()
256            .map(crate::TriggerRegistration::from)
257            .collect())
258    }
259
260    pub async fn invoke_plugin_action(
261        &self,
262        name: &str,
263        args: serde_json::Value,
264        session_id: Option<String>,
265    ) -> Result<crate::ToolResult, PluginActionInvokeError> {
266        let manager = self.runtime_session_services()?;
267        let Some(session) = self.session.as_ref() else {
268            return Err(PluginActionInvokeError::Unknown(name.to_string()));
269        };
270        session
271            .plugins()
272            .invoke_plugin_action(
273                name,
274                args,
275                session_id,
276                true,
277                manager.state_service(),
278                manager.lifecycle_service(),
279                manager.graph_service(),
280                manager.process_service(),
281            )
282            .await
283    }
284}