use std::{
    collections::HashMap,
    ffi::{CStr, CString},
    ops::Index,
};

use crate::{Error, Result, proc::IsaOp, tomasulo::registers::RegisterName};
use thiserror::Error;

use elf::{ElfBytes, ParseError, abi::STT_FUNC, endian::LittleEndian};
11
/// The only ELF header ABI version (`abiversion`) that
/// `FheApplication::parse_elf` accepts; objects with any other value are
/// rejected with `Error::ElfUnsupportedAbiVersion`.
pub(crate) const SUPPORTED_ABI_VERSION: u8 = 0x01;
19
/// Operation selected by the low byte of an encoded 64-bit instruction word.
///
/// The byte value of each variant is assigned in `FheApplication::get_opcode`;
/// bytes without an assignment decode to [`OpCode::Unknown`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum OpCode {
    // Buffer binding, memory access and width casts (0x00..=0x06).
    BindReadOnly,
    BindReadWrite,
    Load,
    LoadI,
    Store,
    Zext,
    Trunc,

    // Arithmetic (0x10..=0x14).
    Add,
    AddC,
    Sub,
    SubB,
    Mul,

    // Shifts and rotations (0x20..=0x23).
    Shl,
    Rotl,
    Shr,
    Rotr,

    // Bitwise logic and negation (0x30..=0x34).
    And,
    Or,
    Xor,
    Not,
    Neg,

    // Comparisons and conditional select (0x40..=0x45).
    Gt,
    Ge,
    Lt,
    Le,
    Eq,
    Cmux,

    /// Return (0xFE); decodes to `IsaOp::Ret`.
    Ret,
    /// Any opcode byte without an assigned operation.
    Unknown,
}
62
/// Name used to look up an [`FheProgram`] inside an [`FheApplication`].
///
/// The name is held as an owned C string, so it need not be valid UTF-8.
#[derive(Debug, Clone, Hash, PartialEq, Eq)]
pub struct Symbol {
    name: CString,
}

impl Symbol {
    /// Creates a symbol holding an owned copy of `name`.
    pub fn new(name: &CStr) -> Self {
        Self {
            name: name.to_owned(),
        }
    }
}

impl From<&CStr> for Symbol {
    fn from(name: &CStr) -> Self {
        Self::new(name)
    }
}

