fn0-worker 0.3.31

Worker binary for the fn0 FaaS platform
use anyhow::{Result, anyhow};
use base64::Engine;
use bytes::Bytes;
use http_body_util::combinators::UnsyncBoxBody;
use http_body_util::{BodyExt, Full};
use oci_rust_sdk::auth::{RequestSigner, SimpleAuthProvider, SimpleAuthProviderRequiredFields};
use serde::Deserialize;
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::{mpsc, oneshot};

use crate::worker_pool::{self, DispatchError, RequestEnvelope};

/// API-version path segment used for every queue messages URL built by the consumer
/// (GetMessages and DeleteMessage alike), despite the `PUT_` prefix in its name.
const PUT_MESSAGES_API_VERSION: &str = "20210201";
/// Server-side long-poll wait, sent as `timeoutInSeconds` on GetMessages.
const LONG_POLL_TIMEOUT_SECS: u32 = 30;
/// Sent as `visibilityInSeconds` on GetMessages — presumably how long fetched messages
/// stay hidden before the queue may redeliver them (confirm against OCI Queue docs).
const VISIBILITY_SECS: u32 = 60;
/// Maximum number of messages requested per GetMessages call, sent as `limit`.
const MAX_MESSAGES_PER_FETCH: u32 = 8;

/// Everything the queue consumer needs: where the queue lives and which OCI API key
/// signs requests to it. Can be loaded from the environment with
/// [`QueueConsumerConfig::from_env`].
pub struct QueueConsumerConfig {
    /// Base URL of the queue messages endpoint; only its host and port are used.
    pub messages_endpoint: String,
    /// OCID of the queue to consume from.
    pub queue_ocid: String,

    /// OCI tenancy OCID of the signing identity.
    pub tenancy: String,
    /// OCI user OCID of the signing identity.
    pub user: String,
    /// Fingerprint of the API key that signs requests.
    pub fingerprint: String,
    /// PEM-encoded private key matching `fingerprint`.
    pub private_key_pem: String,
}

impl QueueConsumerConfig {
    /// Loads the queue consumer configuration from the process environment.
    ///
    /// Variables are read in this order:
    /// - `FN0_QUEUE_MESSAGES_ENDPOINT` — base URL of the queue messages endpoint
    /// - `FN0_QUEUE_OCID` — OCID of the queue to consume
    /// - `FN0_QUEUE_OCI_TENANCY_ID`, `FN0_QUEUE_OCI_USER_ID`, `FN0_QUEUE_OCI_FINGERPRINT`
    ///   — identity of the OCI API key used to sign requests
    /// - `FN0_QUEUE_OCI_PRIVATE_KEY_BASE64` — the PEM private key, standard base64
    ///
    /// ASCII whitespace inside the base64 key (line wraps, trailing newline) is ignored.
    ///
    /// # Errors
    ///
    /// Returns an error naming the first variable that is unset or not valid Unicode,
    /// or if the private key is not valid base64 or does not decode to UTF-8 text.
    pub fn from_env() -> Result<Self> {
        let messages_endpoint = Self::required_env("FN0_QUEUE_MESSAGES_ENDPOINT")?;
        let queue_ocid = Self::required_env("FN0_QUEUE_OCID")?;
        let tenancy = Self::required_env("FN0_QUEUE_OCI_TENANCY_ID")?;
        let user = Self::required_env("FN0_QUEUE_OCI_USER_ID")?;
        let fingerprint = Self::required_env("FN0_QUEUE_OCI_FINGERPRINT")?;
        let private_key_b64 = Self::required_env("FN0_QUEUE_OCI_PRIVATE_KEY_BASE64")?;

        // GNU `base64` wraps its output at 76 columns and secret stores often append a
        // trailing newline; the STANDARD engine rejects any whitespace, so drop it first.
        let private_key_b64: String = private_key_b64
            .chars()
            .filter(|c| !c.is_ascii_whitespace())
            .collect();
        let private_key_bytes = base64::engine::general_purpose::STANDARD
            .decode(private_key_b64.as_bytes())
            .map_err(|e| anyhow!("queue private key base64 decode: {e}"))?;
        let private_key_pem = String::from_utf8(private_key_bytes)
            .map_err(|e| anyhow!("queue private key utf8: {e}"))?;

        Ok(Self {
            queue_ocid,
            messages_endpoint,
            tenancy,
            user,
            fingerprint,
            private_key_pem,
        })
    }

