// lash_core/plugin/session_obj.rs

1use std::collections::{BTreeMap, BTreeSet};
2use std::sync::Arc;
3
4use futures_util::stream::{FuturesUnordered, StreamExt};
5use sha2::{Digest, Sha256};
6
7use super::*;
8
9mod directives;
10mod tools;
11
12async fn collect_owned_async<C, O, H, F>(
13    hooks: &[RegisteredHook<H>],
14    ctx: C,
15    hook_kind: &'static str,
16    phase_probe: Option<&Arc<dyn crate::runtime::RuntimeTurnPhaseProbe>>,
17    invoke: F,
18) -> Result<Vec<PluginOwned<O>>, PluginError>
19where
20    C: Clone,
21    F: Fn(&H, C) -> PluginFuture<Vec<O>>,
22{
23    let mut out = Vec::new();
24    for registered in hooks {
25        let phase_name = plugin_hook_phase_name(hook_kind, &registered.plugin_id);
26        if let Some(probe) = phase_probe {
27            probe.begin_named(&phase_name);
28        }
29        let result = invoke(&registered.hook, ctx.clone()).await;
30        if let Some(probe) = phase_probe {
31            probe.end_named(&phase_name);
32        }
33        for value in result? {
34            out.push(PluginOwned {
35                plugin_id: registered.plugin_id.clone(),
36                value,
37            });
38        }
39    }
40    Ok(out)
41}
42
/// Build the phase-probe label for one plugin's hook invocation, in the form
/// `plugin_hook.<hook_kind>.<plugin_id>`.
fn plugin_hook_phase_name(hook_kind: &str, plugin_id: &str) -> String {
    ["plugin_hook", hook_kind, plugin_id].join(".")
}

