Skip to main content

thru_base/
txn_lib.rs

//! Transaction library: the `Transaction` builder struct, Ed25519 signing and
//! verification, wire-format (de)serialization, and raw header-field accessors.
3
/// Raw 32-byte public key.
pub type TnPubkey = [u8; 32];
/// Raw 32-byte hash.
pub type TnHash = [u8; 32];
/// Raw 64-byte signature.
pub type TnSignature = [u8; 64];
7
8use crate::{
9    tn_signature::{sign_transaction, verify_transaction},
10    tn_state_proof::StateProof,
11    StateProofType,
12};
13use bytemuck::{Pod, Zeroable, bytes_of, from_bytes};
14
/// Bit position of the "has fee payer state proof" flag
/// (matches C `#define TN_TXN_FLAG_HAS_FEE_PAYER_PROOF (0U)`).
pub const TN_TXN_FLAG_HAS_FEE_PAYER_PROOF_BIT: u8 = 0;
/// Bit position of the "may compress account" flag
/// (matches C `#define TN_TXN_FLAG_MAY_COMPRESS_ACCOUNT (1U)`).
pub const TN_TXN_FLAG_MAY_COMPRESS_ACCOUNT_BIT: u8 = 1;

/// State proof type: existing account (matches the C implementation).
pub const TN_STATE_PROOF_TYPE_EXISTING: u64 = 0x0;
/// State proof type: updating (matches the C implementation).
pub const TN_STATE_PROOF_TYPE_UPDATING: u64 = 0x1;
/// State proof type: creation (matches the C implementation).
pub const TN_STATE_PROOF_TYPE_CREATION: u64 = 0x2;

