#[cfg(test)]
mod tests;
use std::io::{Read, Write};
use std::net::TcpStream;
use std::sync::{Arc, RwLock};
use std::time::Duration;
use crate::application::Application;
use crate::core::New;
use crate::mime_type::MimeType;
use crate::range::Range;
use crate::request::Request;
use crate::response::{Response, STATUS_CODE_REASON_PHRASE};
use crate::server::ConnectionInfo;
/// A single routing rule derived from a Kubernetes `Ingress` object: requests
/// for `host` whose URI falls under `path` are forwarded to the service
/// `service_name` in `namespace` on `service_port`.
#[derive(Debug, Clone, PartialEq)]
pub struct IngressRule {
    /// Host the rule applies to; an empty string matches any host.
    pub host: String,
    /// Path prefix the rule applies to; `"/"` (or an empty string) matches every path.
    pub path: String,
    /// Name of the backend service.
    pub service_name: String,
    /// Port of the backend service.
    pub service_port: u16,
    /// Namespace of the backend service.
    pub namespace: String,
}

impl IngressRule {
    /// Returns the in-cluster DNS address of the backend as
    /// `"<service_name>.<namespace>.svc.cluster.local:<service_port>"`.
    pub fn upstream_addr(&self) -> String {
        format!(
            "{}.{}.svc.cluster.local:{}",
            self.service_name, self.namespace, self.service_port
        )
    }

