1use crate::ast::*;
2use std::collections::HashSet;
3
/// The text of every file produced by one code-generation run.
#[derive(Debug, Clone)]
pub struct RustOutput {
    /// Contents of the generated `Cargo.toml`.
    pub cargo_toml: String,
    /// Contents of the generated `lib.rs` (reused as `mod.rs` by [`RustOutput::mod_rs`]).
    pub lib_rs: String,
    /// Contents of the generated `types.rs`.
    pub types_rs: String,
    /// Contents of the generated `entity.rs`.
    pub entity_rs: String,
}

impl RustOutput {
    /// Concatenates `lib_rs`, `types_rs` and `entity_rs` into one string.
    ///
    /// The parts are separated by `// types.rs` and `// entity.rs` marker
    /// comments, each preceded by a blank line.
    pub fn full_lib(&self) -> String {
        [
            self.lib_rs.as_str(),
            "\n\n// types.rs\n",
            self.types_rs.as_str(),
            "\n\n// entity.rs\n",
            self.entity_rs.as_str(),
        ]
        .concat()
    }

    /// Returns an owned copy of `lib_rs`, used as the body of a module's `mod.rs`.
    pub fn mod_rs(&self) -> String {
        self.lib_rs.to_owned()
    }
}
24
/// Options that control the generated Rust sources.
#[derive(Debug, Clone)]
pub struct RustConfig {
    /// Package name written into the generated `Cargo.toml`.
    pub crate_name: String,
    /// Version requirement written for the `hyperstack-sdk` dependency.
    pub sdk_version: String,
    /// When `true`, `entity.rs` imports its types from `super::types`
    /// instead of `crate::types`.
    pub module_mode: bool,
    /// String returned by the generated `url()` function. `None` generates an
    /// empty string literal followed by a TODO comment.
    pub url: Option<String>,
}

