hyperstack_interpreter/
rust.rs

1use crate::ast::*;
2use std::collections::HashSet;
3
/// Generated source files for a single entity SDK.
///
/// Each field holds the complete text of one output file.
#[derive(Debug, Clone)]
pub struct RustOutput {
    /// Contents of `Cargo.toml` (only written by `write_rust_crate`).
    pub cargo_toml: String,
    /// Contents of `lib.rs`; reused verbatim as `mod.rs` in module layout.
    pub lib_rs: String,
    /// Contents of `types.rs` (the serde data types).
    pub types_rs: String,
    /// Contents of `entity.rs` (the `Entity` trait implementation).
    pub entity_rs: String,
}

impl RustOutput {
    /// Joins `lib.rs`, `types.rs` and `entity.rs` into one string, inserting a
    /// `// types.rs` / `// entity.rs` marker comment before the latter two.
    pub fn full_lib(&self) -> String {
        [
            self.lib_rs.as_str(),
            "\n\n// types.rs\n",
            self.types_rs.as_str(),
            "\n\n// entity.rs\n",
            self.entity_rs.as_str(),
        ]
        .concat()
    }

    /// Returns the text for `mod.rs` in module layout, which is identical to `lib.rs`.
    pub fn mod_rs(&self) -> String {
        self.lib_rs.to_owned()
    }
}
24
/// Options controlling how the Rust SDK is generated.
#[derive(Debug, Clone)]
pub struct RustConfig {
    /// Package name written into `Cargo.toml`.
    pub crate_name: String,
    /// Version requirement written for the `hyperstack-sdk` dependency.
    pub sdk_version: String,
    /// When true, `entity.rs` imports from `super::types` (module layout)
    /// instead of `crate::types` (standalone crate layout).
    pub module_mode: bool,
}

