hyperstack_interpreter/
rust.rs

1use crate::ast::*;
2use std::collections::HashSet;
3
/// Source text of every file produced for one generated Rust SDK.
#[derive(Debug, Clone)]
pub struct RustOutput {
    /// Contents of `Cargo.toml`.
    pub cargo_toml: String,
    /// Contents of `lib.rs`; also returned by [`RustOutput::mod_rs`].
    pub lib_rs: String,
    /// Contents of `types.rs`.
    pub types_rs: String,
    /// Contents of `entity.rs`.
    pub entity_rs: String,
}

impl RustOutput {
    /// Concatenates `lib_rs`, `types_rs` and `entity_rs` into a single string, introducing the
    /// latter two with `// types.rs` and `// entity.rs` marker comments.
    pub fn full_lib(&self) -> String {
        let mut combined = String::with_capacity(
            self.lib_rs.len() + self.types_rs.len() + self.entity_rs.len() + 32,
        );
        combined.push_str(&self.lib_rs);
        combined.push_str("\n\n// types.rs\n");
        combined.push_str(&self.types_rs);
        combined.push_str("\n\n// entity.rs\n");
        combined.push_str(&self.entity_rs);
        combined
    }

    /// Contents for `mod.rs` when the SDK is written as a module; identical to `lib_rs`.
    pub fn mod_rs(&self) -> String {
        self.lib_rs.to_owned()
    }
}
24
/// Options controlling Rust SDK generation.
#[derive(Debug, Clone)]
pub struct RustConfig {
    /// Value written to `name` in the generated `Cargo.toml`.
    pub crate_name: String,
    /// Version requirement written for the `hyperstack-sdk` dependency.
    pub sdk_version: String,
    /// When `true`, the generated `entity.rs` imports types through `super::types`
    /// instead of `crate::types`.
    pub module_mode: bool,
}

