parasol_cpu/proc/
program.rs

1use std::{
2    collections::HashMap,
3    ffi::{CStr, CString},
4    ops::Index,
5};
6
7use crate::{Error, Result, proc::IsaOp, tomasulo::registers::RegisterName};
8use thiserror::Error;
9
10use elf::{ElfBytes, abi::STT_FUNC, endian::LittleEndian};
11
/// The only ELF ABI version (the `abiversion` field of the ELF header) that
/// [`FheApplication::parse_elf`] accepts.
///
/// ABI version changes:
///
/// - 1:
///   - Organized instructions into groupings.
///   - Added `rotl`, `rotr`, `neg`, `xor`, `addc`, `subb`.
///   - Note that `addc` and `subb` are not currently implemented in the backend, but
///     they do have defined opcodes.
pub(crate) const SUPPORTED_ABI_VERSION: u8 = 1;
19
/// The operation of one encoded instruction, decoded from the low byte of its
/// 64-bit encoding by `FheApplication::get_opcode`. Each variant notes the
/// opcode byte it is decoded from.
enum OpCode {
    // Types and loading
    /// `0x00`
    BindReadOnly,
    /// `0x01`
    BindReadWrite,
    /// `0x02`
    Load,
    /// `0x03`
    LoadI,
    /// `0x04`
    Store,
    /// `0x05`
    Zext,
    /// `0x06`
    Trunc,

    // Arithmetic
    /// `0x10`
    Add,
    /// `0x11`; has an opcode but is not implemented, so decoding it panics.
    AddC,
    /// `0x12`
    Sub,
    /// `0x13`; has an opcode but is not implemented, so decoding it panics.
    SubB,
    /// `0x14`
    Mul,

    // Shifts
    /// `0x20`
    Shl,
    /// `0x21`
    Rotl,
    /// `0x22`
    Shr,
    /// `0x23`
    Rotr,

    // Logic
    /// `0x30`
    And,
    /// `0x31`
    Or,
    /// `0x32`
    Xor,
    /// `0x33`
    Not,
    /// `0x34`
    Neg,

    // Comparison
    /// `0x40`
    Gt,
    /// `0x41`
    Ge,
    /// `0x42`
    Lt,
    /// `0x43`
    Le,
    /// `0x44`
    Eq,
    /// `0x45`
    Cmux,

    // Control flow
    /// `0xFE`
    Ret,
    /// Any opcode byte not listed above; decoding it panics.
    Unknown,
}
62
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
/// The name of an FHE program as it appears in an ELF file.
pub struct Symbol {
    /// ELF symbols aren't required to be UTF-8, so we must use a
    /// CString to represent them.
    name: CString,
}

impl Symbol {
    /// Create a [`Symbol`] from a [`CStr`], copying the name into an owned buffer.
    pub fn new(name: &CStr) -> Self {
        Self {
            name: name.to_owned(),
        }
    }
}

impl From<&CStr> for Symbol {
    fn from(name: &CStr) -> Self {
        Self::new(name)
    }
}

