1use crate::ast::*;
2use std::collections::HashSet;
3
/// The generated Rust sources for one entity, one string per output file.
#[derive(Debug, Clone)]
pub struct RustOutput {
    /// Contents of `Cargo.toml` for the standalone-crate layout.
    pub cargo_toml: String,
    /// Crate root (`lib.rs`); also returned verbatim by [`RustOutput::mod_rs`].
    pub lib_rs: String,
    /// Contents of `types.rs`.
    pub types_rs: String,
    /// Contents of `entity.rs`.
    pub entity_rs: String,
}

impl RustOutput {
    /// Joins `lib_rs`, `types_rs` and `entity_rs` into one string.
    ///
    /// The last two are each introduced by `\n\n` and a `// <file>` marker line.
    pub fn full_lib(&self) -> String {
        [
            self.lib_rs.as_str(),
            "\n\n// types.rs\n",
            self.types_rs.as_str(),
            "\n\n// entity.rs\n",
            self.entity_rs.as_str(),
        ]
        .concat()
    }

    /// Contents of `mod.rs` for the module layout: an owned copy of `lib_rs`.
    pub fn mod_rs(&self) -> String {
        self.lib_rs.to_owned()
    }
}
24
/// Options controlling the generated Rust output.
#[derive(Debug, Clone)]
pub struct RustConfig {
    /// `[package] name` written into the generated `Cargo.toml`.
    pub crate_name: String,
    /// Version requirement for the `hyperstack-sdk` dependency.
    pub sdk_version: String,
    /// When `true`, `entity.rs` imports from `super::types` (for embedding as
    /// a module) instead of `crate::types`.
    pub module_mode: bool,
}

