Skip to main content

lash_core/plugin/
session_obj.rs

1use std::collections::{BTreeMap, BTreeSet};
2use std::sync::Arc;
3
4use futures_util::stream::{FuturesUnordered, StreamExt};
5use sha2::{Digest, Sha256};
6
7use super::*;
8
9mod directives;
10mod tools;
11
12async fn collect_owned_async<C, O, H, F>(
13    hooks: &[RegisteredHook<H>],
14    ctx: C,
15    hook_kind: &'static str,
16    phase_probe: Option<&Arc<dyn crate::runtime::RuntimeTurnPhaseProbe>>,
17    invoke: F,
18) -> Result<Vec<PluginOwned<O>>, PluginError>
19where
20    C: Clone,
21    F: Fn(&H, C) -> PluginFuture<Vec<O>>,
22{
23    let mut out = Vec::new();
24    for registered in hooks {
25        let phase_name = plugin_hook_phase_name(hook_kind, &registered.plugin_id);
26        if let Some(probe) = phase_probe {
27            probe.begin_named(&phase_name);
28        }
29        let result = invoke(&registered.hook, ctx.clone()).await;
30        if let Some(probe) = phase_probe {
31            probe.end_named(&phase_name);
32        }
33        for value in result? {
34            out.push(PluginOwned {
35                plugin_id: registered.plugin_id.clone(),
36                value,
37            });
38        }
39    }
40    Ok(out)
41}
42
/// Builds the probe phase label for one plugin's hook invocation, shaped as
/// `plugin_hook.<hook_kind>.<plugin_id>`.
fn plugin_hook_phase_name(hook_kind: &str, plugin_id: &str) -> String {
    ["plugin_hook", hook_kind, plugin_id].join(".")
}
46
47fn lifecycle_event_hook_kind(event: &PluginLifecycleEvent<'_>) -> &'static str {
48    match event {
49        PluginLifecycleEvent::TurnFinalized(_) => "turn_finalized",
50        PluginLifecycleEvent::TurnPersisted(_) => "turn_persisted",
51        PluginLifecycleEvent::SessionRestored(_) => "session_restored",
52        PluginLifecycleEvent::SessionConfigChanged(_) => "session_config_changed",
53    }
54}
55
56fn collect_owned_sync<C, O, H, F>(
57    hooks: &[RegisteredHook<H>],
58    ctx: C,
59    invoke: F,
60) -> Result<Vec<PluginOwned<O>>, PluginError>
61where
62    C: Clone,
63    F: Fn(&H, C) -> Result<O, PluginError>,
64{
65    let mut out = Vec::new();
66    for registered in hooks {
67        out.push(PluginOwned {
68            plugin_id: registered.plugin_id.clone(),
69            value: invoke(&registered.hook, ctx.clone())?,
70        });
71    }
72    Ok(out)
73}
74
75struct EmptySnapshotReader;
76
77impl SnapshotReader for EmptySnapshotReader {
78    fn read_blob(&self, _name: &str) -> Option<&[u8]> {
79        None
80    }
81}
82
83pub struct PluginSession {
84    pub(super) host: PluginHost,
85    pub(super) session_id: String,
86    pub(super) plugins: Vec<Arc<dyn SessionPlugin>>,
87    pub(super) tools: Arc<dyn ToolProvider>,
88    pub(super) tool_registry: Arc<crate::ToolRegistry>,
89    pub(super) tool_catalog_overlay: ToolCatalogContribution,
90    pub(super) tool_access: SessionToolAccess,
91    pub(super) subagent: Option<SubagentSessionContext>,
92    pub(super) extensions: PluginExtensions,
93    pub(super) triggers: crate::TriggerEventCatalog,
94    pub(super) contributions: PluginContributions,
95}
96impl PluginSession {
97    pub fn session_id(&self) -> &str {
98        &self.session_id
99    }
100
101    pub fn tool_access(&self) -> &SessionToolAccess {
102        &self.tool_access
103    }
104
105    pub fn subagent_context(&self) -> Option<&SubagentSessionContext> {
106        self.subagent.as_ref()
107    }
108
109    pub fn extensions(&self) -> &PluginExtensions {
110        &self.extensions
111    }
112
113    pub fn triggers(&self) -> &crate::TriggerEventCatalog {
114        &self.triggers
115    }
116
117    pub fn host(&self) -> &PluginHost {
118        &self.host
119    }
120
121    pub fn tools(&self) -> Arc<dyn ToolProvider> {
122        Arc::clone(&self.tools)
123    }
124
125    pub fn tool_registry(&self) -> Arc<crate::ToolRegistry> {
126        Arc::clone(&self.tool_registry)
127    }
128
129    pub(crate) fn protocol_session(&self) -> &Arc<dyn ProtocolSessionPlugin> {
130        &self
131            .contributions
132            .protocol_session
133            .as_ref()
134            .expect("plugin session must have a protocol session")
135            .hook
136    }
137
138    pub(crate) fn code_executor(&self) -> Option<Arc<dyn CodeExecutorPlugin>> {
139        self.contributions
140            .code_executor
141            .as_ref()
142            .map(|entry| Arc::clone(&entry.hook))
143    }
144
145    pub(crate) fn assistant_prose_projector(
146        &self,
147    ) -> Option<Arc<dyn AssistantProseProjectorPlugin>> {
148        self.contributions
149            .assistant_prose_projector
150            .as_ref()
151            .map(|entry| Arc::clone(&entry.hook))
152    }
153
154    pub fn protocol_driver(&self) -> Arc<dyn ProtocolDriverPlugin> {
155        self.contributions
156            .protocol_driver
157            .as_ref()
158            .map(|entry| Arc::clone(&entry.hook))
159            .expect("plugin session must have a protocol driver")
160    }
161
162    pub fn plugin_actions(&self) -> Vec<PluginActionDef> {
163        self.contributions
164            .plugin_actions
165            .values()
166            .map(|op| op.def.clone())
167            .collect()
168    }
169
170    pub fn has_assistant_stream_hooks(&self) -> bool {
171        !self.contributions.assistant_stream_hooks.is_empty()
172    }
173
174    /// Chain registered turn-context transforms, piping each one's output
175    /// into the next in priority order.
176    pub async fn prepare_turn_context(
177        &self,
178        ctx: &TurnTransformContext<'_>,
179        input: crate::session_model::context::PreparedContext,
180        phase_probe: Option<Arc<dyn crate::runtime::RuntimeTurnPhaseProbe>>,
181    ) -> Result<crate::session_model::context::PreparedContext, ContextError> {
182        let mut current = input;
183        for (_, registered) in &self.contributions.turn_context_transforms {
184            let phase_name =
185                plugin_hook_phase_name("context_transform", registered.plugin_id.as_str());
186            if let Some(probe) = phase_probe.as_ref() {
187                probe.begin_named(&phase_name);
188            }
189            let result = registered.hook.transform(ctx, current).await;
190            if let Some(probe) = phase_probe.as_ref() {
191                probe.end_named(&phase_name);
192            }
193            current = result?;
194        }
195        Ok(current)
196    }
197
198    /// Ask registered compactors for seed nodes for a new compaction frame.
199    pub async fn compact_context(
200        &self,
201        ctx: &CompactionContext<'_>,
202    ) -> Result<Option<ContextCompaction>, ContextError> {
203        for (_, registered) in &self.contributions.context_compactors {
204            if let Some(compaction) = registered.hook.compact(ctx).await?
205                && !compaction.is_empty()
206            {
207                return Ok(Some(compaction));
208            }
209        }
210        Ok(None)
211    }
212
213    pub async fn collect_prompt_contributions(
214        &self,
215        ctx: PromptHookContext,
216    ) -> Result<Vec<PromptContribution>, PluginError> {
217        let mut out = collect_owned_async(
218            &self.contributions.prompt_contributors,
219            ctx,
220            "prompt_contributor",
221            None,
222            |hook, ctx| hook(ctx),
223        )
224        .await?
225        .into_iter()
226        .map(|owned| owned.value)
227        .collect::<Vec<_>>();
228        let mut seen = BTreeSet::new();
229        out.retain(|contribution| {
230            seen.insert((
231                format!("{:?}", contribution.slot),
232                contribution.priority,
233                contribution.content.trim().to_string(),
234            ))
235        });
236        out.sort_by(|a, b| {
237            format!("{:?}", a.slot)
238                .cmp(&format!("{:?}", b.slot))
239                .then(a.priority.cmp(&b.priority))
240        });
241        Ok(out)
242    }
243
244    pub async fn before_turn(
245        &self,
246        ctx: TurnHookContext,
247    ) -> Result<Vec<PluginOwned<PluginDirective>>, PluginError> {
248        self.before_turn_with_phase_probe(ctx, None).await
249    }
250
251    async fn before_turn_with_phase_probe(
252        &self,
253        ctx: TurnHookContext,
254        phase_probe: Option<&Arc<dyn crate::runtime::RuntimeTurnPhaseProbe>>,
255    ) -> Result<Vec<PluginOwned<PluginDirective>>, PluginError> {
256        collect_owned_async(
257            &self.contributions.before_turn_hooks,
258            ctx,
259            "before_turn",
260            phase_probe,
261            |hook, ctx| hook(ctx),
262        )
263        .await
264    }
265
266    pub async fn before_tool_call(
267        &self,
268        ctx: ToolCallHookContext,
269    ) -> Result<Vec<PluginOwned<PluginDirective>>, PluginError> {
270        collect_owned_async(
271            &self.contributions.before_tool_call_hooks,
272            ctx,
273            "before_tool_call",
274            None,
275            |hook, ctx| hook(ctx),
276        )
277        .await
278    }
279
280    pub async fn after_tool_call(
281        &self,
282        ctx: ToolResultHookContext,
283    ) -> Result<Vec<PluginOwned<PluginDirective>>, PluginError> {
284        collect_owned_async(
285            &self.contributions.after_tool_call_hooks,
286            ctx,
287            "after_tool_call",
288            None,
289            |hook, ctx| hook(ctx),
290        )
291        .await
292    }
293
294    pub async fn after_turn(
295        &self,
296        ctx: TurnResultHookContext,
297    ) -> Result<Vec<PluginOwned<PluginDirective>>, PluginError> {
298        self.after_turn_with_phase_probe(ctx, None).await
299    }
300
301    async fn after_turn_with_phase_probe(
302        &self,
303        ctx: TurnResultHookContext,
304        phase_probe: Option<&Arc<dyn crate::runtime::RuntimeTurnPhaseProbe>>,
305    ) -> Result<Vec<PluginOwned<PluginDirective>>, PluginError> {
306        collect_owned_async(
307            &self.contributions.after_turn_hooks,
308            ctx,
309            "after_turn",
310            phase_probe,
311            |hook, ctx| hook(ctx),
312        )
313        .await
314    }
315
316    pub async fn at_checkpoint(
317        &self,
318        ctx: CheckpointHookContext,
319    ) -> Result<Vec<PluginOwned<PluginDirective>>, PluginError> {
320        collect_owned_async(
321            &self.contributions.checkpoint_hooks,
322            ctx,
323            "checkpoint",
324            None,
325            |hook, ctx| hook(ctx),
326        )
327        .await
328    }
329
330    pub async fn transform_assistant_stream(
331        &self,
332        session_id: &str,
333        chunk: String,
334    ) -> Result<Vec<PluginOwned<AssistantStreamTransform>>, PluginError> {
335        let mut current = chunk;
336        let mut transforms = Vec::new();
337        for registered in &self.contributions.assistant_stream_hooks {
338            let transform = (registered.hook)(AssistantStreamHookContext {
339                session_id: session_id.to_string(),
340                chunk: current.clone(),
341            })
342            .await?;
343            current = transform.chunk.clone();
344            transforms.push(PluginOwned {
345                plugin_id: registered.plugin_id.clone(),
346                value: transform,
347            });
348        }
349        Ok(transforms)
350    }
351
352    pub async fn transform_assistant_response(
353        &self,
354        session_id: &str,
355        response: crate::llm::types::LlmResponse,
356    ) -> Result<Vec<PluginOwned<AssistantResponseTransform>>, PluginError> {
357        let mut current = response;
358        let mut transforms = Vec::new();
359        for registered in &self.contributions.assistant_response_hooks {
360            let transform = (registered.hook)(AssistantResponseHookContext {
361                session_id: session_id.to_string(),
362                response: current.clone(),
363            })
364            .await?;
365            current = transform.response.clone();
366            transforms.push(PluginOwned {
367                plugin_id: registered.plugin_id.clone(),
368                value: transform,
369            });
370        }
371        Ok(transforms)
372    }
373
374    pub async fn project_tool_result(
375        &self,
376        ctx: ToolResultProjectionContext,
377    ) -> Result<crate::ModelToolReturn, PluginError> {
378        let Some(projector) = &self.contributions.tool_result_projector else {
379            return Ok(crate::ModelToolReturn::from_output(
380                ctx.call_id.clone(),
381                ctx.tool_name.clone(),
382                &ctx.output,
383            ));
384        };
385        (projector.hook)(ctx).await
386    }
387
388    pub async fn emit_runtime_event(&self, event: PluginLifecycleEvent<'_>) {
389        self.emit_runtime_event_with_phase_probe(event, None).await;
390    }
391
392    pub async fn emit_runtime_event_with_phase_probe(
393        &self,
394        event: PluginLifecycleEvent<'_>,
395        phase_probe: Option<Arc<dyn crate::runtime::RuntimeTurnPhaseProbe>>,
396    ) {
397        let hook_kind = lifecycle_event_hook_kind(&event);
398        let mut pending = FuturesUnordered::new();
399        for registered in &self.contributions.runtime_event_hooks {
400            let hook = Arc::clone(&registered.hook);
401            let plugin_id = registered.plugin_id.clone();
402            let phase_name = plugin_hook_phase_name(hook_kind, registered.plugin_id.as_str());
403            let event = event.clone();
404            let phase_probe = phase_probe.clone();
405            pending.push(async move {
406                if let Some(probe) = phase_probe.as_ref() {
407                    probe.begin_named(&phase_name);
408                }
409                let result = hook(event).await;
410                if let Some(probe) = phase_probe.as_ref() {
411                    probe.end_named(&phase_name);
412                }
413                (plugin_id, result)
414            });
415        }
416        while let Some((plugin_id, result)) = pending.next().await {
417            if let Err(err) = result {
418                tracing::warn!(plugin_id, "plugin runtime event hook failed: {err}");
419            }
420        }
421    }
422
423    pub fn has_runtime_event_hooks(&self) -> bool {
424        !self.contributions.runtime_event_hooks.is_empty()
425    }
426
427    pub async fn mutate_session_config(
428        &self,
429        ctx: SessionConfigChangedContext,
430        mut policy: SessionPolicy,
431    ) -> SessionPolicy {
432        for hook in &self.contributions.session_config_mutators {
433            match hook(ctx.clone(), policy.clone()).await {
434                Ok(next_policy) => policy = next_policy,
435                Err(err) => tracing::warn!("plugin config mutator failed: {err}"),
436            }
437        }
438        policy
439    }
440
441    pub fn snapshot(&self) -> Result<PluginSessionSnapshot, PluginError> {
442        let mut plugins = BTreeMap::new();
443        for plugin in &self.plugins {
444            let mut writer = InMemorySnapshotWriter::default();
445            let meta = plugin.snapshot(&mut writer)?;
446            plugins.insert(
447                plugin.id().to_string(),
448                PluginSnapshotEntry {
449                    meta,
450                    artifacts: writer.finish(),
451                },
452            );
453        }
454        Ok(PluginSessionSnapshot { plugins })
455    }
456
457    pub fn snapshot_is_current(&self, previous: Option<&PluginSessionSnapshot>) -> bool {
458        let Some(previous) = previous else {
459            return false;
460        };
461        if previous.plugins.len() != self.plugins.len() {
462            return false;
463        }
464        for plugin in &self.plugins {
465            let Some(entry) = previous.plugins.get(plugin.id()) else {
466                return false;
467            };
468            if entry.meta.plugin_version != plugin.version()
469                || entry.meta.revision != plugin.snapshot_revision()
470            {
471                return false;
472            }
473        }
474        true
475    }
476
477    pub fn snapshot_revision_fingerprint(&self) -> u64 {
478        let mut hasher = Sha256::new();
479        for plugin in &self.plugins {
480            hasher.update(plugin.id().as_bytes());
481            hasher.update([0]);
482            hasher.update(plugin.version().as_bytes());
483            hasher.update([0]);
484            hasher.update(plugin.snapshot_revision().to_le_bytes());
485            hasher.update([0xff]);
486        }
487        let digest = hasher.finalize();
488        u64::from_le_bytes(digest[..8].try_into().expect("digest prefix"))
489    }
490
491    pub fn restore(&self, snapshot: &PluginSessionSnapshot) -> Result<(), PluginError> {
492        for plugin in &self.plugins {
493            if let Some(entry) = snapshot.plugins.get(plugin.id()) {
494                let reader = InMemorySnapshotReader { entry };
495                plugin.restore(&entry.meta, &reader)?;
496            } else {
497                plugin.restore(
498                    &PluginSnapshotMeta {
499                        plugin_id: plugin.id().to_string(),
500                        plugin_version: plugin.version().to_string(),
501                        revision: plugin.snapshot_revision(),
502                        state: None,
503                    },
504                    &EmptySnapshotReader,
505                )?;
506            }
507        }
508        Ok(())
509    }
510
511    pub fn fork_for_session(
512        &self,
513        session_id: impl Into<String>,
514    ) -> Result<Arc<PluginSession>, PluginError> {
515        let snapshot = self.snapshot()?;
516        self.host.build_session_with_overlay(
517            session_id,
518            Some(&snapshot),
519            self.tool_catalog_overlay.clone(),
520            Some(self.tool_registry.export_state()),
521        )
522    }
523
524    pub fn fork_for_child_session(
525        &self,
526        session_id: impl Into<String>,
527        parent_session_id: Option<String>,
528        authority: super::SessionAuthorityContext,
529    ) -> Result<Arc<PluginSession>, PluginError> {
530        let snapshot = self.snapshot()?;
531        self.host.build_session_with_parent_and_overlay(
532            session_id,
533            parent_session_id,
534            Some(&snapshot),
535            self.tool_catalog_overlay.clone(),
536            Some(self.tool_registry.export_state()),
537            authority,
538        )
539    }
540
541    pub fn fork_for_session_with_tool_catalog(
542        &self,
543        session_id: impl Into<String>,
544        tool_catalog_overlay: ToolCatalogContribution,
545    ) -> Result<Arc<PluginSession>, PluginError> {
546        let snapshot = self.snapshot()?;
547        self.host.build_session_with_overlay(
548            session_id,
549            Some(&snapshot),
550            tool_catalog_overlay,
551            Some(self.tool_registry.export_state()),
552        )
553    }
554
555    #[expect(
556        clippy::too_many_arguments,
557        reason = "plugin action invocation carries the explicit host services exposed to actions"
558    )]
559    pub async fn invoke_plugin_action(
560        &self,
561        name: &str,
562        args: serde_json::Value,
563        session_id: Option<String>,
564        default_to_current_session: bool,
565        sessions: Arc<dyn SessionStateService>,
566        session_lifecycle: Arc<dyn SessionLifecycleService>,
567        session_graph: Arc<dyn SessionGraphService>,
568        processes: Arc<dyn crate::ProcessService>,
569    ) -> Result<ToolResult, PluginActionInvokeError> {
570        let Some(op) = self.contributions.plugin_actions.get(name).cloned() else {
571            return Err(PluginActionInvokeError::Unknown(name.to_string()));
572        };
573
574        let effective_session = session_id.or_else(|| {
575            if default_to_current_session && !self.session_id.is_empty() {
576                Some(self.session_id.clone())
577            } else {
578                None
579            }
580        });
581
582        match (op.def.session_param, effective_session.as_ref()) {
583            (SessionParam::Required, None) => {
584                return Err(PluginActionInvokeError::MissingSession(name.to_string()));
585            }
586            (SessionParam::Forbidden, Some(_)) => {
587                return Err(PluginActionInvokeError::UnexpectedSession(name.to_string()));
588            }
589            _ => {}
590        }
591
592        Ok((op.handler)(
593            PluginActionContext {
594                session_id: effective_session,
595                sessions,
596                session_lifecycle,
597                session_graph,
598                processes,
599            },
600            args,
601        )
602        .await)
603    }
604
605    #[expect(
606        clippy::too_many_arguments,
607        reason = "typed action invocation mirrors the raw plugin action host service boundary"
608    )]
609    pub async fn call_plugin_action<Op: PluginAction>(
610        &self,
611        args: Op::Args,
612        session_id: Option<String>,
613        default_to_current_session: bool,
614        sessions: Arc<dyn SessionStateService>,
615        session_lifecycle: Arc<dyn SessionLifecycleService>,
616        session_graph: Arc<dyn SessionGraphService>,
617        processes: Arc<dyn crate::ProcessService>,
618    ) -> Result<Op::Output, PluginError> {
619        let args = serde_json::to_value(args)
620            .map_err(|err| PluginError::Invoke(format!("invalid {} args: {err}", Op::NAME)))?;
621        let result = self
622            .invoke_plugin_action(
623                Op::NAME,
624                args,
625                session_id,
626                default_to_current_session,
627                sessions,
628                session_lifecycle,
629                session_graph,
630                processes,
631            )
632            .await
633            .map_err(|err| PluginError::Invoke(err.to_string()))?;
634        let Some(output) = result.as_done_output() else {
635            return Err(PluginError::Invoke(format!(
636                "{} returned a pending result where completed output is required",
637                Op::NAME
638            )));
639        };
640        if !output.is_success() {
641            return Err(PluginError::Invoke(format!(
642                "{} failed: {}",
643                Op::NAME,
644                output.value_for_projection()
645            )));
646        }
647        serde_json::from_value(output.value_for_projection())
648            .map_err(|err| PluginError::Invoke(format!("invalid {} output: {err}", Op::NAME)))
649    }
650}