impl From<&str> for Symbol {
    /// Create a [`Symbol`] from a UTF-8 string.
    ///
    /// # Panics
    ///
    /// Panics if `name` contains an interior NUL byte, which a C string cannot represent.
    fn from(name: &str) -> Self {
        // Build the owned C string directly instead of creating one and then cloning it.
        let name = CString::new(name).expect("symbol name must not contain an interior NUL byte");
        Self { name }
    }
}
91
92/// A collection of [`FheProgram`]s parsed out of an ELF file.
93pub struct FheApplication {
94    programs: HashMap<Symbol, FheProgram>,
95}
96
97impl FheApplication {
98    /// Retreive an FHE program by its Symbol
99    pub fn get_program(&self, name: &Symbol) -> Option<&FheProgram> {
100        self.programs.get(name)
101    }
102
103    /// Attempt to parse the given bytes as an ELF file and return the resulting [`FheApplication`].
104    pub fn parse_elf(binary: &[u8]) -> Result<Self> {
105        let elf = ElfBytes::<LittleEndian>::minimal_parse(binary)?;
106
107        let abi_version = elf.ehdr.abiversion;
108
109        if abi_version != SUPPORTED_ABI_VERSION {
110            return Err(Error::ElfUnsupportedAbiVersion(abi_version));
111        }
112
113        let get_name = |name: u32| -> Result<&CStr> {
114            let shstrn = elf
115                .section_headers()
116                .ok_or(Error::ElfNoSectionHeaders)?
117                .get(elf.ehdr.e_shstrndx as usize)?;
118
119            let str_offset = shstrn.sh_offset as usize + name as usize;
120            Ok(CStr::from_bytes_until_nul(&binary[str_offset..])?)
121        };
122
123        let (sym, _) = elf.symbol_table()?.ok_or(Error::ElfNoSymbolTable)?;
124
125        let mut programs = HashMap::new();
126
127        for s in sym {
128            if s.st_symtype() == STT_FUNC {
129                let header = elf
130                    .section_headers()
131                    .ok_or(Error::ElfNoSectionHeaders)?
132                    .get(s.st_shndx as usize)?;
133                let (data, _) = elf.section_data(&header)?;
134                let data = &data[s.st_value as usize..s.st_value as usize + s.st_size as usize];
135
136                let mut code = vec![];
137
138                for inst in data.chunks(8) {
139                    // Infallible
140                    let inst: [u8; 8] = inst.try_into().unwrap();
141                    let inst = u64::from_le_bytes(inst);
142                    let inst = Self::parse_instruction(inst);
143
144                    code.push(inst);
145                }
146
147                let symbol_name = get_name(s.st_name)?;
148
149                programs.insert(
150                    Symbol::new(symbol_name),
151                    FheProgram::from_instructions(code),
152                );
153            }
154        }
155
156        Ok(Self { programs })
157    }
158
159    fn parse_instruction(encoded: u64) -> IsaOp {
160        match Self::get_opcode(encoded) {
161            // Types and loading
162            OpCode::BindReadOnly => {
163                let dst = RegisterName::named(Self::get_dst(encoded));
164                let buffer_id = Self::get_bind_buffer_id(encoded);
165                let is_encrypted = Self::get_bind_is_encrypted(encoded);
166
167                IsaOp::BindReadOnly(dst, buffer_id, is_encrypted)
168            }
169            OpCode::BindReadWrite => {
170                let dst = RegisterName::named(Self::get_dst(encoded));
171                let buffer_id = Self::get_bind_buffer_id(encoded);
172                let is_encrypted = Self::get_bind_is_encrypted(encoded);
173
174                IsaOp::BindReadWrite(dst, buffer_id, is_encrypted)
175            }
176            OpCode::Load => {
177                let register = RegisterName::named(Self::get_dst(encoded));
178                let memory_pointer = RegisterName::named(Self::get_src1(encoded));
179                let width = Self::get_casting_width(encoded);
180
181                IsaOp::Load(register, memory_pointer, width)
182            }
183            OpCode::LoadI => {
184                let dst = RegisterName::named(Self::get_dst(encoded));
185                let imm = Self::get_immediate(encoded);
186                let width = Self::get_immediate_width(encoded);
187
188                // All immediates are currently 32 bits, so we need to mask off the
189                // unused bits. Negative numbers are wrapped.
190                let mask = (1u128 << width) - 1;
191                let imm = imm & mask;
192
193                IsaOp::LoadI(dst, imm, width)
194            }
195            OpCode::Store => {
196                let dst = RegisterName::named(Self::get_dst(encoded));
197                let src = RegisterName::named(Self::get_src1(encoded));
198                let width = Self::get_casting_width(encoded);
199
200                IsaOp::Store(dst, src, width)
201            }
202            OpCode::Zext => {
203                let dst = RegisterName::named(Self::get_dst(encoded));
204                let src = RegisterName::named(Self::get_src1(encoded));
205                let width = Self::get_casting_width(encoded);
206
207                IsaOp::Zext(dst, src, width)
208            }
209            OpCode::Trunc => {
210                let dst = RegisterName::named(Self::get_dst(encoded));
211                let src = RegisterName::named(Self::get_src1(encoded));
212                let width = Self::get_casting_width(encoded);
213
214                IsaOp::Trunc(dst, src, width)
215            }
216
217            // Arithmetic
218            OpCode::Add => {
219                let dst = RegisterName::named(Self::get_dst(encoded));
220                let src1 = RegisterName::named(Self::get_src1(encoded));
221                let src2 = RegisterName::named(Self::get_src2(encoded));
222
223                IsaOp::Add(dst, src1, src2)
224            }
225            OpCode::AddC => {
226                unimplemented!(
227                    "Not implemented in the Parasol compiler. This operation should never be generated."
228                );
229            }
230            OpCode::Sub => {
231                let dst = RegisterName::named(Self::get_dst(encoded));
232                let src1 = RegisterName::named(Self::get_src1(encoded));
233                let src2 = RegisterName::named(Self::get_src2(encoded));
234
235                IsaOp::Sub(dst, src1, src2)
236            }
237            OpCode::SubB => {
238                unimplemented!(
239                    "Not implemented in the Parasol compiler. This operation should never be generated."
240                );
241            }
242            OpCode::Mul => {
243                let dst = RegisterName::named(Self::get_dst(encoded));
244                let src1 = RegisterName::named(Self::get_src1(encoded));
245                let src2 = RegisterName::named(Self::get_src2(encoded));
246
247                IsaOp::Mul(dst, src1, src2)
248            }
249
250            // Shifts
251            OpCode::Shl => {
252                let dst = RegisterName::named(Self::get_dst(encoded));
253                let src1 = RegisterName::named(Self::get_src1(encoded));
254                let src2 = RegisterName::named(Self::get_src2(encoded));
255
256                IsaOp::Shl(dst, src1, src2)
257            }
258            OpCode::Rotl => {
259                let dst = RegisterName::named(Self::get_dst(encoded));
260                let src1 = RegisterName::named(Self::get_src1(encoded));
261                let src2 = RegisterName::named(Self::get_src2(encoded));
262
263                IsaOp::Rotl(dst, src1, src2)
264            }
265            OpCode::Shr => {
266                let dst = RegisterName::named(Self::get_dst(encoded));
267                let src1 = RegisterName::named(Self::get_src1(encoded));
268                let src2 = RegisterName::named(Self::get_src2(encoded));
269
270                IsaOp::Shr(dst, src1, src2)
271            }
272            OpCode::Rotr => {
273                let dst = RegisterName::named(Self::get_dst(encoded));
274                let src1 = RegisterName::named(Self::get_src1(encoded));
275                let src2 = RegisterName::named(Self::get_src2(encoded));
276
277                IsaOp::Rotr(dst, src1, src2)
278            }
279
280            // Logic
281            OpCode::And => {
282                let dst = RegisterName::named(Self::get_dst(encoded));
283                let src1 = RegisterName::named(Self::get_src1(encoded));
284                let src2 = RegisterName::named(Self::get_src2(encoded));
285
286                IsaOp::And(dst, src1, src2)
287            }
288            OpCode::Or => {
289                let dst = RegisterName::named(Self::get_dst(encoded));
290                let src1 = RegisterName::named(Self::get_src1(encoded));
291                let src2 = RegisterName::named(Self::get_src2(encoded));
292
293                IsaOp::Or(dst, src1, src2)
294            }
295            OpCode::Xor => {
296                let dst = RegisterName::named(Self::get_dst(encoded));
297                let src1 = RegisterName::named(Self::get_src1(encoded));
298                let src2 = RegisterName::named(Self::get_src2(encoded));
299
300                IsaOp::Xor(dst, src1, src2)
301            }
302            OpCode::Not => {
303                let dst = RegisterName::named(Self::get_dst(encoded));
304                let src = RegisterName::named(Self::get_src1(encoded));
305
306                IsaOp::Not(dst, src)
307            }
308            OpCode::Neg => {
309                let dst = RegisterName::named(Self::get_dst(encoded));
310                let src = RegisterName::named(Self::get_src1(encoded));
311
312                IsaOp::Neg(dst, src)
313            }
314
315            // Comparison
316            OpCode::Gt => {
317                let dst = RegisterName::named(Self::get_dst(encoded));
318                let src1 = RegisterName::named(Self::get_src1(encoded));
319                let src2 = RegisterName::named(Self::get_src2(encoded));
320
321                IsaOp::CmpGt(dst, src1, src2)
322            }
323            OpCode::Ge => {
324                let dst = RegisterName::named(Self::get_dst(encoded));
325                let src1 = RegisterName::named(Self::get_src1(encoded));
326                let src2 = RegisterName::named(Self::get_src2(encoded));
327
328                IsaOp::CmpGe(dst, src1, src2)
329            }
330            OpCode::Lt => {
331                let dst = RegisterName::named(Self::get_dst(encoded));
332                let src1 = RegisterName::named(Self::get_src1(encoded));
333                let src2 = RegisterName::named(Self::get_src2(encoded));
334
335                IsaOp::CmpLt(dst, src1, src2)
336            }
337            OpCode::Le => {
338                let dst = RegisterName::named(Self::get_dst(encoded));
339                let src1 = RegisterName::named(Self::get_src1(encoded));
340                let src2 = RegisterName::named(Self::get_src2(encoded));
341
342                IsaOp::CmpLe(dst, src1, src2)
343            }
344            OpCode::Eq => {
345                let dst = RegisterName::named(Self::get_dst(encoded));
346                let src1 = RegisterName::named(Self::get_src1(encoded));
347                let src2 = RegisterName::named(Self::get_src2(encoded));
348
349                IsaOp::CmpEq(dst, src1, src2)
350            }
351            OpCode::Cmux => {
352                let dst = RegisterName::named(Self::get_dst(encoded));
353                let select = RegisterName::named(Self::get_src1(encoded));
354                let a = RegisterName::named(Self::get_src2(encoded));
355                let b = RegisterName::named(Self::get_src3(encoded));
356
357                IsaOp::Cmux(dst, select, a, b)
358            }
359
360            // Control flow
361            OpCode::Ret => IsaOp::Ret(),
362            _ => {
363                unimplemented!("Unknown opcode {:x}", encoded & 0xFF);
364            }
365        }
366    }
367
368    fn get_opcode(encoded: u64) -> OpCode {
369        match encoded & 0xFF {
370            // Types and loading
371            0x00 => OpCode::BindReadOnly,
372            0x01 => OpCode::BindReadWrite,
373            0x02 => OpCode::Load,
374            0x03 => OpCode::LoadI,
375            0x04 => OpCode::Store,
376            0x05 => OpCode::Zext,
377            0x06 => OpCode::Trunc,
378
379            // Arithmetic
380            0x10 => OpCode::Add,
381            0x11 => OpCode::AddC,
382            0x12 => OpCode::Sub,
383            0x13 => OpCode::SubB,
384            0x14 => OpCode::Mul,
385
386            // Shifts
387            0x20 => OpCode::Shl,
388            0x21 => OpCode::Rotl,
389            0x22 => OpCode::Shr,
390            0x23 => OpCode::Rotr,
391
392            // Logic
393            0x30 => OpCode::And,
394            0x31 => OpCode::Or,
395            0x32 => OpCode::Xor,
396            0x33 => OpCode::Not,
397            0x34 => OpCode::Neg,
398
399            // Comparison
400            0x40 => OpCode::Gt,
401            0x41 => OpCode::Ge,
402            0x42 => OpCode::Lt,
403            0x43 => OpCode::Le,
404            0x44 => OpCode::Eq,
405            0x45 => OpCode::Cmux,
406
407            // Control flow
408            0xFE => OpCode::Ret,
409            _ => OpCode::Unknown,
410        }
411    }
412
413    fn get_dst(encoded: u64) -> usize {
414        ((encoded >> 8) & 0x3F) as usize
415    }
416
417    fn get_src1(encoded: u64) -> usize {
418        ((encoded >> 14) & 0x3F) as usize
419    }
420
421    fn get_src2(encoded: u64) -> usize {
422        ((encoded >> 20) & 0x3F) as usize
423    }
424
425    fn get_src3(encoded: u64) -> usize {
426        ((encoded >> 26) & 0x3F) as usize
427    }
428
429    fn get_bind_buffer_id(encoded: u64) -> usize {
430        ((encoded >> 14) & 0x3FF) as usize
431    }
432
433    fn get_bind_is_encrypted(encoded: u64) -> bool {
434        (encoded >> 24) & 0x1 == 1
435    }
436
437    fn get_immediate_width(encoded: u64) -> u32 {
438        let exponent = ((encoded >> 14) & 0x7) as u32;
439        2u32.pow(exponent)
440    }
441
442    fn get_immediate(encoded: u64) -> u128 {
443        (encoded >> 17) as u128
444    }
445
446    fn get_casting_width(encoded: u64) -> u32 {
447        let exponent = ((encoded >> 20) & 0x7) as u32;
448        2u32.pow(exponent)
449    }
450}
451
/// Whether a buffer is bound as read-only or read/write.
///
/// This is a fieldless enum, so it is `Copy` (and `Hash`, for use as a set/map key).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum BufferType {
    /// Buffer is read-only (bound with `BindReadOnly`).
    Read,

    /// Buffer is read/write (bound with `BindReadWrite`).
    ReadWrite,
}
461
/// Information about one of an FHE program's buffer bindings.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct BufferInfo {
    /// The register the buffer is bound to: the pointer operand of the load (read-only
    /// buffers) or store (read/write buffers) recorded for this binding.
    pub register: usize,

    /// Whether the buffer is writable or not.
    pub buffer_type: BufferType,

    /// Whether the buffer is encrypted or not.
    pub is_encrypted: bool,

    /// The buffer's ID.
    pub buffer_id: usize,

    /// The access width of the load/store recorded for this binding (presumably in bits —
    /// TODO confirm against the ISA spec).
    pub width: u32,
}
480
481#[derive(Debug, Clone, PartialEq, Eq)]
482/// Information about buffer bindings during a call to [`crate::FheComputer::run_program`].
483pub struct ProgramBufferInfo {
484    buffers: Vec<BufferInfo>,
485}
486
487impl ProgramBufferInfo {
488    /// The number of bindings in an [`FheProgram`].
489    pub fn len(&self) -> usize {
490        self.buffers.len()
491    }
492
493    /// Whether the [`FheProgram`] lacks bindings or not.
494    pub fn is_empty(&self) -> bool {
495        self.len() == 0
496    }
497
498    /// Get the number of read-only buffers.
499    pub fn num_read_buffers(&self) -> usize {
500        self.buffers
501            .iter()
502            .filter(|info| info.buffer_type == BufferType::Read)
503            .count()
504    }
505
506    /// Get the number of writable buffers.
507    pub fn num_read_write_buffers(&self) -> usize {
508        self.buffers
509            .iter()
510            .filter(|info| info.buffer_type == BufferType::ReadWrite)
511            .count()
512    }
513
514    /// Return an iterator over the read-only [`BufferInfo`]s.
515    pub fn read_buffers(&self) -> impl Iterator<Item = &BufferInfo> {
516        self.buffers
517            .iter()
518            .filter(|info| info.buffer_type == BufferType::Read)
519    }
520
521    /// Return an iterator over the read/write [`BufferInfo`]s.
522    pub fn read_write_buffers(&self) -> impl Iterator<Item = &BufferInfo> {
523        self.buffers
524            .iter()
525            .filter(|info| info.buffer_type == BufferType::ReadWrite)
526    }
527
528    /// Return an iterator over all [`BufferInfo`]s.
529    pub fn iter(&self) -> std::slice::Iter<BufferInfo> {
530        self.buffers.iter()
531    }
532}
533
534impl IntoIterator for ProgramBufferInfo {
535    type Item = BufferInfo;
536    type IntoIter = std::vec::IntoIter<Self::Item>;
537
538    fn into_iter(self) -> Self::IntoIter {
539        self.buffers.into_iter()
540    }
541}
542
543impl Index<usize> for ProgramBufferInfo {
544    type Output = BufferInfo;
545
546    fn index(&self, index: usize) -> &Self::Output {
547        &self.buffers[index]
548    }
549}
550
/// An executable Parasol program located in an ELF file.
///
/// Holds the decoded instruction stream of one function.
pub struct FheProgram {
    /// The program's decoded instructions, in program order.
    pub(crate) instructions: Vec<IsaOp>,
}
555
#[derive(Debug, Clone, Error, PartialEq, Eq)]
/// An error when attempting to gather buffer information in an [`FheProgram`].
pub enum BufferInfoError {
    /// A bind instruction's destination register number (`.0`) differs from the buffer ID
    /// it binds (`.1`).
    #[error("Register bind pointer ID {} and meta ID {} don't match", .0, .1)]
    MismatchedBinding(usize, usize),

