mlua-socket 0.2.7

A Rust-native implementation of LuaSocket for mlua.
-- https://github.com/lunarmodules/luasocket/blob/v3.1.0/src/socket.lua
-----------------------------------------------------------------------------
-- LuaSocket helper module
-- Author: Diego Nehab
-----------------------------------------------------------------------------

-----------------------------------------------------------------------------
-- Declare module and import dependencies
-----------------------------------------------------------------------------
local base = _G
local string = require("string")
local math = require("math")
local socket = require("socket.core")
local except = require("socket.except")

local _M = socket

-----------------------------------------------------------------------------
-- Exported auxiliar functions
-----------------------------------------------------------------------------
function _M.connect4(address, port, laddress, lport)
    return socket.connect(address, port, laddress, lport, "inet")
end

function _M.connect6(address, port, laddress, lport)
    return socket.connect(address, port, laddress, lport, "inet6")
end

--function _M.bind(host, port, backlog)
--    if host == "*" then
--        host = "0.0.0.0"
--    end
--    local addrinfo, err = socket.dns.getaddrinfo(host);
--    if not addrinfo then
--        return nil, err
--    end
--    local sock, res
--    err = "no info on address"
--    for i, alt in base.ipairs(addrinfo) do
--        if alt.family == "inet" then
--            sock, err = socket.tcp4()
--        else
--            sock, err = socket.tcp6()
--        end
--        if not sock then
--            return nil, err
--        end
--        sock:setoption("reuseaddr", true)
--        res, err = sock:bind(alt.addr, port)
--        if not res then
--            sock:close()
--        else
--            res, err = sock:listen(backlog)
--            if not res then
--                sock:close()
--            else
--                return sock
--            end
--        end
--    end
--    return nil, err
--end

_M.newtry = except.newtry
_M.try = _M.newtry()
_M.protect = except.protect

function _M.choose(table)
    return function(name, opt1, opt2)
        if base.type(name) ~= "string" then
            name, opt1, opt2 = "default", name, opt1
        end
        local f = table[name or "nil"]
        if not f then
            base.error("unknown key (" .. base.tostring(name) .. ")", 3)
        else
            return f(opt1, opt2)
        end
    end
end

-----------------------------------------------------------------------------
-- Socket sources and sinks, conforming to LTN12
-----------------------------------------------------------------------------
-- create namespaces inside LuaSocket namespace
local sourcet, sinkt = {}, {}
_M.sourcet = sourcet
_M.sinkt = sinkt

_M.BLOCKSIZE = 2048

sinkt["close-when-done"] = function(sock)
    return base.setmetatable({
        getfd = function()
            return sock:getfd()
        end,
        dirty = function()
            return sock:dirty()
        end
    }, {
        __call = function(self, chunk, err)
            if not chunk then
                sock:close()
                return 1
            else
                return sock:send(chunk)
            end
        end
    })
end

sinkt["keep-open"] = function(sock)
    return base.setmetatable({
        getfd = function()
            return sock:getfd()
        end,
        dirty = function()
            return sock:dirty()
        end
    }, {
        __call = function(self, chunk, err)
            if chunk then
                return sock:send(chunk)
            else
                return 1
            end
        end
    })
end

sinkt["default"] = sinkt["keep-open"]

_M.sink = _M.choose(sinkt)

sourcet["by-length"] = function(sock, length)
    return base.setmetatable({
        getfd = function()
            return sock:getfd()
        end,
        dirty = function()
            return sock:dirty()
        end
    }, {
        __call = function()
            if length <= 0 then
                return nil
            end
            local size = math.min(socket.BLOCKSIZE, length)
            local chunk, err = sock:receive(size)
            if err then
                return nil, err
            end
            length = length - string.len(chunk)
            return chunk
        end
    })
end

sourcet["until-closed"] = function(sock)
    local done
    return base.setmetatable({
        getfd = function()
            return sock:getfd()
        end,
        dirty = function()
            return sock:dirty()
        end
    }, {
        __call = function()
            if done then
                return nil
            end
            local chunk, err, partial = sock:receive(socket.BLOCKSIZE)
            if not err then
                return chunk
            elseif err == "closed" then
                sock:close()
                done = 1
                return partial
            else
                return nil, err
            end
        end
    })
end

sourcet["default"] = sourcet["until-closed"]

_M.source = _M.choose(sourcet)

return _M