thru_base/
txn_lib.rs

1//! Transaction library: normal Rust struct, signing, serialization, accessors
2//!
3
/// Raw 32-byte public key; used both as an Ed25519 verifying key (fee payer) and
/// as an account / program address.
pub type TnPubkey = [u8; 32];
/// Raw 32-byte hash value.
pub type TnHash = [u8; 32];
/// Raw 64-byte Ed25519 signature.
pub type TnSignature = [u8; 64];
7
8use crate::{StateProofType, tn_state_proof::StateProof};
9use bytemuck::{Pod, Zeroable, bytes_of, from_bytes};
10use ed25519_dalek::{Signature, Signer, SigningKey, Verifier, VerifyingKey};
11
/// Bit index (within the header `flags` byte) signalling that a fee payer state
/// proof follows the instruction data. Mirrors C `#define TN_TXN_FLAG_HAS_FEE_PAYER_PROOF (0U)`.
pub const TN_TXN_FLAG_HAS_FEE_PAYER_PROOF_BIT: u8 = 0;
/// Bit index (within the header `flags` byte) of the "may compress account" flag.
/// Mirrors C `#define TN_TXN_FLAG_MAY_COMPRESS_ACCOUNT (1U)`.
pub const TN_TXN_FLAG_MAY_COMPRESS_ACCOUNT_BIT: u8 = 1;

// State proof type codes, stored in the two most significant bits of a proof's
// `type_slot` word (values match the C implementation).
pub const TN_STATE_PROOF_TYPE_EXISTING: u64 = 0x0;
pub const TN_STATE_PROOF_TYPE_UPDATING: u64 = 0x1;
pub const TN_STATE_PROOF_TYPE_CREATION: u64 = 0x2;

/// Serialized state proof header size: an 8-byte `type_slot` word followed by a
/// 32-byte `path_bitset`.
pub const TN_STATE_PROOF_HDR_SIZE: usize = 8 + 32;
/// Serialized size of `tn_account_meta_t` (matches C `sizeof`).
pub const TN_ACCOUNT_META_FOOTPRINT: usize = 64;
23
/// Errors produced while sizing, parsing, or validating wire-format transactions.
///
/// TEMPORARY: minimal local error type; remove once a shared error type is available.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RpcError {
    /// The payload is larger than the allowed maximum.
    InvalidTransactionSize { size: usize, max_size: usize },
    /// The parsed transaction length (`expected`) differs from the payload length (`found`).
    TrailingBytes { expected: usize, found: usize },
    /// More accounts than the allowed maximum.
    TooManyAccounts { count: usize, max_count: usize },
    /// The fee payer key is invalid or the fee payer signature does not verify.
    InvalidTransactionSignature,
    /// Caller-supplied parameters were rejected; carries a static description.
    InvalidParams(&'static str),
    /// The payload is truncated or structurally malformed.
    InvalidFormat,
    /// The transaction version byte is not supported.
    InvalidVersion,
    /// The flags byte has bits set that are not defined.
    InvalidFlags,
    /// The fee payer state proof has a type that is not allowed.
    InvalidFeePayerStateProofType,
}

impl std::fmt::Display for RpcError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            RpcError::InvalidTransactionSize { size, max_size } => write!(
                f,
                "Transaction size {} exceeds maximum allowed size {}",
                size, max_size
            ),
            RpcError::TrailingBytes { expected, found } => write!(
                f,
                "Transaction has trailing bytes: expected {} bytes, found {} bytes",
                expected, found
            ),
            RpcError::TooManyAccounts { count, max_count } => write!(
                f,
                "Too many accounts: {} exceeds maximum {}",
                count, max_count
            ),
            RpcError::InvalidTransactionSignature => f.write_str("Invalid transaction signature"),
            RpcError::InvalidParams(msg) => write!(f, "Invalid parameters: {}", msg),
            RpcError::InvalidFormat => f.write_str("Invalid transaction format"),
            RpcError::InvalidVersion => f.write_str("Invalid transaction version"),
            RpcError::InvalidFlags => f.write_str("Invalid transaction flags"),
            RpcError::InvalidFeePayerStateProofType => {
                f.write_str("Invalid fee payer state proof type")
            }
        }
    }
}

// Lets `RpcError` be used with `?` in functions returning `Box<dyn Error>` and
// anywhere else a standard error type is expected.
impl std::error::Error for RpcError {}

/// Convenience constructors, one per variant.
impl RpcError {
    pub fn invalid_transaction_size(size: usize, max_size: usize) -> Self {
        Self::InvalidTransactionSize { size, max_size }
    }
    pub fn trailing_bytes(expected: usize, found: usize) -> Self {
        Self::TrailingBytes { expected, found }
    }
    pub fn too_many_accounts(count: usize, max_count: usize) -> Self {
        Self::TooManyAccounts { count, max_count }
    }
    pub fn invalid_transaction_signature() -> Self {
        Self::InvalidTransactionSignature
    }
    pub fn invalid_params(msg: &'static str) -> Self {
        Self::InvalidParams(msg)
    }
    pub fn invalid_format() -> Self {
        Self::InvalidFormat
    }
    pub fn invalid_version() -> Self {
        Self::InvalidVersion
    }
    pub fn invalid_flags() -> Self {
        Self::InvalidFlags
    }
    pub fn invalid_fee_payer_state_proof_type() -> Self {
        Self::InvalidFeePayerStateProofType
    }
}
113
/// On-wire transaction header (matches the C `TnTxnHdrV1` layout).
///
/// With `#[repr(C)]` the fields are laid out in declaration order with no implicit
/// padding (176 bytes total, checked at compile time below). The signature occupies
/// bytes `0..64`, `transaction_version` is byte 64 and `flags` is byte 65.
#[repr(C)]
#[derive(Clone, Copy, Debug)]
pub struct WireTxnHdrV1 {
    /// Ed25519 signature by the fee payer over every serialized byte after this field.
    pub fee_payer_signature: [u8; 64],
    pub transaction_version: u8,
    /// Flag bits; see `TN_TXN_FLAG_*_BIT`.
    pub flags: u8,
    /// Number of 32-byte read-write account keys following the header.
    pub readwrite_accounts_cnt: u16,
    /// Number of 32-byte read-only account keys following the read-write keys.
    pub readonly_accounts_cnt: u16,
    /// Length in bytes of the instruction data following the account keys.
    pub instr_data_sz: u16,
    pub req_compute_units: u32,
    pub req_state_units: u16,
    pub req_memory_units: u16,
    pub fee: u64,
    pub nonce: u64,
    pub start_slot: u64,
    pub expiry_after: u32,
    /// Explicit padding that keeps the struct free of implicit padding bytes.
    pub padding_0: [u8; 4],
    pub fee_payer_pubkey: [u8; 32],
    pub program_pubkey: [u8; 32],
}

