use futures_util::sink::SinkExt;
use serde::Serialize;
use spop::Version;
use spop::frame::Message;
use spop::frames::{Ack, AgentDisconnect, AgentHello, FrameCapabilities, HaproxyHello};
use spop::{FramePayload, FrameType, SpopCodec, TypedData, VarScope};
use std::collections::BTreeMap;
use std::io;
use std::path::Path;
use std::sync::Arc;
use tokio::sync::watch;
use tokio::task::JoinSet;
use tokio_listener::{Connection, Listener};
use tokio_stream::StreamExt;
use tokio_util::codec::Framed;
use crate::{
morgue::{Decay, StateOfDecay, shutdown_signal},
tenx_programmer::TenXProgrammer,
};
use iocaine_powder::{
acab::State,
http::{HeaderMap, HeaderName, HeaderValue},
sex_dungeon::{DungeonMaster, Language, Request},
};
/// HAProxy SPOA frontend.
///
/// Accepts connections from HAProxy, answers its hello/notify/disconnect
/// frames, and hands every `check-request` message to the configured request
/// handler's `decide()` function.
pub struct CheckpointCharlie {
/// Shared state (metrics and request handler); cloned into every
/// per-connection agent loop spawned by `serve`.
state: StateOfDecay,
}
impl CheckpointCharlie {
/// Builds the frontend together with the request handler it delegates to.
///
/// The handler is assembled from `initial_seed`, `language`, and the optional
/// `compiler`, `path` and `config`, and registered against `metrics` and
/// `state`.
///
/// # Errors
///
/// Fails when the request handler cannot be built, or when the resulting
/// handler reports that it cannot decide (the latter is also logged together
/// with the script path, or `(default)` when no path was given).
pub fn new(
    language: Language,
    compiler: Option<&impl AsRef<Path>>,
    path: Option<&impl AsRef<Path>>,
    initial_seed: &str,
    metrics: &TenXProgrammer,
    state: &State,
    config: Option<impl Serialize>,
) -> anyhow::Result<Self> {
    let dungeon_master = DungeonMaster::new(initial_seed)
        .language(language)
        .compiler(compiler.as_ref())
        .path(path.as_ref())
        .config(config)
        .build(&metrics.metrics, state)
        .map_err(|e| anyhow::anyhow!("{e:#?}"))?;
    let request_handler = Arc::new(dungeon_master);

    // Happy path first: a handler that can decide is all we need.
    if request_handler.can_decide() {
        let decay = Decay {
            metrics: metrics.clone(),
            request_handler,
        };
        return Ok(Self {
            state: decay.into(),
        });
    }

    // Without decide() there is nothing this frontend could answer with.
    let path: String = match path.as_ref() {
        Some(p) => p.as_ref().display().to_string(),
        None => "(default)".into(),
    };
    tracing::error!({ path }, "A decide() function is required");
    anyhow::bail!("Requires decide()");
}
/// Builds the agent's reply to an HAProxy hello frame.
///
/// The reply announces version `2.0.0`, echoes the frame size HAProxy
/// proposed, and advertises the pipelining capability. The returned flag is
/// the hello's healthcheck marker, defaulting to `false` when absent.
///
/// # Errors
///
/// Fails if the version string cannot be parsed.
fn haproxy_hello(hello: &HaproxyHello) -> anyhow::Result<(Box<AgentHello>, bool)> {
    let reply = AgentHello {
        version: Version::parse("2.0.0")?,
        max_frame_size: hello.max_frame_size,
        capabilities: vec![FrameCapabilities::Pipelining],
    };
    let is_healthcheck = hello.healthcheck.unwrap_or(false);
    Ok((reply.into(), is_healthcheck))
}
fn check_request(
msg: &Message,
state_snapshot: &Decay,
vars: &mut Vec<(VarScope, &str, TypedData)>,
) -> anyhow::Result<()> {
let Some(TypedData::String(method)) = msg.args.get("req_method") else {
anyhow::bail!("Message request is missing method");
};
let Some(TypedData::String(path)) = msg.args.get("req_path") else {
anyhow::bail!("Message request is missing path");
};
let mut headers = HeaderMap::new();
if let Some(headers_raw) = msg.args.get("req_hdrs") {
let TypedData::String(headers_raw) = headers_raw else {
anyhow::bail!("Message contains malformed headers");
};
for pair in headers_raw.split("\r\n") {
if pair.is_empty() {
continue;
}
let Some((key, value)) = pair.split_once(':') else {
anyhow::bail!("Message contains malformed header");
};
let hdr_name = HeaderName::from_bytes(key.trim().as_bytes())?;
let hdr_value = HeaderValue::from_str(value.trim())?;
headers.insert(hdr_name, hdr_value);
}
}
let mut params = BTreeMap::new();
if let Some(TypedData::String(params_raw)) = msg.args.get("req_query") {
for pair in params_raw.split('&') {
if pair.is_empty() {
continue;
}
let (key, value) = pair.split_once('=').unwrap_or((pair, ""));
params.insert(key.to_owned(), value.to_owned());
}
}
let request = Request {
method: method.to_owned(),
headers,
path: path.to_owned(),
params,
};
let response = state_snapshot
.request_handler
.decide(request.into())
.map_err(|e| anyhow::anyhow!("{e:#?}"))?;
vars.push((
VarScope::Transaction,
"response",
TypedData::String(response),
));
Ok(())
}
/// Processes the payload of a NOTIFY frame and collects the variables to
/// send back in the ACK.
///
/// Takes a read lock on `state` for the duration of the call. Every
/// `check-request` message in the payload is passed to `check_request`;
/// other messages, and payloads that are not a list of messages, are logged
/// and otherwise ignored.
///
/// # Errors
///
/// Fails when the state lock cannot be acquired for reading, or when
/// handling any `check-request` message fails; in the latter case variables
/// collected so far are discarded along with the error.
fn haproxy_notify<'a>(
    message: &FramePayload,
    state: &StateOfDecay,
) -> anyhow::Result<Vec<(VarScope, &'a str, TypedData)>> {
    let state_snapshot = state
        .read()
        .map_err(|e| anyhow::anyhow!("Unable to lock state for reading: {e}"))?;
    let mut vars = Vec::new();

    let FramePayload::ListOfMessages(messages) = message else {
        tracing::warn!("Unhandled message: {message:?}");
        return Ok(vars);
    };

    // The payload is only borrowed here, so walk the messages by reference
    // instead of dereferencing (and thereby trying to consume) the list.
    for msg in messages.iter() {
        if msg.name.as_str() == "check-request" {
            Self::check_request(msg, &state_snapshot, &mut vars)?;
            tracing::trace!({ vars = format!("{vars:?}") }, "handling check-request");
        } else {
            tracing::warn!("Unhandled message: {msg:?}");
        }
    }
    Ok(vars)
}
async fn agent_loop(
connection: Connection,
state: StateOfDecay,
mut shutdown: watch::Receiver<bool>,
) -> anyhow::Result<()> {
let mut socket = Framed::new(connection, SpopCodec);
loop {
tokio::select! {
socket_res = socket.next() => {
let Some(frame_res) = socket_res else {
break;
};
let frame = frame_res?;
tracing::trace!("Received {:?} HAProxy frame", frame.frame_type());
match frame.frame_type() {
FrameType::HaproxyHello => {
let hello = HaproxyHello::try_from(frame.payload())
.map_err(|e| anyhow::anyhow!("Failed to parse HAProxy hello {e}"))?;
let (hello, healthcheck) = Self::haproxy_hello(&hello)?;
socket.send(hello).await?;
if healthcheck {
tracing::warn!("HAProxy SPOE exited");
break;
}
}
FrameType::HaproxyDisconnect => {
let disconnect = AgentDisconnect {
status_code: 0,
message: "Iocaine disconnecting".to_owned(),
};
tracing::warn!("HAProxy SPOE disconnected");
socket.send(disconnect.into()).await?;
socket.close().await?;
break;
}
FrameType::Notify => {
let actions = Self::haproxy_notify(&frame.payload(), &state)?;
let ack = actions.into_iter().fold(
Ack::new(frame.metadata().stream_id, frame.metadata().frame_id),
|ack, (scope, name, val)| ack.set_var(scope, name, val),
);
socket.send(ack.into()).await?;
}
_ => {
tracing::warn!("Unhandled HAProxy SPOE frame: {:?}", frame.frame_type());
}
}
}
_ = shutdown.changed() => {
let disconnect = AgentDisconnect {
status_code: 0,
message: "Iocaine shutting down".to_owned(),
};
socket.send(disconnect.into()).await?;
socket.close().await?;
break;
}
}
}
Ok(())
}
/// Accepts HAProxy connections on `listener` until shutdown is signalled.
///
/// Every accepted connection gets its own task running the agent loop with a
/// clone of the shared state and of the shutdown receiver. Once
/// `shutdown_signal` completes, the shutdown flag is broadcast, accepting
/// stops, and the remaining workers are awaited before returning.
///
/// Accept failures are logged and do not stop the server. Connection resets
/// inside a worker are logged at debug level; every other worker error is
/// logged as an error.
///
/// # Errors
///
/// Fails only if the signal-handling task cannot be joined.
pub async fn serve(&self, mut listener: Listener) -> anyhow::Result<()> {
    let (tx, mut rx) = watch::channel(false);
    let signal_handler = tokio::spawn(async move {
        shutdown_signal(None).await;
        tracing::info!("Signalling HAProxy SPOA shutdown");
        let _ = tx.send(true);
    });
    let mut workers = JoinSet::new();
    loop {
        tokio::select! {
            res = listener.accept() => {
                match res {
                    Ok((stream, _)) => {
                        let shutdown = rx.clone();
                        let state_ = self.state.clone();
                        workers.spawn(async move {
                            if let Err(e) = Self::agent_loop(stream, state_, shutdown).await {
                                match e.downcast_ref::<io::Error>() {
                                    Some(err) if err.kind() == io::ErrorKind::ConnectionReset => {
                                        tracing::debug!("Connection reset in HAProxy agent loop: {err}");
                                    }
                                    _ => tracing::error!("Error in HAProxy agent loop: {e}"),
                                }
                            }
                        });
                    }
                    Err(e) => tracing::error!("Error accepting HAProxy SPOE connection: {e}"),
                }
            },
            // Reap workers as they finish. Without this, every completed
            // connection would stay in the set until shutdown, growing it
            // without bound on a long-running server. When the set is empty
            // `join_next` yields `None`, the pattern does not match, and this
            // branch is simply skipped for the current iteration.
            Some(joined) = workers.join_next() => {
                if let Err(e) = joined {
                    tracing::error!("HAProxy agent task failed: {e}");
                }
            },
            _ = rx.changed() => break,
        }
    }
    // Let in-flight connections say goodbye before returning.
    workers.join_all().await;
    Ok(signal_handler.await?)
}
}