grorm 0.1.0

Goroutine-native ORM for Rust with multi-database support
Documentation
use gorust::net::AsyncTcpStream;
use std::net::SocketAddr;

#[derive(Debug, Clone)]
pub struct MyColumnInfo {
    pub name: String,
    pub data_type: u8,
    pub flags: u16,
    pub decimals: u8,
}

pub enum MyResult {
    Rows(Vec<Vec<String>>, Vec<MyColumnInfo>),
    Ok(u64, Option<i64>),
    Err(String),
}

pub struct MyConnection {
    stream: AsyncTcpStream,
    sequence_id: u8,
}

impl MyConnection {
    pub fn connect(addr: SocketAddr, username: &str, password: &str, database: &str) -> Result<Self, crate::error::Error> {
        let stream = AsyncTcpStream::connect(addr)?;
        let mut conn = MyConnection { stream, sequence_id: 0 };
        conn.read_handshake()?;
        conn.send_handshake_response(username, password, database)?;
        conn.read_auth_result()?;
        Ok(conn)
    }

    fn read_handshake(&mut self) -> Result<(), crate::error::Error> {
        let pkt_len = self.read_u24()?;
        let seq = self.read_u8()?;
        self.sequence_id = seq.wrapping_add(1);

        let _protocol_version = self.read_u8()?;
        let mut version = String::new();
        loop {
            let b = self.read_u8()?;
            if b == 0 { break; }
            version.push(b as char);
        }
        let _thread_id = self.read_u32()?;
        let mut auth_plugin_data = vec![0u8; 8];
        self.stream_read(&mut auth_plugin_data)?;
        self.read_u16()?;
        let _charset = self.read_u8()?;
        let _status = self.read_u16()?;
        let _auth_plugin_len = self.read_u8()?;
        self.skip(10)?;

        let remaining = pkt_len as usize - 36;
        if remaining > 0 {
            self.skip(remaining)?;
        }
        Ok(())
    }

    fn send_handshake_response(&mut self, username: &str, password: &str, database: &str) -> Result<(), crate::error::Error> {
        let mut payload = Vec::new();
        payload.extend_from_slice(&0x0285a2ffu32.to_le_bytes());
        payload.extend_from_slice(&0x21u32.to_le_bytes());
        payload.extend_from_slice(&0x00u32.to_le_bytes());
        payload.extend_from_slice(&0x21u32.to_le_bytes());
        payload.extend_from_slice(&0x00u32.to_le_bytes());
        payload.extend_from_slice(&0x00u32.to_le_bytes());
        payload.extend_from_slice(&0x00u32.to_le_bytes());
        payload.extend_from_slice(&0x00u32.to_le_bytes());
        payload.extend_from_slice(&0x00u32.to_le_bytes());
        payload.extend_from_slice(&0x00u32.to_le_bytes());
        payload.extend_from_slice(&0x00u32.to_le_bytes());
        payload.extend_from_slice(&0x00u32.to_le_bytes());
        payload.extend_from_slice(&0x00u32.to_le_bytes());
        payload.extend_from_slice(&0x00u32.to_le_bytes());
        payload.extend_from_slice(&0x00u32.to_le_bytes());
        payload.extend_from_slice(&0x00u32.to_le_bytes());
        payload.extend_from_slice(&0x00u32.to_le_bytes());
        payload.extend_from_slice(&0x00u32.to_le_bytes());
        payload.extend_from_slice(&0x00u32.to_le_bytes());
        payload.extend_from_slice(&0x00u32.to_le_bytes());
        payload.extend_from_slice(&0x00u32.to_le_bytes());
        payload.extend_from_slice(&0x00u32.to_le_bytes());
        payload.extend_from_slice(&0x00u32.to_le_bytes());

        payload.extend_from_slice(username.as_bytes());
        payload.push(0);

        let auth_response = if password.is_empty() {
            vec![0u8]
        } else {
            let mut resp = vec![0x14u8];
            let mut hash = [0u8; 20];
            for (i, b) in password.bytes().enumerate() {
                if i < 20 { hash[i] = b; }
            }
            resp.extend_from_slice(&hash);
            resp
        };
        payload.push(auth_response.len() as u8);
        payload.extend_from_slice(&auth_response);

        payload.extend_from_slice(database.as_bytes());
        payload.push(0);

        payload.extend_from_slice(b"mysql_native_password\0");

        self.send_packet(1, &payload)?;
        Ok(())
    }

