squishy 0.5.0

A convenient high level library for reading SquashFS files
Documentation
//! fbthrift's Frozen2 format deserializer.

use serde::{de, forward_to_deserialize_any};

use super::{Schema, SchemaLayout};

type Result<T, E = Error> = std::result::Result<T, E>;
type Error = de::value::Error;

/// The offset type we use to index into metadata bytes.
pub(crate) type Offset = u32;

fn to_usize(offset: Offset) -> usize {
    const _: () = assert!(size_of::<Offset>() <= size_of::<usize>());
    offset as usize
}

pub(crate) fn deserialize<T: de::DeserializeOwned>(schema: &Schema, bytes: &[u8]) -> Result<T> {
    let root_layout = schema.layouts.get(schema.root_layout).expect("validated");
    let de = Deserializer {
        src: &Source { schema, bytes },
        layout: Some(root_layout),
        bit_offset: 0,
        storage_start: 0,
    };
    T::deserialize(de)
}

#[derive(Clone, Copy)]
struct Source<'a> {
    schema: &'a Schema,
    bytes: &'a [u8],
}

impl Source<'_> {
    fn load_bit(&self, base_bit: Offset) -> Result<bool> {
        let (byte_idx, bit_idx) = (to_usize(base_bit) / 8, base_bit % 8);
        let b = *self
            .bytes
            .get(byte_idx)
            .ok_or_else(|| de::Error::custom("bit location overflow"))?;
        Ok((b >> bit_idx) & 1 != 0)
    }

    fn load_bits(&self, base_bit: Offset, bits: u16) -> Result<u64> {
        debug_assert!(bits > 0);
        debug_assert!(bits <= 64);
        let (byte_idx, bit_start) = (to_usize(base_bit) / 8, base_bit as u16 % 8);
        let last_byte_idx = (base_bit + Offset::from(bits) - 1) / 8;
        if to_usize(last_byte_idx) >= self.bytes.len() {
            return Err(de::Error::custom("bits location overflow"));
        }

        let rest = &self.bytes[byte_idx..];
        let x = if rest.len() >= 8 {
            u64::from_le_bytes(rest[..8].try_into().unwrap())
        } else {
            let mut buf = [0u8; 8];
            buf[..rest.len()].copy_from_slice(rest);
            u64::from_le_bytes(buf)
        };

        let start_and_bits = bit_start + bits;
        Ok(if start_and_bits <= 64 {
            x << (64 - start_and_bits) >> (64 - bits)
        } else {
            let overshooting_bits = start_and_bits & 63;
            let hi = u64::from(rest[8]);
            x >> bit_start | hi << (64 - overshooting_bits) >> (64 - bits)
        })
    }
}

#[derive(Clone, Copy)]
struct Deserializer<'a, 'de> {
    src: &'a Source<'de>,
    layout: Option<&'de SchemaLayout>,
    bit_offset: Offset,
    storage_start: Offset,
}

impl<'de> Deserializer<'_, 'de> {
    fn field_deserializer(&self, i: i16) -> Self {
        let (layout, offset_bits) = if let Some(field) = self.layout.and_then(|l| l.fields.get(i)) {
            (
                self.src.schema.layouts.get(field.layout_id),
                field.offset_bits(),
            )
        } else {
            (None, 0)
        };
        Self {
            src: self.src,
            layout,
            bit_offset: self.bit_offset + Offset::from(offset_bits),
            storage_start: self.storage_start,
        }
    }

    fn deserialize_field<T: de::Deserialize<'de>>(&self, i: i16) -> Result<T> {
        de::Deserialize::deserialize(self.field_deserializer(i))
    }
}

impl<'de> de::Deserializer<'de> for Deserializer<'_, 'de> {
    type Error = Error;

    fn is_human_readable(&self) -> bool {
        false
    }

    fn deserialize_any<V>(self, _visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        unimplemented!()
    }

    fn deserialize_bool<V>(self, visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        let b = self.layout.is_some()
            && self
                .src
                .load_bit(self.storage_start * 8 + self.bit_offset)?;
        visitor.visit_bool(b)
    }

    fn deserialize_u32<V>(self, visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        self.deserialize_u64(visitor)
    }

    fn deserialize_u64<V>(self, visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        let Some(layout) = self.layout else {
            return visitor.visit_u64(0);
        };
        if !layout.fields.is_empty() {
            return Err(de::Error::invalid_type(
                de::Unexpected::Other("a schema layout with some fields"),
                &"an unsigned integer",
            ));
        }
        let bits = layout.bits;
        if !(0..=64).contains(&bits) {
            return Err(de::Error::custom("too many bits for an unsigned int"));
        }
        visitor.visit_u64(
            self.src
                .load_bits(self.storage_start * 8 + self.bit_offset, bits as u16)?,
        )
    }

    fn deserialize_byte_buf<V>(self, visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        self.deserialize_bytes(visitor)
    }

