1use elicitation::emit_code::{CrateDep, EmitCode};
8use elicitation::{ElicitPlugin, elicit_tool};
9use proc_macro2::TokenStream;
10use quote::quote;
11use rmcp::ErrorData;
12use rmcp::model::{CallToolResult, Content};
13use schemars::JsonSchema;
14use serde::{Deserialize, Serialize};
15use tracing::instrument;
16
/// How a query item's type is accessed inside the rendered `Query<...>` data.
///
/// Serialized in snake_case (`"value"`, `"shared"`, `"mutable"`).
#[derive(Debug, Clone, Copy, Default, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum BevyQueryItemAccess {
    /// Rendered as the bare type `T` (no reference).
    Value,
    /// Rendered as `&T`; this is the default access.
    #[default]
    Shared,
    /// Rendered as `&mut T`.
    Mutable,
}
29
/// One data item of a component query.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct BevyQueryItemSpec {
    /// Rust type of the item, written in Rust type syntax (parsed with `syn`).
    pub ty: String,
    /// How the item is accessed; falls back to shared (`&T`) when omitted.
    #[serde(default)]
    pub access: BevyQueryItemAccess,
}
39
/// Bevy query-filter wrapper, emitted by the `filter` tool as
/// `::bevy::ecs::query::<Kind><T>`.
///
/// Serialized in snake_case (`"with"`, `"without"`, `"added"`, `"changed"`).
#[derive(Debug, Clone, Copy, Serialize, Deserialize, JsonSchema, PartialEq, Eq)]
#[serde(rename_all = "snake_case")]
pub enum BevyQueryFilterKind {
    /// Emits `With<T>`.
    With,
    /// Emits `Without<T>`.
    Without,
    /// Emits `Added<T>`.
    Added,
    /// Emits `Changed<T>`.
    Changed,
}
53
/// Parameters for the `define_component_query` tool, which emits a
/// `<binding>: ::bevy::ecs::system::Query<...>` system parameter.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct DefineComponentQueryParams {
    /// Identifier used for the system parameter.
    pub binding: String,
    /// When true the binding is declared `mut`; defaults to false.
    #[serde(default)]
    pub mutable_binding: bool,
    /// Query data items; must not be empty. A single item is emitted on its
    /// own, several items are emitted as a tuple.
    pub items: Vec<BevyQueryItemSpec>,
    /// Optional filter types in Rust type syntax. One filter is emitted as-is,
    /// several are combined into a tuple; none omits the filter argument.
    #[serde(default)]
    pub filters: Vec<String>,
}
68
/// Parameters for the `define_resource` tool (`Res<T>` / `ResMut<T>`).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct DefineResourceParams {
    /// Identifier used for the system parameter.
    pub binding: String,
    /// Resource type `T`, in Rust type syntax.
    pub resource_type: String,
    /// When true emits `mut <binding>: ResMut<T>`, otherwise `<binding>: Res<T>`.
    #[serde(default)]
    pub mutable: bool,
}
80
/// Parameters for the `define_event_reader` tool, which emits an
/// `EventReader<E>` system parameter.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct DefineEventReaderParams {
    /// Identifier used for the system parameter.
    pub binding: String,
    /// Event type `E`, in Rust type syntax.
    pub event_type: String,
}
89
/// Parameters for the `define_event_writer` tool, which emits an
/// `EventWriter<E>` system parameter whose binding is always declared `mut`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct DefineEventWriterParams {
    /// Identifier used for the system parameter.
    pub binding: String,
    /// Event type `E`, in Rust type syntax.
    pub event_type: String,
}
98
/// Parameters for the `define_handle` tool, which emits a struct field
/// declaration of type `::bevy::asset::Handle<A>`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct DefineHandleParams {
    /// Optional field visibility in Rust syntax (e.g. `pub`); omitted means
    /// no visibility keyword is emitted.
    #[serde(default)]
    pub visibility: Option<String>,
    /// Field name; must be a valid Rust identifier.
    pub binding: String,
    /// Asset type `A`, in Rust type syntax.
    pub asset_type: String,
}
110
/// Parameters for the `define_local` tool, which emits a `Local<T>` system parameter.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct DefineLocalParams {
    /// Identifier used for the system parameter.
    pub binding: String,
    /// Local state type `T`, in Rust type syntax.
    pub local_type: String,
    /// When true the binding is declared `mut`; defaults to false.
    #[serde(default)]
    pub mutable_binding: bool,
}
122
/// Parameters for the `define_time` tool, which emits a `Res`/`ResMut` over
/// `::bevy::prelude::Time` (optionally with a generic argument).
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct DefineTimeParams {
    /// Identifier used for the system parameter.
    pub binding: String,
    /// Optional generic argument for `Time<..>`, in Rust type syntax;
    /// omitted emits plain `Time`.
    #[serde(default)]
    pub time_generic: Option<String>,
    /// When true emits `mut <binding>: ResMut<Time..>`, otherwise `Res<Time..>`.
    #[serde(default)]
    pub mutable: bool,
}
135
/// Parameters for the `filter` tool, which emits `::bevy::ecs::query::<Kind><T>`.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct FilterParams {
    /// Which filter wrapper to emit.
    pub kind: BevyQueryFilterKind,
    /// Filtered type `T`, in Rust type syntax.
    pub type_name: String,
}
144
/// Parameters for the `system_signature` tool, which emits a complete
/// `fn` item from previously generated parameter fragments.
#[derive(Debug, Clone, Serialize, Deserialize, JsonSchema)]
pub struct SystemSignatureParams {
    /// Optional visibility in Rust syntax (e.g. `pub`); omitted emits none.
    #[serde(default)]
    pub visibility: Option<String>,
    /// Function name; must be a valid Rust identifier.
    pub function_name: String,
    /// Parameter fragments, each parsed as a single function argument
    /// (e.g. the output of the other `define_*` tools).
    #[serde(default)]
    pub params: Vec<String>,
    /// Optional return type, emitted as `-> T`.
    #[serde(default)]
    pub return_type: Option<String>,
    /// Optional body statements without surrounding braces; omitted emits an
    /// empty body.
    #[serde(default)]
    pub body: Option<String>,
}
163
164fn tool_err(msg: impl std::fmt::Display) -> ErrorData {
165 ErrorData::invalid_params(msg.to_string(), None)
166}
167
168fn ok_source(source: String) -> Result<CallToolResult, ErrorData> {
169 Ok(CallToolResult::success(vec![Content::text(source)]))
170}
171
172fn parse_ident(src: &str, context: &str) -> Result<syn::Ident, ErrorData> {
173 syn::parse_str::<syn::Ident>(src)
174 .map_err(|error| tool_err(format!("invalid {context} identifier `{src}`: {error}")))
175}
176
177fn parse_type(src: &str, context: &str) -> Result<syn::Type, ErrorData> {
178 syn::parse_str::<syn::Type>(src)
179 .map_err(|error| tool_err(format!("invalid {context} type `{src}`: {error}")))
180}
181
182fn parse_visibility(src: &str) -> Result<syn::Visibility, ErrorData> {
183 syn::parse_str::<syn::Visibility>(src)
184 .map_err(|error| tool_err(format!("invalid visibility `{src}`: {error}")))
185}
186
187fn parse_param_fragment(src: &str) -> Result<syn::FnArg, ErrorData> {
188 syn::parse_str::<syn::FnArg>(src)
189 .map_err(|error| tool_err(format!("invalid system parameter `{src}`: {error}")))
190}
191
192fn parse_body_statements(src: &str) -> Result<Vec<syn::Stmt>, ErrorData> {
193 let wrapped = format!("{{{src}}}");
194 syn::parse_str::<syn::Block>(&wrapped)
195 .map(|block| block.stmts)
196 .map_err(|error| tool_err(format!("invalid function body `{src}`: {error}")))
197}
198
199fn validate_non_empty<T>(values: &[T], context: &str) -> Result<(), ErrorData> {
200 if values.is_empty() {
201 Err(tool_err(format!("{context} must not be empty")))
202 } else {
203 Ok(())
204 }
205}
206
207fn binding_tokens(binding: &str, mutable: bool) -> syn::Pat {
208 let binding = parse_ident(binding, "binding").expect("validated bindings must parse");
209 if mutable {
210 syn::parse_quote!(mut #binding)
211 } else {
212 syn::parse_quote!(#binding)
213 }
214}
215
216fn type_tokens(src: &str, context: &str) -> syn::Type {
217 parse_type(src, context).expect("validated types must parse")
218}
219
220fn visibility_tokens(visibility: &Option<String>) -> syn::Visibility {
221 visibility
222 .as_deref()
223 .map(|src| parse_visibility(src).expect("validated visibility must parse"))
224 .unwrap_or_else(|| syn::parse_quote!())
225}
226
227fn fn_arg_tokens(src: &str) -> syn::FnArg {
228 parse_param_fragment(src).expect("validated function args must parse")
229}
230
231fn filter_kind_ident(kind: BevyQueryFilterKind) -> syn::Ident {
232 let name = match kind {
233 BevyQueryFilterKind::With => "With",
234 BevyQueryFilterKind::Without => "Without",
235 BevyQueryFilterKind::Added => "Added",
236 BevyQueryFilterKind::Changed => "Changed",
237 };
238 syn::Ident::new(name, proc_macro2::Span::call_site())
239}
240
241fn render_query_item(item: &BevyQueryItemSpec) -> TokenStream {
242 let ty = type_tokens(&item.ty, "query item");
243 match item.access {
244 BevyQueryItemAccess::Value => quote! { #ty },
245 BevyQueryItemAccess::Shared => quote! { &#ty },
246 BevyQueryItemAccess::Mutable => quote! { &mut #ty },
247 }
248}
249
250fn render_query_filters(filters: &[String]) -> TokenStream {
251 let filter_tokens: Vec<syn::Type> = filters
252 .iter()
253 .map(|filter| type_tokens(filter, "query filter"))
254 .collect();
255 match filter_tokens.len() {
256 0 => TokenStream::new(),
257 1 => {
258 let filter = &filter_tokens[0];
259 quote! { , #filter }
260 }
261 _ => quote! { , (#(#filter_tokens),*) },
262 }
263}
264
265fn time_type_tokens(time_generic: &Option<String>) -> TokenStream {
266 match time_generic {
267 Some(generic) => {
268 let generic = type_tokens(generic, "time generic");
269 quote! { ::bevy::prelude::Time<#generic> }
270 }
271 None => quote! { ::bevy::prelude::Time },
272 }
273}
274
275fn bevy_dep() -> Vec<CrateDep> {
276 vec![CrateDep::new("bevy", "0.18")]
277}
278
279impl EmitCode for DefineComponentQueryParams {
280 fn emit_code(&self) -> TokenStream {
281 let binding = binding_tokens(&self.binding, self.mutable_binding);
282 let items: Vec<TokenStream> = self.items.iter().map(render_query_item).collect();
283 let data = if items.len() == 1 {
284 let item = &items[0];
285 quote! { #item }
286 } else {
287 quote! { (#(#items),*) }
288 };
289 let filters = render_query_filters(&self.filters);
290 quote! { #binding: ::bevy::ecs::system::Query<#data #filters> }
291 }
292
293 fn crate_deps(&self) -> Vec<CrateDep> {
294 bevy_dep()
295 }
296}
297
298impl EmitCode for DefineResourceParams {
299 fn emit_code(&self) -> TokenStream {
300 let binding = binding_tokens(&self.binding, self.mutable);
301 let resource = type_tokens(&self.resource_type, "resource");
302 if self.mutable {
303 quote! { #binding: ::bevy::ecs::system::ResMut<#resource> }
304 } else {
305 quote! { #binding: ::bevy::ecs::system::Res<#resource> }
306 }
307 }
308
309 fn crate_deps(&self) -> Vec<CrateDep> {
310 bevy_dep()
311 }
312}
313
314impl EmitCode for DefineEventReaderParams {
315 fn emit_code(&self) -> TokenStream {
316 let binding = binding_tokens(&self.binding, false);
317 let event = type_tokens(&self.event_type, "event");
318 quote! { #binding: ::bevy::ecs::event::EventReader<#event> }
319 }
320
321 fn crate_deps(&self) -> Vec<CrateDep> {
322 bevy_dep()
323 }
324}
325
326impl EmitCode for DefineEventWriterParams {
327 fn emit_code(&self) -> TokenStream {
328 let binding = binding_tokens(&self.binding, true);
329 let event = type_tokens(&self.event_type, "event");
330 quote! { #binding: ::bevy::ecs::event::EventWriter<#event> }
331 }
332
333 fn crate_deps(&self) -> Vec<CrateDep> {
334 bevy_dep()
335 }
336}
337
338impl EmitCode for DefineHandleParams {
339 fn emit_code(&self) -> TokenStream {
340 let visibility = visibility_tokens(&self.visibility);
341 let binding = parse_ident(&self.binding, "field").expect("validated field must parse");
342 let asset = type_tokens(&self.asset_type, "asset");
343 quote! { #visibility #binding: ::bevy::asset::Handle<#asset> }
344 }
345
346 fn crate_deps(&self) -> Vec<CrateDep> {
347 bevy_dep()
348 }
349}
350
351impl EmitCode for DefineLocalParams {
352 fn emit_code(&self) -> TokenStream {
353 let binding = binding_tokens(&self.binding, self.mutable_binding);
354 let local = type_tokens(&self.local_type, "local");
355 quote! { #binding: ::bevy::ecs::system::Local<#local> }
356 }
357
358 fn crate_deps(&self) -> Vec<CrateDep> {
359 bevy_dep()
360 }
361}
362
363impl EmitCode for DefineTimeParams {
364 fn emit_code(&self) -> TokenStream {
365 let binding = binding_tokens(&self.binding, self.mutable);
366 let time = time_type_tokens(&self.time_generic);
367 if self.mutable {
368 quote! { #binding: ::bevy::ecs::system::ResMut<#time> }
369 } else {
370 quote! { #binding: ::bevy::ecs::system::Res<#time> }
371 }
372 }
373
374 fn crate_deps(&self) -> Vec<CrateDep> {
375 bevy_dep()
376 }
377}
378
379impl EmitCode for FilterParams {
380 fn emit_code(&self) -> TokenStream {
381 let kind = filter_kind_ident(self.kind);
382 let ty = type_tokens(&self.type_name, "filter");
383 quote! { ::bevy::ecs::query::#kind<#ty> }
384 }
385
386 fn crate_deps(&self) -> Vec<CrateDep> {
387 bevy_dep()
388 }
389}
390
391impl EmitCode for SystemSignatureParams {
392 fn emit_code(&self) -> TokenStream {
393 let visibility = visibility_tokens(&self.visibility);
394 let name = parse_ident(&self.function_name, "function")
395 .expect("validated function name must parse");
396 let params: Vec<syn::FnArg> = self
397 .params
398 .iter()
399 .map(|param| fn_arg_tokens(param))
400 .collect();
401 let output = self
402 .return_type
403 .as_deref()
404 .map(|return_type| {
405 let ty = type_tokens(return_type, "return");
406 quote! { -> #ty }
407 })
408 .unwrap_or_default();
409 let body = self
410 .body
411 .as_deref()
412 .map(|body| parse_body_statements(body).expect("validated body must parse"))
413 .unwrap_or_default();
414 quote! {
415 #visibility fn #name(#(#params),*) #output {
416 #(#body)*
417 }
418 }
419 }
420
421 fn crate_deps(&self) -> Vec<CrateDep> {
422 bevy_dep()
423 }
424}
425
// Register each parameter type's `EmitCode` implementation under the same
// name as the `#[elicit_tool]` handler that produces it (the `name = ...`
// values further down match these strings exactly).
// NOTE(review): presumably this makes the emitters discoverable by name outside
// the tool handlers; confirm against `elicitation::register_emit!`. If so,
// `emit_code` may run without the `validate_*` checks, and its helpers panic
// on invalid input.
elicitation::register_emit!("define_component_query", DefineComponentQueryParams);
elicitation::register_emit!("define_resource", DefineResourceParams);
elicitation::register_emit!("define_event_reader", DefineEventReaderParams);
elicitation::register_emit!("define_event_writer", DefineEventWriterParams);
elicitation::register_emit!("define_handle", DefineHandleParams);
elicitation::register_emit!("define_local", DefineLocalParams);
elicitation::register_emit!("define_time", DefineTimeParams);
elicitation::register_emit!("system_signature", SystemSignatureParams);
elicitation::register_emit!("filter", FilterParams);
435
/// Stateless plugin named `"bevy_query"`; the `#[elicit_tool]` handlers in
/// this module declare `plugin = "bevy_query"` to attach to it.
#[derive(Debug, ElicitPlugin)]
#[plugin(name = "bevy_query")]
pub struct BevyQueryPlugin;
440
441impl BevyQueryPlugin {
442 #[instrument]
444 pub fn new() -> Self {
445 Self
446 }
447}
448
449impl Default for BevyQueryPlugin {
450 fn default() -> Self {
451 Self::new()
452 }
453}
454
455fn validate_define_component_query(params: &DefineComponentQueryParams) -> Result<(), ErrorData> {
456 let _ = parse_ident(¶ms.binding, "binding")?;
457 validate_non_empty(¶ms.items, "query items")?;
458 for item in ¶ms.items {
459 let _ = parse_type(&item.ty, "query item")?;
460 }
461 for filter in ¶ms.filters {
462 let _ = parse_type(filter, "query filter")?;
463 }
464 Ok(())
465}
466
467fn validate_define_resource(params: &DefineResourceParams) -> Result<(), ErrorData> {
468 let _ = parse_ident(¶ms.binding, "binding")?;
469 let _ = parse_type(¶ms.resource_type, "resource")?;
470 Ok(())
471}
472
473fn validate_define_event_reader(params: &DefineEventReaderParams) -> Result<(), ErrorData> {
474 let _ = parse_ident(¶ms.binding, "binding")?;
475 let _ = parse_type(¶ms.event_type, "event")?;
476 Ok(())
477}
478
479fn validate_define_event_writer(params: &DefineEventWriterParams) -> Result<(), ErrorData> {
480 let _ = parse_ident(¶ms.binding, "binding")?;
481 let _ = parse_type(¶ms.event_type, "event")?;
482 Ok(())
483}
484
485fn validate_define_handle(params: &DefineHandleParams) -> Result<(), ErrorData> {
486 if let Some(visibility) = ¶ms.visibility {
487 let _ = parse_visibility(visibility)?;
488 }
489 let _ = parse_ident(¶ms.binding, "field")?;
490 let _ = parse_type(¶ms.asset_type, "asset")?;
491 Ok(())
492}
493
494fn validate_define_local(params: &DefineLocalParams) -> Result<(), ErrorData> {
495 let _ = parse_ident(¶ms.binding, "binding")?;
496 let _ = parse_type(¶ms.local_type, "local")?;
497 Ok(())
498}
499
500fn validate_define_time(params: &DefineTimeParams) -> Result<(), ErrorData> {
501 let _ = parse_ident(¶ms.binding, "binding")?;
502 if let Some(time_generic) = ¶ms.time_generic {
503 let _ = parse_type(time_generic, "time generic")?;
504 }
505 Ok(())
506}
507
508fn validate_filter(params: &FilterParams) -> Result<(), ErrorData> {
509 let _ = parse_type(¶ms.type_name, "filter")?;
510 Ok(())
511}
512
513fn validate_system_signature(params: &SystemSignatureParams) -> Result<(), ErrorData> {
514 if let Some(visibility) = ¶ms.visibility {
515 let _ = parse_visibility(visibility)?;
516 }
517 let _ = parse_ident(¶ms.function_name, "function")?;
518 for param in ¶ms.params {
519 let _ = parse_param_fragment(param)?;
520 }
521 if let Some(return_type) = ¶ms.return_type {
522 let _ = parse_type(return_type, "return")?;
523 }
524 if let Some(body) = ¶ms.body {
525 let _ = parse_body_statements(body)?;
526 }
527 Ok(())
528}
529
530#[elicit_tool(
531 plugin = "bevy_query",
532 name = "define_component_query",
533 description = "Emit a `Query<...>` system parameter fragment from query items and optional filters.",
534 emit = None
535)]
536#[instrument(skip_all)]
537async fn define_component_query(
538 p: DefineComponentQueryParams,
539) -> Result<CallToolResult, ErrorData> {
540 validate_define_component_query(&p)?;
541 ok_source(p.emit_code().to_string())
542}
543
544#[elicit_tool(
545 plugin = "bevy_query",
546 name = "define_resource",
547 description = "Emit a `Res<T>` or `ResMut<T>` system parameter fragment.",
548 emit = None
549)]
550#[instrument(skip_all)]
551async fn define_resource(p: DefineResourceParams) -> Result<CallToolResult, ErrorData> {
552 validate_define_resource(&p)?;
553 ok_source(p.emit_code().to_string())
554}
555
556#[elicit_tool(
557 plugin = "bevy_query",
558 name = "define_event_reader",
559 description = "Emit an `EventReader<E>` system parameter fragment.",
560 emit = None
561)]
562#[instrument(skip_all)]
563async fn define_event_reader(p: DefineEventReaderParams) -> Result<CallToolResult, ErrorData> {
564 validate_define_event_reader(&p)?;
565 ok_source(p.emit_code().to_string())
566}
567
568#[elicit_tool(
569 plugin = "bevy_query",
570 name = "define_event_writer",
571 description = "Emit an `EventWriter<E>` system parameter fragment with a mutable binding.",
572 emit = None
573)]
574#[instrument(skip_all)]
575async fn define_event_writer(p: DefineEventWriterParams) -> Result<CallToolResult, ErrorData> {
576 validate_define_event_writer(&p)?;
577 ok_source(p.emit_code().to_string())
578}
579
580#[elicit_tool(
581 plugin = "bevy_query",
582 name = "define_handle",
583 description = "Emit a `Handle<A>` field declaration fragment.",
584 emit = None
585)]
586#[instrument(skip_all)]
587async fn define_handle(p: DefineHandleParams) -> Result<CallToolResult, ErrorData> {
588 validate_define_handle(&p)?;
589 ok_source(p.emit_code().to_string())
590}
591
592#[elicit_tool(
593 plugin = "bevy_query",
594 name = "define_local",
595 description = "Emit a `Local<T>` system parameter fragment.",
596 emit = None
597)]
598#[instrument(skip_all)]
599async fn define_local(p: DefineLocalParams) -> Result<CallToolResult, ErrorData> {
600 validate_define_local(&p)?;
601 ok_source(p.emit_code().to_string())
602}
603
604#[elicit_tool(
605 plugin = "bevy_query",
606 name = "define_time",
607 description = "Emit a `Res<Time>` or `Res<Time<Fixed>>`-style system parameter fragment.",
608 emit = None
609)]
610#[instrument(skip_all)]
611async fn define_time(p: DefineTimeParams) -> Result<CallToolResult, ErrorData> {
612 validate_define_time(&p)?;
613 ok_source(p.emit_code().to_string())
614}
615
616#[elicit_tool(
617 plugin = "bevy_query",
618 name = "system_signature",
619 description = "Emit a full Bevy system function signature from previously generated parameter fragments.",
620 emit = None
621)]
622#[instrument(skip_all)]
623async fn system_signature(p: SystemSignatureParams) -> Result<CallToolResult, ErrorData> {
624 validate_system_signature(&p)?;
625 ok_source(p.emit_code().to_string())
626}
627
628#[elicit_tool(
629 plugin = "bevy_query",
630 name = "filter",
631 description = "Emit a Bevy query filter such as `With<T>` or `Changed<T>`.",
632 emit = None
633)]
634#[instrument(skip_all)]
635async fn filter(p: FilterParams) -> Result<CallToolResult, ErrorData> {
636 validate_filter(&p)?;
637 ok_source(p.emit_code().to_string())
638}