objectid 0.2.0

A Rust implementation of a BSON ObjectId.
Documentation
//! A Rust implementation of a BSON `ObjectId`.
//!
#![cfg_attr(all(feature = "unstable", test), feature(test))]
#![deny(missing_docs)]

#[macro_use]
extern crate lazy_static;
#[macro_use]
extern crate quick_error;
extern crate byteorder;
extern crate crypto;
extern crate libc;
extern crate rand;
extern crate hostname;
extern crate rustc_serialize;
#[cfg(feature = "serde")]
extern crate serde;
#[cfg(all(feature = "serde", test))]
extern crate serde_json;


use std::sync::atomic::{AtomicUsize, Ordering};
use std::io;
use std::fmt;
use byteorder::{ByteOrder, BigEndian, LittleEndian};
use crypto::digest::Digest;
use crypto::md5::Md5;
use rand::{Rng, OsRng};
use rustc_serialize::hex::{FromHex, ToHex, FromHexError};
use rustc_serialize::{Decodable, Decoder, Encodable, Encoder};


// Layout of the 12-byte id: | timestamp (4) | machine id (3) | process id (2) | counter (3) |
const TIMESTAMP_SIZE: usize = 4;
const MACHINE_ID_SIZE: usize = 3;
const PROCESS_ID_SIZE: usize = 2;
const COUNTER_SIZE: usize = 3;

// Byte offset of each field; every field starts right after the previous one.
const TIMESTAMP_OFFSET: usize = 0;
const MACHINE_ID_OFFSET: usize = TIMESTAMP_OFFSET + TIMESTAMP_SIZE;
const PROCESS_ID_OFFSET: usize = MACHINE_ID_OFFSET + MACHINE_ID_SIZE;
const COUNTER_OFFSET: usize = PROCESS_ID_OFFSET + PROCESS_ID_SIZE;

// Largest value that fits in the counter field (24 bits, i.e. 0xFFFFFF).
const MAX_U24: usize = (1 << (8 * COUNTER_SIZE)) - 1;

lazy_static! {
    // Process-wide counter, seeded by `gen_counter` on first access. A seeding failure is
    // cached here and returned (cloned) by `ObjectId::new`.
    static ref COUNTER: Result<AtomicUsize, Error> = gen_counter();
    // 3-byte machine identifier computed once from the hostname by `gen_machine_id`.
    // A failure is cached here and returned (cloned) by `ObjectId::new`.
    static ref MACHINE_BYTES: Result<[u8; 3], Error> = gen_machine_id();
}

quick_error!{
    /// Errors that can occur during OID construction and generation.
    ///
    /// `ObjectId::new` clones these to return cached initialisation failures.
    #[derive(Debug, Clone)]
    pub enum Error {
        /// Provided string must be a 12-byte hexadecimal string.
        ///
        /// Returned by `ObjectId::with_string` when the decoded hex is not exactly 12 bytes.
        Argument {
            display("Provided string must be a 12-byte hexadecimal string.")
        }
        /// Io error.
        ///
        /// Converted from `io::Error` (presumably by the `?` on `OsRng::new` in
        /// `gen_counter`). The variant has no payload, so the original details are dropped.
        Io {
            from(io::Error)
        }
        /// Hex encode or decode error.
        ///
        /// Converted from `FromHexError` raised while decoding an id string. The variant has
        /// no payload, so the original details are dropped.
        Hex {
            from(FromHexError)
        }
        /// Can't get hostname.
        Hostname
    }
}

/// A wrapper around raw 12-byte `ObjectId` representations.
///
/// The bytes are laid out as follows:
/// - a 4-byte big-endian timestamp;
/// - a 3-byte machine id;
/// - a 2-byte process id;
/// - a 3-byte counter.
///
/// The derived ordering compares the raw bytes lexicographically, so ids order by timestamp
/// first.
#[derive(Clone, PartialEq, PartialOrd, Eq, Ord, Hash)]
pub struct ObjectId([u8; 12]);

impl ObjectId {
    /// Generates a new ObjectId, represented in bytes.
    /// See the [docs](http://docs.mongodb.org/manual/reference/object-id/)
    /// for more information.
    pub fn new() -> Result<ObjectId, Error> {
        if MACHINE_BYTES.is_err() {
            return Err(MACHINE_BYTES.as_ref().err().cloned().unwrap());
        }

        if COUNTER.is_err() {
            return Err(COUNTER.as_ref().err().cloned().unwrap());
        }

        let timestamp = gen_timestamp();
        let machine_id = MACHINE_BYTES.as_ref().unwrap();
        let process_id = gen_process_id();
        let counter = gen_count();

        let mut buf: [u8; 12] = [0; 12];

        for (i, bit) in timestamp.iter().enumerate().take(TIMESTAMP_SIZE) {
            buf[TIMESTAMP_OFFSET + i] = *bit;
        }

        for (i, bit) in machine_id.iter().enumerate().take(MACHINE_ID_SIZE) {
            buf[MACHINE_ID_OFFSET + i] = *bit
        }

        for (i, bit) in process_id.iter().enumerate().take(MACHINE_ID_SIZE) {
            buf[PROCESS_ID_OFFSET + i] = *bit
        }

        for (i, bit) in counter.iter().enumerate().take(MACHINE_ID_SIZE) {
            buf[COUNTER_OFFSET + i] = *bit
        }

        Ok(Self::with_bytes(buf))
    }

