1use serde::{de::Error as DeError, Deserialize, Deserializer, Serialize};
4
5use crate::types::SteelDiscriminant;
6
7#[derive(Debug, Clone, Serialize)]
8pub struct IdlSnapshot {
9 pub name: String,
10 #[serde(default, skip_serializing_if = "Option::is_none", alias = "address")]
11 pub program_id: Option<String>,
12 pub version: String,
13 pub accounts: Vec<IdlAccountSnapshot>,
14 pub instructions: Vec<IdlInstructionSnapshot>,
15 #[serde(default)]
16 pub types: Vec<IdlTypeDefSnapshot>,
17 #[serde(default)]
18 pub events: Vec<IdlEventSnapshot>,
19 #[serde(default)]
20 pub errors: Vec<IdlErrorSnapshot>,
21 pub discriminant_size: usize,
22}
23
24impl<'de> Deserialize<'de> for IdlSnapshot {
25 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
26 where
27 D: Deserializer<'de>,
28 {
29 let value = serde_json::Value::deserialize(deserializer)?;
31
32 let discriminant_size = value
34 .get("instructions")
35 .and_then(|instrs| instrs.as_array())
36 .map(|instrs| {
37 if instrs.is_empty() {
38 return false;
39 }
40 instrs.iter().all(|ix| {
41 let has_discriminant = ix.get("discriminant").is_some();
43 let discriminator = ix.get("discriminator");
44 let has_discriminator = discriminator
45 .map(|d| {
46 !d.is_null() && d.as_array().map(|a| !a.is_empty()).unwrap_or(true)
47 })
48 .unwrap_or(false);
49 has_discriminant && !has_discriminator
50 })
51 })
52 .map(|is_steel| if is_steel { 1 } else { 8 })
53 .unwrap_or(8); let mut intermediate: IdlSnapshotIntermediate = serde_json::from_value(value)
57 .map_err(|e| DeError::custom(format!("Failed to deserialize IDL: {}", e)))?;
58 intermediate.discriminant_size = discriminant_size;
59
60 Ok(IdlSnapshot {
61 name: intermediate.name,
62 program_id: intermediate.program_id,
63 version: intermediate.version,
64 accounts: intermediate.accounts,
65 instructions: intermediate.instructions,
66 types: intermediate.types,
67 events: intermediate.events,
68 errors: intermediate.errors,
69 discriminant_size: intermediate.discriminant_size,
70 })
71 }
72}
73
74#[derive(Debug, Clone, Deserialize)]
76struct IdlSnapshotIntermediate {
77 pub name: String,
78 #[serde(default, alias = "address")]
79 pub program_id: Option<String>,
80 pub version: String,
81 pub accounts: Vec<IdlAccountSnapshot>,
82 pub instructions: Vec<IdlInstructionSnapshot>,
83 #[serde(default)]
84 pub types: Vec<IdlTypeDefSnapshot>,
85 #[serde(default)]
86 pub events: Vec<IdlEventSnapshot>,
87 #[serde(default)]
88 pub errors: Vec<IdlErrorSnapshot>,
89 #[serde(default)]
90 pub discriminant_size: usize,
91}
92
93#[derive(Debug, Clone, Serialize)]
94pub struct IdlAccountSnapshot {
95 pub name: String,
96 pub discriminator: Vec<u8>,
97 pub docs: Vec<String>,
98 pub serialization: Option<IdlSerializationSnapshot>,
99 pub fields: Vec<IdlFieldSnapshot>,
101 #[serde(skip_serializing_if = "Option::is_none")]
103 pub type_def: Option<IdlInlineTypeDef>,
104}
105
106#[derive(Deserialize)]
108struct IdlAccountSnapshotIntermediate {
109 pub name: String,
110 pub discriminator: Vec<u8>,
111 #[serde(default)]
112 pub docs: Vec<String>,
113 #[serde(default, skip_serializing_if = "Option::is_none")]
114 pub serialization: Option<IdlSerializationSnapshot>,
115 #[serde(default)]
116 pub fields: Vec<IdlFieldSnapshot>,
117 #[serde(rename = "type", default, skip_serializing_if = "Option::is_none")]
118 pub type_def: Option<IdlInlineTypeDef>,
119}
120
121impl<'de> Deserialize<'de> for IdlAccountSnapshot {
122 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
123 where
124 D: Deserializer<'de>,
125 {
126 let intermediate = IdlAccountSnapshotIntermediate::deserialize(deserializer)?;
127
128 let fields = if intermediate.fields.is_empty() {
130 if let Some(type_def) = intermediate.type_def.as_ref() {
131 type_def.fields.clone()
132 } else {
133 intermediate.fields
134 }
135 } else {
136 intermediate.fields
137 };
138
139 Ok(IdlAccountSnapshot {
140 name: intermediate.name,
141 discriminator: intermediate.discriminator,
142 docs: intermediate.docs,
143 serialization: intermediate.serialization,
144 fields,
145 type_def: intermediate.type_def,
146 })
147 }
148}
149
150#[derive(Debug, Clone, Serialize, Deserialize)]
152pub struct IdlInlineTypeDef {
153 pub kind: String,
154 pub fields: Vec<IdlFieldSnapshot>,
155}
156
157#[derive(Debug, Clone, Serialize, Deserialize)]
158pub struct IdlInstructionSnapshot {
159 pub name: String,
160 #[serde(default)]
161 pub discriminator: Vec<u8>,
162 #[serde(default)]
163 pub discriminant: Option<SteelDiscriminant>,
164 #[serde(default)]
165 pub docs: Vec<String>,
166 pub accounts: Vec<IdlInstructionAccountSnapshot>,
167 pub args: Vec<IdlFieldSnapshot>,
168}
169
170impl IdlInstructionSnapshot {
171 pub fn get_discriminator(&self) -> Vec<u8> {
174 if !self.discriminator.is_empty() {
175 return self.discriminator.clone();
176 }
177
178 if let Some(disc) = &self.discriminant {
179 match u8::try_from(disc.value) {
180 Ok(value) => return vec![value],
181 Err(_) => {
182 tracing::warn!(
183 instruction = %self.name,
184 value = disc.value,
185 "Steel discriminant exceeds u8::MAX; falling back to Anchor hash"
186 );
187 }
188 }
189 }
190
191 crate::discriminator::anchor_discriminator(&format!("global:{}", self.name))
192 }
193}
194
195#[derive(Debug, Clone, Serialize, Deserialize)]
196pub struct IdlInstructionAccountSnapshot {
197 pub name: String,
198 #[serde(default)]
199 pub writable: bool,
200 #[serde(default)]
201 pub signer: bool,
202 #[serde(default)]
203 pub optional: bool,
204 #[serde(default)]
205 pub address: Option<String>,
206 #[serde(default)]
207 pub docs: Vec<String>,
208}
209
210#[derive(Debug, Clone, Serialize, Deserialize)]
211pub struct IdlFieldSnapshot {
212 pub name: String,
213 #[serde(rename = "type")]
214 pub type_: IdlTypeSnapshot,
215}
216
217#[derive(Debug, Clone, Serialize, Deserialize)]
218#[serde(untagged)]
219pub enum IdlTypeSnapshot {
220 Simple(String),
221 Array(IdlArrayTypeSnapshot),
222 Option(IdlOptionTypeSnapshot),
223 Vec(IdlVecTypeSnapshot),
224 HashMap(IdlHashMapTypeSnapshot),
225 Defined(IdlDefinedTypeSnapshot),
226}
227
228#[derive(Debug, Clone, Serialize, Deserialize)]
229pub struct IdlHashMapTypeSnapshot {
230 #[serde(rename = "hashMap", deserialize_with = "deserialize_hash_map")]
231 pub hash_map: (Box<IdlTypeSnapshot>, Box<IdlTypeSnapshot>),
232}
233
234fn deserialize_hash_map<'de, D>(
235 deserializer: D,
236) -> Result<(Box<IdlTypeSnapshot>, Box<IdlTypeSnapshot>), D::Error>
237where
238 D: Deserializer<'de>,
239{
240 use serde::de::Error;
241 let values: Vec<IdlTypeSnapshot> = Vec::deserialize(deserializer)?;
242 if values.len() != 2 {
243 return Err(D::Error::custom("hashMap must have exactly 2 elements"));
244 }
245 let mut iter = values.into_iter();
246 Ok((
247 Box::new(iter.next().expect("length checked")),
248 Box::new(iter.next().expect("length checked")),
249 ))
250}
251
252#[derive(Debug, Clone, Serialize, Deserialize)]
253pub struct IdlArrayTypeSnapshot {
254 pub array: Vec<IdlArrayElementSnapshot>,
255}
256
257#[derive(Debug, Clone, Serialize, Deserialize)]
258#[serde(untagged)]
259pub enum IdlArrayElementSnapshot {
260 Type(IdlTypeSnapshot),
261 TypeName(String),
262 Size(u32),
263}
264
265#[derive(Debug, Clone, Serialize, Deserialize)]
266pub struct IdlOptionTypeSnapshot {
267 pub option: Box<IdlTypeSnapshot>,
268}
269
270#[derive(Debug, Clone, Serialize, Deserialize)]
271pub struct IdlVecTypeSnapshot {
272 pub vec: Box<IdlTypeSnapshot>,
273}
274
275#[derive(Debug, Clone, Serialize, Deserialize)]
276pub struct IdlDefinedTypeSnapshot {
277 pub defined: IdlDefinedInnerSnapshot,
278}
279
280#[derive(Debug, Clone, Serialize, Deserialize)]
281#[serde(untagged)]
282pub enum IdlDefinedInnerSnapshot {
283 Named { name: String },
284 Simple(String),
285}
286
287#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
288#[serde(rename_all = "lowercase")]
289pub enum IdlSerializationSnapshot {
290 Borsh,
291 Bytemuck,
292 #[serde(alias = "bytemuckunsafe")]
293 BytemuckUnsafe,
294}
295
296#[derive(Debug, Clone, Serialize, Deserialize)]
297pub struct IdlTypeDefSnapshot {
298 pub name: String,
299 #[serde(default)]
300 pub docs: Vec<String>,
301 #[serde(default, skip_serializing_if = "Option::is_none")]
302 pub serialization: Option<IdlSerializationSnapshot>,
303 #[serde(rename = "type")]
304 pub type_def: IdlTypeDefKindSnapshot,
305}
306
307#[derive(Debug, Clone, Serialize, Deserialize)]
308#[serde(untagged)]
309pub enum IdlTypeDefKindSnapshot {
310 Struct {
311 kind: String,
312 fields: Vec<IdlFieldSnapshot>,
313 },
314 TupleStruct {
315 kind: String,
316 fields: Vec<IdlTypeSnapshot>,
317 },
318 Enum {
319 kind: String,
320 variants: Vec<IdlEnumVariantSnapshot>,
321 },
322}
323
324#[derive(Debug, Clone, Serialize, Deserialize)]
325pub struct IdlEnumVariantSnapshot {
326 pub name: String,
327}
328
329#[derive(Debug, Clone, Serialize, Deserialize)]
330pub struct IdlEventSnapshot {
331 pub name: String,
332 pub discriminator: Vec<u8>,
333 #[serde(default)]
334 pub docs: Vec<String>,
335}
336
337#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
338pub struct IdlErrorSnapshot {
339 pub code: u32,
340 pub name: String,
341 #[serde(default, skip_serializing_if = "Option::is_none")]
342 pub msg: Option<String>,
343}
344
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a `Simple` type reference.
    fn simple(ty: &str) -> IdlTypeSnapshot {
        IdlTypeSnapshot::Simple(ty.to_string())
    }

    /// Builds a snapshot that populates every top-level section.
    fn sample_snapshot() -> IdlSnapshot {
        let account = IdlAccountSnapshot {
            name: "ExampleAccount".to_string(),
            discriminator: vec![1, 2, 3, 4, 5, 6, 7, 8],
            docs: vec!["Example account".to_string()],
            serialization: Some(IdlSerializationSnapshot::Borsh),
            fields: Vec::new(),
            type_def: None,
        };

        let payer = IdlInstructionAccountSnapshot {
            name: "payer".to_string(),
            writable: true,
            signer: true,
            optional: false,
            address: None,
            docs: Vec::new(),
        };

        let amount = IdlFieldSnapshot {
            name: "amount".to_string(),
            type_: IdlTypeSnapshot::HashMap(IdlHashMapTypeSnapshot {
                hash_map: (Box::new(simple("u64")), Box::new(simple("string"))),
            }),
        };

        let instruction = IdlInstructionSnapshot {
            name: "example_instruction".to_string(),
            discriminator: vec![8, 7, 6, 5, 4, 3, 2, 1],
            discriminant: None,
            docs: vec!["Example instruction".to_string()],
            accounts: vec![payer],
            args: vec![amount],
        };

        let example_type = IdlTypeDefSnapshot {
            name: "ExampleType".to_string(),
            docs: Vec::new(),
            serialization: None,
            type_def: IdlTypeDefKindSnapshot::Struct {
                kind: "struct".to_string(),
                fields: vec![IdlFieldSnapshot {
                    name: "value".to_string(),
                    type_: simple("u64"),
                }],
            },
        };

        let event = IdlEventSnapshot {
            name: "ExampleEvent".to_string(),
            discriminator: vec![0, 0, 0, 0, 0, 0, 0, 1],
            docs: Vec::new(),
        };

        let error = IdlErrorSnapshot {
            code: 6000,
            name: "ExampleError".to_string(),
            msg: Some("example".to_string()),
        };

        IdlSnapshot {
            name: "test_program".to_string(),
            program_id: Some("11111111111111111111111111111111".to_string()),
            version: "0.1.0".to_string(),
            accounts: vec![account],
            instructions: vec![instruction],
            types: vec![example_type],
            events: vec![event],
            errors: vec![error],
            discriminant_size: 8,
        }
    }

    /// Serializing, deserializing and re-serializing must be lossless.
    #[test]
    fn test_snapshot_serde() {
        let serialized = serde_json::to_value(sample_snapshot()).expect("serialize snapshot");
        let deserialized: IdlSnapshot =
            serde_json::from_value(serialized.clone()).expect("deserialize snapshot");
        let round_trip = serde_json::to_value(&deserialized).expect("re-serialize snapshot");

        assert_eq!(serialized, round_trip);
        assert_eq!(deserialized.name, "test_program");
    }

    /// The two-element `hashMap` array form parses into an ordered pair.
    #[test]
    fn test_hashmap_compat() {
        let parsed: IdlHashMapTypeSnapshot =
            serde_json::from_str(r#"{"hashMap":["u64","string"]}"#).expect("deserialize hashMap");
        let (first, second) = &parsed.hash_map;

        assert!(matches!(first.as_ref(), IdlTypeSnapshot::Simple(s) if s == "u64"));
        assert!(matches!(second.as_ref(), IdlTypeSnapshot::Simple(s) if s == "string"));
    }
}