1use crate::ast::*;
2use std::collections::HashSet;
3
/// The rendered source files of one generated Rust SDK.
#[derive(Debug, Clone)]
pub struct RustOutput {
    /// Contents of `Cargo.toml` (only written in crate layout).
    pub cargo_toml: String,
    /// Contents of `src/lib.rs`; reused verbatim as `mod.rs` in module layout.
    pub lib_rs: String,
    /// Contents of `types.rs`.
    pub types_rs: String,
    /// Contents of `entity.rs`.
    pub entity_rs: String,
}

impl RustOutput {
    /// Joins `lib_rs`, `types_rs` and `entity_rs` into a single string,
    /// separating them with `// types.rs` / `// entity.rs` marker comments.
    ///
    /// NOTE(review): this looks intended for inspection; the `mod` lines from
    /// `lib_rs` are kept, so confirm callers do not expect a compilable file.
    pub fn full_lib(&self) -> String {
        let pieces = [
            self.lib_rs.as_str(),
            "\n\n// types.rs\n",
            self.types_rs.as_str(),
            "\n\n// entity.rs\n",
            self.entity_rs.as_str(),
        ];
        pieces.concat()
    }

    /// Returns the module root (`mod.rs`) used by the module layout; it is the
    /// same text as `lib_rs`.
    pub fn mod_rs(&self) -> String {
        self.lib_rs.to_owned()
    }
}
24
/// Options controlling how the Rust SDK is generated.
#[derive(Debug, Clone)]
pub struct RustConfig {
    /// Package name written into the generated `Cargo.toml`.
    pub crate_name: String,
    /// Version requirement for the `hyperstack-sdk` dependency in `Cargo.toml`.
    pub sdk_version: String,
    /// When `true`, `entity.rs` imports its types via `super::types`
    /// (module layout); otherwise via `crate::types` (crate layout).
    pub module_mode: bool,
}

