Skip to main content

lash_core/plugin/
registrar.rs

use std::collections::btree_map::Entry;
use std::collections::{BTreeMap, BTreeSet};
use std::sync::Arc;

use super::*;
5
/// A non-exclusive hook tagged with the id of the plugin that registered it.
///
/// Many plugins may contribute hooks of the same kind; each entry records its owner.
#[derive(Clone)]
pub(crate) struct RegisteredHook<T> {
    /// Owning plugin id, or `"__unknown__"` when no plugin id was set at
    /// registration time (see `current_registration_owner`).
    pub(crate) plugin_id: String,
    /// The registered hook value itself.
    pub(crate) hook: T,
}
11
/// A singleton hook (at most one per registrar) tagged with its owning plugin.
///
/// Produced by `register_singleton_hook`, which refuses to overwrite an
/// occupied slot and requires a registering plugin id to be set.
#[derive(Clone)]
pub(crate) struct RegisteredExclusiveHook<T> {
    /// Id of the plugin that claimed the slot (never a fallback value).
    pub(crate) plugin_id: String,
    /// The registered hook value itself.
    pub(crate) hook: T,
}
17
/// Returns the id of the plugin currently registering hooks.
///
/// Falls back to the placeholder `"__unknown__"` when no plugin id is set,
/// so non-exclusive hooks can still be recorded.
pub(crate) fn current_registration_owner(registering_plugin_id: &Option<String>) -> String {
    match registering_plugin_id {
        Some(owner) => owner.clone(),
        None => String::from("__unknown__"),
    }
}
23
24fn push_registered_hook<T>(
25    hooks: &mut Vec<RegisteredHook<T>>,
26    registering_plugin_id: &Option<String>,
27    hook: T,
28) {
29    hooks.push(RegisteredHook {
30        plugin_id: current_registration_owner(registering_plugin_id),
31        hook,
32    });
33}
34
35fn exclusive_hook_owner(
36    existing_owner: Option<&str>,
37    registering_plugin_id: &Option<String>,
38    hook_kind: &str,
39    hook_name: &str,
40) -> Result<String, PluginError> {
41    let plugin_id = registering_plugin_id
42        .clone()
43        .ok_or_else(|| PluginError::Registration("missing registering plugin id".to_string()))?;
44    if let Some(existing) = existing_owner {
45        return Err(PluginError::Registration(format!(
46            "duplicate {hook_kind} for `{hook_name}`: `{plugin_id}` conflicts with `{existing}`"
47        )));
48    }
49    Ok(plugin_id)
50}
51
52fn register_singleton_hook<H>(
53    slot: &mut Option<RegisteredExclusiveHook<H>>,
54    registering_plugin_id: &Option<String>,
55    hook_kind: &str,
56    hook_name: &str,
57    hook: H,
58) -> Result<(), PluginError> {
59    let plugin_id = exclusive_hook_owner(
60        slot.as_ref()
61            .map(|registered| registered.plugin_id.as_str()),
62        registering_plugin_id,
63        hook_kind,
64        hook_name,
65    )?;
66    *slot = Some(RegisteredExclusiveHook { plugin_id, hook });
67    Ok(())
68}
69
/// Everything registered plugins have contributed, accumulated by [`PluginRegistrar`].
///
/// Most hook lists wrap each entry in [`RegisteredHook`] so the owning plugin
/// id travels with it. Singleton capabilities are
/// `Option<RegisteredExclusiveHook<_>>` slots, and the registrar rejects a
/// second claim on an occupied slot.
#[derive(Clone, Default)]
pub(crate) struct PluginContributions {
    /// Tool providers in registration order. Their tool names are tracked in
    /// `PluginRegistrar::tool_names` so duplicates can be rejected.
    pub(crate) tool_providers: Vec<Arc<dyn ToolProvider>>,
    /// Declared host events; a second event with the same `key()` is rejected.
    pub(crate) host_events: Vec<crate::HostEvent>,
    pub(crate) prompt_contributors: Vec<RegisteredHook<PromptContributor>>,
    pub(crate) tool_surface_contributors: Vec<RegisteredHook<ToolSurfaceContributor>>,
    pub(crate) tool_discovery_contributors: Vec<RegisteredHook<ToolDiscoveryContributor>>,
    pub(crate) before_turn_hooks: Vec<RegisteredHook<BeforeTurnHook>>,
    pub(crate) before_tool_call_hooks: Vec<RegisteredHook<BeforeToolCallHook>>,
    pub(crate) after_tool_call_hooks: Vec<RegisteredHook<AfterToolCallHook>>,
    pub(crate) after_turn_hooks: Vec<RegisteredHook<AfterTurnHook>>,
    pub(crate) checkpoint_hooks: Vec<RegisteredHook<CheckpointHook>>,
    pub(crate) assistant_stream_hooks: Vec<RegisteredHook<AssistantStreamHook>>,
    pub(crate) assistant_response_hooks: Vec<RegisteredHook<AssistantResponseHook>>,
    /// Singleton slot, claimed under the name `model_observation`.
    pub(crate) tool_result_projector: Option<RegisteredExclusiveHook<ToolResultProjector>>,
    pub(crate) runtime_event_hooks: Vec<RegisteredHook<PluginLifecycleEventHook>>,
    /// Session config mutators. Unlike most hooks, these are stored WITHOUT
    /// the owning plugin id.
    pub(crate) session_config_mutators: Vec<SessionConfigMutator>,
    /// Plugin actions keyed by action name; duplicate names are rejected.
    pub(crate) plugin_actions: BTreeMap<String, RegisteredPluginAction>,
    /// `(priority, hook)` pairs in registration order; not sorted here.
    pub(crate) turn_context_transforms: Vec<(i32, RegisteredHook<Arc<dyn TurnContextTransform>>)>,
    /// `(priority, hook)` pairs in registration order; not sorted here.
    pub(crate) history_rewriters: Vec<(i32, RegisteredHook<Arc<dyn HistoryRewriter>>)>,
    /// Singleton slot for the protocol session capability.
    pub(crate) protocol_session: Option<RegisteredExclusiveHook<Arc<dyn ProtocolSessionPlugin>>>,
    /// Singleton slot for the protocol driver capability.
    pub(crate) protocol_driver: Option<RegisteredExclusiveHook<Arc<dyn ProtocolDriverPlugin>>>,
    /// Singleton slot for the code executor capability.
    pub(crate) code_executor: Option<RegisteredExclusiveHook<Arc<dyn CodeExecutorPlugin>>>,
    /// Singleton slot for the assistant prose projector.
    pub(crate) assistant_prose_projector:
        Option<RegisteredExclusiveHook<Arc<dyn AssistantProseProjectorPlugin>>>,
}
96
/// Collects the contributions of plugins during registration.
///
/// Plugins reach individual registration areas through scoped handles
/// (`tools()`, `prompt()`, `turn()`, ...), each of which borrows the
/// registrar mutably.
pub struct PluginRegistrar {
    /// Names of every tool contributed by registered providers, used to
    /// reject duplicate tool names across providers.
    pub(crate) tool_names: BTreeSet<String>,
    /// Accumulated hooks, providers, actions and singleton capabilities.
    pub(crate) contributions: PluginContributions,
    /// Id of the plugin whose registration is in progress.
    ///
    /// While `None`, non-exclusive hooks are attributed to `"__unknown__"`
    /// and singleton registrations fail. NOTE(review): presumably set by the
    /// plugin loader around each plugin's registration; the setter is not
    /// visible in this file.
    pub(crate) registering_plugin_id: Option<String>,
}
102
103pub struct ToolRegistrations<'a> {
104    reg: &'a mut PluginRegistrar,
105}
106
107impl ToolRegistrations<'_> {
108    pub fn provider(self, provider: Arc<dyn ToolProvider>) -> Result<(), PluginError> {
109        self.reg.add_tool_provider(provider)
110    }
111}
112
113pub struct HostEventRegistrations<'a> {
114    reg: &'a mut PluginRegistrar,
115}
116
117impl HostEventRegistrations<'_> {
118    pub fn declare(self, event: crate::HostEvent) -> Result<(), PluginError> {
119        self.reg.add_host_event(event)
120    }
121}
122
123pub struct PromptRegistrations<'a> {
124    reg: &'a mut PluginRegistrar,
125}
126
127impl PromptRegistrations<'_> {
128    pub fn contribute(self, contributor: PromptContributor) {
129        self.reg.add_prompt_contributor(contributor);
130    }
131}
132
133pub struct SurfaceRegistrations<'a> {
134    reg: &'a mut PluginRegistrar,
135}
136
137impl SurfaceRegistrations<'_> {
138    pub fn contribute(self, contributor: ToolSurfaceContributor) {
139        self.reg.add_tool_surface_contributor(contributor);
140    }
141}
142
143pub struct DiscoveryRegistrations<'a> {
144    reg: &'a mut PluginRegistrar,
145}
146
147impl DiscoveryRegistrations<'_> {
148    pub fn contribute(self, contributor: ToolDiscoveryContributor) {
149        self.reg.add_tool_discovery_contributor(contributor);
150    }
151}
152
153pub struct TurnRegistrations<'a> {
154    reg: &'a mut PluginRegistrar,
155}
156
157impl TurnRegistrations<'_> {
158    pub fn before(self, hook: BeforeTurnHook) {
159        self.reg.add_before_turn_hook(hook);
160    }
161
162    pub fn after(self, hook: AfterTurnHook) {
163        self.reg.add_after_turn_hook(hook);
164    }
165
166    pub fn checkpoint(self, hook: CheckpointHook) {
167        self.reg.add_checkpoint_hook(hook);
168    }
169}
170
171pub struct ToolCallRegistrations<'a> {
172    reg: &'a mut PluginRegistrar,
173}
174
175impl ToolCallRegistrations<'_> {
176    pub fn before(self, hook: BeforeToolCallHook) {
177        self.reg.add_before_tool_call_hook(hook);
178    }
179
180    pub fn after(self, hook: AfterToolCallHook) {
181        self.reg.add_after_tool_call_hook(hook);
182    }
183}
184
185pub struct OutputRegistrations<'a> {
186    reg: &'a mut PluginRegistrar,
187}
188
189impl OutputRegistrations<'_> {
190    pub fn stream(self, hook: AssistantStreamHook) {
191        self.reg.add_assistant_stream_hook(hook);
192    }
193
194    pub fn response(self, hook: AssistantResponseHook) {
195        self.reg.add_assistant_response_hook(hook);
196    }
197
198    pub fn assistant_prose_projector(
199        self,
200        provider: Arc<dyn AssistantProseProjectorPlugin>,
201    ) -> Result<(), PluginError> {
202        self.reg.add_assistant_prose_projector(provider)
203    }
204}
205
206pub struct ToolResultRegistrations<'a> {
207    reg: &'a mut PluginRegistrar,
208}
209
210impl ToolResultRegistrations<'_> {
211    pub fn projector(self, hook: ToolResultProjector) -> Result<(), PluginError> {
212        self.reg.add_tool_result_projector(hook)
213    }
214}
215
216pub struct SessionRegistrations<'a> {
217    reg: &'a mut PluginRegistrar,
218}
219
220impl SessionRegistrations<'_> {
221    pub fn on_event(self, hook: PluginLifecycleEventHook) {
222        push_registered_hook(
223            &mut self.reg.contributions.runtime_event_hooks,
224            &self.reg.registering_plugin_id,
225            hook,
226        );
227    }
228
229    pub fn config_mutator(self, hook: SessionConfigMutator) {
230        self.reg.contributions.session_config_mutators.push(hook);
231    }
232}
233
234pub struct PluginActionRegistrations<'a> {
235    reg: &'a mut PluginRegistrar,
236}
237
238impl PluginActionRegistrations<'_> {
239    pub fn op(self, def: PluginActionDef, handler: PluginActionHandler) -> Result<(), PluginError> {
240        self.reg.add_plugin_action(def, handler)
241    }
242
243    pub fn typed<Op, F, Fut>(self, handler: F) -> Result<(), PluginError>
244    where
245        Op: PluginAction,
246        F: Fn(PluginActionContext, Op::Args) -> Fut + Send + Sync + 'static,
247        Fut: Future<Output = Result<Op::Output, PluginActionFailure>> + Send + 'static,
248    {
249        self.op(
250            plugin_action_def::<Op>(),
251            Arc::new(move |ctx, args| {
252                let parsed = serde_json::from_value::<Op::Args>(args);
253                match parsed {
254                    Ok(args) => {
255                        let fut = handler(ctx, args);
256                        Box::pin(async move {
257                            match fut.await {
258                                Ok(output) => match serde_json::to_value(output) {
259                                    Ok(value) => ToolResult::ok(value),
260                                    Err(err) => ToolResult::err(serde_json::json!(format!(
261                                        "failed to serialize {} output: {err}",
262                                        Op::NAME
263                                    ))),
264                                },
265                                Err(err) => ToolResult::err(serde_json::json!(err.to_string())),
266                            }
267                        })
268                    }
269                    Err(err) => Box::pin(async move {
270                        ToolResult::err(serde_json::json!(format!(
271                            "invalid {} args: {err}",
272                            Op::NAME
273                        )))
274                    }),
275                }
276            }),
277        )
278    }
279}
280
281pub struct HistoryRegistrations<'a> {
282    reg: &'a mut PluginRegistrar,
283}
284
285impl HistoryRegistrations<'_> {
286    /// Register a per-turn context transform. Higher priority runs first.
287    pub fn prepare_turn(self, priority: i32, transform: Arc<dyn TurnContextTransform>) {
288        self.reg.contributions.turn_context_transforms.push((
289            priority,
290            RegisteredHook {
291                plugin_id: current_registration_owner(&self.reg.registering_plugin_id),
292                hook: transform,
293            },
294        ));
295    }
296
297    /// Register a permanent history rewriter. Higher priority runs first.
298    pub fn rewrite(self, priority: i32, rewriter: Arc<dyn HistoryRewriter>) {
299        self.reg.contributions.history_rewriters.push((
300            priority,
301            RegisteredHook {
302                plugin_id: current_registration_owner(&self.reg.registering_plugin_id),
303                hook: rewriter,
304            },
305        ));
306    }
307}
308
309pub struct ProtocolRegistrations<'a> {
310    reg: &'a mut PluginRegistrar,
311}
312
313impl ProtocolRegistrations<'_> {
314    pub fn session(self, provider: Arc<dyn ProtocolSessionPlugin>) -> Result<(), PluginError> {
315        self.reg.add_protocol_session(provider)
316    }
317
318    /// Claim the session-wide singleton protocol-driver slot. The
319    /// plugin provides a `ProtocolDriverHandle` via `build_preamble`.
320    /// The active plugin stack must install exactly one protocol driver.
321    pub fn protocol_driver(
322        self,
323        provider: Arc<dyn ProtocolDriverPlugin>,
324    ) -> Result<(), PluginError> {
325        self.reg.add_protocol_driver(provider)
326    }
327}
328
329pub struct ExecutionRegistrations<'a> {
330    reg: &'a mut PluginRegistrar,
331}
332
333impl ExecutionRegistrations<'_> {
334    pub fn code_executor(self, provider: Arc<dyn CodeExecutorPlugin>) -> Result<(), PluginError> {
335        self.reg.add_code_executor(provider)
336    }
337}
338
339impl PluginRegistrar {
340    pub(crate) fn new() -> Self {
341        Self {
342            tool_names: BTreeSet::new(),
343            contributions: PluginContributions::default(),
344            registering_plugin_id: None,
345        }
346    }
347
348    pub fn tools(&mut self) -> ToolRegistrations<'_> {
349        ToolRegistrations { reg: self }
350    }
351
352    pub fn host_events(&mut self) -> HostEventRegistrations<'_> {
353        HostEventRegistrations { reg: self }
354    }
355
356    pub fn prompt(&mut self) -> PromptRegistrations<'_> {
357        PromptRegistrations { reg: self }
358    }
359
360    pub fn surface(&mut self) -> SurfaceRegistrations<'_> {
361        SurfaceRegistrations { reg: self }
362    }
363
364    pub fn discovery(&mut self) -> DiscoveryRegistrations<'_> {
365        DiscoveryRegistrations { reg: self }
366    }
367
368    pub fn turn(&mut self) -> TurnRegistrations<'_> {
369        TurnRegistrations { reg: self }
370    }
371
372    pub fn tool_calls(&mut self) -> ToolCallRegistrations<'_> {
373        ToolCallRegistrations { reg: self }
374    }
375
376    pub fn output(&mut self) -> OutputRegistrations<'_> {
377        OutputRegistrations { reg: self }
378    }
379
380    pub fn tool_results(&mut self) -> ToolResultRegistrations<'_> {
381        ToolResultRegistrations { reg: self }
382    }
383
384    pub fn session(&mut self) -> SessionRegistrations<'_> {
385        SessionRegistrations { reg: self }
386    }
387
388    pub fn actions(&mut self) -> PluginActionRegistrations<'_> {
389        PluginActionRegistrations { reg: self }
390    }
391
392    pub fn history(&mut self) -> HistoryRegistrations<'_> {
393        HistoryRegistrations { reg: self }
394    }
395
396    pub fn protocol(&mut self) -> ProtocolRegistrations<'_> {
397        ProtocolRegistrations { reg: self }
398    }
399
400    pub fn execution(&mut self) -> ExecutionRegistrations<'_> {
401        ExecutionRegistrations { reg: self }
402    }
403
404    fn add_tool_provider(&mut self, provider: Arc<dyn ToolProvider>) -> Result<(), PluginError> {
405        for manifest in provider.tool_manifests() {
406            if !self.tool_names.insert(manifest.name.clone()) {
407                return Err(PluginError::Registration(format!(
408                    "duplicate plugin tool name `{}`",
409                    manifest.name
410                )));
411            }
412        }
413        self.contributions.tool_providers.push(provider);
414        Ok(())
415    }
416
417    fn add_host_event(&mut self, event: crate::HostEvent) -> Result<(), PluginError> {
418        if self
419            .contributions
420            .host_events
421            .iter()
422            .any(|existing| existing.key() == event.key())
423        {
424            return Err(PluginError::Registration(format!(
425                "duplicate host event `{}.{}.{}`",
426                event.resource_type, event.alias, event.event
427            )));
428        }
429        self.contributions.host_events.push(event);
430        Ok(())
431    }
432
433    fn add_prompt_contributor(&mut self, contributor: PromptContributor) {
434        push_registered_hook(
435            &mut self.contributions.prompt_contributors,
436            &self.registering_plugin_id,
437            contributor,
438        );
439    }
440
441    fn add_tool_surface_contributor(&mut self, contributor: ToolSurfaceContributor) {
442        push_registered_hook(
443            &mut self.contributions.tool_surface_contributors,
444            &self.registering_plugin_id,
445            contributor,
446        );
447    }
448
449    fn add_tool_discovery_contributor(&mut self, contributor: ToolDiscoveryContributor) {
450        push_registered_hook(
451            &mut self.contributions.tool_discovery_contributors,
452            &self.registering_plugin_id,
453            contributor,
454        );
455    }
456
457    fn add_before_turn_hook(&mut self, hook: BeforeTurnHook) {
458        push_registered_hook(
459            &mut self.contributions.before_turn_hooks,
460            &self.registering_plugin_id,
461            hook,
462        );
463    }
464
465    fn add_before_tool_call_hook(&mut self, hook: BeforeToolCallHook) {
466        push_registered_hook(
467            &mut self.contributions.before_tool_call_hooks,
468            &self.registering_plugin_id,
469            hook,
470        );
471    }
472
473    fn add_after_tool_call_hook(&mut self, hook: AfterToolCallHook) {
474        push_registered_hook(
475            &mut self.contributions.after_tool_call_hooks,
476            &self.registering_plugin_id,
477            hook,
478        );
479    }
480
481    fn add_after_turn_hook(&mut self, hook: AfterTurnHook) {
482        push_registered_hook(
483            &mut self.contributions.after_turn_hooks,
484            &self.registering_plugin_id,
485            hook,
486        );
487    }
488
489    fn add_checkpoint_hook(&mut self, hook: CheckpointHook) {
490        push_registered_hook(
491            &mut self.contributions.checkpoint_hooks,
492            &self.registering_plugin_id,
493            hook,
494        );
495    }
496
497    fn add_assistant_stream_hook(&mut self, hook: AssistantStreamHook) {
498        push_registered_hook(
499            &mut self.contributions.assistant_stream_hooks,
500            &self.registering_plugin_id,
501            hook,
502        );
503    }
504
505    fn add_assistant_response_hook(&mut self, hook: AssistantResponseHook) {
506        push_registered_hook(
507            &mut self.contributions.assistant_response_hooks,
508            &self.registering_plugin_id,
509            hook,
510        );
511    }
512
513    fn add_assistant_prose_projector(
514        &mut self,
515        provider: Arc<dyn AssistantProseProjectorPlugin>,
516    ) -> Result<(), PluginError> {
517        register_singleton_hook(
518            &mut self.contributions.assistant_prose_projector,
519            &self.registering_plugin_id,
520            "assistant prose projector",
521            "assistant_prose_projector",
522            provider,
523        )
524    }
525
526    fn add_tool_result_projector(&mut self, hook: ToolResultProjector) -> Result<(), PluginError> {
527        register_singleton_hook(
528            &mut self.contributions.tool_result_projector,
529            &self.registering_plugin_id,
530            "tool result projector",
531            "model_observation",
532            hook,
533        )
534    }
535
536    fn add_plugin_action(
537        &mut self,
538        def: PluginActionDef,
539        handler: PluginActionHandler,
540    ) -> Result<(), PluginError> {
541        if self.contributions.plugin_actions.contains_key(&def.name) {
542            return Err(PluginError::Registration(format!(
543                "duplicate plugin action name `{}`",
544                def.name
545            )));
546        }
547        self.contributions
548            .plugin_actions
549            .insert(def.name.clone(), RegisteredPluginAction { def, handler });
550        Ok(())
551    }
552
553    fn add_protocol_session(
554        &mut self,
555        provider: Arc<dyn ProtocolSessionPlugin>,
556    ) -> Result<(), PluginError> {
557        register_singleton_hook(
558            &mut self.contributions.protocol_session,
559            &self.registering_plugin_id,
560            "protocol session capability",
561            "protocol_session",
562            provider,
563        )
564    }
565
566    fn add_code_executor(
567        &mut self,
568        provider: Arc<dyn CodeExecutorPlugin>,
569    ) -> Result<(), PluginError> {
570        register_singleton_hook(
571            &mut self.contributions.code_executor,
572            &self.registering_plugin_id,
573            "code executor capability",
574            "code_executor",
575            provider,
576        )
577    }
578
579    fn add_protocol_driver(
580        &mut self,
581        provider: Arc<dyn ProtocolDriverPlugin>,
582    ) -> Result<(), PluginError> {
583        register_singleton_hook(
584            &mut self.contributions.protocol_driver,
585            &self.registering_plugin_id,
586            "protocol driver capability",
587            "protocol_driver",
588            provider,
589        )
590    }
591}