use std::collections::{HashMap, HashSet};
use std::path::Path;
use memmap2::Mmap;
use prost::Message;
use prost_reflect::{DescriptorPool, EnumDescriptor, MessageDescriptor};
use prost_types::{FileDescriptorProto, FileDescriptorSet};
use prototext_graph::fds_index::ArchivedFdsIndex;
use rkyv::api::access_unchecked;
use crate::EMBEDDED_DESCRIPTOR;
/// Magic bytes that `check_header` requires in bytes 0..8 of a file.
const MAGIC: &[u8; 8] = b"PTSGRAPH";
/// The only header version `check_header` accepts; it is compared against the
/// little-endian `u32` stored in bytes 8..12 of the file.
const VERSION: u32 = 3;
/// A descriptor pool that is filled on demand from a memory-mapped descriptor
/// file and a memory-mapped index file.
///
/// `index` and `raw_bytes` point into the two mappings held by this struct. Their
/// `'static` lifetimes are manufactured in `open` through raw-pointer casts, so
/// they are only valid for as long as the mappings below are alive.
pub struct LazyPool {
/// Mapping of the file opened from `pb_path`; kept so that `raw_bytes` stays valid.
_raw_mmap: Mmap,
/// Mapping of the file opened from `idx_path`; kept so that `index` stays valid.
_idx_mmap: Mmap,
/// Archived index read from `_idx_mmap`, starting at the root offset taken from its header.
index: &'static ArchivedFdsIndex,
/// The full contents of `_raw_mmap`; `ensure_loaded` slices it using spans from `index`.
raw_bytes: &'static [u8],
/// Pool that receives file descriptors as `ensure_loaded` decodes them.
pub pool: DescriptorPool,
/// Files decoded from `EMBEDDED_DESCRIPTOR`, keyed by file name; consulted when a
/// file or type is not present in `index`.
wkt_fdps: HashMap<String, FileDescriptorProto>,
/// Names of files that have already been added to `pool`.
loaded: HashSet<String>,
/// Names of files whose load is currently underway; used by `ensure_loaded` to
/// detect dependency cycles.
in_progress: HashSet<String>,
}
/// Validates the fixed-size PTSGRAPH header at the start of `bytes` and returns
/// the root offset stored in it.
///
/// Header fields as read by this function (integers are little-endian):
/// - bytes 0..8: magic, must equal `MAGIC`
/// - bytes 8..12: `u32` version, must equal `VERSION`
/// - bytes 12..16: not inspected here
/// - bytes 16..24: `u64` root offset
///
/// `label` is prefixed to every error message produced here (callers pass the
/// file path).
///
/// # Errors
/// Returns an error when the buffer is shorter than the header, the magic or
/// version does not match, or the root offset does not fit in `usize` or points
/// past the end of `bytes`. An offset equal to `bytes.len()` (empty payload) is
/// still accepted; judging the payload itself is left to the caller.
fn check_header(bytes: &[u8], label: &str) -> Result<usize, Box<dyn std::error::Error>> {
    const HEADER_LEN: usize = 24;
    if bytes.len() < HEADER_LEN {
        return Err(format!("{label}: file too short for PTSGRAPH header").into());
    }
    if &bytes[0..8] != MAGIC {
        return Err(format!("{label}: bad magic (expected PTSGRAPH)").into());
    }
    let version = u32::from_le_bytes(bytes[8..12].try_into()?);
    if version != VERSION {
        return Err(format!("{label}: unsupported version {version} (expected {VERSION})").into());
    }
    let raw_offset = u64::from_le_bytes(bytes[16..24].try_into()?);
    // Callers slice the buffer with the returned offset, so an offset beyond the
    // end of the buffer (or one that would be truncated by a cast on a 32-bit
    // target) must be rejected here instead of causing a panic later.
    let root_offset = usize::try_from(raw_offset)
        .ok()
        .filter(|&offset| offset <= bytes.len())
        .ok_or_else(|| {
            format!(
                "{label}: root offset {raw_offset} is out of bounds (file is {} bytes)",
                bytes.len()
            )
        })?;
    Ok(root_offset)
}
impl LazyPool {
    /// Memory-maps the descriptor file at `pb_path` and the index file at
    /// `idx_path`, validates the index header, and returns a pool with nothing
    /// loaded yet.
    ///
    /// The embedded descriptor set (`EMBEDDED_DESCRIPTOR`) is decoded eagerly and
    /// kept as a by-file-name fallback table.
    ///
    /// # Errors
    /// Fails when either file cannot be opened or mapped, when the index header is
    /// rejected by `check_header`, or when the embedded descriptor set does not decode.
    pub fn open(pb_path: &Path, idx_path: &Path) -> Result<Self, Box<dyn std::error::Error>> {
        let raw_file = match std::fs::File::open(pb_path) {
            Ok(handle) => handle,
            Err(e) => return Err(format!("{}: {e}", pb_path.display()).into()),
        };
        let idx_file = match std::fs::File::open(idx_path) {
            Ok(handle) => handle,
            Err(e) => return Err(format!("{}: {e}", idx_path.display()).into()),
        };

        // SAFETY (assumed, not enforced here): the mapped files must not be
        // modified or truncated while the mappings are alive.
        let raw_mmap = match unsafe { Mmap::map(&raw_file) } {
            Ok(mapping) => mapping,
            Err(e) => return Err(format!("{}: mmap: {e}", pb_path.display()).into()),
        };
        let idx_mmap = match unsafe { Mmap::map(&idx_file) } {
            Ok(mapping) => mapping,
            Err(e) => return Err(format!("{}: mmap: {e}", idx_path.display()).into()),
        };

        let idx_label = idx_path.display().to_string();
        let root_offset = check_header(&idx_mmap, &idx_label)?;

        // Both raw pointers below are turned back into references with an
        // unbounded (here: `'static`) lifetime. That is only valid because the
        // mappings are moved into the returned struct alongside the references.
        let idx_payload: *const [u8] = &idx_mmap[root_offset..];
        // SAFETY: NOTE(review) — `access_unchecked` performs no validation, so the
        // bytes after the header are trusted to be a well-formed archive; confirm
        // the index file is always produced by a matching writer.
        let index: &'static ArchivedFdsIndex =
            unsafe { access_unchecked::<ArchivedFdsIndex>(&*idx_payload) };

        let raw_payload: *const [u8] = &raw_mmap[..];
        // SAFETY: the pointer was derived from `raw_mmap`, which is stored in the
        // returned struct next to this reference.
        let raw_bytes: &'static [u8] = unsafe { &*raw_payload };

        let wkt_fds = match FileDescriptorSet::decode(EMBEDDED_DESCRIPTOR) {
            Ok(set) => set,
            Err(e) => return Err(format!("decoding embedded WKT descriptor: {e}").into()),
        };
        let mut wkt_fdps: HashMap<String, FileDescriptorProto> = HashMap::new();
        for fdp in wkt_fds.file {
            wkt_fdps.insert(fdp.name().to_owned(), fdp);
        }

        Ok(LazyPool {
            _raw_mmap: raw_mmap,
            _idx_mmap: idx_mmap,
            index,
            raw_bytes,
            pool: DescriptorPool::new(),
            wkt_fdps,
            loaded: HashSet::new(),
            in_progress: HashSet::new(),
        })
    }
}
impl LazyPool {
fn ensure_loaded(&mut self, file: &str) -> Result<(), Box<dyn std::error::Error>> {
if self.loaded.contains(file) {
return Ok(());
}
if self.in_progress.contains(file) {
return Err(format!("cycle detected in FDS dependency graph: '{file}'").into());
}
self.in_progress.insert(file.to_owned());
let deps: Vec<String> = if let Some(v) = self.index.dep_graph.get(file) {
v.iter().map(|s| s.as_str().to_owned()).collect()
} else if let Some(wkt_fdp) = self.wkt_fdps.get(file) {
wkt_fdp
.dependency
.iter()
.map(|s| s.as_str().to_owned())
.collect()
} else {
vec![]
};
for dep in &deps {
self.ensure_loaded(dep)?;
}
let fdp = if let Some(span) = self.index.file_to_span.get(file) {
let (start, end) = (span.0.to_native() as usize, span.1.to_native() as usize);
FileDescriptorProto::decode(&self.raw_bytes[start..end])
.map_err(|e| format!("decoding FDP for '{file}': {e}"))?
} else if let Some(wkt_fdp) = self.wkt_fdps.get(file) {
wkt_fdp.clone()
} else {
return Err(format!("'{file}' not found in FDS index or embedded WKT fallback").into());
};
self.pool
.add_file_descriptor_proto(fdp)
.map_err(|e| format!("adding FDP '{file}' to pool: {e}"))?;
self.in_progress.remove(file);
self.loaded.insert(file.to_owned());
Ok(())
}
fn resolve_file(&self, fqdn: &str) -> Option<String> {
if let Some(f) = self.index.type_to_file.get(fqdn) {
return Some(f.as_str().to_owned());
}
self.wkt_fdps
.iter()
.find(|(_, fdp)| {
let pkg = fdp.package();
let prefixed = |name: &str| {
if pkg.is_empty() {
name.to_owned()
} else {
format!("{pkg}.{name}")
}
};
fdp.message_type.iter().any(|m| prefixed(m.name()) == fqdn)
|| fdp.enum_type.iter().any(|e| prefixed(e.name()) == fqdn)
})
.map(|(fname, _)| fname.clone())
}
/// Looks up the message descriptor for `fqdn`, loading its declaring file (and
/// that file's dependencies) into the pool first if necessary.
///
/// Leading dots on `fqdn` are ignored. Returns `Ok(None)` when no file is known
/// to declare the name, or when the pool has no message by that name after loading.
///
/// # Errors
/// Propagates any failure from loading the declaring file.
pub fn get_message(
    &mut self,
    fqdn: &str,
) -> Result<Option<MessageDescriptor>, Box<dyn std::error::Error>> {
    let name = fqdn.trim_start_matches('.');
    match self.resolve_file(name) {
        None => Ok(None),
        Some(owner) => {
            self.ensure_loaded(&owner)?;
            Ok(self.pool.get_message_by_name(name))
        }
    }
}
/// Looks up the enum descriptor for `fqdn`, loading its declaring file (and that
/// file's dependencies) into the pool first if necessary.
///
/// Leading dots on `fqdn` are ignored. Returns `Ok(None)` when no file is known
/// to declare the name, or when the pool has no enum by that name after loading.
///
/// # Errors
/// Propagates any failure from loading the declaring file.
pub fn get_enum(
    &mut self,
    fqdn: &str,
) -> Result<Option<EnumDescriptor>, Box<dyn std::error::Error>> {
    let name = fqdn.trim_start_matches('.');
    match self.resolve_file(name) {
        None => Ok(None),
        Some(owner) => {
            self.ensure_loaded(&owner)?;
            Ok(self.pool.get_enum_by_name(name))
        }
    }
}
pub fn load_all(&mut self) -> Result<(), Box<dyn std::error::Error>> {
let files: Vec<String> = self
.index
.file_to_span
.keys()
.map(|s| s.as_str().to_owned())
.collect();
for file in files {
self.ensure_loaded(&file)?;
}
Ok(())
}
}