    /// Reports whether this rule applies to a request for `host` and `uri`.
    ///
    /// * `host` is the raw `Host` header value. A trailing `:port` is ignored
    ///   (ingress hosts never carry one). An empty rule host matches any host;
    ///   otherwise the comparison is ASCII case-insensitive.
    /// * `uri` is the request target. The rule path is matched on whole path
    ///   segments, like the Kubernetes `Prefix` path type: `/api` matches
    ///   `/api`, `/api/`, `/api/v1` and `/api?x=1`, but not `/apis`. A trailing
    ///   slash in the rule path is ignored, and `/` matches everything.
    pub fn matches(&self, host: &str, uri: &str) -> bool {
        if !self.host.is_empty() {
            // Strip a numeric port, but never split inside a bare IPv6 literal:
            // only `name:digits` where `name` has no ':' or is a bracketed
            // `[v6]` literal qualifies.
            let request_host = match host.rsplit_once(':') {
                Some((name, port))
                    if port.bytes().all(|b| b.is_ascii_digit())
                        && (!name.contains(':') || name.ends_with(']')) =>
                {
                    name
                }
                _ => host,
            };
            if !self.host.eq_ignore_ascii_case(request_host) {
                return false;
            }
        }

        let prefix = self.path.trim_end_matches('/');
        if prefix.is_empty() {
            return true;
        }
        match uri.strip_prefix(prefix) {
            // The prefix must end on a segment boundary (or where the query or
            // fragment starts), so `/api` does not capture `/apis`.
            Some(rest) => rest.is_empty() || rest.starts_with(|c: char| c == '/' || c == '?' || c == '#'),
            None => false,
        }
    }
}
/// Returns the raw contents of the string value stored under the first JSON
/// key `field` in `json`.
///
/// Returns `None` if the key is absent, or if its first occurrence holds a
/// non-string value (number, object, `null`, ...) or an unterminated string.
///
/// This is a lightweight scanner, not a JSON parser. It does not track nesting,
/// so the first `"field"` key anywhere in `json` wins. Whitespace is allowed
/// around the colon. An escaped quote (`\"`) does not end the value, but
/// escape sequences are returned verbatim rather than decoded.
fn extract_str_field<'a>(json: &'a str, field: &str) -> Option<&'a str> {
    let key = format!("\"{}\"", field);
    for (pos, _) in json.match_indices(key.as_str()) {
        let after_key = json[pos + key.len()..].trim_start();
        let value = match after_key.strip_prefix(':') {
            Some(v) => v.trim_start(),
            // `"field"` appeared as a string value, not as a key: keep looking.
            None => continue,
        };
        let inner = value.strip_prefix('"')?;
        let mut escaped = false;
        // Byte scan is UTF-8 safe: '"' and '\\' never occur inside multi-byte
        // sequences, so the index of '"' is always a char boundary.
        for (i, b) in inner.bytes().enumerate() {
            match b {
                _ if escaped => escaped = false,
                b'\\' => escaped = true,
                b'"' => return Some(&inner[..i]),
                _ => {}
            }
        }
        return None;
    }
    None
}
/// Returns the unsigned integer stored under the first JSON key `field` in
/// `json`.
///
/// Returns `None` if the key is absent, or if its first occurrence is not a
/// plain decimal number that fits in a `u16` (quoted, negative and
/// out-of-range values all yield `None`). Whitespace is allowed around the
/// colon, and the number may run to the very end of `json`.
fn extract_u16_field(json: &str, field: &str) -> Option<u16> {
    let key = format!("\"{}\"", field);
    for (pos, _) in json.match_indices(key.as_str()) {
        let value = match json[pos + key.len()..].trim_start().strip_prefix(':') {
            Some(v) => v.trim_start(),
            // `"field"` appeared as a string value, not as a key: keep looking.
            None => continue,
        };
        // A number that ends the input has no trailing non-digit; use the
        // whole remainder in that case instead of giving up.
        let digits_len = value
            .find(|c: char| !c.is_ascii_digit())
            .unwrap_or(value.len());
        return value[..digits_len].parse().ok();
    }
    None
}
/// Extracts routing rules from the JSON body of a `networking.k8s.io/v1`
/// `IngressList` response.
///
/// This is a string-scanning extractor, not a JSON parser. It relies on the
/// field order the API server emits:
/// * an item's `metadata` precedes its `spec`;
/// * a rule's `host` precedes its `http.paths`;
/// * a path entry's `path` precedes its `backend`.
///
/// One [`IngressRule`] is produced per path entry whose backend is a
/// `service` with a non-empty `name`. `resource` backends and
/// `spec.defaultBackend` are ignored.
///
/// Missing values fall back to:
/// * namespace `"default"`;
/// * an empty host (any host);
/// * path `"/"`;
/// * port 80, also used when the service port is given by name.
pub fn parse_ingress_list(json: &str) -> Vec<IngressRule> {
    let mut rules = Vec::new();
    // Splitting on a key leaves that key's value at the *start* of the next
    // chunk, while sibling fields written before the key sit at the *end* of
    // the previous chunk. So an item's own `metadata.namespace` lives in the
    // chunk before its `"spec"`, and a rule's `host` in the chunk before its
    // `"paths"`. Reading them from the current chunk would pick up the *next*
    // item's namespace / the first rule's host. Hence the (previous, current)
    // pairs below.
    let spec_chunks: Vec<&str> = json.split("\"spec\"").collect();
    for (before_spec, spec) in spec_chunks.iter().zip(spec_chunks.iter().skip(1)) {
        let namespace = extract_str_field(before_spec, "namespace")
            .unwrap_or("default")
            .to_string();
        for rules_chunk in spec.split("\"rules\"").skip(1) {
            let paths_chunks: Vec<&str> = rules_chunk.split("\"paths\"").collect();
            for (before_paths, paths) in paths_chunks.iter().zip(paths_chunks.iter().skip(1)) {
                let host = extract_str_field(before_paths, "host")
                    .unwrap_or("")
                    .to_string();
                for entry in paths.split("\"path\"").skip(1) {
                    // `entry` starts right after the `"path"` key, e.g. `:"/api",...`.
                    let path = entry
                        .trim_start()
                        .strip_prefix(':')
                        .map(str::trim_start)
                        .and_then(|v| v.strip_prefix('"'))
                        .and_then(|v| v.find('"').map(|end| &v[..end]))
                        .unwrap_or("/")
                        .to_string();
                    // Look only inside `backend.service`, so a `resource`
                    // backend is skipped rather than mistaken for a service.
                    let service = match entry.split_once("\"service\"") {
                        Some((_, service)) => service,
                        None => continue,
                    };
                    let service_name = extract_str_field(service, "name").unwrap_or("");
                    if service_name.is_empty() {
                        continue;
                    }
                    let service_port = extract_u16_field(service, "number").unwrap_or(80);
                    rules.push(IngressRule {
                        host: host.clone(),
                        path,
                        service_name: service_name.to_string(),
                        service_port,
                        namespace: namespace.clone(),
                    });
                }
            }
        }
    }
    rules
}
/// Fetches `networking.k8s.io/v1` Ingress objects from a Kubernetes API
/// server over plain HTTP and keeps the parsed [`IngressRule`]s in a table
/// shared with the background poller started by [`start`](Self::start).
pub struct KubernetesIngressWatcher {
    /// `http://` base URL of the API server.
    api_server: String,
    /// Bearer token sent with each request; empty means no `Authorization` header.
    token: String,
    /// Namespace to list; empty or `"all"` lists ingresses cluster-wide.
    namespace: String,
    /// Seconds between background polls.
    poll_interval_secs: u64,
    /// Most recently parsed rules, shared with the background poller.
    rules: Arc<RwLock<Vec<IngressRule>>>,
}

