// lash_core/session.rs

1use std::sync::{Arc, OnceLock};
2
3use tokio::sync::mpsc::UnboundedSender;
4
5use crate::PluginMessage;
6use crate::tool_dispatch::ToolDispatchContext;
7use crate::{PromptContribution, RuntimeServices, SandboxMessage, SessionEvent, ToolProvider};
8
9mod execution_context;
10pub(crate) mod process_handles;
11mod tool_execution;
12
13pub use execution_context::RuntimeExecutionContext;
14pub use tool_execution::{ToolInvocation, ToolInvocationReply};
15
/// Identifies the inputs a memoized tool-catalog artifact was built from.
///
/// Two lookups share an artifact only when every field matches.
#[derive(Debug, Clone, PartialEq, Eq)]
struct ToolCatalogCacheKey {
    /// Whether the composed registry includes the base tools.
    include_base_tools: bool,
    /// Revision counter of the session's context overlay.
    context_overlay_revision: u64,
    /// Generation reported by the session's composed tool registry.
    tool_generation: u64,
    /// Snapshot revision fingerprint reported by the plugin session.
    plugin_revision: u64,
}
23
24#[derive(Debug, Default)]
25struct ToolCatalogDerived {
26    catalog: OnceLock<Arc<Vec<serde_json::Value>>>,
27}
28
29struct ToolCatalogArtifact {
30    tool_catalog: Arc<crate::ToolCatalog>,
31    preamble: Arc<crate::TurnDriverPreamble>,
32    derived: ToolCatalogDerived,
33}
34
35#[derive(Clone)]
36pub(crate) struct ToolCatalogHandle(Arc<ToolCatalogArtifact>);
37
38impl ToolCatalogHandle {
39    fn tool_catalog(&self) -> Arc<crate::ToolCatalog> {
40        Arc::clone(&self.0.tool_catalog)
41    }
42
43    fn preamble(&self) -> Arc<crate::TurnDriverPreamble> {
44        Arc::clone(&self.0.preamble)
45    }
46
47    fn catalog(&self) -> Arc<Vec<serde_json::Value>> {
48        Arc::clone(self.0.derived.catalog.get_or_init(|| {
49            Arc::new(crate::tool_registry::project_tool_catalog(
50                self.0.tool_catalog.searchable_tools_iter().cloned(),
51            ))
52        }))
53    }
54}
55
56#[derive(Clone, Debug)]
57pub struct InjectedTurnInput {
58    pub id: Option<String>,
59    pub message: PluginMessage,
60}
61
62#[derive(Debug, thiserror::Error)]
63pub enum SessionError {
64    #[error("I/O error: {0}")]
65    Io(#[from] std::io::Error),
66    #[error("JSON error: {0}")]
67    Json(#[from] serde_json::Error),
68    #[error("code execution is not available in this session")]
69    CodeExecutionUnavailable,
70    #[error("code execution runtime exited unexpectedly")]
71    CodeExecutionRuntimeStopped,
72    #[error(
73        "provider mismatch for session `{session_id}`: persisted provider `{expected}` does not match live provider `{actual}`"
74    )]
75    ProviderMismatch {
76        expected: String,
77        actual: String,
78        session_id: String,
79    },
80    #[error("provider is not configured for session `{session_id}`")]
81    ProviderUnconfigured { session_id: String },
82    #[error("provider `{provider_id}` is not registered for session `{session_id}`")]
83    ProviderUnavailable {
84        provider_id: String,
85        session_id: String,
86    },
87    #[error("protocol error: {0}")]
88    Protocol(String),
89}
90
/// A request to execute a snippet of code.
#[derive(Debug, Clone)]
pub struct ExecRequest {
    /// Language identifier of the snippet; interpreted by the executor.
    pub language: String,
    /// Source code to execute.
    pub code: String,
    /// Flag forwarded to the executor.
    /// NOTE(review): presumably whether the snippet may finish the turn —
    /// confirm against the consumer.
    pub accept_finish: bool,
}
97
98pub struct Session {
99    session_id: String,
100    services: RuntimeServices,
101    include_base_tools: bool,
102    context_overlay_revision: u64,
103    context_tools: Vec<Arc<dyn ToolProvider>>,
104    tool_registry: Arc<crate::ToolRegistry>,
105    context_prompt_contributions: Vec<PromptContribution>,
106    message_tx: Option<UnboundedSender<SandboxMessage>>,
107    tool_catalog_cache: std::sync::Mutex<Vec<(ToolCatalogCacheKey, ToolCatalogHandle)>>,
108    /// Memoizes the rendered system prompt across turns. Most consecutive
109    /// turns reuse the same template + context overlay, so the cache hits
110    /// and we skip the section/Vec-join work in
111    /// `lash_sansio::PromptTemplate::render`.
112    prompt_cache: Arc<lash_sansio::PromptCache>,
113}
114
115impl Session {
116    pub async fn new(services: RuntimeServices, session_id: &str) -> Result<Self, SessionError> {
117        let tool_registry = services.plugins.tool_registry();
118        let mut session = Self {
119            session_id: session_id.to_string(),
120            services,
121            include_base_tools: true,
122            context_overlay_revision: 0,
123            context_tools: Vec::new(),
124            tool_registry,
125            context_prompt_contributions: Vec::new(),
126            message_tx: None,
127            tool_catalog_cache: std::sync::Mutex::new(Vec::new()),
128            prompt_cache: Arc::new(lash_sansio::PromptCache::new()),
129        };
130
131        let protocol_session = Arc::clone(session.plugins().protocol_session());
132        protocol_session
133            .initialize_session(crate::plugin::ProtocolSessionContext::new(
134                &mut session,
135                session_id,
136            ))
137            .await?;
138
139        Ok(session)
140    }
141
142    pub fn session_id(&self) -> &str {
143        &self.session_id
144    }
145
146    pub(crate) fn protocol_extra_prompt_contributions(&self) -> Vec<PromptContribution> {
147        // Protocol-specific prompt contributions are owned by the protocol
148        // plugins via their
149        // `reg.prompt().contribute(...)` hooks. Nothing to add here.
150        Vec::new()
151    }
152
153    pub fn tools(&self) -> Arc<dyn ToolProvider> {
154        Arc::clone(&self.tool_registry) as Arc<dyn ToolProvider>
155    }
156
157    pub fn plugins(&self) -> &Arc<crate::PluginSession> {
158        &self.services.plugins
159    }
160
161    pub fn set_context_overlay(
162        &mut self,
163        tool_providers: Vec<Arc<dyn ToolProvider>>,
164        prompt_contributions: Vec<PromptContribution>,
165        include_base_tools: bool,
166    ) -> Result<(), crate::PluginError> {
167        let tool_providers_unchanged = self.context_tools.len() == tool_providers.len()
168            && self
169                .context_tools
170                .iter()
171                .zip(&tool_providers)
172                .all(|(current, next)| Arc::ptr_eq(current, next));
173        if self.include_base_tools == include_base_tools
174            && self.context_prompt_contributions == prompt_contributions
175            && tool_providers_unchanged
176        {
177            return Ok(());
178        }
179        let registry = self
180            .services
181            .plugins
182            .tool_registry()
183            .compose_session_catalog(include_base_tools, tool_providers.clone())
184            .map(Arc::new)
185            .map_err(|err| {
186                crate::PluginError::Session(format!("failed to build session tool registry: {err}"))
187            })?;
188        self.include_base_tools = include_base_tools;
189        self.context_overlay_revision = self.context_overlay_revision.wrapping_add(1);
190        self.context_tools = tool_providers;
191        self.tool_registry = registry;
192        self.context_prompt_contributions = prompt_contributions;
193        self.tool_catalog_cache
194            .lock()
195            .expect("tool catalog cache lock")
196            .clear();
197        Ok(())
198    }
199
200    pub fn prompt_cache(&self) -> Arc<lash_sansio::PromptCache> {
201        Arc::clone(&self.prompt_cache)
202    }
203
204    pub fn context_prompt_contributions(&self) -> &[PromptContribution] {
205        &self.context_prompt_contributions
206    }
207
208    pub fn history_store(&self) -> Option<Arc<dyn crate::store::RuntimePersistence>> {
209        self.services.store.clone()
210    }
211
212    fn tool_catalog_cache_key(&self) -> ToolCatalogCacheKey {
213        ToolCatalogCacheKey {
214            include_base_tools: self.include_base_tools,
215            context_overlay_revision: self.context_overlay_revision,
216            tool_generation: self.tool_registry.generation(),
217            plugin_revision: self.plugins().snapshot_revision_fingerprint(),
218        }
219    }
220
221    fn build_tool_catalog_entry(
222        &self,
223        session_id: &str,
224    ) -> Result<ToolCatalogHandle, crate::PluginError> {
225        let provider = self.tools();
226        let tools = provider.tool_manifests();
227        let contract_provider = Arc::clone(&provider);
228        let resolve_contract: lash_sansio::ToolContractResolver =
229            Arc::new(move |name: &str| contract_provider.resolve_contract(name));
230        let tool_catalog = Arc::new(self.plugins().resolve_tool_catalog(
231            crate::plugin::ToolCatalogContext {
232                session_id: session_id.to_string(),
233                tools,
234                resolve_contract: Some(Arc::clone(&resolve_contract)),
235                tool_access: self.plugins().tool_access().clone(),
236                subagent: self.plugins().subagent_context().cloned(),
237                extensions: self.plugins().extensions().clone(),
238            },
239        )?);
240        let input = crate::ProtocolBuildInput {
241            tool_catalog: Arc::clone(&tool_catalog),
242            plugin_extensions: self.plugins().extensions().clone(),
243            trigger_events: self.plugins().triggers().clone(),
244            extra_prompt_contributions: self.protocol_extra_prompt_contributions(),
245        };
246        let driver = self.plugins().protocol_driver();
247        let preamble = driver.build_preamble(input);
248        Ok(ToolCatalogHandle(Arc::new(ToolCatalogArtifact {
249            tool_catalog,
250            preamble: Arc::new(preamble),
251            derived: ToolCatalogDerived::default(),
252        })))
253    }
254
255    fn tool_catalog_cache_entry(
256        &self,
257        session_id: &str,
258    ) -> Result<ToolCatalogHandle, crate::PluginError> {
259        let key = self.tool_catalog_cache_key();
260        let mut cache = self
261            .tool_catalog_cache
262            .lock()
263            .expect("tool catalog cache lock");
264        if let Some((_, entry)) = cache.iter().find(|(entry_key, _)| *entry_key == key) {
265            return Ok(entry.clone());
266        }
267        let entry = self.build_tool_catalog_entry(session_id)?;
268        cache.push((key, entry.clone()));
269        Ok(entry)
270    }
271
272    pub fn resolved_tool_catalog(
273        &self,
274        session_id: &str,
275    ) -> Result<Arc<crate::ToolCatalog>, crate::PluginError> {
276        Ok(self.tool_catalog_cache_entry(session_id)?.tool_catalog())
277    }
278
279    pub(crate) fn turn_driver_preamble(
280        &self,
281        session_id: &str,
282    ) -> Result<Arc<crate::TurnDriverPreamble>, crate::PluginError> {
283        Ok(self.tool_catalog_cache_entry(session_id)?.preamble())
284    }
285
286    pub(crate) fn shared_tool_catalog(
287        &self,
288        session_id: &str,
289    ) -> Result<Arc<Vec<serde_json::Value>>, crate::PluginError> {
290        Ok(self.tool_catalog_cache_entry(session_id)?.catalog())
291    }
292
293    pub fn tool_catalog(
294        &self,
295        session_id: &str,
296    ) -> Result<Vec<serde_json::Value>, crate::PluginError> {
297        Ok(self.shared_tool_catalog(session_id)?.as_ref().clone())
298    }
299
300    #[allow(
301        clippy::too_many_arguments,
302        reason = "code execution bridge carries explicit per-turn runtime dependencies"
303    )]
304    pub(crate) fn code_execution_context<'run>(
305        &self,
306        session_id: &str,
307        agent_frame_id: &str,
308        sessions: Arc<dyn crate::plugin::SessionStateService>,
309        session_lifecycle: Arc<dyn crate::plugin::SessionLifecycleService>,
310        session_graph: Arc<dyn crate::plugin::SessionGraphService>,
311        processes: Arc<dyn crate::ProcessService>,
312        process_cancel_ability: Arc<dyn crate::ProcessCancelAbility>,
313        effect_controller: crate::runtime::RuntimeEffectControllerHandle<'run>,
314        direct_completions: crate::DirectCompletionClient<'run>,
315        trigger_router: Option<crate::TriggerRouter>,
316        event_tx: tokio::sync::mpsc::Sender<SessionEvent>,
317        chronological_projection: Arc<crate::ChronologicalProjection>,
318        protocol_extension: Option<crate::ProtocolTurnExtensionHandle>,
319        turn_context: crate::TurnContext,
320        execution_env_spec: crate::ProcessExecutionEnvSpec,
321        checkpoint_messages: crate::tool_dispatch::CheckpointMessageBuffer,
322    ) -> Result<RuntimeExecutionContext<'run>, crate::PluginError> {
323        let dispatch = Arc::new(ToolDispatchContext {
324            plugins: Arc::clone(self.plugins()),
325            tools: self.tools(),
326            tool_catalog: self.resolved_tool_catalog(session_id)?,
327            sessions,
328            session_lifecycle,
329            session_graph,
330            processes,
331            process_cancel_ability,
332            trigger_router,
333            effect_controller,
334            direct_completions: direct_completions.clone(),
335            parent_invocation: None,
336            execution_env_spec: execution_env_spec.clone(),
337            session_id: session_id.to_string(),
338            agent_frame_id: agent_frame_id.to_string(),
339            event_tx,
340            checkpoint_messages,
341            trigger_outcomes: crate::tool_dispatch::ToolTriggerOutcomeBuffer::default(),
342            attachment_store: Arc::clone(&self.services.attachment_store),
343            turn_context: turn_context.clone(),
344        });
345        Ok(RuntimeExecutionContext::new(
346            session_id.to_string(),
347            dispatch,
348            Arc::clone(&self.services.process_env_store),
349            Arc::clone(&self.services.attachment_store),
350            chronological_projection,
351            protocol_extension,
352            turn_context,
353        ))
354        .map(|context| context.with_execution_env_spec(execution_env_spec))
355    }
356
357    /// Set the message sender for streaming messages during execution.
358    pub fn set_message_sender(&mut self, tx: UnboundedSender<SandboxMessage>) {
359        self.message_tx = Some(tx);
360    }
361
362    /// Clear the message sender (drops the sender, causing receivers to terminate).
363    pub fn clear_message_sender(&mut self) {
364        self.message_tx = None;
365    }
366
367    pub fn invalidate_runtime_caches(&self) {
368        self.tool_catalog_cache
369            .lock()
370            .expect("tool catalog cache lock")
371            .clear();
372        self.prompt_cache.clear();
373    }
374
375    pub async fn refresh_tool_catalog(&mut self) -> Result<(), SessionError> {
376        self.tool_registry = self
377            .services
378            .plugins
379            .tool_registry()
380            .compose_session_catalog(self.include_base_tools, self.context_tools.clone())
381            .map(Arc::new)
382            .map_err(|err| SessionError::Protocol(format!("tool reconfigure failed: {err}")))?;
383        self.tool_catalog_cache
384            .lock()
385            .expect("tool catalog cache lock")
386            .clear();
387        Ok(())
388    }
389}