impl Default for RustConfig {
    /// Crate `generated-stack`, SDK version `0.2`, standalone-crate layout.
    fn default() -> Self {
        let crate_name = String::from("generated-stack");
        let sdk_version = String::from("0.2");
        RustConfig {
            crate_name,
            sdk_version,
            module_mode: false,
        }
    }
}
41
42pub fn compile_serializable_spec(
43    spec: SerializableStreamSpec,
44    entity_name: String,
45    config: Option<RustConfig>,
46) -> Result<RustOutput, String> {
47    let config = config.unwrap_or_default();
48    let compiler = RustCompiler::new(spec, entity_name, config);
49    Ok(compiler.compile())
50}
51
52pub fn write_rust_crate(
53    output: &RustOutput,
54    crate_dir: &std::path::Path,
55) -> Result<(), std::io::Error> {
56    std::fs::create_dir_all(crate_dir.join("src"))?;
57    std::fs::write(crate_dir.join("Cargo.toml"), &output.cargo_toml)?;
58    std::fs::write(crate_dir.join("src/lib.rs"), &output.lib_rs)?;
59    std::fs::write(crate_dir.join("src/types.rs"), &output.types_rs)?;
60    std::fs::write(crate_dir.join("src/entity.rs"), &output.entity_rs)?;
61    Ok(())
62}
63
64pub fn write_rust_module(
65    output: &RustOutput,
66    module_dir: &std::path::Path,
67) -> Result<(), std::io::Error> {
68    std::fs::create_dir_all(module_dir)?;
69    std::fs::write(module_dir.join("mod.rs"), output.mod_rs())?;
70    std::fs::write(module_dir.join("types.rs"), &output.types_rs)?;
71    std::fs::write(module_dir.join("entity.rs"), &output.entity_rs)?;
72    Ok(())
73}
74
75struct RustCompiler {
76    spec: SerializableStreamSpec,
77    entity_name: String,
78    config: RustConfig,
79}
80
81impl RustCompiler {
82    fn new(spec: SerializableStreamSpec, entity_name: String, config: RustConfig) -> Self {
83        Self {
84            spec,
85            entity_name,
86            config,
87        }
88    }
89
90    fn compile(&self) -> RustOutput {
91        RustOutput {
92            cargo_toml: self.generate_cargo_toml(),
93            lib_rs: self.generate_lib_rs(),
94            types_rs: self.generate_types_rs(),
95            entity_rs: self.generate_entity_rs(),
96        }
97    }
98
99    fn generate_cargo_toml(&self) -> String {
100        format!(
101            r#"[package]
102name = "{}"
103version = "0.1.0"
104edition = "2021"
105
106[dependencies]
107hyperstack-sdk = "{}"
108serde = {{ version = "1", features = ["derive"] }}
109serde_json = "1"
110"#,
111            self.config.crate_name, self.config.sdk_version
112        )
113    }
114
115    fn generate_lib_rs(&self) -> String {
116        format!(
117            r#"mod types;
118mod entity;
119
120pub use types::*;
121pub use entity::{entity_name}Entity;
122
123pub use hyperstack_sdk::{{HyperStack, Entity, Update, ConnectionState}};
124"#,
125            entity_name = self.entity_name
126        )
127    }
128
129    fn generate_types_rs(&self) -> String {
130        let mut output = String::new();
131        output.push_str("use serde::{Deserialize, Serialize};\n\n");
132
133        let mut generated = HashSet::new();
134
135        for section in &self.spec.sections {
136            if !Self::is_root_section(&section.name) && generated.insert(section.name.clone()) {
137                output.push_str(&self.generate_struct_for_section(section));
138                output.push_str("\n\n");
139            }
140        }
141
142        output.push_str(&self.generate_main_entity_struct());
143        output.push_str(&self.generate_resolved_types(&mut generated));
144        output.push_str(&self.generate_event_wrapper());
145
146        output
147    }
148
149    fn generate_struct_for_section(&self, section: &EntitySection) -> String {
150        let struct_name = format!("{}{}", self.entity_name, to_pascal_case(&section.name));
151        let mut fields = Vec::new();
152
153        for field in &section.fields {
154            let field_name = to_snake_case(&field.field_name);
155            let rust_type = self.field_type_to_rust(field);
156
157            let serde_attr = if field_name != to_snake_case(&field.field_name)
158                || field_name != field.field_name
159            {
160                let original = &field.field_name;
161                if to_snake_case(original) != *original {
162                    format!(
163                        "    #[serde(rename = \"{}\", default)]\n",
164                        to_camel_case(original)
165                    )
166                } else {
167                    "    #[serde(default)]\n".to_string()
168                }
169            } else {
170                "    #[serde(default)]\n".to_string()
171            };
172
173            fields.push(format!(
174                "{}    pub {}: {},",
175                serde_attr,
176                to_snake_case(&field.field_name),
177                rust_type
178            ));
179        }
180
181        format!(
182            "#[derive(Debug, Clone, Serialize, Deserialize, Default)]\npub struct {} {{\n{}\n}}",
183            struct_name,
184            fields.join("\n")
185        )
186    }
187
188    /// Check if a section name is the root section (case-insensitive)
189    fn is_root_section(name: &str) -> bool {
190        name.eq_ignore_ascii_case("root")
191    }
192
193    fn generate_main_entity_struct(&self) -> String {
194        let mut fields = Vec::new();
195
196        for section in &self.spec.sections {
197            if !Self::is_root_section(&section.name) {
198                let field_name = to_snake_case(&section.name);
199                let type_name = format!("{}{}", self.entity_name, to_pascal_case(&section.name));
200                let serde_attr = if field_name != section.name {
201                    format!(
202                        "    #[serde(rename = \"{}\", default)]\n",
203                        to_camel_case(&section.name)
204                    )
205                } else {
206                    "    #[serde(default)]\n".to_string()
207                };
208                fields.push(format!(
209                    "{}    pub {}: {},",
210                    serde_attr, field_name, type_name
211                ));
212            }
213        }
214
215        for section in &self.spec.sections {
216            if Self::is_root_section(&section.name) {
217                for field in &section.fields {
218                    let field_name = to_snake_case(&field.field_name);
219                    let rust_type = self.field_type_to_rust(field);
220                    fields.push(format!(
221                        "    #[serde(default)]\n    pub {}: {},",
222                        field_name, rust_type
223                    ));
224                }
225            }
226        }
227
228        format!(
229            "#[derive(Debug, Clone, Serialize, Deserialize, Default)]\npub struct {} {{\n{}\n}}",
230            self.entity_name,
231            fields.join("\n")
232        )
233    }
234
235    fn generate_resolved_types(&self, generated: &mut HashSet<String>) -> String {
236        let mut output = String::new();
237
238        for section in &self.spec.sections {
239            for field in &section.fields {
240                if let Some(resolved) = &field.resolved_type {
241                    if generated.insert(resolved.type_name.clone()) {
242                        output.push_str("\n\n");
243                        output.push_str(&self.generate_resolved_struct(resolved));
244                    }
245                }
246            }
247        }
248
249        output
250    }
251
252    fn generate_resolved_struct(&self, resolved: &ResolvedStructType) -> String {
253        if resolved.is_enum {
254            let variants: Vec<String> = resolved
255                .enum_variants
256                .iter()
257                .map(|v| format!("    {},", to_pascal_case(v)))
258                .collect();
259
260            format!(
261                "#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]\npub enum {} {{\n{}\n}}",
262                to_pascal_case(&resolved.type_name),
263                variants.join("\n")
264            )
265        } else {
266            let fields: Vec<String> = resolved
267                .fields
268                .iter()
269                .map(|f| {
270                    let rust_type = self.resolved_field_to_rust(f);
271                    let serde_attr = format!(
272                        "    #[serde(rename = \"{}\", default)]\n",
273                        to_camel_case(&f.field_name)
274                    );
275                    format!(
276                        "{}    pub {}: {},",
277                        serde_attr,
278                        to_snake_case(&f.field_name),
279                        rust_type
280                    )
281                })
282                .collect();
283
284            format!(
285                "#[derive(Debug, Clone, Serialize, Deserialize, Default)]\npub struct {} {{\n{}\n}}",
286                to_pascal_case(&resolved.type_name),
287                fields.join("\n")
288            )
289        }
290    }
291
292    fn generate_event_wrapper(&self) -> String {
293        r#"
294
295#[derive(Debug, Clone, Serialize, Deserialize)]
296pub struct EventWrapper<T> {
297    #[serde(default)]
298    pub timestamp: i64,
299    pub data: T,
300    #[serde(default)]
301    pub slot: Option<f64>,
302    #[serde(default)]
303    pub signature: Option<String>,
304}
305
306impl<T: Default> Default for EventWrapper<T> {
307    fn default() -> Self {
308        Self {
309            timestamp: 0,
310            data: T::default(),
311            slot: None,
312            signature: None,
313        }
314    }
315}
316"#
317        .to_string()
318    }
319
320    fn generate_entity_rs(&self) -> String {
321        let entity_name = &self.entity_name;
322        let types_import = if self.config.module_mode {
323            "super::types"
324        } else {
325            "crate::types"
326        };
327
328        format!(
329            r#"use hyperstack_sdk::Entity;
330use {types_import}::{entity_name};
331
332pub struct {entity_name}Entity;
333
334impl Entity for {entity_name}Entity {{
335    type Data = {entity_name};
336    
337    const NAME: &'static str = "{entity_name}";
338    
339    fn state_view() -> &'static str {{
340        "{entity_name}/state"
341    }}
342    
343    fn list_view() -> &'static str {{
344        "{entity_name}/list"
345    }}
346}}
347"#,
348            types_import = types_import,
349            entity_name = entity_name
350        )
351    }
352
353    /// Generate Rust type for a field.
354    ///
355    /// All fields are wrapped in Option<T> because we receive partial patches,
356    /// so any field may not yet be present.
357    ///
358    /// - Non-optional spec fields become `Option<T>`:
359    ///   - `None` = not yet received in any patch
360    ///   - `Some(value)` = has value
361    ///
362    /// - Optional spec fields become `Option<Option<T>>`:
363    ///   - `None` = not yet received in any patch
364    ///   - `Some(None)` = explicitly set to null
365    ///   - `Some(Some(value))` = has value
366    fn field_type_to_rust(&self, field: &FieldTypeInfo) -> String {
367        let base = self.base_type_to_rust(&field.base_type, &field.rust_type_name);
368
369        let typed = if field.is_array && !matches!(field.base_type, BaseType::Array) {
370            format!("Vec<{}>", base)
371        } else {
372            base
373        };
374
375        // All fields wrapped in Option since we receive patches
376        // Optional spec fields get Option<Option<T>> to distinguish "not received" from "explicitly null"
377        if field.is_optional {
378            format!("Option<Option<{}>>", typed)
379        } else {
380            format!("Option<{}>", typed)
381        }
382    }
383
384    fn base_type_to_rust(&self, base_type: &BaseType, rust_type_name: &str) -> String {
385        match base_type {
386            BaseType::Integer => {
387                if rust_type_name.contains("u64") {
388                    "u64".to_string()
389                } else if rust_type_name.contains("i64") {
390                    "i64".to_string()
391                } else if rust_type_name.contains("u32") {
392                    "u32".to_string()
393                } else if rust_type_name.contains("i32") {
394                    "i32".to_string()
395                } else {
396                    "i64".to_string()
397                }
398            }
399            BaseType::Float => "f64".to_string(),
400            BaseType::String => "String".to_string(),
401            BaseType::Boolean => "bool".to_string(),
402            BaseType::Timestamp => "i64".to_string(),
403            BaseType::Binary => "Vec<u8>".to_string(),
404            BaseType::Pubkey => "String".to_string(),
405            BaseType::Array => "Vec<serde_json::Value>".to_string(),
406            BaseType::Object => "serde_json::Value".to_string(),
407            BaseType::Any => "serde_json::Value".to_string(),
408        }
409    }
410
411    fn resolved_field_to_rust(&self, field: &ResolvedField) -> String {
412        let base = self.base_type_to_rust(&field.base_type, &field.field_type);
413
414        let typed = if field.is_array {
415            format!("Vec<{}>", base)
416        } else {
417            base
418        };
419
420        if field.is_optional {
421            format!("Option<Option<{}>>", typed)
422        } else {
423            format!("Option<{}>", typed)
424        }
425    }
426}
427
/// Converts `s` to PascalCase.
///
/// `_`, `-` and `.` act as word separators and are dropped. The first character
/// of each word is uppercased and the rest are kept as-is, so `tokenMint` becomes
/// `TokenMint` and `trading_stats` becomes `TradingStats`.
fn to_pascal_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    let mut at_word_start = true;

    for ch in s.chars() {
        if matches!(ch, '_' | '-' | '.') {
            at_word_start = true;
        } else if at_word_start {
            out.extend(ch.to_uppercase());
            at_word_start = false;
        } else {
            out.push(ch);
        }
    }

    out
}
439
/// Converts a camelCase / PascalCase identifier to snake_case.
///
/// An underscore is inserted before an uppercase letter unless one of these holds:
/// - it is the first character;
/// - it directly follows an existing `_` (avoids `foo__bar`);
/// - it continues an uppercase run and is not followed by a lowercase letter.
///
/// The last rule keeps acronyms together: `userID` -> `user_id`,
/// `HTTPServer` -> `http_server`. Existing underscores and all other characters
/// are copied unchanged. Uppercase letters use the full Unicode lowercase
/// mapping, which can yield more than one `char`.
fn to_snake_case(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut result = String::with_capacity(s.len() + 4);

    for (i, &ch) in chars.iter().enumerate() {
        if !ch.is_uppercase() {
            result.push(ch);
            continue;
        }

        let prev = i.checked_sub(1).and_then(|p| chars.get(p).copied());
        let needs_separator = match prev {
            None | Some('_') => false,
            // Inside an acronym: only split where it hands off to a lowercase word
            // (the `S` in `HTTPServer`).
            Some(p) if p.is_uppercase() => chars
                .get(i + 1)
                .map_or(false, |next| next.is_lowercase()),
            Some(_) => true,
        };
        if needs_separator {
            result.push('_');
        }
        result.extend(ch.to_lowercase());
    }

    result
}
454
455fn to_camel_case(s: &str) -> String {
456    let pascal = to_pascal_case(s);
457    let mut chars = pascal.chars();
458    match chars.next() {
459        None => String::new(),
460        Some(first) => first.to_lowercase().collect::<String>() + chars.as_str(),
461    }
462}