Skip to main content

hyperstack_interpreter/
rust.rs

1use crate::ast::*;
2use std::collections::HashSet;
3
/// The generated artifacts for one entity's Rust SDK.
///
/// Each field holds the full text of one output file.
#[derive(Clone, Debug)]
pub struct RustOutput {
    /// Contents of `Cargo.toml` (standalone-crate mode only).
    pub cargo_toml: String,
    /// Contents of `src/lib.rs`; also reused as `mod.rs` in module mode.
    pub lib_rs: String,
    /// Contents of `types.rs` (serde data types).
    pub types_rs: String,
    /// Contents of `entity.rs` (stack and view accessors).
    pub entity_rs: String,
}

impl RustOutput {
    /// Concatenate `lib.rs`, `types.rs` and `entity.rs` into a single string.
    ///
    /// The latter two are introduced by `// types.rs` / `// entity.rs` marker comments.
    pub fn full_lib(&self) -> String {
        let mut combined = String::new();
        combined.push_str(&self.lib_rs);
        combined.push_str("\n\n// types.rs\n");
        combined.push_str(&self.types_rs);
        combined.push_str("\n\n// entity.rs\n");
        combined.push_str(&self.entity_rs);
        combined
    }

    /// Text for `mod.rs` when the output is embedded as a module; identical to `lib.rs`.
    pub fn mod_rs(&self) -> String {
        String::clone(&self.lib_rs)
    }
}
24
/// Options controlling Rust SDK generation.
#[derive(Clone, Debug)]
pub struct RustConfig {
    /// Package name written to the generated `Cargo.toml`.
    pub crate_name: String,
    /// Version requirement for the `hyperstack-sdk` dependency.
    pub sdk_version: String,
    /// When true, generated code imports types via `super::types` (module
    /// layout) instead of `crate::types` (standalone crate layout).
    pub module_mode: bool,
    /// WebSocket URL for the stack. If None, generates a placeholder comment.
    pub url: Option<String>,
}

