Skip to main content

lash_core/
session.rs

1use std::sync::{Arc, OnceLock};
2
3use tokio::sync::mpsc::UnboundedSender;
4
5use crate::PluginMessage;
6use crate::tool_dispatch::ToolDispatchContext;
7use crate::{PromptContribution, RuntimeServices, SandboxMessage, SessionEvent, ToolProvider};
8
9mod execution_context;
10pub(crate) mod process_handles;
11mod tool_execution;
12
13pub use execution_context::RuntimeExecutionContext;
14pub(crate) use execution_context::lashlang_host_environment_from_tool_catalog;
15pub use tool_execution::{ToolInvocation, ToolInvocationReply};
16
17#[derive(Clone, Debug, PartialEq, Eq)]
18struct ToolCatalogCacheKey {
19    include_base_tools: bool,
20    context_overlay_revision: u64,
21    tool_generation: u64,
22    plugin_revision: u64,
23    lashlang_language_features: lashlang::LashlangLanguageFeatures,
24}
25
26#[derive(Debug, Default)]
27struct ToolCatalogDerived {
28    catalog: OnceLock<Arc<Vec<serde_json::Value>>>,
29}
30
31struct ToolCatalogArtifact {
32    tool_catalog: Arc<crate::ToolCatalog>,
33    preamble: Arc<crate::TurnDriverPreamble>,
34    derived: ToolCatalogDerived,
35}
36
37#[derive(Clone)]
38pub(crate) struct ToolCatalogHandle(Arc<ToolCatalogArtifact>);
39
40impl ToolCatalogHandle {
41    fn tool_catalog(&self) -> Arc<crate::ToolCatalog> {
42        Arc::clone(&self.0.tool_catalog)
43    }
44
45    fn preamble(&self) -> Arc<crate::TurnDriverPreamble> {
46        Arc::clone(&self.0.preamble)
47    }
48
49    fn catalog(&self) -> Arc<Vec<serde_json::Value>> {
50        Arc::clone(self.0.derived.catalog.get_or_init(|| {
51            Arc::new(crate::tool_registry::project_tool_catalog(
52                self.0.tool_catalog.searchable_tools_iter().cloned(),
53            ))
54        }))
55    }
56}
57
58#[derive(Clone, Debug)]
59pub struct InjectedTurnInput {
60    pub id: Option<String>,
61    pub message: PluginMessage,
62}
63
64#[derive(Debug, thiserror::Error)]
65pub enum SessionError {
66    #[error("I/O error: {0}")]
67    Io(#[from] std::io::Error),
68    #[error("JSON error: {0}")]
69    Json(#[from] serde_json::Error),
70    #[error("code execution is not available in this session")]
71    CodeExecutionUnavailable,
72    #[error("code execution runtime exited unexpectedly")]
73    CodeExecutionRuntimeStopped,
74    #[error(
75        "provider mismatch for session `{session_id}`: persisted provider `{expected}` does not match live provider `{actual}`"
76    )]
77    ProviderMismatch {
78        expected: String,
79        actual: String,
80        session_id: String,
81    },
82    #[error("provider is not configured for session `{session_id}`")]
83    ProviderUnconfigured { session_id: String },
84    #[error("provider `{provider_id}` is not registered for session `{session_id}`")]
85    ProviderUnavailable {
86        provider_id: String,
87        session_id: String,
88    },
89    #[error("protocol error: {0}")]
90    Protocol(String),
91}
92
/// A request to execute a snippet of code.
#[derive(Debug, Clone)]
pub struct ExecRequest {
    /// Source text to execute.
    pub code: String,
    /// Flag carried with the request. NOTE(review): by its name this controls
    /// whether the execution may accept a "finish" signal — confirm with the
    /// consumer of `ExecRequest`.
    pub accept_finish: bool,
}
98
99pub struct Session {
100    session_id: String,
101    services: RuntimeServices,
102    include_base_tools: bool,
103    context_overlay_revision: u64,
104    context_tools: Vec<Arc<dyn ToolProvider>>,
105    tool_registry: Arc<crate::ToolRegistry>,
106    context_prompt_contributions: Vec<PromptContribution>,
107    message_tx: Option<UnboundedSender<SandboxMessage>>,
108    tool_catalog_cache: std::sync::Mutex<Vec<(ToolCatalogCacheKey, ToolCatalogHandle)>>,
109    /// Memoizes the rendered system prompt across turns. Most consecutive
110    /// turns reuse the same template + context overlay, so the cache hits
111    /// and we skip the section/Vec-join work in
112    /// `lash_sansio::PromptTemplate::render`.
113    prompt_cache: Arc<lash_sansio::PromptCache>,
114}
115
116impl Session {
117    pub async fn new(services: RuntimeServices, session_id: &str) -> Result<Self, SessionError> {
118        let tool_registry = services.plugins.tool_registry();
119        let mut session = Self {
120            session_id: session_id.to_string(),
121            services,
122            include_base_tools: true,
123            context_overlay_revision: 0,
124            context_tools: Vec::new(),
125            tool_registry,
126            context_prompt_contributions: Vec::new(),
127            message_tx: None,
128            tool_catalog_cache: std::sync::Mutex::new(Vec::new()),
129            prompt_cache: Arc::new(lash_sansio::PromptCache::new()),
130        };
131
132        let protocol_session = Arc::clone(session.plugins().protocol_session());
133        protocol_session
134            .initialize_session(crate::plugin::ProtocolSessionContext::new(
135                &mut session,
136                session_id,
137            ))
138            .await?;
139
140        Ok(session)
141    }
142
143    pub fn session_id(&self) -> &str {
144        &self.session_id
145    }
146
147    pub(crate) fn protocol_extra_prompt_contributions(&self) -> Vec<PromptContribution> {
148        // Protocol-specific prompt contributions are owned by the protocol
149        // plugins via their
150        // `reg.prompt().contribute(...)` hooks. Nothing to add here.
151        Vec::new()
152    }
153
154    pub fn tools(&self) -> Arc<dyn ToolProvider> {
155        Arc::clone(&self.tool_registry) as Arc<dyn ToolProvider>
156    }
157
158    pub fn plugins(&self) -> &Arc<crate::PluginSession> {
159        &self.services.plugins
160    }
161
162    pub fn set_context_overlay(
163        &mut self,
164        tool_providers: Vec<Arc<dyn ToolProvider>>,
165        prompt_contributions: Vec<PromptContribution>,
166        include_base_tools: bool,
167    ) -> Result<(), crate::PluginError> {
168        let tool_providers_unchanged = self.context_tools.len() == tool_providers.len()
169            && self
170                .context_tools
171                .iter()
172                .zip(&tool_providers)
173                .all(|(current, next)| Arc::ptr_eq(current, next));
174        if self.include_base_tools == include_base_tools
175            && self.context_prompt_contributions == prompt_contributions
176            && tool_providers_unchanged
177        {
178            return Ok(());
179        }
180        let registry = self
181            .services
182            .plugins
183            .tool_registry()
184            .compose_session_catalog(include_base_tools, tool_providers.clone())
185            .map(Arc::new)
186            .map_err(|err| {
187                crate::PluginError::Session(format!("failed to build session tool registry: {err}"))
188            })?;
189        self.include_base_tools = include_base_tools;
190        self.context_overlay_revision = self.context_overlay_revision.wrapping_add(1);
191        self.context_tools = tool_providers;
192        self.tool_registry = registry;
193        self.context_prompt_contributions = prompt_contributions;
194        self.tool_catalog_cache
195            .lock()
196            .expect("tool catalog cache lock")
197            .clear();
198        Ok(())
199    }
200
201    pub fn prompt_cache(&self) -> Arc<lash_sansio::PromptCache> {
202        Arc::clone(&self.prompt_cache)
203    }
204
205    pub fn context_prompt_contributions(&self) -> &[PromptContribution] {
206        &self.context_prompt_contributions
207    }
208
209    pub fn history_store(&self) -> Option<Arc<dyn crate::store::RuntimePersistence>> {
210        self.services.store.clone()
211    }
212
213    fn tool_catalog_cache_key(&self) -> ToolCatalogCacheKey {
214        ToolCatalogCacheKey {
215            include_base_tools: self.include_base_tools,
216            context_overlay_revision: self.context_overlay_revision,
217            tool_generation: self.tool_registry.generation(),
218            plugin_revision: self.plugins().snapshot_revision_fingerprint(),
219            lashlang_language_features: self.plugins().lashlang_language_features(),
220        }
221    }
222
223    fn build_tool_catalog_entry(
224        &self,
225        session_id: &str,
226    ) -> Result<ToolCatalogHandle, crate::PluginError> {
227        let provider = self.tools();
228        let tools = provider.tool_manifests();
229        let contract_provider = Arc::clone(&provider);
230        let resolve_contract: lash_sansio::ToolContractResolver =
231            Arc::new(move |name: &str| contract_provider.resolve_contract(name));
232        let tool_catalog = Arc::new(self.plugins().resolve_tool_catalog(
233            crate::plugin::ToolCatalogContext {
234                session_id: session_id.to_string(),
235                tools,
236                resolve_contract: Some(Arc::clone(&resolve_contract)),
237                tool_access: self.plugins().tool_access().clone(),
238                subagent: self.plugins().subagent_context().cloned(),
239                lashlang_abilities: self.plugins().lashlang_abilities(),
240            },
241        )?);
242        let input = crate::ProtocolBuildInput {
243            tool_catalog: Arc::clone(&tool_catalog),
244            lashlang_host_environment:
245                execution_context::lashlang_host_environment_from_tool_catalog(
246                    &tool_catalog,
247                    self.plugins().lashlang_abilities(),
248                    self.plugins().lashlang_language_features(),
249                    self.plugins().lashlang_resources(),
250                ),
251            extra_prompt_contributions: self.protocol_extra_prompt_contributions(),
252        };
253        let driver = self.plugins().protocol_driver();
254        let preamble = driver.build_preamble(input);
255        Ok(ToolCatalogHandle(Arc::new(ToolCatalogArtifact {
256            tool_catalog,
257            preamble: Arc::new(preamble),
258            derived: ToolCatalogDerived::default(),
259        })))
260    }
261
262    fn tool_catalog_cache_entry(
263        &self,
264        session_id: &str,
265    ) -> Result<ToolCatalogHandle, crate::PluginError> {
266        let key = self.tool_catalog_cache_key();
267        let mut cache = self
268            .tool_catalog_cache
269            .lock()
270            .expect("tool catalog cache lock");
271        if let Some((_, entry)) = cache.iter().find(|(entry_key, _)| *entry_key == key) {
272            return Ok(entry.clone());
273        }
274        let entry = self.build_tool_catalog_entry(session_id)?;
275        cache.push((key, entry.clone()));
276        Ok(entry)
277    }
278
279    pub fn resolved_tool_catalog(
280        &self,
281        session_id: &str,
282    ) -> Result<Arc<crate::ToolCatalog>, crate::PluginError> {
283        Ok(self.tool_catalog_cache_entry(session_id)?.tool_catalog())
284    }
285
286    pub(crate) fn turn_driver_preamble(
287        &self,
288        session_id: &str,
289    ) -> Result<Arc<crate::TurnDriverPreamble>, crate::PluginError> {
290        Ok(self.tool_catalog_cache_entry(session_id)?.preamble())
291    }
292
293    pub(crate) fn shared_tool_catalog(
294        &self,
295        session_id: &str,
296    ) -> Result<Arc<Vec<serde_json::Value>>, crate::PluginError> {
297        Ok(self.tool_catalog_cache_entry(session_id)?.catalog())
298    }
299
300    pub fn tool_catalog(
301        &self,
302        session_id: &str,
303    ) -> Result<Vec<serde_json::Value>, crate::PluginError> {
304        Ok(self.shared_tool_catalog(session_id)?.as_ref().clone())
305    }
306
307    #[allow(
308        clippy::too_many_arguments,
309        reason = "code execution bridge carries explicit per-turn runtime dependencies"
310    )]
311    pub(crate) fn code_execution_context<'run>(
312        &self,
313        session_id: &str,
314        agent_frame_id: &str,
315        sessions: Arc<dyn crate::plugin::SessionStateService>,
316        session_lifecycle: Arc<dyn crate::plugin::SessionLifecycleService>,
317        session_graph: Arc<dyn crate::plugin::SessionGraphService>,
318        processes: Arc<dyn crate::ProcessService>,
319        process_cancel_ability: Arc<dyn crate::ProcessCancelAbility>,
320        effect_controller: crate::runtime::RuntimeEffectControllerHandle<'run>,
321        direct_completions: crate::DirectCompletionClient<'run>,
322        trigger_router: Option<crate::TriggerRouter>,
323        event_tx: tokio::sync::mpsc::Sender<SessionEvent>,
324        chronological_projection: Arc<crate::ChronologicalProjection>,
325        protocol_extension: Option<crate::ProtocolTurnExtensionHandle>,
326        turn_context: crate::TurnContext,
327        execution_env_spec: crate::ProcessExecutionEnvSpec,
328        checkpoint_messages: crate::tool_dispatch::CheckpointMessageBuffer,
329    ) -> Result<RuntimeExecutionContext<'run>, crate::PluginError> {
330        let dispatch = Arc::new(ToolDispatchContext {
331            plugins: Arc::clone(self.plugins()),
332            tools: self.tools(),
333            tool_catalog: self.resolved_tool_catalog(session_id)?,
334            sessions,
335            session_lifecycle,
336            session_graph,
337            processes,
338            process_cancel_ability,
339            trigger_router,
340            effect_controller,
341            direct_completions: direct_completions.clone(),
342            parent_invocation: None,
343            execution_env_spec: execution_env_spec.clone(),
344            session_id: session_id.to_string(),
345            agent_frame_id: agent_frame_id.to_string(),
346            event_tx,
347            checkpoint_messages,
348            trigger_outcomes: crate::tool_dispatch::ToolTriggerOutcomeBuffer::default(),
349            attachment_store: Arc::clone(&self.services.attachment_store),
350            turn_context: turn_context.clone(),
351        });
352        Ok(RuntimeExecutionContext::new(
353            session_id.to_string(),
354            dispatch,
355            self.plugins().lashlang_abilities(),
356            self.plugins().lashlang_language_features(),
357            Arc::clone(&self.services.lashlang_artifact_store),
358            Arc::clone(&self.services.attachment_store),
359            chronological_projection,
360            protocol_extension,
361            turn_context,
362        ))
363        .map(|context| context.with_execution_env_spec(execution_env_spec))
364    }
365
366    /// Set the message sender for streaming messages during execution.
367    pub fn set_message_sender(&mut self, tx: UnboundedSender<SandboxMessage>) {
368        self.message_tx = Some(tx);
369    }
370
371    /// Clear the message sender (drops the sender, causing receivers to terminate).
372    pub fn clear_message_sender(&mut self) {
373        self.message_tx = None;
374    }
375
376    pub fn invalidate_runtime_caches(&self) {
377        self.tool_catalog_cache
378            .lock()
379            .expect("tool catalog cache lock")
380            .clear();
381        self.prompt_cache.clear();
382    }
383
384    pub async fn refresh_tool_catalog(&mut self) -> Result<(), SessionError> {
385        self.tool_registry = self
386            .services
387            .plugins
388            .tool_registry()
389            .compose_session_catalog(self.include_base_tools, self.context_tools.clone())
390            .map(Arc::new)
391            .map_err(|err| SessionError::Protocol(format!("tool reconfigure failed: {err}")))?;
392        self.tool_catalog_cache
393            .lock()
394            .expect("tool catalog cache lock")
395            .clear();
396        Ok(())
397    }
398}