1use serde::{de::Error as DeError, Deserialize, Deserializer, Serialize};
4
5use crate::types::SteelDiscriminant;
6
7#[derive(Debug, Clone, Serialize)]
8pub struct IdlSnapshot {
9 pub name: String,
10 #[serde(default, skip_serializing_if = "Option::is_none", alias = "address")]
11 pub program_id: Option<String>,
12 pub version: String,
13 pub accounts: Vec<IdlAccountSnapshot>,
14 pub instructions: Vec<IdlInstructionSnapshot>,
15 #[serde(default)]
16 pub types: Vec<IdlTypeDefSnapshot>,
17 #[serde(default)]
18 pub events: Vec<IdlEventSnapshot>,
19 #[serde(default)]
20 pub errors: Vec<IdlErrorSnapshot>,
21 pub discriminant_size: usize,
22}
23
24impl<'de> Deserialize<'de> for IdlSnapshot {
25 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
26 where
27 D: Deserializer<'de>,
28 {
29 let value = serde_json::Value::deserialize(deserializer)?;
31
32 let discriminant_size = value
34 .get("instructions")
35 .and_then(|instrs| instrs.as_array())
36 .map(|instrs| {
37 if instrs.is_empty() {
38 return false;
39 }
40 instrs.iter().all(|ix| {
41 let discriminator = ix.get("discriminator");
42 let disc_len = discriminator
43 .and_then(|d| d.as_array())
44 .map(|a| a.len())
45 .unwrap_or(0);
46
47 let has_discriminant = ix.get("discriminant").is_some();
49 let has_discriminator = discriminator
50 .map(|d| {
51 !d.is_null() && d.as_array().map(|a| !a.is_empty()).unwrap_or(true)
52 })
53 .unwrap_or(false);
54 let is_steel_discriminant = has_discriminant && !has_discriminator;
55
56 let is_steel_short_discriminator = !has_discriminant && disc_len == 1;
61
62 is_steel_discriminant || is_steel_short_discriminator
63 })
64 })
65 .map(|is_steel| if is_steel { 1 } else { 8 })
66 .unwrap_or(8); let mut intermediate: IdlSnapshotIntermediate = serde_json::from_value(value)
70 .map_err(|e| DeError::custom(format!("Failed to deserialize IDL: {}", e)))?;
71 intermediate.discriminant_size = discriminant_size;
72
73 Ok(IdlSnapshot {
74 name: intermediate.name,
75 program_id: intermediate.program_id,
76 version: intermediate.version,
77 accounts: intermediate.accounts,
78 instructions: intermediate.instructions,
79 types: intermediate.types,
80 events: intermediate.events,
81 errors: intermediate.errors,
82 discriminant_size: intermediate.discriminant_size,
83 })
84 }
85}
86
87#[derive(Debug, Clone, Deserialize)]
89struct IdlSnapshotIntermediate {
90 pub name: String,
91 #[serde(default, alias = "address")]
92 pub program_id: Option<String>,
93 pub version: String,
94 pub accounts: Vec<IdlAccountSnapshot>,
95 pub instructions: Vec<IdlInstructionSnapshot>,
96 #[serde(default)]
97 pub types: Vec<IdlTypeDefSnapshot>,
98 #[serde(default)]
99 pub events: Vec<IdlEventSnapshot>,
100 #[serde(default)]
101 pub errors: Vec<IdlErrorSnapshot>,
102 #[serde(default)]
103 pub discriminant_size: usize,
104}
105
106#[derive(Debug, Clone, Serialize)]
107pub struct IdlAccountSnapshot {
108 pub name: String,
109 pub discriminator: Vec<u8>,
110 pub docs: Vec<String>,
111 pub serialization: Option<IdlSerializationSnapshot>,
112 pub fields: Vec<IdlFieldSnapshot>,
114 #[serde(skip_serializing_if = "Option::is_none")]
116 pub type_def: Option<IdlInlineTypeDef>,
117}
118
119#[derive(Deserialize)]
121struct IdlAccountSnapshotIntermediate {
122 pub name: String,
123 pub discriminator: Vec<u8>,
124 #[serde(default)]
125 pub docs: Vec<String>,
126 #[serde(default, skip_serializing_if = "Option::is_none")]
127 pub serialization: Option<IdlSerializationSnapshot>,
128 #[serde(default)]
129 pub fields: Vec<IdlFieldSnapshot>,
130 #[serde(rename = "type", default, skip_serializing_if = "Option::is_none")]
131 pub type_def: Option<IdlInlineTypeDef>,
132}
133
134impl<'de> Deserialize<'de> for IdlAccountSnapshot {
135 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
136 where
137 D: Deserializer<'de>,
138 {
139 let intermediate = IdlAccountSnapshotIntermediate::deserialize(deserializer)?;
140
141 let fields = if intermediate.fields.is_empty() {
143 if let Some(type_def) = intermediate.type_def.as_ref() {
144 type_def.fields.clone()
145 } else {
146 intermediate.fields
147 }
148 } else {
149 intermediate.fields
150 };
151
152 Ok(IdlAccountSnapshot {
153 name: intermediate.name,
154 discriminator: intermediate.discriminator,
155 docs: intermediate.docs,
156 serialization: intermediate.serialization,
157 fields,
158 type_def: intermediate.type_def,
159 })
160 }
161}
162
163#[derive(Debug, Clone, Serialize, Deserialize)]
165pub struct IdlInlineTypeDef {
166 pub kind: String,
167 pub fields: Vec<IdlFieldSnapshot>,
168}
169
170#[derive(Debug, Clone, Serialize, Deserialize)]
171pub struct IdlInstructionSnapshot {
172 pub name: String,
173 #[serde(default)]
174 pub discriminator: Vec<u8>,
175 #[serde(default)]
176 pub discriminant: Option<SteelDiscriminant>,
177 #[serde(default)]
178 pub docs: Vec<String>,
179 pub accounts: Vec<IdlInstructionAccountSnapshot>,
180 pub args: Vec<IdlFieldSnapshot>,
181}
182
183impl IdlInstructionSnapshot {
184 pub fn get_discriminator(&self) -> Vec<u8> {
187 if !self.discriminator.is_empty() {
188 return self.discriminator.clone();
189 }
190
191 if let Some(disc) = &self.discriminant {
192 match u8::try_from(disc.value) {
193 Ok(value) => return vec![value],
194 Err(_) => {
195 tracing::warn!(
196 instruction = %self.name,
197 value = disc.value,
198 "Steel discriminant exceeds u8::MAX; falling back to Anchor hash"
199 );
200 }
201 }
202 }
203
204 crate::discriminator::anchor_discriminator(&format!("global:{}", self.name))
205 }
206}
207
208#[derive(Debug, Clone, Serialize, Deserialize)]
209pub struct IdlInstructionAccountSnapshot {
210 pub name: String,
211 #[serde(default)]
212 pub writable: bool,
213 #[serde(default)]
214 pub signer: bool,
215 #[serde(default)]
216 pub optional: bool,
217 #[serde(default)]
218 pub address: Option<String>,
219 #[serde(default)]
220 pub docs: Vec<String>,
221}
222
223#[derive(Debug, Clone, Serialize, Deserialize)]
224pub struct IdlFieldSnapshot {
225 pub name: String,
226 #[serde(rename = "type")]
227 pub type_: IdlTypeSnapshot,
228}
229
230#[derive(Debug, Clone, Serialize, Deserialize)]
231#[serde(untagged)]
232pub enum IdlTypeSnapshot {
233 Simple(String),
234 Array(IdlArrayTypeSnapshot),
235 Option(IdlOptionTypeSnapshot),
236 Vec(IdlVecTypeSnapshot),
237 HashMap(IdlHashMapTypeSnapshot),
238 Defined(IdlDefinedTypeSnapshot),
239}
240
241#[derive(Debug, Clone, Serialize, Deserialize)]
242pub struct IdlHashMapTypeSnapshot {
243 #[serde(rename = "hashMap", deserialize_with = "deserialize_hash_map")]
244 pub hash_map: (Box<IdlTypeSnapshot>, Box<IdlTypeSnapshot>),
245}
246
247fn deserialize_hash_map<'de, D>(
248 deserializer: D,
249) -> Result<(Box<IdlTypeSnapshot>, Box<IdlTypeSnapshot>), D::Error>
250where
251 D: Deserializer<'de>,
252{
253 use serde::de::Error;
254 let values: Vec<IdlTypeSnapshot> = Vec::deserialize(deserializer)?;
255 if values.len() != 2 {
256 return Err(D::Error::custom("hashMap must have exactly 2 elements"));
257 }
258 let mut iter = values.into_iter();
259 Ok((
260 Box::new(iter.next().expect("length checked")),
261 Box::new(iter.next().expect("length checked")),
262 ))
263}
264
265#[derive(Debug, Clone, Serialize, Deserialize)]
266pub struct IdlArrayTypeSnapshot {
267 pub array: Vec<IdlArrayElementSnapshot>,
268}
269
270#[derive(Debug, Clone, Serialize, Deserialize)]
271#[serde(untagged)]
272pub enum IdlArrayElementSnapshot {
273 Type(IdlTypeSnapshot),
274 TypeName(String),
275 Size(u32),
276}
277
278#[derive(Debug, Clone, Serialize, Deserialize)]
279pub struct IdlOptionTypeSnapshot {
280 pub option: Box<IdlTypeSnapshot>,
281}
282
283#[derive(Debug, Clone, Serialize, Deserialize)]
284pub struct IdlVecTypeSnapshot {
285 pub vec: Box<IdlTypeSnapshot>,
286}
287
288#[derive(Debug, Clone, Serialize, Deserialize)]
289pub struct IdlDefinedTypeSnapshot {
290 pub defined: IdlDefinedInnerSnapshot,
291}
292
293#[derive(Debug, Clone, Serialize, Deserialize)]
294#[serde(untagged)]
295pub enum IdlDefinedInnerSnapshot {
296 Named { name: String },
297 Simple(String),
298}
299
300#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
301#[serde(rename_all = "lowercase")]
302pub enum IdlSerializationSnapshot {
303 Borsh,
304 Bytemuck,
305 #[serde(alias = "bytemuckunsafe")]
306 BytemuckUnsafe,
307}
308
309#[derive(Debug, Clone, Serialize, Deserialize)]
310pub struct IdlTypeDefSnapshot {
311 pub name: String,
312 #[serde(default)]
313 pub docs: Vec<String>,
314 #[serde(default, skip_serializing_if = "Option::is_none")]
315 pub serialization: Option<IdlSerializationSnapshot>,
316 #[serde(rename = "type")]
317 pub type_def: IdlTypeDefKindSnapshot,
318}
319
320#[derive(Debug, Clone, Serialize, Deserialize)]
321#[serde(untagged)]
322pub enum IdlTypeDefKindSnapshot {
323 Struct {
324 kind: String,
325 fields: Vec<IdlFieldSnapshot>,
326 },
327 TupleStruct {
328 kind: String,
329 fields: Vec<IdlTypeSnapshot>,
330 },
331 Enum {
332 kind: String,
333 variants: Vec<IdlEnumVariantSnapshot>,
334 },
335}
336
337#[derive(Debug, Clone, Serialize, Deserialize)]
338pub struct IdlEnumVariantSnapshot {
339 pub name: String,
340}
341
342#[derive(Debug, Clone, Serialize, Deserialize)]
343pub struct IdlEventSnapshot {
344 pub name: String,
345 pub discriminator: Vec<u8>,
346 #[serde(default)]
347 pub docs: Vec<String>,
348}
349
350#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
351pub struct IdlErrorSnapshot {
352 pub code: u32,
353 pub name: String,
354 #[serde(default, skip_serializing_if = "Option::is_none")]
355 pub msg: Option<String>,
356}
357
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for a bare-string type reference.
    fn simple(name: &str) -> IdlTypeSnapshot {
        IdlTypeSnapshot::Simple(name.to_string())
    }

    /// A fully populated snapshot must survive serialize → deserialize →
    /// serialize without any change to the JSON it produces.
    #[test]
    fn test_snapshot_serde() {
        let account = IdlAccountSnapshot {
            name: "ExampleAccount".to_string(),
            discriminator: vec![1, 2, 3, 4, 5, 6, 7, 8],
            docs: vec!["Example account".to_string()],
            serialization: Some(IdlSerializationSnapshot::Borsh),
            fields: vec![],
            type_def: None,
        };

        let payer = IdlInstructionAccountSnapshot {
            name: "payer".to_string(),
            writable: true,
            signer: true,
            optional: false,
            address: None,
            docs: vec![],
        };

        let amount_arg = IdlFieldSnapshot {
            name: "amount".to_string(),
            type_: IdlTypeSnapshot::HashMap(IdlHashMapTypeSnapshot {
                hash_map: (Box::new(simple("u64")), Box::new(simple("string"))),
            }),
        };

        let instruction = IdlInstructionSnapshot {
            name: "example_instruction".to_string(),
            discriminator: vec![8, 7, 6, 5, 4, 3, 2, 1],
            discriminant: None,
            docs: vec!["Example instruction".to_string()],
            accounts: vec![payer],
            args: vec![amount_arg],
        };

        let example_type = IdlTypeDefSnapshot {
            name: "ExampleType".to_string(),
            docs: vec![],
            serialization: None,
            type_def: IdlTypeDefKindSnapshot::Struct {
                kind: "struct".to_string(),
                fields: vec![IdlFieldSnapshot {
                    name: "value".to_string(),
                    type_: simple("u64"),
                }],
            },
        };

        let event = IdlEventSnapshot {
            name: "ExampleEvent".to_string(),
            discriminator: vec![0, 0, 0, 0, 0, 0, 0, 1],
            docs: vec![],
        };

        let error = IdlErrorSnapshot {
            code: 6000,
            name: "ExampleError".to_string(),
            msg: Some("example".to_string()),
        };

        let snapshot = IdlSnapshot {
            name: "test_program".to_string(),
            program_id: Some("11111111111111111111111111111111".to_string()),
            version: "0.1.0".to_string(),
            accounts: vec![account],
            instructions: vec![instruction],
            types: vec![example_type],
            events: vec![event],
            errors: vec![error],
            discriminant_size: 8,
        };

        let first_pass = serde_json::to_value(&snapshot).expect("serialize snapshot");
        let reloaded: IdlSnapshot =
            serde_json::from_value(first_pass.clone()).expect("deserialize snapshot");
        let second_pass = serde_json::to_value(&reloaded).expect("re-serialize snapshot");

        assert_eq!(first_pass, second_pass);
        assert_eq!(reloaded.name, "test_program");
    }

    /// The two-element array form of `hashMap` must parse into the tuple,
    /// preserving element order.
    #[test]
    fn test_hashmap_compat() {
        let parsed: IdlHashMapTypeSnapshot =
            serde_json::from_str(r#"{"hashMap":["u64","string"]}"#).expect("deserialize hashMap");
        let (first_ty, second_ty) = &parsed.hash_map;

        assert!(matches!(
            first_ty.as_ref(),
            IdlTypeSnapshot::Simple(name) if name == "u64"
        ));
        assert!(matches!(
            second_ty.as_ref(),
            IdlTypeSnapshot::Simple(name) if name == "string"
        ));
    }
}