impl Default for RustConfig {
    /// Crate `generated-stack` on SDK `0.2`, standalone-crate layout.
    fn default() -> Self {
        let crate_name = String::from("generated-stack");
        let sdk_version = String::from("0.2");
        Self {
            crate_name,
            sdk_version,
            module_mode: false,
        }
    }
}
41
42pub fn compile_serializable_spec(
43 spec: SerializableStreamSpec,
44 entity_name: String,
45 config: Option<RustConfig>,
46) -> Result<RustOutput, String> {
47 let config = config.unwrap_or_default();
48 let compiler = RustCompiler::new(spec, entity_name, config);
49 Ok(compiler.compile())
50}
51
52pub fn write_rust_crate(
53 output: &RustOutput,
54 crate_dir: &std::path::Path,
55) -> Result<(), std::io::Error> {
56 std::fs::create_dir_all(crate_dir.join("src"))?;
57 std::fs::write(crate_dir.join("Cargo.toml"), &output.cargo_toml)?;
58 std::fs::write(crate_dir.join("src/lib.rs"), &output.lib_rs)?;
59 std::fs::write(crate_dir.join("src/types.rs"), &output.types_rs)?;
60 std::fs::write(crate_dir.join("src/entity.rs"), &output.entity_rs)?;
61 Ok(())
62}
63
64pub fn write_rust_module(
65 output: &RustOutput,
66 module_dir: &std::path::Path,
67) -> Result<(), std::io::Error> {
68 std::fs::create_dir_all(module_dir)?;
69 std::fs::write(module_dir.join("mod.rs"), output.mod_rs())?;
70 std::fs::write(module_dir.join("types.rs"), &output.types_rs)?;
71 std::fs::write(module_dir.join("entity.rs"), &output.entity_rs)?;
72 Ok(())
73}
74
75struct RustCompiler {
76 spec: SerializableStreamSpec,
77 entity_name: String,
78 config: RustConfig,
79}
80
81impl RustCompiler {
82 fn new(spec: SerializableStreamSpec, entity_name: String, config: RustConfig) -> Self {
83 Self {
84 spec,
85 entity_name,
86 config,
87 }
88 }
89
90 fn compile(&self) -> RustOutput {
91 RustOutput {
92 cargo_toml: self.generate_cargo_toml(),
93 lib_rs: self.generate_lib_rs(),
94 types_rs: self.generate_types_rs(),
95 entity_rs: self.generate_entity_rs(),
96 }
97 }
98
99 fn generate_cargo_toml(&self) -> String {
100 format!(
101 r#"[package]
102name = "{}"
103version = "0.1.0"
104edition = "2021"
105
106[dependencies]
107hyperstack-sdk = "{}"
108serde = {{ version = "1", features = ["derive"] }}
109serde_json = "1"
110"#,
111 self.config.crate_name, self.config.sdk_version
112 )
113 }
114
115 fn generate_lib_rs(&self) -> String {
116 format!(
117 r#"mod types;
118mod entity;
119
120pub use types::*;
121pub use entity::{entity_name}Entity;
122
123pub use hyperstack_sdk::{{HyperStack, Entity, Update, ConnectionState}};
124"#,
125 entity_name = self.entity_name
126 )
127 }
128
129 fn generate_types_rs(&self) -> String {
130 let mut output = String::new();
131 output.push_str("use serde::{Deserialize, Serialize};\n\n");
132
133 let mut generated = HashSet::new();
134
135 for section in &self.spec.sections {
136 if !Self::is_root_section(§ion.name) && generated.insert(section.name.clone()) {
137 output.push_str(&self.generate_struct_for_section(section));
138 output.push_str("\n\n");
139 }
140 }
141
142 output.push_str(&self.generate_main_entity_struct());
143 output.push_str(&self.generate_resolved_types(&mut generated));
144 output.push_str(&self.generate_event_wrapper());
145
146 output
147 }
148
149 fn generate_struct_for_section(&self, section: &EntitySection) -> String {
150 let struct_name = format!("{}{}", self.entity_name, to_pascal_case(§ion.name));
151 let mut fields = Vec::new();
152
153 for field in §ion.fields {
154 let field_name = to_snake_case(&field.field_name);
155 let rust_type = self.field_type_to_rust(field);
156
157 let serde_attr = if field_name != to_snake_case(&field.field_name)
158 || field_name != field.field_name
159 {
160 let original = &field.field_name;
161 if to_snake_case(original) != *original {
162 format!(
163 " #[serde(rename = \"{}\", default)]\n",
164 to_camel_case(original)
165 )
166 } else {
167 " #[serde(default)]\n".to_string()
168 }
169 } else {
170 " #[serde(default)]\n".to_string()
171 };
172
173 fields.push(format!(
174 "{} pub {}: {},",
175 serde_attr,
176 to_snake_case(&field.field_name),
177 rust_type
178 ));
179 }
180
181 format!(
182 "#[derive(Debug, Clone, Serialize, Deserialize, Default)]\npub struct {} {{\n{}\n}}",
183 struct_name,
184 fields.join("\n")
185 )
186 }
187
188 fn is_root_section(name: &str) -> bool {
190 name.eq_ignore_ascii_case("root")
191 }
192
193 fn generate_main_entity_struct(&self) -> String {
194 let mut fields = Vec::new();
195
196 for section in &self.spec.sections {
197 if !Self::is_root_section(§ion.name) {
198 let field_name = to_snake_case(§ion.name);
199 let type_name = format!("{}{}", self.entity_name, to_pascal_case(§ion.name));
200 let serde_attr = if field_name != section.name {
201 format!(
202 " #[serde(rename = \"{}\", default)]\n",
203 to_camel_case(§ion.name)
204 )
205 } else {
206 " #[serde(default)]\n".to_string()
207 };
208 fields.push(format!(
209 "{} pub {}: {},",
210 serde_attr, field_name, type_name
211 ));
212 }
213 }
214
215 for section in &self.spec.sections {
216 if Self::is_root_section(§ion.name) {
217 for field in §ion.fields {
218 let field_name = to_snake_case(&field.field_name);
219 let rust_type = self.field_type_to_rust(field);
220 fields.push(format!(
221 " #[serde(default)]\n pub {}: {},",
222 field_name, rust_type
223 ));
224 }
225 }
226 }
227
228 format!(
229 "#[derive(Debug, Clone, Serialize, Deserialize, Default)]\npub struct {} {{\n{}\n}}",
230 self.entity_name,
231 fields.join("\n")
232 )
233 }
234
235 fn generate_resolved_types(&self, generated: &mut HashSet<String>) -> String {
236 let mut output = String::new();
237
238 for section in &self.spec.sections {
239 for field in §ion.fields {
240 if let Some(resolved) = &field.resolved_type {
241 if generated.insert(resolved.type_name.clone()) {
242 output.push_str("\n\n");
243 output.push_str(&self.generate_resolved_struct(resolved));
244 }
245 }
246 }
247 }
248
249 output
250 }
251
252 fn generate_resolved_struct(&self, resolved: &ResolvedStructType) -> String {
253 if resolved.is_enum {
254 let variants: Vec<String> = resolved
255 .enum_variants
256 .iter()
257 .map(|v| format!(" {},", to_pascal_case(v)))
258 .collect();
259
260 format!(
261 "#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]\npub enum {} {{\n{}\n}}",
262 to_pascal_case(&resolved.type_name),
263 variants.join("\n")
264 )
265 } else {
266 let fields: Vec<String> = resolved
267 .fields
268 .iter()
269 .map(|f| {
270 let rust_type = self.resolved_field_to_rust(f);
271 let serde_attr = format!(
272 " #[serde(rename = \"{}\", default)]\n",
273 to_camel_case(&f.field_name)
274 );
275 format!(
276 "{} pub {}: {},",
277 serde_attr,
278 to_snake_case(&f.field_name),
279 rust_type
280 )
281 })
282 .collect();
283
284 format!(
285 "#[derive(Debug, Clone, Serialize, Deserialize, Default)]\npub struct {} {{\n{}\n}}",
286 to_pascal_case(&resolved.type_name),
287 fields.join("\n")
288 )
289 }
290 }
291
292 fn generate_event_wrapper(&self) -> String {
293 r#"
294
295#[derive(Debug, Clone, Serialize, Deserialize)]
296pub struct EventWrapper<T> {
297 #[serde(default)]
298 pub timestamp: i64,
299 pub data: T,
300 #[serde(default)]
301 pub slot: Option<f64>,
302 #[serde(default)]
303 pub signature: Option<String>,
304}
305
306impl<T: Default> Default for EventWrapper<T> {
307 fn default() -> Self {
308 Self {
309 timestamp: 0,
310 data: T::default(),
311 slot: None,
312 signature: None,
313 }
314 }
315}
316"#
317 .to_string()
318 }
319
320 fn generate_entity_rs(&self) -> String {
321 let entity_name = &self.entity_name;
322 let types_import = if self.config.module_mode {
323 "super::types"
324 } else {
325 "crate::types"
326 };
327
328 format!(
329 r#"use hyperstack_sdk::Entity;
330use {types_import}::{entity_name};
331
332pub struct {entity_name}Entity;
333
334impl Entity for {entity_name}Entity {{
335 type Data = {entity_name};
336
337 const NAME: &'static str = "{entity_name}";
338
339 fn state_view() -> &'static str {{
340 "{entity_name}/state"
341 }}
342
343 fn list_view() -> &'static str {{
344 "{entity_name}/list"
345 }}
346}}
347"#,
348 types_import = types_import,
349 entity_name = entity_name
350 )
351 }
352
353 fn field_type_to_rust(&self, field: &FieldTypeInfo) -> String {
367 let base = self.base_type_to_rust(&field.base_type, &field.rust_type_name);
368
369 let typed = if field.is_array && !matches!(field.base_type, BaseType::Array) {
370 format!("Vec<{}>", base)
371 } else {
372 base
373 };
374
375 if field.is_optional {
378 format!("Option<Option<{}>>", typed)
379 } else {
380 format!("Option<{}>", typed)
381 }
382 }
383
384 fn base_type_to_rust(&self, base_type: &BaseType, rust_type_name: &str) -> String {
385 match base_type {
386 BaseType::Integer => {
387 if rust_type_name.contains("u64") {
388 "u64".to_string()
389 } else if rust_type_name.contains("i64") {
390 "i64".to_string()
391 } else if rust_type_name.contains("u32") {
392 "u32".to_string()
393 } else if rust_type_name.contains("i32") {
394 "i32".to_string()
395 } else {
396 "i64".to_string()
397 }
398 }
399 BaseType::Float => "f64".to_string(),
400 BaseType::String => "String".to_string(),
401 BaseType::Boolean => "bool".to_string(),
402 BaseType::Timestamp => "i64".to_string(),
403 BaseType::Binary => "Vec<u8>".to_string(),
404 BaseType::Pubkey => "String".to_string(),
405 BaseType::Array => "Vec<serde_json::Value>".to_string(),
406 BaseType::Object => "serde_json::Value".to_string(),
407 BaseType::Any => "serde_json::Value".to_string(),
408 }
409 }
410
411 fn resolved_field_to_rust(&self, field: &ResolvedField) -> String {
412 let base = self.base_type_to_rust(&field.base_type, &field.field_type);
413
414 let typed = if field.is_array {
415 format!("Vec<{}>", base)
416 } else {
417 base
418 };
419
420 if field.is_optional {
421 format!("Option<Option<{}>>", typed)
422 } else {
423 format!("Option<{}>", typed)
424 }
425 }
426}
427
/// Converts `s` to PascalCase.
///
/// The input is split on `_`, `-` and `.`, and the first character of each
/// non-empty piece is uppercased. The rest of each piece is kept as-is.
fn to_pascal_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for word in s.split(|c: char| matches!(c, '_' | '-' | '.')) {
        let mut chars = word.chars();
        if let Some(head) = chars.next() {
            out.extend(head.to_uppercase());
            out.push_str(chars.as_str());
        }
    }
    out
}
439
/// Converts a camelCase/PascalCase (or `-`/`.`-separated) name to snake_case.
///
/// Each uppercase character starts a new `_`-separated word and is fully
/// lowercased, so `userID` becomes `user_i_d`. `-` and `.` (the other
/// separators `to_pascal_case` splits on) are not valid in identifiers and
/// become `_`. No extra `_` is inserted before an uppercase letter that
/// already follows one, so `foo_Bar` becomes `foo_bar` rather than `foo__bar`.
fn to_snake_case(s: &str) -> String {
    let mut result = String::with_capacity(s.len() + s.len() / 2);
    for ch in s.chars() {
        if ch.is_uppercase() {
            if !result.is_empty() && !result.ends_with('_') {
                result.push('_');
            }
            // Some characters lowercase to more than one char; keep them all.
            result.extend(ch.to_lowercase());
        } else if ch == '-' || ch == '.' {
            result.push('_');
        } else {
            result.push(ch);
        }
    }
    result
}
454
455fn to_camel_case(s: &str) -> String {
456 let pascal = to_pascal_case(s);
457 let mut chars = pascal.chars();
458 match chars.next() {
459 None => String::new(),
460 Some(first) => first.to_lowercase().collect::<String>() + chars.as_str(),
461 }
462}