impl Default for RustConfig {
    /// Crate `generated-stack`, SDK version `0.2`, crate (not module) layout.
    fn default() -> Self {
        let crate_name = String::from("generated-stack");
        let sdk_version = String::from("0.2");
        RustConfig {
            crate_name,
            sdk_version,
            module_mode: false,
        }
    }
}
41
42pub fn compile_serializable_spec(
43    spec: SerializableStreamSpec,
44    entity_name: String,
45    config: Option<RustConfig>,
46) -> Result<RustOutput, String> {
47    let config = config.unwrap_or_default();
48    let compiler = RustCompiler::new(spec, entity_name, config);
49    Ok(compiler.compile())
50}
51
52pub fn write_rust_crate(
53    output: &RustOutput,
54    crate_dir: &std::path::Path,
55) -> Result<(), std::io::Error> {
56    std::fs::create_dir_all(crate_dir.join("src"))?;
57    std::fs::write(crate_dir.join("Cargo.toml"), &output.cargo_toml)?;
58    std::fs::write(crate_dir.join("src/lib.rs"), &output.lib_rs)?;
59    std::fs::write(crate_dir.join("src/types.rs"), &output.types_rs)?;
60    std::fs::write(crate_dir.join("src/entity.rs"), &output.entity_rs)?;
61    Ok(())
62}
63
64pub fn write_rust_module(
65    output: &RustOutput,
66    module_dir: &std::path::Path,
67) -> Result<(), std::io::Error> {
68    std::fs::create_dir_all(module_dir)?;
69    std::fs::write(module_dir.join("mod.rs"), output.mod_rs())?;
70    std::fs::write(module_dir.join("types.rs"), &output.types_rs)?;
71    std::fs::write(module_dir.join("entity.rs"), &output.entity_rs)?;
72    Ok(())
73}
74
75struct RustCompiler {
76    spec: SerializableStreamSpec,
77    entity_name: String,
78    config: RustConfig,
79}
80
81impl RustCompiler {
82    fn new(spec: SerializableStreamSpec, entity_name: String, config: RustConfig) -> Self {
83        Self {
84            spec,
85            entity_name,
86            config,
87        }
88    }
89
90    fn compile(&self) -> RustOutput {
91        RustOutput {
92            cargo_toml: self.generate_cargo_toml(),
93            lib_rs: self.generate_lib_rs(),
94            types_rs: self.generate_types_rs(),
95            entity_rs: self.generate_entity_rs(),
96        }
97    }
98
99    fn generate_cargo_toml(&self) -> String {
100        format!(
101            r#"[package]
102name = "{}"
103version = "0.1.0"
104edition = "2021"
105
106[dependencies]
107hyperstack-sdk = "{}"
108serde = {{ version = "1", features = ["derive"] }}
109serde_json = "1"
110"#,
111            self.config.crate_name, self.config.sdk_version
112        )
113    }
114
115    fn generate_lib_rs(&self) -> String {
116        format!(
117            r#"mod types;
118mod entity;
119
120pub use types::*;
121pub use entity::{{{}Entity, {}Views}};
122
123pub use hyperstack_sdk::{{HyperStack, Entity, Update, ConnectionState, Views}};
124"#,
125            self.entity_name, self.entity_name
126        )
127    }
128
129    fn generate_types_rs(&self) -> String {
130        let mut output = String::new();
131        output.push_str("use serde::{Deserialize, Serialize};\n\n");
132
133        let mut generated = HashSet::new();
134
135        for section in &self.spec.sections {
136            if !Self::is_root_section(&section.name) && generated.insert(section.name.clone()) {
137                output.push_str(&self.generate_struct_for_section(section));
138                output.push_str("\n\n");
139            }
140        }
141
142        output.push_str(&self.generate_main_entity_struct());
143        output.push_str(&self.generate_resolved_types(&mut generated));
144        output.push_str(&self.generate_event_wrapper());
145
146        output
147    }
148
149    fn generate_struct_for_section(&self, section: &EntitySection) -> String {
150        let struct_name = format!("{}{}", self.entity_name, to_pascal_case(&section.name));
151        let mut fields = Vec::new();
152
153        for field in &section.fields {
154            let field_name = to_snake_case(&field.field_name);
155            let rust_type = self.field_type_to_rust(field);
156
157            fields.push(format!(
158                "    #[serde(default)]\n    pub {}: {},",
159                field_name, rust_type
160            ));
161        }
162
163        format!(
164            "#[derive(Debug, Clone, Serialize, Deserialize, Default)]\npub struct {} {{\n{}\n}}",
165            struct_name,
166            fields.join("\n")
167        )
168    }
169
170    /// Check if a section name is the root section (case-insensitive)
171    fn is_root_section(name: &str) -> bool {
172        name.eq_ignore_ascii_case("root")
173    }
174
175    fn generate_main_entity_struct(&self) -> String {
176        let mut fields = Vec::new();
177
178        for section in &self.spec.sections {
179            if !Self::is_root_section(&section.name) {
180                let field_name = to_snake_case(&section.name);
181                let type_name = format!("{}{}", self.entity_name, to_pascal_case(&section.name));
182                fields.push(format!(
183                    "    #[serde(default)]\n    pub {}: {},",
184                    field_name, type_name
185                ));
186            }
187        }
188
189        for section in &self.spec.sections {
190            if Self::is_root_section(&section.name) {
191                for field in &section.fields {
192                    let field_name = to_snake_case(&field.field_name);
193                    let rust_type = self.field_type_to_rust(field);
194                    fields.push(format!(
195                        "    #[serde(default)]\n    pub {}: {},",
196                        field_name, rust_type
197                    ));
198                }
199            }
200        }
201
202        format!(
203            "#[derive(Debug, Clone, Serialize, Deserialize, Default)]\npub struct {} {{\n{}\n}}",
204            self.entity_name,
205            fields.join("\n")
206        )
207    }
208
209    fn generate_resolved_types(&self, generated: &mut HashSet<String>) -> String {
210        let mut output = String::new();
211
212        for section in &self.spec.sections {
213            for field in &section.fields {
214                if let Some(resolved) = &field.resolved_type {
215                    if generated.insert(resolved.type_name.clone()) {
216                        output.push_str("\n\n");
217                        output.push_str(&self.generate_resolved_struct(resolved));
218                    }
219                }
220            }
221        }
222
223        output
224    }
225
226    fn generate_resolved_struct(&self, resolved: &ResolvedStructType) -> String {
227        if resolved.is_enum {
228            let variants: Vec<String> = resolved
229                .enum_variants
230                .iter()
231                .map(|v| format!("    {},", to_pascal_case(v)))
232                .collect();
233
234            format!(
235                "#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]\npub enum {} {{\n{}\n}}",
236                to_pascal_case(&resolved.type_name),
237                variants.join("\n")
238            )
239        } else {
240            let fields: Vec<String> = resolved
241                .fields
242                .iter()
243                .map(|f| {
244                    let rust_type = self.resolved_field_to_rust(f);
245                    format!(
246                        "    #[serde(default)]\n    pub {}: {},",
247                        to_snake_case(&f.field_name),
248                        rust_type
249                    )
250                })
251                .collect();
252
253            format!(
254                "#[derive(Debug, Clone, Serialize, Deserialize, Default)]\npub struct {} {{\n{}\n}}",
255                to_pascal_case(&resolved.type_name),
256                fields.join("\n")
257            )
258        }
259    }
260
261    fn generate_event_wrapper(&self) -> String {
262        r#"
263
264#[derive(Debug, Clone, Serialize, Deserialize)]
265pub struct EventWrapper<T> {
266    #[serde(default)]
267    pub timestamp: i64,
268    pub data: T,
269    #[serde(default)]
270    pub slot: Option<f64>,
271    #[serde(default)]
272    pub signature: Option<String>,
273}
274
275impl<T: Default> Default for EventWrapper<T> {
276    fn default() -> Self {
277        Self {
278            timestamp: 0,
279            data: T::default(),
280            slot: None,
281            signature: None,
282        }
283    }
284}
285"#
286        .to_string()
287    }
288
289    fn generate_entity_rs(&self) -> String {
290        let entity_name = &self.entity_name;
291        let types_import = if self.config.module_mode {
292            "super::types"
293        } else {
294            "crate::types"
295        };
296
297        let views_struct = self.generate_views_struct();
298
299        format!(
300            r#"use hyperstack_sdk::{{Entity, StateView, ViewBuilder, ViewHandle, Views}};
301use {types_import}::{entity_name};
302
303pub struct {entity_name}Entity;
304
305impl Entity for {entity_name}Entity {{
306    type Data = {entity_name};
307    
308    const NAME: &'static str = "{entity_name}";
309    
310    fn state_view() -> &'static str {{
311        "{entity_name}/state"
312    }}
313    
314    fn list_view() -> &'static str {{
315        "{entity_name}/list"
316    }}
317}}
318{views_struct}"#,
319            types_import = types_import,
320            entity_name = entity_name,
321            views_struct = views_struct
322        )
323    }
324
325    fn generate_views_struct(&self) -> String {
326        let entity_name = &self.entity_name;
327
328        let derived: Vec<_> = self
329            .spec
330            .views
331            .iter()
332            .filter(|v| {
333                !v.id.ends_with("/state")
334                    && !v.id.ends_with("/list")
335                    && v.id.starts_with(entity_name)
336            })
337            .collect();
338
339        let mut derived_methods = String::new();
340        for view in &derived {
341            let view_name = view.id.split('/').nth(1).unwrap_or("unknown");
342            let method_name = to_snake_case(view_name);
343
344            derived_methods.push_str(&format!(
345                r#"
346    pub fn {method_name}(&self) -> ViewHandle<{entity_name}> {{
347        self.builder.view("{view_id}")
348    }}
349"#,
350                method_name = method_name,
351                entity_name = entity_name,
352                view_id = view.id
353            ));
354        }
355
356        format!(
357            r#"
358
359pub struct {entity_name}Views {{
360    builder: ViewBuilder,
361}}
362
363impl Views for {entity_name}Views {{
364    type Entity = {entity_name}Entity;
365
366    fn from_builder(builder: ViewBuilder) -> Self {{
367        Self {{ builder }}
368    }}
369}}
370
371impl {entity_name}Views {{
372    pub fn state(&self) -> StateView<{entity_name}> {{
373        StateView::new(
374            self.builder.connection().clone(),
375            self.builder.store().clone(),
376            "{entity_name}/state".to_string(),
377            self.builder.initial_data_timeout(),
378        )
379    }}
380
381    pub fn list(&self) -> ViewHandle<{entity_name}> {{
382        self.builder.view("{entity_name}/list")
383    }}
384{derived_methods}}}
385"#,
386            entity_name = entity_name,
387            derived_methods = derived_methods
388        )
389    }
390
391    /// Generate Rust type for a field.
392    ///
393    /// All fields are wrapped in Option<T> because we receive partial patches,
394    /// so any field may not yet be present.
395    ///
396    /// - Non-optional spec fields become `Option<T>`:
397    ///   - `None` = not yet received in any patch
398    ///   - `Some(value)` = has value
399    ///
400    /// - Optional spec fields become `Option<Option<T>>`:
401    ///   - `None` = not yet received in any patch
402    ///   - `Some(None)` = explicitly set to null
403    ///   - `Some(Some(value))` = has value
404    fn field_type_to_rust(&self, field: &FieldTypeInfo) -> String {
405        let base = self.base_type_to_rust(&field.base_type, &field.rust_type_name);
406
407        let typed = if field.is_array && !matches!(field.base_type, BaseType::Array) {
408            format!("Vec<{}>", base)
409        } else {
410            base
411        };
412
413        // All fields wrapped in Option since we receive patches
414        // Optional spec fields get Option<Option<T>> to distinguish "not received" from "explicitly null"
415        if field.is_optional {
416            format!("Option<Option<{}>>", typed)
417        } else {
418            format!("Option<{}>", typed)
419        }
420    }
421
422    fn base_type_to_rust(&self, base_type: &BaseType, rust_type_name: &str) -> String {
423        match base_type {
424            BaseType::Integer => {
425                if rust_type_name.contains("u64") {
426                    "u64".to_string()
427                } else if rust_type_name.contains("i64") {
428                    "i64".to_string()
429                } else if rust_type_name.contains("u32") {
430                    "u32".to_string()
431                } else if rust_type_name.contains("i32") {
432                    "i32".to_string()
433                } else {
434                    "i64".to_string()
435                }
436            }
437            BaseType::Float => "f64".to_string(),
438            BaseType::String => "String".to_string(),
439            BaseType::Boolean => "bool".to_string(),
440            BaseType::Timestamp => "i64".to_string(),
441            BaseType::Binary => "Vec<u8>".to_string(),
442            BaseType::Pubkey => "String".to_string(),
443            BaseType::Array => "Vec<serde_json::Value>".to_string(),
444            BaseType::Object => "serde_json::Value".to_string(),
445            BaseType::Any => "serde_json::Value".to_string(),
446        }
447    }
448
449    fn resolved_field_to_rust(&self, field: &ResolvedField) -> String {
450        let base = self.base_type_to_rust(&field.base_type, &field.field_type);
451
452        let typed = if field.is_array {
453            format!("Vec<{}>", base)
454        } else {
455            base
456        };
457
458        if field.is_optional {
459            format!("Option<Option<{}>>", typed)
460        } else {
461            format!("Option<{}>", typed)
462        }
463    }
464}
465
/// Converts a `_`-, `-`- or `.`-separated name to PascalCase.
///
/// The first character of each non-empty segment is upper-cased, the remaining characters
/// are kept as-is, and the separators are dropped.
fn to_pascal_case(s: &str) -> String {
    let mut pascal = String::with_capacity(s.len());
    for segment in s.split(|c: char| matches!(c, '_' | '-' | '.')) {
        let mut rest = segment.chars();
        if let Some(head) = rest.next() {
            pascal.extend(head.to_uppercase());
            pascal.push_str(rest.as_str());
        }
    }
    pascal
}
477
/// Converts a camelCase, PascalCase, kebab-case or dotted name to snake_case.
///
/// - An underscore is inserted before an upper-case letter that follows a lower-case letter
///   or digit.
/// - Inside a run of capitals, an underscore is inserted only before the last capital, and
///   only when a lower-case letter follows it. So `tokenID` → `token_id` and
///   `HTTPServer` → `http_server`, not `token_i_d` / `h_t_t_p_server`.
/// - `-`, `.` and spaces are word separators and become `_`.
/// - Existing underscores are kept, and no extra one is added right after them
///   (`Foo_Bar` → `foo_bar`).
fn to_snake_case(s: &str) -> String {
    let chars: Vec<char> = s.chars().collect();
    let mut out = String::with_capacity(s.len() + 4);

    for (i, &ch) in chars.iter().enumerate() {
        if matches!(ch, '-' | '.' | ' ') {
            out.push('_');
            continue;
        }
        if !ch.is_uppercase() {
            out.push(ch);
            continue;
        }

        let prev = i.checked_sub(1).and_then(|j| chars.get(j)).copied();
        let next = chars.get(i + 1).copied();
        let starts_word = match prev {
            Some(p) if p.is_lowercase() || p.is_numeric() => true,
            // Inside an acronym: only break before its last capital (`HTTPServer`).
            Some(p) if p.is_uppercase() => next.map_or(false, char::is_lowercase),
            _ => false,
        };
        if starts_word {
            out.push('_');
        }
        // `to_lowercase` may yield several chars for some non-ASCII letters.
        out.extend(ch.to_lowercase());
    }

    out
}