// hyperstack_interpreter/rust.rs

1use crate::ast::*;
2use std::collections::HashSet;
3
/// Source text of every file in a generated Rust SDK crate.
///
/// Each field holds the complete contents of one file; `write_rust_crate`
/// places them on disk.
#[derive(Clone, Debug)]
pub struct RustOutput {
    /// Contents of `Cargo.toml`.
    pub cargo_toml: String,
    /// Contents of `src/lib.rs`.
    pub lib_rs: String,
    /// Contents of `src/types.rs`.
    pub types_rs: String,
    /// Contents of `src/entity.rs`.
    pub entity_rs: String,
}
11
12impl RustOutput {
13    pub fn full_lib(&self) -> String {
14        format!(
15            "{}\n\n// types.rs\n{}\n\n// entity.rs\n{}",
16            self.lib_rs, self.types_rs, self.entity_rs
17        )
18    }
19}
20
/// Options that control the generated crate's `Cargo.toml`.
#[derive(Clone, Debug)]
pub struct RustConfig {
    /// Value written to the `[package] name` key.
    pub crate_name: String,
    /// Version requirement written for the `hyperstack-sdk` dependency.
    pub sdk_version: String,
}
26
27impl Default for RustConfig {
28    fn default() -> Self {
29        Self {
30            crate_name: "generated-stack".to_string(),
31            sdk_version: "0.2".to_string(),
32        }
33    }
34}
35
36pub fn compile_serializable_spec(
37    spec: SerializableStreamSpec,
38    entity_name: String,
39    config: Option<RustConfig>,
40) -> Result<RustOutput, String> {
41    let config = config.unwrap_or_default();
42    let compiler = RustCompiler::new(spec, entity_name, config);
43    Ok(compiler.compile())
44}
45
46pub fn write_rust_crate(
47    output: &RustOutput,
48    crate_dir: &std::path::Path,
49) -> Result<(), std::io::Error> {
50    std::fs::create_dir_all(crate_dir.join("src"))?;
51    std::fs::write(crate_dir.join("Cargo.toml"), &output.cargo_toml)?;
52    std::fs::write(crate_dir.join("src/lib.rs"), &output.lib_rs)?;
53    std::fs::write(crate_dir.join("src/types.rs"), &output.types_rs)?;
54    std::fs::write(crate_dir.join("src/entity.rs"), &output.entity_rs)?;
55    Ok(())
56}
57
58struct RustCompiler {
59    spec: SerializableStreamSpec,
60    entity_name: String,
61    config: RustConfig,
62}
63
64impl RustCompiler {
65    fn new(spec: SerializableStreamSpec, entity_name: String, config: RustConfig) -> Self {
66        Self {
67            spec,
68            entity_name,
69            config,
70        }
71    }
72
73    fn compile(&self) -> RustOutput {
74        RustOutput {
75            cargo_toml: self.generate_cargo_toml(),
76            lib_rs: self.generate_lib_rs(),
77            types_rs: self.generate_types_rs(),
78            entity_rs: self.generate_entity_rs(),
79        }
80    }
81
82    fn generate_cargo_toml(&self) -> String {
83        format!(
84            r#"[package]
85name = "{}"
86version = "0.1.0"
87edition = "2021"
88
89[dependencies]
90hyperstack-sdk = "{}"
91serde = {{ version = "1", features = ["derive"] }}
92serde_json = "1"
93"#,
94            self.config.crate_name, self.config.sdk_version
95        )
96    }
97
98    fn generate_lib_rs(&self) -> String {
99        format!(
100            r#"mod types;
101mod entity;
102
103pub use types::*;
104pub use entity::{entity_name}Entity;
105
106pub use hyperstack_sdk::{{HyperStack, Entity, Update, ConnectionState}};
107"#,
108            entity_name = self.entity_name
109        )
110    }
111
112    fn generate_types_rs(&self) -> String {
113        let mut output = String::new();
114        output.push_str("use serde::{Deserialize, Deserializer, Serialize};\n\n");
115        output.push_str(&self.generate_serde_helpers());
116
117        let mut generated = HashSet::new();
118
119        for section in &self.spec.sections {
120            if !Self::is_root_section(&section.name) && generated.insert(section.name.clone()) {
121                output.push_str(&self.generate_struct_for_section(section));
122                output.push_str("\n\n");
123            }
124        }
125
126        output.push_str(&self.generate_main_entity_struct());
127        output.push_str(&self.generate_resolved_types(&mut generated));
128        output.push_str(&self.generate_event_wrapper());
129
130        output
131    }
132
133    fn generate_struct_for_section(&self, section: &EntitySection) -> String {
134        let struct_name = format!("{}{}", self.entity_name, to_pascal_case(&section.name));
135        let mut fields = Vec::new();
136
137        for field in &section.fields {
138            let field_name = to_snake_case(&field.field_name);
139            let rust_type = self.field_type_to_rust(field);
140
141            let serde_attr = if field_name != to_snake_case(&field.field_name)
142                || field_name != field.field_name
143            {
144                let original = &field.field_name;
145                if to_snake_case(original) != *original {
146                    format!(
147                        "    #[serde(rename = \"{}\", default)]\n",
148                        to_camel_case(original)
149                    )
150                } else {
151                    "    #[serde(default)]\n".to_string()
152                }
153            } else {
154                "    #[serde(default)]\n".to_string()
155            };
156
157            fields.push(format!(
158                "{}    pub {}: {},",
159                serde_attr,
160                to_snake_case(&field.field_name),
161                rust_type
162            ));
163        }
164
165        format!(
166            "#[derive(Debug, Clone, Serialize, Deserialize, Default)]\npub struct {} {{\n{}\n}}",
167            struct_name,
168            fields.join("\n")
169        )
170    }
171
172    /// Check if a section name is the root section (case-insensitive)
173    fn is_root_section(name: &str) -> bool {
174        name.eq_ignore_ascii_case("root")
175    }
176
177    fn generate_main_entity_struct(&self) -> String {
178        let mut fields = Vec::new();
179
180        for section in &self.spec.sections {
181            if !Self::is_root_section(&section.name) {
182                let field_name = to_snake_case(&section.name);
183                let type_name = format!("{}{}", self.entity_name, to_pascal_case(&section.name));
184                let serde_attr = if field_name != section.name {
185                    format!(
186                        "    #[serde(rename = \"{}\", default)]\n",
187                        to_camel_case(&section.name)
188                    )
189                } else {
190                    "    #[serde(default)]\n".to_string()
191                };
192                fields.push(format!(
193                    "{}    pub {}: {},",
194                    serde_attr, field_name, type_name
195                ));
196            }
197        }
198
199        for section in &self.spec.sections {
200            if Self::is_root_section(&section.name) {
201                for field in &section.fields {
202                    let field_name = to_snake_case(&field.field_name);
203                    let rust_type = self.field_type_to_rust(field);
204                    fields.push(format!(
205                        "    #[serde(default)]\n    pub {}: {},",
206                        field_name, rust_type
207                    ));
208                }
209            }
210        }
211
212        format!(
213            "#[derive(Debug, Clone, Serialize, Deserialize, Default)]\npub struct {} {{\n{}\n}}",
214            self.entity_name,
215            fields.join("\n")
216        )
217    }
218
219    fn generate_resolved_types(&self, generated: &mut HashSet<String>) -> String {
220        let mut output = String::new();
221
222        for section in &self.spec.sections {
223            for field in &section.fields {
224                if let Some(resolved) = &field.resolved_type {
225                    if generated.insert(resolved.type_name.clone()) {
226                        output.push_str("\n\n");
227                        output.push_str(&self.generate_resolved_struct(resolved));
228                    }
229                }
230            }
231        }
232
233        output
234    }
235
236    fn generate_resolved_struct(&self, resolved: &ResolvedStructType) -> String {
237        if resolved.is_enum {
238            let variants: Vec<String> = resolved
239                .enum_variants
240                .iter()
241                .map(|v| format!("    {},", to_pascal_case(v)))
242                .collect();
243
244            format!(
245                "#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]\npub enum {} {{\n{}\n}}",
246                to_pascal_case(&resolved.type_name),
247                variants.join("\n")
248            )
249        } else {
250            let fields: Vec<String> = resolved
251                .fields
252                .iter()
253                .map(|f| {
254                    let rust_type = self.resolved_field_to_rust(f);
255                    let serde_attr = format!(
256                        "    #[serde(rename = \"{}\", default)]\n",
257                        to_camel_case(&f.field_name)
258                    );
259                    format!(
260                        "{}    pub {}: {},",
261                        serde_attr,
262                        to_snake_case(&f.field_name),
263                        rust_type
264                    )
265                })
266                .collect();
267
268            format!(
269                "#[derive(Debug, Clone, Serialize, Deserialize, Default)]\npub struct {} {{\n{}\n}}",
270                to_pascal_case(&resolved.type_name),
271                fields.join("\n")
272            )
273        }
274    }
275
276    fn generate_event_wrapper(&self) -> String {
277        r#"
278
279#[derive(Debug, Clone, Serialize, Deserialize)]
280pub struct EventWrapper<T> {
281    #[serde(default)]
282    pub timestamp: i64,
283    pub data: T,
284    #[serde(default)]
285    pub slot: Option<f64>,
286    #[serde(default)]
287    pub signature: Option<String>,
288}
289
290impl<T: Default> Default for EventWrapper<T> {
291    fn default() -> Self {
292        Self {
293            timestamp: 0,
294            data: T::default(),
295            slot: None,
296            signature: None,
297        }
298    }
299}
300"#
301        .to_string()
302    }
303
304    fn generate_serde_helpers(&self) -> String {
305        r#"mod serde_helpers {
306    use serde::{Deserialize, Deserializer};
307
308    pub fn deserialize_number_from_any<'de, D>(deserializer: D) -> Result<Option<f64>, D::Error>
309    where
310        D: Deserializer<'de>,
311    {
312        #[derive(Deserialize)]
313        #[serde(untagged)]
314        enum NumOrNull {
315            Num(f64),
316            Null,
317        }
318        match NumOrNull::deserialize(deserializer)? {
319            NumOrNull::Num(n) => Ok(Some(n)),
320            NumOrNull::Null => Ok(None),
321        }
322    }
323}
324
325"#
326        .to_string()
327    }
328
329    fn generate_entity_rs(&self) -> String {
330        let entity_name = &self.entity_name;
331
332        format!(
333            r#"use hyperstack_sdk::Entity;
334use crate::types::{entity_name};
335
336pub struct {entity_name}Entity;
337
338impl Entity for {entity_name}Entity {{
339    type Data = {entity_name};
340    
341    const NAME: &'static str = "{entity_name}";
342    
343    fn state_view() -> &'static str {{
344        "{entity_name}/state"
345    }}
346    
347    fn list_view() -> &'static str {{
348        "{entity_name}/list"
349    }}
350}}
351"#,
352            entity_name = entity_name
353        )
354    }
355
356    /// Generate Rust type for a field.
357    ///
358    /// All fields are wrapped in Option<T> because we receive partial patches,
359    /// so any field may not yet be present.
360    ///
361    /// - Non-optional spec fields become `Option<T>`:
362    ///   - `None` = not yet received in any patch
363    ///   - `Some(value)` = has value
364    ///
365    /// - Optional spec fields become `Option<Option<T>>`:
366    ///   - `None` = not yet received in any patch
367    ///   - `Some(None)` = explicitly set to null
368    ///   - `Some(Some(value))` = has value
369    fn field_type_to_rust(&self, field: &FieldTypeInfo) -> String {
370        let base = self.base_type_to_rust(&field.base_type, &field.rust_type_name);
371
372        let typed = if field.is_array && !matches!(field.base_type, BaseType::Array) {
373            format!("Vec<{}>", base)
374        } else {
375            base
376        };
377
378        // All fields wrapped in Option since we receive patches
379        // Optional spec fields get Option<Option<T>> to distinguish "not received" from "explicitly null"
380        if field.is_optional {
381            format!("Option<Option<{}>>", typed)
382        } else {
383            format!("Option<{}>", typed)
384        }
385    }
386
387    fn base_type_to_rust(&self, base_type: &BaseType, rust_type_name: &str) -> String {
388        match base_type {
389            BaseType::Integer => {
390                if rust_type_name.contains("u64") {
391                    "u64".to_string()
392                } else if rust_type_name.contains("i64") {
393                    "i64".to_string()
394                } else if rust_type_name.contains("u32") {
395                    "u32".to_string()
396                } else if rust_type_name.contains("i32") {
397                    "i32".to_string()
398                } else {
399                    "i64".to_string()
400                }
401            }
402            BaseType::Float => "f64".to_string(),
403            BaseType::String => "String".to_string(),
404            BaseType::Boolean => "bool".to_string(),
405            BaseType::Timestamp => "i64".to_string(),
406            BaseType::Binary => "Vec<u8>".to_string(),
407            BaseType::Pubkey => "String".to_string(),
408            BaseType::Array => "Vec<serde_json::Value>".to_string(),
409            BaseType::Object => "serde_json::Value".to_string(),
410            BaseType::Any => "serde_json::Value".to_string(),
411        }
412    }
413
414    fn resolved_field_to_rust(&self, field: &ResolvedField) -> String {
415        let base = self.base_type_to_rust(&field.base_type, &field.field_type);
416
417        let typed = if field.is_array {
418            format!("Vec<{}>", base)
419        } else {
420            base
421        };
422
423        if field.is_optional {
424            format!("Option<Option<{}>>", typed)
425        } else {
426            format!("Option<{}>", typed)
427        }
428    }
429}
430
/// Converts `s` to PascalCase.
///
/// The input is split on `_`, `-` and `.`. In every non-empty piece the first
/// character is uppercased and the rest is kept unchanged, so both `"token_info"`
/// and `"tokenInfo"` become `"TokenInfo"`.
fn to_pascal_case(s: &str) -> String {
    let mut out = String::with_capacity(s.len());
    for word in s.split(|c| c == '_' || c == '-' || c == '.') {
        let mut rest = word.chars();
        if let Some(head) = rest.next() {
            out.extend(head.to_uppercase());
            out.push_str(rest.as_str());
        }
    }
    out
}
442
/// Converts `s` to snake_case for use as a Rust field identifier.
///
/// - An uppercase letter starts a new word when it follows a lowercase letter or a
///   digit (`"totalSupply"` → `"total_supply"`). It also starts a new word when it
///   ends an acronym and is followed by a lowercase letter (`"HTTPServer"` →
///   `"http_server"`). Capitals inside an acronym stay together (`"userID"` →
///   `"user_id"`).
/// - `-` and `.` become `_`, matching the separators `to_pascal_case` splits on.
/// - Existing underscores are kept verbatim, so an already snake_case name maps to
///   itself, and no extra `_` is added right after one (`"foo_Bar"` → `"foo_bar"`).
/// - Lowercasing uses the full `char::to_lowercase` mapping, including characters
///   that lowercase to more than one `char`.
fn to_snake_case(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut result = String::with_capacity(s.len() + 4);

    for (i, &ch) in chars.iter().enumerate() {
        match ch {
            '-' | '.' => result.push('_'),
            c if c.is_uppercase() => {
                let prev = if i > 0 { Some(chars[i - 1]) } else { None };
                let next_is_lower = matches!(chars.get(i + 1), Some(n) if n.is_lowercase());
                let starts_word = match prev {
                    Some(p) if p.is_lowercase() || p.is_ascii_digit() => true,
                    // Inside an acronym only the last capital (before a lowercase
                    // letter) begins a new word.
                    Some(p) if p.is_uppercase() => next_is_lower,
                    // Start of input or right after a separator: no extra `_`.
                    _ => false,
                };
                if starts_word {
                    result.push('_');
                }
                result.extend(c.to_lowercase());
            }
            c => result.push(c),
        }
    }

    result
}
457
458fn to_camel_case(s: &str) -> String {
459    let pascal = to_pascal_case(s);
460    let mut chars = pascal.chars();
461    match chars.next() {
462        None => String::new(),
463        Some(first) => first.to_lowercase().collect::<String>() + chars.as_str(),
464    }
465}