gsm_core/
worker.rs

1use crate::{ChannelMessage, OutboundEnvelope, TenantCtx};
2use async_trait::async_trait;
3use serde_json::{Map, Value};
4use std::collections::BTreeMap;
5use std::time::Duration;
6use thiserror::Error;
7use time::OffsetDateTime;
8use tracing::{error, warn};
9use uuid::Uuid;
10
/// Worker envelope schema version.
///
/// Written into the `version` field of every `WorkerRequest` built by this module.
pub const WORKER_ENVELOPE_VERSION: &str = "1.0";
/// Canonical default worker identifier for the repo assistant.
///
/// Used as `WorkerRoutingConfig::worker_id` when `REPO_WORKER_ID` is unset or empty.
pub const DEFAULT_WORKER_ID: &str = "greentic-repo-assistant";
/// Default NATS subject used when no override is provided.
///
/// Also fills `WorkerRoutingConfig::nats_subject` for routes that use the HTTP transport.
pub const DEFAULT_WORKER_NATS_SUBJECT: &str = "workers.repo-assistant";
17
18pub use greentic_types::{WorkerMessage, WorkerRequest, WorkerResponse};
19
/// Which transport to use for talking to the worker endpoint.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum WorkerTransport {
    /// Request/reply over a NATS subject.
    Nats,
    /// JSON `POST` to an HTTP endpoint.
    Http,
}