// `[u8; 64]` has no `Default` impl (std only covers arrays up to 32 elements),
// so `Default` is written by hand: every field is zero.
impl Default for WireTxnHdrV1 {
    fn default() -> Self {
        Self {
            fee_payer_signature: [0u8; 64],
            transaction_version: 0,
            flags: 0,
            readwrite_accounts_cnt: 0,
            readonly_accounts_cnt: 0,
            instr_data_sz: 0,
            req_compute_units: 0,
            req_state_units: 0,
            req_memory_units: 0,
            fee: 0,
            nonce: 0,
            start_slot: 0,
            expiry_after: 0,
            padding_0: [0u8; 4],
            fee_payer_pubkey: [0u8; 32],
            program_pubkey: [0u8; 32],
        }
    }
}

// Compile-time guard: the field sizes sum to exactly 176 bytes, so a size of 176
// proves `repr(C)` inserted no padding. The manual `Pod` impl and the fixed byte
// offsets used when parsing raw transactions both rely on this layout.
const _: () = assert!(core::mem::size_of::<WireTxnHdrV1>() == 176);
158
159// Manual Pod implementation to avoid derive issues
160unsafe impl Pod for WireTxnHdrV1 {}
161unsafe impl Zeroable for WireTxnHdrV1 {}
162
163/// Normal Rust struct for transaction construction
164#[derive(Clone, Debug, Default)]
165pub struct Transaction {
166    // Core transaction fields
167    pub fee_payer: TnPubkey, // [u8; 32] - who pays the fee
168    pub program: TnPubkey,   // [u8; 32] - target program
169
170    // Account lists (optional)
171    pub rw_accs: Option<Vec<TnPubkey>>, // read-write accounts
172    pub r_accs: Option<Vec<TnPubkey>>,  // read-only accounts
173
174    // Instruction data (optional)
175    pub instructions: Option<Vec<u8>>, // instruction bytes
176
177    // Transaction parameters
178    pub fee: u64,               // transaction fee
179    pub req_compute_units: u32, // requested compute units
180    pub req_state_units: u16,   // requested state units
181    pub req_memory_units: u16,  // requested memory units
182    pub expiry_after: u32,      // expiry time offset
183    pub start_slot: u64,        // starting slot
184    pub nonce: u64,             // transaction nonce
185    pub flags: u8,              // transaction flags
186
187    // Signature (optional until signed)
188    pub signature: Option<TnSignature>, // [u8; 64] - Ed25519 signature
189
190    // Fee payer state proof (optional)
191    pub fee_payer_state_proof: Option<StateProof>, // State proof for fee payer account
192
193    pub fee_payer_account_meta_raw: Option<Vec<u8>>,
194}
195
196impl Transaction {
197    /// Create a new unsigned transaction
198    pub fn new(fee_payer: TnPubkey, program: TnPubkey, fee: u64, nonce: u64) -> Self {
199        Self {
200            fee_payer,
201            program,
202            rw_accs: None,
203            r_accs: None,
204            instructions: None,
205            fee,
206            req_compute_units: 0,
207            req_state_units: 0,
208            req_memory_units: 0,
209            expiry_after: 0,
210            start_slot: 0,
211            nonce,
212            flags: 0,
213            signature: None,
214            fee_payer_state_proof: None,
215            fee_payer_account_meta_raw: None,
216        }
217    }
218
219    pub fn has_fee_payer_state_proof(&self) -> bool {
220        (self.flags & (1 << TN_TXN_FLAG_HAS_FEE_PAYER_PROOF_BIT)) != 0
221    }
222    pub fn may_compress_account(&self) -> bool {
223        (self.flags & (1 << TN_TXN_FLAG_MAY_COMPRESS_ACCOUNT_BIT)) != 0
224    }
225
226    pub fn get_signature(&self) -> Option<crate::Signature> {
227        if let Some(sig) = &self.signature {
228            return Some(crate::Signature::from_bytes(&sig));
229        }
230        None
231    }
232
233    pub fn with_may_compress_account(mut self) -> Self {
234        self.flags |= 1 << TN_TXN_FLAG_MAY_COMPRESS_ACCOUNT_BIT;
235        self
236    }
237
238    /// Builder method: set fee payer state proof
239    pub fn with_fee_payer_state_proof(mut self, state_proof: &StateProof) -> Self {
240        self.fee_payer_state_proof = Some(state_proof.clone());
241        // Set the flag bit to indicate presence of state proof
242        self.flags |= 1 << TN_TXN_FLAG_HAS_FEE_PAYER_PROOF_BIT;
243        self
244    }
245
246    /// Builder method: set fee payer account meta as raw bytes
247    pub fn with_fee_payer_account_meta_raw(mut self, account_meta_raw: Vec<u8>) -> Self {
248        self.fee_payer_account_meta_raw = Some(account_meta_raw);
249        self
250    }
251
252    /// Builder method: remove fee payer state proof
253    pub fn without_fee_payer_state_proof(mut self) -> Self {
254        self.fee_payer_state_proof = None;
255        // Clear the flag bit to indicate absence of state proof
256        self.flags &= !(1 << TN_TXN_FLAG_HAS_FEE_PAYER_PROOF_BIT);
257        self
258    }
259
260    /// Builder method: add read-write accounts
261    pub fn with_rw_accounts(mut self, accounts: Vec<TnPubkey>) -> Self {
262        self.rw_accs = Some(accounts);
263        self
264    }
265
266    /// Builder method: add read-only accounts
267    pub fn with_r_accounts(mut self, accounts: Vec<TnPubkey>) -> Self {
268        self.r_accs = Some(accounts);
269        self
270    }
271
272    /// Builder method: add a single read-write account
273    pub fn add_rw_account(mut self, account: TnPubkey) -> Self {
274        match self.rw_accs {
275            Some(ref mut accounts) => accounts.push(account),
276            None => self.rw_accs = Some(vec![account]),
277        }
278        self
279    }
280
281    /// Builder method: add a single read-only account
282    pub fn add_r_account(mut self, account: TnPubkey) -> Self {
283        match self.r_accs {
284            Some(ref mut accounts) => accounts.push(account),
285            None => self.r_accs = Some(vec![account]),
286        }
287        self
288    }
289
290    /// Builder method: add instruction data
291    pub fn with_instructions(mut self, instructions: Vec<u8>) -> Self {
292        self.instructions = Some(instructions);
293        self
294    }
295
296    /// Builder method: set compute units
297    pub fn with_compute_units(mut self, units: u32) -> Self {
298        self.req_compute_units = units;
299        self
300    }
301
302    /// Builder method: set state units
303    pub fn with_state_units(mut self, units: u16) -> Self {
304        self.req_state_units = units;
305        self
306    }
307
308    /// Builder method: set memory units
309    pub fn with_memory_units(mut self, units: u16) -> Self {
310        self.req_memory_units = units;
311        self
312    }
313
314    /// Builder method: set expiry
315    pub fn with_expiry_after(mut self, expiry: u32) -> Self {
316        self.expiry_after = expiry;
317        self
318    }
319
320    /// Builder method: set nonce
321    pub fn with_nonce(mut self, nonce: u64) -> Self {
322        self.nonce = nonce;
323        self
324    }
325
326    /// Builder method: set start slot
327    pub fn with_start_slot(mut self, slot: u64) -> Self {
328        self.start_slot = slot;
329        self
330    }
331
332    /// Sign the transaction with a 32-byte Ed25519 private key
333    pub fn sign(&mut self, private_key: &[u8; 32]) -> Result<(), Box<dyn std::error::Error>> {
334        let signing_key = SigningKey::from_bytes(private_key);
335        // Sign the wire format bytes (excluding signature field)
336        let wire_bytes = self.to_wire_for_signing();
337        let sig = signing_key.sign(&wire_bytes);
338        self.signature = Some(sig.to_bytes());
339        Ok(())
340    }
341
342    /// Verify the transaction signature
343    pub fn verify(&self) -> bool {
344        if let Some(sig_bytes) = &self.signature {
345            if let Ok(verifying_key) = VerifyingKey::from_bytes(&self.fee_payer) {
346                let sig = Signature::from_bytes(sig_bytes);
347                // Verify against the wire format bytes (excluding signature field)
348                let wire_bytes = self.to_wire_for_signing();
349                return verifying_key.verify(&wire_bytes, &sig).is_ok();
350            }
351        }
352        false
353    }
354
355    /// Create wire format for signing (excluding signature field)
356    fn to_wire_for_signing(&self) -> Vec<u8> {
357        // Zero out all bytes first to ensure deterministic padding
358        let mut wire: WireTxnHdrV1 = unsafe { core::mem::zeroed() };
359        // Don't set fee_payer_signature - it will be excluded from signing
360        wire.transaction_version = 1;
361        wire.flags = self.flags;
362        wire.readwrite_accounts_cnt = self.rw_accs.as_ref().map_or(0, |v| v.len() as u16);
363        wire.readonly_accounts_cnt = self.r_accs.as_ref().map_or(0, |v| v.len() as u16);
364        wire.instr_data_sz = self.instructions.as_ref().map_or(0, |v| v.len() as u16);
365        wire.req_compute_units = self.req_compute_units;
366        wire.req_state_units = self.req_state_units;
367        wire.req_memory_units = self.req_memory_units;
368        wire.expiry_after = self.expiry_after;
369        wire.fee = self.fee;
370        wire.nonce = self.nonce;
371        wire.start_slot = self.start_slot;
372        wire.fee_payer_pubkey = self.fee_payer;
373        wire.program_pubkey = self.program;
374
375        let wire_bytes = bytes_of(&wire);
376        // Skip the first 64 bytes (fee_payer_signature) and include the rest
377        let mut result = wire_bytes[64..].to_vec();
378
379        // Append variable-length data
380        if let Some(ref rw_accs) = self.rw_accs {
381            for acc in rw_accs {
382                result.extend_from_slice(acc);
383            }
384        }
385
386        if let Some(ref r_accs) = self.r_accs {
387            for acc in r_accs {
388                result.extend_from_slice(acc);
389            }
390        }
391
392        if let Some(ref instructions) = self.instructions {
393            result.extend_from_slice(instructions);
394        }
395
396        // Append state proof if present
397        if let Some(ref state_proof) = self.fee_payer_state_proof {
398            result.extend_from_slice(&state_proof.to_wire());
399        }
400
401        // Use raw account meta if available, otherwise use structured account meta
402        if let Some(ref fee_payer_account_meta_raw) = self.fee_payer_account_meta_raw {
403            result.extend_from_slice(fee_payer_account_meta_raw);
404        }
405
406        result
407    }
408
409    /// Serialize to on-wire format (WireTxnHdrV1)
410    pub fn to_wire(&self) -> Vec<u8> {
411        let mut wire = WireTxnHdrV1::default();
412        if let Some(sig) = &self.signature {
413            wire.fee_payer_signature = *sig;
414        }
415        wire.transaction_version = 1;
416        wire.flags = self.flags;
417        wire.readwrite_accounts_cnt = self.rw_accs.as_ref().map_or(0, |v| v.len() as u16);
418        wire.readonly_accounts_cnt = self.r_accs.as_ref().map_or(0, |v| v.len() as u16);
419        wire.instr_data_sz = self.instructions.as_ref().map_or(0, |v| v.len() as u16);
420        wire.req_compute_units = self.req_compute_units;
421        wire.req_state_units = self.req_state_units;
422        wire.req_memory_units = self.req_memory_units;
423        wire.expiry_after = self.expiry_after;
424        wire.fee = self.fee;
425        wire.nonce = self.nonce;
426        wire.start_slot = self.start_slot;
427        wire.fee_payer_pubkey = self.fee_payer;
428        wire.program_pubkey = self.program;
429
430        let mut result = bytes_of(&wire).to_vec();
431
432        // Append variable-length data
433        if let Some(ref rw_accs) = self.rw_accs {
434            for acc in rw_accs {
435                result.extend_from_slice(acc);
436            }
437        }
438
439        if let Some(ref r_accs) = self.r_accs {
440            for acc in r_accs {
441                result.extend_from_slice(acc);
442            }
443        }
444
445        if let Some(ref instructions) = self.instructions {
446            result.extend_from_slice(instructions);
447        }
448
449        // Append state proof if present (after instruction data)
450        if let Some(ref state_proof) = self.fee_payer_state_proof {
451            result.extend_from_slice(&state_proof.to_wire());
452        }
453
454        // Use raw account meta if available, otherwise use structured account meta
455        if let Some(ref fee_payer_account_meta_raw) = self.fee_payer_account_meta_raw {
456            result.extend_from_slice(fee_payer_account_meta_raw);
457        }
458
459        result
460    }
461
462    /// Deserialize from on-wire format (WireTxnHdrV1)
463    pub fn from_wire(bytes: &[u8]) -> Option<Self> {
464        if bytes.len() < core::mem::size_of::<WireTxnHdrV1>() {
465            return None;
466        }
467
468        let wire: &WireTxnHdrV1 = from_bytes(&bytes[0..core::mem::size_of::<WireTxnHdrV1>()]);
469        let mut offset = core::mem::size_of::<WireTxnHdrV1>();
470
471        // Parse read-write accounts
472        let rw_accs = if wire.readwrite_accounts_cnt > 0 {
473            let mut accounts = Vec::new();
474            for _ in 0..wire.readwrite_accounts_cnt {
475                if offset + 32 > bytes.len() {
476                    return None;
477                }
478                let mut acc = [0u8; 32];
479                acc.copy_from_slice(&bytes[offset..offset + 32]);
480                accounts.push(acc);
481                offset += 32;
482            }
483            Some(accounts)
484        } else {
485            None
486        };
487
488        // Parse read-only accounts
489        let r_accs = if wire.readonly_accounts_cnt > 0 {
490            let mut accounts = Vec::new();
491            for _ in 0..wire.readonly_accounts_cnt {
492                if offset + 32 > bytes.len() {
493                    return None;
494                }
495                let mut acc = [0u8; 32];
496                acc.copy_from_slice(&bytes[offset..offset + 32]);
497                accounts.push(acc);
498                offset += 32;
499            }
500            Some(accounts)
501        } else {
502            None
503        };
504
505        // Parse instructions
506        let instructions = if wire.instr_data_sz > 0 {
507            if offset + wire.instr_data_sz as usize > bytes.len() {
508                return None;
509            }
510            let instr = bytes[offset..offset + wire.instr_data_sz as usize].to_vec();
511            offset += wire.instr_data_sz as usize;
512            Some(instr)
513        } else {
514            None
515        };
516
517        let mut fee_payer_account_meta_raw: Option<Vec<u8>> = None;
518        // Parse state proof if present
519        let fee_payer_state_proof = if has_fee_payer_state_proof(wire.flags) {
520            if offset >= bytes.len() {
521                return None;
522            }
523            let state_proof_bytes = &bytes[offset..];
524            if let Some(state_proof) = StateProof::from_wire(state_proof_bytes) {
525                offset += state_proof.footprint();
526                if state_proof.header.proof_type == StateProofType::Existing {
527                    if offset + TN_ACCOUNT_META_FOOTPRINT > bytes.len() {
528                        return None;
529                    }
530                    let account_meta_bytes = &bytes[offset..offset + TN_ACCOUNT_META_FOOTPRINT];
531                    fee_payer_account_meta_raw = Some(account_meta_bytes.to_vec());
532                    offset += TN_ACCOUNT_META_FOOTPRINT;
533                }
534                Some(state_proof)
535            } else {
536                return None;
537            }
538        } else {
539            None
540        };
541
542        // Verify we've consumed all bytes
543        if offset != bytes.len() {
544            log::warn!(
545                "Transaction::from_wire: offset != bytes.len() ({} != {})",
546                offset,
547                bytes.len()
548            );
549            return None;
550        }
551
552        Some(Transaction {
553            fee_payer: wire.fee_payer_pubkey,
554            program: wire.program_pubkey,
555            rw_accs,
556            r_accs,
557            instructions,
558            flags: wire.flags,
559            fee: wire.fee,
560            req_compute_units: wire.req_compute_units,
561            req_state_units: wire.req_state_units,
562            req_memory_units: wire.req_memory_units,
563            expiry_after: wire.expiry_after,
564            start_slot: wire.start_slot,
565            nonce: wire.nonce,
566            signature: Some(wire.fee_payer_signature),
567            fee_payer_state_proof,
568            fee_payer_account_meta_raw,
569        })
570    }
571
572    /// Accessor: read a field from serialized bytes by name
573    pub fn get_field_from_wire(bytes: &[u8], field: &str) -> Option<Vec<u8>> {
574        if bytes.len() < core::mem::size_of::<WireTxnHdrV1>() {
575            return None;
576        }
577        let wire: &WireTxnHdrV1 = from_bytes(&bytes[0..core::mem::size_of::<WireTxnHdrV1>()]);
578        match field {
579            "fee_payer_signature" => Some(wire.fee_payer_signature.to_vec()),
580            "transaction_version" => Some(vec![wire.transaction_version]),
581            "flags" => Some(vec![wire.flags]),
582            "readwrite_accounts_cnt" => Some(wire.readwrite_accounts_cnt.to_le_bytes().to_vec()),
583            "readonly_accounts_cnt" => Some(wire.readonly_accounts_cnt.to_le_bytes().to_vec()),
584            "instr_data_sz" => Some(wire.instr_data_sz.to_le_bytes().to_vec()),
585            "req_compute_units" => Some(wire.req_compute_units.to_le_bytes().to_vec()),
586            "req_state_units" => Some(wire.req_state_units.to_le_bytes().to_vec()),
587            "req_memory_units" => Some(wire.req_memory_units.to_le_bytes().to_vec()),
588            "expiry_after" => Some(wire.expiry_after.to_le_bytes().to_vec()),
589            "fee" => Some(wire.fee.to_le_bytes().to_vec()),
590            "nonce" => Some(wire.nonce.to_le_bytes().to_vec()),
591            "start_slot" => Some(wire.start_slot.to_le_bytes().to_vec()),
592            "fee_payer_pubkey" => Some(wire.fee_payer_pubkey.to_vec()),
593            "program_pubkey" => Some(wire.program_pubkey.to_vec()),
594            _ => None,
595        }
596    }
597}
598
599/// Helper function to check if transaction has fee payer state proof
600fn has_fee_payer_state_proof(flags: u8) -> bool {
601    (flags & (1 << TN_TXN_FLAG_HAS_FEE_PAYER_PROOF_BIT)) != 0
602}
603
/// Returns the 2-bit state proof type code held in the two most significant bits
/// of a proof header's `type_slot` word.
fn extract_state_proof_type(type_slot: u64) -> u64 {
    // A 62-bit right shift leaves exactly the top two bits, so no mask is needed.
    type_slot >> 62
}
608
609/// Helper function to calculate state proof footprint from header
610fn calculate_state_proof_footprint(state_proof_data: &[u8]) -> Result<usize, RpcError> {
611    if state_proof_data.len() < TN_STATE_PROOF_HDR_SIZE {
612        return Err(RpcError::invalid_format());
613    }
614
615    // Extract type_slot (first 8 bytes)
616    let type_slot = u64::from_le_bytes([
617        state_proof_data[0],
618        state_proof_data[1],
619        state_proof_data[2],
620        state_proof_data[3],
621        state_proof_data[4],
622        state_proof_data[5],
623        state_proof_data[6],
624        state_proof_data[7],
625    ]);
626
627    // Extract path_bitset (next 32 bytes) and count set bits
628    let mut sibling_hash_cnt = 0u32;
629    for i in 0..4 {
630        let start = 8 + i * 8;
631        let word = u64::from_le_bytes([
632            state_proof_data[start],
633            state_proof_data[start + 1],
634            state_proof_data[start + 2],
635            state_proof_data[start + 3],
636            state_proof_data[start + 4],
637            state_proof_data[start + 5],
638            state_proof_data[start + 6],
639            state_proof_data[start + 7],
640        ]);
641        sibling_hash_cnt += word.count_ones();
642    }
643
644    let proof_type = extract_state_proof_type(type_slot);
645    let body_sz = (proof_type + sibling_hash_cnt as u64) * 32; // Each hash is 32 bytes
646
647    Ok(TN_STATE_PROOF_HDR_SIZE + body_sz as usize)
648}
649
650pub fn tn_txn_size(bytes: &[u8]) -> Result<usize, RpcError> {
651    // Basic size checks
652    if bytes.len() < core::mem::size_of::<WireTxnHdrV1>() {
653        return Err(RpcError::invalid_format());
654    }
655
656    // Parse the header
657    // Use read_unaligned to safely read from potentially unaligned memory
658    let hdr: WireTxnHdrV1 =
659        unsafe { std::ptr::read_unaligned(bytes.as_ptr() as *const WireTxnHdrV1) };
660    let hdr = &hdr;
661    let mut offset = core::mem::size_of::<WireTxnHdrV1>();
662
663    // Calculate accounts size
664    let accs_sz = (hdr.readwrite_accounts_cnt as usize + hdr.readonly_accounts_cnt as usize) * 32;
665    if offset + accs_sz > bytes.len() {
666        return Err(RpcError::invalid_format());
667    }
668    offset += accs_sz;
669
670    // Calculate instruction data size
671    let instr_sz = hdr.instr_data_sz as usize;
672    if offset + instr_sz > bytes.len() {
673        return Err(RpcError::invalid_format());
674    }
675    offset += instr_sz;
676
677    // Handle fee payer state proof if present
678    if has_fee_payer_state_proof(hdr.flags) {
679        // Check state proof header size
680        if offset + TN_STATE_PROOF_HDR_SIZE > bytes.len() {
681            return Err(RpcError::invalid_format());
682        }
683
684        // Calculate state proof footprint
685        let state_proof_data = &bytes[offset..];
686        let state_proof_sz = calculate_state_proof_footprint(state_proof_data)?;
687
688        if offset + state_proof_sz > bytes.len() {
689            return Err(RpcError::invalid_format());
690        }
691        offset += state_proof_sz;
692
693        // Extract proof type for additional validation
694        let type_slot = u64::from_le_bytes([
695            state_proof_data[0],
696            state_proof_data[1],
697            state_proof_data[2],
698            state_proof_data[3],
699            state_proof_data[4],
700            state_proof_data[5],
701            state_proof_data[6],
702            state_proof_data[7],
703        ]);
704        let proof_type = extract_state_proof_type(type_slot);
705
706        // If proof type is EXISTING, account for account meta
707        if proof_type == TN_STATE_PROOF_TYPE_EXISTING {
708            if offset + TN_ACCOUNT_META_FOOTPRINT > bytes.len() {
709                return Err(RpcError::invalid_format());
710            }
711            offset += TN_ACCOUNT_META_FOOTPRINT;
712        }
713    }
714
715    // Verify we don't exceed the provided bytes
716    if offset > bytes.len() {
717        return Err(RpcError::invalid_format());
718    }
719
720    Ok(offset)
721}
722
723/// Validate a wire-format transaction for protocol correctness (matching C tn_txn_parse_core).
724pub fn validate_wire_transaction(bytes: &[u8]) -> Result<(), RpcError> {
725    const TN_TXN_MTU: usize = 32_768;
726    const TN_TXN_VERSION_OFFSET: usize = 64;
727    const TN_TXN_FLAGS_OFFSET: usize = 65;
728
729    use bytemuck::from_bytes;
730
731    // 1. Check payload size
732    if bytes.len() > TN_TXN_MTU {
733        return Err(RpcError::invalid_transaction_size(bytes.len(), TN_TXN_MTU));
734    }
735
736    // 2. Check transaction version
737    if bytes.len() <= TN_TXN_VERSION_OFFSET {
738        return Err(RpcError::invalid_format());
739    }
740    let transaction_version = bytes[TN_TXN_VERSION_OFFSET];
741    if transaction_version != 0x01 {
742        return Err(RpcError::invalid_version());
743    }
744
745    // 3. Check flags
746    if bytes.len() <= TN_TXN_FLAGS_OFFSET {
747        return Err(RpcError::invalid_format());
748    }
749    let flags = bytes[TN_TXN_FLAGS_OFFSET];
750    // Clear the fee payer proof bit and check that all other bits are 0
751    let flags_without_proof_bit = flags & !(1 << TN_TXN_FLAG_HAS_FEE_PAYER_PROOF_BIT);
752    let flags_cleared = flags_without_proof_bit & !(1 << TN_TXN_FLAG_MAY_COMPRESS_ACCOUNT_BIT);
753    if flags_cleared != 0 {
754        return Err(RpcError::invalid_flags());
755    }
756
757    // 4. Check header size and parse header
758    if bytes.len() < core::mem::size_of::<WireTxnHdrV1>() {
759        return Err(RpcError::invalid_format());
760    }
761    let hdr: &WireTxnHdrV1 = from_bytes(&bytes[0..core::mem::size_of::<WireTxnHdrV1>()]);
762    let mut offset = core::mem::size_of::<WireTxnHdrV1>();
763
764    // 5. Parse accounts
765    let accs_sz = (hdr.readwrite_accounts_cnt as usize + hdr.readonly_accounts_cnt as usize) * 32;
766    if offset + accs_sz > bytes.len() {
767        return Err(RpcError::invalid_format());
768    }
769    offset += accs_sz;
770
771    // 6. Parse instruction data
772    let instr_sz = hdr.instr_data_sz as usize;
773    if offset + instr_sz > bytes.len() {
774        return Err(RpcError::invalid_format());
775    }
776    offset += instr_sz;
777
778    // 7. Handle fee payer state proof if present
779    if has_fee_payer_state_proof(flags) {
780        // Check state proof header size
781        if offset + TN_STATE_PROOF_HDR_SIZE > bytes.len() {
782            return Err(RpcError::invalid_format());
783        }
784
785        // Calculate state proof footprint
786        let state_proof_data = &bytes[offset..];
787        let state_proof_sz = calculate_state_proof_footprint(state_proof_data)?;
788
789        if offset + state_proof_sz > bytes.len() {
790            return Err(RpcError::invalid_format());
791        }
792
793        // Extract proof type and validate
794        let type_slot = u64::from_le_bytes([
795            state_proof_data[0],
796            state_proof_data[1],
797            state_proof_data[2],
798            state_proof_data[3],
799            state_proof_data[4],
800            state_proof_data[5],
801            state_proof_data[6],
802            state_proof_data[7],
803        ]);
804        let proof_type = extract_state_proof_type(type_slot);
805
806        // Check that proof type is not UPDATING
807        if proof_type == TN_STATE_PROOF_TYPE_UPDATING {
808            return Err(RpcError::invalid_fee_payer_state_proof_type());
809        }
810
811        offset += state_proof_sz;
812
813        // If proof type is EXISTING, expect account meta
814        if proof_type == TN_STATE_PROOF_TYPE_EXISTING {
815            if offset + TN_ACCOUNT_META_FOOTPRINT > bytes.len() {
816                return Err(RpcError::invalid_format());
817            }
818            offset += TN_ACCOUNT_META_FOOTPRINT;
819        }
820    }
821
822    // 8. Check for exact size match (no trailing bytes)
823    if offset != bytes.len() {
824        // return Err(RpcError::invalid_format());
825        return Err(RpcError::trailing_bytes(offset, bytes.len()));
826    }
827    // 5. Signature check (fee payer signature)
828    if hdr.fee_payer_signature.len() != 64 {
829        return Err(RpcError::invalid_transaction_signature());
830    }
831    let sig = Signature::from_bytes(&hdr.fee_payer_signature);
832    let wire_for_signing = bytes[64..].to_vec(); // Exclude signature field
833    let verifying_key = match VerifyingKey::from_bytes(&hdr.fee_payer_pubkey) {
834        Ok(key) => key,
835        Err(_) => return Err(RpcError::invalid_transaction_signature()),
836    };
837    if verifying_key.verify(&wire_for_signing, &sig).is_err() {
838        return Err(RpcError::invalid_transaction_signature());
839    }
840
841    Ok(())
842}
843
#[cfg(test)]
mod tests {
    // Tests for `tn_txn_size`, `validate_wire_transaction`, and the fee-payer
    // state-proof helpers on `Transaction`.
    use super::*;
    use crate::tn_state_proof::{StateProof, StateProofType};
    use ed25519_dalek::SigningKey;

