Skip to main content

lash_core/plugin/
registry.rs

//! Plugin registration: `PluginSpec` (the declarative bundle of all of a
//! plugin's hooks), the `PluginFactory` / `SessionPlugin` traits that
//! plugin crates implement, the two convenience factories
//! (`StaticPluginFactory`, `PluginSpecFactory`), and the `SpecPlugin`
//! glue that walks a spec and wires each field into the registrar. It also
//! defines the extension-contribution types and the per-session context
//! structs handed to factories and plugins.
//!
//! Split out of `plugin/mod.rs` to keep file size down; the outer paths are
//! preserved by `pub use` re-exports in `plugin/mod.rs`.

10use std::sync::Arc;
11
12use super::{
13    AfterToolCallHook, AfterTurnHook, AssistantResponseHook, AssistantStreamHook,
14    BeforeToolCallHook, BeforeTurnHook, CheckpointHook, ContextCompactor, PluginCommand,
15    PluginCommandHandler, PluginCommandInvokeFuture, PluginCommandOutcome, PluginError, PluginHost,
16    PluginLifecycleEventHook, PluginOperationDef, PluginOperationFailure, PluginOperationKind,
17    PluginQuery, PluginQueryHandler, PluginQueryInvokeFuture, PluginRegistrar, PluginSnapshotMeta,
18    PluginTask, PluginTaskHandler, PluginTaskInvokeFuture, PluginTaskOutcome, PromptContributor,
19    SessionConfigMutator, SessionToolAccess, SnapshotReader, SnapshotWriter,
20    SubagentSessionContext, ToolCatalogContributor, ToolResultProjector, TurnContextTransform,
21};
22use crate::{PluginOptions, ToolProvider};
23
24#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
25pub struct PluginExtensionContribution {
26    pub extension_id: String,
27    #[serde(default)]
28    pub payload: serde_json::Value,
29}
30
31impl PluginExtensionContribution {
32    pub fn new(
33        extension_id: impl Into<String>,
34        payload: impl serde::Serialize,
35    ) -> Result<Self, serde_json::Error> {
36        Ok(Self {
37            extension_id: extension_id.into(),
38            payload: serde_json::to_value(payload)?,
39        })
40    }
41
42    pub fn from_value(extension_id: impl Into<String>, payload: serde_json::Value) -> Self {
43        Self {
44            extension_id: extension_id.into(),
45            payload,
46        }
47    }
48}
49
50#[derive(Clone, Debug, Default, PartialEq, Eq)]
51pub struct PluginExtensions {
52    contributions: std::collections::BTreeMap<String, Vec<serde_json::Value>>,
53}
54
55impl PluginExtensions {
56    pub fn from_contributions(
57        contributions: impl IntoIterator<Item = PluginExtensionContribution>,
58    ) -> Self {
59        let mut extensions = Self::default();
60        for contribution in contributions {
61            extensions.insert(contribution);
62        }
63        extensions
64    }
65
66    pub fn insert(&mut self, contribution: PluginExtensionContribution) {
67        self.contributions
68            .entry(contribution.extension_id)
69            .or_default()
70            .push(contribution.payload);
71    }
72
73    pub fn payloads(&self, extension_id: &str) -> &[serde_json::Value] {
74        self.contributions
75            .get(extension_id)
76            .map(Vec::as_slice)
77            .unwrap_or(&[])
78    }
79
80    pub fn is_empty(&self) -> bool {
81        self.contributions.is_empty()
82    }
83}
84
85#[derive(Clone, Default)]
86pub struct PluginSpec {
87    pub tool_providers: Vec<Arc<dyn ToolProvider>>,
88    pub triggers: Vec<crate::TriggerEvent>,
89    pub prompt_contributors: Vec<PromptContributor>,
90    pub tool_catalog_contributors: Vec<ToolCatalogContributor>,
91    pub before_turn_hooks: Vec<BeforeTurnHook>,
92    pub before_tool_call_hooks: Vec<BeforeToolCallHook>,
93    pub after_tool_call_hooks: Vec<AfterToolCallHook>,
94    pub after_turn_hooks: Vec<AfterTurnHook>,
95    pub checkpoint_hooks: Vec<CheckpointHook>,
96    pub assistant_stream_hooks: Vec<AssistantStreamHook>,
97    pub assistant_response_hooks: Vec<AssistantResponseHook>,
98    pub tool_result_projector: Option<ToolResultProjector>,
99    pub runtime_event_hooks: Vec<PluginLifecycleEventHook>,
100    pub session_config_mutators: Vec<SessionConfigMutator>,
101    pub(crate) plugin_queries: Vec<(PluginOperationDef, PluginQueryHandler)>,
102    pub(crate) plugin_commands: Vec<(PluginOperationDef, PluginCommandHandler)>,
103    pub(crate) plugin_tasks: Vec<(PluginOperationDef, PluginTaskHandler)>,
104    pub turn_context_transforms: Vec<(i32, Arc<dyn TurnContextTransform>)>,
105    pub context_compactors: Vec<(i32, Arc<dyn ContextCompactor>)>,
106}
107
108impl PluginSpec {
109    pub fn new() -> Self {
110        Self::default()
111    }
112
113    pub fn with_tool_provider(mut self, provider: Arc<dyn ToolProvider>) -> Self {
114        self.tool_providers.push(provider);
115        self
116    }
117
118    pub fn with_trigger_event(mut self, event: crate::TriggerEvent) -> Self {
119        self.triggers.push(event);
120        self
121    }
122
123    pub fn with_prompt_contributor(mut self, contributor: PromptContributor) -> Self {
124        self.prompt_contributors.push(contributor);
125        self
126    }
127
128    pub fn with_tool_catalog_contributor(mut self, contributor: ToolCatalogContributor) -> Self {
129        self.tool_catalog_contributors.push(contributor);
130        self
131    }
132
133    pub fn with_before_turn(mut self, hook: BeforeTurnHook) -> Self {
134        self.before_turn_hooks.push(hook);
135        self
136    }
137
138    pub fn with_before_tool_call(mut self, hook: BeforeToolCallHook) -> Self {
139        self.before_tool_call_hooks.push(hook);
140        self
141    }
142
143    pub fn with_after_tool_call(mut self, hook: AfterToolCallHook) -> Self {
144        self.after_tool_call_hooks.push(hook);
145        self
146    }
147
148    pub fn with_after_turn(mut self, hook: AfterTurnHook) -> Self {
149        self.after_turn_hooks.push(hook);
150        self
151    }
152
153    pub fn with_checkpoint(mut self, hook: CheckpointHook) -> Self {
154        self.checkpoint_hooks.push(hook);
155        self
156    }
157
158    pub fn with_assistant_stream(mut self, hook: AssistantStreamHook) -> Self {
159        self.assistant_stream_hooks.push(hook);
160        self
161    }
162
163    pub fn with_assistant_response(mut self, hook: AssistantResponseHook) -> Self {
164        self.assistant_response_hooks.push(hook);
165        self
166    }
167
168    pub fn with_tool_result_projector(mut self, projector: ToolResultProjector) -> Self {
169        self.tool_result_projector = Some(projector);
170        self
171    }
172
173    pub fn with_runtime_event(mut self, hook: PluginLifecycleEventHook) -> Self {
174        self.runtime_event_hooks.push(hook);
175        self
176    }
177
178    pub fn with_session_config_mutator(mut self, hook: SessionConfigMutator) -> Self {
179        self.session_config_mutators.push(hook);
180        self
181    }
182
183    pub(crate) fn with_plugin_query(
184        mut self,
185        def: PluginOperationDef,
186        handler: PluginQueryHandler,
187    ) -> Self {
188        self.plugin_queries.push((def, handler));
189        self
190    }
191
192    pub fn with_plugin_query_typed<Op, F, Fut>(self, handler: F) -> Self
193    where
194        Op: PluginQuery,
195        F: Fn(super::PluginQueryContext, Op::Args) -> Fut + Send + Sync + 'static,
196        Fut: std::future::Future<Output = Result<Op::Output, PluginOperationFailure>>
197            + Send
198            + 'static,
199    {
200        self.with_plugin_query(
201            super::plugin_operation_def::<Op>(PluginOperationKind::Query),
202            Arc::new(move |ctx, args| {
203                let parsed = serde_json::from_value::<Op::Args>(args);
204                match parsed {
205                    Ok(args) => {
206                        let fut = handler(ctx, args);
207                        Box::pin(async move {
208                            let output = fut.await?;
209                            serde_json::to_value(output).map_err(|err| {
210                                PluginOperationFailure::new(format!(
211                                    "failed to serialize {} output: {err}",
212                                    Op::NAME
213                                ))
214                            })
215                        }) as PluginQueryInvokeFuture
216                    }
217                    Err(err) => Box::pin(async move {
218                        Err(PluginOperationFailure::new(format!(
219                            "invalid {} args: {err}",
220                            Op::NAME
221                        )))
222                    }) as PluginQueryInvokeFuture,
223                }
224            }),
225        )
226    }
227
228    pub(crate) fn with_plugin_command(
229        mut self,
230        def: PluginOperationDef,
231        handler: PluginCommandHandler,
232    ) -> Self {
233        self.plugin_commands.push((def, handler));
234        self
235    }
236
237    pub fn with_plugin_command_typed<Op, F, Fut>(self, handler: F) -> Self
238    where
239        Op: PluginCommand,
240        F: Fn(super::PluginCommandContext, Op::Args) -> Fut + Send + Sync + 'static,
241        Fut: std::future::Future<
242                Output = Result<PluginCommandOutcome<Op::Output>, PluginOperationFailure>,
243            > + Send
244            + 'static,
245    {
246        self.with_plugin_command(
247            super::plugin_operation_def::<Op>(PluginOperationKind::Command),
248            Arc::new(move |ctx, args| {
249                let parsed = serde_json::from_value::<Op::Args>(args);
250                match parsed {
251                    Ok(args) => {
252                        let fut = handler(ctx, args);
253                        Box::pin(async move {
254                            let outcome = fut.await?;
255                            let output = serde_json::to_value(outcome.output).map_err(|err| {
256                                PluginOperationFailure::new(format!(
257                                    "failed to serialize {} output: {err}",
258                                    Op::NAME
259                                ))
260                            })?;
261                            Ok(super::actions::ErasedPluginCommandOutcome {
262                                output,
263                                events: outcome.events,
264                                directives: outcome.directives,
265                            })
266                        }) as PluginCommandInvokeFuture
267                    }
268                    Err(err) => Box::pin(async move {
269                        Err(PluginOperationFailure::new(format!(
270                            "invalid {} args: {err}",
271                            Op::NAME
272                        )))
273                    }) as PluginCommandInvokeFuture,
274                }
275            }),
276        )
277    }
278
279    pub fn with_plugin_command_value<Op, F, Fut>(self, handler: F) -> Self
280    where
281        Op: PluginCommand,
282        F: Fn(super::PluginCommandContext, Op::Args) -> Fut + Send + Sync + 'static,
283        Fut: std::future::Future<Output = Result<Op::Output, PluginOperationFailure>>
284            + Send
285            + 'static,
286    {
287        self.with_plugin_command_typed::<Op, _, _>(move |ctx, args| {
288            let fut = handler(ctx, args);
289            async move { fut.await.map(PluginCommandOutcome::new) }
290        })
291    }
292
293    pub(crate) fn with_plugin_task(
294        mut self,
295        def: PluginOperationDef,
296        handler: PluginTaskHandler,
297    ) -> Self {
298        self.plugin_tasks.push((def, handler));
299        self
300    }
301
302    pub fn with_plugin_task_typed<Op, F, Fut>(self, handler: F) -> Self
303    where
304        Op: PluginTask,
305        F: Fn(super::PluginTaskContext, Op::Args) -> Fut + Send + Sync + 'static,
306        Fut: std::future::Future<
307                Output = Result<PluginTaskOutcome<Op::Output>, PluginOperationFailure>,
308            > + Send
309            + 'static,
310    {
311        self.with_plugin_task(
312            super::plugin_operation_def::<Op>(PluginOperationKind::Task),
313            Arc::new(move |ctx, args| {
314                let parsed = serde_json::from_value::<Op::Args>(args);
315                match parsed {
316                    Ok(args) => {
317                        let fut = handler(ctx, args);
318                        Box::pin(async move {
319                            let outcome = fut.await?;
320                            let output = serde_json::to_value(outcome.output).map_err(|err| {
321                                PluginOperationFailure::new(format!(
322                                    "failed to serialize {} output: {err}",
323                                    Op::NAME
324                                ))
325                            })?;
326                            Ok(super::actions::ErasedPluginTaskOutcome {
327                                output,
328                                events: outcome.events,
329                                directives: outcome.directives,
330                            })
331                        }) as PluginTaskInvokeFuture
332                    }
333                    Err(err) => Box::pin(async move {
334                        Err(PluginOperationFailure::new(format!(
335                            "invalid {} args: {err}",
336                            Op::NAME
337                        )))
338                    }) as PluginTaskInvokeFuture,
339                }
340            }),
341        )
342    }
343
344    pub fn with_plugin_task_value<Op, F, Fut>(self, handler: F) -> Self
345    where
346        Op: PluginTask,
347        F: Fn(super::PluginTaskContext, Op::Args) -> Fut + Send + Sync + 'static,
348        Fut: std::future::Future<Output = Result<Op::Output, PluginOperationFailure>>
349            + Send
350            + 'static,
351    {
352        self.with_plugin_task_typed::<Op, _, _>(move |ctx, args| {
353            let fut = handler(ctx, args);
354            async move { fut.await.map(PluginTaskOutcome::new) }
355        })
356    }
357
358    pub fn with_turn_context_transform(
359        mut self,
360        priority: i32,
361        transform: Arc<dyn TurnContextTransform>,
362    ) -> Self {
363        self.turn_context_transforms.push((priority, transform));
364        self
365    }
366
367    pub fn with_context_compactor(
368        mut self,
369        priority: i32,
370        compactor: Arc<dyn ContextCompactor>,
371    ) -> Self {
372        self.context_compactors.push((priority, compactor));
373        self
374    }
375}
376
377#[derive(Clone, Debug)]
378pub struct PluginSessionContext {
379    pub session_id: String,
380    pub tool_access: SessionToolAccess,
381    pub subagent: Option<SubagentSessionContext>,
382    pub plugin_options: PluginOptions,
383    pub extensions: PluginExtensions,
384    /// Session id of the caller that created this one. `None` identifies
385    /// a root session; any subagent / compaction / forked-child session
386    /// carries the parent here so plugin factories can gate themselves
387    /// on root-only behavior (e.g. `update_plan`'s sticky plan dock).
388    pub parent_session_id: Option<String>,
389}
390
391impl PluginSessionContext {
392    /// Returns `true` when this context represents a root session, not a
393    /// subagent or internal child. Plugins that should only surface in
394    /// user-facing top-level turns check this in their `build`.
395    pub fn is_root_session(&self) -> bool {
396        self.parent_session_id.is_none()
397    }
398}
399
400#[derive(Clone)]
401pub struct SessionReadyContext {
402    pub session_id: String,
403    pub host: PluginHost,
404}
405
406pub trait SessionPlugin: Send + Sync {
407    fn id(&self) -> &'static str;
408
409    fn version(&self) -> &'static str {
410        "1"
411    }
412
413    fn register(&self, reg: &mut PluginRegistrar) -> Result<(), PluginError>;
414
415    fn snapshot(
416        &self,
417        _writer: &mut dyn SnapshotWriter,
418    ) -> Result<PluginSnapshotMeta, PluginError> {
419        Ok(PluginSnapshotMeta {
420            plugin_id: self.id().to_string(),
421            plugin_version: self.version().to_string(),
422            revision: self.snapshot_revision(),
423            state: None,
424        })
425    }
426
427    fn snapshot_revision(&self) -> u64 {
428        0
429    }
430
431    fn restore(
432        &self,
433        _meta: &PluginSnapshotMeta,
434        _reader: &dyn SnapshotReader,
435    ) -> Result<(), PluginError> {
436        Ok(())
437    }
438
439    fn session_ready(&self, _ctx: SessionReadyContext) -> Result<(), PluginError> {
440        Ok(())
441    }
442}
443
444/// Registers a plugin with the runtime and produces a per-session
445/// `SessionPlugin` instance for each new session.
446///
447/// # Cheap-build / stateful-factory contract
448///
449/// `build(ctx)` **must be cheap**. It runs on the hot path every time
450/// a new session is created (including subagents, forked children,
451/// and compaction children) and any latency here is paid per session.
452///
453/// Specifically, `build` must **not**:
454/// - perform any I/O (disk reads, HTTP calls, DB queries),
455/// - compile regexes, templates, or schemas,
456/// - open network connections or initialize connection pools,
457/// - load models, parse large config files, or allocate large buffers,
458/// - block the current thread for non-trivial work.
459///
460/// Expensive state belongs on the `PluginFactory` struct itself,
461/// wrapped in `Arc` so it can be cheaply cloned into per-session
462/// closures. The `PluginFactory` is constructed once by the embedder
463/// and held in the `RuntimeEnvironment`; its fields outlive every
464/// session. Hooks captured into a `PluginSpec` are closures that
465/// clone the `Arc`s off `self` and reference the shared state
466/// directly, so every session sees the same pool / cache / compiled
467/// artifact without rebuilding it.
468///
469/// The typical shape is:
470/// ```ignore
471/// pub struct MyFactory {
472///     pool: Arc<ConnectionPool>,          // expensive, built once
473///     compiled: Arc<Regex>,               // expensive, built once
474/// }
475///
476/// impl PluginFactory for MyFactory {
477///     fn id(&self) -> &'static str { "my_plugin" }
478///
479///     fn build(&self, _ctx: &PluginSessionContext)
480///         -> Result<Arc<dyn SessionPlugin>, PluginError>
481///     {
482///         // Cheap: clone Arcs, assemble spec, wrap in SpecPlugin.
483///         let pool = Arc::clone(&self.pool);
484///         let spec = PluginSpec::new().with_before_turn(Arc::new(move |_ctx| {
485///             let pool = Arc::clone(&pool);
486///             Box::pin(async move { /* use pool */ Ok(vec![]) })
487///         }));
488///         Ok(Arc::new(SpecPluginFromSpec::new("my_plugin", spec)))
489///     }
490/// }
491/// ```
492pub trait PluginFactory: Send + Sync {
493    fn id(&self) -> &'static str;
494
495    fn extension_contributions(&self) -> Vec<PluginExtensionContribution> {
496        Vec::new()
497    }
498
499    /// Produce a session-scoped plugin. **Must be cheap** — see the
500    /// trait-level docs for the full contract.
501    fn build(&self, ctx: &PluginSessionContext) -> Result<Arc<dyn SessionPlugin>, PluginError>;
502}
503
504pub type PluginSpecBuilder =
505    Arc<dyn Fn(&PluginSessionContext) -> Result<PluginSpec, PluginError> + Send + Sync>;
506
507pub struct PluginSpecFactory {
508    id: &'static str,
509    builder: PluginSpecBuilder,
510}
511
512impl PluginSpecFactory {
513    pub fn new(id: &'static str, builder: PluginSpecBuilder) -> Self {
514        Self { id, builder }
515    }
516}
517
518pub struct StaticPluginFactory {
519    id: &'static str,
520    spec: PluginSpec,
521}
522
523impl StaticPluginFactory {
524    pub fn new(id: &'static str, spec: PluginSpec) -> Self {
525        Self { id, spec }
526    }
527}
528
529struct SpecPlugin {
530    id: &'static str,
531    spec: PluginSpec,
532}
533
534impl PluginFactory for PluginSpecFactory {
535    fn id(&self) -> &'static str {
536        self.id
537    }
538
539    fn build(&self, ctx: &PluginSessionContext) -> Result<Arc<dyn SessionPlugin>, PluginError> {
540        Ok(Arc::new(SpecPlugin {
541            id: self.id,
542            spec: (self.builder)(ctx)?,
543        }))
544    }
545}
546
547impl PluginFactory for StaticPluginFactory {
548    fn id(&self) -> &'static str {
549        self.id
550    }
551
552    fn build(&self, _ctx: &PluginSessionContext) -> Result<Arc<dyn SessionPlugin>, PluginError> {
553        Ok(Arc::new(SpecPlugin {
554            id: self.id,
555            spec: self.spec.clone(),
556        }))
557    }
558}
559
560impl SessionPlugin for SpecPlugin {
561    fn id(&self) -> &'static str {
562        self.id
563    }
564
565    fn register(&self, reg: &mut PluginRegistrar) -> Result<(), PluginError> {
566        for provider in &self.spec.tool_providers {
567            reg.tools().provider(Arc::clone(provider))?;
568        }
569        for event in &self.spec.triggers {
570            reg.triggers().declare(event.clone())?;
571        }
572        for contributor in &self.spec.prompt_contributors {
573            reg.prompt().contribute(Arc::clone(contributor));
574        }
575        for contributor in &self.spec.tool_catalog_contributors {
576            reg.tool_catalog().contribute(Arc::clone(contributor));
577        }
578        for hook in &self.spec.before_turn_hooks {
579            reg.turn().before(Arc::clone(hook));
580        }
581        for hook in &self.spec.before_tool_call_hooks {
582            reg.tool_calls().before(Arc::clone(hook));
583        }
584        for hook in &self.spec.after_tool_call_hooks {
585            reg.tool_calls().after(Arc::clone(hook));
586        }
587        for hook in &self.spec.after_turn_hooks {
588            reg.turn().after(Arc::clone(hook));
589        }
590        for hook in &self.spec.checkpoint_hooks {
591            reg.turn().checkpoint(Arc::clone(hook));
592        }
593        for hook in &self.spec.assistant_stream_hooks {
594            reg.output().stream(Arc::clone(hook));
595        }
596        for hook in &self.spec.assistant_response_hooks {
597            reg.output().response(Arc::clone(hook));
598        }
599        if let Some(projector) = &self.spec.tool_result_projector {
600            reg.tool_results().projector(Arc::clone(projector))?;
601        }
602        for hook in &self.spec.runtime_event_hooks {
603            reg.session().on_event(Arc::clone(hook));
604        }
605        for hook in &self.spec.session_config_mutators {
606            reg.session().config_mutator(Arc::clone(hook));
607        }
608        for (def, handler) in &self.spec.plugin_queries {
609            reg.operations().query(def.clone(), Arc::clone(handler))?;
610        }
611        for (def, handler) in &self.spec.plugin_commands {
612            reg.operations().command(def.clone(), Arc::clone(handler))?;
613        }
614        for (def, handler) in &self.spec.plugin_tasks {
615            reg.operations().task(def.clone(), Arc::clone(handler))?;
616        }
617        for (priority, transform) in &self.spec.turn_context_transforms {
618            reg.context().prepare_turn(*priority, Arc::clone(transform));
619        }
620        for (priority, compactor) in &self.spec.context_compactors {
621            reg.context().compact(*priority, Arc::clone(compactor));
622        }
623        Ok(())
624    }
625}