iocaine 3.4.0

The deadliest poison known to AI
// SPDX-FileCopyrightText: famfo
// SPDX-FileContributor: famfo
//
// SPDX-License-Identifier: MIT

use futures_util::sink::SinkExt;
use serde::Serialize;
use spop::Version;
use spop::frame::Message;
use spop::frames::{Ack, AgentDisconnect, AgentHello, FrameCapabilities, HaproxyHello};
use spop::{FramePayload, FrameType, SpopCodec, TypedData, VarScope};
use std::collections::BTreeMap;
use std::io;
use std::path::Path;
use std::sync::Arc;
use tokio::sync::watch;
use tokio::task::JoinSet;
use tokio_listener::{Connection, Listener};
use tokio_stream::StreamExt;
use tokio_util::codec::Framed;

use crate::{
    morgue::{Decay, StateOfDecay, shutdown_signal},
    tenx_programmer::TenXProgrammer,
};

use iocaine_powder::{
    acab::State,
    http::{HeaderMap, HeaderName, HeaderValue},
    sex_dungeon::{DungeonMaster, Language, Request},
};

/// HAProxy SPOA (Stream Processing Offload Agent) frontend.
///
/// Speaks SPOP with HAProxy and answers `check-request` messages by running the
/// configured `decide()` handler, returning its verdict as a transaction variable.
pub struct CheckpointCharlie {
    /// Shared request handler and metrics; cloned into every connection task.
    state: StateOfDecay,
}

impl CheckpointCharlie {
    pub fn new(
        language: Language,
        compiler: Option<&impl AsRef<Path>>,
        path: Option<&impl AsRef<Path>>,
        initial_seed: &str,
        metrics: &TenXProgrammer,
        state: &State,
        config: Option<impl Serialize>,
    ) -> anyhow::Result<Self> {
        let request_handler = DungeonMaster::new(initial_seed)
            .language(language)
            .compiler(compiler.as_ref())
            .path(path.as_ref())
            .config(config)
            .build(&metrics.metrics, state)
            .map_err(|e| anyhow::anyhow!("{e:#?}"))?;
        let request_handler = Arc::new(request_handler);

        if !request_handler.can_decide() {
            let path = path
                .as_ref()
                .map_or_else(|| "(default)".into(), |p| p.as_ref().display().to_string());
            tracing::error!({ path }, "A decide() function is required");
            anyhow::bail!("Requires decide()");
        }

        let state = Decay {
            metrics: metrics.clone(),
            request_handler,
        }
        .into();

        Ok(Self { state })
    }

    fn haproxy_hello(hello: &HaproxyHello) -> anyhow::Result<(Box<AgentHello>, bool)> {
        let healthcheck = hello.healthcheck.unwrap_or(false);
        let max_frame_size = hello.max_frame_size;

        let version = Version::parse("2.0.0")?;
        let agent_hello = AgentHello {
            version,
            max_frame_size,
            capabilities: vec![FrameCapabilities::Pipelining],
        };

        Ok((agent_hello.into(), healthcheck))
    }

    fn check_request(
        msg: &Message,
        state_snapshot: &Decay,
        vars: &mut Vec<(VarScope, &str, TypedData)>,
    ) -> anyhow::Result<()> {
        let Some(TypedData::String(method)) = msg.args.get("req_method") else {
            anyhow::bail!("Message request is missing method");
        };

        let Some(TypedData::String(path)) = msg.args.get("req_path") else {
            anyhow::bail!("Message request is missing path");
        };

        let mut headers = HeaderMap::new();
        if let Some(headers_raw) = msg.args.get("req_hdrs") {
            let TypedData::String(headers_raw) = headers_raw else {
                anyhow::bail!("Message contains malformed headers");
            };

            for pair in headers_raw.split("\r\n") {
                if pair.is_empty() {
                    // End of header map
                    continue;
                }

                let Some((key, value)) = pair.split_once(':') else {
                    anyhow::bail!("Message contains malformed header");
                };

                let hdr_name = HeaderName::from_bytes(key.trim().as_bytes())?;
                let hdr_value = HeaderValue::from_str(value.trim())?;

                headers.insert(hdr_name, hdr_value);
            }
        }

        let mut params = BTreeMap::new();
        if let Some(TypedData::String(params_raw)) = msg.args.get("req_query") {
            for pair in params_raw.split('&') {
                if pair.is_empty() {
                    continue;
                }

                let (key, value) = pair.split_once('=').unwrap_or((pair, ""));
                params.insert(key.to_owned(), value.to_owned());
            }
        }

        let request = Request {
            method: method.to_owned(),
            headers,
            path: path.to_owned(),
            params,
        };

        let response = state_snapshot
            .request_handler
            .decide(request.into())
            .map_err(|e| anyhow::anyhow!("{e:#?}"))?;

        vars.push((
            VarScope::Transaction,
            "response",
            TypedData::String(response),
        ));

        Ok(())
    }