impl KubernetesIngressWatcher {
    /// Creates a watcher for `api_server` that authenticates with `token`,
    /// targets the `"default"` namespace, and polls every 30 seconds.
    /// No request is made until [`start`](Self::start) or [`poll`](Self::poll).
    pub fn new(api_server: impl Into<String>, token: impl Into<String>) -> Self {
        let (api_server, token) = (api_server.into(), token.into());
        Self {
            api_server,
            token,
            namespace: String::from("default"),
            poll_interval_secs: 30,
            rules: Arc::default(),
        }
    }

    /// Always fails: in-cluster (HTTPS) access is not implemented. The error
    /// message explains how to use [`from_env`](Self::from_env) with
    /// `kubectl proxy` instead.
    pub fn from_service_account() -> Result<Self, String> {
        let message = "In-cluster TLS (https://kubernetes.default.svc) is not yet supported. \
             Use `kubectl proxy` and set RWS_K8S_API_SERVER=http://localhost:8001 \
             along with RWS_K8S_TOKEN and RWS_K8S_NAMESPACE, then call \
             KubernetesIngressWatcher::from_env().";
        Err(String::from(message))
    }

    /// Builds a watcher from environment variables:
    /// * `RWS_K8S_API_SERVER` — required;
    /// * `RWS_K8S_TOKEN` — optional, defaults to empty;
    /// * `RWS_K8S_NAMESPACE` — optional, defaults to `"default"`.
    ///
    /// # Errors
    /// Fails when `RWS_K8S_API_SERVER` is unset.
    pub fn from_env() -> Result<Self, String> {
        let api_server = std::env::var("RWS_K8S_API_SERVER")
            .map_err(|_| String::from("RWS_K8S_API_SERVER environment variable is not set"))?;
        let token = std::env::var("RWS_K8S_TOKEN").unwrap_or_default();
        let namespace =
            std::env::var("RWS_K8S_NAMESPACE").unwrap_or_else(|_| String::from("default"));
        Ok(Self::new(api_server, token).namespace(namespace))
    }

    /// Builder: sets the namespace to watch (empty or `"all"` for every namespace).
    pub fn namespace(self, ns: impl Into<String>) -> Self {
        Self {
            namespace: ns.into(),
            ..self
        }
    }

    /// Builder: sets the number of seconds between background polls.
    pub fn poll_interval_secs(self, secs: u64) -> Self {
        Self {
            poll_interval_secs: secs,
            ..self
        }
    }

    /// Polls once on the calling thread, then keeps polling on a detached
    /// background thread (see `WatcherHandle::poll_loop`).
    pub fn start(&self) {
        let handle = self.clone_inner();
        handle.poll_loop();
    }

    /// Copies the configuration into an owned handle that shares this
    /// watcher's rule table.
    fn clone_inner(&self) -> WatcherHandle {
        let rules = Arc::clone(&self.rules);
        WatcherHandle {
            api_server: self.api_server.clone(),
            token: self.token.clone(),
            namespace: self.namespace.clone(),
            poll_interval_secs: self.poll_interval_secs,
            rules,
        }
    }

    /// Returns a snapshot of the current rules.
    pub fn rules(&self) -> Vec<IngressRule> {
        let guard = self.rules.read().unwrap();
        guard.to_vec()
    }

    /// Fetches the ingress list once and, on success, replaces the shared rules.
    ///
    /// # Errors
    /// Propagates the request error; the previous rules are kept in that case.
    pub fn poll(&self) -> Result<(), String> {
        self.do_poll()
            .map(|fresh| *self.rules.write().unwrap() = fresh)
    }