/// State proof header size: 8-byte `type_slot` followed by a 32-byte `path_bitset`.
pub const TN_STATE_PROOF_HDR_SIZE: usize = 40;
/// Size of `tn_account_meta_t` (matches the C `sizeof`).
pub const TN_ACCOUNT_META_FOOTPRINT: usize = 64;
26
// TEMPORARY: Minimal local RpcError for test pass (remove when shared error type is available)
/// Errors produced while validating or sizing wire-format transactions.
///
/// Every payload is plain `Copy` data, so the type also derives `Clone` and `Eq`.
/// This lets callers clone errors and use them in exact comparisons.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RpcError {
    /// Payload is larger than the maximum allowed transaction size.
    InvalidTransactionSize { size: usize, max_size: usize },
    /// The buffer holds a different number of bytes than the transaction describes.
    TrailingBytes { expected: usize, found: usize },
    /// More accounts than the allowed maximum.
    TooManyAccounts { count: usize, max_count: usize },
    /// The transaction signature is invalid.
    InvalidTransactionSignature,
    /// Invalid caller-supplied parameters; the message says which.
    InvalidParams(&'static str),
    /// Structural or size error while parsing the wire format.
    InvalidFormat,
    /// Unsupported `transaction_version`.
    InvalidVersion,
    /// Flag bits outside the known set are present.
    InvalidFlags,
    /// The fee payer state proof has a disallowed proof type.
    InvalidFeePayerStateProofType,
    /// `chain_id` is zero.
    InvalidChainId,
}
41
42impl std::fmt::Display for RpcError {
43    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
44        match self {
45            RpcError::InvalidTransactionSize { size, max_size } => {
46                write!(
47                    f,
48                    "Transaction size {} exceeds maximum allowed size {}",
49                    size, max_size
50                )
51            }
52            RpcError::TrailingBytes { expected, found } => {
53                write!(
54                    f,
55                    "Transaction has trailing bytes: expected {} bytes, found {} bytes",
56                    expected, found
57                )
58            }
59            RpcError::TooManyAccounts { count, max_count } => {
60                write!(
61                    f,
62                    "Too many accounts: {} exceeds maximum {}",
63                    count, max_count
64                )
65            }
66            RpcError::InvalidTransactionSignature => {
67                write!(f, "Invalid transaction signature")
68            }
69            RpcError::InvalidParams(msg) => {
70                write!(f, "Invalid parameters: {}", msg)
71            }
72            RpcError::InvalidFormat => {
73                write!(f, "Invalid transaction format")
74            }
75            RpcError::InvalidVersion => {
76                write!(f, "Invalid transaction version")
77            }
78            RpcError::InvalidFlags => {
79                write!(f, "Invalid transaction flags")
80            }
81            RpcError::InvalidFeePayerStateProofType => {
82                write!(f, "Invalid fee payer state proof type")
83            }
84            RpcError::InvalidChainId => {
85                write!(f, "Invalid chain ID: chain_id cannot be zero")
86            }
87        }
88    }
89}
90
91impl RpcError {
92    pub fn invalid_transaction_size(size: usize, max_size: usize) -> Self {
93        Self::InvalidTransactionSize { size, max_size }
94    }
95    pub fn trailing_bytes(expected: usize, found: usize) -> Self {
96        Self::TrailingBytes { expected, found }
97    }
98    pub fn too_many_accounts(count: usize, max_count: usize) -> Self {
99        Self::TooManyAccounts { count, max_count }
100    }
101    pub fn invalid_transaction_signature() -> Self {
102        Self::InvalidTransactionSignature
103    }
104    pub fn invalid_params(msg: &'static str) -> Self {
105        Self::InvalidParams(msg)
106    }
107    pub fn invalid_format() -> Self {
108        Self::InvalidFormat
109    }
110    pub fn invalid_version() -> Self {
111        Self::InvalidVersion
112    }
113    pub fn invalid_flags() -> Self {
114        Self::InvalidFlags
115    }
116    pub fn invalid_fee_payer_state_proof_type() -> Self {
117        Self::InvalidFeePayerStateProofType
118    }
119    pub fn invalid_chain_id() -> Self {
120        Self::InvalidChainId
121    }
122}
123
/// On-wire transaction header (matches TnTxnHdrV1 layout).
///
/// Transaction wire format:
///   [header (112 bytes)]
///   [input_pubkeys (variable)]
///   [instr_data (variable)]
///   [state_proof (optional)]
///   [account_meta (optional)]
///   [fee_payer_signature (64 bytes)]
///
/// Under `#[repr(C)]` the fields below are laid out back to back with no implicit
/// padding. The byte offset of each field is noted in its doc comment.
#[derive(Clone, Copy, Debug)]
#[repr(C)]
pub struct WireTxnHdrV1 {
    /// Offset 0.
    pub transaction_version: u8,
    /// Offset 1.
    pub flags: u8,
    /// Offset 2.
    pub readwrite_accounts_cnt: u16,
    /// Offset 4.
    pub readonly_accounts_cnt: u16,
    /// Offset 6.
    pub instr_data_sz: u16,
    /// Offset 8.
    pub req_compute_units: u32,
    /// Offset 12.
    pub req_state_units: u16,
    /// Offset 14.
    pub req_memory_units: u16,
    /// Offset 16.
    pub fee: u64,
    /// Offset 24.
    pub nonce: u64,
    /// Offset 32.
    pub start_slot: u64,
    /// Offset 40.
    pub expiry_after: u32,
    /// Offset 44.
    pub chain_id: u16,
    /// Offset 46; explicit padding so the header has no implicit gaps.
    pub padding_0: [u8; 2],
    /// Offset 48.
    pub fee_payer_pubkey: [u8; 32],
    /// Offset 80.
    pub program_pubkey: [u8; 32],
}
153
/// Length in bytes of the fee payer signature, which always occupies the final bytes
/// of a serialized transaction.
pub const TN_TXN_SIGNATURE_SZ: usize = 64;
156
157impl Default for WireTxnHdrV1 {
158    fn default() -> Self {
159        Self {
160            transaction_version: 0,
161            flags: 0,
162            readwrite_accounts_cnt: 0,
163            readonly_accounts_cnt: 0,
164            instr_data_sz: 0,
165            req_compute_units: 0,
166            req_state_units: 0,
167            req_memory_units: 0,
168            fee: 0,
169            nonce: 0,
170            start_slot: 0,
171            expiry_after: 0,
172            chain_id: 0,
173            padding_0: [0u8; 2],
174            fee_payer_pubkey: [0u8; 32],
175            program_pubkey: [0u8; 32],
176        }
177    }
178}
179
180// Manual Pod implementation to avoid derive issues
181unsafe impl Pod for WireTxnHdrV1 {}
182unsafe impl Zeroable for WireTxnHdrV1 {}
183
184/// Normal Rust struct for transaction construction
185#[derive(Clone, Debug, Default)]
186pub struct Transaction {
187    // Core transaction fields
188    pub fee_payer: TnPubkey, // [u8; 32] - who pays the fee
189    pub program: TnPubkey,   // [u8; 32] - target program
190
191    // Account lists (optional)
192    pub rw_accs: Option<Vec<TnPubkey>>, // read-write accounts
193    pub r_accs: Option<Vec<TnPubkey>>,  // read-only accounts
194
195    // Instruction data (optional)
196    pub instructions: Option<Vec<u8>>, // instruction bytes
197
198    // Transaction parameters
199    pub fee: u64,               // transaction fee
200    pub req_compute_units: u32, // requested compute units
201    pub req_state_units: u16,   // requested state units
202    pub req_memory_units: u16,  // requested memory units
203    pub expiry_after: u32,      // expiry time offset
204    pub start_slot: u64,        // starting slot
205    pub nonce: u64,             // transaction nonce
206    pub flags: u8,              // transaction flags
207    pub chain_id: u16,          // chain identifier (must be non-zero)
208
209    // Signature (optional until signed)
210    pub signature: Option<TnSignature>, // [u8; 64] - Ed25519 signature
211
212    // Fee payer state proof (optional)
213    pub fee_payer_state_proof: Option<StateProof>, // State proof for fee payer account
214
215    pub fee_payer_account_meta_raw: Option<Vec<u8>>,
216}
217
218impl Transaction {
219    /// Create a new unsigned transaction
220    pub fn new(fee_payer: TnPubkey, program: TnPubkey, fee: u64, nonce: u64) -> Self {
221        Self {
222            fee_payer,
223            program,
224            rw_accs: None,
225            r_accs: None,
226            instructions: None,
227            fee,
228            req_compute_units: 0,
229            req_state_units: 0,
230            req_memory_units: 0,
231            expiry_after: 0,
232            start_slot: 0,
233            nonce,
234            flags: 0,
235            chain_id: 1,
236            signature: None,
237            fee_payer_state_proof: None,
238            fee_payer_account_meta_raw: None,
239        }
240    }
241
242    /// Create a minimal transaction with just a program ID and instruction data.
243    /// Used for fake event transactions that don't need accounts or fees.
244    pub fn new_raw_instruction(
245        fee_payer: &TnPubkey,
246        program: &TnPubkey,
247        instruction_data: &[u8],
248    ) -> Result<Self, Box<dyn std::error::Error>> {
249        Ok(Self {
250            fee_payer: *fee_payer,
251            program: *program,
252            rw_accs: None,
253            r_accs: None,
254            instructions: Some(instruction_data.to_vec()),
255            fee: 0,
256            req_compute_units: 0,
257            req_state_units: 0,
258            req_memory_units: 0,
259            expiry_after: 0,
260            start_slot: 0,
261            nonce: 0,
262            flags: 0,
263            chain_id: 1,
264            signature: None,
265            fee_payer_state_proof: None,
266            fee_payer_account_meta_raw: None,
267        })
268    }
269
270    pub fn has_fee_payer_state_proof(&self) -> bool {
271        (self.flags & (1 << TN_TXN_FLAG_HAS_FEE_PAYER_PROOF_BIT)) != 0
272    }
273    pub fn may_compress_account(&self) -> bool {
274        (self.flags & (1 << TN_TXN_FLAG_MAY_COMPRESS_ACCOUNT_BIT)) != 0
275    }
276
277    pub fn get_signature(&self) -> Option<crate::Signature> {
278        if let Some(sig) = &self.signature {
279            return Some(crate::Signature::from_bytes(&sig));
280        }
281        None
282    }
283
284    pub fn with_may_compress_account(mut self) -> Self {
285        self.flags |= 1 << TN_TXN_FLAG_MAY_COMPRESS_ACCOUNT_BIT;
286        self
287    }
288
289    /// Builder method: set fee payer state proof
290    pub fn with_fee_payer_state_proof(mut self, state_proof: &StateProof) -> Self {
291        self.fee_payer_state_proof = Some(state_proof.clone());
292        // Set the flag bit to indicate presence of state proof
293        self.flags |= 1 << TN_TXN_FLAG_HAS_FEE_PAYER_PROOF_BIT;
294        self
295    }
296
297    /// Builder method: set fee payer account meta as raw bytes
298    pub fn with_fee_payer_account_meta_raw(mut self, account_meta_raw: Vec<u8>) -> Self {
299        self.fee_payer_account_meta_raw = Some(account_meta_raw);
300        self
301    }
302
303    /// Builder method: remove fee payer state proof
304    pub fn without_fee_payer_state_proof(mut self) -> Self {
305        self.fee_payer_state_proof = None;
306        // Clear the flag bit to indicate absence of state proof
307        self.flags &= !(1 << TN_TXN_FLAG_HAS_FEE_PAYER_PROOF_BIT);
308        self
309    }
310
311    /// Builder method: add read-write accounts
312    pub fn with_rw_accounts(mut self, accounts: Vec<TnPubkey>) -> Self {
313        self.rw_accs = Some(accounts);
314        self
315    }
316
317    /// Builder method: add read-only accounts
318    pub fn with_r_accounts(mut self, accounts: Vec<TnPubkey>) -> Self {
319        self.r_accs = Some(accounts);
320        self
321    }
322
323    /// Builder method: add a single read-write account
324    pub fn add_rw_account(mut self, account: TnPubkey) -> Self {
325        match self.rw_accs {
326            Some(ref mut accounts) => accounts.push(account),
327            None => self.rw_accs = Some(vec![account]),
328        }
329        self
330    }
331
332    /// Builder method: add a single read-only account
333    pub fn add_r_account(mut self, account: TnPubkey) -> Self {
334        match self.r_accs {
335            Some(ref mut accounts) => accounts.push(account),
336            None => self.r_accs = Some(vec![account]),
337        }
338        self
339    }
340
341    /// Builder method: add instruction data
342    pub fn with_instructions(mut self, instructions: Vec<u8>) -> Self {
343        self.instructions = Some(instructions);
344        self
345    }
346
347    /// Builder method: set compute units
348    pub fn with_compute_units(mut self, units: u32) -> Self {
349        self.req_compute_units = units;
350        self
351    }
352
353    /// Builder method: set state units
354    pub fn with_state_units(mut self, units: u16) -> Self {
355        self.req_state_units = units;
356        self
357    }
358
359    /// Builder method: set memory units
360    pub fn with_memory_units(mut self, units: u16) -> Self {
361        self.req_memory_units = units;
362        self
363    }
364
365    /// Builder method: set expiry
366    pub fn with_expiry_after(mut self, expiry: u32) -> Self {
367        self.expiry_after = expiry;
368        self
369    }
370
371    /// Builder method: set nonce
372    pub fn with_nonce(mut self, nonce: u64) -> Self {
373        self.nonce = nonce;
374        self
375    }
376
377    /// Builder method: set start slot
378    pub fn with_start_slot(mut self, slot: u64) -> Self {
379        self.start_slot = slot;
380        self
381    }
382
383    /// Builder method: set chain ID
384    pub fn with_chain_id(mut self, chain_id: u16) -> Self {
385        self.chain_id = chain_id;
386        self
387    }
388
389    /// Sign the transaction with a 32-byte Ed25519 private key
390    pub fn sign(&mut self, private_key: &[u8; 32]) -> Result<(), Box<dyn std::error::Error>> {
391        let wire_bytes = self.to_wire_for_signing();
392        let sig = sign_transaction(&wire_bytes, &self.fee_payer, private_key)
393            .map_err(|e| Box::<dyn std::error::Error>::from(e))?;
394        self.signature = Some(sig);
395        Ok(())
396    }
397
398    /// Verify the transaction signature
399    pub fn verify(&self) -> bool {
400        if let Some(sig_bytes) = &self.signature {
401            let wire_bytes = self.to_wire_for_signing();
402            return verify_transaction(&wire_bytes, sig_bytes, &self.fee_payer).is_ok();
403        }
404        false
405    }
406
407    /// Create wire format for signing (excluding trailing signature)
408    /// The message is: header + accounts + instr_data + state_proof + account_meta
409    fn to_wire_for_signing(&self) -> Vec<u8> {
410        // Zero out all bytes first to ensure deterministic padding
411        let mut wire: WireTxnHdrV1 = unsafe { core::mem::zeroed() };
412        wire.transaction_version = 1;
413        wire.flags = self.flags;
414        wire.readwrite_accounts_cnt = self.rw_accs.as_ref().map_or(0, |v| v.len() as u16);
415        wire.readonly_accounts_cnt = self.r_accs.as_ref().map_or(0, |v| v.len() as u16);
416        wire.instr_data_sz = self.instructions.as_ref().map_or(0, |v| v.len() as u16);
417        wire.req_compute_units = self.req_compute_units;
418        wire.req_state_units = self.req_state_units;
419        wire.req_memory_units = self.req_memory_units;
420        wire.expiry_after = self.expiry_after;
421        wire.chain_id = self.chain_id;
422        wire.fee = self.fee;
423        wire.nonce = self.nonce;
424        wire.start_slot = self.start_slot;
425        wire.fee_payer_pubkey = self.fee_payer;
426        wire.program_pubkey = self.program;
427
428        let mut result = bytes_of(&wire).to_vec();
429
430        // Append variable-length data
431        if let Some(ref rw_accs) = self.rw_accs {
432            for acc in rw_accs {
433                result.extend_from_slice(acc);
434            }
435        }
436
437        if let Some(ref r_accs) = self.r_accs {
438            for acc in r_accs {
439                result.extend_from_slice(acc);
440            }
441        }
442
443        if let Some(ref instructions) = self.instructions {
444            result.extend_from_slice(instructions);
445        }
446
447        // Append state proof if present
448        if let Some(ref state_proof) = self.fee_payer_state_proof {
449            result.extend_from_slice(&state_proof.to_wire());
450        }
451
452        // Use raw account meta if available, otherwise use structured account meta
453        if let Some(ref fee_payer_account_meta_raw) = self.fee_payer_account_meta_raw {
454            result.extend_from_slice(fee_payer_account_meta_raw);
455        }
456
457        result
458    }
459
460    /// Serialize to on-wire format (WireTxnHdrV1)
461    /// Wire format: header + accounts + instr_data + state_proof + account_meta + signature
462    pub fn to_wire(&self) -> Vec<u8> {
463        let mut wire = WireTxnHdrV1::default();
464        wire.transaction_version = 1;
465        wire.flags = self.flags;
466        wire.readwrite_accounts_cnt = self.rw_accs.as_ref().map_or(0, |v| v.len() as u16);
467        wire.readonly_accounts_cnt = self.r_accs.as_ref().map_or(0, |v| v.len() as u16);
468        wire.instr_data_sz = self.instructions.as_ref().map_or(0, |v| v.len() as u16);
469        wire.req_compute_units = self.req_compute_units;
470        wire.req_state_units = self.req_state_units;
471        wire.req_memory_units = self.req_memory_units;
472        wire.expiry_after = self.expiry_after;
473        wire.chain_id = self.chain_id;
474        wire.fee = self.fee;
475        wire.nonce = self.nonce;
476        wire.start_slot = self.start_slot;
477        wire.fee_payer_pubkey = self.fee_payer;
478        wire.program_pubkey = self.program;
479
480        let mut result = bytes_of(&wire).to_vec();
481
482        // Append variable-length data
483        if let Some(ref rw_accs) = self.rw_accs {
484            for acc in rw_accs {
485                result.extend_from_slice(acc);
486            }
487        }
488
489        if let Some(ref r_accs) = self.r_accs {
490            for acc in r_accs {
491                result.extend_from_slice(acc);
492            }
493        }
494
495        if let Some(ref instructions) = self.instructions {
496            result.extend_from_slice(instructions);
497        }
498
499        // Append state proof if present (after instruction data)
500        if let Some(ref state_proof) = self.fee_payer_state_proof {
501            result.extend_from_slice(&state_proof.to_wire());
502        }
503
504        // Use raw account meta if available, otherwise use structured account meta
505        if let Some(ref fee_payer_account_meta_raw) = self.fee_payer_account_meta_raw {
506            result.extend_from_slice(fee_payer_account_meta_raw);
507        }
508
509        // Append signature at the END (last 64 bytes)
510        if let Some(sig) = &self.signature {
511            result.extend_from_slice(sig);
512        } else {
513            // Zero signature if not set
514            result.extend_from_slice(&[0u8; TN_TXN_SIGNATURE_SZ]);
515        }
516
517        result
518    }
519
520    /// Deserialize from on-wire format (WireTxnHdrV1)
521    /// Wire format: header + accounts + instr_data + state_proof + account_meta + signature (at end)
522    pub fn from_wire(bytes: &[u8]) -> Option<Self> {
523        // Minimum size: header + signature at end
524        if bytes.len() < core::mem::size_of::<WireTxnHdrV1>() + TN_TXN_SIGNATURE_SZ {
525            return None;
526        }
527
528        let wire: &WireTxnHdrV1 = from_bytes(&bytes[0..core::mem::size_of::<WireTxnHdrV1>()]);
529        let mut offset = core::mem::size_of::<WireTxnHdrV1>();
530
531        let sig_start = bytes.len() - TN_TXN_SIGNATURE_SZ;
532        let mut signature = [0u8; TN_TXN_SIGNATURE_SZ];
533        signature.copy_from_slice(&bytes[sig_start..]);
534
535        // Parse read-write accounts
536        let rw_accs = if wire.readwrite_accounts_cnt > 0 {
537            let mut accounts = Vec::new();
538            for _ in 0..wire.readwrite_accounts_cnt {
539                if offset + 32 > sig_start {
540                    return None;
541                }
542                let mut acc = [0u8; 32];
543                acc.copy_from_slice(&bytes[offset..offset + 32]);
544                accounts.push(acc);
545                offset += 32;
546            }
547            Some(accounts)
548        } else {
549            None
550        };
551
552        // Parse read-only accounts
553        let r_accs = if wire.readonly_accounts_cnt > 0 {
554            let mut accounts = Vec::new();
555            for _ in 0..wire.readonly_accounts_cnt {
556                if offset + 32 > sig_start {
557                    return None;
558                }
559                let mut acc = [0u8; 32];
560                acc.copy_from_slice(&bytes[offset..offset + 32]);
561                accounts.push(acc);
562                offset += 32;
563            }
564            Some(accounts)
565        } else {
566            None
567        };
568
569        // Parse instructions
570        let instructions = if wire.instr_data_sz > 0 {
571            if offset + wire.instr_data_sz as usize > sig_start {
572                return None;
573            }
574            let instr = bytes[offset..offset + wire.instr_data_sz as usize].to_vec();
575            offset += wire.instr_data_sz as usize;
576            Some(instr)
577        } else {
578            None
579        };
580
581        let mut fee_payer_account_meta_raw: Option<Vec<u8>> = None;
582        // Parse state proof if present
583        let fee_payer_state_proof = if has_fee_payer_state_proof(wire.flags) {
584            if offset >= sig_start {
585                return None;
586            }
587            let state_proof_bytes = &bytes[offset..sig_start];
588            if let Some(state_proof) = StateProof::from_wire(state_proof_bytes) {
589                offset += state_proof.footprint();
590                if state_proof.header.proof_type == StateProofType::Existing {
591                    if offset + TN_ACCOUNT_META_FOOTPRINT > sig_start {
592                        return None;
593                    }
594                    let account_meta_bytes = &bytes[offset..offset + TN_ACCOUNT_META_FOOTPRINT];
595                    fee_payer_account_meta_raw = Some(account_meta_bytes.to_vec());
596                    offset += TN_ACCOUNT_META_FOOTPRINT;
597                }
598                Some(state_proof)
599            } else {
600                return None;
601            }
602        } else {
603            None
604        };
605
606        // Verify we've consumed all bytes before signature
607        if offset != sig_start {
608            log::warn!(
609                "Transaction::from_wire: offset != sig_start ({} != {})",
610                offset,
611                sig_start
612            );
613            return None;
614        }
615
616        Some(Transaction {
617            fee_payer: wire.fee_payer_pubkey,
618            program: wire.program_pubkey,
619            rw_accs,
620            r_accs,
621            instructions,
622            flags: wire.flags,
623            chain_id: wire.chain_id,
624            fee: wire.fee,
625            req_compute_units: wire.req_compute_units,
626            req_state_units: wire.req_state_units,
627            req_memory_units: wire.req_memory_units,
628            expiry_after: wire.expiry_after,
629            start_slot: wire.start_slot,
630            nonce: wire.nonce,
631            signature: Some(signature),
632            fee_payer_state_proof,
633            fee_payer_account_meta_raw,
634        })
635    }
636
637    /// Accessor: read a field from serialized bytes by name
638    pub fn get_field_from_wire(bytes: &[u8], field: &str) -> Option<Vec<u8>> {
639        if bytes.len() < core::mem::size_of::<WireTxnHdrV1>() + TN_TXN_SIGNATURE_SZ {
640            return None;
641        }
642        let wire: &WireTxnHdrV1 = from_bytes(&bytes[0..core::mem::size_of::<WireTxnHdrV1>()]);
643        match field {
644            "fee_payer_signature" => {
645                let sig_start = bytes.len() - TN_TXN_SIGNATURE_SZ;
646                Some(bytes[sig_start..].to_vec())
647            }
648            "transaction_version" => Some(vec![wire.transaction_version]),
649            "flags" => Some(vec![wire.flags]),
650            "readwrite_accounts_cnt" => Some(wire.readwrite_accounts_cnt.to_le_bytes().to_vec()),
651            "readonly_accounts_cnt" => Some(wire.readonly_accounts_cnt.to_le_bytes().to_vec()),
652            "instr_data_sz" => Some(wire.instr_data_sz.to_le_bytes().to_vec()),
653            "req_compute_units" => Some(wire.req_compute_units.to_le_bytes().to_vec()),
654            "req_state_units" => Some(wire.req_state_units.to_le_bytes().to_vec()),
655            "req_memory_units" => Some(wire.req_memory_units.to_le_bytes().to_vec()),
656            "expiry_after" => Some(wire.expiry_after.to_le_bytes().to_vec()),
657            "chain_id" => Some(wire.chain_id.to_le_bytes().to_vec()),
658            "fee" => Some(wire.fee.to_le_bytes().to_vec()),
659            "nonce" => Some(wire.nonce.to_le_bytes().to_vec()),
660            "start_slot" => Some(wire.start_slot.to_le_bytes().to_vec()),
661            "fee_payer_pubkey" => Some(wire.fee_payer_pubkey.to_vec()),
662            "program_pubkey" => Some(wire.program_pubkey.to_vec()),
663            _ => None,
664        }
665    }
666}
667
668/// Helper function to check if transaction has fee payer state proof
669fn has_fee_payer_state_proof(flags: u8) -> bool {
670    (flags & (1 << TN_TXN_FLAG_HAS_FEE_PAYER_PROOF_BIT)) != 0
671}
672
/// Returns the state proof type stored in the two most significant bits of `type_slot`.
fn extract_state_proof_type(type_slot: u64) -> u64 {
    // Shifting a u64 right by 62 leaves only the top two bits, so no extra mask is needed.
    type_slot >> 62
}
677
678/// Helper function to calculate state proof footprint from header
679fn calculate_state_proof_footprint(state_proof_data: &[u8]) -> Result<usize, RpcError> {
680    if state_proof_data.len() < TN_STATE_PROOF_HDR_SIZE {
681        return Err(RpcError::invalid_format());
682    }
683
684    // Extract type_slot (first 8 bytes)
685    let type_slot = u64::from_le_bytes([
686        state_proof_data[0],
687        state_proof_data[1],
688        state_proof_data[2],
689        state_proof_data[3],
690        state_proof_data[4],
691        state_proof_data[5],
692        state_proof_data[6],
693        state_proof_data[7],
694    ]);
695
696    // Extract path_bitset (next 32 bytes) and count set bits
697    let mut sibling_hash_cnt = 0u32;
698    for i in 0..4 {
699        let start = 8 + i * 8;
700        let word = u64::from_le_bytes([
701            state_proof_data[start],
702            state_proof_data[start + 1],
703            state_proof_data[start + 2],
704            state_proof_data[start + 3],
705            state_proof_data[start + 4],
706            state_proof_data[start + 5],
707            state_proof_data[start + 6],
708            state_proof_data[start + 7],
709        ]);
710        sibling_hash_cnt += word.count_ones();
711    }
712
713    let proof_type = extract_state_proof_type(type_slot);
714    let body_sz = (proof_type + sibling_hash_cnt as u64) * 32; // Each hash is 32 bytes
715
716    Ok(TN_STATE_PROOF_HDR_SIZE + body_sz as usize)
717}
718
719pub fn tn_txn_size(bytes: &[u8]) -> Result<usize, RpcError> {
720    // Basic size checks
721    if bytes.len() < core::mem::size_of::<WireTxnHdrV1>() + TN_TXN_SIGNATURE_SZ {
722        return Err(RpcError::invalid_format());
723    }
724
725    // Parse the header
726    // Use read_unaligned to safely read from potentially unaligned memory
727    let hdr: WireTxnHdrV1 =
728        unsafe { std::ptr::read_unaligned(bytes.as_ptr() as *const WireTxnHdrV1) };
729    let hdr = &hdr;
730    let mut offset = core::mem::size_of::<WireTxnHdrV1>();
731
732    // Calculate accounts size
733    let accs_sz = (hdr.readwrite_accounts_cnt as usize + hdr.readonly_accounts_cnt as usize) * 32;
734    if offset + accs_sz > bytes.len() {
735        return Err(RpcError::invalid_format());
736    }
737    offset += accs_sz;
738
739    // Calculate instruction data size
740    let instr_sz = hdr.instr_data_sz as usize;
741    if offset + instr_sz > bytes.len() {
742        return Err(RpcError::invalid_format());
743    }
744    offset += instr_sz;
745
746    // Handle fee payer state proof if present
747    if has_fee_payer_state_proof(hdr.flags) {
748        // Check state proof header size
749        if offset + TN_STATE_PROOF_HDR_SIZE > bytes.len() {
750            return Err(RpcError::invalid_format());
751        }
752
753        // Calculate state proof footprint
754        let state_proof_data = &bytes[offset..];
755        let state_proof_sz = calculate_state_proof_footprint(state_proof_data)?;
756
757        if offset + state_proof_sz > bytes.len() {
758            return Err(RpcError::invalid_format());
759        }
760        offset += state_proof_sz;
761
762        // Extract proof type for additional validation
763        let type_slot = u64::from_le_bytes([
764            state_proof_data[0],
765            state_proof_data[1],
766            state_proof_data[2],
767            state_proof_data[3],
768            state_proof_data[4],
769            state_proof_data[5],
770            state_proof_data[6],
771            state_proof_data[7],
772        ]);
773        let proof_type = extract_state_proof_type(type_slot);
774
775        // If proof type is EXISTING, account for account meta
776        if proof_type == TN_STATE_PROOF_TYPE_EXISTING {
777            if offset + TN_ACCOUNT_META_FOOTPRINT > bytes.len() {
778                return Err(RpcError::invalid_format());
779            }
780            offset += TN_ACCOUNT_META_FOOTPRINT;
781        }
782    }
783
784    // Add trailing signature size
785    offset += TN_TXN_SIGNATURE_SZ;
786
787    // Verify we don't exceed the provided bytes
788    if offset > bytes.len() {
789        return Err(RpcError::invalid_format());
790    }
791
792    Ok(offset)
793}
794
795/// Validate a wire-format transaction for protocol correctness (matching C tn_txn_parse_core).
796pub fn validate_wire_transaction(bytes: &[u8]) -> Result<(), RpcError> {
797    const TN_TXN_MTU: usize = 32_768;
798    // Version and flags are now at offset 0 and 1
799    const TN_TXN_VERSION_OFFSET: usize = 0;
800    const TN_TXN_FLAGS_OFFSET: usize = 1;
801
802    use bytemuck::from_bytes;
803
804    // 1. Check payload size
805    if bytes.len() > TN_TXN_MTU {
806        return Err(RpcError::invalid_transaction_size(bytes.len(), TN_TXN_MTU));
807    }
808
809    // 2. Check minimum size
810    if bytes.len() < core::mem::size_of::<WireTxnHdrV1>() + TN_TXN_SIGNATURE_SZ {
811        return Err(RpcError::invalid_format());
812    }
813
814    // 3. Check transaction version
815    let transaction_version = bytes[TN_TXN_VERSION_OFFSET];
816    if transaction_version != 0x01 {
817        return Err(RpcError::invalid_version());
818    }
819
820    // 4. Check flags
821    let flags = bytes[TN_TXN_FLAGS_OFFSET];
822    // Clear the fee payer proof bit and check that all other bits are 0
823    let flags_without_proof_bit = flags & !(1 << TN_TXN_FLAG_HAS_FEE_PAYER_PROOF_BIT);
824    let flags_cleared = flags_without_proof_bit & !(1 << TN_TXN_FLAG_MAY_COMPRESS_ACCOUNT_BIT);
825    if flags_cleared != 0 {
826        return Err(RpcError::invalid_flags());
827    }
828
829    // 5. Parse header
830    let hdr: &WireTxnHdrV1 = from_bytes(&bytes[0..core::mem::size_of::<WireTxnHdrV1>()]);
831    let mut offset = core::mem::size_of::<WireTxnHdrV1>();
832
833    let sig_start = bytes.len() - TN_TXN_SIGNATURE_SZ;
834
835    // 6. Validate chain_id is non-zero (matches C tn_txn_parse.c:37)
836    if hdr.chain_id == 0 {
837        return Err(RpcError::invalid_chain_id());
838    }
839
840    // 6. Parse accounts
841    let accs_sz = (hdr.readwrite_accounts_cnt as usize + hdr.readonly_accounts_cnt as usize) * 32;
842    if offset + accs_sz > sig_start {
843        return Err(RpcError::invalid_format());
844    }
845    offset += accs_sz;
846
847    // 7. Parse instruction data
848    let instr_sz = hdr.instr_data_sz as usize;
849    if offset + instr_sz > sig_start {
850        return Err(RpcError::invalid_format());
851    }
852    offset += instr_sz;
853
854    // 8. Handle fee payer state proof if present
855    if has_fee_payer_state_proof(flags) {
856        // Check state proof header size
857        if offset + TN_STATE_PROOF_HDR_SIZE > sig_start {
858            return Err(RpcError::invalid_format());
859        }
860
861        // Calculate state proof footprint
862        let state_proof_data = &bytes[offset..sig_start];
863        let state_proof_sz = calculate_state_proof_footprint(state_proof_data)?;
864
865        if offset + state_proof_sz > sig_start {
866            return Err(RpcError::invalid_format());
867        }
868
869        // Extract proof type and validate
870        let type_slot = u64::from_le_bytes([
871            state_proof_data[0],
872            state_proof_data[1],
873            state_proof_data[2],
874            state_proof_data[3],
875            state_proof_data[4],
876            state_proof_data[5],
877            state_proof_data[6],
878            state_proof_data[7],
879        ]);
880        let proof_type = extract_state_proof_type(type_slot);
881
882        // Check that proof type is not UPDATING
883        if proof_type == TN_STATE_PROOF_TYPE_UPDATING {
884            return Err(RpcError::invalid_fee_payer_state_proof_type());
885        }
886
887        offset += state_proof_sz;
888
889        // If proof type is EXISTING, expect account meta
890        if proof_type == TN_STATE_PROOF_TYPE_EXISTING {
891            if offset + TN_ACCOUNT_META_FOOTPRINT > sig_start {
892                return Err(RpcError::invalid_format());
893            }
894            offset += TN_ACCOUNT_META_FOOTPRINT;
895        }
896    }
897
898    // 9. Check for exact size match (offset should be at signature start)
899    if offset != sig_start {
900        return Err(RpcError::trailing_bytes(offset + TN_TXN_SIGNATURE_SZ, bytes.len()));
901    }
902
903    // 10. Signature check
904    let wire_for_signing = &bytes[..sig_start];
905    let signature = &bytes[sig_start..];
906    if verify_transaction(
907        wire_for_signing,
908        signature.try_into().expect("signature should be 64 bytes"),
909        &hdr.fee_payer_pubkey,
910    )
911    .is_err()
912    {
913        return Err(RpcError::invalid_transaction_signature());
914    }
915
916    Ok(())
917}
918
#[cfg(test)]
mod tests {
    use super::*;
    use ed25519_dalek::SigningKey;

