yamlbase 0.7.2

A lightweight SQL server that serves YAML-defined tables over standard SQL protocols
Documentation
use bytes::{BufMut, BytesMut};
use tokio::io::AsyncWriteExt;
use tokio::net::TcpStream;

use crate::YamlBaseError;
use crate::database::Value;

/// PostgreSQL Function Call ("fast-path") protocol implementation.
///
/// Decodes a frontend `FunctionCall` ('F') message, evaluates one of a small set
/// of built-in functions identified by its function OID against mock values,
/// and replies with a `FunctionCallResponse` ('V') message.
pub struct PostgresFunctionProtocol;

impl PostgresFunctionProtocol {
    /// Handle a function call message (F message).
    ///
    /// `data` is expected to be the message body (everything after the type byte
    /// and length word): function OID (Int32), argument format codes (Int16 count
    /// followed by that many Int16), arguments (Int16 count, then for each an Int32
    /// length — -1 meaning NULL — followed by that many bytes), and finally the
    /// result format code (Int16).
    ///
    /// # Errors
    /// Returns `YamlBaseError::Protocol` if the message is truncated, carries an
    /// invalid argument length, contains non-UTF-8 argument data, names an
    /// unsupported function OID, or produces a result too large to encode.
    /// Socket write errors are propagated.
    pub async fn handle_function_call(stream: &mut TcpStream, data: &[u8]) -> crate::Result<()> {
        let (function_oid, arguments) = Self::parse_function_call(data)?;
        let result = Self::call_function(function_oid, &arguments)?;
        Self::send_function_result(stream, &result).await
    }

    /// Decode a `FunctionCall` body into the function OID and its arguments.
    ///
    /// Every non-NULL argument is decoded as UTF-8 text into `Value::Text`. The
    /// argument format codes and the result format code are checked for presence
    /// but are not otherwise applied.
    fn parse_function_call(data: &[u8]) -> crate::Result<(u32, Vec<Value>)> {
        let mut pos = 0;

        let function_oid = Self::read_u32(data, &mut pos, "Incomplete function call message")?;

        let arg_format_count =
            usize::from(Self::read_u16(data, &mut pos, "Incomplete argument format codes")?);
        for _ in 0..arg_format_count {
            // NOTE(review): binary-format (1) arguments are still decoded as text below.
            Self::read_u16(data, &mut pos, "Incomplete argument format")?;
        }

        let arg_count = usize::from(Self::read_u16(data, &mut pos, "Incomplete argument count")?);
        // Each argument needs at least its 4-byte length word, which bounds the
        // client-supplied count before preallocating.
        let mut arguments = Vec::with_capacity(arg_count.min(data.len() / 4));
        for _ in 0..arg_count {
            let raw_len = Self::read_i32(data, &mut pos, "Incomplete argument length")?;
            if raw_len == -1 {
                arguments.push(Value::Null);
                continue;
            }
            // Any other negative length is malformed; casting it with `as usize`
            // would overflow the bounds arithmetic on attacker-controlled input.
            let arg_len = usize::try_from(raw_len).map_err(|_| {
                YamlBaseError::Protocol(format!("Invalid argument length: {}", raw_len))
            })?;
            let arg_data = Self::read_bytes(data, &mut pos, arg_len, "Incomplete argument data")?;
            let arg_str = std::str::from_utf8(arg_data).map_err(|_| {
                YamlBaseError::Protocol("Invalid UTF-8 in function argument".to_string())
            })?;
            arguments.push(Value::Text(arg_str.to_string()));
        }

        // The requested result format is read but not honored: results are
        // always sent as text by `send_function_result`.
        let _result_format = Self::read_u16(data, &mut pos, "Incomplete result format")?;

        Ok((function_oid, arguments))
    }

