use super::tcp::Tcp;
use mlua::{Error, Lua, MultiValue};
use std::net::Shutdown;
pub(super) fn handle(_lua: &Lua, tcp: &Tcp, args: MultiValue) -> Result<(), Error> {
let mode: Shutdown = {
if args.is_empty() {
Shutdown::Both
} else {
let arg0 = args[0].to_string()?;
match arg0.as_str() {
"send" => Shutdown::Write,
"receive" => Shutdown::Read,
_ => Shutdown::Both,
}
}
};
tcp.socket
.shutdown(mode)
.map_err(|err| Error::RuntimeError(err.to_string()))?;
Ok(())
}
#[cfg(test)]
mod tests {
    use mlua::Lua;
    use std::error::Error;
    use std::net::SocketAddr;
    use tokio::io::AsyncReadExt;
    use tokio::net::TcpListener;

    /// Values returned by the Lua `client:send` call: `(bytes_sent, err)`.
    type SendResult = (Option<u16>, Option<String>);

    /// Connects a Lua client to a one-shot local listener, calls
    /// `client:shutdown(<shutdown_args>)` and then `client:send('abc')`,
    /// returning whatever `send` returned.
    ///
    /// `shutdown_args` is spliced verbatim into the Lua source, so string
    /// modes must carry their own quotes (e.g. `"'send'"`); pass `""` for the
    /// no-argument form.
    async fn send_after_shutdown(shutdown_args: &str) -> Result<SendResult, Box<dyn Error>> {
        let addr: SocketAddr = "127.0.0.1:0".parse()?;
        let listener = TcpListener::bind(&addr).await?;
        let port = listener.local_addr()?.port();
        // Peer side: accept the connection and try to read the 3-byte payload;
        // the read result is deliberately ignored.
        tokio::spawn(async move {
            let (mut socket, _) = listener.accept().await.unwrap();
            let mut buf = vec![0; 3];
            let _ = socket.read_exact(&mut buf).await;
        });
        let lua = Lua::new();
        crate::preload(&lua)?;
        let script = format!(
            r#"
            local socket = require('socket')
            local client = socket.connect('127.0.0.1', {port})
            client:shutdown({shutdown_args})
            return client:send('abc')
            "#
        );
        Ok(lua.load(script).eval()?)
    }

    #[tokio::test]
    async fn shutdown() -> Result<(), Box<dyn Error>> {
        let (_bytes_sent, err) = send_after_shutdown("").await?;
        assert!(err.is_some(), "send must fail after a default shutdown");
        Ok(())
    }

    #[tokio::test]
    async fn shutdown_both() -> Result<(), Box<dyn Error>> {
        let (_bytes_sent, err) = send_after_shutdown("'both'").await?;
        assert!(err.is_some(), "send must fail after shutdown('both')");
        Ok(())
    }

    #[tokio::test]
    async fn shutdown_send() -> Result<(), Box<dyn Error>> {
        let (_bytes_sent, err) = send_after_shutdown("'send'").await?;
        assert!(err.is_some(), "send must fail after shutdown('send')");
        Ok(())
    }
}