1use serde::{de::Error as DeError, Deserialize, Deserializer, Serialize};
4
5use crate::types::SteelDiscriminant;
6
/// Snapshot of a program IDL: its identity plus the accounts, instructions, types,
/// events and errors it declares.
///
/// Only `Serialize` is derived here; input is handled by the hand-written
/// `Deserialize` impl for this type.
#[derive(Debug, Clone, Serialize)]
pub struct IdlSnapshot {
    /// Program name.
    pub name: String,
    /// Program id; left out of the serialized output when `None`.
    // NOTE(review): `default` and `alias` are deserialization attributes, and this struct
    // only derives `Serialize`, so they appear to have no effect here — confirm before
    // removing. The same applies to the `#[serde(default)]` attributes below.
    #[serde(default, skip_serializing_if = "Option::is_none", alias = "address")]
    pub program_id: Option<String>,
    /// Version string.
    pub version: String,
    /// Account definitions.
    pub accounts: Vec<IdlAccountSnapshot>,
    /// Instruction definitions.
    pub instructions: Vec<IdlInstructionSnapshot>,
    /// Type definitions.
    #[serde(default)]
    pub types: Vec<IdlTypeDefSnapshot>,
    /// Event definitions.
    #[serde(default)]
    pub events: Vec<IdlEventSnapshot>,
    /// Error definitions.
    #[serde(default)]
    pub errors: Vec<IdlErrorSnapshot>,
    /// Discriminator size for this program. When the input carries no non-zero value,
    /// the `Deserialize` impl fills in 1 or 8 based on the shape of the instructions.
    /// Presumably a byte count — TODO confirm.
    pub discriminant_size: usize,
}
23
24impl<'de> Deserialize<'de> for IdlSnapshot {
25 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
26 where
27 D: Deserializer<'de>,
28 {
29 let value = serde_json::Value::deserialize(deserializer)?;
31
32 let discriminant_size = value
34 .get("instructions")
35 .and_then(|instrs| instrs.as_array())
36 .map(|instrs| {
37 if instrs.is_empty() {
38 return false;
39 }
40 instrs.iter().all(|ix| {
41 let discriminator = ix.get("discriminator");
42 let disc_len = discriminator
43 .and_then(|d| d.as_array())
44 .map(|a| a.len())
45 .unwrap_or(0);
46
47 let has_discriminant = ix
52 .get("discriminant")
53 .map(|v| !v.is_null())
54 .unwrap_or(false);
55 let has_discriminator = discriminator
56 .map(|d| {
57 !d.is_null() && d.as_array().map(|a| !a.is_empty()).unwrap_or(true)
58 })
59 .unwrap_or(false);
60
61 let is_steel_discriminant = has_discriminant && !has_discriminator;
63
64 let is_steel_short_discriminator = !has_discriminant && disc_len == 1;
68
69 is_steel_discriminant || is_steel_short_discriminator
70 })
71 })
72 .map(|is_steel| if is_steel { 1 } else { 8 })
73 .unwrap_or(8); let mut intermediate: IdlSnapshotIntermediate = serde_json::from_value(value)
77 .map_err(|e| DeError::custom(format!("Failed to deserialize IDL: {}", e)))?;
78 if intermediate.discriminant_size == 0 {
81 intermediate.discriminant_size = discriminant_size;
82 }
83
84 Ok(IdlSnapshot {
85 name: intermediate.name,
86 program_id: intermediate.program_id,
87 version: intermediate.version,
88 accounts: intermediate.accounts,
89 instructions: intermediate.instructions,
90 types: intermediate.types,
91 events: intermediate.events,
92 errors: intermediate.errors,
93 discriminant_size: intermediate.discriminant_size,
94 })
95 }
96}
97
/// Private deserialization mirror of [`IdlSnapshot`], used by its manual `Deserialize`
/// impl.
///
/// Fields map one-to-one onto `IdlSnapshot`. The difference is on the input side:
/// several keys may be absent, including `discriminant_size`.
#[derive(Debug, Clone, Deserialize)]
struct IdlSnapshotIntermediate {
    pub name: String,
    /// Read from either the `program_id` or the `address` key; `None` when absent.
    #[serde(default, alias = "address")]
    pub program_id: Option<String>,
    pub version: String,
    pub accounts: Vec<IdlAccountSnapshot>,
    pub instructions: Vec<IdlInstructionSnapshot>,
    /// Empty when the key is absent.
    #[serde(default)]
    pub types: Vec<IdlTypeDefSnapshot>,
    /// Empty when the key is absent.
    #[serde(default)]
    pub events: Vec<IdlEventSnapshot>,
    /// Empty when the key is absent.
    #[serde(default)]
    pub errors: Vec<IdlErrorSnapshot>,
    /// Defaults to 0 when the key is absent; `IdlSnapshot::deserialize` treats 0 as
    /// "not provided" and substitutes an inferred value.
    #[serde(default)]
    pub discriminant_size: usize,
}
116
/// Snapshot of one account definition.
///
/// Only `Serialize` is derived; the manual `Deserialize` impl for this type goes
/// through `IdlAccountSnapshotIntermediate`.
#[derive(Debug, Clone, Serialize)]
pub struct IdlAccountSnapshot {
    /// Account name.
    pub name: String,
    /// Discriminator bytes for this account.
    pub discriminator: Vec<u8>,
    /// Documentation lines.
    pub docs: Vec<String>,
    /// Serialization format, when one is declared.
    pub serialization: Option<IdlSerializationSnapshot>,
    /// Account fields. On deserialization these are copied from the inline type
    /// definition when the input's own field list is empty.
    pub fields: Vec<IdlFieldSnapshot>,
    /// Inline type definition; left out of the serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub type_def: Option<IdlInlineTypeDef>,
}
129
130#[derive(Deserialize)]
132struct IdlAccountSnapshotIntermediate {
133 pub name: String,
134 pub discriminator: Vec<u8>,
135 #[serde(default)]
136 pub docs: Vec<String>,
137 #[serde(default, skip_serializing_if = "Option::is_none")]
138 pub serialization: Option<IdlSerializationSnapshot>,
139 #[serde(default)]
140 pub fields: Vec<IdlFieldSnapshot>,
141 #[serde(rename = "type", default, skip_serializing_if = "Option::is_none")]
142 pub type_def: Option<IdlInlineTypeDef>,
143}
144
145impl<'de> Deserialize<'de> for IdlAccountSnapshot {
146 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
147 where
148 D: Deserializer<'de>,
149 {
150 let intermediate = IdlAccountSnapshotIntermediate::deserialize(deserializer)?;
151
152 let fields = if intermediate.fields.is_empty() {
154 if let Some(type_def) = intermediate.type_def.as_ref() {
155 type_def.fields.clone()
156 } else {
157 intermediate.fields
158 }
159 } else {
160 intermediate.fields
161 };
162
163 Ok(IdlAccountSnapshot {
164 name: intermediate.name,
165 discriminator: intermediate.discriminator,
166 docs: intermediate.docs,
167 serialization: intermediate.serialization,
168 fields,
169 type_def: intermediate.type_def,
170 })
171 }
172}
173
/// Type definition embedded directly in an account entry.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlInlineTypeDef {
    /// Kind label, stored as the raw string from the input.
    pub kind: String,
    /// Fields declared by the inline definition.
    pub fields: Vec<IdlFieldSnapshot>,
}
180
/// Snapshot of one instruction definition.
///
/// An instruction may carry an explicit `discriminator` byte list, a `discriminant`,
/// or neither; `get_discriminator` resolves these into the bytes to use.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlInstructionSnapshot {
    /// Instruction name.
    pub name: String,
    /// Explicit discriminator bytes; empty when the key is absent from the input.
    #[serde(default)]
    pub discriminator: Vec<u8>,
    /// Steel-style discriminant; `None` when the key is absent from the input.
    #[serde(default)]
    pub discriminant: Option<SteelDiscriminant>,
    /// Documentation lines; empty when the key is absent.
    #[serde(default)]
    pub docs: Vec<String>,
    /// Accounts the instruction takes (required key).
    pub accounts: Vec<IdlInstructionAccountSnapshot>,
    /// Instruction arguments (required key).
    pub args: Vec<IdlFieldSnapshot>,
}
193
194impl IdlInstructionSnapshot {
195 pub fn get_discriminator(&self) -> Vec<u8> {
198 if !self.discriminator.is_empty() {
199 return self.discriminator.clone();
200 }
201
202 if let Some(disc) = &self.discriminant {
203 match u8::try_from(disc.value) {
204 Ok(value) => return vec![value],
205 Err(_) => {
206 tracing::warn!(
207 instruction = %self.name,
208 value = disc.value,
209 "Steel discriminant exceeds u8::MAX; falling back to Anchor hash"
210 );
211 }
212 }
213 }
214
215 crate::discriminator::anchor_discriminator(&format!("global:{}", self.name))
216 }
217}
218
/// Snapshot of one account slot in an instruction's account list.
///
/// Everything except `name` is optional on input and falls back to its type's default.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlInstructionAccountSnapshot {
    /// Account name.
    pub name: String,
    /// Writable flag; `false` when absent.
    #[serde(default)]
    pub writable: bool,
    /// Signer flag; `false` when absent.
    #[serde(default)]
    pub signer: bool,
    /// Optional flag; `false` when absent.
    #[serde(default)]
    pub optional: bool,
    /// Address associated with the account, when one is given.
    #[serde(default)]
    pub address: Option<String>,
    /// Documentation lines; empty when absent.
    #[serde(default)]
    pub docs: Vec<String>,
}
233
/// A named, typed field (used for account fields, instruction args and struct fields).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlFieldSnapshot {
    /// Field name.
    pub name: String,
    /// Field type; stored under the `type` key in serialized form.
    #[serde(rename = "type")]
    pub type_: IdlTypeSnapshot,
}
240
/// A type reference as it appears in the IDL.
///
/// The enum is untagged: on input serde tries the variants in declaration order and
/// keeps the first one that matches, so the order below is significant.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum IdlTypeSnapshot {
    /// A bare string type name.
    Simple(String),
    /// A map with an `array` key.
    Array(IdlArrayTypeSnapshot),
    /// A map with an `option` key.
    Option(IdlOptionTypeSnapshot),
    /// A map with a `vec` key.
    Vec(IdlVecTypeSnapshot),
    /// A map with a `hashMap` key.
    HashMap(IdlHashMapTypeSnapshot),
    /// A map with a `defined` key.
    Defined(IdlDefinedTypeSnapshot),
}
251
/// A hash-map type reference, stored under the `hashMap` key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlHashMapTypeSnapshot {
    /// The two component types, in input order (first, second). Read by
    /// `deserialize_hash_map`, which requires exactly two elements.
    // Presumably (key type, value type) — confirm against the IDL producer.
    #[serde(rename = "hashMap", deserialize_with = "deserialize_hash_map")]
    pub hash_map: (Box<IdlTypeSnapshot>, Box<IdlTypeSnapshot>),
}
257
258fn deserialize_hash_map<'de, D>(
259 deserializer: D,
260) -> Result<(Box<IdlTypeSnapshot>, Box<IdlTypeSnapshot>), D::Error>
261where
262 D: Deserializer<'de>,
263{
264 use serde::de::Error;
265 let values: Vec<IdlTypeSnapshot> = Vec::deserialize(deserializer)?;
266 if values.len() != 2 {
267 return Err(D::Error::custom("hashMap must have exactly 2 elements"));
268 }
269 let mut iter = values.into_iter();
270 Ok((
271 Box::new(iter.next().expect("length checked")),
272 Box::new(iter.next().expect("length checked")),
273 ))
274}
275
/// An array type reference, stored under the `array` key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlArrayTypeSnapshot {
    /// Elements of the array descriptor, in input order. The length is not checked
    /// here. Presumably `[element type, size]` — TODO confirm.
    pub array: Vec<IdlArrayElementSnapshot>,
}
280
/// One element of an array descriptor: a type reference, a type name, or a size.
///
/// The enum is untagged, so on input serde tries the variants in declaration order.
// NOTE(review): any string already matches `Type` through `IdlTypeSnapshot::Simple`,
// so `TypeName` looks unreachable when deserializing — confirm whether it is only
// meant to be constructed in code.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum IdlArrayElementSnapshot {
    /// A full type reference.
    Type(IdlTypeSnapshot),
    /// A bare type name.
    TypeName(String),
    /// A numeric size.
    Size(u32),
}
288
/// An optional type reference, stored under the `option` key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlOptionTypeSnapshot {
    /// The wrapped type; boxed because the type tree is recursive.
    pub option: Box<IdlTypeSnapshot>,
}
293
/// A vector type reference, stored under the `vec` key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlVecTypeSnapshot {
    /// The element type; boxed because the type tree is recursive.
    pub vec: Box<IdlTypeSnapshot>,
}
298
/// A reference to a user-defined type, stored under the `defined` key.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlDefinedTypeSnapshot {
    /// The referenced type, in either of the two accepted shapes.
    pub defined: IdlDefinedInnerSnapshot,
}
303
/// The two accepted shapes for the value of a `defined` type reference.
///
/// The enum is untagged; serde tries `Named` first, then `Simple`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum IdlDefinedInnerSnapshot {
    /// A map carrying the type name under a `name` key.
    Named { name: String },
    /// The type name as a bare string.
    Simple(String),
}
310
/// Serialization format declared for an account or type.
///
/// Variant names are (de)serialized in lowercase: `borsh`, `bytemuck`,
/// `bytemuckunsafe`.
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum IdlSerializationSnapshot {
    Borsh,
    Bytemuck,
    // NOTE(review): `rename_all = "lowercase"` already maps this variant to
    // `bytemuckunsafe`, so the alias below appears redundant — confirm whether a
    // different spelling was intended.
    #[serde(alias = "bytemuckunsafe")]
    BytemuckUnsafe,
}
319
/// Snapshot of one named type definition from the IDL's `types` list.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlTypeDefSnapshot {
    /// Type name.
    pub name: String,
    /// Documentation lines; empty when the key is absent.
    #[serde(default)]
    pub docs: Vec<String>,
    /// Serialization format, when declared; omitted from the output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub serialization: Option<IdlSerializationSnapshot>,
    /// The definition body; stored under the `type` key in serialized form.
    #[serde(rename = "type")]
    pub type_def: IdlTypeDefKindSnapshot,
}
330
/// Body of a type definition.
///
/// The enum is untagged: the variant is chosen by the shape of the input, with the
/// variants tried in declaration order. The `kind` string is stored as given and is
/// not what selects the variant.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(untagged)]
pub enum IdlTypeDefKindSnapshot {
    /// `fields` is a list of named fields. Because this variant is tried first, an
    /// empty `fields` list also lands here.
    Struct {
        kind: String,
        fields: Vec<IdlFieldSnapshot>,
    },
    /// `fields` is a list of bare type references (no names).
    TupleStruct {
        kind: String,
        fields: Vec<IdlTypeSnapshot>,
    },
    /// A list of `variants` instead of `fields`.
    Enum {
        kind: String,
        variants: Vec<IdlEnumVariantSnapshot>,
    },
}
347
/// One variant of an enum type definition. Only the variant name is recorded.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlEnumVariantSnapshot {
    /// Variant name.
    pub name: String,
}
352
/// Snapshot of one event definition.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct IdlEventSnapshot {
    /// Event name.
    pub name: String,
    /// Discriminator bytes for this event (required key).
    pub discriminator: Vec<u8>,
    /// Documentation lines; empty when the key is absent.
    #[serde(default)]
    pub docs: Vec<String>,
}
360
/// Snapshot of one error definition.
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub struct IdlErrorSnapshot {
    /// Numeric error code.
    pub code: u32,
    /// Error name.
    pub name: String,
    /// Message text, when one is given; omitted from the output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub msg: Option<String>,
}
368
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a bare string type reference.
    fn simple(type_name: &str) -> IdlTypeSnapshot {
        IdlTypeSnapshot::Simple(type_name.to_string())
    }

    /// The single account used by the round-trip fixture.
    fn sample_account() -> IdlAccountSnapshot {
        IdlAccountSnapshot {
            name: "ExampleAccount".to_string(),
            discriminator: vec![1, 2, 3, 4, 5, 6, 7, 8],
            docs: vec!["Example account".to_string()],
            serialization: Some(IdlSerializationSnapshot::Borsh),
            fields: vec![],
            type_def: None,
        }
    }

    /// The single instruction used by the round-trip fixture; its one argument is a
    /// hash-map type so the custom `hashMap` deserializer is exercised as well.
    fn sample_instruction() -> IdlInstructionSnapshot {
        let payer = IdlInstructionAccountSnapshot {
            name: "payer".to_string(),
            writable: true,
            signer: true,
            optional: false,
            address: None,
            docs: vec![],
        };
        let amount = IdlFieldSnapshot {
            name: "amount".to_string(),
            type_: IdlTypeSnapshot::HashMap(IdlHashMapTypeSnapshot {
                hash_map: (Box::new(simple("u64")), Box::new(simple("string"))),
            }),
        };

        IdlInstructionSnapshot {
            name: "example_instruction".to_string(),
            discriminator: vec![8, 7, 6, 5, 4, 3, 2, 1],
            discriminant: None,
            docs: vec!["Example instruction".to_string()],
            accounts: vec![payer],
            args: vec![amount],
        }
    }

    /// A snapshot with exactly one entry in every collection.
    fn sample_snapshot() -> IdlSnapshot {
        let example_type = IdlTypeDefSnapshot {
            name: "ExampleType".to_string(),
            docs: vec![],
            serialization: None,
            type_def: IdlTypeDefKindSnapshot::Struct {
                kind: "struct".to_string(),
                fields: vec![IdlFieldSnapshot {
                    name: "value".to_string(),
                    type_: simple("u64"),
                }],
            },
        };
        let example_event = IdlEventSnapshot {
            name: "ExampleEvent".to_string(),
            discriminator: vec![0, 0, 0, 0, 0, 0, 0, 1],
            docs: vec![],
        };
        let example_error = IdlErrorSnapshot {
            code: 6000,
            name: "ExampleError".to_string(),
            msg: Some("example".to_string()),
        };

        IdlSnapshot {
            name: "test_program".to_string(),
            program_id: Some("11111111111111111111111111111111".to_string()),
            version: "0.1.0".to_string(),
            accounts: vec![sample_account()],
            instructions: vec![sample_instruction()],
            types: vec![example_type],
            events: vec![example_event],
            errors: vec![example_error],
            discriminant_size: 8,
        }
    }

    /// Serialize -> deserialize -> serialize must reproduce the same JSON value.
    #[test]
    fn test_snapshot_serde() {
        let snapshot = sample_snapshot();

        let first_pass = serde_json::to_value(&snapshot).expect("serialize snapshot");
        let restored: IdlSnapshot =
            serde_json::from_value(first_pass.clone()).expect("deserialize snapshot");
        let second_pass = serde_json::to_value(&restored).expect("re-serialize snapshot");

        assert_eq!(first_pass, second_pass);
        assert_eq!(restored.name, "test_program");
    }

    /// A two-element `hashMap` array parses into the (first, second) tuple.
    #[test]
    fn test_hashmap_compat() {
        let parsed: IdlHashMapTypeSnapshot =
            serde_json::from_str(r#"{"hashMap":["u64","string"]}"#).expect("deserialize hashMap");
        let (first, second) = &parsed.hash_map;

        assert!(matches!(
            first.as_ref(),
            IdlTypeSnapshot::Simple(value) if value == "u64"
        ));
        assert!(matches!(
            second.as_ref(),
            IdlTypeSnapshot::Simple(value) if value == "string"
        ));
    }
}