Skip to main content

lash_core/plugin/
registrar.rs

1use std::collections::{BTreeMap, BTreeSet};
2use std::sync::Arc;
3
4use super::*;
5
/// A non-exclusive hook tagged with the id of the plugin that registered it.
///
/// Many plugins may contribute hooks of the same kind; each entry remembers
/// its owner so it can be attributed later.
#[derive(Clone)]
pub(crate) struct RegisteredHook<T> {
    /// Owning plugin id (see `current_registration_owner` for the fallback used
    /// when no plugin id is set).
    pub(crate) plugin_id: String,
    /// The registered hook value itself.
    pub(crate) hook: T,
}
11
/// A hook occupying a singleton slot, tagged with the plugin that claimed it.
///
/// Structurally identical to `RegisteredHook`; the separate type marks slots
/// where at most one plugin may register (enforced by `register_singleton_hook`).
#[derive(Clone)]
pub(crate) struct RegisteredExclusiveHook<T> {
    /// Id of the plugin that claimed the slot.
    pub(crate) plugin_id: String,
    /// The registered hook value itself.
    pub(crate) hook: T,
}
17
/// Returns the id of the plugin currently registering, as an owned string.
///
/// When no plugin id is set, the placeholder `"__unknown__"` is returned so
/// non-exclusive hooks can still be recorded.
pub(crate) fn current_registration_owner(registering_plugin_id: &Option<String>) -> String {
    match registering_plugin_id.as_deref() {
        Some(owner) => owner.to_owned(),
        None => String::from("__unknown__"),
    }
}
23
24fn push_registered_hook<T>(
25    hooks: &mut Vec<RegisteredHook<T>>,
26    registering_plugin_id: &Option<String>,
27    hook: T,
28) {
29    hooks.push(RegisteredHook {
30        plugin_id: current_registration_owner(registering_plugin_id),
31        hook,
32    });
33}
34
35fn exclusive_hook_owner(
36    existing_owner: Option<&str>,
37    registering_plugin_id: &Option<String>,
38    hook_kind: &str,
39    hook_name: &str,
40) -> Result<String, PluginError> {
41    let plugin_id = registering_plugin_id
42        .clone()
43        .ok_or_else(|| PluginError::Registration("missing registering plugin id".to_string()))?;
44    if let Some(existing) = existing_owner {
45        return Err(PluginError::Registration(format!(
46            "duplicate {hook_kind} for `{hook_name}`: `{plugin_id}` conflicts with `{existing}`"
47        )));
48    }
49    Ok(plugin_id)
50}
51
52fn register_singleton_hook<H>(
53    slot: &mut Option<RegisteredExclusiveHook<H>>,
54    registering_plugin_id: &Option<String>,
55    hook_kind: &str,
56    hook_name: &str,
57    hook: H,
58) -> Result<(), PluginError> {
59    let plugin_id = exclusive_hook_owner(
60        slot.as_ref()
61            .map(|registered| registered.plugin_id.as_str()),
62        registering_plugin_id,
63        hook_kind,
64        hook_name,
65    )?;
66    *slot = Some(RegisteredExclusiveHook { plugin_id, hook });
67    Ok(())
68}
69
70#[derive(Clone, Default)]
71pub(crate) struct PluginContributions {
72    pub(crate) tool_providers: Vec<Arc<dyn ToolProvider>>,
73    pub(crate) host_events: Vec<crate::HostEvent>,
74    pub(crate) trigger_registry: Option<Arc<SessionTriggerRegistry>>,
75    pub(crate) prompt_contributors: Vec<RegisteredHook<PromptContributor>>,
76    pub(crate) tool_surface_contributors: Vec<RegisteredHook<ToolSurfaceContributor>>,
77    pub(crate) tool_discovery_contributors: Vec<RegisteredHook<ToolDiscoveryContributor>>,
78    pub(crate) before_turn_hooks: Vec<RegisteredHook<BeforeTurnHook>>,
79    pub(crate) before_tool_call_hooks: Vec<RegisteredHook<BeforeToolCallHook>>,
80    pub(crate) after_tool_call_hooks: Vec<RegisteredHook<AfterToolCallHook>>,
81    pub(crate) after_turn_hooks: Vec<RegisteredHook<AfterTurnHook>>,
82    pub(crate) checkpoint_hooks: Vec<RegisteredHook<CheckpointHook>>,
83    pub(crate) assistant_stream_hooks: Vec<RegisteredHook<AssistantStreamHook>>,
84    pub(crate) assistant_response_hooks: Vec<RegisteredHook<AssistantResponseHook>>,
85    pub(crate) tool_result_projector: Option<RegisteredExclusiveHook<ToolResultProjector>>,
86    pub(crate) runtime_event_hooks: Vec<RegisteredHook<PluginLifecycleEventHook>>,
87    pub(crate) session_config_mutators: Vec<SessionConfigMutator>,
88    pub(crate) plugin_actions: BTreeMap<String, RegisteredPluginAction>,
89    pub(crate) turn_context_transforms: Vec<(i32, RegisteredHook<Arc<dyn TurnContextTransform>>)>,
90    pub(crate) history_rewriters: Vec<(i32, RegisteredHook<Arc<dyn HistoryRewriter>>)>,
91    pub(crate) protocol_session: Option<RegisteredExclusiveHook<Arc<dyn ProtocolSessionPlugin>>>,
92    pub(crate) protocol_driver: Option<RegisteredExclusiveHook<Arc<dyn ProtocolDriverPlugin>>>,
93    pub(crate) code_executor: Option<RegisteredExclusiveHook<Arc<dyn CodeExecutorPlugin>>>,
94    pub(crate) assistant_prose_projector:
95        Option<RegisteredExclusiveHook<Arc<dyn AssistantProseProjectorPlugin>>>,
96}
97
98pub struct PluginRegistrar {
99    pub(crate) tool_names: BTreeSet<String>,
100    pub(crate) contributions: PluginContributions,
101    pub(crate) registering_plugin_id: Option<String>,
102}
103
104pub struct ToolRegistrations<'a> {
105    reg: &'a mut PluginRegistrar,
106}
107
108impl ToolRegistrations<'_> {
109    pub fn provider(self, provider: Arc<dyn ToolProvider>) -> Result<(), PluginError> {
110        self.reg.add_tool_provider(provider)
111    }
112}
113
114pub struct HostEventRegistrations<'a> {
115    reg: &'a mut PluginRegistrar,
116}
117
118impl HostEventRegistrations<'_> {
119    pub fn declare(self, event: crate::HostEvent) -> Result<(), PluginError> {
120        self.reg.add_host_event(event)
121    }
122}
123
124pub(crate) struct TriggerRegistrations<'a> {
125    reg: &'a mut PluginRegistrar,
126}
127
128impl TriggerRegistrations<'_> {
129    pub(crate) fn registry(self, registry: Arc<SessionTriggerRegistry>) -> Result<(), PluginError> {
130        if self.reg.contributions.trigger_registry.is_some() {
131            return Err(PluginError::Registration(
132                "duplicate session trigger registry".to_string(),
133            ));
134        }
135        self.reg.contributions.trigger_registry = Some(registry);
136        Ok(())
137    }
138}
139
140pub struct PromptRegistrations<'a> {
141    reg: &'a mut PluginRegistrar,
142}
143
144impl PromptRegistrations<'_> {
145    pub fn contribute(self, contributor: PromptContributor) {
146        self.reg.add_prompt_contributor(contributor);
147    }
148}
149
150pub struct SurfaceRegistrations<'a> {
151    reg: &'a mut PluginRegistrar,
152}
153
154impl SurfaceRegistrations<'_> {
155    pub fn contribute(self, contributor: ToolSurfaceContributor) {
156        self.reg.add_tool_surface_contributor(contributor);
157    }
158}
159
160pub struct DiscoveryRegistrations<'a> {
161    reg: &'a mut PluginRegistrar,
162}
163
164impl DiscoveryRegistrations<'_> {
165    pub fn contribute(self, contributor: ToolDiscoveryContributor) {
166        self.reg.add_tool_discovery_contributor(contributor);
167    }
168}
169
170pub struct TurnRegistrations<'a> {
171    reg: &'a mut PluginRegistrar,
172}
173
174impl TurnRegistrations<'_> {
175    pub fn before(self, hook: BeforeTurnHook) {
176        self.reg.add_before_turn_hook(hook);
177    }
178
179    pub fn after(self, hook: AfterTurnHook) {
180        self.reg.add_after_turn_hook(hook);
181    }
182
183    pub fn checkpoint(self, hook: CheckpointHook) {
184        self.reg.add_checkpoint_hook(hook);
185    }
186}
187
188pub struct ToolCallRegistrations<'a> {
189    reg: &'a mut PluginRegistrar,
190}
191
192impl ToolCallRegistrations<'_> {
193    pub fn before(self, hook: BeforeToolCallHook) {
194        self.reg.add_before_tool_call_hook(hook);
195    }
196
197    pub fn after(self, hook: AfterToolCallHook) {
198        self.reg.add_after_tool_call_hook(hook);
199    }
200}
201
202pub struct OutputRegistrations<'a> {
203    reg: &'a mut PluginRegistrar,
204}
205
206impl OutputRegistrations<'_> {
207    pub fn stream(self, hook: AssistantStreamHook) {
208        self.reg.add_assistant_stream_hook(hook);
209    }
210
211    pub fn response(self, hook: AssistantResponseHook) {
212        self.reg.add_assistant_response_hook(hook);
213    }
214
215    pub fn assistant_prose_projector(
216        self,
217        provider: Arc<dyn AssistantProseProjectorPlugin>,
218    ) -> Result<(), PluginError> {
219        self.reg.add_assistant_prose_projector(provider)
220    }
221}
222
223pub struct ToolResultRegistrations<'a> {
224    reg: &'a mut PluginRegistrar,
225}
226
227impl ToolResultRegistrations<'_> {
228    pub fn projector(self, hook: ToolResultProjector) -> Result<(), PluginError> {
229        self.reg.add_tool_result_projector(hook)
230    }
231}
232
233pub struct SessionRegistrations<'a> {
234    reg: &'a mut PluginRegistrar,
235}
236
237impl SessionRegistrations<'_> {
238    pub fn on_event(self, hook: PluginLifecycleEventHook) {
239        push_registered_hook(
240            &mut self.reg.contributions.runtime_event_hooks,
241            &self.reg.registering_plugin_id,
242            hook,
243        );
244    }
245
246    pub fn config_mutator(self, hook: SessionConfigMutator) {
247        self.reg.contributions.session_config_mutators.push(hook);
248    }
249}
250
251pub struct PluginActionRegistrations<'a> {
252    reg: &'a mut PluginRegistrar,
253}
254
255impl PluginActionRegistrations<'_> {
256    pub fn op(self, def: PluginActionDef, handler: PluginActionHandler) -> Result<(), PluginError> {
257        self.reg.add_plugin_action(def, handler)
258    }
259
260    pub fn typed<Op, F, Fut>(self, handler: F) -> Result<(), PluginError>
261    where
262        Op: PluginAction,
263        F: Fn(PluginActionContext, Op::Args) -> Fut + Send + Sync + 'static,
264        Fut: Future<Output = Result<Op::Output, PluginActionFailure>> + Send + 'static,
265    {
266        self.op(
267            plugin_action_def::<Op>(),
268            Arc::new(move |ctx, args| {
269                let parsed = serde_json::from_value::<Op::Args>(args);
270                match parsed {
271                    Ok(args) => {
272                        let fut = handler(ctx, args);
273                        Box::pin(async move {
274                            match fut.await {
275                                Ok(output) => match serde_json::to_value(output) {
276                                    Ok(value) => ToolResult::ok(value),
277                                    Err(err) => ToolResult::err(serde_json::json!(format!(
278                                        "failed to serialize {} output: {err}",
279                                        Op::NAME
280                                    ))),
281                                },
282                                Err(err) => ToolResult::err(serde_json::json!(err.to_string())),
283                            }
284                        })
285                    }
286                    Err(err) => Box::pin(async move {
287                        ToolResult::err(serde_json::json!(format!(
288                            "invalid {} args: {err}",
289                            Op::NAME
290                        )))
291                    }),
292                }
293            }),
294        )
295    }
296}
297
298pub struct HistoryRegistrations<'a> {
299    reg: &'a mut PluginRegistrar,
300}
301
302impl HistoryRegistrations<'_> {
303    /// Register a per-turn context transform. Higher priority runs first.
304    pub fn prepare_turn(self, priority: i32, transform: Arc<dyn TurnContextTransform>) {
305        self.reg.contributions.turn_context_transforms.push((
306            priority,
307            RegisteredHook {
308                plugin_id: current_registration_owner(&self.reg.registering_plugin_id),
309                hook: transform,
310            },
311        ));
312    }
313
314    /// Register a permanent history rewriter. Higher priority runs first.
315    pub fn rewrite(self, priority: i32, rewriter: Arc<dyn HistoryRewriter>) {
316        self.reg.contributions.history_rewriters.push((
317            priority,
318            RegisteredHook {
319                plugin_id: current_registration_owner(&self.reg.registering_plugin_id),
320                hook: rewriter,
321            },
322        ));
323    }
324}
325
326pub struct ProtocolRegistrations<'a> {
327    reg: &'a mut PluginRegistrar,
328}
329
330impl ProtocolRegistrations<'_> {
331    pub fn session(self, provider: Arc<dyn ProtocolSessionPlugin>) -> Result<(), PluginError> {
332        self.reg.add_protocol_session(provider)
333    }
334
335    /// Claim the session-wide singleton protocol-driver slot. The
336    /// plugin provides a `ProtocolDriverHandle` via `build_preamble`.
337    /// The active plugin stack must install exactly one protocol driver.
338    pub fn protocol_driver(
339        self,
340        provider: Arc<dyn ProtocolDriverPlugin>,
341    ) -> Result<(), PluginError> {
342        self.reg.add_protocol_driver(provider)
343    }
344}
345
346pub struct ExecutionRegistrations<'a> {
347    reg: &'a mut PluginRegistrar,
348}
349
350impl ExecutionRegistrations<'_> {
351    pub fn code_executor(self, provider: Arc<dyn CodeExecutorPlugin>) -> Result<(), PluginError> {
352        self.reg.add_code_executor(provider)
353    }
354}
355
356impl PluginRegistrar {
357    pub(crate) fn new() -> Self {
358        Self {
359            tool_names: BTreeSet::new(),
360            contributions: PluginContributions::default(),
361            registering_plugin_id: None,
362        }
363    }
364
365    pub fn tools(&mut self) -> ToolRegistrations<'_> {
366        ToolRegistrations { reg: self }
367    }
368
369    pub fn host_events(&mut self) -> HostEventRegistrations<'_> {
370        HostEventRegistrations { reg: self }
371    }
372
373    pub(crate) fn triggers(&mut self) -> TriggerRegistrations<'_> {
374        TriggerRegistrations { reg: self }
375    }
376
377    pub fn prompt(&mut self) -> PromptRegistrations<'_> {
378        PromptRegistrations { reg: self }
379    }
380
381    pub fn surface(&mut self) -> SurfaceRegistrations<'_> {
382        SurfaceRegistrations { reg: self }
383    }
384
385    pub fn discovery(&mut self) -> DiscoveryRegistrations<'_> {
386        DiscoveryRegistrations { reg: self }
387    }
388
389    pub fn turn(&mut self) -> TurnRegistrations<'_> {
390        TurnRegistrations { reg: self }
391    }
392
393    pub fn tool_calls(&mut self) -> ToolCallRegistrations<'_> {
394        ToolCallRegistrations { reg: self }
395    }
396
397    pub fn output(&mut self) -> OutputRegistrations<'_> {
398        OutputRegistrations { reg: self }
399    }
400
401    pub fn tool_results(&mut self) -> ToolResultRegistrations<'_> {
402        ToolResultRegistrations { reg: self }
403    }
404
405    pub fn session(&mut self) -> SessionRegistrations<'_> {
406        SessionRegistrations { reg: self }
407    }
408
409    pub fn actions(&mut self) -> PluginActionRegistrations<'_> {
410        PluginActionRegistrations { reg: self }
411    }
412
413    pub fn history(&mut self) -> HistoryRegistrations<'_> {
414        HistoryRegistrations { reg: self }
415    }
416
417    pub fn protocol(&mut self) -> ProtocolRegistrations<'_> {
418        ProtocolRegistrations { reg: self }
419    }
420
421    pub fn execution(&mut self) -> ExecutionRegistrations<'_> {
422        ExecutionRegistrations { reg: self }
423    }
424
425    fn add_tool_provider(&mut self, provider: Arc<dyn ToolProvider>) -> Result<(), PluginError> {
426        for manifest in provider.tool_manifests() {
427            if !self.tool_names.insert(manifest.name.clone()) {
428                return Err(PluginError::Registration(format!(
429                    "duplicate plugin tool name `{}`",
430                    manifest.name
431                )));
432            }
433        }
434        self.contributions.tool_providers.push(provider);
435        Ok(())
436    }
437
438    fn add_host_event(&mut self, event: crate::HostEvent) -> Result<(), PluginError> {
439        if self
440            .contributions
441            .host_events
442            .iter()
443            .any(|existing| existing.key() == event.key())
444        {
445            return Err(PluginError::Registration(format!(
446                "duplicate host event `{}.{}.{}`",
447                event.resource_type, event.alias, event.event
448            )));
449        }
450        self.contributions.host_events.push(event);
451        Ok(())
452    }
453
454    fn add_prompt_contributor(&mut self, contributor: PromptContributor) {
455        push_registered_hook(
456            &mut self.contributions.prompt_contributors,
457            &self.registering_plugin_id,
458            contributor,
459        );
460    }
461
462    fn add_tool_surface_contributor(&mut self, contributor: ToolSurfaceContributor) {
463        push_registered_hook(
464            &mut self.contributions.tool_surface_contributors,
465            &self.registering_plugin_id,
466            contributor,
467        );
468    }
469
470    fn add_tool_discovery_contributor(&mut self, contributor: ToolDiscoveryContributor) {
471        push_registered_hook(
472            &mut self.contributions.tool_discovery_contributors,
473            &self.registering_plugin_id,
474            contributor,
475        );
476    }
477
478    fn add_before_turn_hook(&mut self, hook: BeforeTurnHook) {
479        push_registered_hook(
480            &mut self.contributions.before_turn_hooks,
481            &self.registering_plugin_id,
482            hook,
483        );
484    }
485
486    fn add_before_tool_call_hook(&mut self, hook: BeforeToolCallHook) {
487        push_registered_hook(
488            &mut self.contributions.before_tool_call_hooks,
489            &self.registering_plugin_id,
490            hook,
491        );
492    }
493
494    fn add_after_tool_call_hook(&mut self, hook: AfterToolCallHook) {
495        push_registered_hook(
496            &mut self.contributions.after_tool_call_hooks,
497            &self.registering_plugin_id,
498            hook,
499        );
500    }
501
502    fn add_after_turn_hook(&mut self, hook: AfterTurnHook) {
503        push_registered_hook(
504            &mut self.contributions.after_turn_hooks,
505            &self.registering_plugin_id,
506            hook,
507        );
508    }
509
510    fn add_checkpoint_hook(&mut self, hook: CheckpointHook) {
511        push_registered_hook(
512            &mut self.contributions.checkpoint_hooks,
513            &self.registering_plugin_id,
514            hook,
515        );
516    }
517
518    fn add_assistant_stream_hook(&mut self, hook: AssistantStreamHook) {
519        push_registered_hook(
520            &mut self.contributions.assistant_stream_hooks,
521            &self.registering_plugin_id,
522            hook,
523        );
524    }
525
526    fn add_assistant_response_hook(&mut self, hook: AssistantResponseHook) {
527        push_registered_hook(
528            &mut self.contributions.assistant_response_hooks,
529            &self.registering_plugin_id,
530            hook,
531        );
532    }
533
534    fn add_assistant_prose_projector(
535        &mut self,
536        provider: Arc<dyn AssistantProseProjectorPlugin>,
537    ) -> Result<(), PluginError> {
538        register_singleton_hook(
539            &mut self.contributions.assistant_prose_projector,
540            &self.registering_plugin_id,
541            "assistant prose projector",
542            "assistant_prose_projector",
543            provider,
544        )
545    }
546
547    fn add_tool_result_projector(&mut self, hook: ToolResultProjector) -> Result<(), PluginError> {
548        register_singleton_hook(
549            &mut self.contributions.tool_result_projector,
550            &self.registering_plugin_id,
551            "tool result projector",
552            "model_observation",
553            hook,
554        )
555    }
556
557    fn add_plugin_action(
558        &mut self,
559        def: PluginActionDef,
560        handler: PluginActionHandler,
561    ) -> Result<(), PluginError> {
562        if self.contributions.plugin_actions.contains_key(&def.name) {
563            return Err(PluginError::Registration(format!(
564                "duplicate plugin action name `{}`",
565                def.name
566            )));
567        }
568        self.contributions
569            .plugin_actions
570            .insert(def.name.clone(), RegisteredPluginAction { def, handler });
571        Ok(())
572    }
573
574    fn add_protocol_session(
575        &mut self,
576        provider: Arc<dyn ProtocolSessionPlugin>,
577    ) -> Result<(), PluginError> {
578        register_singleton_hook(
579            &mut self.contributions.protocol_session,
580            &self.registering_plugin_id,
581            "protocol session capability",
582            "protocol_session",
583            provider,
584        )
585    }
586
587    fn add_code_executor(
588        &mut self,
589        provider: Arc<dyn CodeExecutorPlugin>,
590    ) -> Result<(), PluginError> {
591        register_singleton_hook(
592            &mut self.contributions.code_executor,
593            &self.registering_plugin_id,
594            "code executor capability",
595            "code_executor",
596            provider,
597        )
598    }
599
600    fn add_protocol_driver(
601        &mut self,
602        provider: Arc<dyn ProtocolDriverPlugin>,
603    ) -> Result<(), PluginError> {
604        register_singleton_hook(
605            &mut self.contributions.protocol_driver,
606            &self.registering_plugin_id,
607            "protocol driver capability",
608            "protocol_driver",
609            provider,
610        )
611    }
612}