/// Alias for [`AprV2Model`].
// NOTE(review): presumably lets callers use a version-neutral name for the
// current model type — confirm against call sites.
pub type AprModel = AprV2Model;
/// Alias for the unit type `()`; it carries no data.
// NOTE(review): looks like a placeholder kept for source compatibility — confirm.
pub type AprModelType = ();
/// Returns the subset of `token_to_id` whose keys are recognised as special
/// tokens by `is_special_token`.
///
/// Each selected token string is cloned into the result and keeps the id it
/// has in the input vocabulary; the input map is not modified.
// The signature is tied to `HashMap`'s default hasher, which is exactly what
// `clippy::implicit_hasher` flags; the lint is silenced rather than changing
// the public signature.
#[allow(clippy::implicit_hasher)]
pub fn extract_special_tokens_from_vocab(
    token_to_id: &HashMap<String, u32>,
) -> HashMap<String, u32> {
    token_to_id
        .iter()
        .filter(|(text, _)| is_special_token(text))
        .map(|(text, &id)| (text.clone(), id))
        .collect()
}
/// Decides whether `token` has the shape of a special (control) token.
///
/// Two shapes are accepted:
///
/// * `<|…|>` — any token wrapped in `<|` and `|>` that is longer than four
///   bytes, i.e. has at least one byte between the delimiters.
/// * `<name>` or `</name>` — at most 20 bytes in total, where `name` is
///   non-empty and made up only of ASCII lowercase letters and `_`.
///
/// Everything else is reported as not special.
#[must_use]
fn is_special_token(token: &str) -> bool {
    let pipe_wrapped = token.len() > 4 && token.starts_with("<|") && token.ends_with("|>");
    if pipe_wrapped {
        return true;
    }

    token
        .strip_prefix('<')
        .and_then(|rest| rest.strip_suffix('>'))
        // The length limit applies to the whole token, brackets included.
        .filter(|_| token.len() <= 20)
        // A single leading '/' (closing-tag form) is ignored.
        .map(|body| body.strip_prefix('/').unwrap_or(body))
        .map_or(false, |name| {
            !name.is_empty() && name.chars().all(|c| matches!(c, 'a'..='z' | '_'))
        })
}
use memmap2::Mmap;
/// An APR model file accessed through a memory mapping.
///
/// Holds the parsed header, metadata and tensor index together with the
/// mapping of the file they were parsed from, so tensor bytes can be handed
/// out as slices of the mapping.
#[derive(Debug)]
pub struct MappedAprModel {
/// Header parsed from the start of the mapped file.
pub header: AprHeader,
/// Metadata decoded from the JSON region described by the header; holds the
/// default value when that region is empty or fails to parse.
pub metadata: AprMetadata,
/// Tensor index entries that were parsed successfully from the file.
pub tensors: Vec<TensorEntry>,
/// Mapping of the whole file; private so callers go through the accessors.
mmap: Mmap,
}
impl MappedAprModel {
    /// Opens the `.apr` file at `path`, memory-maps it and parses its header,
    /// metadata and tensor index.
    ///
    /// # Errors
    ///
    /// Returns `RealizarError::IoError` if the file cannot be opened or mapped,
    /// and whatever error parsing the mapped bytes produces (bad header, wrong
    /// magic, or a metadata region that extends past the end of the file).
    pub fn from_path<P: AsRef<Path>>(path: P) -> Result<Self> {
        let file = File::open(path.as_ref()).map_err(|e| RealizarError::IoError {
            message: format!("Failed to open .apr file: {e}"),
        })?;
        // SAFETY: `Mmap::map` is unsafe because the mapping is only sound while
        // the underlying file is not truncated or modified by anyone else.
        // NOTE(review): nothing in this function enforces that; callers must
        // treat model files as immutable while they are loaded.
        let mmap = unsafe { Mmap::map(&file) }.map_err(|e| RealizarError::IoError {
            message: format!("Failed to mmap .apr file: {e}"),
        })?;
        Self::from_mmap(mmap)
    }