impl From<&str> for Symbol {
    /// Converts a Rust string into a symbol name.
    ///
    /// # Panics
    ///
    /// Panics if `name` contains an interior NUL byte, because such a string
    /// cannot be represented as a C string.
    fn from(name: &str) -> Self {
        // Build the owned CString once and move it in, rather than borrowing it
        // and copying it again through `Symbol::new`.
        let name = CString::new(name).expect("symbol names cannot contain interior NUL bytes");
        Self { name }
    }
}
91
92pub struct FheApplication {
94 programs: HashMap<Symbol, FheProgram>,
95}
96
97impl FheApplication {
98 pub fn get_program(&self, name: &Symbol) -> Option<&FheProgram> {
100 self.programs.get(name)
101 }
102
103 pub fn parse_elf(binary: &[u8]) -> Result<Self> {
105 let elf = ElfBytes::<LittleEndian>::minimal_parse(binary)?;
106
107 let abi_version = elf.ehdr.abiversion;
108
109 if abi_version != SUPPORTED_ABI_VERSION {
110 return Err(Error::ElfUnsupportedAbiVersion(abi_version));
111 }
112
113 let get_name = |name: u32| -> Result<&CStr> {
114 let shstrn = elf
115 .section_headers()
116 .ok_or(Error::ElfNoSectionHeaders)?
117 .get(elf.ehdr.e_shstrndx as usize)?;
118
119 let str_offset = shstrn.sh_offset as usize + name as usize;
120 Ok(CStr::from_bytes_until_nul(&binary[str_offset..])?)
121 };
122
123 let (sym, _) = elf.symbol_table()?.ok_or(Error::ElfNoSymbolTable)?;
124
125 let mut programs = HashMap::new();
126
127 for s in sym {
128 if s.st_symtype() == STT_FUNC {
129 let header = elf
130 .section_headers()
131 .ok_or(Error::ElfNoSectionHeaders)?
132 .get(s.st_shndx as usize)?;
133 let (data, _) = elf.section_data(&header)?;
134 let data = &data[s.st_value as usize..s.st_value as usize + s.st_size as usize];
135
136 let mut code = vec![];
137
138 for inst in data.chunks(8) {
139 let inst: [u8; 8] = inst.try_into().unwrap();
141 let inst = u64::from_le_bytes(inst);
142 let inst = Self::parse_instruction(inst);
143
144 code.push(inst);
145 }
146
147 let symbol_name = get_name(s.st_name)?;
148
149 programs.insert(
150 Symbol::new(symbol_name),
151 FheProgram::from_instructions(code),
152 );
153 }
154 }
155
156 Ok(Self { programs })
157 }
158
159 fn parse_instruction(encoded: u64) -> IsaOp {
160 match Self::get_opcode(encoded) {
161 OpCode::BindReadOnly => {
163 let dst = RegisterName::named(Self::get_dst(encoded));
164 let buffer_id = Self::get_bind_buffer_id(encoded);
165 let is_encrypted = Self::get_bind_is_encrypted(encoded);
166
167 IsaOp::BindReadOnly(dst, buffer_id, is_encrypted)
168 }
169 OpCode::BindReadWrite => {
170 let dst = RegisterName::named(Self::get_dst(encoded));
171 let buffer_id = Self::get_bind_buffer_id(encoded);
172 let is_encrypted = Self::get_bind_is_encrypted(encoded);
173
174 IsaOp::BindReadWrite(dst, buffer_id, is_encrypted)
175 }
176 OpCode::Load => {
177 let register = RegisterName::named(Self::get_dst(encoded));
178 let memory_pointer = RegisterName::named(Self::get_src1(encoded));
179 let width = Self::get_casting_width(encoded);
180
181 IsaOp::Load(register, memory_pointer, width)
182 }
183 OpCode::LoadI => {
184 let dst = RegisterName::named(Self::get_dst(encoded));
185 let imm = Self::get_immediate(encoded);
186 let width = Self::get_immediate_width(encoded);
187
188 let mask = (1u128 << width) - 1;
191 let imm = imm & mask;
192
193 IsaOp::LoadI(dst, imm, width)
194 }
195 OpCode::Store => {
196 let dst = RegisterName::named(Self::get_dst(encoded));
197 let src = RegisterName::named(Self::get_src1(encoded));
198 let width = Self::get_casting_width(encoded);
199
200 IsaOp::Store(dst, src, width)
201 }
202 OpCode::Zext => {
203 let dst = RegisterName::named(Self::get_dst(encoded));
204 let src = RegisterName::named(Self::get_src1(encoded));
205 let width = Self::get_casting_width(encoded);
206
207 IsaOp::Zext(dst, src, width)
208 }
209 OpCode::Trunc => {
210 let dst = RegisterName::named(Self::get_dst(encoded));
211 let src = RegisterName::named(Self::get_src1(encoded));
212 let width = Self::get_casting_width(encoded);
213
214 IsaOp::Trunc(dst, src, width)
215 }
216
217 OpCode::Add => {
219 let dst = RegisterName::named(Self::get_dst(encoded));
220 let src1 = RegisterName::named(Self::get_src1(encoded));
221 let src2 = RegisterName::named(Self::get_src2(encoded));
222
223 IsaOp::Add(dst, src1, src2)
224 }
225 OpCode::AddC => {
226 unimplemented!(
227 "Not implemented in the Parasol compiler. This operation should never be generated."
228 );
229 }
230 OpCode::Sub => {
231 let dst = RegisterName::named(Self::get_dst(encoded));
232 let src1 = RegisterName::named(Self::get_src1(encoded));
233 let src2 = RegisterName::named(Self::get_src2(encoded));
234
235 IsaOp::Sub(dst, src1, src2)
236 }
237 OpCode::SubB => {
238 unimplemented!(
239 "Not implemented in the Parasol compiler. This operation should never be generated."
240 );
241 }
242 OpCode::Mul => {
243 let dst = RegisterName::named(Self::get_dst(encoded));
244 let src1 = RegisterName::named(Self::get_src1(encoded));
245 let src2 = RegisterName::named(Self::get_src2(encoded));
246
247 IsaOp::Mul(dst, src1, src2)
248 }
249
250 OpCode::Shl => {
252 let dst = RegisterName::named(Self::get_dst(encoded));
253 let src1 = RegisterName::named(Self::get_src1(encoded));
254 let src2 = RegisterName::named(Self::get_src2(encoded));
255
256 IsaOp::Shl(dst, src1, src2)
257 }
258 OpCode::Rotl => {
259 let dst = RegisterName::named(Self::get_dst(encoded));
260 let src1 = RegisterName::named(Self::get_src1(encoded));
261 let src2 = RegisterName::named(Self::get_src2(encoded));
262
263 IsaOp::Rotl(dst, src1, src2)
264 }
265 OpCode::Shr => {
266 let dst = RegisterName::named(Self::get_dst(encoded));
267 let src1 = RegisterName::named(Self::get_src1(encoded));
268 let src2 = RegisterName::named(Self::get_src2(encoded));
269
270 IsaOp::Shr(dst, src1, src2)
271 }
272 OpCode::Rotr => {
273 let dst = RegisterName::named(Self::get_dst(encoded));
274 let src1 = RegisterName::named(Self::get_src1(encoded));
275 let src2 = RegisterName::named(Self::get_src2(encoded));
276
277 IsaOp::Rotr(dst, src1, src2)
278 }
279
280 OpCode::And => {
282 let dst = RegisterName::named(Self::get_dst(encoded));
283 let src1 = RegisterName::named(Self::get_src1(encoded));
284 let src2 = RegisterName::named(Self::get_src2(encoded));
285
286 IsaOp::And(dst, src1, src2)
287 }
288 OpCode::Or => {
289 let dst = RegisterName::named(Self::get_dst(encoded));
290 let src1 = RegisterName::named(Self::get_src1(encoded));
291 let src2 = RegisterName::named(Self::get_src2(encoded));
292
293 IsaOp::Or(dst, src1, src2)
294 }
295 OpCode::Xor => {
296 let dst = RegisterName::named(Self::get_dst(encoded));
297 let src1 = RegisterName::named(Self::get_src1(encoded));
298 let src2 = RegisterName::named(Self::get_src2(encoded));
299
300 IsaOp::Xor(dst, src1, src2)
301 }
302 OpCode::Not => {
303 let dst = RegisterName::named(Self::get_dst(encoded));
304 let src = RegisterName::named(Self::get_src1(encoded));
305
306 IsaOp::Not(dst, src)
307 }
308 OpCode::Neg => {
309 let dst = RegisterName::named(Self::get_dst(encoded));
310 let src = RegisterName::named(Self::get_src1(encoded));
311
312 IsaOp::Neg(dst, src)
313 }
314
315 OpCode::Gt => {
317 let dst = RegisterName::named(Self::get_dst(encoded));
318 let src1 = RegisterName::named(Self::get_src1(encoded));
319 let src2 = RegisterName::named(Self::get_src2(encoded));
320
321 IsaOp::CmpGt(dst, src1, src2)
322 }
323 OpCode::Ge => {
324 let dst = RegisterName::named(Self::get_dst(encoded));
325 let src1 = RegisterName::named(Self::get_src1(encoded));
326 let src2 = RegisterName::named(Self::get_src2(encoded));
327
328 IsaOp::CmpGe(dst, src1, src2)
329 }
330 OpCode::Lt => {
331 let dst = RegisterName::named(Self::get_dst(encoded));
332 let src1 = RegisterName::named(Self::get_src1(encoded));
333 let src2 = RegisterName::named(Self::get_src2(encoded));
334
335 IsaOp::CmpLt(dst, src1, src2)
336 }
337 OpCode::Le => {
338 let dst = RegisterName::named(Self::get_dst(encoded));
339 let src1 = RegisterName::named(Self::get_src1(encoded));
340 let src2 = RegisterName::named(Self::get_src2(encoded));
341
342 IsaOp::CmpLe(dst, src1, src2)
343 }
344 OpCode::Eq => {
345 let dst = RegisterName::named(Self::get_dst(encoded));
346 let src1 = RegisterName::named(Self::get_src1(encoded));
347 let src2 = RegisterName::named(Self::get_src2(encoded));
348
349 IsaOp::CmpEq(dst, src1, src2)
350 }
351 OpCode::Cmux => {
352 let dst = RegisterName::named(Self::get_dst(encoded));
353 let select = RegisterName::named(Self::get_src1(encoded));
354 let a = RegisterName::named(Self::get_src2(encoded));
355 let b = RegisterName::named(Self::get_src3(encoded));
356
357 IsaOp::Cmux(dst, select, a, b)
358 }
359
360 OpCode::Ret => IsaOp::Ret(),
362 _ => {
363 unimplemented!("Unknown opcode {:x}", encoded & 0xFF);
364 }
365 }
366 }
367
368 fn get_opcode(encoded: u64) -> OpCode {
369 match encoded & 0xFF {
370 0x00 => OpCode::BindReadOnly,
372 0x01 => OpCode::BindReadWrite,
373 0x02 => OpCode::Load,
374 0x03 => OpCode::LoadI,
375 0x04 => OpCode::Store,
376 0x05 => OpCode::Zext,
377 0x06 => OpCode::Trunc,
378
379 0x10 => OpCode::Add,
381 0x11 => OpCode::AddC,
382 0x12 => OpCode::Sub,
383 0x13 => OpCode::SubB,
384 0x14 => OpCode::Mul,
385
386 0x20 => OpCode::Shl,
388 0x21 => OpCode::Rotl,
389 0x22 => OpCode::Shr,
390 0x23 => OpCode::Rotr,
391
392 0x30 => OpCode::And,
394 0x31 => OpCode::Or,
395 0x32 => OpCode::Xor,
396 0x33 => OpCode::Not,
397 0x34 => OpCode::Neg,
398
399 0x40 => OpCode::Gt,
401 0x41 => OpCode::Ge,
402 0x42 => OpCode::Lt,
403 0x43 => OpCode::Le,
404 0x44 => OpCode::Eq,
405 0x45 => OpCode::Cmux,
406
407 0xFE => OpCode::Ret,
409 _ => OpCode::Unknown,
410 }
411 }
412
413 fn get_dst(encoded: u64) -> usize {
414 ((encoded >> 8) & 0x3F) as usize
415 }
416
417 fn get_src1(encoded: u64) -> usize {
418 ((encoded >> 14) & 0x3F) as usize
419 }
420
421 fn get_src2(encoded: u64) -> usize {
422 ((encoded >> 20) & 0x3F) as usize
423 }
424
425 fn get_src3(encoded: u64) -> usize {
426 ((encoded >> 26) & 0x3F) as usize
427 }
428
429 fn get_bind_buffer_id(encoded: u64) -> usize {
430 ((encoded >> 14) & 0x3FF) as usize
431 }
432
433 fn get_bind_is_encrypted(encoded: u64) -> bool {
434 (encoded >> 24) & 0x1 == 1
435 }
436
437 fn get_immediate_width(encoded: u64) -> u32 {
438 let exponent = ((encoded >> 14) & 0x7) as u32;
439 2u32.pow(exponent)
440 }
441
442 fn get_immediate(encoded: u64) -> u128 {
443 (encoded >> 17) as u128
444 }
445
446 fn get_casting_width(encoded: u64) -> u32 {
447 let exponent = ((encoded >> 20) & 0x7) as u32;
448 2u32.pow(exponent)
449 }
450}
451
/// How a program accesses one of its bound buffers.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum BufferType {
    /// The buffer was bound read-only and is loaded from.
    Read,

    /// The buffer was bound read-write and is stored to.
    ReadWrite,
}
461
462#[derive(Debug, Clone, PartialEq, Eq)]
464pub struct BufferInfo {
465 pub register: usize,
467
468 pub buffer_type: BufferType,
470
471 pub is_encrypted: bool,
473
474 pub buffer_id: usize,
476
477 pub width: u32,
479}
480
481#[derive(Debug, Clone, PartialEq, Eq)]
482pub struct ProgramBufferInfo {
484 buffers: Vec<BufferInfo>,
485}
486
487impl ProgramBufferInfo {
488 pub fn len(&self) -> usize {
490 self.buffers.len()
491 }
492
493 pub fn is_empty(&self) -> bool {
495 self.len() == 0
496 }
497
498 pub fn num_read_buffers(&self) -> usize {
500 self.buffers
501 .iter()
502 .filter(|info| info.buffer_type == BufferType::Read)
503 .count()
504 }
505
506 pub fn num_read_write_buffers(&self) -> usize {
508 self.buffers
509 .iter()
510 .filter(|info| info.buffer_type == BufferType::ReadWrite)
511 .count()
512 }
513
514 pub fn read_buffers(&self) -> impl Iterator<Item = &BufferInfo> {
516 self.buffers
517 .iter()
518 .filter(|info| info.buffer_type == BufferType::Read)
519 }
520
521 pub fn read_write_buffers(&self) -> impl Iterator<Item = &BufferInfo> {
523 self.buffers
524 .iter()
525 .filter(|info| info.buffer_type == BufferType::ReadWrite)
526 }
527
528 pub fn iter(&self) -> std::slice::Iter<BufferInfo> {
530 self.buffers.iter()
531 }
532}
533
534impl IntoIterator for ProgramBufferInfo {
535 type Item = BufferInfo;
536 type IntoIter = std::vec::IntoIter<Self::Item>;
537
538 fn into_iter(self) -> Self::IntoIter {
539 self.buffers.into_iter()
540 }
541}
542
543impl Index<usize> for ProgramBufferInfo {
544 type Output = BufferInfo;
545
546 fn index(&self, index: usize) -> &Self::Output {
547 &self.buffers[index]
548 }
549}
550
551pub struct FheProgram {
553 pub(crate) instructions: Vec<IsaOp>,
554}
555
556#[derive(Debug, Clone, Error, PartialEq, Eq)]
557pub enum BufferInfoError {
559 #[error("Register bind pointer ID {} and meta ID {} don't match", .0, .1)]
561 MismatchedBinding(usize, usize),
562
563 #[error("Non-sequential buffer IDs")]
565 NonSequentialBufferIds,
566
567 #[error("Duplicate binding for register ID {}", .0)]
569 DuplicateBinding(usize),
570}
571
572impl FheProgram {
573 pub(crate) fn from_instructions(inst: Vec<IsaOp>) -> Self {
574 Self { instructions: inst }
575 }
576
577 pub fn get_buffer_info(&self) -> std::result::Result<ProgramBufferInfo, BufferInfoError> {
579 let mut bindings = HashMap::new();
580 let mut verified_buffers = HashMap::new();
581
582 for op in &self.instructions {
584 match op {
585 IsaOp::BindReadOnly(RegisterName::Named(reg_num, _), id, encrypted) => {
586 if reg_num != id {
587 return Err(BufferInfoError::MismatchedBinding(*reg_num, *id));
588 }
589 if bindings.contains_key(reg_num) {
590 return Err(BufferInfoError::DuplicateBinding(*reg_num));
591 }
592 bindings.insert(*reg_num, (*id, true, *encrypted));
593 }
594 IsaOp::BindReadWrite(RegisterName::Named(reg_num, _), id, encrypted) => {
595 if reg_num != id {
596 return Err(BufferInfoError::MismatchedBinding(*reg_num, *id));
597 }
598 if bindings.contains_key(reg_num) {
599 return Err(BufferInfoError::DuplicateBinding(*reg_num));
600 }
601 bindings.insert(*reg_num, (*id, false, *encrypted));
602 }
603 IsaOp::Load(_, RegisterName::Named(reg_num, _), width) => {
604 if let Some(&(id, is_read_only, encrypted)) = bindings.get(reg_num) {
605 if is_read_only {
606 verified_buffers.insert(
607 id,
608 BufferInfo {
609 register: *reg_num,
610 buffer_type: BufferType::Read,
611 is_encrypted: encrypted,
612 buffer_id: id,
613 width: *width,
614 },
615 );
616 }
617 }
618 }
619 IsaOp::Store(RegisterName::Named(reg_num, _), _, width) => {
620 if let Some(&(id, is_read_only, encrypted)) = bindings.get(reg_num) {
621 if !is_read_only {
622 verified_buffers.insert(
623 id,
624 BufferInfo {
625 register: *reg_num,
626 buffer_type: BufferType::ReadWrite,
627 is_encrypted: encrypted,
628 buffer_id: id,
629 width: *width,
630 },
631 );
632 }
633 }
634 }
635 _ => {}
636 }
637 }
638
639 let mut result: Vec<_> = verified_buffers.into_values().collect();
641 result.sort_by_key(|info| info.buffer_id);
642
643 for (i, info) in result.iter().enumerate() {
645 if info.buffer_id != i {
646 return Err(BufferInfoError::NonSequentialBufferIds);
647 }
648 }
649
650 Ok(ProgramBufferInfo { buffers: result })
651 }
652}
653
654impl From<Vec<IsaOp>> for FheProgram {
655 fn from(inst: Vec<IsaOp>) -> Self {
656 Self { instructions: inst }
657 }
658}
659
660impl From<&[IsaOp]> for FheProgram {
661 fn from(inst: &[IsaOp]) -> Self {
662 Self {
663 instructions: inst.to_vec(),
664 }
665 }
666}
667
#[cfg(test)]
mod tests {

