1use crate::ast::*;
2use std::collections::HashSet;
3
/// Source text of every file in a generated SDK crate.
#[derive(Debug, Clone)]
pub struct RustOutput {
    /// Contents of `Cargo.toml`.
    pub cargo_toml: String,
    /// Contents of `src/lib.rs`.
    pub lib_rs: String,
    /// Contents of `src/types.rs`.
    pub types_rs: String,
    /// Contents of `src/entity.rs`.
    pub entity_rs: String,
}

impl RustOutput {
    /// Joins `lib.rs`, `types.rs` and `entity.rs` into one string, in that order.
    ///
    /// Each later file is preceded by a blank line and a `// types.rs` / `// entity.rs`
    /// marker comment. `cargo_toml` is not included.
    pub fn full_lib(&self) -> String {
        [
            self.lib_rs.as_str(),
            "\n\n// types.rs\n",
            self.types_rs.as_str(),
            "\n\n// entity.rs\n",
            self.entity_rs.as_str(),
        ]
        .concat()
    }
}
20
/// Settings for the generated crate's `Cargo.toml`.
#[derive(Debug, Clone)]
pub struct RustConfig {
    /// Written as the `[package] name`.
    pub crate_name: String,
    /// Version requirement written for the `hyperstack-sdk` dependency.
    pub sdk_version: String,
}

