use crate::{ChannelMessage, OutboundEnvelope, TenantCtx};
use async_trait::async_trait;
use serde_json::{Map, Value};
use std::collections::BTreeMap;
use std::str::FromStr;
use std::time::Duration;
use thiserror::Error;
use time::OffsetDateTime;
use tracing::{error, warn};
use uuid::Uuid;
/// Envelope version stamped on every [`WorkerRequest`] built by this module.
pub const WORKER_ENVELOPE_VERSION: &str = "1.0";
/// Worker id used by [`WorkerRoutingConfig::default`].
pub const DEFAULT_WORKER_ID: &str = "greentic-repo-assistant";
/// NATS subject used by [`WorkerRoutingConfig::default`]; HTTP routes built by
/// [`WorkerRoutingConfig::from_route_spec`] also keep this subject.
pub const DEFAULT_WORKER_NATS_SUBJECT: &str = "workers.repo-assistant";
pub use greentic_types::{WorkerMessage, WorkerRequest, WorkerResponse};
/// Transport used to reach a worker.
///
/// Parsing is lenient. `"http"` selects [`WorkerTransport::Http`], compared
/// case-insensitively and ignoring surrounding whitespace. Anything else, including an empty
/// or unknown value, falls back to [`WorkerTransport::Nats`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum WorkerTransport {
    /// Request/reply over a NATS subject.
    Nats,
    /// JSON `POST` to an HTTP endpoint.
    Http,
}

impl WorkerTransport {
    /// Parses an optional transport name, defaulting to NATS when `value` is `None`.
    pub fn from_optional(value: Option<&str>) -> Self {
        value.map_or(WorkerTransport::Nats, Self::parse_lenient)
    }

    /// Shared lenient parser backing both [`WorkerTransport::from_optional`] and [`FromStr`].
    ///
    /// It compares without allocating and ignores surrounding whitespace. As a result, values
    /// sliced straight out of a config string (e.g. `" HTTP "`) still select HTTP instead of
    /// silently degrading to NATS.
    fn parse_lenient(raw: &str) -> Self {
        if raw.trim().eq_ignore_ascii_case("http") {
            WorkerTransport::Http
        } else {
            WorkerTransport::Nats
        }
    }
}

impl FromStr for WorkerTransport {
    /// Parsing never fails; unknown values map to [`WorkerTransport::Nats`].
    type Err = ();

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(Self::parse_lenient(s))
    }
}
/// Where and how requests for a single worker should be delivered.
#[derive(Clone, Debug)]
pub struct WorkerRoutingConfig {
    /// Transport used to reach the worker.
    pub transport: WorkerTransport,
    /// Worker id copied into each outgoing [`WorkerRequest`].
    pub worker_id: String,
    /// NATS subject for this worker. It is the route target for NATS routes and stays at
    /// [`DEFAULT_WORKER_NATS_SUBJECT`] for HTTP routes built by `from_route_spec`.
    pub nats_subject: String,
    /// Worker endpoint URL for HTTP routes; `None` for NATS routes and the default config.
    pub http_url: Option<String>,
    /// Retry budget after the first attempt (the client retry loops make up to
    /// `max_retries + 1` attempts). NOTE(review): this file never passes it to a client
    /// itself; confirm that callers do.
    pub max_retries: u8,
}
impl Default for WorkerRoutingConfig {
fn default() -> Self {
Self {
transport: WorkerTransport::Nats,
worker_id: DEFAULT_WORKER_ID.to_string(),
nats_subject: DEFAULT_WORKER_NATS_SUBJECT.to_string(),
http_url: None,
max_retries: 2,
}
}
}
impl WorkerRoutingConfig {
    /// Builds a config from fully specified settings; no defaults are applied.
    pub fn from_settings(
        transport: WorkerTransport,
        worker_id: String,
        nats_subject: String,
        http_url: Option<String>,
        max_retries: u8,
    ) -> Self {
        WorkerRoutingConfig {
            max_retries,
            http_url,
            nats_subject,
            worker_id,
            transport,
        }
    }