    /// Mask selecting the "has fee payer proof" flag in `Transaction::flags`.
    ///
    /// Assertions compare against this mask instead of the literal `1`, so they stay
    /// correct if the flag's bit position ever changes.
    const FEE_PAYER_PROOF_MASK: u8 = 1 << TN_TXN_FLAG_HAS_FEE_PAYER_PROOF_BIT;

    /// Builds a signed wire transaction with two read-write accounts, one read-only
    /// account, four instruction bytes and the given raw `flags` byte.
    ///
    /// `flags` is stored verbatim and no state proof is attached, so callers can use
    /// it to produce deliberately inconsistent transactions.
    fn make_valid_txn_bytes_with_flags(flags: u8) -> Vec<u8> {
        let signing_key = SigningKey::from(&[1u8; 32]);
        let verifying_key = signing_key.verifying_key();
        let mut tx = Transaction::new(verifying_key.to_bytes(), [2u8; 32], 100, 42);
        tx.rw_accs = Some(vec![[3u8; 32], [4u8; 32]]);
        tx.r_accs = Some(vec![[5u8; 32]]);
        tx.instructions = Some(vec![1, 2, 3, 4]);
        tx.flags = flags;
        tx.sign(&signing_key.to_bytes()).unwrap();
        tx.to_wire()
    }