    /// Parses header, metadata and tensor index out of an existing mapping.
    ///
    /// Metadata JSON that fails to deserialize is replaced by the default
    /// value, and tensor-index parsing stops quietly at the first entry that
    /// cannot be decoded; both are best-effort by design of the original code.
    ///
    /// # Errors
    ///
    /// Returns `RealizarError::FormatError` when the magic bytes do not match
    /// or the metadata region does not fit inside the file, and propagates any
    /// error from `AprHeader::from_bytes`.
    fn from_mmap(mmap: Mmap) -> Result<Self> {
        let data: &[u8] = &mmap;
        let header = AprHeader::from_bytes(data)?;
        if header.magic != MAGIC {
            return Err(RealizarError::FormatError {
                reason: "Invalid APR magic bytes".to_string(),
            });
        }

        // Header fields come straight from the file, so the end offset is
        // computed with checked arithmetic: an overflowing sum must be
        // reported as a truncated file instead of wrapping and panicking on
        // the slice below.
        let metadata_start = header.metadata_offset as usize;
        let metadata_end = metadata_start
            .checked_add(header.metadata_size as usize)
            .filter(|&end| end <= data.len())
            .ok_or_else(|| RealizarError::FormatError {
                reason: "APR file truncated: metadata extends past EOF".to_string(),
            })?;
        let metadata: AprMetadata = if header.metadata_size > 0 {
            serde_json::from_slice(&data[metadata_start..metadata_end]).unwrap_or_default()
        } else {
            AprMetadata::default()
        };

        // The tensor index lies between `tensor_index_offset` and
        // `data_offset`. An inverted or out-of-file range yields an empty
        // index, as before.
        let declared_count = header.tensor_count as usize;
        let index_start = header.tensor_index_offset as usize;
        let index_end = header.data_offset as usize;
        let mut remaining: &[u8] = data.get(index_start..index_end).unwrap_or(&[]);

        // Never reserve more slots than the index region has bytes: a bogus
        // `tensor_count` must not be able to force a huge allocation.
        let mut tensors = Vec::with_capacity(declared_count.min(remaining.len()));
        while !remaining.is_empty() && tensors.len() < declared_count {
            match TensorEntry::from_binary(remaining) {
                // An entry that consumes no bytes would be re-read forever
                // (up to `tensor_count` times); treat it like a parse failure.
                Ok((entry, consumed)) if consumed > 0 => {
                    tensors.push(entry);
                    remaining = remaining.get(consumed..).unwrap_or(&[]);
                }
                _ => break,
            }
        }

        Ok(Self {
            header,
            metadata,
            tensors,
            mmap,
        })
    }

    /// Returns the entire mapped file as a byte slice.
    #[must_use]
    pub fn data(&self) -> &[u8] {
        &self.mmap[..]
    }

    /// Returns the length of the mapping in bytes.
    #[must_use]
    pub fn file_size(&self) -> usize {
        self.mmap.len()
    }

    /// Returns the number of tensor entries that were actually parsed, which
    /// may be fewer than the count declared in the header.
    #[must_use]
    pub fn tensor_count(&self) -> usize {
        self.tensors.len()
    }

    /// Returns the header's `data_offset`, the base that tensor offsets are
    /// added to in [`Self::get_tensor_data`].
    #[must_use]
    pub fn data_offset(&self) -> u64 {
        self.header.data_offset
    }

    /// Looks up a tensor entry by exact name with a linear scan; returns the
    /// first match or `None`.
    #[must_use]
    pub fn find_tensor(&self, name: &str) -> Option<&TensorEntry> {
        self.tensors.iter().find(|t| t.name == name)
    }

    /// Returns the raw bytes of the tensor called `name` as a slice of the
    /// mapping (no copy).
    ///
    /// # Errors
    ///
    /// Returns `RealizarError::FormatError` if no tensor has that name, or if
    /// the tensor's byte range does not lie inside the mapped file. That
    /// includes ranges whose offset arithmetic would overflow.
    pub fn get_tensor_data(&self, name: &str) -> Result<&[u8]> {
        let tensor = self
            .find_tensor(name)
            .ok_or_else(|| RealizarError::FormatError {
                reason: format!("Tensor not found: {name}"),
            })?;

        // Offsets and sizes come from the file; checked arithmetic turns an
        // overflowing range into the same "past EOF" error instead of a
        // wrapped value that would panic when slicing.
        let byte_range = (self.header.data_offset as usize)
            .checked_add(tensor.offset as usize)
            .and_then(|start| {
                start
                    .checked_add(tensor.size as usize)
                    .map(|end| start..end)
            });

        byte_range
            .and_then(|range| self.mmap.get(range))
            .ok_or_else(|| RealizarError::FormatError {
                reason: format!("Tensor {name} extends past EOF"),
            })
    }

    /// Converts a dtype name into a numeric quantization type id using
    /// `GgmlQuantType::from_str_lossy`, returning `0` when that lookup yields
    /// no type.
    // NOTE(review): id 0 presumably corresponds to F32 in the GGML numbering —
    // confirm against `GgmlQuantType`.
    #[must_use]
    pub fn dtype_to_qtype(dtype: &str) -> u32 {
        crate::gguf::GgmlQuantType::from_str_lossy(dtype)
            .map_or(0, crate::gguf::GgmlQuantType::as_id)
    }
}
// Unit-test modules. Each one is compiled only under `cfg(test)` and is loaded
// from the file named in its `#[path]` attribute rather than from a file
// matching the module name.
// NOTE(review): the `apr_tests_part_NN` names skip `part_04` and do not mirror
// their file names; harmless, but confirm no test file was dropped.
#[cfg(test)]
#[path = "tests.rs"]
mod apr_tests;
#[cfg(test)]
#[path = "tests_apr_flags.rs"]
mod apr_tests_part_02;
#[cfg(test)]
#[path = "tests_decode_tokens.rs"]
mod apr_tests_part_03;
#[cfg(test)]
#[path = "tests_pygmy_apr.rs"]
mod apr_tests_pygmy_apr;
#[cfg(test)]
#[path = "tests_apr_flags_02.rs"]
mod apr_tests_part_05;
#[cfg(test)]
#[path = "tests_apr_v2_bytes.rs"]
mod apr_tests_v2_bytes;