    /// Requests the ingress list for the configured namespace and parses it.
    fn do_poll(&self) -> Result<Vec<IngressRule>, String> {
        let path = match self.namespace.as_str() {
            "" | "all" => String::from("/apis/networking.k8s.io/v1/ingresses"),
            ns => format!("/apis/networking.k8s.io/v1/namespaces/{}/ingresses", ns),
        };
        http_get_plain(&self.api_server, &path, &self.token)
            .map(|body| parse_ingress_list(&body))
    }
}
/// Owned copy of a watcher's configuration that shares its rule table, so the
/// poll loop can run on its own thread without borrowing the watcher.
struct WatcherHandle {
    api_server: String,
    token: String,
    namespace: String,
    poll_interval_secs: u64,
    rules: Arc<RwLock<Vec<IngressRule>>>,
}

impl WatcherHandle {
    /// Polls once synchronously, then re-polls every `poll_interval_secs`
    /// seconds on a detached background thread named `ingress-watcher`.
    ///
    /// The interval is clamped to at least one second. An interval of 0 would
    /// otherwise turn the thread into a busy loop that hammers the API server.
    fn poll_loop(self) {
        self.poll_once();
        let interval = Duration::from_secs(self.poll_interval_secs.max(1));
        let spawned = std::thread::Builder::new()
            .name("ingress-watcher".to_string())
            .spawn(move || loop {
                std::thread::sleep(interval);
                self.poll_once();
            });
        if let Err(e) = spawned {
            eprintln!("ingress watcher: failed to start poll thread: {}", e);
        }
    }

    /// Fetches and parses the ingress list once. On success it replaces the
    /// shared rules. On failure it logs to stderr and keeps the previous rules,
    /// so a transient API outage does not drop existing routes.
    fn poll_once(&self) {
        let path = if self.namespace.is_empty() || self.namespace == "all" {
            "/apis/networking.k8s.io/v1/ingresses".to_string()
        } else {
            format!(
                "/apis/networking.k8s.io/v1/namespaces/{}/ingresses",
                self.namespace
            )
        };
        match http_get_plain(&self.api_server, &path, &self.token) {
            Ok(body) => {
                let new_rules = parse_ingress_list(&body);
                *self.rules.write().unwrap() = new_rules;
            }
            Err(e) => {
                eprintln!("ingress watcher: poll failed: {}", e);
            }
        }
    }
}
/// Performs one plain-HTTP `GET` of `path` against `api_server` and returns
/// the response body.
///
/// * `api_server` must be an `http://host[:port][/...]` URL. Only the
///   authority part is used, and the port defaults to 80. HTTPS is not
///   supported.
/// * `token`, when non-empty, is sent as `Authorization: Bearer <token>`.
///
/// The request uses `Connection: close`, and the body is read until EOF.
/// Bodies sent with `Transfer-Encoding: chunked` are de-chunked. Go-based
/// servers such as the Kubernetes API server typically use that encoding for
/// larger responses. Without de-chunking, chunk-size lines would be spliced
/// into the JSON and could split keys or values.
///
/// Timeouts: connect 5 s (per resolved address), write 5 s, read 10 s.
///
/// # Errors
/// Returns a message if any of the following happens:
/// * the URL is not `http://`;
/// * connecting, writing or reading fails or times out;
/// * the response is malformed or its chunked body is invalid;
/// * the status is not 2xx;
/// * the body is not UTF-8.
fn http_get_plain(api_server: &str, path: &str, token: &str) -> Result<String, String> {
    let rest = api_server.strip_prefix("http://").ok_or_else(|| {
        format!(
            "ingress watcher: api_server must start with http://, got: {}",
            api_server
        )
    })?;
    let host_port = rest.split('/').next().unwrap_or(rest);
    let (host, port) = match host_port.rfind(':') {
        Some(colon) => match host_port[colon + 1..].parse::<u16>() {
            Ok(p) => (&host_port[..colon], p),
            Err(_) => (host_port, 80u16),
        },
        None => (host_port, 80u16),
    };
    let addr = format!("{}:{}", host, port);
    let mut stream = connect_with_timeout(&addr, Duration::from_secs(5))?;
    stream
        .set_read_timeout(Some(Duration::from_secs(10)))
        .map_err(|e| e.to_string())?;
    stream
        .set_write_timeout(Some(Duration::from_secs(5)))
        .map_err(|e| e.to_string())?;
    let auth_header = if token.is_empty() {
        String::new()
    } else {
        format!("Authorization: Bearer {}\r\n", token)
    };
    // `Host` must include the port when it is not the scheme default
    // (RFC 9110 §7.2), so send the authority exactly as configured.
    let request = format!(
        "GET {} HTTP/1.1\r\nHost: {}\r\n{}Accept: application/json\r\nConnection: close\r\n\r\n",
        path, host_port, auth_header
    );
    stream
        .write_all(request.as_bytes())
        .map_err(|e| e.to_string())?;

    let mut buf = Vec::with_capacity(8192);
    // `read_to_end` transparently retries reads interrupted by signals.
    stream
        .read_to_end(&mut buf)
        .map_err(|e| format!("ingress watcher: read failed: {}", e))?;

    let header_end = buf
        .windows(4)
        .position(|w| w == b"\r\n\r\n")
        .ok_or_else(|| "ingress watcher: incomplete HTTP response (no header end)".to_string())?;
    let header_str = std::str::from_utf8(&buf[..header_end]).unwrap_or("");
    let mut header_lines = header_str.lines();
    let status_line = header_lines.next().unwrap_or("");
    let parts: Vec<&str> = status_line.splitn(3, ' ').collect();
    if parts.len() < 2 {
        return Err(format!("ingress watcher: malformed status line: {}", status_line));
    }
    let status: u16 = parts[1].parse().unwrap_or(0);
    if !(200..300).contains(&status) {
        return Err(format!("ingress watcher: API returned status {}", status));
    }
    let chunked = header_lines.any(|line| {
        line.split_once(':').map_or(false, |(name, value)| {
            name.trim().eq_ignore_ascii_case("transfer-encoding")
                && value.to_ascii_lowercase().contains("chunked")
        })
    });

    let body_start = header_end + 4;
    let body = if chunked {
        decode_chunked_body(&buf[body_start..])?
    } else {
        buf.drain(..body_start);
        buf
    };
    String::from_utf8(body)
        .map_err(|e| format!("ingress watcher: non-UTF-8 response body: {}", e))
}