    /// Constructs a new ObjectId wrapper around the raw byte representation.
    pub fn with_bytes(bytes: [u8; 12]) -> ObjectId {
        ObjectId(bytes)
    }

    /// Creates an ObjectId using a 12-byte (24-char) hexadecimal string.
    pub fn with_string(oid: &str) -> Result<ObjectId, Error> {
        let bytes = oid.from_hex()?;
        if bytes.len() != 12 {
            Err(Error::Argument)
        } else {
            let mut byte_array: [u8; 12] = [0; 12];
            for i in 0..12 {
                byte_array[i] = bytes[i];
            }
            Ok(Self::with_bytes(byte_array))
        }
    }

    /// Creates a dummy ObjectId with a specific generation time.
    /// This method should only be used to do range queries on a field
    /// containing ObjectId instances.
    pub fn with_timestamp(time: u32) -> ObjectId {
        let mut buf: [u8; 12] = [0; 12];
        BigEndian::write_u32(&mut buf, time);
        Self::with_bytes(buf)
    }

    /// Returns the raw byte representation of an ObjectId.
    pub fn bytes(&self) -> [u8; 12] {
        self.0
    }

    /// Retrieves the timestamp (seconds since epoch) from an ObjectId.
    pub fn timestamp(&self) -> u32 {
        BigEndian::read_u32(&self.0)
    }

    /// Retrieves the machine id associated with an ObjectId.
    pub fn machine_id(&self) -> u32 {
        let mut buf: [u8; 4] = [0; 4];

        for (i, bit) in buf.iter_mut().enumerate().take(MACHINE_ID_SIZE) {
            *bit = self.0[MACHINE_ID_OFFSET + i]
        }

        LittleEndian::read_u32(&buf)
    }

    /// Retrieves the process id associated with an ObjectId.
    pub fn process_id(&self) -> u16 {
        LittleEndian::read_u16(&self.0[PROCESS_ID_OFFSET..])
    }

    /// Retrieves the increment counter from an ObjectId.
    pub fn counter(&self) -> u32 {
        let mut buf: [u8; 4] = [0; 4];
        for i in 0..COUNTER_SIZE {
            buf[i + 1] = self.0[COUNTER_OFFSET + i];
        }
        BigEndian::read_u32(&buf)
    }
}

impl fmt::Display for ObjectId {
    /// Writes the id as its 24-character lowercase hexadecimal string.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let hex = self.0.to_hex();
        f.write_str(&hex)
    }
}

impl fmt::Debug for ObjectId {
    /// Writes the id as `ObjectId(<hex>)`.
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        let hex = self.0.to_hex();
        write!(f, "ObjectId({})", hex)
    }
}

impl ToHex for ObjectId {
    /// Hex-encodes the 12 raw bytes of the id.
    fn to_hex(&self) -> String {
        let raw: &[u8] = &self.0;
        raw.to_hex()
    }
}

impl Decodable for ObjectId {
    fn decode<D: Decoder>(d: &mut D) -> Result<Self, D::Error> {
        Ok(ObjectId::with_string(&d.read_str()?).unwrap())
    }
}

impl Encodable for ObjectId {
    /// Encodes the id as its hexadecimal string form.
    fn encode<S: Encoder>(&self, s: &mut S) -> Result<(), S::Error> {
        let hex = self.0.to_hex();
        s.emit_str(&hex)
    }
}


#[cfg(feature = "serde")]
impl serde::Serialize for ObjectId {
    /// Serializes the id as its hexadecimal string form.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let hex = self.to_hex();
        serializer.serialize_str(&hex)
    }
}

#[cfg(feature = "serde")]
/// serde visitor that builds an `ObjectId` from a hexadecimal string via `with_string`.
struct ObjectIdVisitor;

#[cfg(feature = "serde")]
impl serde::de::Visitor for ObjectIdVisitor {
    type Value = ObjectId;

    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(formatter,
               "Provided string must be a 12-byte hexadecimal string")
    }

    /// Parses `value` with `ObjectId::with_string`.
    ///
    /// Any parse failure becomes a serde error that includes the offending input.
    fn visit_str<E>(self, value: &str) -> Result<ObjectId, E>
        where E: serde::de::Error
    {
        ObjectId::with_string(value).map_err(|_| {
            E::custom(format!("Provided string must be a 12-byte hexadecimal string: {}",
                              value))
        })
    }
}