    /// Builds the config for a single route parsed from a spec string.
    ///
    /// - NATS routes use `target` as the subject and leave `http_url` unset.
    /// - HTTP routes use `target` as the endpoint URL and keep the default NATS subject.
    ///
    /// All remaining fields, including `max_retries`, come from [`WorkerRoutingConfig::default`].
    pub fn from_route_spec(worker_id: &str, transport: WorkerTransport, target: &str) -> Self {
        let base = Self::default();
        match transport {
            WorkerTransport::Nats => Self {
                transport,
                worker_id: worker_id.to_owned(),
                nats_subject: target.to_owned(),
                ..base
            },
            WorkerTransport::Http => Self {
                transport,
                worker_id: worker_id.to_owned(),
                http_url: Some(target.to_owned()),
                ..base
            },
        }
    }
}
/// Parses a comma-separated list of worker routes into a map keyed by worker id.
///
/// Each entry has the form `id=transport:target`, for example
/// `repo=nats:workers.repo,search=http:http://localhost:8080/run`.
///
/// - The transport is parsed leniently via [`WorkerTransport::from_optional`]; unknown values
///   fall back to NATS.
/// - The spec is split on its *first* `:`, so URL targets keep their own colons.
/// - Whitespace around the id, transport, and target is ignored.
/// - Entries missing `=` or `:` are skipped silently.
/// - When an id repeats, the last entry wins.
pub fn worker_routes_from_specs(raw: &str) -> BTreeMap<String, WorkerRoutingConfig> {
    let mut map = BTreeMap::new();
    for entry in raw.split(',').map(str::trim).filter(|s| !s.is_empty()) {
        if let Some((id, spec)) = entry.split_once('=')
            && let Some((transport_raw, target)) = spec.split_once(':')
        {
            let id = id.trim();
            // Trim the transport like the id and target, so `"id = http:..."` does not
            // silently fall back to NATS because of the leading space.
            let transport = WorkerTransport::from_optional(Some(transport_raw.trim()));
            let cfg = WorkerRoutingConfig::from_route_spec(id, transport, target.trim());
            map.insert(id.to_string(), cfg);
        }
    }
    map
}
/// Returns the current UTC time as an RFC 3339 string.
///
/// If RFC 3339 formatting fails, it falls back to the Unix timestamp in whole seconds (as a
/// decimal string), so callers always receive some timestamp.
fn now_timestamp_utc() -> String {
    // Read the clock once so the fallback describes the same instant as the primary format.
    let now = OffsetDateTime::now_utc();
    now.format(&time::format_description::well_known::Rfc3339)
        .unwrap_or_else(|_| now.unix_timestamp().to_string())
}
/// Serializes `payload` to a compact JSON string for the request's `payload_json` field.
///
/// # Errors
/// Returns [`WorkerClientError::PayloadEncode`] if serialization fails.
fn encode_payload(payload: &Value) -> Result<String, WorkerClientError> {
    match serde_json::to_string(payload) {
        Ok(json) => Ok(json),
        Err(cause) => Err(WorkerClientError::PayloadEncode(cause)),
    }
}
/// Parses a worker message payload as JSON.
///
/// Text that is not valid JSON is preserved verbatim as a [`Value::String`] rather than
/// being dropped.
fn decode_payload(payload_json: &str) -> Value {
    match serde_json::from_str::<Value>(payload_json) {
        Ok(parsed) => parsed,
        Err(_) => Value::String(payload_json.to_owned()),
    }
}
/// Assembles a [`WorkerRequest`] envelope stamped with [`WORKER_ENVELOPE_VERSION`] and the
/// current UTC time.
///
/// When no correlation id is supplied, a fresh v4 UUID is used, so the returned request
/// always carries `Some(correlation_id)`.
///
/// # Errors
/// Returns [`WorkerClientError::PayloadEncode`] if `payload` cannot be serialized.
fn build_worker_request(
    tenant: TenantCtx,
    worker_id: String,
    payload: Value,
    session_id: Option<String>,
    thread_id: Option<String>,
    correlation_id: Option<String>,
) -> Result<WorkerRequest, WorkerClientError> {
    let correlation = correlation_id.unwrap_or_else(|| Uuid::new_v4().to_string());
    let payload_json = encode_payload(&payload)?;
    let request = WorkerRequest {
        version: String::from(WORKER_ENVELOPE_VERSION),
        tenant,
        worker_id,
        correlation_id: Some(correlation),
        session_id,
        thread_id,
        payload_json,
        timestamp_utc: now_timestamp_utc(),
    };
    Ok(request)
}
fn worker_request_from_channel(
channel: &ChannelMessage,
payload: Value,
config: &WorkerRoutingConfig,
correlation_id: Option<String>,
) -> Result<WorkerRequest, WorkerClientError> {
let correlation = correlation_id
.or_else(|| {
channel
.payload
.get("correlation_id")
.and_then(|v| v.as_str())
.map(str::to_string)
})
.or_else(|| {
channel
.payload
.get("msg_id")
.and_then(|v| v.as_str())
.map(str::to_string)
});
let thread_id = channel
.payload
.get("thread_id")
.and_then(|v| v.as_str())
.map(str::to_string);
build_worker_request(
channel.tenant.clone(),
config.worker_id.clone(),
payload,
Some(channel.session_id.clone()),
thread_id,
correlation,
)
}
pub fn empty_worker_response_for(request: &WorkerRequest) -> WorkerResponse {
WorkerResponse {
version: request.version.clone(),
tenant: request.tenant.clone(),
worker_id: request.worker_id.clone(),
correlation_id: request.correlation_id.clone(),
session_id: request.session_id.clone(),
thread_id: request.thread_id.clone(),
messages: Vec::new(),
timestamp_utc: now_timestamp_utc(),
}
}
/// Converts each message in a worker response into an [`OutboundEnvelope`] for `channel`.
///
/// Every envelope is addressed to the inbound channel's tenant, channel id, and session id.
///
/// - `meta` carries `worker_id`, then `correlation_id` (only when the response has one),
///   then the message `kind`.
/// - `body` is the message payload parsed as JSON, or the raw text as a JSON string when
///   parsing fails.
pub fn worker_messages_to_outbound(
    response: &WorkerResponse,
    channel: &ChannelMessage,
) -> Vec<OutboundEnvelope> {
    // Metadata common to every message is built once and cloned per envelope; the
    // per-message `kind` is appended last so key order matches the per-message layout.
    let mut shared_meta = Map::new();
    shared_meta.insert(
        "worker_id".into(),
        Value::String(response.worker_id.clone()),
    );
    if let Some(corr) = response.correlation_id.as_ref() {
        shared_meta.insert("correlation_id".into(), Value::String(corr.clone()));
    }

    let mut envelopes = Vec::with_capacity(response.messages.len());
    for message in &response.messages {
        let mut meta = shared_meta.clone();
        meta.insert("kind".into(), Value::String(message.kind.clone()));
        envelopes.push(OutboundEnvelope {
            tenant: channel.tenant.clone(),
            channel_id: channel.channel_id.clone(),
            session_id: channel.session_id.clone(),
            meta: Value::Object(meta),
            body: decode_payload(&message.payload_json),
        });
    }
    envelopes
}
/// Errors produced while building, sending, or decoding worker requests.
///
/// NOTE(review): each `Display` message embeds the source error (`{0}`) while also exposing
/// it through `#[source]`. Error-chain reporters (e.g. anyhow's `{:#}`) will therefore print
/// the cause twice. Consider dropping the `: {0}` suffix.
#[derive(Debug, Error)]
pub enum WorkerClientError {
    /// The caller-supplied payload could not be serialized into `payload_json`.
    #[error("failed to encode worker payload: {0}")]
    PayloadEncode(#[source] serde_json::Error),
    /// The request envelope could not be serialized; in this file, raised by the NATS client.
    #[error("failed to serialize worker request: {0}")]
    Serialize(#[source] serde_json::Error),
    /// The worker's reply body was not a valid `WorkerResponse` JSON document.
    #[error("failed to deserialize worker response: {0}")]
    Deserialize(#[source] serde_json::Error),
    /// The NATS request/reply call failed.
    #[error("NATS request failed: {0}")]
    Nats(#[source] anyhow::Error),
    /// Sending the HTTP request failed, the endpoint returned a non-2xx status, or the
    /// response body could not be read.
    #[error("HTTP request failed: {0}")]
    Http(#[source] anyhow::Error),
}
/// Sends a [`WorkerRequest`] to a worker and returns its [`WorkerResponse`].
///
/// In this file it is implemented by a closure-backed in-memory client and by NATS and HTTP
/// clients. `Send + Sync` allows it to be used behind `&dyn WorkerClient` (as
/// `forward_to_worker` does).
#[async_trait]
pub trait WorkerClient: Send + Sync {
    /// Delivers `request` and waits for the worker's reply.
    ///
    /// # Errors
    /// Returns a [`WorkerClientError`] describing the encoding, transport, or decoding failure.
    async fn send_request(
        &self,
        request: WorkerRequest,
    ) -> Result<WorkerResponse, WorkerClientError>;
}
/// A [`WorkerClient`] that answers every request by calling a closure, with no I/O.
///
/// The closure receives the request by value, and its return value is handed back as a
/// successful response.
pub struct InMemoryWorkerClient {
    responder: Box<dyn Fn(WorkerRequest) -> WorkerResponse + Send + Sync>,
}

impl InMemoryWorkerClient {
    /// Wraps `responder` as a client.
    pub fn new<F>(responder: F) -> Self
    where
        F: Fn(WorkerRequest) -> WorkerResponse + Send + Sync + 'static,
    {
        let boxed: Box<dyn Fn(WorkerRequest) -> WorkerResponse + Send + Sync> =
            Box::new(responder);
        InMemoryWorkerClient { responder: boxed }
    }
}

#[async_trait]
impl WorkerClient for InMemoryWorkerClient {
    /// Invokes the wrapped closure; this implementation never returns an error.
    async fn send_request(
        &self,
        request: WorkerRequest,
    ) -> Result<WorkerResponse, WorkerClientError> {
        let respond = &self.responder;
        let response = respond(request);
        Ok(response)
    }
}
/// Forwards a channel message to a worker and maps the reply back onto the channel.
///
/// It builds the request from `channel`, `payload`, and `config`, sends it through `client`,
/// and converts every returned worker message into an [`OutboundEnvelope`].
///
/// # Errors
/// Returns any [`WorkerClientError`] raised while encoding the request or by `client`.
pub async fn forward_to_worker(
    client: &dyn WorkerClient,
    channel: &ChannelMessage,
    payload: Value,
    config: &WorkerRoutingConfig,
    correlation_id: Option<String>,
) -> Result<Vec<OutboundEnvelope>, WorkerClientError> {
    let request = worker_request_from_channel(channel, payload, config, correlation_id)?;
    client
        .send_request(request)
        .await
        .map(|response| worker_messages_to_outbound(&response, channel))
}
/// [`WorkerClient`] that sends each request via NATS request/reply on a fixed subject.
#[cfg(feature = "nats")]
pub struct NatsWorkerClient {
    /// Connected NATS client used for request/reply.
    client: async_nats::Client,
    /// Subject every request is sent to.
    subject: String,
    /// Retries after the first failed attempt (up to `max_retries + 1` attempts in total).
    max_retries: u8,
}
#[cfg(feature = "nats")]
impl NatsWorkerClient {
    /// Creates a client that sends requests to `subject`, retrying up to `max_retries` times.
    pub fn new(client: async_nats::Client, subject: String, max_retries: u8) -> Self {
        NatsWorkerClient {
            client,
            subject,
            max_retries,
        }
    }

    /// Performs a single request/reply round trip with JSON-encoded envelopes.
    ///
    /// # Errors
    /// - [`WorkerClientError::Serialize`] if `request` cannot be encoded.
    /// - [`WorkerClientError::Nats`] if the NATS request fails.
    /// - [`WorkerClientError::Deserialize`] if the reply is not a valid [`WorkerResponse`].
    async fn send_once(
        &self,
        request: &WorkerRequest,
    ) -> Result<WorkerResponse, WorkerClientError> {
        let encoded = serde_json::to_vec(request).map_err(WorkerClientError::Serialize)?;
        let reply = self
            .client
            .request(self.subject.clone(), encoded.into())
            .await
            .map_err(|cause| WorkerClientError::Nats(anyhow::Error::new(cause)))?;
        serde_json::from_slice(&reply.payload).map_err(WorkerClientError::Deserialize)
    }
}
#[cfg(feature = "nats")]
#[async_trait]
impl WorkerClient for NatsWorkerClient {
    /// Sends `request`, retrying failed attempts with a linear backoff.
    ///
    /// It makes up to `max_retries + 1` attempts, sleeping `50ms * attempt` between them.
    /// Once the retry budget is exhausted, the last error is logged and returned.
    async fn send_request(
        &self,
        request: WorkerRequest,
    ) -> Result<WorkerResponse, WorkerClientError> {
        // Count in u32, not the u8 of `max_retries`. A u8 counter overflows when
        // `max_retries == u8::MAX`: it panics in debug builds, and in release builds it wraps
        // to 0 and retries forever.
        let max_attempts = u32::from(self.max_retries) + 1;
        let mut attempt: u32 = 0;
        loop {
            attempt += 1;
            match self.send_once(&request).await {
                Ok(res) => return Ok(res),
                Err(err) if attempt >= max_attempts => {
                    error!(attempt, subject = %self.subject, error = %err, "worker NATS request failed");
                    return Err(err);
                }
                Err(err) => {
                    warn!(attempt, subject = %self.subject, error = %err, "retrying worker request over NATS");
                    tokio::time::sleep(Duration::from_millis(50 * u64::from(attempt))).await;
                }
            }
        }
    }
}
/// [`WorkerClient`] that POSTs each request as JSON to a fixed URL.
pub struct HttpWorkerClient {
    /// reqwest client reused for every request.
    client: reqwest::Client,
    /// Worker endpoint that receives the POST.
    url: String,
    /// Retries after the first failed attempt (up to `max_retries + 1` attempts in total).
    max_retries: u8,
}
impl HttpWorkerClient {
pub fn new(url: String, max_retries: u8) -> Self {
Self {
client: reqwest::Client::new(),
url,
max_retries,
}
}
async fn send_once(
&self,
request: &WorkerRequest,
) -> Result<WorkerResponse, WorkerClientError> {
let response = self
.client
.post(&self.url)
.json(request)
.send()
.await
.map_err(|e| WorkerClientError::Http(anyhow::Error::new(e)))?;
if !response.status().is_success() {
let status = response.status();
let body = response.text().await.unwrap_or_default();
return Err(WorkerClientError::Http(anyhow::anyhow!(
"HTTP {} from worker endpoint: {}",
status,
body
)));
}
let body = response
.bytes()
.await
.map_err(|e| WorkerClientError::Http(anyhow::Error::new(e)))?;
serde_json::from_slice(&body).map_err(WorkerClientError::Deserialize)
}
}
#[async_trait]
impl WorkerClient for HttpWorkerClient {
    /// Sends `request`, retrying failed attempts with a linear backoff.
    ///
    /// It makes up to `max_retries + 1` attempts, sleeping `50ms * attempt` between them.
    /// Every failure from `send_once` is retried: transport error, non-2xx status, or
    /// undecodable body. Once the budget is exhausted, the last error is logged and returned.
    async fn send_request(
        &self,
        request: WorkerRequest,
    ) -> Result<WorkerResponse, WorkerClientError> {
        // Count in u32, not the u8 of `max_retries`. A u8 counter overflows when
        // `max_retries == u8::MAX`: it panics in debug builds, and in release builds it wraps
        // to 0 and retries forever.
        let max_attempts = u32::from(self.max_retries) + 1;
        let mut attempt: u32 = 0;
        loop {
            attempt += 1;
            match self.send_once(&request).await {
                Ok(res) => return Ok(res),
                Err(err) if attempt >= max_attempts => {
                    error!(attempt, url = %self.url, error = %err, "worker HTTP request failed");
                    return Err(err);
                }
                Err(err) => {
                    warn!(attempt, url = %self.url, error = %err, "retrying worker HTTP request");
                    tokio::time::sleep(Duration::from_millis(50 * u64::from(attempt))).await;
                }
            }
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Inbound `webchat` message for tenant `acme` in session `sess-1`, with a payload that has
    /// no correlation, message, or thread ids.
    fn sample_channel() -> ChannelMessage {
        ChannelMessage {
            tenant: crate::make_tenant_ctx("acme".into(), Some("team".into()), None),
            channel_id: "webchat".into(),
            session_id: "sess-1".into(),
            route: None,
            payload: serde_json::json!({"text": "hi"}),
        }
    }

    /// End-to-end through `forward_to_worker`. The in-memory worker asserts on the request
    /// envelope it receives. The test then checks how one worker message is mapped back onto
    /// the originating channel.
    #[tokio::test]
    async fn builds_request_and_maps_response() {
        let channel = sample_channel();
        let config = WorkerRoutingConfig::default();
        let payload = serde_json::json!({"body": "hello"});
        // Explicit correlation id: it must take precedence and round-trip into envelope meta.
        let corr = Some("corr-1".to_string());
        let client = InMemoryWorkerClient::new(|req| {
            assert_eq!(req.version, WORKER_ENVELOPE_VERSION);
            assert_eq!(req.worker_id, DEFAULT_WORKER_ID);
            assert_eq!(req.session_id.as_deref(), Some("sess-1"));
            assert_eq!(req.correlation_id.as_deref(), Some("corr-1"));
            let decoded: Value = serde_json::from_str(&req.payload_json).unwrap();
            assert_eq!(decoded["body"], "hello");
            let mut resp = empty_worker_response_for(&req);
            resp.messages = vec![WorkerMessage {
                kind: "text".into(),
                payload_json: serde_json::to_string(&serde_json::json!({"reply": "pong"})).unwrap(),
            }];
            resp
        });
        let outbound = forward_to_worker(&client, &channel, payload, &config, corr)
            .await
            .unwrap();
        assert_eq!(outbound.len(), 1);
        assert_eq!(outbound[0].channel_id, "webchat");
        assert_eq!(outbound[0].body["reply"], "pong");
        assert_eq!(outbound[0].tenant.tenant.as_str(), "acme");
        assert_eq!(outbound[0].session_id, "sess-1");
        assert_eq!(outbound[0].meta["kind"], "text");
        assert_eq!(outbound[0].meta["worker_id"], DEFAULT_WORKER_ID);
        assert_eq!(outbound[0].meta["correlation_id"], "corr-1");
    }

    /// With no explicit correlation id and none in the channel payload, a correlation id is
    /// still generated. The thread id is picked up from the channel payload, and an empty
    /// worker response yields no envelopes.
    #[tokio::test]
    async fn populates_thread_and_correlation_defaults() {
        let mut channel = sample_channel();
        channel.payload = serde_json::json!({"text": "ping", "thread_id": "thr-1"});
        let config = WorkerRoutingConfig::default();
        let payload = serde_json::json!({"body": "hello"});
        let client = InMemoryWorkerClient::new(|req| {
            assert_eq!(req.thread_id.as_deref(), Some("thr-1"));
            assert!(req.correlation_id.is_some());
            empty_worker_response_for(&req)
        });
        let outbound = forward_to_worker(&client, &channel, payload, &config, None)
            .await
            .unwrap();
        assert_eq!(outbound.len(), 0);
    }
}