Skip to main content

elicit_bevy/
query_plugin.rs

1//! `BevyQueryPlugin` — generic ECS parameter and query codegen tools.
2//!
3//! This plugin covers the generic Bevy ECS surfaces that cannot be represented
4//! as concrete MCP values: `Query`, `Res`, `EventReader`, `Handle`, `Local`,
5//! and complete system signatures assembled from those parameter fragments.
6
7use elicitation::emit_code::{CrateDep, EmitCode};
8use elicitation::{ElicitPlugin, elicit_tool};
9use proc_macro2::TokenStream;
10use quote::quote;
11use rmcp::ErrorData;
12use rmcp::model::{CallToolResult, Content};
13use schemars::JsonSchema;
14use serde::{Deserialize, Serialize};
15use tracing::instrument;
16
/// How a query item should appear inside `Query<...>`.
//
// Wire values are "value" / "shared" / "mutable" (`rename_all = "snake_case"`).
// `Shared` is the `Default`, so an omitted `access` field on `BevyQueryItemSpec`
// (which is `#[serde(default)]`) means `&T`. `render_query_item` maps the variants to
// `T`, `&T` and `&mut T` respectively.
// NOTE(review): the `///` text on this type presumably surfaces as JSON Schema
// descriptions through the `JsonSchema` derive — confirm before rewording it.
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum BevyQueryItemAccess {
    /// Emit the type by value, e.g. `Entity`.
    Value,
    /// Emit a shared reference, e.g. `&Transform`.
    #[default]
    Shared,
    /// Emit a mutable reference, e.g. `&mut Transform`.
    Mutable,
}
29
/// A single item inside a Bevy query.
//
// `ty` is parsed as a full Rust type (`parse_type`) during validation. `access` falls
// back to `BevyQueryItemAccess::Shared` when omitted from the request.
// NOTE(review): `///` text here is presumably part of the generated tool schema.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct BevyQueryItemSpec {
    /// Rust type path for the query item.
    pub ty: String,
    /// Access mode used for this item.
    #[serde(default)]
    pub access: BevyQueryItemAccess,
}
39
/// Supported Bevy query filter wrappers.
//
// Wire values are "with" / "without" / "added" / "changed" (`rename_all = "snake_case"`).
// `filter_kind_ident` turns each variant into the identically named wrapper under
// `::bevy::ecs::query`. There is no `Default`, so `kind` is required in `FilterParams`.
// NOTE(review): `///` text here is presumably part of the generated tool schema.
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum BevyQueryFilterKind {
    /// Emit `With<T>`.
    With,
    /// Emit `Without<T>`.
    Without,
    /// Emit `Added<T>`.
    Added,
    /// Emit `Changed<T>`.
    Changed,
}
53
/// Parameters for `bevy_query__define_component_query`.
//
// Validation (`validate_define_component_query`) requires a valid identifier binding and
// at least one item. Each filter entry is a complete type fragment (e.g. `With<Player>`).
// A single filter is placed directly after the data. Several filters are emitted as a
// tuple.
// NOTE(review): `///` text here is presumably part of the generated tool schema.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct DefineComponentQueryParams {
    /// Parameter binding name.
    pub binding: String,
    /// Whether to emit `mut binding`.
    #[serde(default)]
    pub mutable_binding: bool,
    /// Query items to include in the data position.
    pub items: Vec<BevyQueryItemSpec>,
    /// Optional query filter type fragments.
    #[serde(default)]
    pub filters: Vec<String>,
}
68
/// Parameters for `bevy_query__define_resource`.
//
// The single `mutable` flag drives both the wrapper (`Res` vs `ResMut`) and the
// `mut` on the binding, so the two can never disagree.
// NOTE(review): `///` text here is presumably part of the generated tool schema.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct DefineResourceParams {
    /// Parameter binding name.
    pub binding: String,
    /// Resource type used inside `Res<T>` / `ResMut<T>`.
    pub resource_type: String,
    /// Whether to emit `ResMut<T>` and a mutable binding.
    #[serde(default)]
    pub mutable: bool,
}
80
/// Parameters for `bevy_query__define_event_reader`.
//
// There is no mutability option: the emitted binding is always plain (`mut` is never
// added by `DefineEventReaderParams::emit_code`).
// NOTE(review): `///` text here is presumably part of the generated tool schema.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct DefineEventReaderParams {
    /// Parameter binding name.
    pub binding: String,
    /// Event type read by the parameter.
    pub event_type: String,
}
89
/// Parameters for `bevy_query__define_event_writer`.
//
// There is no mutability option: the emitted binding is always `mut`
// (`DefineEventWriterParams::emit_code` passes `true` to `binding_tokens`).
// NOTE(review): `///` text here is presumably part of the generated tool schema.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct DefineEventWriterParams {
    /// Parameter binding name.
    pub binding: String,
    /// Event type written by the parameter.
    pub event_type: String,
}
98
/// Parameters for `bevy_query__define_handle`.
//
// Unlike the other params types, this describes a struct *field*
// (`[vis] name: Handle<A>`), not a system parameter. `visibility` is parsed as a
// `syn::Visibility`. When it is absent, no visibility tokens are emitted.
// NOTE(review): `///` text here is presumably part of the generated tool schema.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct DefineHandleParams {
    /// Optional field visibility such as `pub`.
    #[serde(default)]
    pub visibility: Option<String>,
    /// Field name.
    pub binding: String,
    /// Asset type used in `Handle<T>`.
    pub asset_type: String,
}
110
/// Parameters for `bevy_query__define_local`.
//
// `mutable_binding` only toggles `mut` on the binding. The wrapper is always `Local<T>`.
// NOTE(review): `///` text here is presumably part of the generated tool schema.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct DefineLocalParams {
    /// Parameter binding name.
    pub binding: String,
    /// Local state type.
    pub local_type: String,
    /// Whether to emit `mut binding`.
    #[serde(default)]
    pub mutable_binding: bool,
}
122
/// Parameters for `bevy_query__define_time`.
//
// `time_generic`, when present, is parsed as a type and placed inside `Time<...>`.
// When absent, a bare `Time` is emitted. `mutable` drives both `ResMut` and the `mut`
// binding.
// NOTE(review): `///` text here is presumably part of the generated tool schema.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct DefineTimeParams {
    /// Parameter binding name.
    pub binding: String,
    /// Optional time generic such as `Fixed`, `Virtual`, or `Real`.
    #[serde(default)]
    pub time_generic: Option<String>,
    /// Whether to emit `ResMut<...>` and a mutable binding.
    #[serde(default)]
    pub mutable: bool,
}
135
/// Parameters for `bevy_query__filter`.
//
// `type_name` is parsed as a full Rust type and becomes the single generic argument of
// the chosen wrapper.
// NOTE(review): `///` text here is presumably part of the generated tool schema.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct FilterParams {
    /// Filter wrapper to emit.
    pub kind: BevyQueryFilterKind,
    /// Inner type for the filter.
    pub type_name: String,
}
144
/// Parameters for `bevy_query__system_signature`.
//
// Each `params` entry is parsed as a `syn::FnArg`. `body`, when present, is parsed as a
// sequence of statements (`parse_body_statements`). When `body` is absent, the function
// is emitted with an empty `{}` body, and with no `->` clause when `return_type` is
// absent.
// NOTE(review): `///` text here is presumably part of the generated tool schema.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SystemSignatureParams {
    /// Optional function visibility such as `pub`.
    #[serde(default)]
    pub visibility: Option<String>,
    /// System function name.
    pub function_name: String,
    /// Full parameter fragments such as `query: Query<&Transform>`.
    #[serde(default)]
    pub params: Vec<String>,
    /// Optional return type.
    #[serde(default)]
    pub return_type: Option<String>,
    /// Optional function body statements.
    #[serde(default)]
    pub body: Option<String>,
}
163
164fn tool_err(msg: impl std::fmt::Display) -> ErrorData {
165    ErrorData::invalid_params(msg.to_string(), None)
166}
167
168fn ok_source(source: String) -> Result<CallToolResult, ErrorData> {
169    Ok(CallToolResult::success(vec![Content::text(source)]))
170}
171
172fn parse_ident(src: &str, context: &str) -> Result<syn::Ident, ErrorData> {
173    syn::parse_str::<syn::Ident>(src)
174        .map_err(|error| tool_err(format!("invalid {context} identifier `{src}`: {error}")))
175}
176
177fn parse_type(src: &str, context: &str) -> Result<syn::Type, ErrorData> {
178    syn::parse_str::<syn::Type>(src)
179        .map_err(|error| tool_err(format!("invalid {context} type `{src}`: {error}")))
180}
181
182fn parse_visibility(src: &str) -> Result<syn::Visibility, ErrorData> {
183    syn::parse_str::<syn::Visibility>(src)
184        .map_err(|error| tool_err(format!("invalid visibility `{src}`: {error}")))
185}
186
187fn parse_param_fragment(src: &str) -> Result<syn::FnArg, ErrorData> {
188    syn::parse_str::<syn::FnArg>(src)
189        .map_err(|error| tool_err(format!("invalid system parameter `{src}`: {error}")))
190}
191
192fn parse_body_statements(src: &str) -> Result<Vec<syn::Stmt>, ErrorData> {
193    let wrapped = format!("{{{src}}}");
194    syn::parse_str::<syn::Block>(&wrapped)
195        .map(|block| block.stmts)
196        .map_err(|error| tool_err(format!("invalid function body `{src}`: {error}")))
197}
198
199fn validate_non_empty<T>(values: &[T], context: &str) -> Result<(), ErrorData> {
200    if values.is_empty() {
201        Err(tool_err(format!("{context} must not be empty")))
202    } else {
203        Ok(())
204    }
205}
206
207fn binding_tokens(binding: &str, mutable: bool) -> syn::Pat {
208    let binding = parse_ident(binding, "binding").expect("validated bindings must parse");
209    if mutable {
210        syn::parse_quote!(mut #binding)
211    } else {
212        syn::parse_quote!(#binding)
213    }
214}
215
216fn type_tokens(src: &str, context: &str) -> syn::Type {
217    parse_type(src, context).expect("validated types must parse")
218}
219
220fn visibility_tokens(visibility: &Option<String>) -> syn::Visibility {
221    visibility
222        .as_deref()
223        .map(|src| parse_visibility(src).expect("validated visibility must parse"))
224        .unwrap_or_else(|| syn::parse_quote!())
225}
226
227fn fn_arg_tokens(src: &str) -> syn::FnArg {
228    parse_param_fragment(src).expect("validated function args must parse")
229}
230
231fn filter_kind_ident(kind: BevyQueryFilterKind) -> syn::Ident {
232    let name = match kind {
233        BevyQueryFilterKind::With => "With",
234        BevyQueryFilterKind::Without => "Without",
235        BevyQueryFilterKind::Added => "Added",
236        BevyQueryFilterKind::Changed => "Changed",
237    };
238    syn::Ident::new(name, proc_macro2::Span::call_site())
239}
240
241fn render_query_item(item: &BevyQueryItemSpec) -> TokenStream {
242    let ty = type_tokens(&item.ty, "query item");
243    match item.access {
244        BevyQueryItemAccess::Value => quote! { #ty },
245        BevyQueryItemAccess::Shared => quote! { &#ty },
246        BevyQueryItemAccess::Mutable => quote! { &mut #ty },
247    }
248}
249
250fn render_query_filters(filters: &[String]) -> TokenStream {
251    let filter_tokens: Vec<syn::Type> = filters
252        .iter()
253        .map(|filter| type_tokens(filter, "query filter"))
254        .collect();
255    match filter_tokens.len() {
256        0 => TokenStream::new(),
257        1 => {
258            let filter = &filter_tokens[0];
259            quote! { , #filter }
260        }
261        _ => quote! { , (#(#filter_tokens),*) },
262    }
263}
264
265fn time_type_tokens(time_generic: &Option<String>) -> TokenStream {
266    match time_generic {
267        Some(generic) => {
268            let generic = type_tokens(generic, "time generic");
269            quote! { ::bevy::prelude::Time<#generic> }
270        }
271        None => quote! { ::bevy::prelude::Time },
272    }
273}
274
275fn bevy_dep() -> Vec<CrateDep> {
276    vec![CrateDep::new("bevy", "0.18")]
277}
278
279impl EmitCode for DefineComponentQueryParams {
280    fn emit_code(&self) -> TokenStream {
281        let binding = binding_tokens(&self.binding, self.mutable_binding);
282        let items: Vec<TokenStream> = self.items.iter().map(render_query_item).collect();
283        let data = if items.len() == 1 {
284            let item = &items[0];
285            quote! { #item }
286        } else {
287            quote! { (#(#items),*) }
288        };
289        let filters = render_query_filters(&self.filters);
290        quote! { #binding: ::bevy::ecs::system::Query<#data #filters> }
291    }
292
293    fn crate_deps(&self) -> Vec<CrateDep> {
294        bevy_dep()
295    }
296}
297
298impl EmitCode for DefineResourceParams {
299    fn emit_code(&self) -> TokenStream {
300        let binding = binding_tokens(&self.binding, self.mutable);
301        let resource = type_tokens(&self.resource_type, "resource");
302        if self.mutable {
303            quote! { #binding: ::bevy::ecs::system::ResMut<#resource> }
304        } else {
305            quote! { #binding: ::bevy::ecs::system::Res<#resource> }
306        }
307    }
308
309    fn crate_deps(&self) -> Vec<CrateDep> {
310        bevy_dep()
311    }
312}
313
314impl EmitCode for DefineEventReaderParams {
315    fn emit_code(&self) -> TokenStream {
316        let binding = binding_tokens(&self.binding, false);
317        let event = type_tokens(&self.event_type, "event");
318        quote! { #binding: ::bevy::ecs::event::EventReader<#event> }
319    }
320
321    fn crate_deps(&self) -> Vec<CrateDep> {
322        bevy_dep()
323    }
324}
325
326impl EmitCode for DefineEventWriterParams {
327    fn emit_code(&self) -> TokenStream {
328        let binding = binding_tokens(&self.binding, true);
329        let event = type_tokens(&self.event_type, "event");
330        quote! { #binding: ::bevy::ecs::event::EventWriter<#event> }
331    }
332
333    fn crate_deps(&self) -> Vec<CrateDep> {
334        bevy_dep()
335    }
336}
337
338impl EmitCode for DefineHandleParams {
339    fn emit_code(&self) -> TokenStream {
340        let visibility = visibility_tokens(&self.visibility);
341        let binding = parse_ident(&self.binding, "field").expect("validated field must parse");
342        let asset = type_tokens(&self.asset_type, "asset");
343        quote! { #visibility #binding: ::bevy::asset::Handle<#asset> }
344    }
345
346    fn crate_deps(&self) -> Vec<CrateDep> {
347        bevy_dep()
348    }
349}
350
351impl EmitCode for DefineLocalParams {
352    fn emit_code(&self) -> TokenStream {
353        let binding = binding_tokens(&self.binding, self.mutable_binding);
354        let local = type_tokens(&self.local_type, "local");
355        quote! { #binding: ::bevy::ecs::system::Local<#local> }
356    }
357
358    fn crate_deps(&self) -> Vec<CrateDep> {
359        bevy_dep()
360    }
361}
362
363impl EmitCode for DefineTimeParams {
364    fn emit_code(&self) -> TokenStream {
365        let binding = binding_tokens(&self.binding, self.mutable);
366        let time = time_type_tokens(&self.time_generic);
367        if self.mutable {
368            quote! { #binding: ::bevy::ecs::system::ResMut<#time> }
369        } else {
370            quote! { #binding: ::bevy::ecs::system::Res<#time> }
371        }
372    }
373
374    fn crate_deps(&self) -> Vec<CrateDep> {
375        bevy_dep()
376    }
377}
378
379impl EmitCode for FilterParams {
380    fn emit_code(&self) -> TokenStream {
381        let kind = filter_kind_ident(self.kind);
382        let ty = type_tokens(&self.type_name, "filter");
383        quote! { ::bevy::ecs::query::#kind<#ty> }
384    }
385
386    fn crate_deps(&self) -> Vec<CrateDep> {
387        bevy_dep()
388    }
389}
390
391impl EmitCode for SystemSignatureParams {
392    fn emit_code(&self) -> TokenStream {
393        let visibility = visibility_tokens(&self.visibility);
394        let name = parse_ident(&self.function_name, "function")
395            .expect("validated function name must parse");
396        let params: Vec<syn::FnArg> = self
397            .params
398            .iter()
399            .map(|param| fn_arg_tokens(param))
400            .collect();
401        let output = self
402            .return_type
403            .as_deref()
404            .map(|return_type| {
405                let ty = type_tokens(return_type, "return");
406                quote! { -> #ty }
407            })
408            .unwrap_or_default();
409        let body = self
410            .body
411            .as_deref()
412            .map(|body| parse_body_statements(body).expect("validated body must parse"))
413            .unwrap_or_default();
414        quote! {
415            #visibility fn #name(#(#params),*) #output {
416                #(#body)*
417            }
418        }
419    }
420
421    fn crate_deps(&self) -> Vec<CrateDep> {
422        bevy_dep()
423    }
424}
425
// Each params type is registered under the same name as the `name = ...` of its
// `#[elicit_tool]` handler below (those handlers all pass `emit = None`).
// NOTE(review): presumably this lets the `elicitation` emit machinery find the
// `EmitCode` impl by tool name — confirm against `register_emit!` before renaming.
elicitation::register_emit!("define_component_query", DefineComponentQueryParams);
elicitation::register_emit!("define_resource", DefineResourceParams);
elicitation::register_emit!("define_event_reader", DefineEventReaderParams);
elicitation::register_emit!("define_event_writer", DefineEventWriterParams);
elicitation::register_emit!("define_handle", DefineHandleParams);
elicitation::register_emit!("define_local", DefineLocalParams);
elicitation::register_emit!("define_time", DefineTimeParams);
elicitation::register_emit!("system_signature", SystemSignatureParams);
elicitation::register_emit!("filter", FilterParams);
435
/// MCP plugin exposing generic Bevy ECS parameter fragment tools.
//
// Stateless unit struct. The tools are attached through the matching
// `plugin = "bevy_query"` argument on each `#[elicit_tool]` function in this file.
#[derive(Debug, ElicitPlugin)]
#[plugin(name = "bevy_query")]
pub struct BevyQueryPlugin;
440
441impl BevyQueryPlugin {
442    /// Creates a new Bevy query fragment plugin.
443    #[instrument]
444    pub fn new() -> Self {
445        Self
446    }
447}
448
449impl Default for BevyQueryPlugin {
450    fn default() -> Self {
451        Self::new()
452    }
453}
454
455fn validate_define_component_query(params: &DefineComponentQueryParams) -> Result<(), ErrorData> {
456    let _ = parse_ident(&params.binding, "binding")?;
457    validate_non_empty(&params.items, "query items")?;
458    for item in &params.items {
459        let _ = parse_type(&item.ty, "query item")?;
460    }
461    for filter in &params.filters {
462        let _ = parse_type(filter, "query filter")?;
463    }
464    Ok(())
465}
466
467fn validate_define_resource(params: &DefineResourceParams) -> Result<(), ErrorData> {
468    let _ = parse_ident(&params.binding, "binding")?;
469    let _ = parse_type(&params.resource_type, "resource")?;
470    Ok(())
471}
472
473fn validate_define_event_reader(params: &DefineEventReaderParams) -> Result<(), ErrorData> {
474    let _ = parse_ident(&params.binding, "binding")?;
475    let _ = parse_type(&params.event_type, "event")?;
476    Ok(())
477}
478
479fn validate_define_event_writer(params: &DefineEventWriterParams) -> Result<(), ErrorData> {
480    let _ = parse_ident(&params.binding, "binding")?;
481    let _ = parse_type(&params.event_type, "event")?;
482    Ok(())
483}
484
485fn validate_define_handle(params: &DefineHandleParams) -> Result<(), ErrorData> {
486    if let Some(visibility) = &params.visibility {
487        let _ = parse_visibility(visibility)?;
488    }
489    let _ = parse_ident(&params.binding, "field")?;
490    let _ = parse_type(&params.asset_type, "asset")?;
491    Ok(())
492}
493
494fn validate_define_local(params: &DefineLocalParams) -> Result<(), ErrorData> {
495    let _ = parse_ident(&params.binding, "binding")?;
496    let _ = parse_type(&params.local_type, "local")?;
497    Ok(())
498}
499
500fn validate_define_time(params: &DefineTimeParams) -> Result<(), ErrorData> {
501    let _ = parse_ident(&params.binding, "binding")?;
502    if let Some(time_generic) = &params.time_generic {
503        let _ = parse_type(time_generic, "time generic")?;
504    }
505    Ok(())
506}
507
508fn validate_filter(params: &FilterParams) -> Result<(), ErrorData> {
509    let _ = parse_type(&params.type_name, "filter")?;
510    Ok(())
511}
512
513fn validate_system_signature(params: &SystemSignatureParams) -> Result<(), ErrorData> {
514    if let Some(visibility) = &params.visibility {
515        let _ = parse_visibility(visibility)?;
516    }
517    let _ = parse_ident(&params.function_name, "function")?;
518    for param in &params.params {
519        let _ = parse_param_fragment(param)?;
520    }
521    if let Some(return_type) = &params.return_type {
522        let _ = parse_type(return_type, "return")?;
523    }
524    if let Some(body) = &params.body {
525        let _ = parse_body_statements(body)?;
526    }
527    Ok(())
528}
529
530#[elicit_tool(
531    plugin = "bevy_query",
532    name = "define_component_query",
533    description = "Emit a `Query<...>` system parameter fragment from query items and optional filters.",
534    emit = None
535)]
536#[instrument(skip_all)]
537async fn define_component_query(
538    p: DefineComponentQueryParams,
539) -> Result<CallToolResult, ErrorData> {
540    validate_define_component_query(&p)?;
541    ok_source(p.emit_code().to_string())
542}
543
544#[elicit_tool(
545    plugin = "bevy_query",
546    name = "define_resource",
547    description = "Emit a `Res<T>` or `ResMut<T>` system parameter fragment.",
548    emit = None
549)]
550#[instrument(skip_all)]
551async fn define_resource(p: DefineResourceParams) -> Result<CallToolResult, ErrorData> {
552    validate_define_resource(&p)?;
553    ok_source(p.emit_code().to_string())
554}
555
556#[elicit_tool(
557    plugin = "bevy_query",
558    name = "define_event_reader",
559    description = "Emit an `EventReader<E>` system parameter fragment.",
560    emit = None
561)]
562#[instrument(skip_all)]
563async fn define_event_reader(p: DefineEventReaderParams) -> Result<CallToolResult, ErrorData> {
564    validate_define_event_reader(&p)?;
565    ok_source(p.emit_code().to_string())
566}
567
568#[elicit_tool(
569    plugin = "bevy_query",
570    name = "define_event_writer",
571    description = "Emit an `EventWriter<E>` system parameter fragment with a mutable binding.",
572    emit = None
573)]
574#[instrument(skip_all)]
575async fn define_event_writer(p: DefineEventWriterParams) -> Result<CallToolResult, ErrorData> {
576    validate_define_event_writer(&p)?;
577    ok_source(p.emit_code().to_string())
578}
579
580#[elicit_tool(
581    plugin = "bevy_query",
582    name = "define_handle",
583    description = "Emit a `Handle<A>` field declaration fragment.",
584    emit = None
585)]
586#[instrument(skip_all)]
587async fn define_handle(p: DefineHandleParams) -> Result<CallToolResult, ErrorData> {
588    validate_define_handle(&p)?;
589    ok_source(p.emit_code().to_string())
590}
591
592#[elicit_tool(
593    plugin = "bevy_query",
594    name = "define_local",
595    description = "Emit a `Local<T>` system parameter fragment.",
596    emit = None
597)]
598#[instrument(skip_all)]
599async fn define_local(p: DefineLocalParams) -> Result<CallToolResult, ErrorData> {
600    validate_define_local(&p)?;
601    ok_source(p.emit_code().to_string())
602}
603
604#[elicit_tool(
605    plugin = "bevy_query",
606    name = "define_time",
607    description = "Emit a `Res<Time>` or `Res<Time<Fixed>>`-style system parameter fragment.",
608    emit = None
609)]
610#[instrument(skip_all)]
611async fn define_time(p: DefineTimeParams) -> Result<CallToolResult, ErrorData> {
612    validate_define_time(&p)?;
613    ok_source(p.emit_code().to_string())
614}
615
616#[elicit_tool(
617    plugin = "bevy_query",
618    name = "system_signature",
619    description = "Emit a full Bevy system function signature from previously generated parameter fragments.",
620    emit = None
621)]
622#[instrument(skip_all)]
623async fn system_signature(p: SystemSignatureParams) -> Result<CallToolResult, ErrorData> {
624    validate_system_signature(&p)?;
625    ok_source(p.emit_code().to_string())
626}
627
628#[elicit_tool(
629    plugin = "bevy_query",
630    name = "filter",
631    description = "Emit a Bevy query filter such as `With<T>` or `Changed<T>`.",
632    emit = None
633)]
634#[instrument(skip_all)]
635async fn filter(p: FilterParams) -> Result<CallToolResult, ErrorData> {
636    validate_filter(&p)?;
637    ok_source(p.emit_code().to_string())
638}