Skip to main content

lash_core/
session.rs

1use std::collections::VecDeque;
2use std::sync::{Arc, OnceLock};
3
4use tokio::sync::mpsc::UnboundedSender;
5
6use crate::PluginMessage;
7use crate::tool_dispatch::ToolDispatchContext;
8use crate::{PromptContribution, RuntimeServices, SandboxMessage, SessionEvent, ToolProvider};
9
10pub(crate) mod async_handles;
11mod execution_context;
12mod monitor_handles;
13mod tool_execution;
14
15pub(crate) use async_handles::AsyncToolHandleMap;
16pub use execution_context::ModeExecutionContext;
17pub use tool_execution::{ModeToolBatchItem, ModeToolReply};
18
19#[derive(Clone, Debug, PartialEq, Eq)]
20struct ToolSurfaceCacheKey {
21    mode: crate::ExecutionMode,
22    include_base_tools: bool,
23    context_surface_revision: u64,
24    tool_generation: u64,
25    plugin_revision: u64,
26}
27
28#[derive(Debug, Default)]
29struct ToolSurfaceDerived {
30    catalog: OnceLock<Arc<Vec<serde_json::Value>>>,
31}
32
33struct ToolSurfaceArtifact {
34    surface: Arc<crate::ToolSurface>,
35    preamble: Arc<crate::ModePreamble>,
36    derived: ToolSurfaceDerived,
37}
38
39#[derive(Clone)]
40pub(crate) struct ToolSurfaceHandle(Arc<ToolSurfaceArtifact>);
41
42impl ToolSurfaceHandle {
43    fn surface(&self) -> Arc<crate::ToolSurface> {
44        Arc::clone(&self.0.surface)
45    }
46
47    fn preamble(&self) -> Arc<crate::ModePreamble> {
48        Arc::clone(&self.0.preamble)
49    }
50
51    fn catalog(&self) -> Arc<Vec<serde_json::Value>> {
52        Arc::clone(self.0.derived.catalog.get_or_init(|| {
53            Arc::new(crate::tool_registry::project_tool_catalog(
54                self.0.surface.searchable_tools_iter().cloned(),
55            ))
56        }))
57    }
58}
59
60#[derive(Clone, Default)]
61pub struct TurnInjectionBridge {
62    queue: std::sync::Arc<std::sync::Mutex<VecDeque<PluginMessage>>>,
63}
64
65#[derive(Clone, Debug)]
66pub struct InjectedTurnInput {
67    pub id: Option<String>,
68    pub message: PluginMessage,
69}
70
71#[derive(Clone, Default)]
72pub struct TurnInputInjectionBridge {
73    queue: std::sync::Arc<std::sync::Mutex<VecDeque<InjectedTurnInput>>>,
74}
75
76impl TurnInjectionBridge {
77    pub fn new() -> Self {
78        Self::default()
79    }
80
81    pub fn enqueue(&self, messages: Vec<PluginMessage>) -> Result<(), String> {
82        let mut queue = self
83            .queue
84            .lock()
85            .map_err(|_| "turn injection bridge poisoned".to_string())?;
86        queue.extend(messages);
87        Ok(())
88    }
89
90    pub fn drain(&self) -> Result<Vec<PluginMessage>, String> {
91        let mut queue = self
92            .queue
93            .lock()
94            .map_err(|_| "turn injection bridge poisoned".to_string())?;
95        Ok(queue.drain(..).collect())
96    }
97}
98
99impl TurnInputInjectionBridge {
100    pub fn new() -> Self {
101        Self::default()
102    }
103
104    pub fn enqueue(&self, messages: Vec<InjectedTurnInput>) -> Result<(), String> {
105        let mut queue = self
106            .queue
107            .lock()
108            .map_err(|_| "turn injection bridge poisoned".to_string())?;
109        queue.extend(messages);
110        Ok(())
111    }
112
113    pub fn drain(&self) -> Result<Vec<InjectedTurnInput>, String> {
114        let mut queue = self
115            .queue
116            .lock()
117            .map_err(|_| "turn input injection bridge poisoned".to_string())?;
118        Ok(queue.drain(..).collect())
119    }
120}
121
122#[derive(Debug, thiserror::Error)]
123pub enum SessionError {
124    #[error("I/O error: {0}")]
125    Io(#[from] std::io::Error),
126    #[error("JSON error: {0}")]
127    Json(#[from] serde_json::Error),
128    #[error("rlm execution mode is not available in this build or session")]
129    RlmUnavailable,
130    #[error("rlm runtime exited unexpectedly")]
131    RuntimeExited,
132    #[error("protocol error: {0}")]
133    Protocol(String),
134}
135
/// A request to execute a piece of code.
#[derive(Clone, Debug)]
pub struct ExecRequest {
    /// Code payload of the request.
    pub code: String,
    /// NOTE(review): presumably whether this execution may finish the turn;
    /// confirm against the consumers of `ExecRequest`.
    pub accept_finish: bool,
}
141
142pub struct Session {
143    session_id: String,
144    execution_mode: crate::ExecutionMode,
145    services: RuntimeServices,
146    include_base_tools: bool,
147    context_surface_revision: u64,
148    context_tools: Vec<Arc<dyn ToolProvider>>,
149    context_prompt_contributions: Vec<PromptContribution>,
150    message_tx: Option<UnboundedSender<SandboxMessage>>,
151    tool_surface_cache: std::sync::Mutex<Vec<(ToolSurfaceCacheKey, ToolSurfaceHandle)>>,
152    /// Memoizes the rendered system prompt across turns. Most consecutive
153    /// turns reuse the same template + context surface, so the cache hits
154    /// and we skip the section/Vec-join work in
155    /// `lash_sansio::PromptTemplate::render`.
156    prompt_cache: Arc<lash_sansio::PromptCache>,
157    async_tool_handles: AsyncToolHandleMap,
158}
159
160impl Session {
161    pub async fn new(
162        services: RuntimeServices,
163        session_id: &str,
164        execution_mode: crate::ExecutionMode,
165    ) -> Result<Self, SessionError> {
166        let mut session = Self {
167            session_id: session_id.to_string(),
168            execution_mode,
169            services,
170            include_base_tools: true,
171            context_surface_revision: 0,
172            context_tools: Vec::new(),
173            context_prompt_contributions: Vec::new(),
174            message_tx: None,
175            tool_surface_cache: std::sync::Mutex::new(Vec::new()),
176            prompt_cache: Arc::new(lash_sansio::PromptCache::new()),
177            async_tool_handles: Default::default(),
178        };
179
180        let mode_session = Arc::clone(session.plugins().mode_session());
181        mode_session
182            .initialize_session(crate::plugin::ModeSessionContext::new(
183                &mut session,
184                session_id,
185            ))
186            .await?;
187
188        Ok(session)
189    }
190
191    pub fn session_id(&self) -> &str {
192        &self.session_id
193    }
194
195    pub(crate) fn mode_extra_prompt_contributions(
196        &self,
197        _mode: &crate::ExecutionMode,
198    ) -> Vec<PromptContribution> {
199        // Mode-specific prompt contributions are owned by the mode
200        // plugins (`lash-mode-standard`, `lash-mode-rlm`) via their
201        // `reg.prompt().contribute(...)` hooks. Nothing to add here.
202        Vec::new()
203    }
204
205    pub fn tools(&self) -> Arc<dyn ToolProvider> {
206        if self.include_base_tools && self.context_tools.is_empty() {
207            return self.services.plugins.tools();
208        }
209
210        let mut providers = Vec::new();
211        if self.include_base_tools {
212            providers.push(self.services.plugins.tools());
213        }
214        providers.extend(self.context_tools.iter().cloned());
215        Arc::new(crate::tool_provider::CompositeToolProvider::from_providers(
216            providers,
217        ))
218    }
219
220    pub fn plugins(&self) -> &Arc<crate::PluginSession> {
221        &self.services.plugins
222    }
223
224    pub fn set_context_surface(
225        &mut self,
226        tool_providers: Vec<Arc<dyn ToolProvider>>,
227        prompt_contributions: Vec<PromptContribution>,
228        include_base_tools: bool,
229    ) {
230        let tool_providers_unchanged = self.context_tools.len() == tool_providers.len()
231            && self
232                .context_tools
233                .iter()
234                .zip(&tool_providers)
235                .all(|(current, next)| Arc::ptr_eq(current, next));
236        if self.include_base_tools == include_base_tools
237            && self.context_prompt_contributions == prompt_contributions
238            && tool_providers_unchanged
239        {
240            return;
241        }
242        self.include_base_tools = include_base_tools;
243        self.context_surface_revision = self.context_surface_revision.wrapping_add(1);
244        self.context_tools = tool_providers;
245        self.context_prompt_contributions = prompt_contributions;
246        self.tool_surface_cache
247            .lock()
248            .expect("tool surface cache lock")
249            .clear();
250    }
251
252    pub fn prompt_cache(&self) -> Arc<lash_sansio::PromptCache> {
253        Arc::clone(&self.prompt_cache)
254    }
255
256    pub fn context_prompt_contributions(&self) -> &[PromptContribution] {
257        &self.context_prompt_contributions
258    }
259
260    pub fn history_store(&self) -> Option<Arc<dyn crate::store::RuntimePersistence>> {
261        self.services.store.clone()
262    }
263
264    fn tool_surface_cache_key(&self, mode: &crate::ExecutionMode) -> ToolSurfaceCacheKey {
265        ToolSurfaceCacheKey {
266            mode: mode.clone(),
267            include_base_tools: self.include_base_tools,
268            context_surface_revision: self.context_surface_revision,
269            tool_generation: self.plugins().tool_registry().generation(),
270            plugin_revision: self.plugins().snapshot_revision_fingerprint(),
271        }
272    }
273
274    fn build_tool_surface_entry(
275        &self,
276        session_id: &str,
277        mode: crate::ExecutionMode,
278    ) -> ToolSurfaceHandle {
279        let provider = self.tools();
280        let mut tools = provider.tool_manifests();
281        let contract_provider = Arc::clone(&provider);
282        let plugins = self.plugins();
283        let native_contract_providers = plugins.mode_native_tools().to_vec();
284        let resolve_contract: lash_sansio::ToolContractResolver = Arc::new(move |name: &str| {
285            contract_provider.resolve_contract(name).or_else(|| {
286                native_contract_providers
287                    .iter()
288                    .find_map(|provider| provider.resolve_contract(name))
289            })
290        });
291        if self.include_base_tools && mode == self.plugins().execution_mode() {
292            let native_tools = self.plugins().mode_native_tool_manifests();
293            tools.extend(native_tools);
294        }
295        let surface = match self
296            .plugins()
297            .resolve_tool_surface(crate::plugin::ToolSurfaceContext {
298                session_id: session_id.to_string(),
299                mode: mode.clone(),
300                tools,
301                resolve_contract: Some(Arc::clone(&resolve_contract)),
302                tool_access: self.plugins().tool_access().clone(),
303                subagent: self.plugins().subagent_authority().cloned(),
304            }) {
305            Ok(surface) => Arc::new(surface),
306            Err(err) => {
307                tracing::warn!("failed to resolve tool surface: {err}");
308                let provider = self.tools();
309                let mut fallback_tools = provider.tool_manifests();
310                if self.include_base_tools && mode == self.plugins().execution_mode() {
311                    let native_tools = self.plugins().mode_native_tool_manifests();
312                    fallback_tools.extend(native_tools);
313                }
314                Arc::new(crate::build_tool_surface(crate::ToolSurfaceBuildInput {
315                    tools: fallback_tools,
316                    mode: mode.clone(),
317                    resolve_contract: Some(resolve_contract),
318                    contributions: Vec::new(),
319                }))
320            }
321        };
322        let input = crate::ModeBuildInput {
323            mode: mode.clone(),
324            tool_surface: Arc::clone(&surface),
325            extra_prompt_contributions: self.mode_extra_prompt_contributions(&mode),
326        };
327        let driver = self.plugins().mode_protocol_driver().unwrap_or_else(|| {
328            panic!(
329                "no protocol driver registered for execution mode `{}` — \
330                 did you forget to register the mode plugin (e.g. \
331                 `lash_mode_standard::BuiltinStandardModePluginFactory` or \
332                 `lash_mode_rlm::BuiltinRlmModePluginFactory`)?",
333                mode.plugin_id()
334            )
335        });
336        assert_eq!(
337            driver.mode_id(),
338            mode.plugin_id(),
339            "protocol driver `{}` does not match session mode `{}`",
340            driver.mode_id(),
341            mode.plugin_id(),
342        );
343        let preamble = driver.build_preamble(input);
344        ToolSurfaceHandle(Arc::new(ToolSurfaceArtifact {
345            surface,
346            preamble: Arc::new(preamble),
347            derived: ToolSurfaceDerived::default(),
348        }))
349    }
350
351    fn tool_surface_cache_entry(
352        &self,
353        session_id: &str,
354        mode: crate::ExecutionMode,
355    ) -> ToolSurfaceHandle {
356        let key = self.tool_surface_cache_key(&mode);
357        let mut cache = self
358            .tool_surface_cache
359            .lock()
360            .expect("tool surface cache lock");
361        if let Some((_, entry)) = cache.iter().find(|(entry_key, _)| *entry_key == key) {
362            return entry.clone();
363        }
364        let entry = self.build_tool_surface_entry(session_id, mode);
365        cache.push((key, entry.clone()));
366        entry
367    }
368
369    pub fn tool_surface(
370        &self,
371        session_id: &str,
372        mode: crate::ExecutionMode,
373    ) -> Arc<crate::ToolSurface> {
374        self.tool_surface_cache_entry(session_id, mode).surface()
375    }
376
377    pub(crate) fn mode_preamble(
378        &self,
379        session_id: &str,
380        mode: crate::ExecutionMode,
381    ) -> Arc<crate::ModePreamble> {
382        self.tool_surface_cache_entry(session_id, mode).preamble()
383    }
384
385    pub(crate) fn shared_tool_catalog(
386        &self,
387        session_id: &str,
388        mode: crate::ExecutionMode,
389    ) -> Arc<Vec<serde_json::Value>> {
390        self.tool_surface_cache_entry(session_id, mode).catalog()
391    }
392
393    pub fn tool_catalog(
394        &self,
395        session_id: &str,
396        mode: crate::ExecutionMode,
397    ) -> Vec<serde_json::Value> {
398        self.shared_tool_catalog(session_id, mode).as_ref().clone()
399    }
400
401    #[allow(
402        clippy::too_many_arguments,
403        reason = "mode execution bridge carries explicit per-turn runtime dependencies"
404    )]
405    pub(crate) fn mode_execution_context(
406        &self,
407        session_id: &str,
408        host: Arc<dyn crate::plugin::ToolHookHost>,
409        event_tx: tokio::sync::mpsc::Sender<SessionEvent>,
410        chronological_projection: Arc<crate::ChronologicalProjection>,
411        mode_extension: Option<crate::ModeTurnExtensionHandle>,
412        turn_context: crate::TurnContext,
413    ) -> ModeExecutionContext {
414        let dispatch = Arc::new(ToolDispatchContext {
415            plugins: Arc::clone(self.plugins()),
416            tools: self.tools(),
417            surface: self.tool_surface(session_id, self.execution_mode.clone()),
418            host,
419            session_id: session_id.to_string(),
420            event_tx,
421            turn_injection_bridge: self.turn_injection_bridge().clone(),
422            attachment_store: Arc::clone(&self.services.attachment_store),
423            turn_context: turn_context.clone(),
424        });
425        ModeExecutionContext::new(
426            session_id.to_string(),
427            self.execution_mode.clone(),
428            dispatch,
429            Arc::clone(&self.async_tool_handles),
430            self.message_tx.clone(),
431            Arc::clone(&self.services.attachment_store),
432            chronological_projection,
433            mode_extension,
434            turn_context,
435        )
436    }
437
438    pub fn turn_injection_bridge(&self) -> &TurnInjectionBridge {
439        &self.services.turn_injection_bridge
440    }
441
442    pub fn turn_input_injection_bridge(&self) -> &TurnInputInjectionBridge {
443        &self.services.turn_input_injection_bridge
444    }
445
446    /// Set the message sender for streaming messages during execution.
447    pub fn set_message_sender(&mut self, tx: UnboundedSender<SandboxMessage>) {
448        self.message_tx = Some(tx);
449    }
450
451    /// Clear the message sender (drops the sender, causing receivers to terminate).
452    pub fn clear_message_sender(&mut self) {
453        self.message_tx = None;
454    }
455
456    pub async fn reset(&mut self) -> Result<(), SessionError> {
457        self.async_tool_handles
458            .lock()
459            .expect("async tool handle map lock")
460            .clear();
461        self.tool_surface_cache
462            .lock()
463            .expect("tool surface cache lock")
464            .clear();
465        Ok(())
466    }
467
468    pub async fn refresh_tool_surface(&mut self) -> Result<(), SessionError> {
469        self.tool_surface_cache
470            .lock()
471            .expect("tool surface cache lock")
472            .clear();
473        Ok(())
474    }
475}