1use crate::{ChannelMessage, OutboundEnvelope, TenantCtx};
2use async_trait::async_trait;
3use serde_json::{Map, Value};
4use std::collections::BTreeMap;
5use std::str::FromStr;
6use std::time::Duration;
7use thiserror::Error;
8use time::OffsetDateTime;
9use tracing::{error, warn};
10use uuid::Uuid;
11
/// Envelope schema version stamped into every [`WorkerRequest`] built by this module.
pub const WORKER_ENVELOPE_VERSION: &str = "1.0";

/// Worker identifier used by [`WorkerRoutingConfig::default`].
pub const DEFAULT_WORKER_ID: &str = "greentic-repo-assistant";

/// NATS subject used by [`WorkerRoutingConfig::default`], and kept as the (unused)
/// subject of HTTP routes built by [`WorkerRoutingConfig::from_route_spec`].
pub const DEFAULT_WORKER_NATS_SUBJECT: &str = "workers.repo-assistant";
18
19pub use greentic_types::{WorkerMessage, WorkerRequest, WorkerResponse};
20
/// Transport used to reach a worker.
///
/// Parsing is lenient: surrounding whitespace is ignored, the comparison is
/// case-insensitive, and anything other than `"http"` (including unknown names)
/// selects [`WorkerTransport::Nats`].
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum WorkerTransport {
    /// Request/reply over a NATS subject.
    Nats,
    /// JSON `POST` to an HTTP endpoint.
    Http,
}

impl WorkerTransport {
    /// Parses an optional transport name, defaulting to [`WorkerTransport::Nats`] when
    /// `value` is `None`.
    pub fn from_optional(value: Option<&str>) -> Self {
        value.map_or(WorkerTransport::Nats, Self::parse_lenient)
    }

    /// Shared parser for [`WorkerTransport::from_optional`] and [`FromStr`].
    ///
    /// Trims first so values lifted out of comma-separated config (e.g. `" http"`) are
    /// not silently treated as NATS; compares without allocating a lowercase copy.
    fn parse_lenient(raw: &str) -> Self {
        if raw.trim().eq_ignore_ascii_case("http") {
            WorkerTransport::Http
        } else {
            WorkerTransport::Nats
        }
    }
}