    /// Same as `make_valid_txn_bytes_with_flags` with every flag cleared.
    fn make_valid_txn_bytes() -> Vec<u8> {
        make_valid_txn_bytes_with_flags(0)
    }

    #[test]
    fn test_tn_txn_size_basic_transaction() {
        let bytes = make_valid_txn_bytes();
        let calculated_size = tn_txn_size(&bytes).unwrap();

        // The computed size must cover exactly the whole encoding.
        assert_eq!(calculated_size, bytes.len());
    }

    #[test]
    fn test_tn_txn_size_with_state_proof() {
        let signing_key = SigningKey::from(&[1u8; 32]);
        let verifying_key = signing_key.verifying_key();

        // CREATION proof with an all-zero path bitset, i.e. no sibling hashes.
        let path_bitset = [0u8; 32];
        let existing_leaf_pubkey = [7u8; 32];
        let existing_leaf_hash = [8u8; 32];
        let state_proof = StateProof::creation(
            100,
            path_bitset,
            existing_leaf_pubkey,
            existing_leaf_hash,
            vec![],
        );

        let mut tx = Transaction::new(verifying_key.to_bytes(), [2u8; 32], 100, 42)
            .with_rw_accounts(vec![[3u8; 32]])
            .with_instructions(vec![1, 2, 3])
            .with_fee_payer_state_proof(&state_proof);

        tx.sign(&signing_key.to_bytes()).unwrap();
        let bytes = tx.to_wire();

        // The appended state proof must be included in the computed size.
        let calculated_size = tn_txn_size(&bytes).unwrap();
        assert_eq!(calculated_size, bytes.len());
    }