#[cfg(feature = "serde")]
impl serde::Deserialize for ObjectId {
    /// Deserializes an id from its hexadecimal string form using `ObjectIdVisitor`.
    fn deserialize<D>(deserializer: D) -> Result<ObjectId, D::Error>
    where
        D: serde::Deserializer,
    {
        let visitor = ObjectIdVisitor;
        deserializer.deserialize(visitor)
    }
}


/// Creates the global counter, seeded with a random value in `0..=MAX_U24` from the OS RNG.
///
/// Fails if the OS RNG cannot be opened.
#[inline]
fn gen_counter() -> Result<AtomicUsize, Error> {
    let mut rng = OsRng::new()?;
    // `gen_range` excludes the upper bound, hence `+ 1`.
    let seed = rng.gen_range(0, MAX_U24 + 1);
    Ok(AtomicUsize::new(seed))
}

#[inline]
fn gen_process_id() -> [u8; 2] {
    let pid = unsafe { libc::getpid() as u16 };
    let mut buf: [u8; 2] = [0; 2];
    LittleEndian::write_u16(&mut buf, pid);
    buf
}

#[inline]
fn gen_machine_id() -> Result<[u8; 3], Error> {
    let hostname = hostname::get_hostname().unwrap();
    let mut md5 = Md5::new();
    md5.input_str(hostname.as_str());
    let hash = md5.result_str();

    let mut bytes = hash.bytes();
    let mut vec: [u8; 3] = [0; 3];

    for bit in vec.iter_mut().take(MACHINE_ID_SIZE) {
        match bytes.next() {
            Some(b) => *bit = b,
            None => break,
        }
    }

    Ok(vec)
}

// Declaration of the C library `time` function, used by `get_time` (which passes a null
// pointer, so only the return value is used).
extern "C" {
    fn time(time: *mut libc::time_t) -> libc::time_t;
}

#[inline]
fn get_time() -> i64 {
    unsafe { time(0 as *mut libc::time_t) }
}

#[inline]
fn gen_timestamp() -> [u8; 4] {
    let timestamp = get_time() as u32;
    let mut buf: [u8; 4] = [0; 4];
    BigEndian::write_u32(&mut buf, timestamp);
    buf
}

#[inline]
fn gen_count() -> [u8; 3] {
    let counter = COUNTER.as_ref().unwrap().fetch_add(1, Ordering::SeqCst) % MAX_U24;
    let mut buf: [u8; 8] = [0; 8];
    BigEndian::write_u64(&mut buf, counter as u64);
    [buf[5], buf[6], buf[7]]
}



#[test]
fn test_count_generated_is_big_endian() {
    let start = 1122866;
    COUNTER.as_ref().unwrap().store(start, Ordering::SeqCst);

    // Left-pad the 3 counter bytes to 4 so they can be read back as a big-endian u32.
    let mut padded: [u8; 4] = [0; 4];
    padded[1..].copy_from_slice(&gen_count());

    assert_eq!(start as u32, BigEndian::read_u32(&padded));
}

#[test]
fn test_display() {
    let hex = "53e37d08776f724e42000000";
    let id = ObjectId::with_string(hex).unwrap();
    assert_eq!(id.to_string(), hex)
}

#[test]
fn test_debug() {
    let id = ObjectId::with_string("53e37d08776f724e42000000").unwrap();
    let rendered = format!("{:?}", id);
    assert_eq!(rendered, "ObjectId(53e37d08776f724e42000000)")
}

#[cfg(feature = "serde")]
#[test]
fn test_serde_encode() {
    let hex = "53e37d08776f724e42000000";
    let id = ObjectId::with_string(hex).unwrap();
    let expected = serde_json::to_value(hex).unwrap();
    let actual = serde_json::to_value(id).unwrap();
    assert_eq!(expected, actual);
}

#[cfg(feature = "serde")]
#[test]
fn test_serde_decode() {
    let expected = ObjectId::with_string("53e37d08776f724e42000000").unwrap();
    let json = r#""53e37d08776f724e42000000""#;
    let decoded: ObjectId = serde_json::from_str(json).unwrap();
    assert_eq!(expected, decoded);
}

#[test]
fn test_get_time() {
    // 1481757541 is a fixed instant in December 2016; the clock must report a later time.
    let now = get_time();
    assert!(now > 1481757541);
}

#[cfg(all(feature = "unstable", test))]
mod benches {
    extern crate test;

    use self::test::Bencher;
    use super::ObjectId;

    /// Measures the cost of generating a single `ObjectId`.
    #[bench]
    fn bench_create_object_id(bencher: &mut Bencher) {
        bencher.iter(|| ObjectId::new().unwrap());
    }
}