    /// Builds a signed wire transaction with two read-write accounts, one read-only account,
    /// four bytes of instruction data (`[1, 2, 3, 4]`), and the given `flags` byte.
    fn make_valid_txn_bytes_with_flags(flags: u8) -> Vec<u8> {
        let signing_key = SigningKey::from(&[1u8; 32]);
        let verifying_key = signing_key.verifying_key();
        let mut tx = Transaction::new(verifying_key.to_bytes(), [2u8; 32], 100, 42);
        tx.rw_accs = Some(vec![[3u8; 32], [4u8; 32]]);
        tx.r_accs = Some(vec![[5u8; 32]]);
        tx.instructions = Some(vec![1, 2, 3, 4]);
        tx.flags = flags;
        tx.sign(&signing_key.to_bytes()).unwrap();
        tx.to_wire()
    }

    /// Same as `make_valid_txn_bytes_with_flags` with every flag bit cleared.
    fn make_valid_txn_bytes() -> Vec<u8> {
        make_valid_txn_bytes_with_flags(0)
    }

    #[test]
    fn test_tn_txn_size_basic_transaction() {
        let bytes = make_valid_txn_bytes();
        let calculated_size = tn_txn_size(&bytes).unwrap();

        // The calculated size should match the actual bytes length
        assert_eq!(calculated_size, bytes.len());
    }