impl Default for RustConfig {
    /// Crate `generated-stack` against `hyperstack-sdk` `0.2`.
    fn default() -> Self {
        let crate_name = String::from("generated-stack");
        let sdk_version = String::from("0.2");
        RustConfig {
            crate_name,
            sdk_version,
        }
    }
}
35
36pub fn compile_serializable_spec(
37 spec: SerializableStreamSpec,
38 entity_name: String,
39 config: Option<RustConfig>,
40) -> Result<RustOutput, String> {
41 let config = config.unwrap_or_default();
42 let compiler = RustCompiler::new(spec, entity_name, config);
43 Ok(compiler.compile())
44}
45
46pub fn write_rust_crate(
47 output: &RustOutput,
48 crate_dir: &std::path::Path,
49) -> Result<(), std::io::Error> {
50 std::fs::create_dir_all(crate_dir.join("src"))?;
51 std::fs::write(crate_dir.join("Cargo.toml"), &output.cargo_toml)?;
52 std::fs::write(crate_dir.join("src/lib.rs"), &output.lib_rs)?;
53 std::fs::write(crate_dir.join("src/types.rs"), &output.types_rs)?;
54 std::fs::write(crate_dir.join("src/entity.rs"), &output.entity_rs)?;
55 Ok(())
56}
57
58struct RustCompiler {
59 spec: SerializableStreamSpec,
60 entity_name: String,
61 config: RustConfig,
62}
63
64impl RustCompiler {
65 fn new(spec: SerializableStreamSpec, entity_name: String, config: RustConfig) -> Self {
66 Self {
67 spec,
68 entity_name,
69 config,
70 }
71 }
72
73 fn compile(&self) -> RustOutput {
74 RustOutput {
75 cargo_toml: self.generate_cargo_toml(),
76 lib_rs: self.generate_lib_rs(),
77 types_rs: self.generate_types_rs(),
78 entity_rs: self.generate_entity_rs(),
79 }
80 }
81
82 fn generate_cargo_toml(&self) -> String {
83 format!(
84 r#"[package]
85name = "{}"
86version = "0.1.0"
87edition = "2021"
88
89[dependencies]
90hyperstack-sdk = "{}"
91serde = {{ version = "1", features = ["derive"] }}
92serde_json = "1"
93"#,
94 self.config.crate_name, self.config.sdk_version
95 )
96 }
97
98 fn generate_lib_rs(&self) -> String {
99 format!(
100 r#"mod types;
101mod entity;
102
103pub use types::*;
104pub use entity::{entity_name}Entity;
105
106pub use hyperstack_sdk::{{HyperStack, Entity, Update, ConnectionState}};
107"#,
108 entity_name = self.entity_name
109 )
110 }
111
112 fn generate_types_rs(&self) -> String {
113 let mut output = String::new();
114 output.push_str("use serde::{Deserialize, Deserializer, Serialize};\n\n");
115 output.push_str(&self.generate_serde_helpers());
116
117 let mut generated = HashSet::new();
118
119 for section in &self.spec.sections {
120 if !Self::is_root_section(§ion.name) && generated.insert(section.name.clone()) {
121 output.push_str(&self.generate_struct_for_section(section));
122 output.push_str("\n\n");
123 }
124 }
125
126 output.push_str(&self.generate_main_entity_struct());
127 output.push_str(&self.generate_resolved_types(&mut generated));
128 output.push_str(&self.generate_event_wrapper());
129
130 output
131 }
132
133 fn generate_struct_for_section(&self, section: &EntitySection) -> String {
134 let struct_name = format!("{}{}", self.entity_name, to_pascal_case(§ion.name));
135 let mut fields = Vec::new();
136
137 for field in §ion.fields {
138 let field_name = to_snake_case(&field.field_name);
139 let rust_type = self.field_type_to_rust(field);
140
141 let serde_attr = if field_name != to_snake_case(&field.field_name)
142 || field_name != field.field_name
143 {
144 let original = &field.field_name;
145 if to_snake_case(original) != *original {
146 format!(
147 " #[serde(rename = \"{}\", default)]\n",
148 to_camel_case(original)
149 )
150 } else {
151 " #[serde(default)]\n".to_string()
152 }
153 } else {
154 " #[serde(default)]\n".to_string()
155 };
156
157 fields.push(format!(
158 "{} pub {}: {},",
159 serde_attr,
160 to_snake_case(&field.field_name),
161 rust_type
162 ));
163 }
164
165 format!(
166 "#[derive(Debug, Clone, Serialize, Deserialize, Default)]\npub struct {} {{\n{}\n}}",
167 struct_name,
168 fields.join("\n")
169 )
170 }
171
172 fn is_root_section(name: &str) -> bool {
174 name.eq_ignore_ascii_case("root")
175 }
176
177 fn generate_main_entity_struct(&self) -> String {
178 let mut fields = Vec::new();
179
180 for section in &self.spec.sections {
181 if !Self::is_root_section(§ion.name) {
182 let field_name = to_snake_case(§ion.name);
183 let type_name = format!("{}{}", self.entity_name, to_pascal_case(§ion.name));
184 let serde_attr = if field_name != section.name {
185 format!(
186 " #[serde(rename = \"{}\", default)]\n",
187 to_camel_case(§ion.name)
188 )
189 } else {
190 " #[serde(default)]\n".to_string()
191 };
192 fields.push(format!(
193 "{} pub {}: {},",
194 serde_attr, field_name, type_name
195 ));
196 }
197 }
198
199 for section in &self.spec.sections {
200 if Self::is_root_section(§ion.name) {
201 for field in §ion.fields {
202 let field_name = to_snake_case(&field.field_name);
203 let rust_type = self.field_type_to_rust(field);
204 fields.push(format!(
205 " #[serde(default)]\n pub {}: {},",
206 field_name, rust_type
207 ));
208 }
209 }
210 }
211
212 format!(
213 "#[derive(Debug, Clone, Serialize, Deserialize, Default)]\npub struct {} {{\n{}\n}}",
214 self.entity_name,
215 fields.join("\n")
216 )
217 }
218
219 fn generate_resolved_types(&self, generated: &mut HashSet<String>) -> String {
220 let mut output = String::new();
221
222 for section in &self.spec.sections {
223 for field in §ion.fields {
224 if let Some(resolved) = &field.resolved_type {
225 if generated.insert(resolved.type_name.clone()) {
226 output.push_str("\n\n");
227 output.push_str(&self.generate_resolved_struct(resolved));
228 }
229 }
230 }
231 }
232
233 output
234 }
235
236 fn generate_resolved_struct(&self, resolved: &ResolvedStructType) -> String {
237 if resolved.is_enum {
238 let variants: Vec<String> = resolved
239 .enum_variants
240 .iter()
241 .map(|v| format!(" {},", to_pascal_case(v)))
242 .collect();
243
244 format!(
245 "#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]\npub enum {} {{\n{}\n}}",
246 to_pascal_case(&resolved.type_name),
247 variants.join("\n")
248 )
249 } else {
250 let fields: Vec<String> = resolved
251 .fields
252 .iter()
253 .map(|f| {
254 let rust_type = self.resolved_field_to_rust(f);
255 let serde_attr = format!(
256 " #[serde(rename = \"{}\", default)]\n",
257 to_camel_case(&f.field_name)
258 );
259 format!(
260 "{} pub {}: {},",
261 serde_attr,
262 to_snake_case(&f.field_name),
263 rust_type
264 )
265 })
266 .collect();
267
268 format!(
269 "#[derive(Debug, Clone, Serialize, Deserialize, Default)]\npub struct {} {{\n{}\n}}",
270 to_pascal_case(&resolved.type_name),
271 fields.join("\n")
272 )
273 }
274 }
275
276 fn generate_event_wrapper(&self) -> String {
277 r#"
278
279#[derive(Debug, Clone, Serialize, Deserialize)]
280pub struct EventWrapper<T> {
281 #[serde(default)]
282 pub timestamp: i64,
283 pub data: T,
284 #[serde(default)]
285 pub slot: Option<f64>,
286 #[serde(default)]
287 pub signature: Option<String>,
288}
289
290impl<T: Default> Default for EventWrapper<T> {
291 fn default() -> Self {
292 Self {
293 timestamp: 0,
294 data: T::default(),
295 slot: None,
296 signature: None,
297 }
298 }
299}
300"#
301 .to_string()
302 }
303
304 fn generate_serde_helpers(&self) -> String {
305 r#"mod serde_helpers {
306 use serde::{Deserialize, Deserializer};
307
308 pub fn deserialize_number_from_any<'de, D>(deserializer: D) -> Result<Option<f64>, D::Error>
309 where
310 D: Deserializer<'de>,
311 {
312 #[derive(Deserialize)]
313 #[serde(untagged)]
314 enum NumOrNull {
315 Num(f64),
316 Null,
317 }
318 match NumOrNull::deserialize(deserializer)? {
319 NumOrNull::Num(n) => Ok(Some(n)),
320 NumOrNull::Null => Ok(None),
321 }
322 }
323}
324
325"#
326 .to_string()
327 }
328
329 fn generate_entity_rs(&self) -> String {
330 let entity_name = &self.entity_name;
331
332 format!(
333 r#"use hyperstack_sdk::Entity;
334use crate::types::{entity_name};
335
336pub struct {entity_name}Entity;
337
338impl Entity for {entity_name}Entity {{
339 type Data = {entity_name};
340
341 const NAME: &'static str = "{entity_name}";
342
343 fn state_view() -> &'static str {{
344 "{entity_name}/state"
345 }}
346
347 fn list_view() -> &'static str {{
348 "{entity_name}/list"
349 }}
350}}
351"#,
352 entity_name = entity_name
353 )
354 }
355
356 fn field_type_to_rust(&self, field: &FieldTypeInfo) -> String {
370 let base = self.base_type_to_rust(&field.base_type, &field.rust_type_name);
371
372 let typed = if field.is_array && !matches!(field.base_type, BaseType::Array) {
373 format!("Vec<{}>", base)
374 } else {
375 base
376 };
377
378 if field.is_optional {
381 format!("Option<Option<{}>>", typed)
382 } else {
383 format!("Option<{}>", typed)
384 }
385 }
386
387 fn base_type_to_rust(&self, base_type: &BaseType, rust_type_name: &str) -> String {
388 match base_type {
389 BaseType::Integer => {
390 if rust_type_name.contains("u64") {
391 "u64".to_string()
392 } else if rust_type_name.contains("i64") {
393 "i64".to_string()
394 } else if rust_type_name.contains("u32") {
395 "u32".to_string()
396 } else if rust_type_name.contains("i32") {
397 "i32".to_string()
398 } else {
399 "i64".to_string()
400 }
401 }
402 BaseType::Float => "f64".to_string(),
403 BaseType::String => "String".to_string(),
404 BaseType::Boolean => "bool".to_string(),
405 BaseType::Timestamp => "i64".to_string(),
406 BaseType::Binary => "Vec<u8>".to_string(),
407 BaseType::Pubkey => "String".to_string(),
408 BaseType::Array => "Vec<serde_json::Value>".to_string(),
409 BaseType::Object => "serde_json::Value".to_string(),
410 BaseType::Any => "serde_json::Value".to_string(),
411 }
412 }
413
414 fn resolved_field_to_rust(&self, field: &ResolvedField) -> String {
415 let base = self.base_type_to_rust(&field.base_type, &field.field_type);
416
417 let typed = if field.is_array {
418 format!("Vec<{}>", base)
419 } else {
420 base
421 };
422
423 if field.is_optional {
424 format!("Option<Option<{}>>", typed)
425 } else {
426 format!("Option<{}>", typed)
427 }
428 }
429}
430
431fn to_pascal_case(s: &str) -> String {
432 s.split(['_', '-', '.'])
433 .map(|word| {
434 let mut chars = word.chars();
435 match chars.next() {
436 None => String::new(),
437 Some(first) => first.to_uppercase().collect::<String>() + chars.as_str(),
438 }
439 })
440 .collect()
441}
442
443fn to_snake_case(s: &str) -> String {
444 let mut result = String::new();
445 for (i, ch) in s.chars().enumerate() {
446 if ch.is_uppercase() {
447 if i > 0 {
448 result.push('_');
449 }
450 result.push(ch.to_lowercase().next().unwrap());
451 } else {
452 result.push(ch);
453 }
454 }
455 result
456}
457
458fn to_camel_case(s: &str) -> String {
459 let pascal = to_pascal_case(s);
460 let mut chars = pascal.chars();
461 match chars.next() {
462 None => String::new(),
463 Some(first) => first.to_lowercase().collect::<String>() + chars.as_str(),
464 }
465}