    #[test]
    fn test_tn_txn_size_minimal_transaction() {
        let signing_key = SigningKey::from(&[1u8; 32]);
        let verifying_key = signing_key.verifying_key();

        // No accounts, no instructions, no state proof.
        let mut tx = Transaction::new(verifying_key.to_bytes(), [2u8; 32], 100, 42);
        tx.sign(&signing_key.to_bytes()).unwrap();
        let bytes = tx.to_wire();

        let calculated_size = tn_txn_size(&bytes).unwrap();
        assert_eq!(calculated_size, bytes.len());

        // A minimal transaction is nothing but the fixed header.
        assert_eq!(calculated_size, core::mem::size_of::<WireTxnHdrV1>());
    }

    #[test]
    fn test_tn_txn_size_invalid_format() {
        // Too short to hold the fixed header.
        let short_bytes = vec![0u8; 50];
        let result = tn_txn_size(&short_bytes);
        assert!(matches!(result, Err(RpcError::InvalidFormat)));

        // Header present, but the variable-length data after it is truncated.
        let mut bytes = make_valid_txn_bytes();
        bytes.truncate(core::mem::size_of::<WireTxnHdrV1>() + 10);
        let result = tn_txn_size(&bytes);
        assert!(matches!(result, Err(RpcError::InvalidFormat)));
    }