    #[test]
    fn test_tn_txn_size_with_state_proof() {
        use crate::tn_state_proof::StateProof;

        let signing_key = SigningKey::from(&[1u8; 32]);
        let verifying_key = signing_key.verifying_key();

        // Create a CREATION state proof
        let path_bitset = [0u8; 32]; // No set bits = no sibling hashes
        let existing_leaf_pubkey = [7u8; 32];
        let existing_leaf_hash = [8u8; 32];
        let state_proof = StateProof::creation(
            100,
            path_bitset,
            existing_leaf_pubkey,
            existing_leaf_hash,
            vec![],
        );

        // Create transaction with state proof
        let mut tx = Transaction::new(verifying_key.to_bytes(), [2u8; 32], 100, 42)
            .with_rw_accounts(vec![[3u8; 32]])
            .with_instructions(vec![1, 2, 3])
            .with_fee_payer_state_proof(&state_proof);

        tx.sign(&signing_key.to_bytes()).unwrap();
        let bytes = tx.to_wire();

        let calculated_size = tn_txn_size(&bytes).unwrap();

        // The calculated size should match the actual bytes length
        assert_eq!(calculated_size, bytes.len());
    }

    #[test]
    fn test_tn_txn_size_minimal_transaction() {
        let signing_key = SigningKey::from(&[1u8; 32]);
        let verifying_key = signing_key.verifying_key();

        // Create minimal transaction (no accounts, no instructions)
        let mut tx = Transaction::new(verifying_key.to_bytes(), [2u8; 32], 100, 42);
        tx.sign(&signing_key.to_bytes()).unwrap();
        let bytes = tx.to_wire();

        let calculated_size = tn_txn_size(&bytes).unwrap();

        // The calculated size should match the actual bytes length
        assert_eq!(calculated_size, bytes.len());

        // Should be header size + trailing signature for minimal transaction
        let expected_min_size = core::mem::size_of::<WireTxnHdrV1>() + TN_TXN_SIGNATURE_SZ;
        assert_eq!(calculated_size, expected_min_size);
    }

