Skip to main content

gsm_core/
worker.rs

1use crate::{ChannelMessage, OutboundEnvelope, TenantCtx};
2use async_trait::async_trait;
3use serde_json::{Map, Value};
4use std::collections::BTreeMap;
5use std::str::FromStr;
6use std::time::Duration;
7use thiserror::Error;
8use time::OffsetDateTime;
9use tracing::{error, warn};
10use uuid::Uuid;
11
/// Schema version written into the `version` field of worker requests.
pub const WORKER_ENVELOPE_VERSION: &str = "1.0";
/// Worker identifier used by the default routing configuration (the repo assistant).
pub const DEFAULT_WORKER_ID: &str = "greentic-repo-assistant";
/// NATS subject the default routing configuration falls back to when none is supplied.
pub const DEFAULT_WORKER_NATS_SUBJECT: &str = "workers.repo-assistant";
18
19pub use greentic_types::{WorkerMessage, WorkerRequest, WorkerResponse};
20
/// Which transport to use for talking to the worker endpoint.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum WorkerTransport {
    /// Talk to the worker over a NATS subject.
    Nats,
    /// Talk to the worker over an HTTP endpoint.
    Http,
}

impl WorkerTransport {
    /// Single source of truth for the lenient transport parsing rules.
    ///
    /// Surrounding whitespace is ignored and matching is ASCII case-insensitive.
    /// `"http"` selects [`WorkerTransport::Http`]; every other value (including
    /// the empty string) selects [`WorkerTransport::Nats`].
    fn parse_lenient(raw: &str) -> Self {
        // `eq_ignore_ascii_case` avoids allocating a lowercased copy of the input.
        if raw.trim().eq_ignore_ascii_case("http") {
            WorkerTransport::Http
        } else {
            WorkerTransport::Nats
        }
    }

    /// Resolves an optional transport name, defaulting to NATS when `value` is `None`
    /// or is not recognised.
    pub fn from_optional(value: Option<&str>) -> Self {
        value.map_or(WorkerTransport::Nats, Self::parse_lenient)
    }
}

impl FromStr for WorkerTransport {
    type Err = ();