    #[test]
    fn test_tn_txn_size_consistency_with_validation() {
        let bytes = make_valid_txn_bytes();

        // A transaction that validates must also report a size equal to its length.
        assert!(validate_wire_transaction(&bytes).is_ok());
        let calculated_size = tn_txn_size(&bytes).expect("valid transaction must have a size");
        assert_eq!(calculated_size, bytes.len());
    }

    #[test]
    fn test_valid_transaction() {
        let bytes = make_valid_txn_bytes();
        assert!(validate_wire_transaction(&bytes).is_ok());
    }

    #[test]
    fn test_oversize_transaction() {
        let mut bytes = make_valid_txn_bytes();
        bytes.resize(32_769, 0);
        assert_eq!(
            validate_wire_transaction(&bytes).unwrap_err(),
            RpcError::InvalidTransactionSize {
                size: 32_769,
                max_size: 32_768
            }
        );
    }

    #[test]
    fn test_trailing_bytes() {
        let mut bytes = make_valid_txn_bytes();
        // Derive the expected sizes from the valid encoding instead of hard-coding them,
        // so this keeps testing trailing-byte detection if the layout changes.
        let valid_len = bytes.len();
        bytes.push(0);
        assert_eq!(
            validate_wire_transaction(&bytes).unwrap_err(),
            RpcError::TrailingBytes {
                expected: valid_len,
                found: valid_len + 1
            }
        );
    }