impl Default for RustConfig {
    /// Standalone crate named `generated-stack` against SDK `0.2`, with no URL.
    fn default() -> Self {
        let crate_name = String::from("generated-stack");
        let sdk_version = String::from("0.2");
        RustConfig {
            crate_name,
            sdk_version,
            module_mode: false,
            url: None,
        }
    }
}
44
45pub fn compile_serializable_spec(
46    spec: SerializableStreamSpec,
47    entity_name: String,
48    config: Option<RustConfig>,
49) -> Result<RustOutput, String> {
50    let config = config.unwrap_or_default();
51    let compiler = RustCompiler::new(spec, entity_name, config);
52    Ok(compiler.compile())
53}
54
55pub fn write_rust_crate(
56    output: &RustOutput,
57    crate_dir: &std::path::Path,
58) -> Result<(), std::io::Error> {
59    std::fs::create_dir_all(crate_dir.join("src"))?;
60    std::fs::write(crate_dir.join("Cargo.toml"), &output.cargo_toml)?;
61    std::fs::write(crate_dir.join("src/lib.rs"), &output.lib_rs)?;
62    std::fs::write(crate_dir.join("src/types.rs"), &output.types_rs)?;
63    std::fs::write(crate_dir.join("src/entity.rs"), &output.entity_rs)?;
64    Ok(())
65}
66
67pub fn write_rust_module(
68    output: &RustOutput,
69    module_dir: &std::path::Path,
70) -> Result<(), std::io::Error> {
71    std::fs::create_dir_all(module_dir)?;
72    std::fs::write(module_dir.join("mod.rs"), output.mod_rs())?;
73    std::fs::write(module_dir.join("types.rs"), &output.types_rs)?;
74    std::fs::write(module_dir.join("entity.rs"), &output.entity_rs)?;
75    Ok(())
76}
77
78struct RustCompiler {
79    spec: SerializableStreamSpec,
80    entity_name: String,
81    config: RustConfig,
82}
83
84impl RustCompiler {
85    fn new(spec: SerializableStreamSpec, entity_name: String, config: RustConfig) -> Self {
86        Self {
87            spec,
88            entity_name,
89            config,
90        }
91    }
92
93    fn compile(&self) -> RustOutput {
94        RustOutput {
95            cargo_toml: self.generate_cargo_toml(),
96            lib_rs: self.generate_lib_rs(),
97            types_rs: self.generate_types_rs(),
98            entity_rs: self.generate_entity_rs(),
99        }
100    }
101
102    fn generate_cargo_toml(&self) -> String {
103        format!(
104            r#"[package]
105name = "{}"
106version = "0.1.0"
107edition = "2021"
108
109[dependencies]
110hyperstack-sdk = "{}"
111serde = {{ version = "1", features = ["derive"] }}
112serde_json = "1"
113"#,
114            self.config.crate_name, self.config.sdk_version
115        )
116    }
117
118    fn generate_lib_rs(&self) -> String {
119        let stack_name = self.derive_stack_name();
120        let entity_name = &self.entity_name;
121        
122        format!(
123            r#"mod entity;
124mod types;
125
126pub use entity::{{{stack_name}Stack, {stack_name}StackViews, {entity_name}EntityViews}};
127pub use types::*;
128
129pub use hyperstack_sdk::{{ConnectionState, HyperStack, Stack, Update, Views}};
130"#,
131            stack_name = stack_name,
132            entity_name = entity_name
133        )
134    }
135
136    fn generate_types_rs(&self) -> String {
137        let mut output = String::new();
138        output.push_str("use serde::{Deserialize, Serialize};\n\n");
139
140        let mut generated = HashSet::new();
141
142        for section in &self.spec.sections {
143            if !Self::is_root_section(&section.name) && generated.insert(section.name.clone()) {
144                output.push_str(&self.generate_struct_for_section(section));
145                output.push_str("\n\n");
146            }
147        }
148
149        output.push_str(&self.generate_main_entity_struct());
150        output.push_str(&self.generate_resolved_types(&mut generated));
151        output.push_str(&self.generate_event_wrapper());
152
153        output
154    }
155
156    fn generate_struct_for_section(&self, section: &EntitySection) -> String {
157        let struct_name = format!("{}{}", self.entity_name, to_pascal_case(&section.name));
158        let mut fields = Vec::new();
159
160        for field in &section.fields {
161            let field_name = to_snake_case(&field.field_name);
162            let rust_type = self.field_type_to_rust(field);
163
164            fields.push(format!(
165                "    #[serde(default)]\n    pub {}: {},",
166                field_name, rust_type
167            ));
168        }
169
170        format!(
171            "#[derive(Debug, Clone, Serialize, Deserialize, Default)]\npub struct {} {{\n{}\n}}",
172            struct_name,
173            fields.join("\n")
174        )
175    }
176
177    /// Check if a section name is the root section (case-insensitive)
178    fn is_root_section(name: &str) -> bool {
179        name.eq_ignore_ascii_case("root")
180    }
181
182    fn generate_main_entity_struct(&self) -> String {
183        let mut fields = Vec::new();
184
185        for section in &self.spec.sections {
186            if !Self::is_root_section(&section.name) {
187                let field_name = to_snake_case(&section.name);
188                let type_name = format!("{}{}", self.entity_name, to_pascal_case(&section.name));
189                fields.push(format!(
190                    "    #[serde(default)]\n    pub {}: {},",
191                    field_name, type_name
192                ));
193            }
194        }
195
196        for section in &self.spec.sections {
197            if Self::is_root_section(&section.name) {
198                for field in &section.fields {
199                    let field_name = to_snake_case(&field.field_name);
200                    let rust_type = self.field_type_to_rust(field);
201                    fields.push(format!(
202                        "    #[serde(default)]\n    pub {}: {},",
203                        field_name, rust_type
204                    ));
205                }
206            }
207        }
208
209        format!(
210            "#[derive(Debug, Clone, Serialize, Deserialize, Default)]\npub struct {} {{\n{}\n}}",
211            self.entity_name,
212            fields.join("\n")
213        )
214    }
215
216    fn generate_resolved_types(&self, generated: &mut HashSet<String>) -> String {
217        let mut output = String::new();
218
219        for section in &self.spec.sections {
220            for field in &section.fields {
221                if let Some(resolved) = &field.resolved_type {
222                    if generated.insert(resolved.type_name.clone()) {
223                        output.push_str("\n\n");
224                        output.push_str(&self.generate_resolved_struct(resolved));
225                    }
226                }
227            }
228        }
229
230        output
231    }
232
233    fn generate_resolved_struct(&self, resolved: &ResolvedStructType) -> String {
234        if resolved.is_enum {
235            let variants: Vec<String> = resolved
236                .enum_variants
237                .iter()
238                .map(|v| format!("    {},", to_pascal_case(v)))
239                .collect();
240
241            format!(
242                "#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]\npub enum {} {{\n{}\n}}",
243                to_pascal_case(&resolved.type_name),
244                variants.join("\n")
245            )
246        } else {
247            let fields: Vec<String> = resolved
248                .fields
249                .iter()
250                .map(|f| {
251                    let rust_type = self.resolved_field_to_rust(f);
252                    format!(
253                        "    #[serde(default)]\n    pub {}: {},",
254                        to_snake_case(&f.field_name),
255                        rust_type
256                    )
257                })
258                .collect();
259
260            format!(
261                "#[derive(Debug, Clone, Serialize, Deserialize, Default)]\npub struct {} {{\n{}\n}}",
262                to_pascal_case(&resolved.type_name),
263                fields.join("\n")
264            )
265        }
266    }
267
268    fn generate_event_wrapper(&self) -> String {
269        r#"
270
271#[derive(Debug, Clone, Serialize, Deserialize)]
272pub struct EventWrapper<T> {
273    #[serde(default)]
274    pub timestamp: i64,
275    pub data: T,
276    #[serde(default)]
277    pub slot: Option<f64>,
278    #[serde(default)]
279    pub signature: Option<String>,
280}
281
282impl<T: Default> Default for EventWrapper<T> {
283    fn default() -> Self {
284        Self {
285            timestamp: 0,
286            data: T::default(),
287            slot: None,
288            signature: None,
289        }
290    }
291}
292"#
293        .to_string()
294    }
295
296    fn generate_entity_rs(&self) -> String {
297        let entity_name = &self.entity_name;
298        let stack_name = self.derive_stack_name();
299        let stack_name_kebab = to_kebab_case(entity_name);
300        let entity_snake = to_snake_case(entity_name);
301        
302        let types_import = if self.config.module_mode {
303            "super::types"
304        } else {
305            "crate::types"
306        };
307
308        // Generate URL line - either actual URL or placeholder comment
309        let url_impl = match &self.config.url {
310            Some(url) => format!(r#"fn url() -> &'static str {{
311        "{}"
312    }}"#, url),
313            None => r#"fn url() -> &'static str {
314        "" // TODO: Set URL after first deployment in hyperstack.toml
315    }"#.to_string(),
316        };
317
318        let entity_views = self.generate_entity_views_struct();
319
320        format!(
321            r#"use {types_import}::{entity_name};
322use hyperstack_sdk::{{Stack, StateView, ViewBuilder, ViewHandle, Views}};
323
324pub struct {stack_name}Stack;
325
326impl Stack for {stack_name}Stack {{
327    type Views = {stack_name}StackViews;
328
329    fn name() -> &'static str {{
330        "{stack_name_kebab}"
331    }}
332
333    {url_impl}
334}}
335
336pub struct {stack_name}StackViews {{
337    pub {entity_snake}: {entity_name}EntityViews,
338}}
339
340impl Views for {stack_name}StackViews {{
341    fn from_builder(builder: ViewBuilder) -> Self {{
342        Self {{
343            {entity_snake}: {entity_name}EntityViews {{ builder }},
344        }}
345    }}
346}}
347{entity_views}"#,
348            types_import = types_import,
349            entity_name = entity_name,
350            stack_name = stack_name,
351            stack_name_kebab = stack_name_kebab,
352            entity_snake = entity_snake,
353            url_impl = url_impl,
354            entity_views = entity_views
355        )
356    }
357
358    fn generate_entity_views_struct(&self) -> String {
359        let entity_name = &self.entity_name;
360
361        let derived: Vec<_> = self
362            .spec
363            .views
364            .iter()
365            .filter(|v| {
366                !v.id.ends_with("/state")
367                    && !v.id.ends_with("/list")
368                    && v.id.starts_with(entity_name)
369            })
370            .collect();
371
372        let mut derived_methods = String::new();
373        for view in &derived {
374            let view_name = view.id.split('/').nth(1).unwrap_or("unknown");
375            let method_name = to_snake_case(view_name);
376
377            derived_methods.push_str(&format!(
378                r#"
379    pub fn {method_name}(&self) -> ViewHandle<{entity_name}> {{
380        self.builder.view("{view_id}")
381    }}
382"#,
383                method_name = method_name,
384                entity_name = entity_name,
385                view_id = view.id
386            ));
387        }
388
389        format!(
390            r#"
391pub struct {entity_name}EntityViews {{
392    builder: ViewBuilder,
393}}
394
395impl {entity_name}EntityViews {{
396    pub fn state(&self) -> StateView<{entity_name}> {{
397        StateView::new(
398            self.builder.connection().clone(),
399            self.builder.store().clone(),
400            "{entity_name}/state".to_string(),
401            self.builder.initial_data_timeout(),
402        )
403    }}
404
405    pub fn list(&self) -> ViewHandle<{entity_name}> {{
406        self.builder.view("{entity_name}/list")
407    }}
408{derived_methods}}}"#,
409            entity_name = entity_name,
410            derived_methods = derived_methods
411        )
412    }
413
414    /// Derive stack name from entity name.
415    /// E.g., "OreRound" -> "Ore", "PumpfunToken" -> "Pumpfun"
416    fn derive_stack_name(&self) -> String {
417        let entity_name = &self.entity_name;
418        
419        // Common suffixes to strip
420        let suffixes = ["Round", "Token", "Game", "State", "Entity", "Data"];
421        
422        for suffix in suffixes {
423            if entity_name.ends_with(suffix) && entity_name.len() > suffix.len() {
424                return entity_name[..entity_name.len() - suffix.len()].to_string();
425            }
426        }
427        
428        // If no suffix matched, use the full entity name
429        entity_name.clone()
430    }
431
432    /// Generate Rust type for a field.
433    ///
434    /// All fields are wrapped in Option<T> because we receive partial patches,
435    /// so any field may not yet be present.
436    ///
437    /// - Non-optional spec fields become `Option<T>`:
438    ///   - `None` = not yet received in any patch
439    ///   - `Some(value)` = has value
440    ///
441    /// - Optional spec fields become `Option<Option<T>>`:
442    ///   - `None` = not yet received in any patch
443    ///   - `Some(None)` = explicitly set to null
444    ///   - `Some(Some(value))` = has value
445    fn field_type_to_rust(&self, field: &FieldTypeInfo) -> String {
446        let base = self.base_type_to_rust(&field.base_type, &field.rust_type_name);
447
448        let typed = if field.is_array && !matches!(field.base_type, BaseType::Array) {
449            format!("Vec<{}>", base)
450        } else {
451            base
452        };
453
454        // All fields wrapped in Option since we receive patches
455        // Optional spec fields get Option<Option<T>> to distinguish "not received" from "explicitly null"
456        if field.is_optional {
457            format!("Option<Option<{}>>", typed)
458        } else {
459            format!("Option<{}>", typed)
460        }
461    }
462
463    fn base_type_to_rust(&self, base_type: &BaseType, rust_type_name: &str) -> String {
464        match base_type {
465            BaseType::Integer => {
466                if rust_type_name.contains("u64") {
467                    "u64".to_string()
468                } else if rust_type_name.contains("i64") {
469                    "i64".to_string()
470                } else if rust_type_name.contains("u32") {
471                    "u32".to_string()
472                } else if rust_type_name.contains("i32") {
473                    "i32".to_string()
474                } else {
475                    "i64".to_string()
476                }
477            }
478            BaseType::Float => "f64".to_string(),
479            BaseType::String => "String".to_string(),
480            BaseType::Boolean => "bool".to_string(),
481            BaseType::Timestamp => "i64".to_string(),
482            BaseType::Binary => "Vec<u8>".to_string(),
483            BaseType::Pubkey => "String".to_string(),
484            BaseType::Array => "Vec<serde_json::Value>".to_string(),
485            BaseType::Object => "serde_json::Value".to_string(),
486            BaseType::Any => "serde_json::Value".to_string(),
487        }
488    }
489
490    fn resolved_field_to_rust(&self, field: &ResolvedField) -> String {
491        let base = self.base_type_to_rust(&field.base_type, &field.field_type);
492
493        let typed = if field.is_array {
494            format!("Vec<{}>", base)
495        } else {
496            base
497        };
498
499        if field.is_optional {
500            format!("Option<Option<{}>>", typed)
501        } else {
502            format!("Option<{}>", typed)
503        }
504    }
505}
506
507fn to_kebab_case(s: &str) -> String {
508    let mut result = String::new();
509    for (i, c) in s.chars().enumerate() {
510        if c.is_uppercase() {
511            if i > 0 {
512                result.push('-');
513            }
514            result.push(c.to_lowercase().next().unwrap());
515        } else {
516            result.push(c);
517        }
518    }
519    result
520}
521
522fn to_pascal_case(s: &str) -> String {
523    s.split(['_', '-', '.'])
524        .map(|word| {
525            let mut chars = word.chars();
526            match chars.next() {
527                None => String::new(),
528                Some(first) => first.to_uppercase().collect::<String>() + chars.as_str(),
529            }
530        })
531        .collect()
532}
533
534fn to_snake_case(s: &str) -> String {
535    let mut result = String::new();
536    for (i, ch) in s.chars().enumerate() {
537        if ch.is_uppercase() {
538            if i > 0 {
539                result.push('_');
540            }
541            result.push(ch.to_lowercase().next().unwrap());
542        } else {
543            result.push(ch);
544        }
545    }
546    result
547}