Skip to main content

lash_core/plugin/
session_obj.rs

1use std::collections::{BTreeMap, BTreeSet};
2use std::sync::Arc;
3
4use futures_util::stream::{FuturesUnordered, StreamExt};
5use sha2::{Digest, Sha256};
6
7use super::*;
8
9mod directives;
10mod tools;
11
12async fn collect_owned_async<C, O, H, F>(
13    hooks: &[RegisteredHook<H>],
14    ctx: C,
15    hook_kind: &'static str,
16    phase_probe: Option<&Arc<dyn crate::runtime::RuntimeTurnPhaseProbe>>,
17    invoke: F,
18) -> Result<Vec<PluginOwned<O>>, PluginError>
19where
20    C: Clone,
21    F: Fn(&H, C) -> PluginFuture<Vec<O>>,
22{
23    let mut out = Vec::new();
24    for registered in hooks {
25        let phase_name = plugin_hook_phase_name(hook_kind, &registered.plugin_id);
26        if let Some(probe) = phase_probe {
27            probe.begin_named(&phase_name);
28        }
29        let result = invoke(&registered.hook, ctx.clone()).await;
30        if let Some(probe) = phase_probe {
31            probe.end_named(&phase_name);
32        }
33        for value in result? {
34            out.push(PluginOwned {
35                plugin_id: registered.plugin_id.clone(),
36                value,
37            });
38        }
39    }
40    Ok(out)
41}
42
/// Builds the phase-probe label for one plugin hook invocation, shaped as
/// `plugin_hook.<hook_kind>.<plugin_id>`.
fn plugin_hook_phase_name(hook_kind: &str, plugin_id: &str) -> String {
    ["plugin_hook", hook_kind, plugin_id].join(".")
}
46
47fn lifecycle_event_hook_kind(event: &PluginLifecycleEvent<'_>) -> &'static str {
48    match event {
49        PluginLifecycleEvent::TurnFinalized(_) => "turn_finalized",
50        PluginLifecycleEvent::TurnPersisted(_) => "turn_persisted",
51        PluginLifecycleEvent::SessionRestored(_) => "session_restored",
52        PluginLifecycleEvent::SessionConfigChanged(_) => "session_config_changed",
53    }
54}
55
56fn collect_owned_sync<C, O, H, F>(
57    hooks: &[RegisteredHook<H>],
58    ctx: C,
59    invoke: F,
60) -> Result<Vec<PluginOwned<O>>, PluginError>
61where
62    C: Clone,
63    F: Fn(&H, C) -> Result<O, PluginError>,
64{
65    let mut out = Vec::new();
66    for registered in hooks {
67        out.push(PluginOwned {
68            plugin_id: registered.plugin_id.clone(),
69            value: invoke(&registered.hook, ctx.clone())?,
70        });
71    }
72    Ok(out)
73}
74
75struct EmptySnapshotReader;
76
77impl SnapshotReader for EmptySnapshotReader {
78    fn read_blob(&self, _name: &str) -> Option<&[u8]> {
79        None
80    }
81}
82
/// Per-session plugin state: the session plugin instances, tool handles,
/// lashlang settings, and the hooks the plugins contributed.
pub struct PluginSession {
    /// Plugin host; used to build forked sessions.
    pub(super) host: PluginHost,
    /// Identifier of the session this instance serves.
    pub(super) session_id: String,
    /// Session plugin instances, iterated in this order for snapshot,
    /// restore, and fingerprinting.
    pub(super) plugins: Vec<Arc<dyn SessionPlugin>>,
    /// Tool provider handed out by `tools()`.
    pub(super) tools: Arc<dyn ToolProvider>,
    /// Tool registry; its exported state is carried into forked sessions.
    pub(super) tool_registry: Arc<crate::ToolRegistry>,
    /// Tool-surface overlay reused by forks that do not supply their own.
    pub(super) tool_surface_overlay: ToolSurfaceContribution,
    /// Tool access settings exposed through `tool_access()`.
    pub(super) tool_access: SessionToolAccess,
    /// Subagent context, when one is attached to this session.
    pub(super) subagent: Option<SubagentSessionContext>,
    /// Lashlang abilities exposed through `lashlang_abilities()`.
    pub(super) lashlang_abilities: lashlang::LashlangAbilities,
    /// Lashlang language features exposed through `lashlang_language_features()`.
    pub(super) lashlang_language_features: lashlang::LashlangLanguageFeatures,
    /// Lashlang resource catalog; cloned out by `lashlang_resources()`.
    pub(super) lashlang_resources: lashlang::ResourceCatalog,
    /// Host event catalog exposed through `host_events()`.
    pub(super) host_events: crate::HostEventCatalog,
    /// Hooks and extension points contributed by the plugins.
    pub(super) contributions: PluginContributions,
}
98impl PluginSession {
99    pub fn session_id(&self) -> &str {
100        &self.session_id
101    }
102
103    pub fn tool_access(&self) -> &SessionToolAccess {
104        &self.tool_access
105    }
106
107    pub fn subagent_context(&self) -> Option<&SubagentSessionContext> {
108        self.subagent.as_ref()
109    }
110
111    pub fn lashlang_abilities(&self) -> lashlang::LashlangAbilities {
112        self.lashlang_abilities
113    }
114
115    pub fn lashlang_language_features(&self) -> lashlang::LashlangLanguageFeatures {
116        self.lashlang_language_features
117    }
118
119    pub fn lashlang_resources(&self) -> lashlang::ResourceCatalog {
120        self.lashlang_resources.clone()
121    }
122
123    pub fn host_events(&self) -> &crate::HostEventCatalog {
124        &self.host_events
125    }
126
127    pub fn host(&self) -> &PluginHost {
128        &self.host
129    }
130
131    pub fn tools(&self) -> Arc<dyn ToolProvider> {
132        Arc::clone(&self.tools)
133    }
134
135    pub fn tool_registry(&self) -> Arc<crate::ToolRegistry> {
136        Arc::clone(&self.tool_registry)
137    }
138
139    pub(crate) fn protocol_session(&self) -> &Arc<dyn ProtocolSessionPlugin> {
140        &self
141            .contributions
142            .protocol_session
143            .as_ref()
144            .expect("plugin session must have a protocol session")
145            .hook
146    }
147
148    pub(crate) fn code_executor(&self) -> Option<Arc<dyn CodeExecutorPlugin>> {
149        self.contributions
150            .code_executor
151            .as_ref()
152            .map(|entry| Arc::clone(&entry.hook))
153    }
154
155    pub(crate) fn assistant_prose_projector(
156        &self,
157    ) -> Option<Arc<dyn AssistantProseProjectorPlugin>> {
158        self.contributions
159            .assistant_prose_projector
160            .as_ref()
161            .map(|entry| Arc::clone(&entry.hook))
162    }
163
164    pub fn protocol_driver(&self) -> Arc<dyn ProtocolDriverPlugin> {
165        self.contributions
166            .protocol_driver
167            .as_ref()
168            .map(|entry| Arc::clone(&entry.hook))
169            .expect("plugin session must have a protocol driver")
170    }
171
172    pub fn plugin_actions(&self) -> Vec<PluginActionDef> {
173        self.contributions
174            .plugin_actions
175            .values()
176            .map(|op| op.def.clone())
177            .collect()
178    }
179
180    pub fn has_assistant_stream_hooks(&self) -> bool {
181        !self.contributions.assistant_stream_hooks.is_empty()
182    }
183
184    /// Chain registered turn-context transforms, piping each one's output
185    /// into the next in priority order.
186    pub async fn prepare_turn_context(
187        &self,
188        ctx: &TurnTransformContext<'_>,
189        input: crate::session_model::context::PreparedContext,
190        phase_probe: Option<Arc<dyn crate::runtime::RuntimeTurnPhaseProbe>>,
191    ) -> Result<crate::session_model::context::PreparedContext, ContextError> {
192        let mut current = input;
193        for (_, registered) in &self.contributions.turn_context_transforms {
194            let phase_name =
195                plugin_hook_phase_name("context_transform", registered.plugin_id.as_str());
196            if let Some(probe) = phase_probe.as_ref() {
197                probe.begin_named(&phase_name);
198            }
199            let result = registered.hook.transform(ctx, current).await;
200            if let Some(probe) = phase_probe.as_ref() {
201                probe.end_named(&phase_name);
202            }
203            current = result?;
204        }
205        Ok(current)
206    }
207
208    /// Ask registered compactors for seed nodes for a new compaction frame.
209    pub async fn compact_context(
210        &self,
211        ctx: &CompactionContext<'_>,
212    ) -> Result<Option<ContextCompaction>, ContextError> {
213        for (_, registered) in &self.contributions.context_compactors {
214            if let Some(compaction) = registered.hook.compact(ctx).await?
215                && !compaction.is_empty()
216            {
217                return Ok(Some(compaction));
218            }
219        }
220        Ok(None)
221    }
222
223    pub async fn collect_prompt_contributions(
224        &self,
225        ctx: PromptHookContext,
226    ) -> Result<Vec<PromptContribution>, PluginError> {
227        let mut out = collect_owned_async(
228            &self.contributions.prompt_contributors,
229            ctx,
230            "prompt_contributor",
231            None,
232            |hook, ctx| hook(ctx),
233        )
234        .await?
235        .into_iter()
236        .map(|owned| owned.value)
237        .collect::<Vec<_>>();
238        let mut seen = BTreeSet::new();
239        out.retain(|contribution| {
240            seen.insert((
241                format!("{:?}", contribution.slot),
242                contribution.priority,
243                contribution.content.trim().to_string(),
244            ))
245        });
246        out.sort_by(|a, b| {
247            format!("{:?}", a.slot)
248                .cmp(&format!("{:?}", b.slot))
249                .then(a.priority.cmp(&b.priority))
250        });
251        Ok(out)
252    }
253
254    pub async fn before_turn(
255        &self,
256        ctx: TurnHookContext,
257    ) -> Result<Vec<PluginOwned<PluginDirective>>, PluginError> {
258        self.before_turn_with_phase_probe(ctx, None).await
259    }
260
261    async fn before_turn_with_phase_probe(
262        &self,
263        ctx: TurnHookContext,
264        phase_probe: Option<&Arc<dyn crate::runtime::RuntimeTurnPhaseProbe>>,
265    ) -> Result<Vec<PluginOwned<PluginDirective>>, PluginError> {
266        collect_owned_async(
267            &self.contributions.before_turn_hooks,
268            ctx,
269            "before_turn",
270            phase_probe,
271            |hook, ctx| hook(ctx),
272        )
273        .await
274    }
275
276    pub async fn before_tool_call(
277        &self,
278        ctx: ToolCallHookContext,
279    ) -> Result<Vec<PluginOwned<PluginDirective>>, PluginError> {
280        collect_owned_async(
281            &self.contributions.before_tool_call_hooks,
282            ctx,
283            "before_tool_call",
284            None,
285            |hook, ctx| hook(ctx),
286        )
287        .await
288    }
289
290    pub async fn after_tool_call(
291        &self,
292        ctx: ToolResultHookContext,
293    ) -> Result<Vec<PluginOwned<PluginDirective>>, PluginError> {
294        collect_owned_async(
295            &self.contributions.after_tool_call_hooks,
296            ctx,
297            "after_tool_call",
298            None,
299            |hook, ctx| hook(ctx),
300        )
301        .await
302    }
303
304    pub async fn after_turn(
305        &self,
306        ctx: TurnResultHookContext,
307    ) -> Result<Vec<PluginOwned<PluginDirective>>, PluginError> {
308        self.after_turn_with_phase_probe(ctx, None).await
309    }
310
311    async fn after_turn_with_phase_probe(
312        &self,
313        ctx: TurnResultHookContext,
314        phase_probe: Option<&Arc<dyn crate::runtime::RuntimeTurnPhaseProbe>>,
315    ) -> Result<Vec<PluginOwned<PluginDirective>>, PluginError> {
316        collect_owned_async(
317            &self.contributions.after_turn_hooks,
318            ctx,
319            "after_turn",
320            phase_probe,
321            |hook, ctx| hook(ctx),
322        )
323        .await
324    }
325
326    pub async fn at_checkpoint(
327        &self,
328        ctx: CheckpointHookContext,
329    ) -> Result<Vec<PluginOwned<PluginDirective>>, PluginError> {
330        collect_owned_async(
331            &self.contributions.checkpoint_hooks,
332            ctx,
333            "checkpoint",
334            None,
335            |hook, ctx| hook(ctx),
336        )
337        .await
338    }
339
340    pub async fn transform_assistant_stream(
341        &self,
342        session_id: &str,
343        chunk: String,
344    ) -> Result<Vec<PluginOwned<AssistantStreamTransform>>, PluginError> {
345        let mut current = chunk;
346        let mut transforms = Vec::new();
347        for registered in &self.contributions.assistant_stream_hooks {
348            let transform = (registered.hook)(AssistantStreamHookContext {
349                session_id: session_id.to_string(),
350                chunk: current.clone(),
351            })
352            .await?;
353            current = transform.chunk.clone();
354            transforms.push(PluginOwned {
355                plugin_id: registered.plugin_id.clone(),
356                value: transform,
357            });
358        }
359        Ok(transforms)
360    }
361
362    pub async fn transform_assistant_response(
363        &self,
364        session_id: &str,
365        response: crate::llm::types::LlmResponse,
366    ) -> Result<Vec<PluginOwned<AssistantResponseTransform>>, PluginError> {
367        let mut current = response;
368        let mut transforms = Vec::new();
369        for registered in &self.contributions.assistant_response_hooks {
370            let transform = (registered.hook)(AssistantResponseHookContext {
371                session_id: session_id.to_string(),
372                response: current.clone(),
373            })
374            .await?;
375            current = transform.response.clone();
376            transforms.push(PluginOwned {
377                plugin_id: registered.plugin_id.clone(),
378                value: transform,
379            });
380        }
381        Ok(transforms)
382    }
383
384    pub async fn project_tool_result(
385        &self,
386        ctx: ToolResultProjectionContext,
387    ) -> Result<crate::ModelToolReturn, PluginError> {
388        let Some(projector) = &self.contributions.tool_result_projector else {
389            return Ok(crate::ModelToolReturn::from_output(
390                ctx.call_id.clone(),
391                ctx.tool_name.clone(),
392                &ctx.output,
393            ));
394        };
395        (projector.hook)(ctx).await
396    }
397
398    pub async fn emit_runtime_event(&self, event: PluginLifecycleEvent<'_>) {
399        self.emit_runtime_event_with_phase_probe(event, None).await;
400    }
401
402    pub async fn emit_runtime_event_with_phase_probe(
403        &self,
404        event: PluginLifecycleEvent<'_>,
405        phase_probe: Option<Arc<dyn crate::runtime::RuntimeTurnPhaseProbe>>,
406    ) {
407        let hook_kind = lifecycle_event_hook_kind(&event);
408        let mut pending = FuturesUnordered::new();
409        for registered in &self.contributions.runtime_event_hooks {
410            let hook = Arc::clone(&registered.hook);
411            let plugin_id = registered.plugin_id.clone();
412            let phase_name = plugin_hook_phase_name(hook_kind, registered.plugin_id.as_str());
413            let event = event.clone();
414            let phase_probe = phase_probe.clone();
415            pending.push(async move {
416                if let Some(probe) = phase_probe.as_ref() {
417                    probe.begin_named(&phase_name);
418                }
419                let result = hook(event).await;
420                if let Some(probe) = phase_probe.as_ref() {
421                    probe.end_named(&phase_name);
422                }
423                (plugin_id, result)
424            });
425        }
426        while let Some((plugin_id, result)) = pending.next().await {
427            if let Err(err) = result {
428                tracing::warn!(plugin_id, "plugin runtime event hook failed: {err}");
429            }
430        }
431    }
432
433    pub fn has_runtime_event_hooks(&self) -> bool {
434        !self.contributions.runtime_event_hooks.is_empty()
435    }
436
437    pub async fn mutate_session_config(
438        &self,
439        ctx: SessionConfigChangedContext,
440        mut policy: SessionPolicy,
441    ) -> SessionPolicy {
442        for hook in &self.contributions.session_config_mutators {
443            match hook(ctx.clone(), policy.clone()).await {
444                Ok(next_policy) => policy = next_policy,
445                Err(err) => tracing::warn!("plugin config mutator failed: {err}"),
446            }
447        }
448        policy
449    }
450
451    pub fn snapshot(&self) -> Result<PluginSessionSnapshot, PluginError> {
452        let mut plugins = BTreeMap::new();
453        for plugin in &self.plugins {
454            let mut writer = InMemorySnapshotWriter::default();
455            let meta = plugin.snapshot(&mut writer)?;
456            plugins.insert(
457                plugin.id().to_string(),
458                PluginSnapshotEntry {
459                    meta,
460                    artifacts: writer.finish(),
461                },
462            );
463        }
464        Ok(PluginSessionSnapshot { plugins })
465    }
466
467    pub fn snapshot_is_current(&self, previous: Option<&PluginSessionSnapshot>) -> bool {
468        let Some(previous) = previous else {
469            return false;
470        };
471        if previous.plugins.len() != self.plugins.len() {
472            return false;
473        }
474        for plugin in &self.plugins {
475            let Some(entry) = previous.plugins.get(plugin.id()) else {
476                return false;
477            };
478            if entry.meta.plugin_version != plugin.version()
479                || entry.meta.revision != plugin.snapshot_revision()
480            {
481                return false;
482            }
483        }
484        true
485    }
486
487    pub fn snapshot_revision_fingerprint(&self) -> u64 {
488        let mut hasher = Sha256::new();
489        for plugin in &self.plugins {
490            hasher.update(plugin.id().as_bytes());
491            hasher.update([0]);
492            hasher.update(plugin.version().as_bytes());
493            hasher.update([0]);
494            hasher.update(plugin.snapshot_revision().to_le_bytes());
495            hasher.update([0xff]);
496        }
497        let digest = hasher.finalize();
498        u64::from_le_bytes(digest[..8].try_into().expect("digest prefix"))
499    }
500
501    pub fn restore(&self, snapshot: &PluginSessionSnapshot) -> Result<(), PluginError> {
502        for plugin in &self.plugins {
503            if let Some(entry) = snapshot.plugins.get(plugin.id()) {
504                let reader = InMemorySnapshotReader { entry };
505                plugin.restore(&entry.meta, &reader)?;
506            } else {
507                plugin.restore(
508                    &PluginSnapshotMeta {
509                        plugin_id: plugin.id().to_string(),
510                        plugin_version: plugin.version().to_string(),
511                        revision: plugin.snapshot_revision(),
512                        state: None,
513                    },
514                    &EmptySnapshotReader,
515                )?;
516            }
517        }
518        Ok(())
519    }
520
521    pub fn fork_for_session(
522        &self,
523        session_id: impl Into<String>,
524    ) -> Result<Arc<PluginSession>, PluginError> {
525        let snapshot = self.snapshot()?;
526        self.host.build_session_with_surface(
527            session_id,
528            Some(&snapshot),
529            self.tool_surface_overlay.clone(),
530            Some(self.tool_registry.export_state()),
531        )
532    }
533
534    pub fn fork_for_child_session(
535        &self,
536        session_id: impl Into<String>,
537        parent_session_id: Option<String>,
538        authority: super::SessionAuthorityContext,
539    ) -> Result<Arc<PluginSession>, PluginError> {
540        let snapshot = self.snapshot()?;
541        self.host.build_session_with_parent_and_surface(
542            session_id,
543            parent_session_id,
544            Some(&snapshot),
545            self.tool_surface_overlay.clone(),
546            Some(self.tool_registry.export_state()),
547            authority,
548        )
549    }
550
551    pub fn fork_for_session_with_tool_surface(
552        &self,
553        session_id: impl Into<String>,
554        tool_surface_overlay: ToolSurfaceContribution,
555    ) -> Result<Arc<PluginSession>, PluginError> {
556        let snapshot = self.snapshot()?;
557        self.host.build_session_with_surface(
558            session_id,
559            Some(&snapshot),
560            tool_surface_overlay,
561            Some(self.tool_registry.export_state()),
562        )
563    }
564
565    #[expect(
566        clippy::too_many_arguments,
567        reason = "plugin action invocation carries the explicit host services exposed to actions"
568    )]
569    pub async fn invoke_plugin_action(
570        &self,
571        name: &str,
572        args: serde_json::Value,
573        session_id: Option<String>,
574        default_to_current_session: bool,
575        sessions: Arc<dyn SessionStateService>,
576        session_lifecycle: Arc<dyn SessionLifecycleService>,
577        session_graph: Arc<dyn SessionGraphService>,
578        processes: Arc<dyn crate::ProcessService>,
579    ) -> Result<ToolResult, PluginActionInvokeError> {
580        let Some(op) = self.contributions.plugin_actions.get(name).cloned() else {
581            return Err(PluginActionInvokeError::Unknown(name.to_string()));
582        };
583
584        let effective_session = session_id.or_else(|| {
585            if default_to_current_session && !self.session_id.is_empty() {
586                Some(self.session_id.clone())
587            } else {
588                None
589            }
590        });
591
592        match (op.def.session_param, effective_session.as_ref()) {
593            (SessionParam::Required, None) => {
594                return Err(PluginActionInvokeError::MissingSession(name.to_string()));
595            }
596            (SessionParam::Forbidden, Some(_)) => {
597                return Err(PluginActionInvokeError::UnexpectedSession(name.to_string()));
598            }
599            _ => {}
600        }
601
602        Ok((op.handler)(
603            PluginActionContext {
604                session_id: effective_session,
605                sessions,
606                session_lifecycle,
607                session_graph,
608                processes,
609            },
610            args,
611        )
612        .await)
613    }
614
615    #[expect(
616        clippy::too_many_arguments,
617        reason = "typed action invocation mirrors the raw plugin action host service boundary"
618    )]
619    pub async fn call_plugin_action<Op: PluginAction>(
620        &self,
621        args: Op::Args,
622        session_id: Option<String>,
623        default_to_current_session: bool,
624        sessions: Arc<dyn SessionStateService>,
625        session_lifecycle: Arc<dyn SessionLifecycleService>,
626        session_graph: Arc<dyn SessionGraphService>,
627        processes: Arc<dyn crate::ProcessService>,
628    ) -> Result<Op::Output, PluginError> {
629        let args = serde_json::to_value(args)
630            .map_err(|err| PluginError::Invoke(format!("invalid {} args: {err}", Op::NAME)))?;
631        let result = self
632            .invoke_plugin_action(
633                Op::NAME,
634                args,
635                session_id,
636                default_to_current_session,
637                sessions,
638                session_lifecycle,
639                session_graph,
640                processes,
641            )
642            .await
643            .map_err(|err| PluginError::Invoke(err.to_string()))?;
644        if !result.is_success() {
645            return Err(PluginError::Invoke(format!(
646                "{} failed: {}",
647                Op::NAME,
648                result.value_for_projection()
649            )));
650        }
651        serde_json::from_value(result.into_output().value_for_projection())
652            .map_err(|err| PluginError::Invoke(format!("invalid {} output: {err}", Op::NAME)))
653    }
654}