    #[test]
    fn test_invalid_transaction_version() {
        let mut bytes = make_valid_txn_bytes();
        // The version byte sits at offset 64, right after the fee payer signature.
        bytes[64] = 0x02;
        assert_eq!(
            validate_wire_transaction(&bytes).unwrap_err(),
            RpcError::InvalidVersion
        );
    }

    #[test]
    fn test_invalid_flags() {
        // 0x07 sets both defined flag bits (fee payer proof, may-compress-account)
        // plus bit 2, which has no defined meaning and must be rejected.
        let bytes = make_valid_txn_bytes_with_flags(0x07);
        assert_eq!(
            validate_wire_transaction(&bytes).unwrap_err(),
            RpcError::InvalidFlags
        );
    }

    #[test]
    fn test_transaction_too_short() {
        // Too short to hold the fixed header.
        let bytes = vec![0u8; 50];
        assert_eq!(
            validate_wire_transaction(&bytes).unwrap_err(),
            RpcError::InvalidFormat
        );
    }

    #[test]
    fn test_transaction_with_state_proof() {
        let signing_key = SigningKey::from(&[1u8; 32]);
        let verifying_key = signing_key.verifying_key();

        // CREATION proof (no account meta follows it) with no sibling hashes.
        let path_bitset = [0u8; 32];
        let existing_leaf_pubkey = [7u8; 32];
        let existing_leaf_hash = [8u8; 32];
        let state_proof = StateProof::creation(
            100,
            path_bitset,
            existing_leaf_pubkey,
            existing_leaf_hash,
            vec![],
        );

        let mut tx = Transaction::new(verifying_key.to_bytes(), [2u8; 32], 100, 42)
            .with_fee_payer_state_proof(&state_proof);

        // Attaching a proof must set the flag bit.
        assert!(tx.has_fee_payer_state_proof());
        assert_eq!(tx.flags & FEE_PAYER_PROOF_MASK, FEE_PAYER_PROOF_MASK);

        tx.sign(&signing_key.to_bytes()).unwrap();
        let bytes = tx.to_wire();

        // With no accounts or instructions, everything past the fixed header is the
        // proof, so the encoding must be strictly longer than the header alone.
        assert!(bytes.len() > core::mem::size_of::<WireTxnHdrV1>());
        assert!(validate_wire_transaction(&bytes).is_ok());

        let decoded_tx = Transaction::from_wire(&bytes).unwrap();
        assert!(decoded_tx.has_fee_payer_state_proof());
        assert!(decoded_tx.fee_payer_state_proof.is_some());

        let decoded_proof = decoded_tx.fee_payer_state_proof.unwrap();
        assert_eq!(decoded_proof.proof_type(), StateProofType::Creation);
        assert_eq!(decoded_proof.slot(), 100);
    }

