1use std::sync::Arc;
11
12use super::{
13 AfterToolCallHook, AfterTurnHook, AssistantResponseHook, AssistantStreamHook,
14 BeforeToolCallHook, BeforeTurnHook, CheckpointHook, ContextCompactor, PluginCommand,
15 PluginCommandHandler, PluginCommandInvokeFuture, PluginCommandOutcome, PluginError, PluginHost,
16 PluginLifecycleEventHook, PluginOperationDef, PluginOperationFailure, PluginOperationKind,
17 PluginQuery, PluginQueryHandler, PluginQueryInvokeFuture, PluginRegistrar, PluginSnapshotMeta,
18 PluginTask, PluginTaskHandler, PluginTaskInvokeFuture, PluginTaskOutcome, PromptContributor,
19 SessionConfigMutator, SessionToolAccess, SnapshotReader, SnapshotWriter,
20 SubagentSessionContext, ToolCatalogContributor, ToolResultProjector, TurnContextTransform,
21};
22use crate::{PluginOptions, ToolProvider};
23
/// A single extension payload published by a plugin factory, tagged with the
/// id of the extension it is addressed to.
#[derive(Clone, Debug, PartialEq, Eq, serde::Serialize, serde::Deserialize)]
pub struct PluginExtensionContribution {
    /// Identifier of the target extension; `PluginExtensions::insert` uses it
    /// as the grouping key.
    pub extension_id: String,
    /// Arbitrary JSON payload. Because of `#[serde(default)]`, a missing
    /// `payload` field deserializes to `serde_json::Value::default()` (JSON `null`).
    #[serde(default)]
    pub payload: serde_json::Value,
}
30
31impl PluginExtensionContribution {
32 pub fn new(
33 extension_id: impl Into<String>,
34 payload: impl serde::Serialize,
35 ) -> Result<Self, serde_json::Error> {
36 Ok(Self {
37 extension_id: extension_id.into(),
38 payload: serde_json::to_value(payload)?,
39 })
40 }
41
42 pub fn from_value(extension_id: impl Into<String>, payload: serde_json::Value) -> Self {
43 Self {
44 extension_id: extension_id.into(),
45 payload,
46 }
47 }
48}
49
/// Extension payloads grouped by extension id.
///
/// Payloads that share an id keep the order in which they were inserted.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct PluginExtensions {
    // extension_id -> payloads for that id, in insertion order. A BTreeMap keeps
    // the ids ordered, which also makes equality/Debug output deterministic.
    contributions: std::collections::BTreeMap<String, Vec<serde_json::Value>>,
}
54
55impl PluginExtensions {
56 pub fn from_contributions(
57 contributions: impl IntoIterator<Item = PluginExtensionContribution>,
58 ) -> Self {
59 let mut extensions = Self::default();
60 for contribution in contributions {
61 extensions.insert(contribution);
62 }
63 extensions
64 }
65
66 pub fn insert(&mut self, contribution: PluginExtensionContribution) {
67 self.contributions
68 .entry(contribution.extension_id)
69 .or_default()
70 .push(contribution.payload);
71 }
72
73 pub fn payloads(&self, extension_id: &str) -> &[serde_json::Value] {
74 self.contributions
75 .get(extension_id)
76 .map(Vec::as_slice)
77 .unwrap_or(&[])
78 }
79
80 pub fn is_empty(&self) -> bool {
81 self.contributions.is_empty()
82 }
83}
84
/// Declarative description of everything a spec-based plugin contributes.
///
/// Build one with [`PluginSpec::new`] and the `with_*` builder methods. The
/// spec-backed session plugin in this module (`SpecPlugin`) forwards each field
/// to the matching `PluginRegistrar` API from its `register` method.
#[derive(Clone, Default)]
pub struct PluginSpec {
    /// Tool providers, registered via `reg.tools().provider(..)`.
    pub tool_providers: Vec<Arc<dyn ToolProvider>>,
    /// Trigger events, declared via `reg.triggers().declare(..)`.
    pub triggers: Vec<crate::TriggerEvent>,
    /// Prompt contributors, registered via `reg.prompt().contribute(..)`.
    pub prompt_contributors: Vec<PromptContributor>,
    /// Tool-catalog contributors, registered via `reg.tool_catalog().contribute(..)`.
    pub tool_catalog_contributors: Vec<ToolCatalogContributor>,
    /// Hooks registered via `reg.turn().before(..)`.
    pub before_turn_hooks: Vec<BeforeTurnHook>,
    /// Hooks registered via `reg.tool_calls().before(..)`.
    pub before_tool_call_hooks: Vec<BeforeToolCallHook>,
    /// Hooks registered via `reg.tool_calls().after(..)`.
    pub after_tool_call_hooks: Vec<AfterToolCallHook>,
    /// Hooks registered via `reg.turn().after(..)`.
    pub after_turn_hooks: Vec<AfterTurnHook>,
    /// Hooks registered via `reg.turn().checkpoint(..)`.
    pub checkpoint_hooks: Vec<CheckpointHook>,
    /// Hooks registered via `reg.output().stream(..)`.
    pub assistant_stream_hooks: Vec<AssistantStreamHook>,
    /// Hooks registered via `reg.output().response(..)`.
    pub assistant_response_hooks: Vec<AssistantResponseHook>,
    /// At most one projector per spec, registered via `reg.tool_results().projector(..)`.
    pub tool_result_projector: Option<ToolResultProjector>,
    /// Hooks registered via `reg.session().on_event(..)`.
    pub runtime_event_hooks: Vec<PluginLifecycleEventHook>,
    /// Mutators registered via `reg.session().config_mutator(..)`.
    pub session_config_mutators: Vec<SessionConfigMutator>,
    // Type-erased operation handlers; populated by the `with_plugin_*_typed` /
    // `with_plugin_*_value` builders, which wrap the JSON (de)serialization.
    pub(crate) plugin_queries: Vec<(PluginOperationDef, PluginQueryHandler)>,
    pub(crate) plugin_commands: Vec<(PluginOperationDef, PluginCommandHandler)>,
    pub(crate) plugin_tasks: Vec<(PluginOperationDef, PluginTaskHandler)>,
    /// `(priority, transform)` pairs passed to `reg.context().prepare_turn(..)`;
    /// how priorities are ordered is decided by the registrar, not here.
    pub turn_context_transforms: Vec<(i32, Arc<dyn TurnContextTransform>)>,
    /// `(priority, compactor)` pairs passed to `reg.context().compact(..)`;
    /// how priorities are ordered is decided by the registrar, not here.
    pub context_compactors: Vec<(i32, Arc<dyn ContextCompactor>)>,
}
107
108impl PluginSpec {
109 pub fn new() -> Self {
110 Self::default()
111 }
112
113 pub fn with_tool_provider(mut self, provider: Arc<dyn ToolProvider>) -> Self {
114 self.tool_providers.push(provider);
115 self
116 }
117
118 pub fn with_trigger_event(mut self, event: crate::TriggerEvent) -> Self {
119 self.triggers.push(event);
120 self
121 }
122
123 pub fn with_prompt_contributor(mut self, contributor: PromptContributor) -> Self {
124 self.prompt_contributors.push(contributor);
125 self
126 }
127
128 pub fn with_tool_catalog_contributor(mut self, contributor: ToolCatalogContributor) -> Self {
129 self.tool_catalog_contributors.push(contributor);
130 self
131 }
132
133 pub fn with_before_turn(mut self, hook: BeforeTurnHook) -> Self {
134 self.before_turn_hooks.push(hook);
135 self
136 }
137
138 pub fn with_before_tool_call(mut self, hook: BeforeToolCallHook) -> Self {
139 self.before_tool_call_hooks.push(hook);
140 self
141 }
142
143 pub fn with_after_tool_call(mut self, hook: AfterToolCallHook) -> Self {
144 self.after_tool_call_hooks.push(hook);
145 self
146 }
147
148 pub fn with_after_turn(mut self, hook: AfterTurnHook) -> Self {
149 self.after_turn_hooks.push(hook);
150 self
151 }
152
153 pub fn with_checkpoint(mut self, hook: CheckpointHook) -> Self {
154 self.checkpoint_hooks.push(hook);
155 self
156 }
157
158 pub fn with_assistant_stream(mut self, hook: AssistantStreamHook) -> Self {
159 self.assistant_stream_hooks.push(hook);
160 self
161 }
162
163 pub fn with_assistant_response(mut self, hook: AssistantResponseHook) -> Self {
164 self.assistant_response_hooks.push(hook);
165 self
166 }
167
168 pub fn with_tool_result_projector(mut self, projector: ToolResultProjector) -> Self {
169 self.tool_result_projector = Some(projector);
170 self
171 }
172
173 pub fn with_runtime_event(mut self, hook: PluginLifecycleEventHook) -> Self {
174 self.runtime_event_hooks.push(hook);
175 self
176 }
177
178 pub fn with_session_config_mutator(mut self, hook: SessionConfigMutator) -> Self {
179 self.session_config_mutators.push(hook);
180 self
181 }
182
183 pub(crate) fn with_plugin_query(
184 mut self,
185 def: PluginOperationDef,
186 handler: PluginQueryHandler,
187 ) -> Self {
188 self.plugin_queries.push((def, handler));
189 self
190 }
191
192 pub fn with_plugin_query_typed<Op, F, Fut>(self, handler: F) -> Self
193 where
194 Op: PluginQuery,
195 F: Fn(super::PluginQueryContext, Op::Args) -> Fut + Send + Sync + 'static,
196 Fut: std::future::Future<Output = Result<Op::Output, PluginOperationFailure>>
197 + Send
198 + 'static,
199 {
200 self.with_plugin_query(
201 super::plugin_operation_def::<Op>(PluginOperationKind::Query),
202 Arc::new(move |ctx, args| {
203 let parsed = serde_json::from_value::<Op::Args>(args);
204 match parsed {
205 Ok(args) => {
206 let fut = handler(ctx, args);
207 Box::pin(async move {
208 let output = fut.await?;
209 serde_json::to_value(output).map_err(|err| {
210 PluginOperationFailure::new(format!(
211 "failed to serialize {} output: {err}",
212 Op::NAME
213 ))
214 })
215 }) as PluginQueryInvokeFuture
216 }
217 Err(err) => Box::pin(async move {
218 Err(PluginOperationFailure::new(format!(
219 "invalid {} args: {err}",
220 Op::NAME
221 )))
222 }) as PluginQueryInvokeFuture,
223 }
224 }),
225 )
226 }
227
228 pub(crate) fn with_plugin_command(
229 mut self,
230 def: PluginOperationDef,
231 handler: PluginCommandHandler,
232 ) -> Self {
233 self.plugin_commands.push((def, handler));
234 self
235 }
236
237 pub fn with_plugin_command_typed<Op, F, Fut>(self, handler: F) -> Self
238 where
239 Op: PluginCommand,
240 F: Fn(super::PluginCommandContext, Op::Args) -> Fut + Send + Sync + 'static,
241 Fut: std::future::Future<
242 Output = Result<PluginCommandOutcome<Op::Output>, PluginOperationFailure>,
243 > + Send
244 + 'static,
245 {
246 self.with_plugin_command(
247 super::plugin_operation_def::<Op>(PluginOperationKind::Command),
248 Arc::new(move |ctx, args| {
249 let parsed = serde_json::from_value::<Op::Args>(args);
250 match parsed {
251 Ok(args) => {
252 let fut = handler(ctx, args);
253 Box::pin(async move {
254 let outcome = fut.await?;
255 let output = serde_json::to_value(outcome.output).map_err(|err| {
256 PluginOperationFailure::new(format!(
257 "failed to serialize {} output: {err}",
258 Op::NAME
259 ))
260 })?;
261 Ok(super::actions::ErasedPluginCommandOutcome {
262 output,
263 events: outcome.events,
264 directives: outcome.directives,
265 })
266 }) as PluginCommandInvokeFuture
267 }
268 Err(err) => Box::pin(async move {
269 Err(PluginOperationFailure::new(format!(
270 "invalid {} args: {err}",
271 Op::NAME
272 )))
273 }) as PluginCommandInvokeFuture,
274 }
275 }),
276 )
277 }
278
279 pub fn with_plugin_command_value<Op, F, Fut>(self, handler: F) -> Self
280 where
281 Op: PluginCommand,
282 F: Fn(super::PluginCommandContext, Op::Args) -> Fut + Send + Sync + 'static,
283 Fut: std::future::Future<Output = Result<Op::Output, PluginOperationFailure>>
284 + Send
285 + 'static,
286 {
287 self.with_plugin_command_typed::<Op, _, _>(move |ctx, args| {
288 let fut = handler(ctx, args);
289 async move { fut.await.map(PluginCommandOutcome::new) }
290 })
291 }
292
293 pub(crate) fn with_plugin_task(
294 mut self,
295 def: PluginOperationDef,
296 handler: PluginTaskHandler,
297 ) -> Self {
298 self.plugin_tasks.push((def, handler));
299 self
300 }
301
302 pub fn with_plugin_task_typed<Op, F, Fut>(self, handler: F) -> Self
303 where
304 Op: PluginTask,
305 F: Fn(super::PluginTaskContext, Op::Args) -> Fut + Send + Sync + 'static,
306 Fut: std::future::Future<
307 Output = Result<PluginTaskOutcome<Op::Output>, PluginOperationFailure>,
308 > + Send
309 + 'static,
310 {
311 self.with_plugin_task(
312 super::plugin_operation_def::<Op>(PluginOperationKind::Task),
313 Arc::new(move |ctx, args| {
314 let parsed = serde_json::from_value::<Op::Args>(args);
315 match parsed {
316 Ok(args) => {
317 let fut = handler(ctx, args);
318 Box::pin(async move {
319 let outcome = fut.await?;
320 let output = serde_json::to_value(outcome.output).map_err(|err| {
321 PluginOperationFailure::new(format!(
322 "failed to serialize {} output: {err}",
323 Op::NAME
324 ))
325 })?;
326 Ok(super::actions::ErasedPluginTaskOutcome {
327 output,
328 events: outcome.events,
329 directives: outcome.directives,
330 })
331 }) as PluginTaskInvokeFuture
332 }
333 Err(err) => Box::pin(async move {
334 Err(PluginOperationFailure::new(format!(
335 "invalid {} args: {err}",
336 Op::NAME
337 )))
338 }) as PluginTaskInvokeFuture,
339 }
340 }),
341 )
342 }
343
344 pub fn with_plugin_task_value<Op, F, Fut>(self, handler: F) -> Self
345 where
346 Op: PluginTask,
347 F: Fn(super::PluginTaskContext, Op::Args) -> Fut + Send + Sync + 'static,
348 Fut: std::future::Future<Output = Result<Op::Output, PluginOperationFailure>>
349 + Send
350 + 'static,
351 {
352 self.with_plugin_task_typed::<Op, _, _>(move |ctx, args| {
353 let fut = handler(ctx, args);
354 async move { fut.await.map(PluginTaskOutcome::new) }
355 })
356 }
357
358 pub fn with_turn_context_transform(
359 mut self,
360 priority: i32,
361 transform: Arc<dyn TurnContextTransform>,
362 ) -> Self {
363 self.turn_context_transforms.push((priority, transform));
364 self
365 }
366
367 pub fn with_context_compactor(
368 mut self,
369 priority: i32,
370 compactor: Arc<dyn ContextCompactor>,
371 ) -> Self {
372 self.context_compactors.push((priority, compactor));
373 self
374 }
375}
376
/// Per-session information handed to [`PluginFactory::build`] and to
/// [`PluginSpecBuilder`] closures.
#[derive(Clone, Debug)]
pub struct PluginSessionContext {
    /// Id of the session the plugin is being built for.
    pub session_id: String,
    /// Tool-access settings for the session (semantics defined by `SessionToolAccess`).
    pub tool_access: SessionToolAccess,
    /// NOTE(review): presumably `Some` for subagent sessions — confirm against
    /// `SubagentSessionContext`. `is_root_session` does not consult this field.
    pub subagent: Option<SubagentSessionContext>,
    /// Plugin options associated with the session.
    pub plugin_options: PluginOptions,
    /// Extension payloads made available to the session.
    pub extensions: PluginExtensions,
    /// Id of the parent session, if any; `None` is what `is_root_session` treats
    /// as a root session.
    pub parent_session_id: Option<String>,
}
390
391impl PluginSessionContext {
392 pub fn is_root_session(&self) -> bool {
396 self.parent_session_id.is_none()
397 }
398}
399
/// Argument passed to [`SessionPlugin::session_ready`].
#[derive(Clone)]
pub struct SessionReadyContext {
    /// Id of the session this notification refers to.
    pub session_id: String,
    /// Plugin host handle for that session.
    pub host: PluginHost,
}
405
406pub trait SessionPlugin: Send + Sync {
407 fn id(&self) -> &'static str;
408
409 fn version(&self) -> &'static str {
410 "1"
411 }
412
413 fn register(&self, reg: &mut PluginRegistrar) -> Result<(), PluginError>;
414
415 fn snapshot(
416 &self,
417 _writer: &mut dyn SnapshotWriter,
418 ) -> Result<PluginSnapshotMeta, PluginError> {
419 Ok(PluginSnapshotMeta {
420 plugin_id: self.id().to_string(),
421 plugin_version: self.version().to_string(),
422 revision: self.snapshot_revision(),
423 state: None,
424 })
425 }
426
427 fn snapshot_revision(&self) -> u64 {
428 0
429 }
430
431 fn restore(
432 &self,
433 _meta: &PluginSnapshotMeta,
434 _reader: &dyn SnapshotReader,
435 ) -> Result<(), PluginError> {
436 Ok(())
437 }
438
439 fn session_ready(&self, _ctx: SessionReadyContext) -> Result<(), PluginError> {
440 Ok(())
441 }
442}
443
444pub trait PluginFactory: Send + Sync {
493 fn id(&self) -> &'static str;
494
495 fn extension_contributions(&self) -> Vec<PluginExtensionContribution> {
496 Vec::new()
497 }
498
499 fn build(&self, ctx: &PluginSessionContext) -> Result<Arc<dyn SessionPlugin>, PluginError>;
502}
503
/// Shared, thread-safe closure that builds a [`PluginSpec`] from a session
/// context; [`PluginSpecFactory`] calls it on every `build`.
pub type PluginSpecBuilder =
    Arc<dyn Fn(&PluginSessionContext) -> Result<PluginSpec, PluginError> + Send + Sync>;