    use super::*;

    const ELF: &[u8] = include_bytes!("../../tests/test_data/chi_squared.o");

    #[test]
    fn can_parse_elf() {
        let result = FheApplication::parse_elf(ELF).unwrap();

        assert_eq!(result.programs.len(), 1);
        result
            .get_program(&Symbol::new(
                &CString::new("chi_squared_optimized").unwrap(),
            ))
            .unwrap();
    }

    #[test]
    fn test_chi_squared_buffer_info() {
        let result = FheApplication::parse_elf(ELF).unwrap();
        let program = result
            .get_program(&Symbol::from("chi_squared_optimized"))
            .unwrap();
        let buffer_info = program.get_buffer_info().unwrap();

        assert_eq!(buffer_info.len(), 7);

        // Buffers 0..3 are read-only inputs and 3..7 are read-write outputs; all
        // are encrypted, 16 wide, and bound to the register with the same index.
        for (i, info) in buffer_info.iter().enumerate() {
            let buffer_type = if i < 3 {
                BufferType::Read
            } else {
                BufferType::ReadWrite
            };
            assert_eq!(
                *info,
                BufferInfo {
                    register: i,
                    buffer_type,
                    is_encrypted: true,
                    buffer_id: i,
                    width: 16,
                },
                "buffer {i}"
            );
        }
    }