    fn read_auth_result(&mut self) -> Result<(), crate::error::Error> {
        let pkt_len = self.read_u24()?;
        let _seq = self.read_u8()?;
        self.sequence_id = 2;

        let header = self.read_u8()?;
        match header {
            0x00 => {
                self.skip((pkt_len - 1) as usize)?;
                Ok(())
            }
            0xFF => {
                let _code = self.read_u16()?;
                let mut msg = String::new();
                for _ in 0..(pkt_len - 3) {
                    let b = self.read_u8()?;
                    msg.push(b as char);
                }
                Err(format!("MySQL error: {}", msg).into())
            }
            _ => {
                self.skip((pkt_len - 1) as usize)?;
                self.read_auth_result()
            }
        }
    }

    pub fn execute_query(&mut self, sql: &str) -> Result<MyResult, crate::error::Error> {
        self.sequence_id = 0;
        self.send_command(0x03, sql)?;

        let pkt_len = self.read_u24()?;
        let _seq = self.read_u8()?;
        self.sequence_id = 1;

        let first_byte = self.read_u8()?;
        match first_byte {
            0x00 => {
                let affected_rows = self.read_lenenc()?;
                let last_insert_id = self.read_lenenc()?;
                let _status = self.read_u16()?;
                let _warnings = self.read_u16()?;
                Ok(MyResult::Ok(affected_rows, Some(last_insert_id as i64)))
            }
            0xFF => {
                let _code = self.read_u16()?;
                let mut msg = String::new();
                for _ in 0..(pkt_len - 3) {
                    let b = self.read_u8()?;
                    msg.push(b as char);
                }
                Err(format!("MySQL error: {}", msg).into())
            }
            _ => {
                let num_cols = self.read_lenenc_from_byte(first_byte, &mut vec![0u8; 0])?;
                let mut columns = Vec::with_capacity(num_cols as usize);
                for _ in 0..num_cols {
                    let col = self.read_column_definition()?;
                    columns.push(col);
                }
                self.read_eof()?;

                let mut rows = Vec::new();
                loop {
                    let pkt_len = self.read_u24()?;
                    let _seq = self.read_u8()?;
                    if pkt_len < 9 {
                        let first = self.read_u8()?;
                        if first == 0xFE {
                            self.skip((pkt_len - 1) as usize)?;
                            break;
                        }
                        let mut buf = vec![first];
                        for _ in 0..(pkt_len - 1) {
                            buf.push(self.read_u8()?);
                        }
                        let mut row = Vec::with_capacity(columns.len());
                        let mut pos = 0;
                        for _ in 0..columns.len() {
                            if pos < buf.len() {
                                let len = buf[pos] as usize;
                                pos += 1;
                                if len < 251 {
                                    let val = String::from_utf8_lossy(&buf[pos..pos + len]).to_string();
                                    row.push(val);
                                    pos += len;
                                } else {
                                    row.push("NULL".to_string());
                                }
                            } else {
                                row.push("NULL".to_string());
                            }
                        }
                        rows.push(row);
                    } else {
                        let mut buf = vec![0u8; pkt_len as usize];
                        self.stream_read(&mut buf)?;
                        let mut row = Vec::with_capacity(columns.len());
                        let mut pos = 0;
                        for _ in 0..columns.len() {
                            if pos < buf.len() {
                                let len = buf[pos] as usize;
                                pos += 1;
                                if len < 251 {
                                    let val = String::from_utf8_lossy(&buf[pos..pos + len]).to_string();
                                    row.push(val);
                                    pos += len;
                                } else {
                                    row.push("NULL".to_string());
                                }
                            } else {
                                row.push("NULL".to_string());
                            }
                        }
                        rows.push(row);
                    }
                }
                Ok(MyResult::Rows(rows, columns))
            }
        }
    }

