Skip to main content

lash_core/plugin/
registrar.rs

1use std::collections::{BTreeMap, BTreeSet};
2use std::sync::Arc;
3
4use super::*;
5
/// A hook value paired with the id of the plugin that registered it.
///
/// Used for hook kinds where any number of plugins may register.
#[derive(Clone)]
pub(crate) struct RegisteredHook<T> {
    /// Id of the plugin that was registering when the hook was added.
    pub(crate) plugin_id: String,
    /// The registered hook itself.
    pub(crate) hook: T,
}
11
/// A hook value paired with the id of the plugin that owns it.
///
/// Used for singleton slots (`Option<RegisteredExclusiveHook<_>>`) that at
/// most one plugin may claim.
#[derive(Clone)]
pub(crate) struct RegisteredExclusiveHook<T> {
    /// Id of the plugin that claimed the slot.
    pub(crate) plugin_id: String,
    /// The registered hook itself.
    pub(crate) hook: T,
}
17
/// Returns the id of the plugin that is currently registering.
///
/// Falls back to the placeholder `"__unknown__"` when no plugin id is set.
pub(crate) fn current_plugin_id(registering_plugin_id: &Option<String>) -> String {
    match registering_plugin_id {
        Some(id) => id.clone(),
        None => String::from("__unknown__"),
    }
}
23
24fn push_registered_hook<T>(
25    hooks: &mut Vec<RegisteredHook<T>>,
26    registering_plugin_id: &Option<String>,
27    hook: T,
28) {
29    hooks.push(RegisteredHook {
30        plugin_id: current_plugin_id(registering_plugin_id),
31        hook,
32    });
33}
34
35fn exclusive_hook_owner(
36    existing_owner: Option<&str>,
37    registering_plugin_id: &Option<String>,
38    hook_kind: &str,
39    hook_name: &str,
40) -> Result<String, PluginError> {
41    let plugin_id = registering_plugin_id
42        .clone()
43        .ok_or_else(|| PluginError::Registration("missing registering plugin id".to_string()))?;
44    if let Some(existing) = existing_owner {
45        return Err(PluginError::Registration(format!(
46            "duplicate {hook_kind} for `{hook_name}`: `{plugin_id}` conflicts with `{existing}`"
47        )));
48    }
49    Ok(plugin_id)
50}
51
52fn register_singleton_hook<H>(
53    slot: &mut Option<RegisteredExclusiveHook<H>>,
54    registering_plugin_id: &Option<String>,
55    hook_kind: &str,
56    hook_name: &str,
57    hook: H,
58) -> Result<(), PluginError> {
59    let plugin_id = exclusive_hook_owner(
60        slot.as_ref()
61            .map(|registered| registered.plugin_id.as_str()),
62        registering_plugin_id,
63        hook_kind,
64        hook_name,
65    )?;
66    *slot = Some(RegisteredExclusiveHook { plugin_id, hook });
67    Ok(())
68}
69
70pub struct PluginRegistrar {
71    pub(crate) tool_names: BTreeSet<String>,
72    pub(crate) tool_providers: Vec<Arc<dyn ToolProvider>>,
73    pub(crate) prompt_contributors: Vec<RegisteredHook<PromptContributor>>,
74    pub(crate) tool_surface_contributors: Vec<RegisteredHook<ToolSurfaceContributor>>,
75    pub(crate) tool_discovery_contributors: Vec<RegisteredHook<ToolDiscoveryContributor>>,
76    pub(crate) before_turn_hooks: Vec<RegisteredHook<BeforeTurnHook>>,
77    pub(crate) before_tool_call_hooks: Vec<RegisteredHook<BeforeToolCallHook>>,
78    pub(crate) after_tool_call_hooks: Vec<RegisteredHook<AfterToolCallHook>>,
79    pub(crate) after_turn_hooks: Vec<RegisteredHook<AfterTurnHook>>,
80    pub(crate) checkpoint_hooks: Vec<RegisteredHook<CheckpointHook>>,
81    pub(crate) assistant_stream_hooks: Vec<RegisteredHook<AssistantStreamHook>>,
82    pub(crate) assistant_response_hooks: Vec<RegisteredHook<AssistantResponseHook>>,
83    pub(crate) tool_result_projector: Option<RegisteredExclusiveHook<ToolResultProjector>>,
84    pub(crate) runtime_event_hooks: Vec<PluginRuntimeEventHook>,
85    pub(crate) session_config_mutators: Vec<SessionConfigMutator>,
86    pub(crate) plugin_actions: BTreeMap<String, RegisteredPluginAction>,
87    pub(crate) monitor_specs: Vec<PluginOwned<crate::MonitorSpec>>,
88    pub(crate) turn_context_transforms: Vec<(i32, Arc<dyn TurnContextTransform>)>,
89    pub(crate) history_rewriters: Vec<(i32, Arc<dyn HistoryRewriter>)>,
90    pub(crate) mode_session: Option<RegisteredExclusiveHook<Arc<dyn ModeSessionPlugin>>>,
91    pub(crate) mode_native_tools: Vec<RegisteredHook<Arc<dyn ModeNativeToolsPlugin>>>,
92    pub(crate) mode_protocol_driver:
93        Option<RegisteredExclusiveHook<Arc<dyn ModeProtocolDriverPlugin>>>,
94    pub(crate) registering_plugin_id: Option<String>,
95}
96
97pub struct ToolRegistrations<'a> {
98    reg: &'a mut PluginRegistrar,
99}
100
101impl ToolRegistrations<'_> {
102    pub fn provider(self, provider: Arc<dyn ToolProvider>) -> Result<(), PluginError> {
103        self.reg.add_tool_provider(provider)
104    }
105}
106
107pub struct PromptRegistrations<'a> {
108    reg: &'a mut PluginRegistrar,
109}
110
111impl PromptRegistrations<'_> {
112    pub fn contribute(self, contributor: PromptContributor) {
113        self.reg.add_prompt_contributor(contributor);
114    }
115}
116
117pub struct SurfaceRegistrations<'a> {
118    reg: &'a mut PluginRegistrar,
119}
120
121impl SurfaceRegistrations<'_> {
122    pub fn contribute(self, contributor: ToolSurfaceContributor) {
123        self.reg.add_tool_surface_contributor(contributor);
124    }
125}
126
127pub struct DiscoveryRegistrations<'a> {
128    reg: &'a mut PluginRegistrar,
129}
130
131impl DiscoveryRegistrations<'_> {
132    pub fn contribute(self, contributor: ToolDiscoveryContributor) {
133        self.reg.add_tool_discovery_contributor(contributor);
134    }
135}
136
137pub struct TurnRegistrations<'a> {
138    reg: &'a mut PluginRegistrar,
139}
140
141impl TurnRegistrations<'_> {
142    pub fn before(self, hook: BeforeTurnHook) {
143        self.reg.add_before_turn_hook(hook);
144    }
145
146    pub fn after(self, hook: AfterTurnHook) {
147        self.reg.add_after_turn_hook(hook);
148    }
149
150    pub fn checkpoint(self, hook: CheckpointHook) {
151        self.reg.add_checkpoint_hook(hook);
152    }
153}
154
155pub struct ToolCallRegistrations<'a> {
156    reg: &'a mut PluginRegistrar,
157}
158
159impl ToolCallRegistrations<'_> {
160    pub fn before(self, hook: BeforeToolCallHook) {
161        self.reg.add_before_tool_call_hook(hook);
162    }
163
164    pub fn after(self, hook: AfterToolCallHook) {
165        self.reg.add_after_tool_call_hook(hook);
166    }
167}
168
169pub struct OutputRegistrations<'a> {
170    reg: &'a mut PluginRegistrar,
171}
172
173impl OutputRegistrations<'_> {
174    pub fn stream(self, hook: AssistantStreamHook) {
175        self.reg.add_assistant_stream_hook(hook);
176    }
177
178    pub fn response(self, hook: AssistantResponseHook) {
179        self.reg.add_assistant_response_hook(hook);
180    }
181}
182
183pub struct ToolResultRegistrations<'a> {
184    reg: &'a mut PluginRegistrar,
185}
186
187impl ToolResultRegistrations<'_> {
188    pub fn projector(self, hook: ToolResultProjector) -> Result<(), PluginError> {
189        self.reg.add_tool_result_projector(hook)
190    }
191}
192
193pub struct SessionRegistrations<'a> {
194    reg: &'a mut PluginRegistrar,
195}
196
197impl SessionRegistrations<'_> {
198    pub fn on_event(self, hook: PluginRuntimeEventHook) {
199        self.reg.runtime_event_hooks.push(hook);
200    }
201
202    pub fn config_mutator(self, hook: SessionConfigMutator) {
203        self.reg.session_config_mutators.push(hook);
204    }
205}
206
207pub struct MonitorRegistrations<'a> {
208    reg: &'a mut PluginRegistrar,
209}
210
211impl MonitorRegistrations<'_> {
212    pub fn register(self, spec: crate::MonitorSpec) {
213        self.reg.add_monitor_spec(spec);
214    }
215}
216
217pub struct PluginActionRegistrations<'a> {
218    reg: &'a mut PluginRegistrar,
219}
220
221impl PluginActionRegistrations<'_> {
222    pub fn op(self, def: PluginActionDef, handler: PluginActionHandler) -> Result<(), PluginError> {
223        self.reg.add_plugin_action(def, handler)
224    }
225
226    pub fn typed<Op, F, Fut>(self, handler: F) -> Result<(), PluginError>
227    where
228        Op: PluginAction,
229        F: Fn(PluginActionContext, Op::Args) -> Fut + Send + Sync + 'static,
230        Fut: Future<Output = Result<Op::Output, PluginActionFailure>> + Send + 'static,
231    {
232        self.op(
233            plugin_action_def::<Op>(),
234            Arc::new(move |ctx, args| {
235                let parsed = serde_json::from_value::<Op::Args>(args);
236                match parsed {
237                    Ok(args) => {
238                        let fut = handler(ctx, args);
239                        Box::pin(async move {
240                            match fut.await {
241                                Ok(output) => match serde_json::to_value(output) {
242                                    Ok(value) => ToolResult::ok(value),
243                                    Err(err) => ToolResult::err(serde_json::json!(format!(
244                                        "failed to serialize {} output: {err}",
245                                        Op::NAME
246                                    ))),
247                                },
248                                Err(err) => ToolResult::err(serde_json::json!(err.to_string())),
249                            }
250                        })
251                    }
252                    Err(err) => Box::pin(async move {
253                        ToolResult::err(serde_json::json!(format!(
254                            "invalid {} args: {err}",
255                            Op::NAME
256                        )))
257                    }),
258                }
259            }),
260        )
261    }
262}
263
264pub struct HistoryRegistrations<'a> {
265    reg: &'a mut PluginRegistrar,
266}
267
268impl HistoryRegistrations<'_> {
269    /// Register a per-turn context transform. Higher priority runs first.
270    pub fn prepare_turn(self, priority: i32, transform: Arc<dyn TurnContextTransform>) {
271        self.reg.turn_context_transforms.push((priority, transform));
272    }
273
274    /// Register a permanent history rewriter. Higher priority runs first.
275    pub fn rewrite(self, priority: i32, rewriter: Arc<dyn HistoryRewriter>) {
276        self.reg.history_rewriters.push((priority, rewriter));
277    }
278}
279
280pub struct ModeRegistrations<'a> {
281    reg: &'a mut PluginRegistrar,
282}
283
284impl ModeRegistrations<'_> {
285    pub fn session(self, provider: Arc<dyn ModeSessionPlugin>) -> Result<(), PluginError> {
286        self.reg.add_mode_session(provider)
287    }
288
289    pub fn native_tools(self, provider: Arc<dyn ModeNativeToolsPlugin>) -> Result<(), PluginError> {
290        self.reg.add_mode_native_tools(provider)
291    }
292
293    /// Claim the session-wide singleton protocol-driver slot. The
294    /// plugin provides a `ProtocolDriverHandle` via `build_preamble`
295    /// and identifies itself with a `mode_id` that the session's
296    /// `ExecutionMode` must match for the driver to be selected.
297    pub fn protocol_driver(
298        self,
299        provider: Arc<dyn ModeProtocolDriverPlugin>,
300    ) -> Result<(), PluginError> {
301        self.reg.add_mode_protocol_driver(provider)
302    }
303}
304
305impl PluginRegistrar {
306    pub(crate) fn new() -> Self {
307        Self {
308            tool_names: BTreeSet::new(),
309            tool_providers: Vec::new(),
310            prompt_contributors: Vec::new(),
311            tool_surface_contributors: Vec::new(),
312            tool_discovery_contributors: Vec::new(),
313            before_turn_hooks: Vec::new(),
314            before_tool_call_hooks: Vec::new(),
315            after_tool_call_hooks: Vec::new(),
316            after_turn_hooks: Vec::new(),
317            checkpoint_hooks: Vec::new(),
318            assistant_stream_hooks: Vec::new(),
319            assistant_response_hooks: Vec::new(),
320            tool_result_projector: None,
321            runtime_event_hooks: Vec::new(),
322            session_config_mutators: Vec::new(),
323            plugin_actions: BTreeMap::new(),
324            monitor_specs: Vec::new(),
325            turn_context_transforms: Vec::new(),
326            history_rewriters: Vec::new(),
327            mode_session: None,
328            mode_native_tools: Vec::new(),
329            mode_protocol_driver: None,
330            registering_plugin_id: None,
331        }
332    }
333
334    pub fn tools(&mut self) -> ToolRegistrations<'_> {
335        ToolRegistrations { reg: self }
336    }
337
338    pub fn prompt(&mut self) -> PromptRegistrations<'_> {
339        PromptRegistrations { reg: self }
340    }
341
342    pub fn surface(&mut self) -> SurfaceRegistrations<'_> {
343        SurfaceRegistrations { reg: self }
344    }
345
346    pub fn discovery(&mut self) -> DiscoveryRegistrations<'_> {
347        DiscoveryRegistrations { reg: self }
348    }
349
350    pub fn turn(&mut self) -> TurnRegistrations<'_> {
351        TurnRegistrations { reg: self }
352    }
353
354    pub fn tool_calls(&mut self) -> ToolCallRegistrations<'_> {
355        ToolCallRegistrations { reg: self }
356    }
357
358    pub fn output(&mut self) -> OutputRegistrations<'_> {
359        OutputRegistrations { reg: self }
360    }
361
362    pub fn tool_results(&mut self) -> ToolResultRegistrations<'_> {
363        ToolResultRegistrations { reg: self }
364    }
365
366    pub fn session(&mut self) -> SessionRegistrations<'_> {
367        SessionRegistrations { reg: self }
368    }
369
370    pub fn actions(&mut self) -> PluginActionRegistrations<'_> {
371        PluginActionRegistrations { reg: self }
372    }
373
374    pub fn monitors(&mut self) -> MonitorRegistrations<'_> {
375        MonitorRegistrations { reg: self }
376    }
377
378    pub fn history(&mut self) -> HistoryRegistrations<'_> {
379        HistoryRegistrations { reg: self }
380    }
381
382    pub fn mode(&mut self) -> ModeRegistrations<'_> {
383        ModeRegistrations { reg: self }
384    }
385
386    fn add_tool_provider(&mut self, provider: Arc<dyn ToolProvider>) -> Result<(), PluginError> {
387        for manifest in provider.tool_manifests() {
388            if !self.tool_names.insert(manifest.name.clone()) {
389                return Err(PluginError::Registration(format!(
390                    "duplicate plugin tool name `{}`",
391                    manifest.name
392                )));
393            }
394        }
395        self.tool_providers.push(provider);
396        Ok(())
397    }
398
399    fn add_prompt_contributor(&mut self, contributor: PromptContributor) {
400        push_registered_hook(
401            &mut self.prompt_contributors,
402            &self.registering_plugin_id,
403            contributor,
404        );
405    }
406
407    fn add_tool_surface_contributor(&mut self, contributor: ToolSurfaceContributor) {
408        push_registered_hook(
409            &mut self.tool_surface_contributors,
410            &self.registering_plugin_id,
411            contributor,
412        );
413    }
414
415    fn add_tool_discovery_contributor(&mut self, contributor: ToolDiscoveryContributor) {
416        push_registered_hook(
417            &mut self.tool_discovery_contributors,
418            &self.registering_plugin_id,
419            contributor,
420        );
421    }
422
423    fn add_before_turn_hook(&mut self, hook: BeforeTurnHook) {
424        push_registered_hook(
425            &mut self.before_turn_hooks,
426            &self.registering_plugin_id,
427            hook,
428        );
429    }
430
431    fn add_before_tool_call_hook(&mut self, hook: BeforeToolCallHook) {
432        push_registered_hook(
433            &mut self.before_tool_call_hooks,
434            &self.registering_plugin_id,
435            hook,
436        );
437    }
438
439    fn add_after_tool_call_hook(&mut self, hook: AfterToolCallHook) {
440        push_registered_hook(
441            &mut self.after_tool_call_hooks,
442            &self.registering_plugin_id,
443            hook,
444        );
445    }
446
447    fn add_after_turn_hook(&mut self, hook: AfterTurnHook) {
448        push_registered_hook(
449            &mut self.after_turn_hooks,
450            &self.registering_plugin_id,
451            hook,
452        );
453    }
454
455    fn add_checkpoint_hook(&mut self, hook: CheckpointHook) {
456        push_registered_hook(
457            &mut self.checkpoint_hooks,
458            &self.registering_plugin_id,
459            hook,
460        );
461    }
462
463    fn add_assistant_stream_hook(&mut self, hook: AssistantStreamHook) {
464        push_registered_hook(
465            &mut self.assistant_stream_hooks,
466            &self.registering_plugin_id,
467            hook,
468        );
469    }
470
471    fn add_assistant_response_hook(&mut self, hook: AssistantResponseHook) {
472        push_registered_hook(
473            &mut self.assistant_response_hooks,
474            &self.registering_plugin_id,
475            hook,
476        );
477    }
478
479    fn add_tool_result_projector(&mut self, hook: ToolResultProjector) -> Result<(), PluginError> {
480        register_singleton_hook(
481            &mut self.tool_result_projector,
482            &self.registering_plugin_id,
483            "tool result projector",
484            "model_observation",
485            hook,
486        )
487    }
488
489    fn add_plugin_action(
490        &mut self,
491        def: PluginActionDef,
492        handler: PluginActionHandler,
493    ) -> Result<(), PluginError> {
494        if self.plugin_actions.contains_key(&def.name) {
495            return Err(PluginError::Registration(format!(
496                "duplicate plugin action name `{}`",
497                def.name
498            )));
499        }
500        self.plugin_actions
501            .insert(def.name.clone(), RegisteredPluginAction { def, handler });
502        Ok(())
503    }
504
505    fn add_monitor_spec(&mut self, spec: crate::MonitorSpec) {
506        self.monitor_specs.push(PluginOwned {
507            plugin_id: current_plugin_id(&self.registering_plugin_id),
508            value: spec,
509        });
510    }
511
512    fn add_mode_session(
513        &mut self,
514        provider: Arc<dyn ModeSessionPlugin>,
515    ) -> Result<(), PluginError> {
516        register_singleton_hook(
517            &mut self.mode_session,
518            &self.registering_plugin_id,
519            "mode session capability",
520            "mode_session",
521            provider,
522        )
523    }
524
525    fn add_mode_native_tools(
526        &mut self,
527        provider: Arc<dyn ModeNativeToolsPlugin>,
528    ) -> Result<(), PluginError> {
529        push_registered_hook(
530            &mut self.mode_native_tools,
531            &self.registering_plugin_id,
532            provider,
533        );
534        Ok(())
535    }
536
537    fn add_mode_protocol_driver(
538        &mut self,
539        provider: Arc<dyn ModeProtocolDriverPlugin>,
540    ) -> Result<(), PluginError> {
541        register_singleton_hook(
542            &mut self.mode_protocol_driver,
543            &self.registering_plugin_id,
544            "mode protocol driver capability",
545            "mode_protocol_driver",
546            provider,
547        )
548    }
549}