Skip to main content

lash_core/
session.rs

1use std::sync::{Arc, OnceLock};
2
3use tokio::sync::mpsc::UnboundedSender;
4
5use crate::PluginMessage;
6use crate::tool_dispatch::ToolDispatchContext;
7use crate::{PromptContribution, RuntimeServices, SandboxMessage, SessionEvent, ToolProvider};
8
9mod execution_context;
10pub(crate) mod process_handles;
11mod tool_execution;
12pub(crate) mod triggers;
13
14pub use execution_context::RuntimeExecutionContext;
15pub(crate) use execution_context::lashlang_surface_from_tool_surface;
16pub use tool_execution::{ToolInvocation, ToolInvocationReply};
17
/// Identity of a memoized tool surface.
///
/// Holds every input that `Session::tool_surface_cache_key` reads when
/// deciding whether a cached `ToolSurfaceHandle` can be reused; two keys
/// compare equal only if all of these fields match.
#[derive(Clone, Debug, PartialEq, Eq)]
struct ToolSurfaceCacheKey {
    /// Whether base tools were composed into the session tool registry.
    include_base_tools: bool,
    /// Revision bumped by `Session::set_context_surface` when the context
    /// surface actually changes.
    context_surface_revision: u64,
    /// `generation()` of the session's composed tool registry.
    tool_generation: u64,
    /// Plugin snapshot revision fingerprint at key-construction time.
    plugin_revision: u64,
    /// Lashlang language features in effect at key-construction time.
    lashlang_language_features: lashlang::LashlangLanguageFeatures,
}
26
/// Lazily computed projections of a cached tool surface.
///
/// Each field starts empty (`Default`) and is filled on first use.
#[derive(Debug, Default)]
struct ToolSurfaceDerived {
    /// Catalog of the surface's searchable tools, computed once by
    /// `ToolSurfaceHandle::catalog`.
    catalog: OnceLock<Arc<Vec<serde_json::Value>>>,
}
31
/// One tool-surface build result stored in the session's cache.
struct ToolSurfaceArtifact {
    /// Tool surface resolved by the plugin session.
    surface: Arc<crate::ToolSurface>,
    /// Turn-driver preamble the protocol driver built from `surface`.
    preamble: Arc<crate::TurnDriverPreamble>,
    /// Views derived lazily from `surface`.
    derived: ToolSurfaceDerived,
}
37
/// Shared handle to a cached `ToolSurfaceArtifact`.
///
/// Cloning only bumps the `Arc` refcount, so every clone sees the same
/// artifact (including its lazily derived views).
#[derive(Clone)]
pub(crate) struct ToolSurfaceHandle(Arc<ToolSurfaceArtifact>);
40
41impl ToolSurfaceHandle {
42    fn surface(&self) -> Arc<crate::ToolSurface> {
43        Arc::clone(&self.0.surface)
44    }
45
46    fn preamble(&self) -> Arc<crate::TurnDriverPreamble> {
47        Arc::clone(&self.0.preamble)
48    }
49
50    fn catalog(&self) -> Arc<Vec<serde_json::Value>> {
51        Arc::clone(self.0.derived.catalog.get_or_init(|| {
52            Arc::new(crate::tool_registry::project_tool_catalog(
53                self.0.surface.searchable_tools_iter().cloned(),
54            ))
55        }))
56    }
57}
58
/// A plugin message to be injected as input to a turn.
#[derive(Clone, Debug)]
pub struct InjectedTurnInput {
    /// Optional identifier for this injected input (assigned by the caller;
    /// how it is used is not visible in this file).
    pub id: Option<String>,
    /// The message to inject.
    pub message: PluginMessage,
}
64
/// Errors produced while creating or operating a `Session`.
#[derive(Debug, thiserror::Error)]
pub enum SessionError {
    /// Underlying I/O failure; created automatically from `std::io::Error`.
    #[error("I/O error: {0}")]
    Io(#[from] std::io::Error),
    /// JSON (de)serialization failure; created automatically from `serde_json::Error`.
    #[error("JSON error: {0}")]
    Json(#[from] serde_json::Error),
    /// Code execution was requested but this session cannot provide it.
    #[error("code execution is not available in this session")]
    CodeExecutionUnavailable,
    /// The code execution runtime terminated unexpectedly.
    #[error("code execution runtime exited unexpectedly")]
    CodeExecutionRuntimeStopped,
    /// The provider recorded for the session differs from the live provider.
    #[error(
        "provider mismatch for session `{session_id}`: persisted provider `{expected}` does not match live provider `{actual}`"
    )]
    ProviderMismatch {
        /// Provider id that was persisted for the session.
        expected: String,
        /// Provider id of the live provider.
        actual: String,
        /// Session the mismatch was detected for.
        session_id: String,
    },
    /// No provider is configured for the session.
    #[error("provider is not configured for session `{session_id}`")]
    ProviderUnconfigured { session_id: String },
    /// The named provider is not registered for the session.
    #[error("provider `{provider_id}` is not registered for session `{session_id}`")]
    ProviderUnavailable {
        /// Provider id that could not be found.
        provider_id: String,
        /// Session that requested the provider.
        session_id: String,
    },
    /// Free-form protocol failure. `Session::refresh_tool_surface` also
    /// returns this when recomposing the tool registry fails.
    #[error("protocol error: {0}")]
    Protocol(String),
}
93
/// A request to execute a snippet of code.
#[derive(Clone, Debug)]
pub struct ExecRequest {
    /// Source code to execute.
    pub code: String,
    /// NOTE(review): presumably controls whether the execution may signal
    /// that the turn is finished — consumers are not visible here; confirm.
    pub accept_finish: bool,
}
99
/// Per-session runtime state.
///
/// Holds the composed tool registry, the context surface supplied through
/// `set_context_surface`, and memoized tool-surface and prompt builds.
pub struct Session {
    /// Identifier the session was created with.
    session_id: String,
    /// Shared runtime services (plugin session, persistence store,
    /// attachment and lashlang artifact stores).
    services: RuntimeServices,
    /// Whether base tools are composed into `tool_registry`.
    include_base_tools: bool,
    /// Wrapping counter bumped whenever the context surface changes; part of
    /// the tool-surface cache key.
    context_surface_revision: u64,
    /// Context tool providers currently composed into `tool_registry`.
    context_tools: Vec<Arc<dyn ToolProvider>>,
    /// Session tool registry composed from the plugin registry and context tools.
    tool_registry: Arc<crate::ToolRegistry>,
    /// Prompt contributions supplied with the current context surface.
    context_prompt_contributions: Vec<PromptContribution>,
    /// Optional sender for sandbox messages, set and cleared via
    /// `set_message_sender` / `clear_message_sender`.
    message_tx: Option<UnboundedSender<SandboxMessage>>,
    /// Memoized tool-surface builds, each paired with the key it was built for.
    tool_surface_cache: std::sync::Mutex<Vec<(ToolSurfaceCacheKey, ToolSurfaceHandle)>>,
    /// Memoizes the rendered system prompt across turns. Most consecutive
    /// turns reuse the same template + context surface, so the cache hits
    /// and we skip the section/Vec-join work in
    /// `lash_sansio::PromptTemplate::render`.
    prompt_cache: Arc<lash_sansio::PromptCache>,
}
116
117impl Session {
118    pub async fn new(services: RuntimeServices, session_id: &str) -> Result<Self, SessionError> {
119        let tool_registry = services.plugins.tool_registry();
120        let mut session = Self {
121            session_id: session_id.to_string(),
122            services,
123            include_base_tools: true,
124            context_surface_revision: 0,
125            context_tools: Vec::new(),
126            tool_registry,
127            context_prompt_contributions: Vec::new(),
128            message_tx: None,
129            tool_surface_cache: std::sync::Mutex::new(Vec::new()),
130            prompt_cache: Arc::new(lash_sansio::PromptCache::new()),
131        };
132
133        let protocol_session = Arc::clone(session.plugins().protocol_session());
134        protocol_session
135            .initialize_session(crate::plugin::ProtocolSessionContext::new(
136                &mut session,
137                session_id,
138            ))
139            .await?;
140
141        Ok(session)
142    }
143
144    pub fn session_id(&self) -> &str {
145        &self.session_id
146    }
147
148    pub(crate) fn protocol_extra_prompt_contributions(&self) -> Vec<PromptContribution> {
149        // Protocol-specific prompt contributions are owned by the protocol
150        // plugins via their
151        // `reg.prompt().contribute(...)` hooks. Nothing to add here.
152        Vec::new()
153    }
154
155    pub fn tools(&self) -> Arc<dyn ToolProvider> {
156        Arc::clone(&self.tool_registry) as Arc<dyn ToolProvider>
157    }
158
159    pub(crate) fn tool_registry(&self) -> Arc<crate::ToolRegistry> {
160        Arc::clone(&self.tool_registry)
161    }
162
163    pub fn plugins(&self) -> &Arc<crate::PluginSession> {
164        &self.services.plugins
165    }
166
167    pub fn set_context_surface(
168        &mut self,
169        tool_providers: Vec<Arc<dyn ToolProvider>>,
170        prompt_contributions: Vec<PromptContribution>,
171        include_base_tools: bool,
172    ) -> Result<(), crate::PluginError> {
173        let tool_providers_unchanged = self.context_tools.len() == tool_providers.len()
174            && self
175                .context_tools
176                .iter()
177                .zip(&tool_providers)
178                .all(|(current, next)| Arc::ptr_eq(current, next));
179        if self.include_base_tools == include_base_tools
180            && self.context_prompt_contributions == prompt_contributions
181            && tool_providers_unchanged
182        {
183            return Ok(());
184        }
185        let registry = self
186            .services
187            .plugins
188            .tool_registry()
189            .compose_session_surface(include_base_tools, tool_providers.clone())
190            .map(Arc::new)
191            .map_err(|err| {
192                crate::PluginError::Session(format!("failed to build session tool registry: {err}"))
193            })?;
194        self.include_base_tools = include_base_tools;
195        self.context_surface_revision = self.context_surface_revision.wrapping_add(1);
196        self.context_tools = tool_providers;
197        self.tool_registry = registry;
198        self.context_prompt_contributions = prompt_contributions;
199        self.tool_surface_cache
200            .lock()
201            .expect("tool surface cache lock")
202            .clear();
203        Ok(())
204    }
205
206    pub fn prompt_cache(&self) -> Arc<lash_sansio::PromptCache> {
207        Arc::clone(&self.prompt_cache)
208    }
209
210    pub fn context_prompt_contributions(&self) -> &[PromptContribution] {
211        &self.context_prompt_contributions
212    }
213
214    pub fn history_store(&self) -> Option<Arc<dyn crate::store::RuntimePersistence>> {
215        self.services.store.clone()
216    }
217
218    fn tool_surface_cache_key(&self) -> ToolSurfaceCacheKey {
219        ToolSurfaceCacheKey {
220            include_base_tools: self.include_base_tools,
221            context_surface_revision: self.context_surface_revision,
222            tool_generation: self.tool_registry.generation(),
223            plugin_revision: self.plugins().snapshot_revision_fingerprint(),
224            lashlang_language_features: self.plugins().lashlang_language_features(),
225        }
226    }
227
228    fn build_tool_surface_entry(
229        &self,
230        session_id: &str,
231    ) -> Result<ToolSurfaceHandle, crate::PluginError> {
232        let provider = self.tools();
233        let tools = provider.tool_manifests();
234        let contract_provider = Arc::clone(&provider);
235        let resolve_contract: lash_sansio::ToolContractResolver =
236            Arc::new(move |name: &str| contract_provider.resolve_contract(name));
237        let surface = Arc::new(self.plugins().resolve_tool_surface(
238            crate::plugin::ToolSurfaceContext {
239                session_id: session_id.to_string(),
240                tools,
241                resolve_contract: Some(Arc::clone(&resolve_contract)),
242                tool_access: self.plugins().tool_access().clone(),
243                subagent: self.plugins().subagent_context().cloned(),
244                lashlang_abilities: self.plugins().lashlang_abilities(),
245            },
246        )?);
247        let input = crate::ProtocolBuildInput {
248            tool_surface: Arc::clone(&surface),
249            lashlang_surface: execution_context::lashlang_surface_from_tool_surface(
250                &surface,
251                self.plugins().lashlang_abilities(),
252                self.plugins().lashlang_language_features(),
253                self.plugins().lashlang_resources(),
254            ),
255            extra_prompt_contributions: self.protocol_extra_prompt_contributions(),
256        };
257        let driver = self.plugins().protocol_driver();
258        let preamble = driver.build_preamble(input);
259        Ok(ToolSurfaceHandle(Arc::new(ToolSurfaceArtifact {
260            surface,
261            preamble: Arc::new(preamble),
262            derived: ToolSurfaceDerived::default(),
263        })))
264    }
265
266    fn tool_surface_cache_entry(
267        &self,
268        session_id: &str,
269    ) -> Result<ToolSurfaceHandle, crate::PluginError> {
270        let key = self.tool_surface_cache_key();
271        let mut cache = self
272            .tool_surface_cache
273            .lock()
274            .expect("tool surface cache lock");
275        if let Some((_, entry)) = cache.iter().find(|(entry_key, _)| *entry_key == key) {
276            return Ok(entry.clone());
277        }
278        let entry = self.build_tool_surface_entry(session_id)?;
279        cache.push((key, entry.clone()));
280        Ok(entry)
281    }
282
283    pub fn tool_surface(
284        &self,
285        session_id: &str,
286    ) -> Result<Arc<crate::ToolSurface>, crate::PluginError> {
287        Ok(self.tool_surface_cache_entry(session_id)?.surface())
288    }
289
290    pub(crate) fn turn_driver_preamble(
291        &self,
292        session_id: &str,
293    ) -> Result<Arc<crate::TurnDriverPreamble>, crate::PluginError> {
294        Ok(self.tool_surface_cache_entry(session_id)?.preamble())
295    }
296
297    pub(crate) fn shared_tool_catalog(
298        &self,
299        session_id: &str,
300    ) -> Result<Arc<Vec<serde_json::Value>>, crate::PluginError> {
301        Ok(self.tool_surface_cache_entry(session_id)?.catalog())
302    }
303
304    pub fn tool_catalog(
305        &self,
306        session_id: &str,
307    ) -> Result<Vec<serde_json::Value>, crate::PluginError> {
308        Ok(self.shared_tool_catalog(session_id)?.as_ref().clone())
309    }
310
311    #[allow(
312        clippy::too_many_arguments,
313        reason = "code execution bridge carries explicit per-turn runtime dependencies"
314    )]
315    pub(crate) fn code_execution_context<'run>(
316        &self,
317        session_id: &str,
318        agent_frame_id: &str,
319        sessions: Arc<dyn crate::plugin::SessionStateService>,
320        session_lifecycle: Arc<dyn crate::plugin::SessionLifecycleService>,
321        session_graph: Arc<dyn crate::plugin::SessionGraphService>,
322        processes: Arc<dyn crate::ProcessService>,
323        process_cancel_ability: Arc<dyn crate::ProcessCancelAbility>,
324        effect_controller: crate::runtime::RuntimeEffectControllerHandle<'run>,
325        direct_completions: crate::DirectCompletionClient<'run>,
326        host_event_router: Option<crate::HostEventRouter>,
327        event_tx: tokio::sync::mpsc::Sender<SessionEvent>,
328        chronological_projection: Arc<crate::ChronologicalProjection>,
329        protocol_extension: Option<crate::ProtocolTurnExtensionHandle>,
330        turn_context: crate::TurnContext,
331        checkpoint_messages: crate::tool_dispatch::CheckpointMessageBuffer,
332    ) -> Result<RuntimeExecutionContext<'run>, crate::PluginError> {
333        let dispatch = Arc::new(ToolDispatchContext {
334            plugins: Arc::clone(self.plugins()),
335            tools: self.tools(),
336            surface: self.tool_surface(session_id)?,
337            sessions,
338            session_lifecycle,
339            session_graph,
340            processes,
341            process_cancel_ability,
342            host_event_router,
343            effect_controller,
344            direct_completions: direct_completions.clone(),
345            parent_invocation: None,
346            session_id: session_id.to_string(),
347            agent_frame_id: agent_frame_id.to_string(),
348            event_tx,
349            checkpoint_messages,
350            host_event_outcomes: crate::tool_dispatch::ToolHostEventOutcomeBuffer::default(),
351            attachment_store: Arc::clone(&self.services.attachment_store),
352            turn_context: turn_context.clone(),
353        });
354        Ok(RuntimeExecutionContext::new(
355            session_id.to_string(),
356            dispatch,
357            self.plugins().lashlang_abilities(),
358            self.plugins().lashlang_language_features(),
359            Arc::clone(&self.services.lashlang_artifact_store),
360            Arc::clone(&self.services.attachment_store),
361            chronological_projection,
362            protocol_extension,
363            turn_context,
364        ))
365    }
366
367    /// Set the message sender for streaming messages during execution.
368    pub fn set_message_sender(&mut self, tx: UnboundedSender<SandboxMessage>) {
369        self.message_tx = Some(tx);
370    }
371
372    /// Clear the message sender (drops the sender, causing receivers to terminate).
373    pub fn clear_message_sender(&mut self) {
374        self.message_tx = None;
375    }
376
377    pub fn invalidate_runtime_caches(&self) {
378        self.tool_surface_cache
379            .lock()
380            .expect("tool surface cache lock")
381            .clear();
382        self.prompt_cache.clear();
383    }
384
385    pub async fn refresh_tool_surface(&mut self) -> Result<(), SessionError> {
386        self.tool_registry = self
387            .services
388            .plugins
389            .tool_registry()
390            .compose_session_surface(self.include_base_tools, self.context_tools.clone())
391            .map(Arc::new)
392            .map_err(|err| SessionError::Protocol(format!("tool reconfigure failed: {err}")))?;
393        self.tool_surface_cache
394            .lock()
395            .expect("tool surface cache lock")
396            .clear();
397        Ok(())
398    }
399}