    fn read_column_definition(&mut self) -> Result<MyColumnInfo, crate::error::Error> {
        let pkt_len = self.read_u24()?;
        let _seq = self.read_u8()?;

        let mut buf = vec![0u8; pkt_len as usize];
        self.stream_read(&mut buf)?;

        let mut pos = 0;
        while pos < buf.len() && buf[pos] != b'c' { pos += 1; }
        pos += 1;
        let catalog_len = buf[pos] as usize;
        pos += 1 + catalog_len;
        let schema_len = buf[pos] as usize;
        pos += 1 + schema_len;
        let table_len = buf[pos] as usize;
        pos += 1 + table_len;
        let org_table_len = buf[pos] as usize;
        pos += 1 + org_table_len;
        let name_len = buf[pos] as usize;
        pos += 1;
        let name = String::from_utf8_lossy(&buf[pos..pos + name_len]).to_string();
        pos += name_len;
        let org_name_len = buf[pos] as usize;
        pos += 1 + org_name_len;
        pos += 1;
        let _charset = u16::from_le_bytes([buf[pos], buf[pos + 1]]);
        pos += 2;
        let _col_len = u32::from_le_bytes([buf[pos], buf[pos + 1], buf[pos + 2], buf[pos + 3]]);
        pos += 4;
        let data_type = buf[pos];
        pos += 1;
        let flags = u16::from_le_bytes([buf[pos], buf[pos + 1]]);
        pos += 2;
        let decimals = buf[pos];

        Ok(MyColumnInfo { name, data_type, flags, decimals })
    }

    fn read_eof(&mut self) -> Result<(), crate::error::Error> {
        let pkt_len = self.read_u24()?;
        let _seq = self.read_u8()?;
        self.skip(pkt_len as usize)?;
        Ok(())
    }

    fn send_command(&mut self, cmd: u8, arg: &str) -> Result<(), crate::error::Error> {
        let mut payload = vec![cmd];
        payload.extend_from_slice(arg.as_bytes());
        self.send_packet(0, &payload)?;
        Ok(())
    }

    fn send_packet(&mut self, seq: u8, payload: &[u8]) -> Result<(), crate::error::Error> {
        let len = payload.len() as u32;
        let mut header = Vec::new();
        header.extend_from_slice(&len.to_le_bytes()[..3]);
        header.push(seq);
        self.stream.write_all(&header)?;
        self.stream.write_all(payload)?;
        Ok(())
    }

    fn read_u8(&mut self) -> Result<u8, crate::error::Error> {
        let mut buf = [0u8; 1];
        self.stream.read(&mut buf)?;
        Ok(buf[0])
    }

    fn read_u16(&mut self) -> Result<u16, crate::error::Error> {
        let mut buf = [0u8; 2];
        self.stream.read(&mut buf)?;
        Ok(u16::from_le_bytes(buf))
    }

    fn read_u24(&mut self) -> Result<u32, crate::error::Error> {
        let mut buf = [0u8; 3];
        self.stream.read(&mut buf)?;
        Ok(u32::from_le_bytes([buf[0], buf[1], buf[2], 0]))
    }

    fn read_u32(&mut self) -> Result<u32, crate::error::Error> {
        let mut buf = [0u8; 4];
        self.stream.read(&mut buf)?;
        Ok(u32::from_le_bytes(buf))
    }

    fn read_lenenc(&mut self) -> Result<u64, crate::error::Error> {
        let first = self.read_u8()?;
        self.read_lenenc_from_byte(first, &mut vec![0u8; 0])
    }

    fn read_lenenc_from_byte(&mut self, first: u8, _buf: &mut Vec<u8>) -> Result<u64, crate::error::Error> {
        match first {
            0xFB => Ok(0),
            0xFC => {
                let mut b = [0u8; 2];
                self.stream.read(&mut b)?;
                Ok(u16::from_le_bytes(b) as u64)
            }
            0xFD => {
                let mut b = [0u8; 3];
                self.stream.read(&mut b)?;
                Ok(u32::from_le_bytes([b[0], b[1], b[2], 0]) as u64)
            }
            0xFE => {
                let mut b = [0u8; 8];
                self.stream.read(&mut b)?;
                Ok(u64::from_le_bytes(b))
            }
            v => Ok(v as u64),
        }
    }

    fn stream_read(&mut self, buf: &mut [u8]) -> Result<(), crate::error::Error> {
        self.stream.read(buf)?;
        Ok(())
    }

    fn skip(&mut self, count: usize) -> Result<(), crate::error::Error> {
        let mut buf = vec![0u8; count];
        self.stream.read(&mut buf)?;
        Ok(())
    }
}