Skip to main content

hyperstack_idl/
snapshot.rs

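//! Snapshot types for program IDLs (Anchor- and Steel-style) and their serde
//! (de)serialization, including normalization of account fields and inference of the
//! instruction discriminator width.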
2
3use serde::{de::Error as DeError, Deserialize, Deserializer, Serialize};
4
5use crate::types::SteelDiscriminant;
6
/// Snapshot of a program IDL: identity, accounts, instructions, user-defined types, events,
/// errors, and the instruction discriminator width.
///
/// `Serialize` is derived. `Deserialize` is implemented by hand (through an internal
/// intermediate struct) so that `discriminant_size` can be inferred when the input omits it.
#[derive(Debug, Clone, Serialize)]
pub struct IdlSnapshot {
    /// Program name.
    pub name: String,
    /// Program address, if known; omitted from serialized output when `None`.
    // NOTE: `default` and `alias` only matter for deserialization, which this type does not
    // derive; reading the `address` key is handled by the hand-written `Deserialize` impl.
    #[serde(default, skip_serializing_if = "Option::is_none", alias = "address")]
    pub program_id: Option<String>,
    /// Version string copied from the IDL.
    pub version: String,
    /// Account definitions.
    pub accounts: Vec<IdlAccountSnapshot>,
    /// Instruction definitions.
    pub instructions: Vec<IdlInstructionSnapshot>,
    /// User-defined type definitions.
    #[serde(default)]
    pub types: Vec<IdlTypeDefSnapshot>,
    /// Event definitions.
    #[serde(default)]
    pub events: Vec<IdlEventSnapshot>,
    /// Program error definitions.
    #[serde(default)]
    pub errors: Vec<IdlErrorSnapshot>,
    /// Width in bytes of instruction discriminators: `1` for Steel-style IDLs, `8` for
    /// Anchor-style ones. Inferred during deserialization when the input does not provide it.
    pub discriminant_size: usize,
}
23
24impl<'de> Deserialize<'de> for IdlSnapshot {
25    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
26    where
27        D: Deserializer<'de>,
28    {
29        // First deserialize to a generic Value to inspect instructions
30        let value = serde_json::Value::deserialize(deserializer)?;
31
32        // Check if any instruction has discriminant (Steel-style) vs discriminator (Anchor-style)
33        let discriminant_size = value
34            .get("instructions")
35            .and_then(|instrs| instrs.as_array())
36            .map(|instrs| {
37                if instrs.is_empty() {
38                    return false;
39                }
40                instrs.iter().all(|ix| {
41                    let discriminator = ix.get("discriminator");
42                    let disc_len = discriminator
43                        .and_then(|d| d.as_array())
44                        .map(|a| a.len())
45                        .unwrap_or(0);
46
47                    // Treat discriminant as present only if the value is non-null.
48                    // ix.get("discriminant").is_some() returns true even for `null`,
49                    // which causes misclassification when the AST serializer writes
50                    // `discriminant: null` explicitly (as the ore AST does).
51                    let has_discriminant = ix
52                        .get("discriminant")
53                        .map(|v| !v.is_null())
54                        .unwrap_or(false);
55                    let has_discriminator = discriminator
56                        .map(|d| {
57                            !d.is_null() && d.as_array().map(|a| !a.is_empty()).unwrap_or(true)
58                        })
59                        .unwrap_or(false);
60
61                    // Steel-style variant 1: explicit discriminant object, no discriminator array
62                    let is_steel_discriminant = has_discriminant && !has_discriminator;
63
64                    // Steel-style variant 2: discriminator is stored as a 1-byte array with no
65                    // discriminant value. This happens when the AST serializer flattens the
66                    // Steel u8 discriminant directly into the discriminator field.
67                    let is_steel_short_discriminator = !has_discriminant && disc_len == 1;
68
69                    is_steel_discriminant || is_steel_short_discriminator
70                })
71            })
72            .map(|is_steel| if is_steel { 1 } else { 8 })
73            .unwrap_or(8); // Default to 8 if no instructions
74
75        // Now deserialize the full struct
76        let mut intermediate: IdlSnapshotIntermediate = serde_json::from_value(value)
77            .map_err(|e| DeError::custom(format!("Failed to deserialize IDL: {}", e)))?;
78        // Only use the heuristic if discriminant_size wasn't already present in the JSON
79        // (discriminant_size = 0 means it was absent / defaulted).
80        if intermediate.discriminant_size == 0 {
81            intermediate.discriminant_size = discriminant_size;
82        }
83
84        Ok(IdlSnapshot {
85            name: intermediate.name,
86            program_id: intermediate.program_id,
87            version: intermediate.version,
88            accounts: intermediate.accounts,
89            instructions: intermediate.instructions,
90            types: intermediate.types,
91            events: intermediate.events,
92            errors: intermediate.errors,
93            discriminant_size: intermediate.discriminant_size,
94        })
95    }
96}
97
98// Intermediate struct for deserialization
99#[derive(Debug, Clone, Deserialize)]
100struct IdlSnapshotIntermediate {
101    pub name: String,
102    #[serde(default, alias = "address")]
103    pub program_id: Option<String>,
104    pub version: String,
105    pub accounts: Vec<IdlAccountSnapshot>,
106    pub instructions: Vec<IdlInstructionSnapshot>,
107    #[serde(default)]
108    pub types: Vec<IdlTypeDefSnapshot>,
109    #[serde(default)]
110    pub events: Vec<IdlEventSnapshot>,
111    #[serde(default)]
112    pub errors: Vec<IdlErrorSnapshot>,
113    #[serde(default)]
114    pub discriminant_size: usize,
115}
116
/// Account definition from an IDL snapshot.
///
/// `Deserialize` is implemented by hand (through an intermediate struct) so that `fields`
/// can be filled from an inline `type` definition when the top-level `fields` list is empty.
#[derive(Debug, Clone, Serialize)]
pub struct IdlAccountSnapshot {
    /// Account name.
    pub name: String,
    /// Raw account discriminator bytes, exactly as given in the IDL.
    pub discriminator: Vec<u8>,
    /// Documentation lines.
    pub docs: Vec<String>,
    /// Declared serialization scheme, if any (serialized as `null` when `None`).
    pub serialization: Option<IdlSerializationSnapshot>,
    /// Account fields. Taken from the top-level `fields` key when it is non-empty; otherwise
    /// copied from the inline type definition during deserialization.
    pub fields: Vec<IdlFieldSnapshot>,
    /// Inline type definition (Steel format with a `type.fields` structure). Serialized under
    /// the key `type_def` and omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub type_def: Option<IdlInlineTypeDef>,
}
129
130// Intermediate struct for deserialization
131#[derive(Deserialize)]
132struct IdlAccountSnapshotIntermediate {
133    pub name: String,
134    pub discriminator: Vec<u8>,
135    #[serde(default)]
136    pub docs: Vec<String>,
137    #[serde(default, skip_serializing_if = "Option::is_none")]
138    pub serialization: Option<IdlSerializationSnapshot>,
139    #[serde(default)]
140    pub fields: Vec<IdlFieldSnapshot>,
141    #[serde(rename = "type", default, skip_serializing_if = "Option::is_none")]
142    pub type_def: Option<IdlInlineTypeDef>,
143}
144
145impl<'de> Deserialize<'de> for IdlAccountSnapshot {
146    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
147    where
148        D: Deserializer<'de>,
149    {
150        let intermediate = IdlAccountSnapshotIntermediate::deserialize(deserializer)?;
151
152        // Normalize fields: if empty but type_def has fields, use those
153        let fields = if intermediate.fields.is_empty() {
154            if let Some(type_def) = intermediate.type_def.as_ref() {
155                type_def.fields.clone()
156            } else {
157                intermediate.fields
158            }
159        } else {
160            intermediate.fields
161        };
162
163        Ok(IdlAccountSnapshot {
164            name: intermediate.name,
165            discriminator: intermediate.discriminator,
166            docs: intermediate.docs,
167            serialization: intermediate.serialization,
168            fields,
169            type_def: intermediate.type_def,
170        })
171    }
172}
173
/// Inline type definition for account fields (Steel format).
///
/// Read from an account's `type` object, whose `fields` may be used to populate
/// `IdlAccountSnapshot::fields`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlInlineTypeDef {
    /// Type kind as written in the IDL (presumably `"struct"` for account layouts; not
    /// validated here).
    pub kind: String,
    /// Named fields of the inline type.
    pub fields: Vec<IdlFieldSnapshot>,
}
180
/// Instruction definition from an IDL snapshot.
///
/// An instruction is identified either by explicit `discriminator` bytes (Anchor-style, or
/// a flattened 1-byte Steel value) or by a Steel `discriminant`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlInstructionSnapshot {
    /// Instruction name; also used to derive an Anchor-style discriminator as a fallback.
    pub name: String,
    /// Explicit discriminator bytes; empty when the input omits them.
    #[serde(default)]
    pub discriminator: Vec<u8>,
    /// Steel-style discriminant, if present (an explicit `null` deserializes to `None`).
    #[serde(default)]
    pub discriminant: Option<SteelDiscriminant>,
    /// Documentation lines.
    #[serde(default)]
    pub docs: Vec<String>,
    /// Accounts the instruction takes, in declaration order.
    pub accounts: Vec<IdlInstructionAccountSnapshot>,
    /// Instruction arguments.
    pub args: Vec<IdlFieldSnapshot>,
}
193
194impl IdlInstructionSnapshot {
195    /// Get the computed 8-byte discriminator.
196    /// Returns the explicit discriminator if present, otherwise computes from discriminant.
197    pub fn get_discriminator(&self) -> Vec<u8> {
198        if !self.discriminator.is_empty() {
199            return self.discriminator.clone();
200        }
201
202        if let Some(disc) = &self.discriminant {
203            match u8::try_from(disc.value) {
204                Ok(value) => return vec![value],
205                Err(_) => {
206                    tracing::warn!(
207                        instruction = %self.name,
208                        value = disc.value,
209                        "Steel discriminant exceeds u8::MAX; falling back to Anchor hash"
210                    );
211                }
212            }
213        }
214
215        crate::discriminator::anchor_discriminator(&format!("global:{}", self.name))
216    }
217}
218
/// One account slot in an instruction's account list.
///
/// All flags default to `false` and all optional data to empty when absent from the input.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlInstructionAccountSnapshot {
    /// Account name.
    pub name: String,
    /// Whether the account is marked writable.
    #[serde(default)]
    pub writable: bool,
    /// Whether the account is marked as a signer.
    #[serde(default)]
    pub signer: bool,
    /// Whether the account may be omitted.
    #[serde(default)]
    pub optional: bool,
    /// Address declared for this slot, if any (presumably a fixed, known account; verify
    /// against the IDL producer).
    #[serde(default)]
    pub address: Option<String>,
    /// Documentation lines.
    #[serde(default)]
    pub docs: Vec<String>,
}
233
/// A named, typed field (struct field or instruction argument).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlFieldSnapshot {
    /// Field name.
    pub name: String,
    /// Field type, stored under the JSON key `type`.
    #[serde(rename = "type")]
    pub type_: IdlTypeSnapshot,
}
240
/// An IDL type expression.
///
/// Untagged: on deserialization the variants are tried in declaration order and the first
/// one that fits wins. Any JSON string therefore becomes `Simple` (e.g. `"u64"`), and objects
/// are told apart by their single key (`array`, `option`, `vec`, `hashMap`, `defined`).
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum IdlTypeSnapshot {
    /// A type given by name alone.
    Simple(String),
    /// `{"array": [...]}`.
    Array(IdlArrayTypeSnapshot),
    /// `{"option": <type>}`.
    Option(IdlOptionTypeSnapshot),
    /// `{"vec": <type>}`.
    Vec(IdlVecTypeSnapshot),
    /// `{"hashMap": [<key>, <value>]}`.
    HashMap(IdlHashMapTypeSnapshot),
    /// `{"defined": ...}`: a reference to a user-defined type.
    Defined(IdlDefinedTypeSnapshot),
}
251
/// Map type, written as `{"hashMap": [<key type>, <value type>]}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlHashMapTypeSnapshot {
    /// `(key, value)` types. Serialized as a two-element JSON array; deserialization rejects
    /// arrays of any other length.
    #[serde(rename = "hashMap", deserialize_with = "deserialize_hash_map")]
    pub hash_map: (Box<IdlTypeSnapshot>, Box<IdlTypeSnapshot>),
}
257
258fn deserialize_hash_map<'de, D>(
259    deserializer: D,
260) -> Result<(Box<IdlTypeSnapshot>, Box<IdlTypeSnapshot>), D::Error>
261where
262    D: Deserializer<'de>,
263{
264    use serde::de::Error;
265    let values: Vec<IdlTypeSnapshot> = Vec::deserialize(deserializer)?;
266    if values.len() != 2 {
267        return Err(D::Error::custom("hashMap must have exactly 2 elements"));
268    }
269    let mut iter = values.into_iter();
270    Ok((
271        Box::new(iter.next().expect("length checked")),
272        Box::new(iter.next().expect("length checked")),
273    ))
274}
275
/// Fixed-size array type, written as `{"array": [...]}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlArrayTypeSnapshot {
    /// Array descriptor elements. Presumably `[element_type, length]`; the element count is
    /// not enforced here — TODO confirm against the IDL producer.
    pub array: Vec<IdlArrayElementSnapshot>,
}
280
/// One element of an array type descriptor: either a type or a numeric size.
///
/// Untagged, so variants are tried in declaration order during deserialization.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum IdlArrayElementSnapshot {
    /// A type expression (any JSON string or type object).
    Type(IdlTypeSnapshot),
    // NOTE(review): a bare string always matches `Type(IdlTypeSnapshot::Simple(..))` first,
    // so deserialization never produces this variant; it is only reachable when built in code.
    /// A type given by name.
    TypeName(String),
    /// A numeric size (an unsigned integer that fits in `u32`).
    Size(u32),
}
288
/// Optional type, written as `{"option": <type>}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlOptionTypeSnapshot {
    /// Wrapped type (boxed because the type is recursive).
    pub option: Box<IdlTypeSnapshot>,
}
293
/// Variable-length sequence type, written as `{"vec": <type>}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlVecTypeSnapshot {
    /// Element type (boxed because the type is recursive).
    pub vec: Box<IdlTypeSnapshot>,
}
298
/// Reference to a user-defined type, written as `{"defined": ...}`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlDefinedTypeSnapshot {
    /// The referenced type, in either of the accepted spellings.
    pub defined: IdlDefinedInnerSnapshot,
}
303
/// Payload of a `defined` type reference.
///
/// Untagged: an object is tried first as `Named`, then a bare string as `Simple`. Extra keys
/// in the object are ignored because `deny_unknown_fields` is not set.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum IdlDefinedInnerSnapshot {
    /// `{"name": "<TypeName>", ...}`.
    Named { name: String },
    /// `"<TypeName>"`.
    Simple(String),
}
310
/// Serialization scheme declared for an account or type.
///
/// Variants are (de)serialized as lowercase strings: `"borsh"`, `"bytemuck"`,
/// `"bytemuckunsafe"`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum IdlSerializationSnapshot {
    Borsh,
    Bytemuck,
    // NOTE(review): `rename_all = "lowercase"` already maps this variant to "bytemuckunsafe",
    // so this alias adds nothing. If another spelling (e.g. "bytemuck_unsafe") was meant to
    // be accepted, the alias must name it explicitly — confirm intent.
    #[serde(alias = "bytemuckunsafe")]
    BytemuckUnsafe,
}
319
/// A user-defined type declared in the IDL's `types` section.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlTypeDefSnapshot {
    /// Type name, as referenced by `defined` type expressions.
    pub name: String,
    /// Documentation lines.
    #[serde(default)]
    pub docs: Vec<String>,
    /// Declared serialization scheme; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub serialization: Option<IdlSerializationSnapshot>,
    /// Type body, stored under the JSON key `type`.
    #[serde(rename = "type")]
    pub type_def: IdlTypeDefKindSnapshot,
}
330
/// Body of a user-defined type.
///
/// Untagged: the variant is chosen by shape, in declaration order, not by the `kind` string.
/// `fields` made of `{name, type}` objects gives `Struct`; `fields` made of bare type
/// expressions gives `TupleStruct`; `variants` gives `Enum`. An empty `fields` list always
/// matches `Struct` first.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum IdlTypeDefKindSnapshot {
    /// Struct with named fields.
    Struct {
        kind: String,
        fields: Vec<IdlFieldSnapshot>,
    },
    /// Struct with positional (unnamed) fields.
    TupleStruct {
        kind: String,
        fields: Vec<IdlTypeSnapshot>,
    },
    /// Enum; only variant names are captured.
    Enum {
        kind: String,
        variants: Vec<IdlEnumVariantSnapshot>,
    },
}
347
/// An enum variant. Only the name is kept; any other keys in the variant object (such as
/// variant fields) are ignored because `deny_unknown_fields` is not set.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlEnumVariantSnapshot {
    /// Variant name.
    pub name: String,
}
352
/// An event declared in the IDL.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlEventSnapshot {
    /// Event name.
    pub name: String,
    /// Raw event discriminator bytes (required in the input).
    pub discriminator: Vec<u8>,
    /// Documentation lines.
    #[serde(default)]
    pub docs: Vec<String>,
}
360
/// A program error declared in the IDL.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct IdlErrorSnapshot {
    /// Numeric error code.
    pub code: u32,
    /// Error name.
    pub name: String,
    /// Human-readable message, if any; omitted from output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub msg: Option<String>,
}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a minimal IDL JSON document containing only the required top-level keys.
    fn idl_json(instructions: Vec<serde_json::Value>) -> serde_json::Value {
        serde_json::json!({
            "name": "test_program",
            "version": "0.1.0",
            "accounts": [],
            "instructions": instructions
        })
    }

    /// Builds an instruction object with the given discriminator and no accounts or args.
    fn instruction_json(name: &str, discriminator: &[u8]) -> serde_json::Value {
        serde_json::json!({
            "name": name,
            "discriminator": discriminator,
            "accounts": [],
            "args": []
        })
    }

    /// Deserializes an IDL snapshot, panicking with context on failure.
    fn parse(json: serde_json::Value) -> IdlSnapshot {
        serde_json::from_value(json).expect("deserialize IDL")
    }

    #[test]
    fn test_snapshot_serde() {
        let snapshot = IdlSnapshot {
            name: "test_program".to_string(),
            program_id: Some("11111111111111111111111111111111".to_string()),
            version: "0.1.0".to_string(),
            accounts: vec![IdlAccountSnapshot {
                name: "ExampleAccount".to_string(),
                discriminator: vec![1, 2, 3, 4, 5, 6, 7, 8],
                docs: vec!["Example account".to_string()],
                serialization: Some(IdlSerializationSnapshot::Borsh),
                fields: vec![],
                type_def: None,
            }],
            instructions: vec![IdlInstructionSnapshot {
                name: "example_instruction".to_string(),
                discriminator: vec![8, 7, 6, 5, 4, 3, 2, 1],
                discriminant: None,
                docs: vec!["Example instruction".to_string()],
                accounts: vec![IdlInstructionAccountSnapshot {
                    name: "payer".to_string(),
                    writable: true,
                    signer: true,
                    optional: false,
                    address: None,
                    docs: vec![],
                }],
                args: vec![IdlFieldSnapshot {
                    name: "amount".to_string(),
                    type_: IdlTypeSnapshot::HashMap(IdlHashMapTypeSnapshot {
                        hash_map: (
                            Box::new(IdlTypeSnapshot::Simple("u64".to_string())),
                            Box::new(IdlTypeSnapshot::Simple("string".to_string())),
                        ),
                    }),
                }],
            }],
            types: vec![IdlTypeDefSnapshot {
                name: "ExampleType".to_string(),
                docs: vec![],
                serialization: None,
                type_def: IdlTypeDefKindSnapshot::Struct {
                    kind: "struct".to_string(),
                    fields: vec![IdlFieldSnapshot {
                        name: "value".to_string(),
                        type_: IdlTypeSnapshot::Simple("u64".to_string()),
                    }],
                },
            }],
            events: vec![IdlEventSnapshot {
                name: "ExampleEvent".to_string(),
                discriminator: vec![0, 0, 0, 0, 0, 0, 0, 1],
                docs: vec![],
            }],
            errors: vec![IdlErrorSnapshot {
                code: 6000,
                name: "ExampleError".to_string(),
                msg: Some("example".to_string()),
            }],
            discriminant_size: 8,
        };

        let serialized = serde_json::to_value(&snapshot).expect("serialize snapshot");
        let deserialized: IdlSnapshot =
            serde_json::from_value(serialized.clone()).expect("deserialize snapshot");
        let round_trip = serde_json::to_value(&deserialized).expect("re-serialize snapshot");

        assert_eq!(serialized, round_trip);
        assert_eq!(deserialized.name, "test_program");
    }

    #[test]
    fn test_hashmap_compat() {
        let json = r#"{"hashMap":["u64","string"]}"#;
        let parsed: IdlHashMapTypeSnapshot =
            serde_json::from_str(json).expect("deserialize hashMap");

        assert!(matches!(
            parsed.hash_map.0.as_ref(),
            IdlTypeSnapshot::Simple(value) if value == "u64"
        ));
        assert!(matches!(
            parsed.hash_map.1.as_ref(),
            IdlTypeSnapshot::Simple(value) if value == "string"
        ));
    }

    #[test]
    fn test_hashmap_rejects_wrong_arity() {
        assert!(serde_json::from_str::<IdlHashMapTypeSnapshot>(r#"{"hashMap":["u64"]}"#).is_err());
        assert!(serde_json::from_str::<IdlHashMapTypeSnapshot>(
            r#"{"hashMap":["u64","u8","bool"]}"#
        )
        .is_err());
    }

    #[test]
    fn test_anchor_discriminators_infer_eight_bytes() {
        let snapshot = parse(idl_json(vec![instruction_json(
            "initialize",
            &[1, 2, 3, 4, 5, 6, 7, 8],
        )]));
        assert_eq!(snapshot.discriminant_size, 8);
    }

    #[test]
    fn test_one_byte_discriminators_infer_steel_width() {
        let snapshot = parse(idl_json(vec![
            instruction_json("open", &[0]),
            instruction_json("close", &[1]),
        ]));
        assert_eq!(snapshot.discriminant_size, 1);
        assert_eq!(snapshot.instructions[1].get_discriminator(), vec![1]);
    }

    #[test]
    fn test_mixed_or_empty_instructions_fall_back_to_eight_bytes() {
        let mixed = parse(idl_json(vec![
            instruction_json("open", &[0]),
            instruction_json("initialize", &[1, 2, 3, 4, 5, 6, 7, 8]),
        ]));
        assert_eq!(mixed.discriminant_size, 8);

        let empty = parse(idl_json(vec![]));
        assert_eq!(empty.discriminant_size, 8);
    }

    #[test]
    fn test_explicit_discriminant_size_overrides_inference() {
        let mut json = idl_json(vec![instruction_json("open", &[0])]);
        json["discriminant_size"] = serde_json::json!(8);
        assert_eq!(parse(json).discriminant_size, 8);
    }

    #[test]
    fn test_address_alias_sets_program_id() {
        let mut json = idl_json(vec![]);
        json["address"] = serde_json::json!("11111111111111111111111111111111");
        assert_eq!(
            parse(json).program_id.as_deref(),
            Some("11111111111111111111111111111111")
        );
    }

    #[test]
    fn test_account_fields_fall_back_to_inline_type() {
        let json = serde_json::json!({
            "name": "Counter",
            "discriminator": [1],
            "type": {
                "kind": "struct",
                "fields": [{ "name": "count", "type": "u64" }]
            }
        });
        let account: IdlAccountSnapshot =
            serde_json::from_value(json).expect("deserialize account");

        assert_eq!(account.fields.len(), 1);
        assert_eq!(account.fields[0].name, "count");
        assert!(account.type_def.is_some());
    }
}