    /// The IDs of the accessed buffers, once sorted, are not exactly `0, 1, 2, ...`.
    #[error("Non-sequential buffer IDs")]
    NonSequentialBufferIds,

    /// The program bound the same register (`.0`) more than once.
    #[error("Duplicate binding for register ID {}", .0)]
    DuplicateBinding(usize),
}
571
572impl FheProgram {
573    pub(crate) fn from_instructions(inst: Vec<IsaOp>) -> Self {
574        Self { instructions: inst }
575    }
576
577    /// Get information about a program's bound buffers.
578    pub fn get_buffer_info(&self) -> std::result::Result<ProgramBufferInfo, BufferInfoError> {
579        let mut bindings = HashMap::new();
580        let mut verified_buffers = HashMap::new();
581
582        // Single pass through instructions
583        for op in &self.instructions {
584            match op {
585                IsaOp::BindReadOnly(RegisterName::Named(reg_num, _), id, encrypted) => {
586                    if reg_num != id {
587                        return Err(BufferInfoError::MismatchedBinding(*reg_num, *id));
588                    }
589                    if bindings.contains_key(reg_num) {
590                        return Err(BufferInfoError::DuplicateBinding(*reg_num));
591                    }
592                    bindings.insert(*reg_num, (*id, true, *encrypted));
593                }
594                IsaOp::BindReadWrite(RegisterName::Named(reg_num, _), id, encrypted) => {
595                    if reg_num != id {
596                        return Err(BufferInfoError::MismatchedBinding(*reg_num, *id));
597                    }
598                    if bindings.contains_key(reg_num) {
599                        return Err(BufferInfoError::DuplicateBinding(*reg_num));
600                    }
601                    bindings.insert(*reg_num, (*id, false, *encrypted));
602                }
603                IsaOp::Load(_, RegisterName::Named(reg_num, _), width) => {
604                    if let Some(&(id, is_read_only, encrypted)) = bindings.get(reg_num) {
605                        if is_read_only {
606                            verified_buffers.insert(
607                                id,
608                                BufferInfo {
609                                    register: *reg_num,
610                                    buffer_type: BufferType::Read,
611                                    is_encrypted: encrypted,
612                                    buffer_id: id,
613                                    width: *width,
614                                },
615                            );
616                        }
617                    }
618                }
619                IsaOp::Store(RegisterName::Named(reg_num, _), _, width) => {
620                    if let Some(&(id, is_read_only, encrypted)) = bindings.get(reg_num) {
621                        if !is_read_only {
622                            verified_buffers.insert(
623                                id,
624                                BufferInfo {
625                                    register: *reg_num,
626                                    buffer_type: BufferType::ReadWrite,
627                                    is_encrypted: encrypted,
628                                    buffer_id: id,
629                                    width: *width,
630                                },
631                            );
632                        }
633                    }
634                }
635                _ => {}
636            }
637        }
638
639        // Convert to sorted vec and verify sequential IDs
640        let mut result: Vec<_> = verified_buffers.into_values().collect();
641        result.sort_by_key(|info| info.buffer_id);
642
643        // Verify sequential IDs
644        for (i, info) in result.iter().enumerate() {
645            if info.buffer_id != i {
646                return Err(BufferInfoError::NonSequentialBufferIds);
647            }
648        }
649
650        Ok(ProgramBufferInfo { buffers: result })
651    }
652}
653
654impl From<Vec<IsaOp>> for FheProgram {
655    fn from(inst: Vec<IsaOp>) -> Self {
656        Self { instructions: inst }
657    }
658}
659
660impl From<&[IsaOp]> for FheProgram {
661    fn from(inst: &[IsaOp]) -> Self {
662        Self {
663            instructions: inst.to_vec(),
664        }
665    }
666}
667
#[cfg(test)]
mod tests {
    use super::*;