    /// Parses a transport name with the same lenient rules as
    /// [`WorkerTransport::from_optional`].
    ///
    /// This never returns `Err`: unrecognised names fall back to NATS.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        Ok(Self::parse_lenient(s))
    }
}
47
48/// Routing configuration for the repo worker.
49#[derive(Clone, Debug)]
50pub struct WorkerRoutingConfig {
51    pub transport: WorkerTransport,
52    pub worker_id: String,
53    pub nats_subject: String,
54    pub http_url: Option<String>,
55    /// How many transient retries to attempt locally before surfacing an error.
56    pub max_retries: u8,
57}
58
59impl Default for WorkerRoutingConfig {
60    fn default() -> Self {
61        Self {
62            transport: WorkerTransport::Nats,
63            worker_id: DEFAULT_WORKER_ID.to_string(),
64            nats_subject: DEFAULT_WORKER_NATS_SUBJECT.to_string(),
65            http_url: None,
66            max_retries: 2,
67        }
68    }
69}
70
71impl WorkerRoutingConfig {
72    pub fn from_settings(
73        transport: WorkerTransport,
74        worker_id: String,
75        nats_subject: String,
76        http_url: Option<String>,
77        max_retries: u8,
78    ) -> Self {
79        Self {
80            transport,
81            worker_id,
82            nats_subject,
83            http_url,
84            max_retries,
85        }
86    }
87
88    pub fn from_route_spec(worker_id: &str, transport: WorkerTransport, target: &str) -> Self {
89        match transport {
90            WorkerTransport::Nats => WorkerRoutingConfig {
91                transport,
92                worker_id: worker_id.to_string(),
93                nats_subject: target.to_string(),
94                http_url: None,
95                max_retries: 2,
96            },
97            WorkerTransport::Http => WorkerRoutingConfig {
98                transport,
99                worker_id: worker_id.to_string(),
100                nats_subject: DEFAULT_WORKER_NATS_SUBJECT.to_string(),
101                http_url: Some(target.to_string()),
102                max_retries: 2,
103            },
104        }
105    }
106}
107
108/// Parse a simple worker route map from the `WORKER_ROUTES` env var.
109///
110/// Format: `worker_id=transport:target,worker_id2=http:https://example`
111/// transport: `nats` uses `target` as subject; `http` uses `target` as URL.
112pub fn worker_routes_from_specs(raw: &str) -> BTreeMap<String, WorkerRoutingConfig> {
113    let mut map = BTreeMap::new();
114    for entry in raw.split(',').map(str::trim).filter(|s| !s.is_empty()) {
115        if let Some((id, spec)) = entry.split_once('=')
116            && let Some((transport_raw, target)) = spec.split_once(':')
117        {
118            let transport = WorkerTransport::from_optional(Some(transport_raw));
119            let cfg = WorkerRoutingConfig::from_route_spec(id.trim(), transport, target.trim());
120            map.insert(id.trim().to_string(), cfg);
121        }
122    }
123    map
124}
125
126fn now_timestamp_utc() -> String {
127    OffsetDateTime::now_utc()
128        .format(&time::format_description::well_known::Rfc3339)
129        .unwrap_or_else(|_| OffsetDateTime::now_utc().unix_timestamp().to_string())
130}
131
132fn encode_payload(payload: &Value) -> Result<String, WorkerClientError> {
133    serde_json::to_string(payload).map_err(WorkerClientError::PayloadEncode)
134}
135
136fn decode_payload(payload_json: &str) -> Value {
137    serde_json::from_str(payload_json).unwrap_or_else(|_| Value::String(payload_json.to_string()))
138}
139
140fn build_worker_request(
141    tenant: TenantCtx,
142    worker_id: String,
143    payload: Value,
144    session_id: Option<String>,
145    thread_id: Option<String>,
146    correlation_id: Option<String>,
147) -> Result<WorkerRequest, WorkerClientError> {
148    Ok(WorkerRequest {
149        version: WORKER_ENVELOPE_VERSION.to_string(),
150        tenant,
151        worker_id,
152        correlation_id: Some(correlation_id.unwrap_or_else(|| Uuid::new_v4().to_string())),
153        session_id,
154        thread_id,
155        payload_json: encode_payload(&payload)?,
156        timestamp_utc: now_timestamp_utc(),
157    })
158}
159
160fn worker_request_from_channel(
161    channel: &ChannelMessage,
162    payload: Value,
163    config: &WorkerRoutingConfig,
164    correlation_id: Option<String>,
165) -> Result<WorkerRequest, WorkerClientError> {
166    let correlation = correlation_id
167        .or_else(|| {
168            channel
169                .payload
170                .get("correlation_id")
171                .and_then(|v| v.as_str())
172                .map(str::to_string)
173        })
174        .or_else(|| {
175            channel
176                .payload
177                .get("msg_id")
178                .and_then(|v| v.as_str())
179                .map(str::to_string)
180        });
181
182    let thread_id = channel
183        .payload
184        .get("thread_id")
185        .and_then(|v| v.as_str())
186        .map(str::to_string);
187
188    build_worker_request(
189        channel.tenant.clone(),
190        config.worker_id.clone(),
191        payload,
192        Some(channel.session_id.clone()),
193        thread_id,
194        correlation,
195    )
196}
197
198pub fn empty_worker_response_for(request: &WorkerRequest) -> WorkerResponse {
199    WorkerResponse {
200        version: request.version.clone(),
201        tenant: request.tenant.clone(),
202        worker_id: request.worker_id.clone(),
203        correlation_id: request.correlation_id.clone(),
204        session_id: request.session_id.clone(),
205        thread_id: request.thread_id.clone(),
206        messages: Vec::new(),
207        timestamp_utc: now_timestamp_utc(),
208    }
209}
210
211/// Converts a worker response into outbound envelopes targeting the same channel context.
212pub fn worker_messages_to_outbound(
213    response: &WorkerResponse,
214    channel: &ChannelMessage,
215) -> Vec<OutboundEnvelope> {
216    response
217        .messages
218        .iter()
219        .map(|msg| {
220            let mut meta = Map::new();
221            meta.insert(
222                "worker_id".into(),
223                Value::String(response.worker_id.clone()),
224            );
225            if let Some(corr) = &response.correlation_id {
226                meta.insert("correlation_id".into(), Value::String(corr.clone()));
227            }
228            meta.insert("kind".into(), Value::String(msg.kind.clone()));
229
230            OutboundEnvelope {
231                tenant: channel.tenant.clone(),
232                channel_id: channel.channel_id.clone(),
233                session_id: channel.session_id.clone(),
234                meta: Value::Object(meta),
235                body: decode_payload(&msg.payload_json),
236            }
237        })
238        .collect()
239}
240
241#[derive(Debug, Error)]
242pub enum WorkerClientError {
243    #[error("failed to encode worker payload: {0}")]
244    PayloadEncode(#[source] serde_json::Error),
245    #[error("failed to serialize worker request: {0}")]
246    Serialize(#[source] serde_json::Error),
247    #[error("failed to deserialize worker response: {0}")]
248    Deserialize(#[source] serde_json::Error),
249    #[error("NATS request failed: {0}")]
250    Nats(#[source] anyhow::Error),
251    #[error("HTTP request failed: {0}")]
252    Http(#[source] anyhow::Error),
253}
254
255#[async_trait]
256pub trait WorkerClient: Send + Sync {
257    async fn send_request(
258        &self,
259        request: WorkerRequest,
260    ) -> Result<WorkerResponse, WorkerClientError>;
261}
262
263/// In-memory client used in tests.
264pub struct InMemoryWorkerClient {
265    responder: Box<dyn Fn(WorkerRequest) -> WorkerResponse + Send + Sync>,
266}
267
268impl InMemoryWorkerClient {
269    pub fn new<F>(responder: F) -> Self
270    where
271        F: Fn(WorkerRequest) -> WorkerResponse + Send + Sync + 'static,
272    {
273        Self {
274            responder: Box::new(responder),
275        }
276    }
277}
278
279#[async_trait]
280impl WorkerClient for InMemoryWorkerClient {
281    async fn send_request(
282        &self,
283        request: WorkerRequest,
284    ) -> Result<WorkerResponse, WorkerClientError> {
285        Ok((self.responder)(request))
286    }
287}
288
289/// Sends a worker request via the provided client and maps the response back to outbound envelopes.
290pub async fn forward_to_worker(
291    client: &dyn WorkerClient,
292    channel: &ChannelMessage,
293    payload: Value,
294    config: &WorkerRoutingConfig,
295    correlation_id: Option<String>,
296) -> Result<Vec<OutboundEnvelope>, WorkerClientError> {
297    let request = worker_request_from_channel(channel, payload, config, correlation_id)?;
298    let response = client.send_request(request).await?;
299    Ok(worker_messages_to_outbound(&response, channel))
300}
301
/// Worker client that sends requests over a NATS subject and waits for the reply.
#[cfg(feature = "nats")]
pub struct NatsWorkerClient {
    client: async_nats::Client,
    subject: String,
    max_retries: u8,
}

#[cfg(feature = "nats")]
impl NatsWorkerClient {
    /// Creates a client that sends on `subject` and retries a failed request up to
    /// `max_retries` times.
    pub fn new(client: async_nats::Client, subject: String, max_retries: u8) -> Self {
        NatsWorkerClient {
            client,
            subject,
            max_retries,
        }
    }

    /// Performs a single request/reply round trip without any retry.
    ///
    /// The request is serialized to JSON bytes, sent on the configured subject,
    /// and the reply payload is decoded as a [`WorkerResponse`].
    async fn send_once(
        &self,
        request: &WorkerRequest,
    ) -> Result<WorkerResponse, WorkerClientError> {
        let encoded = serde_json::to_vec(request).map_err(WorkerClientError::Serialize)?;

        let reply = match self
            .client
            .request(self.subject.clone(), encoded.into())
            .await
        {
            Ok(reply) => reply,
            Err(e) => return Err(WorkerClientError::Nats(anyhow::Error::new(e))),
        };

        serde_json::from_slice(&reply.payload).map_err(WorkerClientError::Deserialize)
    }
}
332
#[cfg(feature = "nats")]
#[async_trait]
impl WorkerClient for NatsWorkerClient {
    /// Sends `request` over NATS, retrying failed attempts with a linear backoff.
    ///
    /// At most `max_retries + 1` attempts are made. After failed attempt `n`
    /// the client sleeps `50 * n` milliseconds before trying again.
    /// Serialization failures are returned immediately, because encoding the
    /// same request again cannot succeed.
    ///
    /// # Errors
    /// Returns the error of the last attempt.
    async fn send_request(
        &self,
        request: WorkerRequest,
    ) -> Result<WorkerResponse, WorkerClientError> {
        // Count attempts in u32: a u8 counter overflows on `+= 1` when
        // `max_retries` is 255 (debug panic, release wrap-around and endless retries).
        let max_attempts = u32::from(self.max_retries) + 1;
        let mut attempt: u32 = 0;
        loop {
            attempt += 1;
            let err = match self.send_once(&request).await {
                Ok(res) => return Ok(res),
                Err(err) => err,
            };

            // NOTE(review): `Deserialize` failures are still retried, which re-sends a
            // request the worker may already have handled — confirm this is intended.
            let retryable = !matches!(err, WorkerClientError::Serialize(_));
            if !retryable || attempt >= max_attempts {
                error!(attempt, subject = %self.subject, error = %err, "worker NATS request failed");
                return Err(err);
            }

            warn!(attempt, subject = %self.subject, error = %err, "retrying worker request over NATS");
            tokio::time::sleep(Duration::from_millis(50 * u64::from(attempt))).await;
        }
    }
}
356
357pub struct HttpWorkerClient {
358    client: reqwest::Client,
359    url: String,
360    max_retries: u8,
361}
362
363impl HttpWorkerClient {
364    pub fn new(url: String, max_retries: u8) -> Self {
365        Self {
366            client: reqwest::Client::new(),
367            url,
368            max_retries,
369        }
370    }
371
372    async fn send_once(
373        &self,
374        request: &WorkerRequest,
375    ) -> Result<WorkerResponse, WorkerClientError> {
376        let response = self
377            .client
378            .post(&self.url)
379            .json(request)
380            .send()
381            .await
382            .map_err(|e| WorkerClientError::Http(anyhow::Error::new(e)))?;
383
384        if !response.status().is_success() {
385            let status = response.status();
386            let body = response.text().await.unwrap_or_default();
387            return Err(WorkerClientError::Http(anyhow::anyhow!(
388                "HTTP {} from worker endpoint: {}",
389                status,
390                body
391            )));
392        }
393
394        let body = response
395            .bytes()
396            .await
397            .map_err(|e| WorkerClientError::Http(anyhow::Error::new(e)))?;
398        serde_json::from_slice(&body).map_err(WorkerClientError::Deserialize)
399    }
400}
401
402#[async_trait]
403impl WorkerClient for HttpWorkerClient {
404    async fn send_request(
405        &self,
406        request: WorkerRequest,
407    ) -> Result<WorkerResponse, WorkerClientError> {
408        let mut attempt = 0;
409        loop {
410            attempt += 1;
411            match self.send_once(&request).await {
412                Ok(res) => return Ok(res),
413                Err(err) => {
414                    if attempt > self.max_retries {
415                        error!(attempt, url = %self.url, error = %err, "worker HTTP request failed");
416                        return Err(err);
417                    }
418                    warn!(attempt, url = %self.url, error = %err, "retrying worker HTTP request");
419                    tokio::time::sleep(Duration::from_millis(50 * attempt as u64)).await;
420                }
421            }
422        }
423    }
424}
425
#[cfg(test)]
mod tests {
    use super::*;

    /// Webchat message for tenant `acme`, session `sess-1`, with a plain text payload.
    fn sample_channel() -> ChannelMessage {
        let tenant = crate::make_tenant_ctx("acme".into(), Some("team".into()), None);
        ChannelMessage {
            tenant,
            channel_id: "webchat".into(),
            session_id: "sess-1".into(),
            route: None,
            payload: serde_json::json!({"text": "hi"}),
        }
    }

    /// An explicit correlation id reaches the worker, and the reply comes back as
    /// one outbound envelope carrying the channel context and worker metadata.
    #[tokio::test]
    async fn builds_request_and_maps_response() {
        let channel = sample_channel();
        let config = WorkerRoutingConfig::default();

        let client = InMemoryWorkerClient::new(|req| {
            assert_eq!(req.version, WORKER_ENVELOPE_VERSION);
            assert_eq!(req.worker_id, DEFAULT_WORKER_ID);
            assert_eq!(req.session_id.as_deref(), Some("sess-1"));
            assert_eq!(req.correlation_id.as_deref(), Some("corr-1"));
            let decoded: Value = serde_json::from_str(&req.payload_json).unwrap();
            assert_eq!(decoded["body"], "hello");

            let reply = WorkerMessage {
                kind: "text".into(),
                payload_json: serde_json::to_string(&serde_json::json!({"reply": "pong"})).unwrap(),
            };
            let mut resp = empty_worker_response_for(&req);
            resp.messages = vec![reply];
            resp
        });

        let outbound = forward_to_worker(
            &client,
            &channel,
            serde_json::json!({"body": "hello"}),
            &config,
            Some("corr-1".to_string()),
        )
        .await
        .unwrap();

        assert_eq!(outbound.len(), 1);
        let first = &outbound[0];
        assert_eq!(first.channel_id, "webchat");
        assert_eq!(first.body["reply"], "pong");
        assert_eq!(first.tenant.tenant.as_str(), "acme");
        assert_eq!(first.session_id, "sess-1");
        assert_eq!(first.meta["kind"], "text");
        assert_eq!(first.meta["worker_id"], DEFAULT_WORKER_ID);
        assert_eq!(first.meta["correlation_id"], "corr-1");
    }

    /// The thread id is read from the channel payload, and a correlation id is
    /// always present even when the caller supplies none.
    #[tokio::test]
    async fn populates_thread_and_correlation_defaults() {
        let channel = ChannelMessage {
            payload: serde_json::json!({"text": "ping", "thread_id": "thr-1"}),
            ..sample_channel()
        };
        let config = WorkerRoutingConfig::default();

        let client = InMemoryWorkerClient::new(|req| {
            assert_eq!(req.thread_id.as_deref(), Some("thr-1"));
            assert!(req.correlation_id.is_some());
            empty_worker_response_for(&req)
        });

        let outbound = forward_to_worker(
            &client,
            &channel,
            serde_json::json!({"body": "hello"}),
            &config,
            None,
        )
        .await
        .unwrap();

        assert_eq!(outbound.len(), 0);
    }
}