Skip to main content

rrq_runner/
registry.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::future::Future;
4use std::sync::Arc;
5
6use std::time::Instant;
7
8use tracing::Instrument;
9
10use crate::telemetry::{NoopTelemetry, Telemetry};
11use crate::types::{ExecutionOutcome, ExecutionRequest, OutcomeStatus};
12
13#[async_trait]
14pub trait Handler: Send + Sync {
15    async fn handle(&self, request: ExecutionRequest) -> ExecutionOutcome;
16}
17
18struct FnHandler<F>(F);
19
20#[async_trait]
21impl<F, Fut> Handler for FnHandler<F>
22where
23    F: Fn(ExecutionRequest) -> Fut + Send + Sync,
24    Fut: Future<Output = ExecutionOutcome> + Send,
25{
26    async fn handle(&self, request: ExecutionRequest) -> ExecutionOutcome {
27        (self.0)(request).await
28    }
29}
30
31#[derive(Clone, Default)]
32pub struct Registry {
33    handlers: HashMap<String, Arc<dyn Handler>>,
34}
35
36impl Registry {
37    #[must_use]
38    pub fn new() -> Self {
39        Self {
40            handlers: HashMap::new(),
41        }
42    }
43
44    pub fn register<F, Fut>(&mut self, name: impl Into<String>, handler: F)
45    where
46        F: Fn(ExecutionRequest) -> Fut + Send + Sync + 'static,
47        Fut: Future<Output = ExecutionOutcome> + Send + 'static,
48    {
49        let handler = Arc::new(FnHandler(handler)) as Arc<dyn Handler>;
50        self.handlers.insert(name.into(), handler);
51    }
52
53    #[must_use]
54    pub fn get(&self, name: &str) -> Option<Arc<dyn Handler>> {
55        self.handlers.get(name).cloned()
56    }
57
58    pub async fn execute(&self, request: ExecutionRequest) -> ExecutionOutcome {
59        let telemetry = NoopTelemetry;
60        self.execute_with(request, &telemetry).await
61    }
62
63    pub async fn execute_with<T: Telemetry + ?Sized>(
64        &self,
65        request: ExecutionRequest,
66        telemetry: &T,
67    ) -> ExecutionOutcome {
68        let span = telemetry.runner_span(&request);
69        let start = Instant::now();
70        let function_name = request.function_name.clone();
71        let job_id = request.job_id.clone();
72        let request_id = request.request_id.clone();
73        let mut outcome = match self.get(&function_name) {
74            Some(handler) => handler.handle(request).instrument(span.clone()).await,
75            None => ExecutionOutcome::handler_not_found(
76                job_id.clone(),
77                request_id.clone(),
78                format!("No handler registered for function '{function_name}'"),
79            ),
80        };
81        if outcome.job_id.is_none() {
82            outcome.job_id = Some(job_id.clone());
83        }
84        if outcome.request_id.is_none() {
85            outcome.request_id = Some(request_id.clone());
86        }
87        record_outcome(&span, function_name.as_str(), &outcome, start.elapsed());
88        outcome
89    }
90}
91
92fn record_outcome(
93    span: &tracing::Span,
94    function_name: &str,
95    outcome: &ExecutionOutcome,
96    duration: std::time::Duration,
97) {
98    let duration_ms = duration.as_secs_f64() * 1000.0;
99    crate::telemetry::record_job_outcome(function_name, outcome.status.clone(), duration);
100    span.record("rrq.duration_ms", duration_ms);
101    match outcome.status {
102        OutcomeStatus::Success => {
103            span.record("rrq.outcome", "success");
104        }
105        OutcomeStatus::Retry => {
106            span.record("rrq.outcome", "retry");
107            if let Some(delay) = outcome.retry_after_seconds {
108                span.record("rrq.retry_delay_ms", delay * 1000.0);
109            }
110        }
111        OutcomeStatus::Timeout => {
112            span.record("rrq.outcome", "timeout");
113        }
114        OutcomeStatus::Error => {
115            span.record("rrq.outcome", "error");
116        }
117    }
118    if let Some(error) = outcome.error.as_ref() {
119        span.record("rrq.error_message", error.message.as_str());
120        if let Some(error_type) = error.error_type.as_deref() {
121            span.record("rrq.error_type", error_type);
122        }
123    }
124}
125
#[cfg(test)]
mod tests {
    use super::*;
    use crate::types::{ExecutionOutcome, ExecutionRequest, OutcomeStatus};
    use serde_json::json;

    /// Builds a minimal `ExecutionRequest` targeting `function_name` with the given ids.
    fn make_request(job_id: &str, request_id: &str, function_name: &str) -> ExecutionRequest {
        let payload = json!({
            "job_id": job_id,
            "request_id": request_id,
            "function_name": function_name,
            "args": [],
            "kwargs": {},
            "context": {
                "job_id": job_id,
                "attempt": 1,
                "enqueue_time": "2024-01-01T00:00:00Z",
                "queue_name": "default",
                "deadline": null,
                "trace_context": null,
                "worker_id": null
            }
        });
        serde_json::from_value(payload).expect("test payload should deserialize")
    }

    /// Returns a registry with an `echo` handler whose result echoes the job id.
    fn echo_registry() -> Registry {
        let mut registry = Registry::new();
        registry.register("echo", |request| async move {
            ExecutionOutcome::success(
                request.job_id.clone(),
                request.request_id.clone(),
                json!({ "job_id": request.job_id }),
            )
        });
        registry
    }

    #[tokio::test]
    async fn registry_invokes_handler() {
        let registry = echo_registry();
        let handler = registry.get("echo").expect("handler not found");

        let outcome = handler.handle(make_request("job-1", "req-1", "echo")).await;
        assert_eq!(outcome.status, OutcomeStatus::Success);
        assert_eq!(outcome.result, Some(json!({ "job_id": "job-1" })));
    }

    #[tokio::test]
    async fn registry_execute_with_noop_telemetry() {
        let registry = echo_registry();

        let outcome = registry
            .execute_with(make_request("job-2", "req-2", "echo"), &NoopTelemetry)
            .await;
        assert_eq!(outcome.status, OutcomeStatus::Success);
    }

    #[tokio::test]
    async fn registry_execute_defaults_to_noop_telemetry() {
        let registry = echo_registry();

        let outcome = registry.execute(make_request("job-4", "req-4", "echo")).await;
        assert_eq!(outcome.status, OutcomeStatus::Success);
        assert_eq!(outcome.result, Some(json!({ "job_id": "job-4" })));
    }

    #[tokio::test]
    async fn registry_execute_handler_not_found() {
        let registry = Registry::new();

        let outcome = registry
            .execute_with(make_request("job-3", "req-3", "missing"), &NoopTelemetry)
            .await;
        assert_eq!(outcome.status, OutcomeStatus::Error);
        assert_eq!(
            outcome
                .error
                .as_ref()
                .and_then(|error| error.error_type.as_deref()),
            Some("handler_not_found")
        );
        // The outcome must still be correlatable with the request that produced it.
        assert_eq!(outcome.job_id.as_deref(), Some("job-3"));
        assert_eq!(outcome.request_id.as_deref(), Some("req-3"));
    }
}