impl Default for RustConfig {
    /// Crate `generated-stack`, SDK version `0.2`, crate layout.
    fn default() -> Self {
        let crate_name = String::from("generated-stack");
        let sdk_version = String::from("0.2");
        RustConfig {
            crate_name,
            sdk_version,
            module_mode: false,
        }
    }
}
41
42pub fn compile_serializable_spec(
43 spec: SerializableStreamSpec,
44 entity_name: String,
45 config: Option<RustConfig>,
46) -> Result<RustOutput, String> {
47 let config = config.unwrap_or_default();
48 let compiler = RustCompiler::new(spec, entity_name, config);
49 Ok(compiler.compile())
50}
51
52pub fn write_rust_crate(
53 output: &RustOutput,
54 crate_dir: &std::path::Path,
55) -> Result<(), std::io::Error> {
56 std::fs::create_dir_all(crate_dir.join("src"))?;
57 std::fs::write(crate_dir.join("Cargo.toml"), &output.cargo_toml)?;
58 std::fs::write(crate_dir.join("src/lib.rs"), &output.lib_rs)?;
59 std::fs::write(crate_dir.join("src/types.rs"), &output.types_rs)?;
60 std::fs::write(crate_dir.join("src/entity.rs"), &output.entity_rs)?;
61 Ok(())
62}
63
64pub fn write_rust_module(
65 output: &RustOutput,
66 module_dir: &std::path::Path,
67) -> Result<(), std::io::Error> {
68 std::fs::create_dir_all(module_dir)?;
69 std::fs::write(module_dir.join("mod.rs"), output.mod_rs())?;
70 std::fs::write(module_dir.join("types.rs"), &output.types_rs)?;
71 std::fs::write(module_dir.join("entity.rs"), &output.entity_rs)?;
72 Ok(())
73}
74
75struct RustCompiler {
76 spec: SerializableStreamSpec,
77 entity_name: String,
78 config: RustConfig,
79}
80
81impl RustCompiler {
82 fn new(spec: SerializableStreamSpec, entity_name: String, config: RustConfig) -> Self {
83 Self {
84 spec,
85 entity_name,
86 config,
87 }
88 }
89
90 fn compile(&self) -> RustOutput {
91 RustOutput {
92 cargo_toml: self.generate_cargo_toml(),
93 lib_rs: self.generate_lib_rs(),
94 types_rs: self.generate_types_rs(),
95 entity_rs: self.generate_entity_rs(),
96 }
97 }
98
99 fn generate_cargo_toml(&self) -> String {
100 format!(
101 r#"[package]
102name = "{}"
103version = "0.1.0"
104edition = "2021"
105
106[dependencies]
107hyperstack-sdk = "{}"
108serde = {{ version = "1", features = ["derive"] }}
109serde_json = "1"
110"#,
111 self.config.crate_name, self.config.sdk_version
112 )
113 }
114
115 fn generate_lib_rs(&self) -> String {
116 format!(
117 r#"mod types;
118mod entity;
119
120pub use types::*;
121pub use entity::{{{}Entity, {}Views}};
122
123pub use hyperstack_sdk::{{HyperStack, Entity, Update, ConnectionState, Views}};
124"#,
125 self.entity_name, self.entity_name
126 )
127 }
128
129 fn generate_types_rs(&self) -> String {
130 let mut output = String::new();
131 output.push_str("use serde::{Deserialize, Serialize};\n\n");
132
133 let mut generated = HashSet::new();
134
135 for section in &self.spec.sections {
136 if !Self::is_root_section(§ion.name) && generated.insert(section.name.clone()) {
137 output.push_str(&self.generate_struct_for_section(section));
138 output.push_str("\n\n");
139 }
140 }
141
142 output.push_str(&self.generate_main_entity_struct());
143 output.push_str(&self.generate_resolved_types(&mut generated));
144 output.push_str(&self.generate_event_wrapper());
145
146 output
147 }
148
149 fn generate_struct_for_section(&self, section: &EntitySection) -> String {
150 let struct_name = format!("{}{}", self.entity_name, to_pascal_case(§ion.name));
151 let mut fields = Vec::new();
152
153 for field in §ion.fields {
154 let field_name = to_snake_case(&field.field_name);
155 let rust_type = self.field_type_to_rust(field);
156
157 fields.push(format!(
158 " #[serde(default)]\n pub {}: {},",
159 field_name, rust_type
160 ));
161 }
162
163 format!(
164 "#[derive(Debug, Clone, Serialize, Deserialize, Default)]\npub struct {} {{\n{}\n}}",
165 struct_name,
166 fields.join("\n")
167 )
168 }
169
170 fn is_root_section(name: &str) -> bool {
172 name.eq_ignore_ascii_case("root")
173 }
174
175 fn generate_main_entity_struct(&self) -> String {
176 let mut fields = Vec::new();
177
178 for section in &self.spec.sections {
179 if !Self::is_root_section(§ion.name) {
180 let field_name = to_snake_case(§ion.name);
181 let type_name = format!("{}{}", self.entity_name, to_pascal_case(§ion.name));
182 fields.push(format!(
183 " #[serde(default)]\n pub {}: {},",
184 field_name, type_name
185 ));
186 }
187 }
188
189 for section in &self.spec.sections {
190 if Self::is_root_section(§ion.name) {
191 for field in §ion.fields {
192 let field_name = to_snake_case(&field.field_name);
193 let rust_type = self.field_type_to_rust(field);
194 fields.push(format!(
195 " #[serde(default)]\n pub {}: {},",
196 field_name, rust_type
197 ));
198 }
199 }
200 }
201
202 format!(
203 "#[derive(Debug, Clone, Serialize, Deserialize, Default)]\npub struct {} {{\n{}\n}}",
204 self.entity_name,
205 fields.join("\n")
206 )
207 }
208
209 fn generate_resolved_types(&self, generated: &mut HashSet<String>) -> String {
210 let mut output = String::new();
211
212 for section in &self.spec.sections {
213 for field in §ion.fields {
214 if let Some(resolved) = &field.resolved_type {
215 if generated.insert(resolved.type_name.clone()) {
216 output.push_str("\n\n");
217 output.push_str(&self.generate_resolved_struct(resolved));
218 }
219 }
220 }
221 }
222
223 output
224 }
225
226 fn generate_resolved_struct(&self, resolved: &ResolvedStructType) -> String {
227 if resolved.is_enum {
228 let variants: Vec<String> = resolved
229 .enum_variants
230 .iter()
231 .map(|v| format!(" {},", to_pascal_case(v)))
232 .collect();
233
234 format!(
235 "#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]\npub enum {} {{\n{}\n}}",
236 to_pascal_case(&resolved.type_name),
237 variants.join("\n")
238 )
239 } else {
240 let fields: Vec<String> = resolved
241 .fields
242 .iter()
243 .map(|f| {
244 let rust_type = self.resolved_field_to_rust(f);
245 format!(
246 " #[serde(default)]\n pub {}: {},",
247 to_snake_case(&f.field_name),
248 rust_type
249 )
250 })
251 .collect();
252
253 format!(
254 "#[derive(Debug, Clone, Serialize, Deserialize, Default)]\npub struct {} {{\n{}\n}}",
255 to_pascal_case(&resolved.type_name),
256 fields.join("\n")
257 )
258 }
259 }
260
261 fn generate_event_wrapper(&self) -> String {
262 r#"
263
264#[derive(Debug, Clone, Serialize, Deserialize)]
265pub struct EventWrapper<T> {
266 #[serde(default)]
267 pub timestamp: i64,
268 pub data: T,
269 #[serde(default)]
270 pub slot: Option<f64>,
271 #[serde(default)]
272 pub signature: Option<String>,
273}
274
275impl<T: Default> Default for EventWrapper<T> {
276 fn default() -> Self {
277 Self {
278 timestamp: 0,
279 data: T::default(),
280 slot: None,
281 signature: None,
282 }
283 }
284}
285"#
286 .to_string()
287 }
288
289 fn generate_entity_rs(&self) -> String {
290 let entity_name = &self.entity_name;
291 let types_import = if self.config.module_mode {
292 "super::types"
293 } else {
294 "crate::types"
295 };
296
297 let views_struct = self.generate_views_struct();
298
299 format!(
300 r#"use hyperstack_sdk::{{Entity, StateView, ViewBuilder, ViewHandle, Views}};
301use {types_import}::{entity_name};
302
303pub struct {entity_name}Entity;
304
305impl Entity for {entity_name}Entity {{
306 type Data = {entity_name};
307
308 const NAME: &'static str = "{entity_name}";
309
310 fn state_view() -> &'static str {{
311 "{entity_name}/state"
312 }}
313
314 fn list_view() -> &'static str {{
315 "{entity_name}/list"
316 }}
317}}
318{views_struct}"#,
319 types_import = types_import,
320 entity_name = entity_name,
321 views_struct = views_struct
322 )
323 }
324
325 fn generate_views_struct(&self) -> String {
326 let entity_name = &self.entity_name;
327
328 let derived: Vec<_> = self
329 .spec
330 .views
331 .iter()
332 .filter(|v| {
333 !v.id.ends_with("/state")
334 && !v.id.ends_with("/list")
335 && v.id.starts_with(entity_name)
336 })
337 .collect();
338
339 let mut derived_methods = String::new();
340 for view in &derived {
341 let view_name = view.id.split('/').nth(1).unwrap_or("unknown");
342 let method_name = to_snake_case(view_name);
343
344 derived_methods.push_str(&format!(
345 r#"
346 pub fn {method_name}(&self) -> ViewHandle<{entity_name}> {{
347 self.builder.view("{view_id}")
348 }}
349"#,
350 method_name = method_name,
351 entity_name = entity_name,
352 view_id = view.id
353 ));
354 }
355
356 format!(
357 r#"
358
359pub struct {entity_name}Views {{
360 builder: ViewBuilder,
361}}
362
363impl Views for {entity_name}Views {{
364 type Entity = {entity_name}Entity;
365
366 fn from_builder(builder: ViewBuilder) -> Self {{
367 Self {{ builder }}
368 }}
369}}
370
371impl {entity_name}Views {{
372 pub fn state(&self) -> StateView<{entity_name}> {{
373 StateView::new(
374 self.builder.connection().clone(),
375 self.builder.store().clone(),
376 "{entity_name}/state".to_string(),
377 self.builder.initial_data_timeout(),
378 )
379 }}
380
381 pub fn list(&self) -> ViewHandle<{entity_name}> {{
382 self.builder.view("{entity_name}/list")
383 }}
384{derived_methods}}}
385"#,
386 entity_name = entity_name,
387 derived_methods = derived_methods
388 )
389 }
390
391 fn field_type_to_rust(&self, field: &FieldTypeInfo) -> String {
405 let base = self.base_type_to_rust(&field.base_type, &field.rust_type_name);
406
407 let typed = if field.is_array && !matches!(field.base_type, BaseType::Array) {
408 format!("Vec<{}>", base)
409 } else {
410 base
411 };
412
413 if field.is_optional {
416 format!("Option<Option<{}>>", typed)
417 } else {
418 format!("Option<{}>", typed)
419 }
420 }
421
422 fn base_type_to_rust(&self, base_type: &BaseType, rust_type_name: &str) -> String {
423 match base_type {
424 BaseType::Integer => {
425 if rust_type_name.contains("u64") {
426 "u64".to_string()
427 } else if rust_type_name.contains("i64") {
428 "i64".to_string()
429 } else if rust_type_name.contains("u32") {
430 "u32".to_string()
431 } else if rust_type_name.contains("i32") {
432 "i32".to_string()
433 } else {
434 "i64".to_string()
435 }
436 }
437 BaseType::Float => "f64".to_string(),
438 BaseType::String => "String".to_string(),
439 BaseType::Boolean => "bool".to_string(),
440 BaseType::Timestamp => "i64".to_string(),
441 BaseType::Binary => "Vec<u8>".to_string(),
442 BaseType::Pubkey => "String".to_string(),
443 BaseType::Array => "Vec<serde_json::Value>".to_string(),
444 BaseType::Object => "serde_json::Value".to_string(),
445 BaseType::Any => "serde_json::Value".to_string(),
446 }
447 }
448
449 fn resolved_field_to_rust(&self, field: &ResolvedField) -> String {
450 let base = self.base_type_to_rust(&field.base_type, &field.field_type);
451
452 let typed = if field.is_array {
453 format!("Vec<{}>", base)
454 } else {
455 base
456 };
457
458 if field.is_optional {
459 format!("Option<Option<{}>>", typed)
460 } else {
461 format!("Option<{}>", typed)
462 }
463 }
464}
465
/// Converts `s` to PascalCase by treating `_`, `-` and `.` as word separators:
/// separators are dropped and the first character of each word is uppercased.
/// The remaining characters are kept as-is, so `fooBar` becomes `FooBar`.
fn to_pascal_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let mut at_word_start = true;

    for c in s.chars() {
        if matches!(c, '_' | '-' | '.') {
            at_word_start = true;
        } else if at_word_start {
            out.extend(c.to_uppercase());
            at_word_start = false;
        } else {
            out.push(c);
        }
    }

    out
}
477
/// Converts a camelCase / PascalCase / kebab-case name to snake_case.
///
/// - An uppercase letter starts a new word when it follows a lowercase letter
///   or digit (`camelCase` -> `camel_case`, `field2Name` -> `field2_name`), or
///   when it ends an acronym run (`HTTPServer` -> `http_server`).
/// - `-`, `.`, ` ` and `/` are turned into `_`; existing underscores are kept.
/// - No extra `_` is inserted directly after an existing one
///   (`my_Field` -> `my_field`).
fn to_snake_case(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut result = String::with_capacity(s.len() + s.len() / 2);

    for (i, &ch) in chars.iter().enumerate() {
        if matches!(ch, '-' | '.' | ' ' | '/') {
            result.push('_');
            continue;
        }

        if !ch.is_uppercase() {
            result.push(ch);
            continue;
        }

        let prev = if i > 0 { Some(chars[i - 1]) } else { None };
        let next_is_lower = chars.get(i + 1).map_or(false, |c| c.is_lowercase());
        let starts_word = match prev {
            Some(p) => {
                p.is_lowercase() || p.is_numeric() || (p.is_uppercase() && next_is_lower)
            }
            None => false,
        };

        if starts_word && !result.ends_with('_') {
            result.push('_');
        }
        result.extend(ch.to_lowercase());
    }

    result
}