    /// Compiled chi-squared example used as the ELF fixture for these tests.
    const ELF: &[u8] = include_bytes!("../../tests/test_data/chi_squared.o");

    /// Name of the only function in the fixture.
    const PROGRAM_NAME: &str = "chi_squared_optimized";

    /// The fixture's first `NUM_INPUTS` buffers are read-only; the rest are read/write.
    const NUM_INPUTS: usize = 3;

    #[test]
    fn can_parse_elf() {
        let app = FheApplication::parse_elf(ELF).unwrap();

        assert_eq!(app.programs.len(), 1);

        let name = CString::new(PROGRAM_NAME).unwrap();
        assert!(app.get_program(&Symbol::new(&name)).is_some());
    }

    #[test]
    fn test_chi_squared_buffer_info() {
        let app = FheApplication::parse_elf(ELF).unwrap();
        let program = app.get_program(&Symbol::from(PROGRAM_NAME)).unwrap();
        let buffer_info = program.get_buffer_info().unwrap();

        assert_eq!(buffer_info.len(), 7);

        // Every buffer is encrypted with width 16 and bound to the register matching its ID.
        for id in 0..7 {
            let buffer_type = if id < NUM_INPUTS {
                BufferType::Read
            } else {
                BufferType::ReadWrite
            };
            let expected = BufferInfo {
                register: id,
                buffer_type,
                is_encrypted: true,
                buffer_id: id,
                width: 16,
            };
            assert_eq!(buffer_info[id], expected, "buffer {id}");
        }
    }

