1
  2
  3
  4
  5
  6
  7
  8
  9
 10
 11
 12
 13
 14
 15
 16
 17
 18
 19
 20
 21
 22
 23
 24
 25
 26
 27
 28
 29
 30
 31
 32
 33
 34
 35
 36
 37
 38
 39
 40
 41
 42
 43
 44
 45
 46
 47
 48
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
//! Utility functions for the UCS-2 character encoding.

#![no_std]

#[deny(missing_docs)]
#[cfg_attr(feature = "cargo-clippy", deny(clippy))]

/// Possible errors returned by the API.
#[derive(Debug, Copy, Clone)]
pub enum Error {
    /// Not enough space left in the output buffer.
    BufferOverflow,
    /// Input contained a character which cannot be represented in UCS-2.
    MultiByte,
}

type Result<T> = core::result::Result<T, Error>;

/// Encodes an input UTF-8 string into a UCS-2 string.
///
/// The returned `usize` represents the length of the returned buffer,
/// measured in 2-byte characters.
pub fn encode(input: &str, buffer: &mut [u16]) -> Result<usize> {
    let buffer_size = buffer.len();
    let mut i = 0;

    encode_with(input, |ch| {
        if i >= buffer_size {
            Err(Error::BufferOverflow)
        } else {
            buffer[i] = ch;
            i += 1;
            Ok(())
        }
    })?;

    Ok(i)
}

/// Encode UTF-8 string to UCS-2 with a custom callback function.
///
/// `output` is a function which receives every encoded character.
pub fn encode_with<F>(input: &str, mut output: F) -> Result<()>
where
    F: FnMut(u16) -> Result<()>,
{
    let bytes = input.as_bytes();
    let len = bytes.len();
    let mut i = 0;

    while i < len {
        let ch;

        if bytes[i] & 0b1000_0000 == 0b0000_0000 {
            ch = u16::from(bytes[i]);
            i += 1;
        } else if bytes[i] & 0b1110_0000 == 0b1100_0000 {
            // 2 byte codepoint
            if i + 1 >= len {
                // safe: len is the length of bytes,
                // and bytes is a direct view into the
                // buffer of input, which in order to be a valid
                // utf-8 string _must_ contain `i + 1`.
                unsafe { core::hint::unreachable_unchecked() }
            }

            let a = u16::from(bytes[i] & 0b0001_1111);
            let b = u16::from(bytes[i + 1] & 0b0011_1111);
            ch = a << 6 | b;
            i += 2;
        } else if bytes[i] & 0b1111_0000 == 0b1110_0000 {
            // 3 byte codepoint
            if i + 2 >= len || i + 1 >= len {
                // safe: impossible utf-8 string.
                unsafe { core::hint::unreachable_unchecked() }
            }

            let a = u16::from(bytes[i] & 0b0000_1111);
            let b = u16::from(bytes[i + 1] & 0b0011_1111);
            let c = u16::from(bytes[i + 2] & 0b0011_1111);
            ch = a << 12 | b << 6 | c;
            i += 3;
        } else if bytes[i] & 0b1111_0000 == 0b1111_0000 {
            return Err(Error::MultiByte); // UTF-16
        } else {
            // safe: impossible utf-8 string.
            unsafe { core::hint::unreachable_unchecked() }
        }
        output(ch)?;
    }
    Ok(())
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn encoding() {
        // Two-byte (ő, э) and three-byte (╋) UTF-8 sequences.
        let input = "őэ╋";
        let mut buffer = [0u16; 3];

        let result = encode(input, &mut buffer);
        assert_eq!(result.unwrap(), 3);

        assert_eq!(buffer[..], [0x0151, 0x044D, 0x254B]);
    }

    #[test]
    fn encoding_ascii() {
        let mut buffer = [0u16; 4];

        assert_eq!(encode("Az~", &mut buffer).unwrap(), 3);
        assert_eq!(buffer[..3], [0x0041, 0x007A, 0x007E]);
        // The slot past the encoded prefix is left untouched.
        assert_eq!(buffer[3], 0);
    }

    #[test]
    fn encoding_empty() {
        let mut buffer: [u16; 0] = [];
        assert_eq!(encode("", &mut buffer).unwrap(), 0);
    }

    #[test]
    fn encoding_bmp_boundaries() {
        // Largest code point of the 1-, 2- and 3-byte UTF-8 ranges.
        let mut buffer = [0u16; 3];
        assert_eq!(encode("\u{7F}\u{7FF}\u{FFFF}", &mut buffer).unwrap(), 3);
        assert_eq!(buffer, [0x007F, 0x07FF, 0xFFFF]);
    }

    #[test]
    fn buffer_too_small() {
        // Only a buffer overflow is possible here: every input char is ASCII.
        let mut buffer = [0u16; 2];
        assert!(encode("abc", &mut buffer).is_err());
        // Units that fit are written before the error is reported.
        assert_eq!(buffer, [0x0061, 0x0062]);
    }

    #[test]
    fn rejects_characters_outside_bmp() {
        // Only the multi-byte error is possible here: the buffer has room for
        // every char, but U+1F600 lies above U+FFFF.
        let mut buffer = [0u16; 4];
        assert!(encode("a\u{1F600}", &mut buffer).is_err());
        assert_eq!(buffer[0], 0x0061);
    }
}