    /// Reads a mandatory environment variable.
    ///
    /// An unset variable yields `"<name> is required"`; a value that is not valid
    /// Unicode is reported as such instead of being mislabelled as missing.
    fn required_env(name: &str) -> Result<String> {
        std::env::var(name).map_err(|err| match err {
            std::env::VarError::NotPresent => anyhow!("{name} is required"),
            std::env::VarError::NotUnicode(_) => anyhow!("{name} is not valid Unicode"),
        })
    }
}

/// Client for the OCI Queue messages API: polls for messages, dispatches them to the
/// worker pool and deletes them once handled.
struct Consumer {
    /// Plain HTTP client; authentication comes from the headers produced by `signer`.
    http: reqwest::Client,
    /// Signs every outgoing request with the configured OCI API key.
    signer: Arc<RequestSigner>,
    /// `host[:port]` of the messages endpoint, used when building request URLs.
    messages_host: String,
    /// OCID of the queue being consumed.
    queue_ocid: String,
}

/// Body of a successful GetMessages response; only the `messages` array is read.
#[derive(Deserialize)]
struct GetMessagesResponse {
    /// Fetched messages; a response without a `messages` key yields an empty list.
    #[serde(default)]
    messages: Vec<IncomingMessage>,
}

/// One message as returned by GetMessages.
#[derive(Deserialize)]
struct IncomingMessage {
    /// Opaque delivery receipt; passed to DeleteMessage to acknowledge the message.
    receipt: String,
    /// Message body: standard base64 that decodes to a JSON `WrappedMessage`.
    content: String,
}

/// Decoded message content describing one task invocation.
#[derive(Deserialize)]
struct WrappedMessage {
    /// Target project; used as the worker `code_id` and in the `<project_id>.fn0.dev` host.
    project_id: String,
    /// Task name forwarded to the project's `/__fn0_queue_task/execute` endpoint.
    task_name: String,
    /// Arbitrary JSON forwarded unchanged next to `task_name`.
    payload: serde_json::Value,
}

pub async fn run(
    config: QueueConsumerConfig,
    worker_senders: Arc<Vec<mpsc::Sender<RequestEnvelope>>>,
) {
    let consumer = match Consumer::new(config) {
        Ok(c) => c,
        Err(err) => {
            tracing::error!(?err, "queue consumer init failed; consumer disabled");
            return;
        }
    };

    loop {
        match consumer.fetch_messages().await {
            Ok(messages) if messages.is_empty() => {}
            Ok(messages) => {
                for msg in messages {
                    let consumer = &consumer;
                    let worker_senders = worker_senders.clone();
                    if let Err(err) = consumer.dispatch_one(msg, worker_senders).await {
                        tracing::warn!(?err, "queue dispatch failed; leaving for redelivery");
                    }
                }
            }
            Err(err) => {
                tracing::warn!(?err, "queue fetch failed");
                tokio::time::sleep(Duration::from_secs(2)).await;
            }
        }
    }
}

impl Consumer {
    fn new(config: QueueConsumerConfig) -> Result<Self> {
        let url = url::Url::parse(&config.messages_endpoint)
            .map_err(|e| anyhow!("queue endpoint parse: {e}"))?;
        let host = url
            .host_str()
            .ok_or_else(|| anyhow!("queue endpoint missing host"))?;
        let messages_host = if let Some(port) = url.port() {
            format!("{host}:{port}")
        } else {
            host.to_string()
        };

        let provider: Arc<SimpleAuthProvider> = Arc::new(
            SimpleAuthProvider::builder(SimpleAuthProviderRequiredFields {
                tenancy: config.tenancy,
                user: config.user,
                fingerprint: config.fingerprint,
                private_key: config.private_key_pem,
            })
            .build(),
        );
        let signer = Arc::new(
            RequestSigner::new(provider as Arc<dyn oci_rust_sdk::auth::AuthProvider>)
                .map_err(|e| anyhow!("queue signer init failed: {e:?}"))?,
        );

        Ok(Self {
            queue_ocid: config.queue_ocid,
            messages_host,
            signer,
            http: reqwest::Client::new(),
        })
    }

    fn messages_url(&self, suffix: &str) -> String {
        format!(
            "https://{}/{}/queues/{}/messages{}",
            self.messages_host, PUT_MESSAGES_API_VERSION, self.queue_ocid, suffix
        )
    }

