Skip to main content

lash_core/plugin/
session_obj.rs

1use std::collections::{BTreeMap, BTreeSet};
2use std::sync::Arc;
3
4use futures_util::stream::{FuturesUnordered, StreamExt};
5use sha2::{Digest, Sha256};
6
7use super::*;
8
9mod directives;
10mod tools;
11
12async fn collect_owned_async<C, O, H, F>(
13    hooks: &[RegisteredHook<H>],
14    ctx: C,
15    hook_kind: &'static str,
16    phase_probe: Option<&Arc<dyn crate::runtime::RuntimeTurnPhaseProbe>>,
17    invoke: F,
18) -> Result<Vec<PluginOwned<O>>, PluginError>
19where
20    C: Clone,
21    F: Fn(&H, C) -> PluginFuture<Vec<O>>,
22{
23    let mut out = Vec::new();
24    for registered in hooks {
25        let phase_name = plugin_hook_phase_name(hook_kind, &registered.plugin_id);
26        if let Some(probe) = phase_probe {
27            probe.begin_named(&phase_name);
28        }
29        let result = invoke(&registered.hook, ctx.clone()).await;
30        if let Some(probe) = phase_probe {
31            probe.end_named(&phase_name);
32        }
33        for value in result? {
34            out.push(PluginOwned {
35                plugin_id: registered.plugin_id.clone(),
36                value,
37            });
38        }
39    }
40    Ok(out)
41}
42
/// Builds the phase-probe name for one plugin hook invocation, in the form
/// `plugin_hook.<hook_kind>.<plugin_id>`.
fn plugin_hook_phase_name(hook_kind: &str, plugin_id: &str) -> String {
    ["plugin_hook", hook_kind, plugin_id].join(".")
}
46
47fn lifecycle_event_hook_kind(event: &PluginLifecycleEvent<'_>) -> &'static str {
48    match event {
49        PluginLifecycleEvent::TurnFinalized(_) => "turn_finalized",
50        PluginLifecycleEvent::TurnPersisted(_) => "turn_persisted",
51        PluginLifecycleEvent::SessionRestored(_) => "session_restored",
52        PluginLifecycleEvent::SessionConfigChanged(_) => "session_config_changed",
53    }
54}
55
56fn collect_owned_sync<C, O, H, F>(
57    hooks: &[RegisteredHook<H>],
58    ctx: C,
59    invoke: F,
60) -> Result<Vec<PluginOwned<O>>, PluginError>
61where
62    C: Clone,
63    F: Fn(&H, C) -> Result<O, PluginError>,
64{
65    let mut out = Vec::new();
66    for registered in hooks {
67        out.push(PluginOwned {
68            plugin_id: registered.plugin_id.clone(),
69            value: invoke(&registered.hook, ctx.clone())?,
70        });
71    }
72    Ok(out)
73}
74
75struct EmptySnapshotReader;
76
77impl SnapshotReader for EmptySnapshotReader {
78    fn read_blob(&self, _name: &str) -> Option<&[u8]> {
79        None
80    }
81}
82
83pub struct PluginSession {
84    pub(super) host: PluginHost,
85    pub(super) session_id: String,
86    pub(super) plugins: Vec<Arc<dyn SessionPlugin>>,
87    pub(super) tools: Arc<dyn ToolProvider>,
88    pub(super) tool_registry: Arc<crate::ToolRegistry>,
89    pub(super) tool_surface_overlay: ToolSurfaceContribution,
90    pub(super) tool_access: SessionToolAccess,
91    pub(super) subagent: Option<SubagentSessionContext>,
92    pub(super) lashlang_abilities: lashlang::LashlangAbilities,
93    pub(super) lashlang_language_features: lashlang::LashlangLanguageFeatures,
94    pub(super) lashlang_resources: lashlang::ResourceCatalog,
95    pub(super) host_events: crate::HostEventCatalog,
96    pub(super) contributions: PluginContributions,
97}
98impl PluginSession {
99    pub fn session_id(&self) -> &str {
100        &self.session_id
101    }
102
103    pub fn tool_access(&self) -> &SessionToolAccess {
104        &self.tool_access
105    }
106
107    pub fn subagent_context(&self) -> Option<&SubagentSessionContext> {
108        self.subagent.as_ref()
109    }
110
111    pub fn lashlang_abilities(&self) -> lashlang::LashlangAbilities {
112        self.lashlang_abilities
113    }
114
115    pub fn lashlang_language_features(&self) -> lashlang::LashlangLanguageFeatures {
116        self.lashlang_language_features
117    }
118
119    pub fn lashlang_resources(&self) -> lashlang::ResourceCatalog {
120        self.lashlang_resources.clone()
121    }
122
123    pub fn host_events(&self) -> &crate::HostEventCatalog {
124        &self.host_events
125    }
126
127    pub fn host(&self) -> &PluginHost {
128        &self.host
129    }
130
131    pub fn tools(&self) -> Arc<dyn ToolProvider> {
132        Arc::clone(&self.tools)
133    }
134
135    pub fn tool_registry(&self) -> Arc<crate::ToolRegistry> {
136        Arc::clone(&self.tool_registry)
137    }
138
139    pub(crate) fn protocol_session(&self) -> &Arc<dyn ProtocolSessionPlugin> {
140        &self
141            .contributions
142            .protocol_session
143            .as_ref()
144            .expect("plugin session must have a protocol session")
145            .hook
146    }
147
148    pub(crate) fn code_executor(&self) -> Option<Arc<dyn CodeExecutorPlugin>> {
149        self.contributions
150            .code_executor
151            .as_ref()
152            .map(|entry| Arc::clone(&entry.hook))
153    }
154
155    pub(crate) fn assistant_prose_projector(
156        &self,
157    ) -> Option<Arc<dyn AssistantProseProjectorPlugin>> {
158        self.contributions
159            .assistant_prose_projector
160            .as_ref()
161            .map(|entry| Arc::clone(&entry.hook))
162    }
163
164    pub fn protocol_driver(&self) -> Arc<dyn ProtocolDriverPlugin> {
165        self.contributions
166            .protocol_driver
167            .as_ref()
168            .map(|entry| Arc::clone(&entry.hook))
169            .expect("plugin session must have a protocol driver")
170    }
171
172    pub fn plugin_actions(&self) -> Vec<PluginActionDef> {
173        self.contributions
174            .plugin_actions
175            .values()
176            .map(|op| op.def.clone())
177            .collect()
178    }
179
180    pub fn has_assistant_stream_hooks(&self) -> bool {
181        !self.contributions.assistant_stream_hooks.is_empty()
182    }
183
184    /// Chain registered turn-context transforms, piping each one's output
185    /// into the next in priority order.
186    pub async fn prepare_turn_context(
187        &self,
188        ctx: &TurnTransformContext<'_>,
189        input: crate::session_model::context::PreparedContext,
190        phase_probe: Option<Arc<dyn crate::runtime::RuntimeTurnPhaseProbe>>,
191    ) -> Result<crate::session_model::context::PreparedContext, HistoryError> {
192        let mut current = input;
193        for (_, registered) in &self.contributions.turn_context_transforms {
194            let phase_name =
195                plugin_hook_phase_name("context_transform", registered.plugin_id.as_str());
196            if let Some(probe) = phase_probe.as_ref() {
197                probe.begin_named(&phase_name);
198            }
199            let result = registered.hook.transform(ctx, current).await;
200            if let Some(probe) = phase_probe.as_ref() {
201                probe.end_named(&phase_name);
202            }
203            current = result?;
204        }
205        Ok(current)
206    }
207
208    /// Chain registered history rewriters, skipping any that opt out of
209    /// the current trigger via `accepts()`.
210    pub async fn rewrite_history(
211        &self,
212        ctx: &RewriteContext<'_>,
213        input: HistoryState,
214    ) -> Result<HistoryState, HistoryError> {
215        let mut current = input;
216        for (_, registered) in &self.contributions.history_rewriters {
217            if !registered.hook.accepts(&ctx.trigger) {
218                continue;
219            }
220            current = registered.hook.rewrite(ctx, current).await?;
221        }
222        Ok(current)
223    }
224
225    pub async fn collect_prompt_contributions(
226        &self,
227        ctx: PromptHookContext,
228    ) -> Result<Vec<PromptContribution>, PluginError> {
229        let mut out = collect_owned_async(
230            &self.contributions.prompt_contributors,
231            ctx,
232            "prompt_contributor",
233            None,
234            |hook, ctx| hook(ctx),
235        )
236        .await?
237        .into_iter()
238        .map(|owned| owned.value)
239        .collect::<Vec<_>>();
240        let mut seen = BTreeSet::new();
241        out.retain(|contribution| {
242            seen.insert((
243                format!("{:?}", contribution.slot),
244                contribution.priority,
245                contribution.content.trim().to_string(),
246            ))
247        });
248        out.sort_by(|a, b| {
249            format!("{:?}", a.slot)
250                .cmp(&format!("{:?}", b.slot))
251                .then(a.priority.cmp(&b.priority))
252        });
253        Ok(out)
254    }
255
256    pub async fn before_turn(
257        &self,
258        ctx: TurnHookContext,
259    ) -> Result<Vec<PluginOwned<PluginDirective>>, PluginError> {
260        self.before_turn_with_phase_probe(ctx, None).await
261    }
262
263    async fn before_turn_with_phase_probe(
264        &self,
265        ctx: TurnHookContext,
266        phase_probe: Option<&Arc<dyn crate::runtime::RuntimeTurnPhaseProbe>>,
267    ) -> Result<Vec<PluginOwned<PluginDirective>>, PluginError> {
268        collect_owned_async(
269            &self.contributions.before_turn_hooks,
270            ctx,
271            "before_turn",
272            phase_probe,
273            |hook, ctx| hook(ctx),
274        )
275        .await
276    }
277
278    pub async fn before_tool_call(
279        &self,
280        ctx: ToolCallHookContext,
281    ) -> Result<Vec<PluginOwned<PluginDirective>>, PluginError> {
282        collect_owned_async(
283            &self.contributions.before_tool_call_hooks,
284            ctx,
285            "before_tool_call",
286            None,
287            |hook, ctx| hook(ctx),
288        )
289        .await
290    }
291
292    pub async fn after_tool_call(
293        &self,
294        ctx: ToolResultHookContext,
295    ) -> Result<Vec<PluginOwned<PluginDirective>>, PluginError> {
296        collect_owned_async(
297            &self.contributions.after_tool_call_hooks,
298            ctx,
299            "after_tool_call",
300            None,
301            |hook, ctx| hook(ctx),
302        )
303        .await
304    }
305
306    pub async fn after_turn(
307        &self,
308        ctx: TurnResultHookContext,
309    ) -> Result<Vec<PluginOwned<PluginDirective>>, PluginError> {
310        self.after_turn_with_phase_probe(ctx, None).await
311    }
312
313    async fn after_turn_with_phase_probe(
314        &self,
315        ctx: TurnResultHookContext,
316        phase_probe: Option<&Arc<dyn crate::runtime::RuntimeTurnPhaseProbe>>,
317    ) -> Result<Vec<PluginOwned<PluginDirective>>, PluginError> {
318        collect_owned_async(
319            &self.contributions.after_turn_hooks,
320            ctx,
321            "after_turn",
322            phase_probe,
323            |hook, ctx| hook(ctx),
324        )
325        .await
326    }
327
328    pub async fn at_checkpoint(
329        &self,
330        ctx: CheckpointHookContext,
331    ) -> Result<Vec<PluginOwned<PluginDirective>>, PluginError> {
332        collect_owned_async(
333            &self.contributions.checkpoint_hooks,
334            ctx,
335            "checkpoint",
336            None,
337            |hook, ctx| hook(ctx),
338        )
339        .await
340    }
341
342    pub async fn transform_assistant_stream(
343        &self,
344        session_id: &str,
345        chunk: String,
346    ) -> Result<Vec<PluginOwned<AssistantStreamTransform>>, PluginError> {
347        let mut current = chunk;
348        let mut transforms = Vec::new();
349        for registered in &self.contributions.assistant_stream_hooks {
350            let transform = (registered.hook)(AssistantStreamHookContext {
351                session_id: session_id.to_string(),
352                chunk: current.clone(),
353            })
354            .await?;
355            current = transform.chunk.clone();
356            transforms.push(PluginOwned {
357                plugin_id: registered.plugin_id.clone(),
358                value: transform,
359            });
360        }
361        Ok(transforms)
362    }
363
364    pub async fn transform_assistant_response(
365        &self,
366        session_id: &str,
367        response: crate::llm::types::LlmResponse,
368    ) -> Result<Vec<PluginOwned<AssistantResponseTransform>>, PluginError> {
369        let mut current = response;
370        let mut transforms = Vec::new();
371        for registered in &self.contributions.assistant_response_hooks {
372            let transform = (registered.hook)(AssistantResponseHookContext {
373                session_id: session_id.to_string(),
374                response: current.clone(),
375            })
376            .await?;
377            current = transform.response.clone();
378            transforms.push(PluginOwned {
379                plugin_id: registered.plugin_id.clone(),
380                value: transform,
381            });
382        }
383        Ok(transforms)
384    }
385
386    pub async fn project_tool_result(
387        &self,
388        ctx: ToolResultProjectionContext,
389    ) -> Result<crate::ModelToolReturn, PluginError> {
390        let Some(projector) = &self.contributions.tool_result_projector else {
391            return Ok(crate::ModelToolReturn::from_output(
392                ctx.call_id.clone(),
393                ctx.tool_name.clone(),
394                &ctx.output,
395            ));
396        };
397        (projector.hook)(ctx).await
398    }
399
400    pub async fn emit_runtime_event(&self, event: PluginLifecycleEvent<'_>) {
401        self.emit_runtime_event_with_phase_probe(event, None).await;
402    }
403
404    pub async fn emit_runtime_event_with_phase_probe(
405        &self,
406        event: PluginLifecycleEvent<'_>,
407        phase_probe: Option<Arc<dyn crate::runtime::RuntimeTurnPhaseProbe>>,
408    ) {
409        let hook_kind = lifecycle_event_hook_kind(&event);
410        let mut pending = FuturesUnordered::new();
411        for registered in &self.contributions.runtime_event_hooks {
412            let hook = Arc::clone(&registered.hook);
413            let plugin_id = registered.plugin_id.clone();
414            let phase_name = plugin_hook_phase_name(hook_kind, registered.plugin_id.as_str());
415            let event = event.clone();
416            let phase_probe = phase_probe.clone();
417            pending.push(async move {
418                if let Some(probe) = phase_probe.as_ref() {
419                    probe.begin_named(&phase_name);
420                }
421                let result = hook(event).await;
422                if let Some(probe) = phase_probe.as_ref() {
423                    probe.end_named(&phase_name);
424                }
425                (plugin_id, result)
426            });
427        }
428        while let Some((plugin_id, result)) = pending.next().await {
429            if let Err(err) = result {
430                tracing::warn!(plugin_id, "plugin runtime event hook failed: {err}");
431            }
432        }
433    }
434
435    pub fn has_runtime_event_hooks(&self) -> bool {
436        !self.contributions.runtime_event_hooks.is_empty()
437    }
438
439    pub async fn mutate_session_config(
440        &self,
441        ctx: SessionConfigChangedContext,
442        mut policy: SessionPolicy,
443    ) -> SessionPolicy {
444        for hook in &self.contributions.session_config_mutators {
445            match hook(ctx.clone(), policy.clone()).await {
446                Ok(next_policy) => policy = next_policy,
447                Err(err) => tracing::warn!("plugin config mutator failed: {err}"),
448            }
449        }
450        policy
451    }
452
453    pub fn snapshot(&self) -> Result<PluginSessionSnapshot, PluginError> {
454        let mut plugins = BTreeMap::new();
455        for plugin in &self.plugins {
456            let mut writer = InMemorySnapshotWriter::default();
457            let meta = plugin.snapshot(&mut writer)?;
458            plugins.insert(
459                plugin.id().to_string(),
460                PluginSnapshotEntry {
461                    meta,
462                    artifacts: writer.finish(),
463                },
464            );
465        }
466        Ok(PluginSessionSnapshot { plugins })
467    }
468
469    pub fn snapshot_is_current(&self, previous: Option<&PluginSessionSnapshot>) -> bool {
470        let Some(previous) = previous else {
471            return false;
472        };
473        if previous.plugins.len() != self.plugins.len() {
474            return false;
475        }
476        for plugin in &self.plugins {
477            let Some(entry) = previous.plugins.get(plugin.id()) else {
478                return false;
479            };
480            if entry.meta.plugin_version != plugin.version()
481                || entry.meta.revision != plugin.snapshot_revision()
482            {
483                return false;
484            }
485        }
486        true
487    }
488
489    pub fn snapshot_revision_fingerprint(&self) -> u64 {
490        let mut hasher = Sha256::new();
491        for plugin in &self.plugins {
492            hasher.update(plugin.id().as_bytes());
493            hasher.update([0]);
494            hasher.update(plugin.version().as_bytes());
495            hasher.update([0]);
496            hasher.update(plugin.snapshot_revision().to_le_bytes());
497            hasher.update([0xff]);
498        }
499        let digest = hasher.finalize();
500        u64::from_le_bytes(digest[..8].try_into().expect("digest prefix"))
501    }
502
503    pub fn restore(&self, snapshot: &PluginSessionSnapshot) -> Result<(), PluginError> {
504        for plugin in &self.plugins {
505            if let Some(entry) = snapshot.plugins.get(plugin.id()) {
506                let reader = InMemorySnapshotReader { entry };
507                plugin.restore(&entry.meta, &reader)?;
508            } else {
509                plugin.restore(
510                    &PluginSnapshotMeta {
511                        plugin_id: plugin.id().to_string(),
512                        plugin_version: plugin.version().to_string(),
513                        revision: plugin.snapshot_revision(),
514                        state: None,
515                    },
516                    &EmptySnapshotReader,
517                )?;
518            }
519        }
520        Ok(())
521    }
522
523    pub fn fork_for_session(
524        &self,
525        session_id: impl Into<String>,
526    ) -> Result<Arc<PluginSession>, PluginError> {
527        let snapshot = self.snapshot()?;
528        self.host.build_session_with_surface(
529            session_id,
530            Some(&snapshot),
531            self.tool_surface_overlay.clone(),
532            Some(self.tool_registry.export_state()),
533        )
534    }
535
536    pub fn fork_for_child_session(
537        &self,
538        session_id: impl Into<String>,
539        parent_session_id: Option<String>,
540        authority: super::SessionAuthorityContext,
541    ) -> Result<Arc<PluginSession>, PluginError> {
542        let snapshot = self.snapshot()?;
543        self.host.build_session_with_parent_and_surface(
544            session_id,
545            parent_session_id,
546            Some(&snapshot),
547            self.tool_surface_overlay.clone(),
548            Some(self.tool_registry.export_state()),
549            authority,
550        )
551    }
552
553    pub fn fork_for_session_with_tool_surface(
554        &self,
555        session_id: impl Into<String>,
556        tool_surface_overlay: ToolSurfaceContribution,
557    ) -> Result<Arc<PluginSession>, PluginError> {
558        let snapshot = self.snapshot()?;
559        self.host.build_session_with_surface(
560            session_id,
561            Some(&snapshot),
562            tool_surface_overlay,
563            Some(self.tool_registry.export_state()),
564        )
565    }
566
567    #[expect(
568        clippy::too_many_arguments,
569        reason = "plugin action invocation carries the explicit host services exposed to actions"
570    )]
571    pub async fn invoke_plugin_action(
572        &self,
573        name: &str,
574        args: serde_json::Value,
575        session_id: Option<String>,
576        default_to_current_session: bool,
577        sessions: Arc<dyn SessionStateService>,
578        session_lifecycle: Arc<dyn SessionLifecycleService>,
579        session_graph: Arc<dyn SessionGraphService>,
580        processes: Arc<dyn crate::ProcessService>,
581    ) -> Result<ToolResult, PluginActionInvokeError> {
582        let Some(op) = self.contributions.plugin_actions.get(name).cloned() else {
583            return Err(PluginActionInvokeError::Unknown(name.to_string()));
584        };
585
586        let effective_session = session_id.or_else(|| {
587            if default_to_current_session && !self.session_id.is_empty() {
588                Some(self.session_id.clone())
589            } else {
590                None
591            }
592        });
593
594        match (op.def.session_param, effective_session.as_ref()) {
595            (SessionParam::Required, None) => {
596                return Err(PluginActionInvokeError::MissingSession(name.to_string()));
597            }
598            (SessionParam::Forbidden, Some(_)) => {
599                return Err(PluginActionInvokeError::UnexpectedSession(name.to_string()));
600            }
601            _ => {}
602        }
603
604        Ok((op.handler)(
605            PluginActionContext {
606                session_id: effective_session,
607                sessions,
608                session_lifecycle,
609                session_graph,
610                processes,
611            },
612            args,
613        )
614        .await)
615    }
616
617    #[expect(
618        clippy::too_many_arguments,
619        reason = "typed action invocation mirrors the raw plugin action host service boundary"
620    )]
621    pub async fn call_plugin_action<Op: PluginAction>(
622        &self,
623        args: Op::Args,
624        session_id: Option<String>,
625        default_to_current_session: bool,
626        sessions: Arc<dyn SessionStateService>,
627        session_lifecycle: Arc<dyn SessionLifecycleService>,
628        session_graph: Arc<dyn SessionGraphService>,
629        processes: Arc<dyn crate::ProcessService>,
630    ) -> Result<Op::Output, PluginError> {
631        let args = serde_json::to_value(args)
632            .map_err(|err| PluginError::Invoke(format!("invalid {} args: {err}", Op::NAME)))?;
633        let result = self
634            .invoke_plugin_action(
635                Op::NAME,
636                args,
637                session_id,
638                default_to_current_session,
639                sessions,
640                session_lifecycle,
641                session_graph,
642                processes,
643            )
644            .await
645            .map_err(|err| PluginError::Invoke(err.to_string()))?;
646        if !result.is_success() {
647            return Err(PluginError::Invoke(format!(
648                "{} failed: {}",
649                Op::NAME,
650                result.value_for_projection()
651            )));
652        }
653        serde_json::from_value(result.into_output().value_for_projection())
654            .map_err(|err| PluginError::Invoke(format!("invalid {} output: {err}", Op::NAME)))
655    }
656}