1use crate::{ChannelMessage, OutboundEnvelope, TenantCtx};
2use async_trait::async_trait;
3use serde_json::{Map, Value};
4use std::collections::BTreeMap;
5use std::time::Duration;
6use thiserror::Error;
7use time::OffsetDateTime;
8use tracing::{error, warn};
9use uuid::Uuid;
10
/// Version string written into the `version` field of every worker request built by this module.
pub const WORKER_ENVELOPE_VERSION: &str = "1.0";
/// Worker id used when `REPO_WORKER_ID` is unset or empty, and by `WorkerRoutingConfig::default()`.
pub const DEFAULT_WORKER_ID: &str = "greentic-repo-assistant";
/// NATS subject used when `REPO_WORKER_NATS_SUBJECT` is unset or empty; also the subject recorded
/// on HTTP routes, which have no subject of their own.
pub const DEFAULT_WORKER_NATS_SUBJECT: &str = "workers.repo-assistant";
17
18pub use greentic_types::{WorkerMessage, WorkerRequest, WorkerResponse};
19
/// Transport selector for reaching a worker.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum WorkerTransport {
    /// Route requests over NATS (the fallback for missing or unrecognised names).
    Nats,
    /// Route requests over HTTP.
    Http,
}

impl WorkerTransport {
    /// Parses a transport name taken from configuration.
    ///
    /// The comparison ignores ASCII case and surrounding whitespace, so values
    /// such as `"HTTP"` or `" http "` select [`WorkerTransport::Http`].
    /// `None` and every other value fall back to [`WorkerTransport::Nats`].
    pub fn from_env(value: Option<String>) -> Self {
        // Trimming matters for hand-edited env vars and for route specs written
        // with spaces around the separators (`id = http : url`).
        match value.as_deref().map(str::trim) {
            Some(name) if name.eq_ignore_ascii_case("http") => WorkerTransport::Http,
            _ => WorkerTransport::Nats,
        }
    }
}
39
/// Routing settings describing which worker to call and how to reach it.
#[derive(Clone, Debug)]
pub struct WorkerRoutingConfig {
    /// Transport selected for this worker.
    pub transport: WorkerTransport,
    /// Worker id copied into the `worker_id` field of outgoing requests.
    pub worker_id: String,
    /// NATS subject for the worker; HTTP routes built by `from_route_spec` carry the default subject.
    pub nats_subject: String,
    /// HTTP endpoint for the worker, when one is configured.
    pub http_url: Option<String>,
    /// Retry budget; presumably passed to a client's `max_retries` — the wiring is not in this file.
    pub max_retries: u8,
}
50
51impl Default for WorkerRoutingConfig {
52 fn default() -> Self {
53 Self {
54 transport: WorkerTransport::Nats,
55 worker_id: DEFAULT_WORKER_ID.to_string(),
56 nats_subject: DEFAULT_WORKER_NATS_SUBJECT.to_string(),
57 http_url: None,
58 max_retries: 2,
59 }
60 }
61}
62
63impl WorkerRoutingConfig {
64 pub fn from_env() -> Self {
65 let transport = WorkerTransport::from_env(std::env::var("REPO_WORKER_TRANSPORT").ok());
66 let worker_id = std::env::var("REPO_WORKER_ID")
67 .ok()
68 .filter(|v| !v.is_empty())
69 .unwrap_or_else(|| DEFAULT_WORKER_ID.to_string());
70 let nats_subject = std::env::var("REPO_WORKER_NATS_SUBJECT")
71 .ok()
72 .filter(|v| !v.is_empty())
73 .unwrap_or_else(|| DEFAULT_WORKER_NATS_SUBJECT.to_string());
74 let http_url = std::env::var("REPO_WORKER_HTTP_URL").ok();
75 let max_retries = std::env::var("REPO_WORKER_RETRIES")
76 .ok()
77 .and_then(|v| v.parse::<u8>().ok())
78 .unwrap_or(2);
79
80 Self {
81 transport,
82 worker_id,
83 nats_subject,
84 http_url,
85 max_retries,
86 }
87 }
88
89 pub fn from_route_spec(worker_id: &str, transport: WorkerTransport, target: &str) -> Self {
90 match transport {
91 WorkerTransport::Nats => WorkerRoutingConfig {
92 transport,
93 worker_id: worker_id.to_string(),
94 nats_subject: target.to_string(),
95 http_url: None,
96 max_retries: 2,
97 },
98 WorkerTransport::Http => WorkerRoutingConfig {
99 transport,
100 worker_id: worker_id.to_string(),
101 nats_subject: DEFAULT_WORKER_NATS_SUBJECT.to_string(),
102 http_url: Some(target.to_string()),
103 max_retries: 2,
104 },
105 }
106 }
107}
108
109pub fn worker_routes_from_env() -> BTreeMap<String, WorkerRoutingConfig> {
114 let raw = match std::env::var("WORKER_ROUTES") {
115 Ok(v) => v,
116 Err(_) => return BTreeMap::new(),
117 };
118 let mut map = BTreeMap::new();
119 for entry in raw.split(',').map(str::trim).filter(|s| !s.is_empty()) {
120 if let Some((id, spec)) = entry.split_once('=')
121 && let Some((transport_raw, target)) = spec.split_once(':')
122 {
123 let transport = WorkerTransport::from_env(Some(transport_raw.to_string()));
124 let cfg = WorkerRoutingConfig::from_route_spec(id.trim(), transport, target.trim());
125 map.insert(id.trim().to_string(), cfg);
126 }
127 }
128 map
129}
130
131fn now_timestamp_utc() -> String {
132 OffsetDateTime::now_utc()
133 .format(&time::format_description::well_known::Rfc3339)
134 .unwrap_or_else(|_| OffsetDateTime::now_utc().unix_timestamp().to_string())
135}
136
137fn encode_payload(payload: &Value) -> Result<String, WorkerClientError> {
138 serde_json::to_string(payload).map_err(WorkerClientError::PayloadEncode)
139}
140
141fn decode_payload(payload_json: &str) -> Value {
142 serde_json::from_str(payload_json).unwrap_or_else(|_| Value::String(payload_json.to_string()))
143}
144
145fn build_worker_request(
146 tenant: TenantCtx,
147 worker_id: String,
148 payload: Value,
149 session_id: Option<String>,
150 thread_id: Option<String>,
151 correlation_id: Option<String>,
152) -> Result<WorkerRequest, WorkerClientError> {
153 Ok(WorkerRequest {
154 version: WORKER_ENVELOPE_VERSION.to_string(),
155 tenant,
156 worker_id,
157 correlation_id: Some(correlation_id.unwrap_or_else(|| Uuid::new_v4().to_string())),
158 session_id,
159 thread_id,
160 payload_json: encode_payload(&payload)?,
161 timestamp_utc: now_timestamp_utc(),
162 })
163}
164
165fn worker_request_from_channel(
166 channel: &ChannelMessage,
167 payload: Value,
168 config: &WorkerRoutingConfig,
169 correlation_id: Option<String>,
170) -> Result<WorkerRequest, WorkerClientError> {
171 let correlation = correlation_id
172 .or_else(|| {
173 channel
174 .payload
175 .get("correlation_id")
176 .and_then(|v| v.as_str())
177 .map(str::to_string)
178 })
179 .or_else(|| {
180 channel
181 .payload
182 .get("msg_id")
183 .and_then(|v| v.as_str())
184 .map(str::to_string)
185 });
186
187 let thread_id = channel
188 .payload
189 .get("thread_id")
190 .and_then(|v| v.as_str())
191 .map(str::to_string);
192
193 build_worker_request(
194 channel.tenant.clone(),
195 config.worker_id.clone(),
196 payload,
197 Some(channel.session_id.clone()),
198 thread_id,
199 correlation,
200 )
201}
202
203pub fn empty_worker_response_for(request: &WorkerRequest) -> WorkerResponse {
204 WorkerResponse {
205 version: request.version.clone(),
206 tenant: request.tenant.clone(),
207 worker_id: request.worker_id.clone(),
208 correlation_id: request.correlation_id.clone(),
209 session_id: request.session_id.clone(),
210 thread_id: request.thread_id.clone(),
211 messages: Vec::new(),
212 timestamp_utc: now_timestamp_utc(),
213 }
214}
215
216pub fn worker_messages_to_outbound(
218 response: &WorkerResponse,
219 channel: &ChannelMessage,
220) -> Vec<OutboundEnvelope> {
221 response
222 .messages
223 .iter()
224 .map(|msg| {
225 let mut meta = Map::new();
226 meta.insert(
227 "worker_id".into(),
228 Value::String(response.worker_id.clone()),
229 );
230 if let Some(corr) = &response.correlation_id {
231 meta.insert("correlation_id".into(), Value::String(corr.clone()));
232 }
233 meta.insert("kind".into(), Value::String(msg.kind.clone()));
234
235 OutboundEnvelope {
236 tenant: channel.tenant.clone(),
237 channel_id: channel.channel_id.clone(),
238 session_id: channel.session_id.clone(),
239 meta: Value::Object(meta),
240 body: decode_payload(&msg.payload_json),
241 }
242 })
243 .collect()
244}
245
/// Errors raised while building, sending, or decoding worker requests.
#[derive(Debug, Error)]
pub enum WorkerClientError {
    /// The caller-supplied payload could not be serialized to a JSON string.
    #[error("failed to encode worker payload: {0}")]
    PayloadEncode(#[source] serde_json::Error),
    /// The request envelope could not be serialized before being sent (used by the NATS client).
    #[error("failed to serialize worker request: {0}")]
    Serialize(#[source] serde_json::Error),
    /// The worker's reply could not be parsed as a `WorkerResponse`.
    #[error("failed to deserialize worker response: {0}")]
    Deserialize(#[source] serde_json::Error),
    /// The NATS request itself failed.
    #[error("NATS request failed: {0}")]
    Nats(#[source] anyhow::Error),
    /// The HTTP request failed, the body could not be read, or the endpoint returned a non-success status.
    #[error("HTTP request failed: {0}")]
    Http(#[source] anyhow::Error),
}
259
/// Abstraction over a transport that delivers a request to a worker and returns its response.
///
/// Implementations in this file: `InMemoryWorkerClient`, `HttpWorkerClient`, and
/// (behind the `nats` feature) `NatsWorkerClient`.
#[async_trait]
pub trait WorkerClient: Send + Sync {
    /// Sends `request` to the worker and returns its response, or the error that prevented one.
    async fn send_request(
        &self,
        request: WorkerRequest,
    ) -> Result<WorkerResponse, WorkerClientError>;
}
267
/// Worker client that answers requests by calling a local closure; the tests in this file use it.
pub struct InMemoryWorkerClient {
    // Closure invoked once per request to produce the response.
    responder: Box<dyn Fn(WorkerRequest) -> WorkerResponse + Send + Sync>,
}
272
273impl InMemoryWorkerClient {
274 pub fn new<F>(responder: F) -> Self
275 where
276 F: Fn(WorkerRequest) -> WorkerResponse + Send + Sync + 'static,
277 {
278 Self {
279 responder: Box::new(responder),
280 }
281 }
282}
283
284#[async_trait]
285impl WorkerClient for InMemoryWorkerClient {
286 async fn send_request(
287 &self,
288 request: WorkerRequest,
289 ) -> Result<WorkerResponse, WorkerClientError> {
290 Ok((self.responder)(request))
291 }
292}
293
294pub async fn forward_to_worker(
296 client: &dyn WorkerClient,
297 channel: &ChannelMessage,
298 payload: Value,
299 config: &WorkerRoutingConfig,
300 correlation_id: Option<String>,
301) -> Result<Vec<OutboundEnvelope>, WorkerClientError> {
302 let request = worker_request_from_channel(channel, payload, config, correlation_id)?;
303 let response = client.send_request(request).await?;
304 Ok(worker_messages_to_outbound(&response, channel))
305}
306
/// Worker client that sends requests over NATS request/reply (only built with the `nats` feature).
#[cfg(feature = "nats")]
pub struct NatsWorkerClient {
    // Connected NATS client used to issue requests.
    client: async_nats::Client,
    // Subject every request is published to.
    subject: String,
    // Number of retries allowed after the first failed attempt.
    max_retries: u8,
}
313
#[cfg(feature = "nats")]
impl NatsWorkerClient {
    /// Creates a client that sends requests to `subject` through `client`,
    /// retrying a failed request up to `max_retries` times.
    pub fn new(client: async_nats::Client, subject: String, max_retries: u8) -> Self {
        NatsWorkerClient {
            client,
            subject,
            max_retries,
        }
    }

    /// Performs one request/reply round trip: serializes `request` as JSON,
    /// sends it to the configured subject, and parses the reply payload.
    async fn send_once(
        &self,
        request: &WorkerRequest,
    ) -> Result<WorkerResponse, WorkerClientError> {
        let encoded = serde_json::to_vec(request).map_err(WorkerClientError::Serialize)?;
        let subject = self.subject.clone();

        let reply = self
            .client
            .request(subject, encoded.into())
            .await
            .map_err(anyhow::Error::new)
            .map_err(WorkerClientError::Nats)?;

        serde_json::from_slice(&reply.payload).map_err(WorkerClientError::Deserialize)
    }
}
337
#[cfg(feature = "nats")]
#[async_trait]
impl WorkerClient for NatsWorkerClient {
    /// Sends `request` over NATS, retrying on any error.
    ///
    /// Up to `max_retries + 1` attempts are made. After the n-th failed attempt
    /// the client sleeps `50 * n` milliseconds before trying again.
    ///
    /// # Errors
    /// Returns the error from the last attempt once the retry budget is used up.
    async fn send_request(
        &self,
        request: WorkerRequest,
    ) -> Result<WorkerResponse, WorkerClientError> {
        // Count in a type wider than `max_retries` (u8): with a u8 counter and
        // `max_retries == 255`, `attempt += 1` overflows — a panic in debug
        // builds, a wrap to 0 and endless retrying in release builds.
        let max_retries = u32::from(self.max_retries);
        let mut attempt: u32 = 0;
        loop {
            attempt += 1;
            let err = match self.send_once(&request).await {
                Ok(res) => return Ok(res),
                Err(err) => err,
            };
            if attempt > max_retries {
                // Log the terminal failure, as the HTTP client does.
                error!(attempt, subject = %self.subject, error = %err, "worker NATS request failed");
                return Err(err);
            }
            warn!(attempt, subject = %self.subject, error = %err, "retrying worker request over NATS");
            tokio::time::sleep(Duration::from_millis(50 * u64::from(attempt))).await;
        }
    }
}
361
/// Worker client that POSTs requests as JSON to an HTTP endpoint.
pub struct HttpWorkerClient {
    // HTTP client reused for every request.
    client: reqwest::Client,
    // Endpoint every request is POSTed to.
    url: String,
    // Number of retries allowed after the first failed attempt.
    max_retries: u8,
}
367
368impl HttpWorkerClient {
369 pub fn new(url: String, max_retries: u8) -> Self {
370 Self {
371 client: reqwest::Client::new(),
372 url,
373 max_retries,
374 }
375 }
376
377 async fn send_once(
378 &self,
379 request: &WorkerRequest,
380 ) -> Result<WorkerResponse, WorkerClientError> {
381 let response = self
382 .client
383 .post(&self.url)
384 .json(request)
385 .send()
386 .await
387 .map_err(|e| WorkerClientError::Http(anyhow::Error::new(e)))?;
388
389 if !response.status().is_success() {
390 let status = response.status();
391 let body = response.text().await.unwrap_or_default();
392 return Err(WorkerClientError::Http(anyhow::anyhow!(
393 "HTTP {} from worker endpoint: {}",
394 status,
395 body
396 )));
397 }
398
399 let body = response
400 .bytes()
401 .await
402 .map_err(|e| WorkerClientError::Http(anyhow::Error::new(e)))?;
403 serde_json::from_slice(&body).map_err(WorkerClientError::Deserialize)
404 }
405}
406
407#[async_trait]
408impl WorkerClient for HttpWorkerClient {
409 async fn send_request(
410 &self,
411 request: WorkerRequest,
412 ) -> Result<WorkerResponse, WorkerClientError> {
413 let mut attempt = 0;
414 loop {
415 attempt += 1;
416 match self.send_once(&request).await {
417 Ok(res) => return Ok(res),
418 Err(err) => {
419 if attempt > self.max_retries {
420 error!(attempt, url = %self.url, error = %err, "worker HTTP request failed");
421 return Err(err);
422 }
423 warn!(attempt, url = %self.url, error = %err, "retrying worker HTTP request");
424 tokio::time::sleep(Duration::from_millis(50 * attempt as u64)).await;
425 }
426 }
427 }
428 }
429}
430
// Unit tests that drive `forward_to_worker` end to end through `InMemoryWorkerClient`.
#[cfg(test)]
mod tests {
    use super::*;

    // Minimal inbound message: tenant "acme", channel "webchat", session "sess-1".
    fn sample_channel() -> ChannelMessage {
        ChannelMessage {
            tenant: crate::make_tenant_ctx("acme".into(), Some("team".into()), None),
            channel_id: "webchat".into(),
            session_id: "sess-1".into(),
            route: None,
            payload: serde_json::json!({"text": "hi"}),
        }
    }

    // Checks the request envelope built from a channel message and the mapping of a
    // one-message worker response onto an outbound envelope.
    #[tokio::test]
    async fn builds_request_and_maps_response() {
        let channel = sample_channel();
        let config = WorkerRoutingConfig::default();
        let payload = serde_json::json!({"body": "hello"});
        let corr = Some("corr-1".to_string());
        let client = InMemoryWorkerClient::new(|req| {
            // Request side: version, worker id, session, and the explicit correlation id.
            assert_eq!(req.version, WORKER_ENVELOPE_VERSION);
            assert_eq!(req.worker_id, DEFAULT_WORKER_ID);
            assert_eq!(req.session_id.as_deref(), Some("sess-1"));
            assert_eq!(req.correlation_id.as_deref(), Some("corr-1"));
            let decoded: Value = serde_json::from_str(&req.payload_json).unwrap();
            assert_eq!(decoded["body"], "hello");
            let mut resp = empty_worker_response_for(&req);
            resp.messages = vec![WorkerMessage {
                kind: "text".into(),
                payload_json: serde_json::to_string(&serde_json::json!({"reply": "pong"})).unwrap(),
            }];
            resp
        });

        let outbound = forward_to_worker(&client, &channel, payload, &config, corr)
            .await
            .unwrap();

        // Response side: one envelope addressed to the originating channel, with worker metadata.
        assert_eq!(outbound.len(), 1);
        assert_eq!(outbound[0].channel_id, "webchat");
        assert_eq!(outbound[0].body["reply"], "pong");
        assert_eq!(outbound[0].tenant.tenant.as_str(), "acme");
        assert_eq!(outbound[0].session_id, "sess-1");
        assert_eq!(outbound[0].meta["kind"], "text");
        assert_eq!(outbound[0].meta["worker_id"], DEFAULT_WORKER_ID);
        assert_eq!(outbound[0].meta["correlation_id"], "corr-1");
    }

    // With no explicit correlation id, the request still carries one, and the thread id is
    // taken from the channel payload.
    #[tokio::test]
    async fn populates_thread_and_correlation_defaults() {
        let mut channel = sample_channel();
        channel.payload = serde_json::json!({"text": "ping", "thread_id": "thr-1"});
        let config = WorkerRoutingConfig::default();
        let payload = serde_json::json!({"body": "hello"});

        let client = InMemoryWorkerClient::new(|req| {
            assert_eq!(req.thread_id.as_deref(), Some("thr-1"));
            assert!(req.correlation_id.is_some());
            empty_worker_response_for(&req)
        });

        let outbound = forward_to_worker(&client, &channel, payload, &config, None)
            .await
            .unwrap();

        // An empty worker response produces no outbound envelopes.
        assert_eq!(outbound.len(), 0);
    }
}
499}