    fn deserialize_bytes<V>(self, visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        let distance = self.deserialize_field::<u32>(1)?;
        let len = self.deserialize_field::<u32>(2)?;

        let content = (|| {
            let start = self.storage_start.checked_add(distance)?;
            let end = start.checked_add(len)?;
            self.src
                .bytes
                .get(usize::try_from(start).ok()?..usize::try_from(end).ok()?)
        })()
        .ok_or_else(|| <Error as de::Error>::custom("string offset or length overflow"))?;
        visitor.visit_borrowed_bytes(content)
    }

    fn deserialize_seq<V>(self, visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        let distance = self.deserialize_field::<u32>(1)?;
        let len = self.deserialize_field::<u32>(2)?;
        let elem_layout = self.layout.and_then(|l| {
            let id = l.fields.get(3)?.layout_id;
            Some(self.src.schema.layouts.get(id).expect("validated"))
        });
        visitor.visit_seq(CollectionDeserializer {
            elem_de: Self {
                src: self.src,
                layout: elem_layout,
                bit_offset: 0,
                storage_start: self.storage_start + distance,
            },
            len,
        })
    }

    fn deserialize_map<V>(self, visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        let distance = self.deserialize_field::<u32>(1)?;
        let len = self.deserialize_field::<u32>(2)?;
        let elem_layout = self.layout.and_then(|l| {
            let id = l.fields.get(3)?.layout_id;
            Some(self.src.schema.layouts.get(id).expect("validated"))
        });
        visitor.visit_map(CollectionDeserializer {
            elem_de: Self {
                src: self.src,
                layout: elem_layout,
                bit_offset: 0,
                storage_start: self.storage_start + distance,
            },
            len,
        })
    }

    fn deserialize_option<V>(self, visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        if !self.deserialize_field::<bool>(1)? {
            return visitor.visit_none();
        }
        visitor.visit_some(self.field_deserializer(2))
    }

    fn deserialize_struct<V>(
        self,
        _name: &'static str,
        _fields: &'static [&'static str],
        visitor: V,
    ) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        visitor.visit_map(StructDeserializer {
            de: self,
            field_id: 0,
        })
    }

    fn deserialize_ignored_any<V>(self, visitor: V) -> Result<V::Value>
    where
        V: de::Visitor<'de>,
    {
        visitor.visit_unit()
    }

    forward_to_deserialize_any! {
        i8 i16 i32 i64 i128 u8 u16 u128 f32 f64 char str string
        unit unit_struct newtype_struct tuple
        tuple_struct enum identifier
    }
}

struct StructDeserializer<'i, 'de> {
    de: Deserializer<'i, 'de>,
    field_id: usize,
}

impl<'de> de::MapAccess<'de> for StructDeserializer<'_, 'de> {
    type Error = Error;

    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
    where
        K: de::DeserializeSeed<'de>,
    {
        let Some(layout) = self.de.layout else {
            return Ok(None);
        };

        let fields = &layout.fields.0;
        while self.field_id < fields.len() {
            if fields[self.field_id].is_some() {
                let serde_field_id = self.field_id as u64 - 1;
                return seed
                    .deserialize(de::value::U64Deserializer::new(serde_field_id))
                    .map(Some);
            }
            self.field_id += 1;
        }
        Ok(None)
    }

    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
    where
        V: de::DeserializeSeed<'de>,
    {
        self.field_id += 1;
        seed.deserialize(self.de.field_deserializer(self.field_id as i16 - 1))
    }
}

struct CollectionDeserializer<'a, 'de> {
    elem_de: Deserializer<'a, 'de>,
    len: u32,
}

impl<'de> de::SeqAccess<'de> for CollectionDeserializer<'_, 'de> {
    type Error = Error;

    fn size_hint(&self) -> Option<usize> {
        self.len.try_into().ok()
    }

    fn next_element_seed<T>(&mut self, seed: T) -> Result<Option<T::Value>>
    where
        T: de::DeserializeSeed<'de>,
    {
        if self.len == 0 {
            return Ok(None);
        }

        let ret = seed.deserialize(self.elem_de);
        self.len -= 1;
        if let Some(layout) = self.elem_de.layout {
            self.elem_de.bit_offset += layout.bits as Offset;
        }
        ret.map(Some)
    }
}

impl<'de> de::MapAccess<'de> for CollectionDeserializer<'_, 'de> {
    type Error = Error;

    fn size_hint(&self) -> Option<usize> {
        self.len.try_into().ok()
    }

    fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>>
    where
        K: de::DeserializeSeed<'de>,
    {
        if self.len == 0 {
            return Ok(None);
        }
        self.len -= 1;

        seed.deserialize(self.elem_de.field_deserializer(1))
            .map(Some)
    }

    fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value>
    where
        V: de::DeserializeSeed<'de>,
    {
        let ret = seed.deserialize(self.elem_de.field_deserializer(2));
        if let Some(layout) = self.elem_de.layout {
            self.elem_de.bit_offset += layout.bits as Offset;
        }
        ret
    }
}