Skip to main content

lash_core/plugin/
registry.rs

1//! Plugin registration: `PluginSpec` (the declarative bundle of all a
2//! plugin's hooks), the `PluginFactory` / `SessionPlugin` traits
3//! plugin crates implement, and the two convenience factories
4//! (`StaticPluginFactory`, `PluginSpecFactory`) + the `SpecPlugin`
5//! glue that walks a spec and wires each field into the registrar.
6//!
7//! Split out of `plugin/mod.rs` for file size; outer path preserved by
8//! `pub use` in `plugin/mod.rs`.
9
10use std::sync::Arc;
11
12use super::{
13    AfterToolCallHook, AfterTurnHook, AssistantResponseHook, AssistantStreamHook,
14    BeforeToolCallHook, BeforeTurnHook, CheckpointHook, ContextCompactor, PluginCommand,
15    PluginCommandHandler, PluginCommandInvokeFuture, PluginCommandOutcome, PluginError, PluginHost,
16    PluginLifecycleEventHook, PluginOperationDef, PluginOperationFailure, PluginOperationKind,
17    PluginQuery, PluginQueryHandler, PluginQueryInvokeFuture, PluginRegistrar, PluginSnapshotMeta,
18    PluginTask, PluginTaskHandler, PluginTaskInvokeFuture, PluginTaskOutcome, PromptContributor,
19    SessionConfigMutator, SessionToolAccess, SnapshotReader, SnapshotWriter,
20    SubagentSessionContext, ToolCatalogContributor, ToolDiscoveryContributor, ToolResultProjector,
21    TurnContextTransform,
22};
23use crate::{PluginOptions, ToolProvider};
24
25#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
26pub struct PluginExtensionContribution {
27    pub extension_id: String,
28    #[serde(default)]
29    pub payload: serde_json::Value,
30}
31
32impl PluginExtensionContribution {
33    pub fn new(
34        extension_id: impl Into<String>,
35        payload: impl serde::Serialize,
36    ) -> Result<Self, serde_json::Error> {
37        Ok(Self {
38            extension_id: extension_id.into(),
39            payload: serde_json::to_value(payload)?,
40        })
41    }
42
43    pub fn from_value(extension_id: impl Into<String>, payload: serde_json::Value) -> Self {
44        Self {
45            extension_id: extension_id.into(),
46            payload,
47        }
48    }
49}
50
51#[derive(Clone, Debug, Default, PartialEq, Eq)]
52pub struct PluginExtensions {
53    contributions: std::collections::BTreeMap<String, Vec<serde_json::Value>>,
54}
55
56impl PluginExtensions {
57    pub fn from_contributions(
58        contributions: impl IntoIterator<Item = PluginExtensionContribution>,
59    ) -> Self {
60        let mut extensions = Self::default();
61        for contribution in contributions {
62            extensions.insert(contribution);
63        }
64        extensions
65    }
66
67    pub fn insert(&mut self, contribution: PluginExtensionContribution) {
68        self.contributions
69            .entry(contribution.extension_id)
70            .or_default()
71            .push(contribution.payload);
72    }
73
74    pub fn payloads(&self, extension_id: &str) -> &[serde_json::Value] {
75        self.contributions
76            .get(extension_id)
77            .map(Vec::as_slice)
78            .unwrap_or(&[])
79    }
80
81    pub fn is_empty(&self) -> bool {
82        self.contributions.is_empty()
83    }
84}
85
86#[derive(Clone, Default)]
87pub struct PluginSpec {
88    pub tool_providers: Vec<Arc<dyn ToolProvider>>,
89    pub triggers: Vec<crate::TriggerEvent>,
90    pub prompt_contributors: Vec<PromptContributor>,
91    pub tool_catalog_contributors: Vec<ToolCatalogContributor>,
92    pub tool_discovery_contributors: Vec<ToolDiscoveryContributor>,
93    pub before_turn_hooks: Vec<BeforeTurnHook>,
94    pub before_tool_call_hooks: Vec<BeforeToolCallHook>,
95    pub after_tool_call_hooks: Vec<AfterToolCallHook>,
96    pub after_turn_hooks: Vec<AfterTurnHook>,
97    pub checkpoint_hooks: Vec<CheckpointHook>,
98    pub assistant_stream_hooks: Vec<AssistantStreamHook>,
99    pub assistant_response_hooks: Vec<AssistantResponseHook>,
100    pub tool_result_projector: Option<ToolResultProjector>,
101    pub runtime_event_hooks: Vec<PluginLifecycleEventHook>,
102    pub session_config_mutators: Vec<SessionConfigMutator>,
103    pub(crate) plugin_queries: Vec<(PluginOperationDef, PluginQueryHandler)>,
104    pub(crate) plugin_commands: Vec<(PluginOperationDef, PluginCommandHandler)>,
105    pub(crate) plugin_tasks: Vec<(PluginOperationDef, PluginTaskHandler)>,
106    pub turn_context_transforms: Vec<(i32, Arc<dyn TurnContextTransform>)>,
107    pub context_compactors: Vec<(i32, Arc<dyn ContextCompactor>)>,
108}
109
110impl PluginSpec {
111    pub fn new() -> Self {
112        Self::default()
113    }
114
115    pub fn with_tool_provider(mut self, provider: Arc<dyn ToolProvider>) -> Self {
116        self.tool_providers.push(provider);
117        self
118    }
119
120    pub fn with_trigger_event(mut self, event: crate::TriggerEvent) -> Self {
121        self.triggers.push(event);
122        self
123    }
124
125    pub fn with_prompt_contributor(mut self, contributor: PromptContributor) -> Self {
126        self.prompt_contributors.push(contributor);
127        self
128    }
129
130    pub fn with_tool_catalog_contributor(mut self, contributor: ToolCatalogContributor) -> Self {
131        self.tool_catalog_contributors.push(contributor);
132        self
133    }
134
135    pub fn with_tool_discovery_contributor(
136        mut self,
137        contributor: ToolDiscoveryContributor,
138    ) -> Self {
139        self.tool_discovery_contributors.push(contributor);
140        self
141    }
142
143    pub fn with_before_turn(mut self, hook: BeforeTurnHook) -> Self {
144        self.before_turn_hooks.push(hook);
145        self
146    }
147
148    pub fn with_before_tool_call(mut self, hook: BeforeToolCallHook) -> Self {
149        self.before_tool_call_hooks.push(hook);
150        self
151    }
152
153    pub fn with_after_tool_call(mut self, hook: AfterToolCallHook) -> Self {
154        self.after_tool_call_hooks.push(hook);
155        self
156    }
157
158    pub fn with_after_turn(mut self, hook: AfterTurnHook) -> Self {
159        self.after_turn_hooks.push(hook);
160        self
161    }
162
163    pub fn with_checkpoint(mut self, hook: CheckpointHook) -> Self {
164        self.checkpoint_hooks.push(hook);
165        self
166    }
167
168    pub fn with_assistant_stream(mut self, hook: AssistantStreamHook) -> Self {
169        self.assistant_stream_hooks.push(hook);
170        self
171    }
172
173    pub fn with_assistant_response(mut self, hook: AssistantResponseHook) -> Self {
174        self.assistant_response_hooks.push(hook);
175        self
176    }
177
178    pub fn with_tool_result_projector(mut self, projector: ToolResultProjector) -> Self {
179        self.tool_result_projector = Some(projector);
180        self
181    }
182
183    pub fn with_runtime_event(mut self, hook: PluginLifecycleEventHook) -> Self {
184        self.runtime_event_hooks.push(hook);
185        self
186    }
187
188    pub fn with_session_config_mutator(mut self, hook: SessionConfigMutator) -> Self {
189        self.session_config_mutators.push(hook);
190        self
191    }
192
193    pub(crate) fn with_plugin_query(
194        mut self,
195        def: PluginOperationDef,
196        handler: PluginQueryHandler,
197    ) -> Self {
198        self.plugin_queries.push((def, handler));
199        self
200    }
201
202    pub fn with_plugin_query_typed<Op, F, Fut>(self, handler: F) -> Self
203    where
204        Op: PluginQuery,
205        F: Fn(super::PluginQueryContext, Op::Args) -> Fut + Send + Sync + 'static,
206        Fut: std::future::Future<Output = Result<Op::Output, PluginOperationFailure>>
207            + Send
208            + 'static,
209    {
210        self.with_plugin_query(
211            super::plugin_operation_def::<Op>(PluginOperationKind::Query),
212            Arc::new(move |ctx, args| {
213                let parsed = serde_json::from_value::<Op::Args>(args);
214                match parsed {
215                    Ok(args) => {
216                        let fut = handler(ctx, args);
217                        Box::pin(async move {
218                            let output = fut.await?;
219                            serde_json::to_value(output).map_err(|err| {
220                                PluginOperationFailure::new(format!(
221                                    "failed to serialize {} output: {err}",
222                                    Op::NAME
223                                ))
224                            })
225                        }) as PluginQueryInvokeFuture
226                    }
227                    Err(err) => Box::pin(async move {
228                        Err(PluginOperationFailure::new(format!(
229                            "invalid {} args: {err}",
230                            Op::NAME
231                        )))
232                    }) as PluginQueryInvokeFuture,
233                }
234            }),
235        )
236    }
237
238    pub(crate) fn with_plugin_command(
239        mut self,
240        def: PluginOperationDef,
241        handler: PluginCommandHandler,
242    ) -> Self {
243        self.plugin_commands.push((def, handler));
244        self
245    }
246
247    pub fn with_plugin_command_typed<Op, F, Fut>(self, handler: F) -> Self
248    where
249        Op: PluginCommand,
250        F: Fn(super::PluginCommandContext, Op::Args) -> Fut + Send + Sync + 'static,
251        Fut: std::future::Future<
252                Output = Result<PluginCommandOutcome<Op::Output>, PluginOperationFailure>,
253            > + Send
254            + 'static,
255    {
256        self.with_plugin_command(
257            super::plugin_operation_def::<Op>(PluginOperationKind::Command),
258            Arc::new(move |ctx, args| {
259                let parsed = serde_json::from_value::<Op::Args>(args);
260                match parsed {
261                    Ok(args) => {
262                        let fut = handler(ctx, args);
263                        Box::pin(async move {
264                            let outcome = fut.await?;
265                            let output = serde_json::to_value(outcome.output).map_err(|err| {
266                                PluginOperationFailure::new(format!(
267                                    "failed to serialize {} output: {err}",
268                                    Op::NAME
269                                ))
270                            })?;
271                            Ok(super::actions::ErasedPluginCommandOutcome {
272                                output,
273                                events: outcome.events,
274                                directives: outcome.directives,
275                            })
276                        }) as PluginCommandInvokeFuture
277                    }
278                    Err(err) => Box::pin(async move {
279                        Err(PluginOperationFailure::new(format!(
280                            "invalid {} args: {err}",
281                            Op::NAME
282                        )))
283                    }) as PluginCommandInvokeFuture,
284                }
285            }),
286        )
287    }
288
289    pub fn with_plugin_command_value<Op, F, Fut>(self, handler: F) -> Self
290    where
291        Op: PluginCommand,
292        F: Fn(super::PluginCommandContext, Op::Args) -> Fut + Send + Sync + 'static,
293        Fut: std::future::Future<Output = Result<Op::Output, PluginOperationFailure>>
294            + Send
295            + 'static,
296    {
297        self.with_plugin_command_typed::<Op, _, _>(move |ctx, args| {
298            let fut = handler(ctx, args);
299            async move { fut.await.map(PluginCommandOutcome::new) }
300        })
301    }
302
303    pub(crate) fn with_plugin_task(
304        mut self,
305        def: PluginOperationDef,
306        handler: PluginTaskHandler,
307    ) -> Self {
308        self.plugin_tasks.push((def, handler));
309        self
310    }
311
312    pub fn with_plugin_task_typed<Op, F, Fut>(self, handler: F) -> Self
313    where
314        Op: PluginTask,
315        F: Fn(super::PluginTaskContext, Op::Args) -> Fut + Send + Sync + 'static,
316        Fut: std::future::Future<
317                Output = Result<PluginTaskOutcome<Op::Output>, PluginOperationFailure>,
318            > + Send
319            + 'static,
320    {
321        self.with_plugin_task(
322            super::plugin_operation_def::<Op>(PluginOperationKind::Task),
323            Arc::new(move |ctx, args| {
324                let parsed = serde_json::from_value::<Op::Args>(args);
325                match parsed {
326                    Ok(args) => {
327                        let fut = handler(ctx, args);
328                        Box::pin(async move {
329                            let outcome = fut.await?;
330                            let output = serde_json::to_value(outcome.output).map_err(|err| {
331                                PluginOperationFailure::new(format!(
332                                    "failed to serialize {} output: {err}",
333                                    Op::NAME
334                                ))
335                            })?;
336                            Ok(super::actions::ErasedPluginTaskOutcome {
337                                output,
338                                events: outcome.events,
339                                directives: outcome.directives,
340                            })
341                        }) as PluginTaskInvokeFuture
342                    }
343                    Err(err) => Box::pin(async move {
344                        Err(PluginOperationFailure::new(format!(
345                            "invalid {} args: {err}",
346                            Op::NAME
347                        )))
348                    }) as PluginTaskInvokeFuture,
349                }
350            }),
351        )
352    }
353
354    pub fn with_plugin_task_value<Op, F, Fut>(self, handler: F) -> Self
355    where
356        Op: PluginTask,
357        F: Fn(super::PluginTaskContext, Op::Args) -> Fut + Send + Sync + 'static,
358        Fut: std::future::Future<Output = Result<Op::Output, PluginOperationFailure>>
359            + Send
360            + 'static,
361    {
362        self.with_plugin_task_typed::<Op, _, _>(move |ctx, args| {
363            let fut = handler(ctx, args);
364            async move { fut.await.map(PluginTaskOutcome::new) }
365        })
366    }
367
368    pub fn with_turn_context_transform(
369        mut self,
370        priority: i32,
371        transform: Arc<dyn TurnContextTransform>,
372    ) -> Self {
373        self.turn_context_transforms.push((priority, transform));
374        self
375    }
376
377    pub fn with_context_compactor(
378        mut self,
379        priority: i32,
380        compactor: Arc<dyn ContextCompactor>,
381    ) -> Self {
382        self.context_compactors.push((priority, compactor));
383        self
384    }
385}
386
387#[derive(Clone, Debug)]
388pub struct PluginSessionContext {
389    pub session_id: String,
390    pub tool_access: SessionToolAccess,
391    pub subagent: Option<SubagentSessionContext>,
392    pub plugin_options: PluginOptions,
393    pub extensions: PluginExtensions,
394    /// Session id of the caller that created this one. `None` identifies
395    /// a root session; any subagent / compaction / forked-child session
396    /// carries the parent here so plugin factories can gate themselves
397    /// on root-only behavior (e.g. `update_plan`'s sticky plan dock).
398    pub parent_session_id: Option<String>,
399}
400
401impl PluginSessionContext {
402    /// Returns `true` when this context represents a root session, not a
403    /// subagent or internal child. Plugins that should only surface in
404    /// user-facing top-level turns check this in their `build`.
405    pub fn is_root_session(&self) -> bool {
406        self.parent_session_id.is_none()
407    }
408}
409
410#[derive(Clone)]
411pub struct SessionReadyContext {
412    pub session_id: String,
413    pub host: PluginHost,
414}
415
416pub trait SessionPlugin: Send + Sync {
417    fn id(&self) -> &'static str;
418
419    fn version(&self) -> &'static str {
420        "1"
421    }
422
423    fn register(&self, reg: &mut PluginRegistrar) -> Result<(), PluginError>;
424
425    fn snapshot(
426        &self,
427        _writer: &mut dyn SnapshotWriter,
428    ) -> Result<PluginSnapshotMeta, PluginError> {
429        Ok(PluginSnapshotMeta {
430            plugin_id: self.id().to_string(),
431            plugin_version: self.version().to_string(),
432            revision: self.snapshot_revision(),
433            state: None,
434        })
435    }
436
437    fn snapshot_revision(&self) -> u64 {
438        0
439    }
440
441    fn restore(
442        &self,
443        _meta: &PluginSnapshotMeta,
444        _reader: &dyn SnapshotReader,
445    ) -> Result<(), PluginError> {
446        Ok(())
447    }
448
449    fn session_ready(&self, _ctx: SessionReadyContext) -> Result<(), PluginError> {
450        Ok(())
451    }
452}
453
454/// Registers a plugin with the runtime and produces a per-session
455/// `SessionPlugin` instance for each new session.
456///
457/// # Cheap-build / stateful-factory contract
458///
459/// `build(ctx)` **must be cheap**. It runs on the hot path every time
460/// a new session is created (including subagents, forked children,
461/// and compaction children) and any latency here is paid per session.
462///
463/// Specifically, `build` must **not**:
464/// - perform any I/O (disk reads, HTTP calls, DB queries),
465/// - compile regexes, templates, or schemas,
466/// - open network connections or initialize connection pools,
467/// - load models, parse large config files, or allocate large buffers,
468/// - block the current thread for non-trivial work.
469///
470/// Expensive state belongs on the `PluginFactory` struct itself,
471/// wrapped in `Arc` so it can be cheaply cloned into per-session
472/// closures. The `PluginFactory` is constructed once by the embedder
473/// and held in the `RuntimeEnvironment`; its fields outlive every
474/// session. Hooks captured into a `PluginSpec` are closures that
475/// clone the `Arc`s off `self` and reference the shared state
476/// directly, so every session sees the same pool / cache / compiled
477/// artifact without rebuilding it.
478///
479/// The typical shape is:
480/// ```ignore
481/// pub struct MyFactory {
482///     pool: Arc<ConnectionPool>,          // expensive, built once
483///     compiled: Arc<Regex>,               // expensive, built once
484/// }
485///
486/// impl PluginFactory for MyFactory {
487///     fn id(&self) -> &'static str { "my_plugin" }
488///
489///     fn build(&self, _ctx: &PluginSessionContext)
490///         -> Result<Arc<dyn SessionPlugin>, PluginError>
491///     {
492///         // Cheap: clone Arcs, assemble spec, wrap in SpecPlugin.
493///         let pool = Arc::clone(&self.pool);
494///         let spec = PluginSpec::new().with_before_turn(Arc::new(move |_ctx| {
495///             let pool = Arc::clone(&pool);
496///             Box::pin(async move { /* use pool */ Ok(vec![]) })
497///         }));
498///         Ok(Arc::new(SpecPluginFromSpec::new("my_plugin", spec)))
499///     }
500/// }
501/// ```
502pub trait PluginFactory: Send + Sync {
503    fn id(&self) -> &'static str;
504
505    fn extension_contributions(&self) -> Vec<PluginExtensionContribution> {
506        Vec::new()
507    }
508
509    /// Produce a session-scoped plugin. **Must be cheap** — see the
510    /// trait-level docs for the full contract.
511    fn build(&self, ctx: &PluginSessionContext) -> Result<Arc<dyn SessionPlugin>, PluginError>;
512}
513
514pub type PluginSpecBuilder =
515    Arc<dyn Fn(&PluginSessionContext) -> Result<PluginSpec, PluginError> + Send + Sync>;
516
517pub struct PluginSpecFactory {
518    id: &'static str,
519    builder: PluginSpecBuilder,
520}
521
522impl PluginSpecFactory {
523    pub fn new(id: &'static str, builder: PluginSpecBuilder) -> Self {
524        Self { id, builder }
525    }
526}
527
528pub struct StaticPluginFactory {
529    id: &'static str,
530    spec: PluginSpec,
531}
532
533impl StaticPluginFactory {
534    pub fn new(id: &'static str, spec: PluginSpec) -> Self {
535        Self { id, spec }
536    }
537}
538
539struct SpecPlugin {
540    id: &'static str,
541    spec: PluginSpec,
542}
543
544impl PluginFactory for PluginSpecFactory {
545    fn id(&self) -> &'static str {
546        self.id
547    }
548
549    fn build(&self, ctx: &PluginSessionContext) -> Result<Arc<dyn SessionPlugin>, PluginError> {
550        Ok(Arc::new(SpecPlugin {
551            id: self.id,
552            spec: (self.builder)(ctx)?,
553        }))
554    }
555}
556
557impl PluginFactory for StaticPluginFactory {
558    fn id(&self) -> &'static str {
559        self.id
560    }
561
562    fn build(&self, _ctx: &PluginSessionContext) -> Result<Arc<dyn SessionPlugin>, PluginError> {
563        Ok(Arc::new(SpecPlugin {
564            id: self.id,
565            spec: self.spec.clone(),
566        }))
567    }
568}
569
570impl SessionPlugin for SpecPlugin {
571    fn id(&self) -> &'static str {
572        self.id
573    }
574
575    fn register(&self, reg: &mut PluginRegistrar) -> Result<(), PluginError> {
576        for provider in &self.spec.tool_providers {
577            reg.tools().provider(Arc::clone(provider))?;
578        }
579        for event in &self.spec.triggers {
580            reg.triggers().declare(event.clone())?;
581        }
582        for contributor in &self.spec.prompt_contributors {
583            reg.prompt().contribute(Arc::clone(contributor));
584        }
585        for contributor in &self.spec.tool_catalog_contributors {
586            reg.tool_catalog().contribute(Arc::clone(contributor));
587        }
588        for contributor in &self.spec.tool_discovery_contributors {
589            reg.discovery().contribute(Arc::clone(contributor));
590        }
591        for hook in &self.spec.before_turn_hooks {
592            reg.turn().before(Arc::clone(hook));
593        }
594        for hook in &self.spec.before_tool_call_hooks {
595            reg.tool_calls().before(Arc::clone(hook));
596        }
597        for hook in &self.spec.after_tool_call_hooks {
598            reg.tool_calls().after(Arc::clone(hook));
599        }
600        for hook in &self.spec.after_turn_hooks {
601            reg.turn().after(Arc::clone(hook));
602        }
603        for hook in &self.spec.checkpoint_hooks {
604            reg.turn().checkpoint(Arc::clone(hook));
605        }
606        for hook in &self.spec.assistant_stream_hooks {
607            reg.output().stream(Arc::clone(hook));
608        }
609        for hook in &self.spec.assistant_response_hooks {
610            reg.output().response(Arc::clone(hook));
611        }
612        if let Some(projector) = &self.spec.tool_result_projector {
613            reg.tool_results().projector(Arc::clone(projector))?;
614        }
615        for hook in &self.spec.runtime_event_hooks {
616            reg.session().on_event(Arc::clone(hook));
617        }
618        for hook in &self.spec.session_config_mutators {
619            reg.session().config_mutator(Arc::clone(hook));
620        }
621        for (def, handler) in &self.spec.plugin_queries {
622            reg.operations().query(def.clone(), Arc::clone(handler))?;
623        }
624        for (def, handler) in &self.spec.plugin_commands {
625            reg.operations().command(def.clone(), Arc::clone(handler))?;
626        }
627        for (def, handler) in &self.spec.plugin_tasks {
628            reg.operations().task(def.clone(), Arc::clone(handler))?;
629        }
630        for (priority, transform) in &self.spec.turn_context_transforms {
631            reg.context().prepare_turn(*priority, Arc::clone(transform));
632        }
633        for (priority, compactor) in &self.spec.context_compactors {
634            reg.context().compact(*priority, Arc::clone(compactor));
635        }
636        Ok(())
637    }
638}