    #[test]
    fn test_transaction_with_state_proof_serialization_round_trip() {
        let signing_key = SigningKey::from(&[1u8; 32]);
        let verifying_key = signing_key.verifying_key();

        // Two set bits in the path bitset pair with two sibling hashes.
        let mut path_bitset = [0u8; 32];
        path_bitset[0] = 0b11;
        let existing_leaf_pubkey = [7u8; 32];
        let existing_leaf_hash = [8u8; 32];
        let sibling_hashes = vec![[9u8; 32], [10u8; 32]];

        let state_proof = StateProof::creation(
            200,
            path_bitset,
            existing_leaf_pubkey,
            existing_leaf_hash,
            sibling_hashes,
        );

        let mut tx = Transaction::new(verifying_key.to_bytes(), [2u8; 32], 100, 42)
            .with_rw_accounts(vec![[3u8; 32], [4u8; 32]])
            .with_r_accounts(vec![[5u8; 32]])
            .with_instructions(vec![1, 2, 3, 4])
            .with_fee_payer_state_proof(&state_proof);

        tx.sign(&signing_key.to_bytes()).unwrap();
        let bytes = tx.to_wire();

        assert!(validate_wire_transaction(&bytes).is_ok());

        // Every field must survive a to_wire / from_wire round trip.
        let decoded_tx = Transaction::from_wire(&bytes).unwrap();
        assert_eq!(decoded_tx.fee_payer, tx.fee_payer);
        assert_eq!(decoded_tx.program, tx.program);
        assert_eq!(decoded_tx.rw_accs, tx.rw_accs);
        assert_eq!(decoded_tx.r_accs, tx.r_accs);
        assert_eq!(decoded_tx.instructions, tx.instructions);
        assert_eq!(decoded_tx.flags, tx.flags);
        assert!(decoded_tx.has_fee_payer_state_proof());

        let decoded_proof = decoded_tx.fee_payer_state_proof.unwrap();
        assert_eq!(decoded_proof.proof_type(), StateProofType::Creation);
        assert_eq!(decoded_proof.slot(), 200);
        assert_eq!(decoded_proof.path_bitset(), &path_bitset);
    }

    #[test]
    fn test_transaction_without_state_proof() {
        let signing_key = SigningKey::from(&[1u8; 32]);
        let verifying_key = signing_key.verifying_key();

        let mut tx = Transaction::new(verifying_key.to_bytes(), [2u8; 32], 100, 42);

        // A fresh transaction carries no proof and no proof flag.
        assert!(!tx.has_fee_payer_state_proof());
        assert_eq!(tx.flags & FEE_PAYER_PROOF_MASK, 0);
        assert!(tx.fee_payer_state_proof.is_none());

        tx.sign(&signing_key.to_bytes()).unwrap();
        let bytes = tx.to_wire();

        assert!(validate_wire_transaction(&bytes).is_ok());

        let decoded_tx = Transaction::from_wire(&bytes).unwrap();
        assert!(!decoded_tx.has_fee_payer_state_proof());
        assert!(decoded_tx.fee_payer_state_proof.is_none());
    }

    #[test]
    fn test_transaction_remove_state_proof() {
        let signing_key = SigningKey::from(&[1u8; 32]);
        let verifying_key = signing_key.verifying_key();

        let path_bitset = [0u8; 32];
        let existing_leaf_pubkey = [7u8; 32];
        let existing_leaf_hash = [8u8; 32];
        let state_proof = StateProof::creation(
            100,
            path_bitset,
            existing_leaf_pubkey,
            existing_leaf_hash,
            vec![],
        );

        // Attach a proof, then remove it again.
        let tx = Transaction::new(verifying_key.to_bytes(), [2u8; 32], 100, 42)
            .with_fee_payer_state_proof(&state_proof)
            .without_fee_payer_state_proof();

        // Removal must clear both the flag bit and the stored proof.
        assert!(!tx.has_fee_payer_state_proof());
        assert_eq!(tx.flags & FEE_PAYER_PROOF_MASK, 0);
        assert!(tx.fee_payer_state_proof.is_none());
    }
}