506
/// [`PluginFactory`] that invokes a [`PluginSpecBuilder`] on every `build`, so
/// the produced spec can depend on the [`PluginSessionContext`].
pub struct PluginSpecFactory {
    // Id reported by the factory and by every plugin it builds.
    id: &'static str,
    // Called once per `build` to produce a fresh spec.
    builder: PluginSpecBuilder,
}
511
512impl PluginSpecFactory {
513 pub fn new(id: &'static str, builder: PluginSpecBuilder) -> Self {
514 Self { id, builder }
515 }
516}
517
/// [`PluginFactory`] that hands out a clone of one fixed [`PluginSpec`] on
/// every `build`, ignoring the session context.
pub struct StaticPluginFactory {
    // Id reported by the factory and by every plugin it builds.
    id: &'static str,
    // Template cloned into each built plugin.
    spec: PluginSpec,
}
522
523impl StaticPluginFactory {
524 pub fn new(id: &'static str, spec: PluginSpec) -> Self {
525 Self { id, spec }
526 }
527}
528
/// [`SessionPlugin`] backed by a [`PluginSpec`]; built by
/// [`PluginSpecFactory`] and [`StaticPluginFactory`].
struct SpecPlugin {
    // Plugin id, copied from the factory that built this instance.
    id: &'static str,
    // Contributions forwarded to the registrar by `register`.
    spec: PluginSpec,
}
533
534impl PluginFactory for PluginSpecFactory {
535 fn id(&self) -> &'static str {
536 self.id
537 }
538
539 fn build(&self, ctx: &PluginSessionContext) -> Result<Arc<dyn SessionPlugin>, PluginError> {
540 Ok(Arc::new(SpecPlugin {
541 id: self.id,
542 spec: (self.builder)(ctx)?,
543 }))
544 }
545}
546
547impl PluginFactory for StaticPluginFactory {
548 fn id(&self) -> &'static str {
549 self.id
550 }
551
552 fn build(&self, _ctx: &PluginSessionContext) -> Result<Arc<dyn SessionPlugin>, PluginError> {
553 Ok(Arc::new(SpecPlugin {
554 id: self.id,
555 spec: self.spec.clone(),
556 }))
557 }
558}
559
560impl SessionPlugin for SpecPlugin {
561 fn id(&self) -> &'static str {
562 self.id
563 }
564
565 fn register(&self, reg: &mut PluginRegistrar) -> Result<(), PluginError> {
566 for provider in &self.spec.tool_providers {
567 reg.tools().provider(Arc::clone(provider))?;
568 }
569 for event in &self.spec.triggers {
570 reg.triggers().declare(event.clone())?;
571 }
572 for contributor in &self.spec.prompt_contributors {
573 reg.prompt().contribute(Arc::clone(contributor));
574 }
575 for contributor in &self.spec.tool_catalog_contributors {
576 reg.tool_catalog().contribute(Arc::clone(contributor));
577 }
578 for hook in &self.spec.before_turn_hooks {
579 reg.turn().before(Arc::clone(hook));
580 }
581 for hook in &self.spec.before_tool_call_hooks {
582 reg.tool_calls().before(Arc::clone(hook));
583 }
584 for hook in &self.spec.after_tool_call_hooks {
585 reg.tool_calls().after(Arc::clone(hook));
586 }
587 for hook in &self.spec.after_turn_hooks {
588 reg.turn().after(Arc::clone(hook));
589 }
590 for hook in &self.spec.checkpoint_hooks {
591 reg.turn().checkpoint(Arc::clone(hook));
592 }
593 for hook in &self.spec.assistant_stream_hooks {
594 reg.output().stream(Arc::clone(hook));
595 }
596 for hook in &self.spec.assistant_response_hooks {
597 reg.output().response(Arc::clone(hook));
598 }
599 if let Some(projector) = &self.spec.tool_result_projector {
600 reg.tool_results().projector(Arc::clone(projector))?;
601 }
602 for hook in &self.spec.runtime_event_hooks {
603 reg.session().on_event(Arc::clone(hook));
604 }
605 for hook in &self.spec.session_config_mutators {
606 reg.session().config_mutator(Arc::clone(hook));
607 }
608 for (def, handler) in &self.spec.plugin_queries {
609 reg.operations().query(def.clone(), Arc::clone(handler))?;
610 }
611 for (def, handler) in &self.spec.plugin_commands {
612 reg.operations().command(def.clone(), Arc::clone(handler))?;
613 }
614 for (def, handler) in &self.spec.plugin_tasks {
615 reg.operations().task(def.clone(), Arc::clone(handler))?;
616 }
617 for (priority, transform) in &self.spec.turn_context_transforms {
618 reg.context().prepare_turn(*priority, Arc::clone(transform));
619 }
620 for (priority, compactor) in &self.spec.context_compactors {
621 reg.context().compact(*priority, Arc::clone(compactor));
622 }
623 Ok(())
624 }
625}