    #[test]
    fn test_tn_txn_size_invalid_format() {
        // Test with bytes too short for header
        let short_bytes = vec![0u8; 50];
        let result = tn_txn_size(&short_bytes);
        assert!(matches!(result, Err(RpcError::InvalidFormat)));

        // Test with header but missing account data
        let mut bytes = make_valid_txn_bytes();
        bytes.truncate(core::mem::size_of::<WireTxnHdrV1>() + 10); // Truncate to cause missing data
        let result = tn_txn_size(&bytes);
        assert!(matches!(result, Err(RpcError::InvalidFormat)));
    }

    #[test]
    fn test_tn_txn_size_consistency_with_validation() {
        let bytes = make_valid_txn_bytes();

        // Both functions should succeed for valid transactions
        assert!(validate_wire_transaction(&bytes).is_ok());
        assert!(tn_txn_size(&bytes).is_ok());

        // Size should match actual length
        let calculated_size = tn_txn_size(&bytes).unwrap();
        assert_eq!(calculated_size, bytes.len());
    }

    #[test]
    fn test_valid_transaction() {
        let bytes = make_valid_txn_bytes();
        assert!(validate_wire_transaction(&bytes).is_ok());
    }

    #[test]
    fn test_oversize_transaction() {
        let mut bytes = make_valid_txn_bytes();
        bytes.resize(32_769, 0);
        let err = validate_wire_transaction(&bytes).unwrap_err();
        assert!(matches!(
            err,
            RpcError::InvalidTransactionSize {
                size: 32_769,
                max_size: 32_768
            }
        ));
    }

