Skip to main content

rrq_runner/
registry.rs

1use async_trait::async_trait;
2use std::collections::HashMap;
3use std::future::Future;
4use std::sync::Arc;
5
6use std::time::Instant;
7
8use tracing::Instrument;
9
10use crate::telemetry::{NoopTelemetry, Telemetry};
11use crate::types::{ExecutionOutcome, ExecutionRequest, OutcomeStatus};
12
13#[async_trait]
14pub trait Handler: Send + Sync {
15    async fn handle(&self, request: ExecutionRequest) -> ExecutionOutcome;
16}
17
18struct FnHandler<F>(F);
19
20#[async_trait]
21impl<F, Fut> Handler for FnHandler<F>
22where
23    F: Fn(ExecutionRequest) -> Fut + Send + Sync,
24    Fut: Future<Output = ExecutionOutcome> + Send,
25{
26    async fn handle(&self, request: ExecutionRequest) -> ExecutionOutcome {
27        (self.0)(request).await
28    }
29}
30
31#[derive(Clone, Default)]
32pub struct Registry {
33    handlers: HashMap<String, Arc<dyn Handler>>,
34}
35
36impl Registry {
37    pub fn new() -> Self {
38        Self {
39            handlers: HashMap::new(),
40        }
41    }
42
43    pub fn register<F, Fut>(&mut self, name: impl Into<String>, handler: F)
44    where
45        F: Fn(ExecutionRequest) -> Fut + Send + Sync + 'static,
46        Fut: Future<Output = ExecutionOutcome> + Send + 'static,
47    {
48        let handler = Arc::new(FnHandler(handler)) as Arc<dyn Handler>;
49        self.handlers.insert(name.into(), handler);
50    }
51
52    pub fn get(&self, name: &str) -> Option<Arc<dyn Handler>> {
53        self.handlers.get(name).cloned()
54    }
55
56    pub async fn execute(&self, request: ExecutionRequest) -> ExecutionOutcome {
57        let telemetry = NoopTelemetry;
58        self.execute_with(request, &telemetry).await
59    }
60
61    pub async fn execute_with<T: Telemetry + ?Sized>(
62        &self,
63        request: ExecutionRequest,
64        telemetry: &T,
65    ) -> ExecutionOutcome {
66        let span = telemetry.runner_span(&request);
67        let start = Instant::now();
68        let function_name = request.function_name.clone();
69        let job_id = request.job_id.clone();
70        let request_id = request.request_id.clone();
71        let mut outcome = match self.get(&function_name) {
72            Some(handler) => handler.handle(request).instrument(span.clone()).await,
73            None => ExecutionOutcome::handler_not_found(
74                job_id.clone(),
75                request_id.clone(),
76                format!("No handler registered for function '{}'", function_name),
77            ),
78        };
79        if outcome.job_id.is_none() {
80            outcome.job_id = Some(job_id.clone());
81        }
82        if outcome.request_id.is_none() {
83            outcome.request_id = Some(request_id.clone());
84        }
85        record_outcome(&span, &outcome, start.elapsed());
86        outcome
87    }
88}
89
90fn record_outcome(span: &tracing::Span, outcome: &ExecutionOutcome, duration: std::time::Duration) {
91    let duration_ms = duration.as_secs_f64() * 1000.0;
92    span.record("rrq.duration_ms", duration_ms);
93    match outcome.status {
94        OutcomeStatus::Success => {
95            span.record("rrq.outcome", "success");
96        }
97        OutcomeStatus::Retry => {
98            span.record("rrq.outcome", "retry");
99            if let Some(delay) = outcome.retry_after_seconds {
100                span.record("rrq.retry_delay_ms", delay * 1000.0);
101            }
102        }
103        OutcomeStatus::Timeout => {
104            span.record("rrq.outcome", "timeout");
105        }
106        OutcomeStatus::Error => {
107            span.record("rrq.outcome", "error");
108        }
109    }
110    if let Some(error) = outcome.error.as_ref() {
111        span.record("rrq.error_message", error.message.as_str());
112        if let Some(error_type) = error.error_type.as_deref() {
113            span.record("rrq.error_type", error_type);
114        }
115    }
116}
117
#[cfg(test)]
mod tests {
    // `super::*` already brings in ExecutionOutcome, ExecutionRequest,
    // OutcomeStatus and NoopTelemetry via the parent module's imports.
    use super::*;
    use serde_json::json;

    /// Builds a minimal `ExecutionRequest` targeting `function_name` by
    /// deserializing a JSON payload, so every test uses the same request shape.
    fn make_request(job_id: &str, request_id: &str, function_name: &str) -> ExecutionRequest {
        let payload = json!({
            "job_id": job_id,
            "request_id": request_id,
            "function_name": function_name,
            "args": [],
            "kwargs": {},
            "context": {
                "job_id": job_id,
                "attempt": 1,
                "enqueue_time": "2024-01-01T00:00:00Z",
                "queue_name": "default",
                "deadline": null,
                "trace_context": null,
                "worker_id": null
            }
        });
        serde_json::from_value(payload).expect("test payload should deserialize")
    }

    /// Returns a registry whose `echo` handler succeeds with
    /// `{"job_id": <the request's job id>}`.
    fn echo_registry() -> Registry {
        let mut registry = Registry::new();
        registry.register("echo", |request| async move {
            ExecutionOutcome::success(
                request.job_id.clone(),
                request.request_id.clone(),
                json!({ "job_id": request.job_id }),
            )
        });
        registry
    }

    #[tokio::test]
    async fn registry_invokes_handler() {
        let registry = echo_registry();
        let handler = registry.get("echo").expect("handler not found");

        let outcome = handler.handle(make_request("job-1", "req-1", "echo")).await;
        assert_eq!(outcome.status, OutcomeStatus::Success);
        assert_eq!(outcome.result, Some(json!({ "job_id": "job-1" })));
    }

    #[tokio::test]
    async fn registry_execute_with_noop_telemetry() {
        let registry = echo_registry();

        let outcome = registry
            .execute_with(make_request("job-2", "req-2", "echo"), &NoopTelemetry)
            .await;
        assert_eq!(outcome.status, OutcomeStatus::Success);
    }

    #[tokio::test]
    async fn registry_execute_handler_not_found() {
        let registry = Registry::new();

        let outcome = registry
            .execute_with(make_request("job-3", "req-3", "missing"), &NoopTelemetry)
            .await;
        assert_eq!(outcome.status, OutcomeStatus::Error);
        assert_eq!(
            outcome
                .error
                .as_ref()
                .and_then(|error| error.error_type.as_deref()),
            Some("handler_not_found")
        );
    }

    #[tokio::test]
    async fn registry_execute_backfills_missing_ids() {
        // A handler that deliberately leaves both identifiers unset.
        let mut registry = Registry::new();
        registry.register("anonymous", |request| async move {
            let mut outcome = ExecutionOutcome::success(
                request.job_id.clone(),
                request.request_id.clone(),
                json!(null),
            );
            outcome.job_id = None;
            outcome.request_id = None;
            outcome
        });

        let outcome = registry
            .execute_with(make_request("job-4", "req-4", "anonymous"), &NoopTelemetry)
            .await;
        assert_eq!(outcome.job_id.as_deref(), Some("job-4"));
        assert_eq!(outcome.request_id.as_deref(), Some("req-4"));
    }
}