    async fn fetch_messages(&self) -> Result<Vec<IncomingMessage>> {
        let url_str = self.messages_url(&format!(
            "?limit={MAX_MESSAGES_PER_FETCH}&timeoutInSeconds={LONG_POLL_TIMEOUT_SECS}&visibilityInSeconds={VISIBILITY_SECS}"
        ));
        let url = url::Url::parse(&url_str).map_err(|e| anyhow!("queue url parse: {e}"))?;

        let mut headers = reqwest::header::HeaderMap::new();
        self.signer
            .sign_request("GET", &url, &mut headers, None)
            .map_err(|e| anyhow!("queue sign: {e:?}"))?;

        let resp = self
            .http
            .get(url_str)
            .headers(headers)
            .send()
            .await
            .map_err(|e| anyhow!("queue GetMessages: {e}"))?;

        let status = resp.status();
        let body_bytes = resp
            .bytes()
            .await
            .map_err(|e| anyhow!("queue GetMessages body: {e}"))?;

        if !status.is_success() {
            return Err(anyhow!(
                "queue GetMessages failed status={status} body={}",
                String::from_utf8_lossy(&body_bytes)
            ));
        }

        let parsed: GetMessagesResponse =
            serde_json::from_slice(&body_bytes).map_err(|e| anyhow!("queue resp parse: {e}"))?;
        Ok(parsed.messages)
    }

    async fn delete_message(&self, receipt: &str) -> Result<()> {
        let url_str = self.messages_url(&format!("/{receipt}"));
        let url = url::Url::parse(&url_str).map_err(|e| anyhow!("queue url parse: {e}"))?;

        let mut headers = reqwest::header::HeaderMap::new();
        self.signer
            .sign_request("DELETE", &url, &mut headers, None)
            .map_err(|e| anyhow!("queue sign: {e:?}"))?;

        let resp = self
            .http
            .delete(url_str)
            .headers(headers)
            .send()
            .await
            .map_err(|e| anyhow!("queue DeleteMessage: {e}"))?;

        if !resp.status().is_success() {
            let status = resp.status();
            let body = resp.text().await.unwrap_or_default();
            return Err(anyhow!(
                "queue DeleteMessage failed status={status} body={body}"
            ));
        }
        Ok(())
    }

    async fn dispatch_one(
        &self,
        msg: IncomingMessage,
        worker_senders: Arc<Vec<mpsc::Sender<RequestEnvelope>>>,
    ) -> Result<()> {
        let content_bytes = match base64::engine::general_purpose::STANDARD
            .decode(msg.content.as_bytes())
        {
            Ok(bytes) => bytes,
            Err(e) => {
                tracing::warn!(
                    receipt = %msg.receipt,
                    error = %e,
                    "queue content base64 unrecoverable; acking malformed message"
                );
                return self.delete_message(&msg.receipt).await;
            }
        };
        let wrapped: WrappedMessage = match serde_json::from_slice(&content_bytes) {
            Ok(w) => w,
            Err(e) => {
                tracing::warn!(
                    receipt = %msg.receipt,
                    error = %e,
                    "queue content parse unrecoverable; acking malformed message"
                );
                return self.delete_message(&msg.receipt).await;
            }
        };

        let inner_body = serde_json::to_vec(&serde_json::json!({
            "task_name": wrapped.task_name,
            "payload": wrapped.payload,
        }))
        .map_err(|e| anyhow!("queue inner body: {e}"))?;

        let req: hyper::Request<UnsyncBoxBody<Bytes, anyhow::Error>> = hyper::Request::builder()
            .method("POST")
            .uri("/__fn0_queue_task/execute")
            .header("host", format!("{}.fn0.dev", wrapped.project_id))
            .header("content-type", "application/json")
            .body(
                Full::new(Bytes::from(inner_body))
                    .map_err(|never: std::convert::Infallible| match never {})
                    .boxed_unsync(),
            )
            .map_err(|e| anyhow!("queue req build: {e}"))?;

        let (resp_tx, resp_rx) = oneshot::channel();
        let envelope = RequestEnvelope {
            code_id: wrapped.project_id.clone(),
            req,
            resp_tx,
        };

        if let Err(err) = worker_pool::dispatch(&worker_senders, envelope) {
            return match err {
                DispatchError::Full => {
                    Err(anyhow!("worker queue full for {}", wrapped.project_id))
                }
                DispatchError::Closed => Err(anyhow!("worker pool closed")),
            };
        }

        let result = match resp_rx.await {
            Ok(r) => r,
            Err(_) => return Err(anyhow!("worker dropped response channel")),
        };

        let resp = result.map_err(|e| anyhow!("queue task exec: {e:?}"))?;
        if !resp.status().is_success() {
            return Err(anyhow!("queue task returned status {}", resp.status()));
        }

        self.delete_message(&msg.receipt).await
    }
}