    fn haproxy_notify<'a>(
        message: &FramePayload,
        state: &StateOfDecay,
    ) -> anyhow::Result<Vec<(VarScope, &'a str, TypedData)>> {
        let state_snapshot = state
            .read()
            .map_err(|e| anyhow::anyhow!("Unable to lock state for reading: {e}"))?;
        let mut vars = Vec::new();

        if let FramePayload::ListOfMessages(messages) = message {
            for msg in *messages {
                if msg.name.as_str() == "check-request" {
                    Self::check_request(msg, &state_snapshot, &mut vars)?;
                    tracing::trace!({ vars = format!("{vars:?}") }, "handling check-request");
                } else {
                    tracing::warn!("Unhandled message: {msg:?}");
                }
            }
        } else {
            tracing::warn!("Unhandled message: {message:?}");
        }

        Ok(vars)
    }

    async fn agent_loop(
        connection: Connection,
        state: StateOfDecay,
        mut shutdown: watch::Receiver<bool>,
    ) -> anyhow::Result<()> {
        let mut socket = Framed::new(connection, SpopCodec);
        loop {
            tokio::select! {
                socket_res = socket.next() => {
                    let Some(frame_res) = socket_res else {
                        break;
                    };

                    let frame = frame_res?;
                    tracing::trace!("Received {:?} HAProxy frame", frame.frame_type());

                    match frame.frame_type() {
                        FrameType::HaproxyHello => {
                            let hello = HaproxyHello::try_from(frame.payload())
                                .map_err(|e| anyhow::anyhow!("Failed to parse HAProxy hello {e}"))?;

                            let (hello, healthcheck) = Self::haproxy_hello(&hello)?;
                            socket.send(hello).await?;

                            if healthcheck {
                                tracing::warn!("HAProxy SPOE exited");
                                break;
                            }
                        }
                        FrameType::HaproxyDisconnect => {
                            let disconnect = AgentDisconnect {
                                status_code: 0,
                                message: "Iocaine disconnecting".to_owned(),
                            };

                            tracing::warn!("HAProxy SPOE disconnected");

                            socket.send(disconnect.into()).await?;
                            socket.close().await?;

                            break;
                        }
                        FrameType::Notify => {
                            let actions = Self::haproxy_notify(&frame.payload(), &state)?;
                            let ack = actions.into_iter().fold(
                                Ack::new(frame.metadata().stream_id, frame.metadata().frame_id),
                                |ack, (scope, name, val)| ack.set_var(scope, name, val),
                            );

                            socket.send(ack.into()).await?;
                        }
                        _ => {
                            tracing::warn!("Unhandled HAProxy SPOE frame: {:?}", frame.frame_type());
                        }
                    }
                }
                _ = shutdown.changed() => {
                    let disconnect = AgentDisconnect {
                        status_code: 0,
                        message: "Iocaine shutting down".to_owned(),
                    };

                    socket.send(disconnect.into()).await?;
                    socket.close().await?;

                    break;
                }
            }
        }

        Ok(())
    }

    pub async fn serve(&self, mut listener: Listener) -> anyhow::Result<()> {
        let (tx, mut rx) = watch::channel(false);
        let signal_handler = tokio::spawn(async move {
            shutdown_signal(None).await;
            tracing::info!("Signalling HAProxy SPOA shutdown");

            // Realistically, this should never error
            let _ = tx.send(true);
        });

        let mut workers = JoinSet::new();

        loop {
            tokio::select! {
                res = listener.accept() => {
                    match  res {
                        Ok((stream, _)) => {
                            let shutdown = rx.clone();
                            let state_ = self.state.clone();

                            workers.spawn(async move {
                                if let Err(e) = Self::agent_loop(stream, state_, shutdown).await {
                                    // Ignore ConnectionReset, that one is normal.
                                    if let Some(err) = e.downcast_ref::<io::Error>()
                                        && err.kind() == io::ErrorKind::ConnectionReset {
                                            tracing::debug!("Connection reset in HAProxy agent loop: {err}");
                                            return
                                        }
                                    tracing::error!("Error in HAProxy agent loop: {e}");
                                }
                            });
                        }
                        Err(e) => tracing::error!("Error accepting HAProxy SPOE connection: {e}"),
                    }
                },
                _ = rx.changed() => break,
            }
        }

        // Technically this is ignoring any error value
        workers.join_all().await;

        Ok(signal_handler.await?)
    }
}