    #[test]
    fn test_trailing_bytes() {
        let mut bytes = make_valid_txn_bytes();
        let orig_len = bytes.len();
        // Insert one extra byte between the body and the trailing signature; the parsed body
        // then ends one byte before the signature start.
        bytes.insert(orig_len - TN_TXN_SIGNATURE_SZ, 0);
        let err = validate_wire_transaction(&bytes).unwrap_err();
        assert!(matches!(err, RpcError::TrailingBytes { .. }));
    }

    #[test]
    fn test_invalid_transaction_version() {
        let mut bytes = make_valid_txn_bytes();
        // Corrupt the transaction version
        bytes[0] = 0x00; // Invalid version
        let err = validate_wire_transaction(&bytes).unwrap_err();
        assert!(matches!(err, RpcError::InvalidVersion));
    }

    #[test]
    fn test_invalid_flags() {
        // 0x07 sets the two defined bits (fee payer proof, may-compress) plus undefined bit 2.
        let bytes = make_valid_txn_bytes_with_flags(0x07);
        let err = validate_wire_transaction(&bytes).unwrap_err();
        assert!(matches!(err, RpcError::InvalidFlags));
    }

    #[test]
    fn test_invalid_chain_id_zero() {
        let mut bytes = make_valid_txn_bytes();
        let hdr_sz = core::mem::size_of::<WireTxnHdrV1>();

        // Copy the header out, zero chain_id, and write it back. This avoids borrowing the
        // header in place, which would depend on the Vec's allocation alignment.
        let mut hdr: WireTxnHdrV1 = bytemuck::pod_read_unaligned(&bytes[..hdr_sz]);
        hdr.chain_id = 0;
        bytes[..hdr_sz].copy_from_slice(bytemuck::bytes_of(&hdr));

        let err = validate_wire_transaction(&bytes).unwrap_err();
        assert!(matches!(err, RpcError::InvalidChainId));
    }