    #[test]
    fn buffer_info_rejects_mismatched_binding() {
        let program = FheProgram::from(vec![IsaOp::BindReadOnly(RegisterName::named(0), 1, true)]);

        assert_eq!(
            program.get_buffer_info(),
            Err(BufferInfoError::MismatchedBinding(0, 1))
        );
    }

    #[test]
    fn buffer_info_rejects_duplicate_binding() {
        let program = FheProgram::from(vec![
            IsaOp::BindReadOnly(RegisterName::named(0), 0, true),
            IsaOp::BindReadWrite(RegisterName::named(0), 0, false),
        ]);

        assert_eq!(
            program.get_buffer_info(),
            Err(BufferInfoError::DuplicateBinding(0))
        );
    }

    #[test]
    fn buffer_info_rejects_gaps_in_buffer_ids() {
        // Only buffer 1 is ever accessed, so the ids do not start at 0.
        let program = FheProgram::from(vec![
            IsaOp::BindReadOnly(RegisterName::named(1), 1, true),
            IsaOp::Load(RegisterName::named(2), RegisterName::named(1), 16),
        ]);

        assert_eq!(
            program.get_buffer_info(),
            Err(BufferInfoError::NonSequentialBufferIds)
        );
    }

    #[test]
    fn buffer_info_of_empty_program_is_empty() {
        let program = FheProgram::from(Vec::<IsaOp>::new());

        assert!(program.get_buffer_info().unwrap().is_empty());
    }