impl Default for RustConfig {
    /// Crate `generated-stack`, SDK version `0.2`, crate (non-module) mode, no URL.
    fn default() -> Self {
        RustConfig {
            crate_name: String::from("generated-stack"),
            sdk_version: String::from("0.2"),
            module_mode: false,
            url: None,
        }
    }
}
44
45pub fn compile_serializable_spec(
46 spec: SerializableStreamSpec,
47 entity_name: String,
48 config: Option<RustConfig>,
49) -> Result<RustOutput, String> {
50 let config = config.unwrap_or_default();
51 let compiler = RustCompiler::new(spec, entity_name, config);
52 Ok(compiler.compile())
53}
54
55pub fn write_rust_crate(
56 output: &RustOutput,
57 crate_dir: &std::path::Path,
58) -> Result<(), std::io::Error> {
59 std::fs::create_dir_all(crate_dir.join("src"))?;
60 std::fs::write(crate_dir.join("Cargo.toml"), &output.cargo_toml)?;
61 std::fs::write(crate_dir.join("src/lib.rs"), &output.lib_rs)?;
62 std::fs::write(crate_dir.join("src/types.rs"), &output.types_rs)?;
63 std::fs::write(crate_dir.join("src/entity.rs"), &output.entity_rs)?;
64 Ok(())
65}
66
67pub fn write_rust_module(
68 output: &RustOutput,
69 module_dir: &std::path::Path,
70) -> Result<(), std::io::Error> {
71 std::fs::create_dir_all(module_dir)?;
72 std::fs::write(module_dir.join("mod.rs"), output.mod_rs())?;
73 std::fs::write(module_dir.join("types.rs"), &output.types_rs)?;
74 std::fs::write(module_dir.join("entity.rs"), &output.entity_rs)?;
75 Ok(())
76}
77
78struct RustCompiler {
79 spec: SerializableStreamSpec,
80 entity_name: String,
81 config: RustConfig,
82}
83
84impl RustCompiler {
85 fn new(spec: SerializableStreamSpec, entity_name: String, config: RustConfig) -> Self {
86 Self {
87 spec,
88 entity_name,
89 config,
90 }
91 }
92
93 fn compile(&self) -> RustOutput {
94 RustOutput {
95 cargo_toml: self.generate_cargo_toml(),
96 lib_rs: self.generate_lib_rs(),
97 types_rs: self.generate_types_rs(),
98 entity_rs: self.generate_entity_rs(),
99 }
100 }
101
102 fn generate_cargo_toml(&self) -> String {
103 format!(
104 r#"[package]
105name = "{}"
106version = "0.1.0"
107edition = "2021"
108
109[dependencies]
110hyperstack-sdk = "{}"
111serde = {{ version = "1", features = ["derive"] }}
112serde_json = "1"
113"#,
114 self.config.crate_name, self.config.sdk_version
115 )
116 }
117
118 fn generate_lib_rs(&self) -> String {
119 let stack_name = self.derive_stack_name();
120 let entity_name = &self.entity_name;
121
122 format!(
123 r#"mod entity;
124mod types;
125
126pub use entity::{{{stack_name}Stack, {stack_name}StackViews, {entity_name}EntityViews}};
127pub use types::*;
128
129pub use hyperstack_sdk::{{ConnectionState, HyperStack, Stack, Update, Views}};
130"#,
131 stack_name = stack_name,
132 entity_name = entity_name
133 )
134 }
135
136 fn generate_types_rs(&self) -> String {
137 let mut output = String::new();
138 output.push_str("use serde::{Deserialize, Serialize};\n\n");
139
140 let mut generated = HashSet::new();
141
142 for section in &self.spec.sections {
143 if !Self::is_root_section(§ion.name) && generated.insert(section.name.clone()) {
144 output.push_str(&self.generate_struct_for_section(section));
145 output.push_str("\n\n");
146 }
147 }
148
149 output.push_str(&self.generate_main_entity_struct());
150 output.push_str(&self.generate_resolved_types(&mut generated));
151 output.push_str(&self.generate_event_wrapper());
152
153 output
154 }
155
156 fn generate_struct_for_section(&self, section: &EntitySection) -> String {
157 let struct_name = format!("{}{}", self.entity_name, to_pascal_case(§ion.name));
158 let mut fields = Vec::new();
159
160 for field in §ion.fields {
161 let field_name = to_snake_case(&field.field_name);
162 let rust_type = self.field_type_to_rust(field);
163
164 fields.push(format!(
165 " #[serde(default)]\n pub {}: {},",
166 field_name, rust_type
167 ));
168 }
169
170 format!(
171 "#[derive(Debug, Clone, Serialize, Deserialize, Default)]\npub struct {} {{\n{}\n}}",
172 struct_name,
173 fields.join("\n")
174 )
175 }
176
177 fn is_root_section(name: &str) -> bool {
179 name.eq_ignore_ascii_case("root")
180 }
181
182 fn generate_main_entity_struct(&self) -> String {
183 let mut fields = Vec::new();
184
185 for section in &self.spec.sections {
186 if !Self::is_root_section(§ion.name) {
187 let field_name = to_snake_case(§ion.name);
188 let type_name = format!("{}{}", self.entity_name, to_pascal_case(§ion.name));
189 fields.push(format!(
190 " #[serde(default)]\n pub {}: {},",
191 field_name, type_name
192 ));
193 }
194 }
195
196 for section in &self.spec.sections {
197 if Self::is_root_section(§ion.name) {
198 for field in §ion.fields {
199 let field_name = to_snake_case(&field.field_name);
200 let rust_type = self.field_type_to_rust(field);
201 fields.push(format!(
202 " #[serde(default)]\n pub {}: {},",
203 field_name, rust_type
204 ));
205 }
206 }
207 }
208
209 format!(
210 "#[derive(Debug, Clone, Serialize, Deserialize, Default)]\npub struct {} {{\n{}\n}}",
211 self.entity_name,
212 fields.join("\n")
213 )
214 }
215
216 fn generate_resolved_types(&self, generated: &mut HashSet<String>) -> String {
217 let mut output = String::new();
218
219 for section in &self.spec.sections {
220 for field in §ion.fields {
221 if let Some(resolved) = &field.resolved_type {
222 if generated.insert(resolved.type_name.clone()) {
223 output.push_str("\n\n");
224 output.push_str(&self.generate_resolved_struct(resolved));
225 }
226 }
227 }
228 }
229
230 output
231 }
232
233 fn generate_resolved_struct(&self, resolved: &ResolvedStructType) -> String {
234 if resolved.is_enum {
235 let variants: Vec<String> = resolved
236 .enum_variants
237 .iter()
238 .map(|v| format!(" {},", to_pascal_case(v)))
239 .collect();
240
241 format!(
242 "#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]\npub enum {} {{\n{}\n}}",
243 to_pascal_case(&resolved.type_name),
244 variants.join("\n")
245 )
246 } else {
247 let fields: Vec<String> = resolved
248 .fields
249 .iter()
250 .map(|f| {
251 let rust_type = self.resolved_field_to_rust(f);
252 format!(
253 " #[serde(default)]\n pub {}: {},",
254 to_snake_case(&f.field_name),
255 rust_type
256 )
257 })
258 .collect();
259
260 format!(
261 "#[derive(Debug, Clone, Serialize, Deserialize, Default)]\npub struct {} {{\n{}\n}}",
262 to_pascal_case(&resolved.type_name),
263 fields.join("\n")
264 )
265 }
266 }
267
268 fn generate_event_wrapper(&self) -> String {
269 r#"
270
271#[derive(Debug, Clone, Serialize, Deserialize)]
272pub struct EventWrapper<T> {
273 #[serde(default)]
274 pub timestamp: i64,
275 pub data: T,
276 #[serde(default)]
277 pub slot: Option<f64>,
278 #[serde(default)]
279 pub signature: Option<String>,
280}
281
282impl<T: Default> Default for EventWrapper<T> {
283 fn default() -> Self {
284 Self {
285 timestamp: 0,
286 data: T::default(),
287 slot: None,
288 signature: None,
289 }
290 }
291}
292"#
293 .to_string()
294 }
295
296 fn generate_entity_rs(&self) -> String {
297 let entity_name = &self.entity_name;
298 let stack_name = self.derive_stack_name();
299 let stack_name_kebab = to_kebab_case(entity_name);
300 let entity_snake = to_snake_case(entity_name);
301
302 let types_import = if self.config.module_mode {
303 "super::types"
304 } else {
305 "crate::types"
306 };
307
308 let url_impl = match &self.config.url {
310 Some(url) => format!(r#"fn url() -> &'static str {{
311 "{}"
312 }}"#, url),
313 None => r#"fn url() -> &'static str {
314 "" // TODO: Set URL after first deployment in hyperstack.toml
315 }"#.to_string(),
316 };
317
318 let entity_views = self.generate_entity_views_struct();
319
320 format!(
321 r#"use {types_import}::{entity_name};
322use hyperstack_sdk::{{Stack, StateView, ViewBuilder, ViewHandle, Views}};
323
324pub struct {stack_name}Stack;
325
326impl Stack for {stack_name}Stack {{
327 type Views = {stack_name}StackViews;
328
329 fn name() -> &'static str {{
330 "{stack_name_kebab}"
331 }}
332
333 {url_impl}
334}}
335
336pub struct {stack_name}StackViews {{
337 pub {entity_snake}: {entity_name}EntityViews,
338}}
339
340impl Views for {stack_name}StackViews {{
341 fn from_builder(builder: ViewBuilder) -> Self {{
342 Self {{
343 {entity_snake}: {entity_name}EntityViews {{ builder }},
344 }}
345 }}
346}}
347{entity_views}"#,
348 types_import = types_import,
349 entity_name = entity_name,
350 stack_name = stack_name,
351 stack_name_kebab = stack_name_kebab,
352 entity_snake = entity_snake,
353 url_impl = url_impl,
354 entity_views = entity_views
355 )
356 }
357
358 fn generate_entity_views_struct(&self) -> String {
359 let entity_name = &self.entity_name;
360
361 let derived: Vec<_> = self
362 .spec
363 .views
364 .iter()
365 .filter(|v| {
366 !v.id.ends_with("/state")
367 && !v.id.ends_with("/list")
368 && v.id.starts_with(entity_name)
369 })
370 .collect();
371
372 let mut derived_methods = String::new();
373 for view in &derived {
374 let view_name = view.id.split('/').nth(1).unwrap_or("unknown");
375 let method_name = to_snake_case(view_name);
376
377 derived_methods.push_str(&format!(
378 r#"
379 pub fn {method_name}(&self) -> ViewHandle<{entity_name}> {{
380 self.builder.view("{view_id}")
381 }}
382"#,
383 method_name = method_name,
384 entity_name = entity_name,
385 view_id = view.id
386 ));
387 }
388
389 format!(
390 r#"
391pub struct {entity_name}EntityViews {{
392 builder: ViewBuilder,
393}}
394
395impl {entity_name}EntityViews {{
396 pub fn state(&self) -> StateView<{entity_name}> {{
397 StateView::new(
398 self.builder.connection().clone(),
399 self.builder.store().clone(),
400 "{entity_name}/state".to_string(),
401 self.builder.initial_data_timeout(),
402 )
403 }}
404
405 pub fn list(&self) -> ViewHandle<{entity_name}> {{
406 self.builder.view("{entity_name}/list")
407 }}
408{derived_methods}}}"#,
409 entity_name = entity_name,
410 derived_methods = derived_methods
411 )
412 }
413
414 fn derive_stack_name(&self) -> String {
417 let entity_name = &self.entity_name;
418
419 let suffixes = ["Round", "Token", "Game", "State", "Entity", "Data"];
421
422 for suffix in suffixes {
423 if entity_name.ends_with(suffix) && entity_name.len() > suffix.len() {
424 return entity_name[..entity_name.len() - suffix.len()].to_string();
425 }
426 }
427
428 entity_name.clone()
430 }
431
432 fn field_type_to_rust(&self, field: &FieldTypeInfo) -> String {
446 let base = self.base_type_to_rust(&field.base_type, &field.rust_type_name);
447
448 let typed = if field.is_array && !matches!(field.base_type, BaseType::Array) {
449 format!("Vec<{}>", base)
450 } else {
451 base
452 };
453
454 if field.is_optional {
457 format!("Option<Option<{}>>", typed)
458 } else {
459 format!("Option<{}>", typed)
460 }
461 }
462
463 fn base_type_to_rust(&self, base_type: &BaseType, rust_type_name: &str) -> String {
464 match base_type {
465 BaseType::Integer => {
466 if rust_type_name.contains("u64") {
467 "u64".to_string()
468 } else if rust_type_name.contains("i64") {
469 "i64".to_string()
470 } else if rust_type_name.contains("u32") {
471 "u32".to_string()
472 } else if rust_type_name.contains("i32") {
473 "i32".to_string()
474 } else {
475 "i64".to_string()
476 }
477 }
478 BaseType::Float => "f64".to_string(),
479 BaseType::String => "String".to_string(),
480 BaseType::Boolean => "bool".to_string(),
481 BaseType::Timestamp => "i64".to_string(),
482 BaseType::Binary => "Vec<u8>".to_string(),
483 BaseType::Pubkey => "String".to_string(),
484 BaseType::Array => "Vec<serde_json::Value>".to_string(),
485 BaseType::Object => "serde_json::Value".to_string(),
486 BaseType::Any => "serde_json::Value".to_string(),
487 }
488 }
489
490 fn resolved_field_to_rust(&self, field: &ResolvedField) -> String {
491 let base = self.base_type_to_rust(&field.base_type, &field.field_type);
492
493 let typed = if field.is_array {
494 format!("Vec<{}>", base)
495 } else {
496 base
497 };
498
499 if field.is_optional {
500 format!("Option<Option<{}>>", typed)
501 } else {
502 format!("Option<{}>", typed)
503 }
504 }
505}
506
/// Converts a PascalCase or camelCase identifier to kebab-case.
///
/// Every uppercase character is lowercased and, unless it is the first
/// character, preceded by `-`. All other characters are copied unchanged, so
/// consecutive capitals are split individually (`"ABc"` becomes `"a-bc"`).
///
/// The full lowercase mapping of each character is used. Some characters
/// lowercase to more than one `char` (U+0130 becomes `i` followed by U+0307);
/// taking only the first would silently drop the rest.
fn to_kebab_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    for (i, c) in s.chars().enumerate() {
        if c.is_uppercase() {
            if i > 0 {
                result.push('-');
            }
            result.extend(c.to_lowercase());
        } else {
            result.push(c);
        }
    }
    result
}
521
/// Converts an identifier to PascalCase.
///
/// The input is split on `_`, `-` and `.`; the first character of each
/// non-empty piece is uppercased and the rest of the piece is kept as is.
/// The pieces are then concatenated without separators.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::new();
    for piece in s.split(|c: char| matches!(c, '_' | '-' | '.')) {
        let mut rest = piece.chars();
        if let Some(initial) = rest.next() {
            pascal.extend(initial.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}
533
/// Converts a PascalCase or camelCase identifier to snake_case.
///
/// Every uppercase character is lowercased and, unless it is the first
/// character, preceded by `_`. All other characters are copied unchanged, so
/// input that is already snake_case passes through untouched and consecutive
/// capitals are split individually (`"userID"` becomes `"user_i_d"`).
///
/// The full lowercase mapping of each character is used. Some characters
/// lowercase to more than one `char` (U+0130 becomes `i` followed by U+0307);
/// taking only the first would silently drop the rest.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len());
    for (i, ch) in s.chars().enumerate() {
        if ch.is_uppercase() {
            if i > 0 {
                result.push('_');
            }
            result.extend(ch.to_lowercase());
        } else {
            result.push(ch);
        }
    }
    result
}