Skip to main content

lash_core/plugin/
registrar.rs

1use std::collections::{BTreeMap, BTreeSet};
2use std::sync::Arc;
3
4use super::*;
5
/// A hook value tagged with the id of the plugin that registered it.
#[derive(Clone)]
pub(crate) struct RegisteredHook<T> {
    /// Owner id captured when the hook was registered.
    pub(crate) plugin_id: String,
    /// The hook itself.
    pub(crate) hook: T,
}
11
/// A hook occupying a single-owner slot, tagged with the id of the plugin
/// that claimed the slot.
#[derive(Clone)]
pub(crate) struct RegisteredExclusiveHook<T> {
    /// Id of the plugin that owns the slot.
    pub(crate) plugin_id: String,
    /// The hook installed in the slot.
    pub(crate) hook: T,
}
17
/// Resolves the owner id to record for a non-exclusive registration.
///
/// Returns a copy of `registering_plugin_id` when one is set, and the
/// placeholder `"__unknown__"` otherwise.
pub(crate) fn current_registration_owner(registering_plugin_id: &Option<String>) -> String {
    match registering_plugin_id {
        Some(plugin_id) => plugin_id.clone(),
        None => String::from("__unknown__"),
    }
}
23
24fn push_registered_hook<T>(
25    hooks: &mut Vec<RegisteredHook<T>>,
26    registering_plugin_id: &Option<String>,
27    hook: T,
28) {
29    hooks.push(RegisteredHook {
30        plugin_id: current_registration_owner(registering_plugin_id),
31        hook,
32    });
33}
34
35fn push_prioritized_registered_hook<T>(
36    hooks: &mut Vec<(i32, RegisteredHook<T>)>,
37    registering_plugin_id: &Option<String>,
38    priority: i32,
39    hook: T,
40) {
41    hooks.push((
42        priority,
43        RegisteredHook {
44            plugin_id: current_registration_owner(registering_plugin_id),
45            hook,
46        },
47    ));
48}
49
50fn exclusive_hook_owner(
51    existing_owner: Option<&str>,
52    registering_plugin_id: &Option<String>,
53    hook_kind: &str,
54    hook_name: &str,
55) -> Result<String, PluginError> {
56    let plugin_id = registering_plugin_id
57        .clone()
58        .ok_or_else(|| PluginError::Registration("missing registering plugin id".to_string()))?;
59    if let Some(existing) = existing_owner {
60        return Err(PluginError::Registration(format!(
61            "duplicate {hook_kind} for `{hook_name}`: `{plugin_id}` conflicts with `{existing}`"
62        )));
63    }
64    Ok(plugin_id)
65}
66
67fn register_singleton_hook<H>(
68    slot: &mut Option<RegisteredExclusiveHook<H>>,
69    registering_plugin_id: &Option<String>,
70    hook_kind: &str,
71    hook_name: &str,
72    hook: H,
73) -> Result<(), PluginError> {
74    let plugin_id = exclusive_hook_owner(
75        slot.as_ref()
76            .map(|registered| registered.plugin_id.as_str()),
77        registering_plugin_id,
78        hook_kind,
79        hook_name,
80    )?;
81    *slot = Some(RegisteredExclusiveHook { plugin_id, hook });
82    Ok(())
83}
84
85#[derive(Clone, Default)]
86pub(crate) struct PluginContributions {
87    pub(crate) tool_providers: Vec<Arc<dyn ToolProvider>>,
88    pub(crate) host_events: Vec<crate::HostEvent>,
89    pub(crate) prompt_contributors: Vec<RegisteredHook<PromptContributor>>,
90    pub(crate) tool_surface_contributors: Vec<RegisteredHook<ToolSurfaceContributor>>,
91    pub(crate) tool_discovery_contributors: Vec<RegisteredHook<ToolDiscoveryContributor>>,
92    pub(crate) before_turn_hooks: Vec<RegisteredHook<BeforeTurnHook>>,
93    pub(crate) before_tool_call_hooks: Vec<RegisteredHook<BeforeToolCallHook>>,
94    pub(crate) after_tool_call_hooks: Vec<RegisteredHook<AfterToolCallHook>>,
95    pub(crate) after_turn_hooks: Vec<RegisteredHook<AfterTurnHook>>,
96    pub(crate) checkpoint_hooks: Vec<RegisteredHook<CheckpointHook>>,
97    pub(crate) assistant_stream_hooks: Vec<RegisteredHook<AssistantStreamHook>>,
98    pub(crate) assistant_response_hooks: Vec<RegisteredHook<AssistantResponseHook>>,
99    pub(crate) tool_result_projector: Option<RegisteredExclusiveHook<ToolResultProjector>>,
100    pub(crate) runtime_event_hooks: Vec<RegisteredHook<PluginLifecycleEventHook>>,
101    pub(crate) session_config_mutators: Vec<SessionConfigMutator>,
102    pub(crate) plugin_actions: BTreeMap<String, RegisteredPluginAction>,
103    pub(crate) turn_context_transforms: Vec<(i32, RegisteredHook<Arc<dyn TurnContextTransform>>)>,
104    pub(crate) context_compactors: Vec<(i32, RegisteredHook<Arc<dyn ContextCompactor>>)>,
105    pub(crate) protocol_session: Option<RegisteredExclusiveHook<Arc<dyn ProtocolSessionPlugin>>>,
106    pub(crate) protocol_driver: Option<RegisteredExclusiveHook<Arc<dyn ProtocolDriverPlugin>>>,
107    pub(crate) code_executor: Option<RegisteredExclusiveHook<Arc<dyn CodeExecutorPlugin>>>,
108    pub(crate) assistant_prose_projector:
109        Option<RegisteredExclusiveHook<Arc<dyn AssistantProseProjectorPlugin>>>,
110}
111
112pub struct PluginRegistrar {
113    pub(crate) tool_names: BTreeSet<String>,
114    pub(crate) contributions: PluginContributions,
115    pub(crate) registering_plugin_id: Option<String>,
116}
117
118pub struct ToolRegistrations<'a> {
119    reg: &'a mut PluginRegistrar,
120}
121
122impl ToolRegistrations<'_> {
123    pub fn provider(self, provider: Arc<dyn ToolProvider>) -> Result<(), PluginError> {
124        self.reg.add_tool_provider(provider)
125    }
126}
127
128pub struct HostEventRegistrations<'a> {
129    reg: &'a mut PluginRegistrar,
130}
131
132impl HostEventRegistrations<'_> {
133    pub fn declare(self, event: crate::HostEvent) -> Result<(), PluginError> {
134        self.reg.add_host_event(event)
135    }
136}
137
138pub struct PromptRegistrations<'a> {
139    reg: &'a mut PluginRegistrar,
140}
141
142impl PromptRegistrations<'_> {
143    pub fn contribute(self, contributor: PromptContributor) {
144        self.reg.add_prompt_contributor(contributor);
145    }
146}
147
148pub struct SurfaceRegistrations<'a> {
149    reg: &'a mut PluginRegistrar,
150}
151
152impl SurfaceRegistrations<'_> {
153    pub fn contribute(self, contributor: ToolSurfaceContributor) {
154        self.reg.add_tool_surface_contributor(contributor);
155    }
156}
157
158pub struct DiscoveryRegistrations<'a> {
159    reg: &'a mut PluginRegistrar,
160}
161
162impl DiscoveryRegistrations<'_> {
163    pub fn contribute(self, contributor: ToolDiscoveryContributor) {
164        self.reg.add_tool_discovery_contributor(contributor);
165    }
166}
167
168pub struct TurnRegistrations<'a> {
169    reg: &'a mut PluginRegistrar,
170}
171
172impl TurnRegistrations<'_> {
173    pub fn before(self, hook: BeforeTurnHook) {
174        self.reg.add_before_turn_hook(hook);
175    }
176
177    pub fn after(self, hook: AfterTurnHook) {
178        self.reg.add_after_turn_hook(hook);
179    }
180
181    pub fn checkpoint(self, hook: CheckpointHook) {
182        self.reg.add_checkpoint_hook(hook);
183    }
184}
185
186pub struct ToolCallRegistrations<'a> {
187    reg: &'a mut PluginRegistrar,
188}
189
190impl ToolCallRegistrations<'_> {
191    pub fn before(self, hook: BeforeToolCallHook) {
192        self.reg.add_before_tool_call_hook(hook);
193    }
194
195    pub fn after(self, hook: AfterToolCallHook) {
196        self.reg.add_after_tool_call_hook(hook);
197    }
198}
199
200pub struct OutputRegistrations<'a> {
201    reg: &'a mut PluginRegistrar,
202}
203
204impl OutputRegistrations<'_> {
205    pub fn stream(self, hook: AssistantStreamHook) {
206        self.reg.add_assistant_stream_hook(hook);
207    }
208
209    pub fn response(self, hook: AssistantResponseHook) {
210        self.reg.add_assistant_response_hook(hook);
211    }
212
213    pub fn assistant_prose_projector(
214        self,
215        provider: Arc<dyn AssistantProseProjectorPlugin>,
216    ) -> Result<(), PluginError> {
217        self.reg.add_assistant_prose_projector(provider)
218    }
219}
220
221pub struct ToolResultRegistrations<'a> {
222    reg: &'a mut PluginRegistrar,
223}
224
225impl ToolResultRegistrations<'_> {
226    pub fn projector(self, hook: ToolResultProjector) -> Result<(), PluginError> {
227        self.reg.add_tool_result_projector(hook)
228    }
229}
230
231pub struct SessionRegistrations<'a> {
232    reg: &'a mut PluginRegistrar,
233}
234
235impl SessionRegistrations<'_> {
236    pub fn on_event(self, hook: PluginLifecycleEventHook) {
237        push_registered_hook(
238            &mut self.reg.contributions.runtime_event_hooks,
239            &self.reg.registering_plugin_id,
240            hook,
241        );
242    }
243
244    pub fn config_mutator(self, hook: SessionConfigMutator) {
245        self.reg.contributions.session_config_mutators.push(hook);
246    }
247}
248
249pub struct PluginActionRegistrations<'a> {
250    reg: &'a mut PluginRegistrar,
251}
252
253impl PluginActionRegistrations<'_> {
254    pub fn op(self, def: PluginActionDef, handler: PluginActionHandler) -> Result<(), PluginError> {
255        self.reg.add_plugin_action(def, handler)
256    }
257
258    pub fn typed<Op, F, Fut>(self, handler: F) -> Result<(), PluginError>
259    where
260        Op: PluginAction,
261        F: Fn(PluginActionContext, Op::Args) -> Fut + Send + Sync + 'static,
262        Fut: Future<Output = Result<Op::Output, PluginActionFailure>> + Send + 'static,
263    {
264        self.op(
265            plugin_action_def::<Op>(),
266            Arc::new(move |ctx, args| {
267                let parsed = serde_json::from_value::<Op::Args>(args);
268                match parsed {
269                    Ok(args) => {
270                        let fut = handler(ctx, args);
271                        Box::pin(async move {
272                            match fut.await {
273                                Ok(output) => match serde_json::to_value(output) {
274                                    Ok(value) => ToolResult::ok(value),
275                                    Err(err) => ToolResult::err(serde_json::json!(format!(
276                                        "failed to serialize {} output: {err}",
277                                        Op::NAME
278                                    ))),
279                                },
280                                Err(err) => ToolResult::err(serde_json::json!(err.to_string())),
281                            }
282                        })
283                    }
284                    Err(err) => Box::pin(async move {
285                        ToolResult::err(serde_json::json!(format!(
286                            "invalid {} args: {err}",
287                            Op::NAME
288                        )))
289                    }),
290                }
291            }),
292        )
293    }
294}
295
296pub struct ContextRegistrations<'a> {
297    reg: &'a mut PluginRegistrar,
298}
299
300impl ContextRegistrations<'_> {
301    /// Register a per-turn context transform. Higher priority runs first.
302    pub fn prepare_turn(self, priority: i32, transform: Arc<dyn TurnContextTransform>) {
303        push_prioritized_registered_hook(
304            &mut self.reg.contributions.turn_context_transforms,
305            &self.reg.registering_plugin_id,
306            priority,
307            transform,
308        );
309    }
310
311    /// Register an explicit compaction provider. Higher priority runs first.
312    pub fn compact(self, priority: i32, compactor: Arc<dyn ContextCompactor>) {
313        push_prioritized_registered_hook(
314            &mut self.reg.contributions.context_compactors,
315            &self.reg.registering_plugin_id,
316            priority,
317            compactor,
318        );
319    }
320}
321
322pub struct ProtocolRegistrations<'a> {
323    reg: &'a mut PluginRegistrar,
324}
325
326impl ProtocolRegistrations<'_> {
327    pub fn session(self, provider: Arc<dyn ProtocolSessionPlugin>) -> Result<(), PluginError> {
328        self.reg.add_protocol_session(provider)
329    }
330
331    /// Claim the session-wide singleton protocol-driver slot. The
332    /// plugin provides a `ProtocolDriverHandle` via `build_preamble`.
333    /// The active plugin stack must install exactly one protocol driver.
334    pub fn protocol_driver(
335        self,
336        provider: Arc<dyn ProtocolDriverPlugin>,
337    ) -> Result<(), PluginError> {
338        self.reg.add_protocol_driver(provider)
339    }
340}
341
342pub struct ExecutionRegistrations<'a> {
343    reg: &'a mut PluginRegistrar,
344}
345
346impl ExecutionRegistrations<'_> {
347    pub fn code_executor(self, provider: Arc<dyn CodeExecutorPlugin>) -> Result<(), PluginError> {
348        self.reg.add_code_executor(provider)
349    }
350}
351
352impl PluginRegistrar {
353    pub(crate) fn new() -> Self {
354        Self {
355            tool_names: BTreeSet::new(),
356            contributions: PluginContributions::default(),
357            registering_plugin_id: None,
358        }
359    }
360
361    pub fn tools(&mut self) -> ToolRegistrations<'_> {
362        ToolRegistrations { reg: self }
363    }
364
365    pub fn host_events(&mut self) -> HostEventRegistrations<'_> {
366        HostEventRegistrations { reg: self }
367    }
368
369    pub fn prompt(&mut self) -> PromptRegistrations<'_> {
370        PromptRegistrations { reg: self }
371    }
372
373    pub fn surface(&mut self) -> SurfaceRegistrations<'_> {
374        SurfaceRegistrations { reg: self }
375    }
376
377    pub fn discovery(&mut self) -> DiscoveryRegistrations<'_> {
378        DiscoveryRegistrations { reg: self }
379    }
380
381    pub fn turn(&mut self) -> TurnRegistrations<'_> {
382        TurnRegistrations { reg: self }
383    }
384
385    pub fn tool_calls(&mut self) -> ToolCallRegistrations<'_> {
386        ToolCallRegistrations { reg: self }
387    }
388
389    pub fn output(&mut self) -> OutputRegistrations<'_> {
390        OutputRegistrations { reg: self }
391    }
392
393    pub fn tool_results(&mut self) -> ToolResultRegistrations<'_> {
394        ToolResultRegistrations { reg: self }
395    }
396
397    pub fn session(&mut self) -> SessionRegistrations<'_> {
398        SessionRegistrations { reg: self }
399    }
400
401    pub fn actions(&mut self) -> PluginActionRegistrations<'_> {
402        PluginActionRegistrations { reg: self }
403    }
404
405    pub fn context(&mut self) -> ContextRegistrations<'_> {
406        ContextRegistrations { reg: self }
407    }
408
409    pub fn protocol(&mut self) -> ProtocolRegistrations<'_> {
410        ProtocolRegistrations { reg: self }
411    }
412
413    pub fn execution(&mut self) -> ExecutionRegistrations<'_> {
414        ExecutionRegistrations { reg: self }
415    }
416
417    fn add_tool_provider(&mut self, provider: Arc<dyn ToolProvider>) -> Result<(), PluginError> {
418        for manifest in provider.tool_manifests() {
419            if !self.tool_names.insert(manifest.name.clone()) {
420                return Err(PluginError::Registration(format!(
421                    "duplicate plugin tool name `{}`",
422                    manifest.name
423                )));
424            }
425        }
426        self.contributions.tool_providers.push(provider);
427        Ok(())
428    }
429
430    fn add_host_event(&mut self, event: crate::HostEvent) -> Result<(), PluginError> {
431        if self
432            .contributions
433            .host_events
434            .iter()
435            .any(|existing| existing.key() == event.key())
436        {
437            return Err(PluginError::Registration(format!(
438                "duplicate host event `{}.{}.{}`",
439                event.resource_type, event.alias, event.event
440            )));
441        }
442        self.contributions.host_events.push(event);
443        Ok(())
444    }
445
446    fn add_prompt_contributor(&mut self, contributor: PromptContributor) {
447        push_registered_hook(
448            &mut self.contributions.prompt_contributors,
449            &self.registering_plugin_id,
450            contributor,
451        );
452    }
453
454    fn add_tool_surface_contributor(&mut self, contributor: ToolSurfaceContributor) {
455        push_registered_hook(
456            &mut self.contributions.tool_surface_contributors,
457            &self.registering_plugin_id,
458            contributor,
459        );
460    }
461
462    fn add_tool_discovery_contributor(&mut self, contributor: ToolDiscoveryContributor) {
463        push_registered_hook(
464            &mut self.contributions.tool_discovery_contributors,
465            &self.registering_plugin_id,
466            contributor,
467        );
468    }
469
470    fn add_before_turn_hook(&mut self, hook: BeforeTurnHook) {
471        push_registered_hook(
472            &mut self.contributions.before_turn_hooks,
473            &self.registering_plugin_id,
474            hook,
475        );
476    }
477
478    fn add_before_tool_call_hook(&mut self, hook: BeforeToolCallHook) {
479        push_registered_hook(
480            &mut self.contributions.before_tool_call_hooks,
481            &self.registering_plugin_id,
482            hook,
483        );
484    }
485
486    fn add_after_tool_call_hook(&mut self, hook: AfterToolCallHook) {
487        push_registered_hook(
488            &mut self.contributions.after_tool_call_hooks,
489            &self.registering_plugin_id,
490            hook,
491        );
492    }
493
494    fn add_after_turn_hook(&mut self, hook: AfterTurnHook) {
495        push_registered_hook(
496            &mut self.contributions.after_turn_hooks,
497            &self.registering_plugin_id,
498            hook,
499        );
500    }
501
502    fn add_checkpoint_hook(&mut self, hook: CheckpointHook) {
503        push_registered_hook(
504            &mut self.contributions.checkpoint_hooks,
505            &self.registering_plugin_id,
506            hook,
507        );
508    }
509
510    fn add_assistant_stream_hook(&mut self, hook: AssistantStreamHook) {
511        push_registered_hook(
512            &mut self.contributions.assistant_stream_hooks,
513            &self.registering_plugin_id,
514            hook,
515        );
516    }
517
518    fn add_assistant_response_hook(&mut self, hook: AssistantResponseHook) {
519        push_registered_hook(
520            &mut self.contributions.assistant_response_hooks,
521            &self.registering_plugin_id,
522            hook,
523        );
524    }
525
526    fn add_assistant_prose_projector(
527        &mut self,
528        provider: Arc<dyn AssistantProseProjectorPlugin>,
529    ) -> Result<(), PluginError> {
530        register_singleton_hook(
531            &mut self.contributions.assistant_prose_projector,
532            &self.registering_plugin_id,
533            "assistant prose projector",
534            "assistant_prose_projector",
535            provider,
536        )
537    }
538
539    fn add_tool_result_projector(&mut self, hook: ToolResultProjector) -> Result<(), PluginError> {
540        register_singleton_hook(
541            &mut self.contributions.tool_result_projector,
542            &self.registering_plugin_id,
543            "tool result projector",
544            "model_observation",
545            hook,
546        )
547    }
548
549    fn add_plugin_action(
550        &mut self,
551        def: PluginActionDef,
552        handler: PluginActionHandler,
553    ) -> Result<(), PluginError> {
554        if self.contributions.plugin_actions.contains_key(&def.name) {
555            return Err(PluginError::Registration(format!(
556                "duplicate plugin action name `{}`",
557                def.name
558            )));
559        }
560        self.contributions
561            .plugin_actions
562            .insert(def.name.clone(), RegisteredPluginAction { def, handler });
563        Ok(())
564    }
565
566    fn add_protocol_session(
567        &mut self,
568        provider: Arc<dyn ProtocolSessionPlugin>,
569    ) -> Result<(), PluginError> {
570        register_singleton_hook(
571            &mut self.contributions.protocol_session,
572            &self.registering_plugin_id,
573            "protocol session capability",
574            "protocol_session",
575            provider,
576        )
577    }
578
579    fn add_code_executor(
580        &mut self,
581        provider: Arc<dyn CodeExecutorPlugin>,
582    ) -> Result<(), PluginError> {
583        register_singleton_hook(
584            &mut self.contributions.code_executor,
585            &self.registering_plugin_id,
586            "code executor capability",
587            "code_executor",
588            provider,
589        )
590    }
591
592    fn add_protocol_driver(
593        &mut self,
594        provider: Arc<dyn ProtocolDriverPlugin>,
595    ) -> Result<(), PluginError> {
596        register_singleton_hook(
597            &mut self.contributions.protocol_driver,
598            &self.registering_plugin_id,
599            "protocol driver capability",
600            "protocol_driver",
601            provider,
602        )
603    }
604}