/// Connects to the first reachable address `addr` (`"host:port"`) resolves to,
/// allowing each candidate at most `timeout`. A plain `TcpStream::connect`
/// has no timeout of its own. Name resolution itself is not time-bounded.
fn connect_with_timeout(addr: &str, timeout: Duration) -> Result<TcpStream, String> {
    let candidates = std::net::ToSocketAddrs::to_socket_addrs(addr)
        .map_err(|e| format!("ingress watcher: connect to {} failed: {}", addr, e))?;
    let mut last_err = None;
    for candidate in candidates {
        match TcpStream::connect_timeout(&candidate, timeout) {
            Ok(stream) => return Ok(stream),
            Err(e) => last_err = Some(e),
        }
    }
    Err(match last_err {
        Some(e) => format!("ingress watcher: connect to {} failed: {}", addr, e),
        None => format!(
            "ingress watcher: connect to {} failed: no addresses resolved",
            addr
        ),
    })
}

/// Decodes an HTTP/1.1 `Transfer-Encoding: chunked` body (RFC 9112 §7.1).
/// Chunk extensions (after `;`) and trailer fields are ignored.
fn decode_chunked_body(mut data: &[u8]) -> Result<Vec<u8>, String> {
    let truncated = || "ingress watcher: truncated chunked response body".to_string();
    let mut out = Vec::with_capacity(data.len());
    loop {
        let line_len = data
            .windows(2)
            .position(|w| w == b"\r\n")
            .ok_or_else(truncated)?;
        let size_line = std::str::from_utf8(&data[..line_len]).unwrap_or("");
        let size_hex = size_line.split(';').next().unwrap_or("").trim();
        let size = usize::from_str_radix(size_hex, 16).map_err(|_| {
            format!("ingress watcher: invalid chunk size line: {:?}", size_line)
        })?;
        data = &data[line_len + 2..];
        if size == 0 {
            return Ok(out);
        }
        // `get` also guards against absurd sizes that would overflow arithmetic.
        let chunk = data.get(..size).ok_or_else(truncated)?;
        out.extend_from_slice(chunk);
        data = data[size..]
            .strip_prefix(&b"\r\n"[..])
            .ok_or_else(truncated)?;
    }
}
/// [`Application`] that forwards each request to the Kubernetes service
/// chosen by the rules of a [`KubernetesIngressWatcher`].
pub struct IngressRouter {
    /// Source of the current ingress rules.
    watcher: KubernetesIngressWatcher,
    /// Connect timeout handed to the upstream proxy call (default 5 s).
    connect_timeout: Duration,
    /// Read timeout handed to the upstream proxy call (default 30 s).
    read_timeout: Duration,
}