    #[test]
    fn test_transaction_too_short() {
        let bytes = vec![0u8; 50]; // Too short for header
        let err = validate_wire_transaction(&bytes).unwrap_err();
        assert!(matches!(err, RpcError::InvalidFormat));
    }

    #[test]
    fn test_tampered_signature() {
        let mut bytes = make_valid_txn_bytes();
        // Flip one bit in the first signature byte. The body is untouched, so only the
        // signature check can reject it.
        let sig_start = bytes.len() - TN_TXN_SIGNATURE_SZ;
        bytes[sig_start] ^= 0x01;
        let err = validate_wire_transaction(&bytes).unwrap_err();
        assert!(matches!(err, RpcError::InvalidTransactionSignature));
    }

    #[test]
    fn test_tampered_body() {
        let mut bytes = make_valid_txn_bytes();
        // With no flags set, the byte just before the signature is the last instruction byte.
        // Changing it keeps the layout valid but invalidates the signature.
        let idx = bytes.len() - TN_TXN_SIGNATURE_SZ - 1;
        bytes[idx] ^= 0xFF;
        let err = validate_wire_transaction(&bytes).unwrap_err();
        assert!(matches!(err, RpcError::InvalidTransactionSignature));
    }

    #[test]
    fn test_transaction_with_state_proof() {
        use crate::tn_state_proof::{StateProof, StateProofType};

        let signing_key = SigningKey::from(&[1u8; 32]);
        let verifying_key = signing_key.verifying_key();

        // Create a CREATION state proof (doesn't require account meta)
        let path_bitset = [0u8; 32]; // No set bits = no sibling hashes
        let existing_leaf_pubkey = [7u8; 32];
        let existing_leaf_hash = [8u8; 32];
        let state_proof = StateProof::creation(
            100,
            path_bitset,
            existing_leaf_pubkey,
            existing_leaf_hash,
            vec![],
        );

        // Create transaction with state proof
        let mut tx = Transaction::new(verifying_key.to_bytes(), [2u8; 32], 100, 42)
            .with_fee_payer_state_proof(&state_proof);

        // Verify flag is set (checked independently of the bit's position)
        assert!(tx.has_fee_payer_state_proof());
        assert_ne!(tx.flags & (1 << TN_TXN_FLAG_HAS_FEE_PAYER_PROOF_BIT), 0);

        tx.sign(&signing_key.to_bytes()).unwrap();
        let bytes = tx.to_wire();

        // Verify state proof is included in wire format
        assert!(bytes.len() > core::mem::size_of::<WireTxnHdrV1>() + TN_TXN_SIGNATURE_SZ);
        assert!(validate_wire_transaction(&bytes).is_ok());

        // Test deserialization
        let decoded_tx = Transaction::from_wire(&bytes).unwrap();
        assert!(decoded_tx.has_fee_payer_state_proof());
        assert!(decoded_tx.fee_payer_state_proof.is_some());

        let decoded_proof = decoded_tx.fee_payer_state_proof.unwrap();
        assert_eq!(decoded_proof.proof_type(), StateProofType::Creation);
        assert_eq!(decoded_proof.slot(), 100);
    }