impl FromStr for WorkerTransport {
    /// Parsing never fails; unknown values fall back to [`WorkerTransport::Nats`].
    type Err = ();

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(Self::parse_lenient(s))
    }
}
47
48#[derive(Clone, Debug)]
50pub struct WorkerRoutingConfig {
51 pub transport: WorkerTransport,
52 pub worker_id: String,
53 pub nats_subject: String,
54 pub http_url: Option<String>,
55 pub max_retries: u8,
57}
58
59impl Default for WorkerRoutingConfig {
60 fn default() -> Self {
61 Self {
62 transport: WorkerTransport::Nats,
63 worker_id: DEFAULT_WORKER_ID.to_string(),
64 nats_subject: DEFAULT_WORKER_NATS_SUBJECT.to_string(),
65 http_url: None,
66 max_retries: 2,
67 }
68 }
69}
70
71impl WorkerRoutingConfig {
72 pub fn from_settings(
73 transport: WorkerTransport,
74 worker_id: String,
75 nats_subject: String,
76 http_url: Option<String>,
77 max_retries: u8,
78 ) -> Self {
79 Self {
80 transport,
81 worker_id,
82 nats_subject,
83 http_url,
84 max_retries,
85 }
86 }
87
88 pub fn from_route_spec(worker_id: &str, transport: WorkerTransport, target: &str) -> Self {
89 match transport {
90 WorkerTransport::Nats => WorkerRoutingConfig {
91 transport,
92 worker_id: worker_id.to_string(),
93 nats_subject: target.to_string(),
94 http_url: None,
95 max_retries: 2,
96 },
97 WorkerTransport::Http => WorkerRoutingConfig {
98 transport,
99 worker_id: worker_id.to_string(),
100 nats_subject: DEFAULT_WORKER_NATS_SUBJECT.to_string(),
101 http_url: Some(target.to_string()),
102 max_retries: 2,
103 },
104 }
105 }
106}
107
108pub fn worker_routes_from_specs(raw: &str) -> BTreeMap<String, WorkerRoutingConfig> {
113 let mut map = BTreeMap::new();
114 for entry in raw.split(',').map(str::trim).filter(|s| !s.is_empty()) {
115 if let Some((id, spec)) = entry.split_once('=')
116 && let Some((transport_raw, target)) = spec.split_once(':')
117 {
118 let transport = WorkerTransport::from_optional(Some(transport_raw));
119 let cfg = WorkerRoutingConfig::from_route_spec(id.trim(), transport, target.trim());
120 map.insert(id.trim().to_string(), cfg);
121 }
122 }
123 map
124}
125
126fn now_timestamp_utc() -> String {
127 OffsetDateTime::now_utc()
128 .format(&time::format_description::well_known::Rfc3339)
129 .unwrap_or_else(|_| OffsetDateTime::now_utc().unix_timestamp().to_string())
130}
131
132fn encode_payload(payload: &Value) -> Result<String, WorkerClientError> {
133 serde_json::to_string(payload).map_err(WorkerClientError::PayloadEncode)
134}
135
136fn decode_payload(payload_json: &str) -> Value {
137 serde_json::from_str(payload_json).unwrap_or_else(|_| Value::String(payload_json.to_string()))
138}
139
140fn build_worker_request(
141 tenant: TenantCtx,
142 worker_id: String,
143 payload: Value,
144 session_id: Option<String>,
145 thread_id: Option<String>,
146 correlation_id: Option<String>,
147) -> Result<WorkerRequest, WorkerClientError> {
148 Ok(WorkerRequest {
149 version: WORKER_ENVELOPE_VERSION.to_string(),
150 tenant,
151 worker_id,
152 correlation_id: Some(correlation_id.unwrap_or_else(|| Uuid::new_v4().to_string())),
153 session_id,
154 thread_id,
155 payload_json: encode_payload(&payload)?,
156 timestamp_utc: now_timestamp_utc(),
157 })
158}
159
160fn worker_request_from_channel(
161 channel: &ChannelMessage,
162 payload: Value,
163 config: &WorkerRoutingConfig,
164 correlation_id: Option<String>,
165) -> Result<WorkerRequest, WorkerClientError> {
166 let correlation = correlation_id
167 .or_else(|| {
168 channel
169 .payload
170 .get("correlation_id")
171 .and_then(|v| v.as_str())
172 .map(str::to_string)
173 })
174 .or_else(|| {
175 channel
176 .payload
177 .get("msg_id")
178 .and_then(|v| v.as_str())
179 .map(str::to_string)
180 });
181
182 let thread_id = channel
183 .payload
184 .get("thread_id")
185 .and_then(|v| v.as_str())
186 .map(str::to_string);
187
188 build_worker_request(
189 channel.tenant.clone(),
190 config.worker_id.clone(),
191 payload,
192 Some(channel.session_id.clone()),
193 thread_id,
194 correlation,
195 )
196}
197
198pub fn empty_worker_response_for(request: &WorkerRequest) -> WorkerResponse {
199 WorkerResponse {
200 version: request.version.clone(),
201 tenant: request.tenant.clone(),
202 worker_id: request.worker_id.clone(),
203 correlation_id: request.correlation_id.clone(),
204 session_id: request.session_id.clone(),
205 thread_id: request.thread_id.clone(),
206 messages: Vec::new(),
207 timestamp_utc: now_timestamp_utc(),
208 }
209}
210
211pub fn worker_messages_to_outbound(
213 response: &WorkerResponse,
214 channel: &ChannelMessage,
215) -> Vec<OutboundEnvelope> {
216 response
217 .messages
218 .iter()
219 .map(|msg| {
220 let mut meta = Map::new();
221 meta.insert(
222 "worker_id".into(),
223 Value::String(response.worker_id.clone()),
224 );
225 if let Some(corr) = &response.correlation_id {
226 meta.insert("correlation_id".into(), Value::String(corr.clone()));
227 }
228 meta.insert("kind".into(), Value::String(msg.kind.clone()));
229
230 OutboundEnvelope {
231 tenant: channel.tenant.clone(),
232 channel_id: channel.channel_id.clone(),
233 session_id: channel.session_id.clone(),
234 meta: Value::Object(meta),
235 body: decode_payload(&msg.payload_json),
236 }
237 })
238 .collect()
239}
240
241#[derive(Debug, Error)]
242pub enum WorkerClientError {
243 #[error("failed to encode worker payload: {0}")]
244 PayloadEncode(#[source] serde_json::Error),
245 #[error("failed to serialize worker request: {0}")]
246 Serialize(#[source] serde_json::Error),
247 #[error("failed to deserialize worker response: {0}")]
248 Deserialize(#[source] serde_json::Error),
249 #[error("NATS request failed: {0}")]
250 Nats(#[source] anyhow::Error),
251 #[error("HTTP request failed: {0}")]
252 Http(#[source] anyhow::Error),
253}
254
255#[async_trait]
256pub trait WorkerClient: Send + Sync {
257 async fn send_request(
258 &self,
259 request: WorkerRequest,
260 ) -> Result<WorkerResponse, WorkerClientError>;
261}
262
263pub struct InMemoryWorkerClient {
265 responder: Box<dyn Fn(WorkerRequest) -> WorkerResponse + Send + Sync>,
266}
267
268impl InMemoryWorkerClient {
269 pub fn new<F>(responder: F) -> Self
270 where
271 F: Fn(WorkerRequest) -> WorkerResponse + Send + Sync + 'static,
272 {
273 Self {
274 responder: Box::new(responder),
275 }
276 }
277}
278
279#[async_trait]
280impl WorkerClient for InMemoryWorkerClient {
281 async fn send_request(
282 &self,
283 request: WorkerRequest,
284 ) -> Result<WorkerResponse, WorkerClientError> {
285 Ok((self.responder)(request))
286 }
287}
288
289pub async fn forward_to_worker(
291 client: &dyn WorkerClient,
292 channel: &ChannelMessage,
293 payload: Value,
294 config: &WorkerRoutingConfig,
295 correlation_id: Option<String>,
296) -> Result<Vec<OutboundEnvelope>, WorkerClientError> {
297 let request = worker_request_from_channel(channel, payload, config, correlation_id)?;
298 let response = client.send_request(request).await?;
299 Ok(worker_messages_to_outbound(&response, channel))
300}
301
/// [`WorkerClient`] that performs NATS request/reply on a fixed subject.
#[cfg(feature = "nats")]
pub struct NatsWorkerClient {
    client: async_nats::Client,
    subject: String,
    /// Retries allowed after the first failed attempt (total attempts = `max_retries + 1`).
    max_retries: u8,
}

#[cfg(feature = "nats")]
impl NatsWorkerClient {
    /// Creates a client that sends requests to `subject` over `client`.
    pub fn new(client: async_nats::Client, subject: String, max_retries: u8) -> Self {
        Self {
            client,
            subject,
            max_retries,
        }
    }

    /// Performs a single request/reply round trip, without retrying.
    ///
    /// # Errors
    /// Returns [`WorkerClientError::Serialize`] if the request cannot be encoded,
    /// [`WorkerClientError::Nats`] if the NATS call fails, and
    /// [`WorkerClientError::Deserialize`] if the reply is not a valid [`WorkerResponse`].
    async fn send_once(
        &self,
        request: &WorkerRequest,
    ) -> Result<WorkerResponse, WorkerClientError> {
        let bytes = serde_json::to_vec(request).map_err(WorkerClientError::Serialize)?;
        let msg = self
            .client
            .request(self.subject.clone(), bytes.into())
            .await
            .map_err(|e| WorkerClientError::Nats(anyhow::Error::new(e)))?;
        serde_json::from_slice(&msg.payload).map_err(WorkerClientError::Deserialize)
    }
}

#[cfg(feature = "nats")]
#[async_trait]
impl WorkerClient for NatsWorkerClient {
    /// Sends `request`, retrying failed attempts with a linear backoff (50 ms, 100 ms, ...).
    ///
    /// Serialisation errors are returned immediately, because re-encoding the same request
    /// cannot succeed. The final failure is logged at error level, matching the HTTP client.
    async fn send_request(
        &self,
        request: WorkerRequest,
    ) -> Result<WorkerResponse, WorkerClientError> {
        // Count in `u32`. A `u8` counter would overflow when `max_retries == u8::MAX`:
        // a panic in debug builds, or wrap-around to an endless retry loop in release builds.
        let max_retries = u32::from(self.max_retries);
        let mut attempt: u32 = 0;
        loop {
            attempt += 1;
            let err = match self.send_once(&request).await {
                Ok(res) => return Ok(res),
                Err(err) => err,
            };
            if attempt > max_retries || matches!(err, WorkerClientError::Serialize(_)) {
                error!(attempt, subject = %self.subject, error = %err, "worker NATS request failed");
                return Err(err);
            }
            warn!(attempt, subject = %self.subject, error = %err, "retrying worker request over NATS");
            tokio::time::sleep(Duration::from_millis(50 * u64::from(attempt))).await;
        }
    }
}
356
357pub struct HttpWorkerClient {
358 client: reqwest::Client,
359 url: String,
360 max_retries: u8,
361}
362
363impl HttpWorkerClient {
364 pub fn new(url: String, max_retries: u8) -> Self {
365 Self {
366 client: reqwest::Client::new(),
367 url,
368 max_retries,
369 }
370 }
371
372 async fn send_once(
373 &self,
374 request: &WorkerRequest,
375 ) -> Result<WorkerResponse, WorkerClientError> {
376 let response = self
377 .client
378 .post(&self.url)
379 .json(request)
380 .send()
381 .await
382 .map_err(|e| WorkerClientError::Http(anyhow::Error::new(e)))?;
383
384 if !response.status().is_success() {
385 let status = response.status();
386 let body = response.text().await.unwrap_or_default();
387 return Err(WorkerClientError::Http(anyhow::anyhow!(
388 "HTTP {} from worker endpoint: {}",
389 status,
390 body
391 )));
392 }
393
394 let body = response
395 .bytes()
396 .await
397 .map_err(|e| WorkerClientError::Http(anyhow::Error::new(e)))?;
398 serde_json::from_slice(&body).map_err(WorkerClientError::Deserialize)
399 }
400}
401
402#[async_trait]
403impl WorkerClient for HttpWorkerClient {
404 async fn send_request(
405 &self,
406 request: WorkerRequest,
407 ) -> Result<WorkerResponse, WorkerClientError> {
408 let mut attempt = 0;
409 loop {
410 attempt += 1;
411 match self.send_once(&request).await {
412 Ok(res) => return Ok(res),
413 Err(err) => {
414 if attempt > self.max_retries {
415 error!(attempt, url = %self.url, error = %err, "worker HTTP request failed");
416 return Err(err);
417 }
418 warn!(attempt, url = %self.url, error = %err, "retrying worker HTTP request");
419 tokio::time::sleep(Duration::from_millis(50 * attempt as u64)).await;
420 }
421 }
422 }
423 }
424}
425
#[cfg(test)]
mod tests {
    use super::*;

    fn sample_channel() -> ChannelMessage {
        ChannelMessage {
            tenant: crate::make_tenant_ctx("acme".into(), Some("team".into()), None),
            channel_id: "webchat".into(),
            session_id: "sess-1".into(),
            route: None,
            payload: serde_json::json!({"text": "hi"}),
        }
    }

    #[tokio::test]
    async fn builds_request_and_maps_response() {
        let channel = sample_channel();
        let config = WorkerRoutingConfig::default();
        let payload = serde_json::json!({"body": "hello"});
        let corr = Some("corr-1".to_string());
        let client = InMemoryWorkerClient::new(|req| {
            assert_eq!(req.version, WORKER_ENVELOPE_VERSION);
            assert_eq!(req.worker_id, DEFAULT_WORKER_ID);
            assert_eq!(req.session_id.as_deref(), Some("sess-1"));
            assert_eq!(req.correlation_id.as_deref(), Some("corr-1"));
            let decoded: Value = serde_json::from_str(&req.payload_json).unwrap();
            assert_eq!(decoded["body"], "hello");
            let mut resp = empty_worker_response_for(&req);
            resp.messages = vec![WorkerMessage {
                kind: "text".into(),
                payload_json: serde_json::to_string(&serde_json::json!({"reply": "pong"})).unwrap(),
            }];
            resp
        });

        let outbound = forward_to_worker(&client, &channel, payload, &config, corr)
            .await
            .unwrap();

        assert_eq!(outbound.len(), 1);
        assert_eq!(outbound[0].channel_id, "webchat");
        assert_eq!(outbound[0].body["reply"], "pong");
        assert_eq!(outbound[0].tenant.tenant.as_str(), "acme");
        assert_eq!(outbound[0].session_id, "sess-1");
        assert_eq!(outbound[0].meta["kind"], "text");
        assert_eq!(outbound[0].meta["worker_id"], DEFAULT_WORKER_ID);
        assert_eq!(outbound[0].meta["correlation_id"], "corr-1");
    }

    #[tokio::test]
    async fn populates_thread_and_correlation_defaults() {
        let mut channel = sample_channel();
        channel.payload = serde_json::json!({"text": "ping", "thread_id": "thr-1"});
        let config = WorkerRoutingConfig::default();
        let payload = serde_json::json!({"body": "hello"});

        let client = InMemoryWorkerClient::new(|req| {
            assert_eq!(req.thread_id.as_deref(), Some("thr-1"));
            assert!(req.correlation_id.is_some());
            empty_worker_response_for(&req)
        });

        let outbound = forward_to_worker(&client, &channel, payload, &config, None)
            .await
            .unwrap();

        assert_eq!(outbound.len(), 0);
    }

    // The correlation id prefers the explicit argument, then the payload's "correlation_id",
    // then its "msg_id".
    #[test]
    fn correlation_prefers_explicit_then_payload_fields() {
        let mut channel = sample_channel();
        let config = WorkerRoutingConfig::default();

        channel.payload = serde_json::json!({"msg_id": "msg-7"});
        let req = worker_request_from_channel(&channel, Value::Null, &config, None).unwrap();
        assert_eq!(req.correlation_id.as_deref(), Some("msg-7"));

        channel.payload = serde_json::json!({"msg_id": "msg-7", "correlation_id": "corr-9"});
        let req = worker_request_from_channel(&channel, Value::Null, &config, None).unwrap();
        assert_eq!(req.correlation_id.as_deref(), Some("corr-9"));

        let req =
            worker_request_from_channel(&channel, Value::Null, &config, Some("explicit".into()))
                .unwrap();
        assert_eq!(req.correlation_id.as_deref(), Some("explicit"));
    }

    // Non-JSON worker payloads are surfaced as JSON strings instead of failing.
    #[test]
    fn decode_payload_falls_back_to_raw_string() {
        assert_eq!(
            decode_payload("{\"a\":1}"),
            serde_json::json!({"a": 1})
        );
        assert_eq!(
            decode_payload("not json"),
            Value::String("not json".to_string())
        );
    }

    #[test]
    fn transport_parsing_is_case_insensitive_and_defaults_to_nats() {
        assert_eq!(WorkerTransport::from_optional(None), WorkerTransport::Nats);
        assert_eq!(
            WorkerTransport::from_optional(Some("HTTP")),
            WorkerTransport::Http
        );
        assert_eq!("Http".parse::<WorkerTransport>(), Ok(WorkerTransport::Http));
        assert_eq!("unknown".parse::<WorkerTransport>(), Ok(WorkerTransport::Nats));
    }

    #[test]
    fn parses_worker_route_specs() {
        let routes = worker_routes_from_specs(
            "repo=http:https://workers.example/run, chat=nats:workers.chat,bogus",
        );
        assert_eq!(routes.len(), 2);

        let repo = &routes["repo"];
        assert_eq!(repo.transport, WorkerTransport::Http);
        assert_eq!(repo.worker_id, "repo");
        // Only the first ':' splits transport from target, so the URL keeps its scheme.
        assert_eq!(repo.http_url.as_deref(), Some("https://workers.example/run"));
        assert_eq!(repo.nats_subject, DEFAULT_WORKER_NATS_SUBJECT);

        let chat = &routes["chat"];
        assert_eq!(chat.transport, WorkerTransport::Nats);
        assert_eq!(chat.worker_id, "chat");
        assert_eq!(chat.nats_subject, "workers.chat");
        assert!(chat.http_url.is_none());
        assert_eq!(chat.max_retries, WorkerRoutingConfig::default().max_retries);
    }
}