use clap::{command, value_parser, Arg};
use portman::responder::responder;
use std::io::BufRead;
use std::io::BufReader;
use std::io::Write;
use std::net;
use std::net::SocketAddr;
use std::net::TcpListener;
use std::net::TcpStream;
use std::process;
use std::sync::mpsc;
use std::sync::{Arc, Mutex};
use std::thread;
/// Shared, mutex-guarded sending half of the `mpsc` channel that carries
/// `responder::RequestMessage` values. Cloned (via `Arc`) into every
/// connection thread.
type RequestChannel = Arc<Mutex<mpsc::Sender<responder::RequestMessage>>>;
/// Shared, mutex-guarded handle to one client's TCP connection.
type Socket = Arc<Mutex<TcpStream>>;
/// Program settings taken from the command line (see `parse_arguments`).
#[derive(Debug, Clone, Copy)]
struct Arguments {
// TCP port the server binds to (on 0.0.0.0) to accept client connections.
listen_port: u16,
// Handed to `responder::responder` as its first argument; presumably the
// lowest port number in the managed pool -- TODO confirm against responder.
port_base: u16,
// Handed to `responder::responder` as its second argument; presumably the
// size of the managed pool -- TODO confirm against responder.
num_ports: u16,
}
/// A decoded client request line (produced by `decode_request`).
#[derive(Debug)]
enum ClientRequest {
// `GIMME <service> <user>`: ask for a port on behalf of a service/user pair.
Gimme {
service_name: String,
user_name: String,
},
// `LIST`: ask for the current allocations.
List,
// `TERMINATE`: ask the server process to exit.
Terminate,
// Anything else, including an empty line or a malformed `GIMME`.
Invalid,
}
/// Parses the process command line into an [`Arguments`] value.
///
/// Recognised options (each optional, each parsed as `u16`):
///
/// * `-l` / `--listen-port` (default `30000`)
/// * `-p` / `--port-base` (default `31000`)
/// * `-n` / `--num-ports` (default `1000`)
///
/// Validation is delegated to clap: the `u16` value parser runs inside
/// `get_matches()`, which reports a bad value and terminates the process
/// itself, so no separate "not a 16 bit integer" handling is needed here.
///
/// # Panics
///
/// Only if clap fails to supply a value for an option that declares a
/// default, which would be a programming error in the option table below.
fn parse_arguments() -> Arguments {
    let matches = command!()
        .version("1.0")
        .author("Ron Fox")
        .about("Rust replacement for NSCLDAQ port manager - does not need container")
        .arg(
            Arg::new("listen-port")
                .short('l')
                .long("listen-port")
                .default_value("30000")
                .value_parser(value_parser!(u16)),
        )
        .arg(
            Arg::new("port-base")
                .short('p')
                .long("port-base")
                .default_value("31000")
                .value_parser(value_parser!(u16)),
        )
        .arg(
            Arg::new("num-ports")
                .short('n')
                .long("num-ports")
                .default_value("1000")
                .value_parser(value_parser!(u16)),
        )
        .get_matches();

    // Every option carries a default, so a value is always present after a
    // successful parse; the defaults live in exactly one place (above).
    let option_value = |id: &str| -> u16 {
        *matches
            .get_one::<u16>(id)
            .expect("every option declares a default value")
    };

    Arguments {
        listen_port: option_value("listen-port"),
        port_base: option_value("port-base"),
        num_ports: option_value("num-ports"),
    }
}
fn main() {
let args = parse_arguments();
println!("{:#?}", args);
let (request_send, request_receive) = mpsc::channel();
let safe_req = Arc::new(Mutex::new(request_send));
let _service_handle = thread::spawn(move || {
responder::responder(args.port_base, args.num_ports, request_receive)
});
let server = TcpListener::bind(SocketAddr::from(([0, 0, 0, 0], args.listen_port))).unwrap();
for request in server.incoming() {
if let Ok(socket) = request {
let safe_socket = Arc::new(Mutex::new(socket));
let myreq = Arc::clone(&safe_req);
thread::spawn(move || process_request(Arc::clone(&myreq), Arc::clone(&safe_socket)));
} else {
}
}
}
/// Reads one newline-terminated request line from the client.
///
/// Returns the line with trailing whitespace (including the line ending)
/// removed; bytes that are not valid UTF-8 are replaced lossily. Returns an
/// empty string on end-of-stream or on a read error, which the caller treats
/// as "connection finished".
///
/// The socket's mutex is held for the duration of the read.
fn read_request_line(socket: &Socket) -> String {
    let mut line: Vec<u8> = Vec::new();
    let stream = socket.lock().unwrap();

    // The reader is discarded when this function returns, so it must never
    // pull more from the socket than the line being returned: anything left
    // in its buffer would be lost, silently dropping requests a client sent
    // back-to-back. A one-byte buffer guarantees nothing is read past the
    // newline. Request lines are short, so byte-wise reads are acceptable.
    // Reading through `&TcpStream` also avoids duplicating the descriptor.
    let mut reader = BufReader::with_capacity(1, &*stream);

    match reader.read_until(b'\n', &mut line) {
        Ok(_) => String::from_utf8_lossy(&line).trim_end().to_string(),
        Err(_) => String::new(),
    }
}
/// Turns a raw request line into a [`ClientRequest`].
///
/// The line is split on ASCII whitespace and the first word selects the
/// request (case-sensitive):
///
/// * `GIMME <service> <user>`: exactly three words, otherwise `Invalid`.
/// * `LIST`: any further words are ignored.
/// * `TERMINATE`: any further words are ignored.
/// * anything else, including an empty line: `Invalid`.
fn decode_request(request_line: &str) -> ClientRequest {
    let words: Vec<&str> = request_line.split_ascii_whitespace().collect();
    match words.as_slice() {
        ["GIMME", service, user] => ClientRequest::Gimme {
            service_name: String::from(*service),
            user_name: String::from(*user),
        },
        ["LIST", ..] => ClientRequest::List,
        ["TERMINATE", ..] => ClientRequest::Terminate,
        // Empty line, unknown verb, or GIMME with the wrong word count.
        _ => ClientRequest::Invalid,
    }
}
/// Asks the responder to release each port in `ports`, one request per port.
///
/// The request-channel mutex is taken and dropped separately for every port.
///
/// # Panics
///
/// Panics if the channel mutex is poisoned or if `responder::release_port`
/// reports a failure.
fn release_ports(req_chan: &RequestChannel, ports: Vec<u16>) {
    ports.into_iter().for_each(|port| {
        let sender = req_chan.lock().unwrap();
        responder::release_port(port, &sender).unwrap();
    });
}
/// Serves a single client connection until it ends.
///
/// Request lines are read and handled one at a time:
///
/// * `GIMME`: tries to allocate a port. On success the port is remembered
///   and `OK <port>` is sent; on failure `FAIL - <reason>` is sent and the
///   connection is closed.
/// * `LIST`: the current allocations are sent.
/// * `TERMINATE`: the whole server process exits immediately.
///   NOTE(review): no locality check is applied to this request here.
/// * invalid request: a failure reply is sent and the connection is closed.
///
/// The loop also ends on an empty line / end-of-stream or when a reply
/// cannot be written. On the way out every port allocated over this
/// connection is released and the socket is shut down.
fn process_request(req_chan: RequestChannel, so: Socket) {
    let mut allocated_ports: Vec<u16> = Vec::new();
    println!("Connected from {:#?}", so.lock().unwrap().peer_addr());

    loop {
        let line = read_request_line(&so);
        if line.is_empty() {
            break;
        }
        println!("Request: {}", line);

        match decode_request(&line) {
            ClientRequest::List => list_allocations(&req_chan, &so),
            ClientRequest::Terminate => {
                println!("Client requesting shutdown");
                process::exit(0);
            }
            ClientRequest::Invalid => {
                invalid_request(&so);
                break;
            }
            ClientRequest::Gimme {
                service_name,
                user_name,
            } => {
                let outcome = create_allocation(
                    Arc::clone(&req_chan),
                    Arc::clone(&so),
                    &service_name,
                    &user_name,
                );
                // Record a successful allocation before replying so it is
                // released even if the reply cannot be delivered.
                let reply = match &outcome {
                    Ok(port) => {
                        allocated_ports.push(*port);
                        format!("OK {}\n", port)
                    }
                    Err(msg) => format!("FAIL - {}\n", msg),
                };
                let sent = so.lock().unwrap().write_all(reply.as_bytes());
                // A failed allocation or an undeliverable reply ends the session.
                if outcome.is_err() || sent.is_err() {
                    break;
                }
            }
        }
    }

    release_ports(&req_chan, allocated_ports);
    let _ = so.lock().unwrap().shutdown(net::Shutdown::Both);
}
/// Reports whether the peer of `so` is connecting from a loopback address.
///
/// Any address in `127.0.0.0/8` and the IPv6 loopback `::1` count as local.
/// Returns `false` when the peer address cannot be determined.
fn is_local(so: &Socket) -> bool {
    so.lock()
        .unwrap()
        .peer_addr()
        .map(|peer| peer.ip().is_loopback())
        .unwrap_or(false)
}
/// Tells the client that its request was not understood.
///
/// Delivery is best-effort: the caller closes the connection right after
/// this call, so a write or flush failure (for example the peer has already
/// gone) is ignored. Panicking here would unwind the connection thread
/// before it releases the ports allocated over this connection.
fn invalid_request(sock: &Socket) {
    let mut stream = sock.lock().unwrap();
    if stream.write_all(b"FAIL - invalid request\n").is_ok() {
        let _ = stream.flush();
    }
}
/// Sends the current allocations to the client.
///
/// The reply is an `OK <count>` header line followed by one line per
/// allocation, then a flush. Sending stops silently at the first write
/// error; the flush result is ignored as well.
///
/// # Panics
///
/// Panics if either mutex is poisoned or if `responder::get_allocations`
/// reports a failure.
fn list_allocations(req_chan: &RequestChannel, so: &Socket) {
    let allocations = responder::get_allocations(&req_chan.lock().unwrap()).unwrap();
    let mut stream = so.lock().unwrap();

    // `?` stops at the first failed write, skipping the flush as well.
    let send_all = || -> std::io::Result<()> {
        stream.write_all(format!("OK {}\n", allocations.len()).as_bytes())?;
        for entry in allocations {
            stream.write_all(format!("{}\n", entry).as_bytes())?;
        }
        stream.flush()
    };
    let _ = send_all();
}
/// Allocates a port for `service` / `user` on behalf of the client on `so`.
///
/// Only peers accepted by `is_local` may allocate; any other peer is refused
/// without contacting the responder.
///
/// # Errors
///
/// Returns a bare reason string: either the locality refusal below or
/// whatever `responder::request_port` reports. The reason carries no
/// `FAIL` prefix and no trailing newline, because the caller adds both when
/// it formats the reply line.
fn create_allocation(
    req_chan: RequestChannel,
    so: Socket,
    service: &str,
    user: &str,
) -> Result<u16, String> {
    if !is_local(&so) {
        return Err(String::from("can only allocate to local senders"));
    }
    responder::request_port(service, user, &req_chan.lock().unwrap())
}