impl WorkerTransport {
    /// Resolves a transport from an optional configuration string.
    ///
    /// Matching ignores ASCII case and surrounding whitespace, so values such as
    /// `" HTTP "` (easy to produce in hand-written env files or route lists) are
    /// recognised. Only `http` selects [`WorkerTransport::Http`]; `None` and every
    /// other value fall back to [`WorkerTransport::Nats`].
    pub fn from_env(value: Option<String>) -> Self {
        match value.as_deref().map(str::trim) {
            Some(name) if name.eq_ignore_ascii_case("http") => WorkerTransport::Http,
            _ => WorkerTransport::Nats,
        }
    }
}
39
/// Routing configuration for the repo worker.
///
/// Describes which worker to address and over which transport. Both the NATS subject
/// and the HTTP URL are carried so the same value can describe either kind of route.
#[derive(Clone, Debug)]
pub struct WorkerRoutingConfig {
    /// Transport used to reach the worker.
    pub transport: WorkerTransport,
    /// Identifier copied into the `worker_id` field of outgoing worker requests.
    pub worker_id: String,
    /// NATS subject for the worker; always populated, even for HTTP routes.
    pub nats_subject: String,
    /// HTTP endpoint of the worker, when one is configured.
    pub http_url: Option<String>,
    /// How many transient retries to attempt locally before surfacing an error.
    pub max_retries: u8,
}
50
51impl Default for WorkerRoutingConfig {
52    fn default() -> Self {
53        Self {
54            transport: WorkerTransport::Nats,
55            worker_id: DEFAULT_WORKER_ID.to_string(),
56            nats_subject: DEFAULT_WORKER_NATS_SUBJECT.to_string(),
57            http_url: None,
58            max_retries: 2,
59        }
60    }
61}
62
63impl WorkerRoutingConfig {
64    pub fn from_env() -> Self {
65        let transport = WorkerTransport::from_env(std::env::var("REPO_WORKER_TRANSPORT").ok());
66        let worker_id = std::env::var("REPO_WORKER_ID")
67            .ok()
68            .filter(|v| !v.is_empty())
69            .unwrap_or_else(|| DEFAULT_WORKER_ID.to_string());
70        let nats_subject = std::env::var("REPO_WORKER_NATS_SUBJECT")
71            .ok()
72            .filter(|v| !v.is_empty())
73            .unwrap_or_else(|| DEFAULT_WORKER_NATS_SUBJECT.to_string());
74        let http_url = std::env::var("REPO_WORKER_HTTP_URL").ok();
75        let max_retries = std::env::var("REPO_WORKER_RETRIES")
76            .ok()
77            .and_then(|v| v.parse::<u8>().ok())
78            .unwrap_or(2);
79
80        Self {
81            transport,
82            worker_id,
83            nats_subject,
84            http_url,
85            max_retries,
86        }
87    }
88
89    pub fn from_route_spec(worker_id: &str, transport: WorkerTransport, target: &str) -> Self {
90        match transport {
91            WorkerTransport::Nats => WorkerRoutingConfig {
92                transport,
93                worker_id: worker_id.to_string(),
94                nats_subject: target.to_string(),
95                http_url: None,
96                max_retries: 2,
97            },
98            WorkerTransport::Http => WorkerRoutingConfig {
99                transport,
100                worker_id: worker_id.to_string(),
101                nats_subject: DEFAULT_WORKER_NATS_SUBJECT.to_string(),
102                http_url: Some(target.to_string()),
103                max_retries: 2,
104            },
105        }
106    }
107}
108
109/// Parse a simple worker route map from the `WORKER_ROUTES` env var.
110///
111/// Format: `worker_id=transport:target,worker_id2=http:https://example`
112/// transport: `nats` uses `target` as subject; `http` uses `target` as URL.
113pub fn worker_routes_from_env() -> BTreeMap<String, WorkerRoutingConfig> {
114    let raw = match std::env::var("WORKER_ROUTES") {
115        Ok(v) => v,
116        Err(_) => return BTreeMap::new(),
117    };
118    let mut map = BTreeMap::new();
119    for entry in raw.split(',').map(str::trim).filter(|s| !s.is_empty()) {
120        if let Some((id, spec)) = entry.split_once('=')
121            && let Some((transport_raw, target)) = spec.split_once(':')
122        {
123            let transport = WorkerTransport::from_env(Some(transport_raw.to_string()));
124            let cfg = WorkerRoutingConfig::from_route_spec(id.trim(), transport, target.trim());
125            map.insert(id.trim().to_string(), cfg);
126        }
127    }
128    map
129}
130
131fn now_timestamp_utc() -> String {
132    OffsetDateTime::now_utc()
133        .format(&time::format_description::well_known::Rfc3339)
134        .unwrap_or_else(|_| OffsetDateTime::now_utc().unix_timestamp().to_string())
135}
136
137fn encode_payload(payload: &Value) -> Result<String, WorkerClientError> {
138    serde_json::to_string(payload).map_err(WorkerClientError::PayloadEncode)
139}
140
141fn decode_payload(payload_json: &str) -> Value {
142    serde_json::from_str(payload_json).unwrap_or_else(|_| Value::String(payload_json.to_string()))
143}
144
145fn build_worker_request(
146    tenant: TenantCtx,
147    worker_id: String,
148    payload: Value,
149    session_id: Option<String>,
150    thread_id: Option<String>,
151    correlation_id: Option<String>,
152) -> Result<WorkerRequest, WorkerClientError> {
153    Ok(WorkerRequest {
154        version: WORKER_ENVELOPE_VERSION.to_string(),
155        tenant,
156        worker_id,
157        correlation_id: Some(correlation_id.unwrap_or_else(|| Uuid::new_v4().to_string())),
158        session_id,
159        thread_id,
160        payload_json: encode_payload(&payload)?,
161        timestamp_utc: now_timestamp_utc(),
162    })
163}
164
165fn worker_request_from_channel(
166    channel: &ChannelMessage,
167    payload: Value,
168    config: &WorkerRoutingConfig,
169    correlation_id: Option<String>,
170) -> Result<WorkerRequest, WorkerClientError> {
171    let correlation = correlation_id
172        .or_else(|| {
173            channel
174                .payload
175                .get("correlation_id")
176                .and_then(|v| v.as_str())
177                .map(str::to_string)
178        })
179        .or_else(|| {
180            channel
181                .payload
182                .get("msg_id")
183                .and_then(|v| v.as_str())
184                .map(str::to_string)
185        });
186
187    let thread_id = channel
188        .payload
189        .get("thread_id")
190        .and_then(|v| v.as_str())
191        .map(str::to_string);
192
193    build_worker_request(
194        channel.tenant.clone(),
195        config.worker_id.clone(),
196        payload,
197        Some(channel.session_id.clone()),
198        thread_id,
199        correlation,
200    )
201}
202
203pub fn empty_worker_response_for(request: &WorkerRequest) -> WorkerResponse {
204    WorkerResponse {
205        version: request.version.clone(),
206        tenant: request.tenant.clone(),
207        worker_id: request.worker_id.clone(),
208        correlation_id: request.correlation_id.clone(),
209        session_id: request.session_id.clone(),
210        thread_id: request.thread_id.clone(),
211        messages: Vec::new(),
212        timestamp_utc: now_timestamp_utc(),
213    }
214}
215
216/// Converts a worker response into outbound envelopes targeting the same channel context.
217pub fn worker_messages_to_outbound(
218    response: &WorkerResponse,
219    channel: &ChannelMessage,
220) -> Vec<OutboundEnvelope> {
221    response
222        .messages
223        .iter()
224        .map(|msg| {
225            let mut meta = Map::new();
226            meta.insert(
227                "worker_id".into(),
228                Value::String(response.worker_id.clone()),
229            );
230            if let Some(corr) = &response.correlation_id {
231                meta.insert("correlation_id".into(), Value::String(corr.clone()));
232            }
233            meta.insert("kind".into(), Value::String(msg.kind.clone()));
234
235            OutboundEnvelope {
236                tenant: channel.tenant.clone(),
237                channel_id: channel.channel_id.clone(),
238                session_id: channel.session_id.clone(),
239                meta: Value::Object(meta),
240                body: decode_payload(&msg.payload_json),
241            }
242        })
243        .collect()
244}
245
/// Errors produced while building, sending, or decoding a worker request.
#[derive(Debug, Error)]
pub enum WorkerClientError {
    /// The caller-supplied payload could not be serialized into the envelope's
    /// `payload_json` string.
    #[error("failed to encode worker payload: {0}")]
    PayloadEncode(#[source] serde_json::Error),
    /// The complete worker request could not be serialized for transmission.
    #[error("failed to serialize worker request: {0}")]
    Serialize(#[source] serde_json::Error),
    /// The bytes returned by the worker were not a valid worker response.
    #[error("failed to deserialize worker response: {0}")]
    Deserialize(#[source] serde_json::Error),
    /// The NATS request itself failed.
    #[error("NATS request failed: {0}")]
    Nats(#[source] anyhow::Error),
    /// The HTTP exchange failed: the request could not be sent, the endpoint answered
    /// with a non-success status, or the response body could not be read.
    #[error("HTTP request failed: {0}")]
    Http(#[source] anyhow::Error),
}
259
/// Transport-agnostic client for exchanging one request/response pair with a worker.
///
/// Implementations in this file cover NATS, HTTP and an in-memory responder.
#[async_trait]
pub trait WorkerClient: Send + Sync {
    /// Sends `request` to the worker and returns its response.
    ///
    /// # Errors
    ///
    /// Returns a [`WorkerClientError`] describing the failure; retry behavior is up to
    /// the implementation.
    async fn send_request(
        &self,
        request: WorkerRequest,
    ) -> Result<WorkerResponse, WorkerClientError>;
}
267
268/// In-memory client used in tests.
269pub struct InMemoryWorkerClient {
270    responder: Box<dyn Fn(WorkerRequest) -> WorkerResponse + Send + Sync>,
271}
272
273impl InMemoryWorkerClient {
274    pub fn new<F>(responder: F) -> Self
275    where
276        F: Fn(WorkerRequest) -> WorkerResponse + Send + Sync + 'static,
277    {
278        Self {
279            responder: Box::new(responder),
280        }
281    }
282}
283
284#[async_trait]
285impl WorkerClient for InMemoryWorkerClient {
286    async fn send_request(
287        &self,
288        request: WorkerRequest,
289    ) -> Result<WorkerResponse, WorkerClientError> {
290        Ok((self.responder)(request))
291    }
292}
293
294/// Sends a worker request via the provided client and maps the response back to outbound envelopes.
295pub async fn forward_to_worker(
296    client: &dyn WorkerClient,
297    channel: &ChannelMessage,
298    payload: Value,
299    config: &WorkerRoutingConfig,
300    correlation_id: Option<String>,
301) -> Result<Vec<OutboundEnvelope>, WorkerClientError> {
302    let request = worker_request_from_channel(channel, payload, config, correlation_id)?;
303    let response = client.send_request(request).await?;
304    Ok(worker_messages_to_outbound(&response, channel))
305}
306
/// Worker client that performs request/reply over a NATS subject.
#[cfg(feature = "nats")]
pub struct NatsWorkerClient {
    client: async_nats::Client,
    subject: String,
    max_retries: u8,
}

#[cfg(feature = "nats")]
impl NatsWorkerClient {
    /// Creates a client that sends requests to `subject`, retrying failed attempts up
    /// to `max_retries` times.
    pub fn new(client: async_nats::Client, subject: String, max_retries: u8) -> Self {
        NatsWorkerClient {
            client,
            subject,
            max_retries,
        }
    }

    /// Performs a single request/reply round trip without retrying.
    async fn send_once(
        &self,
        request: &WorkerRequest,
    ) -> Result<WorkerResponse, WorkerClientError> {
        let encoded = match serde_json::to_vec(request) {
            Ok(bytes) => bytes,
            Err(source) => return Err(WorkerClientError::Serialize(source)),
        };

        let subject = self.subject.clone();
        let reply = match self.client.request(subject, encoded.into()).await {
            Ok(message) => message,
            Err(source) => return Err(WorkerClientError::Nats(anyhow::Error::new(source))),
        };

        serde_json::from_slice(&reply.payload).map_err(WorkerClientError::Deserialize)
    }
}
337
#[cfg(feature = "nats")]
#[async_trait]
impl WorkerClient for NatsWorkerClient {
    /// Sends `request` over NATS, retrying failed attempts.
    ///
    /// Up to `max_retries + 1` attempts are made. After the n-th failed attempt the
    /// client waits `50 * n` milliseconds before trying again.
    ///
    /// # Errors
    ///
    /// Returns the error of the last attempt once the retry budget is exhausted.
    async fn send_request(
        &self,
        request: WorkerRequest,
    ) -> Result<WorkerResponse, WorkerClientError> {
        // Count attempts in u32: a u8 counter overflows when `max_retries` is 255,
        // which panics in debug builds and retries forever in release builds.
        let max_attempts = u32::from(self.max_retries) + 1;
        let mut attempt: u32 = 0;
        loop {
            attempt += 1;
            let err = match self.send_once(&request).await {
                Ok(res) => return Ok(res),
                Err(err) => err,
            };
            if attempt >= max_attempts {
                error!(attempt, subject = %self.subject, error = %err, "worker NATS request failed");
                return Err(err);
            }
            warn!(attempt, subject = %self.subject, error = %err, "retrying worker request over NATS");
            tokio::time::sleep(Duration::from_millis(50 * u64::from(attempt))).await;
        }
    }
}
361
362pub struct HttpWorkerClient {
363    client: reqwest::Client,
364    url: String,
365    max_retries: u8,
366}
367
368impl HttpWorkerClient {
369    pub fn new(url: String, max_retries: u8) -> Self {
370        Self {
371            client: reqwest::Client::new(),
372            url,
373            max_retries,
374        }
375    }
376
377    async fn send_once(
378        &self,
379        request: &WorkerRequest,
380    ) -> Result<WorkerResponse, WorkerClientError> {
381        let response = self
382            .client
383            .post(&self.url)
384            .json(request)
385            .send()
386            .await
387            .map_err(|e| WorkerClientError::Http(anyhow::Error::new(e)))?;
388
389        if !response.status().is_success() {
390            let status = response.status();
391            let body = response.text().await.unwrap_or_default();
392            return Err(WorkerClientError::Http(anyhow::anyhow!(
393                "HTTP {} from worker endpoint: {}",
394                status,
395                body
396            )));
397        }
398
399        let body = response
400            .bytes()
401            .await
402            .map_err(|e| WorkerClientError::Http(anyhow::Error::new(e)))?;
403        serde_json::from_slice(&body).map_err(WorkerClientError::Deserialize)
404    }
405}
406
407#[async_trait]
408impl WorkerClient for HttpWorkerClient {
409    async fn send_request(
410        &self,
411        request: WorkerRequest,
412    ) -> Result<WorkerResponse, WorkerClientError> {
413        let mut attempt = 0;
414        loop {
415            attempt += 1;
416            match self.send_once(&request).await {
417                Ok(res) => return Ok(res),
418                Err(err) => {
419                    if attempt > self.max_retries {
420                        error!(attempt, url = %self.url, error = %err, "worker HTTP request failed");
421                        return Err(err);
422                    }
423                    warn!(attempt, url = %self.url, error = %err, "retrying worker HTTP request");
424                    tokio::time::sleep(Duration::from_millis(50 * attempt as u64)).await;
425                }
426            }
427        }
428    }
429}
430
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal webchat message for tenant `acme` / team `team` in session `sess-1`.
    fn sample_channel() -> ChannelMessage {
        let tenant = crate::make_tenant_ctx("acme".into(), Some("team".into()), None);
        ChannelMessage {
            tenant,
            channel_id: "webchat".into(),
            session_id: "sess-1".into(),
            route: None,
            payload: serde_json::json!({"text": "hi"}),
        }
    }

    /// The request carries the routing defaults and explicit correlation id, and the
    /// single response message is mapped onto the originating channel.
    #[tokio::test]
    async fn builds_request_and_maps_response() {
        let channel = sample_channel();
        let config = WorkerRoutingConfig::default();

        let client = InMemoryWorkerClient::new(|request| {
            // Request side: envelope fields and payload round-trip.
            assert_eq!(request.version, WORKER_ENVELOPE_VERSION);
            assert_eq!(request.worker_id, DEFAULT_WORKER_ID);
            assert_eq!(request.session_id.as_deref(), Some("sess-1"));
            assert_eq!(request.correlation_id.as_deref(), Some("corr-1"));
            let sent: Value = serde_json::from_str(&request.payload_json).unwrap();
            assert_eq!(sent["body"], "hello");

            let reply = WorkerMessage {
                kind: "text".into(),
                payload_json: serde_json::to_string(&serde_json::json!({"reply": "pong"})).unwrap(),
            };
            let mut response = empty_worker_response_for(&request);
            response.messages = vec![reply];
            response
        });

        let outbound = forward_to_worker(
            &client,
            &channel,
            serde_json::json!({"body": "hello"}),
            &config,
            Some("corr-1".to_string()),
        )
        .await
        .unwrap();

        assert_eq!(outbound.len(), 1);
        let envelope = &outbound[0];
        assert_eq!(envelope.channel_id, "webchat");
        assert_eq!(envelope.body["reply"], "pong");
        assert_eq!(envelope.tenant.tenant.as_str(), "acme");
        assert_eq!(envelope.session_id, "sess-1");
        assert_eq!(envelope.meta["kind"], "text");
        assert_eq!(envelope.meta["worker_id"], DEFAULT_WORKER_ID);
        assert_eq!(envelope.meta["correlation_id"], "corr-1");
    }

    /// The thread id is lifted from the channel payload and a correlation id is always
    /// present, even when none is supplied.
    #[tokio::test]
    async fn populates_thread_and_correlation_defaults() {
        let channel = ChannelMessage {
            payload: serde_json::json!({"text": "ping", "thread_id": "thr-1"}),
            ..sample_channel()
        };
        let config = WorkerRoutingConfig::default();

        let client = InMemoryWorkerClient::new(|request| {
            assert_eq!(request.thread_id.as_deref(), Some("thr-1"));
            assert!(request.correlation_id.is_some());
            empty_worker_response_for(&request)
        });

        let outbound = forward_to_worker(
            &client,
            &channel,
            serde_json::json!({"body": "hello"}),
            &config,
            None,
        )
        .await
        .unwrap();

        assert_eq!(outbound.len(), 0);
    }
}
497        assert_eq!(outbound.len(), 0);
498    }
499}