    #[test]
    fn test_program_buffer_info_methods() {
        let make = |register: usize, buffer_type: BufferType, is_encrypted: bool, width: u32| {
            BufferInfo {
                register,
                buffer_type,
                is_encrypted,
                buffer_id: register,
                width,
            }
        };

        let program_info = ProgramBufferInfo {
            buffers: vec![
                make(0, BufferType::Read, true, 16),
                make(1, BufferType::ReadWrite, true, 16),
                make(2, BufferType::Read, false, 32),
            ],
        };

        // Counting helpers.
        assert_eq!(program_info.len(), 3);
        assert_eq!(program_info.num_read_buffers(), 2);
        assert_eq!(program_info.num_read_write_buffers(), 1);

        // Filtered iterators keep storage order.
        let read_registers: Vec<usize> = program_info.read_buffers().map(|b| b.register).collect();
        assert_eq!(read_registers, [0, 2]);

        let read_write_registers: Vec<usize> = program_info
            .read_write_buffers()
            .map(|b| b.register)
            .collect();
        assert_eq!(read_write_registers, [1]);

        // Unfiltered iteration and indexing.
        assert_eq!(program_info.iter().count(), 3);
        for i in 0..3 {
            assert_eq!(program_info[i].register, i);
        }

        // Consuming iteration.
        let owned_registers: Vec<usize> = program_info.into_iter().map(|b| b.register).collect();
        assert_eq!(owned_registers, [0, 1, 2]);
    }
}