1mod operations;
6pub mod state;
7
8pub use state::XrayState;
9
10use std::sync::Arc;
11
12use async_trait::async_trait;
13use awsim_core::{
14 AccountRegionStore, AwsError, Protocol, RequestContext, RouteDefinition, ServiceHandler,
15};
16use serde_json::Value;
17use tracing::debug;
18
19pub struct XrayService {
20 store: AccountRegionStore<XrayState>,
21}
22
23impl XrayService {
24 pub fn new() -> Self {
25 Self {
26 store: AccountRegionStore::new(),
27 }
28 }
29
30 pub fn store(&self) -> AccountRegionStore<XrayState> {
31 self.store.clone()
32 }
33
34 fn get_state(&self, ctx: &RequestContext) -> Arc<XrayState> {
35 self.store.get(&ctx.account_id, &ctx.region)
36 }
37}
38
39impl Default for XrayService {
40 fn default() -> Self {
41 Self::new()
42 }
43}
44
45#[async_trait]
46impl ServiceHandler for XrayService {
47 fn service_name(&self) -> &str {
48 "xray"
49 }
50
51 fn signing_name(&self) -> &str {
52 "xray"
53 }
54
55 fn protocol(&self) -> Protocol {
56 Protocol::RestJson1
57 }
58
59 fn routes(&self) -> Vec<RouteDefinition> {
60 vec![
61 RouteDefinition {
62 method: "POST",
63 path_pattern: "/TraceSegments",
64 operation: "PutTraceSegments",
65 required_query_param: None,
66 },
67 RouteDefinition {
68 method: "POST",
69 path_pattern: "/Traces",
70 operation: "BatchGetTraces",
71 required_query_param: None,
72 },
73 RouteDefinition {
74 method: "POST",
75 path_pattern: "/TraceSummaries",
76 operation: "GetTraceSummaries",
77 required_query_param: None,
78 },
79 RouteDefinition {
80 method: "POST",
81 path_pattern: "/ServiceGraph",
82 operation: "GetServiceGraph",
83 required_query_param: None,
84 },
85 RouteDefinition {
86 method: "POST",
87 path_pattern: "/GetSamplingRules",
88 operation: "GetSamplingRules",
89 required_query_param: None,
90 },
91 RouteDefinition {
92 method: "POST",
93 path_pattern: "/CreateSamplingRule",
94 operation: "CreateSamplingRule",
95 required_query_param: None,
96 },
97 RouteDefinition {
98 method: "POST",
99 path_pattern: "/DeleteSamplingRule",
100 operation: "DeleteSamplingRule",
101 required_query_param: None,
102 },
103 RouteDefinition {
104 method: "POST",
105 path_pattern: "/SamplingTargets",
106 operation: "GetSamplingTargets",
107 required_query_param: None,
108 },
109 RouteDefinition {
110 method: "POST",
111 path_pattern: "/CreateGroup",
112 operation: "CreateGroup",
113 required_query_param: None,
114 },
115 RouteDefinition {
116 method: "POST",
117 path_pattern: "/DeleteGroup",
118 operation: "DeleteGroup",
119 required_query_param: None,
120 },
121 RouteDefinition {
122 method: "POST",
123 path_pattern: "/Groups",
124 operation: "GetGroups",
125 required_query_param: None,
126 },
127 ]
128 }
129
130 async fn handle(
131 &self,
132 operation: &str,
133 input: Value,
134 ctx: &RequestContext,
135 ) -> Result<Value, AwsError> {
136 debug!(operation, "X-Ray request");
137 let state = self.get_state(ctx);
138 match operation {
139 "PutTraceSegments" => operations::put_trace_segments(&state, &input, ctx),
140 "BatchGetTraces" => operations::batch_get_traces(&state, &input, ctx),
141 "GetTraceSummaries" => operations::get_trace_summaries(&state, &input, ctx),
142 "GetServiceGraph" => operations::get_service_graph(&state, &input, ctx),
143 "GetSamplingRules" => operations::get_sampling_rules(&state, &input, ctx),
144 "CreateSamplingRule" => operations::create_sampling_rule(&state, &input, ctx),
145 "DeleteSamplingRule" => operations::delete_sampling_rule(&state, &input, ctx),
146 "GetSamplingTargets" => operations::get_sampling_targets(&state, &input, ctx),
147 "CreateGroup" => operations::create_group(&state, &input, ctx),
148 "DeleteGroup" => operations::delete_group(&state, &input, ctx),
149 "GetGroups" => operations::get_groups(&state, &input, ctx),
150 _ => Err(AwsError::unknown_operation(operation)),
151 }
152 }
153
154 fn snapshot(&self) -> Option<Vec<u8>> {
155 let mut all = state::XrayStateSnapshot {
156 traces: vec![],
157 sampling_rules: Default::default(),
158 groups: Default::default(),
159 };
160 for (_, st) in self.store.iter_all() {
161 let s = st.to_snapshot();
162 all.traces.extend(s.traces);
163 all.sampling_rules.extend(s.sampling_rules);
164 all.groups.extend(s.groups);
165 }
166 serde_json::to_vec(&all).ok()
167 }
168
169 fn restore(&self, data: &[u8]) -> Result<(), String> {
170 let snap: state::XrayStateSnapshot =
171 serde_json::from_slice(data).map_err(|e| e.to_string())?;
172 let st = self.store.get("000000000000", "us-east-1");
173 st.restore_from_snapshot(snap);
174 Ok(())
175 }
176}
177
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;

    /// Trace ID shared by both segments in `put_then_summarize_and_graph`.
    const TRACE_ID: &str = "1-65f5a8a0-1234567890abcdef12345678";

    /// Builds the request context used by every test.
    fn ctx() -> RequestContext {
        RequestContext::new("xray", "us-east-1")
    }

    /// Drives a future to completion on the current thread.
    ///
    /// It polls in a loop with a waker that does nothing. This busy-spins
    /// rather than parking, so it only suits futures that become ready without
    /// an external wake-up, such as the handler calls in these tests.
    fn block_on<F: std::future::Future>(f: F) -> F::Output {
        use std::task::{Context, Poll, RawWaker, RawWakerVTable, Waker};

        static NOOP_VTABLE: RawWakerVTable =
            RawWakerVTable::new(clone_noop_waker, ignore, ignore, ignore);

        fn make_noop_waker() -> RawWaker {
            RawWaker::new(std::ptr::null(), &NOOP_VTABLE)
        }
        fn clone_noop_waker(_: *const ()) -> RawWaker {
            make_noop_waker()
        }
        fn ignore(_: *const ()) {}

        // SAFETY: every vtable entry ignores its (null) data pointer and does
        // nothing, and `clone` returns another identical no-op waker, so the
        // RawWaker/RawWakerVTable contract is trivially upheld.
        let waker = unsafe { Waker::from_raw(make_noop_waker()) };
        let mut cx = Context::from_waker(&waker);
        let mut fut = std::pin::pin!(f);
        loop {
            match fut.as_mut().poll(&mut cx) {
                Poll::Ready(out) => break out,
                Poll::Pending => {}
            }
        }
    }

    /// Serializes a minimal segment document with a fixed segment `id`.
    fn segment_doc(trace_id: &str, name: &str, start: f64, end: f64, fault: bool) -> String {
        let doc = json!({
            "trace_id": trace_id,
            "id": "1234567890123456",
            "name": name,
            "start_time": start,
            "end_time": end,
            "fault": fault,
        });
        doc.to_string()
    }

    #[test]
    fn put_then_summarize_and_graph() {
        let svc = XrayService::new();
        let ctx = ctx();
        let call = |op: &str, input: Value| block_on(svc.handle(op, input, &ctx)).unwrap();

        let put = call(
            "PutTraceSegments",
            json!({
                "TraceSegmentDocuments": [
                    segment_doc(TRACE_ID, "checkout-svc", 100.0, 100.5, false),
                    segment_doc(TRACE_ID, "payment-svc", 100.1, 100.4, true),
                ]
            }),
        );
        let unprocessed = put["UnprocessedTraceSegments"].as_array().unwrap();
        assert!(unprocessed.is_empty());

        let summaries = call(
            "GetTraceSummaries",
            json!({ "StartTime": 0, "EndTime": 9_999_999_999.0 }),
        );
        let summary_list = summaries["TraceSummaries"].as_array().unwrap();
        assert_eq!(summary_list.len(), 1);
        let summary = &summary_list[0];
        assert_eq!(summary["HasFault"], true);
        assert!(summary["Duration"].as_f64().unwrap() >= 0.4);

        let graph = call("GetServiceGraph", json!({}));
        assert_eq!(graph["Services"].as_array().unwrap().len(), 2);

        let traces = call("BatchGetTraces", json!({ "TraceIds": [TRACE_ID] }));
        let trace_list = traces["Traces"].as_array().unwrap();
        assert_eq!(trace_list.len(), 1);
        assert_eq!(trace_list[0]["Segments"].as_array().unwrap().len(), 2);
    }

    #[test]
    fn unprocessed_segments_returned_for_invalid_input() {
        let svc = XrayService::new();
        let ctx = ctx();
        // The first document is not valid JSON. The second is JSON but carries
        // only an `id` (no `trace_id`, `name`, or timestamps).
        let malformed = "{not json";
        let incomplete = json!({"id": "x"}).to_string();
        let resp = block_on(svc.handle(
            "PutTraceSegments",
            json!({ "TraceSegmentDocuments": [malformed, incomplete] }),
            &ctx,
        ))
        .unwrap();
        let unprocessed = resp["UnprocessedTraceSegments"].as_array().unwrap();
        assert_eq!(unprocessed.len(), 2);
    }
}