impl IngressRouter {
    /// Creates a router backed by `watcher`, with a 5 s connect timeout and a
    /// 30 s read timeout.
    pub fn new(watcher: KubernetesIngressWatcher) -> Self {
        let connect_timeout = Duration::from_secs(5);
        let read_timeout = Duration::from_secs(30);
        IngressRouter {
            watcher,
            connect_timeout,
            read_timeout,
        }
    }

    /// Builder: sets the upstream connect timeout in milliseconds.
    pub fn connect_timeout_ms(self, ms: u64) -> Self {
        Self {
            connect_timeout: Duration::from_millis(ms),
            ..self
        }
    }

    /// Builder: sets the upstream read timeout in milliseconds.
    pub fn read_timeout_ms(self, ms: u64) -> Self {
        Self {
            read_timeout: Duration::from_millis(ms),
            ..self
        }
    }
}
impl Application for IngressRouter {
    /// Selects the ingress rule for `request` and proxies the request to that
    /// rule's service as `<service>.<namespace>.svc.cluster.local:<port>`.
    ///
    /// Among all rules matching the `Host` header and the request URI:
    /// * a rule naming a host wins over a host-less catch-all;
    /// * then the longest path wins;
    /// * ties keep the order reported by the API server.
    ///
    /// Returns a `404` response when no rule matches, and a `502` response
    /// (logging the cause to stderr) when proxying fails.
    fn execute(&self, request: &Request, connection: &ConnectionInfo) -> Result<Response, String> {
        let host = request
            .get_header("host".to_string())
            .map(|h| h.value.as_str())
            .unwrap_or("");
        let rules = self.watcher.rules();
        // Rank instead of taking the first match in list order; otherwise a
        // `/` rule listed first would shadow every more specific path.
        // `min_by_key` returns the first of equally ranked rules.
        let matched = rules
            .iter()
            .filter(|r| r.matches(host, &request.request_uri))
            .min_by_key(|r| (r.host.is_empty(), std::cmp::Reverse(r.path.len())));
        match matched {
            Some(rule) => {
                let upstream_host = format!(
                    "{}.{}.svc.cluster.local",
                    rule.service_name, rule.namespace
                );
                crate::proxy::proxy_http1(
                    request,
                    &connection.client.ip,
                    &upstream_host,
                    rule.service_port,
                    self.connect_timeout,
                    self.read_timeout,
                )
                .or_else(|e| {
                    // Still answer 502, but leave a trace of why the upstream failed.
                    eprintln!(
                        "ingress router: proxy to {}:{} failed: {}",
                        upstream_host, rule.service_port, e
                    );
                    Ok(bad_gateway())
                })
            }
            None => Ok(not_found()),
        }
    }
}
/// Builds the plain-text `502 Bad Gateway` response the router returns when
/// proxying to the selected upstream service fails.
fn bad_gateway() -> Response {
    let payload = b"502 Bad Gateway".to_vec();
    let body = Range::get_content_range(payload, MimeType::TEXT_PLAIN.to_string());
    let mut response = Response::new();
    response.status_code = *STATUS_CODE_REASON_PHRASE.n502_bad_gateway.status_code;
    response.reason_phrase = STATUS_CODE_REASON_PHRASE
        .n502_bad_gateway
        .reason_phrase
        .to_string();
    response.content_range_list = vec![body];
    response
}
/// Builds the plain-text `404` response the router returns when no ingress
/// rule matches the request's host and URI.
fn not_found() -> Response {
    let payload = b"404 No matching ingress rule".to_vec();
    let body = Range::get_content_range(payload, MimeType::TEXT_PLAIN.to_string());
    let mut response = Response::new();
    response.status_code = *STATUS_CODE_REASON_PHRASE.n404_not_found.status_code;
    response.reason_phrase = STATUS_CODE_REASON_PHRASE
        .n404_not_found
        .reason_phrase
        .to_string();
    response.content_range_list = vec![body];
    response
}