Skip to main content

hyperstack_idl/
snapshot.rs

1//! Snapshot type definitions
2
3use serde::{de::Error as DeError, Deserialize, Deserializer, Serialize};
4
5use crate::types::SteelDiscriminant;
6
7#[derive(Debug, Clone, Serialize)]
8pub struct IdlSnapshot {
9    pub name: String,
10    #[serde(default, skip_serializing_if = "Option::is_none", alias = "address")]
11    pub program_id: Option<String>,
12    pub version: String,
13    pub accounts: Vec<IdlAccountSnapshot>,
14    pub instructions: Vec<IdlInstructionSnapshot>,
15    #[serde(default)]
16    pub types: Vec<IdlTypeDefSnapshot>,
17    #[serde(default)]
18    pub events: Vec<IdlEventSnapshot>,
19    #[serde(default)]
20    pub errors: Vec<IdlErrorSnapshot>,
21    pub discriminant_size: usize,
22}
23
24impl<'de> Deserialize<'de> for IdlSnapshot {
25    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
26    where
27        D: Deserializer<'de>,
28    {
29        // First deserialize to a generic Value to inspect instructions
30        let value = serde_json::Value::deserialize(deserializer)?;
31
32        // Check if any instruction has discriminant (Steel-style) vs discriminator (Anchor-style)
33        let discriminant_size = value
34            .get("instructions")
35            .and_then(|instrs| instrs.as_array())
36            .map(|instrs| {
37                if instrs.is_empty() {
38                    return false;
39                }
40                instrs.iter().all(|ix| {
41                    // Steel-style: has discriminant field, no discriminator or empty discriminator
42                    let has_discriminant = ix.get("discriminant").is_some();
43                    let discriminator = ix.get("discriminator");
44                    let has_discriminator = discriminator
45                        .map(|d| {
46                            !d.is_null() && d.as_array().map(|a| !a.is_empty()).unwrap_or(true)
47                        })
48                        .unwrap_or(false);
49                    has_discriminant && !has_discriminator
50                })
51            })
52            .map(|is_steel| if is_steel { 1 } else { 8 })
53            .unwrap_or(8); // Default to 8 if no instructions
54
55        // Now deserialize the full struct
56        let mut intermediate: IdlSnapshotIntermediate = serde_json::from_value(value)
57            .map_err(|e| DeError::custom(format!("Failed to deserialize IDL: {}", e)))?;
58        intermediate.discriminant_size = discriminant_size;
59
60        Ok(IdlSnapshot {
61            name: intermediate.name,
62            program_id: intermediate.program_id,
63            version: intermediate.version,
64            accounts: intermediate.accounts,
65            instructions: intermediate.instructions,
66            types: intermediate.types,
67            events: intermediate.events,
68            errors: intermediate.errors,
69            discriminant_size: intermediate.discriminant_size,
70        })
71    }
72}
73
74// Intermediate struct for deserialization
75#[derive(Debug, Clone, Deserialize)]
76struct IdlSnapshotIntermediate {
77    pub name: String,
78    #[serde(default, alias = "address")]
79    pub program_id: Option<String>,
80    pub version: String,
81    pub accounts: Vec<IdlAccountSnapshot>,
82    pub instructions: Vec<IdlInstructionSnapshot>,
83    #[serde(default)]
84    pub types: Vec<IdlTypeDefSnapshot>,
85    #[serde(default)]
86    pub events: Vec<IdlEventSnapshot>,
87    #[serde(default)]
88    pub errors: Vec<IdlErrorSnapshot>,
89    #[serde(default)]
90    pub discriminant_size: usize,
91}
92
93#[derive(Debug, Clone, Serialize)]
94pub struct IdlAccountSnapshot {
95    pub name: String,
96    pub discriminator: Vec<u8>,
97    pub docs: Vec<String>,
98    pub serialization: Option<IdlSerializationSnapshot>,
99    /// Account fields - populated from inline type definition
100    pub fields: Vec<IdlFieldSnapshot>,
101    /// Inline type definition (for Steel format with type.fields structure)
102    #[serde(skip_serializing_if = "Option::is_none")]
103    pub type_def: Option<IdlInlineTypeDef>,
104}
105
106// Intermediate struct for deserialization
107#[derive(Deserialize)]
108struct IdlAccountSnapshotIntermediate {
109    pub name: String,
110    pub discriminator: Vec<u8>,
111    #[serde(default)]
112    pub docs: Vec<String>,
113    #[serde(default, skip_serializing_if = "Option::is_none")]
114    pub serialization: Option<IdlSerializationSnapshot>,
115    #[serde(default)]
116    pub fields: Vec<IdlFieldSnapshot>,
117    #[serde(rename = "type", default, skip_serializing_if = "Option::is_none")]
118    pub type_def: Option<IdlInlineTypeDef>,
119}
120
121impl<'de> Deserialize<'de> for IdlAccountSnapshot {
122    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
123    where
124        D: Deserializer<'de>,
125    {
126        let intermediate = IdlAccountSnapshotIntermediate::deserialize(deserializer)?;
127
128        // Normalize fields: if empty but type_def has fields, use those
129        let fields = if intermediate.fields.is_empty() {
130            if let Some(type_def) = intermediate.type_def.as_ref() {
131                type_def.fields.clone()
132            } else {
133                intermediate.fields
134            }
135        } else {
136            intermediate.fields
137        };
138
139        Ok(IdlAccountSnapshot {
140            name: intermediate.name,
141            discriminator: intermediate.discriminator,
142            docs: intermediate.docs,
143            serialization: intermediate.serialization,
144            fields,
145            type_def: intermediate.type_def,
146        })
147    }
148}
149
150/// Inline type definition for account fields (Steel format)
151#[derive(Debug, Clone, Serialize, Deserialize)]
152pub struct IdlInlineTypeDef {
153    pub kind: String,
154    pub fields: Vec<IdlFieldSnapshot>,
155}
156
157#[derive(Debug, Clone, Serialize, Deserialize)]
158pub struct IdlInstructionSnapshot {
159    pub name: String,
160    #[serde(default)]
161    pub discriminator: Vec<u8>,
162    #[serde(default)]
163    pub discriminant: Option<SteelDiscriminant>,
164    #[serde(default)]
165    pub docs: Vec<String>,
166    pub accounts: Vec<IdlInstructionAccountSnapshot>,
167    pub args: Vec<IdlFieldSnapshot>,
168}
169
170impl IdlInstructionSnapshot {
171    /// Get the computed 8-byte discriminator.
172    /// Returns the explicit discriminator if present, otherwise computes from discriminant.
173    pub fn get_discriminator(&self) -> Vec<u8> {
174        if !self.discriminator.is_empty() {
175            return self.discriminator.clone();
176        }
177
178        if let Some(disc) = &self.discriminant {
179            match u8::try_from(disc.value) {
180                Ok(value) => return vec![value],
181                Err(_) => {
182                    tracing::warn!(
183                        instruction = %self.name,
184                        value = disc.value,
185                        "Steel discriminant exceeds u8::MAX; falling back to Anchor hash"
186                    );
187                }
188            }
189        }
190
191        crate::discriminator::anchor_discriminator(&format!("global:{}", self.name))
192    }
193}
194
195#[derive(Debug, Clone, Serialize, Deserialize)]
196pub struct IdlInstructionAccountSnapshot {
197    pub name: String,
198    #[serde(default)]
199    pub writable: bool,
200    #[serde(default)]
201    pub signer: bool,
202    #[serde(default)]
203    pub optional: bool,
204    #[serde(default)]
205    pub address: Option<String>,
206    #[serde(default)]
207    pub docs: Vec<String>,
208}
209
210#[derive(Debug, Clone, Serialize, Deserialize)]
211pub struct IdlFieldSnapshot {
212    pub name: String,
213    #[serde(rename = "type")]
214    pub type_: IdlTypeSnapshot,
215}
216
217#[derive(Debug, Clone, Serialize, Deserialize)]
218#[serde(untagged)]
219pub enum IdlTypeSnapshot {
220    Simple(String),
221    Array(IdlArrayTypeSnapshot),
222    Option(IdlOptionTypeSnapshot),
223    Vec(IdlVecTypeSnapshot),
224    HashMap(IdlHashMapTypeSnapshot),
225    Defined(IdlDefinedTypeSnapshot),
226}
227
228#[derive(Debug, Clone, Serialize, Deserialize)]
229pub struct IdlHashMapTypeSnapshot {
230    #[serde(rename = "hashMap", deserialize_with = "deserialize_hash_map")]
231    pub hash_map: (Box<IdlTypeSnapshot>, Box<IdlTypeSnapshot>),
232}
233
234fn deserialize_hash_map<'de, D>(
235    deserializer: D,
236) -> Result<(Box<IdlTypeSnapshot>, Box<IdlTypeSnapshot>), D::Error>
237where
238    D: Deserializer<'de>,
239{
240    use serde::de::Error;
241    let values: Vec<IdlTypeSnapshot> = Vec::deserialize(deserializer)?;
242    if values.len() != 2 {
243        return Err(D::Error::custom("hashMap must have exactly 2 elements"));
244    }
245    let mut iter = values.into_iter();
246    Ok((
247        Box::new(iter.next().expect("length checked")),
248        Box::new(iter.next().expect("length checked")),
249    ))
250}
251
252#[derive(Debug, Clone, Serialize, Deserialize)]
253pub struct IdlArrayTypeSnapshot {
254    pub array: Vec<IdlArrayElementSnapshot>,
255}
256
257#[derive(Debug, Clone, Serialize, Deserialize)]
258#[serde(untagged)]
259pub enum IdlArrayElementSnapshot {
260    Type(IdlTypeSnapshot),
261    TypeName(String),
262    Size(u32),
263}
264
265#[derive(Debug, Clone, Serialize, Deserialize)]
266pub struct IdlOptionTypeSnapshot {
267    pub option: Box<IdlTypeSnapshot>,
268}
269
270#[derive(Debug, Clone, Serialize, Deserialize)]
271pub struct IdlVecTypeSnapshot {
272    pub vec: Box<IdlTypeSnapshot>,
273}
274
275#[derive(Debug, Clone, Serialize, Deserialize)]
276pub struct IdlDefinedTypeSnapshot {
277    pub defined: IdlDefinedInnerSnapshot,
278}
279
280#[derive(Debug, Clone, Serialize, Deserialize)]
281#[serde(untagged)]
282pub enum IdlDefinedInnerSnapshot {
283    Named { name: String },
284    Simple(String),
285}
286
287#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
288#[serde(rename_all = "lowercase")]
289pub enum IdlSerializationSnapshot {
290    Borsh,
291    Bytemuck,
292    #[serde(alias = "bytemuckunsafe")]
293    BytemuckUnsafe,
294}
295
296#[derive(Debug, Clone, Serialize, Deserialize)]
297pub struct IdlTypeDefSnapshot {
298    pub name: String,
299    #[serde(default)]
300    pub docs: Vec<String>,
301    #[serde(default, skip_serializing_if = "Option::is_none")]
302    pub serialization: Option<IdlSerializationSnapshot>,
303    #[serde(rename = "type")]
304    pub type_def: IdlTypeDefKindSnapshot,
305}
306
307#[derive(Debug, Clone, Serialize, Deserialize)]
308#[serde(untagged)]
309pub enum IdlTypeDefKindSnapshot {
310    Struct {
311        kind: String,
312        fields: Vec<IdlFieldSnapshot>,
313    },
314    TupleStruct {
315        kind: String,
316        fields: Vec<IdlTypeSnapshot>,
317    },
318    Enum {
319        kind: String,
320        variants: Vec<IdlEnumVariantSnapshot>,
321    },
322}
323
324#[derive(Debug, Clone, Serialize, Deserialize)]
325pub struct IdlEnumVariantSnapshot {
326    pub name: String,
327}
328
329#[derive(Debug, Clone, Serialize, Deserialize)]
330pub struct IdlEventSnapshot {
331    pub name: String,
332    pub discriminator: Vec<u8>,
333    #[serde(default)]
334    pub docs: Vec<String>,
335}
336
337#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
338pub struct IdlErrorSnapshot {
339    pub code: u32,
340    pub name: String,
341    #[serde(default, skip_serializing_if = "Option::is_none")]
342    pub msg: Option<String>,
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_snapshot_serde() {
        let snapshot = IdlSnapshot {
            name: "test_program".to_string(),
            program_id: Some("11111111111111111111111111111111".to_string()),
            version: "0.1.0".to_string(),
            accounts: vec![IdlAccountSnapshot {
                name: "ExampleAccount".to_string(),
                discriminator: vec![1, 2, 3, 4, 5, 6, 7, 8],
                docs: vec!["Example account".to_string()],
                serialization: Some(IdlSerializationSnapshot::Borsh),
                fields: vec![],
                type_def: None,
            }],
            instructions: vec![IdlInstructionSnapshot {
                name: "example_instruction".to_string(),
                discriminator: vec![8, 7, 6, 5, 4, 3, 2, 1],
                discriminant: None,
                docs: vec!["Example instruction".to_string()],
                accounts: vec![IdlInstructionAccountSnapshot {
                    name: "payer".to_string(),
                    writable: true,
                    signer: true,
                    optional: false,
                    address: None,
                    docs: vec![],
                }],
                args: vec![IdlFieldSnapshot {
                    name: "amount".to_string(),
                    type_: IdlTypeSnapshot::HashMap(IdlHashMapTypeSnapshot {
                        hash_map: (
                            Box::new(IdlTypeSnapshot::Simple("u64".to_string())),
                            Box::new(IdlTypeSnapshot::Simple("string".to_string())),
                        ),
                    }),
                }],
            }],
            types: vec![IdlTypeDefSnapshot {
                name: "ExampleType".to_string(),
                docs: vec![],
                serialization: None,
                type_def: IdlTypeDefKindSnapshot::Struct {
                    kind: "struct".to_string(),
                    fields: vec![IdlFieldSnapshot {
                        name: "value".to_string(),
                        type_: IdlTypeSnapshot::Simple("u64".to_string()),
                    }],
                },
            }],
            events: vec![IdlEventSnapshot {
                name: "ExampleEvent".to_string(),
                discriminator: vec![0, 0, 0, 0, 0, 0, 0, 1],
                docs: vec![],
            }],
            errors: vec![IdlErrorSnapshot {
                code: 6000,
                name: "ExampleError".to_string(),
                msg: Some("example".to_string()),
            }],
            discriminant_size: 8,
        };

        let serialized = serde_json::to_value(&snapshot).expect("serialize snapshot");
        let deserialized: IdlSnapshot =
            serde_json::from_value(serialized.clone()).expect("deserialize snapshot");
        let round_trip = serde_json::to_value(&deserialized).expect("re-serialize snapshot");

        assert_eq!(serialized, round_trip);
        assert_eq!(deserialized.name, "test_program");
    }

    #[test]
    fn test_hashmap_compat() {
        let json = r#"{"hashMap":["u64","string"]}"#;
        let parsed: IdlHashMapTypeSnapshot =
            serde_json::from_str(json).expect("deserialize hashMap");

        assert!(matches!(
            parsed.hash_map.0.as_ref(),
            IdlTypeSnapshot::Simple(value) if value == "u64"
        ));
        assert!(matches!(
            parsed.hash_map.1.as_ref(),
            IdlTypeSnapshot::Simple(value) if value == "string"
        ));
    }

    /// `hashMap` with the wrong number of elements must be rejected.
    #[test]
    fn test_hashmap_rejects_wrong_arity() {
        let one = serde_json::from_str::<IdlHashMapTypeSnapshot>(r#"{"hashMap":["u64"]}"#);
        assert!(one.is_err());
        let three =
            serde_json::from_str::<IdlHashMapTypeSnapshot>(r#"{"hashMap":["u8","u16","u32"]}"#);
        assert!(three.is_err());
    }

    /// With no instructions the width defaults to 8, and any input value is overridden.
    #[test]
    fn test_discriminant_size_defaults_to_eight_without_instructions() {
        let json = serde_json::json!({
            "name": "empty",
            "version": "0.1.0",
            "accounts": [],
            "instructions": [],
            "discriminant_size": 1
        });
        let parsed: IdlSnapshot = serde_json::from_value(json).expect("deserialize snapshot");

        assert_eq!(parsed.discriminant_size, 8);
        assert!(parsed.types.is_empty());
        assert!(parsed.events.is_empty());
        assert!(parsed.errors.is_empty());
        assert!(parsed.program_id.is_none());
    }

    /// The program address may be supplied under the `address` key.
    #[test]
    fn test_program_id_address_alias() {
        let json = serde_json::json!({
            "name": "aliased",
            "address": "11111111111111111111111111111111",
            "version": "0.1.0",
            "accounts": [],
            "instructions": []
        });
        let parsed: IdlSnapshot = serde_json::from_value(json).expect("deserialize snapshot");

        assert_eq!(
            parsed.program_id.as_deref(),
            Some("11111111111111111111111111111111")
        );
    }

    /// Steel-format accounts get their `fields` from the inline `type.fields`.
    #[test]
    fn test_account_fields_from_inline_type() {
        let json = r#"{
            "name": "Counter",
            "discriminator": [1],
            "type": {"kind": "struct", "fields": [{"name": "count", "type": "u64"}]}
        }"#;
        let account: IdlAccountSnapshot =
            serde_json::from_str(json).expect("deserialize account");

        assert_eq!(account.fields.len(), 1);
        assert_eq!(account.fields[0].name, "count");
        assert!(account.type_def.is_some());
        assert!(account.docs.is_empty());
        assert!(account.serialization.is_none());
    }

    /// Serialization formats use lowercase names in both directions.
    #[test]
    fn test_serialization_lowercase() {
        let parsed: IdlSerializationSnapshot =
            serde_json::from_str(r#""bytemuckunsafe""#).expect("deserialize serialization");
        assert_eq!(parsed, IdlSerializationSnapshot::BytemuckUnsafe);

        let encoded =
            serde_json::to_string(&IdlSerializationSnapshot::Borsh).expect("serialize borsh");
        assert_eq!(encoded, r#""borsh""#);
    }
}