Skip to main content

awsim_xray/
lib.rs

//! AWS X-Ray emulator. Accepts trace segments via PutTraceSegments, keeps them
//! in memory, and serves the query operations (BatchGetTraces,
//! GetTraceSummaries, GetServiceGraph) together with the sampling-rule,
//! sampling-target, and group APIs that the SDKs and the X-Ray daemon call.
4
5mod operations;
6pub mod state;
7
8pub use state::XrayState;
9
10use std::sync::Arc;
11
12use async_trait::async_trait;
13use awsim_core::{
14    AccountRegionStore, AwsError, Protocol, RequestContext, RouteDefinition, ServiceHandler,
15};
16use serde_json::Value;
17use tracing::debug;
18
19pub struct XrayService {
20    store: AccountRegionStore<XrayState>,
21}
22
23impl XrayService {
24    pub fn new() -> Self {
25        Self {
26            store: AccountRegionStore::new(),
27        }
28    }
29
30    pub fn store(&self) -> AccountRegionStore<XrayState> {
31        self.store.clone()
32    }
33
34    fn get_state(&self, ctx: &RequestContext) -> Arc<XrayState> {
35        self.store.get(&ctx.account_id, &ctx.region)
36    }
37}
38
39impl Default for XrayService {
40    fn default() -> Self {
41        Self::new()
42    }
43}
44
45#[async_trait]
46impl ServiceHandler for XrayService {
47    fn service_name(&self) -> &str {
48        "xray"
49    }
50
51    fn signing_name(&self) -> &str {
52        "xray"
53    }
54
55    fn protocol(&self) -> Protocol {
56        Protocol::RestJson1
57    }
58
59    fn routes(&self) -> Vec<RouteDefinition> {
60        vec![
61            RouteDefinition {
62                method: "POST",
63                path_pattern: "/TraceSegments",
64                operation: "PutTraceSegments",
65                required_query_param: None,
66            },
67            RouteDefinition {
68                method: "POST",
69                path_pattern: "/Traces",
70                operation: "BatchGetTraces",
71                required_query_param: None,
72            },
73            RouteDefinition {
74                method: "POST",
75                path_pattern: "/TraceSummaries",
76                operation: "GetTraceSummaries",
77                required_query_param: None,
78            },
79            RouteDefinition {
80                method: "POST",
81                path_pattern: "/ServiceGraph",
82                operation: "GetServiceGraph",
83                required_query_param: None,
84            },
85            RouteDefinition {
86                method: "POST",
87                path_pattern: "/GetSamplingRules",
88                operation: "GetSamplingRules",
89                required_query_param: None,
90            },
91            RouteDefinition {
92                method: "POST",
93                path_pattern: "/CreateSamplingRule",
94                operation: "CreateSamplingRule",
95                required_query_param: None,
96            },
97            RouteDefinition {
98                method: "POST",
99                path_pattern: "/DeleteSamplingRule",
100                operation: "DeleteSamplingRule",
101                required_query_param: None,
102            },
103            RouteDefinition {
104                method: "POST",
105                path_pattern: "/SamplingTargets",
106                operation: "GetSamplingTargets",
107                required_query_param: None,
108            },
109            RouteDefinition {
110                method: "POST",
111                path_pattern: "/CreateGroup",
112                operation: "CreateGroup",
113                required_query_param: None,
114            },
115            RouteDefinition {
116                method: "POST",
117                path_pattern: "/DeleteGroup",
118                operation: "DeleteGroup",
119                required_query_param: None,
120            },
121            RouteDefinition {
122                method: "POST",
123                path_pattern: "/Groups",
124                operation: "GetGroups",
125                required_query_param: None,
126            },
127        ]
128    }
129
130    async fn handle(
131        &self,
132        operation: &str,
133        input: Value,
134        ctx: &RequestContext,
135    ) -> Result<Value, AwsError> {
136        debug!(operation, "X-Ray request");
137        let state = self.get_state(ctx);
138        match operation {
139            "PutTraceSegments" => operations::put_trace_segments(&state, &input, ctx),
140            "BatchGetTraces" => operations::batch_get_traces(&state, &input, ctx),
141            "GetTraceSummaries" => operations::get_trace_summaries(&state, &input, ctx),
142            "GetServiceGraph" => operations::get_service_graph(&state, &input, ctx),
143            "GetSamplingRules" => operations::get_sampling_rules(&state, &input, ctx),
144            "CreateSamplingRule" => operations::create_sampling_rule(&state, &input, ctx),
145            "DeleteSamplingRule" => operations::delete_sampling_rule(&state, &input, ctx),
146            "GetSamplingTargets" => operations::get_sampling_targets(&state, &input, ctx),
147            "CreateGroup" => operations::create_group(&state, &input, ctx),
148            "DeleteGroup" => operations::delete_group(&state, &input, ctx),
149            "GetGroups" => operations::get_groups(&state, &input, ctx),
150            _ => Err(AwsError::unknown_operation(operation)),
151        }
152    }
153
154    fn snapshot(&self) -> Option<Vec<u8>> {
155        let mut all = state::XrayStateSnapshot {
156            traces: vec![],
157            sampling_rules: Default::default(),
158            groups: Default::default(),
159        };
160        for (_, st) in self.store.iter_all() {
161            let s = st.to_snapshot();
162            all.traces.extend(s.traces);
163            all.sampling_rules.extend(s.sampling_rules);
164            all.groups.extend(s.groups);
165        }
166        serde_json::to_vec(&all).ok()
167    }
168
169    fn restore(&self, data: &[u8]) -> Result<(), String> {
170        let snap: state::XrayStateSnapshot =
171            serde_json::from_slice(data).map_err(|e| e.to_string())?;
172        let st = self.store.get("000000000000", "us-east-1");
173        st.restore_from_snapshot(snap);
174        Ok(())
175    }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Request context for the `xray` service in `us-east-1`.
    fn ctx() -> RequestContext {
        RequestContext::new("xray", "us-east-1")
    }

    /// Minimal executor: polls `f` in a loop until it completes.
    ///
    /// Uses a safe no-op waker built from `std::task::Wake`, replacing the
    /// previous hand-rolled `RawWaker` vtable that needed `unsafe`. The
    /// handler futures driven here are expected to complete without waiting
    /// on external wakeups, so re-polling is sufficient.
    fn block_on<F: std::future::Future>(f: F) -> F::Output {
        use std::sync::Arc;
        use std::task::{Context, Poll, Wake, Waker};

        // Waker whose wake-up notifications are intentionally ignored.
        struct NoopWake;
        impl Wake for NoopWake {
            fn wake(self: Arc<Self>) {}
        }

        let waker = Waker::from(Arc::new(NoopWake));
        let mut cx = Context::from_waker(&waker);
        let mut fut = std::pin::pin!(f);
        loop {
            if let Poll::Ready(v) = fut.as_mut().poll(&mut cx) {
                return v;
            }
        }
    }

    /// Builds a single X-Ray segment document as a JSON string.
    fn segment_doc(trace_id: &str, name: &str, start: f64, end: f64, fault: bool) -> String {
        serde_json::json!({
            "trace_id": trace_id,
            "id": "1234567890123456",
            "name": name,
            "start_time": start,
            "end_time": end,
            "fault": fault,
        })
        .to_string()
    }

    #[test]
    fn put_then_summarize_and_graph() {
        let svc = XrayService::new();
        let ctx = ctx();

        let put = block_on(svc.handle(
            "PutTraceSegments",
            json!({
                "TraceSegmentDocuments": [
                    segment_doc("1-65f5a8a0-1234567890abcdef12345678", "checkout-svc", 100.0, 100.5, false),
                    segment_doc("1-65f5a8a0-1234567890abcdef12345678", "payment-svc", 100.1, 100.4, true),
                ]
            }),
            &ctx,
        ))
        .unwrap();
        assert!(
            put["UnprocessedTraceSegments"]
                .as_array()
                .unwrap()
                .is_empty()
        );

        let summaries = block_on(svc.handle(
            "GetTraceSummaries",
            json!({ "StartTime": 0, "EndTime": 9_999_999_999.0 }),
            &ctx,
        ))
        .unwrap();
        assert_eq!(summaries["TraceSummaries"].as_array().unwrap().len(), 1);
        let s = &summaries["TraceSummaries"][0];
        assert_eq!(s["HasFault"], true);
        assert!(s["Duration"].as_f64().unwrap() >= 0.4);

        let graph = block_on(svc.handle("GetServiceGraph", json!({}), &ctx)).unwrap();
        let svcs = graph["Services"].as_array().unwrap();
        assert_eq!(svcs.len(), 2);

        let traces = block_on(svc.handle(
            "BatchGetTraces",
            json!({ "TraceIds": ["1-65f5a8a0-1234567890abcdef12345678"] }),
            &ctx,
        ))
        .unwrap();
        assert_eq!(traces["Traces"].as_array().unwrap().len(), 1);
        assert_eq!(traces["Traces"][0]["Segments"].as_array().unwrap().len(), 2);
    }

    #[test]
    fn unprocessed_segments_returned_for_invalid_input() {
        let svc = XrayService::new();
        let ctx = ctx();
        let r = block_on(svc.handle(
            "PutTraceSegments",
            json!({ "TraceSegmentDocuments": ["{not json", json!({"id": "x"}).to_string()] }),
            &ctx,
        ))
        .unwrap();
        let unp = r["UnprocessedTraceSegments"].as_array().unwrap();
        assert_eq!(unp.len(), 2);
    }
}