    #[test]
    fn test_program_buffer_info_methods() {
        let buffers = vec![
            BufferInfo {
                register: 0,
                buffer_type: BufferType::Read,
                is_encrypted: true,
                buffer_id: 0,
                width: 16,
            },
            BufferInfo {
                register: 1,
                buffer_type: BufferType::ReadWrite,
                is_encrypted: true,
                buffer_id: 1,
                width: 16,
            },
            BufferInfo {
                register: 2,
                buffer_type: BufferType::Read,
                is_encrypted: false,
                buffer_id: 2,
                width: 32,
            },
        ];

        let program_info = ProgramBufferInfo { buffers };

        assert_eq!(program_info.len(), 3);

        assert_eq!(program_info.num_read_buffers(), 2);

        assert_eq!(program_info.num_read_write_buffers(), 1);

        let read_buffers: Vec<_> = program_info.read_buffers().collect();
        assert_eq!(read_buffers.len(), 2);
        assert_eq!(read_buffers[0].register, 0);
        assert_eq!(read_buffers[1].register, 2);

        let read_write_buffers: Vec<_> = program_info.read_write_buffers().collect();
        assert_eq!(read_write_buffers.len(), 1);
        assert_eq!(read_write_buffers[0].register, 1);

        let all_buffers: Vec<_> = program_info.iter().collect();
        assert_eq!(all_buffers.len(), 3);

        assert_eq!(program_info[0].register, 0);
        assert_eq!(program_info[1].register, 1);
        assert_eq!(program_info[2].register, 2);

        let into_iter_buffers: Vec<_> = program_info.into_iter().collect();
        assert_eq!(into_iter_buffers.len(), 3);
        assert_eq!(into_iter_buffers[0].register, 0);
        assert_eq!(into_iter_buffers[1].register, 1);
        assert_eq!(into_iter_buffers[2].register, 2);
    }
}