Skip to main content

lightcone_sdk/program/
utils.rs

1//! Utility functions for the Lightcone Pinocchio SDK.
2//!
3//! This module provides helper functions for ATA derivation, validation, and string serialization.
4
5use solana_sdk::pubkey::Pubkey;
6
7use crate::program::constants::{MAX_OUTCOMES, MIN_OUTCOMES};
8use crate::program::error::{SdkError, SdkResult};
9
10// ============================================================================
11// Associated Token Account Helpers
12// ============================================================================
13
14/// Get the Associated Token Address for a wallet and mint.
15///
16/// Uses the standard Solana ATA derivation.
17pub fn get_associated_token_address(
18    wallet: &Pubkey,
19    mint: &Pubkey,
20    token_program_id: &Pubkey,
21) -> Pubkey {
22    let ata_program_id = spl_associated_token_account::id();
23
24    Pubkey::find_program_address(
25        &[
26            wallet.as_ref(),
27            token_program_id.as_ref(),
28            mint.as_ref(),
29        ],
30        &ata_program_id,
31    )
32    .0
33}
34
35/// Get the ATA for a conditional token (using Token-2022).
36pub fn get_conditional_token_ata(wallet: &Pubkey, mint: &Pubkey) -> Pubkey {
37    get_associated_token_address(wallet, mint, &spl_token_2022::id())
38}
39
40/// Get the ATA for a deposit token (using SPL Token).
41pub fn get_deposit_token_ata(wallet: &Pubkey, mint: &Pubkey) -> Pubkey {
42    get_associated_token_address(wallet, mint, &spl_token::id())
43}
44
45// ============================================================================
46// Validation Helpers
47// ============================================================================
48
49/// Validate that the number of outcomes is within the allowed range.
50pub fn validate_outcome_count(num_outcomes: u8) -> SdkResult<()> {
51    if !(MIN_OUTCOMES..=MAX_OUTCOMES).contains(&num_outcomes) {
52        return Err(SdkError::InvalidOutcomeCount { count: num_outcomes });
53    }
54    Ok(())
55}
56
57/// Validate that an outcome index is valid for the given number of outcomes.
58pub fn validate_outcome_index(outcome_index: u8, num_outcomes: u8) -> SdkResult<()> {
59    if outcome_index >= num_outcomes {
60        return Err(SdkError::InvalidOutcomeIndex {
61            index: outcome_index,
62            max: num_outcomes.saturating_sub(1),
63        });
64    }
65    Ok(())
66}
67
68/// Validate that a buffer is exactly 32 bytes.
69pub fn validate_32_bytes(buffer: &[u8]) -> SdkResult<()> {
70    if buffer.len() != 32 {
71        return Err(SdkError::InvalidDataLength {
72            expected: 32,
73            actual: buffer.len(),
74        });
75    }
76    Ok(())
77}
78
79// ============================================================================
80// String Serialization
81// ============================================================================
82
/// Serializes `s` as a little-endian `u16` length prefix followed by its
/// UTF-8 bytes.
///
/// Format: `[length (2 bytes LE)][utf-8 bytes]`. The length counts bytes, not
/// characters, so a multi-byte code point contributes several bytes.
///
/// # Panics
///
/// Panics if `s` is longer than `u16::MAX` (65 535) bytes. Such a string cannot
/// be described by the 2-byte prefix. Truncating the length (as a plain `as`
/// cast would) would emit a prefix that disagrees with the payload and silently
/// desynchronize any reader of the encoded data.
pub fn serialize_string(s: &str) -> Vec<u8> {
    let bytes = s.as_bytes();
    // Checked conversion: an `as u16` cast would wrap lengths above 65 535 and
    // produce a corrupt, misaligned encoding instead of failing loudly.
    let len = u16::try_from(bytes.len()).unwrap_or_else(|_| {
        panic!(
            "string of {} bytes exceeds the u16 length-prefix limit of {} bytes",
            bytes.len(),
            u16::MAX
        )
    });

    let mut result = Vec::with_capacity(2 + bytes.len());
    result.extend_from_slice(&len.to_le_bytes());
    result.extend_from_slice(bytes);
    result
}
94
95/// Deserialize a string with a u16 length prefix.
96///
97/// Returns the string and the number of bytes consumed.
98pub fn deserialize_string(data: &[u8]) -> SdkResult<(String, usize)> {
99    if data.len() < 2 {
100        return Err(SdkError::InvalidDataLength {
101            expected: 2,
102            actual: data.len(),
103        });
104    }
105
106    let len = u16::from_le_bytes([data[0], data[1]]) as usize;
107
108    if data.len() < 2 + len {
109        return Err(SdkError::InvalidDataLength {
110            expected: 2 + len,
111            actual: data.len(),
112        });
113    }
114
115    let s = String::from_utf8(data[2..2 + len].to_vec())
116        .map_err(|e| SdkError::Serialization(e.to_string()))?;
117
118    Ok((s, 2 + len))
119}
120
121// ============================================================================
122// Metadata Serialization
123// ============================================================================
124
/// Outcome metadata for conditional token creation.
///
/// The documented size limits below are not checked by
/// `serialize_outcome_metadata`. Callers must respect them before building an
/// instruction.
#[derive(Debug, Clone, PartialEq, Eq, Default)]
pub struct OutcomeMetadataInput {
    /// Token name (documented maximum: 32 chars).
    pub name: String,
    /// Token symbol (documented maximum: 10 chars).
    pub symbol: String,
    /// Token URI (documented maximum: 200 chars).
    pub uri: String,
}
135
136/// Serialize outcome metadata for the add_deposit_mint instruction.
137pub fn serialize_outcome_metadata(metadata: &[OutcomeMetadataInput]) -> Vec<u8> {
138    let mut result = Vec::new();
139
140    for m in metadata {
141        result.extend(serialize_string(&m.name));
142        result.extend(serialize_string(&m.symbol));
143        result.extend(serialize_string(&m.uri));
144    }
145
146    result
147}
148
149// ============================================================================
150// Checked Arithmetic
151// ============================================================================
152
153/// Multiply two u64 values and check for overflow.
154pub fn checked_mul_u64(a: u64, b: u64) -> SdkResult<u64> {
155    a.checked_mul(b).ok_or(SdkError::Overflow)
156}
157
158/// Divide two u64 values and check for division by zero.
159pub fn checked_div_u64(a: u64, b: u64) -> SdkResult<u64> {
160    if b == 0 {
161        return Err(SdkError::Overflow);
162    }
163    Ok(a / b)
164}
165
166/// Add two u64 values and check for overflow.
167pub fn checked_add_u64(a: u64, b: u64) -> SdkResult<u64> {
168    a.checked_add(b).ok_or(SdkError::Overflow)
169}
170
171/// Subtract two u64 values and check for underflow.
172pub fn checked_sub_u64(a: u64, b: u64) -> SdkResult<u64> {
173    a.checked_sub(b).ok_or(SdkError::Overflow)
174}
175
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_validate_outcome_count() {
        assert!(validate_outcome_count(2).is_ok());
        assert!(validate_outcome_count(3).is_ok());
        assert!(validate_outcome_count(6).is_ok());
        assert!(validate_outcome_count(1).is_err());
        assert!(validate_outcome_count(7).is_err());
        assert!(validate_outcome_count(0).is_err());
    }

    #[test]
    fn test_validate_outcome_index() {
        assert!(validate_outcome_index(0, 3).is_ok());
        assert!(validate_outcome_index(1, 3).is_ok());
        assert!(validate_outcome_index(2, 3).is_ok());
        assert!(validate_outcome_index(3, 3).is_err());
        assert!(validate_outcome_index(4, 3).is_err());
    }

    #[test]
    fn test_validate_outcome_index_error_fields() {
        assert!(matches!(
            validate_outcome_index(3, 3),
            Err(SdkError::InvalidOutcomeIndex { index: 3, max: 2 })
        ));
        // With zero outcomes no index is valid; `max` saturates to 0.
        assert!(matches!(
            validate_outcome_index(0, 0),
            Err(SdkError::InvalidOutcomeIndex { index: 0, max: 0 })
        ));
    }

    #[test]
    fn test_validate_32_bytes() {
        assert!(validate_32_bytes(&[0u8; 32]).is_ok());
        assert!(matches!(
            validate_32_bytes(&[0u8; 31]),
            Err(SdkError::InvalidDataLength {
                expected: 32,
                actual: 31
            })
        ));
        assert!(validate_32_bytes(&[0u8; 33]).is_err());
        let empty: [u8; 0] = [];
        assert!(validate_32_bytes(&empty).is_err());
    }

    #[test]
    fn test_string_serialization_roundtrip() {
        let original = "Hello, World!";
        let serialized = serialize_string(original);
        let (deserialized, consumed) = deserialize_string(&serialized).unwrap();

        assert_eq!(original, deserialized);
        assert_eq!(consumed, serialized.len());
    }

    #[test]
    fn test_string_serialization_empty() {
        let original = "";
        let serialized = serialize_string(original);
        let (deserialized, consumed) = deserialize_string(&serialized).unwrap();

        assert_eq!(original, deserialized);
        assert_eq!(consumed, 2); // Just the length prefix
    }

    #[test]
    fn test_string_serialization_unicode() {
        let original = "Hello, 世界! 🌍";
        let serialized = serialize_string(original);
        let (deserialized, consumed) = deserialize_string(&serialized).unwrap();

        assert_eq!(original, deserialized);
        assert_eq!(consumed, serialized.len());
    }

    #[test]
    fn test_string_serialization_exact_bytes() {
        assert_eq!(serialize_string("abc"), vec![3, 0, b'a', b'b', b'c']);
        // The prefix counts UTF-8 bytes, not characters.
        assert_eq!(serialize_string("é"), vec![2, 0, 0xC3, 0xA9]);
    }

    #[test]
    fn test_string_serialization_max_length() {
        let original = "a".repeat(u16::MAX as usize);
        let serialized = serialize_string(&original);
        assert_eq!(&serialized[..2], &[0xFF, 0xFF]);

        let (deserialized, consumed) = deserialize_string(&serialized).unwrap();
        assert_eq!(deserialized, original);
        assert_eq!(consumed, serialized.len());
    }

    #[test]
    fn test_deserialize_string_ignores_trailing_bytes() {
        let mut data = serialize_string("abc");
        data.extend_from_slice(&[0xAA, 0xBB]);

        let (s, consumed) = deserialize_string(&data).unwrap();
        assert_eq!(s, "abc");
        assert_eq!(consumed, 5);
    }

    #[test]
    fn test_deserialize_string_errors() {
        let empty: [u8; 0] = [];
        assert!(matches!(
            deserialize_string(&empty),
            Err(SdkError::InvalidDataLength {
                expected: 2,
                actual: 0
            })
        ));
        assert!(matches!(
            deserialize_string(&[5]),
            Err(SdkError::InvalidDataLength {
                expected: 2,
                actual: 1
            })
        ));
        // Prefix announces 3 payload bytes but only 2 follow.
        assert!(matches!(
            deserialize_string(&[3, 0, b'a', b'b']),
            Err(SdkError::InvalidDataLength {
                expected: 5,
                actual: 4
            })
        ));
        // 0xFF is never valid UTF-8.
        assert!(matches!(
            deserialize_string(&[1, 0, 0xFF]),
            Err(SdkError::Serialization(_))
        ));
    }

    #[test]
    fn test_outcome_metadata_serialization() {
        let metadata = vec![
            OutcomeMetadataInput {
                name: "Yes".to_string(),
                symbol: "YES".to_string(),
                uri: "https://example.com/yes".to_string(),
            },
            OutcomeMetadataInput {
                name: "No".to_string(),
                symbol: "NO".to_string(),
                uri: "https://example.com/no".to_string(),
            },
        ];

        let serialized = serialize_outcome_metadata(&metadata);

        // Verify it's not empty and has reasonable length
        assert!(!serialized.is_empty());

        // First string should be "Yes" (len=3)
        assert_eq!(u16::from_le_bytes([serialized[0], serialized[1]]), 3);

        // Every entry decodes back as name, symbol, uri in order, with no
        // leftover bytes.
        let mut offset = 0;
        for m in &metadata {
            for expected in [&m.name, &m.symbol, &m.uri] {
                let (s, consumed) = deserialize_string(&serialized[offset..]).unwrap();
                assert_eq!(&s, expected);
                offset += consumed;
            }
        }
        assert_eq!(offset, serialized.len());
    }

    #[test]
    fn test_outcome_metadata_serialization_empty() {
        assert!(serialize_outcome_metadata(&[]).is_empty());
    }

    #[test]
    fn test_checked_arithmetic() {
        assert_eq!(checked_mul_u64(100, 200).unwrap(), 20000);
        assert_eq!(checked_div_u64(200, 100).unwrap(), 2);
        assert_eq!(checked_add_u64(100, 200).unwrap(), 300);
        assert_eq!(checked_sub_u64(200, 100).unwrap(), 100);

        // Boundaries that must still succeed
        assert_eq!(checked_mul_u64(u64::MAX, 1).unwrap(), u64::MAX);
        assert_eq!(checked_add_u64(u64::MAX, 0).unwrap(), u64::MAX);
        assert_eq!(checked_sub_u64(5, 5).unwrap(), 0);
        assert_eq!(checked_div_u64(7, 2).unwrap(), 3); // truncating division
        assert_eq!(checked_div_u64(0, 5).unwrap(), 0);

        // Overflow cases
        assert!(checked_mul_u64(u64::MAX, 2).is_err());
        assert!(checked_div_u64(100, 0).is_err());
        assert!(checked_add_u64(u64::MAX, 1).is_err());
        assert!(checked_sub_u64(0, 1).is_err());
    }

    #[test]
    fn test_ata_derivation() {
        let wallet = Pubkey::new_unique();
        let mint = Pubkey::new_unique();

        // Should not panic and should return a valid pubkey
        let ata = get_conditional_token_ata(&wallet, &mint);
        assert_ne!(ata, Pubkey::default());

        let ata2 = get_deposit_token_ata(&wallet, &mint);
        assert_ne!(ata2, Pubkey::default());

        // Different token programs should produce different ATAs
        assert_ne!(ata, ata2);

        // Derivation is deterministic
        assert_eq!(get_conditional_token_ata(&wallet, &mint), ata);
        assert_eq!(get_deposit_token_ata(&wallet, &mint), ata2);
    }
}