    #[test]
    fn test_transaction_with_state_proof_serialization_round_trip() {
        use crate::tn_state_proof::StateProof;

        let signing_key = SigningKey::from(&[1u8; 32]);
        let verifying_key = signing_key.verifying_key();

        // Create a creation state proof with some sibling hashes
        let mut path_bitset = [0u8; 32];
        path_bitset[0] = 0b11; // Set first 2 bits for 2 sibling hashes
        let existing_leaf_pubkey = [7u8; 32];
        let existing_leaf_hash = [8u8; 32];
        let sibling_hashes = vec![[9u8; 32], [10u8; 32]];

        let state_proof = StateProof::creation(
            200,
            path_bitset,
            existing_leaf_pubkey,
            existing_leaf_hash,
            sibling_hashes.clone(),
        );

        // Create transaction with complex state proof
        let mut tx = Transaction::new(verifying_key.to_bytes(), [2u8; 32], 100, 42)
            .with_rw_accounts(vec![[3u8; 32], [4u8; 32]])
            .with_r_accounts(vec![[5u8; 32]])
            .with_instructions(vec![1, 2, 3, 4])
            .with_fee_payer_state_proof(&state_proof);

        tx.sign(&signing_key.to_bytes()).unwrap();
        let bytes = tx.to_wire();

        // Test validation
        assert!(validate_wire_transaction(&bytes).is_ok());

        // Test round-trip serialization
        let decoded_tx = Transaction::from_wire(&bytes).unwrap();
        assert_eq!(decoded_tx.fee_payer, tx.fee_payer);
        assert_eq!(decoded_tx.program, tx.program);
        assert_eq!(decoded_tx.rw_accs, tx.rw_accs);
        assert_eq!(decoded_tx.r_accs, tx.r_accs);
        assert_eq!(decoded_tx.instructions, tx.instructions);
        assert_eq!(decoded_tx.flags, tx.flags);
        assert!(decoded_tx.has_fee_payer_state_proof());

        let decoded_proof = decoded_tx.fee_payer_state_proof.unwrap();
        assert_eq!(decoded_proof.slot(), 200);
        assert_eq!(decoded_proof.path_bitset(), &path_bitset);
    }

    #[test]
    fn test_transaction_without_state_proof() {
        let signing_key = SigningKey::from(&[1u8; 32]);
        let verifying_key = signing_key.verifying_key();

        let mut tx = Transaction::new(verifying_key.to_bytes(), [2u8; 32], 100, 42);

        // Verify flag is not set
        assert!(!tx.has_fee_payer_state_proof());
        assert_eq!(tx.flags & (1 << TN_TXN_FLAG_HAS_FEE_PAYER_PROOF_BIT), 0);
        assert!(tx.fee_payer_state_proof.is_none());

        tx.sign(&signing_key.to_bytes()).unwrap();
        let bytes = tx.to_wire();

        assert!(validate_wire_transaction(&bytes).is_ok());

        // Test deserialization
        let decoded_tx = Transaction::from_wire(&bytes).unwrap();
        assert!(!decoded_tx.has_fee_payer_state_proof());
        assert!(decoded_tx.fee_payer_state_proof.is_none());
    }

    #[test]
    fn test_transaction_remove_state_proof() {
        use crate::tn_state_proof::StateProof;

        let signing_key = SigningKey::from(&[1u8; 32]);
        let verifying_key = signing_key.verifying_key();

        // Create a CREATION state proof
        let path_bitset = [0u8; 32];
        let existing_leaf_pubkey = [7u8; 32];
        let existing_leaf_hash = [8u8; 32];
        let state_proof = StateProof::creation(
            100,
            path_bitset,
            existing_leaf_pubkey,
            existing_leaf_hash,
            vec![],
        );

        // Create transaction with state proof, then remove it
        let tx = Transaction::new(verifying_key.to_bytes(), [2u8; 32], 100, 42)
            .with_fee_payer_state_proof(&state_proof)
            .without_fee_payer_state_proof();

        // Verify flag is cleared and state proof is removed
        assert!(!tx.has_fee_payer_state_proof());
        assert_eq!(tx.flags & (1 << TN_TXN_FLAG_HAS_FEE_PAYER_PROOF_BIT), 0);
        assert!(tx.fee_payer_state_proof.is_none());
    }
}