Skip to main content

hyperstack_idl/
snapshot.rs

1//! Snapshot type definitions
2
3use serde::{de::Error as DeError, Deserialize, Deserializer, Serialize};
4
5use crate::types::SteelDiscriminant;
6
/// Top-level snapshot of a program IDL: identifying metadata plus the program's
/// accounts, instructions, type definitions, events and errors.
///
/// Only `Serialize` is derived; `Deserialize` is implemented by hand in this module so
/// that `discriminant_size` can be inferred from the instruction data.
#[derive(Debug, Clone, Serialize)]
pub struct IdlSnapshot {
    /// Program name.
    pub name: String,
    /// Program id. Left out of the serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none", alias = "address")]
    pub program_id: Option<String>,
    /// Version string, stored as given.
    pub version: String,
    /// Account definitions.
    pub accounts: Vec<IdlAccountSnapshot>,
    /// Instruction definitions.
    pub instructions: Vec<IdlInstructionSnapshot>,
    /// Named type definitions.
    #[serde(default)]
    pub types: Vec<IdlTypeDefSnapshot>,
    /// Event definitions.
    #[serde(default)]
    pub events: Vec<IdlEventSnapshot>,
    /// Error definitions.
    #[serde(default)]
    pub errors: Vec<IdlErrorSnapshot>,
    /// Discriminator width. The hand-written deserializer sets this to `1` or `8`.
    pub discriminant_size: usize,
}
23
24impl<'de> Deserialize<'de> for IdlSnapshot {
25    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
26    where
27        D: Deserializer<'de>,
28    {
29        // First deserialize to a generic Value to inspect instructions
30        let value = serde_json::Value::deserialize(deserializer)?;
31
32        // Check if any instruction has discriminant (Steel-style) vs discriminator (Anchor-style)
33        let discriminant_size = value
34            .get("instructions")
35            .and_then(|instrs| instrs.as_array())
36            .map(|instrs| {
37                if instrs.is_empty() {
38                    return false;
39                }
40                instrs.iter().all(|ix| {
41                    let discriminator = ix.get("discriminator");
42                    let disc_len = discriminator
43                        .and_then(|d| d.as_array())
44                        .map(|a| a.len())
45                        .unwrap_or(0);
46
47                    // Steel-style variant 1: has explicit discriminant field, no discriminator
48                    let has_discriminant = ix.get("discriminant").is_some();
49                    let has_discriminator = discriminator
50                        .map(|d| {
51                            !d.is_null() && d.as_array().map(|a| !a.is_empty()).unwrap_or(true)
52                        })
53                        .unwrap_or(false);
54                    let is_steel_discriminant = has_discriminant && !has_discriminator;
55
56                    // Steel-style variant 2: discriminator is stored as a 1-byte array (no
57                    // discriminant field). This happens when the AST was generated from a
58                    // Steel IDL that serialised its u8 discriminant directly into the
59                    // discriminator field.
60                    let is_steel_short_discriminator = !has_discriminant && disc_len == 1;
61
62                    is_steel_discriminant || is_steel_short_discriminator
63                })
64            })
65            .map(|is_steel| if is_steel { 1 } else { 8 })
66            .unwrap_or(8); // Default to 8 if no instructions
67
68        // Now deserialize the full struct
69        let mut intermediate: IdlSnapshotIntermediate = serde_json::from_value(value)
70            .map_err(|e| DeError::custom(format!("Failed to deserialize IDL: {}", e)))?;
71        intermediate.discriminant_size = discriminant_size;
72
73        Ok(IdlSnapshot {
74            name: intermediate.name,
75            program_id: intermediate.program_id,
76            version: intermediate.version,
77            accounts: intermediate.accounts,
78            instructions: intermediate.instructions,
79            types: intermediate.types,
80            events: intermediate.events,
81            errors: intermediate.errors,
82            discriminant_size: intermediate.discriminant_size,
83        })
84    }
85}
86
// Intermediate struct for deserialization.
//
// Mirrors the fields of `IdlSnapshot` but derives `Deserialize`, so the hand-written
// `Deserialize` impl for `IdlSnapshot` can reuse serde's generated field handling and
// then overwrite `discriminant_size` with the inferred value.
#[derive(Debug, Clone, Deserialize)]
struct IdlSnapshotIntermediate {
    pub name: String,
    // Also accepted under the key `address`.
    #[serde(default, alias = "address")]
    pub program_id: Option<String>,
    pub version: String,
    pub accounts: Vec<IdlAccountSnapshot>,
    pub instructions: Vec<IdlInstructionSnapshot>,
    #[serde(default)]
    pub types: Vec<IdlTypeDefSnapshot>,
    #[serde(default)]
    pub events: Vec<IdlEventSnapshot>,
    #[serde(default)]
    pub errors: Vec<IdlErrorSnapshot>,
    // Defaults to 0 when missing; the caller replaces it after deserialization.
    #[serde(default)]
    pub discriminant_size: usize,
}
105
/// Snapshot of one account definition.
///
/// Only `Serialize` is derived; `Deserialize` is implemented by hand in this module so
/// that `fields` can be filled from `type_def` when the input carries no top-level fields.
#[derive(Debug, Clone, Serialize)]
pub struct IdlAccountSnapshot {
    /// Account name.
    pub name: String,
    /// Discriminator bytes for the account.
    pub discriminator: Vec<u8>,
    /// Documentation lines.
    pub docs: Vec<String>,
    /// Serialization format, if specified.
    pub serialization: Option<IdlSerializationSnapshot>,
    /// Account fields - populated from inline type definition
    pub fields: Vec<IdlFieldSnapshot>,
    /// Inline type definition (for Steel format with type.fields structure)
    #[serde(skip_serializing_if = "Option::is_none")]
    pub type_def: Option<IdlInlineTypeDef>,
}
118
// Intermediate struct for deserialization.
//
// Derives `Deserialize` with the lenient defaults that the hand-written
// `Deserialize` impl for `IdlAccountSnapshot` builds on.
// NOTE(review): only `Deserialize` is derived, so the `skip_serializing_if`
// attributes below appear to have no effect here — confirm and consider removing.
#[derive(Deserialize)]
struct IdlAccountSnapshotIntermediate {
    pub name: String,
    pub discriminator: Vec<u8>,
    #[serde(default)]
    pub docs: Vec<String>,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub serialization: Option<IdlSerializationSnapshot>,
    #[serde(default)]
    pub fields: Vec<IdlFieldSnapshot>,
    // Read from the `type` key of the input.
    #[serde(rename = "type", default, skip_serializing_if = "Option::is_none")]
    pub type_def: Option<IdlInlineTypeDef>,
}
133
134impl<'de> Deserialize<'de> for IdlAccountSnapshot {
135    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
136    where
137        D: Deserializer<'de>,
138    {
139        let intermediate = IdlAccountSnapshotIntermediate::deserialize(deserializer)?;
140
141        // Normalize fields: if empty but type_def has fields, use those
142        let fields = if intermediate.fields.is_empty() {
143            if let Some(type_def) = intermediate.type_def.as_ref() {
144                type_def.fields.clone()
145            } else {
146                intermediate.fields
147            }
148        } else {
149            intermediate.fields
150        };
151
152        Ok(IdlAccountSnapshot {
153            name: intermediate.name,
154            discriminator: intermediate.discriminator,
155            docs: intermediate.docs,
156            serialization: intermediate.serialization,
157            fields,
158            type_def: intermediate.type_def,
159        })
160    }
161}
162
/// Inline type definition for account fields (Steel format)
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlInlineTypeDef {
    /// Kind label, stored as a plain string; it is not validated here.
    pub kind: String,
    /// Fields of the inline definition. The `IdlAccountSnapshot` deserializer copies
    /// these into the account when the account has no top-level fields.
    pub fields: Vec<IdlFieldSnapshot>,
}
169
/// Snapshot of one instruction definition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlInstructionSnapshot {
    /// Instruction name.
    pub name: String,
    /// Explicit discriminator bytes; empty when the input has none.
    #[serde(default)]
    pub discriminator: Vec<u8>,
    /// Steel-style discriminant, if present. With no `skip_serializing_if`, `None`
    /// is written out as an explicit `null`.
    #[serde(default)]
    pub discriminant: Option<SteelDiscriminant>,
    /// Documentation lines.
    #[serde(default)]
    pub docs: Vec<String>,
    /// Accounts the instruction takes (required in the input).
    pub accounts: Vec<IdlInstructionAccountSnapshot>,
    /// Instruction arguments (required in the input).
    pub args: Vec<IdlFieldSnapshot>,
}
182
183impl IdlInstructionSnapshot {
184    /// Get the computed 8-byte discriminator.
185    /// Returns the explicit discriminator if present, otherwise computes from discriminant.
186    pub fn get_discriminator(&self) -> Vec<u8> {
187        if !self.discriminator.is_empty() {
188            return self.discriminator.clone();
189        }
190
191        if let Some(disc) = &self.discriminant {
192            match u8::try_from(disc.value) {
193                Ok(value) => return vec![value],
194                Err(_) => {
195                    tracing::warn!(
196                        instruction = %self.name,
197                        value = disc.value,
198                        "Steel discriminant exceeds u8::MAX; falling back to Anchor hash"
199                    );
200                }
201            }
202        }
203
204        crate::discriminator::anchor_discriminator(&format!("global:{}", self.name))
205    }
206}
207
/// One account entry in an instruction's account list.
///
/// Every field except `name` falls back to its default (`false`, `None`, empty)
/// when missing from the input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlInstructionAccountSnapshot {
    /// Account name.
    pub name: String,
    /// Writable flag.
    #[serde(default)]
    pub writable: bool,
    /// Signer flag.
    #[serde(default)]
    pub signer: bool,
    /// Optional flag.
    #[serde(default)]
    pub optional: bool,
    /// Address string, if the input provides one.
    #[serde(default)]
    pub address: Option<String>,
    /// Documentation lines.
    #[serde(default)]
    pub docs: Vec<String>,
}
222
/// A named field together with its type.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlFieldSnapshot {
    /// Field name.
    pub name: String,
    /// Field type; uses the key `type` on the wire (`type` is a Rust keyword, hence `type_`).
    #[serde(rename = "type")]
    pub type_: IdlTypeSnapshot,
}
229
/// A type reference as it appears in the IDL.
///
/// The enum is `untagged`: there is no variant tag on the wire. When deserializing,
/// serde tries the variants in declaration order and keeps the first that matches,
/// so the order of the variants below is significant.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum IdlTypeSnapshot {
    /// A bare string such as `"u64"`.
    Simple(String),
    /// An object with an `array` key.
    Array(IdlArrayTypeSnapshot),
    /// An object with an `option` key.
    Option(IdlOptionTypeSnapshot),
    /// An object with a `vec` key.
    Vec(IdlVecTypeSnapshot),
    /// An object with a `hashMap` key.
    HashMap(IdlHashMapTypeSnapshot),
    /// An object with a `defined` key.
    Defined(IdlDefinedTypeSnapshot),
}
240
/// Hash-map type, stored on the wire under the key `hashMap`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlHashMapTypeSnapshot {
    /// The two types listed under `hashMap`, in input order (presumably key then
    /// value — confirm). Deserialization goes through `deserialize_hash_map`, which
    /// rejects lists that do not have exactly two elements.
    #[serde(rename = "hashMap", deserialize_with = "deserialize_hash_map")]
    pub hash_map: (Box<IdlTypeSnapshot>, Box<IdlTypeSnapshot>),
}
246
247fn deserialize_hash_map<'de, D>(
248    deserializer: D,
249) -> Result<(Box<IdlTypeSnapshot>, Box<IdlTypeSnapshot>), D::Error>
250where
251    D: Deserializer<'de>,
252{
253    use serde::de::Error;
254    let values: Vec<IdlTypeSnapshot> = Vec::deserialize(deserializer)?;
255    if values.len() != 2 {
256        return Err(D::Error::custom("hashMap must have exactly 2 elements"));
257    }
258    let mut iter = values.into_iter();
259    Ok((
260        Box::new(iter.next().expect("length checked")),
261        Box::new(iter.next().expect("length checked")),
262    ))
263}
264
/// Array type, stored on the wire under the key `array`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlArrayTypeSnapshot {
    /// Raw elements of the `array` list; presumably `[element type, length]` — confirm
    /// against the IDL producers. No length or shape check is performed here.
    pub array: Vec<IdlArrayElementSnapshot>,
}
269
/// One element of an `array` list.
///
/// The enum is `untagged`, so deserialization tries the variants in declaration order.
/// NOTE(review): `IdlTypeSnapshot` already accepts any bare string through its `Simple`
/// variant, so a string in the input becomes `Type(Simple(..))` and `TypeName` looks
/// unreachable when deserializing. It can still be constructed and serialized directly.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum IdlArrayElementSnapshot {
    /// A type reference.
    Type(IdlTypeSnapshot),
    /// A bare type name.
    TypeName(String),
    /// A number that fits in `u32`.
    Size(u32),
}
277
/// Optional type, stored on the wire under the key `option`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlOptionTypeSnapshot {
    /// The wrapped type (boxed because the type tree is recursive).
    pub option: Box<IdlTypeSnapshot>,
}
282
/// Vector type, stored on the wire under the key `vec`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlVecTypeSnapshot {
    /// The element type (boxed because the type tree is recursive).
    pub vec: Box<IdlTypeSnapshot>,
}
287
/// Reference to a defined type, stored on the wire under the key `defined`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlDefinedTypeSnapshot {
    /// The referenced type, in either of the two accepted spellings.
    pub defined: IdlDefinedInnerSnapshot,
}
292
/// The value under a `defined` key.
///
/// The enum is `untagged`, so the variant is chosen by shape: an object with a `name`
/// key, or a bare string.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum IdlDefinedInnerSnapshot {
    /// Object form: `{ "name": "..." }`.
    Named { name: String },
    /// Bare-string form.
    Simple(String),
}
299
/// Serialization format of an account or type.
///
/// With `rename_all = "lowercase"` the wire names are `borsh`, `bytemuck` and
/// `bytemuckunsafe`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum IdlSerializationSnapshot {
    Borsh,
    Bytemuck,
    // NOTE(review): this alias spells the same name `rename_all` already produces for
    // the variant, so it looks redundant — confirm whether another spelling was meant.
    #[serde(alias = "bytemuckunsafe")]
    BytemuckUnsafe,
}
308
/// A named type definition from the IDL's `types` list.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlTypeDefSnapshot {
    /// Type name.
    pub name: String,
    /// Documentation lines; empty when missing from the input.
    #[serde(default)]
    pub docs: Vec<String>,
    /// Serialization format, if specified. Left out of the output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub serialization: Option<IdlSerializationSnapshot>,
    /// Body of the definition; uses the key `type` on the wire.
    #[serde(rename = "type")]
    pub type_def: IdlTypeDefKindSnapshot,
}
319
/// Body of a type definition.
///
/// The enum is `untagged`: serde picks the first variant, in declaration order, whose
/// fields match the input. `kind` is kept as a plain string and is not used to choose
/// the variant. `Struct` and `TupleStruct` differ only in the element type of `fields`,
/// so an input with an empty `fields` list deserializes as `Struct`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum IdlTypeDefKindSnapshot {
    /// `fields` is a list of named fields.
    Struct {
        kind: String,
        fields: Vec<IdlFieldSnapshot>,
    },
    /// `fields` is a list of bare types.
    TupleStruct {
        kind: String,
        fields: Vec<IdlTypeSnapshot>,
    },
    /// Carries `variants` instead of `fields`.
    Enum {
        kind: String,
        variants: Vec<IdlEnumVariantSnapshot>,
    },
}
336
/// One variant of an enum type definition. Only the name is kept; any other keys on
/// the variant in the input are not represented in this struct.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlEnumVariantSnapshot {
    /// Variant name.
    pub name: String,
}
341
/// Snapshot of one event definition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlEventSnapshot {
    /// Event name.
    pub name: String,
    /// Discriminator bytes (required in the input).
    pub discriminator: Vec<u8>,
    /// Documentation lines; empty when missing from the input.
    #[serde(default)]
    pub docs: Vec<String>,
}
349
/// Snapshot of one error definition.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct IdlErrorSnapshot {
    /// Numeric error code.
    pub code: u32,
    /// Error name.
    pub name: String,
    /// Optional message. Left out of the serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub msg: Option<String>,
}
357
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a snapshot with one entry in every collection, used as the round-trip fixture.
    fn sample_snapshot() -> IdlSnapshot {
        let account = IdlAccountSnapshot {
            name: String::from("ExampleAccount"),
            discriminator: vec![1, 2, 3, 4, 5, 6, 7, 8],
            docs: vec![String::from("Example account")],
            serialization: Some(IdlSerializationSnapshot::Borsh),
            fields: Vec::new(),
            type_def: None,
        };

        let payer = IdlInstructionAccountSnapshot {
            name: String::from("payer"),
            writable: true,
            signer: true,
            optional: false,
            address: None,
            docs: Vec::new(),
        };

        let amount_arg = IdlFieldSnapshot {
            name: String::from("amount"),
            type_: IdlTypeSnapshot::HashMap(IdlHashMapTypeSnapshot {
                hash_map: (
                    Box::new(IdlTypeSnapshot::Simple(String::from("u64"))),
                    Box::new(IdlTypeSnapshot::Simple(String::from("string"))),
                ),
            }),
        };

        let instruction = IdlInstructionSnapshot {
            name: String::from("example_instruction"),
            discriminator: vec![8, 7, 6, 5, 4, 3, 2, 1],
            discriminant: None,
            docs: vec![String::from("Example instruction")],
            accounts: vec![payer],
            args: vec![amount_arg],
        };

        let example_type = IdlTypeDefSnapshot {
            name: String::from("ExampleType"),
            docs: Vec::new(),
            serialization: None,
            type_def: IdlTypeDefKindSnapshot::Struct {
                kind: String::from("struct"),
                fields: vec![IdlFieldSnapshot {
                    name: String::from("value"),
                    type_: IdlTypeSnapshot::Simple(String::from("u64")),
                }],
            },
        };

        let event = IdlEventSnapshot {
            name: String::from("ExampleEvent"),
            discriminator: vec![0, 0, 0, 0, 0, 0, 0, 1],
            docs: Vec::new(),
        };

        let error = IdlErrorSnapshot {
            code: 6000,
            name: String::from("ExampleError"),
            msg: Some(String::from("example")),
        };

        IdlSnapshot {
            name: String::from("test_program"),
            program_id: Some(String::from("11111111111111111111111111111111")),
            version: String::from("0.1.0"),
            accounts: vec![account],
            instructions: vec![instruction],
            types: vec![example_type],
            events: vec![event],
            errors: vec![error],
            discriminant_size: 8,
        }
    }

    /// Serialize -> deserialize -> serialize must reproduce the same JSON value.
    #[test]
    fn test_snapshot_serde() {
        let original = sample_snapshot();

        let first_pass = serde_json::to_value(&original).expect("serialize snapshot");
        let restored: IdlSnapshot =
            serde_json::from_value(first_pass.clone()).expect("deserialize snapshot");
        let second_pass = serde_json::to_value(&restored).expect("re-serialize snapshot");

        assert_eq!(first_pass, second_pass);
        assert_eq!(restored.name, "test_program");
    }

    /// A two-element `hashMap` list is parsed into the pair, preserving order.
    #[test]
    fn test_hashmap_compat() {
        let parsed: IdlHashMapTypeSnapshot =
            serde_json::from_str(r#"{"hashMap":["u64","string"]}"#).expect("deserialize hashMap");
        let (first, second) = &parsed.hash_map;

        assert!(matches!(
            first.as_ref(),
            IdlTypeSnapshot::Simple(name) if name == "u64"
        ));
        assert!(matches!(
            second.as_ref(),
            IdlTypeSnapshot::Simple(name) if name == "string"
        ));
    }
}