47fn lifecycle_event_hook_kind(event: &PluginLifecycleEvent<'_>) -> &'static str {
48    match event {
49        PluginLifecycleEvent::TurnFinalized(_) => "turn_finalized",
50        PluginLifecycleEvent::TurnPersisted(_) => "turn_persisted",
51        PluginLifecycleEvent::SessionRestored(_) => "session_restored",
52        PluginLifecycleEvent::SessionConfigChanged(_) => "session_config_changed",
53    }
54}
55
56fn collect_owned_sync<C, O, H, F>(
57    hooks: &[RegisteredHook<H>],
58    ctx: C,
59    invoke: F,
60) -> Result<Vec<PluginOwned<O>>, PluginError>
61where
62    C: Clone,
63    F: Fn(&H, C) -> Result<O, PluginError>,
64{
65    let mut out = Vec::new();
66    for registered in hooks {
67        out.push(PluginOwned {
68            plugin_id: registered.plugin_id.clone(),
69            value: invoke(&registered.hook, ctx.clone())?,
70        });
71    }
72    Ok(out)
73}
74
75struct EmptySnapshotReader;
76
77impl SnapshotReader for EmptySnapshotReader {
78    fn read_blob(&self, _name: &str) -> Option<&[u8]> {
79        None
80    }
81}
82
83pub struct PluginSession {
84    pub(super) host: PluginHost,
85    pub(super) session_id: String,
86    pub(super) plugins: Vec<Arc<dyn SessionPlugin>>,
87    pub(super) tools: Arc<dyn ToolProvider>,
88    pub(super) tool_registry: Arc<crate::ToolRegistry>,
89    pub(super) tool_surface_overlay: ToolSurfaceContribution,
90    pub(super) tool_access: SessionToolAccess,
91    pub(super) subagent: Option<SubagentSessionContext>,
92    pub(super) lashlang_abilities: lashlang::LashlangAbilities,
93    pub(super) lashlang_language_features: lashlang::LashlangLanguageFeatures,
94    pub(super) lashlang_resources: lashlang::ResourceCatalog,
95    pub(super) host_events: crate::HostEventCatalog,
96    pub(super) trigger_registry: Arc<SessionTriggerRegistry>,
97    pub(super) contributions: PluginContributions,
98}
99impl PluginSession {
100    pub fn session_id(&self) -> &str {
101        &self.session_id
102    }
103
104    pub fn tool_access(&self) -> &SessionToolAccess {
105        &self.tool_access
106    }
107
108    pub fn subagent_context(&self) -> Option<&SubagentSessionContext> {
109        self.subagent.as_ref()
110    }
111
112    pub fn lashlang_abilities(&self) -> lashlang::LashlangAbilities {
113        self.lashlang_abilities
114    }
115
116    pub fn lashlang_language_features(&self) -> lashlang::LashlangLanguageFeatures {
117        self.lashlang_language_features
118    }
119
120    pub fn lashlang_resources(&self) -> lashlang::ResourceCatalog {
121        self.lashlang_resources.clone()
122    }
123
124    pub fn host_events(&self) -> &crate::HostEventCatalog {
125        &self.host_events
126    }
127
128    pub async fn register_lashlang_trigger(
129        &self,
130        request: serde_json::Value,
131        artifact_store: Arc<dyn lashlang::LashlangArtifactStore>,
132    ) -> Result<serde_json::Value, PluginError> {
133        let route = self
134            .trigger_registry
135            .register_route(request, &self.lashlang_resources, artifact_store.as_ref())
136            .await?;
137        Ok(super::trigger_registry::trigger_handle_json(&route.handle))
138    }
139
140    pub fn list_lashlang_triggers(
141        &self,
142        request: serde_json::Value,
143    ) -> Result<serde_json::Value, PluginError> {
144        serde_json::to_value(self.trigger_registry.list(request)?).map_err(|err| {
145            PluginError::Session(format!("failed to encode trigger registrations: {err}"))
146        })
147    }
148
149    pub fn list_all_lashlang_triggers(&self) -> Result<Vec<TriggerRegistration>, PluginError> {
150        self.trigger_registry.list_all()
151    }
152
153    pub fn lashlang_trigger_registrations_by_source_type(
154        &self,
155        source_type: TriggerSourceType,
156    ) -> Result<Vec<TriggerRegistration>, PluginError> {
157        self.trigger_registry.routes_by_source_type(&source_type)
158    }
159
160    pub fn cancel_lashlang_trigger(
161        &self,
162        request: serde_json::Value,
163    ) -> Result<serde_json::Value, PluginError> {
164        let changed = self.trigger_registry.cancel(request)?;
165        Ok(serde_json::json!(changed))
166    }
167
168    pub(crate) fn trigger_activation_service<'a>(
169        &'a self,
170        processes: Arc<dyn crate::ProcessService>,
171        scoped_effect_controller: crate::ScopedEffectController<'a>,
172    ) -> crate::TriggerActivationService<'a> {
173        crate::TriggerActivationService::new(
174            self.session_id.clone(),
175            Arc::clone(&self.trigger_registry),
176            processes,
177            scoped_effect_controller,
178        )
179    }
180
181    pub fn host(&self) -> &PluginHost {
182        &self.host
183    }
184
185    pub fn tools(&self) -> Arc<dyn ToolProvider> {
186        Arc::clone(&self.tools)
187    }
188
189    pub fn tool_registry(&self) -> Arc<crate::ToolRegistry> {
190        Arc::clone(&self.tool_registry)
191    }
192
193    pub(crate) fn protocol_session(&self) -> &Arc<dyn ProtocolSessionPlugin> {
194        &self
195            .contributions
196            .protocol_session
197            .as_ref()
198            .expect("plugin session must have a protocol session")
199            .hook
200    }
201
202    pub(crate) fn code_executor(&self) -> Option<Arc<dyn CodeExecutorPlugin>> {
203        self.contributions
204            .code_executor
205            .as_ref()
206            .map(|entry| Arc::clone(&entry.hook))
207    }
208
209    pub(crate) fn assistant_prose_projector(
210        &self,
211    ) -> Option<Arc<dyn AssistantProseProjectorPlugin>> {
212        self.contributions
213            .assistant_prose_projector
214            .as_ref()
215            .map(|entry| Arc::clone(&entry.hook))
216    }
217
218    pub fn protocol_driver(&self) -> Arc<dyn ProtocolDriverPlugin> {
219        self.contributions
220            .protocol_driver
221            .as_ref()
222            .map(|entry| Arc::clone(&entry.hook))
223            .expect("plugin session must have a protocol driver")
224    }
225
226    pub fn plugin_actions(&self) -> Vec<PluginActionDef> {
227        self.contributions
228            .plugin_actions
229            .values()
230            .map(|op| op.def.clone())
231            .collect()
232    }
233
234    pub fn has_assistant_stream_hooks(&self) -> bool {
235        !self.contributions.assistant_stream_hooks.is_empty()
236    }
237
238    /// Chain registered turn-context transforms, piping each one's output
239    /// into the next in priority order.
240    pub async fn prepare_turn_context(
241        &self,
242        ctx: &TurnTransformContext<'_>,
243        input: crate::session_model::context::PreparedContext,
244        phase_probe: Option<Arc<dyn crate::runtime::RuntimeTurnPhaseProbe>>,
245    ) -> Result<crate::session_model::context::PreparedContext, HistoryError> {
246        let mut current = input;
247        for (_, registered) in &self.contributions.turn_context_transforms {
248            let phase_name =
249                plugin_hook_phase_name("context_transform", registered.plugin_id.as_str());
250            if let Some(probe) = phase_probe.as_ref() {
251                probe.begin_named(&phase_name);
252            }
253            let result = registered.hook.transform(ctx, current).await;
254            if let Some(probe) = phase_probe.as_ref() {
255                probe.end_named(&phase_name);
256            }
257            current = result?;
258        }
259        Ok(current)
260    }
261
262    /// Chain registered history rewriters, skipping any that opt out of
263    /// the current trigger via `accepts()`.
264    pub async fn rewrite_history(
265        &self,
266        ctx: &RewriteContext<'_>,
267        input: HistoryState,
268    ) -> Result<HistoryState, HistoryError> {
269        let mut current = input;
270        for (_, registered) in &self.contributions.history_rewriters {
271            if !registered.hook.accepts(&ctx.trigger) {
272                continue;
273            }
274            current = registered.hook.rewrite(ctx, current).await?;
275        }
276        Ok(current)
277    }
278
279    pub async fn collect_prompt_contributions(
280        &self,
281        ctx: PromptHookContext,
282    ) -> Result<Vec<PromptContribution>, PluginError> {
283        let mut out = collect_owned_async(
284            &self.contributions.prompt_contributors,
285            ctx,
286            "prompt_contributor",
287            None,
288            |hook, ctx| hook(ctx),
289        )
290        .await?
291        .into_iter()
292        .map(|owned| owned.value)
293        .collect::<Vec<_>>();
294        let mut seen = BTreeSet::new();
295        out.retain(|contribution| {
296            seen.insert((
297                format!("{:?}", contribution.slot),
298                contribution.priority,
299                contribution.content.trim().to_string(),
300            ))
301        });
302        out.sort_by(|a, b| {
303            format!("{:?}", a.slot)
304                .cmp(&format!("{:?}", b.slot))
305                .then(a.priority.cmp(&b.priority))
306        });
307        Ok(out)
308    }
309
310    pub async fn before_turn(
311        &self,
312        ctx: TurnHookContext,
313    ) -> Result<Vec<PluginOwned<PluginDirective>>, PluginError> {
314        self.before_turn_with_phase_probe(ctx, None).await
315    }
316
317    async fn before_turn_with_phase_probe(
318        &self,
319        ctx: TurnHookContext,
320        phase_probe: Option<&Arc<dyn crate::runtime::RuntimeTurnPhaseProbe>>,
321    ) -> Result<Vec<PluginOwned<PluginDirective>>, PluginError> {
322        collect_owned_async(
323            &self.contributions.before_turn_hooks,
324            ctx,
325            "before_turn",
326            phase_probe,
327            |hook, ctx| hook(ctx),
328        )
329        .await
330    }
331
332    pub async fn before_tool_call(
333        &self,
334        ctx: ToolCallHookContext,
335    ) -> Result<Vec<PluginOwned<PluginDirective>>, PluginError> {
336        collect_owned_async(
337            &self.contributions.before_tool_call_hooks,
338            ctx,
339            "before_tool_call",
340            None,
341            |hook, ctx| hook(ctx),
342        )
343        .await
344    }
345
346    pub async fn after_tool_call(
347        &self,
348        ctx: ToolResultHookContext,
349    ) -> Result<Vec<PluginOwned<PluginDirective>>, PluginError> {
350        collect_owned_async(
351            &self.contributions.after_tool_call_hooks,
352            ctx,
353            "after_tool_call",
354            None,
355            |hook, ctx| hook(ctx),
356        )
357        .await
358    }
359
360    pub async fn after_turn(
361        &self,
362        ctx: TurnResultHookContext,
363    ) -> Result<Vec<PluginOwned<PluginDirective>>, PluginError> {
364        self.after_turn_with_phase_probe(ctx, None).await
365    }
366
367    async fn after_turn_with_phase_probe(
368        &self,
369        ctx: TurnResultHookContext,
370        phase_probe: Option<&Arc<dyn crate::runtime::RuntimeTurnPhaseProbe>>,
371    ) -> Result<Vec<PluginOwned<PluginDirective>>, PluginError> {
372        collect_owned_async(
373            &self.contributions.after_turn_hooks,
374            ctx,
375            "after_turn",
376            phase_probe,
377            |hook, ctx| hook(ctx),
378        )
379        .await
380    }
381
382    pub async fn at_checkpoint(
383        &self,
384        ctx: CheckpointHookContext,
385    ) -> Result<Vec<PluginOwned<PluginDirective>>, PluginError> {
386        collect_owned_async(
387            &self.contributions.checkpoint_hooks,
388            ctx,
389            "checkpoint",
390            None,
391            |hook, ctx| hook(ctx),
392        )
393        .await
394    }
395
396    pub async fn transform_assistant_stream(
397        &self,
398        session_id: &str,
399        chunk: String,
400    ) -> Result<Vec<PluginOwned<AssistantStreamTransform>>, PluginError> {
401        let mut current = chunk;
402        let mut transforms = Vec::new();
403        for registered in &self.contributions.assistant_stream_hooks {
404            let transform = (registered.hook)(AssistantStreamHookContext {
405                session_id: session_id.to_string(),
406                chunk: current.clone(),
407            })
408            .await?;
409            current = transform.chunk.clone();
410            transforms.push(PluginOwned {
411                plugin_id: registered.plugin_id.clone(),
412                value: transform,
413            });
414        }
415        Ok(transforms)
416    }
417
418    pub async fn transform_assistant_response(
419        &self,
420        session_id: &str,
421        response: crate::llm::types::LlmResponse,
422    ) -> Result<Vec<PluginOwned<AssistantResponseTransform>>, PluginError> {
423        let mut current = response;
424        let mut transforms = Vec::new();
425        for registered in &self.contributions.assistant_response_hooks {
426            let transform = (registered.hook)(AssistantResponseHookContext {
427                session_id: session_id.to_string(),
428                response: current.clone(),
429            })
430            .await?;
431            current = transform.response.clone();
432            transforms.push(PluginOwned {
433                plugin_id: registered.plugin_id.clone(),
434                value: transform,
435            });
436        }
437        Ok(transforms)
438    }
439
440    pub async fn project_tool_result(
441        &self,
442        ctx: ToolResultProjectionContext,
443    ) -> Result<crate::ModelToolReturn, PluginError> {
444        let Some(projector) = &self.contributions.tool_result_projector else {
445            return Ok(crate::ModelToolReturn::from_output(
446                ctx.call_id.clone(),
447                ctx.tool_name.clone(),
448                &ctx.output,
449            ));
450        };
451        (projector.hook)(ctx).await
452    }
453
454    pub async fn emit_runtime_event(&self, event: PluginLifecycleEvent<'_>) {
455        self.emit_runtime_event_with_phase_probe(event, None).await;
456    }
457
458    pub async fn emit_runtime_event_with_phase_probe(
459        &self,
460        event: PluginLifecycleEvent<'_>,
461        phase_probe: Option<Arc<dyn crate::runtime::RuntimeTurnPhaseProbe>>,
462    ) {
463        let hook_kind = lifecycle_event_hook_kind(&event);
464        let mut pending = FuturesUnordered::new();
465        for registered in &self.contributions.runtime_event_hooks {
466            let hook = Arc::clone(&registered.hook);
467            let plugin_id = registered.plugin_id.clone();
468            let phase_name = plugin_hook_phase_name(hook_kind, registered.plugin_id.as_str());
469            let event = event.clone();
470            let phase_probe = phase_probe.clone();
471            pending.push(async move {
472                if let Some(probe) = phase_probe.as_ref() {
473                    probe.begin_named(&phase_name);
474                }
475                let result = hook(event).await;
476                if let Some(probe) = phase_probe.as_ref() {
477                    probe.end_named(&phase_name);
478                }
479                (plugin_id, result)
480            });
481        }
482        while let Some((plugin_id, result)) = pending.next().await {
483            if let Err(err) = result {
484                tracing::warn!(plugin_id, "plugin runtime event hook failed: {err}");
485            }
486        }
487    }
488
489    pub fn has_runtime_event_hooks(&self) -> bool {
490        !self.contributions.runtime_event_hooks.is_empty()
491    }
492
493    pub async fn mutate_session_config(
494        &self,
495        ctx: SessionConfigChangedContext,
496        mut policy: SessionPolicy,
497    ) -> SessionPolicy {
498        for hook in &self.contributions.session_config_mutators {
499            match hook(ctx.clone(), policy.clone()).await {
500                Ok(next_policy) => policy = next_policy,
501                Err(err) => tracing::warn!("plugin config mutator failed: {err}"),
502            }
503        }
504        policy
505    }
506
507    pub fn snapshot(&self) -> Result<PluginSessionSnapshot, PluginError> {
508        let mut plugins = BTreeMap::new();
509        for plugin in &self.plugins {
510            let mut writer = InMemorySnapshotWriter::default();
511            let meta = plugin.snapshot(&mut writer)?;
512            plugins.insert(
513                plugin.id().to_string(),
514                PluginSnapshotEntry {
515                    meta,
516                    artifacts: writer.finish(),
517                },
518            );
519        }
520        Ok(PluginSessionSnapshot { plugins })
521    }
522
523    pub fn snapshot_is_current(&self, previous: Option<&PluginSessionSnapshot>) -> bool {
524        let Some(previous) = previous else {
525            return false;
526        };
527        if previous.plugins.len() != self.plugins.len() {
528            return false;
529        }
530        for plugin in &self.plugins {
531            let Some(entry) = previous.plugins.get(plugin.id()) else {
532                return false;
533            };
534            if entry.meta.plugin_version != plugin.version()
535                || entry.meta.revision != plugin.snapshot_revision()
536            {
537                return false;
538            }
539        }
540        true
541    }
542
543    pub fn snapshot_revision_fingerprint(&self) -> u64 {
544        let mut hasher = Sha256::new();
545        for plugin in &self.plugins {
546            hasher.update(plugin.id().as_bytes());
547            hasher.update([0]);
548            hasher.update(plugin.version().as_bytes());
549            hasher.update([0]);
550            hasher.update(plugin.snapshot_revision().to_le_bytes());
551            hasher.update([0xff]);
552        }
553        let digest = hasher.finalize();
554        u64::from_le_bytes(digest[..8].try_into().expect("digest prefix"))
555    }
556
557    pub fn restore(&self, snapshot: &PluginSessionSnapshot) -> Result<(), PluginError> {
558        for plugin in &self.plugins {
559            if let Some(entry) = snapshot.plugins.get(plugin.id()) {
560                let reader = InMemorySnapshotReader { entry };
561                plugin.restore(&entry.meta, &reader)?;
562            } else {
563                plugin.restore(
564                    &PluginSnapshotMeta {
565                        plugin_id: plugin.id().to_string(),
566                        plugin_version: plugin.version().to_string(),
567                        revision: plugin.snapshot_revision(),
568                        state: None,
569                    },
570                    &EmptySnapshotReader,
571                )?;
572            }
573        }
574        Ok(())
575    }
576
577    pub fn fork_for_session(
578        &self,
579        session_id: impl Into<String>,
580    ) -> Result<Arc<PluginSession>, PluginError> {
581        let snapshot = self.snapshot()?;
582        self.host.build_session_with_surface(
583            session_id,
584            Some(&snapshot),
585            self.tool_surface_overlay.clone(),
586            Some(self.tool_registry.export_state()),
587        )
588    }
589
590    pub fn fork_for_child_session(
591        &self,
592        session_id: impl Into<String>,
593        parent_session_id: Option<String>,
594        authority: super::SessionAuthorityContext,
595    ) -> Result<Arc<PluginSession>, PluginError> {
596        let snapshot = self.snapshot()?;
597        self.host.build_session_with_parent_and_surface(
598            session_id,
599            parent_session_id,
600            Some(&snapshot),
601            self.tool_surface_overlay.clone(),
602            Some(self.tool_registry.export_state()),
603            authority,
604        )
605    }
606
607    pub fn fork_for_session_with_tool_surface(
608        &self,
609        session_id: impl Into<String>,
610        tool_surface_overlay: ToolSurfaceContribution,
611    ) -> Result<Arc<PluginSession>, PluginError> {
612        let snapshot = self.snapshot()?;
613        self.host.build_session_with_surface(
614            session_id,
615            Some(&snapshot),
616            tool_surface_overlay,
617            Some(self.tool_registry.export_state()),
618        )
619    }
620
621    #[expect(
622        clippy::too_many_arguments,
623        reason = "plugin action invocation carries the explicit host services exposed to actions"
624    )]
625    pub async fn invoke_plugin_action(
626        &self,
627        name: &str,
628        args: serde_json::Value,
629        session_id: Option<String>,
630        default_to_current_session: bool,
631        sessions: Arc<dyn SessionStateService>,
632        session_lifecycle: Arc<dyn SessionLifecycleService>,
633        session_graph: Arc<dyn SessionGraphService>,
634        processes: Arc<dyn crate::ProcessService>,
635    ) -> Result<ToolResult, PluginActionInvokeError> {
636        let Some(op) = self.contributions.plugin_actions.get(name).cloned() else {
637            return Err(PluginActionInvokeError::Unknown(name.to_string()));
638        };
639
640        let effective_session = session_id.or_else(|| {
641            if default_to_current_session && !self.session_id.is_empty() {
642                Some(self.session_id.clone())
643            } else {
644                None
645            }
646        });
647
648        match (op.def.session_param, effective_session.as_ref()) {
649            (SessionParam::Required, None) => {
650                return Err(PluginActionInvokeError::MissingSession(name.to_string()));
651            }
652            (SessionParam::Forbidden, Some(_)) => {
653                return Err(PluginActionInvokeError::UnexpectedSession(name.to_string()));
654            }
655            _ => {}
656        }
657
658        Ok((op.handler)(
659            PluginActionContext {
660                session_id: effective_session,
661                sessions,
662                session_lifecycle,
663                session_graph,
664                processes,
665            },
666            args,
667        )
668        .await)
669    }
670
671    #[expect(
672        clippy::too_many_arguments,
673        reason = "typed action invocation mirrors the raw plugin action host service boundary"
674    )]
675    pub async fn call_plugin_action<Op: PluginAction>(
676        &self,
677        args: Op::Args,
678        session_id: Option<String>,
679        default_to_current_session: bool,
680        sessions: Arc<dyn SessionStateService>,
681        session_lifecycle: Arc<dyn SessionLifecycleService>,
682        session_graph: Arc<dyn SessionGraphService>,
683        processes: Arc<dyn crate::ProcessService>,
684    ) -> Result<Op::Output, PluginError> {
685        let args = serde_json::to_value(args)
686            .map_err(|err| PluginError::Invoke(format!("invalid {} args: {err}", Op::NAME)))?;
687        let result = self
688            .invoke_plugin_action(
689                Op::NAME,
690                args,
691                session_id,
692                default_to_current_session,
693                sessions,
694                session_lifecycle,
695                session_graph,
696                processes,
697            )
698            .await
699            .map_err(|err| PluginError::Invoke(err.to_string()))?;
700        if !result.is_success() {
701            return Err(PluginError::Invoke(format!(
702                "{} failed: {}",
703                Op::NAME,
704                result.value_for_projection()
705            )));
706        }
707        serde_json::from_value(result.into_output().value_for_projection())
708            .map_err(|err| PluginError::Invoke(format!("invalid {} output: {err}", Op::NAME)))
709    }
710}