    /// Return the `len` bytes starting at `*pos` and advance `*pos` past them.
    ///
    /// Fails with `YamlBaseError::Protocol(context)` when fewer than `len` bytes
    /// remain (overflow-safe).
    fn read_bytes<'a>(
        data: &'a [u8],
        pos: &mut usize,
        len: usize,
        context: &str,
    ) -> crate::Result<&'a [u8]> {
        let end = (*pos)
            .checked_add(len)
            .filter(|&end| end <= data.len())
            .ok_or_else(|| YamlBaseError::Protocol(context.to_string()))?;
        let bytes = &data[*pos..end];
        *pos = end;
        Ok(bytes)
    }

    /// Read a big-endian `u16` at `*pos`, advancing the cursor.
    fn read_u16(data: &[u8], pos: &mut usize, context: &str) -> crate::Result<u16> {
        let b = Self::read_bytes(data, pos, 2, context)?;
        Ok(u16::from_be_bytes([b[0], b[1]]))
    }

    /// Read a big-endian `u32` at `*pos`, advancing the cursor.
    fn read_u32(data: &[u8], pos: &mut usize, context: &str) -> crate::Result<u32> {
        let b = Self::read_bytes(data, pos, 4, context)?;
        Ok(u32::from_be_bytes([b[0], b[1], b[2], b[3]]))
    }

    /// Read a big-endian `i32` at `*pos`, advancing the cursor.
    fn read_i32(data: &[u8], pos: &mut usize, context: &str) -> crate::Result<i32> {
        let raw = Self::read_u32(data, pos, context)?;
        Ok(i32::from_be_bytes(raw.to_be_bytes()))
    }

    /// Evaluate the built-in function identified by `function_oid`.
    ///
    /// Only a fixed set of informational and catalog functions is supported; each
    /// returns a mock value suited to a read-only server.
    /// NOTE(review): the OIDs below are hard-coded; confirm them against `pg_proc`
    /// for the advertised server version.
    ///
    /// # Errors
    /// Returns `YamlBaseError::Protocol` for an unsupported function OID.
    fn call_function(function_oid: u32, arguments: &[Value]) -> crate::Result<Value> {
        let value = match function_oid {
            // version()
            89 => Value::Text("PostgreSQL 14.0 (YamlBase Mock Server)".to_string()),

            // current_database()
            861 => Value::Text("postgres".to_string()),

            // current_schema()
            843 => Value::Text("public".to_string()),

            // current_user (745), session_user (746), user (747)
            745 | 746 | 747 => Value::Text("postgres".to_string()),

            // pg_backend_pid()
            2026 => Value::Integer(12345),

            // inet_client_addr()
            2197 => Value::Text("127.0.0.1".to_string()),

            // inet_client_port()
            2198 => Value::Integer(0),

            // pg_conf_load_time() (2226), pg_postmaster_start_time() (2560).
            // Both report the current time at each call, not a fixed instant.
            2226 | 2560 => Value::Timestamp(chrono::Utc::now().naive_utc()),

            // has_table_privilege(): the database is read-only, so only SELECT is held.
            1922 => Self::check_privilege(arguments, "SELECT"),

            // has_schema_privilege(): schemas may be used but not created in.
            2305 => Self::check_privilege(arguments, "USAGE"),

            // format_type(type_oid, typemod)
            1081 => {
                // Arguments arrive as text (see `parse_function_call`), so the
                // OID is accepted either as an integer or as decimal text.
                let type_oid = match arguments {
                    [oid, _, ..] => Self::value_as_oid(oid),
                    _ => None,
                };
                Value::Text(type_oid.map_or_else(|| "unknown".to_string(), Self::oid_to_type_name))
            }

            // pg_get_expr(): echo the expression argument back unchanged (simplified).
            1716 => arguments.first().cloned().unwrap_or(Value::Null),

            // col_description() (1216), obj_description() (1215),
            // pg_get_indexdef() (1643), pg_get_constraintdef() (1387):
            // the mock has no descriptions, indexes, or constraints.
            1216 | 1215 | 1643 | 1387 => Value::Null,

            // Unknown function
            _ => {
                return Err(YamlBaseError::Protocol(format!(
                    "Unknown function OID: {}",
                    function_oid
                )))
            }
        };
        Ok(value)
    }

    /// Evaluate a `has_*_privilege` call for the read-only mock.
    ///
    /// In every overload of these functions the privilege string is the last
    /// argument. PostgreSQL accepts a comma-separated list and returns true if
    /// any listed privilege is held; here the only held privilege is `granted`
    /// (compared ASCII case-insensitively). Fewer than two arguments, or a
    /// non-text privilege argument, yields false.
    fn check_privilege(arguments: &[Value], granted: &str) -> Value {
        let held = match arguments {
            [_, .., Value::Text(privileges)] => privileges
                .split(',')
                .any(|privilege| privilege.trim().eq_ignore_ascii_case(granted)),
            _ => false,
        };
        Value::Boolean(held)
    }

    /// Interpret an argument as a type OID.
    ///
    /// Accepts an in-range integer value or decimal text; anything else yields `None`.
    fn value_as_oid(value: &Value) -> Option<u32> {
        match value {
            Value::Integer(oid) => u32::try_from(*oid).ok(),
            Value::Text(text) => text.trim().parse().ok(),
            _ => None,
        }
    }

    /// Write a `FunctionCallResponse` ('V') message carrying `result`.
    ///
    /// NULL is encoded as a value length of -1 with no value bytes. Any other value
    /// is sent as its `Display` text.
    /// NOTE(review): the client's requested result format is not consulted, and
    /// no ReadyForQuery is written here; presumably the caller sends it — confirm.
    ///
    /// # Errors
    /// Returns `YamlBaseError::Protocol` if the text does not fit in an Int32
    /// message length. Socket write errors are propagated.
    async fn send_function_result(stream: &mut TcpStream, result: &Value) -> crate::Result<()> {
        let mut buf = BytesMut::new();
        buf.put_u8(b'V'); // FunctionCallResponse

        match result {
            Value::Null => {
                buf.put_i32(8); // message length: length word (4) + value length (4)
                buf.put_i32(-1); // NULL marker; no value bytes follow
            }
            value => {
                let text = value.to_string();
                let bytes = text.as_bytes();
                // The message length (8 + value bytes) must itself fit in an Int32.
                let value_len = i32::try_from(bytes.len())
                    .ok()
                    .filter(|len| *len <= i32::MAX - 8)
                    .ok_or_else(|| {
                        YamlBaseError::Protocol("Function result too large".to_string())
                    })?;
                buf.put_i32(8 + value_len); // message length
                buf.put_i32(value_len); // value length
                buf.put_slice(bytes); // value data
            }
        }

        stream.write_all(&buf).await?;
        Ok(())
    }

    /// Map a built-in type OID to the name `format_type` would display.
    ///
    /// OIDs not in the table are rendered as `unknown(<oid>)`.
    fn oid_to_type_name(oid: u32) -> String {
        let name = match oid {
            16 => "boolean",
            17 => "bytea",
            18 => "\"char\"",
            19 => "name",
            20 => "bigint",
            21 => "smallint",
            23 => "integer",
            25 => "text",
            26 => "oid",
            114 => "json",
            142 => "xml",
            650 => "cidr",
            700 => "real",
            701 => "double precision",
            790 => "money",
            829 => "macaddr",
            869 => "inet",
            1000 => "boolean[]",
            1005 => "smallint[]",
            1007 => "integer[]",
            1009 => "text[]",
            1015 => "character varying[]",
            1016 => "bigint[]",
            1042 => "character",
            1043 => "character varying",
            1082 => "date",
            1083 => "time without time zone",
            1114 => "timestamp without time zone",
            1184 => "timestamp with time zone",
            1186 => "interval",
            1266 => "time with time zone",
            1700 => "numeric",
            2950 => "uuid",
            3802 => "jsonb",
            _ => return format!("unknown({})", oid),
        };
        name.to_string()
    }
}